diff --git a/apps/hexagon_launcher/launcher_core.cc b/apps/hexagon_launcher/launcher_core.cc
index fa2c3d8e3300..3bf6ce23cf8d 100644
--- a/apps/hexagon_launcher/launcher_core.cc
+++ b/apps/hexagon_launcher/launcher_core.cc
@@ -163,7 +163,7 @@ tvm::runtime::Module load_module(const std::string& file_name) {
   return tvm::runtime::Module();
 }

-std::ostream& operator<<(std::ostream& os, const tvm::Array& strings) {
+std::ostream& operator<<(std::ostream& os, const tvm::ffi::Array& strings) {
   os << '[';
   for (int i = 0, e = strings.size(); i != e; ++i) {
     if (i != 0) os << ',';
@@ -191,7 +191,7 @@ tvm::runtime::Module create_graph_executor(const std::string& graph_json,
 tvm::runtime::Module create_aot_executor(tvm::runtime::Module factory_module, tvm::Device device) {
   tvm::ffi::Function list_modules = get_module_func(factory_module, "list_module_names");
-  tvm::Array module_names = list_modules();
+  tvm::ffi::Array module_names = list_modules();
   if (module_names.size() != 1) {
     LOG(WARNING) << __func__ << ": expecting single module, got: " << module_names
                  << ", using " << module_names[0];
diff --git a/apps/ios_rpc/tvmrpc/TVMRuntime.mm b/apps/ios_rpc/tvmrpc/TVMRuntime.mm
index 09ee55390959..47e82a7f96be 100644
--- a/apps/ios_rpc/tvmrpc/TVMRuntime.mm
+++ b/apps/ios_rpc/tvmrpc/TVMRuntime.mm
@@ -116,7 +116,7 @@ void Init(const std::string& name) {
   namespace refl = tvm::ffi::reflection;
   refl::GlobalDef().def_packed("ffi.Module.load_from_file.dylib_custom",
                               [](ffi::PackedArgs args, ffi::Any* rv) {
-                                auto n = make_object();
+                                auto n = ffi::make_object();
                                 n->Init(args[0]);
                                 *rv = tvm::ffi::CreateLibraryModule(n);
                               });
diff --git a/docs/arch/pass_infra.rst b/docs/arch/pass_infra.rst
index 30e28d20db28..e1afb97b9a34 100644
--- a/docs/arch/pass_infra.rst
+++ b/docs/arch/pass_infra.rst
@@ -93,9 +93,9 @@ needs to be executed when running under a user-provided optimization level. The
 .. code:: c++

   class PassInfoNode : public Object {
-    String name;
+    ffi::String name;
     int opt_level;
-    Array required;
+    ffi::Array required;
   };

 PassContext
 ^^^^^^^^^^^
@@ -125,11 +125,11 @@ Python APIs to create a compilation pipeline using pass context.

   class PassContextNode : public Object {
    public:
     int opt_level{2};
-    tvm::Array required_pass;
-    tvm::Array disabled_pass;
-    mutable Optional diag_ctx;
-    Map config;
-    Array instruments;
+    tvm::ffi::Array required_pass;
+    tvm::ffi::Array disabled_pass;
+    mutable ffi::Optional diag_ctx;
+    ffi::Map config;
+    ffi::Array instruments;
   };

   class PassContext : public NodeRef {
@@ -262,7 +262,7 @@ of passes for execution.

   class SequentialPassNode : PassNode {
     PassInfo pass_info;
     // Passes need to be executed.
-    Array passes;
+    ffi::Array passes;
     bool PassEnabled(const PassInfo& info) const;
     Module operator()(const Module& mod, const PassContext& pass_ctx) const final;
   };
@@ -321,22 +321,22 @@ favorably use Python APIs to create a specific pass object.
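For orientation, here is a minimal sketch of what a C++ call site looks like after this rename. It assumes the template parameters elided throughout this patch keep their documented shapes: ``IRModule(IRModule, PassContext)`` for the pass body, and ``ffi::String`` / ``ffi::Array<ffi::String>`` for the name and required list.

.. code:: c++

  #include <tvm/ir/transform.h>

  // Sketch only: a no-op module pass written against the ffi-qualified spellings.
  tvm::transform::Pass MakeIdentityPass() {
    auto body = [](tvm::IRModule mod, tvm::transform::PassContext ctx) {
      return mod;  // a real pass would rewrite `mod` here
    };
    return tvm::transform::CreateModulePass(
        body, /*opt_level=*/0,
        tvm::ffi::String("Identity"),
        tvm::ffi::Array<tvm::ffi::String>{});
  }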
   Pass CreateFunctionPass(
       std::function pass_func,
       int opt_level,
-      String name,
-      Array required);
+      ffi::String name,
+      ffi::Array required);

   Pass CreatePrimFuncPass(
       std::function pass_func,
       int opt_level,
-      String name,
-      Array required);
+      ffi::String name,
+      ffi::Array required);

   Pass CreateModulePass(
       std::function pass_func,
       int opt_level,
-      String name,
-      Array required);
+      ffi::String name,
+      ffi::Array required);

-  Pass Sequential(tvm::Array passes, PassInfo pass_info);
+  Pass Sequential(tvm::ffi::Array passes, PassInfo pass_info);

 Pass Registration
 ^^^^^^^^^^^^^^^^^
@@ -440,7 +440,7 @@ Multiple ``PassInstrument`` instances can be registered into a single
   class PassInstrumentNode : public Object {
    public:
-    String name;
+    ffi::String name;
     virtual void EnterPassContext() const = 0;
     virtual void ExitPassContext() const = 0;
     virtual bool ShouldRun(const IRModule& mod, const transform::PassInfo& info) const = 0;
diff --git a/ffi/include/tvm/ffi/cast.h b/ffi/include/tvm/ffi/cast.h
index c75d4a075f97..f70df9fe7ca2 100644
--- a/ffi/include/tvm/ffi/cast.h
+++ b/ffi/include/tvm/ffi/cast.h
@@ -73,8 +73,5 @@ inline ObjectPtr GetObjectPtr(ObjectType* ptr) {
   return details::ObjectUnsafe::ObjectPtrFromUnowned(ptr);
 }
 }  // namespace ffi
-
-using ffi::GetObjectPtr;
-using ffi::GetRef;
 }  // namespace tvm
 #endif  // TVM_FFI_CAST_H_
diff --git a/ffi/include/tvm/ffi/container/array.h b/ffi/include/tvm/ffi/container/array.h
index 077a55d6d172..7dbcc1f0189e 100644
--- a/ffi/include/tvm/ffi/container/array.h
+++ b/ffi/include/tvm/ffi/container/array.h
@@ -1140,7 +1140,5 @@ inline constexpr bool type_contains_v, Array> =
 type_contains_v, Map> =
 }  // namespace details
 }  // namespace ffi
-
-using ffi::Map;
 }  // namespace tvm
 #endif  // TVM_FFI_CONTAINER_MAP_H_
diff --git a/ffi/include/tvm/ffi/container/variant.h b/ffi/include/tvm/ffi/container/variant.h
index 5bea42cb0592..5f66d73a1845 100644
--- a/ffi/include/tvm/ffi/container/variant.h
+++ b/ffi/include/tvm/ffi/container/variant.h
@@ -298,7 +298,5 @@ template
 inline constexpr bool type_contains_v, T> = (type_contains_v || ...);
 }  // namespace details
 }  // namespace ffi
-
-using ffi::Variant;
 }  // namespace tvm
 #endif  // TVM_FFI_CONTAINER_VARIANT_H_
diff --git a/ffi/include/tvm/ffi/dtype.h b/ffi/include/tvm/ffi/dtype.h
index 8da30dc5d60b..a9e09d229372 100644
--- a/ffi/include/tvm/ffi/dtype.h
+++ b/ffi/include/tvm/ffi/dtype.h
@@ -38,8 +38,6 @@ namespace ffi {
  * \brief Extension code beyond the DLDataType.
  *
  * This class is always consistent with the DLPack.
- *
- * TODO(tvm-team): update to latest DLPack types.
  */
 enum DLExtDataTypeCode { kDLExtCustomBegin = 129 };
diff --git a/ffi/include/tvm/ffi/memory.h b/ffi/include/tvm/ffi/memory.h
index 2e4f3cd6b4e1..1fa9d6539079 100644
--- a/ffi/include/tvm/ffi/memory.h
+++ b/ffi/include/tvm/ffi/memory.h
@@ -225,7 +225,5 @@ inline ObjectPtr make_inplace_array_object(size_t num_elems, Args&&..
 }
 }  // namespace ffi
-
-using ffi::make_object;
 }  // namespace tvm
 #endif  // TVM_FFI_MEMORY_H_
diff --git a/ffi/include/tvm/ffi/optional.h b/ffi/include/tvm/ffi/optional.h
index 3f406d41810b..f93a0f0d555f 100644
--- a/ffi/include/tvm/ffi/optional.h
+++ b/ffi/include/tvm/ffi/optional.h
@@ -410,7 +410,5 @@ class Optional>> : public Object
 }
 };
 }  // namespace ffi
-
-using ffi::Optional;
 }  // namespace tvm
 #endif  // TVM_FFI_OPTIONAL_H_
diff --git a/ffi/include/tvm/ffi/string.h b/ffi/include/tvm/ffi/string.h
index 8da70e5996ad..41720d0d5610 100644
--- a/ffi/include/tvm/ffi/string.h
+++ b/ffi/include/tvm/ffi/string.h
@@ -993,9 +993,6 @@ inline std::ostream& operator<<(std::ostream& out, const String& input) {
 }
 /// \endcond
 }  // namespace ffi
-
-using ffi::Bytes;
-using ffi::String;
 }  // namespace tvm

 /// \cond Doxygen_Suppress
diff --git a/include/tvm/arith/analyzer.h b/include/tvm/arith/analyzer.h
index 52e9e7209e89..58fde808f068 100644
--- a/include/tvm/arith/analyzer.h
+++ b/include/tvm/arith/analyzer.h
@@ -582,7 +582,7 @@ class IntSetAnalyzer {
    * \param dom_map The domain map to indicate which variable to relax.
    * \return the result of the analysis.
    */
-  TVM_DLL IntSet operator()(const PrimExpr& expr, const Map& dom_map);
+  TVM_DLL IntSet operator()(const PrimExpr& expr, const ffi::Map& dom_map);

   /*!
    * \brief Find a symbolic integer set that contains all possible
@@ -704,7 +704,7 @@ class TVM_DLL Analyzer {
    * expression. This option should not be used if there is any dependency
    * between variables.
    */
-  void Bind(const Map& variables, bool allow_override = false);
+  void Bind(const ffi::Map& variables, bool allow_override = false);

   /*!
    * \brief Whether we can prove expr >= val.
diff --git a/include/tvm/arith/bound.h b/include/tvm/arith/bound.h
index cf84b9a3a641..6cde90b0b8e5 100644
--- a/include/tvm/arith/bound.h
+++ b/include/tvm/arith/bound.h
@@ -53,8 +53,8 @@ using tir::VarNode;
  * The deduced bound must imply e for all values in relax_map
  * \return An integer set that always satisfies the condition.
  */
-IntSet DeduceBound(PrimExpr v, PrimExpr cond, const Map& hint_map,
-                   const Map& relax_map);
+IntSet DeduceBound(PrimExpr v, PrimExpr cond, const ffi::Map& hint_map,
+                   const ffi::Map& relax_map);
 /*!
  * \brief Same as DeduceBound with unordered_map signature.
  *
diff --git a/include/tvm/arith/int_set.h b/include/tvm/arith/int_set.h
index 702edba1a462..012f9a3a4479 100644
--- a/include/tvm/arith/int_set.h
+++ b/include/tvm/arith/int_set.h
@@ -170,12 +170,12 @@ class IntSet : public ObjectRef {
 // Integer set legacy API.
 //------------------------------------------------
 /*!
- * \brief Convert std::unordered_map to Map
+ * \brief Convert std::unordered_map to ffi::Map
  *
  * \param dom_map The domain map to convert.
  * \return The converted map.
  */
-Map ConvertDomMap(const std::unordered_map& dom_map);
+ffi::Map ConvertDomMap(const std::unordered_map& dom_map);
 /*!
  * \brief Find a symbolic integer set that contains all possible values of
  * e given the domain of each iteration variable.
  *
  * \param e The expression to be evaluated.
  * \param dom_map The domain of each variable.
  * \return An integer set that can cover all the possible values of e.
  */
-IntSet EvalSet(PrimExpr e, const Map& dom_map);
+IntSet EvalSet(PrimExpr e, const ffi::Map& dom_map);
 /*!
  * \brief Find a symbolic integer set that contains all possible values of
  * e given the domain of each variable.
  *
  * \param e The expression to be evaluated.
  * \param dom_map The domain of each variable.
  * \return An integer set that can cover all the possible values of e.
  */
-IntSet EvalSet(PrimExpr e, const Map& dom_map);
+IntSet EvalSet(PrimExpr e, const ffi::Map& dom_map);
 /*!
  * \brief Same as EvalSet, but takes unordered_map
  *
@@ -210,7 +210,7 @@ IntSet EvalSet(PrimExpr e, const std::unordered_map
  * \param dom_map The domain of each variable.
  * \return An integer set that can cover all the possible values.
  */
-IntSet EvalSet(Range r, const Map& dom_map);
+IntSet EvalSet(Range r, const ffi::Map& dom_map);

 /*!
  * \brief Find a symbolic integer set that contains the union over
@@ -230,13 +230,13 @@ IntSet EvalSet(IntSet s, const std::unordered_map& dom_m
  */
 IntSet EvalSet(Range r, const std::unordered_map& dom_map);
 /*!
- * \brief Same as EvalSet, but takes Array
+ * \brief Same as EvalSet, but takes ffi::Array
  *
  * \param region The range to be evaluated.
  * \param dom_map The domain of each variable.
  * \return An array of integer sets that can cover all the possible values.
  */
-Array EvalSet(const Array& region, const Map& dom_map);
+ffi::Array EvalSet(const ffi::Array& region, const ffi::Map& dom_map);
 /*! \brief Map from Expr to IntSet */
 using ExprIntSetMap = std::unordered_map;
 /*!
@@ -255,42 +255,42 @@ ExprIntSetMap EvalSetForEachSubExpr(PrimExpr e,
  * \param sets The sets to be combined
  * \return the set after union
  */
-IntSet Union(const Array& sets);
+IntSet Union(const ffi::Array& sets);

 /*!
  * \brief The union of N-dimensional integer sets
  * \param nd_int_sets A list of N-dimensional integer sets
  * \return An N-dimensional integer set as the result of union
  */
-Array UnionRegion(const Array>& nd_int_sets);
+ffi::Array UnionRegion(const ffi::Array>& nd_int_sets);

 /*!
  * \brief Create a lower-bound of union set, where some of the segments may be dropped
  * \param sets The sets to be combined
  * \return the set after union
  */
-IntSet UnionLowerBound(const Array& sets);
+IntSet UnionLowerBound(const ffi::Array& sets);

 /*!
  * \brief The union of N-dimensional integer sets
  * \param nd_int_sets A list of N-dimensional integer sets
  * \return An N-dimensional integer set as the result of union
  */
-Array UnionRegionLowerBound(const Array>& nd_int_sets);
+ffi::Array UnionRegionLowerBound(const ffi::Array>& nd_int_sets);

 /*!
  * \brief Create an intersected set of all sets
  * \param sets The sets to be intersected
  * \return the set after intersected
  */
-IntSet Intersect(const Array& sets);
+IntSet Intersect(const ffi::Array& sets);

 /*!
  * \brief Converts the Ranges to IntSets
  * \param var_dom The ranges of variables
  * \return The integer sets of the variables
  */
-Map AsIntSet(const Map& var_dom);
+ffi::Map AsIntSet(const ffi::Map& var_dom);

 /*!
  * \brief Analyze the region with affine map, given the domain of variables and their predicate.
@@ -302,10 +302,9 @@ Map AsIntSet(const Map& var_dom);
  * \return std::nullopt if the detection fails, or an array of arith::IntSet as the result of
  * analysis
  */
-TVM_DLL Optional> EstimateRegionStrictBound(const Array& region,
-                                            const Map& var_dom,
-                                            const PrimExpr& predicate,
-                                            arith::Analyzer* analyzer);
+TVM_DLL ffi::Optional> EstimateRegionStrictBound(
+    const ffi::Array& region, const ffi::Map& var_dom, const PrimExpr& predicate,
+    arith::Analyzer* analyzer);

 /*!
  * \brief Analyze the region with affine map, given the domain of variables and their predicate.
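As a usage sketch for the ``EvalSet`` family above — assuming the elided map type is ``ffi::Map<tir::Var, arith::IntSet>``, which is what the ``ConvertDomMap`` helper suggests:

.. code:: c++

  tvm::tir::Var i("i");
  tvm::ffi::Map<tvm::tir::Var, tvm::arith::IntSet> dom_map;
  dom_map.Set(i, tvm::arith::IntSet::FromMinExtent(0, 16));
  // All values of i + 1 with i in [0, 16): the interval [1, 16].
  tvm::arith::IntSet s = tvm::arith::EvalSet(i + 1, dom_map);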
@@ -317,10 +316,9 @@ TVM_DLL Optional> EstimateRegionStrictBound(const Array& re
  * \return std::nullopt if the detection fails, or an array of arith::IntSet as the result of
  * analysis
  */
-TVM_DLL Optional> EstimateRegionLowerBound(const Array& region,
-                                           const Map& var_dom,
-                                           const PrimExpr& predicate,
-                                           arith::Analyzer* analyzer);
+TVM_DLL ffi::Optional> EstimateRegionLowerBound(
+    const ffi::Array& region, const ffi::Map& var_dom, const PrimExpr& predicate,
+    arith::Analyzer* analyzer);

 /*!
  * \brief Analyze the region with affine map, given the domain of variables and their predicate
@@ -332,10 +330,10 @@ TVM_DLL Optional> EstimateRegionLowerBound(const Array& reg
  * \param analyzer The analyzer used
  * \return an array of arith::IntSet as the result of analysis
  */
-TVM_DLL Array EstimateRegionUpperBound(const Array& region,
-                                       const Map& var_dom,
-                                       const PrimExpr& predicate,
-                                       arith::Analyzer* analyzer);
+TVM_DLL ffi::Array EstimateRegionUpperBound(const ffi::Array& region,
+                                            const ffi::Map& var_dom,
+                                            const PrimExpr& predicate,
+                                            arith::Analyzer* analyzer);

 }  // namespace arith
 }  // namespace tvm
diff --git a/include/tvm/arith/int_solver.h b/include/tvm/arith/int_solver.h
index 6dfc2f0ecb88..eb1e8650e174 100644
--- a/include/tvm/arith/int_solver.h
+++ b/include/tvm/arith/int_solver.h
@@ -58,9 +58,9 @@ constexpr int kSimplifyRewriteCanonicalRewrite = 3;
 class IntGroupBoundsNode : public Object {
  public:
   PrimExpr coef;
-  Array lower;
-  Array equal;
-  Array upper;
+  ffi::Array lower;
+  ffi::Array equal;
+  ffi::Array upper;

   static void RegisterReflection() {
     namespace refl = tvm::ffi::reflection;
@@ -93,8 +93,8 @@ class IntGroupBounds : public ObjectRef {
    * \param equal equalities
    * \param upper the upper bounds (include)
    */
-  TVM_DLL IntGroupBounds(PrimExpr coef, Array lower, Array equal,
-                         Array upper);
+  TVM_DLL IntGroupBounds(PrimExpr coef, ffi::Array lower, ffi::Array equal,
+                         ffi::Array upper);

   /*!
    * \brief Construct bounds from a range.
@@ -106,7 +106,7 @@ class IntGroupBounds : public ObjectRef {
   /*!
    * \brief Perform substitution on all components of the struct.
    */
-  IntGroupBounds Substitute(const Map& subst) const;
+  IntGroupBounds Substitute(const ffi::Map& subst) const;

   /*!
    * \brief Find the best range from the grouped bounds.
@@ -114,7 +114,7 @@
    * \return The best range (has the least difference between the lower bound and upper bound).
    *         undefined if (-inf, +inf).
    */
-  Range FindBestRange(const Map& vranges_addl = {}) const;
+  Range FindBestRange(const ffi::Map& vranges_addl = {}) const;

   /*!
    * \brief Combine the bounds with another range.
@@ -134,14 +134,14 @@
 class IntConstraintsNode : public Object {
  public:
   // e.g., \alpha, \beta, must be integers
-  Array variables;
+  ffi::Array variables;
   // e.g., 1 <= \alpha <= N, etc.
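Continuing the constraint API above, a hedged sketch of assembling an ``IntConstraints`` system — the elided element types are assumed to be ``tir::Var``, ``Range``, and ``PrimExpr``, matching the field comments:

.. code:: c++

  tvm::tir::Var x("x"), y("y");
  tvm::ffi::Array<tvm::tir::Var> vars{x, y};
  tvm::ffi::Map<tvm::tir::Var, tvm::Range> ranges;
  ranges.Set(x, tvm::Range::FromMinExtent(0, 10));
  ranges.Set(y, tvm::Range::FromMinExtent(0, 10));
  // One linear relation: x + y <= 10.
  tvm::ffi::Array<tvm::PrimExpr> relations{x + y <= 10};
  tvm::arith::IntConstraints system(vars, ranges, relations);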
   // it is absolutely ok to include ranges for parameters
   // (variables that are not in this->variables) in this map
-  Map ranges;
+  ffi::Map ranges;
   // linear equalities or inequalities
   // e.g., A \alpha = \beta or A \alpha <= \beta
-  Array relations;
+  ffi::Array relations;

   static void RegisterReflection() {
     namespace refl = tvm::ffi::reflection;
@@ -170,7 +170,8 @@ class IntConstraints : public ObjectRef {
    * \param relations The linear relations between the variables
    *                  (either equations or inequalities)
    */
-  TVM_DLL IntConstraints(Array variables, Map ranges, Array relations);
+  TVM_DLL IntConstraints(ffi::Array variables, ffi::Map ranges,
+                         ffi::Array relations);

   TVM_DEFINE_OBJECT_REF_METHODS(IntConstraints, ObjectRef, IntConstraintsNode);
 };
@@ -193,8 +194,8 @@ class IntConstraintsTransformNode : public Object {
  public:
   IntConstraints src;
   IntConstraints dst;
-  Map src_to_dst;
-  Map dst_to_src;
+  ffi::Map src_to_dst;
+  ffi::Map dst_to_src;

   static void RegisterReflection() {
     namespace refl = tvm::ffi::reflection;
@@ -228,7 +229,8 @@ class IntConstraintsTransform : public ObjectRef {
    *        e.g., {m -> a, n -> -b}
    */
   TVM_DLL IntConstraintsTransform(IntConstraints src, IntConstraints dst,
-                                  Map src_to_dst, Map dst_to_src);
+                                  ffi::Map src_to_dst,
+                                  ffi::Map dst_to_src);

   /*!
    * \brief Chain-compose two IntConstraintsTransform together.
@@ -242,7 +244,7 @@
   TVM_DEFINE_OBJECT_REF_METHODS(IntConstraintsTransform, ObjectRef, IntConstraintsTransformNode);
 };

-typedef std::pair, Array> PartialSolvedInequalities;
+typedef std::pair, ffi::Array> PartialSolvedInequalities;

 /*!
  * \brief Obtain Smith Normal Form of linear equation A x = y.
@@ -301,8 +303,9 @@ PartialSolvedInequalities SolveLinearInequalities(const IntConstraints& system_t
  * \param bounds grouped boundary of the variables.
  * \param relations other relations.
  */
-Array AsConditions(const Array& variables, const Map& bounds,
-                   const Array& relations);
+ffi::Array AsConditions(const ffi::Array& variables,
+                        const ffi::Map& bounds,
+                        const ffi::Array& relations);

 /*!
  * \brief Solve linear inequalities and infer the range of each variable.
diff --git a/include/tvm/arith/iter_affine_map.h b/include/tvm/arith/iter_affine_map.h
index 25f8e14a7f7b..566b67bf5644 100644
--- a/include/tvm/arith/iter_affine_map.h
+++ b/include/tvm/arith/iter_affine_map.h
@@ -197,7 +197,7 @@ class IterSplitExpr : public IterMapExpr {
 class IterSumExprNode : public IterMapExprNode {
  public:
   /*! \brief The args to the sum. */
-  Array args;
+  ffi::Array args;
   /*! \brief The base offset. */
   PrimExpr base;
@@ -224,7 +224,7 @@ class IterSumExpr : public IterMapExpr {
    * \param args The args to the sum.
    * \param base The base offset.
    */
-  TVM_DLL IterSumExpr(Array args, PrimExpr base);
+  TVM_DLL IterSumExpr(ffi::Array args, PrimExpr base);

   TVM_DEFINE_OBJECT_REF_METHODS(IterSumExpr, IterMapExpr, IterSumExprNode);
   TVM_DEFINE_OBJECT_REF_COW_METHOD(IterSumExprNode);
@@ -246,11 +246,11 @@ enum IterMapLevel {
 class IterMapResultNode : public Object {
  public:
   // The detected pattern if a match exists.
-  Array indices;
+  ffi::Array indices;

   // Any errors that occurred while converting the input indices. If
   // the array is empty, the conversion was successful.
-  Array errors;
+  ffi::Array errors;

 /*!
  \brief Boolean expression indicating if a specific value w
 *
@@ -281,7 +281,7 @@ class IterMapResultNode : public Object {
 class IterMapResult : public ObjectRef {
  public:
   // constructor
-  IterMapResult() { data_ = make_object(); }
+  IterMapResult() { data_ = ffi::make_object(); }

   /*! \return mutable pointers to the node. */
   IterMapResultNode* operator->() const { return static_cast(get_mutable()); }
@@ -310,9 +310,10 @@
  * \return The detected iteration result.
  *         The return object's .indices is empty on failure.
  */
-IterMapResult DetectIterMap(const Array& indices, const Map& input_iters,
-                            const PrimExpr& predicate, IterMapLevel check_level,
-                            arith::Analyzer* analyzer, bool simplify_trivial_iterators = true);
+IterMapResult DetectIterMap(const ffi::Array& indices,
+                            const ffi::Map& input_iters, const PrimExpr& predicate,
+                            IterMapLevel check_level, arith::Analyzer* analyzer,
+                            bool simplify_trivial_iterators = true);

 /*!
  * \brief Use IterVarMap detector to rewrite and simplify the indices
@@ -325,9 +326,11 @@
  */
-Array IterMapSimplify(const Array& indices, const Map& input_iters,
-                      const PrimExpr& input_pred, IterMapLevel check_level,
-                      arith::Analyzer* analyzer, bool simplify_trivial_iterators = true);
+ffi::Array IterMapSimplify(const ffi::Array& indices,
+                           const ffi::Map& input_iters,
+                           const PrimExpr& input_pred, IterMapLevel check_level,
+                           arith::Analyzer* analyzer,
+                           bool simplify_trivial_iterators = true);

 /*!
  * \brief Apply the inverse of the affine transformation to the outputs.
@@ -349,8 +352,8 @@
  */
-Map InverseAffineIterMap(const Array& iter_map,
-                         const Array outputs);
+ffi::Map InverseAffineIterMap(const ffi::Array& iter_map,
+                              const ffi::Array outputs);

 /*!
  * \brief Detect if bindings can be written as
@@ -379,11 +382,12 @@
  *         len(bindings): the predicate of outer space and inner space
  *         Empty array if no match can be found.
  */
-Array> SubspaceDivide(const Array& bindings,
-                      const Map& input_iters,
-                      const Array& sub_iters, const PrimExpr& predicate,
-                      IterMapLevel check_level, arith::Analyzer* analyzer,
-                      bool simplify_trivial_iterators = true);
+ffi::Array> SubspaceDivide(const ffi::Array& bindings,
+                           const ffi::Map& input_iters,
+                           const ffi::Array& sub_iters,
+                           const PrimExpr& predicate, IterMapLevel check_level,
+                           arith::Analyzer* analyzer,
+                           bool simplify_trivial_iterators = true);

 /*!
  * \brief Given an expression that may contain IterMapExpr, transform it to normal PrimExpr.
@@ -408,7 +412,7 @@ PrimExpr NormalizeIterMapToExpr(const PrimExpr& expr);
  * \param analyzer The input analyzer.
  * \note This function is useful to detect iterator stride patterns.
  */
-IterSumExpr NormalizeToIterSum(PrimExpr index, const Map& input_iters,
+IterSumExpr NormalizeToIterSum(PrimExpr index, const ffi::Map& input_iters,
                                arith::Analyzer* analyzer);

 }  // namespace arith
diff --git a/include/tvm/arith/pattern.h b/include/tvm/arith/pattern.h
index 5e1165d509c4..254c1d0933ec 100644
--- a/include/tvm/arith/pattern.h
+++ b/include/tvm/arith/pattern.h
@@ -37,7 +37,7 @@ namespace arith {
  * \param vars List of variables to be used in detection.
  * \return [coeff[i]] if it is possible, empty array if it is not.
  */
-Array DetectLinearEquation(const PrimExpr& e, const Array& vars);
+ffi::Array DetectLinearEquation(const PrimExpr& e, const ffi::Array& vars);

 /*!
  * \brief Detect if expression corresponds to clip bound of the vars
@@ -47,7 +47,7 @@ Array DetectLinearEquation(const PrimExpr& e, const Array& v
  * \return concat([min_value[i], max_value[i]]), None is returned if there is no min or max value
  *         return empty if e does not match the pattern.
  */
-Array DetectClipBound(const PrimExpr& e, const Array& vars);
+ffi::Array DetectClipBound(const PrimExpr& e, const ffi::Array& vars);

 }  // namespace arith
 }  // namespace tvm
diff --git a/include/tvm/ir/analysis.h b/include/tvm/ir/analysis.h
index ad95f2f0ebb5..5879f34633a2 100644
--- a/include/tvm/ir/analysis.h
+++ b/include/tvm/ir/analysis.h
@@ -55,7 +55,7 @@ class CalleeCollector {
   virtual void Mark(GlobalVar gvar) = 0;
 };

-Map> CollectCallMap(const IRModule& mod);
+ffi::Map> CollectCallMap(const IRModule& mod);

 }  // namespace ir
 }  // namespace tvm
diff --git a/include/tvm/ir/attrs.h b/include/tvm/ir/attrs.h
index 2553116634a2..55576549169c 100644
--- a/include/tvm/ir/attrs.h
+++ b/include/tvm/ir/attrs.h
@@ -68,11 +68,11 @@ inline DataType NullValue() {
 class AttrFieldInfoNode : public Object {
  public:
   /*! \brief name of the field */
-  String name;
+  ffi::String name;
   /*! \brief type docstring information in str. */
-  String type_info;
+  ffi::String type_info;
   /*! \brief detailed description of the type */
-  String description;
+  ffi::String description;

   static void RegisterReflection() {
     namespace rfl = ffi::reflection;
@@ -145,7 +145,7 @@ class Attrs : public ObjectRef {
 class DictAttrsNode : public BaseAttrsNode {
  public:
   /*! \brief internal attrs map */
-  Map dict;
+  ffi::Map dict;

   static void RegisterReflection() {
     namespace rfl = ffi::reflection;
@@ -169,7 +169,7 @@ class DictAttrs : public Attrs {
    * \brief Construct an Attrs backed by DictAttrsNode.
    * \param dict The attributes.
    */
-  TVM_DLL explicit DictAttrs(Map dict = {});
+  TVM_DLL explicit DictAttrs(ffi::Map dict = {});

   // Utils for accessing attributes
   // This needs to be on DictAttrs, not DictAttrsNode because we return the default
@@ -194,9 +194,9 @@ class DictAttrs : public Attrs {
    * \endcode
    */
   template
-  Optional GetAttr(
+  ffi::Optional GetAttr(
       const std::string& attr_key,
-      Optional default_value = Optional(std::nullopt)) const {
+      ffi::Optional default_value = ffi::Optional(std::nullopt)) const {
     if (!defined()) return default_value;
     const DictAttrsNode* node = this->as();
     auto it = node->dict.find(attr_key);
@@ -208,8 +208,8 @@
   }
   // variant that uses TObjectRef to enable implicit conversion to default value.
   template
-  Optional GetAttr(const std::string& attr_key, TObjectRef default_value) const {
-    return GetAttr(attr_key, Optional(default_value));
+  ffi::Optional GetAttr(const std::string& attr_key, TObjectRef default_value) const {
+    return GetAttr(attr_key, ffi::Optional(default_value));
   }
   /*!
    * \brief Check whether the function has a non-zero integer attr.
@@ -248,7 +248,7 @@
  *
  * \returns The new DictAttrs with updated attributes.
  */
-DictAttrs WithAttrs(DictAttrs attrs, Map new_attrs);
+DictAttrs WithAttrs(DictAttrs attrs, ffi::Map new_attrs);

 /*!
  * \brief Copy the DictAttrs, but overrides a single attribute.
@@ -261,10 +261,10 @@ DictAttrs WithAttrs(DictAttrs attrs, Map new_attrs);
  *
  * \returns The new DictAttrs with updated attributes.
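A short sketch of the ``Optional``-returning lookup pattern that ``GetAttr`` implements (the attribute key here is just an example, and ``use`` stands in for caller code):

.. code:: c++

  tvm::DictAttrs attrs;  // possibly empty
  // Returns ffi::Optional<ffi::String>: std::nullopt when the key is absent.
  if (auto sym = attrs.GetAttr<tvm::ffi::String>("global_symbol")) {
    use(sym.value());
  }
  // Or supply a default and read unconditionally:
  tvm::ffi::String name =
      attrs.GetAttr<tvm::ffi::String>("global_symbol", tvm::ffi::String("main")).value();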
 */
-DictAttrs WithAttr(DictAttrs attrs, String key, Any value);
+DictAttrs WithAttr(DictAttrs attrs, ffi::String key, Any value);

 inline DictAttrs WithAttr(DictAttrs attrs, const std::string& key, Any value) {
-  return WithAttr(std::move(attrs), String(key), std::move(value));
+  return WithAttr(std::move(attrs), ffi::String(key), std::move(value));
 }

 /*!
@@ -325,7 +325,7 @@ inline TFunc WithAttr(TFunc input, const std::string& attr_key, Any attr_value)
  * \returns The new function or module with updated attributes.
  */
 template
-inline TFunc WithAttrs(TFunc input, Map attrs) {
+inline TFunc WithAttrs(TFunc input, ffi::Map attrs) {
   using TNode = typename TFunc::ContainerType;
   static_assert(TNode::_type_final, "Can only operate on the leaf nodes");
   TNode* node = input.CopyOnWrite();
@@ -410,7 +410,7 @@ inline TAttrs AttrsWithDefaultValues() {
     finit_object.CallPacked(ffi::PackedArgs(packed_args, 1), &rv);
     return rv.cast();
   } else {
-    auto n = make_object();
+    auto n = ffi::make_object();
     n->InitByPackedArgs(ffi::PackedArgs(nullptr, 0), false);
     return TAttrs(n);
   }
diff --git a/include/tvm/ir/diagnostic.h b/include/tvm/ir/diagnostic.h
index 9f4f5770aa60..1d44918cfa21 100644
--- a/include/tvm/ir/diagnostic.h
+++ b/include/tvm/ir/diagnostic.h
@@ -64,7 +64,7 @@ class DiagnosticNode : public Object {
    */
   ObjectRef loc;
   /*! \brief The diagnostic message. */
-  String message;
+  ffi::String message;

   static void RegisterReflection() {
     namespace refl = tvm::ffi::reflection;
@@ -194,7 +194,7 @@ class DiagnosticContextNode : public Object {
   IRModule module;

   /*! \brief The set of diagnostics to report. */
-  Array diagnostics;
+  ffi::Array diagnostics;

   /*! \brief The renderer set for the context. */
   DiagnosticRenderer renderer;
diff --git a/include/tvm/ir/env_func.h b/include/tvm/ir/env_func.h
index 5afe464109cc..e43575d486eb 100644
--- a/include/tvm/ir/env_func.h
+++ b/include/tvm/ir/env_func.h
@@ -43,7 +43,7 @@ namespace tvm {
 class EnvFuncNode : public Object {
  public:
   /*! \brief Unique name of the global function */
-  String name;
+  ffi::String name;
   /*! \brief The internal packed function */
   ffi::Function func;
   /*! \brief constructor */
@@ -90,7 +90,7 @@ class EnvFunc : public ObjectRef {
    * \return The created global function.
    * \note The function can be unique
    */
-  TVM_DLL static EnvFunc Get(const String& name);
+  TVM_DLL static EnvFunc Get(const ffi::String& name);
   /*! \brief specify container node */
   using ContainerType = EnvFuncNode;
 };
diff --git a/include/tvm/ir/expr.h b/include/tvm/ir/expr.h
index f0350af56549..65954b83ac9d 100644
--- a/include/tvm/ir/expr.h
+++ b/include/tvm/ir/expr.h
@@ -39,8 +39,6 @@
 namespace tvm {

-using tvm::String;
-
 // Forward-declare VirtualDevice to avoid circular imports.
 class VirtualDevice;
@@ -148,7 +146,7 @@ class PrimExpr : public BaseExpr {
    * \brief construct from string to form a StringImm.
    * \param value The value to be constructed.
    */
-  TVM_DLL static PrimExpr ConvertFallbackValue(String value);  // NOLINT(*)
+  TVM_DLL static PrimExpr ConvertFallbackValue(ffi::String value);  // NOLINT(*)
 };

 /*!
@@ -175,19 +173,19 @@ class PrimExprConvertible : public ObjectRef {
 };

 namespace ffi {
-// define automatic conversion from bool, int64_t, double, String to PrimExpr
+// define automatic conversion from bool, int64_t, double, ffi::String to PrimExpr
 // These functions are declared early to avoid circular dependency
 template <>
 inline constexpr bool use_default_type_traits_v = false;

 template <>
 struct TypeTraits
-    : public ObjectRefWithFallbackTraitsBase {
   TVM_FFI_INLINE static PrimExpr ConvertFallbackValue(StrictBool value);
   TVM_FFI_INLINE static PrimExpr ConvertFallbackValue(int64_t value);
   TVM_FFI_INLINE static PrimExpr ConvertFallbackValue(double value);
-  TVM_FFI_INLINE static PrimExpr ConvertFallbackValue(String value) {
+  TVM_FFI_INLINE static PrimExpr ConvertFallbackValue(ffi::String value) {
     return PrimExpr::ConvertFallbackValue(value);
   }
   TVM_FFI_INLINE static PrimExpr ConvertFallbackValue(PrimExprConvertible value) {
@@ -426,7 +424,7 @@ class RelaxExprNode : public BaseExprNode {
    *        expression that encapsulate both static shape and
    *        runtime information such as shape.
    */
-  mutable Optional struct_info_ = Optional();
+  mutable ffi::Optional struct_info_ = ffi::Optional();

   static void RegisterReflection() {
     namespace refl = tvm::ffi::reflection;
@@ -460,7 +458,7 @@ class GlobalVar;
 class GlobalVarNode : public RelaxExprNode {
  public:
   /*! \brief The name of the variable, this only acts as a hint. */
-  String name_hint;
+  ffi::String name_hint;

   static void RegisterReflection() {
     namespace refl = tvm::ffi::reflection;
@@ -488,7 +486,7 @@ class GlobalVarNode : public RelaxExprNode {
  */
 class GlobalVar : public RelaxExpr {
  public:
-  TVM_DLL explicit GlobalVar(String name_hint, Span span = {});
+  TVM_DLL explicit GlobalVar(ffi::String name_hint, Span span = {});

   TVM_DEFINE_OBJECT_REF_METHODS(GlobalVar, RelaxExpr, GlobalVarNode);
   TVM_DEFINE_OBJECT_REF_COW_METHOD(GlobalVarNode);
diff --git a/include/tvm/ir/function.h b/include/tvm/ir/function.h
index 53f19ed3f17c..9dd533736f42 100644
--- a/include/tvm/ir/function.h
+++ b/include/tvm/ir/function.h
@@ -161,14 +161,14 @@ class BaseFuncNode : public RelaxExprNode {
    * \endcode
    */
   template
-  Optional GetAttr(const std::string& attr_key,
-                   Optional default_value = std::nullopt) const {
+  ffi::Optional GetAttr(const std::string& attr_key,
+                        ffi::Optional default_value = std::nullopt) const {
     return attrs.GetAttr(attr_key, default_value);
   }
   // variant that uses TObjectRef to enable implicit conversion to default value.
   template
-  Optional GetAttr(const std::string& attr_key, TObjectRef default_value) const {
-    return GetAttr(attr_key, Optional(default_value));
+  ffi::Optional GetAttr(const std::string& attr_key, TObjectRef default_value) const {
+    return GetAttr(attr_key, ffi::Optional(default_value));
   }

   /*!
@@ -211,7 +211,7 @@ class BaseFuncNode : public RelaxExprNode {
    */
   LinkageType GetLinkageType() const {
-    if (GetAttr(attr::kGlobalSymbol))
+    if (GetAttr(attr::kGlobalSymbol))
       return LinkageType::kExternal;
     else
       return LinkageType::kInternal;
diff --git a/include/tvm/ir/global_info.h b/include/tvm/ir/global_info.h
index e6ff10ad1bc4..464d781fe472 100644
--- a/include/tvm/ir/global_info.h
+++ b/include/tvm/ir/global_info.h
@@ -34,7 +34,7 @@ namespace tvm {
 /*!
  * \brief Abstract label for an area of memory.
  */
-using MemoryScope = String;
+using MemoryScope = ffi::String;

 /*!
  * \brief GlobalInfo are globally static objects that are referred to by the IR itself.
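The ``GetLinkageType`` helper above keys off ``kGlobalSymbol``; the same check can be spelled directly at a call site:

.. code:: c++

  // A function is externally linked iff it carries a global symbol attribute.
  bool IsExternallyLinked(const tvm::BaseFunc& func) {
    return func->GetAttr<tvm::ffi::String>(tvm::attr::kGlobalSymbol).has_value();
  }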
diff --git a/include/tvm/ir/global_var_supply.h b/include/tvm/ir/global_var_supply.h
index 8ed8e5ed4c13..10ca56c9c600 100644
--- a/include/tvm/ir/global_var_supply.h
+++ b/include/tvm/ir/global_var_supply.h
@@ -58,7 +58,7 @@ class GlobalVarSupplyNode : public Object {
    * \param add_prefix If set to true, then the prefix of the contained NameSupply will be prepended
    *        to the name.
    * \return A unique GlobalVar.
    */
-  GlobalVar FreshGlobal(String name, bool add_prefix = true);
+  GlobalVar FreshGlobal(ffi::String name, bool add_prefix = true);

   /*!
    * \brief Looks up for a GlobalVar with the given name in this supply.
@@ -67,7 +67,7 @@
    * \param add_prefix If set to true, the prefix of the contained NameSupply will be prepended to
    *        the name before performing the search.
    * \return A cached GlobalVar.
    */
-  GlobalVar UniqueGlobalFor(const String& name, bool add_prefix = true);
+  GlobalVar UniqueGlobalFor(const ffi::String& name, bool add_prefix = true);

   /*!
    * \brief Reserves an existing GlobalVar with this supply.
@@ -111,7 +111,7 @@ class GlobalVarSupply : public ObjectRef {
    *        guaranteed not to conflict with any GlobalVars that belong to the modules.
    * \param modules Array of IRModules.
    */
-  TVM_DLL explicit GlobalVarSupply(const Array& modules);
+  TVM_DLL explicit GlobalVarSupply(const ffi::Array& modules);

   /*!
    * \brief Constructs a GlobalVarSupply from an IRModule. GlobalVars generated by this supply are
diff --git a/include/tvm/ir/instrument.h b/include/tvm/ir/instrument.h
index 1a91371cd38f..18ce99740a24 100644
--- a/include/tvm/ir/instrument.h
+++ b/include/tvm/ir/instrument.h
@@ -103,7 +103,7 @@ namespace instrument {
 class PassInstrumentNode : public Object {
  public:
   /*! \brief Name of this pass instrument object. */
-  String name;
+  ffi::String name;

   virtual ~PassInstrumentNode() {}
diff --git a/include/tvm/ir/module.h b/include/tvm/ir/module.h
index f04a6cfe6d53..5da00fb0b377 100644
--- a/include/tvm/ir/module.h
+++ b/include/tvm/ir/module.h
@@ -57,18 +57,18 @@ class IRModule;
 class IRModuleNode : public Object {
  public:
   /*! \brief A map from ids to all global functions. */
-  Map functions;
+  ffi::Map functions;
   /*! \brief The source map for the module. */
   SourceMap source_map;
   /* \brief Additional attributes storing meta-data about the module. */
   DictAttrs attrs;
   /*! \brief Globally static object that are referred by the IR itself */
-  Map> global_infos;
+  ffi::Map> global_infos;
   /*!
    * \brief A map from string names to global variables that
    * ensures global uniqueness.
    */
-  Map global_var_map_;
+  ffi::Map global_var_map_;

   /*!
    * \brief Get a module attribute.
@@ -90,15 +90,15 @@
    * \endcode
    */
   template
-  Optional GetAttr(
+  ffi::Optional GetAttr(
       const std::string& attr_key,
-      Optional default_value = Optional(std::nullopt)) const {
+      ffi::Optional default_value = ffi::Optional(std::nullopt)) const {
     return attrs.GetAttr(attr_key, default_value);
   }
   // variant that uses TObjectRef to enable implicit conversion to default value.
   template
-  Optional GetAttr(const std::string& attr_key, TObjectRef default_value) const {
-    return GetAttr(attr_key, Optional(default_value));
+  ffi::Optional GetAttr(const std::string& attr_key, TObjectRef default_value) const {
+    return GetAttr(attr_key, ffi::Optional(default_value));
   }

   /*!
@@ -179,7 +179,7 @@
    * \param name The name of the global info.
    * \param info The new array of global infos.
    */
-  TVM_DLL void UpdateGlobalInfo(const String& name, const Array& info);
+  TVM_DLL void UpdateGlobalInfo(const ffi::String& name, const ffi::Array& info);

   /*!
    * \brief Remove a function from the global environment.
@@ -192,21 +192,21 @@
    * \param name The variable name.
    * \returns true if contains, otherwise false.
    */
-  TVM_DLL bool ContainGlobalVar(const String& name) const;
+  TVM_DLL bool ContainGlobalVar(const ffi::String& name) const;

   /*!
    * \brief Lookup a global function by its variable.
    * \param str The unique string specifying the global variable.
    * \returns The global variable.
    */
-  TVM_DLL GlobalVar GetGlobalVar(const String& str) const;
+  TVM_DLL GlobalVar GetGlobalVar(const ffi::String& str) const;

   /*!
    * \brief Collect all global vars defined in this module, ordered by
    *        the global variable name.
    * \returns An array of global vars
    */
-  TVM_DLL Array GetGlobalVars() const;
+  TVM_DLL ffi::Array GetGlobalVars() const;

   /*!
    * \brief Look up a global function by its variable.
@@ -220,7 +220,7 @@
    * \param name The name of the function.
    * \returns The function named by the argument.
    */
-  TVM_DLL BaseFunc Lookup(const String& name) const;
+  TVM_DLL BaseFunc Lookup(const ffi::String& name) const;

   /*!
    * \brief Update the functions inside this environment by
@@ -237,7 +237,7 @@
   /*!
    * \brief The set of imported files.
    */
-  TVM_DLL std::unordered_set Imports() const;
+  TVM_DLL std::unordered_set Imports() const;

   TVM_OBJECT_ENABLE_SCRIPT_PRINTER();
@@ -263,12 +263,12 @@ class IRModule : public ObjectRef {
    * \param attrs The module meta-data attributes.
    * \param global_infos Global infos in the module.
    */
-  TVM_DLL explicit IRModule(Map functions, SourceMap map = {},
+  TVM_DLL explicit IRModule(ffi::Map functions, SourceMap map = {},
                             DictAttrs attrs = DictAttrs(),
-                            Map> global_infos = {});
+                            ffi::Map> global_infos = {});

   /*! \brief default constructor */
-  IRModule() : IRModule(Map({})) {}
+  IRModule() : IRModule(ffi::Map({})) {}
   /*!
    * \brief constructor
    * \param n The object pointer.
@@ -286,7 +286,7 @@
    * imports.
    */
   TVM_DLL static IRModule FromExpr(const RelaxExpr& expr,
-                                   const Map& global_funcs = {});
+                                   const ffi::Map& global_funcs = {});

   /*!
    * \brief Create a shallow copy of an IRModule.
@@ -318,7 +318,7 @@ constexpr const char* kModuleName = "mod_name";
  * node will record the index into this array. See also kConstNameToConstant below, which is
  * the analog for Relay Functions.
  *
- * Type: Array
+ * Type: ffi::Array
  */
 constexpr const char* kConstants = "constants";

 /*!
  * \brief All the runtime::Modules accumulated during compilation by external codegen. These
  * modules must be either directly linked or captured in the final compilation artifact.
  *
- * Type: Array
+ * Type: ffi::Array
  */
 constexpr const char* kExternalMods = "external_mods";
@@ -365,7 +365,7 @@ constexpr const char* kSystemLibPrefix = "system_lib_prefix";
  * and during module initialization these bindings will be recovered from a ConstLoaderModule.
  * See also kConstantsArray above, which is the analog for PrimFuncs.
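For the constructors above, a sketch of building and querying a module — assuming the elided map type is ``ffi::Map<GlobalVar, BaseFunc>`` as the field comment indicates (``my_func`` is a stand-in for any ``BaseFunc`` defined elsewhere):

.. code:: c++

  tvm::GlobalVar gv("main");
  tvm::ffi::Map<tvm::GlobalVar, tvm::BaseFunc> funcs;
  funcs.Set(gv, my_func);  // `my_func` defined elsewhere
  tvm::IRModule mod(funcs);

  ICHECK(mod->ContainGlobalVar("main"));
  tvm::BaseFunc f = mod->Lookup("main");  // same function back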
  *
- * Type: Map
+ * Type: ffi::Map
  */
 constexpr const char* kConstNameToConstant = "const_name_to_constant";
diff --git a/include/tvm/ir/name_supply.h b/include/tvm/ir/name_supply.h
index 6eefaefea793..f367df47ca59 100644
--- a/include/tvm/ir/name_supply.h
+++ b/include/tvm/ir/name_supply.h
@@ -50,7 +50,7 @@ class NameSupplyNode : public Object {
    * \param prefix The prefix to be used with this NameSupply.
    * \param name_map The map used to guarantee uniqueness.
    */
-  NameSupplyNode(const String& prefix, std::unordered_map name_map)
+  NameSupplyNode(const ffi::String& prefix, std::unordered_map name_map)
       : prefix_(prefix), name_map(std::move(name_map)) {}

   /*!
@@ -61,7 +61,8 @@
    * \param add_underscore If set to true, add '_' between prefix and a digit.
    * \return A unique name.
    */
-  String FreshName(const String& name, bool add_prefix = true, bool add_underscore = true);
+  ffi::String FreshName(const ffi::String& name, bool add_prefix = true,
+                        bool add_underscore = true);

   /*!
    * \brief Reserves an existing name with this NameSupply.
@@ -70,7 +71,7 @@
    *        name before reserving it.
    * \return The name that was reserved with the NameSupply. It can be
    *         different if a prefix is added.
    */
-  String ReserveName(const String& name, bool add_prefix = true);
+  ffi::String ReserveName(const ffi::String& name, bool add_prefix = true);

   /*!
    * \brief Checks if this NameSupply already generated a name.
@@ -79,7 +80,7 @@
    *        name before checking for it.
    * \return True if the name has already been generated. False otherwise.
    */
-  bool ContainsName(const String& name, bool add_prefix = true);
+  bool ContainsName(const ffi::String& name, bool add_prefix = true);

   // Prefix for all GlobalVar names. It can be empty.
   std::string prefix_;
@@ -89,7 +90,7 @@
  private:
   /*! \brief Helper function to add the NameSupply prefix to the name. */
-  String add_prefix_to_name(const String& name);
+  ffi::String add_prefix_to_name(const ffi::String& name);

   /*!
    * \brief Function that will generate a unique name.
@@ -114,7 +115,7 @@ class NameSupply : public ObjectRef {
    * \param prefix The prefix to be used with this NameSupply.
    * \param name_map An optional map.
    */
-  TVM_DLL explicit NameSupply(const String& prefix = "",
+  TVM_DLL explicit NameSupply(const ffi::String& prefix = "",
                               std::unordered_map name_map = {});

   /*!
diff --git a/include/tvm/ir/op.h b/include/tvm/ir/op.h
index 5f40ff4d3a7b..505b8e1427eb 100644
--- a/include/tvm/ir/op.h
+++ b/include/tvm/ir/op.h
@@ -59,21 +59,21 @@ class OpAttrMap;
 class OpNode : public RelaxExprNode {
  public:
   /*! \brief name of the operator */
-  String name;
+  ffi::String name;
   /*! \brief the type of the operator */
   mutable FuncType op_type;
   /*!
    * \brief detailed description of the operator
    *  This can be used to generate docstring automatically for the operator.
    */
-  String description;
+  ffi::String description;
   /* \brief Information of input arguments to the operator */
-  Array arguments;
+  ffi::Array arguments;
   /*!
    * \brief The type key of the attribute field
    *  This can be empty, in which case it defaults to anything.
    */
-  String attrs_type_key;
+  ffi::String attrs_type_key;
   /*!
    * \brief attribute type index,
    * this field varies in each run and is not exposed to frontend.
@@ -139,20 +139,20 @@ class Op : public RelaxExpr {
    * \tparam ValueType The type of the attribute.
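A small sketch of the deduplication behavior documented for ``FreshName`` (the exact generated names depend on ``add_underscore`` and the supply's prefix, so the ones shown are illustrative):

.. code:: c++

  tvm::NameSupply supply("fused");
  tvm::ffi::String a = supply->FreshName("conv");  // e.g. "fused_conv"
  tvm::ffi::String b = supply->FreshName("conv");  // deduplicated, e.g. "fused_conv_1"
  ICHECK(supply->ContainsName("conv"));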
   */
  template
-  inline static OpAttrMap GetAttrMap(const String& attr_name);
+  inline static OpAttrMap GetAttrMap(const ffi::String& attr_name);
  /*!
   * \brief Checks if an attr map is present in the registry.
   * \param attr_name The name of the attribute.
   * \return bool True if the attr is present.
   */
-  TVM_DLL static bool HasAttrMap(const String& attr_name);
+  TVM_DLL static bool HasAttrMap(const ffi::String& attr_name);
  /*!
   * \brief Get an Op for a given operator name.
   *  Will raise an error if the op has not been registered.
   * \param op_name Name of the operator.
   * \return Pointer to a Op, valid throughout program lifetime.
   */
-  TVM_DLL static const Op& Get(const String& op_name);
+  TVM_DLL static const Op& Get(const ffi::String& op_name);

   TVM_DEFINE_OBJECT_REF_METHODS(Op, RelaxExpr, OpNode);
@@ -162,7 +162,7 @@
   * \param key The attribute key
   * \return The attr map.
   */
-  TVM_DLL static const AttrRegistryMapContainerMap& GetAttrMapContainer(const String& key);
+  TVM_DLL static const AttrRegistryMapContainerMap& GetAttrMapContainer(const ffi::String& key);
 };

 /*!
@@ -201,7 +201,7 @@ class OpRegEntry {
   * \param key The attribute type key to be set.
   * \return reference to self.
   */
-  inline OpRegEntry& set_attrs_type_key(const String& key);
+  inline OpRegEntry& set_attrs_type_key(const ffi::String& key);
  /*!
   * \brief Set the num_inputs
   * \param n The number of inputs to be set.
@@ -249,7 +249,7 @@
   * \param name The name of the operator.
   * \return the corresponding entry.
   */
-  TVM_DLL static OpRegEntry& RegisterOrGet(const String& name);
+  TVM_DLL static OpRegEntry& RegisterOrGet(const ffi::String& name);

  private:
   template
@@ -263,11 +263,11 @@
   // return internal pointer to op.
   inline OpNode* get();
   // update the attribute OpAttrMap
-  TVM_DLL void UpdateAttr(const String& key, ffi::Any value, int plevel);
+  TVM_DLL void UpdateAttr(const ffi::String& key, ffi::Any value, int plevel);
 };

 /*!
- * \brief Map used to store meta-information about Op.
+ * \brief ffi::Map used to store meta-information about Op.
  * \tparam ValueType The type of the value stored in map.
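To illustrate the attribute-map machinery above: ops are looked up by name and their per-op metadata is fetched through ``GetAttrMap``. The attribute name and value type here (``TOpPattern`` / ``Integer``) are a conventional example rather than something this patch defines:

.. code:: c++

  const tvm::Op& op = tvm::Op::Get("relax.add");
  if (tvm::Op::HasAttrMap("TOpPattern")) {
    auto pattern_map = tvm::Op::GetAttrMap<tvm::Integer>("TOpPattern");
    tvm::Integer pattern = pattern_map.get(op, /*def_value=*/tvm::Integer(0));
  }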
  */
 template
@@ -318,7 +318,7 @@ class OpAttrMap : public AttrRegistryMap {
 // implementations
 template
-inline OpAttrMap Op::GetAttrMap(const String& key) {
+inline OpAttrMap Op::GetAttrMap(const ffi::String& key) {
   return OpAttrMap(Op::GetAttrMapContainer(key));
 }
@@ -331,7 +331,7 @@ inline OpRegEntry& OpRegEntry::describe(const std::string& descr) {  // NOLINT(*
 inline OpRegEntry& OpRegEntry::add_argument(const std::string& name, const std::string& type,
                                             const std::string& description) {
-  auto n = make_object();
+  auto n = ffi::make_object();
   n->name = name;
   n->type_info = type;
   n->description = description;
@@ -351,7 +351,7 @@ inline OpRegEntry& OpRegEntry::set_attrs_type() {  // NOLINT(*)
   return *this;
 }

-inline OpRegEntry& OpRegEntry::set_attrs_type_key(const String& key) {  // NOLINT(*)
+inline OpRegEntry& OpRegEntry::set_attrs_type_key(const ffi::String& key) {  // NOLINT(*)
   get()->attrs_type_key = key;
   get()->attrs_type_index = tvm::ffi::TypeKeyToIndex(key.c_str());
   return *this;
@@ -376,7 +376,7 @@ template
 inline ValueType OpAttrMap::get(const RelaxExpr& expr, ValueType def_value) const {
   ICHECK(expr.defined());
   if (const OpNode* op = expr.as()) {
-    return this->map_.get(GetRef(op), def_value);
+    return this->map_.get(ffi::GetRef(op), def_value);
   } else {
     return def_value;
   }
diff --git a/include/tvm/ir/replace_global_vars.h b/include/tvm/ir/replace_global_vars.h
index ea91d46d7c0a..0ed25c9a0a7a 100644
--- a/include/tvm/ir/replace_global_vars.h
+++ b/include/tvm/ir/replace_global_vars.h
@@ -41,10 +41,10 @@ namespace transform {
  *
  * \return The updated IRModule
  */
-TVM_DLL IRModule ReplaceGlobalVars(IRModule mod, Map replacements);
+TVM_DLL IRModule ReplaceGlobalVars(IRModule mod, ffi::Map replacements);

 struct GlobalVarReplacer {
-  using FType = NodeFunctor)>;
+  using FType = NodeFunctor)>;
   TVM_DLL static FType& vtable() {
     static FType inst;
     return inst;
diff --git a/include/tvm/ir/source_map.h b/include/tvm/ir/source_map.h
index c7fce1c5024c..a8184df6ebdb 100644
--- a/include/tvm/ir/source_map.h
+++ b/include/tvm/ir/source_map.h
@@ -46,7 +46,7 @@ class SourceName;
 class SourceNameNode : public Object {
  public:
   /*! \brief The source name. */
-  String name;
+  ffi::String name;

   static void RegisterReflection() {
     namespace refl = tvm::ffi::reflection;
@@ -70,7 +70,7 @@ class SourceName : public ObjectRef {
    * \param name Name of the operator.
    * \return SourceName valid throughout program lifetime.
    */
-  TVM_DLL static SourceName Get(const String& name);
+  TVM_DLL static SourceName Get(const ffi::String& name);

   TVM_DEFINE_OBJECT_REF_METHODS(SourceName, ObjectRef, SourceNameNode);
 };
@@ -126,7 +126,7 @@ class Span : public ObjectRef {
 class SequentialSpanNode : public SpanNode {
  public:
   /*! \brief The original source list of spans to construct a sequential span. */
-  Array spans;
+  ffi::Array spans;

   static void RegisterReflection() {
     namespace refl = tvm::ffi::reflection;
@@ -143,7 +143,7 @@ class SequentialSpanNode : public SpanNode {
  */
 class SequentialSpan : public Span {
  public:
-  TVM_DLL SequentialSpan(Array spans);
+  TVM_DLL SequentialSpan(ffi::Array spans);

   TVM_DLL SequentialSpan(std::initializer_list init);
@@ -163,7 +163,7 @@ class SourceNode : public Object {
   SourceName source_name;

   /*! \brief The raw source. */
-  String source;
+  ffi::String source;

   /*! \brief A mapping of line breaks into the raw source. */
   std::vector> line_map;
@@ -182,7 +182,7 @@
 class Source : public ObjectRef {
  public:
   TVM_DLL Source(SourceName src_name, std::string source);
-  TVM_DLL tvm::String GetLine(int line);
+  TVM_DLL tvm::ffi::String GetLine(int line);

   TVM_DEFINE_NOTNULLABLE_OBJECT_REF_METHODS(Source, ObjectRef, SourceNode);
 };
@@ -197,7 +197,7 @@ class SourceMap;
 class SourceMapObj : public Object {
  public:
   /*! \brief The source mapping. */
-  Map source_map;
+  ffi::Map source_map;

   static void RegisterReflection() {
     namespace refl = tvm::ffi::reflection;
@@ -211,12 +211,12 @@
 class SourceMap : public ObjectRef {
  public:
-  explicit SourceMap(Map source_map);
+  explicit SourceMap(ffi::Map source_map);

   explicit SourceMap(std::initializer_list> source_map)
-      : SourceMap(Map(source_map)) {}
+      : SourceMap(ffi::Map(source_map)) {}

-  SourceMap() : SourceMap(Map()) {}
+  SourceMap() : SourceMap(ffi::Map()) {}

   void Add(const Source& source);
diff --git a/include/tvm/ir/transform.h b/include/tvm/ir/transform.h
index 45f97ff61f2b..e501ace15997 100644
--- a/include/tvm/ir/transform.h
+++ b/include/tvm/ir/transform.h
@@ -82,16 +82,16 @@ class PassContextNode : public Object {
   int opt_level{2};

   /*! \brief The list of required passes. */
-  Array required_pass;
+  ffi::Array required_pass;
   /*! \brief The list of disabled passes. */
-  Array disabled_pass;
+  ffi::Array disabled_pass;
   /*! \brief The diagnostic context. */
-  mutable Optional diag_ctx;
+  mutable ffi::Optional diag_ctx;
   /*! \brief Pass specific configurations. */
-  Map config;
+  ffi::Map config;
   /*! \brief A list of pass instrument implementations. */
-  Array instruments;
+  ffi::Array instruments;

   PassContextNode() = default;
@@ -107,21 +107,21 @@
    * \throw Error if the key exists but the value does not match TObjectRef.
    */
   template
-  Optional GetConfig(
+  ffi::Optional GetConfig(
       const std::string& key,
-      Optional default_value = Optional(std::nullopt)) const {
+      ffi::Optional default_value = ffi::Optional(std::nullopt)) const {
     if (!config.defined()) return default_value;
     auto it = config.find(key);
     if (it != config.end()) {
-      return Downcast>((*it).second);
+      return Downcast>((*it).second);
     } else {
       return default_value;
     }
   }
   // variant that uses TObjectRef to enable implicit conversion to default value.
   template
-  Optional GetConfig(const std::string& key, TObjectRef default_value) const {
-    return GetConfig(key, Optional(default_value));
+  ffi::Optional GetConfig(const std::string& key, TObjectRef default_value) const {
+    return GetConfig(key, ffi::Optional(default_value));
   }

   static void RegisterReflection() {
@@ -189,7 +189,7 @@ class PassContext : public ObjectRef {
    * \brief Get all supported configuration names and metadata, registered within the PassContext.
    * \return Map indexed by the config name, pointing to the metadata map as key-value
    */
-  TVM_DLL static Map> ListConfigs();
+  TVM_DLL static ffi::Map> ListConfigs();

   /*!
    * \brief Call instrument implementations' callbacks when entering PassContext.
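A sketch of the typed-config lookup that ``GetConfig`` provides; ``tir.disable_vectorize`` is an existing boolean option used here only as an example key:

.. code:: c++

  tvm::transform::PassContext ctx = tvm::transform::PassContext::Current();
  bool disable_vectorize =
      ctx->GetConfig<tvm::Bool>("tir.disable_vectorize", tvm::Bool(false)).value();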
@@ -247,7 +247,7 @@ class PassContext : public ObjectRef {
     int32_t tindex = ffi::TypeToRuntimeTypeIndex::v();
     auto type_key = ffi::TypeIndexToTypeKey(tindex);
     auto legalization = [=](ffi::Any value) -> ffi::Any {
-      if (auto opt_map = value.try_cast>()) {
+      if (auto opt_map = value.try_cast>()) {
         return ffi::reflection::ObjectCreator(type_key)(opt_map.value());
       } else {
         auto opt_val = value.try_cast();
@@ -288,7 +288,7 @@ class PassContext : public ObjectRef {
   // The exit of a pass context scope.
   TVM_DLL void ExitWithScope();
   // Register configuration key value type.
-  TVM_DLL static void RegisterConfigOption(const char* key, String value_type_str,
+  TVM_DLL static void RegisterConfigOption(const char* key, ffi::String value_type_str,
                                            std::function legalization);

   // Classes to get the Python `with` like syntax.
@@ -318,13 +318,13 @@ class PassInfoNode : public Object {
   int opt_level;

   /*! \brief The name of an optimization/analysis pass. */
-  String name;
+  ffi::String name;

   /*! \brief Boolean that tells whether this pass will be traced or not. */
   bool traceable;

   /*! \brief The passes that are required to perform the current pass. */
-  Array required;
+  ffi::Array required;

   PassInfoNode() = default;
@@ -355,7 +355,8 @@ class PassInfo : public ObjectRef {
    * \param required The passes that are required to perform the current pass.
    * \param traceable Boolean that tells whether the pass is traceable.
    */
-  TVM_DLL PassInfo(int opt_level, String name, Array required, bool traceable);
+  TVM_DLL PassInfo(int opt_level, ffi::String name, ffi::Array required,
+                   bool traceable);

   TVM_DEFINE_OBJECT_REF_METHODS(PassInfo, ObjectRef, PassInfoNode);
 };
@@ -447,7 +448,7 @@ class SequentialNode : public PassNode {
   PassInfo pass_info;

   /*! \brief A list of passes that used to compose a sequential pass. */
-  tvm::Array passes;
+  tvm::ffi::Array passes;

   static void RegisterReflection() {
     namespace refl = tvm::ffi::reflection;
@@ -498,7 +499,7 @@ class Sequential : public Pass {
    * \param passes The passes to apply.
    * \param pass_info The pass metadata.
    */
-  TVM_DLL Sequential(Array passes, PassInfo pass_info);
+  TVM_DLL Sequential(ffi::Array passes, PassInfo pass_info);

   /*!
    * \brief The constructor of `Sequential`.
@@ -508,7 +509,7 @@
    * This allows users to only provide a list of passes and execute them
    * under a given context.
    */
-  TVM_DLL Sequential(Array passes, String name = "sequential");
+  TVM_DLL Sequential(ffi::Array passes, ffi::String name = "sequential");

   Sequential() = default;
   explicit Sequential(ObjectPtr n) : Pass(n) {}
@@ -528,7 +529,7 @@
  * \return The created module pass.
  */
 TVM_DLL Pass CreateModulePass(std::function pass_func,
-                              int opt_level, String name, Array required,
+                              int opt_level, ffi::String name, ffi::Array required,
                               bool traceable = false);

 /*
@@ -553,7 +554,7 @@ TVM_DLL Pass CreateModulePass(std::function pas
  *
  * \return The modified IRModule to IRModule pass.
  */
-TVM_DLL Pass ApplyPassToFunction(Pass pass, String func_name_regex,
+TVM_DLL Pass ApplyPassToFunction(Pass pass, ffi::String func_name_regex,
                                  bool error_if_no_function_matches_regex = false);

 /*!
@@ -562,7 +563,7 @@ TVM_DLL Pass ApplyPassToFunction(Pass pass, String func_name_regex,
  * \param show_meta_data Whether we should show meta data.
  * \return The pass.
  */
-TVM_DLL Pass PrintIR(String header = "", bool show_meta_data = false);
+TVM_DLL Pass PrintIR(ffi::String header = "", bool show_meta_data = false);

 }  // namespace transform
 }  // namespace tvm
diff --git a/include/tvm/ir/type.h b/include/tvm/ir/type.h
index 9d75e845f88f..1d4992abfb3a 100644
--- a/include/tvm/ir/type.h
+++ b/include/tvm/ir/type.h
@@ -162,7 +162,7 @@ class PointerTypeNode : public TypeNode {
   /*!
    * \brief The storage scope of the pointer
    */
-  String storage_scope;
+  ffi::String storage_scope;

   static void RegisterReflection() {
     namespace refl = tvm::ffi::reflection;
@@ -186,7 +186,7 @@ class PointerType : public Type {
    * \param element_type The type of the element which the pointer points to.
    * \param storage_scope The storage scope into which the pointer addresses
    */
-  TVM_DLL explicit PointerType(Type element_type, String storage_scope = "");
+  TVM_DLL explicit PointerType(Type element_type, ffi::String storage_scope = "");

   TVM_DEFINE_OBJECT_REF_METHODS(PointerType, Type, PointerTypeNode);
 };
@@ -198,7 +198,7 @@
 class TupleTypeNode : public TypeNode {
  public:
   /*! \brief The type of each field in the tuple. */
-  Array fields;
+  ffi::Array fields;

   TupleTypeNode() {}
@@ -224,7 +224,7 @@ class TupleType : public Type {
    * \param fields Fields in the tuple.
    * \param span The span of the type.
    */
-  TVM_DLL explicit TupleType(Array fields, Span span = Span());
+  TVM_DLL explicit TupleType(ffi::Array fields, Span span = Span());

   /*!
    * \brief Create an empty tuple type that contains nothing.
@@ -260,7 +260,7 @@ inline bool IsVoidType(const Type& type) {
 class FuncTypeNode : public TypeNode {
  public:
   /*! \brief the type of arguments */
-  Array arg_types;
+  ffi::Array arg_types;
   /*! \brief The type of return value. */
   Type ret_type;
@@ -289,7 +289,7 @@ class FuncType : public Type {
    * \param span The span information.
    * \sa FuncTypeNode for more docs about these fields.
    */
-  TVM_DLL FuncType(Array arg_types, Type ret_type, Span span = Span());
+  TVM_DLL FuncType(ffi::Array arg_types, Type ret_type, Span span = Span());

   TVM_DEFINE_OBJECT_REF_METHODS(FuncType, Type, FuncTypeNode);
 };
diff --git a/include/tvm/ir/type_functor.h b/include/tvm/ir/type_functor.h
index 858226354c66..b2878519c424 100644
--- a/include/tvm/ir/type_functor.h
+++ b/include/tvm/ir/type_functor.h
@@ -123,7 +123,7 @@ class TVM_DLL TypeMutator : public TypeFunctor {
   Type VisitType_(const PointerTypeNode* op) override;

  private:
-  Array MutateArray(Array arr);
+  ffi::Array MutateArray(ffi::Array arr);
 };

 }  // namespace tvm
diff --git a/include/tvm/meta_schedule/arg_info.h b/include/tvm/meta_schedule/arg_info.h
index de005dcd125b..75ef64daa4d4 100644
--- a/include/tvm/meta_schedule/arg_info.h
+++ b/include/tvm/meta_schedule/arg_info.h
@@ -60,14 +60,14 @@ class ArgInfo : public runtime::ObjectRef {
    * \param func The PrimFunc to get argument information from.
    * \return An array of the argument information derived.
    */
-  TVM_DLL static Array FromPrimFunc(const tir::PrimFunc& func);
+  TVM_DLL static ffi::Array FromPrimFunc(const tir::PrimFunc& func);
   /*!
    * \brief Extract a list of the argument information from the entry func of an IRModule
    * \param mod The IRModule to extract argument information from.
    * \param remove_preproc Whether to remove the preprocessing blocks.
    * \return An array of the argument information derived.
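Rounding out the ``ArgInfo`` API above, a hedged usage sketch (``prim_func`` stands in for a scheduled ``tir::PrimFunc`` obtained elsewhere):

.. code:: c++

  tvm::ffi::Array<tvm::meta_schedule::ArgInfo> infos =
      tvm::meta_schedule::ArgInfo::FromPrimFunc(prim_func);
  for (const tvm::meta_schedule::ArgInfo& info : infos) {
    LOG(INFO) << info->AsJSON();  // e.g. tensor dtype/shape records
  }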
*/ - TVM_DLL static Array FromEntryFunc(const IRModule& mod, bool remove_preproc); + TVM_DLL static ffi::Array FromEntryFunc(const IRModule& mod, bool remove_preproc); TVM_DEFINE_MUTABLE_NOTNULLABLE_OBJECT_REF_METHODS(ArgInfo, runtime::ObjectRef, ArgInfoNode); diff --git a/include/tvm/meta_schedule/builder.h b/include/tvm/meta_schedule/builder.h index a5c3fe5f2c5f..6a6df2950271 100644 --- a/include/tvm/meta_schedule/builder.h +++ b/include/tvm/meta_schedule/builder.h @@ -41,7 +41,7 @@ class BuilderInputNode : public runtime::Object { /*! \brief The target to be built for. */ Target target; /*! \brief Parameters for Relax build module. */ - Optional> params; + ffi::Optional> params; static void RegisterReflection() { namespace refl = tvm::ffi::reflection; @@ -67,8 +67,9 @@ class BuilderInput : public runtime::ObjectRef { * \param target The target to be built for. * \param params Parameters for Relax build module. */ - TVM_DLL explicit BuilderInput(IRModule mod, Target target, - Optional> params = std::nullopt); + TVM_DLL explicit BuilderInput( + IRModule mod, Target target, + ffi::Optional> params = std::nullopt); TVM_DEFINE_NOTNULLABLE_OBJECT_REF_METHODS(BuilderInput, runtime::ObjectRef, BuilderInputNode); }; @@ -76,9 +77,9 @@ class BuilderInput : public runtime::ObjectRef { class BuilderResultNode : public runtime::Object { public: /*! \brief The path to the built artifact. */ - Optional artifact_path; + ffi::Optional artifact_path; /*! \brief The error message if any. */ - Optional error_msg; + ffi::Optional error_msg; static void RegisterReflection() { namespace refl = tvm::ffi::reflection; @@ -102,7 +103,8 @@ class BuilderResult : public runtime::ObjectRef { * \param artifact_path The path to the built artifact. * \param error_msg The error message if any. */ - TVM_DLL explicit BuilderResult(Optional artifact_path, Optional error_msg); + TVM_DLL explicit BuilderResult(ffi::Optional artifact_path, + ffi::Optional error_msg); TVM_DEFINE_NOTNULLABLE_OBJECT_REF_METHODS(BuilderResult, runtime::ObjectRef, BuilderResultNode); }; @@ -116,13 +118,13 @@ class BuilderNode : public runtime::Object { * \param build_inputs The inputs to be built. * \return The build results. */ - virtual Array Build(const Array& build_inputs) = 0; + virtual ffi::Array Build(const ffi::Array& build_inputs) = 0; /*! * \brief The function type of `Build` method. * \param build_inputs The inputs to be built. * \return The build results. */ - using FBuild = ffi::TypedFunction(const Array&)>; + using FBuild = ffi::TypedFunction(const ffi::Array&)>; static constexpr const char* _type_key = "meta_schedule.Builder"; TVM_DECLARE_BASE_OBJECT_INFO(BuilderNode, runtime::Object); @@ -154,7 +156,7 @@ class PyBuilderNode : public BuilderNode { refl::ObjectDef().def_ro("f_build", &PyBuilderNode::f_build); } - Array Build(const Array& build_inputs) final { + ffi::Array Build(const ffi::Array& build_inputs) final { ICHECK(f_build != nullptr) << "PyBuilder's Build method not implemented!"; return f_build(build_inputs); } diff --git a/include/tvm/meta_schedule/cost_model.h b/include/tvm/meta_schedule/cost_model.h index 9311fdef40c9..2ac20fcca8db 100644 --- a/include/tvm/meta_schedule/cost_model.h +++ b/include/tvm/meta_schedule/cost_model.h @@ -47,13 +47,13 @@ class CostModelNode : public runtime::Object { * \brief Load the cost model from given file location. * \param path The file path. */ - virtual void Load(const String& path) = 0; + virtual void Load(const ffi::String& path) = 0; /*! 
* \brief Save the cost model to given file location. * \param path The file path. */ - virtual void Save(const String& path) = 0; + virtual void Save(const ffi::String& path) = 0; /*! * \brief Update the cost model given running results. @@ -61,8 +61,8 @@ class CostModelNode : public runtime::Object { * \param candidates The measure candidates. * \param results The running results of the measure candidates. */ - virtual void Update(const TuneContext& context, const Array& candidates, - const Array& results) = 0; + virtual void Update(const TuneContext& context, const ffi::Array& candidates, + const ffi::Array& results) = 0; /*! * \brief Predict the normalized score (the larger the better) of given measure candidates. @@ -71,7 +71,7 @@ class CostModelNode : public runtime::Object { * \return The predicted normalized score. */ virtual std::vector Predict(const TuneContext& context, - const Array& candidates) = 0; + const ffi::Array& candidates) = 0; static constexpr const char* _type_key = "meta_schedule.CostModel"; TVM_DECLARE_BASE_OBJECT_INFO(CostModelNode, Object); @@ -84,12 +84,12 @@ class PyCostModelNode : public CostModelNode { * \brief Load the cost model from given file location. * \param path The file path. */ - using FLoad = ffi::TypedFunction; + using FLoad = ffi::TypedFunction; /*! * \brief Save the cost model to given file location. * \param path The file path. */ - using FSave = ffi::TypedFunction; + using FSave = ffi::TypedFunction; /*! * \brief Update the cost model given running results. * \param context The tuning context. @@ -97,21 +97,21 @@ class PyCostModelNode : public CostModelNode { * \param results The running results of the measure candidates. * \return Whether cost model was updated successfully. */ - using FUpdate = ffi::TypedFunction&, - const Array&)>; + using FUpdate = ffi::TypedFunction&, + const ffi::Array&)>; /*! * \brief Predict the running results of given measure candidates. * \param context The tuning context. * \param candidates The measure candidates. * \param p_addr The address to save the estimated running results. */ - using FPredict = - ffi::TypedFunction&, void* p_addr)>; + using FPredict = ffi::TypedFunction&, + void* p_addr)>; /*! * \brief Get the cost model as string with name. * \return The string representation of the cost model. */ - using FAsString = ffi::TypedFunction; + using FAsString = ffi::TypedFunction; /*! \brief The packed function to the `Load` function. */ FLoad f_load; @@ -124,12 +124,12 @@ class PyCostModelNode : public CostModelNode { /*! \brief The packed function to the `AsString` function. */ FAsString f_as_string; - void Load(const String& path); - void Save(const String& path); - void Update(const TuneContext& context, const Array& candidates, - const Array& results); + void Load(const ffi::String& path); + void Save(const ffi::String& path); + void Update(const TuneContext& context, const ffi::Array& candidates, + const ffi::Array& results); std::vector Predict(const TuneContext& context, - const Array& candidates); + const ffi::Array& candidates); static constexpr const char* _type_key = "meta_schedule.PyCostModel"; TVM_DECLARE_FINAL_OBJECT_INFO(PyCostModelNode, CostModelNode); diff --git a/include/tvm/meta_schedule/database.h b/include/tvm/meta_schedule/database.h index 6c631a9eca74..fbb09d7852c6 100644 --- a/include/tvm/meta_schedule/database.h +++ b/include/tvm/meta_schedule/database.h @@ -119,11 +119,11 @@ class TuningRecordNode : public runtime::Object { /*! \brief The workload. */ Workload workload{nullptr}; /*! 
\brief The profiling result in seconds. */ - Optional> run_secs; + ffi::Optional> run_secs; /*! \brief The target for tuning. */ - Optional target; + ffi::Optional target; /*! \brief The argument information. */ - Optional> args_info; + ffi::Optional> args_info; static void RegisterReflection() { namespace refl = tvm::ffi::reflection; @@ -170,8 +170,9 @@ class TuningRecord : public runtime::ObjectRef { \param args_info The argument information of the tuning record. */ TVM_DLL explicit TuningRecord(tir::Trace trace, Workload workload, - Optional> run_secs, Optional target, - Optional> args_info); + ffi::Optional> run_secs, + ffi::Optional target, + ffi::Optional> args_info); /*! * \brief Create a tuning record from a json object. * \param json_obj The json object. @@ -199,7 +200,7 @@ class DatabaseNode : public runtime::Object { * or in case no anchor block is found. * For the definition of the anchor block, see tvm/tir/analysis.h. */ - explicit DatabaseNode(String mod_eq_name = "structural"); + explicit DatabaseNode(ffi::String mod_eq_name = "structural"); /*! \brief Default destructor */ virtual ~DatabaseNode(); @@ -226,12 +227,12 @@ class DatabaseNode : public runtime::Object { * \param top_k The number of top records to be returned. * \return An array of top K tuning records for the given workload. */ - virtual Array GetTopK(const Workload& workload, int top_k) = 0; + virtual ffi::Array GetTopK(const Workload& workload, int top_k) = 0; /*! * \brief Get all tuning records from the database. * \return An Array of all the tuning records in the database. */ - virtual Array GetAllTuningRecords() = 0; + virtual ffi::Array GetAllTuningRecords() = 0; /*! * \brief Get the size of the database. * \return The size of the database. @@ -244,8 +245,8 @@ class DatabaseNode : public runtime::Object { * \param workload_name The name of the workload to be searched for. * \return The best record of the given workload; std::nullopt if not found. */ - virtual Optional QueryTuningRecord(const IRModule& mod, const Target& target, - const String& workload_name); + virtual ffi::Optional QueryTuningRecord(const IRModule& mod, const Target& target, - const ffi::String& workload_name); /*! * \brief Query the best schedule of the given workload from the database. * \param mod The IRModule to be searched for. * \param workload_name The name of the workload to be searched for. * \return The schedule in the best schedule of the given workload; std::nullopt if not found. */ - virtual Optional QuerySchedule(const IRModule& mod, const Target& target, - const String& workload_name); + virtual ffi::Optional QuerySchedule(const IRModule& mod, const Target& target, + const ffi::String& workload_name); /*! * \brief Query the best IRModule of the given workload from the database. * \param mod The IRModule to be searched for. * \param workload_name The name of the workload to be searched for. * \return The IRModule in the best IRModule of the given workload; std::nullopt if not found. */ - virtual Optional QueryIRModule(const IRModule& mod, const Target& target, - const String& workload_name); + virtual ffi::Optional QueryIRModule(const IRModule& mod, const Target& target, + const ffi::String& workload_name); /*! * \brief Prune the database and dump it to a given database. * \param destination The destination database to be dumped to.
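A usage sketch of the query API above (illustrative only: the file paths and the workload name are placeholders, and error handling is elided):

    // Sketch: querying a JSON-backed database through the ffi::Optional API.
    #include <tvm/meta_schedule/database.h>

    using tvm::meta_schedule::Database;
    using tvm::meta_schedule::TuningRecord;

    void QueryBest(const tvm::IRModule& mod, const tvm::Target& target) {
      Database db = Database::JSONDatabase("workload.json", "tuning_record.json",
                                           /*allow_missing=*/true);
      // QueryTuningRecord now returns ffi::Optional<TuningRecord>.
      tvm::ffi::Optional<TuningRecord> best =
          db->QueryTuningRecord(mod, target, "main");
      if (best.has_value()) {
        // ... use *best ...
      }
    }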
@@ -298,7 +299,7 @@ class PyDatabaseNode : public DatabaseNode { * or in case no anchor block is found. * For the definition of the anchor block, see tvm/tir/analysis.h. */ - explicit PyDatabaseNode(String mod_eq_name = "structural"); + explicit PyDatabaseNode(ffi::String mod_eq_name = "structural"); /*! * \brief The function type of `HasWorkload` method. @@ -323,12 +324,12 @@ class PyDatabaseNode : public DatabaseNode { * \param top_k The number of top records to be returned. * \return An array of top K tuning records for the given workload. */ - using FGetTopK = ffi::TypedFunction(const Workload&, int)>; + using FGetTopK = ffi::TypedFunction(const Workload&, int)>; /*! * \brief The function type of `GetAllTuningRecords` method. * \return An Array of all the tuning records in the database. */ - using FGetAllTuningRecords = ffi::TypedFunction()>; + using FGetAllTuningRecords = ffi::TypedFunction()>; /*! * \brief The function type of `QueryTuningRecord` method. * \param mod The IRModule to be searched for. @@ -336,8 +337,8 @@ class PyDatabaseNode : public DatabaseNode { * \param workload_name The name of the workload to be searched for. * \return The best record of the given workload; std::nullopt if not found. */ - using FQueryTuningRecord = - ffi::TypedFunction(const IRModule&, const Target&, const String&)>; + using FQueryTuningRecord = ffi::TypedFunction( + const IRModule&, const Target&, const ffi::String&)>; /*! * \brief The function type of `QuerySchedule` method. * \param mod The IRModule to be searched for. @@ -345,8 +346,8 @@ class PyDatabaseNode : public DatabaseNode { * \param workload_name The name of the workload to be searched for. * \return The schedule in the best schedule of the given workload; std::nullopt if not found. */ - using FQuerySchedule = - ffi::TypedFunction(const IRModule&, const Target&, const String&)>; + using FQuerySchedule = ffi::TypedFunction( + const IRModule&, const Target&, const ffi::String&)>; /*! * \brief The function type of `QueryIRModule` method. * \param mod The IRModule to be searched for. @@ -354,8 +355,8 @@ class PyDatabaseNode : public DatabaseNode { * \param workload_name The name of the workload to be searched for. * \return The IRModule in the best IRModule of the given workload; std::nullopt if not found. */ - using FQueryIRModule = - ffi::TypedFunction(const IRModule&, const Target&, const String&)>; + using FQueryIRModule = ffi::TypedFunction(const IRModule&, const Target&, + const ffi::String&)>; /*! * \brief The function type of `Size` method. * \return The size of the database. 
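Each `F*` alias above is an ffi::TypedFunction, so a custom database can be assembled from plain callables; a sketch of one such callback, assuming ffi::TypedFunction remains constructible from a compatible lambda as its TypedPackedFunc predecessor was:

    // Sketch: a GetTopK callback in the new signature.
    #include <tvm/meta_schedule/database.h>

    using tvm::meta_schedule::PyDatabaseNode;
    using tvm::meta_schedule::TuningRecord;
    using tvm::meta_schedule::Workload;

    PyDatabaseNode::FGetTopK MakeGetTopK() {
      return [](const Workload& workload, int top_k) -> tvm::ffi::Array<TuningRecord> {
        tvm::ffi::Array<TuningRecord> top;
        // ... collect up to top_k records for `workload` ...
        return top;
      };
    }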
@@ -412,19 +413,19 @@ class PyDatabaseNode : public DatabaseNode { f_commit_tuning_record(record); } - Array GetTopK(const Workload& workload, int top_k) final { + ffi::Array GetTopK(const Workload& workload, int top_k) final { ICHECK(f_get_top_k != nullptr) << "PyDatabase's GetTopK method not implemented!"; return f_get_top_k(workload, top_k); } - Array GetAllTuningRecords() final { + ffi::Array GetAllTuningRecords() final { ICHECK(f_get_all_tuning_records != nullptr) << "PyDatabase's GetAllTuningRecords method not implemented!"; return f_get_all_tuning_records(); } - Optional QueryTuningRecord(const IRModule& mod, const Target& target, - const String& workload_name) final { + ffi::Optional QueryTuningRecord(const IRModule& mod, const Target& target, + const ffi::String& workload_name) final { if (f_query_tuning_record == nullptr) { return DatabaseNode::QueryTuningRecord(mod, target, workload_name); } else { @@ -432,8 +433,8 @@ class PyDatabaseNode : public DatabaseNode { } } - Optional QuerySchedule(const IRModule& mod, const Target& target, - const String& workload_name) final { + ffi::Optional QuerySchedule(const IRModule& mod, const Target& target, + const ffi::String& workload_name) final { if (f_query_schedule == nullptr) { return DatabaseNode::QuerySchedule(mod, target, workload_name); } else { @@ -441,8 +442,8 @@ class PyDatabaseNode : public DatabaseNode { } } - Optional QueryIRModule(const IRModule& mod, const Target& target, - const String& workload_name) final { + ffi::Optional QueryIRModule(const IRModule& mod, const Target& target, + const ffi::String& workload_name) final { if (f_query_ir_module == nullptr) { return DatabaseNode::QueryIRModule(mod, target, workload_name); } else { @@ -469,7 +470,7 @@ class Database : public runtime::ObjectRef { * \brief An in-memory database. * \param mod_eq_name A string to specify the module equality testing and hashing method. */ - TVM_DLL static Database MemoryDatabase(String mod_eq_name = "structural"); + TVM_DLL static Database MemoryDatabase(ffi::String mod_eq_name = "structural"); /*! * \brief A database for injecting handcrafted schedule functions. * \param schedule_fn The function to do scheduling, which takes a TIR schedule, @@ -477,7 +478,7 @@ class Database : public runtime::ObjectRef { * \param mod_eq_name A string to specify the module equality testing and hashing method. */ TVM_DLL static Database ScheduleFnDatabase(ffi::TypedFunction schedule_fn, - String mod_eq_name = "structural"); + ffi::String mod_eq_name = "structural"); /*! * \brief Create a default database that uses JSON file for tuning records. * \param path_workload The path to the workload table. @@ -485,8 +486,8 @@ class Database : public runtime::ObjectRef { * \param allow_missing Whether to create new file when the given path is not found. * \param mod_eq_name A string to specify the module equality testing and hashing method. */ - TVM_DLL static Database JSONDatabase(String path_workload, String path_tuning_record, - bool allow_missing, String mod_eq_name = "structural"); + TVM_DLL static Database JSONDatabase(ffi::String path_workload, ffi::String path_tuning_record, + bool allow_missing, ffi::String mod_eq_name = "structural"); /*! * \brief A database composed of multiple databases, allowing users to guide IR rewriting using * combined knowledge of those databases. To each query, it returns the best record among all the @@ -494,7 +495,7 @@ class Database : public runtime::ObjectRef { * \param databases The list of databases to be combined. 
* \return The combined database. */ - TVM_DLL static Database UnionDatabase(Array databases); + TVM_DLL static Database UnionDatabase(ffi::Array databases); /*! * \brief A database composed of multiple databases, allowing users to guide IR rewriting using * combined knowledge of those databases. To each query, it returns the record from the first @@ -502,7 +503,7 @@ class Database : public runtime::ObjectRef { * \param databases The database to be subsetted. * \return The subsetted database. */ - TVM_DLL static Database OrderedUnionDatabase(Array databases); + TVM_DLL static Database OrderedUnionDatabase(ffi::Array databases); /*! * \brief Create a database with customized methods on the python-side. * \param f_has_workload The packed function of `HasWorkload`. @@ -526,9 +527,9 @@ class Database : public runtime::ObjectRef { PyDatabaseNode::FQuerySchedule f_query_schedule, PyDatabaseNode::FQueryIRModule f_query_ir_module, PyDatabaseNode::FSize f_size, - String mod_eq_name = "structural"); + ffi::String mod_eq_name = "structural"); /*! \return The current Database in the scope. */ - static Optional Current(); + static ffi::Optional Current(); /*! \brief Entering the scope of the context manager */ void EnterWithScope(); /*! \brief Exiting the scope of the context manager */ diff --git a/include/tvm/meta_schedule/extracted_task.h b/include/tvm/meta_schedule/extracted_task.h index 57debfee2267..974664bba505 100644 --- a/include/tvm/meta_schedule/extracted_task.h +++ b/include/tvm/meta_schedule/extracted_task.h @@ -42,13 +42,13 @@ namespace meta_schedule { class ExtractedTaskNode : public runtime::Object { public: /*! \brief The name of the task extracted */ - String task_name; + ffi::String task_name; /*! \brief The high-level IR */ IRModule mod; /*! \brief Target */ Target target; /*! \brief A list of low-level IRs that the high-level IR could potentially dispatch to */ - Array dispatched; + ffi::Array dispatched; /*! \brief Weight of the task */ int weight; @@ -73,8 +73,8 @@ class ExtractedTaskNode : public runtime::Object { */ class ExtractedTask : public runtime::ObjectRef { public: - explicit ExtractedTask(String task_name, IRModule mod, Target target, Array dispatched, - int weight); + explicit ExtractedTask(ffi::String task_name, IRModule mod, Target target, + ffi::Array dispatched, int weight); TVM_DEFINE_MUTABLE_NOTNULLABLE_OBJECT_REF_METHODS(ExtractedTask, runtime::ObjectRef, ExtractedTaskNode); }; diff --git a/include/tvm/meta_schedule/feature_extractor.h b/include/tvm/meta_schedule/feature_extractor.h index 88fcf9ac618d..e15d87679e03 100644 --- a/include/tvm/meta_schedule/feature_extractor.h +++ b/include/tvm/meta_schedule/feature_extractor.h @@ -49,8 +49,8 @@ class FeatureExtractorNode : public runtime::Object { * \param candidates The measure candidates to extract features from. * \return The feature tensor extracted. */ - virtual Array ExtractFrom(const TuneContext& context, - const Array& candidates) = 0; + virtual ffi::Array ExtractFrom( + const TuneContext& context, const ffi::Array& candidates) = 0; static constexpr const char* _type_key = "meta_schedule.FeatureExtractor"; TVM_DECLARE_BASE_OBJECT_INFO(FeatureExtractorNode, Object); @@ -65,13 +65,13 @@ class PyFeatureExtractorNode : public FeatureExtractorNode { * \param candidates The measure candidates to extract features from. * \return The feature tensor extracted. 
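Returning to the union constructors above: UnionDatabase answers each query with the best record across all member databases, while OrderedUnionDatabase returns the first match in order. A sketch (again assuming initializer-list construction of ffi::Array):

    using tvm::meta_schedule::Database;

    Database CombineBest(Database a, Database b) {
      return Database::UnionDatabase({a, b});         // best across a and b
    }

    Database FirstMatch(Database a, Database b) {
      return Database::OrderedUnionDatabase({a, b});  // a takes precedence
    }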
*/ - using FExtractFrom = ffi::TypedFunction( - const TuneContext& context, const Array& candidates)>; + using FExtractFrom = ffi::TypedFunction( + const TuneContext& context, const ffi::Array& candidates)>; /*! * \brief Get the feature extractor as string with name. * \return The string of the feature extractor. */ - using FAsString = ffi::TypedFunction; + using FAsString = ffi::TypedFunction; /*! \brief The packed function to the `ExtractFrom` function. */ FExtractFrom f_extract_from; @@ -83,8 +83,8 @@ class PyFeatureExtractorNode : public FeatureExtractorNode { // `f_as_string` is not registered } - Array ExtractFrom(const TuneContext& context, - const Array& candidates) final; + ffi::Array ExtractFrom( + const TuneContext& context, const ffi::Array& candidates) final; static constexpr const char* _type_key = "meta_schedule.PyFeatureExtractor"; TVM_DECLARE_FINAL_OBJECT_INFO(PyFeatureExtractorNode, FeatureExtractorNode); diff --git a/include/tvm/meta_schedule/measure_callback.h b/include/tvm/meta_schedule/measure_callback.h index d7377c3e5d1f..a266eeb26762 100644 --- a/include/tvm/meta_schedule/measure_callback.h +++ b/include/tvm/meta_schedule/measure_callback.h @@ -54,11 +54,11 @@ class MeasureCallbackNode : public runtime::Object { * \param builder_results The builder results by building the measure candidates. * \param runner_results The runner results by running the built measure candidates. */ - virtual void Apply(const TaskScheduler& task_scheduler, // - int task_id, // - const Array& measure_candidates, // - const Array& builder_results, // - const Array& runner_results) = 0; + virtual void Apply(const TaskScheduler& task_scheduler, // + int task_id, // + const ffi::Array& measure_candidates, // + const ffi::Array& builder_results, // + const ffi::Array& runner_results) = 0; static constexpr const char* _type_key = "meta_schedule.MeasureCallback"; TVM_DECLARE_BASE_OBJECT_INFO(MeasureCallbackNode, Object); @@ -76,16 +76,16 @@ class PyMeasureCallbackNode : public MeasureCallbackNode { * \param results The runner results by running the built measure candidates. * \return Whether the measure callback was successfully applied. */ - using FApply = ffi::TypedFunction& measure_candidates, // - const Array& builds, // - const Array& results)>; + using FApply = ffi::TypedFunction& measure_candidates, // + const ffi::Array& builds, // + const ffi::Array& results)>; /*! * \brief Get the measure callback function as string with name. * \return The string of the measure callback function. */ - using FAsString = ffi::TypedFunction; + using FAsString = ffi::TypedFunction; /*! \brief The packed function to the `Apply` function. */ FApply f_apply; @@ -97,11 +97,11 @@ class PyMeasureCallbackNode : public MeasureCallbackNode { // `f_as_string` is not registered } - void Apply(const TaskScheduler& task_scheduler, // - int task_id, // - const Array& measure_candidates, // - const Array& builds, // - const Array& results); + void Apply(const TaskScheduler& task_scheduler, // + int task_id, // + const ffi::Array& measure_candidates, // + const ffi::Array& builds, // + const ffi::Array& results); static constexpr const char* _type_key = "meta_schedule.PyMeasureCallback"; TVM_DECLARE_FINAL_OBJECT_INFO(PyMeasureCallbackNode, MeasureCallbackNode); @@ -137,7 +137,7 @@ class MeasureCallback : public runtime::ObjectRef { TVM_DLL static MeasureCallback PyMeasureCallback(PyMeasureCallbackNode::FApply f_apply, PyMeasureCallbackNode::FAsString f_as_string); /*! \brief The default list of measure callbacks. 
*/ - TVM_DLL static Array Default(); + TVM_DLL static ffi::Array Default(); TVM_DEFINE_MUTABLE_OBJECT_REF_METHODS(MeasureCallback, ObjectRef, MeasureCallbackNode); }; diff --git a/include/tvm/meta_schedule/measure_candidate.h b/include/tvm/meta_schedule/measure_candidate.h index 0aee01fff5eb..dbc5892236b2 100644 --- a/include/tvm/meta_schedule/measure_candidate.h +++ b/include/tvm/meta_schedule/measure_candidate.h @@ -35,7 +35,7 @@ class MeasureCandidateNode : public runtime::Object { /*! \brief The schedule for measurement. */ tir::Schedule sch; /*! \brief The argument information, e.g., (shape, dtype) for tensors. */ - Array args_info; + ffi::Array args_info; static void RegisterReflection() { namespace refl = tvm::ffi::reflection; @@ -59,7 +59,7 @@ class MeasureCandidate : public runtime::ObjectRef { * \param sch The schedule for measurement. * \param args_info The argument information, e.g., (shape, dtype) for tensors. */ - TVM_DLL MeasureCandidate(tir::Schedule sch, Array args_info); + TVM_DLL MeasureCandidate(tir::Schedule sch, ffi::Array args_info); TVM_DEFINE_NOTNULLABLE_OBJECT_REF_METHODS(MeasureCandidate, ObjectRef, MeasureCandidateNode); }; diff --git a/include/tvm/meta_schedule/mutator.h b/include/tvm/meta_schedule/mutator.h index 701045b7fb3f..823501623fe1 100644 --- a/include/tvm/meta_schedule/mutator.h +++ b/include/tvm/meta_schedule/mutator.h @@ -57,8 +57,8 @@ class MutatorNode : public runtime::Object { * \param rand_state The random state for mutation. * \return None if mutator failed, otherwise return the mutated trace. */ - virtual Optional Apply(const tir::Trace& trace, - support::LinearCongruentialEngine::TRandState* rand_state) = 0; + virtual ffi::Optional Apply( + const tir::Trace& trace, support::LinearCongruentialEngine::TRandState* rand_state) = 0; /*! * \brief Clone the mutator. @@ -86,7 +86,7 @@ class Mutator : public runtime::ObjectRef { * \param trace The given trace for mutation. * \return None if mutator failed, otherwise return the mutated trace. */ - using FApply = ffi::TypedFunction( + using FApply = ffi::TypedFunction( const tir::Trace&, support::LinearCongruentialEngine::TRandState rand_state)>; /*! * \brief Clone the mutator. @@ -97,7 +97,7 @@ class Mutator : public runtime::ObjectRef { * \brief Get the mutator as string with name. * \return The string of the mutator. */ - using FAsString = ffi::TypedFunction; + using FAsString = ffi::TypedFunction; /*! \brief Create a Mutator that mutates the decision of instruction Sample-Perfect-Tile */ TVM_DLL static Mutator MutateTileSize(); /*! @@ -132,13 +132,13 @@ class Mutator : public runtime::ObjectRef { TVM_DLL static Mutator PyMutator(FInitializeWithTuneContext f_initialize_with_tune_context, FApply f_apply, FClone f_clone, FAsString f_as_string); /*! \brief Create default mutators for LLVM */ - TVM_DLL static Map DefaultLLVM(); + TVM_DLL static ffi::Map DefaultLLVM(); /*! \brief Create default mutators for CUDA */ - TVM_DLL static Map DefaultCUDA(); + TVM_DLL static ffi::Map DefaultCUDA(); /*! \brief Create default mutators for CUDA with TensorCore */ - TVM_DLL static Map DefaultCUDATensorCore(); + TVM_DLL static ffi::Map DefaultCUDATensorCore(); /*! 
\brief Create default mutators for Hexagon */ - TVM_DLL static Map DefaultHexagon(); + TVM_DLL static ffi::Map DefaultHexagon(); TVM_DEFINE_MUTABLE_OBJECT_REF_METHODS(Mutator, ObjectRef, MutatorNode); }; @@ -167,8 +167,8 @@ class PyMutatorNode : public MutatorNode { } void InitializeWithTuneContext(const TuneContext& context) final; - Optional Apply(const tir::Trace& trace, - support::LinearCongruentialEngine::TRandState* rand_state) final; + ffi::Optional Apply(const tir::Trace& trace, + support::LinearCongruentialEngine::TRandState* rand_state) final; Mutator Clone() const final; static constexpr const char* _type_key = "meta_schedule.PyMutator"; diff --git a/include/tvm/meta_schedule/postproc.h b/include/tvm/meta_schedule/postproc.h index 6ed7272fe9b4..91d45e8680f8 100644 --- a/include/tvm/meta_schedule/postproc.h +++ b/include/tvm/meta_schedule/postproc.h @@ -93,7 +93,7 @@ class Postproc : public runtime::ObjectRef { * \brief Get the postprocessor function as string with name. * \return The string of the postprocessor function. */ - using FAsString = ffi::TypedFunction; + using FAsString = ffi::TypedFunction; /*! * \brief Create a postprocessor with customized methods on the python-side. * \param f_initialize_with_tune_context The packed function of `InitializeWithTuneContext`. @@ -163,17 +163,17 @@ class Postproc : public runtime::ObjectRef { */ TVM_DLL static Postproc RewriteLayout(); /*! \brief Create default postprocessors for LLVM */ - TVM_DLL static Array DefaultLLVM(); + TVM_DLL static ffi::Array DefaultLLVM(); /*! \brief Create default postprocessors for x86 (AVX512 and VNNI) */ - TVM_DLL static Array DefaultCPUTensorization(); + TVM_DLL static ffi::Array DefaultCPUTensorization(); /*! \brief Create default postprocessors for RISCV */ - TVM_DLL static Array DefaultRISCV(); + TVM_DLL static ffi::Array DefaultRISCV(); /*! \brief Create default postprocessors for CUDA */ - TVM_DLL static Array DefaultCUDA(); + TVM_DLL static ffi::Array DefaultCUDA(); /*! \brief Create default postprocessors for CUDA with TensorCore */ - TVM_DLL static Array DefaultCUDATensorCore(); + TVM_DLL static ffi::Array DefaultCUDATensorCore(); /*! \brief Create default postprocessors for Hexagon */ - TVM_DLL static Array DefaultHexagon(); + TVM_DLL static ffi::Array DefaultHexagon(); TVM_DEFINE_MUTABLE_OBJECT_REF_METHODS(Postproc, ObjectRef, PostprocNode); }; diff --git a/include/tvm/meta_schedule/profiler.h b/include/tvm/meta_schedule/profiler.h index c3754e0211a1..e8288a5ae6a1 100644 --- a/include/tvm/meta_schedule/profiler.h +++ b/include/tvm/meta_schedule/profiler.h @@ -69,9 +69,9 @@ class ProfilerNode : public runtime::Object { public: /*! \brief Get the internal stats of the running time */ - Map Get() const; + ffi::Map Get() const; /*! \brief Return a summary of profiling results in table format */ - String Table() const; + ffi::String Table() const; }; /*! @@ -88,13 +88,13 @@ class Profiler : public runtime::ObjectRef { /*! \brief Exiting the scope of the context manager */ void ExitWithScope(); /*! \brief Returns the current profiler */ - static Optional Current(); + static ffi::Optional Current(); /*! * \brief Profile the time usage in the given scope with the given name. * \param name Name for the scope. * \return A scope timer for time profiling.
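A small sketch of the profiler scope above (illustrative: the scope name is arbitrary, and it assumes ScopedTimer stops the measurement when destroyed, per its RAII naming):

    #include <tvm/meta_schedule/profiler.h>

    void TimedWork() {
      using tvm::meta_schedule::Profiler;
      // Records into the active profiler, if one is in scope.
      auto timer = Profiler::TimedScope("example/timed-work");
      // ... the work to be measured ...
    }  // timer goes out of scope here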
*/ - static ScopedTimer TimedScope(String name); + static ScopedTimer TimedScope(ffi::String name); }; } // namespace meta_schedule diff --git a/include/tvm/meta_schedule/runner.h b/include/tvm/meta_schedule/runner.h index 1bfda4820f6a..2d42b5e590d4 100644 --- a/include/tvm/meta_schedule/runner.h +++ b/include/tvm/meta_schedule/runner.h @@ -35,11 +35,11 @@ namespace meta_schedule { class RunnerInputNode : public runtime::Object { public: /*! \brief The path to the built artifact. */ - String artifact_path; + ffi::String artifact_path; /*! \brief The type of device. */ - String device_type; + ffi::String device_type; /*! \brief The argument information. */ - Array args_info; + ffi::Array args_info; static void RegisterReflection() { namespace refl = tvm::ffi::reflection; @@ -66,7 +66,8 @@ class RunnerInput : public runtime::ObjectRef { * \param device_type The type of device. * \param args_info The argument information. */ - TVM_DLL explicit RunnerInput(String artifact_path, String device_type, Array args_info); + TVM_DLL explicit RunnerInput(ffi::String artifact_path, ffi::String device_type, + ffi::Array args_info); TVM_DEFINE_NOTNULLABLE_OBJECT_REF_METHODS(RunnerInput, runtime::ObjectRef, RunnerInputNode); }; @@ -74,9 +75,9 @@ class RunnerResultNode : public runtime::Object { public: /*! \brief The run time in seconds. */ - Optional> run_secs; + ffi::Optional> run_secs; /*! \brief The error message, if any. */ - Optional error_msg; + ffi::Optional error_msg; static void RegisterReflection() { namespace refl = tvm::ffi::reflection; @@ -101,7 +102,8 @@ class RunnerResult : public runtime::ObjectRef { * \param run_secs The run time in seconds. * \param error_msg The error message, if any. */ - TVM_DLL explicit RunnerResult(Optional> run_secs, Optional error_msg); + TVM_DLL explicit RunnerResult(ffi::Optional> run_secs, + ffi::Optional error_msg); TVM_DEFINE_NOTNULLABLE_OBJECT_REF_METHODS(RunnerResult, runtime::ObjectRef, RunnerResultNode); }; @@ -182,7 +184,7 @@ class RunnerNode : public runtime::Object { * \return The runner futures. * \sa RunnerFuture */ - using FRun = ffi::TypedFunction(Array)>; + using FRun = ffi::TypedFunction(ffi::Array)>; /*! \brief Default destructor */ virtual ~RunnerNode() = default; @@ -192,7 +194,7 @@ class RunnerNode : public runtime::Object { * \param runner_inputs The runner's inputs. * \return The runner futures. */ - virtual Array Run(Array runner_inputs) = 0; + virtual ffi::Array Run(ffi::Array runner_inputs) = 0; static constexpr const char* _type_key = "meta_schedule.Runner"; TVM_DECLARE_BASE_OBJECT_INFO(RunnerNode, runtime::Object); @@ -225,7 +227,7 @@ class PyRunnerNode : public RunnerNode { // `f_run` is not registered } - Array Run(Array runner_inputs) final { + ffi::Array Run(ffi::Array runner_inputs) final { ICHECK(f_run != nullptr) << "PyRunner's Run method not implemented!"; return f_run(runner_inputs); } diff --git a/include/tvm/meta_schedule/schedule/cuda/thread_bind.h b/include/tvm/meta_schedule/schedule/cuda/thread_bind.h index 125d6dc11fc8..aa3df4e7d443 100644 --- a/include/tvm/meta_schedule/schedule/cuda/thread_bind.h +++ b/include/tvm/meta_schedule/schedule/cuda/thread_bind.h @@ -36,7 +36,7 @@ namespace meta_schedule { * \return A sampler that returns a random thread extent. */ std::function MakeFactorSampler(tir::Schedule sch, - Array thread_extents); + ffi::Array thread_extents); /*!
* \brief Bind blockIdx.x and threadIdx.x to the given loop @@ -47,9 +47,9 @@ std::function MakeFactorSampler(tir::Schedule sch, * \param get_factor A function that returns the tiling factor. * \return The bound loops in the order of blockIdx.x, threadIdx.x, and the rest. */ -Array BindSpatialLoop(tir::Schedule sch, tir::LoopRV loop, // - int64_t max_threadblocks, int64_t max_threads_per_block, - std::function get_factor = nullptr); +ffi::Array BindSpatialLoop(tir::Schedule sch, tir::LoopRV loop, // + int64_t max_threadblocks, int64_t max_threads_per_block, + std::function get_factor = nullptr); /*! * \brief Bind the given block if it is not bound to blockIdx or threadIdx. diff --git a/include/tvm/meta_schedule/schedule_rule.h b/include/tvm/meta_schedule/schedule_rule.h index 407914e3d074..7305b1b9c82e 100644 --- a/include/tvm/meta_schedule/schedule_rule.h +++ b/include/tvm/meta_schedule/schedule_rule.h @@ -59,7 +59,7 @@ class ScheduleRuleNode : public runtime::Object { * \param block The specific block to apply the schedule rule. * \return The list of schedules generated by applying the schedule rule. */ - virtual Array Apply(const tir::Schedule& sch, const tir::BlockRV& block) = 0; + virtual ffi::Array Apply(const tir::Schedule& sch, const tir::BlockRV& block) = 0; /*! * \brief Deep clone the schedule rule. @@ -89,12 +89,12 @@ class ScheduleRule : public runtime::ObjectRef { * \return The list of schedules generated by applying the schedule rule. */ using FApply = - ffi::TypedFunction(const tir::Schedule&, const tir::BlockRV&)>; + ffi::TypedFunction(const tir::Schedule&, const tir::BlockRV&)>; /*! * \brief Get the schedule rule as string with name. * \return The string of the schedule rule. */ - using FAsString = ffi::TypedFunction; + using FAsString = ffi::TypedFunction; /*! * \brief The function type of `Clone` method. * \return The cloned schedule rule. @@ -125,7 +125,7 @@ class ScheduleRule : public runtime::ObjectRef { bool disallow_if_then_else, // bool require_injective, // bool require_ordered, // - Optional> disallow_op); + ffi::Optional> disallow_op); /*! * \brief Inline blocks that produce a constant scalar. Such blocks get in the way of @@ -155,13 +155,14 @@ class ScheduleRule : public runtime::ObjectRef { * ignored by default. This function should return True for a block that should be tiled. * \return The schedule rule created */ - TVM_DLL static ScheduleRule MultiLevelTiling(String structure, // - Optional> tile_binds, // - Optional max_innermost_factor, // - Optional> vector_load_lens, // - Optional> reuse_read, // - Optional> reuse_write, - Optional filter_fn = std::nullopt); + TVM_DLL static ScheduleRule MultiLevelTiling( + ffi::String structure, // + ffi::Optional> tile_binds, // + ffi::Optional max_innermost_factor, // + ffi::Optional> vector_load_lens, // + ffi::Optional> reuse_read, // + ffi::Optional> reuse_write, + ffi::Optional filter_fn = std::nullopt); /*! * \brief Extension of MultiLevelTiling for auto-tensorization with a single intrinsic.
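A usage sketch of BindSpatialLoop above (illustrative: the extents 256 and 1024 are arbitrary, and sch/loop are assumed to come from an existing schedule):

    #include <tvm/meta_schedule/schedule/cuda/thread_bind.h>

    void BindOneLoop(tvm::tir::Schedule sch, tvm::tir::LoopRV loop) {
      using namespace tvm::meta_schedule;
      // Returns the loops ordered as blockIdx.x, threadIdx.x, then the rest.
      tvm::ffi::Array<tvm::tir::LoopRV> bound = BindSpatialLoop(
          sch, loop, /*max_threadblocks=*/256, /*max_threads_per_block=*/1024);
      (void)bound;
    }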
@@ -181,9 +182,12 @@ class ScheduleRule : public runtime::ObjectRef { * \return The schedule rule created */ TVM_DLL static ScheduleRule MultiLevelTilingWithIntrin( - String intrin_name, String structure, Optional> tile_binds, - Optional max_innermost_factor, Optional> vector_load_lens, - Optional> reuse_read, Optional> reuse_write); + ffi::String intrin_name, ffi::String structure, + ffi::Optional> tile_binds, + ffi::Optional max_innermost_factor, + ffi::Optional> vector_load_lens, + ffi::Optional> reuse_read, + ffi::Optional> reuse_write); /*! * \brief Extension of MultiLevelTiling for auto-tensorization with multiple groups of candidate @@ -206,10 +210,12 @@ class ScheduleRule : public runtime::ObjectRef { * \return The schedule rule created */ TVM_DLL static ScheduleRule MultiLevelTilingTensorCore( - Array> intrin_groups, String structure, - Optional> tile_binds, Optional max_innermost_factor, - Optional> vector_load_lens, Optional> reuse_read, - Optional> reuse_write, bool use_software_pipeline); + ffi::Array> intrin_groups, ffi::String structure, + ffi::Optional> tile_binds, + ffi::Optional max_innermost_factor, + ffi::Optional> vector_load_lens, + ffi::Optional> reuse_read, + ffi::Optional> reuse_write, bool use_software_pipeline); /*! * \brief Extension of MultiLevelTiling for backends with wide vectors. @@ -223,8 +229,10 @@ class ScheduleRule : public runtime::ObjectRef { * \return The schedule rule created */ TVM_DLL static ScheduleRule MultiLevelTilingWideVector( - String structure, Integer vector_length_in_bits, Optional max_innermost_factor, - Optional> reuse_read, Optional> reuse_write); + ffi::String structure, Integer vector_length_in_bits, + ffi::Optional max_innermost_factor, + ffi::Optional> reuse_read, + ffi::Optional> reuse_write); /*! * \brief Create a rule: add-rfactor to some blocks if needed @@ -235,14 +243,14 @@ class ScheduleRule : public runtime::ObjectRef { * limit \return The schedule rule created */ TVM_DLL static ScheduleRule AddRFactor(int max_jobs_per_core, // - Optional max_innermost_factor); + ffi::Optional max_innermost_factor); /*! * \brief Create a schedule rule which applies cross-thread reduction to some reduction blocks * correspondingly when needed * \param thread_extents Candidates of thread axis extent (values are required to be positive). * \return The schedule rule created */ - TVM_DLL static ScheduleRule CrossThreadReduction(Array thread_extents); + TVM_DLL static ScheduleRule CrossThreadReduction(ffi::Array thread_extents); /*! * \brief A rule that randomly select a compute-at location for a free block * \return The schedule rule created @@ -261,9 +269,9 @@ class ScheduleRule : public runtime::ObjectRef { * \param unroll_explicit Whether to explicitly unroll the loop, or just add an "unroll" pragma. * \return The schedule rule created */ - TVM_DLL static ScheduleRule ParallelizeVectorizeUnroll(int max_jobs_per_core, // - int max_vectorize_extent, // - Array unroll_max_steps, // + TVM_DLL static ScheduleRule ParallelizeVectorizeUnroll(int max_jobs_per_core, // + int max_vectorize_extent, // + ffi::Array unroll_max_steps, // bool unroll_explicit); /*! * \brief Auto bind loops around the block to BlockIdx and ThreadIdx @@ -273,7 +281,7 @@ class ScheduleRule : public runtime::ObjectRef { * when this schedule rule is created. 
* \return The schedule rule created */ - TVM_DLL static ScheduleRule AutoBind(int max_threadblocks, Array thread_extents, + TVM_DLL static ScheduleRule AutoBind(int max_threadblocks, ffi::Array thread_extents, int max_threads_per_block = -1); /*! * \brief Create a schedule rule with customized methods on the python-side. @@ -290,19 +298,19 @@ class ScheduleRule : public runtime::ObjectRef { FAsString f_as_string); /*! \brief Create default schedule rules for LLVM */ - TVM_DLL static Array DefaultLLVM(); + TVM_DLL static ffi::Array DefaultLLVM(); /*! \brief Create default schedule rules for x86 (AVX512 and VNNI) */ - TVM_DLL static Array DefaultX86(const String& type); + TVM_DLL static ffi::Array DefaultX86(const ffi::String& type); /*! \brief Create default schedule rules for CUDA */ - TVM_DLL static Array DefaultCUDA(); + TVM_DLL static ffi::Array DefaultCUDA(); /*! \brief Create default postprocessors for CUDA with TensorCore */ - TVM_DLL static Array DefaultCUDATensorCore(); + TVM_DLL static ffi::Array DefaultCUDATensorCore(); /*! \brief Create default schedule rules for Hexagon */ - TVM_DLL static Array DefaultHexagon(); + TVM_DLL static ffi::Array DefaultHexagon(); /*! \brief Create default schedule rules for ARM CPU (NEON and DOTPROD) */ - TVM_DLL static Array DefaultARM(const String& type); + TVM_DLL static ffi::Array DefaultARM(const ffi::String& type); /*! \brief Create default schedule rules for RISCV CPU (RVV) */ - TVM_DLL static Array DefaultRISCV(int vlen); + TVM_DLL static ffi::Array DefaultRISCV(int vlen); TVM_DEFINE_MUTABLE_OBJECT_REF_METHODS(ScheduleRule, ObjectRef, ScheduleRuleNode); }; @@ -332,7 +340,7 @@ class PyScheduleRuleNode : public ScheduleRuleNode { } void InitializeWithTuneContext(const TuneContext& context) final; - Array Apply(const tir::Schedule& sch, const tir::BlockRV& block) final; + ffi::Array Apply(const tir::Schedule& sch, const tir::BlockRV& block) final; ScheduleRule Clone() const final; static constexpr const char* _type_key = "meta_schedule.PyScheduleRule"; diff --git a/include/tvm/meta_schedule/search_strategy.h b/include/tvm/meta_schedule/search_strategy.h index 9e1af10a01d6..8d49ff25fffa 100644 --- a/include/tvm/meta_schedule/search_strategy.h +++ b/include/tvm/meta_schedule/search_strategy.h @@ -98,9 +98,9 @@ class SearchStrategyNode : public runtime::Object { * and reset the search strategy. */ virtual void PreTuning(int max_trials, int num_trials_per_iter, - const Array& design_spaces, - const Optional& database, - const Optional& cost_model) = 0; + const ffi::Array& design_spaces, + const ffi::Optional& database, + const ffi::Optional& cost_model) = 0; /*! * \brief Post-tuning for the search strategy. @@ -113,15 +113,15 @@ class SearchStrategyNode : public runtime::Object { * \brief Generate measure candidates from design spaces for measurement. * \return The measure candidates generated, nullptr if finished. */ - virtual Optional> GenerateMeasureCandidates() = 0; + virtual ffi::Optional> GenerateMeasureCandidates() = 0; /*! * \brief Update the search strategy with measurement results. * \param measure_candidates The candidates to be measured. * \param results The measurement results from the runner. */ - virtual void NotifyRunnerResults(const Array& measure_candidates, - const Array& results) = 0; + virtual void NotifyRunnerResults(const ffi::Array& measure_candidates, + const ffi::Array& results) = 0; /*! * \brief Clone the search strategy. @@ -147,22 +147,23 @@ class SearchStrategy : public runtime::ObjectRef { /*! 
* \brief The function type of `PreTuning` method. */ - using FPreTuning = - ffi::TypedFunction&, - const Optional&, const Optional&)>; + using FPreTuning = ffi::TypedFunction&, + const ffi::Optional&, const ffi::Optional&)>; /*! \brief The function type of `PostTuning` method. */ using FPostTuning = ffi::TypedFunction; /*! * \brief The function type of `GenerateMeasureCandidates` method. * \return The measure candidates generated, nullptr if finished. */ - using FGenerateMeasureCandidates = ffi::TypedFunction>()>; + using FGenerateMeasureCandidates = + ffi::TypedFunction>()>; /*! * \brief The function type of `NotifyRunnerResults` method. * \param results The measurement results from the runner. */ - using FNotifyRunnerResults = - ffi::TypedFunction&, const Array&)>; + using FNotifyRunnerResults = ffi::TypedFunction&, + const ffi::Array&)>; /*! * \brief The function type of `Clone` method. * \return The cloned search strategy. @@ -251,12 +252,14 @@ class PySearchStrategyNode : public SearchStrategyNode { } void InitializeWithTuneContext(const TuneContext& context) final; - void PreTuning(int max_trials, int num_trials_per_iter, const Array& design_spaces, - const Optional& database, const Optional& cost_model) final; + void PreTuning(int max_trials, int num_trials_per_iter, + const ffi::Array& design_spaces, + const ffi::Optional& database, + const ffi::Optional& cost_model) final; void PostTuning() final; - Optional> GenerateMeasureCandidates() final; - void NotifyRunnerResults(const Array& measure_candidates, - const Array& results); + ffi::Optional> GenerateMeasureCandidates() final; + void NotifyRunnerResults(const ffi::Array& measure_candidates, + const ffi::Array& results); SearchStrategy Clone() const final; static constexpr const char* _type_key = "meta_schedule.PySearchStrategy"; diff --git a/include/tvm/meta_schedule/space_generator.h b/include/tvm/meta_schedule/space_generator.h index 7b26b56abbed..f013934e2342 100644 --- a/include/tvm/meta_schedule/space_generator.h +++ b/include/tvm/meta_schedule/space_generator.h @@ -76,11 +76,11 @@ class SpaceGenerator; class SpaceGeneratorNode : public runtime::Object { public: /*! \brief The schedule rules. */ - Optional> sch_rules; + ffi::Optional> sch_rules; /*! \brief The postprocessors. */ - Optional> postprocs; + ffi::Optional> postprocs; /*! \brief The probability of using certain mutator. */ - Optional> mutator_probs; + ffi::Optional> mutator_probs; static void RegisterReflection() { namespace refl = tvm::ffi::reflection; @@ -105,7 +105,7 @@ class SpaceGeneratorNode : public runtime::Object { * \param mod The module used for design space generation. * \return The generated design spaces, i.e., schedules. */ - virtual Array GenerateDesignSpace(const IRModule& mod) = 0; + virtual ffi::Array GenerateDesignSpace(const IRModule& mod) = 0; /*! * \brief Clone the space generator. @@ -133,7 +133,7 @@ class SpaceGenerator : public runtime::ObjectRef { * \param mod The module used for design space generation. * \return The generated design spaces, i.e., schedules. */ - using FGenerateDesignSpace = ffi::TypedFunction(const IRModule&)>; + using FGenerateDesignSpace = ffi::TypedFunction(const IRModule&)>; /*! * \brief The function type of `Clone` method. * \return The cloned space generator. @@ -155,8 +155,9 @@ class SpaceGenerator : public runtime::ObjectRef { * \return The design space generator created. 
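The F* aliases here admit plain lambdas as well; a sketch of a trivial design-space callback (illustrative only: a real generator would derive schedules from mod):

    #include <tvm/meta_schedule/space_generator.h>

    tvm::meta_schedule::SpaceGenerator::FGenerateDesignSpace MakeTrivialGenerator() {
      return [](const tvm::IRModule& mod) -> tvm::ffi::Array<tvm::tir::Schedule> {
        tvm::ffi::Array<tvm::tir::Schedule> spaces;
        // ... build schedules for `mod` and append them to `spaces` ...
        return spaces;
      };
    }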
*/ TVM_DLL static SpaceGenerator PySpaceGenerator( - Optional> sch_rules, Optional> postprocs, - Optional> mutator_probs, + ffi::Optional> sch_rules, + ffi::Optional> postprocs, + ffi::Optional> mutator_probs, FInitializeWithTuneContext f_initialize_with_tune_context, FGenerateDesignSpace f_generate_design_space, FClone f_clone); /*! @@ -164,15 +165,15 @@ class SpaceGenerator : public runtime::ObjectRef { * \param schedule_fn The schedule function, which can have the following signatures: * 1) void(Schedule) * 2) Schedule(Schedule) - * 3) Array(Schedule) + * 3) ffi::Array(Schedule) * \param sch_rules The schedule rules. * \param postprocs The postprocessors. * \param mutator_probs The probability of using certain mutator. */ - TVM_DLL static SpaceGenerator ScheduleFn(ffi::Function schedule_fn, - Optional> sch_rules, - Optional> postprocs, - Optional> mutator_probs); + TVM_DLL static SpaceGenerator ScheduleFn( + ffi::Function schedule_fn, ffi::Optional> sch_rules, + ffi::Optional> postprocs, + ffi::Optional> mutator_probs); /*! * \brief Create a design space generator that is union of multiple design space generators. * \param space_generators An array of design space generators to be unioned. @@ -181,10 +182,11 @@ class SpaceGenerator : public runtime::ObjectRef { * \param mutator_probs The probability of using certain mutator. * \return The design space generator created. */ - TVM_DLL static SpaceGenerator SpaceGeneratorUnion(Array space_generators, - Optional> sch_rules, - Optional> postprocs, - Optional> mutator_probs); + TVM_DLL static SpaceGenerator SpaceGeneratorUnion( + ffi::Array space_generators, + ffi::Optional> sch_rules, + ffi::Optional> postprocs, + ffi::Optional> mutator_probs); /*! * \brief Create a design space generator that generates design spaces by applying schedule * rules to blocks in post-DFS order. @@ -194,10 +196,10 @@ class SpaceGenerator : public runtime::ObjectRef { * \param mutator_probs The probability of using certain mutator. * \return The design space generator created. */ - TVM_DLL static SpaceGenerator PostOrderApply(ffi::Function f_block_filter, - Optional> sch_rules, - Optional> postprocs, - Optional> mutator_probs); + TVM_DLL static SpaceGenerator PostOrderApply( + ffi::Function f_block_filter, ffi::Optional> sch_rules, + ffi::Optional> postprocs, + ffi::Optional> mutator_probs); TVM_DEFINE_MUTABLE_NOTNULLABLE_OBJECT_REF_METHODS(SpaceGenerator, ObjectRef, SpaceGeneratorNode); }; @@ -221,7 +223,7 @@ class PySpaceGeneratorNode : public SpaceGeneratorNode { } void InitializeWithTuneContext(const TuneContext& context) final; - Array GenerateDesignSpace(const IRModule& mod) final; + ffi::Array GenerateDesignSpace(const IRModule& mod) final; SpaceGenerator Clone() const final; static constexpr const char* _type_key = "meta_schedule.PySpaceGenerator"; diff --git a/include/tvm/meta_schedule/task_scheduler.h b/include/tvm/meta_schedule/task_scheduler.h index 9c1300d2433f..0c88cb12c8cc 100644 --- a/include/tvm/meta_schedule/task_scheduler.h +++ b/include/tvm/meta_schedule/task_scheduler.h @@ -54,11 +54,11 @@ class TaskRecordNode : public runtime::Object { /*! \brief The latency of each run, in milliseconds. */ std::vector latency_ms = {}; /*! \brief The measure candidates. */ - Optional> measure_candidates = std::nullopt; + ffi::Optional> measure_candidates = std::nullopt; /*! \brief The building results. */ - Optional> builder_results = std::nullopt; + ffi::Optional> builder_results = std::nullopt; /*! 
\brief Packed functions to fetch the runner results asynchronously. */ - Optional> runner_futures = std::nullopt; + ffi::Optional> runner_futures = std::nullopt; static void RegisterReflection() { namespace refl = tvm::ffi::reflection; @@ -131,13 +131,13 @@ class TaskSchedulerNode : public runtime::Object { /*! \brief The tuning task's logging function. */ ffi::Function logger; /*! \brief Records for each task */ - Array tasks_; + ffi::Array tasks_; /*! \brief The list of measure callbacks of the scheduler. */ - Array measure_callbacks_; + ffi::Array measure_callbacks_; /*! \brief The database used in tuning */ - Optional database_; + ffi::Optional database_; /*! \brief The cost model used in tuning */ - Optional cost_model_; + ffi::Optional cost_model_; /*! \brief The number of remaining tasks to be tuned. */ int remaining_tasks_; @@ -164,7 +164,7 @@ class TaskSchedulerNode : public runtime::Object { * \param task_id The task id to be joined. * \return The results from the runner. */ - virtual Array JoinRunningTask(int task_id); + virtual ffi::Array JoinRunningTask(int task_id); /*! * \brief Jointly tune a given list of tasks. * \param tasks The tasks to be tuned @@ -178,16 +178,16 @@ class TaskSchedulerNode : public runtime::Object { * \param database The database used in tuning * \param cost_model The cost model used in tuning */ - virtual void Tune(Array tasks, // - Array task_weights, // - int max_trials_global, // - int max_trials_per_task, // - int num_trials_per_iter, // - Builder builder, // - Runner runner, // - Array measure_callbacks, // - Optional database, // - Optional cost_model); + virtual void Tune(ffi::Array tasks, // + ffi::Array task_weights, // + int max_trials_global, // + int max_trials_per_task, // + int num_trials_per_iter, // + Builder builder, // + Runner runner, // + ffi::Array measure_callbacks, // + ffi::Optional database, // + ffi::Optional cost_model); /*! * \brief Terminate a task * \param task_id The id of the task to be terminated @@ -219,18 +219,18 @@ class PyTaskSchedulerNode : public TaskSchedulerNode { * \brief The function type of `JoinRunningTask` method. * \param task_id The task id to be joined. */ - using FJoinRunningTask = ffi::TypedFunction(int)>; + using FJoinRunningTask = ffi::TypedFunction(int)>; /*! \brief The function type of `Tune` method. */ - using FTune = ffi::TypedFunction tasks, // - Array task_weights, // - int max_trials_global, // - int max_trials_per_task, // - int num_trials_per_iter, // - Builder builder, // - Runner runner, // - Array measure_callbacks, // - Optional database, // - Optional cost_model)>; + using FTune = ffi::TypedFunction tasks, // + ffi::Array task_weights, // + int max_trials_global, // + int max_trials_per_task, // + int num_trials_per_iter, // + Builder builder, // + Runner runner, // + ffi::Array measure_callbacks, // + ffi::Optional database, // + ffi::Optional cost_model)>; /*! \brief The packed function to the `NextTaskId` function. 
*/ FNextTaskId f_next_task_id; @@ -245,11 +245,11 @@ class PyTaskSchedulerNode : public TaskSchedulerNode { } int NextTaskId() final; - Array JoinRunningTask(int task_id) final; - void Tune(Array tasks, Array task_weights, int max_trials_global, + ffi::Array JoinRunningTask(int task_id) final; + void Tune(ffi::Array tasks, ffi::Array task_weights, int max_trials_global, int max_trials_per_task, int num_trials_per_iter, Builder builder, Runner runner, - Array measure_callbacks, Optional database, - Optional cost_model) final; + ffi::Array measure_callbacks, ffi::Optional database, + ffi::Optional cost_model) final; static constexpr const char* _type_key = "meta_schedule.PyTaskScheduler"; TVM_DECLARE_FINAL_OBJECT_INFO(PyTaskSchedulerNode, TaskSchedulerNode); diff --git a/include/tvm/meta_schedule/tune_context.h b/include/tvm/meta_schedule/tune_context.h index 47326ac46b99..cd9b8f1b5ad2 100644 --- a/include/tvm/meta_schedule/tune_context.h +++ b/include/tvm/meta_schedule/tune_context.h @@ -48,15 +48,15 @@ class TuneContextNode : public runtime::Object { using TRandState = support::LinearCongruentialEngine::TRandState; /*! \brief The workload to be tuned. */ - Optional mod; + ffi::Optional mod; /*! \brief The target to be tuned for. */ - Optional target; + ffi::Optional target; /*! \brief The design space generator. */ - Optional space_generator; + ffi::Optional space_generator; /*! \brief The search strategy. */ - Optional search_strategy; + ffi::Optional search_strategy; /*! \brief The name of the tuning task. */ - Optional task_name; + ffi::Optional task_name; /*! \brief The number of threads to be used. */ int num_threads; /*! \brief The random state. */ @@ -109,10 +109,11 @@ class TuneContext : public runtime::ObjectRef { * \param rand_state The random state. * \param logger The tuning task's logging function. */ - TVM_DLL explicit TuneContext(Optional mod, Optional target, - Optional space_generator, - Optional search_strategy, Optional task_name, - int num_threads, TRandState rand_state, ffi::Function logger); + TVM_DLL explicit TuneContext(ffi::Optional mod, ffi::Optional target, + ffi::Optional space_generator, + ffi::Optional search_strategy, + ffi::Optional task_name, int num_threads, + TRandState rand_state, ffi::Function logger); TVM_DEFINE_MUTABLE_NOTNULLABLE_OBJECT_REF_METHODS(TuneContext, ObjectRef, TuneContextNode); }; diff --git a/include/tvm/node/attr_registry_map.h b/include/tvm/node/attr_registry_map.h index 37dc710ac161..e273fa8f5fe1 100644 --- a/include/tvm/node/attr_registry_map.h +++ b/include/tvm/node/attr_registry_map.h @@ -86,7 +86,7 @@ class AttrRegistryMapContainerMap { private: /*! \brief The name of the attr field */ - String attr_name_; + ffi::String attr_name_; /*! \brief The internal data. */ std::vector> data_; /*! \brief The constructor */ @@ -97,7 +97,7 @@ class AttrRegistryMapContainerMap { }; /*! * \brief Map used to store meta-data. * \tparam KeyType The type of the key * \tparam ValueType The type of the value stored in map. */ diff --git a/include/tvm/node/cast.h b/include/tvm/node/cast.h index ae23c9e9aa33..4ed5f4178c8b 100644 --- a/include/tvm/node/cast.h +++ b/include/tvm/node/cast.h @@ -57,7 +57,7 @@ inline SubRef Downcast(BaseRef ref) { } TVM_FFI_THROW(TypeError) << "Downcast from undefined(nullptr) to `" << SubRef::ContainerType::_type_key - << "` is not allowed. Use Downcast<Optional<T>> instead."; + << "` is not allowed.
Use Downcast<Optional<T>> instead."; TVM_FFI_UNREACHABLE(); } } diff --git a/include/tvm/node/reflection.h b/include/tvm/node/reflection.h index 7c8c2bfb9214..d5716f96f6d5 100644 --- a/include/tvm/node/reflection.h +++ b/include/tvm/node/reflection.h @@ -34,7 +34,8 @@ namespace tvm { * \param fields The fields of the object. * \return The created object. */ -TVM_DLL ffi::Any CreateObject(const String& type_key, const Map& fields); +TVM_DLL ffi::Any CreateObject(const ffi::String& type_key, + const ffi::Map& fields); } // namespace tvm #endif  // TVM_NODE_REFLECTION_H_ diff --git a/include/tvm/node/repr_printer.h b/include/tvm/node/repr_printer.h index 05687d70d742..f3e0edab6e07 100644 --- a/include/tvm/node/repr_printer.h +++ b/include/tvm/node/repr_printer.h @@ -83,7 +83,7 @@ inline std::ostream& operator<<(std::ostream& os, const Any& n) { // NOLINT(*) } template <typename... V> -inline std::ostream& operator<<(std::ostream& os, const Variant<V...>& n) { // NOLINT(*) +inline std::ostream& operator<<(std::ostream& os, const ffi::Variant<V...>& n) { // NOLINT(*) ReprPrinter(os).Print(Any(n)); return os; } @@ -94,7 +94,7 @@ inline std::ostream& operator<<(std::ostream& os, const AccessStep& step) { namespace refl = ffi::reflection; switch (step->kind) { case refl::AccessKind::kAttr: { - os << '.' << step->key.cast<String>(); + os << '.' << step->key.cast<ffi::String>(); return os; } case refl::AccessKind::kArrayItem: { @@ -106,7 +106,7 @@ inline std::ostream& operator<<(std::ostream& os, const AccessStep& step) { return os; } case refl::AccessKind::kAttrMissing: { - os << ".<`" << step->key.cast<String>() << "`>"; + os << ".<`" << step->key.cast<ffi::String>() << "`>"; return os; } case refl::AccessKind::kArrayItemMissing: { @@ -125,7 +125,7 @@ inline std::ostream& operator<<(std::ostream& os, const AccessStep& step) { } inline std::ostream& operator<<(std::ostream& os, const AccessPath& path) { - Array<AccessStep> steps = path->ToSteps(); + ffi::Array<AccessStep> steps = path->ToSteps(); os << ""; for (const auto& step : steps) { os << step; diff --git a/include/tvm/node/script_printer.h b/include/tvm/node/script_printer.h index d046dbfae732..03468150d61e 100644 --- a/include/tvm/node/script_printer.h +++ b/include/tvm/node/script_printer.h @@ -40,7 +40,7 @@ namespace tvm { class PrinterConfigNode : public ffi::Object { public: /*! \brief A stack that tracks the names of the binding hierarchy */ - Array binding_names = {}; + ffi::Array binding_names = {}; /*! \brief Whether or not to show metadata. */ bool show_meta = false; /*! \brief The prefix of IR nodes */ @@ -113,13 +113,13 @@ class PrinterConfigNode : public ffi::Object { bool show_all_struct_info = true; /* \brief Object path to be underlined */ - Array path_to_underline; + ffi::Array path_to_underline; /*! \brief Object path to be annotated. */ - Map path_to_annotate; + ffi::Map path_to_annotate; /*! \brief Object to be underlined. */ - Array obj_to_underline = Array(); + ffi::Array obj_to_underline = ffi::Array(); /*! \brief Object to be annotated.
*/ - Map obj_to_annotate = Map(); + ffi::Map obj_to_annotate = ffi::Map(); static void RegisterReflection() { namespace refl = tvm::ffi::reflection; @@ -146,7 +146,7 @@ class PrinterConfigNode : public ffi::Object { .def_ro("obj_to_annotate", &PrinterConfigNode::obj_to_annotate); } - Array GetBuiltinKeywords(); + ffi::Array GetBuiltinKeywords(); static constexpr const char* _type_key = "script.PrinterConfig"; TVM_DECLARE_FINAL_OBJECT_INFO(PrinterConfigNode, Object); @@ -154,7 +154,8 @@ class PrinterConfigNode : public ffi::Object { class PrinterConfig : public ObjectRef { public: - explicit PrinterConfig(Map config_dict = Map()); + explicit PrinterConfig( + ffi::Map config_dict = ffi::Map()); TVM_DEFINE_MUTABLE_NOTNULLABLE_OBJECT_REF_METHODS(PrinterConfig, runtime::ObjectRef, PrinterConfigNode); @@ -164,15 +165,16 @@ class PrinterConfig : public ObjectRef { class TVMScriptPrinter { public: /* Convert the object to TVMScript format */ - static std::string Script(const ObjectRef& node, const Optional& cfg); + static std::string Script(const ObjectRef& node, const ffi::Optional& cfg); // Allow registration to be printer. using FType = NodeFunctor; TVM_DLL static FType& vtable(); }; -#define TVM_OBJECT_ENABLE_SCRIPT_PRINTER() \ - std::string Script(const Optional& config = std::nullopt) const { \ - return TVMScriptPrinter::Script(GetRef(this), config.value_or(PrinterConfig())); \ +#define TVM_OBJECT_ENABLE_SCRIPT_PRINTER() \ + std::string Script(const ffi::Optional& config = std::nullopt) const { \ + return TVMScriptPrinter::Script(ffi::GetRef(this), \ + config.value_or(PrinterConfig())); \ } } // namespace tvm diff --git a/include/tvm/node/structural_equal.h b/include/tvm/node/structural_equal.h index 12ba59118b72..4f00e1770b41 100644 --- a/include/tvm/node/structural_equal.h +++ b/include/tvm/node/structural_equal.h @@ -58,10 +58,10 @@ class BaseValueEqual { bool operator()(const int64_t& lhs, const int64_t& rhs) const { return lhs == rhs; } bool operator()(const uint64_t& lhs, const uint64_t& rhs) const { return lhs == rhs; } - bool operator()(const Optional& lhs, const Optional& rhs) const { + bool operator()(const ffi::Optional& lhs, const ffi::Optional& rhs) const { return lhs == rhs; } - bool operator()(const Optional& lhs, const Optional& rhs) const { + bool operator()(const ffi::Optional& lhs, const ffi::Optional& rhs) const { return lhs == rhs; } bool operator()(const int& lhs, const int& rhs) const { return lhs == rhs; } diff --git a/include/tvm/node/structural_hash.h b/include/tvm/node/structural_hash.h index 2c0c54db4121..ba7cbaf88aa6 100644 --- a/include/tvm/node/structural_hash.h +++ b/include/tvm/node/structural_hash.h @@ -78,14 +78,14 @@ class BaseValueHash { uint64_t operator()(const std::string& key) const { return tvm::ffi::details::StableHashBytes(key.data(), key.length()); } - uint64_t operator()(const Optional& key) const { + uint64_t operator()(const ffi::Optional& key) const { if (key.has_value()) { return Reinterpret(*key); } else { return 0; } } - uint64_t operator()(const Optional& key) const { + uint64_t operator()(const ffi::Optional& key) const { if (key.has_value()) { return Reinterpret(*key); } else { diff --git a/include/tvm/relax/analysis.h b/include/tvm/relax/analysis.h index 267eb1b66eeb..73d1a3dbebce 100644 --- a/include/tvm/relax/analysis.h +++ b/include/tvm/relax/analysis.h @@ -53,7 +53,7 @@ namespace relax { * if result is false, there is still possibility that * two shapes equals to each other during runtime. 
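// A hedged sketch of the printer entry points above. The PrinterConfig map's
// key/value parameters are elided in this diff, so a default-constructed
// config is used and no particular config key is assumed; `mod` may be any
// printable ObjectRef such as an IRModule.
#include <tvm/ir/module.h>
#include <tvm/node/script_printer.h>

std::string PrintScript(const tvm::IRModule& mod) {
  tvm::PrinterConfig cfg;  // explicit ctor with a defaulted config_dict
  return tvm::TVMScriptPrinter::Script(mod, cfg);
}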
*/ -TVM_DLL bool CanProveShapeEqual(const Array& lhs, const Array& rhs, +TVM_DLL bool CanProveShapeEqual(const ffi::Array& lhs, const ffi::Array& rhs, arith::Analyzer* ana); /*! @@ -155,11 +155,11 @@ TVM_DLL StructInfo DeriveCallRetStructInfo(const FuncStructInfo& finfo, const Ca * * \return the corresponding erased struct info. */ -TVM_DLL StructInfo -EraseToWellDefined(const StructInfo& info, - std::function(const tir::Var& var)> f_shape_var_map = nullptr, - std::function(const Var& var)> f_var_map = nullptr, - arith::Analyzer* ana = nullptr); +TVM_DLL StructInfo EraseToWellDefined( + const StructInfo& info, + std::function(const tir::Var& var)> f_shape_var_map = nullptr, + std::function(const Var& var)> f_var_map = nullptr, + arith::Analyzer* ana = nullptr); /*! * \brief EraseToWellDefined variant with map. @@ -174,8 +174,9 @@ EraseToWellDefined(const StructInfo& info, * * \return the corresponding erased struct info. */ -TVM_DLL StructInfo EraseToWellDefined(const StructInfo& info, Map shape_var_map, - Map var_map, arith::Analyzer* ana = nullptr); +TVM_DLL StructInfo EraseToWellDefined(const StructInfo& info, + ffi::Map shape_var_map, + ffi::Map var_map, arith::Analyzer* ana = nullptr); /*! * \brief Fine grained result of base check. @@ -289,7 +290,7 @@ TVM_DLL StructInfo StructInfoLCA(const StructInfo& lhs, const StructInfo& rhs, * \param sinfo The struct info object to be analyzed. * \return The list of TIR variables that appear in the input struct info. */ -TVM_DLL Array TIRVarsInStructInfo(const StructInfo& sinfo); +TVM_DLL ffi::Array TIRVarsInStructInfo(const StructInfo& sinfo); /*! * \brief Get the TIR variables that appear in the input struct info. @@ -303,7 +304,7 @@ TVM_DLL Array TIRVarsInStructInfo(const StructInfo& sinfo); * deduplicated, each TIR variable will appear at most once, and in * order of occurrence. */ -TVM_DLL Array DefinableTIRVarsInStructInfo(const StructInfo& sinfo); +TVM_DLL ffi::Array DefinableTIRVarsInStructInfo(const StructInfo& sinfo); /*! \brief Collect expressions whose usage requires them to be non-negative * @@ -316,7 +317,7 @@ TVM_DLL Array DefinableTIRVarsInStructInfo(const StructInfo& sinfo); * * \return A list of non-negative expressions. */ -TVM_DLL Array CollectNonNegativeExpressions(const StructInfo& sinfo); +TVM_DLL ffi::Array CollectNonNegativeExpressions(const StructInfo& sinfo); /*! * \brief Get the TIR variables that defined in the input function. @@ -324,7 +325,7 @@ TVM_DLL Array CollectNonNegativeExpressions(const StructInfo& sinfo); * \param expr The relax expression (e.g. a Function) to be analyzed. * \return The list of TIR variables that are defined in the input function. */ -TVM_DLL Array DefinedSymbolicVars(const Expr& expr); +TVM_DLL ffi::Array DefinedSymbolicVars(const Expr& expr); /*! * \brief Get the TIR variables that are used but not defined in the input function. @@ -332,7 +333,7 @@ TVM_DLL Array DefinedSymbolicVars(const Expr& expr); * \param expr The relax expression (e.g. a Function) to be analyzed. * \return The list of TIR variables that are used but not defined in the input function. */ -TVM_DLL Array FreeSymbolicVars(const Expr& expr); +TVM_DLL ffi::Array FreeSymbolicVars(const Expr& expr); //----------------------------------- // General IR analysis //----------------------------------- @@ -346,7 +347,7 @@ TVM_DLL Array FreeSymbolicVars(const Expr& expr); * * \return List of bound vars, in the PostDFS order in the expression. 
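// A sketch of the map-based EraseToWellDefined overload above. The map
// template parameters are elided in this diff; going by the parameter names,
// the first map is assumed to be tir::Var -> PrimExpr and the second
// relax::Var -> Expr. `sinfo` is assumed to reference a symbolic shape var `n`.
#include <tvm/relax/analysis.h>

tvm::relax::StructInfo EraseSketch(const tvm::relax::StructInfo& sinfo,
                                   const tvm::tir::Var& n) {
  tvm::ffi::Map<tvm::tir::Var, tvm::PrimExpr> shape_var_map;
  shape_var_map.Set(n, tvm::tir::Var("m", tvm::DataType::Int(64)));
  return tvm::relax::EraseToWellDefined(sinfo, shape_var_map, {});
}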
*/ -TVM_DLL tvm::Array BoundVars(const Expr& expr); +TVM_DLL tvm::ffi::Array BoundVars(const Expr& expr); /*! * \brief Get free type parameters from expression expr. @@ -358,7 +359,7 @@ TVM_DLL tvm::Array BoundVars(const Expr& expr); * * \return List of free vars, in the PostDFS order in the expression. */ -TVM_DLL tvm::Array FreeVars(const Expr& expr); +TVM_DLL tvm::ffi::Array FreeVars(const Expr& expr); /*! * \brief Get all variables from expression expr. @@ -367,7 +368,7 @@ TVM_DLL tvm::Array FreeVars(const Expr& expr); * * \return List of all vars, in the PostDFS order in the expression. */ -TVM_DLL tvm::Array AllVars(const Expr& expr); +TVM_DLL tvm::ffi::Array AllVars(const Expr& expr); /*! * \brief Get all global variables from expression expr. @@ -379,7 +380,7 @@ TVM_DLL tvm::Array AllVars(const Expr& expr); * * \return List of all global variables, in the PostDFS order in the expression. */ -TVM_DLL tvm::Array AllGlobalVars(const Expr& expr); +TVM_DLL tvm::ffi::Array AllGlobalVars(const Expr& expr); /*! * \brief Find all sets of recursive or mutually recursive functions in the module. @@ -404,7 +405,7 @@ TVM_DLL tvm::Array AllGlobalVars(const Expr& expr); * If a function is simply recursive and not mutually recursive with any other, * then it will be listed as a group by itself. */ -TVM_DLL tvm::Array> DetectRecursion(const IRModule& m); +TVM_DLL tvm::ffi::Array> DetectRecursion(const IRModule& m); /*! * \brief Analyze var -> value mapping from VarBindings. @@ -412,7 +413,7 @@ TVM_DLL tvm::Array> DetectRecursion(const IRModule& m); * \param m The IRModule to check. * \return Var -> Value (Expr) */ -TVM_DLL Map AnalyzeVar2Value(const IRModule& m); +TVM_DLL ffi::Map AnalyzeVar2Value(const IRModule& m); /*! * \brief Analyze var -> value mapping from VarBindings. @@ -420,7 +421,7 @@ TVM_DLL Map AnalyzeVar2Value(const IRModule& m); * \param expr The expression to check. * \return Var -> Value (Expr) */ -TVM_DLL Map AnalyzeVar2Value(const Expr& expr); +TVM_DLL ffi::Map AnalyzeVar2Value(const Expr& expr); /*! * \brief Analyze var -> value mapping from VarBindings. @@ -428,7 +429,7 @@ TVM_DLL Map AnalyzeVar2Value(const Expr& expr); * \param dfb The dataflow block to check. * \return Var -> Value (Expr) */ -TVM_DLL Map AnalyzeVar2Value(const DataflowBlock& dfb); +TVM_DLL ffi::Map AnalyzeVar2Value(const DataflowBlock& dfb); /*! * \brief Return a mapping from variable name to its Bindings. @@ -436,7 +437,7 @@ TVM_DLL Map AnalyzeVar2Value(const DataflowBlock& dfb); * \param fn The function to be analyzed. * \return A mapping from variable name to its Bindings. */ -TVM_DLL Map> NameToBinding(const Function& fn); +TVM_DLL ffi::Map> NameToBinding(const Function& fn); /*! * \brief Get the use-def chain of variables inside a dataflow block. @@ -444,7 +445,7 @@ TVM_DLL Map> NameToBinding(const Function& fn); * \param dfb The dataflow block to be analyzed. * \return A map mapping variable definitions to a set of uses. */ -TVM_DLL Map> DataflowBlockUseDef(const DataflowBlock& dfb); +TVM_DLL ffi::Map> DataflowBlockUseDef(const DataflowBlock& dfb); /*! * \brief Get the use-def chain of variables inside a function. @@ -457,7 +458,7 @@ TVM_DLL Map> DataflowBlockUseDef(const DataflowBlock& dfb); * variables whose usage occurs outside of any variable binding, * typically the output body of a relax::Function or a relax::SeqExpr. */ -std::pair>, Array> FunctionUseDef(const Expr& expr); +std::pair>, ffi::Array> FunctionUseDef(const Expr& expr); /*! 
\brief A utility struct returned by CollectVarUsage */ @@ -466,19 +467,19 @@ struct VarUsageInfo { * * This is equivalent to the output of AnalyzeVar2Value */ - Map bound_values; + ffi::Map bound_values; /* \brief The map from variables to downstream usages of the variable * * This is equivalent to the first output of FunctionUseDef. */ - Map> downstream_usage; + ffi::Map> downstream_usage; /* \brief A list of variables produced as output * * This is equivalent to the second output of FunctionUseDef */ - Array outputs; + ffi::Array outputs; }; /*! \brief Collect variable bindings and usage @@ -541,8 +542,8 @@ TVM_DLL bool HasReshapePattern(const tir::PrimFunc& func); * Also, an impure call in a *nested* function does *not* mean that the outer expression contains * an impure call--it only does if the nested function is *later called*. */ -TVM_DLL Optional FindImpureCall( - const Expr& expr, const Optional& own_name = Optional(std::nullopt)); +TVM_DLL ffi::Optional FindImpureCall( + const Expr& expr, const ffi::Optional& own_name = ffi::Optional(std::nullopt)); /*! * \brief Check if the given expression (likely a function body) contains any impure calls. @@ -555,8 +556,8 @@ TVM_DLL Optional FindImpureCall( * Also, an impure call in a *nested* function does *not* mean that the outer expression contains * an impure call--it only does if the nested function is *later called*. */ -TVM_DLL bool ContainsImpureCall(const Expr& expr, - const Optional& own_name = Optional(std::nullopt)); +TVM_DLL bool ContainsImpureCall( + const Expr& expr, const ffi::Optional& own_name = ffi::Optional(std::nullopt)); /*! * \brief Check if the IRModule is well formed. @@ -569,7 +570,7 @@ TVM_DLL bool ContainsImpureCall(const Expr& expr, * where `check_struct_info` might be false, so that other well-formed requirements * will be well tested and will not be blocked by not having structure info. */ -TVM_DLL bool WellFormed(Variant obj, bool check_struct_info = true); +TVM_DLL bool WellFormed(ffi::Variant obj, bool check_struct_info = true); /*! * \brief Using the layout transforms on the outputs, suggest layout transformation on the blocks @@ -581,8 +582,8 @@ TVM_DLL bool WellFormed(Variant obj, bool check_struct_info * from the object (block or buffer) to it's index map transformation. */ -TVM_DLL Map> SuggestLayoutTransforms( - const Function& fn, Array write_buffer_transformations); +TVM_DLL ffi::Map> SuggestLayoutTransforms( + const Function& fn, ffi::Array write_buffer_transformations); /* \brief Collect variables whose value can be computed at compile-time * @@ -597,7 +598,7 @@ TVM_DLL Map> SuggestLayoutTransforms( * \return The set of variables that can be computed at compile-time, * in order of their occurrence within the function. */ -TVM_DLL Array ComputableAtCompileTime(const Function& func); +TVM_DLL ffi::Array ComputableAtCompileTime(const Function& func); } // namespace relax } // namespace tvm diff --git a/include/tvm/relax/attrs/ccl.h b/include/tvm/relax/attrs/ccl.h index e6736dd2e731..b1f2632acc5c 100644 --- a/include/tvm/relax/attrs/ccl.h +++ b/include/tvm/relax/attrs/ccl.h @@ -32,7 +32,7 @@ namespace relax { /*! 
\brief Attributes used in allreduce operators */ struct AllReduceAttrs : public tvm::AttrsNodeReflAdapter { - String op_type; + ffi::String op_type; bool in_group; static void RegisterReflection() { diff --git a/include/tvm/relax/attrs/image.h b/include/tvm/relax/attrs/image.h index 544ad1ebd1dc..778dffbc55c3 100644 --- a/include/tvm/relax/attrs/image.h +++ b/include/tvm/relax/attrs/image.h @@ -31,11 +31,11 @@ namespace relax { /*! \brief Attributes used in image resize2d operator */ struct Resize2DAttrs : public AttrsNodeReflAdapter { - Array roi; - String layout; - String method; - String coordinate_transformation_mode; - String rounding_method; + ffi::Array roi; + ffi::String layout; + ffi::String method; + ffi::String coordinate_transformation_mode; + ffi::String rounding_method; double cubic_alpha; int cubic_exclude; double extrapolation_value; diff --git a/include/tvm/relax/attrs/index.h b/include/tvm/relax/attrs/index.h index cc914449db30..827fa67eb113 100644 --- a/include/tvm/relax/attrs/index.h +++ b/include/tvm/relax/attrs/index.h @@ -31,8 +31,8 @@ namespace relax { /*! \brief Attributes used in take operator */ struct TakeAttrs : public AttrsNodeReflAdapter { - Optional axis; - String mode; + ffi::Optional axis; + ffi::String mode; static void RegisterReflection() { namespace refl = tvm::ffi::reflection; diff --git a/include/tvm/relax/attrs/linear_algebra.h b/include/tvm/relax/attrs/linear_algebra.h index 041b9cb1bef4..2ba871aec63a 100644 --- a/include/tvm/relax/attrs/linear_algebra.h +++ b/include/tvm/relax/attrs/linear_algebra.h @@ -45,7 +45,7 @@ struct MatmulAttrs : public AttrsNodeReflAdapter { /*! \brief Attributes used in einsum operator */ struct EinsumAttrs : public AttrsNodeReflAdapter { - String subscripts; + ffi::String subscripts; static void RegisterReflection() { namespace refl = tvm::ffi::reflection; diff --git a/include/tvm/relax/attrs/manipulate.h b/include/tvm/relax/attrs/manipulate.h index 6a7cfe0baba2..af4d5f5b806b 100644 --- a/include/tvm/relax/attrs/manipulate.h +++ b/include/tvm/relax/attrs/manipulate.h @@ -32,7 +32,7 @@ namespace relax { /*! \brief Attributes used in concat operators */ struct ConcatAttrs : public AttrsNodeReflAdapter { - Optional axis; + ffi::Optional axis; static void RegisterReflection() { namespace refl = tvm::ffi::reflection; @@ -47,7 +47,7 @@ struct ConcatAttrs : public AttrsNodeReflAdapter { /*! \brief Attributes used in expand_dims operators */ struct ExpandDimsAttrs : public AttrsNodeReflAdapter { - Array axis; + ffi::Array axis; static void RegisterReflection() { namespace refl = tvm::ffi::reflection; @@ -67,20 +67,20 @@ struct LayoutTransformAttrs : public AttrsNodeReflAdapter tir::IndexMap index_map; // pad_value is chosen to be of PrimValue type, as it represents constant TIR POD expression. This // needs to be revisited in case PrimValue is evolved to represent symbolic expression in future. - Optional pad_value; + ffi::Optional pad_value; /*! * axis_separators between input axes when generating flattened output axes. For buffers * representing flat 1-d memory (e.g. any buffer in RAM), this should be an empty array. * For buffers representing non-flat memory, each entry in axis_separators should be the * first input axis that is part of a new flattened axis. */ - Optional> axis_separators; + ffi::Optional> axis_separators; /*! * axis_separators for input buffers. * Needed to identify if the input buffer to layout_transform * contains axis separator. 
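// A sketch of consuming one of the attrs structs above after the ffi::
// migration; Resize2DAttrs is picked arbitrarily, and `call` is assumed to
// be a relax.image.resize2d call node.
#include <tvm/relax/attrs/image.h>
#include <tvm/relax/expr.h>

void InspectResize(const tvm::relax::Call& call) {
  if (const auto* attrs = call->attrs.as<tvm::relax::Resize2DAttrs>()) {
    // ffi::String compares directly against string literals.
    if (attrs->method == "linear") { /* bilinear lowering path */ }
  }
}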
*/ - Optional> input_axis_separators; + ffi::Optional> input_axis_separators; static void RegisterReflection() { namespace refl = tvm::ffi::reflection; @@ -103,7 +103,7 @@ struct LayoutTransformAttrs : public AttrsNodeReflAdapter /*! \brief Attributes used in permute_dims operator */ struct PermuteDimsAttrs : public AttrsNodeReflAdapter { - Optional> axes; + ffi::Optional> axes; static void RegisterReflection() { namespace refl = tvm::ffi::reflection; @@ -134,7 +134,7 @@ struct SplitAttrs : public AttrsNodeReflAdapter { /*! \brief Attributes used in squeeze operators */ struct SqueezeAttrs : public AttrsNodeReflAdapter { - Optional> axis; + ffi::Optional> axis; static void RegisterReflection() { namespace refl = tvm::ffi::reflection; @@ -151,7 +151,7 @@ struct SqueezeAttrs : public AttrsNodeReflAdapter { /*! \brief Attributes used in stack operators */ struct StackAttrs : public AttrsNodeReflAdapter { - Optional axis; + ffi::Optional axis; static void RegisterReflection() { namespace refl = tvm::ffi::reflection; @@ -170,7 +170,7 @@ struct StackAttrs : public AttrsNodeReflAdapter { /*! \brief Attributes used in repeat operators */ struct RepeatAttrs : public AttrsNodeReflAdapter { int repeats; - Optional axis; + ffi::Optional axis; static void RegisterReflection() { namespace refl = tvm::ffi::reflection; @@ -188,7 +188,7 @@ struct RepeatAttrs : public AttrsNodeReflAdapter { /*! \brief Attributes used in tile operators */ struct TileAttrs : public AttrsNodeReflAdapter { - Array repeats; + ffi::Array repeats; static void RegisterReflection() { namespace refl = tvm::ffi::reflection; @@ -264,7 +264,7 @@ struct IndexPutAttrs : public AttrsNodeReflAdapter { /*! \brief Attribute used in meshgrid operator */ struct MeshgridAttrs : public AttrsNodeReflAdapter { - Optional indexing; + ffi::Optional indexing; static void RegisterReflection() { namespace refl = tvm::ffi::reflection; @@ -279,7 +279,7 @@ struct MeshgridAttrs : public AttrsNodeReflAdapter { /*! \brief Attributes used in scatter_elements operators */ struct ScatterElementsAttrs : public AttrsNodeReflAdapter { Integer axis; - String reduction; + ffi::String reduction; static void RegisterReflection() { namespace refl = tvm::ffi::reflection; @@ -298,7 +298,7 @@ struct ScatterElementsAttrs : public AttrsNodeReflAdapter /*! \brief Attributes used in scatter_nd operators */ struct ScatterNDAttrs : public AttrsNodeReflAdapter { - String reduction; + ffi::String reduction; static void RegisterReflection() { namespace refl = tvm::ffi::reflection; diff --git a/include/tvm/relax/attrs/nn.h b/include/tvm/relax/attrs/nn.h index 9f09bce6af2c..b21a68fb82c0 100644 --- a/include/tvm/relax/attrs/nn.h +++ b/include/tvm/relax/attrs/nn.h @@ -31,13 +31,13 @@ namespace relax { /*! \brief Attributes used in Conv1d operator */ struct Conv1DAttrs : public AttrsNodeReflAdapter { - Array strides; - Array padding; - Array dilation; + ffi::Array strides; + ffi::Array padding; + ffi::Array dilation; int groups; - String data_layout; - String kernel_layout; - String out_layout; + ffi::String data_layout; + ffi::String kernel_layout; + ffi::String out_layout; DataType out_dtype; static void RegisterReflection() { @@ -77,13 +77,13 @@ struct Conv1DAttrs : public AttrsNodeReflAdapter { /*! 
\brief Attributes used in Conv2d operator */ struct Conv2DAttrs : public AttrsNodeReflAdapter { - Array strides; - Array padding; - Array dilation; + ffi::Array strides; + ffi::Array padding; + ffi::Array dilation; int groups; - String data_layout; - String kernel_layout; - String out_layout; + ffi::String data_layout; + ffi::String kernel_layout; + ffi::String out_layout; DataType out_dtype; static void RegisterReflection() { @@ -125,13 +125,13 @@ struct Conv2DAttrs : public AttrsNodeReflAdapter { /*! \brief Attributes used in Conv3d operator */ struct Conv3DAttrs : public AttrsNodeReflAdapter { - Array strides; - Array padding; - Array dilation; + ffi::Array strides; + ffi::Array padding; + ffi::Array dilation; int groups; - String data_layout; - String kernel_layout; - String out_layout; + ffi::String data_layout; + ffi::String kernel_layout; + ffi::String out_layout; DataType out_dtype; static void RegisterReflection() { @@ -175,14 +175,14 @@ struct Conv3DAttrs : public AttrsNodeReflAdapter { /*! \brief Attributes used in Conv1DTranspose operator */ struct Conv1DTransposeAttrs : public AttrsNodeReflAdapter { - Array strides; - Array padding; - Array output_padding; - Array dilation; + ffi::Array strides; + ffi::Array padding; + ffi::Array output_padding; + ffi::Array dilation; int groups; - String data_layout; - String kernel_layout; - String out_layout; + ffi::String data_layout; + ffi::String kernel_layout; + ffi::String out_layout; DataType out_dtype; static void RegisterReflection() { @@ -225,14 +225,14 @@ struct Conv1DTransposeAttrs : public AttrsNodeReflAdapter /*! \brief Attributes used in Conv2d operator */ struct Conv2DTransposeAttrs : public AttrsNodeReflAdapter { - Array strides; - Array padding; - Array output_padding; - Array dilation; + ffi::Array strides; + ffi::Array padding; + ffi::Array output_padding; + ffi::Array dilation; int groups; - String data_layout; - String kernel_layout; - String out_layout; + ffi::String data_layout; + ffi::String kernel_layout; + ffi::String out_layout; DataType out_dtype; static void RegisterReflection() { @@ -277,14 +277,14 @@ struct Conv2DTransposeAttrs : public AttrsNodeReflAdapter /*! \brief Attributes used in max_pool1d and avg_pool1d operator */ struct Pool1DAttrs : public AttrsNodeReflAdapter { - Array pool_size; - Array strides; - Array padding; - Array dilation; + ffi::Array pool_size; + ffi::Array strides; + ffi::Array padding; + ffi::Array dilation; bool ceil_mode; bool count_include_pad; - String layout; - String out_layout; + ffi::String layout; + ffi::String out_layout; static void RegisterReflection() { namespace refl = tvm::ffi::reflection; @@ -320,14 +320,14 @@ struct Pool1DAttrs : public AttrsNodeReflAdapter { /*! \brief Attributes used in max_pool2d and avg_pool2d operator */ struct Pool2DAttrs : public AttrsNodeReflAdapter { - Array pool_size; - Array strides; - Array padding; - Array dilation; + ffi::Array pool_size; + ffi::Array strides; + ffi::Array padding; + ffi::Array dilation; bool ceil_mode; bool count_include_pad; - String layout; - String out_layout; + ffi::String layout; + ffi::String out_layout; static void RegisterReflection() { namespace refl = tvm::ffi::reflection; @@ -365,14 +365,14 @@ struct Pool2DAttrs : public AttrsNodeReflAdapter { /*! 
\brief Attributes used in max_pool3d and avg_pool3d operator */ struct Pool3DAttrs : public AttrsNodeReflAdapter { - Array pool_size; - Array strides; - Array padding; - Array dilation; + ffi::Array pool_size; + ffi::Array strides; + ffi::Array padding; + ffi::Array dilation; bool ceil_mode; bool count_include_pad; - String layout; - String out_layout; + ffi::String layout; + ffi::String out_layout; static void RegisterReflection() { namespace refl = tvm::ffi::reflection; @@ -410,9 +410,9 @@ struct Pool3DAttrs : public AttrsNodeReflAdapter { /*! \brief Attributes for 1d adaptive pool operator */ struct AdaptivePool1DAttrs : public AttrsNodeReflAdapter { - Optional> output_size; - String layout; - String out_layout; + ffi::Optional> output_size; + ffi::String layout; + ffi::String out_layout; static void RegisterReflection() { namespace refl = tvm::ffi::reflection; @@ -436,9 +436,9 @@ struct AdaptivePool1DAttrs : public AttrsNodeReflAdapter { /*! \brief Attributes for 2d adaptive pool operator */ struct AdaptivePool2DAttrs : public AttrsNodeReflAdapter { - Optional> output_size; - String layout; - String out_layout; + ffi::Optional> output_size; + ffi::String layout; + ffi::String out_layout; static void RegisterReflection() { namespace refl = tvm::ffi::reflection; @@ -462,9 +462,9 @@ struct AdaptivePool2DAttrs : public AttrsNodeReflAdapter { /*! \brief Attributes for 3d adaptive pool operator */ struct AdaptivePool3DAttrs : public AttrsNodeReflAdapter { - Optional> output_size; - String layout; - String out_layout; + ffi::Optional> output_size; + ffi::String layout; + ffi::String out_layout; static void RegisterReflection() { namespace refl = tvm::ffi::reflection; @@ -577,7 +577,7 @@ struct BatchNormAttrs : public AttrsNodeReflAdapter { /*! \brief Attributes used in layer_norm operator */ struct LayerNormAttrs : public AttrsNodeReflAdapter { - Array axes; + ffi::Array axes; double epsilon; bool center; bool scale; @@ -603,7 +603,7 @@ struct LayerNormAttrs : public AttrsNodeReflAdapter { struct GroupNormAttrs : public AttrsNodeReflAdapter { int num_groups; int channel_axis; - Array axes; + ffi::Array axes; double epsilon; bool center; bool scale; @@ -633,7 +633,7 @@ struct GroupNormAttrs : public AttrsNodeReflAdapter { /*! \brief Attributes used in instance_norm operator */ struct InstanceNormAttrs : public AttrsNodeReflAdapter { int channel_axis; - Array axes; + ffi::Array axes; double epsilon; bool center; bool scale; @@ -659,7 +659,7 @@ struct InstanceNormAttrs : public AttrsNodeReflAdapter { /*! \brief Attributes used in rms_norm operator */ struct RMSNormAttrs : public AttrsNodeReflAdapter { - Array axes; + ffi::Array axes; double epsilon; static void RegisterReflection() { @@ -677,7 +677,7 @@ struct RMSNormAttrs : public AttrsNodeReflAdapter { /*! \brief Attributes used in nll_loss operator */ struct NLLLossAttrs : public AttrsNodeReflAdapter { - String reduction; + ffi::String reduction; int ignore_index; static void RegisterReflection() { @@ -711,9 +711,9 @@ struct DropoutAttrs : public AttrsNodeReflAdapter { /*! \brief Attributes used in Attention operator */ struct AttentionAttrs : public AttrsNodeReflAdapter { - Optional scale; - Optional causal_mask; - Optional window_size; + ffi::Optional scale; + ffi::Optional causal_mask; + ffi::Optional window_size; static void RegisterReflection() { namespace refl = tvm::ffi::reflection; @@ -733,9 +733,9 @@ struct AttentionAttrs : public AttrsNodeReflAdapter { /*! 
\brief Attributes used for the padding operator */ struct PadAttrs : public AttrsNodeReflAdapter { - Array pad_width; + ffi::Array pad_width; double pad_value = 0.0; - tvm::String pad_mode; + tvm::ffi::String pad_mode; static void RegisterReflection() { namespace refl = tvm::ffi::reflection; diff --git a/include/tvm/relax/attrs/op.h b/include/tvm/relax/attrs/op.h index 8af3f77539fe..5f4956f93caf 100644 --- a/include/tvm/relax/attrs/op.h +++ b/include/tvm/relax/attrs/op.h @@ -32,8 +32,8 @@ namespace relax { /*! \brief Attributes used in call_tir_with_grad */ struct CallTIRWithGradAttrs : public AttrsNodeReflAdapter { - String te_grad_name; - Map te_grad_kwargs; + ffi::String te_grad_name; + ffi::Map te_grad_kwargs; static void RegisterReflection() { namespace refl = tvm::ffi::reflection; @@ -58,7 +58,7 @@ struct CallTIRInplaceAttrs : public AttrsNodeReflAdapter { * store the `i`th output. If an element has the value -1, that means a new tensor should be * allocated for that output. */ - Array inplace_indices; + ffi::Array inplace_indices; static void RegisterReflection() { namespace refl = tvm::ffi::reflection; @@ -79,7 +79,7 @@ struct CallInplacePackedAttrs : public AttrsNodeReflAdapter inplace_indices; + ffi::Array inplace_indices; static void RegisterReflection() { namespace refl = tvm::ffi::reflection; diff --git a/include/tvm/relax/attrs/search.h b/include/tvm/relax/attrs/search.h index 6fdbe59cea74..4ba775f7a76f 100644 --- a/include/tvm/relax/attrs/search.h +++ b/include/tvm/relax/attrs/search.h @@ -31,7 +31,7 @@ namespace relax { /*! \brief Attributes for search operators */ struct ArgmaxArgminAttrs : public AttrsNodeReflAdapter { - Optional axis; + ffi::Optional axis; bool keepdims; static void RegisterReflection() { diff --git a/include/tvm/relax/attrs/sorting.h b/include/tvm/relax/attrs/sorting.h index 81705c71a261..4dbf7e172f0b 100644 --- a/include/tvm/relax/attrs/sorting.h +++ b/include/tvm/relax/attrs/sorting.h @@ -82,7 +82,7 @@ struct TopKAttrs : public AttrsNodeReflAdapter { int k; int axis; bool largest; - String ret_type; + ffi::String ret_type; DataType dtype; static void RegisterReflection() { diff --git a/include/tvm/relax/attrs/statistical.h b/include/tvm/relax/attrs/statistical.h index c61169dc9923..48e0d196dbe7 100644 --- a/include/tvm/relax/attrs/statistical.h +++ b/include/tvm/relax/attrs/statistical.h @@ -31,7 +31,7 @@ namespace relax { /*! \brief Attributes for statistical operators */ struct StatisticalAttrs : public AttrsNodeReflAdapter { - Optional> axis; + ffi::Optional> axis; bool keepdims; static void RegisterReflection() { @@ -51,7 +51,7 @@ struct StatisticalAttrs : public AttrsNodeReflAdapter { /*! \brief Attributes used in scan operators like cumsum, cumprod */ struct ScanopAttrs : public AttrsNodeReflAdapter { - Optional axis; + ffi::Optional axis; DataType dtype; Bool exclusive = Bool(false); diff --git a/include/tvm/relax/binding_rewrite.h b/include/tvm/relax/binding_rewrite.h index bdb405a0af6e..e6f574808955 100644 --- a/include/tvm/relax/binding_rewrite.h +++ b/include/tvm/relax/binding_rewrite.h @@ -46,7 +46,7 @@ class DataflowBlockRewriteNode : public Object { /*! \brief Insert a Binding statement. */ void Add(Binding binding); /*! \brief Insert an expression as VarBinding with variable name. */ - void Add(String var_name, Expr expr, bool is_dfvar = false) { + void Add(ffi::String var_name, Expr expr, bool is_dfvar = false) { auto var = is_dfvar ? 
DataflowVar(var_name, GetStructInfo(expr)) // : Var(var_name, GetStructInfo(expr)); Add(VarBinding(std::move(var), std::move(expr))); @@ -81,11 +81,11 @@ class DataflowBlockRewriteNode : public Object { protected: friend class DataflowBlockRewrite; - DataflowBlock dfb_; //!< The rewritten dataflow block. - Optional root_fn_; //!< The rewritten function. - const FunctionNode* original_fn_ptr_; //!< Pointer to the original function. - Map> to_users_; //!< Map from variable to its users. - Array fn_outputs_; //!< Variables required by function outputs. + DataflowBlock dfb_; //!< The rewritten dataflow block. + ffi::Optional root_fn_; //!< The rewritten function. + const FunctionNode* original_fn_ptr_; //!< Pointer to the original function. + ffi::Map> to_users_; //!< Map from variable to its users. + ffi::Array fn_outputs_; //!< Variables required by function outputs. private: NameSupply name_supply_; //!< Name supply for tracking and generating unique names. diff --git a/include/tvm/relax/block_builder.h b/include/tvm/relax/block_builder.h index c33d99b5f91f..b93a2090f6e2 100644 --- a/include/tvm/relax/block_builder.h +++ b/include/tvm/relax/block_builder.h @@ -104,7 +104,7 @@ class BlockBuilderNode : public Object { * GlobalVar directly. * \return The global var bound to the added function. */ - virtual GlobalVar AddFunction(const BaseFunc& func, String func_name_hint) = 0; + virtual GlobalVar AddFunction(const BaseFunc& func, ffi::String func_name_hint) = 0; /*! * \brief Update a Relax function or a TIR PrimFunc in the internal context module. @@ -128,7 +128,7 @@ class BlockBuilderNode : public Object { * \return The Expr bound to the input \p var. * \note For function parameters, this function returns std::nullopt. */ - virtual Optional LookupBinding(const Var& var) = 0; + virtual ffi::Optional LookupBinding(const Var& var) = 0; /*! * \brief Begin a new scope, with optional parameters that @@ -144,7 +144,7 @@ class BlockBuilderNode : public Object { * * \sa EndScope */ - virtual void BeginScope(Optional> params) = 0; + virtual void BeginScope(ffi::Optional> params) = 0; /*! * \brief Begin a new scope, which inherits visible parameters from @@ -204,7 +204,7 @@ class BlockBuilderNode : public Object { * \note This Emit function normalizes the \p expr, and * performs shape and type deductions by calling Normalize. */ - virtual Var Emit(Expr expr, String name_hint = "") = 0; + virtual Var Emit(Expr expr, ffi::String name_hint = "") = 0; /*! * \brief Emit a MatchCast. @@ -213,7 +213,7 @@ class BlockBuilderNode : public Object { * \param name_hint Name hint for the bound variable. * \return The variable bound to the MatchCast. */ - virtual Var EmitMatchCast(Expr value, StructInfo struct_info, String name_hint = "") = 0; + virtual Var EmitMatchCast(Expr value, StructInfo struct_info, ffi::String name_hint = "") = 0; /*! * \brief Generate an output for the current dataflow block. @@ -221,7 +221,7 @@ class BlockBuilderNode : public Object { * \param name_hint Name hint for the bound variable. * \return The variable bound to \p output. */ - virtual Var EmitOutput(Expr output, String name_hint = "") = 0; + virtual Var EmitOutput(Expr output, ffi::String name_hint = "") = 0; /*! * \brief Emit a binding that is already normalized. @@ -274,7 +274,7 @@ class BlockBuilder : public ObjectRef { * ctx_mod so you can lookup the context functions for cross function * call analysis. */ - TVM_DLL static BlockBuilder Create(Optional ctx_mod); + TVM_DLL static BlockBuilder Create(ffi::Optional ctx_mod); /*! 
\brief A marker struct to disable FNormalize * @@ -315,7 +315,7 @@ class BlockBuilder : public ObjectRef { * ctx_mod so you can lookup the context functions for cross function * call analysis. */ - TVM_DLL static BlockBuilder Create(Optional ctx_mod, + TVM_DLL static BlockBuilder Create(ffi::Optional ctx_mod, DisableOperatorSpecificNormalizationForTVMScript tag); TVM_DEFINE_MUTABLE_OBJECT_REF_METHODS(BlockBuilder, ObjectRef, BlockBuilderNode); diff --git a/include/tvm/relax/dataflow_matcher.h b/include/tvm/relax/dataflow_matcher.h index 80359135c200..8a834d1fcd01 100644 --- a/include/tvm/relax/dataflow_matcher.h +++ b/include/tvm/relax/dataflow_matcher.h @@ -44,11 +44,12 @@ namespace relax { * \return true if matched * \return false if unmatched */ -bool MatchExpr(DFPattern pattern, Expr expr, Optional> bindings = std::nullopt); +bool MatchExpr(DFPattern pattern, Expr expr, + ffi::Optional> bindings = std::nullopt); /* \brief Similar to above, but return pairs of a matching pattern and an expression. */ -Optional> ExtractMatchedExpr(DFPattern pattern, Expr expr, - Optional> bindings = std::nullopt); +ffi::Optional> ExtractMatchedExpr( + DFPattern pattern, Expr expr, ffi::Optional> bindings = std::nullopt); /** * \brief Match a sub-graph in a DataflowBlock with a graph of patterns and return the mapping. @@ -56,8 +57,8 @@ Optional> ExtractMatchedExpr(DFPattern pattern, Expr expr, * \param dfb The function to match. * \return Matched patterns and corresponding bound variables */ -TVM_DLL Optional> MatchGraph(const PatternContext& ctx, - const DataflowBlock& dfb); +TVM_DLL ffi::Optional> MatchGraph(const PatternContext& ctx, + const DataflowBlock& dfb); /** * \brief Rewrite a function with the given pattern and the rewriter function. @@ -70,7 +71,8 @@ TVM_DLL Optional> MatchGraph(const PatternContext& ctx, */ TVM_DLL Function RewriteBindings( const PatternContext& ctx, - ffi::TypedFunction(Map, Map)> rewriter, Function f); + ffi::TypedFunction(ffi::Map, ffi::Map)> rewriter, + Function f); /** * \brief Rewrite a function with the given pattern and the rewriter function. @@ -96,7 +98,7 @@ TVM_DLL Function RewriteBindings( * \return The updated function, if any updates were applied. */ TVM_DLL Function RewriteCall(const DFPattern& pattern, - ffi::TypedFunction)> rewriter, + ffi::TypedFunction)> rewriter, Function func); } // namespace relax diff --git a/include/tvm/relax/dataflow_pattern.h b/include/tvm/relax/dataflow_pattern.h index c302b29864ab..4a7fd73c6ac0 100644 --- a/include/tvm/relax/dataflow_pattern.h +++ b/include/tvm/relax/dataflow_pattern.h @@ -113,7 +113,7 @@ class DFPattern : public ObjectRef { /*! \brief Syntatic Sugar for creating a NotPattern */ TVM_DLL NotPattern operator~() const; /*! \brief Syntatic Sugar for creating an AttrPattern */ - TVM_DLL AttrPattern HasAttr(const Map& attrs) const; + TVM_DLL AttrPattern HasAttr(const ffi::Map& attrs) const; /*! \brief Syntatic Sugar for creating a StructInfoPattern */ TVM_DLL StructInfoPattern HasStructInfo(const StructInfo& struct_info) const; /*! \brief Syntatic Sugar for creating a DataTypePattern with a DataType */ @@ -121,7 +121,7 @@ class DFPattern : public ObjectRef { /*! \brief Syntatic Sugar for creating a DataTypePattern with a data type's name */ TVM_DLL DataTypePattern HasDtype(const std::string& dtype) const; /*! \brief Syntatic Sugar for creating a ShapePattern */ - TVM_DLL ShapePattern HasShape(const Array& shape) const; + TVM_DLL ShapePattern HasShape(const ffi::Array& shape) const; /*! 
\brief Syntatic Sugar for creating a ShapePattern */ TVM_DLL SameShapeConstraint HasSameShapeAs(const DFPattern& other) const; /*! \brief Syntatic Sugar for duplicating the current pattern */ @@ -165,7 +165,7 @@ struct PairCons { class DFConstraintNode : public Object { public: /*! \brief Return the patterns on which the constraint depends */ - virtual Array GetDependentPatterns() const = 0; + virtual ffi::Array GetDependentPatterns() const = 0; /*! \brief Convert the constraint to a PrimExpr * @@ -195,7 +195,7 @@ class DFConstraintNode : public Object { * sufficient for the constraint to be satisfied. */ virtual std::tuple AsPrimExpr( - std::function(const DFPatternNode*)> match_state) const = 0; + std::function(const DFPatternNode*)> match_state) const = 0; static constexpr const char* _type_key = "DFConstraintNode"; static constexpr const uint32_t _type_child_slots = 1; @@ -213,7 +213,7 @@ class DFConstraint : public ObjectRef { */ class PatternSeqNode final : public Object { public: - tvm::Array patterns; /*!< The sequence of DFPatterns */ + tvm::ffi::Array patterns; /*!< The sequence of DFPatterns */ std::vector pair_constraints; /*!< Constraints between the previous and next patterns */ static void RegisterReflection() { @@ -232,7 +232,7 @@ class PatternSeqNode final : public Object { class PatternSeq final : public ObjectRef { public: TVM_DLL explicit PatternSeq(DFPattern init_pattern); - TVM_DLL explicit PatternSeq(tvm::Array patterns, bool only_used_by = false); + TVM_DLL explicit PatternSeq(tvm::ffi::Array patterns, bool only_used_by = false); PatternSeq UsedBy(PatternSeq other, int index = -1) const; PatternSeq OnlyUsedBy(PatternSeq other, int index = -1) const; @@ -329,7 +329,7 @@ class PatternContext : public ObjectRef { } /*! \brief Get the constraint context object on the top of the stack */ - TVM_DLL static Optional Current(); + TVM_DLL static ffi::Optional Current(); /*! \brief The RAII-like entry of a constraint context scope */ TVM_DLL void EnterWithScope() const; @@ -374,8 +374,8 @@ class ExprPattern : public DFPattern { */ class VarPatternNode : public DFPatternNode { public: - String name; - const String& name_hint() const { return name; } + ffi::String name; + const ffi::String& name_hint() const { return name; } static void RegisterReflection() { namespace refl = tvm::ffi::reflection; @@ -398,7 +398,7 @@ class VarPattern : public DFPattern { * * \param name_hint Variable name to match. Any if empty (""). */ - TVM_DLL VarPattern(String name_hint); + TVM_DLL VarPattern(ffi::String name_hint); TVM_DEFINE_OBJECT_REF_METHODS(VarPattern, DFPattern, VarPatternNode); }; @@ -424,7 +424,7 @@ class DataflowVarPatternNode : public VarPatternNode { class DataflowVarPattern : public DFPattern { public: /*! \sa VarPattern::VarPattern */ - TVM_DLL DataflowVarPattern(String name_hint); + TVM_DLL DataflowVarPattern(ffi::String name_hint); TVM_DEFINE_OBJECT_REF_METHODS(DataflowVarPattern, DFPattern, DataflowVarPatternNode); }; @@ -444,7 +444,7 @@ class GlobalVarPatternNode : public VarPatternNode { */ class GlobalVarPattern : public DFPattern { public: - TVM_DLL GlobalVarPattern(String name_hint); + TVM_DLL GlobalVarPattern(ffi::String name_hint); TVM_DEFINE_OBJECT_REF_METHODS(GlobalVarPattern, DFPattern, GlobalVarPatternNode); }; @@ -483,8 +483,8 @@ class CallPatternNode : public DFPatternNode { * - relax::Op which corresponds to the primitive operators. * - user defined functions (Function, GlobalVar, Var). 
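// A sketch combining the pattern sugar with RewriteCall from the matcher
// header above. `relax.add` is a real operator name; the identity rewriter
// is a placeholder for whatever Expr a real pass would build.
#include <tvm/relax/dataflow_matcher.h>
#include <tvm/relax/dataflow_pattern.h>

tvm::relax::Function RewriteAdds(tvm::relax::Function func) {
  using namespace tvm::relax;
  DFPattern pat = IsOp("relax.add")(Wildcard(), Wildcard());
  return RewriteCall(
      pat,
      [](Expr matched, tvm::ffi::Map<DFPattern, Expr>) -> Expr {
        return matched;  // no-op rewrite keeps the function unchanged
      },
      func);
}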
*/ - DFPattern op; /*!< The operator (function) being invoked */ - tvm::Array args; /*!< The arguments of the function call */ + DFPattern op; /*!< The operator (function) being invoked */ + tvm::ffi::Array args; /*!< The arguments of the function call */ /*! * \note If varg_default_wildcard is true. Given args of [pA, pB], when matching a call whose * arguments are [A, B, ...], the pattern will still match despite N(args) < N(call.args). That @@ -508,7 +508,7 @@ class CallPatternNode : public DFPatternNode { class CallPattern : public DFPattern { public: - TVM_DLL CallPattern(DFPattern op, Array args, bool varg_default_wildcard = false); + TVM_DLL CallPattern(DFPattern op, ffi::Array args, bool varg_default_wildcard = false); TVM_DEFINE_OBJECT_REF_METHODS(CallPattern, DFPattern, CallPatternNode); }; @@ -519,7 +519,7 @@ class CallPattern : public DFPattern { */ class PrimArrPatternNode : public DFPatternNode { public: - Array fields; /*!< The array to match */ + ffi::Array fields; /*!< The array to match */ static void RegisterReflection() { namespace refl = tvm::ffi::reflection; @@ -536,7 +536,7 @@ class PrimArrPatternNode : public DFPatternNode { */ class PrimArrPattern : public DFPattern { public: - TVM_DLL PrimArrPattern(Array arr); + TVM_DLL PrimArrPattern(ffi::Array arr); TVM_DEFINE_OBJECT_REF_METHODS(PrimArrPattern, DFPattern, PrimArrPatternNode); }; @@ -547,7 +547,7 @@ class PrimArrPattern : public DFPattern { */ class FunctionPatternNode : public DFPatternNode { public: - tvm::Array params; /*!< The parameters of the function */ + tvm::ffi::Array params; /*!< The parameters of the function */ /*! * \note Note that in Relax, the function body is a SeqExpr which contains * 1) SeqExprNode::blocks, which is a list of blocks of statements; and 2) @@ -578,7 +578,7 @@ class FunctionPattern : public DFPattern { * \param params The parameters of the function. * \param body The body of the function. 
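// A sketch of the varg_default_wildcard note above: with the flag enabled, a
// pattern with args [pA] still matches calls that pass additional trailing
// arguments. `first` is any caller-supplied sub-pattern.
#include <tvm/relax/dataflow_pattern.h>

tvm::relax::CallPattern MatchFirstArgOnly(tvm::relax::DFPattern first) {
  using namespace tvm::relax;
  return CallPattern(Wildcard(), {first}, /*varg_default_wildcard=*/true);
}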
*/ - TVM_DLL FunctionPattern(tvm::Array params, DFPattern body); + TVM_DLL FunctionPattern(tvm::ffi::Array params, DFPattern body); TVM_DEFINE_OBJECT_REF_METHODS(FunctionPattern, DFPattern, FunctionPatternNode); }; @@ -589,7 +589,7 @@ class FunctionPattern : public DFPattern { */ class TuplePatternNode : public DFPatternNode { public: - tvm::Array fields; /*!< The fields of the tuple */ + tvm::ffi::Array fields; /*!< The fields of the tuple */ static void RegisterReflection() { namespace refl = tvm::ffi::reflection; @@ -606,7 +606,7 @@ class TuplePatternNode : public DFPatternNode { */ class TuplePattern : public DFPattern { public: - TVM_DLL explicit TuplePattern(tvm::Array fields); + TVM_DLL explicit TuplePattern(tvm::ffi::Array fields); TVM_DEFINE_OBJECT_REF_METHODS(TuplePattern, DFPattern, TuplePatternNode); }; @@ -616,7 +616,7 @@ class TuplePattern : public DFPattern { */ class UnorderedTuplePatternNode : public DFPatternNode { public: - tvm::Array fields; /*!< The fields of the tuple */ + tvm::ffi::Array fields; /*!< The fields of the tuple */ static void RegisterReflection() { namespace refl = tvm::ffi::reflection; @@ -634,7 +634,7 @@ class UnorderedTuplePatternNode : public DFPatternNode { */ class UnorderedTuplePattern : public DFPattern { public: - TVM_DLL explicit UnorderedTuplePattern(tvm::Array fields); + TVM_DLL explicit UnorderedTuplePattern(tvm::ffi::Array fields); TVM_DEFINE_OBJECT_REF_METHODS(UnorderedTuplePattern, DFPattern, UnorderedTuplePatternNode); }; @@ -819,8 +819,8 @@ class StructInfoPattern : public DFPattern { */ class ShapePatternNode : public DFPatternNode { public: - DFPattern pattern; /*!< The root pattern to match */ - Array shape; /*!< The shape to match */ + DFPattern pattern; /*!< The root pattern to match */ + ffi::Array shape; /*!< The shape to match */ static void RegisterReflection() { namespace refl = tvm::ffi::reflection; @@ -839,7 +839,7 @@ class ShapePatternNode : public DFPatternNode { */ class ShapePattern : public DFPattern { public: - TVM_DLL ShapePattern(DFPattern pattern, Array type); + TVM_DLL ShapePattern(DFPattern pattern, ffi::Array type); TVM_DEFINE_OBJECT_REF_METHODS(ShapePattern, DFPattern, ShapePatternNode); }; @@ -849,12 +849,12 @@ class ShapePattern : public DFPattern { */ class SameShapeConstraintNode : public DFConstraintNode { public: - Array args; /*!< The patterns with matching shapes */ + ffi::Array args; /*!< The patterns with matching shapes */ - Array GetDependentPatterns() const override { return args; } + ffi::Array GetDependentPatterns() const override { return args; } std::tuple AsPrimExpr( - std::function(const DFPatternNode*)> match_state) const override; + std::function(const DFPatternNode*)> match_state) const override; static void RegisterReflection() { namespace refl = tvm::ffi::reflection; @@ -871,7 +871,7 @@ class SameShapeConstraintNode : public DFConstraintNode { */ class SameShapeConstraint : public DFConstraint { public: - TVM_DLL SameShapeConstraint(Array args); + TVM_DLL SameShapeConstraint(ffi::Array args); TVM_DEFINE_OBJECT_REF_METHODS(SameShapeConstraint, DFConstraint, SameShapeConstraintNode); }; @@ -942,10 +942,10 @@ class AttrPattern : public DFPattern { */ class ExternFuncPatternNode : public DFPatternNode { public: - String global_symbol_; /*!< The global symbol name of the external function */ + ffi::String global_symbol_; /*!< The global symbol name of the external function */ /*! 
\brief The external function name */ - const String& global_symbol() const { return global_symbol_; } + const ffi::String& global_symbol() const { return global_symbol_; } static void RegisterReflection() { namespace refl = tvm::ffi::reflection; @@ -963,12 +963,12 @@ class ExternFuncPatternNode : public DFPatternNode { */ class ExternFuncPattern : public DFPattern { public: - TVM_DLL ExternFuncPattern(String global_symbol); + TVM_DLL ExternFuncPattern(ffi::String global_symbol); TVM_DEFINE_OBJECT_REF_METHODS(ExternFuncPattern, DFPattern, ExternFuncPatternNode); }; /*! \brief Syntatic Sugar for creating a VarPattern with a name */ -VarPattern IsVar(const String& name); +VarPattern IsVar(const ffi::String& name); /*! \brief Syntatic Sugar for creating a ConstantPattern */ ConstantPattern IsConst(); /*! \brief Syntatic Sugar for creating a WildcardPattern */ @@ -976,26 +976,27 @@ WildcardPattern Wildcard(); /*! \brief Syntatic Sugar for creating a ExprPattern */ ExprPattern IsExpr(const Expr& expr); /*! \brief Syntatic Sugar for creating a ExprPattern base on an Op */ -ExprPattern IsOp(const String& op_name); +ExprPattern IsOp(const ffi::String& op_name); /*! \brief Syntatic Sugar for call_tir (return a tensor) */ // Todo(relax-team): Dataflow pattern for StructInfo, and match out_sinfo -CallPattern IsCallTIR(const String& name, Optional args = std::nullopt); +CallPattern IsCallTIR(const ffi::String& name, ffi::Optional args = std::nullopt); /*! \brief Syntatic Sugar for call_tir (return a tuple of tensor) */ -CallPattern IsCallTIR(const String& name, TuplePattern var_args); +CallPattern IsCallTIR(const ffi::String& name, TuplePattern var_args); /*! \brief Syntatic Sugar for call_dps_packed (return a tensor) */ -CallPattern IsCallDPSPacked(const String& name, Optional args = std::nullopt); +CallPattern IsCallDPSPacked(const ffi::String& name, + ffi::Optional args = std::nullopt); /*! \brief Syntatic Sugar for call_dps_packed (return a tuple of tensor) */ -CallPattern IsCallDPSPacked(const String& name, TuplePattern var_args); +CallPattern IsCallDPSPacked(const ffi::String& name, TuplePattern var_args); /*! \brief Syntatic Sugar for creating TuplePattern or UnorderedTuplePattern (unordered=true) */ -DFPattern IsTuple(const Array& fields, bool unordered = false); +DFPattern IsTuple(const ffi::Array& fields, bool unordered = false); /*! \brief Syntatic Sugar for creating a TupleGetItemPattern */ TupleGetItemPattern IsTupleGetItem(const DFPattern tuple, int index = -1); /*! \brief Implementation of the templated CallPattern syntax sugar */ template CallPattern DFPattern::operator()(Args&&... args) const { - return CallPattern(GetRef(this->get()), - Array({std::forward(args)...})); + return CallPattern(ffi::GetRef(this->get()), + ffi::Array({std::forward(args)...})); } } // namespace relax diff --git a/include/tvm/relax/distributed/axis_group_graph.h b/include/tvm/relax/distributed/axis_group_graph.h index 565aaa0835f5..ddb618e06b1f 100644 --- a/include/tvm/relax/distributed/axis_group_graph.h +++ b/include/tvm/relax/distributed/axis_group_graph.h @@ -58,7 +58,8 @@ class BufferAxisHash { * \param analyzer The analyzer * \return The iter var whose extent to be changed */ -Var GetShardingVarFromIndex(PrimExpr index, Map var_range, arith::Analyzer* analyzer); +Var GetShardingVarFromIndex(PrimExpr index, ffi::Map var_range, + arith::Analyzer* analyzer); /*! * \brief Construct an axis group graph from a PrimFunc. 
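// A sketch of the IsTuple / IsTupleGetItem sugar declared above. The derived
// pattern types convert implicitly to DFPattern inside the array literal.
#include <tvm/relax/dataflow_pattern.h>

void TupleSugarSketch() {
  using namespace tvm::relax;
  DFPattern tup = IsTuple({Wildcard(), IsConst()});             // ordered form
  TupleGetItemPattern item0 = IsTupleGetItem(tup, /*index=*/0);
  (void)item0;  // would be fed to MatchExpr or a rewriter in a real pass
}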
Two buffer axis are connected if they @@ -69,7 +70,7 @@ class BufferAxisGraphExtractor : public StmtExprVisitor { static std::vector> GetTIRVarAxisGraph(const PrimFunc& prim_func) { BufferAxisGraphExtractor extractor; extractor(prim_func->body); - Map inverse_buffer_map; + ffi::Map inverse_buffer_map; for (const auto& pr : prim_func->buffer_map) { inverse_buffer_map.Set(pr.second, pr.first); } @@ -162,14 +163,14 @@ class BufferAxisGraphExtractor : public StmtExprVisitor { arith::Analyzer analyzer; for (const auto& access_pr : buffer_access_indices_) { Buffer buffer = access_pr.first; - Array indices = access_pr.second; + ffi::Array indices = access_pr.second; for (int i = 0; i < static_cast(indices.size()); i++) { for (const auto& another_access_pr : buffer_access_indices_) { if (another_access_pr.first.same_as(buffer)) { continue; } Buffer another_buffer = another_access_pr.first; - Array another_indices = another_access_pr.second; + ffi::Array another_indices = another_access_pr.second; for (int j = 0; j < static_cast(another_indices.size()); j++) { if (Match(indices[i], buffer->shape[i], another_indices[j], another_buffer->shape[j], &analyzer)) { @@ -192,9 +193,9 @@ class BufferAxisGraphExtractor : public StmtExprVisitor { buffer_axis_graph_[axis2].push_back(axis1); } - std::vector>> buffer_access_indices_; + std::vector>> buffer_access_indices_; std::unordered_map, BufferAxisHash> buffer_axis_graph_; - Map iter_var_range_; + ffi::Map iter_var_range_; std::string func_name; }; } // namespace tir @@ -439,7 +440,7 @@ class AxisGroupGraph { } } ICHECK(specs.size() == 1) << "multiple possible sharding for axis: (" - << GetRef(axis.tensor) << ", " << axis.dim << ")"; + << ffi::GetRef(axis.tensor) << ", " << axis.dim << ")"; } } diff --git a/include/tvm/relax/distributed/global_info.h b/include/tvm/relax/distributed/global_info.h index 5e0afc0dcaa7..4606388b43c1 100644 --- a/include/tvm/relax/distributed/global_info.h +++ b/include/tvm/relax/distributed/global_info.h @@ -40,10 +40,10 @@ class DeviceMeshNode : public GlobalInfoNode { ffi::Shape shape; /*! \brief device ids in the mesh*/ - Array device_ids; + ffi::Array device_ids; /*! \brief Optionally use range to represent device_ids*/ - Optional device_range; + ffi::Optional device_range; static void RegisterReflection() { namespace refl = tvm::ffi::reflection; @@ -63,7 +63,7 @@ class DeviceMeshNode : public GlobalInfoNode { */ class DeviceMesh : public GlobalInfo { public: - TVM_DLL DeviceMesh(ffi::Shape shape, Array device_ids); + TVM_DLL DeviceMesh(ffi::Shape shape, ffi::Array device_ids); TVM_DLL DeviceMesh(ffi::Shape shape, Range device_range); TVM_DEFINE_OBJECT_REF_METHODS(DeviceMesh, GlobalInfo, DeviceMeshNode); }; diff --git a/include/tvm/relax/distributed/struct_info.h b/include/tvm/relax/distributed/struct_info.h index cd4c2e7daef2..9de7273d5ee0 100644 --- a/include/tvm/relax/distributed/struct_info.h +++ b/include/tvm/relax/distributed/struct_info.h @@ -86,9 +86,9 @@ class ShardingNode : public PlacementSpecNode { class PlacementNode : public Object { public: /*! \brief specs for each dim of device mesh.*/ - Array dim_specs; + ffi::Array dim_specs; - String ToString() const; + ffi::String ToString() const; static void RegisterReflection() { namespace refl = tvm::ffi::reflection; @@ -106,9 +106,9 @@ class PlacementNode : public Object { */ class Placement : public ObjectRef { public: - TVM_DLL explicit Placement(Array dim_specs); + TVM_DLL explicit Placement(ffi::Array dim_specs); /*! 
\brief replica dim is printed as "R" and sharding dim is printed as "S[i]".]*/ - static Placement FromText(String text_repr); + static Placement FromText(ffi::String text_repr); TVM_DEFINE_OBJECT_REF_METHODS(Placement, ObjectRef, PlacementNode); }; diff --git a/include/tvm/relax/exec_builder.h b/include/tvm/relax/exec_builder.h index dd0539cb9666..464d42c2e423 100644 --- a/include/tvm/relax/exec_builder.h +++ b/include/tvm/relax/exec_builder.h @@ -62,7 +62,7 @@ class ExecBuilderNode : public Object { * \param init_register_size Initial setting of register file size. */ void EmitFunction(const std::string& func, int64_t num_inputs, - Optional> param_names, + ffi::Optional> param_names, vm::VMFuncInfo::FuncKind kind = vm::VMFuncInfo::FuncKind::kVMFunc, int64_t init_register_size = 0); /*! diff --git a/include/tvm/relax/expr.h b/include/tvm/relax/expr.h index e7198fcf2237..e0e2f4770fe9 100644 --- a/include/tvm/relax/expr.h +++ b/include/tvm/relax/expr.h @@ -53,7 +53,7 @@ class IdNode : public Object { * this only acts as a hint to the user, * and is not used for equality. */ - String name_hint; + ffi::String name_hint; static void RegisterReflection() { namespace refl = tvm::ffi::reflection; @@ -73,7 +73,7 @@ class Id : public ObjectRef { * \brief The constructor * \param name_hint The name of the variable. */ - TVM_DLL explicit Id(String name_hint); + TVM_DLL explicit Id(ffi::String name_hint); TVM_DEFINE_OBJECT_REF_METHODS(Id, ObjectRef, IdNode); }; @@ -152,7 +152,7 @@ class CallNode : public ExprNode { Expr op; /*! \brief The arguments(inputs) of the call */ - tvm::Array args; + tvm::ffi::Array args; /*! \brief The additional attributes */ Attrs attrs; @@ -163,7 +163,7 @@ class CallNode : public ExprNode { * call_tir, call_builtin_with_ctx, etc.) and calls to ExternFuncs, with the main * usage of structure info inference. */ - Array sinfo_args; + ffi::Array sinfo_args; static void RegisterReflection() { namespace refl = tvm::ffi::reflection; @@ -188,8 +188,8 @@ class Call : public Expr { * \param sinfo_args The structure info arguments passed to a function. * \param span The source span of the expression. */ - TVM_DLL Call(Expr op, Array args, Attrs attrs = Attrs(), - Array sinfo_args = Array(), Span span = Span()); + TVM_DLL Call(Expr op, ffi::Array args, Attrs attrs = Attrs(), + ffi::Array sinfo_args = ffi::Array(), Span span = Span()); TVM_DEFINE_OBJECT_REF_METHODS(Call, Expr, CallNode); TVM_DEFINE_OBJECT_REF_COW_METHOD(CallNode); @@ -200,17 +200,18 @@ class Call : public Expr { * Returns \p call if all properties are unchanged. Otherwise, returns a copy with the new * fields. */ -Call WithFields(Call call, Optional opt_op = Optional(), - Optional> opt_args = Optional>(), - Optional opt_attrs = Optional(), - Optional> opt_sinfo_args = Optional>(), - Optional opt_span = Optional()); +Call WithFields( + Call call, ffi::Optional opt_op = ffi::Optional(), + ffi::Optional> opt_args = ffi::Optional>(), + ffi::Optional opt_attrs = ffi::Optional(), + ffi::Optional> opt_sinfo_args = ffi::Optional>(), + ffi::Optional opt_span = ffi::Optional()); /*! \brief Tuple container */ class TupleNode : public ExprNode { public: /*! \brief the fields of the tuple */ - tvm::Array fields; + tvm::ffi::Array fields; static void RegisterReflection() { namespace refl = tvm::ffi::reflection; @@ -228,15 +229,15 @@ class Tuple : public Expr { * \param fields The fields of a tuple. * \param span The source span of the expression. 
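// A sketch of the WithFields copy-on-write helper above: only the argument
// list is replaced, every other field of `call` is reused, and `call` itself
// is returned when nothing actually changes.
#include <tvm/relax/expr.h>

tvm::relax::Call ReplaceArgs(tvm::relax::Call call,
                             tvm::ffi::Array<tvm::relax::Expr> new_args) {
  return tvm::relax::WithFields(call, /*opt_op=*/std::nullopt, new_args);
}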
*/ - TVM_DLL explicit Tuple(tvm::Array fields, Span span = Span()); + TVM_DLL explicit Tuple(tvm::ffi::Array fields, Span span = Span()); /*! * \brief Utility constructor to handle conversion to relax::Expr * * If the calling scope already has an array of a specific type of - * relax expression (e.g. `Array`), it must be converted + * relax expression (e.g. `ffi::Array`), it must be converted * into an array of base type. This constructor handles the - * conversion to the base `Array`. + * conversion to the base `ffi::Array`. * * \tparam RelaxExpr The type of relax expression passed in as an argument. * @@ -245,7 +246,7 @@ class Tuple : public Expr { * \param span The source span of the expression. */ template >> - TVM_DLL explicit Tuple(tvm::Array fields, Span span = Span()) + TVM_DLL explicit Tuple(tvm::ffi::Array fields, Span span = Span()) : Tuple(fields.Map([](const RelaxExpr& expr) -> Expr { return expr; }), span) {} TVM_DEFINE_OBJECT_REF_METHODS(Tuple, Expr, TupleNode); @@ -257,8 +258,9 @@ class Tuple : public Expr { * Returns \p tuple if all properties are unchanged. Otherwise, returns a copy with the new * fields. */ -Tuple WithFields(Tuple tuple, Optional> opt_fields = Optional>(), - Optional opt_span = Optional()); +Tuple WithFields(Tuple tuple, + ffi::Optional> opt_fields = ffi::Optional>(), + ffi::Optional opt_span = ffi::Optional()); /*! \brief Get index-th field out of a tuple. */ class TupleGetItemNode : public ExprNode { @@ -298,9 +300,10 @@ class TupleGetItem : public Expr { * Returns \p tuple_get_item if all properties are unchanged. Otherwise, returns a copy with the new * fields. */ -TupleGetItem WithFields(TupleGetItem tuple_get_item, Optional opt_tuple = Optional(), - Optional opt_index = Optional(), - Optional opt_span = Optional()); +TupleGetItem WithFields(TupleGetItem tuple_get_item, + ffi::Optional opt_tuple = ffi::Optional(), + ffi::Optional opt_index = ffi::Optional(), + ffi::Optional opt_span = ffi::Optional()); /*! * \brief Base type of all (non-function) leaf Exprs. @@ -327,7 +330,7 @@ class LeafExpr : public Expr { class ShapeExprNode : public LeafExprNode { public: /*! The values of the shape expression. */ - Array values; + ffi::Array values; static void RegisterReflection() { namespace refl = tvm::ffi::reflection; @@ -340,7 +343,7 @@ class ShapeExprNode : public LeafExprNode { class ShapeExpr : public LeafExpr { public: - TVM_DLL explicit ShapeExpr(Array values, Span span = Span()); + TVM_DLL explicit ShapeExpr(ffi::Array values, Span span = Span()); TVM_DEFINE_OBJECT_REF_METHODS(ShapeExpr, LeafExpr, ShapeExprNode); TVM_DEFINE_OBJECT_REF_COW_METHOD(ShapeExprNode); }; @@ -353,7 +356,7 @@ class VarNode : public LeafExprNode { Id vid; /*! 
\return The name hint of the variable */ - const String& name_hint() const { return vid->name_hint; } + const ffi::String& name_hint() const { return vid->name_hint; } static void RegisterReflection() { namespace refl = tvm::ffi::reflection; @@ -386,11 +389,12 @@ class VarNode : public LeafExprNode { class Var : public LeafExpr { public: - TVM_DLL explicit Var(String name_hint, Optional struct_info_annotation, + TVM_DLL explicit Var(ffi::String name_hint, ffi::Optional struct_info_annotation, Span span = Span()) : Var(Id(name_hint), struct_info_annotation, span) {} - TVM_DLL explicit Var(Id vid, Optional struct_info_annotation, Span span = Span()); + TVM_DLL explicit Var(Id vid, ffi::Optional struct_info_annotation, + Span span = Span()); TVM_DEFINE_OBJECT_REF_METHODS(Var, LeafExpr, VarNode); VarNode* CopyOnWrite(); @@ -413,11 +417,11 @@ class DataflowVarNode : public VarNode { class DataflowVar : public Var { public: - TVM_DLL explicit DataflowVar(String name_hint, Optional struct_info_annotation, - Span span = Span()) + TVM_DLL explicit DataflowVar(ffi::String name_hint, + ffi::Optional struct_info_annotation, Span span = Span()) : DataflowVar(Id(name_hint), struct_info_annotation, span) {} - TVM_DLL explicit DataflowVar(Id vid, Optional struct_info_annotation, + TVM_DLL explicit DataflowVar(Id vid, ffi::Optional struct_info_annotation, Span span = Span()); TVM_DEFINE_OBJECT_REF_METHODS(DataflowVar, Var, DataflowVarNode); @@ -459,7 +463,7 @@ class Constant : public LeafExpr { * \param span The source span of the expression. */ TVM_DLL explicit Constant(runtime::Tensor data, - Optional struct_info_annotation = std::nullopt, + ffi::Optional struct_info_annotation = std::nullopt, Span span = Span()); TVM_DEFINE_OBJECT_REF_METHODS(Constant, LeafExpr, ConstantNode); @@ -516,7 +520,7 @@ class PrimValue : public LeafExpr { class StringImmNode : public LeafExprNode { public: /*! \brief The data value. */ - String value; + ffi::String value; static void RegisterReflection() { namespace refl = tvm::ffi::reflection; @@ -538,7 +542,7 @@ class StringImm : public LeafExpr { * \param value The value input. * \param span The source span of the expression. 
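The `Var` and `DataflowVar` constructors above now spell name hints as `ffi::String` and annotations as `ffi::Optional<StructInfo>`. A small sketch of the common declaration forms (names are illustrative):

    #include <tvm/relax/expr.h>
    #include <tvm/relax/struct_info.h>
    using namespace tvm;
    using namespace tvm::relax;

    void DeclareVars() {
      TensorStructInfo sinfo(DataType::Float(32), /*ndim=*/2);  // shape left unknown
      Var x("x", sinfo);          // name hint converts to ffi::String
      DataflowVar y("y", sinfo);  // only visible inside a dataflow block
      Var z("z", std::nullopt);   // ffi::Optional<StructInfo> left empty
    }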
*/ - TVM_DLL explicit StringImm(String value, Span span = Span()); + TVM_DLL explicit StringImm(ffi::String value, Span span = Span()); TVM_DEFINE_OBJECT_REF_METHODS(StringImm, LeafExpr, StringImmNode); TVM_DEFINE_OBJECT_REF_COW_METHOD(StringImmNode); @@ -680,7 +684,7 @@ class VarBinding : public Binding { class BindingBlockNode : public Object { public: - Array bindings; + ffi::Array bindings; mutable Span span; static void RegisterReflection() { @@ -699,7 +703,7 @@ class BindingBlockNode : public Object { class BindingBlock : public ObjectRef { public: - TVM_DLL explicit BindingBlock(Array bindings, Span span = Span()); + TVM_DLL explicit BindingBlock(ffi::Array bindings, Span span = Span()); TVM_DEFINE_OBJECT_REF_METHODS(BindingBlock, ObjectRef, BindingBlockNode); BindingBlockNode* CopyOnWrite(); @@ -719,7 +723,7 @@ class DataflowBlockNode : public BindingBlockNode { class DataflowBlock : public BindingBlock { public: - TVM_DLL explicit DataflowBlock(Array bindings, Span span = Span()); + TVM_DLL explicit DataflowBlock(ffi::Array bindings, Span span = Span()); TVM_DEFINE_OBJECT_REF_METHODS(DataflowBlock, BindingBlock, DataflowBlockNode); TVM_DEFINE_OBJECT_REF_COW_METHOD(DataflowBlockNode); }; @@ -730,7 +734,7 @@ class DataflowBlock : public BindingBlock { */ class SeqExprNode : public ExprNode { public: - Array blocks; + ffi::Array blocks; Expr body; static void RegisterReflection() { @@ -760,7 +764,7 @@ class SeqExpr : public Expr { */ TVM_DLL SeqExpr(Expr body); // NOLINT(*) - TVM_DLL explicit SeqExpr(Array blocks, Expr body, Span span = Span()); + TVM_DLL explicit SeqExpr(ffi::Array blocks, Expr body, Span span = Span()); TVM_DEFINE_OBJECT_REF_METHODS(SeqExpr, Expr, SeqExprNode); TVM_DEFINE_OBJECT_REF_COW_METHOD(SeqExprNode); }; @@ -828,16 +832,16 @@ class If : public Expr { * Returns \p if_expr if all properties are unchanged. Otherwise, returns a copy with the new * fields. */ -If WithFields(If if_expr, Optional opt_cond = Optional(), - Optional opt_true_branch = Optional(), - Optional opt_false_branch = Optional(), - Optional opt_span = Optional()); +If WithFields(If if_expr, ffi::Optional opt_cond = ffi::Optional(), + ffi::Optional opt_true_branch = ffi::Optional(), + ffi::Optional opt_false_branch = ffi::Optional(), + ffi::Optional opt_span = ffi::Optional()); /*! \brief A Relax function. */ class FunctionNode : public BaseFuncNode { public: /*! \brief The parameters to the function. */ - Array params; + ffi::Array params; /*! \brief The body of the function. */ SeqExpr body; /*! \brief The return type of the function. */ @@ -882,14 +886,15 @@ class Function : public BaseFunc { * * \param span The source span of the expression. */ - TVM_DLL explicit Function(Array params, Expr body, Optional ret_struct_info, - bool is_pure = true, DictAttrs attrs = DictAttrs(), Span span = Span()); + TVM_DLL explicit Function(ffi::Array params, Expr body, + ffi::Optional ret_struct_info, bool is_pure = true, + DictAttrs attrs = DictAttrs(), Span span = Span()); /*! * \brief Mimics the constructor but without body Expr. * \note ret_struct_info is required, since it can not deduced by the body. */ - TVM_DLL static Function CreateEmpty(Array params, StructInfo ret_struct_info, + TVM_DLL static Function CreateEmpty(ffi::Array params, StructInfo ret_struct_info, bool is_pure = true, DictAttrs attrs = DictAttrs(), Span span = Span()); @@ -932,7 +937,7 @@ constexpr const char* kNumInput = "num_input"; class ExternFuncNode : public BaseFuncNode { public: /*! \brief The name of global symbol. 
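For the `SeqExpr`/`Function` hunks, a rough sketch of assembling an identity function with the renamed containers. This deliberately skips the normalization a `BlockBuilder` would perform (struct info propagation), so it is illustrative only:

    #include <tvm/relax/expr.h>
    #include <tvm/relax/struct_info.h>
    using namespace tvm;
    using namespace tvm::relax;

    Function IdentityFunc() {
      TensorStructInfo sinfo(DataType::Float(32), /*ndim=*/2);
      Var param("data", sinfo);
      // A SeqExpr with no binding blocks that simply returns the parameter.
      SeqExpr body(ffi::Array<BindingBlock>{}, param);
      return Function(ffi::Array<Var>{param}, body, /*ret_struct_info=*/sinfo);
    }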
*/ - String global_symbol; + ffi::String global_symbol; static void RegisterReflection() { namespace refl = tvm::ffi::reflection; @@ -945,8 +950,8 @@ class ExternFuncNode : public BaseFuncNode { class ExternFunc : public BaseFunc { public: - TVM_DLL ExternFunc(String global_symbol, Span span = Span()); - TVM_DLL ExternFunc(String global_symbol, StructInfo struct_info, Span span = Span()); + TVM_DLL ExternFunc(ffi::String global_symbol, Span span = Span()); + TVM_DLL ExternFunc(ffi::String global_symbol, StructInfo struct_info, Span span = Span()); TVM_DEFINE_OBJECT_REF_METHODS(ExternFunc, BaseFunc, ExternFuncNode); TVM_DEFINE_OBJECT_REF_COW_METHOD(ExternFuncNode); diff --git a/include/tvm/relax/expr_functor.h b/include/tvm/relax/expr_functor.h index 7634bc34a26f..afacb81e4072 100644 --- a/include/tvm/relax/expr_functor.h +++ b/include/tvm/relax/expr_functor.h @@ -379,7 +379,7 @@ class ExprMutatorBase : public ExprFunctor { */ bool VisitAndCheckStructInfoFieldUnchanged(const ObjectRef& struct_info) { if (const StructInfoNode* sinfo = struct_info.as()) { - return this->VisitExprDepStructInfoField(GetRef(sinfo)).same_as(struct_info); + return this->VisitExprDepStructInfoField(ffi::GetRef(sinfo)).same_as(struct_info); } else { return true; } @@ -421,7 +421,7 @@ class ExprMutator : public ExprMutatorBase { public: using ExprMutatorBase::VisitExpr_; - ExprMutator(Optional mod = std::nullopt) { builder_ = BlockBuilder::Create(mod); } + ExprMutator(ffi::Optional mod = std::nullopt) { builder_ = BlockBuilder::Create(mod); } Expr VisitExpr(const Expr& expr) override; Expr VisitExpr_(const VarNode* op) override; Expr VisitExpr_(const DataflowVarNode* op) override; @@ -502,7 +502,8 @@ class ExprMutator : public ExprMutatorBase { * * \note The body_expr must be an SeqExpr in the normal form. */ - Expr VisitWithNewScope(const Expr& body_expr, Optional> params = std::nullopt); + Expr VisitWithNewScope(const Expr& body_expr, + ffi::Optional> params = std::nullopt); /*! * \brief Rewrite the expr with a new scope, used in the branches of If. @@ -526,7 +527,7 @@ class ExprMutator : public ExprMutatorBase { * \return The value bound to the input \p var. * \note For function parameters, this function returns std::nullopt. */ - Optional LookupBinding(const Var& var); + ffi::Optional LookupBinding(const Var& var); /*! * \brief Post-order rewrite a node and normalize. diff --git a/include/tvm/relax/nested_msg.h b/include/tvm/relax/nested_msg.h index 8620ad80bda7..aac3175d72df 100644 --- a/include/tvm/relax/nested_msg.h +++ b/include/tvm/relax/nested_msg.h @@ -140,20 +140,20 @@ class NestedMsg { data_ = std::move(other); return *this; } - // Array> handling - NestedMsg(Array, void> other) // NOLINT(*) + // ffi::Array> handling + NestedMsg(ffi::Array, void> other) // NOLINT(*) : data_(other) {} - NestedMsg& operator=(Array, void> other) { + NestedMsg& operator=(ffi::Array, void> other) { data_ = std::move(other); return *this; } // initializer list handling NestedMsg(std::initializer_list> other) // NOLINT(*) - : NestedMsg(Array, void>(other)) {} + : NestedMsg(ffi::Array, void>(other)) {} NestedMsg& operator=(std::initializer_list> other) { - return operator=(Array, void>(other)); + return operator=(ffi::Array, void>(other)); } // delete the int constructor @@ -190,8 +190,9 @@ class NestedMsg { * \return a corresponding nested array. * \note This checks if the underlying data type is array. 
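Given the `ExprMutator` changes above (`ffi::Optional<IRModule>` in the constructor, `ffi::Optional<Expr>` from `LookupBinding`), a minimal mutator sketch that exercises them; the pass itself is hypothetical:

    #include <tvm/relax/expr_functor.h>
    using namespace tvm;
    using namespace tvm::relax;

    // Inlines constant bindings at their use sites; purely illustrative.
    class InlineConstants : public ExprMutator {
     public:
      using ExprMutator::ExprMutator;

      Expr VisitExpr_(const VarNode* op) final {
        if (ffi::Optional<Expr> bound = LookupBinding(ffi::GetRef<Var>(op))) {
          if (bound.value()->IsInstance<ConstantNode>()) {
            return bound.value();
          }
        }
        return ExprMutator::VisitExpr_(op);
      }
    };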
*/ - Array, void> NestedArray() const { - return ffi::details::AnyUnsafe::CopyFromAnyViewAfterCheck, void>>(data_); + ffi::Array, void> NestedArray() const { + return ffi::details::AnyUnsafe::CopyFromAnyViewAfterCheck, void>>( + data_); } private: @@ -238,8 +239,8 @@ bool Equal(const NestedMsg& lhs, const NestedMsg& rhs, FType fequal) { return rhs.IsLeaf() && fequal(lhs.LeafValue(), rhs.LeafValue()); } else { if (!rhs.IsNested()) return false; - Array> arr_lhs = lhs.NestedArray(); - Array> arr_rhs = rhs.NestedArray(); + ffi::Array> arr_lhs = lhs.NestedArray(); + ffi::Array> arr_rhs = rhs.NestedArray(); if (arr_lhs.size() != arr_rhs.size()) return false; for (size_t i = 0; i < arr_lhs.size(); ++i) { if (!Equal(arr_lhs[i], arr_rhs[i], fequal)) return false; @@ -264,7 +265,7 @@ bool Equal(const NestedMsg& lhs, const NestedMsg& rhs, FType fequal) { template NestedMsg MapToNestedMsg(Expr expr, FType fmapleaf) { if (auto* tuple = expr.as()) { - Array> res; + ffi::Array> res; res.reserve(tuple->fields.size()); for (Expr x : tuple->fields) { res.push_back(MapToNestedMsg(x, fmapleaf)); @@ -291,7 +292,7 @@ NestedMsg MapToNestedMsg(Expr expr, FType fmapleaf) { template NestedMsg MapToNestedMsg(StructInfo sinfo, FType fmapleaf) { if (auto* tuple = sinfo.as()) { - Array> res; + ffi::Array> res; res.reserve(tuple->fields.size()); for (StructInfo x : tuple->fields) { res.push_back(MapToNestedMsg(x, fmapleaf)); @@ -320,7 +321,7 @@ template NestedMsg MapToNestedMsgBySInfo(Expr expr, FType fmapleaf) { auto sinfo = GetStructInfo(expr); if (auto* tuple = sinfo.as()) { - Array> res; + ffi::Array> res; res.reserve(tuple->fields.size()); for (size_t i = 0; i < tuple->fields.size(); ++i) { Expr field; @@ -346,9 +347,9 @@ NestedMsg MapToNestedMsgBySInfo(Expr expr, FType fmapleaf) { * * \param msg The input nested message. * \param fmapleaf The mapping function for each leaf with signature - * `TargetType fmapleaf(Optional)`. + * `TargetType fmapleaf(ffi::Optional)`. * \param fcombine The function for combining all childs of a node into TargetType with signature - * `TargetType fmapleaf(Array)`. + * `TargetType fmapleaf(ffi::Array)`. * \tparam TargetType the target type to map nested msg to. * \tparam T the content type of nested msg. * \tparam FMapLeaf The leaf mapping function type. @@ -362,8 +363,8 @@ TargetType NestedMsgTo(NestedMsg msg, FMapLeaf fmapleaf, FCombine fcombine) { return fmapleaf(msg.LeafValue()); } else { ICHECK(msg.IsNested()); - Array> arr = msg.NestedArray(); - Array subexpr; + ffi::Array> arr = msg.NestedArray(); + ffi::Array subexpr; subexpr.reserve(arr.size()); for (size_t i = 0; i < arr.size(); ++i) { subexpr.push_back(NestedMsgTo(arr[i], fmapleaf, fcombine)); @@ -380,14 +381,14 @@ TargetType NestedMsgTo(NestedMsg msg, FMapLeaf fmapleaf, FCombine fcombine) { * then recursively combines the results as tuple expr. * * \param msg The input nested message. - * \param fmapleaf The mapping function for each leaf with signature `Expr fmapleaf(Optional)`. - * \tparam T the content type of nested msg. - * \tparam FType The mapping function type. + * \param fmapleaf The mapping function for each leaf with signature `Expr + * fmapleaf(ffi::Optional)`. \tparam T the content type of nested msg. \tparam FType The mapping + * function type. 
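A short round-trip sketch for the `NestedMsg` helpers touched here, assuming the leaf signatures read as `NestedMsg<T> fmapleaf(Expr)` and `Expr fmapleaf(ffi::Optional<T>)` — an assumption, since template arguments are elided in this rendering of the diff:

    #include <tvm/relax/nested_msg.h>
    using namespace tvm;
    using namespace tvm::relax;

    // Map every leaf of a (possibly nested) tuple into a message, then
    // rebuild an Expr from the messages; identity mappings keep it simple.
    Expr RoundTrip(const Expr& expr) {
      NestedMsg<Expr> msg =
          MapToNestedMsg<Expr>(expr, [](Expr leaf) -> NestedMsg<Expr> { return leaf; });
      return NestedMsgToExpr<Expr>(
          msg, [](ffi::Optional<Expr> leaf) -> Expr { return leaf.value(); });
    }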
*/ template Expr NestedMsgToExpr(NestedMsg msg, FType fmapleaf) { - return NestedMsgTo(msg, fmapleaf, [](Array arr) { - Optional simplified_tuple; + return NestedMsgTo(msg, fmapleaf, [](ffi::Array arr) { + ffi::Optional simplified_tuple; bool simplified_flag = false; if (arr.size() >= 1) { simplified_flag = true; @@ -436,11 +437,11 @@ NestedMsg CombineNestedMsg(NestedMsg lhs, NestedMsg rhs, FType fcombine } else { ICHECK(lhs.IsNested()); ICHECK(rhs.IsNested()) << "Cannot combine leaf with nested"; - Array> arr_lhs = lhs.NestedArray(); - Array> arr_rhs = rhs.NestedArray(); + ffi::Array> arr_lhs = lhs.NestedArray(); + ffi::Array> arr_rhs = rhs.NestedArray(); ICHECK_EQ(arr_lhs.size(), arr_rhs.size()) << "Cannot combine two nested array with different sizes"; - Array> res; + ffi::Array> res; res.reserve(arr_lhs.size()); for (size_t i = 0; i < arr_lhs.size(); ++i) { res.push_back(CombineNestedMsg(arr_lhs[i], arr_rhs[i], fcombine)); @@ -465,8 +466,8 @@ NestedMsg MapNestedMsg(NestedMsg msg, FType fmapleaf) { return fmapleaf(msg.LeafValue()); } else { ICHECK(msg.IsNested()); - Array> arr = msg.NestedArray(); - Array> res; + ffi::Array> arr = msg.NestedArray(); + ffi::Array> res; res.reserve(arr.size()); for (int i = 0; i < static_cast(arr.size()); ++i) { res.push_back(MapNestedMsg(arr[i], fmapleaf)); @@ -492,7 +493,7 @@ template void DecomposeNestedMsg(Expr expr, NestedMsg msg, FType fvisitleaf) { if (auto* tuple = expr.as()) { ICHECK(msg.IsNested()) << "Expected nested to match tuple"; - Array> arr = msg.NestedArray(); + ffi::Array> arr = msg.NestedArray(); ICHECK_EQ(arr.size(), tuple->fields.size()) << "Expected nested array size to match tuple size"; for (size_t i = 0; i < arr.size(); ++i) { DecomposeNestedMsg(tuple->fields[i], arr[i], fvisitleaf); @@ -511,7 +512,7 @@ void DecomposeNestedMsg(Expr expr, NestedMsg msg, FType fvisitleaf) { * * \param expr The input expression to be transform.  * \param msgs The input messages to guide the transformation. - * \param ftransleaf with signature ftransleaf(Expr, Array>)->Expr + * \param ftransleaf with signature ftransleaf(Expr, ffi::Array>)->Expr * \tparam T the content type of nested msg * \tparam N the number of messages * \tparam FType The visit function type. @@ -520,13 +521,13 @@ template Expr TransformTupleLeaf(Expr expr, std::array, N> msgs, FType ftransleaf) { StructInfo sinfo = GetStructInfo(expr); if (const auto* tuple = sinfo.as()) { - std::array>, N> msg_arrays; + std::array>, N> msg_arrays; for (size_t i = 0; i < N; ++i) { ICHECK(msgs[i].IsNested()) << "Expected nested to match tuple"; msg_arrays[i] = msgs[i].NestedArray(); } bool same = true; - Array fields; + ffi::Array fields; fields.reserve(tuple->fields.size()); for (size_t i = 0; i < tuple->fields.size(); ++i) { Expr field; @@ -560,7 +561,7 @@ Expr TransformTupleLeaf(Expr expr, std::array, N> msgs, FType ftran * * \param sinfo The input sinfo to be transform.  * \param msgs The input messages to guide the transformation. - * \param ftransleaf with signature ftransleaf(StructInfo, Array>)->StructInfo + * \param ftransleaf with signature ftransleaf(StructInfo, ffi::Array>)->StructInfo * \tparam T the content type of nested msg * \tparam N the number of messages * \tparam FType The visit function type. 
@@ -569,13 +570,13 @@ template StructInfo TransformTupleLeaf(StructInfo sinfo, std::array, N> msgs, FType ftransleaf) { if (const auto* tuple = sinfo.as()) { - std::array>, N> msg_arrays; + std::array>, N> msg_arrays; for (size_t i = 0; i < N; ++i) { ICHECK(msgs[i].IsNested()) << "Expected nested to match tuple"; msg_arrays[i] = msgs[i].NestedArray(); } bool same = true; - Array fields; + ffi::Array fields; fields.reserve(tuple->fields.size()); for (size_t i = 0; i < tuple->fields.size(); ++i) { StructInfo field = tuple->fields[i]; @@ -654,7 +655,7 @@ struct TypeTraits> : public TypeTraitsBase { } if (src->type_index == TypeIndex::kTVMFFIArray) { const ArrayObj* n = reinterpret_cast(src->v_obj); - Array> result; + ffi::Array> result; result.reserve(n->size()); for (size_t i = 0; i < n->size(); i++) { const Any& any_v = (*n)[i]; diff --git a/include/tvm/relax/op_attr_types.h b/include/tvm/relax/op_attr_types.h index bd9c59da3acb..2e686035b20c 100644 --- a/include/tvm/relax/op_attr_types.h +++ b/include/tvm/relax/op_attr_types.h @@ -65,7 +65,7 @@ using FInferStructInfo = ffi::TypedFunction( +using FPrimalGradient = ffi::TypedFunction( const Var& orig_var, const Call& orig_call, const Var& output_grad, const BlockBuilder& ctx)>; } // namespace relax diff --git a/include/tvm/relax/struct_info.h b/include/tvm/relax/struct_info.h index a897f031a289..8a97658330df 100644 --- a/include/tvm/relax/struct_info.h +++ b/include/tvm/relax/struct_info.h @@ -61,7 +61,7 @@ class ObjectStructInfo : public StructInfo { class PrimStructInfoNode : public StructInfoNode { public: /*! \brief Underlying primitive value, if known */ - Optional value; + ffi::Optional value; /*! \brief Underlying data type of the primitive value */ DataType dtype; @@ -98,7 +98,7 @@ class PrimStructInfo : public StructInfo { class ShapeStructInfoNode : public StructInfoNode { public: /*! \brief optionally stores the symbolic value patterns of the shape */ - Optional> values; + ffi::Optional> values; /*! * \brief The number of dimension of the shape, can be unknown. * \sa kUnknownNDim @@ -130,7 +130,7 @@ class ShapeStructInfo : public StructInfo { * \param values The symbolic shape values * \param span The span of the AST. */ - TVM_DLL ShapeStructInfo(Array values, Span span = Span()); + TVM_DLL ShapeStructInfo(ffi::Array values, Span span = Span()); /*! * \brief Construction with known unknown symbolic shape patterns. * \param ndim Number of dimensions -- can be kUnknownNDim @@ -150,11 +150,11 @@ class TensorStructInfoNode : public StructInfoNode { * \brief optionally store the shape expression of the tensor. * \note shape must be normalized: it can only be std::nullopt or ShapeExpr or Var. */ - Optional shape; + ffi::Optional shape; /*! \brief The virtual device, indicates where the tensor * is expected to be executed. */ - Optional vdevice; + ffi::Optional vdevice; /*! \brief The content data type, use void to denote the dtype is unknown. */ DataType dtype; /*! @@ -170,7 +170,7 @@ class TensorStructInfoNode : public StructInfoNode { bool IsUnknownDtype() const { return dtype.is_void(); } /*! \return Shape if it is known. */ - Optional> GetShape() const { + ffi::Optional> GetShape() const { if (!shape.defined()) return {}; ShapeStructInfo shape_sinfo = Downcast(this->shape.value()->struct_info_); return shape_sinfo->values; @@ -204,8 +204,8 @@ class TensorStructInfo : public StructInfo { * * \note shape must already be normalized. 
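Constructing struct info under the new spellings, per the `ShapeStructInfo`/`TensorStructInfo` hunks above (variable names are illustrative):

    #include <tvm/relax/expr.h>
    #include <tvm/relax/struct_info.h>
    using namespace tvm;
    using namespace tvm::relax;

    void BuildStructInfo() {
      tir::Var m("m", DataType::Int(64));
      tir::Var n("n", DataType::Int(64));
      ShapeStructInfo shape_sinfo(ffi::Array<PrimExpr>{m, n});      // symbolic values known
      TensorStructInfo with_shape(ShapeExpr({m, n}), DataType::Float(32));
      TensorStructInfo rank_only(DataType::Float(32), /*ndim=*/2);  // shape unknown
    }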
*/ - TVM_DLL TensorStructInfo(Expr shape, DataType dtype, Optional vdevice = std::nullopt, - Span span = Span()); + TVM_DLL TensorStructInfo(Expr shape, DataType dtype, + ffi::Optional vdevice = std::nullopt, Span span = Span()); /*! * \brief Construction with an unknown shape expression. @@ -214,7 +214,7 @@ class TensorStructInfo : public StructInfo { * \param vdevice The virtual device. * \param span The span of the AST. */ - TVM_DLL TensorStructInfo(DataType dtype, int ndim, Optional vdevice = std::nullopt, + TVM_DLL TensorStructInfo(DataType dtype, int ndim, ffi::Optional vdevice = std::nullopt, Span span = Span()); TVM_DEFINE_OBJECT_REF_METHODS(TensorStructInfo, StructInfo, TensorStructInfoNode); @@ -226,7 +226,7 @@ class TensorStructInfo : public StructInfo { class TupleStructInfoNode : public StructInfoNode { public: /*! \brief The struct info of tuple fields. */ - Array fields; + ffi::Array fields; static void RegisterReflection() { namespace refl = tvm::ffi::reflection; @@ -248,7 +248,7 @@ class TupleStructInfo : public StructInfo { * \param fields Struct info of tuple fields. * \param span The span of the AST. */ - TVM_DLL TupleStructInfo(Array fields, Span span = Span()); + TVM_DLL TupleStructInfo(ffi::Array fields, Span span = Span()); TVM_DEFINE_NOTNULLABLE_OBJECT_REF_METHODS(TupleStructInfo, StructInfo, TupleStructInfoNode); }; @@ -274,7 +274,7 @@ class FuncStructInfoNode : public StructInfoNode { * \note When params is std::nullopt means the function can take arbitrary number of arguments. * We define such functions as Opaque function. */ - Optional> params; + ffi::Optional> params; /*! * \brief The struct info of the function's return value. */ @@ -284,7 +284,7 @@ class FuncStructInfoNode : public StructInfoNode { * \note When derive_func is not empty, then params should be std::nullopt, * ret should be ObjectStructInfo() */ - Optional derive_func; + ffi::Optional derive_func; /*! * \brief Whether the function is pure. * \note This parameter should be set to true only if the function is pure on all inputs. @@ -327,7 +327,7 @@ class FuncStructInfo : public StructInfo { * \note If the ret contains variables(tir::Var and relax::Var), they must be deducible from * params. If you are unsure, you can always erase ret to static. */ - TVM_DLL FuncStructInfo(Array params, StructInfo ret, bool purity = true, + TVM_DLL FuncStructInfo(ffi::Array params, StructInfo ret, bool purity = true, Span span = Span()); /*! @@ -369,10 +369,10 @@ class FuncStructInfo : public StructInfo { * \tparam T the underlying structure info type */ template -inline Optional MatchStructInfo(const Expr& expr) { +inline ffi::Optional MatchStructInfo(const Expr& expr) { using TNode = typename T::ContainerType; if (const TNode* ptr = expr->struct_info_.as()) { - return GetRef(ptr); + return ffi::GetRef(ptr); } else { return std::nullopt; } @@ -401,7 +401,7 @@ inline const T* GetStructInfoAs(const Expr& expr) { inline StructInfo GetStructInfo(const Expr& expr) { auto* ptr = expr->struct_info_.as(); ICHECK(ptr) << "The struct_info is not populated, check if you have normalized the expr"; - return GetRef(ptr); + return ffi::GetRef(ptr); } /*! diff --git a/include/tvm/relax/tir_pattern.h b/include/tvm/relax/tir_pattern.h index 1397bafc36ff..695a509bddd5 100644 --- a/include/tvm/relax/tir_pattern.h +++ b/include/tvm/relax/tir_pattern.h @@ -41,9 +41,9 @@ class MatchResultNode : public Object { /*! The matched tir pattern*/ TIRPattern pattern; /*! \brief The evaluated values of symbolic vars. 
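`MatchStructInfo<T>` now returns `ffi::Optional<T>`, so a typical query reads:

    #include <tvm/relax/struct_info.h>
    #include <tvm/runtime/logging.h>
    using namespace tvm;
    using namespace tvm::relax;

    void Inspect(const Expr& expr) {
      // Empty optional if the expression's annotation is not a tensor.
      if (auto tensor = MatchStructInfo<TensorStructInfo>(expr)) {
        LOG(INFO) << "dtype: " << tensor.value()->dtype;
      }
      // Hard assertion that struct info was populated by normalization.
      StructInfo sinfo = GetStructInfo(expr);
      (void)sinfo;
    }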
*/ - Array symbol_values; + ffi::Array symbol_values; /*! \brief The matched buffers of input and output. */ - Array matched_buffers; + ffi::Array matched_buffers; static void RegisterReflection() { namespace refl = tvm::ffi::reflection; @@ -68,13 +68,13 @@ class MatchResult : public ObjectRef { * \param symbol_values The evaluated values of symbolic vars. * \param matched_buffers The matched buffers of input and output. */ - TVM_DLL explicit MatchResult(TIRPattern pattern, Array symbol_values, - Array matched_buffers); + TVM_DLL explicit MatchResult(TIRPattern pattern, ffi::Array symbol_values, + ffi::Array matched_buffers); TVM_DEFINE_OBJECT_REF_METHODS(MatchResult, ObjectRef, MatchResultNode); }; -using FCodegen = ffi::TypedFunction(Array match_results)>; +using FCodegen = ffi::TypedFunction(ffi::Array match_results)>; } // namespace relax } // namespace tvm #endif // TVM_RELAX_TIR_PATTERN_H_ diff --git a/include/tvm/relax/transform.h b/include/tvm/relax/transform.h index 1567294a4b38..ba3a41fa63fb 100644 --- a/include/tvm/relax/transform.h +++ b/include/tvm/relax/transform.h @@ -54,8 +54,8 @@ using tvm::transform::CreateModulePass; * \return The created function pass. */ TVM_DLL Pass CreateFunctionPass(std::function pass_func, - int opt_level, String name, tvm::Array required, - bool traceable = false); + int opt_level, ffi::String name, + tvm::ffi::Array required, bool traceable = false); /*! * \brief Create a dataflowblock pass. @@ -70,7 +70,7 @@ TVM_DLL Pass CreateFunctionPass(std::function pass_func, int opt_level, - String name, tvm::Array required, bool traceable = false); + ffi::String name, tvm::ffi::Array required, bool traceable = false); /*! * \brief Perform lambda lifting to lift functions from nested into global. @@ -196,7 +196,7 @@ TVM_DLL Pass EliminateCommonSubexpr(bool call_only = false); * * \return The Pass. */ -TVM_DLL Pass BindParams(String func_name, Map params); +TVM_DLL Pass BindParams(ffi::String func_name, ffi::Map params); /*! * \brief Bind symbolic vars to constant shape values. @@ -213,8 +213,8 @@ TVM_DLL Pass BindParams(String func_name, Map params); * * \return The Pass. */ -TVM_DLL Pass BindSymbolicVars(Map, PrimExpr> binding_map, - Optional func_name = std::nullopt); +TVM_DLL Pass BindSymbolicVars(ffi::Map, PrimExpr> binding_map, + ffi::Optional func_name = std::nullopt); /*! * \brief Fold constant expressions within dataflow blocks. @@ -248,7 +248,8 @@ TVM_DLL Pass FoldConstant(); * showing up in the database. * \return The Pass. */ -TVM_DLL Pass LegalizeOps(Optional> cmap, bool enable_warning = false); +TVM_DLL Pass LegalizeOps(ffi::Optional> cmap, + bool enable_warning = false); /*! * \brief Propagate virtual device information. @@ -303,7 +304,8 @@ TVM_DLL Pass SplitLayoutRewritePreproc(); * * \return The Pass. */ -TVM_DLL Pass LiftTransformParams(Variant> shared_transform = Bool(false)); +TVM_DLL Pass +LiftTransformParams(ffi::Variant> shared_transform = Bool(false)); /*! * \brief Update virtual device. @@ -364,7 +366,7 @@ class FusionPatternNode : public Object { * \brief The name of pattern. It becomes the value of the kComposite attribute * of a fused function after successful matching */ - String name; + ffi::String name; /*! * \brief The dataflow pattern that will be used to match expression in the DataflowBlock. @@ -376,7 +378,7 @@ class FusionPatternNode : public Object { * \brief The map which is used to extract important expressions from the pattern match * result. All DFPattern in this map should be part of the `pattern`. 
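For the `BindSymbolicVars` hunk, reading the elided key type as `ffi::Variant<tir::Var, ffi::String>` — an assumption, since template arguments are stripped in this rendering — a sketch of building the pass:

    #include <tvm/relax/transform.h>
    using namespace tvm;

    transform::Pass MakeBindN() {
      // Keys may name the symbolic var directly or by its string hint.
      ffi::Map<ffi::Variant<tir::Var, ffi::String>, PrimExpr> binding;
      binding.Set(ffi::String("n"), IntImm(DataType::Int(64), 16));
      return relax::transform::BindSymbolicVars(binding);
    }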
*/ - Map annotation_patterns; + ffi::Map annotation_patterns; /*! * \brief The function to determine whether the match result is accepted. This can be @@ -385,15 +387,15 @@ class FusionPatternNode : public Object { * It should have signature * bool(const PatternCheckContext& context) */ - Optional check; + ffi::Optional check; /*! * \brief The function to get attributes for fused function * * It should have signature - * Map(const Map& context) + * ffi::Map(const ffi::Map& context) */ - Optional attrs_getter; + ffi::Optional attrs_getter; static void RegisterReflection() { namespace refl = tvm::ffi::reflection; @@ -411,10 +413,11 @@ class FusionPatternNode : public Object { class FusionPattern : public ObjectRef { public: - FusionPattern(String name, DFPattern pattern, Map annotation_patterns, - Optional check, Optional attrs_getter); + FusionPattern(ffi::String name, DFPattern pattern, + ffi::Map annotation_patterns, + ffi::Optional check, ffi::Optional attrs_getter); - FusionPattern(String name, DFPattern pattern) + FusionPattern(ffi::String name, DFPattern pattern) : FusionPattern(name, pattern, {}, std::nullopt, std::nullopt) {} TVM_DEFINE_NOTNULLABLE_OBJECT_REF_METHODS(FusionPattern, ObjectRef, FusionPatternNode); @@ -434,25 +437,25 @@ class PatternCheckContextNode : public Object { * \brief A map which contains all expressions matched by the sub patterns in * FusionPattern::annotation_patterns. */ - Map annotated_expr; + ffi::Map annotated_expr; /*! * \brief Map from variable to its value. It contains variables from bindings that * is being fused by FuseOpsByPattern. */ - Map matched_bindings; + ffi::Map matched_bindings; /*! * \brief A map mapping variable definitions to a set of uses. It has all variables * used in the function. */ - Map> var_usages; + ffi::Map> var_usages; /*! * \brief Map from value to its bound variable. It doesn't have variables after the * matched expression. */ - Map value_to_bound_var; + ffi::Map value_to_bound_var; static void RegisterReflection() { namespace refl = tvm::ffi::reflection; @@ -470,9 +473,10 @@ class PatternCheckContextNode : public Object { class PatternCheckContext : public ObjectRef { public: - PatternCheckContext(Expr matched_expr, Map annotated_expr, - Map matched_bindings, Map> var_usages, - Map value_to_bound_var); + PatternCheckContext(Expr matched_expr, ffi::Map annotated_expr, + ffi::Map matched_bindings, + ffi::Map> var_usages, + ffi::Map value_to_bound_var); TVM_DEFINE_NOTNULLABLE_OBJECT_REF_METHODS(PatternCheckContext, ObjectRef, PatternCheckContextNode); @@ -503,7 +507,8 @@ class PatternCheckContext : public ObjectRef { * * \note ConvertToDataflow may need to be called first to provide dataflow blocks. */ -TVM_DLL Pass Gradient(String func_name, Optional> require_grads = std::nullopt, +TVM_DLL Pass Gradient(ffi::String func_name, + ffi::Optional> require_grads = std::nullopt, int target_index = 0); /*! @@ -526,9 +531,9 @@ TVM_DLL Pass Gradient(String func_name, Optional> require_grads = std * * \note Only operates within dataflow blocks. ConvertToDataflow may need to be called first. */ -TVM_DLL Pass FuseOpsByPattern(const tvm::Array& patterns, bool bind_constants = true, - bool annotate_codegen = false, - const tvm::Array& entry_function_names = {}); +TVM_DLL Pass FuseOpsByPattern(const tvm::ffi::Array& patterns, + bool bind_constants = true, bool annotate_codegen = false, + const tvm::ffi::Array& entry_function_names = {}); /*! 
* \brief Group one or multiple composite functions created by FuseOpsByPattern into a new @@ -553,8 +558,9 @@ TVM_DLL Pass FuseTIR(); * \param entry_functions list of entry functions * \return The Pass. */ -TVM_DLL Pass RunCodegen(Optional>> target_options, - Array entry_functions); +TVM_DLL Pass +RunCodegen(ffi::Optional>> target_options, + ffi::Array entry_functions); /*! * \brief Decompose composite operators during inference. For example, The result of batch norm (a @@ -564,7 +570,7 @@ TVM_DLL Pass RunCodegen(Optional>> target_opti * \param func_name The name of the specified function. If not specified, the pass will run in * all functions. */ -TVM_DLL Pass DecomposeOpsForInference(Optional func_name); +TVM_DLL Pass DecomposeOpsForInference(ffi::Optional func_name); /*! * \brief Decompose composite operators during training. For example, The result of batch norm (a @@ -574,7 +580,7 @@ TVM_DLL Pass DecomposeOpsForInference(Optional func_name); * \param func_name The name of the specified function. If not specified, the pass will run in * all functions. */ -TVM_DLL Pass DecomposeOpsForTraining(Optional func_name); +TVM_DLL Pass DecomposeOpsForTraining(ffi::Optional func_name); /*! * \brief Returns a pass which replaces PrimFuncs which have matching kOperatorName attribute in \p @@ -590,10 +596,12 @@ TVM_DLL Pass DecomposeOpsForTraining(Optional func_name); * \param input_axis_separators Map from kOperatorName attr to axis_separator for input buffer * \return The Pass. */ -TVM_DLL Pass AlterOpImpl(const Map& op_impl_map, - const Map>& op_buffer_transforms, - const Map>>>& axis_separators, - const Map>>>& input_axis_separators); +TVM_DLL Pass AlterOpImpl( + const ffi::Map& op_impl_map, + const ffi::Map>& op_buffer_transforms, + const ffi::Map>>>& axis_separators, + const ffi::Map>>>& + input_axis_separators); /*! * \brief Layout conversion pass. @@ -601,7 +609,7 @@ TVM_DLL Pass AlterOpImpl(const Map& op_impl_map, * \return The Pass. * \note Operates only on dataflow blocks. ConvertToDataflow may need to be called first. */ -TVM_DLL Pass ConvertLayout(Map> desired_layouts); +TVM_DLL Pass ConvertLayout(ffi::Map> desired_layouts); /*! * \brief A pass that converts consecutive dataflow operations @@ -628,7 +636,7 @@ TVM_DLL Pass ConvertToDataflow(int min_size = 2); * * \return The Pass. */ -TVM_DLL Pass DeadCodeElimination(Array entry_functions = {}); +TVM_DLL Pass DeadCodeElimination(ffi::Array entry_functions = {}); /*! * \brief Pass that changes calls to operators that can be done in-place @@ -651,8 +659,9 @@ TVM_DLL Pass DataflowUseInplaceCalls(); * * \note Mainly operates within dataflow blocks. ConvertToDataflow may need to be called first. */ -TVM_DLL Pass ToMixedPrecision(const DataType& out_dtype, - Optional> fp16_input_names = std::nullopt); +TVM_DLL Pass +ToMixedPrecision(const DataType& out_dtype, + ffi::Optional> fp16_input_names = std::nullopt); /*! * \brief Rewrite a Relax module for executing with CUDA graph. This pass identifies diff --git a/include/tvm/relax/utils.h b/include/tvm/relax/utils.h index e48c1856f9fe..70ecbe4855ac 100644 --- a/include/tvm/relax/utils.h +++ b/include/tvm/relax/utils.h @@ -47,15 +47,15 @@ namespace relax { * * \return The updated expression. */ -TVM_DLL Expr Bind(const Expr& expr, const tvm::Map& binds, - const tvm::Map& symbolic_var_map = {}); +TVM_DLL Expr Bind(const Expr& expr, const tvm::ffi::Map& binds, + const tvm::ffi::Map& symbolic_var_map = {}); /*! * \brief Bind the symbolic variables to a StructInfo. 
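A sketch of `ConvertLayout` with the renamed map type; the op name and layout strings are illustrative values, not mandated by the patch:

    #include <tvm/relax/transform.h>
    using namespace tvm;

    transform::Pass MakeLayoutPass() {
      // Keyed by op name; the array lists the desired data/kernel layouts.
      ffi::Map<ffi::String, ffi::Array<ffi::String>> desired;
      desired.Set("relax.nn.conv2d", {"NHWC", "OHWI"});
      return relax::transform::ConvertLayout(desired);
    }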
This is a helper function usually called by * other pass functions to help optimizations. */ TVM_DLL StructInfo Bind(const StructInfo& sinfo, - const tvm::Map& symbolic_var_map); + const tvm::ffi::Map& symbolic_var_map); /*! * \brief Infer a binding map for symbolic variables @@ -74,8 +74,8 @@ TVM_DLL StructInfo Bind(const StructInfo& sinfo, * * \return A map of TIR variables to TIR expressions */ -TVM_DLL tvm::Map InferSymbolicVarMap( - const tvm::Map& binds, arith::Analyzer* analyzer); +TVM_DLL tvm::ffi::Map InferSymbolicVarMap( + const tvm::ffi::Map& binds, arith::Analyzer* analyzer); /*! * \brief Check if the given StructInfo is for a boolean scalar (tensor of rank 0 with a boolean diff --git a/include/tvm/runtime/contrib/papi.h b/include/tvm/runtime/contrib/papi.h index 93c1aa274bfd..551f66726473 100644 --- a/include/tvm/runtime/contrib/papi.h +++ b/include/tvm/runtime/contrib/papi.h @@ -38,7 +38,8 @@ namespace profiling { * collected on that device. You can find the names of available metrics by * running `papi_native_avail`. */ -TVM_DLL MetricCollector CreatePAPIMetricCollector(Map> metrics); +TVM_DLL MetricCollector +CreatePAPIMetricCollector(ffi::Map> metrics); } // namespace profiling } // namespace runtime } // namespace tvm diff --git a/include/tvm/runtime/disco/builtin.h b/include/tvm/runtime/disco/builtin.h index acd4a214ff7b..ae119e52652b 100644 --- a/include/tvm/runtime/disco/builtin.h +++ b/include/tvm/runtime/disco/builtin.h @@ -62,7 +62,7 @@ inline std::string ReduceKind2String(ReduceKind kind) { * \param device The default device used to initialize the RelaxVM * \return The RelaxVM as a runtime Module */ -TVM_DLL ffi::Module LoadVMModule(std::string path, Optional device); +TVM_DLL ffi::Module LoadVMModule(std::string path, ffi::Optional device); /*! * \brief Create an uninitialized empty Tensor * \param shape The shape of the Tensor @@ -70,7 +70,7 @@ TVM_DLL ffi::Module LoadVMModule(std::string path, Optional device); * \param device The device the Tensor is created on. If None, use the thread local default device * \return The Tensor created */ -TVM_DLL Tensor DiscoEmptyTensor(ffi::Shape shape, DataType dtype, Optional device); +TVM_DLL Tensor DiscoEmptyTensor(ffi::Shape shape, DataType dtype, ffi::Optional device); /*! * \brief Perform an allreduce operation using the underlying communication library * \param send The array send to perform allreduce on @@ -100,7 +100,7 @@ TVM_DLL void BroadcastFromWorker0(Tensor send, bool in_group, Tensor recv); * \param in_group Whether the scatter operation performs globally or in group as default. * \param recv The receiving buffer, which must not be None. */ -TVM_DLL void ScatterFromWorker0(Optional send, bool in_group, Tensor recv); +TVM_DLL void ScatterFromWorker0(ffi::Optional send, bool in_group, Tensor recv); /*! * \brief Perform a gather operation to worker-0. * \param send The sending buffer, which must not be None. @@ -108,7 +108,7 @@ TVM_DLL void ScatterFromWorker0(Optional send, bool in_group, Tensor rec * \param recv For worker-0, it must be provided, and otherwise, the buffer must be None. The * receiving buffer will be divided into equal parts and receive from each worker accordingly. */ -TVM_DLL void GatherToWorker0(Tensor send, bool in_group, Optional recv); +TVM_DLL void GatherToWorker0(Tensor send, bool in_group, ffi::Optional recv); /*! * \brief Receive a buffer from worker-0. No-op if the current worker is worker-0. 
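For the disco collectives, only worker-0 supplies a payload to `ScatterFromWorker0`; under the new `ffi::Optional<Tensor>` parameter that reads as (sketch, helper name hypothetical):

    #include <tvm/runtime/disco/builtin.h>
    using namespace tvm;
    using namespace tvm::runtime;

    // Every worker passes its receive buffer; only worker-0 passes the payload.
    void ScatterWeights(Tensor full, Tensor local_recv, bool is_worker0) {
      ffi::Optional<Tensor> send;
      if (is_worker0) send = full;
      ScatterFromWorker0(send, /*in_group=*/true, local_recv);
    }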
* \param buffer The buffer to be received diff --git a/include/tvm/runtime/disco/disco_worker.h b/include/tvm/runtime/disco/disco_worker.h index 078c061b7b82..464efb59c01b 100644 --- a/include/tvm/runtime/disco/disco_worker.h +++ b/include/tvm/runtime/disco/disco_worker.h @@ -79,7 +79,7 @@ class DiscoWorker { /*! \brief The default device to allocate data if not specified */ Device default_device; /*! \brief The name of the underlying collective communication library. */ - String ccl; + ffi::String ccl; /*! * \brief The data shared between worker-0 and the controler. It's a nullptr if * the worker is not worker-0. diff --git a/include/tvm/runtime/disco/session.h b/include/tvm/runtime/disco/session.h index 72ac577d52d4..1506d2548f1f 100644 --- a/include/tvm/runtime/disco/session.h +++ b/include/tvm/runtime/disco/session.h @@ -235,7 +235,7 @@ class SessionObj : public Object { * \param ccl The name of the communication backend, e.g., nccl, rccl, mpi. * \param device_ids The device ids of the workers. */ - TVM_DLL virtual void InitCCL(String ccl, IntTuple device_ids) = 0; + TVM_DLL virtual void InitCCL(ffi::String ccl, IntTuple device_ids) = 0; /*! * \brief Get the value of a register from a remote worker. * \param reg_id The id of the register to be fetched. @@ -287,7 +287,7 @@ class Session : public ObjectRef { * worker-0 does not exist in the process pool. */ TVM_DLL static Session ProcessSession(int num_workers, int num_groups, - String process_pool_creator, String entrypoint); + ffi::String process_pool_creator, ffi::String entrypoint); TVM_FFI_DEFINE_MUTABLE_OBJECT_REF_METHODS(Session, ObjectRef, SessionObj); }; diff --git a/include/tvm/runtime/memory/memory_manager.h b/include/tvm/runtime/memory/memory_manager.h index a10bc6b36e04..52a91d63c66c 100644 --- a/include/tvm/runtime/memory/memory_manager.h +++ b/include/tvm/runtime/memory/memory_manager.h @@ -67,7 +67,7 @@ class Allocator { * \return The empty Tensor. */ TVM_DLL Tensor Empty(ffi::Shape shape, DLDataType dtype, Device dev, - Optional mem_scope = std::nullopt); + ffi::Optional mem_scope = std::nullopt); /*! \brief Return the allocator type. */ inline AllocatorType type() const { return type_; } /*! \brief Allocate a buffer given a size, alignment and type. @@ -168,7 +168,7 @@ class StorageObj : public Object { /*! \brief Allocate an Tensor with memory scope from a given piece of storage. */ TVM_DLL Tensor AllocTensorScoped(int64_t offset, ffi::Shape shape, DLDataType dtype, - String scope = "global"); + ffi::String scope = "global"); ~StorageObj() { if (allocator) { diff --git a/include/tvm/runtime/module.h b/include/tvm/runtime/module.h index f805ec988d37..1e0e7039448b 100644 --- a/include/tvm/runtime/module.h +++ b/include/tvm/runtime/module.h @@ -45,7 +45,7 @@ namespace runtime { * \param target The target module name. * \return Whether runtime is enabled. */ -TVM_DLL bool RuntimeEnabled(const String& target); +TVM_DLL bool RuntimeEnabled(const ffi::String& target); /*! 
\brief namespace for constant symbols */ namespace symbol { @@ -105,11 +105,11 @@ struct ModuleVTableEntryHelper { } // namespace runtime } // namespace tvm -#define TVM_MODULE_VTABLE_BEGIN(TypeKey) \ - const char* kind() const final { return TypeKey; } \ - ::tvm::ffi::Optional<::tvm::ffi::Function> GetFunction(const String& _name) override { \ - using SelfPtr = std::remove_cv_t; \ - ::tvm::ffi::ObjectPtr<::tvm::ffi::Object> _self = \ +#define TVM_MODULE_VTABLE_BEGIN(TypeKey) \ + const char* kind() const final { return TypeKey; } \ + ::tvm::ffi::Optional<::tvm::ffi::Function> GetFunction(const ffi::String& _name) override { \ + using SelfPtr = std::remove_cv_t; \ + ::tvm::ffi::ObjectPtr<::tvm::ffi::Object> _self = \ ::tvm::ffi::GetObjectPtr<::tvm::ffi::Object>(this); #define TVM_MODULE_VTABLE_END() \ return std::nullopt; \ diff --git a/include/tvm/runtime/object.h b/include/tvm/runtime/object.h index 9da9467e8ff2..e04a800400f1 100644 --- a/include/tvm/runtime/object.h +++ b/include/tvm/runtime/object.h @@ -106,18 +106,18 @@ static_assert(static_cast(TypeIndex::kCustomStaticIndex) >= * * \endcode */ -#define TVM_DEFINE_OBJECT_REF_COW_METHOD(ObjectName) \ - static_assert(ObjectName::_type_final, \ - "TVM's CopyOnWrite may only be used for " \ - "Object types that are declared as final, " \ - "using the TVM_DECLARE_FINAL_OBJECT_INFO macro."); \ - ObjectName* CopyOnWrite() { \ - ICHECK(data_ != nullptr); \ - if (!data_.unique()) { \ - auto n = make_object(*(operator->())); \ - ObjectPtr(std::move(n)).swap(data_); \ - } \ - return static_cast(data_.get()); \ +#define TVM_DEFINE_OBJECT_REF_COW_METHOD(ObjectName) \ + static_assert(ObjectName::_type_final, \ + "TVM's CopyOnWrite may only be used for " \ + "Object types that are declared as final, " \ + "using the TVM_DECLARE_FINAL_OBJECT_INFO macro."); \ + ObjectName* CopyOnWrite() { \ + ICHECK(data_ != nullptr); \ + if (!data_.unique()) { \ + auto n = ::tvm::ffi::make_object(*(operator->())); \ + ObjectPtr(std::move(n)).swap(data_); \ + } \ + return static_cast(data_.get()); \ } /* diff --git a/include/tvm/runtime/profiling.h b/include/tvm/runtime/profiling.h index 88a22c981652..43bb2f25ce20 100644 --- a/include/tvm/runtime/profiling.h +++ b/include/tvm/runtime/profiling.h @@ -137,7 +137,7 @@ class Timer : public ObjectRef { * TVM_FFI_STATIC_INIT_BLOCK({ * namespace refl = tvm::ffi::reflection; * refl::GlobalDef().def("profiling.timer.cpu", [](Device dev) { - * return Timer(make_object()); + * return Timer(ffi::make_object()); * }); * }); * \endcode @@ -174,7 +174,7 @@ struct DeviceWrapperNode : public Object { /*! \brief Wrapper for `Device`. */ class DeviceWrapper : public ObjectRef { public: - explicit DeviceWrapper(Device dev) { data_ = make_object(dev); } + explicit DeviceWrapper(Device dev) { data_ = ffi::make_object(dev); } TVM_DEFINE_MUTABLE_OBJECT_REF_METHODS(DeviceWrapper, ObjectRef, DeviceWrapperNode); }; @@ -189,7 +189,7 @@ class ReportNode : public Object { * and "Duration (us)". Values are one of `String`, `PercentNode`, * `DurationNode`, or `CountNode`. */ - Array> calls; + ffi::Array> calls; /*! \brief Metrics collected for the entire run of the model on a per-device basis. * * `device_metrics` is indexed by device name then metric. @@ -197,17 +197,17 @@ class ReportNode : public Object { * These metrics may be larger than the sum of the same metric in `calls` * because these metrics include the overhead of the executor. */ - Map> device_metrics; + ffi::Map> device_metrics; /*! Configuration used for this profiling run. 
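Report consumption is unchanged apart from the container spellings; for example:

    #include <tvm/runtime/logging.h>
    #include <tvm/runtime/profiling.h>
    using namespace tvm;
    using namespace tvm::runtime;

    void PrintReport(const profiling::Report& report) {
      // `calls` holds per-op rows; `device_metrics` aggregates per device.
      LOG(INFO) << "recorded calls: " << report->calls.size();
      LOG(INFO) << report->AsTable();
    }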
Includes number of threads, executor. * * Values must be an object type that can be used with device_metrics. */ - Map configuration; + ffi::Map configuration; /*! \brief Output `calls` in CSV format. * * Note that this does not include `device_metrics`, it only includes per-call metrics. */ - String AsCSV() const; + ffi::String AsCSV() const; /*! \brief Create a human readable table of profiling metrics. * * \param aggregate Whether or not to join multiple calls to the @@ -222,7 +222,7 @@ class ReportNode : public Object { * the Count, Duation, and Percent columns. * */ - String AsTable(bool sort = true, bool aggregate = true, bool compute_col_sums = true) const; + ffi::String AsTable(bool sort = true, bool aggregate = true, bool compute_col_sums = true) const; /*! \brief Convert this report to JSON. * * Output JSON will be of this format: @@ -255,7 +255,7 @@ class ReportNode : public Object { * } * \endcode */ - String AsJSON() const; + ffi::String AsJSON() const; static constexpr const char* _type_key = "runtime.profiling.Report"; TVM_DECLARE_FINAL_OBJECT_INFO(ReportNode, Object); @@ -268,15 +268,15 @@ class Report : public ObjectRef { * \param device_metrics Per-device metrics for overall execution. * \param configuration Configuration data specific to this profiling run. */ - explicit Report(Array> calls, - Map> device_metrics, - Map configuration); + explicit Report(ffi::Array> calls, + ffi::Map> device_metrics, + ffi::Map configuration); /*! Deserialize a Report from a JSON object. Needed for sending the report over RPC. * \param json Serialized json report from `ReportNode::AsJSON`. * \returns A Report. */ - static Report FromJSON(String json); + static Report FromJSON(ffi::String json); TVM_DEFINE_NOTNULLABLE_OBJECT_REF_METHODS(Report, ObjectRef, ReportNode); }; @@ -304,7 +304,7 @@ class MetricCollectorNode : public Object { * expensive precomputation should happen here. * \param devs The list of devices this collector will be run on. */ - virtual void Init(Array devs) = 0; + virtual void Init(ffi::Array devs) = 0; /*! \brief Start colling metrics for a function call. * \param dev The device the call will be run on. * \returns An object used to maintain state of the metric collection. This @@ -317,7 +317,7 @@ class MetricCollectorNode : public Object { * \returns A set of metric names and the associated values. Values must be * one of DurationNode, PercentNode, CountNode, or String. */ - virtual Map Stop(ffi::ObjectRef obj) = 0; + virtual ffi::Map Stop(ffi::ObjectRef obj) = 0; virtual ~MetricCollectorNode() {} @@ -336,7 +336,7 @@ struct CallFrame { /*! Device on which the call was made */ Device dev; /*! Name of the function or op */ - String name; + ffi::String name; /*! Runtime of the function or op */ Timer timer; /*! Extra performance metrics */ @@ -382,7 +382,7 @@ class Profiler { * \param configuration Additional configuration data to add to the outputted profiling report. */ explicit Profiler(std::vector devs, std::vector metric_collectors, - std::unordered_map configuration = {}); + std::unordered_map configuration = {}); /*! \brief Start the profiler. * * This function should only be called once per object. @@ -403,7 +403,7 @@ class Profiler { * `StopCall`. Function calls are stopped in LIFO order, so calls to * `StartCall` and `StopCall` must be nested properly. */ - void StartCall(String name, Device dev, + void StartCall(ffi::String name, Device dev, std::unordered_map extra_metrics = {}); /*! \brief Stop the last `StartCall`. 
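A minimal profiling session against these signatures; the device, op name, and configuration key are hypothetical, and the configuration map is read here as `{ffi::String -> ffi::Any}`:

    #include <tvm/runtime/profiling.h>
    using namespace tvm;
    using namespace tvm::runtime;

    void ProfileOnce() {
      Device dev{kDLCPU, 0};
      profiling::Profiler prof({dev}, /*metric_collectors=*/{},
                               {{"executor", ffi::String("example")}});
      prof.Start();
      prof.StartCall("my_op", dev);  // calls must nest LIFO
      // ... launch the operator being measured ...
      prof.StopCall();
      prof.Stop();
    }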
* \param extra_metrics Optional additional profiling information to add to @@ -427,7 +427,7 @@ class Profiler { std::vector calls_; std::stack in_flight_; std::vector collectors_; - std::unordered_map configuration_; + std::unordered_map configuration_; }; /* \brief A duration in time. */ @@ -490,23 +490,23 @@ class RatioNode : public Object { TVM_DECLARE_FINAL_OBJECT_INFO(RatioNode, Object); }; -/*! \brief String representation of an array of Tensor shapes +/*! \brief ffi::String representation of an array of Tensor shapes * \param shapes Array of Tensors to get the shapes of. * \return A textual representation of the shapes. For example: `float32[2], int64[1, 2]`. */ -String ShapeString(const std::vector& shapes); -/*! \brief String representation of shape encoded as an Tensor +ffi::String ShapeString(const std::vector& shapes); +/*! \brief ffi::String representation of shape encoded as an Tensor * \param shape Tensor containing the shape. * \param dtype The dtype of the shape. * \return A textual representation of the shape. For example: `float32[2]`. */ -String ShapeString(Tensor shape, DLDataType dtype); -/*! \brief String representation of a shape encoded as a vector +ffi::String ShapeString(Tensor shape, DLDataType dtype); +/*! \brief ffi::String representation of a shape encoded as a vector * \param shape Shape as a vector of integers. * \param dtype The dtype of the shape. * \return A textual representation of the shape. For example: `float32[2]`. */ -String ShapeString(const std::vector& shape, DLDataType dtype); +ffi::String ShapeString(const std::vector& shape, DLDataType dtype); /*! \brief Collect performance information of a function execution. Usually * used with a compiled PrimFunc (via tvm.compile). @@ -536,11 +536,12 @@ String ShapeString(const std::vector& shape, DLDataType dtype); * \param collectors List of different * ways to collect metrics. See MetricCollector. * \returns A ffi::Function which takes the same arguments as the `mod[func_name]` - * and returns performance metrics as a `Map` where + * and returns performance metrics as a `ffi::Map` where * values can be `CountNode`, `DurationNode`, `PercentNode`. */ ffi::Function ProfileFunction(ffi::Module mod, std::string func_name, int device_type, - int device_id, int warmup_iters, Array collectors); + int device_id, int warmup_iters, + ffi::Array collectors); /*! * \brief Wrap a timer function to measure the time cost of a given packed function. diff --git a/include/tvm/runtime/tensor.h b/include/tvm/runtime/tensor.h index 9536dd2005c5..71f8d27be008 100644 --- a/include/tvm/runtime/tensor.h +++ b/include/tvm/runtime/tensor.h @@ -112,7 +112,8 @@ class Tensor : public tvm::ffi::Tensor { * \return The array under another device. * \note The copy always triggers a TVMSynchronize. */ - TVM_DLL Tensor CopyTo(const Device& dev, Optional mem_scope = std::nullopt) const; + TVM_DLL Tensor CopyTo(const Device& dev, + ffi::Optional mem_scope = std::nullopt) const; /*! * \brief Load Tensor from stream * \param stream The input data stream @@ -156,7 +157,7 @@ class Tensor : public tvm::ffi::Tensor { * \return The created Array */ TVM_DLL static Tensor Empty(ffi::Shape shape, DLDataType dtype, Device dev, - Optional mem_scope = std::nullopt); + ffi::Optional mem_scope = std::nullopt); /*! * \brief Function to copy data from one array to another. * \param from The source array. 
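`Tensor::Empty` and `CopyTo` keep their shape arguments but now take `ffi::Optional<ffi::String>` memory scopes; a sketch:

    #include <tvm/runtime/tensor.h>
    using namespace tvm;
    using namespace tvm::runtime;

    void MakeAndCopy() {
      Device cpu{kDLCPU, 0};
      Tensor host = Tensor::Empty(ffi::Shape{2, 3}, DataType::Float(32), cpu);
      Device cuda{kDLCUDA, 0};
      Tensor on_gpu = host.CopyTo(cuda);  // mem_scope defaults to std::nullopt
    }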
diff --git a/include/tvm/runtime/vm/executable.h b/include/tvm/runtime/vm/executable.h index 6dfc2b0c50be..37488ff31f52 100644 --- a/include/tvm/runtime/vm/executable.h +++ b/include/tvm/runtime/vm/executable.h @@ -113,12 +113,12 @@ class VMExecutable : public ffi::ModuleObj { * \brief Print the instructions as text format. * \return The text format of the instructions. */ - String AsText() const; + ffi::String AsText() const; /*! * \brief Print the instructions as python program. * \return The python program of the instructions, represented by a string. */ - String AsPython() const; + ffi::String AsPython() const; /*! * \brief Write the VMExecutable to the binary stream in serialized form. * \return The binary bytes that save the executable to. @@ -135,19 +135,19 @@ class VMExecutable : public ffi::ModuleObj { * \param file_name The name of the file to write the serialized data to. * \param format The target format of the saved file. */ - void WriteToFile(const String& file_name, const String& format) const final; + void WriteToFile(const ffi::String& file_name, const ffi::String& format) const final; /*! \brief Create a Relax virtual machine and load `this` as the executable. */ ffi::Module VMLoadExecutable() const; /*! \brief Create a Relax virtual machine with profiler and load `this` as the executable. */ ffi::Module VMProfilerLoadExecutable() const; /*! \brief Check if the VMExecutable contains a specific function. */ - bool HasFunction(const String& name) const; + bool HasFunction(const ffi::String& name) const; /*! * \brief Load VMExecutable from the file. * \param file_name The path of the file that load the executable from. * \return The loaded executable, in the form of a `runtime::Module`. */ - static ffi::Module LoadFromFile(const String& file_name); + static ffi::Module LoadFromFile(const ffi::String& file_name); /*! \brief The virtual machine's function table. */ std::vector func_table; diff --git a/include/tvm/runtime/vm/tensor_cache_support.h b/include/tvm/runtime/vm/tensor_cache_support.h index d2112cc83f4e..c489064792e7 100644 --- a/include/tvm/runtime/vm/tensor_cache_support.h +++ b/include/tvm/runtime/vm/tensor_cache_support.h @@ -47,7 +47,7 @@ struct TensorCacheMetadata { * in other cases */ TVM_DLL Tensor Load(Device device, const std::string* raw_data, - Optional* staging_buffer = nullptr) const; + ffi::Optional* staging_buffer = nullptr) const; /*! \brief Name of the parameter */ std::string name; @@ -64,10 +64,10 @@ struct TensorCacheMetadata { }; /*! \brief Load a FileRecord into memory */ - TVM_DLL Array Load(Device device, // - const std::string& path_prefix, // - std::string* raw_data_buffer, // - Optional* staging_buffer = nullptr) const; + TVM_DLL ffi::Array Load(Device device, // + const std::string& path_prefix, // + std::string* raw_data_buffer, // + ffi::Optional* staging_buffer = nullptr) const; /*! \brief Relative path to the bin file */ std::string data_path; diff --git a/include/tvm/runtime/vm/vm.h b/include/tvm/runtime/vm/vm.h index 3a0b7418b946..9fa894f61367 100644 --- a/include/tvm/runtime/vm/vm.h +++ b/include/tvm/runtime/vm/vm.h @@ -68,7 +68,7 @@ class VMClosureObj : public Object { * \brief The function name. The function could be any * function object that is compatible to the VM runtime. */ - String func_name; + ffi::String func_name; /*! * \brief The implementation of the Closure. @@ -85,7 +85,7 @@ class VMClosureObj : public Object { /*! \brief reference to closure. 
*/ class VMClosure : public ObjectRef { public: - VMClosure(String func_name, ffi::Function impl); + VMClosure(ffi::String func_name, ffi::Function impl); TVM_DEFINE_OBJECT_REF_METHODS(VMClosure, ObjectRef, VMClosureObj); /*! @@ -149,7 +149,7 @@ class VirtualMachine : public ffi::ModuleObj { * \param func_name The name of the function. * \return The closure */ - virtual VMClosure GetClosure(const String& func_name) = 0; + virtual VMClosure GetClosure(const ffi::String& func_name) = 0; /*! * \brief Invoke closure or packed function using ffi::Function convention. * \param closure_or_packedfunc A VM closure or a packed_func. diff --git a/include/tvm/script/ir_builder/base.h b/include/tvm/script/ir_builder/base.h index 0c9e54eaf113..b2586e938719 100644 --- a/include/tvm/script/ir_builder/base.h +++ b/include/tvm/script/ir_builder/base.h @@ -157,9 +157,9 @@ class IRBuilderFrame : public runtime::ObjectRef { class IRBuilderNode : public runtime::Object { public: /*! \brief A stack of context frames in the IRBuilder */ - Array frames; + ffi::Array frames; /*! \brief The outcome of IR construction */ - Optional result; + ffi::Optional result; static void RegisterReflection() { namespace refl = tvm::ffi::reflection; @@ -178,7 +178,7 @@ class IRBuilderNode : public runtime::Object { * \return The frame if found, otherwise std::nullopt. */ template - inline Optional FindFrame() const; + inline ffi::Optional FindFrame() const; /*! * \brief Get the frame on top of the stack `this->frames` if its type is `TFrame`. * \tparam TFrame The assumed type of the last frame on stack. @@ -186,7 +186,7 @@ class IRBuilderNode : public runtime::Object { * Otherwise std::nullopt. */ template - inline Optional GetLastFrame() const; + inline ffi::Optional GetLastFrame() const; /*! * \brief Get the IR being constructed. * \tparam TObjectRef The type of the IR being constructed. @@ -249,7 +249,7 @@ class IRBuilder : public runtime::ObjectRef { * \param obj The object to name. 
*/ template - inline static TObjectRef Name(String name, TObjectRef obj); + inline static TObjectRef Name(ffi::String name, TObjectRef obj); }; ////////////////////////////// Details ////////////////////////////// @@ -258,32 +258,32 @@ namespace details { class Namer { public: - using FType = NodeFunctor; + using FType = NodeFunctor; static FType& vtable(); - static void Name(ObjectRef node, String name); + static void Name(ObjectRef node, ffi::String name); }; } // namespace details template -inline TObjectRef IRBuilder::Name(String name, TObjectRef obj) { +inline TObjectRef IRBuilder::Name(ffi::String name, TObjectRef obj) { details::Namer::Name(obj, name); return Downcast(obj); } template -inline Optional IRBuilderNode::FindFrame() const { +inline ffi::Optional IRBuilderNode::FindFrame() const { using TFrameNode = typename TFrame::ContainerType; for (auto it = frames.rbegin(); it != frames.rend(); ++it) { if (const TFrameNode* p = (*it).template as()) { - return GetRef(p); + return ffi::GetRef(p); } } return std::nullopt; } template -inline Optional IRBuilderNode::GetLastFrame() const { +inline ffi::Optional IRBuilderNode::GetLastFrame() const { using TFrameNode = typename TFrame::ContainerType; if (!frames.empty() && frames.back()->IsInstance()) { return Downcast(frames.back()); @@ -297,7 +297,7 @@ inline TObjectRef IRBuilderNode::Get() const { CHECK(result.defined()) << "IndexError: No result exists in IRBuilder yet"; const auto* n = result.as(); CHECK(n != nullptr) << "TypeError: IRBuilder result is not of type: " << TObject::_type_key; - return GetRef(n); + return ffi::GetRef(n); } } // namespace ir_builder diff --git a/include/tvm/script/ir_builder/ir/frame.h b/include/tvm/script/ir_builder/ir/frame.h index b009338cf0d4..e9f98d4a8ea6 100644 --- a/include/tvm/script/ir_builder/ir/frame.h +++ b/include/tvm/script/ir_builder/ir/frame.h @@ -41,16 +41,16 @@ namespace ir { class IRModuleFrameNode : public IRBuilderFrameNode { public: /*! \brief A map from string names to global variables that ensures global uniqueness. */ - Map global_var_map; + ffi::Map global_var_map; /*! * \brief A map from GlobalVar to all global functions. * \note Only defined functions are in the map, while declared functions are not included. */ - Map functions; + ffi::Map functions; /*! \brief IRModule's attributes. */ - Map attrs; + ffi::Map attrs; /*! \brief IRModule's global_infos */ - Map> global_infos; + ffi::Map> global_infos; static void RegisterReflection() { namespace refl = tvm::ffi::reflection; diff --git a/include/tvm/script/ir_builder/ir/ir.h b/include/tvm/script/ir_builder/ir/ir.h index 49bdcf60e6fb..9fe3d7e1ac65 100644 --- a/include/tvm/script/ir_builder/ir/ir.h +++ b/include/tvm/script/ir_builder/ir/ir.h @@ -45,14 +45,14 @@ TVM_DLL IRModuleFrame IRModule(); * (i.e. func params and func return type/shape). * \return The corresponding GlobalVar. */ -TVM_DLL GlobalVar DeclFunction(const String& func_name, const BaseFunc& func_signature); +TVM_DLL GlobalVar DeclFunction(const ffi::String& func_name, const BaseFunc& func_signature); /*! * \brief Define the function which is declared before. * \param func_name The function unique name. 
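Because `FindFrame` and `GetLastFrame` now return `ffi::Optional`, callers test the result explicitly before use. A minimal sketch against the signatures above (assuming `IRBuilder::Current()` as the usual accessor):

    using tvm::script::ir_builder::IRBuilder;
    using tvm::script::ir_builder::ir::IRModuleFrame;

    void UseEnclosingModuleFrame() {
      IRBuilder builder = IRBuilder::Current();
      tvm::ffi::Optional<IRModuleFrame> frame = builder->FindFrame<IRModuleFrame>();
      if (frame.has_value()) {
        IRModuleFrame mod_frame = frame.value();  // safe: presence checked above
      }
    }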
* \param func The given function implementation */ -TVM_DLL void DefFunction(const String& func_name, const BaseFunc& func); +TVM_DLL void DefFunction(const ffi::String& func_name, const BaseFunc& func); } // namespace ir } // namespace ir_builder diff --git a/include/tvm/script/ir_builder/relax/frame.h b/include/tvm/script/ir_builder/relax/frame.h index f729d19a14dd..053f84285f6e 100644 --- a/include/tvm/script/ir_builder/relax/frame.h +++ b/include/tvm/script/ir_builder/relax/frame.h @@ -57,9 +57,9 @@ class RelaxFrame : public IRBuilderFrame { class SeqExprFrameNode : public RelaxFrameNode { public: /*! \brief The binding blocks inside the frame. */ - Array binding_blocks; + ffi::Array binding_blocks; /*! \brief The frame output expr. `std::nullopt` when undefined. */ - Optional output; + ffi::Optional output; static void RegisterReflection() { namespace refl = tvm::ffi::reflection; @@ -89,9 +89,9 @@ class FunctionFrameNode : public SeqExprFrameNode { * \note The name will not be specified in the constructor, so it is "Optional"; * however, we must specify the name via `R.func_name` before exiting this frame. */ - Optional name; + ffi::Optional name; /*! \brief The function params. */ - Array params; + ffi::Array params; /*! * \brief The function return struct info. * \note Usually the function return type can be deduced by the function body. @@ -101,13 +101,13 @@ class FunctionFrameNode : public SeqExprFrameNode { * if the ret_struct_info is a base of body.struct_info. If not, we will * take the specified `ret_struct_info`. */ - Optional ret_struct_info; + ffi::Optional ret_struct_info; /*! \brief Whether the function is annotated as pure */ - Optional is_pure; + ffi::Optional is_pure; /*! \brief Whether the function is annotated as private */ - Optional is_private; + ffi::Optional is_private; /*! \brief The function attributes. */ - Map attrs; + ffi::Map attrs; /*! \brief The block builder to create Relax function. */ tvm::relax::BlockBuilder block_builder; @@ -143,7 +143,7 @@ class BlockFrameNode : public RelaxFrameNode { /*! \brief The flag that indicates whether the block is a dataflow block. */ bool is_dataflow; /*! \brief The variables emitted in this block. */ - Array emitted_vars; + ffi::Array emitted_vars; /*! * \brief A boolean indicating whether construction of the dataflow block has ended. * If it is true, any new binding trying to be emitted into this block will cause an error. @@ -154,7 +154,7 @@ class BlockFrameNode : public RelaxFrameNode { * \brief The output vars of the dataflow block. * \note Only used for a dataflow block. */ - Array output_vars; + ffi::Array output_vars; static void RegisterReflection() { namespace refl = tvm::ffi::reflection; @@ -188,13 +188,13 @@ class IfFrameNode : public RelaxFrameNode { /*! \brief The condition of the if statement. */ tvm::relax::Expr condition; /*! \brief The Bindings in the true branch. */ - Optional then_expr; + ffi::Optional then_expr; /*! \brief The Bindings in the false branch. */ - Optional else_expr; + ffi::Optional else_expr; /*! \brief The Binding var. */ tvm::relax::Var var; /*! \brief The binding var name.
*/ - String var_name; + ffi::String var_name; static void RegisterReflection() { namespace refl = tvm::ffi::reflection; diff --git a/include/tvm/script/ir_builder/relax/ir.h b/include/tvm/script/ir_builder/relax/ir.h index 49bc1a2851d3..80b70daffd0b 100644 --- a/include/tvm/script/ir_builder/relax/ir.h +++ b/include/tvm/script/ir_builder/relax/ir.h @@ -45,19 +45,19 @@ TVM_DLL FunctionFrame Function(const Bool& is_pure, const Bool& is_private); * \param struct_info The struct_info of the parameter. * \return The created function parameter var. */ -TVM_DLL tvm::relax::Var Arg(const String& name, const tvm::relax::StructInfo& struct_info); +TVM_DLL tvm::relax::Var Arg(const ffi::String& name, const tvm::relax::StructInfo& struct_info); /*! * \brief Specify the name of the last function frame. * \param name The function name. */ -TVM_DLL void FuncName(const String& name); +TVM_DLL void FuncName(const ffi::String& name); /*! * \brief Specify the attrs of the last function frame. * \param attrs The function attrs. */ -TVM_DLL void FuncAttrs(Map attrs); +TVM_DLL void FuncAttrs(ffi::Map attrs); /*! * \brief Specify the return struct info of the last function frame. @@ -89,7 +89,7 @@ TVM_DLL BlockFrame Dataflow(); * \brief Expose the dataflow block output variables as global ones * \param vars The output variables of a dataflow block */ -TVM_DLL void DataflowBlockOutput(const Array& vars); +TVM_DLL void DataflowBlockOutput(const ffi::Array& vars); ////////////////////////////// Bindings //////////////////////////////// @@ -101,7 +101,7 @@ TVM_DLL void DataflowBlockOutput(const Array& vars); */ TVM_DLL tvm::relax::Var Emit( const tvm::relax::Expr& value, - const Optional& annotate_struct_info = std::nullopt); + const ffi::Optional& annotate_struct_info = std::nullopt); /*! * \brief Emit a match_cast binding to the last binding block frame. diff --git a/include/tvm/script/ir_builder/tir/frame.h b/include/tvm/script/ir_builder/tir/frame.h index 52173a8d8a4f..1c3e19959024 100644 --- a/include/tvm/script/ir_builder/tir/frame.h +++ b/include/tvm/script/ir_builder/tir/frame.h @@ -36,7 +36,7 @@ namespace tir { class TIRFrameNode : public IRBuilderFrameNode { public: /*! \brief The Stmt within this frame. */ - Array stmts; + ffi::Array stmts; static void RegisterReflection() { namespace refl = tvm::ffi::reflection; @@ -68,21 +68,21 @@ class TIRFrame : public IRBuilderFrame { class PrimFuncFrameNode : public TIRFrameNode { public: /*! \brief The name of the block. */ - Optional name; + ffi::Optional name; /*! \brief Function parameters. */ - Array args; + ffi::Array args; /*! \brief Whether the PrimFunc is annotated as private. */ bool is_private; /*! \brief The return type of the function. */ - Optional ret_type; + ffi::Optional ret_type; /*! \brief Maps some parameters to specific Buffer data structures. */ - Map buffer_map; + ffi::Map buffer_map; /*! \brief Additional attributes storing the meta-data */ - Map attrs; + ffi::Map attrs; /*! \brief The variable map bound to thread env. */ - Map env_threads; + ffi::Map env_threads; /*! \brief The buffer allocated in root block. */ - Array root_alloc_buffers; + ffi::Array root_alloc_buffers; static void RegisterReflection() { namespace refl = tvm::ffi::reflection; @@ -126,28 +126,28 @@ class PrimFuncFrame : public TIRFrame { class BlockFrameNode : public TIRFrameNode { public: /*! \brief The name of the block. */ - String name; + ffi::String name; /*! \brief The variables of the block. */ - Array iter_vars; + ffi::Array iter_vars; /*!
\brief The read buffer regions of the block. */ - Optional> reads; + ffi::Optional> reads; /*! \brief The write buffer regions of the block. */ - Optional> writes; + ffi::Optional> writes; /*! \brief The init statement of the block. */ - Optional init; + ffi::Optional init; /*! \brief The buffer allocated in the block. */ - Array alloc_buffers; + ffi::Array alloc_buffers; /*! \brief The match buffer regions. */ - Array match_buffers; + ffi::Array match_buffers; /*! \brief The annotation of the block. */ - Optional> annotations; + ffi::Optional> annotations; /*! \brief The corresponding values of the iter vars. */ - Array iter_values; + ffi::Array iter_values; /*! * \brief The predicate of the block realization, the block will only be executed when the * predicate is true. */ - Optional predicate; + ffi::Optional predicate; /*! \brief The flag whether to construct BlockRealize or Block. */ bool no_realize; @@ -241,12 +241,13 @@ class ForFrameNode : public TIRFrameNode { * \param loop_body The loop body * \return A stmt, the loop nest */ - using FMakeForLoop = ffi::TypedFunction loop_vars, Array loop_extents, tvm::tir::Stmt loop_body)>; + using FMakeForLoop = + ffi::TypedFunction loop_vars, + ffi::Array loop_extents, tvm::tir::Stmt loop_body)>; /*! \brief The loop variable. */ - Array vars; + ffi::Array vars; /*! \brief The domains of iteration. */ - Array doms; + ffi::Array doms; /*! \brief The for loop generating function. */ FMakeForLoop f_make_for_loop; @@ -369,7 +370,7 @@ class LaunchThreadFrameNode : public TIRFrameNode { /*! \brief The extent of environment thread. */ PrimExpr extent; /*! \brief The attribute key, could be either virtual_thread or thread_extent. */ - String attr_key; + ffi::String attr_key; /*! \brief The iteration variable. */ tvm::tir::IterVar iter_var; @@ -413,7 +414,7 @@ class RealizeFrameNode : public TIRFrameNode { /*! \brief The region of buffer access. */ tvm::tir::BufferRegion buffer_slice; /*! \brief The storage scope associated with this realization. */ - String storage_scope; + ffi::String storage_scope; /*! \brief The condition expression. */ PrimExpr condition; @@ -454,15 +455,15 @@ class RealizeFrame : public TIRFrame { class AllocateFrameNode : public TIRFrameNode { public: /*! \brief The extents of the allocate. */ - Array extents; + ffi::Array extents; /*! \brief The data type of the buffer. */ DataType dtype; /*! \brief The storage scope. */ - String storage_scope; + ffi::String storage_scope; /*! \brief The condition. */ PrimExpr condition; /*! \brief Additional annotation hints. */ - Map annotations; + ffi::Map annotations; /*! \brief The buffer var. */ tvm::tir::Var buffer_var; @@ -508,13 +509,13 @@ class AllocateConstFrameNode : public TIRFrameNode { /*! \brief The data type of the buffer. */ DataType dtype; /*! \brief The extents of the allocate. */ - Array extents; + ffi::Array extents; /*! \brief The data associated with the constant. */ tvm::runtime::Tensor data; /*! \brief The buffer var */ tvm::tir::Var buffer_var; /*! \brief Additional annotations about the allocation. */ - Map annotations; + ffi::Map annotations; static void RegisterReflection() { namespace refl = tvm::ffi::reflection; @@ -557,7 +558,7 @@ class AttrFrameNode : public TIRFrameNode { /*! \brief The node to annotate the attribute. */ Any node; /*! \brief Attribute type key. */ - String attr_key; + ffi::String attr_key; /*! \brief The value of the attribute. */ PrimExpr value; @@ -636,9 +637,9 @@ class IfFrameNode : public TIRFrameNode { /*!
\brief The condition of the if statement. */ PrimExpr condition; /*! \brief The statements in the true branch. */ - Optional> then_stmts; + ffi::Optional> then_stmts; /*! \brief The statements in the false branch. */ - Optional> else_stmts; + ffi::Optional> else_stmts; static void RegisterReflection() { namespace refl = tvm::ffi::reflection; diff --git a/include/tvm/script/ir_builder/tir/ir.h b/include/tvm/script/ir_builder/tir/ir.h index 6894bfa1fb58..24ce8fdf990a 100644 --- a/include/tvm/script/ir_builder/tir/ir.h +++ b/include/tvm/script/ir_builder/tir/ir.h @@ -47,10 +47,11 @@ using tvm::tir::Var; * \param axis_separators The separators between input axes when generating flattened output axes. * \return The declared buffer. */ -Buffer BufferDecl(Array shape, DataType dtype, String buffer_name, Optional data, - Optional> strides, Optional elem_offset, - String storage_scope, int align, int offset_factor, String buffer_type, - Optional> axis_separators); +Buffer BufferDecl(ffi::Array shape, DataType dtype, ffi::String buffer_name, + ffi::Optional data, ffi::Optional> strides, + ffi::Optional elem_offset, ffi::String storage_scope, int align, + int offset_factor, ffi::String buffer_type, + ffi::Optional> axis_separators); /*! * \brief The primitive function statement. @@ -64,7 +65,7 @@ PrimFuncFrame PrimFunc(bool is_private); * \param var The variable argument. * \return The variable. */ -Var Arg(String name, Var var); +Var Arg(ffi::String name, Var var); /*! * \brief The PrimFunc buffer arguments adding function. @@ -72,19 +73,19 @@ Var Arg(String name, Var var); * \param buffer The buffer argument. * \return The buffer. */ -Buffer Arg(String name, Buffer buffer); +Buffer Arg(ffi::String name, Buffer buffer); /*! * \brief The PrimFunc naming statement. * \param name The name of the PrimFunc. */ -void FuncName(String name); +void FuncName(ffi::String name); /*! * \brief The PrimFunc annotation statement. * \param attrs The annotations of the PrimFunc. */ -void FuncAttrs(Map attrs); +void FuncAttrs(ffi::Map attrs); /*! * \brief The PrimFunc return type statement. @@ -108,11 +109,12 @@ Type FuncRet(Type ret_type); * \param axis_separators The separators between input axes when generating flattened output axes. * \return The matched buffer. */ -Buffer MatchBuffer(ObjectRef param, Array shape, DataType dtype = DataType::Float(32), - Optional data = std::nullopt, Array strides = {}, - PrimExpr elem_offset = PrimExpr(), String storage_scope = "global", - int align = -1, int offset_factor = 0, String buffer_type = "default", - Optional> axis_separators = std::nullopt); +Buffer MatchBuffer(ObjectRef param, ffi::Array shape, + DataType dtype = DataType::Float(32), ffi::Optional data = std::nullopt, + ffi::Array strides = {}, PrimExpr elem_offset = PrimExpr(), + ffi::String storage_scope = "global", int align = -1, int offset_factor = 0, + ffi::String buffer_type = "default", + ffi::Optional> axis_separators = std::nullopt); /*! * \brief The block declaration statement. @@ -120,7 +122,7 @@ Buffer MatchBuffer(ObjectRef param, Array shape, DataType dtype = Data * \param no_realize The flag whether to construct BlockRealize or Block. * \return The BlockFrame. */ -BlockFrame Block(String name, bool no_realize = false); +BlockFrame Block(ffi::String name, bool no_realize = false); /*! * \brief The block initialization statement. @@ -138,19 +140,19 @@ void Where(PrimExpr predicate); * \brief The block buffer region reading statement. * \param buffer_slices The array of buffer regions to read.
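The braced-list and `std::nullopt` conventions carry over unchanged: an initializer list builds an `ffi::Array` argument and `std::nullopt` fills each `ffi::Optional`. A sketch of `BufferDecl` under the declaration above (argument values are illustrative; the stripped template parameters are assumed to be the upstream `PrimExpr`/`Var` element types):

    namespace tir_ib = tvm::script::ir_builder::tir;

    tvm::tir::Buffer DeclareA() {
      // {128, 128} becomes an ffi::Array of PrimExpr; "A" becomes an ffi::String.
      return tir_ib::BufferDecl({128, 128}, tvm::DataType::Float(32), "A",
                                /*data=*/std::nullopt, /*strides=*/std::nullopt,
                                /*elem_offset=*/std::nullopt, /*storage_scope=*/"global",
                                /*align=*/64, /*offset_factor=*/1,
                                /*buffer_type=*/"default", /*axis_separators=*/std::nullopt);
    }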
*/ -void Reads(Array buffer_slices); +void Reads(ffi::Array buffer_slices); /*! * \brief The block buffer region writing statement. * \param buffer_slices The array of buffer regions to write. */ -void Writes(Array buffer_slices); +void Writes(ffi::Array buffer_slices); /*! * \brief The block annotation statement. * \param attrs The annotation of the block. */ -void BlockAttrs(Map attrs); +void BlockAttrs(ffi::Map attrs); /*! * \brief The buffer allocation function. @@ -166,11 +168,11 @@ void BlockAttrs(Map attrs); * \param axis_separators The separators between input axes when generating flattened output axes. * \return The allocated buffer. */ -Buffer AllocBuffer(Array shape, DataType dtype = DataType::Float(32), - Optional data = std::nullopt, Array strides = {}, - PrimExpr elem_offset = PrimExpr(), String storage_scope = "", int align = -1, - int offset_factor = 0, String buffer_type = "default", - Optional> axis_separators = std::nullopt); +Buffer AllocBuffer(ffi::Array shape, DataType dtype = DataType::Float(32), + ffi::Optional data = std::nullopt, ffi::Array strides = {}, + PrimExpr elem_offset = PrimExpr(), ffi::String storage_scope = "", + int align = -1, int offset_factor = 0, ffi::String buffer_type = "default", + ffi::Optional> axis_separators = std::nullopt); namespace axis { /*! @@ -216,7 +218,8 @@ Var Opaque(Range dom, PrimExpr binding, DataType dtype = DataType::Int(32)); * \param dtype The data types of the iteration variables. * \return The iteration variables. */ -Array Remap(String kinds, Array bindings, DataType dtype = DataType::Int(32)); +ffi::Array Remap(ffi::String kinds, ffi::Array bindings, + DataType dtype = DataType::Int(32)); } // namespace axis @@ -228,7 +231,7 @@ Array Remap(String kinds, Array bindings, DataType dtype = DataTy * \return The ForFrame. */ ForFrame Serial(PrimExpr start, PrimExpr stop, - Optional> annotations = std::nullopt); + ffi::Optional> annotations = std::nullopt); /*! * \brief The parallel For statement. * \param start The minimum value of iteration. @@ -237,7 +240,7 @@ ForFrame Serial(PrimExpr start, PrimExpr stop, * \return The ForFrame. */ ForFrame Parallel(PrimExpr start, PrimExpr stop, - Optional> annotations = std::nullopt); + ffi::Optional> annotations = std::nullopt); /*! * \brief The vectorized For statement. * \param start The minimum value of iteration. @@ -246,7 +249,7 @@ ForFrame Parallel(PrimExpr start, PrimExpr stop, * \return The ForFrame. */ ForFrame Vectorized(PrimExpr start, PrimExpr stop, - Optional> annotations = std::nullopt); + ffi::Optional> annotations = std::nullopt); /*! * \brief The unrolled For statement. * \param start The minimum value of iteration. @@ -255,7 +258,7 @@ ForFrame Vectorized(PrimExpr start, PrimExpr stop, * \return The ForFrame. */ ForFrame Unroll(PrimExpr start, PrimExpr stop, - Optional> annotations = std::nullopt); + ffi::Optional> annotations = std::nullopt); /*! * \brief The thread-binding For statement. * \param start The minimum value of iteration. @@ -264,14 +267,14 @@ ForFrame Unroll(PrimExpr start, PrimExpr stop, * \param annotations The optional annotations of the For statement. * \return The ForFrame. */ -ForFrame ThreadBinding(PrimExpr start, PrimExpr stop, String thread, - Optional> annotations = std::nullopt); +ForFrame ThreadBinding(PrimExpr start, PrimExpr stop, ffi::String thread, + ffi::Optional> annotations = std::nullopt); /*! * \brief The grid For statement. * \param extents The extents of the iteration. * \return The ForFrame. 
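For the loop statements declared here, an `ffi::Map` converts implicitly into the `ffi::Optional` annotation parameter. A sketch (the annotation key and value are illustrative, and the stripped map type is assumed to be `ffi::Map<ffi::String, ffi::Any>`):

    namespace tir_ib = tvm::script::ir_builder::tir;

    tir_ib::ForFrame MakeAnnotatedLoop() {
      tvm::ffi::Map<tvm::ffi::String, tvm::ffi::Any> ann;
      ann.Set("pragma_auto_unroll_max_step", 16);  // int converts to the Any value type
      return tir_ib::Serial(/*start=*/0, /*stop=*/128, ann);
    }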
*/ -ForFrame Grid(Array extents); +ForFrame Grid(ffi::Array extents); /*! * \brief The assertion statement. @@ -279,7 +282,7 @@ ForFrame Grid(Array extents); * \param message The error message when the assertion fails. * \return The AssertFrame. */ -AssertFrame Assert(PrimExpr condition, String message); +AssertFrame Assert(PrimExpr condition, ffi::String message); /*! * \brief The let binding. @@ -290,8 +293,8 @@ AssertFrame Assert(PrimExpr condition, String message); * \param var The variable to be bound. If not specified, a new variable will be created. * \return The created LetFrame. */ -LetFrame LetStmt(PrimExpr value, Optional type_annotation = std::nullopt, - Optional var = std::nullopt); +LetFrame LetStmt(PrimExpr value, ffi::Optional type_annotation = std::nullopt, + ffi::Optional var = std::nullopt); /*! * \brief The realization. @@ -300,7 +303,8 @@ LetFrame LetStmt(PrimExpr value, Optional type_annotation = std::nullopt, * \param condition The condition expression. * \return The result RealizeFrame. */ -RealizeFrame Realize(tvm::tir::BufferRegion buffer_slice, String storage_scope, PrimExpr condition); +RealizeFrame Realize(tvm::tir::BufferRegion buffer_slice, ffi::String storage_scope, + PrimExpr condition); /*! * \brief The allocate node. @@ -311,9 +315,9 @@ RealizeFrame Realize(tvm::tir::BufferRegion buffer_slice, String storage_scope, * \param annotations Additional annotation hints. * \return The created AllocateFrame. */ -AllocateFrame Allocate(Array extents, DataType dtype, String storage_scope = "", - Optional condition = std::nullopt, - Optional> annotations = std::nullopt); +AllocateFrame Allocate(ffi::Array extents, DataType dtype, ffi::String storage_scope = "", + ffi::Optional condition = std::nullopt, + ffi::Optional> annotations = std::nullopt); /*! * \brief The allocate constant node. @@ -323,8 +327,9 @@ AllocateFrame Allocate(Array extents, DataType dtype, String storage_s * \param annotations Additional annotation hints. * \return The created AllocateConstFrame. */ -AllocateConstFrame AllocateConst(Tensor data, DataType dtype, Array extents, - Optional> annotations = std::nullopt); +AllocateConstFrame AllocateConst( + Tensor data, DataType dtype, ffi::Array extents, + ffi::Optional> annotations = std::nullopt); /*! * \brief Create an attribute. @@ -333,7 +338,7 @@ AllocateConstFrame AllocateConst(Tensor data, DataType dtype, Array ex * \param value The value of the attribute. * \return The result AttrFrame. */ -AttrFrame Attr(ffi::Any node, String attr_key, PrimExpr value); +AttrFrame Attr(ffi::Any node, ffi::String attr_key, PrimExpr value); /*! * \brief Create a while loop. @@ -376,11 +381,11 @@ ElseFrame Else(); * \param axis_separators The separators between input axes when generating flattened output axes. * \return The declared buffer. */ -DeclBufferFrame DeclBuffer(Array shape, DataType dtype, String buffer_name, - Optional data, Optional> strides, - Optional elem_offset, String storage_scope, int align, - int offset_factor, String buffer_type, - Optional> axis_separators); +DeclBufferFrame DeclBuffer(ffi::Array shape, DataType dtype, ffi::String buffer_name, + ffi::Optional data, ffi::Optional> strides, + ffi::Optional elem_offset, ffi::String storage_scope, + int align, int offset_factor, ffi::String buffer_type, + ffi::Optional> axis_separators); /*! * \brief Launch a thread. @@ -396,7 +401,7 @@ LaunchThreadFrame LaunchThread(Var var, PrimExpr extent); * \param extent The extent of environment thread. * \return The result LaunchThreadFrame. 
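The same conventions apply to the statements above: a braced list builds the extents array of `Grid`, and a string literal the message of `Assert`. A short sketch (values are illustrative):

    namespace tir_ib = tvm::script::ir_builder::tir;

    void LoopAndGuard() {
      // {2, 4, 8} builds the ffi::Array of loop extents (element type assumed PrimExpr).
      tir_ib::ForFrame loops = tir_ib::Grid({2, 4, 8});
      // "in bounds" builds the ffi::String message.
      tir_ib::AssertFrame guard = tir_ib::Assert(tvm::tir::const_true(), "in bounds");
    }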
*/ -LaunchThreadFrame LaunchThread(String thread_tag, PrimExpr extent); +LaunchThreadFrame LaunchThread(ffi::String thread_tag, PrimExpr extent); /*! * \brief Bind a var to thread env. @@ -404,7 +409,7 @@ LaunchThreadFrame LaunchThread(String thread_tag, PrimExpr extent); * \param dtype The data type of the variable. * \return The result variable which gets bound to the thread env. */ -Var EnvThread(String thread_tag, DataType dtype = DataType::Int(32)); +Var EnvThread(ffi::String thread_tag, DataType dtype = DataType::Int(32)); /*! * \brief Store data in a buffer. @@ -414,8 +419,8 @@ Var EnvThread(String thread_tag, DataType dtype = DataType::Int(32)); * \param predicate A vector mask of boolean values indicating which lanes of a vector are to be * stored. The number of lanes of the mask must be equal to the number of lanes in value. */ -void BufferStore(Buffer buffer, PrimExpr value, Array indices, - Optional predicate); +void BufferStore(Buffer buffer, PrimExpr value, ffi::Array indices, + ffi::Optional predicate); /*! * \brief Evaluate the input expression. @@ -441,7 +446,7 @@ void Evaluate(PrimExpr value); * \return The pointer. */ inline Var Handle(runtime::DataType dtype = runtime::DataType::Void(), - String storage_scope = "global", bool is_size_var = false, + ffi::String storage_scope = "global", bool is_size_var = false, bool is_unknown_type = false) { Type type_annotation{nullptr}; if (is_unknown_type && storage_scope == "global") { @@ -454,12 +459,13 @@ inline Var Handle(runtime::DataType dtype = runtime::DataType::Void(), inline Var TensormapHandle() { return tvm::tir::Var("", PointerType(TensorMapType())); } -#define TVM_TIR_IR_BUILDER_DEF_DTYPE_CAST(FuncName, DType) \ - inline PrimExpr FuncName(Optional expr = std::nullopt, bool is_size_var = false) { \ - DataType dtype = DType; \ - return expr.defined() \ - ? tvm::cast(dtype, expr.value()) \ - : (is_size_var ? tvm::tir::SizeVar("", dtype) : tvm::tir::Var("", dtype)); \ +#define TVM_TIR_IR_BUILDER_DEF_DTYPE_CAST(FuncName, DType) \ + inline PrimExpr FuncName(ffi::Optional expr = std::nullopt, \ + bool is_size_var = false) { \ + DataType dtype = DType; \ + return expr.defined() \ + ? tvm::cast(dtype, expr.value()) \ + : (is_size_var ? tvm::tir::SizeVar("", dtype) : tvm::tir::Var("", dtype)); \ } #define TVM_TIR_IR_BUILDER_DEF_DTYPE_CAST_SIZES(DType, FDType) \ diff --git a/include/tvm/script/printer/doc.h b/include/tvm/script/printer/doc.h index b045ee00315b..976e3183a16e 100644 --- a/include/tvm/script/printer/doc.h +++ b/include/tvm/script/printer/doc.h @@ -42,7 +42,7 @@ class Doc; * \param doc Doc to be converted * \param cfg The configuration of the printer */ -String DocToPythonScript(Doc doc, const PrinterConfig& cfg); +ffi::String DocToPythonScript(Doc doc, const PrinterConfig& cfg); /*! * \brief The base class of all Doc. @@ -64,7 +64,7 @@ class DocNode : public Object { * this Doc is generated, in order to position the diagnostic * message. */ - mutable Array source_paths; + mutable ffi::Array source_paths; static void RegisterReflection() { namespace refl = tvm::ffi::reflection; @@ -106,19 +106,19 @@ class ExprDocNode : public DocNode { * \brief Create a doc representing attribute access on the current ExprDoc * \param attr The attribute to access. */ - ExprDoc Attr(String attr) const; + ExprDoc Attr(ffi::String attr) const; /*! * \brief Create a doc representing index access on the current ExprDoc * \param indices The indices to access.
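With the predicate now an explicit `ffi::Optional`, an unpredicated store simply passes `std::nullopt`. A sketch against the `BufferStore` declaration above (buffer and index variables are illustrative):

    namespace tir_ib = tvm::script::ir_builder::tir;

    void StoreZero(const tvm::tir::Buffer& buf, const tvm::tir::Var& i, const tvm::tir::Var& j) {
      // {i, j} builds the ffi::Array of indices; no vector mask is supplied.
      tir_ib::BufferStore(buf, /*value=*/tvm::PrimExpr(0.0f), {i, j}, /*predicate=*/std::nullopt);
    }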
*/ - ExprDoc operator[](Array indices) const; + ExprDoc operator[](ffi::Array indices) const; /*! * \brief Create a doc representing calling the current ExprDoc * \param args The positional arguments of the function call. */ - ExprDoc Call(Array args) const; + ExprDoc Call(ffi::Array args) const; /*! * \brief Create a doc representing attribute access on the current ExprDoc @@ -126,9 +126,9 @@ class ExprDocNode : public DocNode { * \param kwargs_keys Keys of keywords arguments of the function call. * \param kwargs_values Values of keywords arguments of the function call. */ - ExprDoc Call(Array args, // - Array kwargs_keys, // - Array kwargs_values) const; + ExprDoc Call(ffi::Array args, // + ffi::Array kwargs_keys, // + ffi::Array kwargs_values) const; static void RegisterReflection() { namespace refl = tvm::ffi::reflection; @@ -154,7 +154,7 @@ class ExprDoc : public Doc { * \brief Create a doc representing index access on the current ExprDoc * \param indices The indices to access. */ - ExprDoc operator[](Array indices) const; + ExprDoc operator[](ffi::Array indices) const; TVM_DEFINE_NOTNULLABLE_OBJECT_REF_METHODS(ExprDoc, Doc, ExprDocNode); }; @@ -174,7 +174,7 @@ class StmtDocNode : public DocNode { * line as the statement, or the line above, or inside the statement * if it spans over multiple lines. * */ - mutable Optional comment{std::nullopt}; + mutable ffi::Optional comment{std::nullopt}; static void RegisterReflection() { namespace refl = tvm::ffi::reflection; @@ -208,7 +208,7 @@ class StmtDoc : public Doc { class StmtBlockDocNode : public DocNode { public: /*! \brief The list of statements. */ - Array stmts; + ffi::Array stmts; static void RegisterReflection() { namespace refl = tvm::ffi::reflection; @@ -230,7 +230,7 @@ class StmtBlockDoc : public Doc { * \brief Constructor of StmtBlockDoc. * \param stmts The list of statements. */ - explicit StmtBlockDoc(Array stmts); + explicit StmtBlockDoc(ffi::Array stmts); TVM_DEFINE_NOTNULLABLE_OBJECT_REF_METHODS(StmtBlockDoc, Doc, StmtBlockDocNode); }; @@ -269,20 +269,22 @@ class LiteralDocNode : public ExprDocNode { */ class LiteralDoc : public ExprDoc { protected: - explicit LiteralDoc(ffi::Any value, const Optional& object_path); + explicit LiteralDoc(ffi::Any value, const ffi::Optional& object_path); public: /*! * \brief Create a LiteralDoc to represent None/null/empty value. * \param p The object path */ - static LiteralDoc None(const Optional& p) { return LiteralDoc(ffi::Any(nullptr), p); } + static LiteralDoc None(const ffi::Optional& p) { + return LiteralDoc(ffi::Any(nullptr), p); + } /*! * \brief Create a LiteralDoc to represent integer. * \param v The integer value. * \param p The object path */ - static LiteralDoc Int(int64_t v, const Optional& p) { + static LiteralDoc Int(int64_t v, const ffi::Optional& p) { return LiteralDoc(IntImm(DataType::Int(64), v), p); } /*! @@ -290,7 +292,7 @@ class LiteralDoc : public ExprDoc { * \param v The boolean value. * \param p The object path */ - static LiteralDoc Boolean(bool v, const Optional& p) { + static LiteralDoc Boolean(bool v, const ffi::Optional& p) { return LiteralDoc(IntImm(DataType::Bool(), v), p); } /*! @@ -298,7 +300,7 @@ class LiteralDoc : public ExprDoc { * \param v The float value. * \param p The object path */ - static LiteralDoc Float(double v, const Optional& p) { + static LiteralDoc Float(double v, const ffi::Optional& p) { return LiteralDoc(FloatImm(DataType::Float(64), v), p); } /*! @@ -306,13 +308,15 @@ class LiteralDoc : public ExprDoc { * \param v The string value. 
* \param p The object path */ - static LiteralDoc Str(const String& v, const Optional& p) { return LiteralDoc(v, p); } + static LiteralDoc Str(const ffi::String& v, const ffi::Optional& p) { + return LiteralDoc(v, p); + } /*! * \brief Create a LiteralDoc to represent a data type. * \param v The data type value. * \param p The object path */ - static LiteralDoc DataType(const runtime::DataType& v, const Optional& p) { + static LiteralDoc DataType(const runtime::DataType& v, const ffi::Optional& p) { std::string dtype = v.is_void() ? "void" : runtime::DLDataTypeToString(v); return LiteralDoc::Str(dtype, p); } @@ -321,7 +325,7 @@ * \param v The device. * \param p The object path */ - static LiteralDoc Device(const DLDevice& v, const Optional& p) { + static LiteralDoc Device(const DLDevice& v, const ffi::Optional& p) { std::ostringstream os; runtime::operator<<(os, v); return LiteralDoc::Str(os.str(), p); @@ -338,7 +342,7 @@ class IdDocNode : public ExprDocNode { public: /*! \brief The name of the identifier */ - String name; + ffi::String name; static void RegisterReflection() { namespace refl = tvm::ffi::reflection; @@ -361,7 +365,7 @@ class IdDoc : public ExprDoc { * \brief Constructor of IdDoc. * \param name The name of identifier. */ - explicit IdDoc(String name); + explicit IdDoc(ffi::String name); explicit IdDoc(std::nullptr_t) : ExprDoc(nullptr) {} TVM_DEFINE_NOTNULLABLE_OBJECT_REF_METHODS(IdDoc, ExprDoc, IdDocNode); }; @@ -376,7 +380,7 @@ class AttrAccessDocNode : public ExprDocNode { /*! \brief The target expression to be accessed */ ExprDoc value{nullptr}; /*! \brief The attribute to be accessed */ - String name; + ffi::String name; static void RegisterReflection() { namespace refl = tvm::ffi::reflection; @@ -402,7 +406,7 @@ class AttrAccessDoc : public ExprDoc { * \param value The target expression of attribute access. * \param name The name of attribute to access. */ - explicit AttrAccessDoc(ExprDoc value, String name); + explicit AttrAccessDoc(ExprDoc value, ffi::String name); TVM_DEFINE_NOTNULLABLE_OBJECT_REF_METHODS(AttrAccessDoc, ExprDoc, AttrAccessDocNode); }; @@ -422,7 +426,7 @@ class IndexDocNode : public ExprDocNode { * - ExprDoc (single point access like a[1, 2]) * - SliceDoc (slice access like a[1:5, 2]) */ - Array indices; // Each element is union of: Slice / ExprDoc + ffi::Array indices; // Each element is union of: Slice / ExprDoc static void RegisterReflection() { namespace refl = tvm::ffi::reflection; @@ -448,7 +452,7 @@ class IndexDoc : public ExprDoc { * \param value The target expression of index access. * \param indices The indices to access. */ - explicit IndexDoc(ExprDoc value, Array indices); + explicit IndexDoc(ExprDoc value, ffi::Array indices); TVM_DEFINE_NOTNULLABLE_OBJECT_REF_METHODS(IndexDoc, ExprDoc, IndexDocNode); }; @@ -462,16 +466,16 @@ class CallDocNode : public ExprDocNode { /*! \brief The callee of this function call */ ExprDoc callee{nullptr}; /*! \brief The positional arguments */ - Array args; + ffi::Array args; /*! \brief The keys of keyword arguments */ - Array kwargs_keys; + ffi::Array kwargs_keys; /*! * \brief The values of keyword arguments. * * The i-th element is the value of the i-th key in `kwargs_keys`. * It must have the same length as `kwargs_keys`.
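The literal factories keep their call shape; `std::nullopt` still stands in when no object path is being tracked. A minimal sketch:

    using namespace tvm::script::printer;

    void MakeLiterals() {
      LiteralDoc none = LiteralDoc::None(std::nullopt);
      LiteralDoc answer = LiteralDoc::Int(42, std::nullopt);
      LiteralDoc greeting = LiteralDoc::Str("hello", std::nullopt);  // const char* -> ffi::String
    }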
*/ - Array kwargs_values; + ffi::Array kwargs_values; static void RegisterReflection() { namespace refl = tvm::ffi::reflection; @@ -501,8 +505,8 @@ class CallDoc : public ExprDoc { * \param kwargs_keys Keys of keyword arguments. * \param kwargs_values Values of keyword arguments, must have the same length as `kwargs_keys. */ - CallDoc(ExprDoc callee, Array args, Array kwargs_keys, - Array kwargs_values); + CallDoc(ExprDoc callee, ffi::Array args, ffi::Array kwargs_keys, + ffi::Array kwargs_values); TVM_DEFINE_NOTNULLABLE_OBJECT_REF_METHODS(CallDoc, ExprDoc, CallDocNode); }; @@ -557,7 +561,7 @@ class OperationDocNode : public ExprDocNode { /*! \brief The kind of operation (operator) */ Kind kind; /*! \brief Operands of this expression */ - Array operands; + ffi::Array operands; static void RegisterReflection() { namespace refl = tvm::ffi::reflection; @@ -583,7 +587,7 @@ class OperationDoc : public ExprDoc { * \param kind The kind of operation. * \param operands Operands of this expression. */ - explicit OperationDoc(OperationDocNode::Kind kind, Array operands); + explicit OperationDoc(OperationDocNode::Kind kind, ffi::Array operands); TVM_DEFINE_NOTNULLABLE_OBJECT_REF_METHODS(OperationDoc, ExprDoc, OperationDocNode); }; @@ -598,7 +602,7 @@ class OperationDoc : public ExprDoc { class LambdaDocNode : public ExprDocNode { public: /*! \brief The arguments of this anonymous function */ - Array args; + ffi::Array args; /*! \brief The body of this anonymous function */ ExprDoc body{nullptr}; @@ -626,7 +630,7 @@ class LambdaDoc : public ExprDoc { * \param args Arguments of this function. * \param body Body expression of this function. */ - explicit LambdaDoc(Array args, ExprDoc body); + explicit LambdaDoc(ffi::Array args, ExprDoc body); TVM_DEFINE_NOTNULLABLE_OBJECT_REF_METHODS(LambdaDoc, ExprDoc, LambdaDocNode); }; @@ -638,7 +642,7 @@ class LambdaDoc : public ExprDoc { class TupleDocNode : public ExprDocNode { public: /*! \brief Elements of tuple */ - Array elements; + ffi::Array elements; static void RegisterReflection() { namespace refl = tvm::ffi::reflection; @@ -665,7 +669,7 @@ class TupleDoc : public ExprDoc { * \brief Constructor of TupleDoc * \param elements Elements of tuple. */ - explicit TupleDoc(Array elements); + explicit TupleDoc(ffi::Array elements); TVM_DEFINE_NOTNULLABLE_OBJECT_REF_METHODS(TupleDoc, ExprDoc, TupleDocNode); }; @@ -677,7 +681,7 @@ class TupleDoc : public ExprDoc { class ListDocNode : public ExprDocNode { public: /*! \brief Elements of list */ - Array elements; + ffi::Array elements; static void RegisterReflection() { namespace refl = tvm::ffi::reflection; @@ -704,7 +708,7 @@ class ListDoc : public ExprDoc { * \brief Constructor of ListDoc * \param elements Elements of list. */ - explicit ListDoc(Array elements); + explicit ListDoc(ffi::Array elements); TVM_DEFINE_NOTNULLABLE_OBJECT_REF_METHODS(ListDoc, ExprDoc, ListDocNode); }; @@ -716,14 +720,14 @@ class ListDoc : public ExprDoc { class DictDocNode : public ExprDocNode { public: /*! \brief keys of dictionary */ - Array keys; + ffi::Array keys; /*! * \brief Values of dictionary * * The i-th element is the value of the i-th element of `keys`. * It must have the same length as `keys`. */ - Array values; + ffi::Array values; static void RegisterReflection() { namespace refl = tvm::ffi::reflection; @@ -753,7 +757,7 @@ class DictDoc : public ExprDoc { * \param keys Keys of dictionary. * \param values Values of dictionary, must have same length as `keys`. 
*/ - explicit DictDoc(Array keys, Array values); + explicit DictDoc(ffi::Array keys, ffi::Array values); TVM_DEFINE_NOTNULLABLE_OBJECT_REF_METHODS(DictDoc, ExprDoc, DictDocNode); }; @@ -767,11 +771,11 @@ class DictDoc : public ExprDoc { class SliceDocNode : public DocNode { public: /*! \brief The start of slice */ - Optional start; + ffi::Optional start; /*! \brief The exclusive end of slice */ - Optional stop; + ffi::Optional stop; /*! \brief The step of slice */ - Optional step; + ffi::Optional step; static void RegisterReflection() { namespace refl = tvm::ffi::reflection; @@ -799,7 +803,8 @@ class SliceDoc : public Doc { * \param stop The exclusive end of slice. * \param step The step of slice. */ - explicit SliceDoc(Optional start, Optional stop, Optional step); + explicit SliceDoc(ffi::Optional start, ffi::Optional stop, + ffi::Optional step); TVM_DEFINE_NOTNULLABLE_OBJECT_REF_METHODS(SliceDoc, Doc, SliceDocNode); }; @@ -817,9 +822,9 @@ class AssignDocNode : public StmtDocNode { * * If null, this doc represents declaration, e.g. `A: T.Buffer((1,2))` * */ - Optional rhs; + ffi::Optional rhs; /*! \brief The type annotation of this assignment. */ - Optional annotation; + ffi::Optional annotation; static void RegisterReflection() { namespace refl = tvm::ffi::reflection; @@ -847,7 +852,7 @@ class AssignDoc : public StmtDoc { * \param rhs The right hand side of the assignment. * \param annotation The type annotation of this assignment. */ - explicit AssignDoc(ExprDoc lhs, Optional rhs, Optional annotation); + explicit AssignDoc(ExprDoc lhs, ffi::Optional rhs, ffi::Optional annotation); TVM_DEFINE_NOTNULLABLE_OBJECT_REF_METHODS(AssignDoc, StmtDoc, AssignDocNode); }; @@ -861,9 +866,9 @@ class IfDocNode : public StmtDocNode { /*! \brief The predicate of the if-then-else statement. */ ExprDoc predicate{nullptr}; /*! \brief The then branch of the if-then-else statement. */ - Array then_branch; + ffi::Array then_branch; /*! \brief The else branch of the if-then-else statement. */ - Array else_branch; + ffi::Array else_branch; static void RegisterReflection() { namespace refl = tvm::ffi::reflection; @@ -891,7 +896,8 @@ class IfDoc : public StmtDoc { * \param then_branch The then branch of the if-then-else statement. * \param else_branch The else branch of the if-then-else statement. */ - explicit IfDoc(ExprDoc predicate, Array then_branch, Array else_branch); + explicit IfDoc(ExprDoc predicate, ffi::Array then_branch, + ffi::Array else_branch); TVM_DEFINE_NOTNULLABLE_OBJECT_REF_METHODS(IfDoc, StmtDoc, IfDocNode); }; @@ -905,7 +911,7 @@ class WhileDocNode : public StmtDocNode { /*! \brief The predicate of the while statement. */ ExprDoc predicate{nullptr}; /*! \brief The body of the while statement. */ - Array body; + ffi::Array body; static void RegisterReflection() { namespace refl = tvm::ffi::reflection; @@ -931,7 +937,7 @@ class WhileDoc : public StmtDoc { * \param predicate The predicate of the while statement. * \param body The body of the while statement. */ - explicit WhileDoc(ExprDoc predicate, Array body); + explicit WhileDoc(ExprDoc predicate, ffi::Array body); TVM_DEFINE_NOTNULLABLE_OBJECT_REF_METHODS(WhileDoc, StmtDoc, WhileDocNode); }; @@ -951,7 +957,7 @@ class ForDocNode : public StmtDocNode { /*! \brief The right hand side of the assignment of iterating variable. */ ExprDoc rhs{nullptr}; /*! \brief The body of the for statement. 
*/ - Array body; + ffi::Array body; static void RegisterReflection() { namespace refl = tvm::ffi::reflection; @@ -979,7 +985,7 @@ class ForDoc : public StmtDoc { * \param rhs The right hand side of the assignment of iterating variable. * \param body The body of the for statement. */ - explicit ForDoc(ExprDoc lhs, ExprDoc rhs, Array body); + explicit ForDoc(ExprDoc lhs, ExprDoc rhs, ffi::Array body); TVM_DEFINE_NOTNULLABLE_OBJECT_REF_METHODS(ForDoc, StmtDoc, ForDocNode); }; @@ -996,11 +1002,11 @@ class ForDoc : public StmtDoc { class ScopeDocNode : public StmtDocNode { public: /*! \brief The name of the scoped variable. */ - Optional lhs{std::nullopt}; + ffi::Optional lhs{std::nullopt}; /*! \brief The value of the scoped variable. */ ExprDoc rhs{nullptr}; /*! \brief The body of the scope doc. */ - Array body; + ffi::Array body; static void RegisterReflection() { namespace refl = tvm::ffi::reflection; @@ -1028,14 +1034,14 @@ class ScopeDoc : public StmtDoc { * \param rhs The value of the scoped variable. * \param body The body of the scope doc. */ - explicit ScopeDoc(Optional lhs, ExprDoc rhs, Array body); + explicit ScopeDoc(ffi::Optional lhs, ExprDoc rhs, ffi::Array body); /*! * \brief Constructor of ScopeDoc. * \param rhs The value of the scoped variable. * \param body The body of the scope doc. */ - explicit ScopeDoc(ExprDoc rhs, Array body); + explicit ScopeDoc(ExprDoc rhs, ffi::Array body); TVM_DEFINE_NOTNULLABLE_OBJECT_REF_METHODS(ScopeDoc, StmtDoc, ScopeDocNode); }; @@ -1085,7 +1091,7 @@ class AssertDocNode : public StmtDocNode { /*! \brief The expression to test. */ ExprDoc test{nullptr}; /*! \brief The optional error message when assertion failed. */ - Optional msg{std::nullopt}; + ffi::Optional msg{std::nullopt}; static void RegisterReflection() { namespace refl = tvm::ffi::reflection; @@ -1111,7 +1117,7 @@ class AssertDoc : public StmtDoc { * \param test The expression to test. * \param msg The optional error message when assertion failed. */ - explicit AssertDoc(ExprDoc test, Optional msg = std::nullopt); + explicit AssertDoc(ExprDoc test, ffi::Optional msg = std::nullopt); TVM_DEFINE_NOTNULLABLE_OBJECT_REF_METHODS(AssertDoc, StmtDoc, AssertDocNode); }; @@ -1166,13 +1172,13 @@ class FunctionDocNode : public StmtDocNode { * `annotation` means argument type, * and `rhs` means default value. */ - Array args; + ffi::Array args; /*! \brief Decorators of function. */ - Array decorators; + ffi::Array decorators; /*! \brief The return type of function. */ - Optional return_type{std::nullopt}; + ffi::Optional return_type{std::nullopt}; /*! \brief The body of function. */ - Array body; + ffi::Array body; static void RegisterReflection() { namespace refl = tvm::ffi::reflection; @@ -1204,8 +1210,8 @@ class FunctionDoc : public StmtDoc { * \param return_type The return type of function. * \param body The body of function. */ - explicit FunctionDoc(IdDoc name, Array args, Array decorators, - Optional return_type, Array body); + explicit FunctionDoc(IdDoc name, ffi::Array args, ffi::Array decorators, + ffi::Optional return_type, ffi::Array body); TVM_DEFINE_NOTNULLABLE_OBJECT_REF_METHODS(FunctionDoc, StmtDoc, FunctionDocNode); }; @@ -1219,9 +1225,9 @@ class ClassDocNode : public StmtDocNode { /*! \brief The name of class. */ IdDoc name{nullptr}; /*! \brief Decorators of class. */ - Array decorators; + ffi::Array decorators; /*! \brief The body of class. 
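Assembled together, the doc constructors above still take braced lists for their `ffi::Array` parameters (element types assumed per upstream: `AssignDoc` args, `ExprDoc` decorators, `StmtDoc` body). A minimal sketch:

    using namespace tvm::script::printer;

    FunctionDoc MakeTrivialFunction() {
      StmtDoc check = AssertDoc(LiteralDoc::Boolean(true, std::nullopt),
                                LiteralDoc::Str("always holds", std::nullopt));
      return FunctionDoc(IdDoc("main"), /*args=*/{}, /*decorators=*/{},
                         /*return_type=*/std::nullopt, /*body=*/{check});
    }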
*/ - Array body; + ffi::Array body; static void RegisterReflection() { namespace refl = tvm::ffi::reflection; @@ -1249,7 +1255,7 @@ class ClassDoc : public StmtDoc { * \param decorators The decorator of class. * \param body The body of class. */ - explicit ClassDoc(IdDoc name, Array decorators, Array body); + explicit ClassDoc(IdDoc name, ffi::Array decorators, ffi::Array body); TVM_DEFINE_NOTNULLABLE_OBJECT_REF_METHODS(ClassDoc, StmtDoc, ClassDocNode); }; @@ -1276,7 +1282,7 @@ class CommentDocNode : public StmtDocNode { */ class CommentDoc : public StmtDoc { public: - explicit CommentDoc(String comment); + explicit CommentDoc(ffi::String comment); TVM_DEFINE_NOTNULLABLE_OBJECT_REF_METHODS(CommentDoc, StmtDoc, CommentDocNode); }; @@ -1303,7 +1309,7 @@ class DocStringDocNode : public StmtDocNode { */ class DocStringDoc : public StmtDoc { public: - explicit DocStringDoc(String docs); + explicit DocStringDoc(ffi::String docs); TVM_DEFINE_NOTNULLABLE_OBJECT_REF_METHODS(DocStringDoc, StmtDoc, DocStringDocNode); }; diff --git a/include/tvm/script/printer/ir_docsifier.h b/include/tvm/script/printer/ir_docsifier.h index dd7eaff7cc69..6e6be57f9ce5 100644 --- a/include/tvm/script/printer/ir_docsifier.h +++ b/include/tvm/script/printer/ir_docsifier.h @@ -50,7 +50,7 @@ class IRDocsifierNode; class FrameNode : public Object { public: /*! The docs generated in the frame */ - Array stmts; + ffi::Array stmts; /*! The corresponding IRDocsifier */ IRDocsifierNode* d; /*! The callbacks that are going to be invoked when the frame exits */ @@ -82,7 +82,7 @@ class FrameNode : public Object { * \param d The docsifier. * \param token The token to be added. */ - void AddDispatchToken(const IRDocsifier& d, const String& token); + void AddDispatchToken(const IRDocsifier& d, const ffi::String& token); /*! * \brief Method that's called when Frame enters the scope. */ @@ -129,7 +129,7 @@ class IRDocsifierNode : public Object { /*! \brief The creator */ DocCreator creator; /*! \brief The name of the variable */ - Optional name; + ffi::Optional name; }; /*! \brief The configuration of the printer */ PrinterConfig cfg{nullptr}; @@ -137,22 +137,22 @@ class IRDocsifierNode : public Object { * \brief The stack of frames. * \sa FrameNode */ - Array frames; + ffi::Array frames; /*! * \brief The stack of dispatch tokens. * * The dispatch token on the top decides which dispatch function to use * when converting IR node object to Doc. */ - Array dispatch_tokens; + ffi::Array dispatch_tokens; /*! \brief Mapping from a var to its info */ std::unordered_map obj2info; /*! \brief Metadata printing */ - std::unordered_map> metadata; + std::unordered_map> metadata; /*! \brief GlobalInfo printing */ - std::unordered_map> global_infos; + std::unordered_map> global_infos; /*! \brief The variable names used already */ - std::unordered_set defined_names; + std::unordered_set defined_names; /*! \brief Common prefixes of variable usages */ std::unordered_map> common_prefix; /*! \brief The IR usages for headers printing */ @@ -181,7 +181,7 @@ class IRDocsifierNode : public Object { * This function will rename the variable to avoid name conflict with other variables * in the table. */ - IdDoc Define(const ObjectRef& obj, const Frame& frame, const String& name_hint); + IdDoc Define(const ObjectRef& obj, const Frame& frame, const ffi::String& name_hint); /*! * \brief Define variable by doc factory. @@ -207,14 +207,14 @@ class IRDocsifierNode : public Object { * * \return The doc for variable, if it exists in the table. 
Otherwise it returns std::nullopt. */ - Optional GetVarDoc(const ObjectRef& obj) const; + ffi::Optional GetVarDoc(const ObjectRef& obj) const; /*! \brief Add a TVM object to the metadata section*/ ExprDoc AddMetadata(const ffi::Any& obj); /*! \brief Add a GlobalInfo to the global_infos map. * \param name The name of key of global_infos. * \param ginfo The GlobalInfo to be added. */ - void AddGlobalInfo(const String& name, const GlobalInfo& ginfo); + void AddGlobalInfo(const ffi::String& name, const GlobalInfo& ginfo); /*! * \brief Check if a variable exists in the table. * \param obj The variable object. @@ -259,7 +259,7 @@ class IRDocsifier : public ObjectRef { inline void FrameNode::EnterWithScope() { if (d != nullptr) { - d->frames.push_back(GetRef(this)); + d->frames.push_back(ffi::GetRef(this)); } } @@ -295,7 +295,7 @@ inline static void AddDocDecoration(const Doc& d, const ObjectRef& obj, const Ac } for (const auto& pair : cfg->path_to_annotate) { AccessPath p = pair.first; - String attn = pair.second; + ffi::String attn = pair.second; if (p->IsPrefixOf(path) && path->IsPrefixOf(p)) { if (const auto* stmt = d.as()) { if (stmt->comment.has_value()) { @@ -340,7 +340,8 @@ inline TDoc IRDocsifierNode::AsDoc(const Any& value, const AccessPath& path) con default: { if (auto opt_obj = value.as()) { ObjectRef obj = opt_obj.value(); - Doc d = IRDocsifier::vtable()(dispatch_tokens.back(), obj, path, GetRef(this)); + Doc d = IRDocsifier::vtable()(dispatch_tokens.back(), obj, path, + ffi::GetRef(this)); d->source_paths.push_back(path); AddDocDecoration(d, obj, path, cfg); return Downcast(d); @@ -352,7 +353,7 @@ inline TDoc IRDocsifierNode::AsDoc(const Any& value, const AccessPath& path) con } } -inline void FrameNode::AddDispatchToken(const IRDocsifier& d, const String& token) { +inline void FrameNode::AddDispatchToken(const IRDocsifier& d, const ffi::String& token) { d->dispatch_tokens.push_back(token); this->AddExitCallback([doc = d.get()]() { doc->dispatch_tokens.pop_back(); }); } diff --git a/include/tvm/script/printer/ir_docsifier_functor.h b/include/tvm/script/printer/ir_docsifier_functor.h index e4be2d31aa57..4500a7d8607b 100644 --- a/include/tvm/script/printer/ir_docsifier_functor.h +++ b/include/tvm/script/printer/ir_docsifier_functor.h @@ -61,7 +61,7 @@ class IRDocsifierFunctor { * dispatch function for TObjectRef with the default dispatch token (empty string). */ template - R operator()(const String& token, TObjectRef obj, Args... args) const { + R operator()(const ffi::String& token, TObjectRef obj, Args... args) const { uint32_t type_index = obj.defined() ? obj->type_index() : 0; const ffi::Function* pf = nullptr; if ((pf = LookupDispatchTable(token, type_index)) != nullptr) { @@ -91,7 +91,7 @@ class IRDocsifierFunctor { * This takes a type-erased packed function as input. It should be used * through FFI boundary, for example, registering dispatch function from Python. 
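Dispatch registration is unchanged apart from the token type. A hypothetical rule under the "tir" token (the `MyExpr` node type is invented for illustration, and the `TVM_STATIC_IR_FUNCTOR` macro and lambda signature are assumed from the existing printers):

    using namespace tvm::script::printer;

    // Hypothetical: print MyExpr nodes as a bare identifier.
    TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable)
        .set_dispatch<MyExpr>("tir", [](MyExpr e, AccessPath p, IRDocsifier d) -> Doc {
          return IdDoc("my_expr");
        });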
*/ - TSelf& set_dispatch(String token, uint32_t type_index, ffi::Function f) { + TSelf& set_dispatch(ffi::String token, uint32_t type_index, ffi::Function f) { std::vector* table = &dispatch_table_[token]; if (table->size() <= type_index) { table->resize(type_index + 1, nullptr); @@ -120,7 +120,7 @@ class IRDocsifierFunctor { */ template ::value>> - TSelf& set_dispatch(String token, TCallable f) { + TSelf& set_dispatch(ffi::String token, TCallable f) { return set_dispatch(token, TObjectRef::ContainerType::RuntimeTypeIndex(), ffi::TypedFunction(f)); } @@ -140,7 +140,7 @@ class IRDocsifierFunctor { * This is useful when dispatch function comes from other language's runtime, and * those function should be removed before that language runtime shuts down. */ - void remove_dispatch(String token, uint32_t type_index) { + void remove_dispatch(ffi::String token, uint32_t type_index) { std::vector* table = &dispatch_table_[token]; if (table->size() <= type_index) { return; @@ -155,7 +155,7 @@ class IRDocsifierFunctor { * \param type_index The TVM object type index. * \return Returns the functor if the lookup succeeds, nullptr otherwise. */ - const ffi::Function* LookupDispatchTable(const String& token, uint32_t type_index) const { + const ffi::Function* LookupDispatchTable(const ffi::String& token, uint32_t type_index) const { auto it = dispatch_table_.find(token); if (it == dispatch_table_.end()) { return nullptr; diff --git a/include/tvm/target/tag.h b/include/tvm/target/tag.h index 26111caa079a..5513a8298e8f 100644 --- a/include/tvm/target/tag.h +++ b/include/tvm/target/tag.h @@ -37,9 +37,9 @@ namespace tvm { class TargetTagNode : public Object { public: /*! \brief Name of the target */ - String name; + ffi::String name; /*! \brief Config map to generate the target */ - Map config; + ffi::Map config; static void RegisterReflection() { namespace refl = tvm::ffi::reflection; @@ -56,7 +56,7 @@ class TargetTagNode : public Object { /*! \brief Return the index stored in attr registry */ uint32_t AttrRegistryIndex() const { return index_; } /*! \brief Return the name stored in attr registry */ - String AttrRegistryName() const { return name; } + ffi::String AttrRegistryName() const { return name; } /*! \brief Index used for internal lookup of attribute registry */ uint32_t index_; @@ -78,12 +78,12 @@ class TargetTag : public ObjectRef { * \param target_tag_name Name of the target tag * \return The Target requested */ - TVM_DLL static Optional Get(const String& target_tag_name); + TVM_DLL static ffi::Optional Get(const ffi::String& target_tag_name); /*! * \brief List all names of the existing target tags * \return A dictionary that maps tag name to the concrete target it corresponds to */ - TVM_DLL static Map ListTags(); + TVM_DLL static ffi::Map ListTags(); /*! * \brief Add a tag into the registry * \param name Name of the tag @@ -91,7 +91,7 @@ class TargetTag : public ObjectRef { * \param override Allow overriding existing tags * \return Target created with the tag */ - TVM_DLL static Target AddTag(String name, Map config, bool override); + TVM_DLL static Target AddTag(ffi::String name, ffi::Map config, bool override); TVM_DEFINE_OBJECT_REF_METHODS(TargetTag, ObjectRef, TargetTagNode); @@ -107,13 +107,13 @@ class TargetTagRegEntry { * \brief Set the config dict corresponding to the target tag * \param config The config dict for target creation */ - inline TargetTagRegEntry& set_config(Map config); + inline TargetTagRegEntry& set_config(ffi::Map config); /*! 
* \brief Add a key-value pair to the config dict * \param key The attribute name * \param value The attribute value */ - inline TargetTagRegEntry& with_config(String key, Any value); + inline TargetTagRegEntry& with_config(ffi::String key, Any value); /*! \brief Set name of the TargetTag to be the same as registry if it is empty */ inline TargetTagRegEntry& set_name(); /*! @@ -121,14 +121,14 @@ class TargetTagRegEntry { * \param target_tag_name The name of the TargetTag. * \return the corresponding entry. */ - TVM_DLL static TargetTagRegEntry& RegisterOrGet(const String& target_tag_name); + TVM_DLL static TargetTagRegEntry& RegisterOrGet(const ffi::String& target_tag_name); private: TargetTag tag_; - String name; + ffi::String name; /*! \brief private constructor */ - explicit TargetTagRegEntry(uint32_t reg_index) : tag_(make_object()) { + explicit TargetTagRegEntry(uint32_t reg_index) : tag_(ffi::make_object()) { tag_->index_ = reg_index; } template @@ -136,12 +136,12 @@ class TargetTagRegEntry { friend class TargetTag; }; -inline TargetTagRegEntry& TargetTagRegEntry::set_config(Map config) { +inline TargetTagRegEntry& TargetTagRegEntry::set_config(ffi::Map config) { tag_->config = std::move(config); return *this; } -inline TargetTagRegEntry& TargetTagRegEntry::with_config(String key, ffi::Any value) { +inline TargetTagRegEntry& TargetTagRegEntry::with_config(ffi::String key, ffi::Any value) { tag_->config.Set(key, value); return *this; } diff --git a/include/tvm/target/target.h b/include/tvm/target/target.h index 678d36aeceda..d4486c34e8ba 100644 --- a/include/tvm/target/target.h +++ b/include/tvm/target/target.h @@ -51,15 +51,15 @@ class TargetNode : public Object { /*! \brief The kind of the target device */ TargetKind kind; /*! \brief Target host information, must be Target type */ - Optional host; + ffi::Optional host; /*! \brief Tag of the target, can be empty */ - String tag; + ffi::String tag; /*! \brief Keys for this target */ - Array keys; + ffi::Array keys; /*! \brief Collection of attributes */ - Map attrs; + ffi::Map attrs; /*! \brief Target features */ - Map features; + ffi::Map features; /*! * \brief The raw string representation of the target @@ -68,9 +68,9 @@ class TargetNode : public Object { */ TVM_DLL const std::string& str() const; /*! \return Export target to JSON-like configuration */ - TVM_DLL Map Export() const; - /*! \return The Optional typed target host of the TargetNode */ - TVM_DLL Optional GetHost() const; + TVM_DLL ffi::Map Export() const; + /*! \return The ffi::Optional typed target host of the TargetNode */ + TVM_DLL ffi::Optional GetHost() const; /*! \return The device type for this target */ TVM_DLL int GetTargetDeviceType() const; @@ -91,7 +91,7 @@ class TargetNode : public Object { * TODO(mbs): The ReprPrinter version should perhaps switch to this form, however currently * code depends on str() and << being the same. 
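Tag registration composes these setters as before; a braced list of pairs builds the `ffi::Map` config. A hypothetical tag (the name and attribute values are illustrative, and the `TVM_REGISTER_TARGET_TAG` macro is assumed from the existing registry):

    TVM_REGISTER_TARGET_TAG("example-gpu")
        .set_config({{"kind", tvm::ffi::String("cuda")},
                     {"arch", tvm::ffi::String("sm_80")}})
        .with_config("max_num_threads", 1024);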
*/ - String ToDebugString() const; + ffi::String ToDebugString() const; static void RegisterReflection() { namespace refl = tvm::ffi::reflection; @@ -112,12 +112,12 @@ class TargetNode : public Object { * \return An optional, std::nullopt if not found, otherwise the value found */ template - Optional GetAttr( + ffi::Optional GetAttr( const std::string& attr_key, - Optional default_value = Optional(std::nullopt)) const { + ffi::Optional default_value = ffi::Optional(std::nullopt)) const { auto it = attrs.find(attr_key); if (it != attrs.end()) { - return Downcast>((*it).second); + return Downcast>((*it).second); } else { return default_value; } @@ -130,8 +130,8 @@ class TargetNode : public Object { * \return An optional, std::nullopt if not found, otherwise the value found */ template - Optional GetAttr(const std::string& attr_key, TObjectRef default_value) const { - return GetAttr(attr_key, Optional(default_value)); + ffi::Optional GetAttr(const std::string& attr_key, TObjectRef default_value) const { + return GetAttr(attr_key, ffi::Optional(default_value)); } /*! @@ -154,8 +154,9 @@ class TargetNode : public Object { * \endcode */ template - Optional GetFeature(const std::string& feature_key, - Optional default_value = std::nullopt) const { + ffi::Optional GetFeature( + const std::string& feature_key, + ffi::Optional default_value = std::nullopt) const { if (auto feature = features.Get(feature_key)) { return Downcast(feature.value()); } else { @@ -164,8 +165,9 @@ class TargetNode : public Object { } // variant that uses TObjectRef to enable implicit conversion to default value. template - Optional GetFeature(const std::string& attr_key, TObjectRef default_value) const { - return GetFeature(attr_key, Optional(default_value)); + ffi::Optional GetFeature(const std::string& attr_key, + TObjectRef default_value) const { + return GetFeature(attr_key, ffi::Optional(default_value)); } /*! \brief Get the keys for this target as a vector of string */ @@ -196,12 +198,12 @@ class Target : public ObjectRef { * \brief Construct a Target given a string * \param tag_or_config_or_target_str the string to parse for target */ - TVM_DLL explicit Target(const String& tag_or_config_or_target_str); + TVM_DLL explicit Target(const ffi::String& tag_or_config_or_target_str); /*! * \brief Construct a Target using a JSON-like configuration * \param config The JSON-like configuration for target */ - TVM_DLL explicit Target(const Map& config); + TVM_DLL explicit Target(const ffi::Map& config); /*! * \brief Get the current target context from thread local storage. * \param allow_not_defined If the context stack is empty and this is set to true, an @@ -230,8 +232,8 @@ class Target : public ObjectRef { Target WithoutHost() const; private: - Target(TargetKind kind, Optional host, String tag, Array keys, - Map attrs); + Target(TargetKind kind, ffi::Optional host, ffi::String tag, + ffi::Array keys, ffi::Map attrs); // enable with syntax. friend class TargetInternal; diff --git a/include/tvm/target/target_kind.h b/include/tvm/target/target_kind.h index d89148964bcd..ad167ce08bcc 100644 --- a/include/tvm/target/target_kind.h +++ b/include/tvm/target/target_kind.h @@ -41,7 +41,7 @@ class Target; /*! * \brief Map containing parsed features of a specific Target */ -using TargetFeatures = Map; +using TargetFeatures = ffi::Map; /*! * \brief TargetParser to apply on instantiation of a given TargetKind @@ -50,7 +50,7 @@ using TargetFeatures = Map; * * \return The transformed Target JSON object. 
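+ *
+ * A hedged usage sketch only: a parser hooked in through this type takes and
+ * returns the JSON-like dict. The parser name and the "keys" default below
+ * are illustrative assumptions, not part of this header.
+ *
+ * \code
+ *
+ *   TargetJSON ExampleParser(TargetJSON target) {
+ *     // hypothetical normalization: make sure some default keys exist
+ *     if (!target.count("keys")) {
+ *       target.Set("keys", ffi::Array<ffi::String>{"cpu"});
+ *     }
+ *     return target;
+ *   }
+ *
+ * \endcode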
*/ -using TargetJSON = Map; +using TargetJSON = ffi::Map; using FTVMTargetParser = ffi::TypedFunction; namespace detail { @@ -67,11 +67,11 @@ class TargetKindAttrMap; class TargetKindNode : public Object { public: /*! \brief Name of the target kind */ - String name; + ffi::String name; /*! \brief Device type of target kind */ int default_device_type; /*! \brief Default keys of the target */ - Array default_keys; + ffi::Array default_keys; /*! \brief Function used to preprocess on target creation */ ffi::Function preprocessor; /*! \brief Function used to parse a JSON target during creation */ @@ -95,18 +95,18 @@ class TargetKindNode : public Object { /*! \brief Return the index stored in attr registry */ uint32_t AttrRegistryIndex() const { return index_; } /*! \brief Return the name stored in attr registry */ - String AttrRegistryName() const { return name; } + ffi::String AttrRegistryName() const { return name; } /*! \brief Stores the required type_key and type_index of a specific attr of a target */ struct ValueTypeInfo { - String type_key; + ffi::String type_key; int32_t type_index; std::unique_ptr key; std::unique_ptr val; }; /*! \brief A hash table that stores the type information of each attr of the target key */ - std::unordered_map key2vtype_; + std::unordered_map key2vtype_; /*! \brief A hash table that stores the default value of each attr of the target key */ - std::unordered_map key2default_; + std::unordered_map key2default_; /*! \brief Index used for internal lookup of attribute registry */ uint32_t index_; @@ -129,13 +129,13 @@ class TargetKind : public ObjectRef { TargetKind() = default; /*! \brief Get the attribute map given the attribute name */ template - static inline TargetKindAttrMap GetAttrMap(const String& attr_name); + static inline TargetKindAttrMap GetAttrMap(const ffi::String& attr_name); /*! * \brief Retrieve the TargetKind given its name * \param target_kind_name Name of the target kind * \return The TargetKind requested */ - TVM_DLL static Optional Get(const String& target_kind_name); + TVM_DLL static ffi::Optional Get(const ffi::String& target_kind_name); /*! \brief Mutable access to the container class */ TargetKindNode* operator->() { return static_cast(data_.get()); } @@ -143,13 +143,13 @@ class TargetKind : public ObjectRef { private: TVM_DLL static const AttrRegistryMapContainerMap& GetAttrMapContainer( - const String& attr_name); + const ffi::String& attr_name); friend class TargetKindRegEntry; friend class TargetInternal; }; /*! - * \brief Map used to store meta-information about TargetKind + * \brief ffi::Map used to store meta-information about TargetKind * \tparam ValueType The type of the value stored in map */ template @@ -188,7 +188,7 @@ class TargetKindRegEntry { * \tparam ValueType The type of the value to be set. */ template - inline TargetKindRegEntry& set_attr(const String& attr_name, const ValueType& value, + inline TargetKindRegEntry& set_attr(const ffi::String& attr_name, const ValueType& value, int plevel = 10); /*! * \brief Set DLPack's device_type of the target * \param device_type Device type */ inline TargetKindRegEntry& set_default_device_type(int device_type); /*! * \brief Set the default keys of the target * \param keys The default keys */ - inline TargetKindRegEntry& set_default_keys(std::vector keys); /*!
* \brief Set the pre-processing function applied upon target creation * \tparam FLambda Type of the function @@ -218,7 +218,7 @@ class TargetKindRegEntry { * \tparam ValueType The value type to be registered */ template - inline TargetKindRegEntry& add_attr_option(const String& key); + inline TargetKindRegEntry& add_attr_option(const ffi::String& key); /*! * \brief Register a valid configuration option and its ValueType for validation * \param key The configuration key @@ -226,33 +226,33 @@ class TargetKindRegEntry { * \tparam ValueType The value type to be registered */ template - inline TargetKindRegEntry& add_attr_option(const String& key, ffi::Any default_value); + inline TargetKindRegEntry& add_attr_option(const ffi::String& key, ffi::Any default_value); /*! \brief Set name of the TargetKind to be the same as registry if it is empty */ inline TargetKindRegEntry& set_name(); /*! * \brief List all the entry names in the registry. * \return The entry names. */ - TVM_DLL static Array ListTargetKinds(); + TVM_DLL static ffi::Array ListTargetKinds(); /*! * \brief Get all supported option names and types for a given Target kind. * \return Map of option name to type */ - TVM_DLL static Map ListTargetKindOptions(const TargetKind& kind); + TVM_DLL static ffi::Map ListTargetKindOptions(const TargetKind& kind); /*! * \brief Register or get a new entry. * \param target_kind_name The name of the TargetKind. * \return the corresponding entry. */ - TVM_DLL static TargetKindRegEntry& RegisterOrGet(const String& target_kind_name); + TVM_DLL static TargetKindRegEntry& RegisterOrGet(const ffi::String& target_kind_name); private: TargetKind kind_; - String name; + ffi::String name; /*! \brief private constructor */ - explicit TargetKindRegEntry(uint32_t reg_index) : kind_(make_object()) { + explicit TargetKindRegEntry(uint32_t reg_index) : kind_(ffi::make_object()) { kind_->index_ = reg_index; } /*! @@ -261,7 +261,7 @@ class TargetKindRegEntry { * \param value The value to be set * \param plevel The priority level */ - TVM_DLL void UpdateAttr(const String& key, ffi::Any value, int plevel); + TVM_DLL void UpdateAttr(const ffi::String& key, ffi::Any value, int plevel); template friend class AttrRegistry; friend class TargetKind; @@ -278,8 +278,9 @@ struct is_specialized, Container> : std::true_type { using type = std::true_type; }; -template ::type, - typename IsMap = typename is_specialized::type> +template ::type, + typename IsMap = typename is_specialized::type> struct ValueTypeInfoMaker {}; template @@ -295,7 +296,7 @@ struct ValueTypeInfoMaker { info.type_index = tindex; info.type_key = runtime::Object::TypeIndex2Key(tindex); return info; - } else if constexpr (std::is_same_v) { + } else if constexpr (std::is_same_v) { // special handle string since it can be backed by multiple types. 
info.type_index = ffi::TypeIndex::kTVMFFIStr; info.type_key = ffi::TypeTraits::TypeStr(); @@ -346,12 +347,12 @@ struct ValueTypeInfoMaker { } // namespace detail template -inline TargetKindAttrMap TargetKind::GetAttrMap(const String& attr_name) { +inline TargetKindAttrMap TargetKind::GetAttrMap(const ffi::String& attr_name) { return TargetKindAttrMap(GetAttrMapContainer(attr_name)); } template -inline TargetKindRegEntry& TargetKindRegEntry::set_attr(const String& attr_name, +inline TargetKindRegEntry& TargetKindRegEntry::set_attr(const ffi::String& attr_name, const ValueType& value, int plevel) { ICHECK_GT(plevel, 0) << "plevel in set_attr must be greater than 0"; ffi::Any rv; @@ -365,7 +366,7 @@ inline TargetKindRegEntry& TargetKindRegEntry::set_default_device_type(int devic return *this; } -inline TargetKindRegEntry& TargetKindRegEntry::set_default_keys(std::vector keys) { +inline TargetKindRegEntry& TargetKindRegEntry::set_default_keys(std::vector keys) { kind_->default_keys = keys; return *this; } @@ -383,7 +384,7 @@ inline TargetKindRegEntry& TargetKindRegEntry::set_target_parser(FTVMTargetParse } template -inline TargetKindRegEntry& TargetKindRegEntry::add_attr_option(const String& key) { +inline TargetKindRegEntry& TargetKindRegEntry::add_attr_option(const ffi::String& key) { ICHECK(!kind_->key2vtype_.count(key)) << "AttributeError: add_attr_option failed because '" << key << "' has been set once"; kind_->key2vtype_[key] = detail::ValueTypeInfoMaker()(); @@ -391,7 +392,7 @@ inline TargetKindRegEntry& TargetKindRegEntry::add_attr_option(const String& key } template -inline TargetKindRegEntry& TargetKindRegEntry::add_attr_option(const String& key, +inline TargetKindRegEntry& TargetKindRegEntry::add_attr_option(const ffi::String& key, Any default_value) { add_attr_option(key); kind_->key2default_[key] = default_value; @@ -420,8 +421,8 @@ inline TargetKindRegEntry& TargetKindRegEntry::set_name() { * TVM_REGISTER_TARGET_KIND("llvm") * .set_attr("TPreCodegenPass", a-pre-codegen-pass) * .add_attr_option("system_lib") - * .add_attr_option("mtriple") - * .add_attr_option("mattr"); + * .add_attr_option("mtriple") + * .add_attr_option("mattr"); * * \endcode */ @@ -430,11 +431,11 @@ inline TargetKindRegEntry& TargetKindRegEntry::set_name() { ::tvm::TargetKindRegEntry::RegisterOrGet(TargetKindName) \ .set_name() \ .set_default_device_type(DeviceType) \ - .add_attr_option>("keys") \ - .add_attr_option("tag") \ - .add_attr_option("device") \ - .add_attr_option("model") \ - .add_attr_option>("libs") \ + .add_attr_option>("keys") \ + .add_attr_option("tag") \ + .add_attr_option("device") \ + .add_attr_option("model") \ + .add_attr_option>("libs") \ .add_attr_option("host") \ .add_attr_option("from_device") \ .add_attr_option("target_device_type") diff --git a/include/tvm/target/virtual_device.h b/include/tvm/target/virtual_device.h index aabd3a2ecaf2..bb67d96fbe7a 100644 --- a/include/tvm/target/virtual_device.h +++ b/include/tvm/target/virtual_device.h @@ -39,10 +39,10 @@ namespace tvm { * Abstract label for an area of memory. * * Currently uninterpreted and arbitrary. Likely to be replaced by a structured representation - * of a memory pool in the future. Please try to use this alias instead of String to aid future + * of a memory pool in the future. Please try to use this alias instead of ffi::String to aid future * code migration. 
*/ -using MemoryScope = String; +using MemoryScope = ffi::String; // NOTE: cannot use enum as they are out of bound of the original enum // and results in an undefined behavior @@ -333,7 +333,7 @@ class VirtualDevice : public ObjectRef { * \p lhs and \p rhs on all their constrained fields. Returns the null optional if no such * join exists, ie there's disagreement on at least one constrained field. */ - static Optional Join(const VirtualDevice& lhs, const VirtualDevice& rhs); + static ffi::Optional Join(const VirtualDevice& lhs, const VirtualDevice& rhs); /*! * \brief Returns the 'default' of \p lhs and \p rhs. The result will be \p lhs, except any diff --git a/include/tvm/te/operation.h b/include/tvm/te/operation.h index 6c1ea6195f5e..f978c9953cf1 100644 --- a/include/tvm/te/operation.h +++ b/include/tvm/te/operation.h @@ -60,7 +60,7 @@ class TVM_DLL OperationNode : public Object { /*! \brief optional tag of the operation */ std::string tag; /*! \brief additional attributes of the operation*/ - Map attrs; + ffi::Map attrs; // virtual destructor. virtual ~OperationNode() {} /*! \return number of outputs */ @@ -76,12 +76,12 @@ class TVM_DLL OperationNode : public Object { * \param i The output index. * \return shape of i-th output. */ - virtual Array output_shape(size_t i) const = 0; + virtual ffi::Array output_shape(size_t i) const = 0; /*! * \brief List all the input Tensors. * \return List of input tensors. */ - virtual Array InputTensors() const = 0; + virtual ffi::Array InputTensors() const = 0; static void RegisterReflection() { namespace refl = tvm::ffi::reflection; @@ -102,14 +102,14 @@ class TVM_DLL OperationNode : public Object { class PlaceholderOpNode : public OperationNode { public: /*! \brief The shape of the input */ - Array shape; + ffi::Array shape; /*! \brief The data type of the input. */ DataType dtype; // override behavior. int num_outputs() const final; DataType output_dtype(size_t i) const final; - Array output_shape(size_t i) const final; - Array InputTensors() const final; + ffi::Array output_shape(size_t i) const final; + ffi::Array InputTensors() const final; static void RegisterReflection() { namespace refl = tvm::ffi::reflection; @@ -129,7 +129,7 @@ class PlaceholderOpNode : public OperationNode { */ class PlaceholderOp : public Operation { public: - TVM_DLL PlaceholderOp(std::string name, Array shape, DataType dtype); + TVM_DLL PlaceholderOp(std::string name, ffi::Array shape, DataType dtype); TVM_DEFINE_OBJECT_REF_METHODS(PlaceholderOp, Operation, PlaceholderOpNode); }; @@ -141,11 +141,11 @@ class PlaceholderOp : public Operation { class TVM_DLL BaseComputeOpNode : public OperationNode { public: /*! \brief IterVar on each axis */ - Array axis; + ffi::Array axis; /*! \brief IterVar on each reduction axis, if the body is a Reduce */ - Array reduce_axis; + ffi::Array reduce_axis; // override functions - Array output_shape(size_t idx) const final; + ffi::Array output_shape(size_t idx) const final; static void RegisterReflection() { namespace refl = tvm::ffi::reflection; @@ -165,13 +165,13 @@ class TVM_DLL BaseComputeOpNode : public OperationNode { class TVM_DLL ComputeOpNode : public BaseComputeOpNode { public: /*! \brief the compute expression */ - Array body; + ffi::Array body; /*! 
\brief constructor */ ComputeOpNode() {} // override functions int num_outputs() const final; DataType output_dtype(size_t i) const final; - Array InputTensors() const final; + ffi::Array InputTensors() const final; static void RegisterReflection() { namespace refl = tvm::ffi::reflection; @@ -189,8 +189,8 @@ class TVM_DLL ComputeOpNode : public BaseComputeOpNode { */ class ComputeOp : public Operation { public: - TVM_DLL ComputeOp(std::string name, std::string tag, Map attrs, - Array axis, Array body); + TVM_DLL ComputeOp(std::string name, std::string tag, ffi::Map attrs, + ffi::Array axis, ffi::Array body); TVM_DEFINE_OBJECT_REF_METHODS(ComputeOp, Operation, ComputeOpNode); TVM_DEFINE_OBJECT_REF_COW_METHOD(ComputeOpNode); @@ -204,16 +204,16 @@ class ScanOpNode : public OperationNode { /*! \brief IterVar to scan over */ IterVar scan_axis; /*! \brief the initialization tensors */ - Array init; + ffi::Array init; /*! \brief the update function represented by tensor */ - Array update; + ffi::Array update; /*! \brief The placeholder to refer as states in update. */ - Array state_placeholder; + ffi::Array state_placeholder; /*! * \brief the inputs to the scan; these are optionally provided, * but they can be helpful as hints to speed up deriving the scan body. */ - Array inputs; + ffi::Array inputs; /*! * \brief Spatial axis to indicate spatial dimension of each output. * They correspond to the flattened spatial axes of the outputs. @@ -223,14 +223,14 @@ class ScanOpNode : public OperationNode { * They do not correspond to splittable iterations, thus the name comes * with an underscore. */ - Array spatial_axis_; + ffi::Array spatial_axis_; /*! \brief constructor */ ScanOpNode() {} // override behavior. int num_outputs() const final; DataType output_dtype(size_t i) const final; - Array output_shape(size_t i) const final; - Array InputTensors() const final; + ffi::Array output_shape(size_t i) const final; + ffi::Array InputTensors() const final; static void RegisterReflection() { namespace refl = tvm::ffi::reflection; @@ -254,9 +254,10 @@ class ScanOpNode : public OperationNode { */ class ScanOp : public Operation { public: - TVM_DLL ScanOp(std::string name, std::string tag, Optional> attrs, - IterVar axis, Array init, Array update, - Array state_placeholder, Array input); + TVM_DLL ScanOp(std::string name, std::string tag, + ffi::Optional> attrs, IterVar axis, + ffi::Array init, ffi::Array update, + ffi::Array state_placeholder, ffi::Array input); TVM_DEFINE_OBJECT_REF_METHODS(ScanOp, Operation, ScanOpNode); }; @@ -267,11 +268,11 @@ class ScanOp : public Operation { class ExternOpNode : public OperationNode { public: /*! \brief The input tensors */ - Array inputs; + ffi::Array inputs; /*! \brief Symbolic placeholder representation of inputs */ - Array input_placeholders; + ffi::Array input_placeholders; /*! \brief Symbolic placeholder representation of outputs */ - Array output_placeholders; + ffi::Array output_placeholders; /*! \brief the statement that generates the computation.
*/ Stmt body; @@ -280,8 +281,8 @@ class ExternOpNode : public OperationNode { // override functions int num_outputs() const final; DataType output_dtype(size_t i) const final; - Array output_shape(size_t i) const final; - Array InputTensors() const final; + ffi::Array output_shape(size_t i) const final; + ffi::Array InputTensors() const final; static void RegisterReflection() { namespace refl = tvm::ffi::reflection; @@ -303,9 +304,9 @@ class ExternOpNode : public OperationNode { */ class ExternOp : public Operation { public: - TVM_DLL ExternOp(std::string name, std::string tag, Map attrs, - Array inputs, Array input_placeholders, - Array output_placeholders, Stmt body); + TVM_DLL ExternOp(std::string name, std::string tag, ffi::Map attrs, + ffi::Array inputs, ffi::Array input_placeholders, + ffi::Array output_placeholders, Stmt body); TVM_DEFINE_OBJECT_REF_METHODS(ExternOp, Operation, ExternOpNode); }; @@ -334,10 +335,10 @@ TVM_DLL IterVar thread_axis(Range dom, std::string tag); TVM_DLL IterVar reduce_axis(Range dom, std::string name = "rv"); /*! \brief The compute function to specify the input source of a Tensor */ -using FCompute = std::function& i)>; +using FCompute = std::function& i)>; /*! \brief The compute function to specify the inputs source of Tensors */ -using FBatchCompute = std::function(const Array& i)>; +using FBatchCompute = std::function(const ffi::Array& i)>; /*! * \brief create a place holder tensor. @@ -345,7 +346,7 @@ using FBatchCompute = std::function(const Array& i)>; * \param dtype the data type of the tensor. * \param name The name of the Tensor. */ -TVM_DLL Tensor placeholder(Array shape, DataType dtype = DataType::Float(32), +TVM_DLL Tensor placeholder(ffi::Array shape, DataType dtype = DataType::Float(32), std::string name = "placeholder"); /*! @@ -357,8 +358,8 @@ TVM_DLL Tensor placeholder(Array shape, DataType dtype = DataType::Flo * \param tag The optional tag of the tensor. * \param attrs Optional additional attributes of the compute. */ -TVM_DLL Tensor compute(Array shape, FCompute fcompute, std::string name = "tensor", - std::string tag = "", Map attrs = {}); +TVM_DLL Tensor compute(ffi::Array shape, FCompute fcompute, std::string name = "tensor", + std::string tag = "", ffi::Map attrs = {}); /*! * \brief Construct a new tensor by computing over shape, @@ -369,9 +370,9 @@ TVM_DLL Tensor compute(Array shape, FCompute fcompute, std::string nam * \param tag The optional tag of the tensor. * \param attrs Optional additional attributes of the compute. */ -TVM_DLL Array compute(Array shape, FBatchCompute fcompute, - std::string name = "tensor", std::string tag = "", - Map attrs = {}); +TVM_DLL ffi::Array compute(ffi::Array shape, FBatchCompute fcompute, + std::string name = "tensor", std::string tag = "", + ffi::Map attrs = {}); /*! * \brief Construct new tensors by scan. @@ -385,34 +386,35 @@ TVM_DLL Array compute(Array shape, FBatchCompute fcompute, * \param tag The optional tag of the tensor. * \param attrs Optional additional attributes of the compute. 
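+ *
+ * A hedged sketch of compute + scan under the renamed ffi:: containers (the
+ * cumulative-sum recurrence and the names X, m, n are illustrative
+ * assumptions, not part of this header):
+ *
+ * \code
+ *
+ *   // state[t, i] = state[t - 1, i] + X[t, i], seeded with row 0 of X
+ *   Var m("m"), n("n");
+ *   Tensor X = placeholder({m, n}, DataType::Float(32), "X");
+ *   Tensor s_state = placeholder({m, n}, DataType::Float(32), "state");
+ *   Tensor s_init = compute({1, n}, [&](Var _, Var i) { return X(0, i); });
+ *   Tensor s_update =
+ *       compute({m, n}, [&](Var t, Var i) { return s_state(t - 1, i) + X(t, i); });
+ *   ffi::Array<Tensor> res = scan({s_init}, {s_update}, {s_state}, {X});
+ *
+ * \endcode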
*/ -TVM_DLL Array scan(Array init, Array update, - Array state_placeholder, Array inputs = Array(), - std::string name = "scan", std::string tag = "", - Map attrs = {}); +TVM_DLL ffi::Array scan(ffi::Array init, ffi::Array update, + ffi::Array state_placeholder, + ffi::Array inputs = ffi::Array(), + std::string name = "scan", std::string tag = "", + ffi::Map attrs = {}); // same as compute, specialized for different fcompute function -inline Tensor compute(Array shape, std::function f, +inline Tensor compute(ffi::Array shape, std::function f, std::string name = "tensor", std::string tag = "", - Map attrs = {}) { - FCompute fc = [f](const Array& i) { return f(i[0]); }; + ffi::Map attrs = {}) { + FCompute fc = [f](const ffi::Array& i) { return f(i[0]); }; return compute(shape, fc, name, tag, attrs); } -inline Tensor compute(Array shape, std::function f, +inline Tensor compute(ffi::Array shape, std::function f, std::string name = "tensor", std::string tag = "", - Map attrs = {}) { - FCompute fc = [f](const Array& i) { return f(i[0], i[1]); }; + ffi::Map attrs = {}) { + FCompute fc = [f](const ffi::Array& i) { return f(i[0], i[1]); }; return compute(shape, fc, name, tag, attrs); } -inline Tensor compute(Array shape, std::function f, +inline Tensor compute(ffi::Array shape, std::function f, std::string name = "tensor", std::string tag = "", - Map attrs = {}) { - FCompute fc = [f](const Array& i) { return f(i[0], i[1], i[2]); }; + ffi::Map attrs = {}) { + FCompute fc = [f](const ffi::Array& i) { return f(i[0], i[1], i[2]); }; return compute(shape, fc, name, tag, attrs); } -inline Tensor compute(Array shape, std::function f, +inline Tensor compute(ffi::Array shape, std::function f, std::string name = "tensor", std::string tag = "", - Map attrs = {}) { - FCompute fc = [f](const Array& i) { return f(i[0], i[1], i[2], i[3]); }; + ffi::Map attrs = {}) { + FCompute fc = [f](const ffi::Array& i) { return f(i[0], i[1], i[2], i[3]); }; return compute(shape, fc, name, tag, attrs); } diff --git a/include/tvm/te/tensor.h b/include/tvm/te/tensor.h index f45a96df63d8..8bcad6950f4d 100644 --- a/include/tvm/te/tensor.h +++ b/include/tvm/te/tensor.h @@ -69,7 +69,7 @@ class Operation : public ObjectRef { class TensorNode : public DataProducerNode { public: /*! \brief The shape of the tensor */ - Array shape; + ffi::Array shape; /*! \brief data type in the content of the tensor */ DataType dtype; /*! \brief the source operation, can be None */ @@ -79,13 +79,13 @@ class TensorNode : public DataProducerNode { static void RegisterReflection(); - Array GetShape() const final { return shape; } + ffi::Array GetShape() const final { return shape; } DataType GetDataType() const final { return dtype; } TVM_DLL PrimExpr ToPrimExpr() const final; - TVM_DLL String GetNameHint() const final; + TVM_DLL ffi::String GetNameHint() const final; static constexpr const char* _type_key = "te.Tensor"; static constexpr TVMFFISEqHashKind _type_s_eq_hash_kind = kTVMFFISEqHashKindConstTreeNode; @@ -105,10 +105,10 @@ class Tensor : public DataProducer { * \param support_negative_indices Whether to normalize indices in the case of negative indices. * \return the result expression representing tensor read. 
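+ *
+ * For illustration only (A and i are assumed to already exist; with negative
+ * indices supported, -1 normalizes to the last element of the axis):
+ *
+ * \code
+ *
+ *   PrimExpr first = A(i, 0);                           // plain read
+ *   PrimExpr last = A.IndexWithNegativeIndices(i, -1);  // normalized read
+ *
+ * \endcode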
*/ - inline PrimExpr IndexTensor(Array indices, bool support_negative_indices) const; + inline PrimExpr IndexTensor(ffi::Array indices, bool support_negative_indices) const; public: - TVM_DLL Tensor(Array shape, DataType dtype, Operation op, int value_index); + TVM_DLL Tensor(ffi::Array shape, DataType dtype, Operation op, int value_index); /*! * \brief check if two tensors equal each other. * \param other tensor to be checked. * @@ -130,7 +130,7 @@ */ template inline PrimExpr operator()(Args&&... args) const { - Array indices{std::forward(args)...}; + ffi::Array indices{std::forward(args)...}; return operator()(indices); } /*! @@ -138,13 +138,13 @@ * \param indices the indices. * \return the result expression representing tensor read. */ - TVM_DLL PrimExpr operator()(Array indices) const; + TVM_DLL PrimExpr operator()(ffi::Array indices) const; /*! * \brief Take elements from the tensor * \param indices the indices. * \return the result expression representing tensor read. */ - TVM_DLL PrimExpr operator()(Array indices) const; + TVM_DLL PrimExpr operator()(ffi::Array indices) const; /*! * \brief Take elements from the tensor with support for negative indices. * \param args The indices * @@ -152,7 +152,7 @@ */ template TVM_DLL PrimExpr IndexWithNegativeIndices(Args&&... args) const { - Array indices{std::forward(args)...}; + ffi::Array indices{std::forward(args)...}; return IndexWithNegativeIndices(indices); } /*! @@ -160,13 +160,13 @@ * \param indices the indices. * \return the result expression representing tensor read. */ - TVM_DLL PrimExpr IndexWithNegativeIndices(Array indices) const; + TVM_DLL PrimExpr IndexWithNegativeIndices(ffi::Array indices) const; /*! * \brief Take elements from the tensor with support for negative indices. * \param indices the indices. * \return the result expression representing tensor read. */ - TVM_DLL PrimExpr IndexWithNegativeIndices(Array indices) const; + TVM_DLL PrimExpr IndexWithNegativeIndices(ffi::Array indices) const; /*! * \brief data structure to represent a slice that fixes first k coordinates. diff --git a/include/tvm/tir/analysis.h b/include/tvm/tir/analysis.h index a21112b7d6f6..0f4b6afd62fb 100644 --- a/include/tvm/tir/analysis.h +++ b/include/tvm/tir/analysis.h @@ -99,14 +99,14 @@ TVM_DLL double EstimateTIRFlops(const IRModule& mod); * \param defs The vars that are defined. * \return Array of undefined vars. */ -TVM_DLL Array UndefinedVars(const Stmt& stmt, const Array& defs); +TVM_DLL ffi::Array UndefinedVars(const Stmt& stmt, const ffi::Array& defs); /*! * \brief Find undefined vars in the expression. * \param expr The expression to be checked. * \return Array of undefined vars. */ -TVM_DLL Array UndefinedVars(const PrimExpr& expr); +TVM_DLL ffi::Array UndefinedVars(const PrimExpr& expr); /*! * \brief Find undefined vars in the expression. @@ -114,7 +114,7 @@ TVM_DLL Array UndefinedVars(const PrimExpr& expr); * \param defs The vars that are defined. * \return Array of undefined vars. */ -TVM_DLL Array UndefinedVars(const PrimExpr& expr, const Array& defs); +TVM_DLL ffi::Array UndefinedVars(const PrimExpr& expr, const ffi::Array& defs); /*!
* \brief Analyze the side effect of an expression @@ -195,7 +195,7 @@ TVM_DLL bool VerifyMemory(const PrimFunc& func); * \return valid Whether it is a valid GPU code * */ -TVM_DLL bool VerifyGPUCode(const PrimFunc& func, Map constraints); +TVM_DLL bool VerifyGPUCode(const PrimFunc& func, ffi::Map constraints); /** * @brief Utility function to get the list of lowering passes to be applied to calculate the @@ -203,7 +203,7 @@ TVM_DLL bool VerifyGPUCode(const PrimFunc& func, Map constrain * * @return returns list of passes */ -TVM_DLL Array GetVTCMCompactionPasses(); +TVM_DLL ffi::Array GetVTCMCompactionPasses(); /*! * \brief Verifies that the VTCM usage for all prim_funcs in the given IRModule @@ -233,8 +233,8 @@ TVM_DLL bool VerifyVTCMLimit(const PrimFunc& func, Integer limit); * - second: write regions * - third: opaque regions */ -TVM_DLL Array> GetBlockAccessRegion(const Block& block, - const Map& buffer_var_map); +TVM_DLL ffi::Array> GetBlockAccessRegion( + const Block& block, const ffi::Map& buffer_var_map); /*! * \brief Auto detect the block read/write region according to its body stmt. An opaque access will @@ -244,8 +244,8 @@ TVM_DLL Array> GetBlockAccessRegion(const Block& block, * It is a map from buffer var to the buffer * \return An array only consisting of the read regions and write regions of the input block */ -TVM_DLL Array> GetBlockReadWriteRegion(const Block& block, - const Map& buffer_var_map); +TVM_DLL ffi::Array> GetBlockReadWriteRegion( + const Block& block, const ffi::Map& buffer_var_map); /*! \brief Helper struct for return value of IdentifyMemCpy * @@ -298,7 +298,8 @@ TVM_DLL size_t CalculateWorkspaceBytes(const PrimFunc& func, * \return Allocated memory size per scope in bytes inside the PrimFunc returned as a Map with * key "main" and a Map of allocated sizes as values. */ -TVM_DLL tvm::Map> CalculateAllocatedBytes(const PrimFunc& func); +TVM_DLL tvm::ffi::Map> CalculateAllocatedBytes( + const PrimFunc& func); /*! * \brief Calculate the allocated memory per scope in bytes for each function inside the module @@ -306,7 +307,8 @@ TVM_DLL tvm::Map> CalculateAllocatedBytes(cons * \return Allocated memory size per scope in bytes for each function in the IRModule returned as a Map with function names as keys and a Map of allocated sizes as values. */ -TVM_DLL tvm::Map> CalculateAllocatedBytes(const IRModule& mod); +TVM_DLL tvm::ffi::Map> CalculateAllocatedBytes( + const IRModule& mod); /*! * \brief Detect the lowest common ancestor(LCA) of buffer access, including both high-level @@ -316,7 +318,7 @@ TVM_DLL tvm::Map> CalculateAllocatedBytes(cons * \return The Map from buffer to the LCA of all access to it. The lca is function root if the * return stmt is std::nullopt. */ -TVM_DLL Map> DetectBufferAccessLCA(const PrimFunc& func); +TVM_DLL ffi::Map> DetectBufferAccessLCA(const PrimFunc& func); /*! * \brief Verify if the given TIR is well-formed. The verification includes: @@ -410,7 +412,7 @@ TVM_DLL Pass VerifyMemory(); * \returns The pass. * \sa tvm::tir::VerifyGPUCode */ -TVM_DLL Pass VerifyGPUCode(Map constraints); +TVM_DLL Pass VerifyGPUCode(ffi::Map constraints); /*! * \brief Pass to checks if the size of the allocated vtcm memory satisfies the limit @@ -421,7 +423,7 @@ TVM_DLL Pass VerifyGPUCode(Map constraints); * \returns The pass. * \sa tvm::tir::CalculateAllocatedBytes */ -TVM_DLL Pass VerifyVTCMLimit(Optional target = std::nullopt); +TVM_DLL Pass VerifyVTCMLimit(ffi::Optional target = std::nullopt); /*! 
* \brief Statically check TIR code for out of bounds array access. diff --git a/include/tvm/tir/block_dependence_info.h b/include/tvm/tir/block_dependence_info.h index 7b00894ea805..c5fd72173e3c 100644 --- a/include/tvm/tir/block_dependence_info.h +++ b/include/tvm/tir/block_dependence_info.h @@ -78,7 +78,7 @@ class BlockDependenceInfoNode : public Object { auto it = sref2scope.find(scope_root); CHECK(it != sref2scope.end()) << "IndexError: Cannot find the corresponding BlockScope to the block sref:\n" - << GetRef(scope_root->stmt); + << ffi::GetRef(scope_root->stmt); return it->second; } }; diff --git a/include/tvm/tir/block_scope.h b/include/tvm/tir/block_scope.h index 9ea77d7b9b46..3fc2515d0812 100644 --- a/include/tvm/tir/block_scope.h +++ b/include/tvm/tir/block_scope.h @@ -262,11 +262,11 @@ class BlockScopeNode : public Object { * \note We intentionally didn't use tvm::Map as the data structure, because we need the values * inside to be mutable so that they could be further maintained properly during transformations. */ - std::unordered_map, ObjectPtrHash, ObjectPtrEqual> src2deps; + std::unordered_map, ObjectPtrHash, ObjectPtrEqual> src2deps; /*! \brief Lookup table for the `dst` of dependencies */ - std::unordered_map, ObjectPtrHash, ObjectPtrEqual> dst2deps; + std::unordered_map, ObjectPtrHash, ObjectPtrEqual> dst2deps; /*! \brief The mapping from the buffer to the blocks who write it */ - std::unordered_map, ObjectPtrHash, ObjectPtrEqual> buffer_writers; + std::unordered_map, ObjectPtrHash, ObjectPtrEqual> buffer_writers; static void RegisterReflection() { // No fields to register as they are not visited @@ -282,13 +282,13 @@ class BlockScopeNode : public Object { * \param src The queried block * \return The dependencies */ - TVM_DLL Array GetDepsBySrc(const StmtSRef& src) const; + TVM_DLL ffi::Array GetDepsBySrc(const StmtSRef& src) const; /*! * \brief Get all dependencies whose `dst` equals `dst` * \param dst The queried block * \return The dependencies */ - TVM_DLL Array GetDepsByDst(const StmtSRef& dst) const; + TVM_DLL ffi::Array GetDepsByDst(const StmtSRef& dst) const; }; /*! @@ -305,7 +305,7 @@ class BlockScope : public ObjectRef { * \param child_block_srefs The srefs to the leaf blocks * \note We assume the leaf blocks are given in pre-DFS order */ - TVM_DLL explicit BlockScope(const Array& child_block_srefs); + TVM_DLL explicit BlockScope(const ffi::Array& child_block_srefs); TVM_DEFINE_MUTABLE_NOTNULLABLE_OBJECT_REF_METHODS(BlockScope, ObjectRef, BlockScopeNode); }; diff --git a/include/tvm/tir/buffer.h b/include/tvm/tir/buffer.h index 3cc988f49e38..1ca420e5db2e 100644 --- a/include/tvm/tir/buffer.h +++ b/include/tvm/tir/buffer.h @@ -75,7 +75,7 @@ class BufferNode : public Object { * BufferLoad/BufferStore nodes, and used by the low-level code * generators. */ - Array shape; + ffi::Array shape; /*! * \brief Separators between input axes when generating flattened output axes * @@ -84,17 +84,17 @@ class BufferNode : public Object { * non-flat memory, each entry in axis_separators should be the * first input axis that is part of a new flattened axis. */ - Array axis_separators; + ffi::Array axis_separators; /*! * \brief The strides of each dimension * This can be an empty array, indicating array is contiguous */ - Array strides; + ffi::Array strides; /*! \brief The offset in terms of number of dtype elements (including lanes) */ PrimExpr elem_offset; // Meta data /*! \brief optional name of the buffer */ - String name; + ffi::String name; /*! 
\brief Alignment requirement of data pointer in bytes. */ int data_alignment; /*! @@ -140,7 +140,7 @@ class BufferNode : public Object { * without adjusting for number of lanes. (e.g. The number of * float16x4 elements in a buffer of type float16x4.) */ - Array ElemOffset(Array index) const; + ffi::Array ElemOffset(ffi::Array index) const; static constexpr const char* _type_key = "tir.Buffer"; static constexpr TVMFFISEqHashKind _type_s_eq_hash_kind = kTVMFFISEqHashKindTreeNode; @@ -158,9 +158,10 @@ class Buffer : public ObjectRef { public: // User can specify data_alignment and offset_factor to be 0 // A default value will be picked. - TVM_DLL Buffer(Var data, DataType dtype, Array shape, Array strides, - PrimExpr elem_offset, String name, int data_alignment, int offset_factor, - BufferType buffer_type, Array axis_separators = {}, Span span = Span()); + TVM_DLL Buffer(Var data, DataType dtype, ffi::Array shape, ffi::Array strides, + PrimExpr elem_offset, ffi::String name, int data_alignment, int offset_factor, + BufferType buffer_type, ffi::Array axis_separators = {}, + Span span = Span()); /*! * \brief Return a new buffer that is equivalent with current one * If stride is not needed in the slice, it won't be presented * \return the result buffer. */ - TVM_DLL Buffer MakeSlice(Array begins, Array extents) const; + TVM_DLL Buffer MakeSlice(ffi::Array begins, ffi::Array extents) const; /*! * \brief Get access ptr to the entire buffer. * \param access_mask The access mask */ TVM_DLL PrimExpr access_ptr(int access_mask, DataType ptr_type = DataType::Handle(), int content_lanes = 1, PrimExpr offset = IntImm(DataType::Int(32), 0), - Optional input_extent = std::nullopt) const; + ffi::Optional input_extent = std::nullopt) const; /*! * \brief Create an Expr that does a vector load at begin index. * \param begin The beginning index * \param predicate A vector mask of boolean values indicating which lanes of a vector are to be * loaded. The number of lanes of the mask must be equal to the number of lanes being loaded. */ - TVM_DLL PrimExpr vload(Array begin, DataType dtype, - Optional predicate = std::nullopt) const; + TVM_DLL PrimExpr vload(ffi::Array begin, DataType dtype, + ffi::Optional predicate = std::nullopt) const; /*! * \brief Create a Stmt that does a vector store at begin index. * \param begin The beginning index * \param predicate A vector mask of boolean values indicating which lanes of a vector are to be * stored. The number of lanes of the mask must be equal to the number of lanes in value. */ - TVM_DLL Stmt vstore(Array begin, PrimExpr value, - Optional predicate = std::nullopt) const; + TVM_DLL Stmt vstore(ffi::Array begin, PrimExpr value, + ffi::Optional predicate = std::nullopt) const; /*! * \brief Get a flattened version of the buffer * without adjusting for number of lanes. (e.g. The number of * float16x4 elements in a buffer of type float16x4.) */ - Array OffsetOf(Array index) const; + ffi::Array OffsetOf(ffi::Array index) const; /*! * \brief Return the storage scope associated with this buffer.
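+ *
+ * For instance (a hypothetical shared-memory buffer; decl_buffer is the
+ * helper declared below in this header):
+ *
+ * \code
+ *
+ *   Buffer buf = decl_buffer({16, 16}, DataType::Float(32), "buf", "shared");
+ *   ffi::String s = buf.scope();  // == "shared"
+ *
+ * \endcode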
*/ - TVM_DLL String scope() const; + TVM_DLL ffi::String scope() const; TVM_DEFINE_OBJECT_REF_METHODS(Buffer, ObjectRef, BufferNode); TVM_DEFINE_OBJECT_REF_COW_METHOD(BufferNode); @@ -240,9 +241,9 @@ class Buffer : public ObjectRef { * \return The created buffer. * \sa Buffer for complete constructor. */ -TVM_DLL Buffer decl_buffer(Array shape, DataType dtype = DataType::Float(32), - String name = "buffer", String storage_scope = "", - Optional> axis_separators = std::nullopt, +TVM_DLL Buffer decl_buffer(ffi::Array shape, DataType dtype = DataType::Float(32), + ffi::String name = "buffer", ffi::String storage_scope = "", + ffi::Optional> axis_separators = std::nullopt, Span span = Span()); /*! @@ -265,7 +266,7 @@ class DataProducerNode : public PrimExprConvertibleNode { * \brief Get the shape of the result. * \return The shape. */ - virtual Array GetShape() const = 0; + virtual ffi::Array GetShape() const = 0; /*! * \brief Get the data type of the result. * \return The data type. @@ -275,7 +276,7 @@ class DataProducerNode : public PrimExprConvertibleNode { * \brief Get the name hint of the data producer. * \return The data type. */ - virtual String GetNameHint() const = 0; + virtual ffi::String GetNameHint() const = 0; static constexpr const char* _type_key = "tir.DataProducer"; TVM_DECLARE_BASE_OBJECT_INFO(DataProducerNode, PrimExprConvertibleNode); @@ -303,7 +304,7 @@ class DataProducer : public PrimExprConvertible { * \param compact If the statement has already bound to a compact buffer. * \param memory_scope memory scope of the buffer */ -TVM_DLL tir::Buffer BufferWithOffsetAlignment(Array shape, DataType dtype, +TVM_DLL tir::Buffer BufferWithOffsetAlignment(ffi::Array shape, DataType dtype, std::string name, int data_alignment, int offset_factor, bool compact, std::string memory_scope = ""); diff --git a/include/tvm/tir/builtin.h b/include/tvm/tir/builtin.h index a48a8909c4d3..8cef462b0257 100644 --- a/include/tvm/tir/builtin.h +++ b/include/tvm/tir/builtin.h @@ -298,7 +298,7 @@ TVM_DLL const Op& tvm_struct_set(); /*! * \brief See pseudo code - * Type lookup_param(String param_name) { + * Type lookup_param(ffi::String param_name) { * return __tvm_param__param_name; * } */ diff --git a/include/tvm/tir/data_layout.h b/include/tvm/tir/data_layout.h index 1395c2b6817b..f6f1582517d0 100644 --- a/include/tvm/tir/data_layout.h +++ b/include/tvm/tir/data_layout.h @@ -99,14 +99,14 @@ class LayoutAxis { class LayoutNode : public Object { public: /*! \brief string representation of layout, "" for scalar. */ - String name; + ffi::String name; /*! \brief specify each axis of the layout, * in which the variable name is the name of the axis. * The IterVar's extent indicates the size of the axis, * it is a variable for a primal axis, but a constant for a subordinate axis. * Empty for scalar's layout. */ - Array axes; + ffi::Array axes; static void RegisterReflection() { namespace refl = tvm::ffi::reflection; @@ -125,10 +125,10 @@ class LayoutNode : public Object { */ class Layout : public ObjectRef { public: - explicit Layout(const Array& axes); + explicit Layout(const ffi::Array& axes); /*! \brief construct from a string */ - Layout(const tvm::String& name) : Layout(name.operator std::string()) {} // NOLINT(*) + Layout(const tvm::ffi::String& name) : Layout(name.operator std::string()) {} // NOLINT(*) /*! \brief construct from a string */ Layout(const char* name) : Layout(std::string(name)) {} // NOLINT(*) @@ -300,13 +300,13 @@ class BijectiveLayoutNode : public Object { /*! 
\brief Describes how source axes can be mapped to the destination axes, * e.g., [i0 / 16, i1, i0 % 16] can describe NC -> NC16n */ - Array index_forward_rule; + ffi::Array index_forward_rule; /*! \brief Describes how destination axes can be mapped to the source axes */ - Array index_backward_rule; + ffi::Array index_backward_rule; /*! \brief Describes how source shapes can be mapped to the destination shapes */ - Array shape_forward_rule; + ffi::Array shape_forward_rule; /*! \brief Describes how destination shapes can be mapped to the source shapes */ - Array shape_backward_rule; + ffi::Array shape_backward_rule; /*! \brief The source layout */ Layout src_layout; @@ -344,13 +344,13 @@ class BijectiveLayout : public ObjectRef { TVM_DLL BijectiveLayout(Layout src_layout, Layout dst_layout); // Given the source shape, infer the destination shape. - TVM_DLL Array ForwardShape(const Array& shape) const; + TVM_DLL ffi::Array ForwardShape(const ffi::Array& shape) const; // Given the destination shape, recover the source shape. - TVM_DLL Array BackwardShape(const Array& dst_shape) const; + TVM_DLL ffi::Array BackwardShape(const ffi::Array& dst_shape) const; // Given the destination indices, infer the destination indices. - TVM_DLL Array ForwardIndex(const Array& index) const; + TVM_DLL ffi::Array ForwardIndex(const ffi::Array& index) const; // Given the destination indices, recover the source indices. - TVM_DLL Array BackwardIndex(const Array& dst_index) const; + TVM_DLL ffi::Array BackwardIndex(const ffi::Array& dst_index) const; TVM_DEFINE_OBJECT_REF_METHODS(BijectiveLayout, ObjectRef, BijectiveLayoutNode); }; diff --git a/include/tvm/tir/data_type_rewriter.h b/include/tvm/tir/data_type_rewriter.h index a9185e97af69..88398cf06f06 100644 --- a/include/tvm/tir/data_type_rewriter.h +++ b/include/tvm/tir/data_type_rewriter.h @@ -106,7 +106,7 @@ class IndexDataTypeRewriter : public DataTypeLegalizer { Stmt VisitStmt_(const BufferStoreNode* op) override; Stmt VisitStmt_(const AttrStmtNode* op) override; PrimExpr VisitExpr_(const BufferLoadNode* op) override; - Array VisitIndices(Array indices); + ffi::Array VisitIndices(ffi::Array indices); Stmt VisitStmt_(const IfThenElseNode* op) override; Stmt VisitStmt_(const DeclBufferNode* op) override; Stmt VisitStmt_(const AllocateNode* op) override; @@ -124,7 +124,8 @@ class IndexDataTypeRewriter : public DataTypeLegalizer { Buffer VisitBuffer(const Buffer& buffer); Buffer GetRemappedBuffer(const Buffer& buffer); - Map VisitBlockAnnotations(const Map& annotations); + ffi::Map VisitBlockAnnotations( + const ffi::Map& annotations); BufferRegion VisitBufferRegion(const BufferRegion& region); IterVar VisitIterVar(const IterVar& iter_var); // indicator of index expr to rewrite @@ -132,7 +133,7 @@ class IndexDataTypeRewriter : public DataTypeLegalizer { // indicator of condition bool is_condition_{false}; - Map buffer_remap_; + ffi::Map buffer_remap_; }; /*! diff --git a/include/tvm/tir/expr.h b/include/tvm/tir/expr.h index 1b419b569311..24946332e5a2 100644 --- a/include/tvm/tir/expr.h +++ b/include/tvm/tir/expr.h @@ -49,11 +49,11 @@ namespace tir { using IntImmNode = tvm::IntImmNode; using FloatImmNode = tvm::FloatImmNode; -/*! \brief String constants, only used in asserts. */ +/*! \brief ffi::String constants, only used in asserts. */ class StringImmNode : public PrimExprNode { public: /*! \brief The constant value content. 
*/ - String value; + ffi::String value; static void RegisterReflection() { namespace refl = tvm::ffi::reflection; @@ -70,7 +70,7 @@ class StringImmNode : public PrimExprNode { */ class StringImm : public PrimExpr { public: - TVM_DLL StringImm(String value, Span span = Span()); + TVM_DLL StringImm(ffi::String value, Span span = Span()); TVM_DEFINE_OBJECT_REF_METHODS(StringImm, PrimExpr, StringImmNode); TVM_DEFINE_OBJECT_REF_COW_METHOD(StringImmNode); }; @@ -543,9 +543,9 @@ class BufferLoadNode : public PrimExprNode { /*! \brief The buffer variable. */ Buffer buffer; /*! \brief The indices location to be loaded. */ - Array indices; + ffi::Array indices; /*! \brief The predicate mask for loading values. */ - Optional predicate; + ffi::Optional predicate; static void RegisterReflection() { namespace refl = tvm::ffi::reflection; @@ -581,8 +581,8 @@ class BufferLoadNode : public PrimExprNode { */ class BufferLoad : public PrimExpr { public: - TVM_DLL explicit BufferLoad(Buffer buffer, Array indices, - Optional predicate = std::nullopt, Span span = Span()); + TVM_DLL explicit BufferLoad(Buffer buffer, ffi::Array indices, + ffi::Optional predicate = std::nullopt, Span span = Span()); TVM_DEFINE_OBJECT_REF_METHODS(BufferLoad, PrimExpr, BufferLoadNode); TVM_DEFINE_OBJECT_REF_COW_METHOD(BufferLoadNode); }; @@ -601,7 +601,7 @@ class ProducerLoadNode : public PrimExprNode { /*! \brief The buffer producer. */ DataProducer producer; /*! \brief The location arguments. */ - Array indices; + ffi::Array indices; static void RegisterReflection() { namespace refl = tvm::ffi::reflection; @@ -620,7 +620,8 @@ class ProducerLoadNode : public PrimExprNode { */ class ProducerLoad : public PrimExpr { public: - TVM_DLL explicit ProducerLoad(DataProducer producer, Array indices, Span span = Span()); + TVM_DLL explicit ProducerLoad(DataProducer producer, ffi::Array indices, + Span span = Span()); TVM_DEFINE_OBJECT_REF_METHODS(ProducerLoad, PrimExpr, ProducerLoadNode); TVM_DEFINE_OBJECT_REF_COW_METHOD(ProducerLoadNode); @@ -746,7 +747,7 @@ class CallNode : public PrimExprNode { RelaxExpr op; /*! \brief The arguments. */ - Array args; + ffi::Array args; static void RegisterReflection() { namespace refl = tvm::ffi::reflection; @@ -763,7 +764,7 @@ class CallNode : public PrimExprNode { */ class Call : public PrimExpr { public: - TVM_DLL Call(DataType dtype, RelaxExpr op, Array args, Span span = Span()); + TVM_DLL Call(DataType dtype, RelaxExpr op, ffi::Array args, Span span = Span()); TVM_DEFINE_OBJECT_REF_METHODS(Call, PrimExpr, CallNode); TVM_DEFINE_OBJECT_REF_COW_METHOD(CallNode); }; @@ -776,9 +777,9 @@ class Call : public PrimExpr { class ShuffleNode : public PrimExprNode { public: /*! \brief the input vectors. */ - Array vectors; + ffi::Array vectors; /*! \brief The indices of each element. 
*/ - Array indices; + ffi::Array indices; static void RegisterReflection() { namespace refl = tvm::ffi::reflection; @@ -797,8 +798,8 @@ class ShuffleNode : public PrimExprNode { */ class Shuffle : public PrimExpr { public: - TVM_DLL Shuffle(Array vectors, Array indices, Span span = Span()); - TVM_DLL static PrimExpr Concat(Array vectors, Span span = Span()); + TVM_DLL Shuffle(ffi::Array vectors, ffi::Array indices, Span span = Span()); + TVM_DLL static PrimExpr Concat(ffi::Array vectors, Span span = Span()); TVM_DLL static PrimExpr ExtractElement(PrimExpr vector, int index, Span span = Span()); TVM_DEFINE_OBJECT_REF_METHODS(Shuffle, PrimExpr, ShuffleNode); @@ -813,19 +814,19 @@ class Shuffle : public PrimExpr { class CommReducerNode : public Object { public: /*! \brief The left argument of reducer */ - Array lhs; + ffi::Array lhs; /*! \brief The right argument of reducer */ - Array rhs; + ffi::Array rhs; /*! \brief The result of reducer */ - Array result; + ffi::Array result; /*! * \brief The identity element of reducer, which leaves other * elements unchanged when combined with it, with respect to * the binary operation of this reducer uses. */ - Array identity_element; + ffi::Array identity_element; /*! \brief Function call operator to combine a and b */ - Array operator()(Array a, Array b) const; + ffi::Array operator()(ffi::Array a, ffi::Array b) const; /*! * \brief Span that points to the original source code. * Reserved debug information. @@ -853,8 +854,8 @@ class CommReducerNode : public Object { */ class CommReducer : public ObjectRef { public: - TVM_DLL CommReducer(Array lhs, Array rhs, Array result, - Array identity_element, Span span = Span()); + TVM_DLL CommReducer(ffi::Array lhs, ffi::Array rhs, ffi::Array result, + ffi::Array identity_element, Span span = Span()); TVM_DEFINE_OBJECT_REF_METHODS(CommReducer, ObjectRef, CommReducerNode); }; @@ -865,11 +866,11 @@ class ReduceNode : public PrimExprNode { /*! \brief The commutative combiner */ CommReducer combiner; /*! \brief The source operand */ - Array source; + ffi::Array source; /*! \brief The init operand */ - Array init; + ffi::Array init; /*! \brief The reduction axis */ - Array axis; + ffi::Array axis; /*! * \brief Predicate on the reduction * Only add the body to reduction if condition is true. @@ -899,8 +900,9 @@ class ReduceNode : public PrimExprNode { */ class Reduce : public PrimExpr { public: - TVM_DLL Reduce(CommReducer combiner, Array src, Array rdom, PrimExpr condition, - int value_index, Array init, Span span = Span()); + TVM_DLL Reduce(CommReducer combiner, ffi::Array src, ffi::Array rdom, + PrimExpr condition, int value_index, ffi::Array init, + Span span = Span()); TVM_DEFINE_OBJECT_REF_METHODS(Reduce, PrimExpr, ReduceNode); TVM_DEFINE_OBJECT_REF_COW_METHOD(ReduceNode); @@ -915,7 +917,7 @@ class Reduce : public PrimExpr { * \tparam V the value of the Map. 
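+ *
+ * A small sketch (the variable x and its zero binding are illustrative
+ * assumptions):
+ *
+ * \code
+ *
+ *   ffi::Map<Var, PrimExpr> dmap = {{x, make_const(DataType::Int(32), 0)}};
+ *   auto umap = as_unordered_map(dmap);  // plain std::unordered_map copy
+ *
+ * \endcode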
*/ template -inline std::unordered_map as_unordered_map(const Map& dmap) { +inline std::unordered_map as_unordered_map(const ffi::Map& dmap) { std::unordered_map ret; for (auto kv : dmap) { ret[kv.first] = kv.second; @@ -931,8 +933,8 @@ inline constexpr bool use_default_type_traits_v = false; template <> struct TypeTraits - : public ObjectRefWithFallbackTraitsBase { - TVM_FFI_INLINE static tvm::tir::StringImm ConvertFallbackValue(String value) { + : public ObjectRefWithFallbackTraitsBase { + TVM_FFI_INLINE static tvm::tir::StringImm ConvertFallbackValue(ffi::String value) { return tvm::tir::StringImm(value); } }; diff --git a/include/tvm/tir/function.h b/include/tvm/tir/function.h index 21a97f986d4f..5e46a5c2c1dd 100644 --- a/include/tvm/tir/function.h +++ b/include/tvm/tir/function.h @@ -48,7 +48,7 @@ namespace tir { class PrimFuncNode : public BaseFuncNode { public: /*! \brief Function parameters */ - Array params; + ffi::Array params; /*! \brief The return type of the function. */ Type ret_type; /*! @@ -96,7 +96,7 @@ class PrimFuncNode : public BaseFuncNode { * all usage in the body of the function is done through a * flattened alias of the buffer. */ - Map buffer_map; + ffi::Map buffer_map; /*! \brief The body of the function */ tir::Stmt body; @@ -148,8 +148,8 @@ class PrimFunc : public BaseFunc { * * \param span The location of this object in the source code. */ - TVM_DLL PrimFunc(Array params, Stmt body, Type ret_type = VoidType(), - Map buffer_map = Map(), + TVM_DLL PrimFunc(ffi::Array params, Stmt body, Type ret_type = VoidType(), + ffi::Map buffer_map = ffi::Map(), DictAttrs attrs = DictAttrs(), Span span = Span()); TVM_DEFINE_OBJECT_REF_METHODS(PrimFunc, BaseFunc, PrimFuncNode); @@ -198,7 +198,7 @@ class TensorIntrin : public ObjectRef { * \throws This method throws an exception if the TensorIntrin with the specified name already * exists. */ - TVM_DLL static void Register(String name, TensorIntrin intrin, bool override = false); + TVM_DLL static void Register(ffi::String name, TensorIntrin intrin, bool override = false); /*! * \brief Look up TensorIntrin by name. Raises an exception if not found. @@ -209,7 +209,7 @@ class TensorIntrin : public ObjectRef { * \throws This method throws an exception if the TensorIntrin does not exist and allow_missing is * false. */ - TVM_DLL static Optional Get(String name, bool allow_missing = false); + TVM_DLL static ffi::Optional Get(ffi::String name, bool allow_missing = false); TVM_DEFINE_OBJECT_REF_METHODS(TensorIntrin, ObjectRef, TensorIntrinNode); }; @@ -252,7 +252,7 @@ class TensorIntrin : public ObjectRef { * B[vi, vj] = A[vi, vj] * \endcode */ -PrimFunc Specialize(PrimFunc func, const Map>& param_map); +PrimFunc Specialize(PrimFunc func, const ffi::Map>& param_map); /*! * \brief PrimFunc specific attribute names. @@ -264,7 +264,7 @@ namespace attr { /*! * \brief List of thread IterVar that a DeviceLaunch function corresponds to. * - * Type: Array + * Type: ffi::Array * * We call a device kernel launch function f using the following convention: * diff --git a/include/tvm/tir/index_map.h b/include/tvm/tir/index_map.h index 7c8c9c30c7b5..ef6aa81e0578 100644 --- a/include/tvm/tir/index_map.h +++ b/include/tvm/tir/index_map.h @@ -56,7 +56,7 @@ class IndexMapNode : public Object { * If initial_indices is empty, then final_indices should also be * empty, and no mapping is applied. */ - Array initial_indices; + ffi::Array initial_indices; /*! * \brief Expressions defining the indices after remapping. 
@@ -68,7 +68,7 @@ class IndexMapNode : public Object { * If final_indices is empty, then initial_indices should also be * empty, and the map is an identity function. */ - Array final_indices; + ffi::Array final_indices; /*! * \brief The inverse index map. @@ -80,7 +80,7 @@ * * \note ObjectRef is used here instead of IndexMap to avoid circular reference. */ - Optional inverse_index_map; + ffi::Optional inverse_index_map; /*! * \brief Default constructor @@ -102,7 +102,8 @@ * \returns The indices in the output space. Contains one value for * each expression in `final_indices`. */ - Array MapIndices(const Array& indices, arith::Analyzer* analyzer) const; + ffi::Array MapIndices(const ffi::Array& indices, + arith::Analyzer* analyzer) const; /*! \brief Map a memory range to the output space * * \returns The ranges in the output space. Contains one value for * each expression in `final_indices`. */ - Array MapRanges(const Array& ranges, arith::Analyzer* analyzer) const; + ffi::Array MapRanges(const ffi::Array& ranges, arith::Analyzer* analyzer) const; /*! \brief Map a buffer shape to the output space * * \returns The buffer shape in the output space. Contains one * value for each expression in `final_indices`. */ - Array MapShape(const Array& shape, arith::Analyzer* analyzer) const; + ffi::Array MapShape(const ffi::Array& shape, arith::Analyzer* analyzer) const; /*! \brief Map a Tensor according to this index map * * \param f_name_map Optional function to specify the stringified name of the variables. * \return The stringified lambda expression in Python. */ - String ToPythonString( - const std::function(const Var& var)>& f_name_map = nullptr) const; + ffi::String ToPythonString( + const std::function(const Var& var)>& f_name_map = nullptr) const; static void RegisterReflection() { namespace refl = tvm::ffi::reflection; @@ -174,8 +175,8 @@ class IndexMap : public ObjectRef { * \param final_indices Expressions defining the indices after remapping. * \param inverse_index_map The optional pre-defined inverse index map */ - IndexMap(Array initial_indices, Array final_indices, - Optional inverse_index_map = std::nullopt); + IndexMap(ffi::Array initial_indices, ffi::Array final_indices, + ffi::Optional inverse_index_map = std::nullopt); /*! * \brief Create an index map from a packed function * \param ndim The number of dimensions * \param func The function to be applied * \param inverse_index_map The optional pre-defined inverse index map * \return The created index map */ - static IndexMap FromFunc(int ndim, ffi::TypedFunction(Array)> func, - Optional inverse_index_map = std::nullopt); + static IndexMap FromFunc(int ndim, ffi::TypedFunction(ffi::Array)> func, + ffi::Optional inverse_index_map = std::nullopt); /*! \brief Generate the inverse mapping. * * If the user has supplied an `inverse_index_map`, that map is * assumed to be correct and bijective, and is returned. */ - IndexMap Inverse(Array initial_ranges, arith::Analyzer* analyzer) const; + IndexMap Inverse(ffi::Array initial_ranges, arith::Analyzer* analyzer) const; /*! \brief Rename the variables in the index map and ensure the names are unique. * * \return The renamed index map.
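+ *
+ * A hedged sketch of a name-mapping callback (the "ax_" prefixing policy is
+ * an illustrative assumption):
+ *
+ * \code
+ *
+ *   IndexMap renamed = index_map.RenameVariables(
+ *       [](const Var& v) -> ffi::Optional<ffi::String> {
+ *         return ffi::String("ax_" + std::string(v->name_hint));
+ *       });
+ *
+ * \endcode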
*/ IndexMap RenameVariables( - const std::function(const Var& var)>& f_name_map = nullptr) const; + const std::function(const Var& var)>& f_name_map = nullptr) const; /*! \brief Generate the inverse mapping. * @@ -217,7 +218,7 @@ class IndexMap : public ObjectRef { * \return The inverted index map, along with the predicate for * which the inverse maps to a valid range. */ - std::pair NonSurjectiveInverse(Array initial_ranges, + std::pair NonSurjectiveInverse(ffi::Array initial_ranges, arith::Analyzer* analyzer) const; TVM_DEFINE_OBJECT_REF_METHODS(IndexMap, ObjectRef, IndexMapNode); @@ -229,7 +230,7 @@ class IndexMap : public ObjectRef { * \param f_subst The substitution function */ IndexMap Substitute(const IndexMap& index_map, - std::function(const Var& var)> f_subst); + std::function(const Var& var)> f_subst); } // namespace tir } // namespace tvm diff --git a/include/tvm/tir/op.h b/include/tvm/tir/op.h index 3dda3f7c63c5..e1be6834fe2b 100644 --- a/include/tvm/tir/op.h +++ b/include/tvm/tir/op.h @@ -566,7 +566,7 @@ TVM_DLL PrimExpr isinf(PrimExpr x, Span span = Span()); * \param span The location of this operation in the source. * \return The result. */ -TVM_DLL PrimExpr sum(PrimExpr source, Array axis, Array init = {}, +TVM_DLL PrimExpr sum(PrimExpr source, ffi::Array axis, ffi::Array init = {}, Span span = Span()); /*! @@ -576,7 +576,7 @@ TVM_DLL PrimExpr sum(PrimExpr source, Array axis, Array * \param init The value with which to initialize the output. * \param span The location of this operation in the source. */ -TVM_DLL PrimExpr all(PrimExpr source, Array axis, Array init = {}, +TVM_DLL PrimExpr all(PrimExpr source, ffi::Array axis, ffi::Array init = {}, Span span = Span()); /*! @@ -587,7 +587,7 @@ TVM_DLL PrimExpr all(PrimExpr source, Array axis, Array * \param span The location of this operation in the source. * \return The result. */ -TVM_DLL PrimExpr any(PrimExpr source, Array axis, Array init = {}, +TVM_DLL PrimExpr any(PrimExpr source, ffi::Array axis, ffi::Array init = {}, Span span = Span()); /*! @@ -598,7 +598,7 @@ TVM_DLL PrimExpr any(PrimExpr source, Array axis, Array * \param span The location of this operation in the source. * \return The result. */ -TVM_DLL PrimExpr max(PrimExpr source, Array axis, Array init = {}, +TVM_DLL PrimExpr max(PrimExpr source, ffi::Array axis, ffi::Array init = {}, Span span = Span()); /*! @@ -609,7 +609,7 @@ TVM_DLL PrimExpr max(PrimExpr source, Array axis, Array * \param span The location of this operation in the source. * \return The result. */ -TVM_DLL PrimExpr min(PrimExpr source, Array axis, Array init = {}, +TVM_DLL PrimExpr min(PrimExpr source, ffi::Array axis, ffi::Array init = {}, Span span = Span()); /*! @@ -620,8 +620,8 @@ TVM_DLL PrimExpr min(PrimExpr source, Array axis, Array * \param span The location of this operation in the source. * \return The result. */ -TVM_DLL PrimExpr prod(PrimExpr source, Array axis, Array init = {}, - Span span = Span()); +TVM_DLL PrimExpr prod(PrimExpr source, ffi::Array axis, + ffi::Array init = {}, Span span = Span()); /*! * \brief Calculate floor(x) @@ -883,7 +883,7 @@ inline bool is_const_number(const PrimExpr& x); * \tparam FReduce The type of the reduction. 
*/ template -inline PrimExpr foldl(FReduce freduce, PrimExpr init_value, const Array& values, +inline PrimExpr foldl(FReduce freduce, PrimExpr init_value, const ffi::Array& values, Span span = Span()) { for (PrimExpr val : values) { init_value = freduce(init_value, val, span); diff --git a/include/tvm/tir/op_attr_types.h b/include/tvm/tir/op_attr_types.h index 883477dd645e..c87ccd741a5e 100644 --- a/include/tvm/tir/op_attr_types.h +++ b/include/tvm/tir/op_attr_types.h @@ -39,7 +39,7 @@ namespace tir { /*! * \brief Global symbol of the op after lowering. */ -using TGlobalSymbol = String; +using TGlobalSymbol = ffi::String; /*! * \brief Whether the op is overloaded for vector form. @@ -59,7 +59,7 @@ using FLegalize = ffi::TypedFunction; /*! * \brief The operator's name in TVMScript printer */ -using TScriptPrinterName = String; +using TScriptPrinterName = ffi::String; /*! * \brief Specifies that TVMScript printer prints the dtype as the first/last argument. diff --git a/include/tvm/tir/schedule/instruction.h b/include/tvm/tir/schedule/instruction.h index 146d3e8ec9bb..aff2912a88e3 100644 --- a/include/tvm/tir/schedule/instruction.h +++ b/include/tvm/tir/schedule/instruction.h @@ -42,8 +42,9 @@ class Schedule; * \param decision Decisions made on the instruction * \return The functor returns an array of output random variables */ -using FInstructionApply = ffi::TypedFunction( - Schedule sch, const Array& inputs, const Array& attrs, const Any& decision)>; +using FInstructionApply = + ffi::TypedFunction(Schedule sch, const ffi::Array& inputs, + const ffi::Array& attrs, const Any& decision)>; /*! * \brief Type of the functor that converts the instruction to a statement in python syntax @@ -54,8 +55,8 @@ using FInstructionApply = ffi::TypedFunction( * \return A string representing the python api call */ using FInstructionAsPython = - ffi::TypedFunction& inputs, const Array& attrs, - const Any& decision, const Array& outputs)>; + ffi::TypedFunction& inputs, const ffi::Array& attrs, + const Any& decision, const ffi::Array& outputs)>; /*! * \brief Type of the functor that serialize its attributes to JSON @@ -63,7 +64,7 @@ using FInstructionAsPython = * \return An array, serialized attributes * \note This functor is nullable */ -using FInstructionAttrsAsJSON = ffi::TypedFunction attrs)>; +using FInstructionAttrsAsJSON = ffi::TypedFunction attrs)>; /*! * \brief Type of the functor that deserialize its attributes from JSON @@ -71,7 +72,7 @@ using FInstructionAttrsAsJSON = ffi::TypedFunction attrs)>; * \return An array, deserialized attributes * \note This functor is nullable */ -using FInstructionAttrsFromJSON = ffi::TypedFunction(ObjectRef json_attrs)>; +using FInstructionAttrsFromJSON = ffi::TypedFunction(ObjectRef json_attrs)>; /*! * \brief Kind of an instruction, e.g. Split, Reorder, etc. @@ -88,7 +89,7 @@ using FInstructionAttrsFromJSON = ffi::TypedFunction(ObjectRef json_a class InstructionKindNode : public runtime::Object { public: /*! \brief The name of a kind of instructions */ - String name; + ffi::String name; /*! * \brief Indicates if the instruction is pure, i.e. removing it alone doesn't mutate the schedule * state. 
For example, the instruction `GetBlock` is pure because it changes @@ -136,7 +137,7 @@ class InstructionKind : public runtime::ObjectRef { * \param name The registered name of the InstructionKind * \return The InstructionKind retrieved */ - static InstructionKind Get(const String& name); + static InstructionKind Get(const ffi::String& name); TVM_DEFINE_OBJECT_REF_METHODS(InstructionKind, runtime::ObjectRef, InstructionKindNode); }; @@ -156,20 +157,20 @@ class InstructionNode : public runtime::Object { * - String * - null pointer */ - Array inputs; + ffi::Array inputs; /*! * \brief The attributes of the instruction. Similar to attributes of an operator, * attributes of an instruction are arbitrary constant metadata required by the instructions. * For example, the name of the block to be retrieved in `GetBlock`. */ - Array attrs; + ffi::Array attrs; /*! \brief The output random variables of the instruction, and the type of each element can be one * of the following: * - BlockRV * - LoopRV * - ExprRV, atomic variables only, won't be constants or composite PrimExpr */ - Array outputs; + ffi::Array outputs; static void RegisterReflection() { namespace refl = tvm::ffi::reflection; @@ -197,8 +198,8 @@ class Instruction : public runtime::ObjectRef { * \param attrs The attributes of the instruction * \param outputs The output random variables of the instruction */ - explicit Instruction(InstructionKind kind, Array inputs, Array attrs, - Array outputs); + explicit Instruction(InstructionKind kind, ffi::Array inputs, ffi::Array attrs, + ffi::Array outputs); TVM_DEFINE_OBJECT_REF_METHODS(Instruction, runtime::ObjectRef, InstructionNode); }; @@ -235,7 +236,7 @@ class Instruction : public runtime::ObjectRef { /*! \brief An entry in the registry of InstructionKind */ class InstructionKindRegEntry { public: - static InstructionKindRegEntry& RegisterOrGet(const String& name); + static InstructionKindRegEntry& RegisterOrGet(const ffi::String& name); InstructionKindRegEntry& set_name() { get_mutable()->name = this->name; @@ -276,7 +277,7 @@ class InstructionKindRegEntry { } /*! \brief The name of the registry entry */ - String name; + ffi::String name; /*! \brief The instruction kind */ InstructionKind inst_kind_; template diff --git a/include/tvm/tir/schedule/schedule.h b/include/tvm/tir/schedule/schedule.h index 9fbb9981e55c..38003fc37e7b 100644 --- a/include/tvm/tir/schedule/schedule.h +++ b/include/tvm/tir/schedule/schedule.h @@ -120,9 +120,9 @@ class ScheduleNode : public runtime::Object { /*! \return The internal state of scheduling */ virtual ScheduleState state() const = 0; /*! \return The internally maintained trace of scheduling program execution */ - virtual Optional trace() const = 0; + virtual ffi::Optional trace() const = 0; /*! \return The GlobalVar of the func that the schedule is currently working on */ - virtual Optional func_working_on() const = 0; + virtual ffi::Optional func_working_on() const = 0; /*! * \brief Instruct the schedule to work on a function in the IRModule. * @@ -137,7 +137,7 @@ class ScheduleNode : public runtime::Object { * * \sa GetBlock */ - virtual void WorkOn(const String& func_name) = 0; + virtual void WorkOn(const ffi::String& func_name) = 0; /*! 
* \brief Returns a copy of the schedule, including both its state and its symbol table, * guaranteeing that @@ -230,8 +230,9 @@ class ScheduleNode : public runtime::Object { * \param decision The sampling decision * \return The random variable sampled from candidates */ - virtual ExprRV SampleCategorical(const Array& candidates, const Array& probs, - Optional decision = std::nullopt) = 0; + virtual ExprRV SampleCategorical(const ffi::Array& candidates, + const ffi::Array& probs, + ffi::Optional decision = std::nullopt) = 0; /*! * \brief Sample the factors to perfect tile a specific loop * \param loop_rv The loop to be tiled * @@ -240,8 +241,9 @@ class ScheduleNode : public runtime::Object { * \param decision The sampling decision * \return A list of length `n`, the random perfect tile sizes sampled */ - virtual Array SamplePerfectTile(const LoopRV& loop_rv, int n, int max_innermost_factor, - Optional> decision = std::nullopt) = 0; + virtual ffi::Array SamplePerfectTile( + const LoopRV& loop_rv, int n, int max_innermost_factor, + ffi::Optional> decision = std::nullopt) = 0; /*! * \brief Sample the factors to a partitioned tile for a specific loop * * @@ -257,9 +259,9 @@ class ScheduleNode : public runtime::Object { * \param decision The sampling decision * \return A list of length `n`, the random partitioned tile sizes sampled */ - virtual Array SamplePartitionedTile(const LoopRV& loop_rv, int n, int partition_pos, - int innerpart_factor, - Optional> decision = std::nullopt) = 0; + virtual ffi::Array SamplePartitionedTile( + const LoopRV& loop_rv, int n, int partition_pos, int innerpart_factor, + ffi::Optional> decision = std::nullopt) = 0; /*! * \brief Sample a compute-at location of the given block * \param block_rv The block whose compute-at location is to be sampled * @@ -267,7 +269,7 @@ class ScheduleNode : public runtime::Object { * \return The sampled loop where the input block is to be computed at */ virtual LoopRV SampleComputeLocation(const BlockRV& block_rv, - Optional decision = std::nullopt) = 0; + ffi::Optional decision = std::nullopt) = 0; /******** Schedule: Get blocks & loops ********/ /*! * \brief Retrieve a block in a specific function with its name * @@ -284,40 +286,40 @@ class ScheduleNode : public runtime::Object { * * \sa WorkOn */ - virtual BlockRV GetBlock(const String& name, - const Optional& func_name = std::nullopt) = 0; + virtual BlockRV GetBlock(const ffi::String& name, + const ffi::Optional& func_name = std::nullopt) = 0; /*! * \brief Get the parent loops of the block in its scope, from outer to inner * \param block_rv The query block * \return A list of loops above the given block in its scope, from outer to inner */ - virtual Array GetLoops(const BlockRV& block_rv) = 0; + virtual ffi::Array GetLoops(const BlockRV& block_rv) = 0; /*! * \brief Get the leaf blocks of a specific scope * \param block_rv The block where the scope is rooted * \return A list of child blocks */ - virtual Array GetChildBlocks(const BlockRV& block_rv) = 0; + virtual ffi::Array GetChildBlocks(const BlockRV& block_rv) = 0; /*! * \brief Get the leaf blocks under a specific loop * \param loop_rv The loop under which collecting is conducted * \return A list of child blocks */ - virtual Array GetChildBlocks(const LoopRV& loop_rv) = 0; + virtual ffi::Array GetChildBlocks(const LoopRV& loop_rv) = 0; /*!
* \brief Get the producer of a specific block, under the same block scope * \param block_rv The block in the query * \return A list of blocks, the producers of the given block under the same scope of the given * block */ - virtual Array GetProducers(const BlockRV& block_rv) = 0; + virtual ffi::Array GetProducers(const BlockRV& block_rv) = 0; /*! * \brief Get the consumers of a specific block, under the same block scope * \param block_rv The block to be queried * \return A list of blocks, the consumers of the given block under the same scope of the given * block */ - virtual Array GetConsumers(const BlockRV& block_rv) = 0; + virtual ffi::Array GetConsumers(const BlockRV& block_rv) = 0; /*! * \brief Get the list of output blocks within the given scope * An output block is a block which has at least one buffer being written * @@ -326,7 +328,7 @@ class ScheduleNode : public runtime::Object { * \return A list of all blocks that write to some output buffer * block */ - virtual Array GetOutputBlocks(const BlockRV& scope_block_rv) = 0; + virtual ffi::Array GetOutputBlocks(const BlockRV& scope_block_rv) = 0; /******** Schedule: Transform loops ********/ /*! * \brief Merge a list of loops into one. The loops under their LCA requires: * 1) The loops can't have annotations or thread bindings. * @@ -337,7 +339,7 @@ class ScheduleNode : public runtime::Object { * \param loop_rvs The loops to be merged * \return The new loop after merge */ - virtual LoopRV Merge(const Array& loop_rvs) = 0; + virtual LoopRV Merge(const ffi::Array& loop_rvs) = 0; /*! * \brief Fuse a list of consecutive loops into one. It requires: * 1) The loops can't have annotations or thread bindings. * @@ -348,7 +350,7 @@ class ScheduleNode : public runtime::Object { * \param preserve_unit_iters Whether or not to preserve unit iterators in block bindings * \return The new loop after fusion */ - virtual LoopRV Fuse(const Array& loop_rvs, bool preserve_unit_iters = true) = 0; + virtual LoopRV Fuse(const ffi::Array& loop_rvs, bool preserve_unit_iters = true) = 0; /*! * \brief Split a loop into a list of consecutive loops. It requires: * 1) The loop can't have annotation or thread binding. * @@ -361,9 +363,10 @@ class ScheduleNode : public runtime::Object { * schedule writer knows are divisible by the loop bound. Warning: enabling this feature may * result in incorrect code generation if not used carefully. \return The new loops after split. */ - virtual Array Split(const LoopRV& loop_rv, const Array>& factors, - bool preserve_unit_iters = true, - bool disable_predication = false) = 0; + virtual ffi::Array Split(const LoopRV& loop_rv, + const ffi::Array>& factors, + bool preserve_unit_iters = true, + bool disable_predication = false) = 0; /*! * \brief Partition the loops into sequence of multiple loops * 1) The loop can't have annotation or thread binding. * @@ -373,8 +376,9 @@ class ScheduleNode : public runtime::Object { * \param preserve_unit_iters Whether or not to preserve unit iterators in block bindings * \return The new loops after partition */ - virtual Array LoopPartition(const LoopRV& loop_rv, const Array>& factors, - bool preserve_unit_iters = true) = 0; + virtual ffi::Array LoopPartition(const LoopRV& loop_rv, + const ffi::Array>& factors, + bool preserve_unit_iters = true) = 0; /*! * \brief Reorder a list of loops. It doesn't require the loops to be consecutive. * It requires: * @@ -387,13 +391,14 @@ class ScheduleNode : public runtime::Object { * 4) No duplicated loops are allowed in the arguments.
* \param ordered_loop_rvs The loops in the new order */ - virtual void Reorder(const Array& ordered_loop_rvs) = 0; + virtual void Reorder(const ffi::Array& ordered_loop_rvs) = 0; /*! * \brief Reorder the itervars inside a block. * \param block_rv The block to be transformed. * \param new_order The new itervar order. */ - virtual void ReorderBlockIterVar(const BlockRV& block_rv, const Array new_order) = 0; + virtual void ReorderBlockIterVar(const BlockRV& block_rv, + const ffi::Array new_order) = 0; /*! * \brief Create a new unit loop on top of the specific block. * \param block_rv The block above which the new loop is created @@ -438,7 +443,7 @@ class ScheduleNode : public runtime::Object { * \param loop_rv The loop to be bound to the thread axis * \param thread_axis The thread axis to be bound to the loop */ - virtual void Bind(const LoopRV& loop_rv, const String& thread_axis) = 0; + virtual void Bind(const LoopRV& loop_rv, const ffi::String& thread_axis) = 0; /*! * \brief Unroll the input loop. It requires nothing * \param loop_rv The loop to be unrolled @@ -456,8 +461,8 @@ class ScheduleNode : public runtime::Object { * \return The cache stage block. */ virtual BlockRV CacheRead(const BlockRV& block_rv, int read_buffer_index, - const String& storage_scope, - const Array consumer_blocks = {}) = 0; + const ffi::String& storage_scope, + const ffi::Array consumer_blocks = {}) = 0; /*! * \brief Create a block that writes a buffer region into a write cache. It requires: * 1) There is only one block who writes the target buffer. @@ -469,8 +474,8 @@ class ScheduleNode : public runtime::Object { * \return The cache stage block. */ virtual BlockRV CacheWrite(const BlockRV& block_rv, int write_buffer_index, - const String& storage_scope, - const Array consumer_blocks = {}) = 0; + const ffi::String& storage_scope, + const ffi::Array consumer_blocks = {}) = 0; /*! * \brief Create a block that reads a buffer region into a read cache. It requires: * 1) There is at most one block who writes the buffer in the scope. @@ -484,7 +489,7 @@ class ScheduleNode : public runtime::Object { * \return The cache stage block. */ virtual BlockRV ReindexCacheRead(const BlockRV& block_rv, int read_buffer_index, - const String& storage_scope, const IndexMap& index_map) = 0; + const ffi::String& storage_scope, const IndexMap& index_map) = 0; /*! * \brief Create a block that writes a buffer region into a write cache. It requires: * 1) There is only one block who writes the target buffer. @@ -498,7 +503,8 @@ class ScheduleNode : public runtime::Object { * \return The cache stage block. */ virtual BlockRV ReindexCacheWrite(const BlockRV& block_rv, int write_buffer_index, - const String& storage_scope, const IndexMap& index_map) = 0; + const ffi::String& storage_scope, + const IndexMap& index_map) = 0; /*! * \brief Create 2 blocks that read&write a buffer region into a read/write cache. * It requires the target block both read & write the target buffer. @@ -507,8 +513,8 @@ class ScheduleNode : public runtime::Object { * \param storage_scope The target storage scope * \return The cache stage blocks, cache read block together with cache write block. */ - virtual Array CacheInplace(const BlockRV& block_rv, int read_buffer_index, - const String& storage_scope) = 0; + virtual ffi::Array CacheInplace(const BlockRV& block_rv, int read_buffer_index, + const ffi::String& storage_scope) = 0; /*! * \brief Create a block to cache precomputed index for later use. * if there is no index computation, keep unchanged. 
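The storage-scope arguments above are now ffi::String, which still converts implicitly from string literals, so existing C++ call sites typically compile unchanged. A minimal sketch, assuming `sch` is a tir::Schedule and `block` a valid BlockRV (hypothetical names):

    // Stage the block's first read buffer into shared memory; the returned
    // BlockRV names the newly created cache stage.
    tvm::tir::BlockRV cached = sch->CacheRead(block, /*read_buffer_index=*/0,
                                              /*storage_scope=*/"shared");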
@@ -517,8 +523,8 @@ class ScheduleNode : public runtime::Object { * \param cse_thresh The repeat threshold that determines a common sub expr * \return The cache stage blocks. */ - virtual Array CacheIndex(const BlockRV& block_rv, const String& storage_scope, - int cse_thresh) = 0; + virtual ffi::Array CacheIndex(const BlockRV& block_rv, const ffi::String& storage_scope, + int cse_thresh) = 0; /*! * \brief Create a block that read/write a buffer region into a read/write cache with reindexing. * The layout of the cache will be the same as by the iterators of the block that reads/writes the @@ -534,9 +540,9 @@ class ScheduleNode : public runtime::Object { BufferIndexType buffer_index_type) = 0; /******** Schedule: Data movement ********/ virtual BlockRV ReadAt(const LoopRV& loop_rv, const BlockRV& block_rv, int read_buffer_index, - const String& storage_scope) = 0; + const ffi::String& storage_scope) = 0; virtual BlockRV WriteAt(const LoopRV& loop_rv, const BlockRV& block_rv, int write_buffer_index, - const String& storage_scope) = 0; + const ffi::String& storage_scope) = 0; /******** Schedule: Compute location ********/ /*! * \brief Move a producer block under the specific loop, and regenerate the @@ -661,7 +667,8 @@ class ScheduleNode : public runtime::Object { * \param buffer_index The index of the buffer in block's write region * \param storage_scope The storage scope to be set */ - virtual void SetScope(const BlockRV& block_rv, int buffer_index, const String& storage_scope) = 0; + virtual void SetScope(const BlockRV& block_rv, int buffer_index, + const ffi::String& storage_scope) = 0; /*! * \brief Set the data type of a buffer, where the buffer is specified by a block and a * write-index @@ -671,7 +678,8 @@ class ScheduleNode : public runtime::Object { * \param buffer_index the index of the buffer in block's write region * \param dtype The data type to be set */ - virtual void UnsafeSetDType(const BlockRV& block_rv, int buffer_index, const String& dtype) = 0; + virtual void UnsafeSetDType(const BlockRV& block_rv, int buffer_index, + const ffi::String& dtype) = 0; /******** Schedule: Blockize & Tensorize ********/ /*! * \brief Convert the subtree rooted at a specific loop into a block. @@ -686,14 +694,14 @@ class ScheduleNode : public runtime::Object { * \param preserve_unit_iters Whether or not to preserve unit iterators in block bindings * \return the new block */ - virtual BlockRV Blockize(const Array& blocks, bool preserve_unit_iters = true) = 0; + virtual BlockRV Blockize(const ffi::Array& blocks, bool preserve_unit_iters = true) = 0; /*! * \brief Tensorize the computation enclosed by loop with the tensor intrin. * \param loop_rv The loop to be tensorized * \param intrin Name of the tensor intrinsic * \param preserve_unit_iters Whether or not to preserve unit iterators in block bindings */ - virtual void Tensorize(const LoopRV& loop_rv, const String& intrin, + virtual void Tensorize(const LoopRV& loop_rv, const ffi::String& intrin, bool preserve_unit_iters = true) = 0; /*! * \brief Tensorize the computation enclosed by loop with the tensor intrin. 
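For the tensorize primitives above, the intrinsic is still resolved by name via `TensorIntrin::Get`; only the string type changes. A minimal sketch, where "demo.intrin" is a placeholder for an intrinsic registered elsewhere:

    // Replace the computation under `loop_rv` with a previously registered
    // tensor intrinsic (placeholder name, for illustration only).
    sch->Tensorize(loop_rv, /*intrin=*/"demo.intrin", /*preserve_unit_iters=*/true);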
@@ -701,7 +709,7 @@ class ScheduleNode : public runtime::Object { * \param intrin Name of the tensor intrinsic * \param preserve_unit_iters Whether or not to preserve unit iterators in block bindings */ - virtual void Tensorize(const BlockRV& block_rv, const String& intrin, + virtual void Tensorize(const BlockRV& block_rv, const ffi::String& intrin, bool preserve_unit_iters = true) = 0; /******** Schedule: Annotation ********/ /*! * \brief Annotate a loop with a key value pair * \param ann_key The annotation key * \param ann_val The annotation value, a string or an ExprRV */ - virtual void Annotate(const LoopRV& loop_rv, const String& ann_key, const Any& ann_val) = 0; + virtual void Annotate(const LoopRV& loop_rv, const ffi::String& ann_key, const Any& ann_val) = 0; /*! * \brief Annotate a block with a key value pair * \param block_rv The block to be annotated * \param ann_key The annotation key * \param ann_val The annotation value, a string or an ExprRV */ - virtual void Annotate(const BlockRV& block_rv, const String& ann_key, const Any& ann_val) = 0; + virtual void Annotate(const BlockRV& block_rv, const ffi::String& ann_key, + const Any& ann_val) = 0; /*! * \brief Unannotate a loop's annotation with key ann_key * \param loop_rv The loop to be unannotated * \param ann_key The annotation key */ - virtual void Unannotate(const LoopRV& loop_rv, const String& ann_key) = 0; + virtual void Unannotate(const LoopRV& loop_rv, const ffi::String& ann_key) = 0; /*! * \brief Unannotate a block's annotation with key ann_key * \param block_rv The block to be unannotated * \param ann_key The annotation key */ - virtual void Unannotate(const BlockRV& block_rv, const String& ann_key) = 0; + virtual void Unannotate(const BlockRV& block_rv, const ffi::String& ann_key) = 0; /******** Schedule: Layout transformation ********/ /*! * @@ -766,7 +775,7 @@ class ScheduleNode : public runtime::Object { */ virtual void TransformLayout(const BlockRV& block_rv, int buffer_index, BufferIndexType buffer_index_type, const IndexMap& index_map, - const Optional& pad_value = std::nullopt, + const ffi::Optional& pad_value = std::nullopt, bool assume_injective_transform = false) = 0; /*! * @@ -789,7 +798,7 @@ class ScheduleNode : public runtime::Object { */ virtual void SetAxisSeparator(const BlockRV& block_rv, int buffer_index, BufferIndexType buffer_index_type, - const Array& axis_separators) = 0; + const ffi::Array& axis_separators) = 0; /******** Schedule: Padding ********/ /*! * @@ -818,7 +827,7 @@ class ScheduleNode : public runtime::Object { * The sizes of the producer buffers are inferred from the padding size of the Einsum computation. * The producer buffers are padded by the initial value of the corresponding reduction. */ - virtual void PadEinsum(const BlockRV& block_rv, const Array& padding) = 0; + virtual void PadEinsum(const BlockRV& block_rv, const ffi::Array& padding) = 0; /******** Schedule: Buffer transformation ********/ /*! * @@ -858,8 +867,8 @@ class ScheduleNode : public runtime::Object { * \param buf_type The buffer type: read/write * \param buf_index_array The array of buffer indices we hide access. */ - virtual void UnsafeHideBufferAccess(const BlockRV& block_rv, const String& buf_type, - const Array& buf_index_array) = 0; + virtual void UnsafeHideBufferAccess(const BlockRV& block_rv, const ffi::String& buf_type, + const ffi::Array& buf_index_array) = 0; }; /*!
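A minimal sketch combining the annotation and layout APIs above under the new signatures, assuming `sch`, `block`, and `index_map` already exist (hypothetical names; the annotation key is illustrative, not a recognized pragma):

    // Attach a string annotation (ffi::String converts to Any), then remap
    // the layout of the block's first written buffer without padding.
    sch->Annotate(block, /*ann_key=*/"demo.note", /*ann_val=*/tvm::ffi::String("checked"));
    sch->TransformLayout(block, /*buffer_index=*/0,
                         tvm::tir::BufferIndexType::kWrite, index_map,
                         /*pad_value=*/std::nullopt,
                         /*assume_injective_transform=*/false);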
diff --git a/include/tvm/tir/schedule/state.h b/include/tvm/tir/schedule/state.h index 99994d2bf68a..8cb0053df79c 100644 --- a/include/tvm/tir/schedule/state.h +++ b/include/tvm/tir/schedule/state.h @@ -147,7 +147,7 @@ class ScheduleStateNode : public Object { * \note The reuse of loop srefs are detected automatically according to the reuse of loop vars. */ TVM_DLL void Replace(const tir::StmtSRef& src_sref, const Stmt& tgt_stmt, - const Map& block_sref_reuse); + const ffi::Map& block_sref_reuse); /*! * \brief Trigger the verification according to the `debug_mask` bitmask. * 1) If the bitmask `kVerifySRefTree` is on, verify the correctness of the sref tree. diff --git a/include/tvm/tir/schedule/trace.h b/include/tvm/tir/schedule/trace.h index 6e3dd29551ef..b20e070daf88 100644 --- a/include/tvm/tir/schedule/trace.h +++ b/include/tvm/tir/schedule/trace.h @@ -37,8 +37,8 @@ class Trace; * \return A new decision */ using FTraceDecisionProvider = - ffi::TypedFunction& inputs, - const Array& attrs, const Any& decision)>; + ffi::TypedFunction& inputs, + const ffi::Array& attrs, const Any& decision)>; /*! * \brief An execution trace of a scheduling program @@ -58,9 +58,9 @@ using FTraceDecisionProvider = class TraceNode : public runtime::Object { public: /*! \brief The instructions invoked so far in the program execution */ - Array insts; + ffi::Array insts; /*! \brief The random decisions made upon those instructions */ - Map decisions; + ffi::Map decisions; static void RegisterReflection() { namespace refl = tvm::ffi::reflection; @@ -89,14 +89,14 @@ class TraceNode : public runtime::Object { * \param inst The new instruction to be appended * \param decision The random decision made on this instruction * The type of `decision` depends on the instruction, e.g. - * the decision of `SamplePerfectTile` has type `Array` + * the decision of `SamplePerfectTile` has type `ffi::Array` */ void Append(Instruction inst, Any decision); /*! * \brief Remove the last instruction, along with the decision made on that instruction, if any * \return The instruction removed; std::nullopt if the trace is empty */ - Optional Pop(); + ffi::Optional Pop(); /*! * \brief Apply the trace to a TensorIR schedule * \param sch The schedule to be applied onto @@ -118,7 +118,7 @@ class TraceNode : public runtime::Object { * \param remove_postproc If postprocessing instructions are removed * \return A sequence of python statements */ - Array AsPython(bool remove_postproc) const; + ffi::Array AsPython(bool remove_postproc) const; /*! * \brief Create a new trace with an instruction whose decision is changed, * assuming this instruction exists in the resulting trace @@ -149,7 +149,7 @@ class Trace : public runtime::ObjectRef { * \param insts The instructions used * \param decisions The decisions made in sampling */ - explicit Trace(Array insts, Map decisions); + explicit Trace(ffi::Array insts, ffi::Map decisions); /*! * \brief Apply a JSON-serialized trace to a TensorIR schedule * \param json The JSON-serialized trace diff --git a/include/tvm/tir/stmt.h b/include/tvm/tir/stmt.h index b8c7ea594abe..705359118d68 100644 --- a/include/tvm/tir/stmt.h +++ b/include/tvm/tir/stmt.h @@ -117,7 +117,7 @@ class AttrStmtNode : public StmtNode { /*! \brief this is attribute about certain node */ ffi::Any node; /*! \brief the type key of the attribute */ - String attr_key; + ffi::String attr_key; /*! \brief The attribute value, value is well defined at current scope. */ PrimExpr value; /*! 
\brief The body statement to be executed */ @@ -142,7 +142,8 @@ class AttrStmtNode : public StmtNode { */ class AttrStmt : public Stmt { public: - TVM_DLL AttrStmt(ffi::Any node, String attr_key, PrimExpr value, Stmt body, Span span = Span()); + TVM_DLL AttrStmt(ffi::Any node, ffi::String attr_key, PrimExpr value, Stmt body, + Span span = Span()); TVM_DEFINE_OBJECT_REF_METHODS(AttrStmt, Stmt, AttrStmtNode); TVM_DEFINE_OBJECT_REF_COW_METHOD(AttrStmtNode); @@ -204,9 +205,9 @@ class BufferStoreNode : public StmtNode { /*! \brief The value to be stored. */ PrimExpr value; /*! \brief The indices location to be stored. */ - Array indices; + ffi::Array indices; /*! \brief The predicate mask for storing values. */ - Optional predicate; + ffi::Optional predicate; static void RegisterReflection() { namespace refl = tvm::ffi::reflection; @@ -227,8 +228,9 @@ class BufferStoreNode : public StmtNode { */ class BufferStore : public Stmt { public: - TVM_DLL explicit BufferStore(Buffer buffer, PrimExpr value, Array indices, - Optional predicate = std::nullopt, Span span = Span()); + TVM_DLL explicit BufferStore(Buffer buffer, PrimExpr value, ffi::Array indices, + ffi::Optional predicate = std::nullopt, + Span span = Span()); TVM_DEFINE_OBJECT_REF_METHODS(BufferStore, Stmt, BufferStoreNode); TVM_DEFINE_OBJECT_REF_COW_METHOD(BufferStoreNode); @@ -250,7 +252,7 @@ class BufferRealizeNode : public StmtNode { /*! \brief The buffer variable. */ Buffer buffer; /*! \brief Bounds to be realized */ - Array bounds; + ffi::Array bounds; /*! \brief Only realize if condition holds. */ PrimExpr condition; /*! \brief The body of realization. */ @@ -266,7 +268,7 @@ class BufferRealizeNode : public StmtNode { } BufferRealizeNode() = default; - BufferRealizeNode(Buffer buffer, Array bounds, PrimExpr condition, Stmt body, + BufferRealizeNode(Buffer buffer, ffi::Array bounds, PrimExpr condition, Stmt body, Span span = Span()) : StmtNode(span), buffer(buffer), bounds(bounds), condition(condition), body(body) {} @@ -280,8 +282,8 @@ class BufferRealizeNode : public StmtNode { */ class BufferRealize : public Stmt { public: - TVM_DLL explicit BufferRealize(Buffer buffer, Array bounds, PrimExpr condition, Stmt body, - Span span = Span()); + TVM_DLL explicit BufferRealize(Buffer buffer, ffi::Array bounds, PrimExpr condition, + Stmt body, Span span = Span()); TVM_DEFINE_NOTNULLABLE_OBJECT_REF_METHODS(BufferRealize, Stmt, BufferRealizeNode); TVM_DEFINE_OBJECT_REF_COW_METHOD(BufferRealizeNode); @@ -297,7 +299,7 @@ class AllocateNode : public StmtNode { /*! \brief The type of the buffer. */ DataType dtype; /*! \brief The extents of the buffer. */ - Array extents; + ffi::Array extents; /*! \brief Only allocate buffer when condition is satisfied. */ PrimExpr condition; /*! \brief The body to be executed. */ @@ -308,7 +310,7 @@ class AllocateNode : public StmtNode { * These annotations can be used as auxiliary hint * to future transformations. */ - Map annotations; + ffi::Map annotations; static void RegisterReflection() { namespace refl = tvm::ffi::reflection; @@ -333,7 +335,7 @@ class AllocateNode : public StmtNode { * \param extents The extents of the buffer. * \return The result. 
*/ - TVM_DLL static int64_t ConstantAllocationSize(const Array& extents); + TVM_DLL static int64_t ConstantAllocationSize(const ffi::Array& extents); static constexpr const char* _type_key = "tir.Allocate"; @@ -346,8 +348,9 @@ class AllocateNode : public StmtNode { */ class Allocate : public Stmt { public: - TVM_DLL Allocate(Var buffer_var, DataType dtype, Array extents, PrimExpr condition, - Stmt body, Map annotations = Map(), + TVM_DLL Allocate(Var buffer_var, DataType dtype, ffi::Array extents, PrimExpr condition, + Stmt body, + ffi::Map annotations = ffi::Map(), Span span = Span()); TVM_DEFINE_OBJECT_REF_METHODS(Allocate, Stmt, AllocateNode); @@ -363,16 +366,16 @@ class AllocateConstNode : public StmtNode { Var buffer_var; /*! \brief The optional data associated to the constant. */ - Optional data; + ffi::Optional data; /*! * \brief If the PrimFunc containing the Stmt is added to IRModule, this is an optional index - * to indicate the index within "constants" attribute, that is a Array of IRModule. + * to indicate the index within "constants" attribute, that is a ffi::Array of IRModule. */ - Optional irmod_storage_idx; + ffi::Optional irmod_storage_idx; /*! \brief The type of the buffer. */ DataType dtype; /*! \brief The extents of the buffer. */ - Array extents; + ffi::Array extents; /*! \brief The body to be executed. */ Stmt body; /*! @@ -381,7 +384,7 @@ class AllocateConstNode : public StmtNode { * These annotations can be used as auxiliary hint * to future transformations. */ - Map annotations; + ffi::Map annotations; static void RegisterReflection() { namespace refl = tvm::ffi::reflection; @@ -407,7 +410,7 @@ class AllocateConstNode : public StmtNode { * \param extents The extents of the buffer. * \return The result. */ - TVM_DLL static int64_t ConstantAllocationSize(const Array& extents); + TVM_DLL static int64_t ConstantAllocationSize(const ffi::Array& extents); static constexpr const char* _type_key = "tir.AllocateConst"; TVM_DECLARE_FINAL_OBJECT_INFO(AllocateConstNode, StmtNode); @@ -423,10 +426,10 @@ class AllocateConst : public Stmt { * depending on the type of ObjectRef, it will either * create AllocateConstNode with irmod_storage_idx or data */ - TVM_DLL AllocateConst(Var buffer_var, DataType dtype, Array extents, - ObjectRef data_or_idx, Stmt body, - Map annotations = Map(), - Span span = Span()); + TVM_DLL AllocateConst( + Var buffer_var, DataType dtype, ffi::Array extents, ObjectRef data_or_idx, + Stmt body, ffi::Map annotations = ffi::Map(), + Span span = Span()); TVM_DEFINE_OBJECT_REF_METHODS(AllocateConst, Stmt, AllocateConstNode); TVM_DEFINE_OBJECT_REF_COW_METHOD(AllocateConstNode); }; @@ -465,7 +468,7 @@ class DeclBuffer : public Stmt { class SeqStmtNode : public StmtNode { public: /*! \brief internal sequence content. */ - Array seq; + ffi::Array seq; /*! \return get the size of the sequence */ size_t size() const { return seq.size(); } @@ -525,7 +528,7 @@ class SeqStmt : public Stmt { * \param seq The sequence. * \param span The location of this object in the source code. */ - TVM_DLL explicit SeqStmt(Array seq, Span span = Span()); + TVM_DLL explicit SeqStmt(ffi::Array seq, Span span = Span()); /*! \return get the size of the sequence */ size_t size() const { return operator->()->size(); } @@ -555,7 +558,7 @@ class SeqStmt : public Stmt { */ template static Stmt Flatten(Args&&... seq_args) { - Array seq; + ffi::Array seq; ffi::details::for_each(Flattener(&seq), std::forward(seq_args)...); @@ -593,10 +596,10 @@ class SeqStmt : public Stmt { /*! 
\brief Helper class to flatten sequence of arguments into Array. */ class Flattener { public: - explicit Flattener(Array* seq) : seq_(seq) {} + explicit Flattener(ffi::Array* seq) : seq_(seq) {} template - static Optional AsSeqStmt(const T& t) { + static ffi::Optional AsSeqStmt(const T& t) { if constexpr (std::is_same_v) { return t; } @@ -605,7 +608,7 @@ class SeqStmt : public Stmt { } if constexpr (std::is_base_of_v) { if (const SeqStmtNode* ptr = t.template as()) { - return GetRef(ptr); + return ffi::GetRef(ptr); } else { return std::nullopt; } @@ -661,7 +664,7 @@ class SeqStmt : public Stmt { } private: - Array* seq_; + ffi::Array* seq_; }; TVM_DEFINE_OBJECT_REF_METHODS(SeqStmt, Stmt, SeqStmtNode); @@ -678,7 +681,7 @@ class IfThenElseNode : public StmtNode { /*! \brief The branch to be executed when condition is true. */ Stmt then_case; /*! \brief The branch to be executed when condition is false, can be null. */ - Optional else_case; + ffi::Optional else_case; static void RegisterReflection() { namespace refl = tvm::ffi::reflection; @@ -698,8 +701,8 @@ class IfThenElseNode : public StmtNode { */ class IfThenElse : public Stmt { public: - TVM_DLL IfThenElse(PrimExpr condition, Stmt then_case, Optional else_case = std::nullopt, - Span span = Span()); + TVM_DLL IfThenElse(PrimExpr condition, Stmt then_case, + ffi::Optional else_case = std::nullopt, Span span = Span()); TVM_DEFINE_OBJECT_REF_METHODS(IfThenElse, Stmt, IfThenElseNode); TVM_DEFINE_OBJECT_REF_COW_METHOD(IfThenElseNode); @@ -759,7 +762,7 @@ class ForNode : public StmtNode { * \brief Only valid when kind == ForKind::kThreadBinding * The context thread that this loop variable bounds to. */ - Optional thread_binding; + ffi::Optional thread_binding; /*! * \brief Additional annotations about the loop. * @@ -768,7 +771,7 @@ class ForNode : public StmtNode { * not change the control flow semantics of the loop * and can be ignored in most passes. */ - Map annotations; + ffi::Map annotations; static void RegisterReflection() { namespace refl = tvm::ffi::reflection; @@ -793,8 +796,9 @@ class ForNode : public StmtNode { class For : public Stmt { public: TVM_DLL For(Var loop_var, PrimExpr min, PrimExpr extent, ForKind kind, Stmt body, - Optional thread_binding = std::nullopt, - Map annotations = Map(), Span span = Span()); + ffi::Optional thread_binding = std::nullopt, + ffi::Map annotations = ffi::Map(), + Span span = Span()); TVM_DEFINE_OBJECT_REF_METHODS(For, Stmt, ForNode); TVM_DEFINE_OBJECT_REF_COW_METHOD(ForNode); @@ -848,7 +852,7 @@ class BufferRegionNode : public PrimExprConvertibleNode { /*! \brief The buffer of the buffer region. */ Buffer buffer; /*! \brief The region array of the buffer region. */ - Array region; + ffi::Array region; static void RegisterReflection() { namespace refl = tvm::ffi::reflection; @@ -870,7 +874,7 @@ class BufferRegionNode : public PrimExprConvertibleNode { */ class BufferRegion : public PrimExprConvertible { public: - TVM_DLL explicit BufferRegion(Buffer buffer, Array region); + TVM_DLL explicit BufferRegion(Buffer buffer, ffi::Array region); /*! * \brief Create a BufferRegion which is full region of the given buffer. @@ -885,7 +889,7 @@ class BufferRegion : public PrimExprConvertible { * \param indices The access point indices of the buffer * \return The BufferRegion which is the single point of the given buffer. 
*/ - TVM_DLL static BufferRegion FromPoint(Buffer buffer, Array indices); + TVM_DLL static BufferRegion FromPoint(Buffer buffer, ffi::Array indices); TVM_DEFINE_OBJECT_REF_METHODS(BufferRegion, PrimExprConvertible, BufferRegionNode); TVM_DEFINE_OBJECT_REF_COW_METHOD(BufferRegionNode); @@ -955,19 +959,19 @@ class MatchBufferRegion : public ObjectRef { class BlockNode : public StmtNode { public: /*! \brief The variables of the block. */ - Array iter_vars; + ffi::Array iter_vars; /*! \brief The read buffer regions of the block. */ - Array reads; + ffi::Array reads; /*! \brief The write buffer regions of the block. */ - Array writes; + ffi::Array writes; /*! \brief The name_hint of the block. */ - String name_hint; + ffi::String name_hint; /*! \brief The buffer allocated in the block. */ - Array alloc_buffers; + ffi::Array alloc_buffers; /*! \brief The match buffer regions. */ - Array match_buffers; + ffi::Array match_buffers; /*! \brief The annotation of the block. */ - Map annotations; + ffi::Map annotations; /*! * \brief The init statement is executed during the first iteration of reduction loops in a * reduction block. The optional init field allows us to represent initialization and @@ -975,7 +979,7 @@ class BlockNode : public StmtNode { * We also provide primitives to decompose the init into a separate block during scheduling. * Init field is `std::nullopt` if there is no reduction iter_vars */ - Optional init; + ffi::Optional init; /*! \brief The body of the block. */ Stmt body; @@ -1003,13 +1007,14 @@ class BlockNode : public StmtNode { */ class Block : public Stmt { public: - TVM_DLL explicit Block(Array iter_vars, Array reads, - Array writes, String name_hint, Stmt body, - Optional init = std::nullopt, - Array alloc_buffers = Array(), - Array match_buffers = Array(), - Map annotations = Map(), - Span span = Span()); + TVM_DLL explicit Block( + ffi::Array iter_vars, ffi::Array reads, + ffi::Array writes, ffi::String name_hint, Stmt body, + ffi::Optional init = std::nullopt, + ffi::Array alloc_buffers = ffi::Array(), + ffi::Array match_buffers = ffi::Array(), + ffi::Map annotations = ffi::Map(), + Span span = Span()); TVM_DEFINE_OBJECT_REF_METHODS(Block, Stmt, BlockNode); TVM_DEFINE_OBJECT_REF_COW_METHOD(BlockNode); @@ -1021,7 +1026,7 @@ class Block : public Stmt { class BlockRealizeNode : public StmtNode { public: /*! \brief The corresponding values of the iter vars. */ - Array iter_values; + ffi::Array iter_values; /*! * \brief The predicate of the block realization, the block will only be executed when the * predicate is true. @@ -1048,7 +1053,7 @@ class BlockRealizeNode : public StmtNode { */ class BlockRealize : public Stmt { public: - TVM_DLL explicit BlockRealize(Array iter_values, PrimExpr predicate, Block block, + TVM_DLL explicit BlockRealize(ffi::Array iter_values, PrimExpr predicate, Block block, Span span = Span()); TVM_DEFINE_OBJECT_REF_METHODS(BlockRealize, Stmt, BlockRealizeNode); @@ -1146,7 +1151,7 @@ constexpr const char* buffer_dim_align = "buffer_dim_align"; constexpr const char* buffer_bound = "buffer_bound"; /*! * \brief Bind the buffer specification to the region of the op - * When this scope occurs, the stmt.node is a Array = [buffer, tensor] + * When this scope occurs, the stmt.node is a ffi::Array = [buffer, tensor] * stmt.value is a tvm_tuple(min0, extent0, min1, extent1, ...). * The scope represents that we need to bind the storage region of tensor to buffer. 
* This will affect replacement of some variables inside the scope that diff --git a/include/tvm/tir/stmt_functor.h b/include/tvm/tir/stmt_functor.h index 23747a7e936c..b3c43bdc1459 100644 --- a/include/tvm/tir/stmt_functor.h +++ b/include/tvm/tir/stmt_functor.h @@ -325,7 +325,7 @@ class StmtExprMutator : public StmtMutator, public ExprMutator { * when the IRNode's type key is in the list. */ TVM_DLL Stmt IRTransform(Stmt stmt, const ffi::Function& preorder, const ffi::Function& postorder, - Optional> only_enable = std::nullopt); + ffi::Optional> only_enable = std::nullopt); /*! * \brief Recursively visit the ir in post DFS order node, apply fvisit @@ -341,7 +341,7 @@ TVM_DLL void PostOrderVisit(const ObjectRef& node, std::function(const Var& var)> vmap); +TVM_DLL Stmt Substitute(Stmt stmt, std::function(const Var& var)> vmap); /*! * \brief Substitute the var specified by vmap. @@ -349,7 +349,8 @@ TVM_DLL Stmt Substitute(Stmt stmt, std::function(const Var& v * \param vmap returns a new value if re-mapping is needed, otherwise returns nullptr. * \return The result. */ -TVM_DLL PrimExpr Substitute(PrimExpr expr, std::function(const Var& var)> vmap); +TVM_DLL PrimExpr Substitute(PrimExpr expr, + std::function(const Var& var)> vmap); /*! * \brief Substitute the var specified by vmap. @@ -358,7 +359,8 @@ TVM_DLL PrimExpr Substitute(PrimExpr expr, std::function(cons * \return The result. */ template -Array Substitute(const Array& arr, std::function(const Var& var)> vmap) { +ffi::Array Substitute(const ffi::Array& arr, + std::function(const Var& var)> vmap) { return arr.Map([&vmap](const auto& elem) { return Substitute(elem, vmap); }); } @@ -369,7 +371,7 @@ Array Substitute(const Array& arr, std::function(const * \return The modified Range. */ inline Range Substitute(const Range& range, - std::function(const Var& var)> vmap) { + std::function(const Var& var)> vmap) { return Range::FromMinExtent(Substitute(range->min, vmap), Substitute(range->extent, vmap)); } @@ -385,8 +387,8 @@ inline Range Substitute(const Range& range, * \return The modified object. 
*/ template -auto Substitute(Obj&& obj, const Map& vmap) { - auto func = [&vmap](const Var& var) -> Optional { return vmap.Get(var); }; +auto Substitute(Obj&& obj, const ffi::Map& vmap) { + auto func = [&vmap](const Var& var) -> ffi::Optional { return vmap.Get(var); }; return Substitute(std::forward(obj), func); } @@ -401,8 +403,8 @@ auto Substitute(Obj&& obj, const Map& vmap) { */ template >> -auto Substitute(Obj&& obj, const Map& vmap) { - auto func = [&vmap](const Var& var) -> Optional { +auto Substitute(Obj&& obj, const ffi::Map& vmap) { + auto func = [&vmap](const Var& var) -> ffi::Optional { if (auto opt = vmap.Get(var)) { return opt.value(); } else { @@ -424,7 +426,7 @@ auto Substitute(Obj&& obj, const Map& vmap) { template >> auto Substitute(Obj&& obj, const std::unordered_map& vmap) { - auto func = [&vmap](const Var& var) -> Optional { + auto func = [&vmap](const Var& var) -> ffi::Optional { if (auto it = vmap.find(var.get()); it != vmap.end()) { return it->second; } else { @@ -446,7 +448,7 @@ auto Substitute(Obj&& obj, const std::unordered_map& vmap) template >> auto Substitute(Obj&& obj, const std::unordered_map& vmap) { - auto func = [&vmap](const Var& var) -> Optional { + auto func = [&vmap](const Var& var) -> ffi::Optional { if (auto it = vmap.find(var); it != vmap.end()) { return it->second; } else { @@ -473,7 +475,7 @@ auto Substitute(Obj&& obj, const std::unordered_map& iter_vmap) { vmap[iter_var->var.get()] = expr; } - auto func = [&vmap](const Var& var) -> Optional { + auto func = [&vmap](const Var& var) -> ffi::Optional { if (auto it = vmap.find(var.get()); it != vmap.end()) { return it->second; } else { @@ -493,8 +495,8 @@ auto Substitute(Obj&& obj, const std::unordered_map& iter_vmap) { * \sa Substitute * \return The result. */ -TVM_DLL Stmt SubstituteWithDataTypeLegalization(Stmt stmt, - std::function(const Var&)> vmap); +TVM_DLL Stmt SubstituteWithDataTypeLegalization( + Stmt stmt, std::function(const Var&)> vmap); /*! * \brief Substitute the var specified by vmap and legalize data types after substitution. * @@ -507,7 +509,7 @@ TVM_DLL Stmt SubstituteWithDataTypeLegalization(Stmt stmt, * \return The result. */ TVM_DLL PrimExpr SubstituteWithDataTypeLegalization( - PrimExpr expr, std::function(const Var&)> vmap); + PrimExpr expr, std::function(const Var&)> vmap); /*! * \brief Recursively visit the IR in pre DFS order node, apply fvisit. diff --git a/include/tvm/tir/transform.h b/include/tvm/tir/transform.h index bd6a5d537239..af59db38771d 100644 --- a/include/tvm/tir/transform.h +++ b/include/tvm/tir/transform.h @@ -56,8 +56,8 @@ using tvm::transform::Sequential; * \return The created function pass. */ TVM_DLL Pass CreatePrimFuncPass(std::function pass_func, - int opt_level, String name, tvm::Array required, - bool traceable = false); + int opt_level, ffi::String name, + tvm::ffi::Array required, bool traceable = false); /*! * \brief partition loops in the stmt. * @@ -197,7 +197,7 @@ TVM_DLL Pass MakeUnpackedAPI(); * * \return The pass. */ -TVM_DLL Pass RemapThreadAxis(Map axis_map); +TVM_DLL Pass RemapThreadAxis(ffi::Map axis_map); /*! * \brief Lower custom datatypes. * @@ -273,7 +273,7 @@ TVM_DLL Pass SkipAssert(); * \param storage_scope The storage scope considered. * \return The pass. */ -TVM_DLL Pass ThreadSync(String storage_scope); +TVM_DLL Pass ThreadSync(ffi::String storage_scope); /*! * \brief Lower cross thread allreduce.
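The pass constructors above now take ffi::String / ffi::Map parameters, but literals still convert implicitly, so invocation is unchanged in practice. A minimal sketch, assuming `mod` is an IRModule in scope:

    // Build the thread-sync pass for the "shared" scope and run it.
    tvm::transform::Pass sync = tvm::tir::transform::ThreadSync("shared");
    mod = sync(mod);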
@@ -361,7 +361,7 @@ TVM_DLL Pass BF16ComputeLegalize(); * \note Must be run after BindTarget, as it relies on target attributes for PrimFuncs * \return The pass. */ -TVM_DLL Pass FP8ComputeLegalize(String promote_dtype_str = "float16"); +TVM_DLL Pass FP8ComputeLegalize(ffi::String promote_dtype_str = "float16"); /*! * \brief Legalize bf16 storage types to u16. * @@ -676,7 +676,7 @@ TVM_DLL Pass UnifiedStaticMemoryPlanner(); */ TVM_DLL Pass InjectSoftwarePipeline(); -TVM_DLL Pass BindParams(const Array& constants); +TVM_DLL Pass BindParams(const ffi::Array& constants); /*! * \brief Pass to collect tir non-scalar constants into module's 'Constants' attribute. diff --git a/include/tvm/tir/var.h b/include/tvm/tir/var.h index 7bf29265ceea..578b00fc08d4 100644 --- a/include/tvm/tir/var.h +++ b/include/tvm/tir/var.h @@ -51,7 +51,7 @@ class VarNode : public PrimExprNode { * \brief The hint to the variable name. * \note Each variable is uniquely identified by its address. */ - String name_hint; + ffi::String name_hint; /*! * \brief type annotation of the variable. * * @@ -84,7 +84,7 @@ class Var : public PrimExpr { * \param dtype data type * \param span The location of this object in the source code. */ - TVM_DLL explicit Var(String name_hint = "v", DataType dtype = DataType::Int(32), + TVM_DLL explicit Var(ffi::String name_hint = "v", DataType dtype = DataType::Int(32), Span span = Span()); /*! * \brief Constructor which provides a more detailed type annotation. * \param type_annotation The type annotation. * \param span The location of this object in the source code. */ - TVM_DLL explicit Var(String name_hint, Type type_annotation, Span span = Span()); + TVM_DLL explicit Var(ffi::String name_hint, Type type_annotation, Span span = Span()); /*! * \brief Make a new copy of var with same type, but a different name * \param name The new name to be used. * \return the new Var copy */ - TVM_DLL Var copy_with_name(const String& name) const; + TVM_DLL Var copy_with_name(const ffi::String& name) const; /*! * \brief Make a new copy of var with same type, append suffix * \param suffix The suffix to be appended. * \return the new Var copy */ - TVM_DLL Var copy_with_suffix(const String& suffix) const; + TVM_DLL Var copy_with_suffix(const ffi::String& suffix) const; /*! * \brief Make a new copy of the variable with specified dtype * \param dtype The specified dtype * @@ -150,7 +150,7 @@ class SizeVar : public Var { * \param t data type * \param span The location of this object in the source code. */ - TVM_DLL explicit SizeVar(String name_hint = "s", DataType t = DataType::Int(32), + TVM_DLL explicit SizeVar(ffi::String name_hint = "s", DataType t = DataType::Int(32), Span span = Span()); /*! * \brief Constructor which provides a more detailed type annotation. * \param name_hint variable name * \param type_annotation The type annotation. * \param span The location of this object in the source code. */ - TVM_DLL explicit SizeVar(String name_hint, Type type_annotation, Span span = Span()); + TVM_DLL explicit SizeVar(ffi::String name_hint, Type type_annotation, Span span = Span()); /*! * \brief Get pointer to the internal value. * \return the corresponding Variable. * @@ -173,7 +173,7 @@ class SizeVar : public Var { using ContainerType = SizeVarNode; }; -using Region = Array; +using Region = ffi::Array; /*! * \brief Type of iteration variable.
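For the Var name-handling methods above, behavior is unchanged; only the string type moves to ffi::String. A minimal sketch with hypothetical variables:

    // Construct a loop variable and derive renamed copies from it.
    tvm::tir::Var i("i", tvm::DataType::Int(32));
    tvm::tir::Var i_outer = i.copy_with_suffix("_outer");  // name_hint: "i_outer"
    tvm::tir::Var j = i.copy_with_name("j");               // same dtype, new name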
@@ -266,7 +266,7 @@ class IterVarNode : public PrimExprConvertibleNode { * \brief additional tag on the iteration variable, * set this if this is bound already to a known thread tag. */ - String thread_tag; + ffi::String thread_tag; /*! * \brief Span that points to the original source code. * Reserved debug information. @@ -297,7 +297,7 @@ class IterVarNode : public PrimExprConvertibleNode { */ class IterVar : public PrimExprConvertible { public: - TVM_DLL IterVar(Range dom, Var var, IterVarType iter_type, String thread_tag = "", + TVM_DLL IterVar(Range dom, Var var, IterVarType iter_type, ffi::String thread_tag = "", Span span = Span()); /*! * \return the corresponding var in the IterVar. diff --git a/include/tvm/topi/broadcast.h b/include/tvm/topi/broadcast.h index 9be7256b446e..2aedef4c58b6 100644 --- a/include/tvm/topi/broadcast.h +++ b/include/tvm/topi/broadcast.h @@ -46,7 +46,7 @@ namespace topi { * \return A Tensor whose op member is a broadcast operation */ inline tvm::te::Tensor broadcast_to(const tvm::te::Tensor& t, - const tvm::Array& output_shape, + const tvm::ffi::Array& output_shape, std::string name = "T_broadcast_to", std::string tag = kBroadcast) { ICHECK_GE(output_shape.size(), t->shape.size()) @@ -54,7 +54,7 @@ inline tvm::te::Tensor broadcast_to(const tvm::te::Tensor& t, << "\nvs\ninput: " << t; auto bh = detail::BroadcastShape(output_shape, t->shape); ICHECK_EQ(output_shape.size(), bh.common_shape.size()); - Array oshape; + ffi::Array oshape; for (size_t i = 0; i < output_shape.size(); ++i) { if (output_shape[i].as() == nullptr) { oshape.push_back(output_shape[i]); @@ -63,30 +63,32 @@ inline tvm::te::Tensor broadcast_to(const tvm::te::Tensor& t, oshape.push_back(bh.common_shape[i]); } } - auto l = [&](tvm::Array ovars) { + auto l = [&](tvm::ffi::Array ovars) { return t(detail::InputIndexFromBroadcast(ovars, t, bh.vars2, bh.all_vars)); }; return tvm::te::compute(oshape, l, name, tag); } -#define TOPI_DEFINE_BCAST_OP(Name, ComputeRule) \ - inline tvm::PrimExpr Name(const tvm::PrimExpr& a, const tvm::PrimExpr& b) { ComputeRule; } \ - inline tvm::te::Tensor Name(const tvm::te::Tensor& A, const tvm::te::Tensor& B, \ - std::string name = "T_" #Name, std::string tag = kBroadcast) { \ - auto l = [](tvm::PrimExpr a, tvm::PrimExpr b) { ComputeRule; }; \ - return detail::WithBroadcast(l, A, B, name, tag); \ - } \ - inline tvm::te::Tensor Name(const tvm::te::Tensor& A, const tvm::PrimExpr& B, \ - std::string name = "T_" #Name, std::string tag = kElementWise) { \ - auto l = [](tvm::PrimExpr a, tvm::PrimExpr b) { ComputeRule; }; \ - return tvm::te::compute( \ - A->shape, [&](const ::tvm::Array<::tvm::tir::Var>& i) { return l(A(i), B); }, name, tag); \ - } \ - inline tvm::te::Tensor Name(const tvm::PrimExpr& A, const tvm::te::Tensor& B, \ - std::string name = "T_" #Name, std::string tag = kElementWise) { \ - auto l = [&](tvm::PrimExpr a, tvm::PrimExpr b) { ComputeRule; }; \ - return tvm::te::compute( \ - B->shape, [&](const ::tvm::Array<::tvm::tir::Var>& i) { return l(A, B(i)); }, name, tag); \ +#define TOPI_DEFINE_BCAST_OP(Name, ComputeRule) \ + inline tvm::PrimExpr Name(const tvm::PrimExpr& a, const tvm::PrimExpr& b) { ComputeRule; } \ + inline tvm::te::Tensor Name(const tvm::te::Tensor& A, const tvm::te::Tensor& B, \ + std::string name = "T_" #Name, std::string tag = kBroadcast) { \ + auto l = [](tvm::PrimExpr a, tvm::PrimExpr b) { ComputeRule; }; \ + return detail::WithBroadcast(l, A, B, name, tag); \ + } \ + inline tvm::te::Tensor Name(const tvm::te::Tensor& A, const 
tvm::PrimExpr& B, \ + std::string name = "T_" #Name, std::string tag = kElementWise) { \ + auto l = [](tvm::PrimExpr a, tvm::PrimExpr b) { ComputeRule; }; \ + return tvm::te::compute( \ + A->shape, [&](const ::tvm::ffi::Array<::tvm::tir::Var>& i) { return l(A(i), B); }, name, \ + tag); \ + } \ + inline tvm::te::Tensor Name(const tvm::PrimExpr& A, const tvm::te::Tensor& B, \ + std::string name = "T_" #Name, std::string tag = kElementWise) { \ + auto l = [&](tvm::PrimExpr a, tvm::PrimExpr b) { ComputeRule; }; \ + return tvm::te::compute( \ + B->shape, [&](const ::tvm::ffi::Array<::tvm::tir::Var>& i) { return l(A, B(i)); }, name, \ + tag); \ } #define TOPI_DEFINE_OP_OVERLOAD(Name, OpName) \ diff --git a/include/tvm/topi/contrib/cublas.h b/include/tvm/topi/contrib/cublas.h index 3032643ed700..3590b7a54458 100644 --- a/include/tvm/topi/contrib/cublas.h +++ b/include/tvm/topi/contrib/cublas.h @@ -49,7 +49,7 @@ inline Tensor cublas_matmul(const Tensor& lhs, const Tensor& rhs, bool transa, b return make_extern( {{n, m}}, {lhs->dtype}, {lhs, rhs}, - [&](Array ins, Array outs) { + [&](ffi::Array ins, ffi::Array outs) { return call_packed({StringImm("tvm.contrib.cublas.matmul"), pack_buffer(ins[0]), pack_buffer(ins[1]), pack_buffer(outs[0]), transa, transb}); }, @@ -74,7 +74,7 @@ inline Tensor cublas_batch_matmul(const Tensor& lhs, const Tensor& rhs, bool tra return make_extern( {{b, n, m}}, {lhs->dtype}, {lhs, rhs}, - [&](Array ins, Array outs) { + [&](ffi::Array ins, ffi::Array outs) { return call_packed({StringImm("tvm.contrib.cublas.batch_matmul"), pack_buffer(ins[0]), pack_buffer(ins[1]), pack_buffer(outs[0]), transa, transb}); }, diff --git a/include/tvm/topi/contrib/rocblas.h b/include/tvm/topi/contrib/rocblas.h index 4f0b887fb178..e29b135b7d2c 100644 --- a/include/tvm/topi/contrib/rocblas.h +++ b/include/tvm/topi/contrib/rocblas.h @@ -48,7 +48,7 @@ inline Tensor rocblas_matmul(const Tensor& lhs, const Tensor& rhs, bool transa, return make_extern( {{n, m}}, {lhs->dtype}, {lhs, rhs}, - [&](Array ins, Array outs) { + [&](ffi::Array ins, ffi::Array outs) { return call_packed({StringImm("tvm.contrib.rocblas.matmul"), pack_buffer(ins[0]), pack_buffer(ins[1]), pack_buffer(outs[0]), transa, transb}); }, @@ -71,7 +71,7 @@ inline Tensor rocblas_batch_matmul(const Tensor& lhs, const Tensor& rhs, bool tr return make_extern( {{batch_size, n, m}}, {lhs->dtype}, {lhs, rhs}, - [&](Array ins, Array outs) { + [&](ffi::Array ins, ffi::Array outs) { return call_packed({StringImm("tvm.contrib.rocblas.batch_matmul"), pack_buffer(ins[0]), pack_buffer(ins[1]), pack_buffer(outs[0]), transa, transb}); }, diff --git a/include/tvm/topi/detail/array_utils.h b/include/tvm/topi/detail/array_utils.h index 89c985695865..f10eff6f61cb 100644 --- a/include/tvm/topi/detail/array_utils.h +++ b/include/tvm/topi/detail/array_utils.h @@ -41,7 +41,7 @@ using namespace tvm::te; * \return True iff the given array contains the given item. 
*/ template <typename T> -inline bool contains(Array<T> array, T item) { +inline bool contains(ffi::Array<T> array, T item) { for (auto& i : array) { if (i == item) { return true; diff --git a/include/tvm/topi/detail/broadcast.h b/include/tvm/topi/detail/broadcast.h index c861fbb71b2a..aab6fea22d2c 100644 --- a/include/tvm/topi/detail/broadcast.h +++ b/include/tvm/topi/detail/broadcast.h @@ -48,8 +48,8 @@ static inline DataType CommonType(DataType type1, DataType type2) { return DataType(type1.code(), std::max(type1.bits(), type2.bits()), /*lanes=*/1); } -inline BroadcastHelper BroadcastShape(const tvm::Array<tvm::PrimExpr>& shape1, - const tvm::Array<tvm::PrimExpr>& shape2) { +inline BroadcastHelper BroadcastShape(const tvm::ffi::Array<tvm::PrimExpr>& shape1, + const tvm::ffi::Array<tvm::PrimExpr>& shape2) { BroadcastHelper bh; int s1_size = shape1.size(); int s2_size = shape2.size(); @@ -94,8 +94,8 @@ inline BroadcastHelper BroadcastShape(const tvm::Array<tvm::PrimExpr>& shape1, } else { ICHECK(false) << "Incompatible broadcast dims: " << shape1[s1_size - i] << " and " << shape2[s2_size - i] - << " in: " << tvm::Array<tvm::PrimExpr>(shape1.begin(), shape1.end()) << " and " - << tvm::Array<tvm::PrimExpr>(shape2.begin(), shape2.end()); + << " in: " << tvm::ffi::Array<tvm::PrimExpr>(shape1.begin(), shape1.end()) + << " and " << tvm::ffi::Array<tvm::PrimExpr>(shape2.begin(), shape2.end()); } } // Remaining dimensions whether on shape1 or shape2 can always be completed @@ -110,10 +110,10 @@ inline BroadcastHelper BroadcastShape(const tvm::Array<tvm::PrimExpr>& shape1, return bh; } -inline tvm::Array<tvm::PrimExpr> InputIndexFromBroadcast( - const tvm::Array<tvm::tir::Var>& ovars, const tvm::te::Tensor& T, +inline tvm::ffi::Array<tvm::PrimExpr> InputIndexFromBroadcast( + const tvm::ffi::Array<tvm::tir::Var>& ovars, const tvm::te::Tensor& T, const std::deque<tvm::tir::Var>& my_vars, const std::deque<tvm::tir::Var>& all_vars) { - tvm::Array<tvm::PrimExpr> ivars; + tvm::ffi::Array<tvm::PrimExpr> ivars; ICHECK_EQ(ovars.size(), all_vars.size()); // N^2, could use a map but NBD.
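// ---------------------------------------------------------------------------
// Illustrative sketch, not part of the patch: after this migration, broadcast
// shapes travel in tvm::ffi::Array<tvm::PrimExpr>. A minimal call into the
// public topi entry point, assuming <tvm/topi/broadcast.h> and a 2-D input;
// the wrapper name `widen` is hypothetical.
tvm::te::Tensor widen(const tvm::te::Tensor& t) {
  // Prepend a batch dimension of 16; the trailing dims must match t->shape.
  tvm::ffi::Array<tvm::PrimExpr> output_shape{16, t->shape[0], t->shape[1]};
  return tvm::topi::broadcast_to(t, output_shape);
}
// ---------------------------------------------------------------------------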
size_t expected_dims = T->shape.size(); @@ -141,12 +141,12 @@ inline tvm::te::Tensor WithBroadcast(FBinaryExpr op, const tvm::te::Tensor& A, const tvm::te::Tensor& B, const std::string& name = "tensor", const std::string& tag = "") { auto bh = BroadcastShape(A->shape, B->shape); - auto l = [&](tvm::Array ovars) { + auto l = [&](tvm::ffi::Array ovars) { return op(A(InputIndexFromBroadcast(ovars, A, bh.vars1, bh.all_vars)), B(InputIndexFromBroadcast(ovars, B, bh.vars2, bh.all_vars))); }; - return tvm::te::compute(tvm::Array(bh.common_shape.begin(), bh.common_shape.end()), - l, name, tag); + return tvm::te::compute( + tvm::ffi::Array(bh.common_shape.begin(), bh.common_shape.end()), l, name, tag); } } // namespace detail diff --git a/include/tvm/topi/detail/constant_utils.h b/include/tvm/topi/detail/constant_utils.h index 95e68f5f6d61..74b4ce143cad 100644 --- a/include/tvm/topi/detail/constant_utils.h +++ b/include/tvm/topi/detail/constant_utils.h @@ -55,7 +55,7 @@ inline bool IsConstInt(PrimExpr expr) { return expr->IsInstance array) { +inline bool IsConstIntArray(ffi::Array array) { bool is_const_int = true; for (auto const& elem : array) { is_const_int &= !elem.defined() || elem->IsInstance(); @@ -88,7 +88,7 @@ inline int64_t GetConstInt(PrimExpr expr) { * * \return A vector of the integer values */ -inline std::vector GetConstIntValues(Array exprs, const std::string& var_name) { +inline std::vector GetConstIntValues(ffi::Array exprs, const std::string& var_name) { std::vector result; if (!exprs.defined()) return result; for (auto expr : exprs) { @@ -107,7 +107,7 @@ inline std::vector GetConstIntValues(Array exprs, const std::stri * * \return A vector of the int64_t values */ -inline std::vector GetConstInt64Values(Array exprs, +inline std::vector GetConstInt64Values(ffi::Array exprs, const std::string& var_name) { std::vector result; if (!exprs.defined()) return result; diff --git a/include/tvm/topi/detail/extern.h b/include/tvm/topi/detail/extern.h index e54169ea2934..05543f74a50b 100644 --- a/include/tvm/topi/detail/extern.h +++ b/include/tvm/topi/detail/extern.h @@ -41,7 +41,7 @@ using namespace tvm::te; * function. The function expects two arguments: an array of Buffers holding the input * tensor values, and a pre-allocated array of Buffers to be filled with the outputs. */ -using FExtern = std::function, Array)>; +using FExtern = std::function, ffi::Array)>; /*! * \brief Create tensors representing the result of invoking an external function. @@ -60,18 +60,19 @@ using FExtern = std::function, Array)>; * be one output Tensor for each element of out_shapes, with dtype equal to the corresponding * element of out_types. 
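// ---------------------------------------------------------------------------
// Illustrative sketch, not part of the patch, of the make_extern contract just
// described, assuming the post-migration signatures in this header. The packed
// function name "tvm.contrib.my_copy" and the wrapper `extern_copy` are
// hypothetical, mirroring the cuBLAS/rocBLAS wrappers earlier in this diff.
inline tvm::te::Tensor extern_copy(const tvm::te::Tensor& in) {
  using namespace tvm;
  using namespace tvm::topi::detail;
  return make_extern(
      {in->shape}, {in->dtype}, {in},
      [&](ffi::Array<tir::Buffer> ins, ffi::Array<tir::Buffer> outs) {
        // Marshal input/output buffers and dispatch to the packed function.
        return call_packed({tir::StringImm("tvm.contrib.my_copy"),
                            pack_buffer(ins[0]), pack_buffer(outs[0])});
      },
      "extern_copy", /*tag=*/"", /*attrs=*/{})[0];
}
// ---------------------------------------------------------------------------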
*/ -inline Array make_extern(const Array>& out_shapes, - const std::vector& out_types, - const Array& inputs, FExtern fextern, std::string name, - std::string tag, ::tvm::Map attrs) { +inline ffi::Array make_extern(const ffi::Array>& out_shapes, + const std::vector& out_types, + const ffi::Array& inputs, FExtern fextern, + std::string name, std::string tag, + ::tvm::ffi::Map attrs) { ICHECK_EQ(out_shapes.size(), out_types.size()) << "make_extern: out_shapes and out_types must have equal size"; - Array input_placeholders; + ffi::Array input_placeholders; for (auto t : inputs) { input_placeholders.push_back(tvm::tir::decl_buffer(t->shape, t->dtype, t->op->name)); } - Array output_placeholders; + ffi::Array output_placeholders; for (size_t i = 0; i < out_shapes.size(); ++i) { output_placeholders.push_back(tvm::tir::decl_buffer(out_shapes[i], out_types[i], name)); } @@ -81,7 +82,7 @@ inline Array make_extern(const Array>& out_shapes, auto op = ExternOp(name, tag, attrs, inputs, input_placeholders, output_placeholders, body_stmt); - Array outputs; + ffi::Array outputs; for (size_t i = 0; i < output_placeholders.size(); ++i) { outputs.push_back(op.output(i)); } @@ -107,12 +108,13 @@ inline PrimExpr pack_buffer(Buffer buf) { } else { strides = 0; } - Array pack_args{buf->data, - shape, - strides, - make_const(DataType::Int(32), static_cast(buf->shape.size())), - make_const(buf->dtype, 0), - buf->elem_offset}; + ffi::Array pack_args{ + buf->data, + shape, + strides, + make_const(DataType::Int(32), static_cast(buf->shape.size())), + make_const(buf->dtype, 0), + buf->elem_offset}; return tvm::tir::Call(DataType::Handle(), tvm::tir::builtin::tvm_stack_make_array(), pack_args); } @@ -125,7 +127,7 @@ inline PrimExpr pack_buffer(Buffer buf) { * * \return An expression representing the invocation */ -inline PrimExpr call_packed(Array args) { +inline PrimExpr call_packed(ffi::Array args) { return tvm::tir::Call(DataType::Int(32), tvm::tir::builtin::tvm_call_packed(), args); } diff --git a/include/tvm/topi/detail/fuse.h b/include/tvm/topi/detail/fuse.h index 7305ccef9b1d..993a837ca46c 100644 --- a/include/tvm/topi/detail/fuse.h +++ b/include/tvm/topi/detail/fuse.h @@ -40,7 +40,7 @@ using namespace tvm::te; * * \return The fused iteration variable */ -inline IterVar Fuse(Stage stage, const Array& args) { +inline IterVar Fuse(Stage stage, const ffi::Array& args) { IterVar res; stage.fuse(args, &res); return res; diff --git a/include/tvm/topi/detail/pad_utils.h b/include/tvm/topi/detail/pad_utils.h index 96eb49a505e4..dfb9542e7655 100644 --- a/include/tvm/topi/detail/pad_utils.h +++ b/include/tvm/topi/detail/pad_utils.h @@ -45,7 +45,7 @@ using namespace tvm::te; * \return An array of 4 elements, representing padding sizes for * each individual side. 
The array is in the order { top, left, bottom, right } */ -inline Array<PrimExpr> GetPadTuple(PrimExpr pad_h, PrimExpr pad_w) { +inline ffi::Array<PrimExpr> GetPadTuple(PrimExpr pad_h, PrimExpr pad_w) { pad_h *= 2; pad_w *= 2; diff --git a/include/tvm/topi/detail/ravel_unravel.h b/include/tvm/topi/detail/ravel_unravel.h index e91d6afb666a..27d2f9180251 100644 --- a/include/tvm/topi/detail/ravel_unravel.h +++ b/include/tvm/topi/detail/ravel_unravel.h @@ -42,7 +42,7 @@ using namespace tvm::te; * * \return The index after flattening */ -inline PrimExpr RavelIndex(Array<PrimExpr> indices, Array<PrimExpr> shape) { +inline PrimExpr RavelIndex(ffi::Array<PrimExpr> indices, ffi::Array<PrimExpr> shape) { ICHECK_EQ(indices.size(), shape.size()) << "indices and shape must have equal size"; if (indices.size() == 0U) { return 0; @@ -66,7 +66,7 @@ inline PrimExpr RavelIndex(Array<PrimExpr> indices, Array<PrimExpr> shape) { * * \return The coordinate corresponding to the 1D index */ -inline Array<PrimExpr> UnravelIndex(PrimExpr idx, Array<PrimExpr> shape) { +inline ffi::Array<PrimExpr> UnravelIndex(PrimExpr idx, ffi::Array<PrimExpr> shape) { std::vector<PrimExpr> indices; for (int i = static_cast<int>(shape.size()) - 1; i >= 0; --i) { diff --git a/include/tvm/topi/detail/strided_slice.h b/include/tvm/topi/detail/strided_slice.h index f2e021ed98bc..e75aeed8b97d 100644 --- a/include/tvm/topi/detail/strided_slice.h +++ b/include/tvm/topi/detail/strided_slice.h @@ -50,8 +50,8 @@ inline int64_t CanonicalizeIndex(int64_t index, int64_t extent, int64_t stride) } inline std::tuple<std::vector<int64_t>, std::vector<int64_t>, std::vector<int64_t>> ConvertToVec( - const Array<Integer>& begin, const Array<Integer>& end, const Array<Integer>& strides, - std::string slice_mode) { + const ffi::Array<Integer>& begin, const ffi::Array<Integer>& end, + const ffi::Array<Integer>& strides, std::string slice_mode) { std::vector<int64_t> stride_vec(strides.size(), 1); if (slice_mode == "end") { for (size_t i = 0; i < strides.size(); ++i) { @@ -88,12 +88,13 @@ inline std::tuple<std::vector<int64_t>, std::vector<int64_t>, std::vector<int64_t>> -inline Array<PrimExpr> StridedSliceCanonicalizeBegin(const Array<PrimExpr>& ishape, - const std::vector<int64_t>& begin, - const std::vector<int64_t>& strides, - const Array<Integer>& axes, DataType dtype, - std::string slice_mode = "end") { - Array<PrimExpr> begin_expr; +inline ffi::Array<PrimExpr> StridedSliceCanonicalizeBegin(const ffi::Array<PrimExpr>& ishape, + const std::vector<int64_t>& begin, + const std::vector<int64_t>& strides, + const ffi::Array<Integer>& axes, + DataType dtype, + std::string slice_mode = "end") { + ffi::Array<PrimExpr> begin_expr; for (size_t i = 0; i < axes.size(); ++i) { if (ishape[axes[i].IntValue()]->IsInstance<IntImmNode>()) { int64_t dim_i = GetConstInt(ishape[axes[i].IntValue()]); @@ -115,16 +116,14 @@ inline Array<PrimExpr> StridedSliceCanonicalizeBegin(const Array<PrimExpr>& isha return begin_expr; } -inline Array<PrimExpr> StridedSliceOutputShape(const Array<PrimExpr>& ishape, - const std::vector<int64_t>& begin, - const std::vector<int64_t>& end, - const std::vector<int64_t>& strides, - const Array<Integer>& axes, std::string slice_mode, - const Array<PrimExpr>& begin_canonicalized, - bool use_any = false) { +inline ffi::Array<PrimExpr> StridedSliceOutputShape( + const ffi::Array<PrimExpr>& ishape, const std::vector<int64_t>& begin, + const std::vector<int64_t>& end, const std::vector<int64_t>& strides, + const ffi::Array<Integer>& axes, std::string slice_mode, + const ffi::Array<PrimExpr>& begin_canonicalized, bool use_any = false) { ICHECK(!use_any) << "StridedSliceOutputShape does not support legacy use_any"; const size_t src_tensor_dim = ishape.size(); - Array<PrimExpr> out_shape; + ffi::Array<PrimExpr> out_shape; for (size_t i = 0; i < src_tensor_dim; ++i) { out_shape.push_back(ishape[i]); } diff --git a/include/tvm/topi/detail/tensor_utils.h b/include/tvm/topi/detail/tensor_utils.h index 397c70c9451e..d67ad6359434 100644 --- a/include/tvm/topi/detail/tensor_utils.h +++ b/include/tvm/topi/detail/tensor_utils.h @@ -40,7 +40,7 @@
using namespace tvm::te; * * \return True if the input shape is empty. */ -inline bool is_empty_shape(const Array& x) { +inline bool is_empty_shape(const ffi::Array& x) { bool is_empty = false; for (const auto& dim : x) { if (auto int_dim = dim.as()) { @@ -63,7 +63,7 @@ inline bool is_empty_shape(const Array& x) { * * \return The interpolated value in the given index. */ -inline PrimExpr bilinear_sample_nchw(const Tensor& input, const Array& indices, +inline PrimExpr bilinear_sample_nchw(const Tensor& input, const ffi::Array& indices, const PrimExpr max_y, const PrimExpr max_x) { auto batch_id = indices[0]; auto channel_id = indices[1]; @@ -107,7 +107,7 @@ inline PrimExpr bilinear_sample_nchw(const Tensor& input, const Array& * * \return The interpolated value in the given index. */ -inline PrimExpr bilinear_sample_nhwc(const Tensor& input, const Array& indices, +inline PrimExpr bilinear_sample_nhwc(const Tensor& input, const ffi::Array& indices, const PrimExpr max_y, const PrimExpr max_x) { auto batch_id = indices[0]; auto channel_id = indices[3]; diff --git a/include/tvm/topi/einsum.h b/include/tvm/topi/einsum.h index 5e7813f8431b..44f01b0a967c 100644 --- a/include/tvm/topi/einsum.h +++ b/include/tvm/topi/einsum.h @@ -56,8 +56,8 @@ using namespace topi::detail; * * \return the shape of the output. */ -Array InferEinsumShape(const std::string& subscripts, - const std::vector>& operands); +ffi::Array InferEinsumShape(const std::string& subscripts, + const std::vector>& operands); /*! * \brief Evaluates the Einstein summation convention on the operands. @@ -70,7 +70,7 @@ Array InferEinsumShape(const std::string& subscripts, * * \return The calculation based on the Einstein summation convention. */ -Tensor einsum(const std::string& subscripts_str, const Array inputs, +Tensor einsum(const std::string& subscripts_str, const ffi::Array inputs, std::string name = "T_einsum", std::string tag = kEinsum); struct EinsumEquation { diff --git a/include/tvm/topi/elemwise.h b/include/tvm/topi/elemwise.h index 806ddcb662f9..0ed082b0c140 100644 --- a/include/tvm/topi/elemwise.h +++ b/include/tvm/topi/elemwise.h @@ -40,11 +40,11 @@ namespace topi { using namespace tvm::te; // Unary intrinsic operators -#define TOPI_DECLARE_UNARY_OP(OpName) \ - inline Tensor OpName(const Tensor& x, std::string name = "T_" #OpName, \ - std::string tag = kElementWise) { \ - return compute( \ - x->shape, [&](const Array& i) { return ::tvm::OpName(x(i)); }, name, tag); \ +#define TOPI_DECLARE_UNARY_OP(OpName) \ + inline Tensor OpName(const Tensor& x, std::string name = "T_" #OpName, \ + std::string tag = kElementWise) { \ + return compute( \ + x->shape, [&](const ffi::Array& i) { return ::tvm::OpName(x(i)); }, name, tag); \ } TOPI_DECLARE_UNARY_OP(exp); @@ -101,7 +101,7 @@ inline Tensor fast_tanh_float(const Tensor& in, std::string name, std::string ta return compute( x->shape, - [&](const Array& i) { + [&](const ffi::Array& i) { auto x2 = x(i) * x(i); auto p = x2 * alpha_13 + alpha_11; p = x2 * p + alpha_9; @@ -136,7 +136,7 @@ inline Tensor fast_tanh(const Tensor& x, std::string name = "T_fast_tanh", } else { // fallback to default implementation return compute( - x->shape, [&](const Array& i) { return ::tvm::tanh(x(i)); }, name, tag); + x->shape, [&](const ffi::Array& i) { return ::tvm::tanh(x(i)); }, name, tag); } } @@ -152,7 +152,7 @@ inline Tensor fast_tanh(const Tensor& x, std::string name = "T_fast_tanh", inline Tensor identity(const Tensor& x, std::string name = "T_identity", std::string tag = kElementWise) { return 
compute( - x->shape, [&](const Array& i) { return x(i); }, name, tag); + x->shape, [&](const ffi::Array& i) { return x(i); }, name, tag); } /*! @@ -167,7 +167,7 @@ inline Tensor identity(const Tensor& x, std::string name = "T_identity", inline Tensor negative(const Tensor& x, std::string name = "T_negative", std::string tag = kElementWise) { return compute( - x->shape, [&](const Array& i) { return -x(i); }, name, tag); + x->shape, [&](const ffi::Array& i) { return -x(i); }, name, tag); } /*! @@ -182,7 +182,7 @@ inline Tensor negative(const Tensor& x, std::string name = "T_negative", inline Tensor logical_not(const Tensor& x, std::string name = "T_logical_not", std::string tag = kElementWise) { return compute( - x->shape, [&](const Array& i) { return !x(i); }, name, tag); + x->shape, [&](const ffi::Array& i) { return !x(i); }, name, tag); } /*! @@ -197,7 +197,7 @@ inline Tensor logical_not(const Tensor& x, std::string name = "T_logical_not", inline Tensor bitwise_not(const Tensor& x, std::string name = "T_bitwise_not", std::string tag = kElementWise) { return compute( - x->shape, [&](const Array& i) { return ~x(i); }, name, tag); + x->shape, [&](const ffi::Array& i) { return ~x(i); }, name, tag); } /*! @@ -212,7 +212,7 @@ inline Tensor bitwise_not(const Tensor& x, std::string name = "T_bitwise_not", inline Tensor sign(const Tensor& x, std::string name = "T_sign", std::string tag = kElementWise) { return compute( x->shape, - [&](const Array& i) { + [&](const ffi::Array& i) { PrimExpr zero = make_zero(x->dtype); PrimExpr one = make_const(x->dtype, 1); PrimExpr minus_one = make_const(x->dtype, -1); @@ -235,7 +235,7 @@ inline Tensor sign(const Tensor& x, std::string name = "T_sign", std::string tag inline Tensor rsqrt(const Tensor& x, std::string name = "tensor", std::string tag = kElementWise) { return compute( x->shape, - [&](const Array& i) { + [&](const ffi::Array& i) { PrimExpr one = make_const(x->dtype, 1); return one / tvm::sqrt(x(i)); }, @@ -258,7 +258,7 @@ inline Tensor clip(const Tensor& x, const PrimExpr& a_min, const PrimExpr& a_max std::string name = "T_clip", std::string tag = kElementWise) { return compute( x->shape, - [&](const Array& i) { + [&](const ffi::Array& i) { auto min_val = tvm::cast(x->dtype, a_min); auto max_val = tvm::cast(x->dtype, a_max); return tvm::max(tvm::min(x(i), max_val), min_val); // NOLINT(*) @@ -282,7 +282,7 @@ inline Tensor cast(const Tensor& x, DataType type, std::string name = "T_cast", std::string tag = kElementWise) { return compute( x->shape, - [&](const Array& i) -> PrimExpr { + [&](const ffi::Array& i) -> PrimExpr { auto expr = x(i); if (expr.dtype().code() == type.code() && expr.dtype().bits() == type.bits()) { if (expr.dtype().lanes() == type.lanes()) { @@ -310,7 +310,7 @@ inline Tensor cast(const Tensor& x, DataType type, std::string name = "T_cast", inline Tensor reinterpret(const Tensor& x, DataType type, std::string name = "tensor", std::string tag = kElementWise) { return compute( - x->shape, [&](const Array& i) { return reinterpret(type, x(i)); }, name, tag); + x->shape, [&](const ffi::Array& i) { return reinterpret(type, x(i)); }, name, tag); } /*! 
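// ---------------------------------------------------------------------------
// Illustrative sketch, not part of the patch: the compute-lambda signature
// used throughout elemwise.h now takes tvm::ffi::Array<tvm::tir::Var>. A
// hypothetical "T_square" op written in the same pattern as identity/negative:
inline tvm::te::Tensor square(const tvm::te::Tensor& x) {
  return tvm::te::compute(
      x->shape,
      [&](const tvm::ffi::Array<tvm::tir::Var>& i) { return x(i) * x(i); },
      "T_square", tvm::topi::kElementWise);
}
// ---------------------------------------------------------------------------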
@@ -322,12 +322,12 @@ inline Tensor reinterpret(const Tensor& x, DataType type, std::string name = "te * * \return A Tensor whose op member is the sum operation */ -inline Tensor elemwise_sum(const Array& xs, std::string name = "T_elemwise_sum", +inline Tensor elemwise_sum(const ffi::Array& xs, std::string name = "T_elemwise_sum", std::string tag = kElementWise) { ICHECK_GT(xs.size(), 0) << "elemwise sum must have at least one input tensor."; return compute( xs[0]->shape, - [&](const Array& i) { + [&](const ffi::Array& i) { auto sum_expr = xs[0](i); for (size_t j = 1; j < xs.size(); j++) { sum_expr = sum_expr + xs[j](i); @@ -348,14 +348,14 @@ inline Tensor elemwise_sum(const Array& xs, std::string name = "T_elemwi * * \return A Tensor whose op member is the full operation */ -inline Tensor full(const Array& shape, DataType dtype, const PrimExpr fill_value, +inline Tensor full(const ffi::Array& shape, DataType dtype, const PrimExpr fill_value, std::string name = "T_full", std::string tag = kElementWise) { PrimExpr ev = cast(dtype, fill_value); if (!ev.defined()) { LOG(ERROR) << "Can't cast fill_value to " << dtype; } return compute( - shape, [&](const Array& i) { return ev; }, name, tag); + shape, [&](const ffi::Array& i) { return ev; }, name, tag); } /*! @@ -373,7 +373,7 @@ inline Tensor full_like(const Tensor& x, const PrimExpr fill_value, std::string name = "T_full_like", std::string tag = kElementWise) { PrimExpr ev = cast(x->dtype, fill_value); return compute( - x->shape, [&](const Array& i) { return ev; }, name, tag); + x->shape, [&](const ffi::Array& i) { return ev; }, name, tag); } /*! @@ -414,7 +414,7 @@ inline Tensor fast_exp_float32(const Tensor& _x, std::string name, std::string t return compute( _x->shape, - [&](const Array& i) { + [&](const ffi::Array& i) { // clamp x auto x = ::tvm::max(::tvm::min(_x(i), x_hi), x_lo); // integer part @@ -448,7 +448,7 @@ inline Tensor fast_exp(const Tensor& x, std::string name = "T_fast_exp", return ret; } else { return compute( - x->shape, [&](const Array& i) { return ::tvm::exp(x(i)); }, name, tag); + x->shape, [&](const ffi::Array& i) { return ::tvm::exp(x(i)); }, name, tag); } } @@ -457,7 +457,7 @@ inline Tensor fast_exp(const Tensor& x, std::string name = "T_fast_exp", */ inline Tensor fast_erf_float32(const Tensor& data, std::string name, std::string tag) { return compute( - data->shape, [&](const Array& i) { return fast_erf_float_expr(data(i), 32); }, name, + data->shape, [&](const ffi::Array& i) { return fast_erf_float_expr(data(i), 32); }, name, tag); } @@ -466,7 +466,7 @@ inline Tensor fast_erf_float32(const Tensor& data, std::string name, std::string */ inline Tensor fast_erf_float16(const Tensor& data, std::string name, std::string tag) { return compute( - data->shape, [&](const Array& i) { return fast_erf_float_expr(data(i), 16); }, name, + data->shape, [&](const ffi::Array& i) { return fast_erf_float_expr(data(i), 16); }, name, tag); } diff --git a/include/tvm/topi/nn.h b/include/tvm/topi/nn.h index 6bef5d0f1c2a..36ce8594b3db 100644 --- a/include/tvm/topi/nn.h +++ b/include/tvm/topi/nn.h @@ -56,7 +56,7 @@ inline tvm::te::Tensor relu(const tvm::te::Tensor& t, T threshold = static_cast< std::string name = "T_relu", std::string tag = kElementWise) { return tvm::te::compute( t->shape, - [&](const tvm::Array& i) { + [&](const tvm::ffi::Array& i) { auto threshold_const = tvm::tir::make_const(t->dtype, threshold); return tvm::max(t(i), threshold_const); }, @@ -78,7 +78,7 @@ inline tvm::te::Tensor leaky_relu(const tvm::te::Tensor& t, 
double alpha = 0.1, std::string tag = kElementWise) { return tvm::te::compute( t->shape, - [&](const tvm::Array& i) { + [&](const tvm::ffi::Array& i) { auto value = t(i); auto calpha = tvm::tir::make_const(value.dtype(), alpha); return tvm::tir::Select(value > 0, value, value * calpha); @@ -106,7 +106,7 @@ inline tvm::te::Tensor prelu(const tvm::te::Tensor& x, const tvm::te::Tensor& sl return tvm::te::compute( x->shape, - [&](const tvm::Array& indices) { + [&](const tvm::ffi::Array& indices) { auto xval = x(indices); return tvm::tir::Select(xval > 0, xval, xval * slope(indices[axis])); }, @@ -152,11 +152,11 @@ inline tvm::te::Tensor prelu(const tvm::te::Tensor& x, const tvm::te::Tensor& sl * * */ -inline tvm::te::Tensor pad(const tvm::te::Tensor& t, const tvm::Array& pad_before, - tvm::Array pad_after = tvm::Array(), - PrimExpr pad_value = PrimExpr(), std::string name = "T_pad", - std::string tag = kElementWise, std::string pad_mode = "constant", - const Array* dyn_output_shape = nullptr) { +inline tvm::te::Tensor pad( + const tvm::te::Tensor& t, const tvm::ffi::Array& pad_before, + tvm::ffi::Array pad_after = tvm::ffi::Array(), + PrimExpr pad_value = PrimExpr(), std::string name = "T_pad", std::string tag = kElementWise, + std::string pad_mode = "constant", const ffi::Array* dyn_output_shape = nullptr) { if (pad_after.size() < pad_before.size()) { for (size_t i = pad_after.size(); i < pad_before.size(); ++i) { pad_after.push_back(pad_before[i]); @@ -166,8 +166,8 @@ inline tvm::te::Tensor pad(const tvm::te::Tensor& t, const tvm::Array pad_before_int32; - tvm::Array pad_after_int32; + tvm::ffi::Array pad_before_int32; + tvm::ffi::Array pad_after_int32; for (const auto& ele : pad_before) { pad_before_int32.push_back(tvm::cast(tvm::DataType::Int(32), ele)); @@ -176,7 +176,7 @@ inline tvm::te::Tensor pad(const tvm::te::Tensor& t, const tvm::Array output_shape; + tvm::ffi::Array output_shape; if (dyn_output_shape == nullptr) { for (size_t i = 0; i < t->shape.size(); ++i) { if (i >= pad_before.size()) { @@ -196,10 +196,10 @@ inline tvm::te::Tensor pad(const tvm::te::Tensor& t, const tvm::Arraydtype, 0); } - auto l = [&](tvm::Array ovars) { - tvm::Array indices; - tvm::Array sel; - tvm::Array pad_idx; + auto l = [&](tvm::ffi::Array ovars) { + tvm::ffi::Array indices; + tvm::ffi::Array sel; + tvm::ffi::Array pad_idx; for (size_t i = 0; i < t->shape.size(); ++i) { if (i >= pad_before_int32.size()) { indices.push_back(ovars[i]); @@ -273,7 +273,7 @@ inline tvm::te::Tensor conv2d_nchw(const tvm::te::Tensor& I, const tvm::te::Tens ICHECK_EQ(4, W->shape.size()); auto pH = I->shape[2]; auto pW = I->shape[3]; - tvm::Array output_shape{ + tvm::ffi::Array output_shape{ I->shape[0], // B W->shape[0], // O indexdiv(I->shape[2] - W->shape[2] + 2 * pad_h, stride_h) + 1, // H @@ -317,7 +317,7 @@ inline tvm::te::Tensor conv2d_hwcn(const tvm::te::Tensor& I, const tvm::te::Tens ICHECK_EQ(4, W->shape.size()); auto pH = I->shape[2]; auto pW = I->shape[3]; - tvm::Array output_shape{ + tvm::ffi::Array output_shape{ indexdiv(I->shape[2] - W->shape[2] + 2 * pad_h, stride_h) + 1, // H indexdiv(I->shape[3] - W->shape[3] + 2 * pad_w, stride_w) + 1, // W I->shape[2], // B @@ -363,7 +363,7 @@ inline tvm::te::Tensor depthwise_conv2d_nchw(const tvm::te::Tensor& I, const tvm auto pH = I->shape[2]; auto pW = I->shape[3]; auto pCM = W->shape[1]; // channel_multiplier - tvm::Array output_shape{ + tvm::ffi::Array output_shape{ I->shape[0], // B W->shape[1], // O indexdiv(I->shape[2] - W->shape[2] + 2 * pad_h, stride_h) + 1, // H @@ 
-392,7 +392,7 @@ inline tvm::te::Tensor depthwise_conv2d_nhwc(const tvm::te::Tensor& I, const tvm auto pH = I->shape[1]; auto pW = I->shape[2]; auto pCM = W->shape[1]; // channel_multiplier - tvm::Array output_shape{ + tvm::ffi::Array output_shape{ I->shape[0], // B indexdiv(I->shape[1] - W->shape[1] + 2 * pad_h, stride_h) + 1, // H indexdiv(I->shape[2] - W->shape[2] + 2 * pad_w, stride_w) + 1, // W @@ -440,7 +440,7 @@ inline tvm::te::Tensor group_conv2d_ngchw(const tvm::te::Tensor& I, const tvm::t ICHECK_EQ(5, W->shape.size()); auto pH = I->shape[2]; auto pW = I->shape[3]; - tvm::Array output_shape{ + tvm::ffi::Array output_shape{ I->shape[0], // B I->shape[1], // G W->shape[2], // O @@ -454,7 +454,7 @@ inline tvm::te::Tensor group_conv2d_ngchw(const tvm::te::Tensor& I, const tvm::t auto T = (pad_h == 0 && pad_w == 0) ? I : pad(I, {tvm::PrimExpr(0), tvm::PrimExpr(0), tvm::PrimExpr(0), pad_h, pad_w}); - auto l = [&](tvm::Array args) { + auto l = [&](tvm::ffi::Array args) { tvm::tir::Var b = args[0]; tvm::tir::Var g = args[1]; tvm::tir::Var o = args[2]; @@ -480,9 +480,9 @@ inline tvm::te::Tensor group_conv2d_ngchw(const tvm::te::Tensor& I, const tvm::t * \return A Tensor whose op member is the space_to_batch_nd operation */ inline tvm::te::Tensor space_to_batch_nd(const tvm::te::Tensor& data, - const tvm::Array& block_shape, - const tvm::Array& pad_before, - const tvm::Array& pad_after, + const tvm::ffi::Array& block_shape, + const tvm::ffi::Array& pad_before, + const tvm::ffi::Array& pad_after, PrimExpr pad_value = PrimExpr(), std::string name = "space_to_batch_nd", std::string tag = kInjective) { @@ -490,8 +490,8 @@ inline tvm::te::Tensor space_to_batch_nd(const tvm::te::Tensor& data, CHECK_EQ(pad_before.size(), pad_after.size()); CHECK_EQ(block_shape.size(), pad_before.size()) << "Paddings must be provided for each spatial dimension"; - tvm::Array pad_before_int32; - tvm::Array pad_after_int32; + tvm::ffi::Array pad_before_int32; + tvm::ffi::Array pad_after_int32; // pad size for batch dimension is 0 pad_before_int32.push_back(tvm::cast(tvm::DataType::Int(32), 0)); @@ -514,9 +514,9 @@ inline tvm::te::Tensor space_to_batch_nd(const tvm::te::Tensor& data, auto padded_shape = padded_t->shape; // infer shapes - tvm::Array r_shape; - tvm::Array axis; - tvm::Array o_shape; + tvm::ffi::Array r_shape; + tvm::ffi::Array axis; + tvm::ffi::Array o_shape; size_t num_block_dims = block_shape.size(); int batch = static_cast(GetConstInt(input_shape[0])); @@ -576,15 +576,15 @@ inline tvm::te::Tensor space_to_batch_nd(const tvm::te::Tensor& data, * \return A Tensor whose op member is the batch_to_space_nd operation */ inline tvm::te::Tensor batch_to_space_nd(const tvm::te::Tensor& data, - const tvm::Array& block_shape, - const tvm::Array& crop_begin_list, - const tvm::Array& crop_end_list, + const tvm::ffi::Array& block_shape, + const tvm::ffi::Array& crop_begin_list, + const tvm::ffi::Array& crop_end_list, std::string name = "batch_to_space_nd", std::string tag = kInjective) { // Construct shapes for reshape and transpose operation - Array in_shape = data->shape; - Array r_shape; - Array axis; + ffi::Array in_shape = data->shape; + ffi::Array r_shape; + ffi::Array axis; size_t num_block_dims = block_shape.size(); size_t num_input_dims = in_shape.size(); tvm::PrimExpr block_shape_prod(1); @@ -605,7 +605,7 @@ inline tvm::te::Tensor batch_to_space_nd(const tvm::te::Tensor& data, r_shape.push_back(in_shape[i]); } - Array r_p_shape; + ffi::Array r_p_shape; r_p_shape.push_back(batch / block_shape_prod); for 
(size_t i = 1; i <= num_block_dims; i++) { r_p_shape.push_back(in_shape[i] * block_shape[i - 1]); @@ -620,7 +620,7 @@ inline tvm::te::Tensor batch_to_space_nd(const tvm::te::Tensor& data, out = reshape(out, r_p_shape); // Crop the start and end of dimensions of out - Array begin_idx, end_idx, strides; + ffi::Array begin_idx, end_idx, strides; for (size_t i = 0; i < r_p_shape.size(); ++i) { strides.push_back(Integer(1)); if (i > 0 && i <= num_block_dims) { @@ -665,7 +665,7 @@ inline Tensor nll_loss(const Tensor& predictions, const Tensor& targets, const T // prediction->shape = (C,), targets->shape = (), weights->shape = (C,) auto T = tvm::te::compute( {}, - [&](const tvm::Array& target_indices) { + [&](const tvm::ffi::Array& target_indices) { auto c = targets(); return tvm::tir::Select(c != ignore_index, -predictions(c) * weights(c), tvm::tir::make_const(predictions->dtype, 0)); @@ -674,7 +674,7 @@ inline Tensor nll_loss(const Tensor& predictions, const Tensor& targets, const T if (reduction == "mean") { auto W = tvm::te::compute( {}, - [&](const tvm::Array& target_indices) { + [&](const tvm::ffi::Array& target_indices) { auto c = targets(); return tvm::tir::Select(c != ignore_index, weights(c), tvm::tir::make_const(predictions->dtype, 0)); @@ -687,9 +687,9 @@ inline Tensor nll_loss(const Tensor& predictions, const Tensor& targets, const T } auto T = tvm::te::compute( targets->shape, - [&](const tvm::Array& target_indices) { + [&](const tvm::ffi::Array& target_indices) { auto c = targets(target_indices); - tvm::Array pred_indices; + tvm::ffi::Array pred_indices; pred_indices.push_back(target_indices[0]); // batch index pred_indices.push_back(c); // class index for (size_t i = 1; i < target_indices.size(); i++) { @@ -703,16 +703,16 @@ inline Tensor nll_loss(const Tensor& predictions, const Tensor& targets, const T if (reduction == "mean") { auto W = tvm::te::compute( targets->shape, - [&](const tvm::Array& target_indices) { + [&](const tvm::ffi::Array& target_indices) { auto c = targets(target_indices); return tvm::tir::Select(c != ignore_index, weights(c), tvm::tir::make_const(predictions->dtype, 0)); }, name, tag); - return topi::divide(topi::sum(T, tvm::Array(nullptr)), - topi::sum(W, tvm::Array(nullptr))); + return topi::divide(topi::sum(T, tvm::ffi::Array(nullptr)), + topi::sum(W, tvm::ffi::Array(nullptr))); } else if (reduction == "sum") { - return topi::sum(T, tvm::Array(nullptr)); + return topi::sum(T, tvm::ffi::Array(nullptr)); } else { // reduction == "none" return T; } diff --git a/include/tvm/topi/nn/bnn.h b/include/tvm/topi/nn/bnn.h index 815b8a23c998..2cc494eaa9d4 100644 --- a/include/tvm/topi/nn/bnn.h +++ b/include/tvm/topi/nn/bnn.h @@ -57,7 +57,7 @@ inline tvm::te::Tensor binarize_pack(const tvm::te::Tensor& data, int axis, arith::Analyzer analyzer; auto n = ishape.size(); - Array oshape; + ffi::Array oshape; for (size_t i = 0; i < n; ++i) { oshape.push_back(i == static_cast(axis) ? analyzer.Simplify(indexdiv(ishape[i], 32)) : ishape[i]); @@ -65,15 +65,15 @@ inline tvm::te::Tensor binarize_pack(const tvm::te::Tensor& data, int axis, return tvm::te::compute( oshape, - [&](const Array& indices) { - Array start_idx; + [&](const ffi::Array& indices) { + ffi::Array start_idx; for (size_t i = 0; i < n; ++i) { start_idx.push_back(i == static_cast(axis) ? indices[i] * 32 : static_cast(indices[i])); } auto packed = make_const(DataType::UInt(32), 0); for (size_t j = 0; j < 32; ++j) { - Array idx; + ffi::Array idx; for (size_t i = 0; i < n; ++i) { idx.push_back(i == static_cast(axis) ? 
start_idx[i] + static_cast<int>(j) : start_idx[i]); diff --git a/include/tvm/topi/nn/dilate.h b/include/tvm/topi/nn/dilate.h index 74c46e2694b3..816d489c400e 100644 --- a/include/tvm/topi/nn/dilate.h +++ b/include/tvm/topi/nn/dilate.h @@ -44,7 +44,7 @@ using namespace tvm::te; * * \return The logical conjunction expression */ -PrimExpr all(Array<PrimExpr> args) { +PrimExpr all(ffi::Array<PrimExpr> args) { ICHECK_GT(args.size(), 0) << "all requires at least one argument"; PrimExpr ret = args[0]; @@ -67,13 +67,13 @@ PrimExpr all(Array<PrimExpr> args) { * * \return The output tensor. */ -inline Tensor dilate(const Tensor& x, Array<PrimExpr> strides, double dilation_value, +inline Tensor dilate(const Tensor& x, ffi::Array<PrimExpr> strides, double dilation_value, std::string name = "tensor", std::string tag = kInjective) { auto n = x->shape.size(); ICHECK_EQ(n, strides.size()) << "strides size (" << strides.size() << ") must match dimension of x (" << n << ")"; - Array<PrimExpr> out_shape; + ffi::Array<PrimExpr> out_shape; arith::Analyzer analyzer; for (size_t i = 0; i < n; ++i) { out_shape.push_back(analyzer.Simplify((x->shape[i] - 1) * strides[i] + 1)); @@ -81,9 +81,9 @@ inline Tensor dilate(const Tensor& x, Array<PrimExpr> strides, double dilation_v return tvm::te::compute( out_shape, - [&](const Array<Var>& indices) { - Array<PrimExpr> not_zero; - Array<PrimExpr> index_tuple; + [&](const ffi::Array<Var>& indices) { + ffi::Array<PrimExpr> not_zero; + ffi::Array<PrimExpr> index_tuple; for (size_t i = 0; i < n; ++i) { if (IsConstInt(strides[i]) && GetConstInt(strides[i]) == 1) { index_tuple.push_back(indices[i]); diff --git a/include/tvm/topi/nn/flatten.h b/include/tvm/topi/nn/flatten.h index cd96d303b920..e60ae1e1d641 100644 --- a/include/tvm/topi/nn/flatten.h +++ b/include/tvm/topi/nn/flatten.h @@ -54,7 +54,7 @@ inline Tensor flatten(const Tensor& x, std::string name = "tensor", std::string dim = dim * ishape[i]; } - Array<PrimExpr> oshape({ishape[0], dim}); + ffi::Array<PrimExpr> oshape({ishape[0], dim}); std::vector<PrimExpr> extra_shape; for (size_t i = 1; i < ishape.size(); ++i) { diff --git a/include/tvm/topi/nn/group_norm.h b/include/tvm/topi/nn/group_norm.h index 9dcc1dda9e43..9c03b682407d 100644 --- a/include/tvm/topi/nn/group_norm.h +++ b/include/tvm/topi/nn/group_norm.h @@ -37,7 +37,7 @@ namespace nn { using namespace tvm::te; inline Tensor group_norm(const Tensor& data, const Tensor& gamma, const Tensor& beta, - int num_groups, int channel_axis, const Array<Integer>& axes, + int num_groups, int channel_axis, const ffi::Array<Integer>& axes, double epsilon, std::string name = "T_group_norm", std::string tag = kInjective) { const auto& data_type = data->dtype; @@ -50,11 +50,11 @@ inline Tensor group_norm(const Tensor& data, const Tensor& gamma, const Tensor& bool is_float16 = data_type == DataType::Float(16); // reshape data C -> G, C/G int ndim = data->shape.size(); - channel_axis = GetRealAxis(static_cast<int>(ndim), Array<Integer>({channel_axis}))[0]; + channel_axis = GetRealAxis(static_cast<int>(ndim), ffi::Array<Integer>({channel_axis}))[0]; auto shape = data->shape; auto group_size = floordiv(shape[channel_axis], num_groups); - auto new_shape = Array<PrimExpr>(); + auto new_shape = ffi::Array<PrimExpr>(); for (int i = 0; i < ndim; ++i) { if (i == channel_axis) { new_shape.push_back(num_groups); @@ -82,7 +82,7 @@ inline Tensor group_norm(const Tensor& data, const Tensor& gamma, const Tensor& // get the new axes to normalize after reshape std::vector<int> new_axes{channel_axis + 1}; for (auto axis : axes) { - int new_axis = GetRealAxis(static_cast<int>(ndim), Array<Integer>({axis}))[0]; + int new_axis = GetRealAxis(static_cast<int>(ndim), ffi::Array<Integer>({axis}))[0]; if (new_axis < channel_axis) { new_axes.push_back(new_axis); }
else if (new_axis > channel_axis) { @@ -100,8 +100,9 @@ inline Tensor group_norm(const Tensor& data, const Tensor& gamma, const Tensor& MakeReduceTargetShape(new_axes, data_reshaped, /*keepdims=*/false, /*atleast1d=*/true); auto func = MakeTupleSumReducer(); - auto compute = [ndim, &new_axes, &reduce_axes, &func, &data_reshaped](const Array& indices) { - Array eval_range; + auto compute = [ndim, &new_axes, &reduce_axes, &func, + &data_reshaped](const ffi::Array& indices) { + ffi::Array eval_range; int arg_counter = 0; int red_counter = 0; @@ -129,8 +130,8 @@ inline Tensor group_norm(const Tensor& data, const Tensor& gamma, const Tensor& for (auto axis : new_axes) { reduce_extent *= data_reshaped->shape[axis]; } - auto group_norm_func = [&](const Array& indices) { - Array reduce_indices, non_reduce_indices, gamma_indices; + auto group_norm_func = [&](const ffi::Array& indices) { + ffi::Array reduce_indices, non_reduce_indices, gamma_indices; for (int i = 0, n = static_cast(indices.size()); i < n; ++i) { if (std::find(new_axes.begin(), new_axes.end(), i) != new_axes.end()) { reduce_indices.push_back(indices[i]); diff --git a/include/tvm/topi/nn/instance_norm.h b/include/tvm/topi/nn/instance_norm.h index d400721215ec..c6a10ec89f0a 100644 --- a/include/tvm/topi/nn/instance_norm.h +++ b/include/tvm/topi/nn/instance_norm.h @@ -51,7 +51,7 @@ using namespace tvm::te; * \return The normalized tensor, with the same shape as data. */ inline Tensor instance_norm(const Tensor& data, const Tensor& gamma, const Tensor& beta, - int channel_axis, const Array& axis, double epsilon, + int channel_axis, const ffi::Array& axis, double epsilon, std::string name = "T_instance_norm", std::string tag = kInjective) { const auto& data_type = data->dtype; const auto& gamma_type = gamma.defined() ? gamma->dtype : data_type; @@ -71,8 +71,8 @@ inline Tensor instance_norm(const Tensor& data, const Tensor& gamma, const Tenso auto func = MakeTupleSumReducer(); auto compute = [ndim, is_float16, &real_axis, &reduce_axes, &func, - &data](const Array& indices) { - Array eval_range; + &data](const ffi::Array& indices) { + ffi::Array eval_range; int arg_counter = 0; int red_counter = 0; @@ -110,8 +110,8 @@ inline Tensor instance_norm(const Tensor& data, const Tensor& gamma, const Tenso for (int i : real_axis) { reduce_extent *= data->shape[i]; } - auto instance_norm_func = [&](const Array& indices) { - Array reduce_indices, non_reduce_indices; + auto instance_norm_func = [&](const ffi::Array& indices) { + ffi::Array reduce_indices, non_reduce_indices; for (int i = 0, n = static_cast(indices.size()); i < n; ++i) { if (std::find(real_axis.begin(), real_axis.end(), i) != real_axis.end()) { diff --git a/include/tvm/topi/nn/layer_norm.h b/include/tvm/topi/nn/layer_norm.h index f1b0e4ac9eaa..7caa30b0a23b 100644 --- a/include/tvm/topi/nn/layer_norm.h +++ b/include/tvm/topi/nn/layer_norm.h @@ -49,7 +49,7 @@ using namespace tvm::te; * \return The normalized tensor, with the same shape as data. */ inline Tensor layer_norm(const Tensor& data, const Tensor& gamma, const Tensor& beta, - const Array& axis, double epsilon, + const ffi::Array& axis, double epsilon, std::string name = "T_layer_norm", std::string tag = kInjective) { const auto& data_type = data->dtype; const auto& gamma_type = gamma.defined() ? 
gamma->dtype : data_type; @@ -69,8 +69,8 @@ inline Tensor layer_norm(const Tensor& data, const Tensor& gamma, const Tensor& auto func = MakeTupleSumReducer(); auto compute = [ndim, is_float16, &real_axis, &reduce_axes, &func, - &data](const Array& indices) { - Array eval_range; + &data](const ffi::Array& indices) { + ffi::Array eval_range; int arg_counter = 0; int red_counter = 0; @@ -108,8 +108,8 @@ inline Tensor layer_norm(const Tensor& data, const Tensor& gamma, const Tensor& for (int i : real_axis) { reduce_extent *= data->shape[i]; } - auto layer_norm_func = [&](const Array& indices) { - Array reduce_indices, non_reduce_indices; + auto layer_norm_func = [&](const ffi::Array& indices) { + ffi::Array reduce_indices, non_reduce_indices; for (int i = 0, n = static_cast(indices.size()); i < n; ++i) { if (std::find(real_axis.begin(), real_axis.end(), i) != real_axis.end()) { reduce_indices.push_back(indices[i]); diff --git a/include/tvm/topi/nn/local_response_norm.h b/include/tvm/topi/nn/local_response_norm.h index a9d72250bbb0..119ab0c19eb0 100644 --- a/include/tvm/topi/nn/local_response_norm.h +++ b/include/tvm/topi/nn/local_response_norm.h @@ -57,8 +57,8 @@ inline Tensor lrn(const Tensor& data, int size, int axis = 1, float alpha = 0.00 ICHECK(axis == 1 || axis == 3) << "axis should be 1 or 3 for NCHW and NHWC"; ICHECK(data->dtype.is_float()) << "datatype should be float"; auto input_shape = data->shape; - Array pad_before{0, 0, 0, 0}; - Array pad_after{0, 0, 0, 0}; + ffi::Array pad_before{0, 0, 0, 0}; + ffi::Array pad_after{0, 0, 0, 0}; pad_before.Set(axis, static_cast(size / 2)); pad_after.Set(axis, static_cast(size / 2)); auto pad_data = pad(data, pad_before, pad_after, 0, "pad_data"); diff --git a/include/tvm/topi/nn/pooling.h b/include/tvm/topi/nn/pooling.h index 8e13ae49afdf..b977a54a5920 100644 --- a/include/tvm/topi/nn/pooling.h +++ b/include/tvm/topi/nn/pooling.h @@ -47,8 +47,9 @@ enum PoolType : int { }; inline Tensor pool_grad_impl(const Tensor& out_grad, const Tensor& x, - const Array& kernel_size, const Array& stride_size, - const Array& padding_size, PoolType pool_type, + const ffi::Array& kernel_size, + const ffi::Array& stride_size, + const ffi::Array& padding_size, PoolType pool_type, bool ceil_mode, const size_t height_axis, const size_t width_axis, bool count_include_pad) { ICHECK(out_grad->shape.size() >= 2) << "Pooling grad output must >= 2-D (H, W)"; @@ -77,11 +78,11 @@ inline Tensor pool_grad_impl(const Tensor& out_grad, const Tensor& x, pad_right += stride_width - 1; } - Array pad_before(std::vector(x->shape.size(), 0)); + ffi::Array pad_before(std::vector(x->shape.size(), 0)); pad_before.Set(height_axis, pad_top); pad_before.Set(width_axis, pad_left); - Array pad_after(std::vector(x->shape.size(), 0)); + ffi::Array pad_after(std::vector(x->shape.size(), 0)); pad_after.Set(height_axis, pad_bottom); pad_after.Set(width_axis, pad_right); arith::Analyzer analyzer; @@ -93,8 +94,8 @@ inline Tensor pool_grad_impl(const Tensor& out_grad, const Tensor& x, auto dheight = tvm::te::reduce_axis(Range(0, kernel_height), "dh"); auto dwidth = tvm::te::reduce_axis(Range(0, kernel_width), "dw"); - Array data_shape = x->shape; - Array out_shape = data_shape; + ffi::Array data_shape = x->shape; + ffi::Array out_shape = data_shape; out_shape.Set(height_axis, out_height); out_shape.Set(width_axis, out_width); @@ -106,7 +107,7 @@ inline Tensor pool_grad_impl(const Tensor& out_grad, const Tensor& x, ((padding_h1 && *padding_h1) || (padding_w1 && *padding_w1)); if (pool_type == kMaxPool) 
{ - Array ravel_shape{data_shape.begin(), data_shape.end()}; + ffi::Array ravel_shape{data_shape.begin(), data_shape.end()}; ravel_shape.Set(height_axis, ravel_shape[height_axis] + pad_top + pad_bottom); ravel_shape.Set(width_axis, ravel_shape[width_axis] + pad_left + pad_right); @@ -120,8 +121,8 @@ inline Tensor pool_grad_impl(const Tensor& out_grad, const Tensor& x, auto mp_argmax = tvm::te::compute( out_shape, - [&](const Array& inds) { - Array window_inds{inds.begin(), inds.end()}; + [&](const ffi::Array& inds) { + ffi::Array window_inds{inds.begin(), inds.end()}; window_inds.Set(height_axis, inds[height_axis] * stride_height + dheight); window_inds.Set(width_axis, inds[width_axis] * stride_width + dwidth); auto idx = detail::RavelIndex(window_inds, ravel_shape); @@ -133,13 +134,13 @@ inline Tensor pool_grad_impl(const Tensor& out_grad, const Tensor& x, return tvm::te::compute( data_shape, - [&](const Array& inds) { - Array pad_inds{inds.begin(), inds.end()}; + [&](const ffi::Array& inds) { + ffi::Array pad_inds{inds.begin(), inds.end()}; pad_inds.Set(height_axis, pad_inds[height_axis] + pad_top); pad_inds.Set(width_axis, pad_inds[width_axis] + pad_left); auto idx = detail::RavelIndex(pad_inds, ravel_shape); - Array out_idx{inds.begin(), inds.end()}; + ffi::Array out_idx{inds.begin(), inds.end()}; out_idx.Set(height_axis, (inds[height_axis] + pad_top) / stride_height - windowh); out_idx.Set(width_axis, (inds[width_axis] + pad_left) / stride_width - windoww); @@ -165,12 +166,12 @@ inline Tensor pool_grad_impl(const Tensor& out_grad, const Tensor& x, tvm::te::reduce_axis(Range(0, (kernel_width + stride_width - 1) / stride_width), "ww"); return tvm::te::compute( data_shape, - [&](const Array& inds) { + [&](const ffi::Array& inds) { PrimExpr pad_h_idx = inds[height_axis] + pad_top; PrimExpr pad_w_idx = inds[width_axis] + pad_left; // output indices whose pooling windows cover current input element (can be out-of-bound) - Array out_idx{inds.begin(), inds.end()}; + ffi::Array out_idx{inds.begin(), inds.end()}; out_idx.Set(height_axis, (pad_h_idx / stride_height - windowh)); out_idx.Set(width_axis, (pad_w_idx / stride_width - windoww)); @@ -290,9 +291,11 @@ inline bool find_width(const std::string& layout, int* width_axis) { * * \return The output tensor in the same layout */ -inline Tensor pool_grad(const Tensor& out_grad, const Tensor& x, const Array& kernel_size, - const Array& stride_size, const Array& padding_size, - PoolType pool_type, bool ceil_mode, const std::string& layout = "NCHW", +inline Tensor pool_grad(const Tensor& out_grad, const Tensor& x, + const ffi::Array& kernel_size, + const ffi::Array& stride_size, + const ffi::Array& padding_size, PoolType pool_type, + bool ceil_mode, const std::string& layout = "NCHW", bool count_include_pad = true) { int height_axis = -1, width_axis = -1; ICHECK(find_height_width(layout, &height_axis, &width_axis)) << "Unsupported layout " << layout; @@ -319,24 +322,24 @@ inline PrimExpr end_index(const Var& out_index, const PrimExpr& odim, const Prim * * \return The output tensor in same layout order */ -inline Tensor adaptive_pool_impl(const Tensor& x, const Array& output_size, +inline Tensor adaptive_pool_impl(const Tensor& x, const ffi::Array& output_size, PoolType pool_type, const std::vector& axes) { const auto n_dim = output_size.size(); ICHECK_EQ(axes.size(), n_dim) << "The number of axes not equal to the in/out dimension"; - Array data_shape = x->shape; - Array out_shape = data_shape; - Array in_size, out_size; + ffi::Array data_shape = 
x->shape; + ffi::Array out_shape = data_shape; + ffi::Array in_size, out_size; for (size_t i = 0; i < n_dim; ++i) { in_size.push_back(data_shape[axes[i]]); out_size.push_back(output_size[i]); out_shape.Set(axes[i], out_size[i]); } - auto get_iter_vars = [=](const Array& output, bool reduce_indices) { - Array indices; + auto get_iter_vars = [=](const ffi::Array& output, bool reduce_indices) { + ffi::Array indices; for (size_t i = 0; i < output.size(); ++i) indices.push_back(output[i]); - Array reduce_axes; + ffi::Array reduce_axes; for (size_t i = 0; i < n_dim; ++i) { auto i_start = start_index(output[axes[i]], out_size[i], in_size[i]); auto i_end = end_index(output[axes[i]], out_size[i], in_size[i]); @@ -350,25 +353,25 @@ inline Tensor adaptive_pool_impl(const Tensor& x, const Array& output_ return std::make_tuple(indices, reduce_axes); }; - Map attrs; + ffi::Map attrs; if (pool_type == kMaxPool) { - attrs.Set("schedule_rule", tvm::String("meta_schedule.adaptive_pool_max")); + attrs.Set("schedule_rule", tvm::ffi::String("meta_schedule.adaptive_pool_max")); return tvm::te::compute( out_shape, - [&](const Array& output) { - Array indices; - Array reduce_axes; + [&](const ffi::Array& output) { + ffi::Array indices; + ffi::Array reduce_axes; std::tie(indices, reduce_axes) = get_iter_vars(output, true); return tvm::max(x(indices), reduce_axes); // NOLINT(*) }, "adaptive_pool_max", "adaptive_pool_max", attrs); } else if (pool_type == kAvgPool) { - attrs.Set("schedule_rule", tvm::String("meta_schedule.adaptive_pool_avg")); + attrs.Set("schedule_rule", tvm::ffi::String("meta_schedule.adaptive_pool_avg")); auto pool_sum = tvm::te::compute( out_shape, - [&](const Array& output) { - Array indices; - Array reduce_axes; + [&](const ffi::Array& output) { + ffi::Array indices; + ffi::Array reduce_axes; std::tie(indices, reduce_axes) = get_iter_vars(output, true); return tvm::sum(x(indices), reduce_axes); }, @@ -376,9 +379,9 @@ inline Tensor adaptive_pool_impl(const Tensor& x, const Array& output_ return tvm::te::compute( out_shape, - [&](const Array& output) { - Array indices; - Array reduce_axes; + [&](const ffi::Array& output) { + ffi::Array indices; + ffi::Array reduce_axes; std::tie(indices, reduce_axes) = get_iter_vars(output, false); PrimExpr divide_factor = tvm::cast(x->dtype, 1); @@ -421,8 +424,8 @@ inline Tensor adaptive_pool_impl(const Tensor& x, const Array& output_ * * \return The output tensor in same layout order */ -inline Tensor adaptive_pool(const Tensor& x, const Array& output_size, PoolType pool_type, - const std::string& layout = "NCHW") { +inline Tensor adaptive_pool(const Tensor& x, const ffi::Array& output_size, + PoolType pool_type, const std::string& layout = "NCHW") { int height_axis = -1, width_axis = -1; ICHECK(find_height_width(layout, &height_axis, &width_axis)) << "Unsupported layout " << layout; return adaptive_pool_impl(x, output_size, pool_type, {height_axis, width_axis}); @@ -436,7 +439,7 @@ inline Tensor adaptive_pool(const Tensor& x, const Array& output_size, * \param pool_type The type of pooling operator * \param layout The input layout. The default is "NCDHW". 
*/ -inline Tensor adaptive_pool3d(const Tensor& x, const Array& output_size, +inline Tensor adaptive_pool3d(const Tensor& x, const ffi::Array& output_size, PoolType pool_type, const std::string& layout = "NCDHW") { int depth_axis = -1, height_axis = -1, width_axis = -1; ICHECK(find_depth_height_width(layout, &depth_axis, &height_axis, &width_axis)) @@ -452,7 +455,7 @@ inline Tensor adaptive_pool3d(const Tensor& x, const Array& output_siz * \param pool_type The type of pooling operator * \param layout The input layout. The default is "NCW". */ -inline Tensor adaptive_pool1d(const Tensor& x, const Array& output_size, +inline Tensor adaptive_pool1d(const Tensor& x, const ffi::Array& output_size, PoolType pool_type, const std::string& layout = "NCW") { int width_axis = -1; ICHECK(find_width(layout, &width_axis)) << "Unsupported layout " << layout; @@ -485,7 +488,7 @@ inline Tensor adaptive_pool1d(const Tensor& x, const Array& output_siz * e.g., for NCHW, the output shape will be [batch, channel, 1, 1] */ inline Tensor global_pool(const Tensor& x, PoolType pool_type, const std::string& layout = "NCHW") { - return adaptive_pool(x, Array{1, 1}, pool_type, layout); + return adaptive_pool(x, ffi::Array{1, 1}, pool_type, layout); } /*! @@ -504,10 +507,11 @@ inline Tensor global_pool(const Tensor& x, PoolType pool_type, const std::string * * \return The output tensor in same layout order */ -inline Tensor pool_impl_nd(const Tensor& x, const Array& kernel_size, - const Array& stride_size, const Array& dilation_size, - const Array& padding_size, PoolType pool_type, bool ceil_mode, - const std::vector& axis, bool count_include_pad) { +inline Tensor pool_impl_nd(const Tensor& x, const ffi::Array& kernel_size, + const ffi::Array& stride_size, + const ffi::Array& dilation_size, + const ffi::Array& padding_size, PoolType pool_type, + bool ceil_mode, const std::vector& axis, bool count_include_pad) { int k_size = kernel_size.size(); int x_size = x->shape.size(); ICHECK_EQ(stride_size.size(), k_size) << "Pooling stride_size must have same elements as kernel"; @@ -515,17 +519,17 @@ inline Tensor pool_impl_nd(const Tensor& x, const Array& kernel_size, " kernel"; ICHECK_EQ(axis.size(), k_size) << "axis must have same elements as kernel"; - Array daxis; + ffi::Array daxis; std::vector kernel(k_size); std::vector stride(k_size); std::vector dilation(k_size); std::vector pad_head(k_size); std::vector pad_tail(k_size); std::vector offset(k_size, 0); - Array pad_before(std::vector(x_size, 0)); - Array pad_after(std::vector(x_size, 0)); - Array data_shape = x->shape; - Array out_shape = data_shape; + ffi::Array pad_before(std::vector(x_size, 0)); + ffi::Array pad_after(std::vector(x_size, 0)); + ffi::Array data_shape = x->shape; + ffi::Array out_shape = data_shape; bool do_pad = false; for (int i = 0; i < k_size; i++) { @@ -563,14 +567,14 @@ inline Tensor pool_impl_nd(const Tensor& x, const Array& kernel_size, out_shape.Set(ii, out_dim); } - Map attrs; + ffi::Map attrs; if (pool_type == kMaxPool) { auto temp = do_pad ? 
pad(x, pad_before, pad_after, tvm::min_value(x->dtype), "pad_temp") : x; - attrs.Set("schedule_rule", tvm::String("meta_schedule.pool_max")); + attrs.Set("schedule_rule", tvm::ffi::String("meta_schedule.pool_max")); return tvm::te::compute( out_shape, - [&](const Array& output) { - Array indices; + [&](const ffi::Array& output) { + ffi::Array indices; for (const Var& var : output) indices.push_back(var); for (int i = 0; i < k_size; i++) { @@ -581,15 +585,15 @@ inline Tensor pool_impl_nd(const Tensor& x, const Array& kernel_size, }, "pool_max", "pool_max", attrs); } else if (pool_type == kAvgPool) { - attrs.Set("schedule_rule", tvm::String("meta_schedule.pool_avg")); + attrs.Set("schedule_rule", tvm::ffi::String("meta_schedule.pool_avg")); // Pad the inputs auto temp = do_pad ? pad(x, pad_before, pad_after, 0, "pad_temp") : x; // TVM compute for summing the pooling window. auto pool_sum = tvm::te::compute( out_shape, - [&](const Array& output) { - Array indices; + [&](const ffi::Array& output) { + ffi::Array indices; for (const Var& var : output) indices.push_back(var); for (int i = 0; i < k_size; i++) { @@ -603,8 +607,8 @@ inline Tensor pool_impl_nd(const Tensor& x, const Array& kernel_size, // TVM compute for dividing the reduced window sum by kernel size. return tvm::te::compute( out_shape, - [&](const Array& output) { - Array indices; + [&](const ffi::Array& output) { + ffi::Array indices; for (const Var& var : output) indices.push_back(var); if (count_include_pad) { std::vector start(k_size); @@ -687,9 +691,10 @@ inline Tensor pool_impl_nd(const Tensor& x, const Array& kernel_size, * * \return The output tensor in the same layout */ -inline Tensor pool1d(const Tensor& x, const Array& kernel_size, - const Array& stride_size, const Array& dilation_size, - const Array& padding_size, PoolType pool_type, bool ceil_mode, +inline Tensor pool1d(const Tensor& x, const ffi::Array& kernel_size, + const ffi::Array& stride_size, + const ffi::Array& dilation_size, + const ffi::Array& padding_size, PoolType pool_type, bool ceil_mode, const std::string& layout = "NCW", bool count_include_pad = true) { int width_axis = -1; ICHECK(find_width(layout, &width_axis)) << "Unsupported layout " << layout; @@ -728,9 +733,10 @@ inline Tensor pool1d(const Tensor& x, const Array& kernel_size, * * \return The output tensor in the same layout */ -inline Tensor pool2d(const Tensor& x, const Array& kernel_size, - const Array& stride_size, const Array& dilation_size, - const Array& padding_size, PoolType pool_type, bool ceil_mode, +inline Tensor pool2d(const Tensor& x, const ffi::Array& kernel_size, + const ffi::Array& stride_size, + const ffi::Array& dilation_size, + const ffi::Array& padding_size, PoolType pool_type, bool ceil_mode, const std::string& layout = "NCHW", bool count_include_pad = true) { int height_axis = -1, width_axis = -1; ICHECK(find_height_width(layout, &height_axis, &width_axis)) << "Unsupported layout " << layout; @@ -770,9 +776,10 @@ inline Tensor pool2d(const Tensor& x, const Array& kernel_size, * * \return The output tensor in the same layout */ -inline Tensor pool3d(const Tensor& x, const Array& kernel_size, - const Array& stride_size, const Array& dilation_size, - const Array& padding_size, PoolType pool_type, bool ceil_mode, +inline Tensor pool3d(const Tensor& x, const ffi::Array& kernel_size, + const ffi::Array& stride_size, + const ffi::Array& dilation_size, + const ffi::Array& padding_size, PoolType pool_type, bool ceil_mode, const std::string& layout = "NCDHW", bool count_include_pad = 
true) { int depth_axis = -1, height_axis = -1, width_axis = -1; ICHECK(find_depth_height_width(layout, &depth_axis, &height_axis, &width_axis)) diff --git a/include/tvm/topi/nn/rms_norm.h b/include/tvm/topi/nn/rms_norm.h index 7e95000f1ee2..66a2ae62dfec 100644 --- a/include/tvm/topi/nn/rms_norm.h +++ b/include/tvm/topi/nn/rms_norm.h @@ -47,7 +47,7 @@ using namespace tvm::te; * \param tag The tag to mark the operation. * \return The normalized tensor, with the same shape as data. */ -inline Tensor rms_norm(const Tensor& data, const Tensor& weight, const Array& axis, +inline Tensor rms_norm(const Tensor& data, const Tensor& weight, const ffi::Array& axis, double epsilon, std::string name = "T_rms_norm", std::string tag = kInjective) { const auto& data_type = data->dtype; @@ -67,8 +67,8 @@ inline Tensor rms_norm(const Tensor& data, const Tensor& weight, const Arrayshape[i]; } - auto rsqrt_func = [&](const Array& indices) { - Array non_reduce_indices; + auto rsqrt_func = [&](const ffi::Array& indices) { + ffi::Array non_reduce_indices; for (int i = 0, n = static_cast(indices.size()); i < n; ++i) { if (std::find(real_axis.begin(), real_axis.end(), i) == real_axis.end()) { non_reduce_indices.push_back(indices[i]); @@ -78,7 +78,7 @@ inline Tensor rms_norm(const Tensor& data, const Tensor& weight, const Array(); + auto rsqrt_shape = ffi::Array(); for (int i = 0, n = static_cast(data_fp32->shape.size()); i < n; ++i) { if (std::find(real_axis.begin(), real_axis.end(), i) == real_axis.end()) { rsqrt_shape.push_back(data_fp32->shape[i]); @@ -86,8 +86,8 @@ inline Tensor rms_norm(const Tensor& data, const Tensor& weight, const Array& indices) { - Array reduce_indices, non_reduce_indices; + auto rms_norm_func = [&](const ffi::Array& indices) { + ffi::Array reduce_indices, non_reduce_indices; for (int i = 0, n = static_cast(indices.size()); i < n; ++i) { if (std::find(real_axis.begin(), real_axis.end(), i) != real_axis.end()) { reduce_indices.push_back(indices[i]); diff --git a/include/tvm/topi/nn/softmax.h b/include/tvm/topi/nn/softmax.h index 6679b84c8d03..f58d66ece139 100644 --- a/include/tvm/topi/nn/softmax.h +++ b/include/tvm/topi/nn/softmax.h @@ -60,11 +60,12 @@ inline Tensor softmax(const Tensor& x, int axis = -1, std::string name = "tensor auto k2 = tvm::te::reduce_axis(Range(0, input_shape[axis]), "k2"); auto reduced_shape = MakeReduceTargetShape({axis}, x, false, false); - tvm::Map attrs; + tvm::ffi::Map attrs; attrs.Set("axis", Integer(axis)); - auto insert_reduce_index = [axis, ndim](const Array& indices, const IterVar& reduce_index) { - Array eval_range; + auto insert_reduce_index = [axis, ndim](const ffi::Array& indices, + const IterVar& reduce_index) { + ffi::Array eval_range; int arg_counter = 0; for (size_t i = 0; i < ndim; ++i) { if (static_cast(i) == axis) { @@ -76,41 +77,41 @@ inline Tensor softmax(const Tensor& x, int axis = -1, std::string name = "tensor return eval_range; }; - auto get_non_reduce_indices = [axis, ndim](const Array& indices) { - Array non_reduce_indices; + auto get_non_reduce_indices = [axis, ndim](const ffi::Array& indices) { + ffi::Array non_reduce_indices; for (size_t i = 0; i < ndim; ++i) { if (static_cast(i) != axis) non_reduce_indices.push_back(indices[i]); } return non_reduce_indices; }; - auto _compute_max = [&](const Array& indices) { + auto _compute_max = [&](const ffi::Array& indices) { auto eval_range = insert_reduce_index(indices, k1); return topi::MaxOp(x(eval_range), {k1}); }; - auto _compute_exp = [&](const Tensor& max_elem, const Array& indices) { + 
auto _compute_exp = [&](const Tensor& max_elem, const ffi::Array& indices) { auto non_reduce_indices = get_non_reduce_indices(indices); return tvm::exp(x(indices) - max_elem(non_reduce_indices)); }; - auto _compute_expsum = [&](const Tensor& exp, const Array& indices) { + auto _compute_expsum = [&](const Tensor& exp, const ffi::Array& indices) { auto eval_range = insert_reduce_index(indices, k2); return tvm::sum(exp(eval_range), {k2}); }; - auto _normalize = [&](const Tensor& exp, const Tensor& expsum, const Array& indices) { + auto _normalize = [&](const Tensor& exp, const Tensor& expsum, const ffi::Array& indices) { auto non_reduce_indices = get_non_reduce_indices(indices); return exp(indices) / expsum(non_reduce_indices); }; auto max_elem = tvm::te::compute(reduced_shape, _compute_max); auto exp = tvm::te::compute( - input_shape, [&](const Array& indices) { return _compute_exp(max_elem, indices); }); + input_shape, [&](const ffi::Array& indices) { return _compute_exp(max_elem, indices); }); auto expsum = tvm::te::compute( - reduced_shape, [&](const Array& indices) { return _compute_expsum(exp, indices); }); + reduced_shape, [&](const ffi::Array& indices) { return _compute_expsum(exp, indices); }); return tvm::te::compute( - input_shape, [&](const Array& indices) { return _normalize(exp, expsum, indices); }, + input_shape, [&](const ffi::Array& indices) { return _normalize(exp, expsum, indices); }, name, tag, attrs); } @@ -132,7 +133,7 @@ inline Tensor log_softmax(const Tensor& x, std::string name = "tensor", auto k = tvm::te::reduce_axis(Range(0, n), "k"); auto max_elem = - tvm::te::compute({m}, [&](Var i) { return tvm::max(x(i, k), Array{k}); }); + tvm::te::compute({m}, [&](Var i) { return tvm::max(x(i, k), ffi::Array{k}); }); k = tvm::te::reduce_axis(Range(0, n), "k"); auto expsum = diff --git a/include/tvm/topi/reduction.h b/include/tvm/topi/reduction.h index 277de68e972e..fda754061bbe 100644 --- a/include/tvm/topi/reduction.h +++ b/include/tvm/topi/reduction.h @@ -43,12 +43,12 @@ namespace topi { using namespace tvm::te; /*! \brief The operation to use for CommReduce */ -using FReduce = std::function& axis, - Array init, Span span)>; +using FReduce = std::function& axis, + ffi::Array init, Span span)>; /*! \brief The operation to use for CommReduceIdx */ -using FCommReduce = std::function(Array exprs, const Array& axis, - PrimExpr* condition)>; +using FCommReduce = std::function( + ffi::Array exprs, const ffi::Array& axis, PrimExpr* condition)>; /*! * \brief Convert a reduction axis which could be empty or have negative @@ -62,7 +62,7 @@ using FCommReduce = std::function(Array exprs, const A * If any input element is negative, it will be treated as an offset from the * last dimension (same as python indexing rules). */ -inline std::vector GetRealAxis(int ndim, const Optional>& axis) { +inline std::vector GetRealAxis(int ndim, const ffi::Optional>& axis) { std::vector real_axis; if (!axis.has_value()) { for (int i = 0; i < ndim; ++i) { @@ -86,8 +86,8 @@ inline std::vector GetRealAxis(int ndim, const Optional>& ax } /*! 
\brief Enumerate the axes for a reduce op */ -inline Array MakeReduceAxes(const std::vector& real_axis, const Tensor& data) { - Array reduce_axes; +inline ffi::Array MakeReduceAxes(const std::vector& real_axis, const Tensor& data) { + ffi::Array reduce_axes; for (auto i : real_axis) { std::string name = "k" + std::to_string(i); reduce_axes.push_back(tvm::te::reduce_axis(Range(0, data->shape[i]), name)); @@ -96,10 +96,11 @@ inline Array MakeReduceAxes(const std::vector& real_axis, const Te } /*! \brief Calculate the target shape for a reduce op */ -inline Array MakeReduceTargetShape(const std::vector& real_axis, const Tensor& data, - bool keepdims, bool atleast1d) { +inline ffi::Array MakeReduceTargetShape(const std::vector& real_axis, + const Tensor& data, bool keepdims, + bool atleast1d) { auto ndim = data->shape.size(); - Array target_shape; + ffi::Array target_shape; if (keepdims) { for (size_t i = 0; i < ndim; ++i) { if (std::find(real_axis.begin(), real_axis.end(), i) != real_axis.end()) { @@ -136,13 +137,14 @@ inline Array MakeReduceTargetShape(const std::vector& real_axis, * * \return The result tensor. */ -inline Tensor DoCommReduce(const Tensor& data, FReduce func, const Array& target_shape, +inline Tensor DoCommReduce(const Tensor& data, FReduce func, + const ffi::Array& target_shape, const std::vector& reduce_axes, const std::vector& squeeze_axes, Span span = Span()) { auto r_axes = MakeReduceAxes(reduce_axes, data); - auto compute = [&](const Array& indices) { - Array eval_range; - Array eval_indices; + auto compute = [&](const ffi::Array& indices) { + ffi::Array eval_range; + ffi::Array eval_indices; int arg_counter = 0; int red_counter = 0; @@ -179,8 +181,8 @@ inline Tensor DoCommReduce(const Tensor& data, FReduce func, const Array>& axis, FReduce func, - bool keepdims, bool atleast1d) { +inline Tensor CommReduce(const Tensor& data, const ffi::Optional>& axis, + FReduce func, bool keepdims, bool atleast1d) { auto ndim = data->shape.size(); ICHECK_NE(ndim, 0) << "Cannot reduce a 0 dim Tensor"; auto real_axis = GetRealAxis(static_cast(ndim), axis); @@ -202,7 +204,7 @@ inline Tensor CommReduce(const Tensor& data, const Optional>& axi * * \return The result tensor. */ -inline Tensor CommReduceIdx(const Tensor& data, const Optional>& axis, +inline Tensor CommReduceIdx(const Tensor& data, const ffi::Optional>& axis, FCommReduce func, bool keepdims, bool atleast1d) { auto ndim = data->shape.size(); ICHECK_NE(ndim, 0) << "Cannot reduce a 0 dim Tensor"; @@ -211,9 +213,9 @@ inline Tensor CommReduceIdx(const Tensor& data, const Optional>& auto target_shape = MakeReduceTargetShape(real_axis, data, keepdims, atleast1d); auto compute = [ndim, keepdims, &real_axis, &reduce_axes, &func, - &data](const Array& indices) { - Array eval_range; - Array eval_indices; + &data](const ffi::Array& indices) { + ffi::Array eval_range; + ffi::Array eval_indices; int arg_counter = 0; int red_counter = 0; @@ -233,7 +235,7 @@ inline Tensor CommReduceIdx(const Tensor& data, const Optional>& } } - Array ravel_shape; + ffi::Array ravel_shape; for (auto i : real_axis) { ravel_shape.push_back(data->shape[i]); } @@ -246,15 +248,15 @@ inline Tensor CommReduceIdx(const Tensor& data, const Optional>& auto temp_idx = temp_idx_val[0]; auto temp_val = temp_idx_val[1]; return tvm::te::compute( - target_shape, [&temp_idx](const Array& indices) { return temp_idx(indices); }, + target_shape, [&temp_idx](const ffi::Array& indices) { return temp_idx(indices); }, data->op->name + "_red", kCommReduceIdx); } /*! 
\brief A combiner function for a reduction */ -using FCombine = std::function(Array lhs, Array rhs)>; +using FCombine = std::function(ffi::Array lhs, ffi::Array rhs)>; /*! \brief An initializer function for a reduction */ -using FIdentity = std::function(std::vector types)>; +using FIdentity = std::function(std::vector types)>; /*! * \brief Create a commutative reducer for a reduction @@ -267,9 +269,9 @@ using FIdentity = std::function(std::vector types)>; */ inline FCommReduce MakeCommReducer(FCombine fcombine, FIdentity fidentity, std::string name = "reduce") { - return [fcombine, fidentity, name](Array exprs, const Array& axis, + return [fcombine, fidentity, name](ffi::Array exprs, const ffi::Array& axis, PrimExpr* condition) { - Array lhs, rhs; + ffi::Array lhs, rhs; std::vector dtypes; for (size_t i = 0; i < exprs.size(); ++i) { @@ -284,7 +286,7 @@ inline FCommReduce MakeCommReducer(FCombine fcombine, FIdentity fidentity, auto cond = condition != nullptr ? *condition : tir::const_true(); auto combiner = tvm::tir::CommReducer(lhs, rhs, result, id_elem); - Array outputs; + ffi::Array outputs; for (size_t i = 0; i < exprs.size(); ++i) { outputs.push_back(tvm::tir::Reduce(combiner, exprs, axis, cond, static_cast(i), {})); } @@ -293,19 +295,19 @@ inline FCommReduce MakeCommReducer(FCombine fcombine, FIdentity fidentity, } /*! \brief Wrap tvm::min to ensure we get the correct overload */ -inline PrimExpr MinOp(PrimExpr source, Array axis, Array init = {}, +inline PrimExpr MinOp(PrimExpr source, ffi::Array axis, ffi::Array init = {}, Span span = Span()) { return tvm::min(source, axis, init, span); } /*! \brief Wrap tvm::max to ensure we get the correct overload */ -inline PrimExpr MaxOp(PrimExpr source, Array axis, Array init = {}, +inline PrimExpr MaxOp(PrimExpr source, ffi::Array axis, ffi::Array init = {}, Span span = Span()) { return tvm::max(source, axis, init, span); // NOLINT(*) } /*! 
\brief Wrap tvm::prod to ensure we get the correct overload */ -inline PrimExpr ProdOp(PrimExpr source, Array axis, Array init = {}, +inline PrimExpr ProdOp(PrimExpr source, ffi::Array axis, ffi::Array init = {}, Span span = Span()) { return tvm::prod(source, axis, init, span); // NOLINT(*) } @@ -323,8 +325,8 @@ inline PrimExpr ProdOp(PrimExpr source, Array axis, Array ini * * \return A Tensor whose op member is the sum operation */ -inline Tensor sum(const Tensor& data, const Optional>& axis, bool keepdims = false, - bool atleast1d = false) { +inline Tensor sum(const Tensor& data, const ffi::Optional>& axis, + bool keepdims = false, bool atleast1d = false) { if (data->dtype.is_bool()) { return CommReduce(data, axis, tvm::any, keepdims, atleast1d); } else { @@ -332,7 +334,7 @@ inline Tensor sum(const Tensor& data, const Optional>& axis, bool } } -inline Tensor collapse_sum(const Tensor& data, Array target_shape) { +inline Tensor collapse_sum(const Tensor& data, ffi::Array target_shape) { const auto& ishape = data->shape; const auto& oshape = target_shape; int isize = data->shape.size(); @@ -380,8 +382,8 @@ inline Tensor collapse_sum(const Tensor& data, Array target_shape) { * * \return A Tensor whose op member is the all operation */ -inline Tensor all(const Tensor& data, const Optional>& axis, bool keepdims = false, - bool atleast1d = false) { +inline Tensor all(const Tensor& data, const ffi::Optional>& axis, + bool keepdims = false, bool atleast1d = false) { return CommReduce(data, axis, tvm::all, keepdims, atleast1d); } @@ -399,8 +401,8 @@ inline Tensor all(const Tensor& data, const Optional>& axis, bool * * \return A Tensor whose op member is the all operation */ -inline Tensor any(const Tensor& data, const Optional>& axis, bool keepdims = false, - bool atleast1d = false) { +inline Tensor any(const Tensor& data, const ffi::Optional>& axis, + bool keepdims = false, bool atleast1d = false) { return CommReduce(data, axis, tvm::any, keepdims, atleast1d); } @@ -418,8 +420,8 @@ inline Tensor any(const Tensor& data, const Optional>& axis, bool * * \return A Tensor whose op member is the min operation */ -inline Tensor min(const Tensor& data, const Optional>& axis, bool keepdims = false, - bool atleast1d = false) { +inline Tensor min(const Tensor& data, const ffi::Optional>& axis, + bool keepdims = false, bool atleast1d = false) { return CommReduce(data, axis, MinOp, keepdims, atleast1d); } @@ -437,15 +439,15 @@ inline Tensor min(const Tensor& data, const Optional>& axis, bool * * \return A Tensor whose op member is the max operation */ -inline Tensor max(const Tensor& data, const Optional>& axis, bool keepdims = false, - bool atleast1d = false) { +inline Tensor max(const Tensor& data, const ffi::Optional>& axis, + bool keepdims = false, bool atleast1d = false) { return CommReduce(data, axis, MaxOp, keepdims, atleast1d); } inline FCommReduce MakeArgminReducer(bool select_last_index = false) { // Create a Commutative Reducer with a comparison operation, and method to get the initial value. 
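The MakeArgminReducer defined just below is an instance of the FCombine/FIdentity pattern that MakeCommReducer consumes, now spelled in terms of ffi::Array. As a hedged sketch of that pattern (illustrative only, not part of this patch; the template arguments are assumptions, since the diff renders without them), a single-value min reducer would look like:

    // Sketch: combiner and identity element for a one-value "min" reduction.
    auto fcombine = [](ffi::Array<PrimExpr> lhs, ffi::Array<PrimExpr> rhs) {
      return ffi::Array<PrimExpr>{tvm::min(lhs[0], rhs[0])};  // merge two partial results
    };
    auto fidentity = [](std::vector<DataType> types) {
      return ffi::Array<PrimExpr>{tvm::max_value(types[0])};  // identity element for min
    };
    FCommReduce reducer = MakeCommReducer(fcombine, fidentity, "custom_min");

The argmin/argmax reducers below carry an (index, value) pair through the same machinery.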
- auto fcombine = [=](Array lhs, Array rhs) { - Array result; + auto fcombine = [=](ffi::Array lhs, ffi::Array rhs) { + ffi::Array result; // Casting to avoid operator ambiguity PrimExpr lhs_idx = static_cast(lhs[0]); @@ -473,7 +475,7 @@ inline FCommReduce MakeArgminReducer(bool select_last_index = false) { return result; }; auto fidentity = [&](std::vector types) { - Array result; + ffi::Array result; result.push_back(tvm::tir::make_const(types[0], -1)); // idx result.push_back(tvm::max_value(types[1])); // val return result; @@ -497,7 +499,7 @@ inline FCommReduce MakeArgminReducer(bool select_last_index = false) { * * \return A Tensor whose op member is the argmin operation */ -inline Tensor argmin(const Tensor& data, const Optional>& axis, +inline Tensor argmin(const Tensor& data, const ffi::Optional>& axis, bool keepdims = false, bool atleast1d = false, bool select_last_index = false) { auto reducer = MakeArgminReducer(select_last_index); @@ -506,8 +508,8 @@ inline Tensor argmin(const Tensor& data, const Optional>& axis, inline FCommReduce MakeArgmaxReducer(bool select_last_index = false) { // Create a Commutative Reducer with a comparison operation, and method to get the initial value. - auto fcombine = [=](Array lhs, Array rhs) { - Array result; + auto fcombine = [=](ffi::Array lhs, ffi::Array rhs) { + ffi::Array result; // Casting to avoid operator ambiguity PrimExpr lhs_idx = static_cast(lhs[0]); @@ -535,7 +537,7 @@ inline FCommReduce MakeArgmaxReducer(bool select_last_index = false) { return result; }; auto fidentity = [&](std::vector types) { - Array result; + ffi::Array result; result.push_back(tvm::tir::make_const(types[0], -1)); // idx result.push_back(tvm::min_value(types[1])); // val return result; @@ -558,7 +560,7 @@ inline FCommReduce MakeArgmaxReducer(bool select_last_index = false) { * appears multiple times, else select the first index. 
* \return A Tensor whose op member is the argmax operation */ -inline Tensor argmax(const Tensor& data, const Optional>& axis, +inline Tensor argmax(const Tensor& data, const ffi::Optional>& axis, bool keepdims = false, bool atleast1d = false, bool select_last_index = false) { auto reducer = MakeArgmaxReducer(select_last_index); @@ -578,8 +580,8 @@ inline Tensor argmax(const Tensor& data, const Optional>& axis, * * \return A Tensor whose op member is the prod operation */ -inline Tensor prod(const Tensor& data, const Optional>& axis, bool keepdims = false, - bool atleast1d = false) { +inline Tensor prod(const Tensor& data, const ffi::Optional>& axis, + bool keepdims = false, bool atleast1d = false) { return CommReduce(data, axis, ProdOp, keepdims, atleast1d); } @@ -587,8 +589,8 @@ inline Tensor prod(const Tensor& data, const Optional>& axis, boo * \brief Create commutative reducer summing over tuples */ inline FCommReduce MakeTupleSumReducer() { - auto fcombine = [](Array lhs, Array rhs) { - Array result; + auto fcombine = [](ffi::Array lhs, ffi::Array rhs) { + ffi::Array result; ICHECK_EQ(lhs.size(), rhs.size()); result.reserve(lhs.size()); for (size_t i = 0; i < lhs.size(); ++i) { @@ -597,7 +599,7 @@ inline FCommReduce MakeTupleSumReducer() { return result; }; auto fidentity = [](std::vector types) { - Array result; + ffi::Array result; for (size_t i = 0; i < types.size(); ++i) { result.push_back(tvm::tir::make_const(types[i], 0)); } diff --git a/include/tvm/topi/transform.h b/include/tvm/topi/transform.h index 71b1bd3b8d25..2d7096613bdc 100644 --- a/include/tvm/topi/transform.h +++ b/include/tvm/topi/transform.h @@ -73,8 +73,8 @@ using namespace topi::detail; * * \return A Tensor whose op member is the sliding_window operation */ -inline Tensor sliding_window(const Tensor& x, int axis, Array window_shape, - Array strides, std::string name = "T_sliding_window", +inline Tensor sliding_window(const Tensor& x, int axis, ffi::Array window_shape, + ffi::Array strides, std::string name = "T_sliding_window", std::string tag = "") { CHECK_GE(axis, 0); auto _axis = size_t(axis); @@ -85,7 +85,7 @@ inline Tensor sliding_window(const Tensor& x, int axis, Array window_sh CHECK_EQ(strides.size(), window_shape.size()) << "Windows and strides should be the same length."; // Compute the new shape. - Array new_shape; + ffi::Array new_shape; // Dimensions up until `axis` remain the same. for (size_t i = 0; i < _axis; ++i) { new_shape.push_back(x->shape[i]); @@ -113,9 +113,9 @@ inline Tensor sliding_window(const Tensor& x, int axis, Array window_sh return compute( new_shape, - [&](const Array& indices) { + [&](const ffi::Array& indices) { // The index at which to index the old tensor x. - Array idx; + ffi::Array idx; // Dimensions up until `axis` remain the same.
for (size_t i = 0; i < _axis; ++i) { @@ -164,7 +164,7 @@ inline Tensor expand_dims(const Tensor& x, int axis, int num_newaxis = 1, // Calculate offset from last dimension axis = ndim + axis + 1; } - Array new_shape; + ffi::Array new_shape; for (size_t i = 0; i < static_cast(axis); ++i) { new_shape.push_back(x->shape[i]); } @@ -177,8 +177,8 @@ inline Tensor expand_dims(const Tensor& x, int axis, int num_newaxis = 1, return compute( new_shape, - [&](const Array& indices) { - Array idx; + [&](const ffi::Array& indices) { + ffi::Array idx; for (size_t i = 0; i < static_cast(axis); ++i) { idx.push_back(indices[i]); } @@ -201,16 +201,16 @@ inline Tensor expand_dims(const Tensor& x, int axis, int num_newaxis = 1, * * \return A Tensor whose op member is the transpose operation */ -inline Tensor transpose(const Tensor& x, Optional> opt_axes, +inline Tensor transpose(const Tensor& x, ffi::Optional> opt_axes, std::string name = "T_transpose", std::string tag = kInjective) { - Array axes = opt_axes.value_or({}); + ffi::Array axes = opt_axes.value_or({}); if (axes.size() == 0) { for (int i = static_cast(x->shape.size()) - 1; i >= 0; --i) { axes.push_back(i); } } - Array new_shape; + ffi::Array new_shape; for (size_t i = 0; i < axes.size(); ++i) { int axis = static_cast(axes[i]->value); int new_axis = axis; @@ -232,7 +232,7 @@ inline Tensor transpose(const Tensor& x, Optional> opt_axes, return compute( new_shape, - [&](const Array& indices) { + [&](const ffi::Array& indices) { std::vector idx; for (size_t i = 0; i < axes.size(); ++i) { idx.push_back(1); @@ -292,8 +292,8 @@ inline Tensor reverse_sequence(const Tensor& x, const Tensor& seq_lengths, int s << "seq_axis=" << seq_axis_inp << " is invalid for the " << static_cast(x->shape.size()) << "-dimensional input tensor"; - auto func = [&](const Array& indices) { - Array real_indices; + auto func = [&](const ffi::Array& indices) { + ffi::Array real_indices; for (size_t i = 0; i < src_tensor_dim; ++i) { if (i == static_cast(seq_axis)) { if (seq_lengths.defined()) { @@ -325,10 +325,10 @@ inline Tensor reverse_sequence(const Tensor& x, const Tensor& seq_lengths, int s * * \return A Tensor whose op member is the reshape operation */ -inline Tensor reshape(const Tensor& x, Array newshape, std::string name = "T_reshape", - std::string tag = kInjective) { +inline Tensor reshape(const Tensor& x, ffi::Array newshape, + std::string name = "T_reshape", std::string tag = kInjective) { auto x_shape = x->shape; - Array target_shape; + ffi::Array target_shape; for (const auto& ele : newshape) { target_shape.push_back(ele); @@ -337,13 +337,15 @@ inline Tensor reshape(const Tensor& x, Array newshape, std::string nam // If either the input shape or the target shape contains a zero, return an empty tensor. 
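The transform ops in this hunk (expand_dims, transpose, reverse_sequence, reshape) all move their axis and shape parameters to ffi::-qualified containers, with optional axis lists becoming ffi::Optional<ffi::Array<Integer>> (the instantiations are assumptions; the diff strips template arguments). A hedged call-site sketch:

    // Sketch: a placeholder tensor fed through the migrated transpose overload.
    te::Tensor x = te::placeholder({2, 3, 4}, DataType::Float(32), "x");
    te::Tensor xt = topi::transpose(x, ffi::Array<Integer>{1, 0, 2});  // explicit permutation
    te::Tensor xr = topi::transpose(x, std::nullopt);                  // default: reverse all axes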
if (is_empty_shape(target_shape) || is_empty_shape(x->shape)) { return compute( - target_shape, [&](const Array& indices) { return tvm::cast(x->dtype, 0); }, name, tag); + target_shape, [&](const ffi::Array& indices) { return tvm::cast(x->dtype, 0); }, name, + tag); } else { return compute( target_shape, - [&](const Array& indices) { + [&](const ffi::Array& indices) { return x(UnravelIndex( - RavelIndex(Array{indices.begin(), indices.end()}, target_shape), x_shape)); + RavelIndex(ffi::Array{indices.begin(), indices.end()}, target_shape), + x_shape)); }, name, tag); } @@ -365,13 +367,13 @@ inline Tensor unravel_index(const Tensor& x, const Tensor& shape, std::string na auto x_shape = x->shape; auto shape_shape = shape->shape; - Array oshape; + ffi::Array oshape; oshape.push_back(shape_shape[0]); if (x_shape.size() != 0) { oshape.push_back(x_shape[0]); } - auto func = [&](const Array& indices) { + auto func = [&](const ffi::Array& indices) { auto i = indices[0]; std::vector indices_divs; PrimExpr ret = 0; @@ -408,8 +410,9 @@ inline Tensor unravel_index(const Tensor& x, const Tensor& shape, std::string na * * \return A Tensor whose op member is the squeeze operation */ -inline Tensor squeeze(const Tensor& x, Optional> opt_axes, bool atleast1d = false, - std::string name = "T_squeeze", std::string tag = kInjective) { +inline Tensor squeeze(const Tensor& x, ffi::Optional> opt_axes, + bool atleast1d = false, std::string name = "T_squeeze", + std::string tag = kInjective) { auto ndim = x->shape.size(); std::vector axis_val; if (!opt_axes.has_value()) { @@ -419,7 +422,7 @@ inline Tensor squeeze(const Tensor& x, Optional> opt_axes, bool a } } } else { - Array axis = *std::move(opt_axes); + ffi::Array axis = *std::move(opt_axes); for (size_t i = 0; i < axis.size(); ++i) { int64_t val = axis[i]->value; if (val < 0) { @@ -434,7 +437,7 @@ inline Tensor squeeze(const Tensor& x, Optional> opt_axes, bool a std::unordered_set axis_set(axis_val.begin(), axis_val.end()); - Array out_shape; + ffi::Array out_shape; for (size_t i = 0; i < ndim; ++i) { if (axis_set.count(static_cast(i)) == 0) { out_shape.push_back(x->shape[i]); @@ -446,8 +449,8 @@ inline Tensor squeeze(const Tensor& x, Optional> opt_axes, bool a return compute( out_shape, - [&](const Array& indices) { - Array real_indices; + [&](const ffi::Array& indices) { + ffi::Array real_indices; int flag = 0; for (size_t i = 0; i < ndim; ++i) { if (axis_set.count(static_cast(i)) == 0) { @@ -472,8 +475,8 @@ inline Tensor squeeze(const Tensor& x, Optional> opt_axes, bool a * * \return A Tensor whose op member is the concatenate operation */ -inline Tensor concatenate(const Array& inputs, int axis = 0, std::string name = "T_concat", - std::string tag = kInjective) { +inline Tensor concatenate(const ffi::Array& inputs, int axis = 0, + std::string name = "T_concat", std::string tag = kInjective) { int ndim = static_cast(inputs[0]->shape.size()); ICHECK(-ndim <= axis && axis < ndim) << "concatenate only accepts `axis` in [-ndim, ndim)" << ", but got axis = " << axis << ", and ndim = " << ndim; @@ -482,7 +485,7 @@ inline Tensor concatenate(const Array& inputs, int axis = 0, std::string } ICHECK_LT(axis, inputs[0]->shape.size()) << "axis out of bounds"; - Array axis_sizes; + ffi::Array axis_sizes; for (auto t : inputs) { axis_sizes.push_back(t->shape[axis]); } @@ -492,20 +495,20 @@ inline Tensor concatenate(const Array& inputs, int axis = 0, std::string join_size += axis_sizes[i]; } join_size = analyzer.Simplify(join_size); - Array out_shape; + ffi::Array 
out_shape; for (size_t i = 0; i < inputs[0]->shape.size(); ++i) { out_shape.push_back(i == static_cast(axis) ? join_size : inputs[0]->shape[i]); } return compute( out_shape, - [&](const Array& indices) { + [&](const ffi::Array& indices) { auto ret = inputs[0](indices); auto ind = indices[axis]; for (size_t i = 0; i < inputs.size() - 1; ++i) { ind -= axis_sizes[i]; - Array idx; + ffi::Array idx; for (size_t i = 0; i < static_cast(axis); ++i) { idx.push_back(indices[i]); } @@ -531,7 +534,7 @@ inline Tensor concatenate(const Array& inputs, int axis = 0, std::string * * \return A Tensor whose op member is the stack operation */ -inline Tensor stack(const Array& inputs, int axis = 0, std::string name = "T_stack", +inline Tensor stack(const ffi::Array& inputs, int axis = 0, std::string name = "T_stack", std::string tag = kInjective) { int ndim = static_cast(inputs[0]->shape.size()); ICHECK(-ndim - 1 <= axis && axis <= ndim) @@ -543,7 +546,7 @@ inline Tensor stack(const Array& inputs, int axis = 0, std::string name ICHECK_LT(axis, inputs[0]->shape.size() + 1) << "axis out of bounds"; const int stack_size = static_cast(inputs.size()); - Array out_shape; + ffi::Array out_shape; for (size_t i = 0; i < static_cast(axis); ++i) out_shape.push_back(inputs[0]->shape[i]); out_shape.push_back(stack_size); for (size_t i = static_cast(axis); i < static_cast(ndim); ++i) @@ -551,8 +554,8 @@ inline Tensor stack(const Array& inputs, int axis = 0, std::string name return compute( out_shape, - [&](const Array& indices) { - Array idx; + [&](const ffi::Array& indices) { + ffi::Array idx; for (size_t i = 0; i < indices.size(); ++i) if (i != static_cast(axis)) idx.push_back(indices[i]); auto ind = indices[axis]; @@ -577,9 +580,9 @@ inline Tensor stack(const Array& inputs, int axis = 0, std::string name * * \return A Tensor whose op member is the split operation */ -inline Array split_indices_array(const Tensor& x, Array split_indices, int axis, - std::string name = "T_split", - std::string tag = kInjective) { +inline ffi::Array split_indices_array(const Tensor& x, ffi::Array split_indices, + int axis, std::string name = "T_split", + std::string tag = kInjective) { if (axis < 0) { axis += static_cast(x->shape.size()); } @@ -598,7 +601,7 @@ inline Array split_indices_array(const Tensor& x, Array split_ begin_ids.push_back(idx); } - Array> out_shapes; + ffi::Array> out_shapes; for (size_t i = 0; i < begin_ids.size(); ++i) { PrimExpr out_axis_size; if (i == begin_ids.size() - 1) { @@ -607,7 +610,7 @@ inline Array split_indices_array(const Tensor& x, Array split_ out_axis_size = begin_ids[i + 1] - begin_ids[i]; } - Array shape; + ffi::Array shape; for (size_t i = 0; i < static_cast(axis); ++i) { shape.push_back(x->shape[i]); } @@ -619,13 +622,13 @@ inline Array split_indices_array(const Tensor& x, Array split_ out_shapes.push_back(shape); } - Array result; + ffi::Array result; for (size_t i = 0; i < begin_ids.size(); ++i) { result.push_back(compute( out_shapes[i], - [&](const Array& indices) { + [&](const ffi::Array& indices) { auto begin = begin_ids[i]; - Array real_indices; + ffi::Array real_indices; for (size_t j = 0; j < static_cast(axis); ++j) { real_indices.push_back(indices[j]); } @@ -707,9 +710,10 @@ inline PrimExpr GetLength(PrimExpr begin, PrimExpr end, PrimExpr stride, PrimExp * \return A Tensor whose op member is the dynamic_strided_slice operation */ inline te::Tensor dynamic_strided_slice_with_axes( - const te::Tensor& x, const Array& begin, const Array& end, - const Array& strides, const Array& axes, bool 
assume_inbound = true, - std::string name = "T_dynamic_strided_slice_with_axes", std::string tag = kInjective) { + const te::Tensor& x, const ffi::Array& begin, const ffi::Array& end, + const ffi::Array& strides, const ffi::Array& axes, + bool assume_inbound = true, std::string name = "T_dynamic_strided_slice_with_axes", + std::string tag = kInjective) { const size_t src_tensor_dim = x->shape.size(); ICHECK_EQ(begin.size(), end.size()); ICHECK_EQ(begin.size(), strides.size()); @@ -723,7 +727,7 @@ inline te::Tensor dynamic_strided_slice_with_axes( arith::Analyzer analyzer; - Array out_shape = x->shape; + ffi::Array out_shape = x->shape; for (size_t i = 0; i < begin.size(); i++) { int axis = axes[i]->value; PrimExpr new_shape = @@ -733,8 +737,9 @@ inline te::Tensor dynamic_strided_slice_with_axes( return te::compute( out_shape, - [&](const Array& indices) { - Array real_indices = indices.Map([](const auto& var) -> PrimExpr { return var; }); + [&](const ffi::Array& indices) { + ffi::Array real_indices = + indices.Map([](const auto& var) -> PrimExpr { return var; }); for (size_t i = 0; i < begin.size(); i++) { int axis = axes[i]->value; @@ -761,9 +766,9 @@ inline te::Tensor dynamic_strided_slice_with_axes( * * \return A Tensor whose op member is the dynamic_strided_slice operation */ -inline Tensor dynamic_strided_slice(const Tensor& x, const Array& begin, - const Array& end, const Array& strides, - bool assume_inbound = true, +inline Tensor dynamic_strided_slice(const Tensor& x, const ffi::Array& begin, + const ffi::Array& end, + const ffi::Array& strides, bool assume_inbound = true, std::string name = "T_dynamic_strided_slice", std::string tag = kInjective) { const size_t src_tensor_dim = x->shape.size(); @@ -774,7 +779,7 @@ inline Tensor dynamic_strided_slice(const Tensor& x, const Array& begi ICHECK_EQ(begin.size(), strides.size()); const size_t num_slice_axes = begin.size(); - Array out_shape; + ffi::Array out_shape; arith::Analyzer analyzer; for (size_t i = 0; i < num_slice_axes; ++i) { @@ -794,8 +799,8 @@ inline Tensor dynamic_strided_slice(const Tensor& x, const Array& begi return te::compute( out_shape, - [&](const Array& indices) { - Array real_indices; + [&](const ffi::Array& indices) { + ffi::Array real_indices; for (size_t i = 0; i < num_slice_axes; ++i) { real_indices.push_back(indices[i] * strides[i] + tvm::min(begin[i], x->shape[i] - 1)); } @@ -832,7 +837,7 @@ inline te::Tensor dynamic_strided_slice(const te::Tensor& x, const te::Tensor& b ICHECK_EQ(end->shape[0].as()->value, num_dynamic_axes); ICHECK_EQ(strides->shape[0].as()->value, num_dynamic_axes); - Array begin_expr, end_expr, strides_expr; + ffi::Array begin_expr, end_expr, strides_expr; for (int64_t i = 0; i < num_dynamic_axes; ++i) { auto ind = make_const(index_dtype, i); begin_expr.push_back(begin(ind)); @@ -856,9 +861,12 @@ inline te::Tensor dynamic_strided_slice(const te::Tensor& x, const te::Tensor& b * * \return The output shape of strided_slice using the arguments above */ -inline Array StridedSliceOutputShape( - const Array& ishape, const Array& begin, const Array& end, - const Array& strides, const Array& axes, const std::string& slice_mode) { +inline ffi::Array StridedSliceOutputShape(const ffi::Array& ishape, + const ffi::Array& begin, + const ffi::Array& end, + const ffi::Array& strides, + const ffi::Array& axes, + const std::string& slice_mode) { ICHECK(axes.size() == begin.size() && axes.size() == end.size() && axes.size() == strides.size()); std::vector begin_vec, end_vec, strides_vec; std::tie(begin_vec, 
end_vec, strides_vec) = ConvertToVec(begin, end, strides, slice_mode); @@ -884,9 +892,11 @@ inline Array StridedSliceOutputShape( * * \return A Tensor whose op member is the strided_slice operation */ -inline Tensor strided_slice_with_axes(const Tensor& x, const Array& begin, - const Array& end, const Array& strides, - const Array& axes, std::string slice_mode = "end", +inline Tensor strided_slice_with_axes(const Tensor& x, const ffi::Array& begin, + const ffi::Array& end, + const ffi::Array& strides, + const ffi::Array& axes, + std::string slice_mode = "end", std::string name = "T_strided_slice_with_axes", std::string tag = kInjective) { const size_t src_tensor_dim = x->shape.size(); @@ -903,8 +913,8 @@ inline Tensor strided_slice_with_axes(const Tensor& x, const Array& beg return te::compute( out_shape, - [&](const Array& indices) { - Array real_indices; + [&](const ffi::Array& indices) { + ffi::Array real_indices; for (size_t i = 0; i < out_shape.size(); ++i) real_indices.push_back(indices[i]); for (size_t i = 0; i < axes.size(); ++i) { auto stride = make_const(strides[i].dtype(), strides_vec[i]); @@ -930,15 +940,16 @@ inline Tensor strided_slice_with_axes(const Tensor& x, const Array& beg * * \return A Tensor whose op member is the strided_slice operation */ -inline Tensor strided_slice(const Tensor& x, const Array& begin, const Array& end, - const Array& strides, std::string slice_mode = "end", - std::string name = "T_strided_slice", std::string tag = kInjective) { +inline Tensor strided_slice(const Tensor& x, const ffi::Array& begin, + const ffi::Array& end, const ffi::Array& strides, + std::string slice_mode = "end", std::string name = "T_strided_slice", + std::string tag = kInjective) { size_t src_tensor_dim = static_cast(x->shape.size()); - Array axes; + ffi::Array axes; for (size_t i = 0; i < src_tensor_dim; ++i) axes.push_back(i); - Array begin_full(begin); - Array end_full(end); - Array strides_full(strides); + ffi::Array begin_full(begin); + ffi::Array end_full(end); + ffi::Array strides_full(strides); DataType index_dtype = begin.size() > 0 ?
begin[0]->dtype : DataType::Int(64); const IntImm one = IntImm(index_dtype, 1); @@ -971,9 +982,9 @@ inline Tensor strided_slice(const Tensor& x, const Array& begin, const * * \return A Tensor whose op member is the split operation */ -inline Array split_n_sections(const Tensor& x, int num_sections, int axis, - std::string name = "T_split_sections", - std::string tag = kInjective) { +inline ffi::Array split_n_sections(const Tensor& x, int num_sections, int axis, + std::string name = "T_split_sections", + std::string tag = kInjective) { if (axis < 0) { axis += static_cast(x->shape.size()); } @@ -983,7 +994,7 @@ inline Array split_n_sections(const Tensor& x, int num_sections, int axi ICHECK_GT(num_sections, 0) << "Slice count must be > 0"; - Array split_indices; + ffi::Array split_indices; auto seg_size = indexdiv(src_axis_size + num_sections - 1, num_sections); for (int i = 0; i < num_sections; ++i) { // region at index 0 is added by split() @@ -1010,8 +1021,8 @@ inline Array split_n_sections(const Tensor& x, int num_sections, int axi inline Tensor take(const Tensor& a, const Tensor& indices, int batch_dims, std::string mode = "fast", std::string name = "T_take", std::string tag = kInjective) { - Array a_shape = a->shape; - Array out_shape = indices->shape; + ffi::Array a_shape = a->shape; + ffi::Array out_shape = indices->shape; PrimExpr a_size = 1; for (size_t i = 0; i < a_shape.size(); ++i) { a_size = a_size * a_shape[i]; @@ -1020,7 +1031,7 @@ inline Tensor take(const Tensor& a, const Tensor& indices, int batch_dims, if (mode == "clip") { return compute( out_shape, - [&](const Array& out_index) { + [&](const ffi::Array& out_index) { auto idx = tvm::min(tvm::max(0, indices(out_index)), a_size - 1); return a(UnravelIndex(idx, a_shape)); }, @@ -1030,12 +1041,14 @@ inline Tensor take(const Tensor& a, const Tensor& indices, int batch_dims, "Make sure input indices are in bound"; return compute( out_shape, - [&](const Array& out_index) { return a(UnravelIndex(indices(out_index), a_shape)); }, + [&](const ffi::Array& out_index) { + return a(UnravelIndex(indices(out_index), a_shape)); + }, name, tag); } else if (mode == "nan") { return compute( out_shape, - [&](const Array& out_index) { + [&](const ffi::Array& out_index) { auto idx = tvm::if_then_else( indices(out_index) < 0 || indices(out_index) >= a_size, tvm::FloatImm(a->dtype, std::numeric_limits::quiet_NaN()), indices(out_index)); @@ -1045,7 +1058,7 @@ inline Tensor take(const Tensor& a, const Tensor& indices, int batch_dims, } else { // mode == "wrap" return compute( out_shape, - [&](const Array& out_index) { + [&](const ffi::Array& out_index) { auto idx = truncmod(truncmod(indices(out_index), a_size) + a_size, a_size); return a(UnravelIndex(idx, a_shape)); }, @@ -1072,11 +1085,11 @@ inline Tensor sequence_mask(const Tensor& data, const Tensor& valid_length, doub ICHECK_EQ(valid_length->shape.size(), 1) << "valid_length must have ndim=1, i.e., (batch_size,)."; auto length_dim = data->shape[axis]; auto batch_dim = data->shape[1 - axis]; - Array out_shape = data->shape; + ffi::Array out_shape = data->shape; Tensor out = compute( out_shape, - [&](const Array& out_index) { - Array len_index; + [&](const ffi::Array& out_index) { + ffi::Array len_index; auto tid = out_index[axis]; auto bid = out_index[1 - axis]; len_index.push_back(bid); @@ -1103,8 +1116,8 @@ inline Tensor sequence_mask(const Tensor& data, const Tensor& valid_length, doub * * \return A Tensor whose op member is the take operation */ -inline Tensor take(const Tensor& a, Variant 
indices, int batch_dims, int axis, - std::string mode = "fast", std::string name = "T_take", +inline Tensor take(const Tensor& a, ffi::Variant indices, int batch_dims, + int axis, std::string mode = "fast", std::string name = "T_take", std::string tag = kInjective) { if (axis < 0) { axis += static_cast(a->shape.size()); @@ -1112,7 +1125,7 @@ inline Tensor take(const Tensor& a, Variant indices, int batch ICHECK_GE(axis, 0) << "axis out of bounds"; ICHECK_LT(axis, a->shape.size()) << "axis out of bounds"; auto axis_dim = a->shape[axis]; - auto indices_shape = [&]() -> Array { + auto indices_shape = [&]() -> ffi::Array { if (auto tensor = indices.as()) { return tensor->shape; } else { @@ -1145,7 +1158,7 @@ inline Tensor take(const Tensor& a, Variant indices, int batch // The result shape is a.shape[:axis] + indices.shape[batch_dims:] + // a.shape[axis + 1:]. - Array out_shape; + ffi::Array out_shape; for (int i = 0; i < batch_dims_; ++i) { out_shape.push_back(a->shape[i]); } @@ -1159,7 +1172,7 @@ inline Tensor take(const Tensor& a, Variant indices, int batch out_shape.push_back(a->shape[i]); } - auto get_index = [&](const Array& indices_position) -> PrimExpr { + auto get_index = [&](const ffi::Array& indices_position) -> PrimExpr { if (auto tensor = indices.as()) { return tensor.value()(indices_position); } else if (auto prim = indices.as()) { @@ -1174,12 +1187,12 @@ inline Tensor take(const Tensor& a, Variant indices, int batch if (batch_dims_ == 0) { return compute( out_shape, - [&](const Array& out_index) { - Array indices_position; + [&](const ffi::Array& out_index) { + ffi::Array indices_position; for (size_t j = axis; j < static_cast(axis + indices_len); ++j) { indices_position.push_back(out_index[j]); } - Array real_indices; + ffi::Array real_indices; for (size_t j = 0; j < static_cast(axis); ++j) { real_indices.push_back(out_index[j]); } @@ -1194,15 +1207,15 @@ inline Tensor take(const Tensor& a, Variant indices, int batch } else { return compute( out_shape, - [&](const Array& out_index) { - Array indices_position; + [&](const ffi::Array& out_index) { + ffi::Array indices_position; for (size_t j = 0; j < static_cast(batch_dims_); ++j) { indices_position.push_back(out_index[j]); } for (size_t j = axis; j < static_cast(axis + indices_len - batch_dims_); ++j) { indices_position.push_back(out_index[j]); } - Array real_indices; + ffi::Array real_indices; for (size_t j = 0; j < static_cast(axis); ++j) { real_indices.push_back(out_index[j]); } @@ -1220,12 +1233,12 @@ inline Tensor take(const Tensor& a, Variant indices, int batch "Make sure input indices are in bound"; return compute( out_shape, - [&](const Array& out_index) { - Array indices_position; + [&](const ffi::Array& out_index) { + ffi::Array indices_position; for (size_t j = axis; j < static_cast(axis + indices_len); ++j) { indices_position.push_back(out_index[j]); } - Array real_indices; + ffi::Array real_indices; for (size_t j = 0; j < static_cast(axis); ++j) { real_indices.push_back(out_index[j]); } @@ -1239,12 +1252,12 @@ inline Tensor take(const Tensor& a, Variant indices, int batch } else if (mode == "nan") { return compute( out_shape, - [&](const Array& out_index) { - Array indices_position; + [&](const ffi::Array& out_index) { + ffi::Array indices_position; for (size_t j = axis; j < static_cast(axis + indices_len); ++j) { indices_position.push_back(out_index[j]); } - Array real_indices; + ffi::Array real_indices; for (size_t j = 0; j < static_cast(axis); ++j) { real_indices.push_back(out_index[j]); } @@ -1262,12 +1275,12 @@ 
inline Tensor take(const Tensor& a, Variant indices, int batch } else { // mode == "wrap" return compute( out_shape, - [&](const Array& out_index) { - Array indices_position; + [&](const ffi::Array& out_index) { + ffi::Array indices_position; for (size_t j = axis; j < static_cast(axis + indices_len); ++j) { indices_position.push_back(out_index[j]); } - Array real_indices; + ffi::Array real_indices; for (size_t j = 0; j < static_cast(axis); ++j) { real_indices.push_back(out_index[j]); } @@ -1299,9 +1312,9 @@ inline Tensor where(const Tensor& condition, const Tensor& x, const Tensor& y, << y->dtype; auto get_out_shape = [&]() { auto bh1 = detail::BroadcastShape(x->shape, y->shape); - Array common_shape1(bh1.common_shape.begin(), bh1.common_shape.end()); + ffi::Array common_shape1(bh1.common_shape.begin(), bh1.common_shape.end()); auto bh2 = detail::BroadcastShape(condition->shape, common_shape1); - Array common_shape2(bh2.common_shape.begin(), bh2.common_shape.end()); + ffi::Array common_shape2(bh2.common_shape.begin(), bh2.common_shape.end()); return common_shape2; }; @@ -1311,7 +1324,7 @@ inline Tensor where(const Tensor& condition, const Tensor& x, const Tensor& y, auto x_bh = detail::BroadcastShape(x->shape, oshape); auto y_bh = detail::BroadcastShape(y->shape, oshape); - auto select = [&](tvm::Array ovars) { + auto select = [&](tvm::ffi::Array ovars) { auto c = condition(InputIndexFromBroadcast(ovars, condition, c_bh.vars1, c_bh.all_vars)); auto true_val = x(InputIndexFromBroadcast(ovars, x, x_bh.vars1, x_bh.all_vars)); auto false_val = y(InputIndexFromBroadcast(ovars, y, y_bh.vars1, y_bh.all_vars)); @@ -1345,7 +1358,7 @@ inline Tensor repeat(const Tensor& x, int repeats, int axis, std::string name = // Calculate offset from last dimension axis += ndim; } - Array new_shape; + ffi::Array new_shape; for (size_t i = 0; i < static_cast(axis); ++i) { new_shape.push_back(x->shape[i]); } @@ -1356,8 +1369,8 @@ inline Tensor repeat(const Tensor& x, int repeats, int axis, std::string name = return compute( new_shape, - [&](const Array& indices) { - Array idx; + [&](const ffi::Array& indices) { + ffi::Array idx; for (size_t i = 0; i < static_cast(axis); ++i) { idx.push_back(indices[i]); } @@ -1380,14 +1393,14 @@ inline Tensor repeat(const Tensor& x, int repeats, int axis, std::string name = * * \return A Tensor whose op member is the tile operation */ -inline Tensor tile(const Tensor& x, Array reps, std::string name = "T_tile", +inline Tensor tile(const Tensor& x, ffi::Array reps, std::string name = "T_tile", std::string tag = kBroadcast) { size_t ndim = x->shape.size(); size_t rdim = reps.size(); size_t tdim = (ndim > rdim) ? 
ndim : rdim; - Array data_shape; - Array reps_shape; - Array new_shape; + ffi::Array data_shape; + ffi::Array reps_shape; + ffi::Array new_shape; if (ndim == rdim) { for (size_t i = 0; i < ndim; ++i) { data_shape.push_back(x->shape[i]); @@ -1406,12 +1419,13 @@ inline Tensor tile(const Tensor& x, Array reps, std::string name = "T_t if (is_empty_shape(new_shape)) { return compute( - new_shape, [&](const Array& indices) { return tvm::cast(x->dtype, 0); }, name, tag); + new_shape, [&](const ffi::Array& indices) { return tvm::cast(x->dtype, 0); }, name, + tag); } else { return compute( new_shape, - [&](const Array& indices) { - Array idx; + [&](const ffi::Array& indices) { + ffi::Array idx; if (ndim >= rdim) { for (size_t i = 0; i < ndim; ++i) idx.push_back(indexmod(indices[i], x->shape[i])); } else { @@ -1435,17 +1449,18 @@ inline Tensor tile(const Tensor& x, Array reps, std::string name = "T_t * * \return A Tensor whose op member is the tile operation */ -inline Tensor dyn_tile(const Tensor& x, Array new_shape, size_t rdim, +inline Tensor dyn_tile(const Tensor& x, ffi::Array new_shape, size_t rdim, std::string name = "T_tile", std::string tag = kBroadcast) { size_t ndim = x->shape.size(); if (is_empty_shape(new_shape)) { return compute( - new_shape, [&](const Array& indices) { return tvm::cast(x->dtype, 0); }, name, tag); + new_shape, [&](const ffi::Array& indices) { return tvm::cast(x->dtype, 0); }, name, + tag); } else { return compute( new_shape, - [&](const Array& indices) { - Array idx; + [&](const ffi::Array& indices) { + ffi::Array idx; if (ndim >= rdim) { for (size_t i = 0; i < ndim; ++i) { idx.push_back(indexmod(indices[i], x->shape[i])); @@ -1489,19 +1504,19 @@ inline Tensor gather(const Tensor& data, int axis, const Tensor& indices, } ICHECK(indices->dtype.is_int() || indices->dtype.is_uint()); - Array out_shape; + ffi::Array out_shape; for (size_t i = 0; i < ndim_i; ++i) { out_shape.push_back(indices->shape[i]); } return compute( out_shape, - [&](const Array& out_index) { - Array indices_position; + [&](const ffi::Array& out_index) { + ffi::Array indices_position; for (size_t i = 0; i < ndim_i; ++i) { indices_position.push_back(out_index[i]); } - Array real_indices; + ffi::Array real_indices; for (size_t i = 0; i < ndim_i; ++i) { if (i == static_cast(axis)) { real_indices.push_back(indices(indices_position)); @@ -1533,7 +1548,7 @@ inline Tensor gather_nd(const Tensor& data, const Tensor& indices, int batch_dim size_t indices_dim0 = static_cast(GetConstInt(indices->shape[0])); ICHECK_LE(indices_dim0, ndim_d) << "dim 0 of indices tensor must be no more " << "than dimensions of data tensor"; - Array out_shape; + ffi::Array out_shape; for (size_t i = 1; i < ndim_i; ++i) { out_shape.push_back(indices->shape[i]); } @@ -1542,13 +1557,13 @@ inline Tensor gather_nd(const Tensor& data, const Tensor& indices, int batch_dim } return compute( out_shape, - [&](const Array& out_index) { - Array indices_position; + [&](const ffi::Array& out_index) { + ffi::Array indices_position; indices_position.push_back(0); for (size_t i = 0; i < ndim_i - 1; ++i) { indices_position.push_back(out_index[i]); } - Array real_indices; + ffi::Array real_indices; for (size_t i = 0; i < static_cast(batch_dims); ++i) { real_indices.push_back(out_index[i]); } @@ -1589,7 +1604,7 @@ inline Tensor gather_nd(const Tensor& data, const Tensor& indices, int batch_dim inline tvm::te::Tensor matmul(const tvm::te::Tensor& A, const tvm::te::Tensor& B, bool trans_a = false, bool trans_b = false, std::string name = "T_matmul", 
std::string tag = kMatMul) { - tvm::Array output_shape{A->shape[trans_a ? 1 : 0], B->shape[trans_b ? 0 : 1]}; + tvm::ffi::Array output_shape{A->shape[trans_a ? 1 : 0], B->shape[trans_b ? 0 : 1]}; auto k = tvm::te::reduce_axis(tvm::Range{0, A->shape[trans_a ? 0 : 1]}, "k"); auto l = [&](tvm::tir::Var i, tvm::tir::Var j) { return tvm::sum((trans_a ? A[k][i] : A[i][k]) * (trans_b ? B[j][k] : B[k][j]), {k}); @@ -1613,19 +1628,19 @@ inline Tensor tensordot(const Tensor& A, const tvm::te::Tensor& B, int axes = 2, ICHECK_GE(A->shape.size(), axes); ICHECK_GE(B->shape.size(), axes); - Array output_shape(A->shape.begin(), A->shape.end() + (-axes)); + ffi::Array output_shape(A->shape.begin(), A->shape.end() + (-axes)); for (auto it = B->shape.begin() + axes; it != B->shape.end(); ++it) output_shape.push_back(*it); - Array iter_vars; + ffi::Array iter_vars; for (int i = 0; i < axes; ++i) iter_vars.push_back(reduce_axis(Range(0, B->shape[i]), "k" + std::to_string(i))); - auto func = [&A, &B, &iter_vars, axes](const Array& input_indices) { - Array A_indices(input_indices.begin(), - input_indices.begin() + (A->shape.size() - axes)); + auto func = [&A, &B, &iter_vars, axes](const ffi::Array& input_indices) { + ffi::Array A_indices(input_indices.begin(), + input_indices.begin() + (A->shape.size() - axes)); for (auto& v : iter_vars) A_indices.push_back(v); - Array B_indices; + ffi::Array B_indices; for (auto& v : iter_vars) B_indices.push_back(v); auto it = input_indices.begin() + (A->shape.size() - axes); @@ -1654,15 +1669,15 @@ inline Tensor tensordot(const Tensor& A, const tvm::te::Tensor& B, int axes = 2, * * \return A Tensor computing the result */ -inline Tensor tensordot(const Tensor& A, const tvm::te::Tensor& B, Array A_axes, - Array B_axes, std::string name = "T_tensordot", +inline Tensor tensordot(const Tensor& A, const tvm::te::Tensor& B, ffi::Array A_axes, + ffi::Array B_axes, std::string name = "T_tensordot", std::string tag = kMatMul) { ICHECK_EQ(A_axes.size(), B_axes.size()); auto A_axes_val = GetConstIntValues(A_axes, "A_axes"); auto B_axes_val = GetConstIntValues(B_axes, "B_axes"); - Array output_shape; + ffi::Array output_shape; for (unsigned i = 0; i < A->shape.size(); ++i) if (std::find(A_axes_val.begin(), A_axes_val.end(), i) == A_axes_val.end()) output_shape.push_back(A->shape[i]); @@ -1670,13 +1685,13 @@ inline Tensor tensordot(const Tensor& A, const tvm::te::Tensor& B, Arrayshape[i]); - Array iter_vars; + ffi::Array iter_vars; for (unsigned i = 0; i < B_axes_val.size(); ++i) iter_vars.push_back(reduce_axis(Range(0, B->shape[B_axes_val[i]]), "k" + std::to_string(i))); - auto func = [&A, &B, &iter_vars, A_axes_val, B_axes_val](const Array& input_indices) { + auto func = [&A, &B, &iter_vars, A_axes_val, B_axes_val](const ffi::Array& input_indices) { int idx_input = 0; - Array A_indices; + ffi::Array A_indices; for (unsigned i = 0; i < A->shape.size(); ++i) { auto axes_pos = std::find(A_axes_val.begin(), A_axes_val.end(), i); if (axes_pos == A_axes_val.end()) { @@ -1686,7 +1701,7 @@ inline Tensor tensordot(const Tensor& A, const tvm::te::Tensor& B, Array B_indices; + ffi::Array B_indices; for (unsigned i = 0; i < B->shape.size(); ++i) { auto axes_pos = std::find(B_axes_val.begin(), B_axes_val.end(), i); if (axes_pos == B_axes_val.end()) { @@ -1720,8 +1735,8 @@ inline Tensor arange(const PrimExpr& start, const PrimExpr& stop, const PrimExpr return compute( {num_elem}, - [&](const Array& indices) { return tvm::cast(dtype, start + step * indices[0]); }, name, - tag); + [&](const ffi::Array& 
indices) { return tvm::cast(dtype, start + step * indices[0]); }, + name, tag); } /*! @@ -1734,22 +1749,22 @@ inline Tensor arange(const PrimExpr& start, const PrimExpr& stop, const PrimExpr * * \return A Tensor whose op member is the meshgrid operation */ -inline Array meshgrid(const Array& inputs, const std::string& indexing, - std::string name = "T_meshgrid", std::string tag = kInjective) { +inline ffi::Array meshgrid(const ffi::Array& inputs, const std::string& indexing, + std::string name = "T_meshgrid", std::string tag = kInjective) { const bool cartesian_indexing = indexing == "xy" && inputs.size() >= 2; - Array out_shape; + ffi::Array out_shape; for (size_t i = 0; i < inputs.size(); ++i) { const int src_index = (cartesian_indexing && i < 2) ? 1 - i : i; out_shape.push_back(inputs[src_index]->shape.size() == 0 ? 1 : inputs[src_index]->shape[0]); } - Array result; + ffi::Array result; for (size_t i = 0; i < inputs.size(); ++i) { result.push_back(compute( out_shape, - [&](const Array& indices) { + [&](const ffi::Array& indices) { const int src_index = (cartesian_indexing && i < 2) ? 1 - i : i; auto ndim = inputs[i]->GetShape().size(); - Array real_indices = {}; + ffi::Array real_indices = {}; if (ndim > 0) { real_indices = {indices[src_index]}; } @@ -1789,19 +1804,19 @@ inline Tensor layout_transform(const Tensor& src, const std::string& src_layout, ICHECK(layout_converter.defined()) << "cannot convert from " << src_layout << " to " << dst_layout; - Array dst_shape = layout_converter.ForwardShape(src->shape); + ffi::Array dst_shape = layout_converter.ForwardShape(src->shape); - Map attrs = {{"schedule_rule", String(schedule_rule)}, - // Information about layouts needed for the schedule rule - {"src_layout", String(src_layout)}, - {"dst_layout", String(dst_layout)}, - {"input_shape", src->shape}}; + ffi::Map attrs = {{"schedule_rule", ffi::String(schedule_rule)}, + // Information about layouts needed for the schedule rule + {"src_layout", ffi::String(src_layout)}, + {"dst_layout", ffi::String(dst_layout)}, + {"input_shape", src->shape}}; return compute( dst_shape, - [&](const Array& dst_indices) { - Array dst_indices_expr(dst_indices.begin(), dst_indices.end()); - Array src_indices = layout_converter.BackwardIndex(dst_indices_expr); + [&](const ffi::Array& dst_indices) { + ffi::Array dst_indices_expr(dst_indices.begin(), dst_indices.end()); + ffi::Array src_indices = layout_converter.BackwardIndex(dst_indices_expr); PrimExpr in_range = PrimExpr(1) > PrimExpr(0); // init with dtype=bool and value=true for (size_t i = 0; i < src.ndim(); ++i) { in_range = in_range && (src_indices[i] < src->shape[i]); @@ -1812,7 +1827,7 @@ inline Tensor layout_transform(const Tensor& src, const std::string& src_layout, } /*! \brief Utility function for auto_scheduler_layout_transform */ -inline void parse_auto_scheduler_layout(const String& layout, Array* shape, +inline void parse_auto_scheduler_layout(const ffi::String& layout, ffi::Array* shape, std::vector* axes) { int32_t factor = 0; std::string axis = ""; @@ -1848,22 +1863,21 @@ inline void parse_auto_scheduler_layout(const String& layout, Array* s * \param tag output tensor tag. 
* \return A tensor with shape in \p dst_layout */ -inline Tensor auto_scheduler_layout_transform(const Tensor& src, const String& src_layout, - const String& dst_layout, - const String name = "T_auto_scheduler_layout_trans", - const String tag = kInjective) { - Array src_shape; +inline Tensor auto_scheduler_layout_transform( + const Tensor& src, const ffi::String& src_layout, const ffi::String& dst_layout, + const ffi::String name = "T_auto_scheduler_layout_trans", const ffi::String tag = kInjective) { + ffi::Array src_shape; std::vector src_axes; - Array dst_shape; + ffi::Array dst_shape; std::vector dst_axes; parse_auto_scheduler_layout(src_layout, &src_shape, &src_axes); parse_auto_scheduler_layout(dst_layout, &dst_shape, &dst_axes); return compute( dst_shape, - [&](const Array& dst_indices) { - Array dst_indices_expr(dst_indices.begin(), dst_indices.end()); - Array src_indices; + [&](const ffi::Array& dst_indices) { + ffi::Array dst_indices_expr(dst_indices.begin(), dst_indices.end()); + ffi::Array src_indices; for (const std::string& src_axis : src_axes) { PrimExpr src_index = 0; CHECK_EQ(dst_indices_expr.size(), dst_axes.size()); @@ -1915,21 +1929,22 @@ inline Tensor auto_scheduler_layout_transform(const Tensor& src, const String& s * In this case, the transformation pattern is: * A'[a, b, c, d] = A[a * 4 + c, b * 16 + d] */ -inline Tensor meta_schedule_layout_transform(const Tensor& src, const tir::IndexMap& index_map, - const String name = "T_meta_schedule_layout_trans", - const String tag = kInjective) { +inline Tensor meta_schedule_layout_transform( + const Tensor& src, const tir::IndexMap& index_map, + const ffi::String name = "T_meta_schedule_layout_trans", const ffi::String tag = kInjective) { arith::Analyzer analyzer; - Array iter_domain; + ffi::Array iter_domain; iter_domain.reserve(src->shape.size()); for (const PrimExpr& e : src->shape) { iter_domain.push_back(Range::FromMinExtent(make_zero(e->dtype), e)); } - Array post_transform_shape = index_map->MapShape(src->shape, &analyzer); + ffi::Array post_transform_shape = index_map->MapShape(src->shape, &analyzer); return compute( post_transform_shape, [src, inv = index_map.Inverse(iter_domain, &analyzer), - &analyzer](const Array& indices) -> PrimExpr { - return src(inv->MapIndices(Array{indices.begin(), indices.end()}, &analyzer)); + &analyzer](const ffi::Array& indices) -> PrimExpr { + return src( + inv->MapIndices(ffi::Array{indices.begin(), indices.end()}, &analyzer)); }, name, tag); } @@ -1945,10 +1960,10 @@ inline Tensor meta_schedule_layout_transform(const Tensor& src, const tir::Index inline Tensor shape(const Tensor& src, DataType dtype, const std::string name = "T_shape", const std::string tag = kInjective) { int ndim = static_cast(src->shape.size()); - Array out_shape{ndim}; + ffi::Array out_shape{ndim}; return compute( out_shape, - [&](const Array& indices) { + [&](const ffi::Array& indices) { auto idx = indices[0]; PrimExpr ret = 0; for (int i = 0; i < ndim; ++i) { @@ -1971,10 +1986,10 @@ inline te::Tensor tensor_size(const te::Tensor& src, const DataType& dtype, const std::string& name = "tensor_size", const std::string& tag = kInjective) { int ndim = static_cast(src->shape.size()); - Array out_tensor_size = {}; + ffi::Array out_tensor_size = {}; return compute( out_tensor_size, - [&](const Array& indices) { + [&](const ffi::Array& indices) { PrimExpr ret = 1; for (int i = 0; i < ndim; ++i) { ret *= src->shape[i]; @@ -2000,7 +2015,7 @@ inline te::Tensor tensor_size(const te::Tensor& src, const DataType& dtype, */ 
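The one_hot signature in the hunk just below swaps its oshape default to an ffi::Array. A hedged call-site sketch (the placeholder tensor and literal arguments are hypothetical):

    // Sketch: expand int labels of shape {4} into a 4 x 8 one-hot matrix.
    te::Tensor labels = te::placeholder({4}, DataType::Int(32), "labels");
    te::Tensor onehot = topi::one_hot(labels, 1.0f, 0.0f, /*depth=*/8, /*axis=*/-1, DataType::Float(32));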
inline Tensor one_hot(const Tensor& indices, const PrimExpr on_value, const PrimExpr off_value, int depth, int axis, const DataType& dtype, - Array oshape = Array(), + ffi::Array oshape = ffi::Array(), const std::string name = "T_one_hot", const std::string tag = kInjective) { int true_axis = (axis == -1) ? indices->shape.size() : axis; if (oshape.size() == 0) { @@ -2019,8 +2034,8 @@ inline Tensor one_hot(const Tensor& indices, const PrimExpr on_value, const Prim PrimExpr off_value_cast = cast(dtype, off_value); return compute( oshape, - [&](const Array& iter_vars) { - Array indices_indices; + [&](const ffi::Array& iter_vars) { + ffi::Array indices_indices; for (size_t i = 0; i < iter_vars.size(); i++) { if (static_cast(i) == true_axis) { continue; @@ -2045,8 +2060,9 @@ inline Tensor one_hot(const Tensor& indices, const PrimExpr on_value, const Prim * \param tag output tensor tag. * \return Tensor of output_shape. */ -inline Tensor sparse_to_dense(const Tensor& sparse_indices, const Array& output_shape, - const Tensor& sparse_values, const PrimExpr& default_value, +inline Tensor sparse_to_dense(const Tensor& sparse_indices, + const ffi::Array& output_shape, const Tensor& sparse_values, + const PrimExpr& default_value, const std::string name = "T_sparse_to_dense", const std::string tag = kInjective) { ICHECK(sparse_indices->dtype.is_int()) << "sparse_indices only accepts integer values"; @@ -2055,13 +2071,13 @@ inline Tensor sparse_to_dense(const Tensor& sparse_indices, const Arrayshape.size(), 2) << "sparse_values tensor should be 0D or 1D only"; const auto rank_sparse_indices = static_cast(sparse_indices->shape.size()); - Array oshape; + ffi::Array oshape; for (auto l : output_shape) { oshape.push_back(l); } return compute( oshape, - [&](const Array& indices) { + [&](const ffi::Array& indices) { PrimExpr ret = default_value; if (0 == rank_sparse_indices) { ret = if_then_else(indices[0] == sparse_indices(), sparse_values(), ret); @@ -2106,9 +2122,9 @@ inline Tensor matrix_set_diag(const Tensor& input, const Tensor& diagonal, int k return compute( input->shape, - [&](const Array& iter_vars) { + [&](const ffi::Array& iter_vars) { auto get_diag = [&]() { - Array diagonal_indices; + ffi::Array diagonal_indices; PrimExpr k, offset = 0; for (size_t i = 0; i < ndim - 1; i++) { diagonal_indices.push_back(iter_vars[i]); @@ -2152,18 +2168,18 @@ inline Tensor matrix_set_diag(const Tensor& input, const Tensor& diagonal, int k * \param tag output tensor tag. * \return Output tensor. 
*/ -inline Tensor adv_index(const Tensor& data, const Array& indices, +inline Tensor adv_index(const Tensor& data, const ffi::Array& indices, const std::string name = "advanced_index", const std::string tag = kInjective) { ICHECK_LE(indices.size(), data->shape.size()) << "too many indices for data!"; - Array oshape; - Array broadcast_shape; - Array bindices; + ffi::Array oshape; + ffi::Array broadcast_shape; + ffi::Array bindices; broadcast_shape = indices[0]->shape; for (size_t i = 1; i < indices.size(); ++i) { auto bh = detail::BroadcastShape(broadcast_shape, indices[i]->shape); - broadcast_shape = Array(bh.common_shape.begin(), bh.common_shape.end()); + broadcast_shape = ffi::Array(bh.common_shape.begin(), bh.common_shape.end()); } if (indices.size() == 1) { // quick path @@ -2184,12 +2200,12 @@ inline Tensor adv_index(const Tensor& data, const Array& indices, return compute( oshape, - [&](const Array& iter_var) { - Array tensor_indices; + [&](const ffi::Array& iter_var) { + ffi::Array tensor_indices; for (size_t i = 0; i < broadcast_shape.size(); ++i) { tensor_indices.push_back(iter_var[i]); } - Array real_indices; + ffi::Array real_indices; for (size_t i = 0; i < bindices.size(); ++i) { real_indices.push_back(bindices[i](tensor_indices)); } @@ -2206,7 +2222,7 @@ namespace relax { // relax dynamic slice inline te::Tensor dynamic_strided_slice(const te::Tensor& x, const te::Tensor& begin, const te::Tensor& end, const te::Tensor& strides, - Array output_shape, + ffi::Array output_shape, std::string name = "T_strided_slice_dynamic", std::string tag = kInjective) { const size_t num_dynamic_axes = x.ndim(); @@ -2225,8 +2241,8 @@ inline te::Tensor dynamic_strided_slice(const te::Tensor& x, const te::Tensor& b return te::compute( output_shape, - [&](const Array& indices) { - Array real_indices; + [&](const ffi::Array& indices) { + ffi::Array real_indices; for (size_t i = 0; i < num_dynamic_axes; ++i) { auto ind = make_const(DataType::Int(64), i); real_indices.push_back(indices[i] * strides(ind) + tvm::min(begin(ind), x->shape[i] - 1)); diff --git a/include/tvm/topi/utils.h b/include/tvm/topi/utils.h index b5f2d6c38d61..41a2cce0e4f9 100644 --- a/include/tvm/topi/utils.h +++ b/include/tvm/topi/utils.h @@ -32,17 +32,17 @@ namespace topi { using namespace tvm::runtime; -/*! \brief Canonicalize an argument that may be Array or int to Array */ -inline Optional> ArrayOrInt(AnyView arg) { +/*! 
\brief Canonicalize an argument that may be ffi::Array or int to ffi::Array */ +inline ffi::Optional> ArrayOrInt(AnyView arg) { if (arg == nullptr) { return std::nullopt; } if (auto opt_int = arg.try_cast()) { - Array result; + ffi::Array result; result.push_back(opt_int.value()); return result; } else { - return arg.cast>(); + return arg.cast>(); } } } // namespace topi diff --git a/include/tvm/topi/vision/reorg.h b/include/tvm/topi/vision/reorg.h index 381272bb818c..f9a089d1abdc 100644 --- a/include/tvm/topi/vision/reorg.h +++ b/include/tvm/topi/vision/reorg.h @@ -72,7 +72,7 @@ inline Tensor reorg(const Tensor& data, int stride = 1, std::string name = "tens int out_h = h_in / stride; int out_w = w_in / stride; - Array out_shape = {batch, out_c, out_h, out_w}; + ffi::Array out_shape = {batch, out_c, out_h, out_w}; return reshape(out, out_shape); } } // namespace vision diff --git a/src/arith/analyzer.cc b/src/arith/analyzer.cc index 9c4220ce29b6..a96f3cdf223b 100644 --- a/src/arith/analyzer.cc +++ b/src/arith/analyzer.cc @@ -103,7 +103,7 @@ void Analyzer::MarkGlobalNonNegValue(const PrimExpr& value) { // We may consider enhance the sub analyzer to directly take // MarkPositiveVar so their bounds do not overlap if (const auto* var_ptr = symbol.as()) { - Var var = GetRef(var_ptr); + Var var = ffi::GetRef(var_ptr); // skip non-index type, keep it to be compatible // with any_dim that do not represent any value if (!IsIndexType(var.dtype())) return; @@ -116,7 +116,7 @@ void Analyzer::MarkGlobalNonNegValue(const PrimExpr& value) { } } -void Analyzer::Bind(const Map& variables, bool allow_override) { +void Analyzer::Bind(const ffi::Map& variables, bool allow_override) { for (const auto& iter : variables) { this->Bind(iter.first, iter.second, allow_override); } @@ -202,7 +202,7 @@ bool Analyzer::CanProve(const PrimExpr& expr, ProofStrength strength) { // This is to avoid repeatitive calling of this function // that causes speed issues. // This strategy can only be called from top-level and not from sub-analyzers. - Optional pos_diff; + ffi::Optional pos_diff; int lower_bound = 0; if (const auto* ptr_lt = expr.as()) { pos_diff = ptr_lt->b - ptr_lt->a; @@ -322,7 +322,7 @@ TVM_FFI_STATIC_INIT_BLOCK({ }); } else if (name == "int_set") { return ffi::Function([self](ffi::PackedArgs args, ffi::Any* ret) { - *ret = self->int_set(args[0].cast(), args[1].cast>()); + *ret = self->int_set(args[0].cast(), args[1].cast>()); }); } else if (name == "bind") { return ffi::Function([self](ffi::PackedArgs args, ffi::Any* ret) { diff --git a/src/arith/bound_deducer.cc b/src/arith/bound_deducer.cc index f7720095eb2d..ed941c7dbdad 100644 --- a/src/arith/bound_deducer.cc +++ b/src/arith/bound_deducer.cc @@ -390,8 +390,8 @@ IntSet DeduceBound(PrimExpr v, PrimExpr e, // assuming e >= 0, deduce the bound of variable from it. // return empty set to represent deduce failure. 
-IntSet DeduceBound(PrimExpr v, PrimExpr e, const Map& hint_map, - const Map& relax_map) { +IntSet DeduceBound(PrimExpr v, PrimExpr e, const ffi::Map& hint_map, + const ffi::Map& relax_map) { std::unordered_map hmap; for (auto kv : hint_map) { hmap[kv.first.get()] = kv.second; @@ -405,10 +405,11 @@ IntSet DeduceBound(PrimExpr v, PrimExpr e, const Map& hint_map, TVM_FFI_STATIC_INIT_BLOCK({ namespace refl = tvm::ffi::reflection; - refl::GlobalDef().def( - "arith.DeduceBound", - [](PrimExpr v, PrimExpr cond, const Map hint_map, - const Map relax_map) { return DeduceBound(v, cond, hint_map, relax_map); }); + refl::GlobalDef().def("arith.DeduceBound", + [](PrimExpr v, PrimExpr cond, const ffi::Map hint_map, + const ffi::Map relax_map) { + return DeduceBound(v, cond, hint_map, relax_map); + }); }); } // namespace arith diff --git a/src/arith/canonical_simplify.cc b/src/arith/canonical_simplify.cc index 7a02a3bedba8..0f7be4466743 100644 --- a/src/arith/canonical_simplify.cc +++ b/src/arith/canonical_simplify.cc @@ -680,7 +680,7 @@ class CanonicalSimplifier::Impl : public RewriteSimplifier::Impl { if (const auto* op = expr.as()) { expr = op->Normalize(); } - ObjectPtr n = make_object(); + ObjectPtr n = ffi::make_object(); n->dtype = expr.dtype(); n->index = std::move(expr); n->div_mode = kTruncDiv; @@ -717,7 +717,7 @@ class CanonicalSimplifier::Impl : public RewriteSimplifier::Impl { if (auto op = expr.as()) { return op.value(); } - ObjectPtr n = make_object(); + ObjectPtr n = ffi::make_object(); n->dtype = expr.dtype(); if (const auto* op = expr.as()) { n->base = op->value; @@ -816,7 +816,7 @@ PrimExpr CanonicalSimplifier::Impl::VisitExpr_(const MulNode* op) { const MulNode* mul = ret.as(); if (mul && mul->a.same_as(op->a) && mul->b.same_as(op->b)) { - return GetRef(op); + return ffi::GetRef(op); } else { return ret; } @@ -825,8 +825,8 @@ PrimExpr CanonicalSimplifier::Impl::VisitExpr_(const MulNode* op) { void CanonicalSimplifier::Impl::SeparateDivisibleParts(const SumExprNode* psum, int64_t coeff, SumExpr* out_divisible, SumExpr* out_non_divisible) { - auto divisible = make_object(); - auto non_divisible = make_object(); + auto divisible = ffi::make_object(); + auto non_divisible = ffi::make_object(); divisible->dtype = psum->dtype; non_divisible->dtype = psum->dtype; @@ -894,7 +894,7 @@ bool CanonicalSimplifier::Impl::ProdDivSimplify(PrimExpr* plhs, PrimExpr* prhs, // we just skip to save the time if (prhs->as()) return false; // collect lhs products and try to eliminate by matching them to prod in rhs - Array> lhs_prods; + ffi::Array> lhs_prods; PrimExpr new_rhs = make_const(prhs->dtype(), 1); PrimExpr new_common_scale = make_const(prhs->dtype(), 1); int64_t lhs_cscale = 1, rhs_cscale = 1; @@ -939,7 +939,7 @@ bool CanonicalSimplifier::Impl::ProdDivSimplify(PrimExpr* plhs, PrimExpr* prhs, // construct prod via canonical form PrimExpr new_lhs = make_const(plhs->dtype(), 1); - for (Optional val : lhs_prods) { + for (ffi::Optional val : lhs_prods) { if (val.defined()) new_lhs = new_lhs * val.value(); } *plhs = new_lhs * make_const(plhs->dtype(), lhs_cscale); @@ -1006,7 +1006,7 @@ PrimExpr CanonicalSimplifier::Impl::VisitExpr_(const DivNode* op) { return truncdiv(a, b); } if (op->a.same_as(a) && op->b.same_as(b)) { - return GetRef(op); + return ffi::GetRef(op); } else { return Div(a, b); } @@ -1066,7 +1066,7 @@ PrimExpr CanonicalSimplifier::Impl::VisitExpr_(const FloorDivNode* op) { return floordiv(a, b); } if (op->a.same_as(a) && op->b.same_as(b)) { - return GetRef(op); + return ffi::GetRef(op); } 
else { return FloorDiv(a, b); } @@ -1194,7 +1194,7 @@ PrimExpr CanonicalSimplifier::Impl::VisitExpr_(const ModNode* op) { } if (op->a.same_as(a) && op->b.same_as(b)) { - return GetRef(op); + return ffi::GetRef(op); } else { return Mod(a, b); } @@ -1259,7 +1259,7 @@ PrimExpr CanonicalSimplifier::Impl::VisitExpr_(const FloorModNode* op) { } if (op->a.same_as(a) && op->b.same_as(b)) { - return GetRef(op); + return ffi::GetRef(op); } else { return FloorMod(a, b); } @@ -1268,7 +1268,7 @@ PrimExpr CanonicalSimplifier::Impl::VisitExpr_(const FloorModNode* op) { // Simplify reduce expression. PrimExpr CanonicalSimplifier::Impl::SimplifyReduceCombiner(const ReduceNode* op) { // First simplify the results - Array simplified_result; + ffi::Array simplified_result; for (const auto& res : op->combiner->result) { PrimExpr new_res = this->VisitExpr(res); simplified_result.push_back(new_res); @@ -1311,12 +1311,12 @@ PrimExpr CanonicalSimplifier::Impl::SimplifyReduceCombiner(const ReduceNode* op) } int new_value_index = op->value_index; - Array new_result; - Array new_identity; - Array new_lhs; - Array new_rhs; - Array new_source; - Array new_init; + ffi::Array new_result; + ffi::Array new_identity; + ffi::Array new_lhs; + ffi::Array new_rhs; + ffi::Array new_source; + ffi::Array new_init; // new stuff is old stuff which is used for (size_t i = 0; i < used.size(); ++i) { diff --git a/src/arith/const_fold.h b/src/arith/const_fold.h index 2c905dd563ef..dda7f6746598 100644 --- a/src/arith/const_fold.h +++ b/src/arith/const_fold.h @@ -48,7 +48,7 @@ namespace arith { * \return std::nullopt if constant fold fails, otherwise return folded result. */ template -inline Optional TryConstFold(PrimExpr a, PrimExpr b); +inline ffi::Optional TryConstFold(PrimExpr a, PrimExpr b); /*! * \brief Try to run unary compute with constant folding. @@ -60,7 +60,7 @@ inline Optional TryConstFold(PrimExpr a, PrimExpr b); * \return std::nullopt if constant fold fails, otherwise return folded result. */ template -inline Optional TryConstFold(PrimExpr a); +inline ffi::Optional TryConstFold(PrimExpr a); /*! * \brief Check whether type is used to represent index. @@ -128,7 +128,7 @@ inline double GetFoldResultDoubleRepr(float x) { // specialization of constant folders. 
template <> -inline Optional TryConstFold(PrimExpr a, PrimExpr b) { +inline ffi::Optional TryConstFold(PrimExpr a, PrimExpr b) { TVM_ARITH_CONST_PROPAGATION({ const DataType& rtype = a.dtype(); if (pa && pb) { @@ -152,7 +152,7 @@ inline Optional TryConstFold(PrimExpr a, PrimExpr b) { } template <> -inline Optional TryConstFold(PrimExpr a, PrimExpr b) { +inline ffi::Optional TryConstFold(PrimExpr a, PrimExpr b) { TVM_ARITH_CONST_PROPAGATION({ ICHECK(!((pa && pa->dtype.is_uint() && pa->value == 0U) && (pb && pb->dtype.is_uint() && pb->value > 0U))) @@ -178,7 +178,7 @@ inline Optional TryConstFold(PrimExpr a, PrimExpr b) { } template <> -inline Optional TryConstFold(PrimExpr a, PrimExpr b) { +inline ffi::Optional TryConstFold(PrimExpr a, PrimExpr b) { TVM_ARITH_CONST_PROPAGATION({ const DataType& rtype = a.dtype(); if (pa && pb) { @@ -214,7 +214,7 @@ inline Optional TryConstFold(PrimExpr a, PrimExpr b) { } template <> -inline Optional TryConstFold(PrimExpr a, PrimExpr b) { +inline ffi::Optional TryConstFold(PrimExpr a, PrimExpr b) { TVM_ARITH_CONST_PROPAGATION({ const DataType& rtype = a.dtype(); if (pa && pb) { @@ -250,7 +250,7 @@ inline Optional TryConstFold(PrimExpr a, PrimExpr b) { } template <> -inline Optional TryConstFold(PrimExpr a, PrimExpr b) { +inline ffi::Optional TryConstFold(PrimExpr a, PrimExpr b) { TVM_INDEX_CONST_PROPAGATION({ const DataType& rtype = a.dtype(); if (pa && pb) { @@ -270,7 +270,7 @@ inline Optional TryConstFold(PrimExpr a, PrimExpr b) { } template <> -inline Optional TryConstFold(PrimExpr a, PrimExpr b) { +inline ffi::Optional TryConstFold(PrimExpr a, PrimExpr b) { TVM_ARITH_CONST_PROPAGATION({ const DataType& rtype = a.dtype(); if (pa && pb) { @@ -305,7 +305,7 @@ inline Optional TryConstFold(PrimExpr a, PrimExpr b) { } template <> -inline Optional TryConstFold(PrimExpr a, PrimExpr b) { +inline ffi::Optional TryConstFold(PrimExpr a, PrimExpr b) { TVM_INDEX_CONST_PROPAGATION({ const DataType& rtype = a.dtype(); if (pa && pb) { @@ -325,7 +325,7 @@ inline Optional TryConstFold(PrimExpr a, PrimExpr b) { } template <> -inline Optional TryConstFold(PrimExpr a, PrimExpr b) { +inline ffi::Optional TryConstFold(PrimExpr a, PrimExpr b) { TVM_ARITH_CONST_PROPAGATION({ const DataType& rtype = a.dtype(); if (pa && pb) return IntImm(rtype, std::min(pa->value, pb->value)); @@ -336,7 +336,7 @@ inline Optional TryConstFold(PrimExpr a, PrimExpr b) { } template <> -inline Optional TryConstFold(PrimExpr a, PrimExpr b) { +inline ffi::Optional TryConstFold(PrimExpr a, PrimExpr b) { TVM_ARITH_CONST_PROPAGATION({ const DataType& rtype = a.dtype(); if (pa && pb) return IntImm(rtype, std::max(pa->value, pb->value)); @@ -347,7 +347,7 @@ inline Optional TryConstFold(PrimExpr a, PrimExpr b) { } template <> -inline Optional TryConstFold(PrimExpr a, PrimExpr b) { +inline ffi::Optional TryConstFold(PrimExpr a, PrimExpr b) { TVM_ARITH_CONST_PROPAGATION({ if (pa && pb) return IntImm(DataType::UInt(1), pa->value > pb->value); if (fa && fb) return IntImm(DataType::UInt(1), fa->value > fb->value); @@ -356,7 +356,7 @@ inline Optional TryConstFold(PrimExpr a, PrimExpr b) { } template <> -inline Optional TryConstFold(PrimExpr a, PrimExpr b) { +inline ffi::Optional TryConstFold(PrimExpr a, PrimExpr b) { TVM_ARITH_CONST_PROPAGATION({ if (pa && pb) return IntImm(DataType::UInt(1), pa->value >= pb->value); if (fa && fb) return IntImm(DataType::UInt(1), fa->value >= fb->value); @@ -365,7 +365,7 @@ inline Optional TryConstFold(PrimExpr a, PrimExpr b) { } template <> -inline Optional TryConstFold(PrimExpr a, 
PrimExpr b) { +inline ffi::Optional TryConstFold(PrimExpr a, PrimExpr b) { TVM_ARITH_CONST_PROPAGATION({ if (pa && pb) return IntImm(DataType::UInt(1), pa->value < pb->value); if (fa && fb) return IntImm(DataType::UInt(1), fa->value < fb->value); @@ -374,7 +374,7 @@ inline Optional TryConstFold(PrimExpr a, PrimExpr b) { } template <> -inline Optional TryConstFold(PrimExpr a, PrimExpr b) { +inline ffi::Optional TryConstFold(PrimExpr a, PrimExpr b) { TVM_ARITH_CONST_PROPAGATION({ if (pa && pb) return IntImm(DataType::UInt(1), pa->value <= pb->value); if (fa && fb) return IntImm(DataType::UInt(1), fa->value <= fb->value); @@ -383,7 +383,7 @@ inline Optional TryConstFold(PrimExpr a, PrimExpr b) { } template <> -inline Optional TryConstFold(PrimExpr a, PrimExpr b) { +inline ffi::Optional TryConstFold(PrimExpr a, PrimExpr b) { TVM_ARITH_CONST_PROPAGATION({ if (pa && pb) return IntImm(DataType::UInt(1), pa->value == pb->value); if (fa && fb) return IntImm(DataType::UInt(1), fa->value == fb->value); @@ -392,7 +392,7 @@ inline Optional TryConstFold(PrimExpr a, PrimExpr b) { } template <> -inline Optional TryConstFold(PrimExpr a, PrimExpr b) { +inline ffi::Optional TryConstFold(PrimExpr a, PrimExpr b) { TVM_ARITH_CONST_PROPAGATION({ if (pa && pb) return IntImm(DataType::UInt(1), pa->value != pb->value); if (fa && fb) return IntImm(DataType::UInt(1), fa->value != fb->value); @@ -401,7 +401,7 @@ inline Optional TryConstFold(PrimExpr a, PrimExpr b) { } template <> -inline Optional TryConstFold(PrimExpr a, PrimExpr b) { +inline ffi::Optional TryConstFold(PrimExpr a, PrimExpr b) { const IntImmNode* pa = a.as(); const IntImmNode* pb = b.as(); if (pa && pa->value) return b; @@ -412,7 +412,7 @@ inline Optional TryConstFold(PrimExpr a, PrimExpr b) { } template <> -inline Optional TryConstFold(PrimExpr a, PrimExpr b) { +inline ffi::Optional TryConstFold(PrimExpr a, PrimExpr b) { const IntImmNode* pa = a.as(); const IntImmNode* pb = b.as(); if (pa && pa->value) return a; @@ -423,7 +423,7 @@ inline Optional TryConstFold(PrimExpr a, PrimExpr b) { } template <> -inline Optional TryConstFold(PrimExpr a) { +inline ffi::Optional TryConstFold(PrimExpr a) { const IntImmNode* pa = a.as(); if (pa) { return IntImm(DataType::UInt(1), !(pa->value)); diff --git a/src/arith/const_int_bound.cc b/src/arith/const_int_bound.cc index c2dd8f120a99..9f5a0ab00084 100644 --- a/src/arith/const_int_bound.cc +++ b/src/arith/const_int_bound.cc @@ -42,7 +42,7 @@ using namespace tir; TVM_FFI_STATIC_INIT_BLOCK({ ConstIntBoundNode::RegisterReflection(); }); ConstIntBound::ConstIntBound(int64_t min_value, int64_t max_value) { - auto node = make_object(); + auto node = ffi::make_object(); node->min_value = min_value; node->max_value = max_value; data_ = std::move(node); @@ -387,7 +387,7 @@ class ConstIntBoundAnalyzer::Impl } Entry VisitExpr_(const VarNode* op) final { - Var v = GetRef(op); + Var v = ffi::GetRef(op); auto it = var_map_.find(v); if (it != var_map_.end()) { return it->second; @@ -397,7 +397,7 @@ class ConstIntBoundAnalyzer::Impl } Entry VisitExpr_(const SizeVarNode* op) final { - SizeVar v = GetRef(op); + SizeVar v = ffi::GetRef(op); auto it = var_map_.find(v); if (it != var_map_.end()) { return it->second; @@ -744,7 +744,7 @@ class ConstIntBoundAnalyzer::Impl * This expression is used as the implementation of * topi.math.ceil_log2, and can appear in iteration bounds. 
*/ - static Optional FindCeilLog2Arg(const CastNode* op) { + static ffi::Optional FindCeilLog2Arg(const CastNode* op) { if (op->dtype.is_int()) { if (auto as_call = op->value.as()) { if (as_call->op.same_as(Op::Get("tir.ceil"))) { diff --git a/src/arith/detect_common_subexpr.cc b/src/arith/detect_common_subexpr.cc index 3c7d4e0e4bea..a10105f7c3c8 100644 --- a/src/arith/detect_common_subexpr.cc +++ b/src/arith/detect_common_subexpr.cc @@ -33,7 +33,7 @@ namespace arith { using namespace tir; -Map DetectCommonSubExpr(const PrimExpr& e, int thresh) { +ffi::Map DetectCommonSubExpr(const PrimExpr& e, int thresh) { // Check the threshold in the range of size_t CHECK_GE(thresh, std::numeric_limits::min()); CHECK_LE(thresh, std::numeric_limits::max()); @@ -63,7 +63,7 @@ Map DetectCommonSubExpr(const PrimExpr& e, int thresh) { } // Return the common sub expr that occur more than thresh times - Map results; + ffi::Map results; for (auto& it : semantic_comp_done_by_expr) { if (it.second >= repeat_thr) results.Set(it.first, it.second); } diff --git a/src/arith/detect_linear_equation.cc b/src/arith/detect_linear_equation.cc index e6746efd3717..d86dace8725d 100644 --- a/src/arith/detect_linear_equation.cc +++ b/src/arith/detect_linear_equation.cc @@ -142,14 +142,14 @@ class LinearEqDetector : public ExprFunctor DetectLinearEquation(const PrimExpr& e, const Array& vars) { +ffi::Array DetectLinearEquation(const PrimExpr& e, const ffi::Array& vars) { PrimExpr base = e; - Array coeff; + ffi::Array coeff; for (Var v : vars) { LinearEqEntry ret; if (!LinearEqDetector(v).Detect(base, &ret)) { - return Array(); + return ffi::Array(); } coeff.push_back(ret.coeff); base = std::move(ret.base); @@ -162,7 +162,7 @@ Array DetectLinearEquation(const PrimExpr& e, const Array& vars) vset.insert(vars[i - 1].get()); // The previous coeff contains the variable if (UsesVar(coeff[i - 2], vset_contains)) { - return Array(); + return ffi::Array(); } } coeff.push_back(base); @@ -218,8 +218,8 @@ bool DetectClipBound(const PrimExpr& cond, ret.coeff = analyzer.Simplify(ret.coeff); IntervalEntry& p = (*bmap)[var.get()]; - Optional min_value; - Optional max_value; + ffi::Optional min_value; + ffi::Optional max_value; if (is_const_int(ret.coeff, 1)) { // var + shift >=0 -> var >= -shift min_value = -ret.base; @@ -265,7 +265,7 @@ void SplitCommExpr(const PrimExpr& e, std::vector* ret) { // Detect the lower and upper bound from the expression. // e must be connected by and. 
-Array DetectClipBound(const PrimExpr& e, const Array& vars) { +ffi::Array DetectClipBound(const PrimExpr& e, const ffi::Array& vars) { std::vector splits; Analyzer analyzer; SplitCommExpr(analyzer.Simplify(e), &splits); @@ -274,9 +274,9 @@ Array DetectClipBound(const PrimExpr& e, const Array& vars) { rmap[v.get()] = IntervalEntry(); } for (PrimExpr cond : splits) { - if (!DetectClipBound(cond, &rmap)) return Array(); + if (!DetectClipBound(cond, &rmap)) return ffi::Array(); } - Array ret; + ffi::Array ret; for (Var v : vars) { IntervalEntry e = rmap[v.get()]; if (e.min_value.defined()) { @@ -296,7 +296,7 @@ TVM_FFI_STATIC_INIT_BLOCK({ refl::GlobalDef() .def("arith.DetectLinearEquation", DetectLinearEquation) .def("arith.DetectClipBound", - [](const PrimExpr& e, const Array& vars) { return DetectClipBound(e, vars); }); + [](const PrimExpr& e, const ffi::Array& vars) { return DetectClipBound(e, vars); }); }); } // namespace arith } // namespace tvm diff --git a/src/arith/domain_touched.cc b/src/arith/domain_touched.cc index 96a269d7294f..319f786f6a37 100644 --- a/src/arith/domain_touched.cc +++ b/src/arith/domain_touched.cc @@ -115,7 +115,7 @@ class BufferTouchedDomain final : public IRVisitorWithAnalyzer { } private: - void Touch(BufferTouches* bounds, const Array& args) { + void Touch(BufferTouches* bounds, const ffi::Array& args) { if (args.size() > bounds->size()) { bounds->resize(args.size()); } @@ -136,25 +136,25 @@ Region DomainTouched(const Stmt& stmt, const Buffer& buffer, bool consider_loads return BufferTouchedDomain(stmt).FindUnion(buffer, consider_loads, consider_stores); } -Map> DomainTouchedAccessMap(const PrimFunc& func) { +ffi::Map> DomainTouchedAccessMap(const PrimFunc& func) { auto buffer_access_map = BufferTouchedDomain(func->body).GetAccessedBufferRegions(); - Map> ret; + ffi::Map> ret; auto& buffer_map = func->buffer_map; for (auto& var : func->params) { auto& buffer = buffer_map[var]; auto& access = buffer_access_map[buffer.get()]; - Array> loads, stores, combined; + ffi::Array> loads, stores, combined; for (std::vector& touch : std::get(access).set) { - loads.push_back(Array(touch)); + loads.push_back(ffi::Array(touch)); } for (std::vector& touch : std::get(access).set) { - stores.push_back(Array(touch)); + stores.push_back(ffi::Array(touch)); } for (std::vector& touch : std::get(access).set) { - combined.push_back(Array(touch)); + combined.push_back(ffi::Array(touch)); } - Array fields; + ffi::Array fields; fields.push_back(loads); fields.push_back(stores); fields.push_back(combined); diff --git a/src/arith/int_constraints.cc b/src/arith/int_constraints.cc index b074e6400aaf..eec0fd2ef1b7 100644 --- a/src/arith/int_constraints.cc +++ b/src/arith/int_constraints.cc @@ -45,9 +45,10 @@ TVM_FFI_STATIC_INIT_BLOCK({ IntConstraintsTransformNode::RegisterReflection(); }); -Array AsConditions(const Array& variables, const Map& bounds, - const Array& relations) { - Array res; +ffi::Array AsConditions(const ffi::Array& variables, + const ffi::Map& bounds, + const ffi::Array& relations) { + ffi::Array res; // use variables to keep the order of iteration // so as to get rid of any non-determinism. 
ICHECK_EQ(variables.size(), bounds.size()); @@ -71,11 +72,11 @@ Array AsConditions(const Array& variables, const Map lower, Array equal, - Array upper) { +IntGroupBounds::IntGroupBounds(PrimExpr coef, ffi::Array lower, + ffi::Array equal, ffi::Array upper) { ICHECK(coef.dtype().is_int() || coef.dtype().is_uint()) << "Coefficient in IntGroupBounds must be integers"; - ObjectPtr node = make_object(); + ObjectPtr node = ffi::make_object(); node->coef = std::move(coef); node->lower = std::move(lower); node->equal = std::move(equal); @@ -86,9 +87,9 @@ IntGroupBounds::IntGroupBounds(PrimExpr coef, Array lower, Arraymin.dtype(), 1); - Array equal; - Array lower; - Array upper; + ffi::Array equal; + ffi::Array lower; + ffi::Array upper; if (tir::is_one(r->extent)) { equal.push_back(r->min); } else { @@ -100,9 +101,9 @@ IntGroupBounds IntGroupBounds::FromRange(const Range& r) { IntGroupBounds IntGroupBounds::operator+(const Range& r) { Analyzer analyzer; - Array equal; - Array lower; - Array upper; + ffi::Array equal; + ffi::Array lower; + ffi::Array upper; const PrimExpr& coef = operator->()->coef; if (tir::is_one(r->extent)) { equal.push_back(analyzer.Simplify(r->min * coef)); @@ -116,7 +117,7 @@ IntGroupBounds IntGroupBounds::operator+(const Range& r) { return IntGroupBounds(coef, lower, equal, upper); } -IntGroupBounds IntGroupBounds::Substitute(const Map& subst) const { +IntGroupBounds IntGroupBounds::Substitute(const ffi::Map& subst) const { auto apply_fun = [&subst](const PrimExpr& e) { return tir::Substitute(e, subst); }; return IntGroupBounds(tir::Substitute(operator->()->coef, subst), tir::UpdateArray(operator->()->lower, apply_fun), @@ -124,7 +125,7 @@ IntGroupBounds IntGroupBounds::Substitute(const Map& subst) const tir::UpdateArray(operator->()->upper, apply_fun)); } -Range IntGroupBounds::FindBestRange(const Map& vranges_addl) const { +Range IntGroupBounds::FindBestRange(const ffi::Map& vranges_addl) const { Analyzer analyzer; analyzer.Bind(vranges_addl); @@ -133,7 +134,7 @@ Range IntGroupBounds::FindBestRange(const Map& vranges_addl) const { var_intsets[kv.first.get()] = IntSet::FromRange(kv.second); } - const Array& equal = operator->()->equal; + const ffi::Array& equal = operator->()->equal; const PrimExpr& coef = operator->()->coef; std::vector lowers(equal.begin(), equal.end()); @@ -161,7 +162,7 @@ Range IntGroupBounds::FindBestRange(const Map& vranges_addl) const { for (const PrimExpr& low : lowers) { for (const PrimExpr& upp : uppers) { // Since diff may depend on some other variables, we compute its overapproximation - Optional diff_over; + ffi::Optional diff_over; PrimExpr diff_1 = analyzer.Simplify(floordiv(upp - low, coef), 3); IntSet diff_set1 = EvalSet(diff_1, var_intsets); if (diff_set1.HasUpperBound()) { @@ -204,9 +205,8 @@ TVM_FFI_STATIC_INIT_BLOCK({ namespace refl = tvm::ffi::reflection; refl::GlobalDef() .def("arith.IntGroupBounds", - [](PrimExpr coef, Array lower, Array equal, Array upper) { - return IntGroupBounds(coef, lower, equal, upper); - }) + [](PrimExpr coef, ffi::Array lower, ffi::Array equal, + ffi::Array upper) { return IntGroupBounds(coef, lower, equal, upper); }) .def("arith.IntGroupBounds_from_range", IntGroupBounds::FromRange) .def_packed("arith.IntGroupBounds_FindBestRange", [](ffi::PackedArgs args, ffi::Any* ret) { ICHECK(args.size() == 1 || args.size() == 2); @@ -214,7 +214,7 @@ TVM_FFI_STATIC_INIT_BLOCK({ if (args.size() == 1) { *ret = bounds.FindBestRange(); } else if (args.size() == 2) { - *ret = bounds.FindBestRange(args[1].cast>()); + *ret = 
bounds.FindBestRange(args[1].cast>()); } }); }); @@ -226,14 +226,14 @@ TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) << ", equal=" << op->equal << ", upper=" << op->upper << ")"; }); -IntConstraints::IntConstraints(Array variables, Map ranges, - Array relations) { - ObjectPtr node = make_object(); +IntConstraints::IntConstraints(ffi::Array variables, ffi::Map ranges, + ffi::Array relations) { + ObjectPtr node = ffi::make_object(); if (!variables.defined()) { - variables = Array(); + variables = ffi::Array(); } if (!ranges.defined()) { - ranges = Map(); + ranges = ffi::Map(); } ICHECK(relations.defined()); for (const auto& var : variables) { @@ -248,10 +248,11 @@ IntConstraints::IntConstraints(Array variables, Map ranges, TVM_FFI_STATIC_INIT_BLOCK({ namespace refl = tvm::ffi::reflection; - refl::GlobalDef().def("arith.IntConstraints", [](Array variables, Map ranges, - Array relations) { - return IntConstraints(variables, ranges, relations); - }); + refl::GlobalDef().def( + "arith.IntConstraints", + [](ffi::Array variables, ffi::Map ranges, ffi::Array relations) { + return IntConstraints(variables, ranges, relations); + }); }); TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) @@ -262,9 +263,9 @@ TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) }); IntConstraintsTransform::IntConstraintsTransform(IntConstraints src, IntConstraints dst, - Map src_to_dst, - Map dst_to_src) { - ObjectPtr node = make_object(); + ffi::Map src_to_dst, + ffi::Map dst_to_src) { + ObjectPtr node = ffi::make_object(); node->src = std::move(src); node->dst = std::move(dst); node->src_to_dst = std::move(src_to_dst); @@ -275,8 +276,8 @@ IntConstraintsTransform::IntConstraintsTransform(IntConstraints src, IntConstrai IntConstraintsTransform IntConstraintsTransform::operator+( const IntConstraintsTransform& other) const { ICHECK(other->src.same_as(operator->()->dst)); - Map dst_to_src; - Map src_to_dst; + ffi::Map dst_to_src; + ffi::Map src_to_dst; Analyzer ana_first; ana_first.Bind(operator->()->src->ranges); @@ -295,8 +296,8 @@ IntConstraintsTransform IntConstraintsTransform::operator+( TVM_FFI_STATIC_INIT_BLOCK({ namespace refl = tvm::ffi::reflection; refl::GlobalDef().def("arith.IntConstraintsTransform", - [](IntConstraints src, IntConstraints dst, Map src_to_dst, - Map dst_to_src) { + [](IntConstraints src, IntConstraints dst, + ffi::Map src_to_dst, ffi::Map dst_to_src) { return IntConstraintsTransform(src, dst, src_to_dst, dst_to_src); }); }); diff --git a/src/arith/int_set.cc b/src/arith/int_set.cc index 6bd0400673be..b37680376a35 100644 --- a/src/arith/int_set.cc +++ b/src/arith/int_set.cc @@ -50,7 +50,7 @@ PrimExpr SymbolicLimits::pos_inf_ = Var("pos_inf", DataType::Handle()); PrimExpr SymbolicLimits::neg_inf_ = Var("neg_inf", DataType::Handle()); IntervalSet::IntervalSet(PrimExpr min_value, PrimExpr max_value) { - auto node = make_object(); + auto node = ffi::make_object(); node->min_value = std::move(min_value); node->max_value = std::move(max_value); data_ = std::move(node); @@ -368,7 +368,7 @@ using namespace tir; // We might use better set analysis in the future to replace the intervalset. 
class IntervalSetEvaluator : public ExprFunctor { public: - IntervalSetEvaluator(Analyzer* analyzer, const Map& dom_map, + IntervalSetEvaluator(Analyzer* analyzer, const ffi::Map& dom_map, const std::vector>* dom_constraints = nullptr, bool eval_vec = false) : analyzer_(analyzer), @@ -390,13 +390,13 @@ class IntervalSetEvaluator : public ExprFunctor { } IntervalSet VisitExpr_(const IntImmNode* op) final { - return IntervalSet::SinglePoint(GetRef(op)); + return IntervalSet::SinglePoint(ffi::GetRef(op)); } IntervalSet VisitExpr_(const VarNode* op) final { - Var var = GetRef(op); + Var var = ffi::GetRef(op); - Array values; + ffi::Array values; if (dom_constraints_) { for (const auto& constraint : *dom_constraints_) { if (var.same_as(constraint.first)) { @@ -491,7 +491,7 @@ class IntervalSetEvaluator : public ExprFunctor { } } } - DLOG(WARNING) << "cannot evaluate set on expression " << GetRef(op); + DLOG(WARNING) << "cannot evaluate set on expression " << ffi::GetRef(op); return IntervalSet::Everything(); } @@ -530,17 +530,17 @@ class IntervalSetEvaluator : public ExprFunctor { // Otherwise return `IntervalSet::everything()` since we have no knowledge on the buffer data. for (const PrimExpr& index : op->indices) { if (UsesVar(index, [dom_map = &this->dom_map_](const VarNode* var) { - return dom_map->find(GetRef(var)) != dom_map->end(); + return dom_map->find(ffi::GetRef(var)) != dom_map->end(); })) { return IntervalSet::Everything(); } } - return IntervalSet::SinglePoint(GetRef(op)); + return IntervalSet::SinglePoint(ffi::GetRef(op)); } IntervalSet VisitExpr_(const CallNode* op) final { if (op->op.same_as(tir::builtin::vscale())) - return IntervalSet(GetRef(op), GetRef(op)); + return IntervalSet(ffi::GetRef(op), ffi::GetRef(op)); return IntervalSet::Everything(); } @@ -561,7 +561,7 @@ class IntervalSetEvaluator : public ExprFunctor { IntervalSet a = this->Eval(op->a); IntervalSet b = this->Eval(op->b); if (MatchPoint(a, op->a) && MatchPoint(b, op->b)) { - return IntervalSet::SinglePoint(GetRef(op)); + return IntervalSet::SinglePoint(ffi::GetRef(op)); } return Combine(analyzer_, a, b, op->dtype); } @@ -570,7 +570,7 @@ class IntervalSetEvaluator : public ExprFunctor { int recur_depth_{0}; // analyzer Analyzer* analyzer_; - const Map& dom_map_; + const ffi::Map& dom_map_; const std::vector>* dom_constraints_; bool eval_vec_{false}; }; @@ -579,7 +579,7 @@ class IntSetAnalyzer::Impl { public: explicit Impl(Analyzer* analyzer) : analyzer_(analyzer) {} - IntSet Eval(const PrimExpr& expr, const Map& dom_map) const { + IntSet Eval(const PrimExpr& expr, const ffi::Map& dom_map) const { return IntervalSetEvaluator(analyzer_, dom_map).Eval(expr); } @@ -605,11 +605,11 @@ class IntSetAnalyzer::Impl { // Map of variables to global variable bounds (e.g. loop iterator // ranges) - Map dom_map_; + ffi::Map dom_map_; // List of implicit scope-dependent bounds (e.g. inside the body of // an if-statement). Maintained as a list of constraints, rather - // than as a `Map`, to avoid computing an Intersection + // than as a `ffi::Map`, to avoid computing an Intersection // until required. 
std::vector> dom_constraints_; }; @@ -618,7 +618,7 @@ IntSetAnalyzer::IntSetAnalyzer(Analyzer* parent) : impl_(new Impl(parent)) {} IntSetAnalyzer::~IntSetAnalyzer() { delete impl_; } -IntSet IntSetAnalyzer::operator()(const PrimExpr& expr, const Map& dom_map) { +IntSet IntSetAnalyzer::operator()(const PrimExpr& expr, const ffi::Map& dom_map) { return impl_->Eval(expr, dom_map); } @@ -861,7 +861,7 @@ bool IntSet::MatchRange(const Range& b) const { ProveEqual(&ana, a_int->max_value, b->extent + b->min - 1); } -IntSet Union(const Array& sets) { +IntSet Union(const ffi::Array& sets) { if (sets.size() == 0) return IntSet::Nothing(); if (sets.size() == 1) return sets[0]; Analyzer ana; @@ -872,16 +872,16 @@ IntSet Union(const Array& sets) { return IntervalSet(ana.Simplify(x->min_value), ana.Simplify(x->max_value)); } -Array UnionRegion(const Array>& nd_int_sets) { +ffi::Array UnionRegion(const ffi::Array>& nd_int_sets) { if (nd_int_sets.empty()) { return {}; } int n = nd_int_sets.size(); int ndim = nd_int_sets[0].size(); - Array result; + ffi::Array result; result.reserve(ndim); for (int i = 0; i < ndim; ++i) { - Array candidates; + ffi::Array candidates; candidates.reserve(n); for (int j = 0; j < n; ++j) { candidates.push_back(nd_int_sets[j][i]); @@ -891,7 +891,7 @@ Array UnionRegion(const Array>& nd_int_sets) { return result; } -IntSet UnionLowerBound(const Array& sets) { +IntSet UnionLowerBound(const ffi::Array& sets) { if (sets.size() == 0) return IntSet::Nothing(); if (sets.size() == 1) return sets[0]; Analyzer analyzer; @@ -925,16 +925,16 @@ IntSet UnionLowerBound(const Array& sets) { return IntSet::Interval(min_inclusive, max_inclusive); } -Array UnionRegionLowerBound(const Array>& nd_int_sets) { +ffi::Array UnionRegionLowerBound(const ffi::Array>& nd_int_sets) { if (nd_int_sets.empty()) { return {}; } int n = nd_int_sets.size(); int ndim = nd_int_sets[0].size(); - Array result; + ffi::Array result; result.reserve(ndim); for (int i = 0; i < ndim; ++i) { - Array candidates; + ffi::Array candidates; candidates.reserve(n); for (int j = 0; j < n; ++j) { candidates.push_back(nd_int_sets[j][i]); @@ -944,7 +944,7 @@ Array UnionRegionLowerBound(const Array>& nd_int_sets) { return result; } -IntSet Intersect(const Array& sets) { +IntSet Intersect(const ffi::Array& sets) { if (sets.size() == 0) return IntSet::Nothing(); if (sets.size() == 1) return sets[0]; Analyzer ana; @@ -955,23 +955,23 @@ IntSet Intersect(const Array& sets) { return IntervalSet(ana.Simplify(x->min_value), ana.Simplify(x->max_value)); } -Map ConvertDomMap(const Map& dom_map) { - Map dmap; +ffi::Map ConvertDomMap(const ffi::Map& dom_map) { + ffi::Map dmap; for (auto kv : dom_map) { dmap.Set(kv.first->var, kv.second); } return dmap; } -Map ConvertDomMap(const std::unordered_map& dom_map) { - Map dmap; +ffi::Map ConvertDomMap(const std::unordered_map& dom_map) { + ffi::Map dmap; for (auto kv : dom_map) { - dmap.Set(GetRef(kv.first), kv.second); + dmap.Set(ffi::GetRef(kv.first), kv.second); } return dmap; } -IntSet EvalSet(PrimExpr e, const Map& dom_map) { +IntSet EvalSet(PrimExpr e, const ffi::Map& dom_map) { Analyzer ana; return IntervalSetEvaluator(&ana, dom_map, {}, false).Eval(e); } @@ -983,12 +983,12 @@ IntSet IntSet::Vector(PrimExpr x) { } else { // vector case. 
Analyzer ana; - Map dmap; + ffi::Map dmap; return IntervalSetEvaluator(&ana, dmap, {}, true).Eval(x); } } -IntSet EvalSet(PrimExpr e, const Map& dom_map) { +IntSet EvalSet(PrimExpr e, const ffi::Map& dom_map) { return EvalSet(e, ConvertDomMap(dom_map)); } @@ -996,7 +996,7 @@ IntSet EvalSet(PrimExpr e, const std::unordered_map& dom return EvalSet(e, ConvertDomMap(dom_map)); } -IntSet EvalSet(Range r, const Map& dom_map) { +IntSet EvalSet(Range r, const ffi::Map& dom_map) { Analyzer ana; if ((r->min->dtype.is_int() || r->min->dtype.is_uint()) && ana.CanProveEqual(r->extent, 1)) { return EvalSet(r->min, dom_map); @@ -1012,10 +1012,10 @@ IntSet EvalSet(Range r, const std::unordered_map& dom_ma return EvalSet(r, ConvertDomMap(dom_map)); } -Array EvalSet(const Array& region, const Map& dom_map) { +ffi::Array EvalSet(const ffi::Array& region, const ffi::Map& dom_map) { Analyzer ana; IntervalSetEvaluator m(&ana, dom_map); - Array result; + ffi::Array result; result.reserve(region.size()); for (const Range& r : region) { PrimExpr sum = r->min + (r->extent - 1); @@ -1036,7 +1036,7 @@ IntSet EvalSet(IntSet s, const std::unordered_map& dom_m class SubExprIntervalSetEvaluator : public IntervalSetEvaluator { public: - explicit SubExprIntervalSetEvaluator(Analyzer* analyzer, const Map& dom_map) + explicit SubExprIntervalSetEvaluator(Analyzer* analyzer, const ffi::Map& dom_map) : IntervalSetEvaluator(analyzer, dom_map) {} IntervalSet VisitExpr(const PrimExpr& n) final { @@ -1057,12 +1057,12 @@ ExprIntSetMap EvalSetForEachSubExpr(PrimExpr e, return m.expr_map; } -IntSet EvalSet(Range r, const Map& dom_map) { +IntSet EvalSet(Range r, const ffi::Map& dom_map) { return EvalSet(r, ConvertDomMap(dom_map)); } -Map AsIntSet(const Map& var_dom) { - Map result; +ffi::Map AsIntSet(const ffi::Map& var_dom) { + ffi::Map result; for (auto kv : var_dom) { const Var& var = kv.first; const Range& range = kv.second; @@ -1072,8 +1072,8 @@ Map AsIntSet(const Map& var_dom) { } /*! \brief Helper function to convert IterSumExpr to the actual touched range. 
*/ -static Optional EvalIterSum(const IterSumExpr& iter_min, const PrimExpr& extent, - Analyzer* analyzer) { +static ffi::Optional EvalIterSum(const IterSumExpr& iter_min, const PrimExpr& extent, + Analyzer* analyzer) { if (analyzer->CanProve(extent == 0)) { return IntSet::Nothing(); } @@ -1105,13 +1105,14 @@ static Optional EvalIterSum(const IterSumExpr& iter_min, const PrimExpr& } } -Optional> EstimateRegionStrictBound(const Array& region, - const Map& var_dom, - const PrimExpr& predicate, Analyzer* analyzer) { +ffi::Optional> EstimateRegionStrictBound(const ffi::Array& region, + const ffi::Map& var_dom, + const PrimExpr& predicate, + Analyzer* analyzer) { int ndim = region.size(); - Array iter_sum_exprs{nullptr}; + ffi::Array iter_sum_exprs{nullptr}; { - Array affine_indices; + ffi::Array affine_indices; affine_indices.reserve(ndim); for (const Range& range : region) { if (!is_const_number(range->extent)) { @@ -1129,12 +1130,12 @@ Optional> EstimateRegionStrictBound(const Array& region, return std::nullopt; } ICHECK_EQ(iter_sum_exprs.size(), ndim); - Array result; + ffi::Array result; result.reserve(ndim); for (int i = 0; i < ndim; ++i) { const IterSumExpr& sum_expr = iter_sum_exprs[i]; const Range& range = region[i]; - Optional int_set = EvalIterSum(sum_expr, range->extent, analyzer); + ffi::Optional int_set = EvalIterSum(sum_expr, range->extent, analyzer); if (int_set.defined()) { result.push_back(int_set.value()); } else { @@ -1144,22 +1145,23 @@ Optional> EstimateRegionStrictBound(const Array& region, return result; } -Optional> EstimateRegionLowerBound(const Array& region, - const Map& var_dom, - const PrimExpr& predicate, - arith::Analyzer* analyzer) { +ffi::Optional> EstimateRegionLowerBound(const ffi::Array& region, + const ffi::Map& var_dom, + const PrimExpr& predicate, + arith::Analyzer* analyzer) { return EstimateRegionStrictBound(region, var_dom, predicate, analyzer); } -Array EstimateRegionUpperBound(const Array& region, const Map& var_dom, - const PrimExpr& predicate, Analyzer* analyzer) { - if (Optional> result = EstimateRegionStrictBound( +ffi::Array EstimateRegionUpperBound(const ffi::Array& region, + const ffi::Map& var_dom, + const PrimExpr& predicate, Analyzer* analyzer) { + if (ffi::Optional> result = EstimateRegionStrictBound( /*region=*/region, /*var_dom=*/var_dom, /*predicate=*/predicate, /*analyzer=*/analyzer)) { return result.value(); } - Array result; + ffi::Array result; result.reserve(region.size()); // try estimate each dimension independently for (const Range& range : region) { @@ -1178,7 +1180,7 @@ Array EstimateRegionUpperBound(const Array& region, const Map int_set = EvalIterSum(sum_expr, range->extent, analyzer)) { + if (ffi::Optional int_set = EvalIterSum(sum_expr, range->extent, analyzer)) { result.push_back(int_set.value()); continue; } @@ -1207,20 +1209,20 @@ TVM_FFI_STATIC_INIT_BLOCK({ .def_method("arith.IntSetIsNothing", &IntSet::IsNothing) .def_method("arith.IntSetIsEverything", &IntSet::IsEverything) .def("arith.EstimateRegionLowerBound", - [](Array region, Map var_dom, - PrimExpr predicate) -> Optional> { + [](ffi::Array region, ffi::Map var_dom, + PrimExpr predicate) -> ffi::Optional> { Analyzer analyzer; return EstimateRegionLowerBound(region, var_dom, predicate, &analyzer); }) .def("arith.EstimateRegionStrictBound", - [](Array region, Map var_dom, - PrimExpr predicate) -> Optional> { + [](ffi::Array region, ffi::Map var_dom, + PrimExpr predicate) -> ffi::Optional> { Analyzer analyzer; return EstimateRegionStrictBound(region, var_dom, 
predicate, &analyzer); }) .def("arith.EstimateRegionUpperBound", - [](Array region, Map var_dom, - PrimExpr predicate) -> Optional> { + [](ffi::Array region, ffi::Map var_dom, + PrimExpr predicate) -> ffi::Optional> { Analyzer analyzer; return EstimateRegionUpperBound(region, var_dom, predicate, &analyzer); }) diff --git a/src/arith/ir_mutator_with_analyzer.cc b/src/arith/ir_mutator_with_analyzer.cc index d26ac3667620..59b0b0546dab 100644 --- a/src/arith/ir_mutator_with_analyzer.cc +++ b/src/arith/ir_mutator_with_analyzer.cc @@ -40,14 +40,14 @@ void IRMutatorWithAnalyzer::MarkBufferMapShapes(const tir::PrimFunc& func) { } } -Array IRMutatorWithAnalyzer::IterMapSimplifyWithContext(const Array& indices, - bool non_trivial_only) { +ffi::Array IRMutatorWithAnalyzer::IterMapSimplifyWithContext( + const ffi::Array& indices, bool non_trivial_only) { PrimExpr pred = const_true(); for (PrimExpr val : iter_predicates_) { pred = pred && val; } int n = indices.size(); - Array simplified = arith::IterMapSimplify( + ffi::Array simplified = arith::IterMapSimplify( indices, this->iter_vars_, pred, arith::IterMapLevel::Surjective, this->analyzer_); if (non_trivial_only) { for (int i = 0; i < n; ++i) { @@ -84,7 +84,7 @@ Stmt IRMutatorWithAnalyzer::VisitStmt_(const LetStmtNode* op) { // as sub-class may or maynot choose to replace it. Stmt body = this->VisitStmt(op->body); if (value.same_as(op->value) && body.same_as(op->body)) { - return GetRef(op); + return ffi::GetRef(op); } else { auto n = this->CopyOnWrite(op); n->value = std::move(value); @@ -105,7 +105,7 @@ Stmt IRMutatorWithAnalyzer::VisitStmt_(const IfThenElseNode* op) { } Stmt then_case; - Optional else_case; + ffi::Optional else_case; { With ctx(analyzer_, real_condition); WithRecordIterPredicate(real_condition, [&] { then_case = this->VisitStmt(op->then_case); }); @@ -121,7 +121,7 @@ Stmt IRMutatorWithAnalyzer::VisitStmt_(const IfThenElseNode* op) { if (condition.same_as(op->condition) && then_case.same_as(op->then_case) && else_case.same_as(op->else_case)) { - return GetRef(op); + return ffi::GetRef(op); } else { auto n = this->CopyOnWrite(op); n->condition = std::move(condition); @@ -152,7 +152,7 @@ Stmt IRMutatorWithAnalyzer::VisitStmt_(const AssertStmtNode* op) { Stmt body = this->VisitStmt(op->body); if (condition.same_as(op->condition) && message.same_as(op->message) && body.same_as(op->body)) { - return GetRef(op); + return ffi::GetRef(op); } else { auto n = this->CopyOnWrite(op); n->condition = std::move(condition); @@ -185,7 +185,7 @@ PrimExpr IRMutatorWithAnalyzer::VisitExpr_(const CallNode* op) { } if (cond.same_as(op->args[0]) && true_value.same_as(op->args[1]) && false_value.same_as(op->args[2])) { - return GetRef(op); + return ffi::GetRef(op); } else { return Call(op->dtype, op->op, {cond, true_value, false_value}); } @@ -202,7 +202,7 @@ PrimExpr IRMutatorWithAnalyzer::VisitExpr_(const LetNode* op) { // as sub-class may or maynot choose to replace it. 
PrimExpr body = this->VisitExpr(op->body); if (value.same_as(op->value) && body.same_as(op->body)) { - return GetRef(op); + return ffi::GetRef(op); } else { return Let(op->var, value, body); } @@ -228,7 +228,7 @@ PrimExpr IRMutatorWithAnalyzer::VisitExpr_(const SelectNode* op) { // normal path if (cond.same_as(op->condition) && true_value.same_as(op->true_value) && false_value.same_as(op->false_value)) { - return GetRef(op); + return ffi::GetRef(op); } else { return Select(cond, true_value, false_value); } diff --git a/src/arith/ir_mutator_with_analyzer.h b/src/arith/ir_mutator_with_analyzer.h index fb01fd19cee7..28f8e600d38e 100644 --- a/src/arith/ir_mutator_with_analyzer.h +++ b/src/arith/ir_mutator_with_analyzer.h @@ -74,7 +74,8 @@ class IRMutatorWithAnalyzer : public tir::StmtExprMutator { * \brief Use internal bound information to perform inter map simplification of indices. * \note Only do this during layout remapping */ - Array IterMapSimplifyWithContext(const Array& indices, bool non_trivial_only); + ffi::Array IterMapSimplifyWithContext(const ffi::Array& indices, + bool non_trivial_only); /*! \brief internal analyzer field. */ Analyzer* analyzer_; @@ -83,9 +84,9 @@ class IRMutatorWithAnalyzer : public tir::StmtExprMutator { // expensive and we only encourage doing them during // necessary cases like layout remapping /*! \brief Recorded loop iterators */ - Map iter_vars_; + ffi::Map iter_vars_; /*! \brief iterator predicates */ - Array iter_predicates_; + ffi::Array iter_predicates_; /*! * \brief Run callback while trying to record iter predicate * \param conditon Condition to be checked. @@ -94,7 +95,7 @@ class IRMutatorWithAnalyzer : public tir::StmtExprMutator { template void WithRecordIterPredicate(PrimExpr condition, FLambda callback) { auto f_use_itervar = [this](const tir::VarNode* v) { - return iter_vars_.count(GetRef(v)); + return iter_vars_.count(ffi::GetRef(v)); }; // simple heuristics for detecting predicate if (tir::UsesVar(condition, f_use_itervar)) { diff --git a/src/arith/iter_affine_map.cc b/src/arith/iter_affine_map.cc index 42b99abd4063..e8c96c908a7b 100644 --- a/src/arith/iter_affine_map.cc +++ b/src/arith/iter_affine_map.cc @@ -49,7 +49,7 @@ TVM_FFI_STATIC_INIT_BLOCK({ }); IterMark::IterMark(PrimExpr source, PrimExpr extent) { - auto n = make_object(); + auto n = ffi::make_object(); n->source = std::move(source); n->extent = std::move(extent); data_ = std::move(n); @@ -68,7 +68,7 @@ TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) }); IterSplitExpr::IterSplitExpr(IterMark source) { - auto n = make_object(); + auto n = ffi::make_object(); auto one = make_const(source->source->dtype, 1); n->dtype = source->source->dtype; n->source = std::move(source); @@ -79,7 +79,7 @@ IterSplitExpr::IterSplitExpr(IterMark source) { } IterSplitExpr::IterSplitExpr(IterMark source, PrimExpr scale) { - auto n = make_object(); + auto n = ffi::make_object(); auto one = make_const(source->source->dtype, 1); n->dtype = source->source->dtype; n->source = std::move(source); @@ -91,7 +91,7 @@ IterSplitExpr::IterSplitExpr(IterMark source, PrimExpr scale) { IterSplitExpr::IterSplitExpr(IterMark source, PrimExpr lower_factor, PrimExpr extent, PrimExpr scale) { - auto n = make_object(); + auto n = ffi::make_object(); n->dtype = source->source->dtype; n->source = std::move(source); n->lower_factor = std::move(lower_factor); @@ -115,8 +115,8 @@ TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) << ", extent=" << op->extent << ", scale=" << op->scale << ")"; }); -IterSumExpr::IterSumExpr(Array args, PrimExpr 
base) { - auto n = make_object(); +IterSumExpr::IterSumExpr(ffi::Array args, PrimExpr base) { + auto n = ffi::make_object(); n->dtype = base->dtype; n->args = std::move(args); n->base = std::move(base); @@ -125,7 +125,7 @@ IterSumExpr::IterSumExpr(Array args, PrimExpr base) { TVM_FFI_STATIC_INIT_BLOCK({ namespace refl = tvm::ffi::reflection; - refl::GlobalDef().def("arith.IterSumExpr", [](Array args, PrimExpr base) { + refl::GlobalDef().def("arith.IterSumExpr", [](ffi::Array args, PrimExpr base) { return IterSumExpr(args, base); }); }); @@ -152,7 +152,7 @@ class IterMarkSplitCollector { * \brief Collect all mark2splits recursively from indices. * \param indices The iterator of interest. */ - void Collect(const Array& indices) { + void Collect(const ffi::Array& indices) { for (IterSumExpr sum_expr : indices) { for (IterSplitExpr split : sum_expr->args) { this->CollectInternal(split->source); @@ -186,9 +186,9 @@ class IterMapRewriter : public ExprMutator { public: using Parent = ExprMutator; - explicit IterMapRewriter(Analyzer* analyzer, const Map& input_iters, + explicit IterMapRewriter(Analyzer* analyzer, const ffi::Map& input_iters, IterMapLevel check_level, bool simplify_trivial_iterators, - Array* errors) + ffi::Array* errors) : analyzer_(analyzer), check_level_(check_level), errors_(*errors), @@ -227,8 +227,8 @@ class IterMapRewriter : public ExprMutator { } IterSumExpr RewriteIterConstraint(const PrimExpr& expr, - const Optional& predicate_induced_min, - const Optional& predicate_induced_max) { + const ffi::Optional& predicate_induced_min, + const ffi::Optional& predicate_induced_max) { return NormalizeToIterOnBoundExpr(ToIterSumExpr(DirectMutate(expr)), predicate_induced_min, predicate_induced_max); } @@ -263,7 +263,7 @@ class IterMapRewriter : public ExprMutator { * - bindings = [x / 3] will not pass because x / 3 can not be one split of x * \return whether the bindings are valid */ - bool CheckMapping(const Array& bindings, IterMapLevel check_level) { + bool CheckMapping(const ffi::Array& bindings, IterMapLevel check_level) { IterMarkSplitCollector collector; // We can check that for each iter mark: // All the splits that refers to the iter_mark covers its extent. @@ -447,7 +447,7 @@ class IterMapRewriter : public ExprMutator { // Iter map check level IterMapLevel check_level_; // Error messages for each unresolved expression. - Array& errors_; + ffi::Array& errors_; // The var map std::unordered_map var_map_; // input iter marks @@ -568,9 +568,9 @@ class IterMapRewriter : public ExprMutator { * \param check_level Iteration mapping's check level. * \return The normalized splits. 
*/ - Array TryNormalizeSplits(const IterMark& mark, - const std::vector& splits, - IterMapLevel check_level) { + ffi::Array TryNormalizeSplits(const IterMark& mark, + const std::vector& splits, + IterMapLevel check_level) { std::vector used(splits.size(), false); std::vector iters; PrimExpr expected_lower_factor = make_const(mark->source->dtype, 1); @@ -586,7 +586,7 @@ class IterMapRewriter : public ExprMutator { if (j == splits.size()) { // we do not allow incomplete split if the bindings should be bijective if (check_level == IterMapLevel::Bijective) { - return Array(); + return ffi::Array(); } // look for the next split skipping this lower factor // For example, y \in [0, 24) has 3 splits [y / 6, (y / 2) % 6, y % 2] @@ -595,7 +595,7 @@ class IterMapRewriter : public ExprMutator { j = SearchSkipLowerFactor(splits, used, expected_lower_factor); // split not found if (j == splits.size()) { - return Array(); + return ffi::Array(); } } @@ -647,24 +647,24 @@ class IterMapRewriter : public ExprMutator { if (match_full_iter) { if (splits.size() != 1) { ErrorLogger(this) << "Dependent iterations on padding iter space"; - return Array(); + return ffi::Array(); } else if (analyzer_->CanProveEqual(splits[0]->extent, expected_lower_factor) && !analyzer_->CanProve(extent_before_padding >= expected_lower_factor)) { ErrorLogger(this) << "Split on padding iteration is not surjective " << "if the split extent equals to the full iter space extent"; - return Array(); + return ffi::Array(); } } else if (match_iter_divisor) { if (!analyzer_->CanProve(extent_before_padding >= expected_lower_factor)) { ErrorLogger(this) << "The extent before padding is less than lower factor"; - return Array(); + return ffi::Array(); } } else { ErrorLogger(this) << "The lower factor is not divisible by the full iter space extent"; return {}; } } - return Array(iters.rbegin(), iters.rend()); + return ffi::Array(iters.rbegin(), iters.rend()); } /*! @@ -674,8 +674,9 @@ class IterMapRewriter : public ExprMutator { * \param predicate_induced_max Open upper bound from iter constraint, maybe undefined. * \return The Normalized expression. 
*/ - IterSumExpr NormalizeToIterOnBoundExpr(IterSumExpr expr, Optional predicate_induced_min, - Optional predicate_induced_max) { + IterSumExpr NormalizeToIterOnBoundExpr(IterSumExpr expr, + ffi::Optional predicate_induced_min, + ffi::Optional predicate_induced_max) { // normalize to zero base PrimExpr base = expr->base; if (!is_zero(base)) { @@ -685,7 +686,7 @@ class IterMapRewriter : public ExprMutator { if (predicate_induced_max.defined()) predicate_induced_max = predicate_induced_max.value() - base; } - Optional opt = TryFuseIters(expr, check_level_, false); + ffi::Optional opt = TryFuseIters(expr, check_level_, false); ICHECK(!opt.defined() || opt.value()->args.size() == 1); // scale should be 1 if (opt.defined() && is_one(opt.value()->args[0]->scale)) { @@ -739,7 +740,7 @@ class IterMapRewriter : public ExprMutator { // to check the validity of constraints, see also CheckConstraints() constrained_iters_flattened_.push_back(flattened_form); IterSumExprNode* normalized_expr = expr.CopyOnWrite(); - normalized_expr->args = Array({split}); + normalized_expr->args = ffi::Array({split}); normalized_expr->base = base; return expr; } @@ -755,7 +756,7 @@ class IterMapRewriter : public ExprMutator { IterSumExpr NormalizeToIterWithOffset(IterSumExpr expr) { // We are normalizing a regular iter if (expr->args.size() < 1) return expr; - Optional opt = TryFuseIters(expr, check_level_, true); + ffi::Optional opt = TryFuseIters(expr, check_level_, true); if (opt.defined()) { return opt.value(); } else { @@ -820,7 +821,7 @@ class IterMapRewriter : public ExprMutator { return lhs.symbol_prod_count > rhs.symbol_prod_count; }); - Array args; + ffi::Array args; for (const Item& item : items) { args.push_back(item.split); } @@ -857,7 +858,7 @@ class IterMapRewriter : public ExprMutator { * \return Whether we can find one. */ int FindBaseIter(const IterSumExpr& expr, const std::vector& skip_flag, - Optional match_source, int rbegin = -1) { + ffi::Optional match_source, int rbegin = -1) { if (rbegin == -1) { rbegin = static_cast(expr->args.size()) - 1; } @@ -927,7 +928,7 @@ class IterMapRewriter : public ExprMutator { * \return -1 if not no match found, otherwise return the index. */ int FindIterWithExactScale(const IterSumExpr& expr, const std::vector& skip_flag, - const PrimExpr& expected_scale, Optional match_source, + const PrimExpr& expected_scale, ffi::Optional match_source, int rbegin = -1, int first_possible_unit_extent_pos = 0) { if (rbegin == -1) { rbegin = static_cast(expr->args.size()) - 1; @@ -993,7 +994,7 @@ class IterMapRewriter : public ExprMutator { * \param check_level The check level if iter mapping. * \return The sum with the fused IterMark and extra offset if succeed. 
*/ - Optional TryCombineSplitFromSameSource(IterSumExpr expr) { + ffi::Optional TryCombineSplitFromSameSource(IterSumExpr expr) { if (expr->args.size() <= 1) return std::nullopt; std::unordered_map hit_count; // most iter map are small n < 5 @@ -1078,7 +1079,7 @@ class IterMapRewriter : public ExprMutator { IterSumExpr simplified_sum = expr; // flip the order so we preserve the original order simplified_sum.CopyOnWrite()->args = - Array(reverse_flattened_iters.rbegin(), reverse_flattened_iters.rend()); + ffi::Array(reverse_flattened_iters.rbegin(), reverse_flattened_iters.rend()); return simplified_sum; } @@ -1095,8 +1096,8 @@ class IterMapRewriter : public ExprMutator { * (this may cause us to return parameters that are not canonically wrapped as * IterSum(IterMark)) \return The sum with the fused IterMark and extra offset if succeed. */ - Optional TryFuseIters(IterSumExpr expr, IterMapLevel check_level, - bool allow_early_skip) { + ffi::Optional TryFuseIters(IterSumExpr expr, IterMapLevel check_level, + bool allow_early_skip) { if (auto opt = TryCombineSplitFromSameSource(expr)) { expr = opt.value(); if (expr->args.size() <= 1 && allow_early_skip) { @@ -1146,7 +1147,7 @@ class IterMapRewriter : public ExprMutator { // predicate: j*2 + k < 9 // We need to match the predicate in expr and adjust the expected scale, // otherwise we expect the scale of i to be 2*5=10 - Optional constraint_to_match; + ffi::Optional constraint_to_match; for (const IterSumExpr& iter : constrained_iters_flattened_) { if (IterSplitEqual(expr->args[matched_pos], iter->args.back(), false)) { // find a predicate started from match position @@ -1208,10 +1209,10 @@ class IterMapRewriter : public ExprMutator { // both forms have splits from outermost to innermost IterSumExpr structured_form = expr, flattened_form = expr; flattened_form.CopyOnWrite()->args = - Array(flattened_iters.rbegin(), flattened_iters.rend()); + ffi::Array(flattened_iters.rbegin(), flattened_iters.rend()); flattened_form.CopyOnWrite()->base = make_const(expr.dtype(), 0); structured_form.CopyOnWrite()->args = - Array(grouped_iters.rbegin(), grouped_iters.rend()); + ffi::Array(grouped_iters.rbegin(), grouped_iters.rend()); structured_form.CopyOnWrite()->base = make_const(expr.dtype(), 0); auto it = sum_fuse_map_.find(flattened_form); if (it != sum_fuse_map_.end()) { @@ -1285,14 +1286,14 @@ struct IterConstraint { // The expr of the iter PrimExpr iter; // The expr of the lower_bound, maybe undefined - Optional lower_bound; + ffi::Optional lower_bound; // The expr of the upper_bound, maybe undefined - Optional upper_bound; + ffi::Optional upper_bound; // The size of the iter, which is the number of nodes size_t expr_size = 0; - IterConstraint(PrimExpr iter, Optional lower_bound, Optional upper_bound, - size_t size) + IterConstraint(PrimExpr iter, ffi::Optional lower_bound, + ffi::Optional upper_bound, size_t size) : iter(std::move(iter)), lower_bound(std::move(lower_bound)), upper_bound(std::move(upper_bound)), @@ -1306,7 +1307,7 @@ struct IterConstraint { * \param result The result of predicate split. * \return A list of IterConstraint, empty if the split failed. 
*/ -bool MatchBoundConstraints(PrimExpr pred, Map* input_iters, +bool MatchBoundConstraints(PrimExpr pred, ffi::Map* input_iters, std::vector* result) { arith::PVar lhs, rhs, rest; for (;;) { @@ -1348,7 +1349,7 @@ bool MatchBoundConstraints(PrimExpr pred, Map* input_iters, // determine iter and bound, if we can not distinguish them simply, // try divide (lhs - rhs) into itervar aware and itervar free parts auto f_use_itervar = [&input_iters](const VarNode* v) { - return input_iters->count(GetRef(v)); + return input_iters->count(ffi::GetRef(v)); }; bool bound_at_left; if (UsesVar(lhs_expr, f_use_itervar) || UsesVar(rhs_expr, f_use_itervar)) { @@ -1381,7 +1382,7 @@ bool MatchBoundConstraints(PrimExpr pred, Map* input_iters, lhs_expr = analyzer.Simplify(lhs_expr); rhs_expr = analyzer.Simplify(rhs_expr); } - Optional lower_bound = std::nullopt, upper_bound = std::nullopt; + ffi::Optional lower_bound = std::nullopt, upper_bound = std::nullopt; PrimExpr iter; if (is_greater) { if (bound_at_left) { @@ -1427,19 +1428,20 @@ bool MatchBoundConstraints(PrimExpr pred, Map* input_iters, return true; } -bool IterRangeSanityCheck(const Map& iter_ranges) { +bool IterRangeSanityCheck(const ffi::Map& iter_ranges) { std::unordered_set iters; for (const auto& it : iter_ranges) iters.insert(it.first); - auto f = [&](const VarNode* var) { return iters.count(GetRef(var)); }; + auto f = [&](const VarNode* var) { return iters.count(ffi::GetRef(var)); }; for (const auto& it : iter_ranges) { if (UsesVar(it.second->min, f) || UsesVar(it.second->extent, f)) return false; } return true; } -IterMapResult DetectIterMap(const Array& indices, const Map& input_iters, - const PrimExpr& predicate, IterMapLevel check_level, - arith::Analyzer* analyzer, bool simplify_trivial_iterators) { +IterMapResult DetectIterMap(const ffi::Array& indices, + const ffi::Map& input_iters, const PrimExpr& predicate, + IterMapLevel check_level, arith::Analyzer* analyzer, + bool simplify_trivial_iterators) { IterMapResult result; // Overall detection algorithm is divided into two steps: @@ -1449,7 +1451,7 @@ IterMapResult DetectIterMap(const Array& indices, const Maperrors.push_back("Invalid iterators. 
Iterators may not be expressions of each other."); return result; } - Map constrained_input_iters = input_iters; + ffi::Map constrained_input_iters = input_iters; std::vector constraints; if (!is_one(predicate) && !MatchBoundConstraints(predicate, &constrained_input_iters, &constraints)) { @@ -1484,7 +1486,7 @@ IterMapResult DetectIterMap(const Array& indices, const Map rewrite_indices; + ffi::Array rewrite_indices; rewrite_indices.reserve(indices.size()); bool allow_padding = check_level != IterMapLevel::Bijective; if (allow_padding) { @@ -1526,7 +1528,7 @@ TVM_FFI_STATIC_INIT_BLOCK({ namespace refl = tvm::ffi::reflection; refl::GlobalDef().def( "arith.DetectIterMap", - [](const Array& indices, const Map& input_iters, + [](const ffi::Array& indices, const ffi::Map& input_iters, const PrimExpr& input_pred, int check_level, bool simplify_trivial_iterators) { arith::Analyzer ana; return DetectIterMap(indices, input_iters, input_pred, IterMapLevel(check_level), &ana, @@ -1534,7 +1536,7 @@ TVM_FFI_STATIC_INIT_BLOCK({ }); }); -IterSumExpr NormalizeToIterSum(PrimExpr index, const Map& input_iters, +IterSumExpr NormalizeToIterSum(PrimExpr index, const ffi::Map& input_iters, arith::Analyzer* analyzer) { IterMapResult result; ICHECK(IterRangeSanityCheck(input_iters)) @@ -1553,14 +1555,14 @@ IterSumExpr NormalizeToIterSum(PrimExpr index, const Map& input_iter TVM_FFI_STATIC_INIT_BLOCK({ namespace refl = tvm::ffi::reflection; refl::GlobalDef().def("arith.NormalizeToIterSum", - [](PrimExpr index, const Map& input_iters) { + [](PrimExpr index, const ffi::Map& input_iters) { arith::Analyzer ana; return NormalizeToIterSum(index, input_iters, &ana); }); }); PrimExpr IterMapRewriter::VisitExpr_(const VarNode* op) { - auto var = GetRef(op); + auto var = ffi::GetRef(op); auto it = var_map_.find(var); if (it != var_map_.end()) return it->second; return var; @@ -1578,7 +1580,7 @@ PrimExpr IterMapRewriter::VisitExpr_(const AddNode* op) { // does not contain iter map. if (!a->IsInstance() && !b->IsInstance()) { if (op->a.same_as(a) && op->b.same_as(b)) { - return GetRef(op); + return ffi::GetRef(op); } else { return Add(a, b); } @@ -1613,7 +1615,7 @@ PrimExpr IterMapRewriter::VisitExpr_(const SubNode* op) { // does not contain iter map. if (!a->IsInstance() && !b->IsInstance()) { if (op->a.same_as(a) && op->b.same_as(b)) { - return GetRef(op); + return ffi::GetRef(op); } else { return Sub(a, b); } @@ -1648,7 +1650,7 @@ PrimExpr IterMapRewriter::VisitExpr_(const MulNode* op) { // does not contain iter map. if (!a->IsInstance() && !b->IsInstance()) { if (op->a.same_as(a) && op->b.same_as(b)) { - return GetRef(op); + return ffi::GetRef(op); } else { return Mul(a, b); } @@ -1657,8 +1659,8 @@ PrimExpr IterMapRewriter::VisitExpr_(const MulNode* op) { if (a->IsInstance() && b->IsInstance()) { // cannot multiply two iterators, mark as unresolved. ErrorLogger(this) << "Product of two iterators cannot be represented as an IterMap, " - << "occurs in " << GetRef(op); - return GetRef(op); + << "occurs in " << ffi::GetRef(op); + return ffi::GetRef(op); } if (!a->IsInstance()) { @@ -1961,7 +1963,7 @@ PrimExpr IterMapRewriter::VisitExpr_(const FloorDivNode* op) { // does not contain iter map. if (!a->IsInstance() && !b->IsInstance()) { if (op->a.same_as(a) && op->b.same_as(b)) { - return GetRef(op); + return ffi::GetRef(op); } else { return FloorDiv(a, b); } @@ -1969,19 +1971,19 @@ PrimExpr IterMapRewriter::VisitExpr_(const FloorDivNode* op) { if (b->IsInstance()) { // cannot divide an iterator, mark as unresolved. 
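
The binary-op visitors in this file (Add, Sub, Mul, and FloorDiv here) all share one early-exit convention that this patch migrates to ffi::GetRef. A minimal sketch of that pattern, with the template argument (stripped in the rendering above) written out; illustrative only, not a hunk of this patch:

    PrimExpr VisitExpr_(const AddNode* op) override {
      PrimExpr a = this->VisitExpr(op->a);
      PrimExpr b = this->VisitExpr(op->b);
      if (op->a.same_as(a) && op->b.same_as(b)) {
        // Nothing changed: wrap the unowned node pointer back into a
        // reference instead of allocating a new node.
        return ffi::GetRef<PrimExpr>(op);
      }
      return Add(a, b);  // operands were rewritten, rebuild the node
    }
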
- ErrorLogger(this) << "Cannot represent as an IterMap: the divisor in " << GetRef(op) - << " may not be an iterator"; - return GetRef(op); + ErrorLogger(this) << "Cannot represent as an IterMap: the divisor in " + << ffi::GetRef(op) << " may not be an iterator"; + return ffi::GetRef(op); } IterSumExpr preprocessed = PreprocessDividend(Downcast(a), op->a); if (!preprocessed.defined()) { - return GetRef(op); + return ffi::GetRef(op); } ICHECK_EQ(preprocessed->args.size(), 1U); PrimExpr remainder = SplitFloorDivConst(preprocessed->args[0], preprocessed->base, b); if (!remainder.defined()) { - return GetRef(op); + return ffi::GetRef(op); } return remainder; } @@ -2045,7 +2047,7 @@ PrimExpr IterMapRewriter::VisitExpr_(const FloorModNode* op) { // does not contain iter map. if (!a->IsInstance() && !b->IsInstance()) { if (op->a.same_as(a) && op->b.same_as(b)) { - return GetRef(op); + return ffi::GetRef(op); } else { return FloorMod(a, b); } @@ -2054,19 +2056,19 @@ PrimExpr IterMapRewriter::VisitExpr_(const FloorModNode* op) { if (b->IsInstance()) { // cannot mod an iterator, mark as unresolved. ErrorLogger(this) << "Cannot represent as an IterMap: the right-hand side of FloorMod in " - << GetRef(op) << " may not be an iterator"; - return GetRef(op); + << ffi::GetRef(op) << " may not be an iterator"; + return ffi::GetRef(op); } IterSumExpr preprocessed = PreprocessDividend(Downcast(a), op->a); if (!preprocessed.defined()) { - return GetRef(op); + return ffi::GetRef(op); } ICHECK_EQ(preprocessed->args.size(), 1U); PrimExpr remainder = SplitFloorModConst(preprocessed->args[0], preprocessed->base, b); if (!remainder.defined()) { - return GetRef(op); + return ffi::GetRef(op); } return remainder; } @@ -2157,13 +2159,14 @@ TVM_FFI_STATIC_INIT_BLOCK({ refl::GlobalDef().def("arith.NormalizeIterMapToExpr", NormalizeIterMapToExpr); }); -Array IterMapSimplify(const Array& indices, const Map& input_iters, - const PrimExpr& input_pred, IterMapLevel check_level, - arith::Analyzer* ana, bool simplify_trivial_iterators) { +ffi::Array IterMapSimplify(const ffi::Array& indices, + const ffi::Map& input_iters, + const PrimExpr& input_pred, IterMapLevel check_level, + arith::Analyzer* ana, bool simplify_trivial_iterators) { if (!IterRangeSanityCheck(input_iters)) return indices; auto res = DetectIterMap(indices, input_iters, input_pred, check_level, ana, /*simplify_trivial_iterators=*/simplify_trivial_iterators); - Array rewrite = res->indices; + ffi::Array rewrite = res->indices; if (rewrite.empty() && !is_one(input_pred) && check_level != IterMapLevel::Bijective) { // The input predicate may cause detect iter map to fail @@ -2177,7 +2180,7 @@ Array IterMapSimplify(const Array& indices, const Map simplified; + ffi::Array simplified; simplified.reserve(rewrite.size()); IterMapToExprNormalizer converter(ana); for (const auto& expr : rewrite) simplified.push_back(converter.Convert(expr)); @@ -2188,7 +2191,7 @@ TVM_FFI_STATIC_INIT_BLOCK({ namespace refl = tvm::ffi::reflection; refl::GlobalDef().def( "arith.IterMapSimplify", - [](const Array& indices, const Map& input_iters, + [](const ffi::Array& indices, const ffi::Map& input_iters, const PrimExpr& input_pred, int check_level, bool simplify_trivial_iterators) { arith::Analyzer ana; return IterMapSimplify(indices, input_iters, input_pred, IterMapLevel(check_level), &ana, @@ -2384,7 +2387,7 @@ class SubspaceDivider { extent *= arg->extent; res.push_back(arg); } - return IterMark(IterSumExpr(Array(res.rbegin(), res.rend()), base), extent); + return 
IterMark(IterSumExpr(ffi::Array(res.rbegin(), res.rend()), base), extent); } DivisionResult DivideIterSplitExpr(const IterSplitExpr& expr) { @@ -2394,7 +2397,7 @@ class SubspaceDivider { // encounter one of them. If we encounter another later, we directly return the record. return it->second; } - const Array& splits = collector_.mark2splits_.at(expr->source); + const ffi::Array& splits = collector_.mark2splits_.at(expr->source); if (auto iter_ptr = expr->source->source.as()) { // source is input_iter bool inner = sub_iters_.count(iter_ptr.value()); @@ -2487,15 +2490,16 @@ class SubspaceDivider { PrimExpr outer_preds_{Bool(true)}, inner_preds_{Bool(true)}; }; -Array> SubspaceDivide(const Array& bindings, - const Map& input_iters, - const Array& sub_iters, const PrimExpr& predicate, - IterMapLevel check_level, arith::Analyzer* analyzer, - bool simplify_trivial_iterators) { - if (!IterRangeSanityCheck(input_iters)) return Array>(); +ffi::Array> SubspaceDivide(const ffi::Array& bindings, + const ffi::Map& input_iters, + const ffi::Array& sub_iters, + const PrimExpr& predicate, IterMapLevel check_level, + arith::Analyzer* analyzer, + bool simplify_trivial_iterators) { + if (!IterRangeSanityCheck(input_iters)) return ffi::Array>(); auto res = DetectIterMap(bindings, input_iters, predicate, check_level, analyzer, simplify_trivial_iterators); - const Array& maps = res->indices; + const ffi::Array& maps = res->indices; if (maps.empty()) return {}; std::unordered_set inner_iter_set; @@ -2507,7 +2511,7 @@ Array> SubspaceDivide(const Array& bindings, collector.Collect(maps); SubspaceDivider subspace_divider(analyzer, collector, inner_iter_set); - std::vector> results; + std::vector> results; for (const IterSumExpr& expr : maps) { SubspaceDivider::DivisionResult res = subspace_divider.DivideIterSumExpr(expr, 0); if (subspace_divider.unresolved_count()) return {}; @@ -2523,9 +2527,10 @@ Array> SubspaceDivide(const Array& bindings, TVM_FFI_STATIC_INIT_BLOCK({ namespace refl = tvm::ffi::reflection; refl::GlobalDef().def( - "arith.SubspaceDivide", [](const Array& bindings, const Map& root_iters, - const Array& sub_iters, const PrimExpr& predicate, - int check_level, bool simplify_trivial_iterators) { + "arith.SubspaceDivide", + [](const ffi::Array& bindings, const ffi::Map& root_iters, + const ffi::Array& sub_iters, const PrimExpr& predicate, int check_level, + bool simplify_trivial_iterators) { arith::Analyzer ana; return SubspaceDivide(bindings, root_iters, sub_iters, predicate, IterMapLevel(check_level), &ana, simplify_trivial_iterators); @@ -2536,14 +2541,14 @@ class InverseAffineIterMapTransformer { public: explicit InverseAffineIterMapTransformer(Analyzer* analyzer) : analyzer_(analyzer) {} - Map operator()(const Array& iter_map, - const Array& outputs) { + ffi::Map operator()(const ffi::Array& iter_map, + const ffi::Array& outputs) { ICHECK(iter_map.size() == outputs.size()); std::vector post_dfs_order = ReverseTopologyOrder(iter_map); // initialize back propagation accumulator for (const IterMapExprNode* node : post_dfs_order) { - backprop_.Set(GetRef(node), Integer(0)); + backprop_.Set(ffi::GetRef(node), Integer(0)); } for (size_t i = 0; i < iter_map.size(); i++) { backprop_.Set(iter_map[i], outputs[i]); @@ -2552,10 +2557,10 @@ class InverseAffineIterMapTransformer { // run back propagation for (const IterMapExprNode* node : post_dfs_order) { if (node->IsInstance()) { - Visit_(Downcast(GetRef(node))); + Visit_(Downcast(ffi::GetRef(node))); } else { ICHECK(node->IsInstance()); - 
Visit_(Downcast(GetRef(node))); + Visit_(Downcast(ffi::GetRef(node))); } } return std::move(inverse_); } @@ -2591,7 +2596,8 @@ class InverseAffineIterMapTransformer { } } - std::vector ReverseTopologyOrder(const Array& iter_map) { + std::vector ReverseTopologyOrder( + const ffi::Array& iter_map) { std::vector post_dfs_order; std::unordered_map visited; @@ -2652,12 +2658,12 @@ class InverseAffineIterMapTransformer { } Analyzer* analyzer_; - Map backprop_; // the accumulator of backpropgation - Map inverse_; // the result of inverse transformation + ffi::Map backprop_; // the accumulator of backpropagation + ffi::Map inverse_; // the result of inverse transformation }; -Map InverseAffineIterMap(const Array& iter_map, - const Array outputs) { +ffi::Map InverseAffineIterMap(const ffi::Array& iter_map, + const ffi::Array outputs) { Analyzer analyzer; return InverseAffineIterMapTransformer(&analyzer)(iter_map, outputs); } diff --git a/src/arith/modular_set.cc b/src/arith/modular_set.cc index fc082907a6d2..1c8d1ba8b4d8 100644 --- a/src/arith/modular_set.cc +++ b/src/arith/modular_set.cc @@ -42,7 +42,7 @@ using namespace tir; TVM_FFI_STATIC_INIT_BLOCK({ ModularSetNode::RegisterReflection(); }); ModularSet::ModularSet(int64_t coeff, int64_t base) { - auto node = make_object(); + auto node = ffi::make_object(); node->coeff = coeff; node->base = base; // finish construction. @@ -273,7 +273,7 @@ class ModularSetAnalyzer::Impl : public ExprFunctor(op); + Var v = ffi::GetRef(op); auto it = var_map_.find(v); if (it != var_map_.end()) { return it->second; diff --git a/src/arith/narrow_predicate_expression.cc b/src/arith/narrow_predicate_expression.cc index d339b728db2c..c608de6b2c45 100644 --- a/src/arith/narrow_predicate_expression.cc +++ b/src/arith/narrow_predicate_expression.cc @@ -50,14 +50,14 @@ using namespace tir; // with free parameters, and the range of those parameters.
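
The ModularSet constructor in the hunk above is the canonical object-construction idiom this patch moves onto the ffi allocator. With the template argument (elided by the rendering) restored, the full pattern reads as this illustrative sketch:

    ModularSet::ModularSet(int64_t coeff, int64_t base) {
      // Allocate the node through ffi::make_object, fill in its fields,
      // then move ownership into the ObjectRef.
      auto node = ffi::make_object<ModularSetNode>();
      node->coeff = coeff;
      node->base = base;
      data_ = std::move(node);
    }
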
class ExpressionNarrower : public tir::ExprMutator { public: - static PrimExpr Apply(PrimExpr expr, Map free_parameters) { + static PrimExpr Apply(PrimExpr expr, ffi::Map free_parameters) { ICHECK(expr.dtype().is_bool()) << "Expected boolean expression, but received " << expr; ExpressionNarrower mutator(free_parameters); return mutator(expr); } private: - explicit ExpressionNarrower(Map free_parameters) + explicit ExpressionNarrower(ffi::Map free_parameters) : free_parameters_(free_parameters) {} using Parent = tir::ExprMutator; @@ -111,22 +111,22 @@ class ExpressionNarrower : public tir::ExprMutator { PrimExpr VisitExpr_(const GTNode* op) override { auto current = CurrentContext(); - return VisitInequality(GetRef(op), OppositeContext(current), current); + return VisitInequality(ffi::GetRef(op), OppositeContext(current), current); } PrimExpr VisitExpr_(const GENode* op) override { auto current = CurrentContext(); - return VisitInequality(GetRef(op), OppositeContext(current), current); + return VisitInequality(ffi::GetRef(op), OppositeContext(current), current); } PrimExpr VisitExpr_(const LTNode* op) override { auto current = CurrentContext(); - return VisitInequality(GetRef(op), current, OppositeContext(current)); + return VisitInequality(ffi::GetRef(op), current, OppositeContext(current)); } PrimExpr VisitExpr_(const LENode* op) override { auto current = CurrentContext(); - return VisitInequality(GetRef(op), current, OppositeContext(current)); + return VisitInequality(ffi::GetRef(op), current, OppositeContext(current)); } PrimExpr VisitExpr_(const EQNode* op) override { @@ -143,7 +143,7 @@ class ExpressionNarrower : public tir::ExprMutator { PrimExpr VisitExpr_(const SubNode* op) override { auto current = CurrentContext(); - return VisitInequality(GetRef(op), current, OppositeContext(current)); + return VisitInequality(ffi::GetRef(op), current, OppositeContext(current)); } PrimExpr VisitExpr_(const NotNode* op) override { @@ -154,11 +154,11 @@ class ExpressionNarrower : public tir::ExprMutator { PrimExpr VisitExpr_(const BufferLoadNode* op) override { contains_unknown_expr_ = true; - return GetRef(op); + return ffi::GetRef(op); } PrimExpr VisitExpr_(const VarNode* op) override { - auto it = free_parameters_.find(GetRef(op)); + auto it = free_parameters_.find(ffi::GetRef(op)); if (it == free_parameters_.end()) { return Parent::VisitExpr_(op); } @@ -206,11 +206,11 @@ class ExpressionNarrower : public tir::ExprMutator { }; std::vector context_stack_; - Map free_parameters_; + ffi::Map free_parameters_; bool contains_unknown_expr_{false}; }; -PrimExpr NarrowPredicateExpression(PrimExpr expr, Map free_parameters) { +PrimExpr NarrowPredicateExpression(PrimExpr expr, ffi::Map free_parameters) { return ExpressionNarrower::Apply(std::move(expr), std::move(free_parameters)); } diff --git a/src/arith/narrow_predicate_expression.h b/src/arith/narrow_predicate_expression.h index 1e452e3ad493..42a7c2cf038f 100644 --- a/src/arith/narrow_predicate_expression.h +++ b/src/arith/narrow_predicate_expression.h @@ -50,7 +50,7 @@ namespace arith { * \returns An expression that, if true, implies that the original * expression is also true. 
*/ -PrimExpr NarrowPredicateExpression(PrimExpr expr, Map free_parameters); +PrimExpr NarrowPredicateExpression(PrimExpr expr, ffi::Map free_parameters); } // namespace arith } // namespace tvm diff --git a/src/arith/pattern_match.h b/src/arith/pattern_match.h index 98cf61990d90..7c498d7a9c90 100644 --- a/src/arith/pattern_match.h +++ b/src/arith/pattern_match.h @@ -214,7 +214,7 @@ class PVar : public Pattern> { typename = typename std::enable_if::value>::type> bool Match_(const NodeRefType& value) const { if (const auto* ptr = value.template as()) { - return Match_(GetRef(ptr)); + return Match_(ffi::GetRef(ptr)); } else { return false; } @@ -257,7 +257,7 @@ class PVarWithCheck : public arith::Pattern> { typename = typename std::enable_if::value>::type> bool Match_(const NodeRefType& value) const { if (const auto* ptr = value.template as()) { - return Match_(GetRef(ptr)); + return Match_(ffi::GetRef(ptr)); } else { return false; } @@ -727,7 +727,7 @@ struct PCallExprMatchFunctor { }; struct PCallExprEvalArgsFunctor { - Array args_; + ffi::Array args_; template void operator()(size_t i, const T& pattern) { @@ -778,7 +778,7 @@ class PCallExpr : public Pattern> { // arithemetic intrinsics #define TVM_PATTERN_BINARY_INTRIN(FuncName, OpName, IntrinOpName) \ struct OpName { \ - static PrimExpr Eval(Array args) { \ + static PrimExpr Eval(ffi::Array args) { \ return tir::Call(args[0].dtype(), GetOp(), args); \ } \ static const Op& GetOp() { return tir::builtin::IntrinOpName(); } \ @@ -797,7 +797,7 @@ TVM_PATTERN_BINARY_INTRIN(operator^, PBitwiseXorOp, bitwise_xor); // unary intrinsics #define TVM_PATTERN_UNARY_INTRIN(FuncName, OpName, IntrinOpName) \ struct OpName { \ - static PrimExpr Eval(Array args) { \ + static PrimExpr Eval(ffi::Array args) { \ return tir::Call(args[0].dtype(), GetOp(), args); \ } \ static const Op& GetOp() { return tir::builtin::IntrinOpName(); } \ @@ -811,7 +811,9 @@ TVM_PATTERN_UNARY_INTRIN(operator~, PBitwiseNotOp, bitwise_not); // if_then_else struct PIfThenElseOp { - static PrimExpr Eval(Array args) { return tir::Call(args[1].dtype(), GetOp(), args); } + static PrimExpr Eval(ffi::Array args) { + return tir::Call(args[1].dtype(), GetOp(), args); + } static const Op& GetOp() { return tir::builtin::if_then_else(); } }; diff --git a/src/arith/presburger_set.cc b/src/arith/presburger_set.cc index 5674cf4f65bf..8f2edb0c1360 100644 --- a/src/arith/presburger_set.cc +++ b/src/arith/presburger_set.cc @@ -92,10 +92,10 @@ static void Update(const PrimExpr& constraint, PresburgerSetNode* intset) { } PresburgerSet::PresburgerSet(const PrimExpr& constraint) { - Array vars; + ffi::Array vars; PostOrderVisit(constraint, [&vars](const ObjectRef& obj) { if (const VarNode* new_var = obj.as()) { - auto var = GetRef(new_var); + auto var = ffi::GetRef(new_var); if (!std::any_of(vars.begin(), vars.end(), [&var](const Var& v) { return v.same_as(var); })) { vars.push_back(var); } @@ -105,19 +105,19 @@ PresburgerSet::PresburgerSet(const PrimExpr& constraint) { Analyzer analyzer; PrimExpr simplified_constraint = analyzer.Simplify(constraint, kSimplifyRewriteCanonicalRewrite); auto space = PresburgerSpace::getRelationSpace(vars.size(), 0, 0, 0); - auto node = make_object(std::move(space), vars); + auto node = ffi::make_object(std::move(space), vars); node->SetVars(vars); Update(simplified_constraint, node.get()); data_ = std::move(node); } PresburgerSet::PresburgerSet(const std::vector& disjuncts, - const Array& vars) { - auto node = make_object(disjuncts, disjuncts[0].getSpace(), vars); + const 
ffi::Array& vars) { + auto node = ffi::make_object(disjuncts, disjuncts[0].getSpace(), vars); data_ = std::move(node); } -void PresburgerSetNode::UpdateConstraint(const PrimExpr& constraint, const Array& vars) { +void PresburgerSetNode::UpdateConstraint(const PrimExpr& constraint, const ffi::Array& vars) { Analyzer analyzer; PrimExpr simplified_constraint = analyzer.Simplify(constraint, kSimplifyRewriteCanonicalRewrite); Update(simplified_constraint, this); @@ -186,7 +186,7 @@ PrimExpr PresburgerSetNode::GenerateConstraint() const { return constraint; } -PresburgerSet Union(const Array& sets) { +PresburgerSet Union(const ffi::Array& sets) { CHECK_GT(sets.size(), 0); if (sets.size() == 1) return sets[0]; auto relations = sets[0]->disjuncts; @@ -198,7 +198,7 @@ PresburgerSet Union(const Array& sets) { return PresburgerSet(std::move(relations), sets[0]->GetVars()); } -PresburgerSet Intersect(const Array& sets) { +PresburgerSet Intersect(const ffi::Array& sets) { CHECK_GT(sets.size(), 0); if (sets.size() == 1) return sets[0]; auto relations = sets[0]->disjuncts; @@ -217,7 +217,7 @@ PresburgerSet Intersect(const Array& sets) { } IntSet EvalSet(const PrimExpr& e, const PresburgerSet& set) { - Array tvm_coeffs = DetectLinearEquation(e, set->GetVars()); + ffi::Array tvm_coeffs = DetectLinearEquation(e, set->GetVars()); #if TVM_MLIR_VERSION >= 190 SmallVector coeffs; #elif TVM_MLIR_VERSION >= 160 diff --git a/src/arith/presburger_set.h b/src/arith/presburger_set.h index 3a7114048f92..6996d6188316 100644 --- a/src/arith/presburger_set.h +++ b/src/arith/presburger_set.h @@ -60,10 +60,10 @@ using namespace presburger; class PresburgerSetNode : public IntSetNode { public: PresburgerSetNode() : space(PresburgerSpace::getRelationSpace()) {} - explicit PresburgerSetNode(const PresburgerSpace& space, const Array& vars) + explicit PresburgerSetNode(const PresburgerSpace& space, const ffi::Array& vars) : disjuncts({}), space(space), vars(vars) {} explicit PresburgerSetNode(const std::vector& disjuncts, - const PresburgerSpace& space, const Array& vars) + const PresburgerSpace& space, const ffi::Array& vars) : disjuncts(disjuncts), space(space), vars(vars) {} /*! \brief Represent the union of multiple IntegerRelation */ @@ -91,7 +91,7 @@ class PresburgerSetNode : public IntSetNode { * \param constraint The added constraint to the PresburgerSet. * \param vars The specified domain vars in constraint expression. */ - void UpdateConstraint(const PrimExpr& constraint, const Array& vars); + void UpdateConstraint(const PrimExpr& constraint, const ffi::Array& vars); /*! * \brief Generate expression that represents the constraint @@ -103,13 +103,13 @@ class PresburgerSetNode : public IntSetNode { * \brief Set domain vars * \param new_vars Vars that will be taken as the domain vars */ - void SetVars(const Array& new_vars) { vars = new_vars; } + void SetVars(const ffi::Array& new_vars) { vars = new_vars; } /*! * \brief Get the current domain vars * \return The current doamin vars */ - Array GetVars() const { return vars; } + ffi::Array GetVars() const { return vars; } /*! \return whether integer set is empty */ bool IsEmpty() const { @@ -121,7 +121,7 @@ class PresburgerSetNode : public IntSetNode { TVM_DECLARE_FINAL_OBJECT_INFO(PresburgerSetNode, IntSetNode); private: - Array vars; + ffi::Array vars; }; /*! @@ -136,7 +136,7 @@ class PresburgerSet : public IntSet { * \param vars The variables that the constraint describes about. * \return The created PresburgerSet. 
*/ - TVM_DLL PresburgerSet(const std::vector& disjuncts, const Array& vars); + TVM_DLL PresburgerSet(const std::vector& disjuncts, const ffi::Array& vars); /*! * \brief Make a new instance of PresburgerSet, collect all vars as space vars. @@ -178,14 +178,14 @@ class PresburgerSet : public IntSet { * \param sets The sets to be combined * \return the set after union */ -PresburgerSet Union(const Array& sets); +PresburgerSet Union(const ffi::Array& sets); /*! * \brief Create an intersected set of all sets * \param sets The sets to be intersected * \return The intersect set */ -PresburgerSet Intersect(const Array& sets); +PresburgerSet Intersect(const ffi::Array& sets); /*! * \brief Evaluate the range of given expression based on the constraint diff --git a/src/arith/rewrite_simplify.cc b/src/arith/rewrite_simplify.cc index 66720a579233..9ed30a9de0cd 100644 --- a/src/arith/rewrite_simplify.cc +++ b/src/arith/rewrite_simplify.cc @@ -1652,7 +1652,8 @@ PrimExpr RewriteSimplifier::Impl::VisitExpr_(const MaxNode* op) { return ret; } -Optional RewriteSimplifier::Impl::TryMatchLiteralConstraint(const PrimExpr& expr) const { +ffi::Optional RewriteSimplifier::Impl::TryMatchLiteralConstraint( + const PrimExpr& expr) const { PrimExpr negation = Not(expr); ExprDeepEqual expr_equal; @@ -1946,7 +1947,7 @@ PrimExpr RewriteSimplifier::Impl::ApplyRewriteRules(LT ret) { TVM_TRY_RECURSIVE_REWRITE(x < c1 + y, x - y < c1); TVM_TRY_RECURSIVE_REWRITE(c1 + y < x, c1 < x - y); - auto merge_constants = [&]() -> Optional { + auto merge_constants = [&]() -> ffi::Optional { auto [lhs, lhs_offset] = ExtractConstantOffset(ret->a); auto [rhs, rhs_offset] = ExtractConstantOffset(ret->b); if (lhs_offset == 0 && rhs_offset == 0) { @@ -2051,7 +2052,7 @@ PrimExpr RewriteSimplifier::Impl::VisitExpr_(const AndNode* op) { // Otherwise, follow ExprMutator's convention of returning the // original object. if (a.same_as(op->a) && b.same_as(op->b)) { - return GetRef(op); + return ffi::GetRef(op); } else { return And(a, b); } @@ -2160,7 +2161,7 @@ PrimExpr RewriteSimplifier::Impl::VisitExpr_(const AndNode* op) { } PrimExpr RewriteSimplifier::Impl::VisitExpr_(const OrNode* op) { - PrimExpr orig = GetRef(op); + PrimExpr orig = ffi::GetRef(op); PrimExpr ret = [&]() -> PrimExpr { // If this extension isn't enabled, just delegate out. @@ -2200,7 +2201,7 @@ PrimExpr RewriteSimplifier::Impl::VisitExpr_(const OrNode* op) { // Otherwise, follow ExprMutator's convention of returning the // original object. 
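
TryMatchLiteralConstraint now returns ffi::Optional<PrimExpr> (see the rewrite_simplify.h hunk below). Call sites in this file lean on the optional's boolean test, as in this sketch of the pattern the VarNode visitor uses a few hunks down:

    if (auto match = TryMatchLiteralConstraint(expr)) {
      // A literal constraint in scope pins the expression's value.
      return match.value();
    }
    // std::nullopt: no constraint matched, keep simplifying expr.
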
if (a.same_as(op->a) && b.same_as(op->b)) { - return GetRef(op); + return ffi::GetRef(op); } else { return Or(a, b); } @@ -2350,7 +2351,7 @@ PrimExpr RewriteSimplifier::Impl::VisitExpr_(const CallNode* op) { } PrimExpr RewriteSimplifier::Impl::VisitExpr_(const VarNode* op) { - Var var = GetRef(op); + Var var = ffi::GetRef(op); if (op->dtype == DataType::Bool()) { if (auto match = TryMatchLiteralConstraint(var)) { return match.value(); @@ -2361,7 +2362,7 @@ PrimExpr RewriteSimplifier::Impl::VisitExpr_(const VarNode* op) { if (it != var_map_.end()) { return it->second; } - return GetRef(op); + return ffi::GetRef(op); } PrimExpr RewriteSimplifier::Impl::VisitExpr_(const CastNode* op) { @@ -2388,7 +2389,7 @@ PrimExpr RewriteSimplifier::Impl::VisitExpr_(const LetNode* op) { } PrimExpr body = this->VisitExpr(op->body); if (value.same_as(op->value) && body.same_as(op->body)) { - return GetRef(op); + return ffi::GetRef(op); } else { return Let(op->var, value, body); } diff --git a/src/arith/rewrite_simplify.h b/src/arith/rewrite_simplify.h index b4bd799a2933..8e43da636506 100644 --- a/src/arith/rewrite_simplify.h +++ b/src/arith/rewrite_simplify.h @@ -71,7 +71,7 @@ struct RewriteSimplifierStatsNode : Object { struct RewriteSimplifierStats : ObjectRef { explicit RewriteSimplifierStats(RewriteSimplifierStatsNode data) { - data_ = make_object(data); + data_ = ffi::make_object(data); } TVM_DEFINE_OBJECT_REF_METHODS(RewriteSimplifierStats, ObjectRef, RewriteSimplifierStatsNode); @@ -193,7 +193,7 @@ class RewriteSimplifier::Impl : public IRMutatorWithAnalyzer { * matches a constraint, return the boolean it should be replaced * with. Otherwise, return false. */ - Optional TryMatchLiteralConstraint(const PrimExpr& expr) const; + ffi::Optional TryMatchLiteralConstraint(const PrimExpr& expr) const; /*! \brief Rewrite rules for Less Than comparisons * diff --git a/src/arith/scalable_expression.cc b/src/arith/scalable_expression.cc index 1937b9c34e03..5c968966e2f0 100644 --- a/src/arith/scalable_expression.cc +++ b/src/arith/scalable_expression.cc @@ -86,7 +86,7 @@ bool CanProveVscaleExpressionFromKnownValues(arith::Analyzer* analyzer, const Pr return can_prove_expr; } -bool TargetHasVLA(Optional target) { +bool TargetHasVLA(ffi::Optional target) { if (!target.defined()) { target = Target::Current(); } @@ -102,7 +102,7 @@ bool TargetHasVLA(Optional target) { return has_vla; } -const std::vector GetVScaleValues(Optional target) { +const std::vector GetVScaleValues(ffi::Optional target) { unsigned int vector_width = 0; std::vector kVScaleValues; if (!target.defined()) { diff --git a/src/arith/scalable_expression.h b/src/arith/scalable_expression.h index 2470d5dcd827..88c140288734 100644 --- a/src/arith/scalable_expression.h +++ b/src/arith/scalable_expression.h @@ -81,14 +81,14 @@ bool CanProveVscaleExpressionFromKnownValues(arith::Analyzer* analyzer, const Pr * \param target The target to check. * \return Whether VLA is supported */ -bool TargetHasVLA(Optional target = std::nullopt); +bool TargetHasVLA(ffi::Optional target = std::nullopt); /*! * \brief Get a list of known vscale values to try for an VLA target. * \param target The target to check. 
* \return A list of vscale values as std::vector */ -const std::vector GetVScaleValues(Optional target = std::nullopt); +const std::vector GetVScaleValues(ffi::Optional target = std::nullopt); } // namespace arith } // namespace tvm diff --git a/src/arith/solve_linear_equation.cc b/src/arith/solve_linear_equation.cc index 5d1f102a5b7e..2e1b725f83c5 100644 --- a/src/arith/solve_linear_equation.cc +++ b/src/arith/solve_linear_equation.cc @@ -209,10 +209,11 @@ void SmithNormalFormDiag(std::vector>* S, std::vector InferRange(const Map& vars_to_infer, const Array& ori_vars, - const Map& ori_ranges) { +ffi::Map InferRange(const ffi::Map& vars_to_infer, + const ffi::Array& ori_vars, + const ffi::Map& ori_ranges) { // The resulting ranges - Map new_ranges; + ffi::Map new_ranges; std::unordered_set ori_vset; for (const Var& v : ori_vars) { @@ -260,7 +261,7 @@ void DebugPrint(const std::vector>& S, } std::cout << "\n"; } - std::cout << "V_inv x:\n" << Array(V_inv_x); + std::cout << "V_inv x:\n" << ffi::Array(V_inv_x); std::cout << "\n" << std::endl; } @@ -298,8 +299,8 @@ IntConstraintsTransform SolveLinearEquations(const IntConstraints& system_to_sol for (const PrimExpr& equation : system_to_solve->relations) { if (const tir::EQNode* eq = equation.as()) { // a-b = sum_{i=0}^{n-1} variables[i] * coeff[i] + coeff[n] - Array coeffs = arith::DetectLinearEquation(analyzer_problem.Simplify(eq->a - eq->b), - system_to_solve->variables); + ffi::Array coeffs = arith::DetectLinearEquation( + analyzer_problem.Simplify(eq->a - eq->b), system_to_solve->variables); if (!coeffs.empty()) { std::vector row; for (size_t j = 0; j < coeffs.size() - 1; ++j) { @@ -337,10 +338,10 @@ IntConstraintsTransform SolveLinearEquations(const IntConstraints& system_to_sol // Uy is U \times y SmithNormalFormDiag(&S, &V, &V_inv_x, &Uy); - Array new_vars; - Array new_relations; - Map new_to_old_map; - Map old_to_new_map; + ffi::Array new_vars; + ffi::Array new_relations; + ffi::Map new_to_old_map; + ffi::Map old_to_new_map; // Simplify right hand sides for (PrimExpr r : Uy) { @@ -372,7 +373,7 @@ IntConstraintsTransform SolveLinearEquations(const IntConstraints& system_to_sol } } - Array solution_for_V_inv_x; + ffi::Array solution_for_V_inv_x; // Now create new variables or directly solve the equations // suppose the rank of A is r, aka r = # of non-zeros in S // the solution of S_{mxn} V^{-1}_{nxn} x_{nx1} = U b @@ -421,7 +422,7 @@ IntConstraintsTransform SolveLinearEquations(const IntConstraints& system_to_sol } // The resulting ranges - Map new_ranges = + ffi::Map new_ranges = InferRange(new_to_old_map, system_to_solve->variables, system_to_solve->ranges); Analyzer analyzer_solution; analyzer_solution.Bind(new_ranges); @@ -462,9 +463,9 @@ TVM_FFI_STATIC_INIT_BLOCK({ if (args.size() == 1) { *ret = SolveLinearEquations(args[0].cast()); } else if (args.size() == 3) { - auto opt_vars = args[0].cast>>(); - auto opt_map = args[1].cast>>(); - auto opt_relations = args[2].cast>>(); + auto opt_vars = args[0].cast>>(); + auto opt_map = args[1].cast>>(); + auto opt_relations = args[2].cast>>(); IntConstraints problem(opt_vars.value_or({}), opt_map.value_or({}), opt_relations.value_or({})); *ret = SolveLinearEquations(problem); diff --git a/src/arith/solve_linear_inequality.cc b/src/arith/solve_linear_inequality.cc index bf50a0ea52ec..bbca4ccbd97e 100644 --- a/src/arith/solve_linear_inequality.cc +++ b/src/arith/solve_linear_inequality.cc @@ -133,7 +133,7 @@ void ClassifyByPolarity(const Var& var, const std::vector& current_ine // and store to 
coef_pos and coef_neg respectively. for (const PrimExpr& ineq : current_ineq_set) { if (const LENode* le = ineq.as()) { - Array coef = arith::DetectLinearEquation(le->a, {var}); + ffi::Array coef = arith::DetectLinearEquation(le->a, {var}); if (!coef.empty() && is_const_int(coef[0])) { int64_t coef0 = *as_const_int(coef[0]); if (coef0 == 0) { @@ -147,7 +147,7 @@ void ClassifyByPolarity(const Var& var, const std::vector& current_ine continue; } } else if (const EQNode* eq = ineq.as()) { - Array coef = arith::DetectLinearEquation(eq->a, {var}); + ffi::Array coef = arith::DetectLinearEquation(eq->a, {var}); if (!coef.empty() && is_const_int(coef[0])) { int64_t coef0 = *as_const_int(coef[0]); if (coef0 == 0) { @@ -218,7 +218,7 @@ PartialSolvedInequalities SolveLinearInequalities(const IntConstraints& system_t &analyzer); } - Map res_bounds; + ffi::Map res_bounds; for (const Var& v : system_to_solve->variables) { ICHECK(!res_bounds.count(v)) << "Variable " << v @@ -329,16 +329,16 @@ PartialSolvedInequalities SolveLinearInequalities(const IntConstraints& system_t // Write it to the result. IntGroupBounds bnds(make_const(v.dtype(), coef_lcm), - Array(lower_bounds.begin(), lower_bounds.end()), - Array(equal_list.begin(), equal_list.end()), - Array(upper_bounds.begin(), upper_bounds.end())); + ffi::Array(lower_bounds.begin(), lower_bounds.end()), + ffi::Array(equal_list.begin(), equal_list.end()), + ffi::Array(upper_bounds.begin(), upper_bounds.end())); res_bounds.Set(v, bnds); std::swap(current_ineq_set_to_solve, next_ineq_set_to_solve); } // Everything that is left goes to res.relations - Array other_conditions; + ffi::Array other_conditions; for (const PrimExpr& e : current_ineq_set_to_solve) { PrimExpr e_simp = analyzer.Simplify(e, kSimplifyRewriteCanonicalRewrite); if (is_const_int(e_simp, 0)) { @@ -366,17 +366,17 @@ IntConstraints SolveInequalitiesToRange(const IntConstraints& inequalities) { // Resulting ranges will contain ranges for the new variables and for the variables that are // not in the inequalities->variables but are in inequalities->ranges // It will be useful when solving Jacobian axes jac_xxx) - Map res_ranges; + ffi::Map res_ranges; // we get a set of equality, lower, upper bound of each variable. auto solved_system = SolveLinearInequalities(inequalities); - Map solved_bounds = solved_system.first; - Array solved_other_relations = solved_system.second; + ffi::Map solved_bounds = solved_system.first; + ffi::Array solved_other_relations = solved_system.second; - Array res_relations; + ffi::Array res_relations; // this keeps being updated during determining the range of each variable. - Map vranges; + ffi::Map vranges; for (std::pair vr : inequalities->ranges) { vranges.Set(vr.first, vr.second); } @@ -441,21 +441,21 @@ IntConstraints SolveInequalitiesToRange(const IntConstraints& inequalities) { IntConstraintsTransform SolveInequalitiesDeskewRange(const IntConstraints& inequalities) { // Resulting ranges will contain ranges for the new variables and for the variables that are // not in the inequalities->variables but are in inequalities->ranges (jac_xxx) - Map res_ranges; + ffi::Map res_ranges; // we get a set of equality, lower, upper bound of each variable. 
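
Both solver entry points unpack the same PartialSolvedInequalities pair; with the template arguments (elided by the rendering) restored, the shape is roughly this sketch:

    // First half: per-variable bounds; second half: relations that could
    // not be folded into any single variable's bound.
    auto solved_system = SolveLinearInequalities(inequalities);
    ffi::Map<Var, IntGroupBounds> solved_bounds = solved_system.first;
    ffi::Array<PrimExpr> solved_other_relations = solved_system.second;
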
auto solved_system = SolveLinearInequalities(inequalities); - Map solved_bounds = solved_system.first; - Array solved_other_relations = solved_system.second; + ffi::Map solved_bounds = solved_system.first; + ffi::Array solved_other_relations = solved_system.second; arith::Analyzer analyzer; - Map res_src_to_dst; - Map res_dst_to_src; - Array res_variables; - Array res_relations; + ffi::Map res_src_to_dst; + ffi::Map res_dst_to_src; + ffi::Array res_variables; + ffi::Array res_relations; // this keeps being updated during determining the range of each variable. - Map vranges; + ffi::Map vranges; for (std::pair vr : inequalities->ranges) { vranges.Set(vr.first, vr.second); } @@ -528,7 +528,7 @@ IntConstraintsTransform SolveInequalitiesDeskewRange(const IntConstraints& inequ } // Reverse the axis so that it matches the order of the original variables - res_variables = Array(res_variables.rbegin(), res_variables.rend()); + res_variables = ffi::Array(res_variables.rbegin(), res_variables.rend()); IntConstraints new_inequalities(res_variables, res_ranges, res_relations); IntConstraintsTransform transform(inequalities, new_inequalities, res_src_to_dst, res_dst_to_src); @@ -548,8 +548,9 @@ TVM_FFI_STATIC_INIT_BLOCK({ problem = args[0].cast(); ret_ineq = SolveLinearInequalities(problem); } else if (args.size() == 3) { - problem = IntConstraints(args[0].cast>(), args[1].cast>(), - args[2].cast>()); + problem = IntConstraints(args[0].cast>(), + args[1].cast>(), + args[2].cast>()); ret_ineq = SolveLinearInequalities(problem); } else { LOG(FATAL) << "arith.SolveInequalitiesAsCondition expects 1 or 3 arguments, gets " @@ -562,9 +563,9 @@ TVM_FFI_STATIC_INIT_BLOCK({ if (args.size() == 1) { *ret = SolveInequalitiesToRange(args[0].cast()); } else if (args.size() == 3) { - auto opt_map = args[1].cast>>(); - IntConstraints problem(args[0].cast>(), opt_map.value_or({}), - args[2].cast>()); + auto opt_map = args[1].cast>>(); + IntConstraints problem(args[0].cast>(), opt_map.value_or({}), + args[2].cast>()); *ret = SolveInequalitiesToRange(problem); } else { LOG(FATAL) << "arith.SolveInequalitiesToRange expects 1 or 3 arguments, gets " @@ -575,9 +576,9 @@ TVM_FFI_STATIC_INIT_BLOCK({ if (args.size() == 1) { *ret = SolveInequalitiesDeskewRange(args[0].cast()); } else if (args.size() == 3) { - auto opt_map = args[1].cast>>(); - IntConstraints problem(args[0].cast>(), opt_map.value_or({}), - args[2].cast>()); + auto opt_map = args[1].cast>>(); + IntConstraints problem(args[0].cast>(), opt_map.value_or({}), + args[2].cast>()); *ret = SolveInequalitiesDeskewRange(problem); } else { LOG(FATAL) << "arith.SolveInequalitiesDeskewRange expects 1 or 3 arguments, gets " diff --git a/src/arith/transitive_comparison_analyzer.cc b/src/arith/transitive_comparison_analyzer.cc index 52010ec322c8..b4cd7b260ebb 100644 --- a/src/arith/transitive_comparison_analyzer.cc +++ b/src/arith/transitive_comparison_analyzer.cc @@ -276,7 +276,7 @@ class TransitiveComparisonAnalyzer::Impl { * Tracked separatedly to handle the `allow_override` option used by * all sub-analyzers when binding variables. */ - Map prev_bindings_; + ffi::Map prev_bindings_; /*! 
\brief Known comparisons based on definitionally-true statements * diff --git a/src/arith/unwrap_vector_expr.cc b/src/arith/unwrap_vector_expr.cc index 6a3e8c3d434c..c074eb5c935a 100644 --- a/src/arith/unwrap_vector_expr.cc +++ b/src/arith/unwrap_vector_expr.cc @@ -47,7 +47,7 @@ class Scalarizer : public ExprMutator { PrimExpr VisitExpr_(const BroadcastNode* op) final { return op->value; } PrimExpr VisitExpr_(const VarNode* op) final { - Var var = GetRef(op); + Var var = ffi::GetRef(op); auto it = let_var_remap_.find(op); if (it != let_var_remap_.end()) { diff --git a/src/contrib/msc/core/codegen/base_codegen.h b/src/contrib/msc/core/codegen/base_codegen.h index f582f6416d93..dc2d5d1ef9a1 100644 --- a/src/contrib/msc/core/codegen/base_codegen.h +++ b/src/contrib/msc/core/codegen/base_codegen.h @@ -53,41 +53,41 @@ class BaseOpCode { * \brief The constructor of BaseOpCode * \param func_name the function name for the node. */ - explicit BaseOpCode(const String& func_name) : func_name_(func_name) {} + explicit BaseOpCode(const ffi::String& func_name) : func_name_(func_name) {} virtual ~BaseOpCode() = default; /*! \brief Config the BaseOpCode*/ void Config(const MSCJoint& node, const std::shared_ptr config, - const Map& prims) { + const ffi::Map& prims) { node_ = node; config_ = config; prims_ = prims; } /*! \brief Get docs for the node*/ - virtual const Array GetDocs() = 0; + virtual const ffi::Array GetDocs() = 0; /*! \brief Get return describe for default node*/ - virtual const String IdxNode() { return IdxNodeBase(node_); } + virtual const ffi::String IdxNode() { return IdxNodeBase(node_); } /*! \brief Get describe for default node input*/ - const String IdxInput(int idx = 0, bool process = true) { + const ffi::String IdxInput(int idx = 0, bool process = true) { return IdxInputBase(node_, idx, process); } /*! \brief Get describe for default node output*/ - const String IdxOutput(int idx = 0) { return IdxOutputBase(node_, idx); } + const ffi::String IdxOutput(int idx = 0) { return IdxOutputBase(node_, idx); } /*! \brief Get describe for default node weight*/ - const String IdxWeight(const String& wtype, bool process = true) { + const ffi::String IdxWeight(const ffi::String& wtype, bool process = true) { return IdxWeightBase(node_, wtype, process); } /*! \brief Get the node attr as doc*/ - const ExprDoc GetAttrDoc(const String& key, const String& type) { + const ExprDoc GetAttrDoc(const ffi::String& key, const ffi::String& type) { if (StringUtils::StartsWith(type, "list")) { - const String& ele_type = + const ffi::String& ele_type = StringUtils::Replace(StringUtils::Replace(type, "list(", ""), ")", ""); if (ele_type == "bool") { return DocUtils::ToList(node_->GetTypeArrayAttr(key)); @@ -115,16 +115,16 @@ class BaseOpCode { } /*! \brief Get comment for default node*/ - const String Comment() { return Comment(node_); } + const ffi::String Comment() { return Comment(node_); } /*! \brief Get func_name for the default node*/ - const String func_name() { return func_name_; } + const ffi::String func_name() { return func_name_; } /*! \brief Get valid func name for the default node*/ - virtual const String callee_name() { return func_name(); } + virtual const ffi::String callee_name() { return func_name(); } /*! \brief Get valid return name for the default node*/ - virtual const String ret_name() { return IdxNode(); } + virtual const ffi::String ret_name() { return IdxNode(); } /*! 
\brief Get the default node*/ const MSCJoint node() { return node_; } @@ -132,7 +132,7 @@ class BaseOpCode { CODEGEN_MEMBERS; private: - String func_name_; + ffi::String func_name_; MSCJoint node_; }; @@ -170,7 +170,8 @@ class BaseCodeGen { virtual ~BaseCodeGen() = default; /*! \brief Get sources*/ - virtual const Map GetSources(const std::string& print_options = "") = 0; + virtual const ffi::Map GetSources( + const std::string& print_options = "") = 0; CODEGEN_MEMBERS; @@ -210,7 +211,7 @@ class BaseCodeGen { } /*! \brief Get the optype for op codegen*/ - const String GetOpType(const MSCJoint& node) { + const ffi::String GetOpType(const MSCJoint& node) { if (config_->use_plugin && IsPlugin(node->optype)) { return "plugin"; } @@ -218,10 +219,10 @@ class BaseCodeGen { } /*! \brief Get the docs for the op*/ - virtual const Array GetOpCodes(const MSCJoint& node) = 0; + virtual const ffi::Array GetOpCodes(const MSCJoint& node) = 0; /*! \brief Describe the prim*/ - virtual const String DescribePrim(const MSCPrim& prim) { + virtual const ffi::String DescribePrim(const MSCPrim& prim) { if (prim->optype == "Int") { return prim->GetTypeAttr("value"); } @@ -247,14 +248,14 @@ class BaseCodeGen { const MSCGraph graph() const { return graph_; } /*! \brief Get the scopes*/ - const std::stack> scopes() const { return scopes_; } + const std::stack> scopes() const { return scopes_; } /*! \brief The stack of codes*/ CodeStack stack_; private: MSCGraph graph_; - std::stack> scopes_; + std::stack> scopes_; }; } // namespace msc diff --git a/src/contrib/msc/core/codegen/code_stack.cc b/src/contrib/msc/core/codegen/code_stack.cc index 041ffe7091b2..e1b34f7d28b7 100644 --- a/src/contrib/msc/core/codegen/code_stack.cc +++ b/src/contrib/msc/core/codegen/code_stack.cc @@ -27,16 +27,16 @@ namespace tvm { namespace contrib { namespace msc { -const Array BaseStack::GetDocs() const { +const ffi::Array BaseStack::GetDocs() const { ICHECK(blocks_.size() == 1) << "Has incomplete blocks, please check"; return TopBlock(); } void BaseStack::Line(const Doc& doc) { PushDoc(doc); } -void BaseStack::Line(const String& line) { Line(IdDoc(line)); } +void BaseStack::Line(const ffi::String& line) { Line(IdDoc(line)); } -void BaseStack::Comment(const String& comment, bool attach) { +void BaseStack::Comment(const ffi::String& comment, bool attach) { if (attach) { const auto& doc = TopDoc(); ICHECK(doc->IsInstance()) << "Only stmt doc support attach comments"; @@ -47,38 +47,39 @@ void BaseStack::Comment(const String& comment, bool attach) { } } -void BaseStack::Declare(const String& type, const String& variable, size_t len, +void BaseStack::Declare(const ffi::String& type, const ffi::String& variable, size_t len, bool use_constructor) { PushDoc(DocUtils::ToDeclare(type, variable, len, use_constructor)); } void BaseStack::DeclareArgBase(const ExprDoc& value) { const auto& declare = PopCheckedDoc(); - Array init_args = declare->init_args; + ffi::Array init_args = declare->init_args; init_args.push_back(value); PushDoc(DeclareDoc(declare->type, declare->variable, init_args, declare->use_constructor)); } -void BaseStack::FuncDef(const String& func_name, const String& ret_type) { +void BaseStack::FuncDef(const ffi::String& func_name, const ffi::String& ret_type) { if (ret_type.size() > 0) { - PushDoc(FunctionDoc(IdDoc(func_name), Array(), Array(), IdDoc(ret_type), - Array())); + PushDoc(FunctionDoc(IdDoc(func_name), ffi::Array(), ffi::Array(), + IdDoc(ret_type), ffi::Array())); } else { - PushDoc(FunctionDoc(IdDoc(func_name), Array(), 
Array(), std::nullopt, - Array())); + PushDoc(FunctionDoc(IdDoc(func_name), ffi::Array(), ffi::Array(), + std::nullopt, ffi::Array())); } } -void BaseStack::FuncArg(const String& arg, const String& annotation, const String& value) { +void BaseStack::FuncArg(const ffi::String& arg, const ffi::String& annotation, + const ffi::String& value) { const auto& func = PopCheckedDoc(); - Array args = func->args; + ffi::Array args = func->args; args.push_back(DocUtils::ToAssign(arg, value, annotation)); PushDoc(FunctionDoc(func->name, args, func->decorators, func->return_type, func->body)); } -void BaseStack::FuncDecorator(const String& decorator) { +void BaseStack::FuncDecorator(const ffi::String& decorator) { const auto& func = PopCheckedDoc(); - Array decorators = func->decorators; + ffi::Array decorators = func->decorators; decorators.push_back(IdDoc(decorator)); PushDoc(FunctionDoc(func->name, func->args, decorators, func->return_type, func->body)); } @@ -95,13 +96,13 @@ void BaseStack::FuncEnd() { PushDoc(FunctionDoc(func->name, func->args, func->decorators, func->return_type, body)); } -void BaseStack::ClassDef(const String& class_name) { - PushDoc(ClassDoc(IdDoc(class_name), Array(), Array())); +void BaseStack::ClassDef(const ffi::String& class_name) { + PushDoc(ClassDoc(IdDoc(class_name), ffi::Array(), ffi::Array())); } -void BaseStack::ClassDecorator(const String& decorator) { +void BaseStack::ClassDecorator(const ffi::String& decorator) { const auto& class_doc = PopCheckedDoc(); - Array decorators = class_doc->decorators; + ffi::Array decorators = class_doc->decorators; decorators.push_back(IdDoc(decorator)); PushDoc(ClassDoc(class_doc->name, decorators, class_doc->body)); } @@ -118,8 +119,8 @@ void BaseStack::ClassEnd() { PushDoc(ClassDoc(class_doc->name, class_doc->decorators, body)); } -void BaseStack::StructStart(const String& struct_name) { - PushDoc(StructDoc(IdDoc(struct_name), Array(), Array())); +void BaseStack::StructStart(const ffi::String& struct_name) { + PushDoc(StructDoc(IdDoc(struct_name), ffi::Array(), ffi::Array())); BlockStart(); } @@ -130,13 +131,14 @@ void BaseStack::StructEnd() { PushDoc(StructDoc(struct_doc->name, struct_doc->decorators, body)); } -void BaseStack::ConstructorDef(const String& constructor_name) { - PushDoc(ConstructorDoc(IdDoc(constructor_name), Array(), Array())); +void BaseStack::ConstructorDef(const ffi::String& constructor_name) { + PushDoc(ConstructorDoc(IdDoc(constructor_name), ffi::Array(), ffi::Array())); } -void BaseStack::ConstructorArg(const String& arg, const String& annotation, const String& value) { +void BaseStack::ConstructorArg(const ffi::String& arg, const ffi::String& annotation, + const ffi::String& value) { const auto& func = PopCheckedDoc(); - Array args = func->args; + ffi::Array args = func->args; args.push_back(DocUtils::ToAssign(arg, value, annotation)); PushDoc(ConstructorDoc(func->name, args, func->body)); } @@ -153,20 +155,22 @@ void BaseStack::ConstructorEnd() { PushDoc(ConstructorDoc(func->name, func->args, body)); } -void BaseStack::LambdaDef(const String& lambda_name) { - PushDoc(LambdaDoc(IdDoc(lambda_name), Array(), Array(), Array())); +void BaseStack::LambdaDef(const ffi::String& lambda_name) { + PushDoc(LambdaDoc(IdDoc(lambda_name), ffi::Array(), ffi::Array(), + ffi::Array())); } -void BaseStack::LambdaArg(const String& arg, const String& annotation, const String& value) { +void BaseStack::LambdaArg(const ffi::String& arg, const ffi::String& annotation, + const ffi::String& value) { const auto& lambda = 
PopCheckedDoc(); - Array args = lambda->args; + ffi::Array args = lambda->args; args.push_back(DocUtils::ToAssign(arg, value, annotation)); PushDoc(LambdaDoc(lambda->name, args, lambda->refs, lambda->body)); } -void BaseStack::LambdaRef(const String& ref) { +void BaseStack::LambdaRef(const ffi::String& ref) { const auto& lambda = PopCheckedDoc(); - Array refs = lambda->refs; + ffi::Array refs = lambda->refs; refs.push_back(IdDoc(ref)); PushDoc(LambdaDoc(lambda->name, lambda->args, refs, lambda->body)); } @@ -176,7 +180,7 @@ void BaseStack::LambdaStart() { BlockStart(); } -void BaseStack::LambdaEnd(const String& ret_val) { +void BaseStack::LambdaEnd(const ffi::String& ret_val) { if (ret_val.size() > 0) { PushDoc(ReturnDoc(IdDoc(ret_val))); } @@ -191,13 +195,15 @@ void BaseStack::LambdaEnd(const ExprDoc& ret_val) { LambdaEnd(""); } -void BaseStack::FuncCall(const String& callee, Optional assign_to, - Optional caller) { +void BaseStack::FuncCall(const ffi::String& callee, ffi::Optional assign_to, + ffi::Optional caller) { if (!caller.defined()) { - PushDoc(CallDoc(IdDoc(callee), Array(), Array(), Array())); + PushDoc(CallDoc(IdDoc(callee), ffi::Array(), ffi::Array(), + ffi::Array())); } else { const auto& new_access = AttrAccessDoc(caller.value(), callee); - PushDoc(CallDoc(new_access, Array(), Array(), Array())); + PushDoc(CallDoc(new_access, ffi::Array(), ffi::Array(), + ffi::Array())); } if (assign_to.defined()) { const auto& last_call = PopCheckedDoc(); @@ -211,14 +217,15 @@ void BaseStack::FuncCall(const String& callee, Optional assign_to, } } -void BaseStack::FuncCall(const String& callee, const String& assign_to, const String& caller) { - Optional assign_doc; +void BaseStack::FuncCall(const ffi::String& callee, const ffi::String& assign_to, + const ffi::String& caller) { + ffi::Optional assign_doc; if (assign_to.size() == 0) { assign_doc = std::nullopt; } else { assign_doc = IdDoc(assign_to); } - Optional caller_doc; + ffi::Optional caller_doc; if (caller.size() == 0) { caller_doc = std::nullopt; } else { @@ -227,26 +234,27 @@ void BaseStack::FuncCall(const String& callee, const String& assign_to, const St FuncCall(callee, assign_doc, caller_doc); } -void BaseStack::MethodCall(const String& callee, bool new_line) { +void BaseStack::MethodCall(const ffi::String& callee, bool new_line) { const auto& host = PopDoc(); if (host->IsInstance()) { const auto& v_callee = callee + (new_line ? 
DocSymbol::NextLine() : ""); FuncCall(v_callee, std::nullopt, Downcast(host)); } else if (const auto* a_node = host.as()) { ICHECK(a_node->rhs.defined()) << "Can not find rhs for inplace host"; - FuncCall(callee, DeclareDoc(a_node->annotation, a_node->lhs, Array(), true), + FuncCall(callee, DeclareDoc(a_node->annotation, a_node->lhs, ffi::Array(), true), a_node->rhs); } else { LOG(FATAL) << "Unexpected host type for inplace " << host->GetTypeKey(); } } -void BaseStack::InplaceStart(const String& callee, Optional assign_to, - Optional caller) { +void BaseStack::InplaceStart(const ffi::String& callee, ffi::Optional assign_to, + ffi::Optional caller) { FuncCall(callee, assign_to, caller); } -void BaseStack::InplaceStart(const String& callee, const String& assign_to, const String& caller) { +void BaseStack::InplaceStart(const ffi::String& callee, const ffi::String& assign_to, + const ffi::String& caller) { FuncCall(callee, assign_to, caller); } @@ -266,7 +274,7 @@ void BaseStack::InplaceEnd() { } } -void BaseStack::PopNest(const String& key) { +void BaseStack::PopNest(const ffi::String& key) { const auto& last = PopDoc(); if (last->IsInstance()) { CallArgBase(Downcast(last), key); @@ -275,11 +283,11 @@ void BaseStack::PopNest(const String& key) { } } -void BaseStack::CallArgBase(const ExprDoc& value, const String& key) { +void BaseStack::CallArgBase(const ExprDoc& value, const ffi::String& key) { const auto& last = PopDoc(); - Array args; - Array kwargs_keys; - Array kwargs_values; + ffi::Array args; + ffi::Array kwargs_keys; + ffi::Array kwargs_values; // get args and kwargs if (const auto* call = last.as()) { args = call->args; @@ -313,16 +321,16 @@ void BaseStack::CallArgBase(const ExprDoc& value, const String& key) { } } -void BaseStack::ConditionIf(const String& predicate) { - Array else_branch{ExprStmtDoc(IdDoc("pass"))}; - PushDoc(IfDoc(IdDoc(predicate), Array(), else_branch)); +void BaseStack::ConditionIf(const ffi::String& predicate) { + ffi::Array else_branch{ExprStmtDoc(IdDoc("pass"))}; + PushDoc(IfDoc(IdDoc(predicate), ffi::Array(), else_branch)); BlockStart(); } void BaseStack::ConditionElse() { const auto& block = PopBlock(); const auto& if_doc = PopCheckedDoc(); - PushDoc(IfDoc(if_doc->predicate, DocUtils::ToStmts(block), Array())); + PushDoc(IfDoc(if_doc->predicate, DocUtils::ToStmts(block), ffi::Array())); BlockStart(); } @@ -331,7 +339,7 @@ void BaseStack::ConditionEnd() { const auto& if_doc = PopCheckedDoc(); const auto& branch = DocUtils::ToStmts(block); if (if_doc->then_branch.size() == 0) { - PushDoc(IfDoc(if_doc->predicate, branch, Array())); + PushDoc(IfDoc(if_doc->predicate, branch, ffi::Array())); } else { PushDoc(IfDoc(if_doc->predicate, if_doc->then_branch, branch)); } @@ -344,8 +352,8 @@ void BaseStack::ForEnd() { PushDoc(ForDoc(for_doc->lhs, for_doc->rhs, body)); } -void BaseStack::WhileStart(const String& predicate) { - PushDoc(WhileDoc(IdDoc(predicate), Array())); +void BaseStack::WhileStart(const ffi::String& predicate) { + PushDoc(WhileDoc(IdDoc(predicate), ffi::Array())); BlockStart(); } @@ -356,20 +364,20 @@ void BaseStack::WhileEnd() { PushDoc(WhileDoc(while_doc->predicate, body)); } -void BaseStack::SwitchStart(const String& predicate) { - Array predicates; +void BaseStack::SwitchStart(const ffi::String& predicate) { + ffi::Array predicates; predicates.push_back(IdDoc(predicate)); - PushDoc(SwitchDoc(predicates, Array>(), Array())); + PushDoc(SwitchDoc(predicates, ffi::Array>(), ffi::Array())); BlockStart(); } -void BaseStack::SwitchCase(const String& 
-void BaseStack::SwitchCase(const String& predicate) {
+void BaseStack::SwitchCase(const ffi::String& predicate) {
   const auto& block = PopBlock();
   const auto& switch_doc = PopCheckedDoc<SwitchDoc>();
   auto branchs = switch_doc->branchs;
   branchs.push_back(DocUtils::ToStmts(block));
   if (predicate.size() == 0) {
-    Array<StmtDoc> default_branch{ExprStmtDoc(IdDoc("pass"))};
+    ffi::Array<StmtDoc> default_branch{ExprStmtDoc(IdDoc("pass"))};
     PushDoc(SwitchDoc(switch_doc->predicates, branchs, default_branch));
   } else {
     auto predicates = switch_doc->predicates;
@@ -392,7 +400,7 @@ void BaseStack::SwitchEnd() {
 }

 void BaseStack::BlockStart() {
-  Array<Doc> block;
+  ffi::Array<Doc> block;
   blocks_.push(block);
 }

@@ -407,11 +415,11 @@ void BaseStack::BlockEnd(bool block_docs) {
   }
 }

-void BaseStack::ScopeStart(const String& scope_def, const String& scope_ref) {
+void BaseStack::ScopeStart(const ffi::String& scope_def, const ffi::String& scope_ref) {
   if (scope_ref.size() > 0) {
-    PushDoc(ScopeDoc(IdDoc(scope_ref), IdDoc(scope_def), Array<StmtDoc>()));
+    PushDoc(ScopeDoc(IdDoc(scope_ref), IdDoc(scope_def), ffi::Array<StmtDoc>()));
   } else {
-    PushDoc(ScopeDoc(std::nullopt, IdDoc(scope_def), Array<StmtDoc>()));
+    PushDoc(ScopeDoc(std::nullopt, IdDoc(scope_def), ffi::Array<StmtDoc>()));
   }
   BlockStart();
 }

@@ -424,12 +432,12 @@ void BaseStack::ScopeEnd() {
 bool BaseStack::HasBlock() const { return blocks_.size() > 0; }

-const Array<Doc> BaseStack::TopBlock() const {
+const ffi::Array<Doc> BaseStack::TopBlock() const {
   ICHECK(HasBlock()) << "No block found";
   return blocks_.top();
 }

-const Array<Doc> BaseStack::PopBlock() {
+const ffi::Array<Doc> BaseStack::PopBlock() {
   const auto& block = TopBlock();
   blocks_.pop();
   return block;
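A behavioural assumption the code above leans on: ffi::Array keeps the value-plus-copy-on-write semantics of the old tvm::Array, which is why copying lambda->args or switch_doc->branchs and push_back-ing builds a new branch without mutating the doc still referenced by the stack. A standalone sketch of that semantics (header paths taken from the ffi/include tree; the string header is assumed):

    #include <tvm/ffi/container/array.h>
    #include <tvm/ffi/string.h>

    using tvm::ffi::Array;
    using tvm::ffi::String;

    int main() {
      Array<String> a{"x", "y"};
      Array<String> b = a;  // shallow copy: storage is shared
      b.push_back("z");     // copy-on-write detaches b before mutating
      return (a.size() == 2 && b.size() == 3) ? 0 : 1;  // holds under COW
    }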
diff --git a/src/contrib/msc/core/codegen/code_stack.h b/src/contrib/msc/core/codegen/code_stack.h
index ff4e6b58247a..d588c3cf4f31 100644
--- a/src/contrib/msc/core/codegen/code_stack.h
+++ b/src/contrib/msc/core/codegen/code_stack.h
@@ -59,24 +59,24 @@ class BaseStack {
   }

   /*! \brief Get the docs*/
-  const Array<Doc> GetDocs() const;
+  const ffi::Array<Doc> GetDocs() const;

  protected:
   /*! \brief Push Id Doc*/
   void Line(const Doc& doc);
-  void Line(const String& line = "");
+  void Line(const ffi::String& line = "");

   /*! \brief Push Comment Doc*/
-  void Comment(const String& comment, bool attach = false);
+  void Comment(const ffi::String& comment, bool attach = false);

   /*! \brief Push Assign Doc*/
   template <typename LT, typename RT>
-  inline void Assign(const LT& lhs, const RT& rhs, const String& annotation = "") {
+  inline void Assign(const LT& lhs, const RT& rhs, const ffi::String& annotation = "") {
     PushDoc(DocUtils::ToAssign(lhs, rhs, annotation));
   }

   /*! \brief Push declare Doc*/
-  void Declare(const String& type, const String& variable, size_t len = 0,
+  void Declare(const ffi::String& type, const ffi::String& variable, size_t len = 0,
                bool use_constructor = true);

   /*! \brief Cache declare argument*/
@@ -89,10 +89,10 @@ class BaseStack {
   }

   /*! \brief Cache class Doc*/
-  void ClassDef(const String& class_name);
+  void ClassDef(const ffi::String& class_name);

   /*! \brief Cache class decorator*/
-  void ClassDecorator(const String& decorator);
+  void ClassDecorator(const ffi::String& decorator);

   /*! \brief Start class body block*/
   void ClassStart();
@@ -101,19 +101,20 @@ class BaseStack {
   void ClassEnd();

   /*! \brief Start struct body block*/
-  void StructStart(const String& struct_name);
+  void StructStart(const ffi::String& struct_name);

   /*! \brief End struct body block*/
   void StructEnd();

   /*! \brief Cache function Doc*/
-  void FuncDef(const String& func_name, const String& ret_type = "");
+  void FuncDef(const ffi::String& func_name, const ffi::String& ret_type = "");

   /*! \brief Cache function argument*/
-  void FuncArg(const String& arg, const String& annotation = "", const String& value = "");
+  void FuncArg(const ffi::String& arg, const ffi::String& annotation = "",
+               const ffi::String& value = "");

   /*! \brief Cache function decorator*/
-  void FuncDecorator(const String& decorator);
+  void FuncDecorator(const ffi::String& decorator);

   /*! \brief Start function body block*/
   void FuncStart();
@@ -128,10 +129,11 @@
   }

   /*! \brief Cache constructor Doc*/
-  void ConstructorDef(const String& constructor_name);
+  void ConstructorDef(const ffi::String& constructor_name);

   /*! \brief Cache constructor argument*/
-  void ConstructorArg(const String& arg, const String& annotation = "", const String& value = "");
+  void ConstructorArg(const ffi::String& arg, const ffi::String& annotation = "",
+                      const ffi::String& value = "");

   /*! \brief Start constructor body block*/
   void ConstructorStart();
@@ -140,52 +142,55 @@
   void ConstructorEnd();

   /*! \brief Cache lambda Doc*/
-  void LambdaDef(const String& lambda_name);
+  void LambdaDef(const ffi::String& lambda_name);

   /*! \brief Cache lambda argument*/
-  void LambdaArg(const String& arg, const String& annotation = "", const String& value = "");
+  void LambdaArg(const ffi::String& arg, const ffi::String& annotation = "",
+                 const ffi::String& value = "");

   /*! \brief Cache lambda reference*/
-  void LambdaRef(const String& ref);
+  void LambdaRef(const ffi::String& ref);

   /*! \brief Start lambda body block*/
   void LambdaStart();

   /*! \brief End lambda body block*/
-  void LambdaEnd(const String& ret_val = "");
+  void LambdaEnd(const ffi::String& ret_val = "");
   void LambdaEnd(const ExprDoc& ret_val);

   /*! \brief Push call and maybe assign Doc*/
-  void FuncCall(const String& callee, Optional<ExprDoc> assign_to,
-                Optional<ExprDoc> caller = std::nullopt);
-  void FuncCall(const String& callee, const String& assign_to = "", const String& caller = "");
+  void FuncCall(const ffi::String& callee, ffi::Optional<ExprDoc> assign_to,
+                ffi::Optional<ExprDoc> caller = std::nullopt);
+  void FuncCall(const ffi::String& callee, const ffi::String& assign_to = "",
+                const ffi::String& caller = "");

   /*! \brief Push method call Doc*/
-  void MethodCall(const String& callee, bool new_line = false);
+  void MethodCall(const ffi::String& callee, bool new_line = false);

   /*! \brief Push inplace call and maybe assign Doc*/
-  void InplaceStart(const String& callee, Optional<ExprDoc> assign_to,
-                    Optional<ExprDoc> caller = std::nullopt);
-  void InplaceStart(const String& callee, const String& assign_to = "", const String& caller = "");
+  void InplaceStart(const ffi::String& callee, ffi::Optional<ExprDoc> assign_to,
+                    ffi::Optional<ExprDoc> caller = std::nullopt);
+  void InplaceStart(const ffi::String& callee, const ffi::String& assign_to = "",
+                    const ffi::String& caller = "");

   /*! \brief End inplace call*/
   void InplaceEnd();

   /*! \brief Push nested expr to last Doc*/
-  void PopNest(const String& key = "");
+  void PopNest(const ffi::String& key = "");

   /*! \brief Cache call typed argument*/
-  void CallArgBase(const ExprDoc& value, const String& key = "");
+  void CallArgBase(const ExprDoc& value, const ffi::String& key = "");
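CallArgBase is the single funnel for call arguments: it pops the pending CallDoc, appends the value to args when the key is empty or to kwargs_keys/kwargs_values otherwise, and pushes the rebuilt call. Keyword arguments are therefore just a non-empty key at the call site; a sketch with hypothetical names:

    stack.func_call("nn.conv2d", "out")
        .call_arg("x")                // positional: appended to args
        .call_arg(3, "kernel_size");  // keyword: kwargs_keys/kwargs_values
    // renders as: out = nn.conv2d(x, kernel_size=3)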
/*! \brief Cache call normal argument*/
   template <typename T>
-  inline void CallArg(T value, const String& key = "") {
+  inline void CallArg(T value, const ffi::String& key = "") {
     const auto& doc_value = DocUtils::ToDoc(value);
     if (doc_value.defined()) {
       CallArgBase(doc_value, key);
     }
   }

-  inline void CallArg(const Array<ExprDoc>& values) {
+  inline void CallArg(const ffi::Array<ExprDoc>& values) {
     for (const auto& v : values) {
       if (v.defined()) {
         CallArgBase(v);
@@ -194,7 +199,7 @@ class BaseStack {
   }

   /*! \brief Push if to cache and start if block*/
-  void ConditionIf(const String& predicate);
+  void ConditionIf(const ffi::String& predicate);

   /*! \brief Push the then branch to cache and start a block*/
   void ConditionElse();
@@ -205,15 +210,15 @@
   /*! \brief Push for to cache and start for block*/
   template <typename LT, typename RT>
   void ForStart(const LT& lhs, const RT& rhs) {
-    PushDoc(ForDoc(DocUtils::ToDoc(lhs), DocUtils::ToDoc(rhs), Array<StmtDoc>()));
+    PushDoc(ForDoc(DocUtils::ToDoc(lhs), DocUtils::ToDoc(rhs), ffi::Array<StmtDoc>()));
     BlockStart();
   }

   /*! \brief Push for range to cache and start for block*/
   template <typename ST, typename ET>
-  void ForStart(const String& lhs, const ST& start, const ET& end) {
-    Array<ExprDoc> range{DocUtils::ToDoc(start), DocUtils::ToDoc(end)};
-    PushDoc(ForDoc(IdDoc(lhs), TupleDoc(range), Array<StmtDoc>()));
+  void ForStart(const ffi::String& lhs, const ST& start, const ET& end) {
+    ffi::Array<ExprDoc> range{DocUtils::ToDoc(start), DocUtils::ToDoc(end)};
+    PushDoc(ForDoc(IdDoc(lhs), TupleDoc(range), ffi::Array<StmtDoc>()));
     BlockStart();
   }

@@ -221,16 +226,16 @@ class BaseStack {
   void ForEnd();

   /*! \brief Push while to cache and start while block*/
-  void WhileStart(const String& predicate);
+  void WhileStart(const ffi::String& predicate);

   /*! \brief End a while block*/
   void WhileEnd();

   /*! \brief Push switch to cache and start switch block*/
-  void SwitchStart(const String& predicate);
+  void SwitchStart(const ffi::String& predicate);

   /*! \brief Add a new case to the switch*/
-  void SwitchCase(const String& predicate = "");
+  void SwitchCase(const ffi::String& predicate = "");

   /*! \brief Push the switch to cache*/
   void SwitchEnd();
@@ -242,7 +247,7 @@
   void BlockEnd(bool block_docs = true);

   /*! \brief Start a new scope*/
-  void ScopeStart(const String& scope_def = "", const String& scope_ref = "");
+  void ScopeStart(const ffi::String& scope_def = "", const ffi::String& scope_ref = "");

   /*! \brief End a scope*/
   void ScopeEnd();
@@ -252,10 +257,10 @@ class BaseStack {
   bool HasBlock() const;

   /*! \brief Get the last block*/
-  const Array<Doc> TopBlock() const;
+  const ffi::Array<Doc> TopBlock() const;

   /*! \brief Pop the last block*/
-  const Array<Doc> PopBlock();
+  const ffi::Array<Doc> PopBlock();

   /*! \brief Check if any doc is left*/
   bool HasDoc();
@@ -274,237 +279,239 @@ class BaseStack {
   void PushDoc(const Doc& doc);

   /*!
\brief The blocks, each has docs array*/ - std::stack> blocks_; + std::stack> blocks_; }; -#define COMMON_WRAPPERS(Stack) \ - Stack& line(const Doc& doc) { \ - Line(doc); \ - return *this; \ - } \ - Stack& line(const String& line = "") { \ - Line(line); \ - return *this; \ - } \ - Stack& comment(const String& comment, bool attach = false) { \ - Comment(comment, attach); \ - return *this; \ - } \ - template \ - Stack& assign(const LT& lhs, const RT& rhs, const String& annotation = "") { \ - Assign(lhs, rhs, annotation); \ - return *this; \ - } \ - Stack& declare(const String& type, const String& variable, size_t len = 0, \ - bool use_constructor = true) { \ - Declare(type, variable, len, use_constructor); \ - return *this; \ - } \ - template \ - Stack& declare_arg(const T& value) { \ - DeclareArg(value); \ - return *this; \ - } \ - Stack& class_def(const String& class_name) { \ - ClassDef(class_name); \ - return *this; \ - } \ - Stack& class_decorator(const String& decorator) { \ - ClassDecorator(decorator); \ - return *this; \ - } \ - Stack& class_start() { \ - ClassStart(); \ - return *this; \ - } \ - Stack& class_end() { \ - ClassEnd(); \ - return *this; \ - } \ - Stack& struct_start(const String& struct_name) { \ - StructStart(struct_name); \ - return *this; \ - } \ - Stack& struct_end() { \ - StructEnd(); \ - return *this; \ - } \ - Stack& func_def(const String& func_name, const String& ret_type = "") { \ - FuncDef(func_name, ret_type); \ - return *this; \ - } \ - Stack& func_arg(const String& arg, const String& annotation = "", const String& value = "") { \ - FuncArg(arg, annotation, value); \ - return *this; \ - } \ - Stack& func_decorator(const String& decorator) { \ - FuncDecorator(decorator); \ - return *this; \ - } \ - Stack& func_start() { \ - FuncStart(); \ - return *this; \ - } \ - Stack& func_end() { \ - FuncEnd(); \ - return *this; \ - } \ - template \ - Stack& func_end(const T& ret_val) { \ - FuncEnd(ret_val); \ - return *this; \ - } \ - Stack& func_call(const String& callee, Optional assign_to, \ - Optional caller = std::nullopt) { \ - FuncCall(callee, assign_to, caller); \ - return *this; \ - } \ - Stack& func_call(const String& callee, const String& assign_to = "", \ - const String& caller = "") { \ - FuncCall(callee, assign_to, caller); \ - return *this; \ - } \ - Stack& method_call(const String& callee, bool new_line = false) { \ - MethodCall(callee, new_line); \ - return *this; \ - } \ - Stack& inplace_start(const String& callee, Optional assign_to, \ - Optional caller = std::nullopt) { \ - InplaceStart(callee, assign_to, caller); \ - return *this; \ - } \ - Stack& inplace_start(const String& callee, const String& assign_to = "", \ - const String& caller = "") { \ - InplaceStart(callee, assign_to, caller); \ - return *this; \ - } \ - Stack& inplace_end() { \ - InplaceEnd(); \ - return *this; \ - } \ - Stack& constructor_def(const String& func_name) { \ - ConstructorDef(func_name); \ - return *this; \ - } \ - Stack& constructor_arg(const String& arg, const String& annotation = "", \ - const String& value = "") { \ - ConstructorArg(arg, annotation, value); \ - return *this; \ - } \ - Stack& constructor_start() { \ - ConstructorStart(); \ - return *this; \ - } \ - Stack& constructor_end() { \ - ConstructorEnd(); \ - return *this; \ - } \ - Stack& lambda_def(const String& lambda_name) { \ - LambdaDef(lambda_name); \ - return *this; \ - } \ - Stack& lambda_arg(const String& arg, const String& annotation = "", const String& value = "") { \ - LambdaArg(arg, annotation, 
value); \ - return *this; \ - } \ - Stack& lambda_ref(const String& ref) { \ - LambdaRef(ref); \ - return *this; \ - } \ - Stack& lambda_start() { \ - LambdaStart(); \ - return *this; \ - } \ - Stack& lambda_end(const String& ret_val = "") { \ - LambdaEnd(ret_val); \ - return *this; \ - } \ - Stack& lambda_end(const ExprDoc& ret_val) { \ - LambdaEnd(ret_val); \ - return *this; \ - } \ - Stack& pop_nest(const String& key = "") { \ - PopNest(key); \ - return *this; \ - } \ - template \ - Stack& call_arg(T value, const String& key = "") { \ - CallArg(value, key); \ - return *this; \ - } \ - Stack& call_arg(const ExprDoc& value, const String& key = "") { \ - CallArg(value, key); \ - return *this; \ - } \ - Stack& call_arg(const Array& values) { \ - CallArg(values); \ - return *this; \ - } \ - Stack& cond_if(const String& predicate) { \ - ConditionIf(predicate); \ - return *this; \ - } \ - Stack& cond_else() { \ - ConditionElse(); \ - return *this; \ - } \ - Stack& cond_end() { \ - ConditionEnd(); \ - return *this; \ - } \ - template \ - Stack& for_start(const LT& lhs, const RT& rhs) { \ - ForStart(lhs, rhs); \ - return *this; \ - } \ - template \ - Stack& for_start(const String& lhs, const ST& start, const ET& end) { \ - ForStart(lhs, start, end); \ - return *this; \ - } \ - Stack& for_start(const String& lhs, const String& start, const String& end) { \ - ForStart(lhs, start, end); \ - return *this; \ - } \ - Stack& for_end() { \ - ForEnd(); \ - return *this; \ - } \ - Stack& while_start(const String& predicate) { \ - WhileStart(predicate); \ - return *this; \ - } \ - Stack& while_end() { \ - WhileEnd(); \ - return *this; \ - } \ - Stack& switch_start(const String& predicate) { \ - SwitchStart(predicate); \ - return *this; \ - } \ - Stack& switch_case(const String& predicate = "") { \ - SwitchCase(predicate); \ - return *this; \ - } \ - Stack& switch_end() { \ - SwitchEnd(); \ - return *this; \ - } \ - Stack& block_start() { \ - BlockStart(); \ - return *this; \ - } \ - Stack& block_end(bool block_docs = true) { \ - BlockEnd(block_docs); \ - return *this; \ - } \ - Stack& scope_start(const String& scope_def = "", const String& scope_ref = "") { \ - ScopeStart(scope_def, scope_ref); \ - return *this; \ - } \ - Stack& scope_end() { \ - ScopeEnd(); \ - return *this; \ +#define COMMON_WRAPPERS(Stack) \ + Stack& line(const Doc& doc) { \ + Line(doc); \ + return *this; \ + } \ + Stack& line(const ffi::String& line = "") { \ + Line(line); \ + return *this; \ + } \ + Stack& comment(const ffi::String& comment, bool attach = false) { \ + Comment(comment, attach); \ + return *this; \ + } \ + template \ + Stack& assign(const LT& lhs, const RT& rhs, const ffi::String& annotation = "") { \ + Assign(lhs, rhs, annotation); \ + return *this; \ + } \ + Stack& declare(const ffi::String& type, const ffi::String& variable, size_t len = 0, \ + bool use_constructor = true) { \ + Declare(type, variable, len, use_constructor); \ + return *this; \ + } \ + template \ + Stack& declare_arg(const T& value) { \ + DeclareArg(value); \ + return *this; \ + } \ + Stack& class_def(const ffi::String& class_name) { \ + ClassDef(class_name); \ + return *this; \ + } \ + Stack& class_decorator(const ffi::String& decorator) { \ + ClassDecorator(decorator); \ + return *this; \ + } \ + Stack& class_start() { \ + ClassStart(); \ + return *this; \ + } \ + Stack& class_end() { \ + ClassEnd(); \ + return *this; \ + } \ + Stack& struct_start(const ffi::String& struct_name) { \ + StructStart(struct_name); \ + return *this; \ + } \ + Stack& 
struct_end() { \ + StructEnd(); \ + return *this; \ + } \ + Stack& func_def(const ffi::String& func_name, const ffi::String& ret_type = "") { \ + FuncDef(func_name, ret_type); \ + return *this; \ + } \ + Stack& func_arg(const ffi::String& arg, const ffi::String& annotation = "", \ + const ffi::String& value = "") { \ + FuncArg(arg, annotation, value); \ + return *this; \ + } \ + Stack& func_decorator(const ffi::String& decorator) { \ + FuncDecorator(decorator); \ + return *this; \ + } \ + Stack& func_start() { \ + FuncStart(); \ + return *this; \ + } \ + Stack& func_end() { \ + FuncEnd(); \ + return *this; \ + } \ + template \ + Stack& func_end(const T& ret_val) { \ + FuncEnd(ret_val); \ + return *this; \ + } \ + Stack& func_call(const ffi::String& callee, ffi::Optional assign_to, \ + ffi::Optional caller = std::nullopt) { \ + FuncCall(callee, assign_to, caller); \ + return *this; \ + } \ + Stack& func_call(const ffi::String& callee, const ffi::String& assign_to = "", \ + const ffi::String& caller = "") { \ + FuncCall(callee, assign_to, caller); \ + return *this; \ + } \ + Stack& method_call(const ffi::String& callee, bool new_line = false) { \ + MethodCall(callee, new_line); \ + return *this; \ + } \ + Stack& inplace_start(const ffi::String& callee, ffi::Optional assign_to, \ + ffi::Optional caller = std::nullopt) { \ + InplaceStart(callee, assign_to, caller); \ + return *this; \ + } \ + Stack& inplace_start(const ffi::String& callee, const ffi::String& assign_to = "", \ + const ffi::String& caller = "") { \ + InplaceStart(callee, assign_to, caller); \ + return *this; \ + } \ + Stack& inplace_end() { \ + InplaceEnd(); \ + return *this; \ + } \ + Stack& constructor_def(const ffi::String& func_name) { \ + ConstructorDef(func_name); \ + return *this; \ + } \ + Stack& constructor_arg(const ffi::String& arg, const ffi::String& annotation = "", \ + const ffi::String& value = "") { \ + ConstructorArg(arg, annotation, value); \ + return *this; \ + } \ + Stack& constructor_start() { \ + ConstructorStart(); \ + return *this; \ + } \ + Stack& constructor_end() { \ + ConstructorEnd(); \ + return *this; \ + } \ + Stack& lambda_def(const ffi::String& lambda_name) { \ + LambdaDef(lambda_name); \ + return *this; \ + } \ + Stack& lambda_arg(const ffi::String& arg, const ffi::String& annotation = "", \ + const ffi::String& value = "") { \ + LambdaArg(arg, annotation, value); \ + return *this; \ + } \ + Stack& lambda_ref(const ffi::String& ref) { \ + LambdaRef(ref); \ + return *this; \ + } \ + Stack& lambda_start() { \ + LambdaStart(); \ + return *this; \ + } \ + Stack& lambda_end(const ffi::String& ret_val = "") { \ + LambdaEnd(ret_val); \ + return *this; \ + } \ + Stack& lambda_end(const ExprDoc& ret_val) { \ + LambdaEnd(ret_val); \ + return *this; \ + } \ + Stack& pop_nest(const ffi::String& key = "") { \ + PopNest(key); \ + return *this; \ + } \ + template \ + Stack& call_arg(T value, const ffi::String& key = "") { \ + CallArg(value, key); \ + return *this; \ + } \ + Stack& call_arg(const ExprDoc& value, const ffi::String& key = "") { \ + CallArg(value, key); \ + return *this; \ + } \ + Stack& call_arg(const ffi::Array& values) { \ + CallArg(values); \ + return *this; \ + } \ + Stack& cond_if(const ffi::String& predicate) { \ + ConditionIf(predicate); \ + return *this; \ + } \ + Stack& cond_else() { \ + ConditionElse(); \ + return *this; \ + } \ + Stack& cond_end() { \ + ConditionEnd(); \ + return *this; \ + } \ + template \ + Stack& for_start(const LT& lhs, const RT& rhs) { \ + ForStart(lhs, rhs); \ + 
return *this; \ + } \ + template \ + Stack& for_start(const ffi::String& lhs, const ST& start, const ET& end) { \ + ForStart(lhs, start, end); \ + return *this; \ + } \ + Stack& for_start(const ffi::String& lhs, const ffi::String& start, const ffi::String& end) { \ + ForStart(lhs, start, end); \ + return *this; \ + } \ + Stack& for_end() { \ + ForEnd(); \ + return *this; \ + } \ + Stack& while_start(const ffi::String& predicate) { \ + WhileStart(predicate); \ + return *this; \ + } \ + Stack& while_end() { \ + WhileEnd(); \ + return *this; \ + } \ + Stack& switch_start(const ffi::String& predicate) { \ + SwitchStart(predicate); \ + return *this; \ + } \ + Stack& switch_case(const ffi::String& predicate = "") { \ + SwitchCase(predicate); \ + return *this; \ + } \ + Stack& switch_end() { \ + SwitchEnd(); \ + return *this; \ + } \ + Stack& block_start() { \ + BlockStart(); \ + return *this; \ + } \ + Stack& block_end(bool block_docs = true) { \ + BlockEnd(block_docs); \ + return *this; \ + } \ + Stack& scope_start(const ffi::String& scope_def = "", const ffi::String& scope_ref = "") { \ + ScopeStart(scope_def, scope_ref); \ + return *this; \ + } \ + Stack& scope_end() { \ + ScopeEnd(); \ + return *this; \ } /*! @@ -542,35 +549,37 @@ class OpCodeStack : public BaseStack { COMMON_WRAPPERS(OpCodeStack) /*! \brief Push op_call Doc*/ - OpCodeStack& op_call(const String& callee = "msc::auto", - const String& assign_to = "msc::auto") { - const String& v_callee = callee == "msc::auto" ? codegen_->callee_name() : callee; - const String& v_assign = assign_to == "msc::auto" ? codegen_->ret_name() : assign_to; + OpCodeStack& op_call(const ffi::String& callee = "msc::auto", + const ffi::String& assign_to = "msc::auto") { + const ffi::String& v_callee = callee == "msc::auto" ? codegen_->callee_name() : callee; + const ffi::String& v_assign = assign_to == "msc::auto" ? codegen_->ret_name() : assign_to; return func_call(v_callee, v_assign); } /*! \brief Push op comment Doc*/ - OpCodeStack& op_comment(const String& comment_str = "msc::auto") { - const String& v_comment = (comment_str == "msc::auto" ? codegen_->Comment() : comment_str); + OpCodeStack& op_comment(const ffi::String& comment_str = "msc::auto") { + const ffi::String& v_comment = (comment_str == "msc::auto" ? codegen_->Comment() : comment_str); return comment(v_comment); } /*! \brief Cache typed attribute as argument*/ template - OpCodeStack& op_arg(const String& attr_key, const String& key = "msc::auto") { + OpCodeStack& op_arg(const ffi::String& attr_key, + const ffi::String& key = "msc::auto") { T attr_val; if (codegen_->node()->GetAttr(attr_key, &attr_val)) { - const String& valid_key = key == "msc::auto" ? attr_key : key; + const ffi::String& valid_key = key == "msc::auto" ? attr_key : key; return call_arg(attr_val, valid_key); } return *this; } /*! \brief Cache str attribute as argument*/ - OpCodeStack& op_str_arg(const String& attr_key, const String& key = "msc::auto") { + OpCodeStack& op_str_arg(const ffi::String& attr_key, + const ffi::String& key = "msc::auto") { std::string attr_val; if (codegen_->node()->GetAttr(attr_key, &attr_val)) { - const String& valid_key = key == "msc::auto" ? attr_key : key; + const ffi::String& valid_key = key == "msc::auto" ? attr_key : key; return call_arg(DocUtils::ToStr(attr_val), valid_key); } return *this; @@ -578,24 +587,25 @@ class OpCodeStack : public BaseStack { /*! 
\brief Cache list attribute as argument*/ template - OpCodeStack& op_list_arg(const String& attr_key, const String& key = "msc::auto", + OpCodeStack& op_list_arg(const ffi::String& attr_key, + const ffi::String& key = "msc::auto", bool allow_empty = false) { std::vector attr_val; if (codegen_->node()->GetAttr(attr_key, &attr_val)) { - const String& valid_key = key == "msc::auto" ? attr_key : key; + const ffi::String& valid_key = key == "msc::auto" ? attr_key : key; return call_arg(DocUtils::ToList(attr_val, allow_empty), valid_key); } return *this; } /*! \brief Cache input as argument*/ - OpCodeStack& op_input_arg(int idx = 0, const String& key = "") { + OpCodeStack& op_input_arg(int idx = 0, const ffi::String& key = "") { return call_arg(codegen_->IdxInput(idx, true), key); } /*! \brief Cache inputs as argument*/ - OpCodeStack& op_inputs_arg(bool as_list = true, const String& key = "") { - Array inputs; + OpCodeStack& op_inputs_arg(bool as_list = true, const ffi::String& key = "") { + ffi::Array inputs; for (size_t i = 0; i < codegen_->node()->inputs.size(); i++) { inputs.push_back(codegen_->IdxInput(i, true)); } @@ -607,12 +617,12 @@ class OpCodeStack : public BaseStack { } /*! \brief Cache output as argument*/ - OpCodeStack& op_output_arg(int idx = 0, const String& key = "") { + OpCodeStack& op_output_arg(int idx = 0, const ffi::String& key = "") { return call_arg(codegen_->IdxOutput(idx), key); } /*! \brief Cache weight as argument*/ - OpCodeStack& op_weight_arg(const String& wtype, const String& key = "") { + OpCodeStack& op_weight_arg(const ffi::String& wtype, const ffi::String& key = "") { if (codegen_->node()->weights.count(wtype)) { return call_arg(codegen_->IdxWeight(wtype, true), key); } @@ -620,15 +630,15 @@ class OpCodeStack : public BaseStack { } /*! \brief Cache name as argument*/ - OpCodeStack& op_name_arg(const String& key = "msc::auto", - const String& name = "msc::auto") { - const String& valid_key = key == "msc::auto" ? "name" : key; - const String& valid_name = name == "msc::auto" ? codegen_->node()->name : name; + OpCodeStack& op_name_arg(const ffi::String& key = "msc::auto", + const ffi::String& name = "msc::auto") { + const ffi::String& valid_key = key == "msc::auto" ? "name" : key; + const ffi::String& valid_name = name == "msc::auto" ? 
codegen_->node()->name : name; return call_arg(DocUtils::ToStr(valid_name), valid_key); return *this; } - OpCodeStack& op_dtype_arg(const DataType& dtype, const String& key = "") { + OpCodeStack& op_dtype_arg(const DataType& dtype, const ffi::String& key = "") { return call_arg(codegen_->DType(dtype), key); } diff --git a/src/contrib/msc/core/codegen/codegen_json.cc b/src/contrib/msc/core/codegen/codegen_json.cc index 7bbe576b6bfe..6ccec35b78b4 100644 --- a/src/contrib/msc/core/codegen/codegen_json.cc +++ b/src/contrib/msc/core/codegen/codegen_json.cc @@ -50,11 +50,11 @@ std::vector MSCJSONSerializer::VisitExpr_(const CallNode* ca } global_options_set_ = true; } - return AddNode(node, GetRef(call_node)); + return AddNode(node, ffi::GetRef(call_node)); } -void MSCJSONSerializer::AddNodeAttr(JSONGraphObjectPtr node, const String& key, - const String& value) { +void MSCJSONSerializer::AddNodeAttr(JSONGraphObjectPtr node, const ffi::String& key, + const ffi::String& value) { std::vector array_value{std::string(value)}; std::vector dmlc_value; dmlc_value.emplace_back(array_value); diff --git a/src/contrib/msc/core/codegen/codegen_json.h b/src/contrib/msc/core/codegen/codegen_json.h index dfc2d699a968..08a834bdaa27 100644 --- a/src/contrib/msc/core/codegen/codegen_json.h +++ b/src/contrib/msc/core/codegen/codegen_json.h @@ -69,7 +69,7 @@ class MSCJSONSerializer : public JSONSerializer { * \brief Constructor * \param constant_names The names of all constants in the original module. */ - explicit MSCJSONSerializer(const Map& constant_names, + explicit MSCJSONSerializer(const ffi::Map& constant_names, const std::string& options) : JSONSerializer(constant_names) { MSCCompileConfig config; @@ -86,19 +86,19 @@ class MSCJSONSerializer : public JSONSerializer { std::vector VisitExpr_(const CallNode* call_node) final; - const String GetOption(const String& key) { + const ffi::String GetOption(const ffi::String& key) { ICHECK(options_.count(key)) << "Can not find option " << key; return options_[key]; } - const Map GetOptions() { return options_; } + const ffi::Map GetOptions() { return options_; } protected: - void AddNodeAttr(JSONGraphObjectPtr node, const String& key, const String& value); + void AddNodeAttr(JSONGraphObjectPtr node, const ffi::String& key, const ffi::String& value); private: MSCGraph graph_; - Map options_; + ffi::Map options_; bool global_options_set_; }; diff --git a/src/contrib/msc/core/codegen/codegen_utils.cc b/src/contrib/msc/core/codegen/codegen_utils.cc index 741b729bd015..768c9f276e9e 100644 --- a/src/contrib/msc/core/codegen/codegen_utils.cc +++ b/src/contrib/msc/core/codegen/codegen_utils.cc @@ -27,13 +27,13 @@ namespace tvm { namespace contrib { namespace msc { -const String CodeGenUtils::IdxNode(const MSCJoint& node, const String& prefix, - const String& suffix) { +const ffi::String CodeGenUtils::IdxNode(const MSCJoint& node, const ffi::String& prefix, + const ffi::String& suffix) { return prefix + std::to_string(node->index) + suffix; } -const String CodeGenUtils::IdxOutput(const MSCJoint& node, const String& prefix, int idx, - const String& suffix) { +const ffi::String CodeGenUtils::IdxOutput(const MSCJoint& node, const ffi::String& prefix, int idx, + const ffi::String& suffix) { const auto& idx_node = IdxNode(node, prefix, suffix); size_t output_size = node->outputs.size(); if (output_size == 1 && node->optype != "tuple") { @@ -43,20 +43,20 @@ const String CodeGenUtils::IdxOutput(const MSCJoint& node, const String& prefix, return idx_node + "[" + std::to_string(v_index) 
+ "]"; } -const String CodeGenUtils::IdxInput(const MSCJoint& node, const String& prefix, int idx, - const String& suffix) { +const ffi::String CodeGenUtils::IdxInput(const MSCJoint& node, const ffi::String& prefix, int idx, + const ffi::String& suffix) { const auto& pair = node->ProducerAndIdxOf(idx); return IdxOutput(pair.first, prefix, pair.second, suffix); } -const String CodeGenUtils::IdxWeight(const MSCJoint& node, const String& wtype, - const String& suffix) { +const ffi::String CodeGenUtils::IdxWeight(const MSCJoint& node, const ffi::String& wtype, + const ffi::String& suffix) { return wtype + "_" + std::to_string(node->index) + suffix; } -const Array CodeGenUtils::GetPrims(const MSCTensor& tensor, - const Map& prims) { - Array dims; +const ffi::Array CodeGenUtils::GetPrims( + const MSCTensor& tensor, const ffi::Map& prims) { + ffi::Array dims; if (tensor->prims.size() == 0) { for (size_t i = 0; i < tensor->Ndim(); i++) { dims.push_back(StringUtils::ToString(tensor->DimAt(i))); @@ -70,9 +70,9 @@ const Array CodeGenUtils::GetPrims(const MSCTensor& tensor, return dims; } -const String CodeGenUtils::CommentNode(const MSCJoint& node, const String& prefix, - const Map& prims) { - String comment = node->name + "(" + node->optype + "): <"; +const ffi::String CodeGenUtils::CommentNode(const MSCJoint& node, const ffi::String& prefix, + const ffi::Map& prims) { + ffi::String comment = node->name + "(" + node->optype + "): <"; for (size_t i = 0; i < node->inputs.size(); i++) { comment = comment + IdxInput(node, prefix, i) + (i == node->inputs.size() - 1 ? "> -> <" : ","); } diff --git a/src/contrib/msc/core/codegen/codegen_utils.h b/src/contrib/msc/core/codegen/codegen_utils.h index 09b44af894e4..6fbaa96dd698 100644 --- a/src/contrib/msc/core/codegen/codegen_utils.h +++ b/src/contrib/msc/core/codegen/codegen_utils.h @@ -86,39 +86,42 @@ using namespace tvm::script::printer; this->DescribePrim(prim->ParentAt(1)) + ")"; \ } -#define CODEGEN_MEMBERS \ - public: \ - virtual const String DType(const DataType& dtype) { return runtime::DLDataTypeToString(dtype); } \ - \ - protected: \ - const std::shared_ptr config() { return config_; } \ - const Map prims() { return prims_; } \ - const String IdxNodeBase(const MSCJoint& node) { \ - return helper_.IdxNodeBase(node, config()->prefix, ""); \ - } \ - const String IdxInputBase(const MSCJoint& node, int idx = 0, bool process = true) { \ - return helper_.IdxInputBase(node, config()->prefix, idx, "", process && config()->use_tools); \ - } \ - const String IdxOutputBase(const MSCJoint& node, int idx = 0, bool mark_exit = false) { \ - return helper_.IdxOutputBase(node, config()->prefix, idx, "", \ - mark_exit && config()->use_tools); \ - } \ - const String IdxWeightBase(const MSCJoint& node, const String& wtype, bool process = true) { \ - return helper_.IdxWeightBase(node, wtype, "", process && config()->use_tools); \ - } \ - const Array GetPrims(const MSCTensor& tensor) { \ - return CodeGenUtils::GetPrims(tensor, prims_); \ - } \ - const String Comment(const MSCJoint& node) { \ - return helper_.Comment(node, config()->prefix, prims_); \ - } \ - int CompareVersion(size_t major, size_t minor, size_t patch) { \ - return CommonUtils::CompareVersion(config()->version, {major, minor, patch}); \ - } \ - \ - private: \ - std::shared_ptr config_; \ - Map prims_; \ +#define CODEGEN_MEMBERS \ + public: \ + virtual const ffi::String DType(const DataType& dtype) { \ + return runtime::DLDataTypeToString(dtype); \ + } \ + \ + protected: \ + const std::shared_ptr config() { 
return config_; } \ + const ffi::Map prims() { return prims_; } \ + const ffi::String IdxNodeBase(const MSCJoint& node) { \ + return helper_.IdxNodeBase(node, config()->prefix, ""); \ + } \ + const ffi::String IdxInputBase(const MSCJoint& node, int idx = 0, bool process = true) { \ + return helper_.IdxInputBase(node, config()->prefix, idx, "", process && config()->use_tools); \ + } \ + const ffi::String IdxOutputBase(const MSCJoint& node, int idx = 0, bool mark_exit = false) { \ + return helper_.IdxOutputBase(node, config()->prefix, idx, "", \ + mark_exit && config()->use_tools); \ + } \ + const ffi::String IdxWeightBase(const MSCJoint& node, const ffi::String& wtype, \ + bool process = true) { \ + return helper_.IdxWeightBase(node, wtype, "", process && config()->use_tools); \ + } \ + const ffi::Array GetPrims(const MSCTensor& tensor) { \ + return CodeGenUtils::GetPrims(tensor, prims_); \ + } \ + const ffi::String Comment(const MSCJoint& node) { \ + return helper_.Comment(node, config()->prefix, prims_); \ + } \ + int CompareVersion(size_t major, size_t minor, size_t patch) { \ + return CommonUtils::CompareVersion(config()->version, {major, minor, patch}); \ + } \ + \ + private: \ + std::shared_ptr config_; \ + ffi::Map prims_; \ HelperType helper_; /*! @@ -130,42 +133,42 @@ class CodeGenUtils { * \brief Get indexed node string. * \return The String. */ - TVM_DLL static const String IdxNode(const MSCJoint& node, const String& prefix, - const String& suffix = ""); + TVM_DLL static const ffi::String IdxNode(const MSCJoint& node, const ffi::String& prefix, + const ffi::String& suffix = ""); /*! * \brief Get indexed output string. * \return The String. */ - TVM_DLL static const String IdxOutput(const MSCJoint& node, const String& prefix, int idx = 0, - const String& suffix = ""); + TVM_DLL static const ffi::String IdxOutput(const MSCJoint& node, const ffi::String& prefix, + int idx = 0, const ffi::String& suffix = ""); /*! * \brief Get indexed input string. * \return The String. */ - TVM_DLL static const String IdxInput(const MSCJoint& node, const String& prefix, int idx = 0, - const String& suffix = ""); + TVM_DLL static const ffi::String IdxInput(const MSCJoint& node, const ffi::String& prefix, + int idx = 0, const ffi::String& suffix = ""); /*! * \brief Get indexed weight string. * \return The String. */ - TVM_DLL static const String IdxWeight(const MSCJoint& node, const String& wtype, - const String& suffix = ""); + TVM_DLL static const ffi::String IdxWeight(const MSCJoint& node, const ffi::String& wtype, + const ffi::String& suffix = ""); /*! * \brief Infer prims of tensor. * \return The prims. */ - TVM_DLL static const Array GetPrims(const MSCTensor& tensor, - const Map& prims); + TVM_DLL static const ffi::Array GetPrims( + const MSCTensor& tensor, const ffi::Map& prims); /*! * \brief Get comment of a node. * \return The String. */ - TVM_DLL static const String CommentNode(const MSCJoint& node, const String& prefix, - const Map& prims); + TVM_DLL static const ffi::String CommentNode(const MSCJoint& node, const ffi::String& prefix, + const ffi::Map& prims); }; /*! @@ -173,16 +176,17 @@ class CodeGenUtils { */ class BaseCodeGenHelper { public: - const String GetSuffix(const MSCJoint& node, bool process = false) { + const ffi::String GetSuffix(const MSCJoint& node, bool process = false) { return process ? 
"c" + std::to_string(node->index) : ""; } - virtual const String IdxNodeBase(const MSCJoint& node, const String& prefix = "", - const String& suffix = "") { + virtual const ffi::String IdxNodeBase(const MSCJoint& node, const ffi::String& prefix = "", + const ffi::String& suffix = "") { return CodeGenUtils::IdxNode(node, prefix, suffix); } - virtual const String IdxInputBase(const MSCJoint& node, const String& prefix = "", int idx = 0, - const String& suffix = "", bool process = false) { + virtual const ffi::String IdxInputBase(const MSCJoint& node, const ffi::String& prefix = "", + int idx = 0, const ffi::String& suffix = "", + bool process = false) { const auto& pair = node->ProducerAndIdxOf(idx); size_t output_size = pair.first->outputs.size(); if (process && (output_size > 1 || pair.first->optype == "tuple")) { @@ -190,8 +194,9 @@ class BaseCodeGenHelper { } return CodeGenUtils::IdxInput(node, prefix, idx, suffix + GetSuffix(node, process)); } - virtual const String IdxOutputBase(const MSCJoint& node, const String& prefix = "", int idx = 0, - const String& suffix = "", bool mark_exit = false) { + virtual const ffi::String IdxOutputBase(const MSCJoint& node, const ffi::String& prefix = "", + int idx = 0, const ffi::String& suffix = "", + bool mark_exit = false) { if (mark_exit) { if (node->outputs.size() > 1 || node->optype == "tuple") { return CodeGenUtils::IdxNode(node, prefix, suffix) + "_" + std::to_string(idx) + "_exit"; @@ -200,12 +205,13 @@ class BaseCodeGenHelper { } return CodeGenUtils::IdxOutput(node, prefix, idx, suffix); } - virtual const String IdxWeightBase(const MSCJoint& node, const String& wtype, - const String& suffix = "", bool process = false) { + virtual const ffi::String IdxWeightBase(const MSCJoint& node, const ffi::String& wtype, + const ffi::String& suffix = "", bool process = false) { return CodeGenUtils::IdxWeight(node, wtype, suffix + GetSuffix(node, process)); } - virtual const String Comment(const MSCJoint& node, const String& prefix = "", - const Map& prims = Map()) { + virtual const ffi::String Comment( + const MSCJoint& node, const ffi::String& prefix = "", + const ffi::Map& prims = ffi::Map()) { return CodeGenUtils::CommentNode(node, prefix, prims); } }; diff --git a/src/contrib/msc/core/codegen/cpp_codegen.h b/src/contrib/msc/core/codegen/cpp_codegen.h index 260bd27ca35a..99988d689a95 100644 --- a/src/contrib/msc/core/codegen/cpp_codegen.h +++ b/src/contrib/msc/core/codegen/cpp_codegen.h @@ -69,9 +69,10 @@ class CppCodeGen : public BaseCodeGen { virtual void CodeGenCmake() = 0; /*! \brief Get sources*/ - virtual const Map GetSources(const std::string& print_options = "") { - Map sources; - auto add_source = [&print_options, &sources, this](const String& file) { + virtual const ffi::Map GetSources( + const std::string& print_options = "") { + ffi::Map sources; + auto add_source = [&print_options, &sources, this](const ffi::String& file) { CppPrinter printer(print_options); for (const auto& d : this->stack_.GetDocs()) { printer.Append(d); @@ -96,7 +97,7 @@ class CppCodeGen : public BaseCodeGen { protected: /*! \brief Describe the prim*/ - virtual const String DescribePrim(const MSCPrim& prim) { + virtual const ffi::String DescribePrim(const MSCPrim& prim) { // binary ops DESCRIBE_PRIM_BINARY("Min", "std::min", true) DESCRIBE_PRIM_BINARY("Max", "std::max", true) @@ -152,8 +153,8 @@ class CppCodeGen : public BaseCodeGen { } /*! 
\brief Get the tensor context for codegen_tensor*/ - virtual const Map GetTensorCtx(const MSCTensor& tensor) { - Map tensor_ctx; + virtual const ffi::Map GetTensorCtx(const MSCTensor& tensor) { + ffi::Map tensor_ctx; MSCJoint producer; if (this->graph()->weight_holders.count(tensor->name)) { producer = this->graph()->FindProducer(tensor); @@ -175,8 +176,8 @@ class CppCodeGen : public BaseCodeGen { } /*! \brief Get the step context for codegen_step*/ - virtual const Map GetStepCtx() { - Map step_ctx; + virtual const ffi::Map GetStepCtx() { + ffi::Map step_ctx; std::string version = ""; for (size_t i = 0; i < this->config()->version.size(); i++) { version += std::to_string(this->config()->version[i]) + diff --git a/src/contrib/msc/core/codegen/py_codegen.h b/src/contrib/msc/core/codegen/py_codegen.h index af75f0e4233d..460818089f82 100644 --- a/src/contrib/msc/core/codegen/py_codegen.h +++ b/src/contrib/msc/core/codegen/py_codegen.h @@ -70,8 +70,9 @@ class PyCodeGen : public BaseCodeGen { } /*! \brief Get sources*/ - virtual const Map GetSources(const std::string& print_options = "") { - Map sources; + virtual const ffi::Map GetSources( + const std::string& print_options = "") { + ffi::Map sources; PythonPrinter printer(print_options); CodeGenScript(); for (const auto& d : this->stack_.GetDocs()) { @@ -83,7 +84,7 @@ class PyCodeGen : public BaseCodeGen { protected: /*! \brief Describe the prim*/ - virtual const String DescribePrim(const MSCPrim& prim) { + virtual const ffi::String DescribePrim(const MSCPrim& prim) { // binary ops DESCRIBE_PRIM_BINARY("Min", "min", true) DESCRIBE_PRIM_BINARY("Max", "max", true) @@ -216,7 +217,7 @@ class PyCodeGen : public BaseCodeGen { virtual void CodeGenInference() = 0; /*! \brief Get tensor type of the framework*/ - virtual const String TensorType() const { return "np.ndarray"; } + virtual const ffi::String TensorType() const { return "np.ndarray"; } private: std::set graph_outputs_; diff --git a/src/contrib/msc/core/ir/graph.cc b/src/contrib/msc/core/ir/graph.cc index dff38aade5aa..2d062d033bba 100644 --- a/src/contrib/msc/core/ir/graph.cc +++ b/src/contrib/msc/core/ir/graph.cc @@ -36,9 +36,10 @@ namespace tvm { namespace contrib { namespace msc { -MSCTensor::MSCTensor(const String& name, const DataType& dtype, const String& layout, - const Array& shape, const String& alias, const Array& prims) { - ObjectPtr n = make_object(); +MSCTensor::MSCTensor(const ffi::String& name, const DataType& dtype, const ffi::String& layout, + const ffi::Array& shape, const ffi::String& alias, + const ffi::Array& prims) { + ObjectPtr n = ffi::make_object(); n->name = std::move(name); n->alias = std::move(alias); n->dtype = std::move(dtype); @@ -49,13 +50,13 @@ MSCTensor::MSCTensor(const String& name, const DataType& dtype, const String& la } MSCTensor::MSCTensor(const JsonMSCTensor& j_tensor) { - ObjectPtr n = make_object(); + ObjectPtr n = ffi::make_object(); n->FromJson(j_tensor); data_ = std::move(n); } MSCTensor::MSCTensor(const std::string& json_str) { - ObjectPtr n = make_object(); + ObjectPtr n = ffi::make_object(); n->FromJson(json_str); data_ = std::move(n); } @@ -107,23 +108,23 @@ const Integer MSCTensorNode::DimAt(int index) const { return shape[v_index]; } -const Integer MSCTensorNode::DimAt(const String& axis) const { +const Integer MSCTensorNode::DimAt(const ffi::String& axis) const { auto index = layout.IndexOf(tvm::tir::LayoutAxis::Get(axis)); return DimAt(index); } -const String MSCTensorNode::PrimAt(int index) const { +const ffi::String 
MSCTensorNode::PrimAt(int index) const { if (prims.size() == 0) { return ""; } return prims[CommonUtils::GetIndex(index, Ndim())]; } -const String MSCTensorNode::PrimAt(const String& axis) const { +const ffi::String MSCTensorNode::PrimAt(const ffi::String& axis) const { return PrimAt(layout.IndexOf(tvm::tir::LayoutAxis::Get(axis))); } -int32_t MSCTensorNode::LayoutOf(const String& axis) const { +int32_t MSCTensorNode::LayoutOf(const ffi::String& axis) const { return layout.IndexOf(tvm::tir::LayoutAxis::Get(axis)); } @@ -135,7 +136,7 @@ const Integer MSCTensorNode::GetSize() const { return size; } -const String MSCTensorNode::DTypeName() const { return runtime::DLDataTypeToString(dtype); } +const ffi::String MSCTensorNode::DTypeName() const { return runtime::DLDataTypeToString(dtype); } size_t BaseJointNode::AddChild(const BaseJoint& child) const { for (size_t i = 0; i < children.size(); i++) { @@ -157,9 +158,9 @@ const BaseJoint BaseJointNode::ChildAt(int index) const { return Downcast(children[v_index]); } -bool BaseJointNode::HasAttr(const String& key) const { return attrs.count(key); } +bool BaseJointNode::HasAttr(const ffi::String& key) const { return attrs.count(key); } -bool BaseJointNode::GetAttr(const String& key, std::string* val) const { +bool BaseJointNode::GetAttr(const ffi::String& key, std::string* val) const { if (attrs.count(key) && attrs[key].size() > 0) { *val = attrs[key]; return true; @@ -167,7 +168,7 @@ bool BaseJointNode::GetAttr(const String& key, std::string* val) const { return false; } -bool BaseJointNode::GetAttr(const String& key, int* val) const { +bool BaseJointNode::GetAttr(const ffi::String& key, int* val) const { std::string val_str; if (GetAttr(key, &val_str)) { int pos = val_str.find(","); @@ -184,7 +185,7 @@ bool BaseJointNode::GetAttr(const String& key, int* val) const { return false; } -bool BaseJointNode::GetAttr(const String& key, int64_t* val) const { +bool BaseJointNode::GetAttr(const ffi::String& key, int64_t* val) const { std::string val_str; if (GetAttr(key, &val_str)) { try { @@ -197,7 +198,7 @@ bool BaseJointNode::GetAttr(const String& key, int64_t* val) const { return false; } -bool BaseJointNode::GetAttr(const String& key, float* val) const { +bool BaseJointNode::GetAttr(const ffi::String& key, float* val) const { std::string val_str; if (GetAttr(key, &val_str)) { try { @@ -210,7 +211,7 @@ bool BaseJointNode::GetAttr(const String& key, float* val) const { return false; } -bool BaseJointNode::GetAttr(const String& key, bool* val) const { +bool BaseJointNode::GetAttr(const ffi::String& key, bool* val) const { int val_int; if (GetAttr(key, &val_int)) { *val = (val_int != 0); @@ -219,7 +220,7 @@ bool BaseJointNode::GetAttr(const String& key, bool* val) const { return false; } -bool BaseJointNode::GetAttr(const String& key, std::vector* val) const { +bool BaseJointNode::GetAttr(const ffi::String& key, std::vector* val) const { std::string val_str; if (GetAttr(key, &val_str)) { int pos = val_str.find(","); @@ -238,7 +239,7 @@ bool BaseJointNode::GetAttr(const String& key, std::vector* val) co return false; } -bool BaseJointNode::GetAttr(const String& key, std::vector* val) const { +bool BaseJointNode::GetAttr(const ffi::String& key, std::vector* val) const { std::string val_str; if (GetAttr(key, &val_str)) { int pos = val_str.find(","); @@ -257,7 +258,7 @@ bool BaseJointNode::GetAttr(const String& key, std::vector* val) const { return false; } -bool BaseJointNode::GetAttr(const String& key, std::vector* val) const { +bool 
BaseJointNode::GetAttr(const ffi::String& key, std::vector* val) const { std::string val_str; if (GetAttr(key, &val_str)) { try { @@ -275,7 +276,7 @@ bool BaseJointNode::GetAttr(const String& key, std::vector* val) const } return false; } -bool BaseJointNode::GetAttr(const String& key, std::vector* val) const { +bool BaseJointNode::GetAttr(const ffi::String& key, std::vector* val) const { std::string val_str; if (GetAttr(key, &val_str)) { int pos = val_str.find(","); @@ -294,7 +295,7 @@ bool BaseJointNode::GetAttr(const String& key, std::vector* val) const { return false; } -bool BaseJointNode::GetAttr(const String& key, std::vector* val) const { +bool BaseJointNode::GetAttr(const ffi::String& key, std::vector* val) const { std::string val_str; if (GetAttr(key, &val_str)) { int pos = val_str.find(","); @@ -313,20 +314,22 @@ bool BaseJointNode::GetAttr(const String& key, std::vector* val) const { return false; } -MSCJoint::MSCJoint(int index, const String& name, const String& shared_ref, const String& optype, - const Map& attrs, const Array& scope, +MSCJoint::MSCJoint(int index, const ffi::String& name, const ffi::String& shared_ref, + const ffi::String& optype, const ffi::Map& attrs, + const ffi::Array& scope, const std::vector>& inputs, - const Array& outputs, const Map& weights) { - ObjectPtr n = make_object(); + const ffi::Array& outputs, + const ffi::Map& weights) { + ObjectPtr n = ffi::make_object(); n->index = index; n->name = std::move(name); n->shared_ref = std::move(shared_ref); n->optype = std::move(optype); n->attrs = std::move(attrs); n->scope = std::move(scope); - Array parents; - Array> array_inputs; - Array added_parents; + ffi::Array parents; + ffi::Array> array_inputs; + ffi::Array added_parents; for (const auto& pair : inputs) { // const auto& parent=Downcast(pair.first); const auto& p_name = pair.first->name; @@ -342,7 +345,7 @@ MSCJoint::MSCJoint(int index, const String& name, const String& shared_ref, cons added_parents.push_back(p_name); p_idx = added_parents.size() - 1; } - Array input{Integer(p_idx), Integer(pair.second)}; + ffi::Array input{Integer(p_idx), Integer(pair.second)}; array_inputs.push_back(input); } n->parents = std::move(parents); @@ -352,14 +355,14 @@ MSCJoint::MSCJoint(int index, const String& name, const String& shared_ref, cons data_ = std::move(n); } -MSCJoint::MSCJoint(const JsonMSCJoint& j_joint, const Map& nodes) { - ObjectPtr n = make_object(); +MSCJoint::MSCJoint(const JsonMSCJoint& j_joint, const ffi::Map& nodes) { + ObjectPtr n = ffi::make_object(); n->FromJson(j_joint, nodes); data_ = std::move(n); } -MSCJoint::MSCJoint(const std::string& json_str, const Map& nodes) { - ObjectPtr n = make_object(); +MSCJoint::MSCJoint(const std::string& json_str, const ffi::Map& nodes) { + ObjectPtr n = ffi::make_object(); n->FromJson(json_str, nodes); data_ = std::move(n); } @@ -397,7 +400,8 @@ const JsonMSCJoint MSCJointNode::ToJson() const { return j_joint; } -void MSCJointNode::FromJson(const JsonMSCJoint& j_joint, const Map& nodes) { +void MSCJointNode::FromJson(const JsonMSCJoint& j_joint, + const ffi::Map& nodes) { index = j_joint.index; name = j_joint.name; shared_ref = j_joint.shared_ref; @@ -413,7 +417,7 @@ void MSCJointNode::FromJson(const JsonMSCJoint& j_joint, const Map= 0) << "Can not find parent for " << in_name; - Array input{Integer(p_idx), Integer(std::stol(index_str))}; + ffi::Array input{Integer(p_idx), Integer(std::stol(index_str))}; inputs.push_back(input); } for (const auto& o : j_joint.outputs) { @@ -434,7 +438,8 @@ void 
MSCJointNode::FromJson(const JsonMSCJoint& j_joint, const Map& nodes) { +void MSCJointNode::FromJson(const std::string& json_str, + const ffi::Map& nodes) { std::istringstream is(json_str); dmlc::JSONReader reader(&is); JsonMSCJoint j_joint; @@ -449,8 +454,8 @@ const MSCTensor MSCJointNode::InputAt(int index) const { return ParentAt(p_idx->value)->OutputAt(out_idx->value); } -const Array MSCJointNode::GetInputs() const { - Array t_inputs; +const ffi::Array MSCJointNode::GetInputs() const { + ffi::Array t_inputs; for (size_t i = 0; i < inputs.size(); i++) { t_inputs.push_back(InputAt(i)); } @@ -462,15 +467,15 @@ const MSCTensor MSCJointNode::OutputAt(int index) const { return outputs[v_index]; } -const Array MSCJointNode::GetOutputs() const { - Array t_outputs; +const ffi::Array MSCJointNode::GetOutputs() const { + ffi::Array t_outputs; for (size_t i = 0; i < outputs.size(); i++) { t_outputs.push_back(OutputAt(i)); } return t_outputs; } -const MSCTensor MSCJointNode::WeightAt(const String& wtype) const { +const MSCTensor MSCJointNode::WeightAt(const ffi::String& wtype) const { ICHECK(weights.count(wtype)) << "Can not find " << wtype << " from weights"; return weights[wtype]; } @@ -490,7 +495,7 @@ const MSCJoint MSCJointNode::ProducerOf(int index) const { return pair.first; } -const MSCJoint MSCJointNode::ProducerOf(const String& input_name) const { +const MSCJoint MSCJointNode::ProducerOf(const ffi::String& input_name) const { const auto& pair = ProducerAndIdxOf(input_name); return pair.first; } @@ -505,7 +510,7 @@ const std::pair MSCJointNode::ProducerAndIdxOf(int index) cons return std::make_pair(ParentAt(p_idx->value), inputs[v_index][1]->value); } -const std::pair MSCJointNode::ProducerAndIdxOf(const String& name) const { +const std::pair MSCJointNode::ProducerAndIdxOf(const ffi::String& name) const { for (size_t i = 0; i < inputs.size(); i++) { if (InputAt(i)->name == name) { return ProducerAndIdxOf(i); @@ -518,9 +523,10 @@ const std::pair MSCJointNode::ProducerAndIdxOf(const MSCTensor return ProducerAndIdxOf(input->name); } -MSCPrim::MSCPrim(int index, const String& name, const String& optype, - const Array& parents, const Map& attrs) { - ObjectPtr n = make_object(); +MSCPrim::MSCPrim(int index, const ffi::String& name, const ffi::String& optype, + const ffi::Array& parents, + const ffi::Map& attrs) { + ObjectPtr n = ffi::make_object(); n->index = index; n->name = std::move(name); n->optype = std::move(optype); @@ -531,14 +537,14 @@ MSCPrim::MSCPrim(int index, const String& name, const String& optype, data_ = std::move(n); } -MSCPrim::MSCPrim(const JsonMSCPrim& j_prim, const Map& prims) { - ObjectPtr n = make_object(); +MSCPrim::MSCPrim(const JsonMSCPrim& j_prim, const ffi::Map& prims) { + ObjectPtr n = ffi::make_object(); n->FromJson(j_prim, prims); data_ = std::move(n); } -MSCPrim::MSCPrim(const std::string& json_str, const Map& prims) { - ObjectPtr n = make_object(); +MSCPrim::MSCPrim(const std::string& json_str, const ffi::Map& prims) { + ObjectPtr n = ffi::make_object(); n->FromJson(json_str, prims); data_ = std::move(n); } @@ -557,7 +563,8 @@ const JsonMSCPrim MSCPrimNode::ToJson() const { return j_prim; } -void MSCPrimNode::FromJson(const JsonMSCPrim& j_prim, const Map& prims) { +void MSCPrimNode::FromJson(const JsonMSCPrim& j_prim, + const ffi::Map& prims) { index = j_prim.index; name = j_prim.name; optype = j_prim.optype; @@ -570,7 +577,8 @@ void MSCPrimNode::FromJson(const JsonMSCPrim& j_prim, const Map& prims) { +void MSCPrimNode::FromJson(const std::string& json_str, + 
const ffi::Map& prims) { std::istringstream is(json_str); dmlc::JSONReader reader(&is); JsonMSCPrim j_prim; @@ -588,11 +596,12 @@ const MSCPrim MSCPrimNode::ChildAt(int index) const { return Downcast(children[v_index]); } -WeightJoint::WeightJoint(int index, const String& name, const String& shared_ref, - const String& weight_type, const MSCTensor& weight, - const Array parents, const Map& attrs, - const Array& friends) { - ObjectPtr n = make_object(); +WeightJoint::WeightJoint(int index, const ffi::String& name, const ffi::String& shared_ref, + const ffi::String& weight_type, const MSCTensor& weight, + const ffi::Array parents, + const ffi::Map& attrs, + const ffi::Array& friends) { + ObjectPtr n = ffi::make_object(); n->index = index; n->name = std::move(name); n->shared_ref = std::move(shared_ref); @@ -606,14 +615,16 @@ WeightJoint::WeightJoint(int index, const String& name, const String& shared_ref data_ = std::move(n); } -WeightJoint::WeightJoint(const JsonWeightJoint& j_joint, const Map& nodes) { - ObjectPtr n = make_object(); +WeightJoint::WeightJoint(const JsonWeightJoint& j_joint, + const ffi::Map& nodes) { + ObjectPtr n = ffi::make_object(); n->FromJson(j_joint, nodes); data_ = std::move(n); } -WeightJoint::WeightJoint(const std::string& json_str, const Map& nodes) { - ObjectPtr n = make_object(); +WeightJoint::WeightJoint(const std::string& json_str, + const ffi::Map& nodes) { + ObjectPtr n = ffi::make_object(); n->FromJson(json_str, nodes); data_ = std::move(n); } @@ -639,7 +650,7 @@ const JsonWeightJoint WeightJointNode::ToJson() const { } void WeightJointNode::FromJson(const JsonWeightJoint& j_joint, - const Map& nodes) { + const ffi::Map& nodes) { index = j_joint.index; name = j_joint.name; shared_ref = j_joint.shared_ref; @@ -654,7 +665,8 @@ void WeightJointNode::FromJson(const JsonWeightJoint& j_joint, } } -void WeightJointNode::FromJson(const std::string& json_str, const Map& nodes) { +void WeightJointNode::FromJson(const std::string& json_str, + const ffi::Map& nodes) { std::istringstream is(json_str); dmlc::JSONReader reader(&is); JsonWeightJoint j_joint; @@ -672,14 +684,14 @@ const WeightJoint WeightJointNode::ChildAt(int index) const { return Downcast(children[v_index]); } -const bool BaseGraphNode::HasNode(const String& name) const { +const bool BaseGraphNode::HasNode(const ffi::String& name) const { return nodes.count(name) ? 
true : false; } -MSCGraph::MSCGraph(const String& name, const Array& nodes, - const Array& input_names, const Array& output_names, - const Array& prims) { - ObjectPtr n = make_object(); +MSCGraph::MSCGraph(const ffi::String& name, const ffi::Array& nodes, + const ffi::Array& input_names, + const ffi::Array& output_names, const ffi::Array& prims) { + ObjectPtr n = ffi::make_object(); n->name = std::move(name); for (const auto& node : nodes) { n->node_names.push_back(node->name); @@ -696,13 +708,13 @@ MSCGraph::MSCGraph(const String& name, const Array& nodes, } MSCGraph::MSCGraph(const JsonMSCGraph& j_graph) { - ObjectPtr n = make_object(); + ObjectPtr n = ffi::make_object(); n->FromJson(j_graph); data_ = std::move(n); } MSCGraph::MSCGraph(const std::string& json_str) { - ObjectPtr n = make_object(); + ObjectPtr n = ffi::make_object(); n->FromJson(json_str); data_ = std::move(n); } @@ -735,7 +747,7 @@ void MSCGraphNode::FromJson(const JsonMSCGraph& j_graph) { for (const auto& o : j_graph.outputs) { output_names.push_back(o); } - Map loaded_nodes; + ffi::Map loaded_nodes; for (const auto& n : j_graph.nodes) { const auto& node = MSCJoint(n, loaded_nodes); loaded_nodes.Set(node->name, node); @@ -745,7 +757,7 @@ void MSCGraphNode::FromJson(const JsonMSCGraph& j_graph) { node_names.push_back(node->name); nodes.Set(node->name, node); } - Map loaded_prims; + ffi::Map loaded_prims; for (const auto& n : j_graph.prims) { const auto& prim = MSCPrim(n, loaded_prims); loaded_prims.Set(prim->name, prim); @@ -766,13 +778,13 @@ void MSCGraphNode::FromJson(const std::string& json_str) { FromJson(j_graph); } -const String MSCGraphNode::ToPrototxt() const { +const ffi::String MSCGraphNode::ToPrototxt() const { PrototxtPrinter printer; - printer.Append(Map{{"name", name}}); + printer.Append(ffi::Map{{"name", name}}); for (const auto& n : node_names) { const auto& node = FindNode(n); // define layer - std::vector> layer; + std::vector> layer; layer.push_back(std::make_pair("name", node->name)); layer.push_back(std::make_pair("type", StringUtils::Replace(node->optype, ".", "_"))); layer.push_back(std::make_pair("top", node->name)); @@ -780,7 +792,7 @@ const String MSCGraphNode::ToPrototxt() const { layer.push_back(std::make_pair("bottom", Downcast(p)->name)); } // define layer param - Map param; + ffi::Map param; param.Set("idx", Integer(node->index)); for (size_t i = 0; i < node->inputs.size(); i++) { param.Set("input_" + std::to_string(i), node->InputAt(i)); @@ -796,17 +808,17 @@ const String MSCGraphNode::ToPrototxt() const { } layer.push_back(std::make_pair("layer_param", PrototxtPrinter::ToDictDoc(param))); // Append the layer Map - printer.Append(Map{{"layer", PrototxtPrinter::ToDictDoc(layer)}}); + printer.Append(ffi::Map{{"layer", PrototxtPrinter::ToDictDoc(layer)}}); } return printer.GetString(); } -const MSCJoint MSCGraphNode::FindNode(const String& name) const { +const MSCJoint MSCGraphNode::FindNode(const ffi::String& name) const { ICHECK(nodes.count(name)) << "Can not find node " << name; return Downcast(nodes[name]); } -const MSCPrim MSCGraphNode::FindPrim(const String& name) const { +const MSCPrim MSCGraphNode::FindPrim(const ffi::String& name) const { ICHECK(prims.count(name)) << "Can not find prim " << name; return prims[name]; } @@ -816,8 +828,8 @@ const MSCTensor MSCGraphNode::InputAt(int index) const { return FindTensor(input_names[v_index]); } -const Array MSCGraphNode::GetInputs() const { - Array t_inputs; +const ffi::Array MSCGraphNode::GetInputs() const { + ffi::Array t_inputs; for (size_t 
i = 0; i < input_names.size(); i++) { t_inputs.push_back(InputAt(i)); } @@ -829,25 +841,25 @@ const MSCTensor MSCGraphNode::OutputAt(int index) const { return FindTensor(output_names[v_index]); } -const Array MSCGraphNode::GetOutputs() const { - Array t_outputs; +const ffi::Array MSCGraphNode::GetOutputs() const { + ffi::Array t_outputs; for (size_t i = 0; i < output_names.size(); i++) { t_outputs.push_back(OutputAt(i)); } return t_outputs; } -const Array MSCGraphNode::GetEntries() const { - Array entries; +const ffi::Array MSCGraphNode::GetEntries() const { + ffi::Array entries; for (size_t i = 0; i < input_names.size(); i++) { entries.push_back(FindProducer(input_names[i])); } return entries; } -const Array MSCGraphNode::GetExits() const { - Array exits; - std::set setted_exits; +const ffi::Array MSCGraphNode::GetExits() const { + ffi::Array exits; + std::set setted_exits; for (size_t i = 0; i < output_names.size(); i++) { const auto& exit = FindProducer(output_names[i]); if (setted_exits.count(exit->name)) { @@ -859,18 +871,18 @@ const Array MSCGraphNode::GetExits() const { return exits; } -const bool MSCGraphNode::HasTensor(const String& name) const { - const String& tensor_name = tensor_alias.count(name) ? tensor_alias[name] : name; +const bool MSCGraphNode::HasTensor(const ffi::String& name) const { + const ffi::String& tensor_name = tensor_alias.count(name) ? tensor_alias[name] : name; if (weight_holders.count(tensor_name)) { return true; } - String host, index; + ffi::String host, index; std::tie(host, index) = StringUtils::SplitOnce(tensor_name, ":"); return nodes.count(host) > 0 ? true : false; } -const MSCTensor MSCGraphNode::FindTensor(const String& name) const { - const String& tensor_name = tensor_alias.count(name) ? tensor_alias[name] : name; +const MSCTensor MSCGraphNode::FindTensor(const ffi::String& name) const { + const ffi::String& tensor_name = tensor_alias.count(name) ? tensor_alias[name] : name; if (weight_holders.count(tensor_name)) { const auto& node = FindNode(weight_holders[tensor_name][0]); for (const auto& pair : node->weights) { @@ -884,8 +896,8 @@ const MSCTensor MSCGraphNode::FindTensor(const String& name) const { return pair.first->OutputAt(pair.second); } -const MSCJoint MSCGraphNode::FindProducer(const String& name) const { - const String& tensor_name = tensor_alias.count(name) ? tensor_alias[name] : name; +const MSCJoint MSCGraphNode::FindProducer(const ffi::String& name) const { + const ffi::String& tensor_name = tensor_alias.count(name) ? tensor_alias[name] : name; if (weight_holders.count(tensor_name)) { return FindNode(weight_holders[tensor_name][0]); } @@ -897,10 +909,10 @@ const MSCJoint MSCGraphNode::FindProducer(const MSCTensor& tensor) const { return FindProducer(tensor->name); } -const std::pair MSCGraphNode::FindProducerAndIdx(const String& name) const { - const String& tensor_name = tensor_alias.count(name) ? tensor_alias[name] : name; +const std::pair MSCGraphNode::FindProducerAndIdx(const ffi::String& name) const { + const ffi::String& tensor_name = tensor_alias.count(name) ? 
tensor_alias[name] : name; ICHECK(!weight_holders.count(tensor_name)) << "Weight " << name << " has no producer with index"; - String host, index; + ffi::String host, index; std::tie(host, index) = StringUtils::SplitOnce(tensor_name, ":"); if (index.size() == 0) { const auto& node = FindNode(host); @@ -914,9 +926,9 @@ const std::pair MSCGraphNode::FindProducerAndIdx(const MSCTens return FindProducerAndIdx(tensor->name); } -const Array MSCGraphNode::FindConsumers(const String& name) const { - Array consumers; - const String& tensor_name = tensor_alias.count(name) ? tensor_alias[name] : name; +const ffi::Array MSCGraphNode::FindConsumers(const ffi::String& name) const { + ffi::Array consumers; + const ffi::String& tensor_name = tensor_alias.count(name) ? tensor_alias[name] : name; if (weight_holders.count(tensor_name)) { for (const auto& h : weight_holders[tensor_name]) { consumers.push_back(FindNode(h)); @@ -930,13 +942,13 @@ const Array MSCGraphNode::FindConsumers(const String& name) const { return consumers; } -const Array MSCGraphNode::FindConsumers(const MSCTensor& tensor) const { +const ffi::Array MSCGraphNode::FindConsumers(const MSCTensor& tensor) const { return FindConsumers(tensor->name); } const std::vector> MSCGraphNode::FindConsumersAndIndices( - const String& name) const { - const String& tensor_name = tensor_alias.count(name) ? tensor_alias[name] : name; + const ffi::String& name) const { + const ffi::String& tensor_name = tensor_alias.count(name) ? tensor_alias[name] : name; ICHECK(!weight_holders.count(tensor_name)) << "Weight has no index"; std::vector> consumers; for (const auto& c : FindConsumers(name)) { @@ -987,11 +999,11 @@ void MSCGraphNode::AnalysisGraph() { for (const auto& pair : node->weights) { const auto& w_name = pair.second->name; if (weight_holders.count(w_name)) { - Array holders = weight_holders[w_name]; + ffi::Array holders = weight_holders[w_name]; holders.push_back(n); weight_holders.Set(w_name, holders); } else { - weight_holders.Set(w_name, Array({n})); + weight_holders.Set(w_name, ffi::Array({n})); if (pair.second->alias.size() > 0) { tensor_alias.Set(pair.second->alias, pair.second->name); } @@ -1000,28 +1012,30 @@ void MSCGraphNode::AnalysisGraph() { } } -WeightGraph::WeightGraph(const MSCGraph& graph, const Map>& main_wtypes, - const Map& relation_wtypes) { - ObjectPtr n = make_object(); +WeightGraph::WeightGraph(const MSCGraph& graph, + const ffi::Map>& main_wtypes, + const ffi::Map& relation_wtypes) { + ObjectPtr n = ffi::make_object(); n->name = graph->name + "_weights"; n->Build(graph, main_wtypes, relation_wtypes); data_ = std::move(n); } WeightGraph::WeightGraph(const JsonWeightGraph& j_graph) { - ObjectPtr n = make_object(); + ObjectPtr n = ffi::make_object(); n->FromJson(j_graph); data_ = std::move(n); } WeightGraph::WeightGraph(const std::string& json_str) { - ObjectPtr n = make_object(); + ObjectPtr n = ffi::make_object(); n->FromJson(json_str); data_ = std::move(n); } -void WeightGraphNode::Build(const MSCGraph& graph, const Map>& main_wtypes, - const Map& relation_wtypes) { +void WeightGraphNode::Build(const MSCGraph& graph, + const ffi::Map>& main_wtypes, + const ffi::Map& relation_wtypes) { auto sort_nodes = [&graph](const BaseJoint& node_a, const BaseJoint& node_b) { return graph->FindProducer(node_a->name)->index < graph->FindProducer(node_b->name)->index; }; @@ -1058,7 +1072,7 @@ void WeightGraphNode::Build(const MSCGraph& graph, const Map parents_array; + ffi::Array parents_array; if (parents.size() > 1) { 
std::sort(parents.begin(), parents.end(), sort_nodes); } @@ -1089,7 +1103,7 @@ void WeightGraphNode::Build(const MSCGraph& graph, const Mapoptype]) { if (node->weights.count(wtype)) { const auto& weight = node->WeightAt(wtype); - Map attrs; + ffi::Map attrs; attrs.Set("producer_type", node->optype); attrs.Set("weight_strategy", "main"); const auto& w_node = @@ -1104,7 +1118,7 @@ void WeightGraphNode::Build(const MSCGraph& graph, const Mapweights) { if (!nodes.count(pair.second->name)) { - Map attrs; + ffi::Map attrs; attrs.Set("producer_type", node->optype); attrs.Set("weight_strategy", "follow"); const auto& w_node = WeightJoint(node_names.size(), pair.second->name, "", pair.first, @@ -1116,7 +1130,7 @@ void WeightGraphNode::Build(const MSCGraph& graph, const Mapoptype)) { const auto& tensor = node->OutputAt(0); - Map attrs; + ffi::Map attrs; attrs.Set("producer_type", node->optype); if (node->optype == "reshape") { // TODO(archermmt): check non-passby reshape @@ -1134,7 +1148,7 @@ void WeightGraphNode::Build(const MSCGraph& graph, const Mapweights.size() > 0) { for (const auto& pair : node->weights) { if (!nodes.count(pair.second->name)) { - Map attrs; + ffi::Map attrs; attrs.Set("producer_type", node->optype); attrs.Set("weight_strategy", "follow"); const auto& w_node = WeightJoint(node_names.size(), pair.second->name, "", pair.first, @@ -1151,7 +1165,7 @@ void WeightGraphNode::Build(const MSCGraph& graph, const Map(nodes[name]); } @@ -1168,7 +1182,7 @@ const JsonWeightGraph WeightGraphNode::ToJson() const { void WeightGraphNode::FromJson(const JsonWeightGraph& j_graph) { name = j_graph.name; - Map loaded_nodes; + ffi::Map loaded_nodes; for (const auto& n : j_graph.nodes) { const auto& node = WeightJoint(n, loaded_nodes); loaded_nodes.Set(node->name, node); @@ -1196,13 +1210,13 @@ void WeightGraphNode::FromJson(const std::string& json_str) { FromJson(j_graph); } -const String WeightGraphNode::ToPrototxt() const { +const ffi::String WeightGraphNode::ToPrototxt() const { PrototxtPrinter printer; - printer.Append(Map{{"name", name}}); + printer.Append(ffi::Map{{"name", name}}); for (const auto& n : node_names) { const auto& node = FindNode(n); // define layer - std::vector> layer; + std::vector> layer; layer.push_back(std::make_pair("name", node->name)); layer.push_back(std::make_pair("type", node->weight_type)); layer.push_back(std::make_pair("top", node->name)); @@ -1210,7 +1224,7 @@ const String WeightGraphNode::ToPrototxt() const { layer.push_back(std::make_pair("bottom", Downcast(p)->name)); } // define layer param - Map param; + ffi::Map param; param.Set("idx", Integer(node->index)); param.Set("weight", node->weight); for (size_t i = 0; i < node->friends.size(); i++) { @@ -1221,14 +1235,15 @@ const String WeightGraphNode::ToPrototxt() const { } layer.push_back(std::make_pair("layer_param", PrototxtPrinter::ToDictDoc(param))); // Append the layer Map - printer.Append(Map{{"layer", PrototxtPrinter::ToDictDoc(layer)}}); + printer.Append(ffi::Map{{"layer", PrototxtPrinter::ToDictDoc(layer)}}); } return printer.GetString(); } -MSCGraph PruneWeights(const MSCGraph& graph, const Map& pruned_tensors) { - Array nodes; - std::unordered_map> inputs_map; +MSCGraph PruneWeights(const MSCGraph& graph, + const ffi::Map& pruned_tensors) { + ffi::Array nodes; + std::unordered_map> inputs_map; for (const auto& name : graph->node_names) { const auto& node = graph->FindNode(name); // define inputs @@ -1238,20 +1253,20 @@ MSCGraph PruneWeights(const MSCGraph& graph, const Map& prune 
inputs.push_back(inputs_map[input->name]); } // define outputs - Array outputs; + ffi::Array outputs; for (const auto& out : node->outputs) { const auto& output = pruned_tensors.count(out->name) ? pruned_tensors[out->name] : out; outputs.push_back(output); } // define weights - Map weights; + ffi::Map weights; for (const auto& pair : node->weights) { const auto& weight = pruned_tensors.count(pair.second->name) ? pruned_tensors[pair.second->name] : pair.second; weights.Set(pair.first, weight); } // define attributes - Map attrs = node->attrs; + ffi::Map attrs = node->attrs; if (node->optype == "reshape" && attrs.count("shape") && pruned_tensors.count(node->OutputAt(0)->name)) { const auto& new_shape = pruned_tensors[node->OutputAt(0)->name]->shape; @@ -1268,7 +1283,7 @@ MSCGraph PruneWeights(const MSCGraph& graph, const Map& prune Downcast(p)->AddChild(new_node); } } - Array prims; + ffi::Array prims; for (const auto& name : graph->prim_names) { prims.push_back(graph->FindPrim(name)); } @@ -1436,13 +1451,13 @@ TVM_FFI_STATIC_INIT_BLOCK({ namespace refl = tvm::ffi::reflection; refl::GlobalDef() .def("msc.core.MSCTensor", - [](const String& name, const DataType& dtype, const String& layout, - const Array& shape, const String& alias, - const Array& prims) -> MSCTensor { + [](const ffi::String& name, const DataType& dtype, const ffi::String& layout, + const ffi::Array& shape, const ffi::String& alias, + const ffi::Array& prims) -> MSCTensor { return MSCTensor(name, dtype, layout, shape, alias, prims); }) .def("msc.core.MSCTensorToJson", - [](const MSCTensor& tensor) -> String { + [](const MSCTensor& tensor) -> ffi::String { const auto& tensor_json = tensor->ToJson(); std::ostringstream os; dmlc::JSONWriter writer(&os); @@ -1450,12 +1465,13 @@ TVM_FFI_STATIC_INIT_BLOCK({ return os.str(); }) .def("msc.core.MSCTensorFromJson", - [](const String& tensor_json) -> MSCTensor { return MSCTensor(tensor_json); }) + [](const ffi::String& tensor_json) -> MSCTensor { return MSCTensor(tensor_json); }) .def("msc.core.MSCJoint", - [](Integer index, const String& name, const String& shared_ref, const String& optype, - const Map& attrs, const Array& scope, - const Array& parents, const Array out_indices, - const Array& outputs, const Map& weights) -> MSCJoint { + [](Integer index, const ffi::String& name, const ffi::String& shared_ref, + const ffi::String& optype, const ffi::Map& attrs, + const ffi::Array& scope, const ffi::Array& parents, + const ffi::Array out_indices, const ffi::Array& outputs, + const ffi::Map& weights) -> MSCJoint { std::vector> inputs; for (size_t i = 0; i < parents.size(); i++) { inputs.push_back(std::make_pair(parents[i], out_indices[i]->value)); @@ -1464,19 +1480,21 @@ TVM_FFI_STATIC_INIT_BLOCK({ weights); }) .def("msc.core.MSCPrim", - [](Integer index, const String& name, const String& optype, - const Map& attrs, const Array& parents) -> MSCPrim { - Array b_parents; + [](Integer index, const ffi::String& name, const ffi::String& optype, + const ffi::Map& attrs, + const ffi::Array& parents) -> MSCPrim { + ffi::Array b_parents; for (const auto& p : parents) { b_parents.push_back(p); } return MSCPrim(index->value, name, optype, b_parents, attrs); }) .def("msc.core.WeightJoint", - [](Integer index, const String& name, const String& shared_ref, - const String& weight_type, const MSCTensor& weight, const Array parents, - const Map& attrs, const Array& friends) -> WeightJoint { - Array b_parents, b_friends; + [](Integer index, const ffi::String& name, const ffi::String& shared_ref, + const 
ffi::String& weight_type, const MSCTensor& weight, + const ffi::Array parents, const ffi::Map& attrs, + const ffi::Array& friends) -> WeightJoint { + ffi::Array b_parents, b_friends; for (const auto& p : parents) { b_parents.push_back(p); } @@ -1486,16 +1504,21 @@ TVM_FFI_STATIC_INIT_BLOCK({ return WeightJoint(index->value, name, shared_ref, weight_type, weight, b_parents, attrs, b_friends); }) - .def("msc.core.WeightJointSetAttr", [](const WeightJoint& node, const String& key, - const String& value) { node->attrs.Set(key, value); }) + .def("msc.core.WeightJointSetAttr", + [](const WeightJoint& node, const ffi::String& key, const ffi::String& value) { + node->attrs.Set(key, value); + }) .def("msc.core.MSCGraph", - [](const String& name, const Array& nodes, const Array& input_names, - const Array& output_names, const Array& prims) -> MSCGraph { + [](const ffi::String& name, const ffi::Array& nodes, + const ffi::Array& input_names, + const ffi::Array& output_names, + const ffi::Array& prims) -> MSCGraph { return MSCGraph(name, nodes, input_names, output_names, prims); }) .def("msc.core.WeightGraph", - [](const MSCGraph& graph, const Map>& main_wtypes, - const Map& relation_wtypes) -> WeightGraph { + [](const MSCGraph& graph, + const ffi::Map>& main_wtypes, + const ffi::Map& relation_wtypes) -> WeightGraph { return WeightGraph(graph, main_wtypes, relation_wtypes); }); }); @@ -1505,36 +1528,36 @@ TVM_FFI_STATIC_INIT_BLOCK({ namespace refl = tvm::ffi::reflection; refl::GlobalDef() .def("msc.core.MSCGraphHasNode", - [](const MSCGraph& graph, const String& name) -> Bool { + [](const MSCGraph& graph, const ffi::String& name) -> Bool { return Bool(graph->HasNode(name)); }) .def("msc.core.MSCGraphFindNode", - [](const MSCGraph& graph, const String& name) -> MSCJoint { + [](const MSCGraph& graph, const ffi::String& name) -> MSCJoint { return graph->FindNode(name); }) .def("msc.core.MSCGraphFindPrim", - [](const MSCGraph& graph, const String& name) -> MSCPrim { + [](const MSCGraph& graph, const ffi::String& name) -> MSCPrim { return graph->FindPrim(name); }) .def("msc.core.MSCGraphHasTensor", - [](const MSCGraph& graph, const String& name) -> Bool { + [](const MSCGraph& graph, const ffi::String& name) -> Bool { return Bool(graph->HasTensor(name)); }) .def("msc.core.MSCGraphFindTensor", - [](const MSCGraph& graph, const String& name) -> MSCTensor { + [](const MSCGraph& graph, const ffi::String& name) -> MSCTensor { return graph->FindTensor(name); }) .def("msc.core.MSCGraphSetTensorAlias", - [](const MSCGraph& graph, const MSCTensor& tensor, const String& alias) { + [](const MSCGraph& graph, const MSCTensor& tensor, const ffi::String& alias) { tensor->alias = alias; graph->tensor_alias.Set(alias, tensor->name); }) .def("msc.core.MSCGraphFindProducer", - [](const MSCGraph& graph, const String& name) -> MSCJoint { + [](const MSCGraph& graph, const ffi::String& name) -> MSCJoint { return graph->FindProducer(name); }) .def("msc.core.MSCGraphFindConsumers", - [](const MSCGraph& graph, const String& name) -> Array { + [](const MSCGraph& graph, const ffi::String& name) -> ffi::Array { return graph->FindConsumers(name); }) .def("msc.core.MSCGraphInputAt", @@ -1542,11 +1565,11 @@ TVM_FFI_STATIC_INIT_BLOCK({ .def("msc.core.MSCGraphOutputAt", [](const MSCGraph& graph, int index) -> MSCTensor { return graph->OutputAt(index); }) .def("msc.core.MSCGraphGetInputs", - [](const MSCGraph& graph) -> Array { return graph->GetInputs(); }) + [](const MSCGraph& graph) -> ffi::Array { return graph->GetInputs(); }) 
.def("msc.core.MSCGraphGetOutputs", - [](const MSCGraph& graph) -> Array { return graph->GetOutputs(); }) + [](const MSCGraph& graph) -> ffi::Array { return graph->GetOutputs(); }) .def("msc.core.MSCGraphToJson", - [](const MSCGraph& graph) -> String { + [](const MSCGraph& graph) -> ffi::String { const auto& graph_json = graph->ToJson(); std::ostringstream os; dmlc::JSONWriter writer(&os); @@ -1554,9 +1577,9 @@ TVM_FFI_STATIC_INIT_BLOCK({ return os.str(); }) .def("msc.core.MSCGraphFromJson", - [](const String& graph_json) -> MSCGraph { return MSCGraph(graph_json); }) + [](const ffi::String& graph_json) -> MSCGraph { return MSCGraph(graph_json); }) .def("msc.core.MSCGraphToPrototxt", - [](const MSCGraph& graph) -> String { return graph->ToPrototxt(); }); + [](const MSCGraph& graph) -> ffi::String { return graph->ToPrototxt(); }); }); // Weight Graph APIS @@ -1564,15 +1587,15 @@ TVM_FFI_STATIC_INIT_BLOCK({ namespace refl = tvm::ffi::reflection; refl::GlobalDef() .def("msc.core.WeightGraphHasNode", - [](const WeightGraph& graph, const String& name) -> Bool { + [](const WeightGraph& graph, const ffi::String& name) -> Bool { return Bool(graph->HasNode(name)); }) .def("msc.core.WeightGraphFindNode", - [](const WeightGraph& graph, const String& name) -> WeightJoint { + [](const WeightGraph& graph, const ffi::String& name) -> WeightJoint { return graph->FindNode(name); }) .def("msc.core.WeightGraphToJson", - [](const WeightGraph& graph) -> String { + [](const WeightGraph& graph) -> ffi::String { const auto& graph_json = graph->ToJson(); std::ostringstream os; dmlc::JSONWriter writer(&os); @@ -1580,47 +1603,49 @@ TVM_FFI_STATIC_INIT_BLOCK({ return os.str(); }) .def("msc.core.WeightGraphFromJson", - [](const String& graph_json) -> WeightGraph { return WeightGraph(graph_json); }) + [](const ffi::String& graph_json) -> WeightGraph { return WeightGraph(graph_json); }) .def("msc.core.WeightGraphToPrototxt", - [](const WeightGraph& graph) -> String { return graph->ToPrototxt(); }) + [](const WeightGraph& graph) -> ffi::String { return graph->ToPrototxt(); }) .def("msc.core.MSCJointInputAt", [](const MSCJoint& node, int index) -> MSCTensor { return node->InputAt(index); }) .def("msc.core.MSCJointOutputAt", [](const MSCJoint& node, int index) -> MSCTensor { return node->OutputAt(index); }) .def("msc.core.MSCJointWeightAt", - [](const MSCJoint& node, const String& wtype) -> MSCTensor { + [](const MSCJoint& node, const ffi::String& wtype) -> MSCTensor { return node->WeightAt(wtype); }) .def("msc.core.MSCJointGetInputs", - [](const MSCJoint& node) -> Array { return node->GetInputs(); }) + [](const MSCJoint& node) -> ffi::Array { return node->GetInputs(); }) .def("msc.core.MSCJointGetOutputs", - [](const MSCJoint& node) -> Array { return node->GetOutputs(); }) + [](const MSCJoint& node) -> ffi::Array { return node->GetOutputs(); }) .def("msc.core.MSCJointGetWeights", - [](const MSCJoint& node) -> Map { return node->weights; }) + [](const MSCJoint& node) -> ffi::Map { return node->weights; }) .def("msc.core.MSCJointHasAttr", - [](const MSCJoint& node, const String& key) -> Bool { return Bool(node->HasAttr(key)); }) + [](const MSCJoint& node, const ffi::String& key) -> Bool { + return Bool(node->HasAttr(key)); + }) .def("msc.core.MSCJointGetAttrs", - [](const MSCJoint& node) -> Map { return node->attrs; }) + [](const MSCJoint& node) -> ffi::Map { return node->attrs; }) .def("msc.core.WeightJointHasAttr", - [](const WeightJoint& node, const String& key) -> Bool { + [](const WeightJoint& node, const ffi::String& 
key) -> Bool { return Bool(node->HasAttr(key)); }) - .def("msc.core.WeightJointGetAttrs", - [](const WeightJoint& node) -> Map { return node->attrs; }) + .def( + "msc.core.WeightJointGetAttrs", + [](const WeightJoint& node) -> ffi::Map { return node->attrs; }) .def("msc.core.MSCTensorDTypeName", - [](const MSCTensor& tensor) -> String { return tensor->DTypeName(); }) + [](const MSCTensor& tensor) -> ffi::String { return tensor->DTypeName(); }) .def("msc.core.MSCTensorDimAt", - [](const MSCTensor& tensor, const String& axis) -> Integer { + [](const MSCTensor& tensor, const ffi::String& axis) -> Integer { return tensor->DimAt(axis); }) .def("msc.core.MSCTensorGetSize", [](const MSCTensor& tensor) -> Integer { return tensor->GetSize(); }) .def("msc.core.MSCTensorSetAlias", - [](const MSCTensor& tensor, const String& alias) { tensor->alias = alias; }) + [](const MSCTensor& tensor, const ffi::String& alias) { tensor->alias = alias; }) .def("msc.core.PruneWeights", - [](const MSCGraph& graph, const Map& pruned_tensors) -> MSCGraph { - return PruneWeights(graph, pruned_tensors); - }); + [](const MSCGraph& graph, const ffi::Map& pruned_tensors) + -> MSCGraph { return PruneWeights(graph, pruned_tensors); }); }); } // namespace msc diff --git a/src/contrib/msc/core/ir/graph.h b/src/contrib/msc/core/ir/graph.h index a8587a2e5ed8..46da84dc03b8 100644 --- a/src/contrib/msc/core/ir/graph.h +++ b/src/contrib/msc/core/ir/graph.h @@ -342,17 +342,17 @@ struct JsonWeightGraph { class MSCTensorNode : public Object { public: /*! \brief The name of tensor. */ - String name; + ffi::String name; /*! \brief The alias of tensor, can be changed. */ - mutable String alias; + mutable ffi::String alias; /*! \brief The data type of tensor. */ DataType dtype; /*! \brief The layout of tensor. */ tvm::tir::Layout layout; /*! \brief The shape of tensor. */ - Array shape; + ffi::Array shape; /*! \brief The prims of tensor. */ - Array prims; + ffi::Array prims; /*! \brief Export tensor to json. */ const JsonMSCTensor ToJson() const; /*! \brief Load tensor from json struct. */ @@ -364,17 +364,17 @@ class MSCTensorNode : public Object { /*! \brief Get dim at given index. */ const Integer DimAt(int index) const; /*! \brief Get dim at given axis. */ - const Integer DimAt(const String& axis) const; + const Integer DimAt(const ffi::String& axis) const; /*! \brief Get prim at given index. */ - const String PrimAt(int index) const; + const ffi::String PrimAt(int index) const; /*! \brief Get prim at given axis. */ - const String PrimAt(const String& axis) const; + const ffi::String PrimAt(const ffi::String& axis) const; /*! \brief Get layout index of given axis. */ - int32_t LayoutOf(const String& axis) const; + int32_t LayoutOf(const ffi::String& axis) const; /*! \brief Get size of the tensor. */ const Integer GetSize() const; /*! \brief Get name of the dtype. */ - const String DTypeName() const; + const ffi::String DTypeName() const; static void RegisterReflection() { namespace refl = tvm::ffi::reflection; @@ -407,9 +407,9 @@ class MSCTensor : public ObjectRef { * \param alias The alias of the tensor. * \param prims The prims of the tensor shape. */ - TVM_DLL MSCTensor(const String& name, const DataType& dtype, const String& layout, - const Array& shape, const String& alias = "", - const Array& prims = Array()); + TVM_DLL MSCTensor(const ffi::String& name, const DataType& dtype, const ffi::String& layout, + const ffi::Array& shape, const ffi::String& alias = "", + const ffi::Array& prims = ffi::Array()); /*! 
* \brief The json constructor. @@ -435,15 +435,15 @@ class BaseJointNode : public Object { /*! \brief The index of node, can be changed. */ mutable int index; /*! \brief The name of node. */ - String name; + ffi::String name; /*! \brief The shared_ref of node, can be changed. */ - String shared_ref; + ffi::String shared_ref; /*! \brief The attributes of node. */ - mutable Map attrs; + mutable ffi::Map attrs; /*! \brief The parents of node. */ - Array parents; + ffi::Array parents; /*! \brief The children of node. */ - mutable Array children; + mutable ffi::Array children; /*! \brief Add child to the node. */ size_t AddChild(const BaseJoint& child) const; /*! \brief Get parent from the node. */ @@ -451,27 +451,27 @@ class BaseJointNode : public Object { /*! \brief Get child from the node. */ const BaseJoint ChildAt(int index) const; /*! \brief Check if has the attribute. */ - bool HasAttr(const String& key) const; + bool HasAttr(const ffi::String& key) const; /*! \brief Get the attribute by type. */ - bool GetAttr(const String& key, std::string* val) const; - bool GetAttr(const String& key, int* val) const; - bool GetAttr(const String& key, int64_t* val) const; - bool GetAttr(const String& key, float* val) const; - bool GetAttr(const String& key, bool* val) const; - bool GetAttr(const String& key, std::vector* val) const; - bool GetAttr(const String& key, std::vector* val) const; - bool GetAttr(const String& key, std::vector* val) const; - bool GetAttr(const String& key, std::vector* val) const; - bool GetAttr(const String& key, std::vector* val) const; + bool GetAttr(const ffi::String& key, std::string* val) const; + bool GetAttr(const ffi::String& key, int* val) const; + bool GetAttr(const ffi::String& key, int64_t* val) const; + bool GetAttr(const ffi::String& key, float* val) const; + bool GetAttr(const ffi::String& key, bool* val) const; + bool GetAttr(const ffi::String& key, std::vector* val) const; + bool GetAttr(const ffi::String& key, std::vector* val) const; + bool GetAttr(const ffi::String& key, std::vector* val) const; + bool GetAttr(const ffi::String& key, std::vector* val) const; + bool GetAttr(const ffi::String& key, std::vector* val) const; /*! \brief Check and get the attribute by type. */ template - const T GetTypeAttr(const String& key) const { + const T GetTypeAttr(const ffi::String& key) const { T val; ICHECK(GetAttr(key, &val)) << "Can not get attr " << key; return val; } template - const std::vector GetTypeArrayAttr(const String& key) const { + const std::vector GetTypeArrayAttr(const ffi::String& key) const { std::vector val; ICHECK(GetAttr(key, &val)) << "Can not get attr " << key; return val; @@ -510,42 +510,42 @@ class MSCJoint; class MSCJointNode : public BaseJointNode { public: /*! \brief The op type of node. */ - String optype; + ffi::String optype; /*! \brief The scope of node. */ - Array scope; + ffi::Array scope; /*! \brief The inputs of node, can be changed. */ - Array> inputs; + ffi::Array> inputs; /*! \brief The outputs of node. */ - Array outputs; + ffi::Array outputs; /*! \brief The weights of node. */ - Map weights; + ffi::Map weights; /*! \brief Export node to json. */ const JsonMSCJoint ToJson() const; /*! \brief Load node from json struct. */ - void FromJson(const JsonMSCJoint& j_joint, const Map& nodes); + void FromJson(const JsonMSCJoint& j_joint, const ffi::Map& nodes); /*! \brief Load node from json string. */ - void FromJson(const std::string& json_str, const Map& nodes); + void FromJson(const std::string& json_str, const ffi::Map& nodes); /*! 
\brief Get input from the node. */ const MSCTensor InputAt(int index) const; /*! \brief Get inputs from the node. */ - const Array GetInputs() const; + const ffi::Array GetInputs() const; /*! \brief Get output from the node. */ const MSCTensor OutputAt(int index) const; /*! \brief Get outputs from the node. */ - const Array GetOutputs() const; + const ffi::Array GetOutputs() const; /*! \brief Get weight from the node. */ - const MSCTensor WeightAt(const String& wtype) const; + const MSCTensor WeightAt(const ffi::String& wtype) const; /*! \brief Get parent from the node. */ const MSCJoint ParentAt(int index) const; /*! \brief Get child from the node. */ const MSCJoint ChildAt(int index) const; /*! \brief Get Producer of the input. */ const MSCJoint ProducerOf(int index) const; - const MSCJoint ProducerOf(const String& input_name) const; + const MSCJoint ProducerOf(const ffi::String& input_name) const; const MSCJoint ProducerOf(const MSCTensor& input) const; /*! \brief Get Producer and out index of the input. */ const std::pair ProducerAndIdxOf(int index) const; - const std::pair ProducerAndIdxOf(const String& name) const; + const std::pair ProducerAndIdxOf(const ffi::String& name) const; const std::pair ProducerAndIdxOf(const MSCTensor& input) const; static void RegisterReflection() { @@ -580,22 +580,24 @@ class MSCJoint : public BaseJoint { * \param outputs The outputs of the node. * \param weights The weights of the node. */ - TVM_DLL MSCJoint(int index, const String& name, const String& shared_ref, const String& optype, - const Map& attrs, const Array& scope, + TVM_DLL MSCJoint(int index, const ffi::String& name, const ffi::String& shared_ref, + const ffi::String& optype, const ffi::Map& attrs, + const ffi::Array& scope, const std::vector>& inputs, - const Array& outputs, const Map& weights); + const ffi::Array& outputs, + const ffi::Map& weights); /*! * \brief The json constructor. * \param j_joint The json describe of the node. */ - TVM_DLL MSCJoint(const JsonMSCJoint& j_joint, const Map& nodes); + TVM_DLL MSCJoint(const JsonMSCJoint& j_joint, const ffi::Map& nodes); /*! * \brief The json constructor. * \param json_str The json describe of the node. */ - TVM_DLL MSCJoint(const std::string& json_str, const Map& nodes); + TVM_DLL MSCJoint(const std::string& json_str, const ffi::Map& nodes); /*! \brief Clone the node. */ TVM_DLL static const MSCJoint Clone(const MSCJoint& node, @@ -611,13 +613,13 @@ class MSCPrim; class MSCPrimNode : public BaseJointNode { public: /*! \brief The op of prim. */ - String optype; + ffi::String optype; /*! \brief Export prim to json. */ const JsonMSCPrim ToJson() const; /*! \brief Load prim from json struct. */ - void FromJson(const JsonMSCPrim& j_prim, const Map& prims); + void FromJson(const JsonMSCPrim& j_prim, const ffi::Map& prims); /*! \brief Load prim from json string. */ - void FromJson(const std::string& json_str, const Map& prims); + void FromJson(const std::string& json_str, const ffi::Map& prims); /*! \brief Get parent from the prim. */ const MSCPrim ParentAt(int index) const; /*! \brief Get child from the prim. */ @@ -646,21 +648,22 @@ class MSCPrim : public BaseJoint { * \param parents The parents of the prim. * \param attrs The attributes of the prim. */ - TVM_DLL MSCPrim(int index, const String& name, const String& optype, - const Array& parents, - const Map& attrs = Map()); + TVM_DLL MSCPrim( + int index, const ffi::String& name, const ffi::String& optype, + const ffi::Array& parents, + const ffi::Map& attrs = ffi::Map()); /*! 
* \brief The json constructor. * \param j_prim The json describe of the prim. */ - TVM_DLL MSCPrim(const JsonMSCPrim& j_prim, const Map& prims); + TVM_DLL MSCPrim(const JsonMSCPrim& j_prim, const ffi::Map& prims); /*! * \brief The json constructor. * \param json_str The json describe of the prim. */ - TVM_DLL MSCPrim(const std::string& json_str, const Map& prims); + TVM_DLL MSCPrim(const std::string& json_str, const ffi::Map& prims); TVM_DEFINE_OBJECT_REF_METHODS(MSCPrim, BaseJoint, MSCPrimNode); }; @@ -672,17 +675,17 @@ class WeightJoint; class WeightJointNode : public BaseJointNode { public: /*! \brief The weight reference of weight node. */ - String weight_type; + ffi::String weight_type; /*! \brief The weight of weight node. */ MSCTensor weight; /*! \brief The friends of weight node. */ - mutable Array friends; + mutable ffi::Array friends; /*! \brief Export node to json. */ const JsonWeightJoint ToJson() const; /*! \brief Load node from json struct. */ - void FromJson(const JsonWeightJoint& j_joint, const Map& nodes); + void FromJson(const JsonWeightJoint& j_joint, const ffi::Map& nodes); /*! \brief Load node from json string. */ - void FromJson(const std::string& json_str, const Map& nodes); + void FromJson(const std::string& json_str, const ffi::Map& nodes); /*! \brief Get parent from the node. */ const WeightJoint ParentAt(int index) const; /*! \brief Get child from the node. */ @@ -717,23 +720,24 @@ class WeightJoint : public BaseJoint { * \param attrs The attributes of the node. * \param friends The friends of the node. */ - TVM_DLL WeightJoint(int index, const String& name, const String& shared_ref, - const String& weight_type, const MSCTensor& weight, - const Array parents, - const Map& attrs = Map(), - const Array& friends = Array()); + TVM_DLL WeightJoint( + int index, const ffi::String& name, const ffi::String& shared_ref, + const ffi::String& weight_type, const MSCTensor& weight, const ffi::Array parents, + const ffi::Map& attrs = ffi::Map(), + const ffi::Array& friends = ffi::Array()); /*! * \brief The json constructor. * \param j_joint The json describe of the node. */ - TVM_DLL WeightJoint(const JsonWeightJoint& j_joint, const Map& nodes); + TVM_DLL WeightJoint(const JsonWeightJoint& j_joint, + const ffi::Map& nodes); /*! * \brief The json constructor. * \param json_str The json describe of the node. */ - TVM_DLL WeightJoint(const std::string& json_str, const Map& nodes); + TVM_DLL WeightJoint(const std::string& json_str, const ffi::Map& nodes); TVM_DEFINE_OBJECT_REF_METHODS(WeightJoint, BaseJoint, WeightJointNode); }; @@ -744,13 +748,13 @@ class WeightJoint : public BaseJoint { class BaseGraphNode : public Object { public: /*! \brief The name of graph. */ - String name; + ffi::String name; /*! \brief The node names in graph, can be changed. */ - Array node_names; + ffi::Array node_names; /*! \brief The nodes in graph, can be changed. */ - Map nodes; + ffi::Map nodes; /*! \brief Check if node in the graph. */ - const bool HasNode(const String& name) const; + const bool HasNode(const ffi::String& name) const; static void RegisterReflection() { namespace refl = tvm::ffi::reflection; @@ -783,17 +787,17 @@ class MSCGraph; class MSCGraphNode : public BaseGraphNode { public: /*! \brief The shape node names in graph. */ - Array prim_names; + ffi::Array prim_names; /*! \brief The shape nodes in graph. */ - Map prims; + ffi::Map prims; /*! \brief The input names of graph. */ - Array input_names; + ffi::Array input_names; /*! \brief The output names of graph. 
*/ - Array output_names; + ffi::Array output_names; /*! \brief The tensor alias in graph, get by AnalysisGraph. */ - mutable Map tensor_alias; + mutable ffi::Map tensor_alias; /*! \brief The weights in graph, get by AnalysisGraph. */ - Map> weight_holders; + ffi::Map> weight_holders; /*! \brief Export graph to json. */ const JsonMSCGraph ToJson() const; /*! \brief Load graph from json. */ @@ -801,41 +805,42 @@ class MSCGraphNode : public BaseGraphNode { /*! \brief Load graph from json string. */ void FromJson(const std::string& json_str); /*! \brief Export graph to prototxt. */ - const String ToPrototxt() const; + const ffi::String ToPrototxt() const; /*! \brief Find node in graph. */ - const MSCJoint FindNode(const String& name) const; + const MSCJoint FindNode(const ffi::String& name) const; /*! \brief Find prim in graph. */ - const MSCPrim FindPrim(const String& name) const; + const MSCPrim FindPrim(const ffi::String& name) const; /*! \brief Get input from the graph. */ const MSCTensor InputAt(int index) const; /*! \brief Get inputs from the graph. */ - const Array GetInputs() const; + const ffi::Array GetInputs() const; /*! \brief Get output from the graph. */ const MSCTensor OutputAt(int index) const; /*! \brief Get outputs from the graph. */ - const Array GetOutputs() const; + const ffi::Array GetOutputs() const; /*! \brief Get entries from the graph. */ - const Array GetEntries() const; + const ffi::Array GetEntries() const; /*! \brief Get exits from the graph. */ - const Array GetExits() const; + const ffi::Array GetExits() const; /*! \brief Check if tensor in the graph. */ - const bool HasTensor(const String& name) const; + const bool HasTensor(const ffi::String& name) const; /*! \brief Find tensor from the graph. */ - const MSCTensor FindTensor(const String& name) const; + const MSCTensor FindTensor(const ffi::String& name) const; /*! \brief Find producer of tensor from the graph. */ - const MSCJoint FindProducer(const String& name) const; + const MSCJoint FindProducer(const ffi::String& name) const; /*! \brief Find producer of tensor from the graph. */ const MSCJoint FindProducer(const MSCTensor& tensor) const; /*! \brief Find producer and output index of tensor from the graph. */ - const std::pair FindProducerAndIdx(const String& name) const; + const std::pair FindProducerAndIdx(const ffi::String& name) const; /*! \brief Find producer and output index of tensor from the graph. */ const std::pair FindProducerAndIdx(const MSCTensor& tensor) const; /*! \brief Find consumers of tensor from the graph. */ - const Array FindConsumers(const String& name) const; + const ffi::Array FindConsumers(const ffi::String& name) const; /*! \brief Find consumers of tensor from the graph. */ - const Array FindConsumers(const MSCTensor& tensor) const; + const ffi::Array FindConsumers(const MSCTensor& tensor) const; /*! \brief Find consumers and input indices of tensor from the graph. */ - const std::vector> FindConsumersAndIndices(const String& name) const; + const std::vector> FindConsumersAndIndices( + const ffi::String& name) const; /*! \brief Find consumers and input indices of tensor from the graph. */ const std::vector> FindConsumersAndIndices( const MSCTensor& tensor) const; @@ -870,9 +875,10 @@ class MSCGraph : public BaseGraph { * \param output_names The output names of the graph. * \param prims The prims in the graph. 
*/ - TVM_DLL MSCGraph(const String& name, const Array& nodes, - const Array& input_names, const Array& output_names, - const Array& prims = Array()); + TVM_DLL MSCGraph(const ffi::String& name, const ffi::Array& nodes, + const ffi::Array& input_names, + const ffi::Array& output_names, + const ffi::Array& prims = ffi::Array()); /*! * \brief The json constructor. @@ -895,10 +901,11 @@ class MSCGraph : public BaseGraph { class WeightGraphNode : public BaseGraphNode { public: /*! \brief build from MSCGraph. */ - void Build(const MSCGraph& graph, const Map>& prunable_types, - const Map& relation_types); + void Build(const MSCGraph& graph, + const ffi::Map>& prunable_types, + const ffi::Map& relation_types); /*! \brief Find node in graph. */ - const WeightJoint FindNode(const String& name) const; + const WeightJoint FindNode(const ffi::String& name) const; /*! \brief Export graph to json. */ const JsonWeightGraph ToJson() const; /*! \brief Load graph from json. */ @@ -906,7 +913,7 @@ class WeightGraphNode : public BaseGraphNode { /*! \brief Load graph from json string. */ void FromJson(const std::string& json_str); /*! \brief Export graph to prototxt. */ - const String ToPrototxt() const; + const ffi::String ToPrototxt() const; static void RegisterReflection() { namespace refl = tvm::ffi::reflection; @@ -929,8 +936,9 @@ class WeightGraph : public BaseGraph { * \param prunable_types The prunable types. * \param relation_types The relation types. */ - TVM_DLL WeightGraph(const MSCGraph& graph, const Map>& prunable_types, - const Map& relation_types); + TVM_DLL WeightGraph(const MSCGraph& graph, + const ffi::Map>& prunable_types, + const ffi::Map& relation_types); /*! * \brief The json constructor. @@ -947,7 +955,8 @@ class WeightGraph : public BaseGraph { TVM_DEFINE_OBJECT_REF_METHODS(WeightGraph, BaseGraph, WeightGraphNode); }; -MSCGraph PruneWeights(const MSCGraph& graph, const Map& pruned_tensors); +MSCGraph PruneWeights(const MSCGraph& graph, + const ffi::Map& pruned_tensors); } // namespace msc } // namespace contrib diff --git a/src/contrib/msc/core/ir/graph_builder.cc b/src/contrib/msc/core/ir/graph_builder.cc index 00176fb2ca0f..67770a21f27a 100644 --- a/src/contrib/msc/core/ir/graph_builder.cc +++ b/src/contrib/msc/core/ir/graph_builder.cc @@ -50,13 +50,13 @@ const std::string GetScalarStr(const runtime::Tensor& data, int float_precision) void FuncAttrGetter::VisitExpr_(const CallNode* op) { if (op->attrs.defined()) { - Map attrs; + ffi::Map attrs; AttrGetter getter(&attrs); getter(op->attrs); for (const auto& pair : attrs) { if (attrs_.count(pair.first)) { int cnt = 1; - String rep_key = pair.first; + ffi::String rep_key = pair.first; while (attrs_.count(rep_key + "_" + std::to_string(cnt))) { cnt++; } @@ -87,7 +87,7 @@ void FuncValueGetter::VisitExpr_(const CallNode* op) { } void FuncParamsFinder::VisitBinding_(const VarBindingNode* binding, const FunctionNode* val) { - local_funcs_.Set(binding->var, GetRef(val)); + local_funcs_.Set(binding->var, ffi::GetRef(val)); } void FuncParamsFinder::VisitExpr_(const CallNode* call_node) { @@ -112,7 +112,7 @@ void FuncParamsFinder::VisitExpr_(const CallNode* call_node) { } void LayoutsFinder::VisitBinding_(const VarBindingNode* binding, const FunctionNode* val) { - local_funcs_.Set(binding->var, GetRef(val)); + local_funcs_.Set(binding->var, ffi::GetRef(val)); } void LayoutsFinder::VisitExpr_(const CallNode* call_node) { @@ -126,7 +126,8 @@ void LayoutsFinder::VisitExpr_(const CallNode* call_node) { func = local_funcs_[call_node->op]; } if 
(func.defined()) { - const auto& layouts_opt = func->GetAttr>(msc_attr::kInputLayouts); + const auto& layouts_opt = + func->GetAttr>(msc_attr::kInputLayouts); if (layouts_opt.defined()) { for (const auto& pair : layouts_opt.value()) { layouts_.Set(pair.first, pair.second); @@ -137,8 +138,8 @@ void LayoutsFinder::VisitExpr_(const CallNode* call_node) { const MSCGraph GraphBuilder::Build(const Function& func) { // Add input nodes and record inputs; - Array input_names, output_names; - std::set added_inputs; + ffi::Array input_names, output_names; + std::set added_inputs; // Add prims for (const auto& p : func->params) { if (!p->struct_info_.defined()) { @@ -148,11 +149,11 @@ const MSCGraph GraphBuilder::Build(const Function& func) { const auto& shape = ExprUtils::GetShape(p, false); for (size_t i = 0; i < shape.size(); i++) { if (shape[i]->IsInstance()) { - Map attrs; + ffi::Map attrs; attrs.Set("producer", p->name_hint()); attrs.Set("out_idx", "0"); attrs.Set("dim", std::to_string(i)); - MatchOrCreatePrim(shape[i], "shape", Array(), attrs); + MatchOrCreatePrim(shape[i], "shape", ffi::Array(), attrs); } } } else { @@ -169,7 +170,7 @@ const MSCGraph GraphBuilder::Build(const Function& func) { } if (func_params_.count(p) && func_params_[p]->IsInstance()) { const auto& tuple = Downcast(func_params_[p]); - Array tuple_names; + ffi::Array tuple_names; for (const auto& f : tuple->fields) { if (expr_tensor_map_.count(f)) { LOG_INFO << "Replica tuple input " << f; @@ -200,8 +201,8 @@ const MSCGraph GraphBuilder::Build(const Function& func) { << "Can not find seqexpr body " << func->body->body; output_names = expr_tensor_map_[func->body->body]; // remove const nodes as weights - Array valid_nodes; - std::set ignore_inputs; + ffi::Array valid_nodes; + std::set ignore_inputs; for (const auto& n : nodes_) { if (weights_.count(n->name) || ignore_nodes_.count(n->name)) { for (const auto& o : n->outputs) { @@ -218,7 +219,7 @@ const MSCGraph GraphBuilder::Build(const Function& func) { } } // remove useless inputs - Array valid_inputs; + ffi::Array valid_inputs; for (const auto& i : input_names) { if (!ignore_inputs.count(i)) { valid_inputs.push_back(i); @@ -255,12 +256,12 @@ const MSCGraph GraphBuilder::Build(const Function& func) { return graph; } -const MSCJoint GraphBuilder::AddNode(const Expr& expr, const Optional& binding_var, - const String& name) { +const MSCJoint GraphBuilder::AddNode(const Expr& expr, const ffi::Optional& binding_var, + const ffi::String& name) { // Get optype, node_name and layout - String node_name = name.size() > 0 ? name : SpanUtils::GetAttr(expr->span, msc_attr::kName); - String optype = "unknown"; - String layout = SpanUtils::GetAttr(expr->span, msc_attr::kLayout); + ffi::String node_name = name.size() > 0 ? name : SpanUtils::GetAttr(expr->span, msc_attr::kName); + ffi::String optype = "unknown"; + ffi::String layout = SpanUtils::GetAttr(expr->span, msc_attr::kLayout); if (func_params_.count(expr) && func_params_[expr]->IsInstance()) { node_name = SpanUtils::GetAttr(func_params_[expr]->span, msc_attr::kName); optype = "constant"; @@ -318,11 +319,12 @@ const MSCJoint GraphBuilder::AddNode(const Expr& expr, const Optional& bin const auto& plugin = IsPlugin(optype) ?
GetPlugin(optype) : Plugin(); // Extract normal attributes - Map attrs; + ffi::Map attrs; if (plugin.defined()) { const auto& op = Downcast(expr)->op; if (target_funcs_.count(op)) { - const auto& opattrs_opt = target_funcs_[op]->GetAttr>(msc_attr::kOpattrs); + const auto& opattrs_opt = + target_funcs_[op]->GetAttr>(msc_attr::kOpattrs); if (opattrs_opt.defined()) { const auto& opattrs = opattrs_opt.value(); ICHECK_EQ(opattrs.size(), plugin->attrs.size()) @@ -341,7 +343,7 @@ const MSCJoint GraphBuilder::AddNode(const Expr& expr, const Optional& bin } else if (const auto* call_node = expr.as()) { if (const auto* v_node = call_node->op.as()) { const auto& func = Downcast(ref_module_->Lookup(v_node->name_hint)); - const auto& name_opt = func->GetAttr(relax::attr::kComposite); + const auto& name_opt = func->GetAttr(relax::attr::kComposite); if (name_opt.has_value()) { attrs = FuncAttrGetter().GetAttrs(func); } @@ -365,10 +367,10 @@ const MSCJoint GraphBuilder::AddNode(const Expr& expr, const Optional& bin } // Extract attributes from arguments - Array input_types; + ffi::Array input_types; if (!plugin.defined() && expr->IsInstance()) { const auto& call = Downcast(expr); - Array values; + ffi::Array values; if (call->op->IsInstance()) { ICHECK(target_funcs_.count(call->op)) << "Can not find target func: " << call->op; values = FuncValueGetter().GetValues(target_funcs_[call->op]); @@ -396,8 +398,8 @@ const MSCJoint GraphBuilder::AddNode(const Expr& expr, const Optional& bin } // Build inputs and weights - Array input_names; - Map node_weights; + ffi::Array input_names; + ffi::Map node_weights; if (plugin.defined()) { const auto& call = Downcast(expr); if (call->args.size() == 1) { @@ -419,7 +421,7 @@ const MSCJoint GraphBuilder::AddNode(const Expr& expr, const Optional& bin continue; } const auto& arg = call_node->args[i]; - Array arg_names; + ffi::Array arg_names; if (expr_tensor_map_.count(arg)) { arg_names = expr_tensor_map_[arg]; } else if (input_types[i] == "input" && arg->IsInstance()) { @@ -431,7 +433,7 @@ const MSCJoint GraphBuilder::AddNode(const Expr& expr, const Optional& bin } } } - String weight_name; + ffi::String weight_name; if (input_types[i] != "input" && arg->IsInstance()) { weight_name = SpanUtils::GetAttr(arg->span, msc_attr::kName); } else if (input_types[i] != "input" && func_params_.count(arg) && @@ -448,12 +450,12 @@ const MSCJoint GraphBuilder::AddNode(const Expr& expr, const Optional& bin const auto& ref = producer->OutputAt(pair.second); MSCTensor weight; if (input_types[i] == "bias") { - weight = MSCTensor(weight_name, ref->dtype, "O", Array{ref->GetSize()}); + weight = MSCTensor(weight_name, ref->dtype, "O", ffi::Array{ref->GetSize()}); } else if (input_types[i] == "weight" && (optype == "msc.linear" || optype == "msc.linear_bias")) { if (ref->layout.name() == "IO") { - String valid_layout = ref->layout[1].name() + ref->layout[0].name(); - const auto& valid_shape = Array({ref->shape[1], ref->shape[0]}); + ffi::String valid_layout = ref->layout[1].name() + ref->layout[0].name(); + const auto& valid_shape = ffi::Array({ref->shape[1], ref->shape[0]}); weight = MSCTensor(weight_name, ref->dtype, valid_layout, valid_shape); } else { weight = MSCTensor(weight_name, ref->dtype, ref->layout.name(), ref->shape); @@ -512,13 +514,13 @@ const MSCJoint GraphBuilder::AddNode(const Expr& expr, const Optional& bin } // Build output tensor - auto build_output = [this](const StructInfo& sinfo, const String& node_name, - const String& layout) { + auto build_output = [this](const 
StructInfo& sinfo, const ffi::String& node_name, + const ffi::String& layout) { ICHECK(sinfo->IsInstance()) << "sinfo should be TensorStructInfo, get " << sinfo->GetTypeKey(); const auto& t_info = Downcast(sinfo); const auto& shape = ArrayUtils::Cast(ExprUtils::GetShape(t_info)); - Array prims; + ffi::Array prims; bool has_prims = false; if (shape.size() > 0) { for (const auto& s : t_info->GetShape().value()) { @@ -537,15 +539,15 @@ const MSCJoint GraphBuilder::AddNode(const Expr& expr, const Optional& bin }; // Gather outputs - Array outputs; + ffi::Array outputs; const auto& sinfo = GetStructInfo(expr); - Array layouts = StringUtils::Split(layout, ","); + ffi::Array layouts = StringUtils::Split(layout, ","); size_t num_output = 1; if (const auto* tuple_sinfo = sinfo.as()) { num_output = tuple_sinfo->fields.size(); } if (layouts.size() == 0) { - layouts = Array(num_output, ""); + layouts = ffi::Array(num_output, ""); } ICHECK_EQ(layouts.size(), num_output) << "Layouts " << layouts << " mismatch with output size " << num_output; @@ -553,7 +555,7 @@ const MSCJoint GraphBuilder::AddNode(const Expr& expr, const Optional& bin const auto& t_name = node_name + ":" + std::to_string(0); outputs.push_back(build_output(sinfo, t_name, layouts[0])); } else if (const auto* s_sinfo = sinfo.as()) { - Array shape{s_sinfo->ndim}; + ffi::Array shape{s_sinfo->ndim}; const auto& t_name = node_name + ":" + std::to_string(0); const auto& dtype = DataType(ffi::StringToDLDataType("int32")); outputs.push_back(MSCTensor(t_name, dtype, layouts[0], shape)); @@ -568,14 +570,14 @@ const MSCJoint GraphBuilder::AddNode(const Expr& expr, const Optional& bin } // Build node - Array scope; + ffi::Array scope; if (optype != "input" && optype != "constant") { scope = StringUtils::Split(scope_name_, "."); } const auto& shared_ref = SpanUtils::GetAttr(expr->span, msc_attr::kSharedRef); const auto& node = MSCJoint(nodes_.size(), node_name, shared_ref, optype, attrs, scope, inputs, outputs, node_weights); - Array output_names; + ffi::Array output_names; for (size_t i = 0; i < outputs.size(); i++) { output_names.push_back(outputs[i]->name); tensor_input_map_[outputs[i]->name] = std::make_pair(node, i); @@ -587,11 +589,11 @@ const MSCJoint GraphBuilder::AddNode(const Expr& expr, const Optional& bin } void GraphBuilder::VisitBindingBlock(const BindingBlock& block) { - String block_name = SpanUtils::GetAttr(block->span, msc_attr::kName); + ffi::String block_name = SpanUtils::GetAttr(block->span, msc_attr::kName); if (block_name.size() == 0) { block_name = "block"; } - const String& prefix = StringUtils::Join(block_stack_, "."); + const ffi::String& prefix = StringUtils::Join(block_stack_, "."); if (setted_blocks_.count(prefix + "." + block_name)) { int cnt = 1; while (setted_blocks_.count(prefix + "."
+ block_name + "_" + std::to_string(cnt))) { @@ -638,15 +640,15 @@ const MSCPrim GraphBuilder::AddPrim(const PrimExpr& prim) { // scalar if (prim->IsInstance()) { - Map attrs; + ffi::Map attrs; attrs.Set("value", StringUtils::ToString(prim)); - return MatchOrCreatePrim(prim, "Int", Array(), attrs); + return MatchOrCreatePrim(prim, "Int", ffi::Array(), attrs); } // call if (const auto* c_node = prim.as()) { - String optype; - Array parents; + ffi::String optype; + ffi::Array parents; if (const auto* op_node = c_node->op.as()) { optype = StringUtils::Replace(op_node->name, "tir.", ""); } else { @@ -660,9 +662,9 @@ const MSCPrim GraphBuilder::AddPrim(const PrimExpr& prim) { return MatchOrCreatePrim(prim); } -const MSCPrim GraphBuilder::MatchOrCreatePrim(const PrimExpr& prim, const String& optype, - const Array& parents, - const Map& attrs) { +const MSCPrim GraphBuilder::MatchOrCreatePrim(const PrimExpr& prim, const ffi::String& optype, + const ffi::Array& parents, + const ffi::Map& attrs) { if (prim_map_.count(prim)) { return prim_map_[prim]; } @@ -692,7 +694,7 @@ const MSCPrim GraphBuilder::MatchOrCreatePrim(const PrimExpr& prim, const String prim_map_.Set(prim, p); return p; } - String name; + ffi::String name; if (const auto* v_node = prim.as()) { name = v_node->name_hint; } else { @@ -705,26 +707,26 @@ const MSCPrim GraphBuilder::MatchOrCreatePrim(const PrimExpr& prim, const String } void GraphBuilder::VisitExpr_(const ConstantNode* op) { - if (!expr_tensor_map_.count(GetRef(op))) { - AddNode(GetRef(op)); + if (!expr_tensor_map_.count(ffi::GetRef(op))) { + AddNode(ffi::GetRef(op)); } } void GraphBuilder::VisitBinding_(const VarBindingNode* binding, const ConstantNode* val) { - const String& name = config_.use_var_name ? binding->var->name_hint() : ""; - AddNode(GetRef(val), binding->var, name); + const ffi::String& name = config_.use_var_name ? binding->var->name_hint() : ""; + AddNode(ffi::GetRef(val), binding->var, name); } void GraphBuilder::VisitBinding_(const VarBindingNode* binding, const ShapeExprNode* val) { - const String& name = config_.use_var_name ? binding->var->name_hint() : ""; - AddNode(GetRef(val), binding->var, name); + const ffi::String& name = config_.use_var_name ? binding->var->name_hint() : ""; + AddNode(ffi::GetRef(val), binding->var, name); } void GraphBuilder::VisitBinding_(const VarBindingNode* binding, const CallNode* call_node) { ExprVisitor::VisitBinding_(binding, call_node); - const String& name = config_.use_var_name ? binding->var->name_hint() : ""; + const ffi::String& name = config_.use_var_name ? binding->var->name_hint() : ""; try { - AddNode(GetRef(call_node), binding->var, name); + AddNode(ffi::GetRef(call_node), binding->var, name); } catch (runtime::InternalError& err) { LOG(WARNING) << "Failed to add node from " << binding->var << " : " << binding->value << ", reason: " << err.what(); @@ -734,49 +736,50 @@ void GraphBuilder::VisitBinding_(const VarBindingNode* binding, const CallNode* void GraphBuilder::VisitBinding_(const VarBindingNode* binding, const TupleNode* val) { ExprVisitor::VisitBinding_(binding, val); - const String& name = config_.use_var_name ? binding->var->name_hint() : ""; - AddNode(GetRef(val), binding->var, name); + const ffi::String& name = config_.use_var_name ? binding->var->name_hint() : ""; + AddNode(ffi::GetRef(val), binding->var, name); } void GraphBuilder::VisitBinding_(const VarBindingNode* binding, const TupleGetItemNode* val) { ExprVisitor::VisitBinding_(binding, val); - const String& name = config_.use_var_name ? 
binding->var->name_hint() : ""; - AddNode(GetRef(val), binding->var, name); + const ffi::String& name = config_.use_var_name ? binding->var->name_hint() : ""; + AddNode(ffi::GetRef(val), binding->var, name); } void GraphBuilder::VisitBinding_(const VarBindingNode* binding, const VarNode* val) { ExprVisitor::VisitBinding_(binding, val); - const auto& output = GetRef(val); + const auto& output = ffi::GetRef(val); ICHECK(expr_tensor_map_.count(output)) << "Can not find var " << output; expr_tensor_map_.Set(binding->var, expr_tensor_map_[output]); } void GraphBuilder::VisitBinding_(const VarBindingNode* binding, const DataflowVarNode* val) { ExprVisitor::VisitBinding_(binding, val); - const auto& output = GetRef(val); + const auto& output = ffi::GetRef(val); ICHECK(expr_tensor_map_.count(output)) << "Can not find dataflow var " << output; expr_tensor_map_.Set(binding->var, expr_tensor_map_[output]); } void GraphBuilder::VisitBinding_(const VarBindingNode* binding, const FunctionNode* val) { - const auto& name_opt = val->GetAttr(relax::attr::kComposite); + const auto& name_opt = val->GetAttr(relax::attr::kComposite); ICHECK(name_opt.has_value()) << "Unexpected target func without composite"; ICHECK(config_.target.size() > 0 && StringUtils::StartsWith(name_opt.value(), config_.target)) << "Target should be given for target function"; - target_funcs_.Set(binding->var, GetRef(val)); + target_funcs_.Set(binding->var, ffi::GetRef(val)); } -const std::tuple GraphBuilder::ParseFunc(const Function& func) { - String node_name, optype, layout; - const auto& name_opt = func->GetAttr(msc_attr::kUnique); +const std::tuple GraphBuilder::ParseFunc( + const Function& func) { + ffi::String node_name, optype, layout; + const auto& name_opt = func->GetAttr(msc_attr::kUnique); // get node_name if (name_opt.has_value()) { node_name = name_opt.value(); } // get optype - const auto& codegen_opt = func->GetAttr(relax::attr::kCodegen); - const auto& optype_opt = func->GetAttr(msc_attr::kOptype); - const auto& composite_opt = func->GetAttr(relax::attr::kComposite); + const auto& codegen_opt = func->GetAttr(relax::attr::kCodegen); + const auto& optype_opt = func->GetAttr(msc_attr::kOptype); + const auto& composite_opt = func->GetAttr(relax::attr::kComposite); if (codegen_opt.has_value()) { optype = codegen_opt.value(); } else if (optype_opt.has_value()) { @@ -788,7 +791,7 @@ const std::tuple GraphBuilder::ParseFunc(const Function& } } // get layout - const auto& layout_opt = func->GetAttr(msc_attr::kLayout); + const auto& layout_opt = func->GetAttr(msc_attr::kLayout); if (layout_opt.has_value()) { layout = layout_opt.value(); } @@ -802,14 +805,14 @@ void GraphBuilder::VisitPrimExpr(const PrimExpr& prim) { } } -Array GraphBuilder::GetPluginInputs(const Expr& expr) { +ffi::Array GraphBuilder::GetPluginInputs(const Expr& expr) { ICHECK(expr->IsInstance()) << "plugin expr should be call"; const auto& call = Downcast(expr); ICHECK(call->args[1]->IsInstance()) << "plugin argument 1 should be call"; return Downcast(call->args[1])->fields; } -Map WeightsExtractor::GetWeights(const Function& func) { +ffi::Map WeightsExtractor::GetWeights(const Function& func) { VisitExpr(func); return weights_; } @@ -817,13 +820,13 @@ Map WeightsExtractor::GetWeights(const Function& func) { void WeightsExtractor::VisitExpr_(const ConstantNode* op) { const auto& name = SpanUtils::GetAttr(op->span, msc_attr::kName); const auto& layout = SpanUtils::GetAttr(op->span, msc_attr::kLayout); - const auto& sinfo = GetStructInfo(GetRef(op)); + const auto& 
sinfo = GetStructInfo(ffi::GetRef(op)); ICHECK(sinfo->IsInstance()) << "Constant StructInfo should be TensorStructInfo"; const auto& t_info = Downcast(sinfo); const auto& opt_shape = t_info->GetShape(); const auto& shape = - opt_shape.defined() ? ArrayUtils::Cast(opt_shape.value()) : Array(); + opt_shape.defined() ? ArrayUtils::Cast(opt_shape.value()) : ffi::Array(); const auto& weight = MSCTensor(name, t_info->dtype, layout, shape); weights_.Set(weight, op->data); } @@ -840,19 +843,21 @@ TVM_FFI_STATIC_INIT_BLOCK({ namespace refl = tvm::ffi::reflection; refl::GlobalDef() .def("msc.core.BuildFromRelax", - [](const IRModule& module, const String& entry_name, const String& options) -> MSCGraph { + [](const IRModule& module, const ffi::String& entry_name, + const ffi::String& options) -> MSCGraph { auto builder = GraphBuilder(module, entry_name, options); const auto& func_name = builder.config().byoc_entry.size() > 0 - ? String(builder.config().byoc_entry) + ? ffi::String(builder.config().byoc_entry) : entry_name; const auto& func = Downcast(module->Lookup(func_name)); return builder.Build(func); }) - .def("msc.core.GetRelaxWeights", - [](const IRModule& module, const String& entry_name) -> Map { - const auto& func = Downcast(module->Lookup(entry_name)); - return WeightsExtractor(module).GetWeights(func); - }); + .def( + "msc.core.GetRelaxWeights", + [](const IRModule& module, const ffi::String& entry_name) -> ffi::Map { + const auto& func = Downcast(module->Lookup(entry_name)); + return WeightsExtractor(module).GetWeights(func); + }); }); } // namespace msc diff --git a/src/contrib/msc/core/ir/graph_builder.h b/src/contrib/msc/core/ir/graph_builder.h index 79c4048304cf..22a4929fe12f 100644 --- a/src/contrib/msc/core/ir/graph_builder.h +++ b/src/contrib/msc/core/ir/graph_builder.h @@ -110,10 +110,10 @@ struct MSCRBuildConfig { class AttrGetter { public: /*! - * \brief Get the attributes as Map + * \brief Get the attributes as ffi::Map * \param attrs the attributes. */ - explicit AttrGetter(Map* attrs) : attrs_(attrs) {} + explicit AttrGetter(ffi::Map* attrs) : attrs_(attrs) {} void operator()(const Attrs& attrs) { if (const auto* dict_attrs = attrs.as()) { @@ -125,14 +125,14 @@ class AttrGetter { if (attrs_tinfo->metadata != nullptr) { tvm::ffi::reflection::ForEachFieldInfo(attrs_tinfo, [&](const TVMFFIFieldInfo* field_info) { Any field_value = tvm::ffi::reflection::FieldGetter(field_info)(attrs); - this->VisitAny(String(field_info->name), field_value); + this->VisitAny(ffi::String(field_info->name), field_value); }); } } } private: - void VisitAny(String key, Any value) { switch (value.type_index()) { case kTVMFFINone: { attrs_->Set(key, ""); @@ -156,7 +156,7 @@ class AttrGetter { } case kTVMFFISmallStr: case kTVMFFIStr: { - attrs_->Set(key, value.cast()); + attrs_->Set(key, value.cast()); break; } default: { @@ -171,13 +171,13 @@ class AttrGetter { } private: - Map* attrs_; + ffi::Map* attrs_; }; class FuncAttrGetter : public ExprVisitor { public: - /*! \brief Get the attributes as Map*/ - Map GetAttrs(const Expr& expr) { VisitExpr(expr); return attrs_; } @@ -187,13 +187,13 @@ class FuncAttrGetter : public ExprVisitor { void VisitExpr_(const TupleGetItemNode* op) final; private: - Map attrs_; + ffi::Map attrs_; }; class FuncValueGetter : public ExprVisitor { public: - /*!
\brief Get the attributes from prim value as Map*/ - Array GetValues(const Expr& expr) { + /*! \brief Get the attributes from prim value as ffi::Array*/ + ffi::Array GetValues(const Expr& expr) { VisitExpr(expr); return values_; } @@ -201,7 +201,7 @@ class FuncValueGetter : public ExprVisitor { void VisitExpr_(const CallNode* op) final; private: - Array values_; + ffi::Array values_; }; class FuncParamsFinder : public ExprVisitor { @@ -215,7 +215,7 @@ } /*! \brief Find the func params and bind with arguments*/ - Map FindParams(const Expr& expr) { + ffi::Map FindParams(const Expr& expr) { VisitExpr(expr); return params_; } @@ -226,8 +226,8 @@ private: IRModule ref_module_; - Map params_; - Map local_funcs_; + ffi::Map params_; + ffi::Map local_funcs_; }; class LayoutsFinder : public ExprVisitor { @@ -239,7 +239,7 @@ explicit LayoutsFinder(const IRModule& ref_module) : ExprVisitor() { ref_module_ = ref_module; } /*! \brief Find the layouts from attrs*/ - Map FindLayouts(const Expr& expr) { + ffi::Map FindLayouts(const Expr& expr) { VisitExpr(expr); return layouts_; } @@ -250,8 +250,8 @@ private: IRModule ref_module_; - Map layouts_; - Map local_funcs_; + ffi::Map layouts_; + ffi::Map local_funcs_; }; class GraphBuilder : public ExprVisitor { @@ -262,7 +262,7 @@ * \param name the name of the graph. * \param options the options of build the graph. */ - explicit GraphBuilder(const IRModule& ref_module, const String& name, + explicit GraphBuilder(const IRModule& ref_module, const ffi::String& name, const std::string& options = "") : ExprVisitor() { ref_module_ = ref_module; @@ -271,7 +271,7 @@ dmlc::JSONReader reader(&is); reader.Read(&config_); } - name_ = config_.graph_name.size() > 0 ? String(config_.graph_name) : name; + name_ = config_.graph_name.size() > 0 ? ffi::String(config_.graph_name) : name; if (config_.byoc_entry.size() > 0) { func_params_ = FuncParamsFinder(ref_module).FindParams(ref_module->Lookup(name)); } @@ -285,15 +285,16 @@ const MSCRBuildConfig config() { return config_; } /*! \brief Create and add MSCJoint from expr*/ - const MSCJoint AddNode(const Expr& expr, const Optional& binding_var = std::nullopt, - const String& name = ""); + const MSCJoint AddNode(const Expr& expr, const ffi::Optional& binding_var = std::nullopt, + const ffi::String& name = ""); /*! \brief Create and add MSCPrim from prim*/ const MSCPrim AddPrim(const PrimExpr& prim); - const MSCPrim MatchOrCreatePrim(const PrimExpr& prim, const String& op = "", - const Array& parents = Array(), - const Map& attrs = Map()); + const MSCPrim MatchOrCreatePrim( + const PrimExpr& prim, const ffi::String& op = "", + const ffi::Array& parents = ffi::Array(), + const ffi::Map& attrs = ffi::Map()); void VisitBindingBlock(const BindingBlock& block) final; @@ -319,30 +320,30 @@ private: /*! \brief Get the node_name, optype, layout for func*/ - const std::tuple ParseFunc(const Function& func); + const std::tuple ParseFunc(const Function& func); /*!
\brief Get the plugin inputs*/ - Array GetPluginInputs(const Expr& expr); + ffi::Array GetPluginInputs(const Expr& expr); - String name_; + ffi::String name_; IRModule ref_module_; MSCRBuildConfig config_; - Map layouts_; - Array nodes_; - Map weights_; - Map> expr_tensor_map_; - std::unordered_map> tensor_input_map_; - std::set ignore_nodes_; + ffi::Map layouts_; + ffi::Array nodes_; + ffi::Map weights_; + ffi::Map> expr_tensor_map_; + std::unordered_map> tensor_input_map_; + std::set ignore_nodes_; // scope name - String scope_name_; - std::set setted_blocks_; - Array block_stack_; + ffi::String scope_name_; + std::set setted_blocks_; + ffi::Array block_stack_; // BYOC maps - Map target_funcs_; - Map func_params_; + ffi::Map target_funcs_; + ffi::Map func_params_; // prims - Array prims_; - Map prim_map_; + ffi::Array prims_; + ffi::Map prim_map_; }; class WeightsExtractor : public ExprVisitor { @@ -358,15 +359,15 @@ class WeightsExtractor : public ExprVisitor { } /*! \brief Visit the constant and save weights */ - Map GetWeights(const Function& func); + ffi::Map GetWeights(const Function& func); void VisitExpr_(const ConstantNode* op) final; void VisitExpr_(const CallNode* op) final; private: - Map weights_; - Map local_funcs_; + ffi::Map weights_; + ffi::Map local_funcs_; IRModule ref_module_; }; diff --git a/src/contrib/msc/core/ir/plugin.cc b/src/contrib/msc/core/ir/plugin.cc index 659cb29628e7..3c143b03ea18 100644 --- a/src/contrib/msc/core/ir/plugin.cc +++ b/src/contrib/msc/core/ir/plugin.cc @@ -35,9 +35,9 @@ namespace tvm { namespace contrib { namespace msc { -PluginAttr::PluginAttr(const String& name, const String& type, const String& default_value, - const String& describe) { - ObjectPtr n = make_object(); +PluginAttr::PluginAttr(const ffi::String& name, const ffi::String& type, + const ffi::String& default_value, const ffi::String& describe) { + ObjectPtr n = ffi::make_object(); n->name = std::move(name); n->type = std::move(type); n->default_value = std::move(default_value); @@ -46,13 +46,13 @@ PluginAttr::PluginAttr(const String& name, const String& type, const String& def } PluginAttr::PluginAttr(const JsonPluginAttr& j_attr) { - ObjectPtr n = make_object(); + ObjectPtr n = ffi::make_object(); n->FromJson(j_attr); data_ = std::move(n); } PluginAttr::PluginAttr(const std::string& json_str) { - ObjectPtr n = make_object(); + ObjectPtr n = ffi::make_object(); n->FromJson(json_str); data_ = std::move(n); } @@ -81,9 +81,9 @@ void PluginAttrNode::FromJson(const std::string& json_str) { FromJson(j_attr); } -PluginTensor::PluginTensor(const String& name, const String& dtype, const Integer& ndim, - const String& device, const String& describe) { - ObjectPtr n = make_object(); +PluginTensor::PluginTensor(const ffi::String& name, const ffi::String& dtype, const Integer& ndim, + const ffi::String& device, const ffi::String& describe) { + ObjectPtr n = ffi::make_object(); n->name = std::move(name); n->dtype = std::move(dtype); n->ndim = std::move(ndim); @@ -93,13 +93,13 @@ PluginTensor::PluginTensor(const String& name, const String& dtype, const Intege } PluginTensor::PluginTensor(const JsonPluginTensor& j_tensor) { - ObjectPtr n = make_object(); + ObjectPtr n = ffi::make_object(); n->FromJson(j_tensor); data_ = std::move(n); } PluginTensor::PluginTensor(const std::string& json_str) { - ObjectPtr n = make_object(); + ObjectPtr n = ffi::make_object(); n->FromJson(json_str); data_ = std::move(n); } @@ -130,9 +130,10 @@ void PluginTensorNode::FromJson(const std::string& json_str) { 
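Every constructor in these plugin.cc hunks repeats the same two-step idiom, so it is worth isolating. A minimal sketch, assuming plugin.h from this diff is on the include path (the snippet is not standalone, and the field values are hypothetical):

```cpp
#include <tvm/ffi/memory.h>

#include "contrib/msc/core/ir/plugin.h"  // include path relative to src/ is an assumption

void SketchNodeConstruction() {
  // ffi::make_object allocates the node and returns an owning ObjectPtr.
  auto n = tvm::ffi::make_object<tvm::contrib::msc::PluginAttrNode>();
  n->name = "axis";  // ffi::String fields accept string literals directly
  n->type = "int";
  n->default_value = "0";
  n->describe = "reduction axis";
  // The real constructors finish by moving ownership into the ref:
  //   data_ = std::move(n);
}
```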
FromJson(j_tensor); } -PluginExtern::PluginExtern(const String& name, const String& header, const String& source, - const String& lib, const String& describe) { - ObjectPtr n = make_object(); +PluginExtern::PluginExtern(const ffi::String& name, const ffi::String& header, + const ffi::String& source, const ffi::String& lib, + const ffi::String& describe) { + ObjectPtr n = ffi::make_object(); n->name = std::move(name); n->header = std::move(header); n->source = std::move(source); @@ -142,13 +143,13 @@ PluginExtern::PluginExtern(const String& name, const String& header, const Strin } PluginExtern::PluginExtern(const JsonPluginExtern& j_extern) { - ObjectPtr n = make_object(); + ObjectPtr n = ffi::make_object(); n->FromJson(j_extern); data_ = std::move(n); } PluginExtern::PluginExtern(const std::string& json_str) { - ObjectPtr n = make_object(); + ObjectPtr n = ffi::make_object(); n->FromJson(json_str); data_ = std::move(n); } @@ -179,13 +180,13 @@ void PluginExternNode::FromJson(const std::string& json_str) { FromJson(j_extern); } -Plugin::Plugin(const String& name, const String& version, const String& describe, - const Array& attrs, const Array& inputs, - const Array& outputs, const Array& buffers, - const Map& externs, - const Map>& support_dtypes, - const Map& options) { - ObjectPtr n = make_object(); +Plugin::Plugin(const ffi::String& name, const ffi::String& version, const ffi::String& describe, + const ffi::Array& attrs, const ffi::Array& inputs, + const ffi::Array& outputs, const ffi::Array& buffers, + const ffi::Map& externs, + const ffi::Map>& support_dtypes, + const ffi::Map& options) { + ObjectPtr n = ffi::make_object(); n->name = std::move(name); n->version = std::move(version); n->describe = std::move(describe); @@ -200,13 +201,13 @@ Plugin::Plugin(const String& name, const String& version, const String& describe } Plugin::Plugin(const JsonPlugin& j_plugin) { - ObjectPtr n = make_object(); + ObjectPtr n = ffi::make_object(); n->FromJson(j_plugin); data_ = std::move(n); } Plugin::Plugin(const std::string& json_str) { - ObjectPtr n = make_object(); + ObjectPtr n = ffi::make_object(); n->FromJson(json_str); data_ = std::move(n); } @@ -264,7 +265,7 @@ void PluginNode::FromJson(const JsonPlugin& j_plugin) { externs.Set(pair.first, PluginExtern(pair.second)); } for (const auto& pair : j_plugin.support_dtypes) { - Array dtypes; + ffi::Array dtypes; for (const auto& d : pair.second) { dtypes.push_back(d); } @@ -301,11 +302,11 @@ int PluginNode::FindDeviceRefIdx(const PluginTensor& tensor) const { return -1; } -const Array ListPluginNames() { return PluginRegistry::Global()->ListAllNames(); } +const ffi::Array ListPluginNames() { return PluginRegistry::Global()->ListAllNames(); } -const Plugin GetPlugin(const String& name) { return PluginRegistry::Global()->Get(name); } +const Plugin GetPlugin(const ffi::String& name) { return PluginRegistry::Global()->Get(name); } -bool IsPlugin(const String& name) { return PluginRegistry::Global()->Registered(name); } +bool IsPlugin(const ffi::String& name) { return PluginRegistry::Global()->Registered(name); } TVM_FFI_STATIC_INIT_BLOCK({ PluginAttrNode::RegisterReflection(); @@ -318,12 +319,14 @@ TVM_FFI_STATIC_INIT_BLOCK({ namespace refl = tvm::ffi::reflection; refl::GlobalDef() .def("msc.core.RegisterPlugin", - [](const String& name, const String& json_str) { + [](const ffi::String& name, const ffi::String& json_str) { PluginRegistry::Global()->Register(name, json_str); }) - .def("msc.core.ListPluginNames", []() -> Array { return ListPluginNames(); 
}) - .def("msc.core.GetPlugin", [](const String& name) -> Plugin { return GetPlugin(name); }) - .def("msc.core.IsPlugin", [](const String& name) -> Bool { return Bool(IsPlugin(name)); }); + .def("msc.core.ListPluginNames", + []() -> ffi::Array { return ListPluginNames(); }) + .def("msc.core.GetPlugin", [](const ffi::String& name) -> Plugin { return GetPlugin(name); }) + .def("msc.core.IsPlugin", + [](const ffi::String& name) -> Bool { return Bool(IsPlugin(name)); }); }); } // namespace msc diff --git a/src/contrib/msc/core/ir/plugin.h b/src/contrib/msc/core/ir/plugin.h index f0a5dc9937b8..2d8b429959a3 100644 --- a/src/contrib/msc/core/ir/plugin.h +++ b/src/contrib/msc/core/ir/plugin.h @@ -254,13 +254,13 @@ struct JsonPlugin { class PluginAttrNode : public Object { public: /*! \brief The name of attribute. */ - String name; + ffi::String name; /*! \brief The type of attribute. */ - String type; + ffi::String type; /*! \brief The default_value of attribute. */ - String default_value; + ffi::String default_value; /*! \brief The describe of attribute. */ - String describe; + ffi::String describe; /*! \brief Export attribute to json. */ const JsonPluginAttr ToJson() const; @@ -296,8 +296,8 @@ class PluginAttr : public ObjectRef { * \param default_value The default_value of the attribute. * \param describe The describe of the attribute. */ - TVM_DLL PluginAttr(const String& name, const String& type, const String& default_value, - const String& describe); + TVM_DLL PluginAttr(const ffi::String& name, const ffi::String& type, + const ffi::String& default_value, const ffi::String& describe); /*! * \brief The json constructor. @@ -320,15 +320,15 @@ class PluginAttr : public ObjectRef { class PluginTensorNode : public Object { public: /*! \brief The name of tensor. */ - String name; + ffi::String name; /*! \brief The dtype of tensor. */ - String dtype; + ffi::String dtype; /*! \brief The ndim of tensor. */ Integer ndim; /*! \brief The device of tensor. */ - String device; + ffi::String device; /*! \brief The describe of tensor. */ - String describe; + ffi::String describe; /*! \brief Export tensor to json. */ const JsonPluginTensor ToJson() const; @@ -366,8 +366,8 @@ class PluginTensor : public ObjectRef { * \param device The device of the tensor. * \param describe The describe of the tensor. */ - TVM_DLL PluginTensor(const String& name, const String& dtype, const Integer& ndim, - const String& device, const String& describe); + TVM_DLL PluginTensor(const ffi::String& name, const ffi::String& dtype, const Integer& ndim, + const ffi::String& device, const ffi::String& describe); /*! * \brief The json constructor. @@ -390,15 +390,15 @@ class PluginTensor : public ObjectRef { class PluginExternNode : public Object { public: /*! \brief The name of extern. */ - String name; + ffi::String name; /*! \brief The header of extern. */ - String header; + ffi::String header; /*! \brief The source of extern. */ - String source; + ffi::String source; /*! \brief The lib of extern. */ - String lib; + ffi::String lib; /*! \brief The describe of extern. */ - String describe; + ffi::String describe; /*! \brief Export extern to json. */ const JsonPluginExtern ToJson() const; @@ -436,8 +436,9 @@ class PluginExtern : public ObjectRef { * \param lib The lib of the extern. * \param describe The describe of the extern. 
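The PluginNode declared just after this stores `support_dtypes` as one ffi container nested inside another. A self-contained sketch of building such a value; the header paths follow the ffi/include layout shown earlier in this diff but are assumptions:

```cpp
#include <tvm/ffi/container/array.h>
#include <tvm/ffi/container/map.h>
#include <tvm/ffi/string.h>  // header path is an assumption

using tvm::ffi::Array;
using tvm::ffi::Map;
using tvm::ffi::String;

Map<String, Array<String>> BuildSupportDtypes() {
  Array<String> dtypes;
  dtypes.push_back("float32");
  dtypes.push_back("float16");
  Map<String, Array<String>> support;
  support.Set("input", dtypes);  // Set stores a handle to the (copy-on-write) array
  return support;
}
```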
*/ - TVM_DLL PluginExtern(const String& name, const String& header, const String& source, - const String& lib, const String& describe); + TVM_DLL PluginExtern(const ffi::String& name, const ffi::String& header, + const ffi::String& source, const ffi::String& lib, + const ffi::String& describe); /*! * \brief The json constructor. @@ -460,25 +461,25 @@ class PluginExtern : public ObjectRef { class PluginNode : public Object { public: /*! \brief The name of plugin. */ - String name; + ffi::String name; /*! \brief The version of plugin. */ - String version; + ffi::String version; /*! \brief The describe of plugin. */ - String describe; + ffi::String describe; /*! \brief The attributes of plugin. */ - Array attrs; + ffi::Array attrs; /*! \brief The inputs of plugin. */ - Array inputs; + ffi::Array inputs; /*! \brief The outputs of plugin. */ - Array outputs; + ffi::Array outputs; /*! \brief The buffers of plugin. */ - Array buffers; + ffi::Array buffers; /*! \brief The externs of plugin. */ - Map externs; + ffi::Map externs; /*! \brief The support_dtypes of plugin. */ - Map> support_dtypes; + ffi::Map> support_dtypes; /*! \brief The options of plugin. */ - Map options; + ffi::Map options; /*! \brief Export plugin to json. */ const JsonPlugin ToJson() const; @@ -531,12 +532,12 @@ class Plugin : public ObjectRef { * \param support_dtypes The support_dtypes of the plugin. * \param options The options of the plugin. */ - TVM_DLL Plugin(const String& name, const String& version, const String& describe, - const Array& attrs, const Array& inputs, - const Array& outputs, const Array& buffers, - const Map& externs, - const Map>& support_dtypes, - const Map& options); + TVM_DLL Plugin(const ffi::String& name, const ffi::String& version, const ffi::String& describe, + const ffi::Array& attrs, const ffi::Array& inputs, + const ffi::Array& outputs, const ffi::Array& buffers, + const ffi::Map& externs, + const ffi::Map>& support_dtypes, + const ffi::Map& options); /*! * \brief The json constructor. @@ -561,7 +562,7 @@ class PluginRegistry { * \param json_str The json_str. * \return The corresponding entry. */ - bool Register(const String& name, const String& json_str) { + bool Register(const ffi::String& name, const ffi::String& json_str) { plugin_map_[name] = Plugin(json_str); return true; } @@ -571,7 +572,7 @@ class PluginRegistry { * \param name The name of the item. * \return Whether the plugin is registered. */ - bool Registered(const String& name) const { + bool Registered(const ffi::String& name) const { auto it = plugin_map_.find(name); return it != plugin_map_.end(); } @@ -581,7 +582,7 @@ class PluginRegistry { * \param name The name of the item. * \return The corresponding plugin. */ - const Plugin Get(const String& name) const { + const Plugin Get(const ffi::String& name) const { auto it = plugin_map_.find(name); ICHECK(it != plugin_map_.end()) << "Can not find plugin " << name; return it->second; @@ -591,8 +592,8 @@ class PluginRegistry { * \brief List all the plugin names in the registry. * \return The plugin names. */ - Array ListAllNames() const { - Array names; + ffi::Array ListAllNames() const { + ffi::Array names; for (const auto& kv : plugin_map_) { names.push_back(kv.first); } @@ -609,28 +610,28 @@ class PluginRegistry { private: // map from name to plugins. - std::unordered_map plugin_map_; + std::unordered_map plugin_map_; }; /*! * \brief List all plugin names. * \return the corresponding plugin names. */ -const Array ListPluginNames(); +const ffi::Array ListPluginNames(); /*! 
* \brief Get the registered plugin. * \param name The name of the Plugin. * \return the corresponding plugin. */ -const Plugin GetPlugin(const String& name); +const Plugin GetPlugin(const ffi::String& name); /*! * \brief Check if a plugin is registered. * \param name The name of the item. * \return Whether the plugin is registered. */ -bool IsPlugin(const String& name); +bool IsPlugin(const ffi::String& name); } // namespace msc } // namespace contrib diff --git a/src/contrib/msc/core/printer/cpp_printer.cc b/src/contrib/msc/core/printer/cpp_printer.cc index 1f0fdb11778a..8c2a512a6d86 100644 --- a/src/contrib/msc/core/printer/cpp_printer.cc +++ b/src/contrib/msc/core/printer/cpp_printer.cc @@ -348,7 +348,7 @@ bool CppPrinter::IsEmptyDoc(const ExprDoc& doc) { return id_doc->name == DocSymbol::Empty(); } -void CppPrinter::PrintIndentedBlock(const Array& docs) { +void CppPrinter::PrintIndentedBlock(const ffi::Array& docs) { IncreaseIndent(); for (const StmtDoc& d : docs) { PrintDoc(d); diff --git a/src/contrib/msc/core/printer/cpp_printer.h b/src/contrib/msc/core/printer/cpp_printer.h index bdd25acdebed..62e205a7c749 100644 --- a/src/contrib/msc/core/printer/cpp_printer.h +++ b/src/contrib/msc/core/printer/cpp_printer.h @@ -147,7 +147,7 @@ class CppPrinter : public MSCBasePrinter { bool IsEmptyDoc(const ExprDoc& doc); /*! \brief Print block with indent*/ - void PrintIndentedBlock(const Array& docs); + void PrintIndentedBlock(const ffi::Array& docs); }; } // namespace msc diff --git a/src/contrib/msc/core/printer/msc_base_printer.h b/src/contrib/msc/core/printer/msc_base_printer.h index af369a530dae..10dafb54c2ac 100644 --- a/src/contrib/msc/core/printer/msc_base_printer.h +++ b/src/contrib/msc/core/printer/msc_base_printer.h @@ -97,7 +97,7 @@ class MSCBasePrinter { * \brief Get the printed string of all Doc appended * \sa Append */ - String GetString() const { return output_.str(); } + ffi::String GetString() const { return output_.str(); } protected: /*! \brief Print doc*/ @@ -199,7 +199,7 @@ class MSCBasePrinter { /*! \brief Print docs to joined doc */ template - void PrintJoinedDocs(const Array& docs, const String& separator = ", ") { + void PrintJoinedDocs(const ffi::Array& docs, const ffi::String& separator = ", ") { for (size_t i = 0; i < docs.size(); i++) { PrintDoc(docs[i], false); output_ << (i == docs.size() - 1 ?
"" : separator); diff --git a/src/contrib/msc/core/printer/msc_doc.cc b/src/contrib/msc/core/printer/msc_doc.cc index b69e554ab9c4..40d1ada3b4d7 100644 --- a/src/contrib/msc/core/printer/msc_doc.cc +++ b/src/contrib/msc/core/printer/msc_doc.cc @@ -29,9 +29,9 @@ namespace tvm { namespace contrib { namespace msc { -DeclareDoc::DeclareDoc(Optional type, ExprDoc variable, Array init_args, +DeclareDoc::DeclareDoc(ffi::Optional type, ExprDoc variable, ffi::Array init_args, bool use_constructor) { - ObjectPtr n = make_object(); + ObjectPtr n = ffi::make_object(); n->type = type; n->variable = variable; n->init_args = init_args; @@ -40,45 +40,46 @@ DeclareDoc::DeclareDoc(Optional type, ExprDoc variable, Array } StrictListDoc::StrictListDoc(ListDoc list, bool allow_empty) { - ObjectPtr n = make_object(); + ObjectPtr n = ffi::make_object(); n->list = list; n->allow_empty = allow_empty; this->data_ = std::move(n); } -PointerDoc::PointerDoc(String name) { - ObjectPtr n = make_object(); +PointerDoc::PointerDoc(ffi::String name) { + ObjectPtr n = ffi::make_object(); n->name = name; this->data_ = std::move(n); } -StructDoc::StructDoc(IdDoc name, Array decorators, Array body) { - ObjectPtr n = make_object(); +StructDoc::StructDoc(IdDoc name, ffi::Array decorators, ffi::Array body) { + ObjectPtr n = ffi::make_object(); n->name = name; n->decorators = decorators; n->body = body; this->data_ = std::move(n); } -ConstructorDoc::ConstructorDoc(IdDoc name, Array args, Array body) { - ObjectPtr n = make_object(); +ConstructorDoc::ConstructorDoc(IdDoc name, ffi::Array args, ffi::Array body) { + ObjectPtr n = ffi::make_object(); n->name = name; n->args = args; n->body = body; this->data_ = std::move(n); } -SwitchDoc::SwitchDoc(Array predicates, Array> branchs, - Array default_branch) { - ObjectPtr n = make_object(); +SwitchDoc::SwitchDoc(ffi::Array predicates, ffi::Array> branchs, + ffi::Array default_branch) { + ObjectPtr n = ffi::make_object(); n->predicates = predicates; n->branchs = branchs; n->default_branch = default_branch; this->data_ = std::move(n); } -LambdaDoc::LambdaDoc(IdDoc name, Array args, Array refs, Array body) { - ObjectPtr n = make_object(); +LambdaDoc::LambdaDoc(IdDoc name, ffi::Array args, ffi::Array refs, + ffi::Array body) { + ObjectPtr n = ffi::make_object(); n->name = name; n->args = args; n->refs = refs; diff --git a/src/contrib/msc/core/printer/msc_doc.h b/src/contrib/msc/core/printer/msc_doc.h index ea13d74d569f..ea1cee396ba6 100644 --- a/src/contrib/msc/core/printer/msc_doc.h +++ b/src/contrib/msc/core/printer/msc_doc.h @@ -43,11 +43,11 @@ using namespace tvm::script::printer; class DeclareDocNode : public ExprDocNode { public: /*! \brief The type of the variable */ - Optional type; + ffi::Optional type; /*! \brief The variable */ ExprDoc variable{nullptr}; /*! \brief The init arguments for the variable. */ - Array init_args; + ffi::Array init_args; /*! \brief Whether to use constructor(otherwise initializer) */ bool use_constructor{true}; @@ -78,7 +78,7 @@ class DeclareDoc : public ExprDoc { * \param init_args The init arguments of the variable. * \param use_constructor Whether to use constructor(otherwise initializer). 
*/ - explicit DeclareDoc(Optional type, ExprDoc variable, Array init_args, + explicit DeclareDoc(ffi::Optional type, ExprDoc variable, ffi::Array init_args, bool use_constructor); TVM_DEFINE_NOTNULLABLE_OBJECT_REF_METHODS(DeclareDoc, ExprDoc, DeclareDocNode); }; @@ -130,7 +130,7 @@ class StrictListDoc : public ExprDoc { class PointerDocNode : public ExprDocNode { public: /*! \brief The name of the identifier */ - String name; + ffi::String name; static void RegisterReflection() { namespace refl = tvm::ffi::reflection; @@ -152,7 +152,7 @@ class PointerDoc : public ExprDoc { * \brief Constructor of PointerDoc. * \param name The name of identifier. */ - explicit PointerDoc(String name); + explicit PointerDoc(ffi::String name); TVM_DEFINE_NOTNULLABLE_OBJECT_REF_METHODS(PointerDoc, ExprDoc, PointerDocNode); }; @@ -166,9 +166,9 @@ class StructDocNode : public StmtDocNode { /*! \brief The name of class. */ IdDoc name{nullptr}; /*! \brief Decorators of class. */ - Array decorators; + ffi::Array decorators; /*! \brief The body of class. */ - Array body; + ffi::Array body; static void RegisterReflection() { namespace refl = tvm::ffi::reflection; @@ -195,7 +195,7 @@ class StructDoc : public StmtDoc { * \param decorators The decorator of class. * \param body The body of class. */ - explicit StructDoc(IdDoc name, Array decorators, Array body); + explicit StructDoc(IdDoc name, ffi::Array decorators, ffi::Array body); TVM_DEFINE_NOTNULLABLE_OBJECT_REF_METHODS(StructDoc, StmtDoc, StructDocNode); }; @@ -215,9 +215,9 @@ class ConstructorDocNode : public StmtDocNode { * `annotation` means argument type, * and `rhs` means default value. */ - Array args; + ffi::Array args; /*! \brief The body of function. */ - Array body; + ffi::Array body; static void RegisterReflection() { namespace refl = tvm::ffi::reflection; @@ -244,7 +244,7 @@ class ConstructorDoc : public StmtDoc { * \param args The arguments of function. * \param body The body of function. */ - explicit ConstructorDoc(IdDoc name, Array args, Array body); + explicit ConstructorDoc(IdDoc name, ffi::Array args, ffi::Array body); TVM_DEFINE_NOTNULLABLE_OBJECT_REF_METHODS(ConstructorDoc, StmtDoc, ConstructorDocNode); }; @@ -256,11 +256,11 @@ class ConstructorDoc : public StmtDoc { class SwitchDocNode : public StmtDocNode { public: /*! \brief The predicates of the switch statement. */ - Array predicates; + ffi::Array predicates; /*! \brief The branchs of the switch statement. */ - Array> branchs; + ffi::Array> branchs; /*! \brief The default_branch of the switch statement. */ - Array default_branch; + ffi::Array default_branch; static void RegisterReflection() { namespace refl = tvm::ffi::reflection; @@ -287,8 +287,8 @@ class SwitchDoc : public StmtDoc { * \param branchs The branchs of the switch statement. * \param default_branch The default_branch of the switch statement. */ - explicit SwitchDoc(Array predicates, Array> branchs, - Array default_branch); + explicit SwitchDoc(ffi::Array predicates, ffi::Array> branchs, + ffi::Array default_branch); TVM_DEFINE_NOTNULLABLE_OBJECT_REF_METHODS(SwitchDoc, StmtDoc, SwitchDocNode); }; @@ -308,11 +308,11 @@ class LambdaDocNode : public StmtDocNode { * `annotation` means argument type, * and `rhs` means default value. */ - Array args; + ffi::Array args; /*! \brief References of lambda. */ - Array refs; + ffi::Array refs; /*! \brief The body of lambda. 
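DeclareDoc's optional `type` slot shows the ffi::Optional pattern that print_utils.h rebuilds later in ToDeclare and ToAssign: hold std::nullopt when nothing was given, otherwise wrap the value. A self-contained sketch with ffi::String (header paths are assumptions):

```cpp
#include <tvm/ffi/optional.h>  // header path is an assumption
#include <tvm/ffi/string.h>

using tvm::ffi::Optional;
using tvm::ffi::String;

// Empty annotations become std::nullopt, mirroring ToDeclare's type_doc logic.
Optional<String> TypeAnnotation(const String& type) {
  if (type.size() == 0) {
    return std::nullopt;
  }
  return type;
}
```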
*/ - Array body; + ffi::Array body; static void RegisterReflection() { namespace refl = tvm::ffi::reflection; @@ -341,7 +341,8 @@ class LambdaDoc : public StmtDoc { * \param refs The references of lambda. * \param body The body of lambda. */ - explicit LambdaDoc(IdDoc name, Array args, Array refs, Array body); + explicit LambdaDoc(IdDoc name, ffi::Array args, ffi::Array refs, + ffi::Array body); TVM_DEFINE_NOTNULLABLE_OBJECT_REF_METHODS(LambdaDoc, StmtDoc, LambdaDocNode); }; diff --git a/src/contrib/msc/core/printer/print_utils.cc b/src/contrib/msc/core/printer/print_utils.cc index 234ca3aec9c3..50d36df10bdb 100644 --- a/src/contrib/msc/core/printer/print_utils.cc +++ b/src/contrib/msc/core/printer/print_utils.cc @@ -28,9 +28,9 @@ namespace tvm { namespace contrib { namespace msc { -const String DocSymbol::Empty() { return "::EMPTY"; } +const ffi::String DocSymbol::Empty() { return "::EMPTY"; } -const String DocSymbol::NextLine() { return "::NEXT_LINE"; } +const ffi::String DocSymbol::NextLine() { return "::NEXT_LINE"; } const ExprDoc DocUtils::ToDoc(int64_t val) { return LiteralDoc::Int(val, std::nullopt); } @@ -50,19 +50,19 @@ const ExprDoc DocUtils::ToDoc(const FloatImm& val) { return ToDoc(val->value); } const ExprDoc DocUtils::ToDoc(const char* val) { return IdDoc(std::string(val)); } -const ExprDoc DocUtils::ToDoc(const String& val) { return IdDoc(val); } +const ExprDoc DocUtils::ToDoc(const ffi::String& val) { return IdDoc(val); } const ExprDoc DocUtils::ToDoc(bool val) { return LiteralDoc::Boolean(val, std::nullopt); } const ExprDoc DocUtils::ToDoc(const ExprDoc& val) { return val; } -const ExprDoc DocUtils::ToStr(const String& val) { return LiteralDoc::Str(val, std::nullopt); } +const ExprDoc DocUtils::ToStr(const ffi::String& val) { return LiteralDoc::Str(val, std::nullopt); } -const PointerDoc DocUtils::ToPtr(const String& val) { return PointerDoc(val); } +const PointerDoc DocUtils::ToPtr(const ffi::String& val) { return PointerDoc(val); } const StrictListDoc DocUtils::ToStrList(const std::vector& values, bool allow_empty) { if (values.size() > 0 || allow_empty) { - Array elements; + ffi::Array elements; for (const auto& v : values) { elements.push_back(ToStr(v)); } @@ -71,7 +71,7 @@ const StrictListDoc DocUtils::ToStrList(const std::vector& values, return StrictListDoc(ListDoc(), false); } -const StrictListDoc DocUtils::ToStrList(const std::vector& values, bool allow_empty) { +const StrictListDoc DocUtils::ToStrList(const std::vector& values, bool allow_empty) { std::vector v_values; for (const auto& v : values) { v_values.push_back(v); @@ -79,7 +79,7 @@ const StrictListDoc DocUtils::ToStrList(const std::vector& values, bool return ToStrList(v_values, allow_empty); } -const StrictListDoc DocUtils::ToStrList(const Array& values, bool allow_empty) { +const StrictListDoc DocUtils::ToStrList(const ffi::Array& values, bool allow_empty) { std::vector v_values; for (const auto& v : values) { v_values.push_back(v); @@ -87,8 +87,8 @@ const StrictListDoc DocUtils::ToStrList(const Array& values, bool allow_ return ToStrList(v_values, allow_empty); } -const Array DocUtils::ToStmts(const Array& docs) { - Array stmts; +const ffi::Array DocUtils::ToStmts(const ffi::Array& docs) { + ffi::Array stmts; for (const auto& d : docs) { if (d->IsInstance()) { stmts.push_back(Downcast(d)); @@ -101,7 +101,7 @@ const Array DocUtils::ToStmts(const Array& docs) { return stmts; } -const StmtBlockDoc DocUtils::ToStmtBlock(const Array& docs) { +const StmtBlockDoc DocUtils::ToStmtBlock(const ffi::Array& docs) 
{ return StmtBlockDoc(ToStmts(docs)); } diff --git a/src/contrib/msc/core/printer/print_utils.h b/src/contrib/msc/core/printer/print_utils.h index b3949d54a762..3ccc1cdc22cc 100644 --- a/src/contrib/msc/core/printer/print_utils.h +++ b/src/contrib/msc/core/printer/print_utils.h @@ -44,10 +44,10 @@ using namespace tvm::script::printer; class DocSymbol { public: /*! * \brief The empty symbol*/ - TVM_DLL static const String Empty(); + TVM_DLL static const ffi::String Empty(); /*! * \brief The next line symbol*/ - TVM_DLL static const String NextLine(); + TVM_DLL static const ffi::String NextLine(); }; /*! @@ -68,30 +68,30 @@ class DocUtils { TVM_DLL static const ExprDoc ToDoc(double val); TVM_DLL static const ExprDoc ToDoc(const FloatImm& val); TVM_DLL static const ExprDoc ToDoc(const char* val); - TVM_DLL static const ExprDoc ToDoc(const String& val); + TVM_DLL static const ExprDoc ToDoc(const ffi::String& val); TVM_DLL static const ExprDoc ToDoc(bool val); TVM_DLL static const ExprDoc ToDoc(const ExprDoc& val); - TVM_DLL static const ExprDoc ToStr(const String& val); - TVM_DLL static const PointerDoc ToPtr(const String& val); + TVM_DLL static const ExprDoc ToStr(const ffi::String& val); + TVM_DLL static const PointerDoc ToPtr(const ffi::String& val); /*! * \brief Change object to DeclareDoc. * \return The DeclareDoc. */ template - TVM_DLL static const DeclareDoc ToDeclare(const String& type, const T& variable, size_t len = 0, - bool use_constructor = true) { - Optional type_doc; + TVM_DLL static const DeclareDoc ToDeclare(const ffi::String& type, const T& variable, + size_t len = 0, bool use_constructor = true) { + ffi::Optional type_doc; if (type.size() == 0) { type_doc = std::nullopt; } else { type_doc = IdDoc(type); } if (len == 0) { - return DeclareDoc(type_doc, ToDoc(variable), Array(), use_constructor); + return DeclareDoc(type_doc, ToDoc(variable), ffi::Array(), use_constructor); } - Array doc_indices{DocUtils::ToDoc(len)}; - return DeclareDoc(type_doc, IndexDoc(ToDoc(variable), doc_indices), Array(), + ffi::Array doc_indices{DocUtils::ToDoc(len)}; + return DeclareDoc(type_doc, IndexDoc(ToDoc(variable), doc_indices), ffi::Array(), use_constructor); } @@ -101,22 +101,22 @@ class DocUtils { */ template TVM_DLL static const AssignDoc ToAssign(const LT& lhs, const RT& rhs, - const String& annotation = "") { + const ffi::String& annotation = "") { if (annotation.size() == 0) { return AssignDoc(ToDoc(lhs), ToDoc(rhs), std::nullopt); } return AssignDoc(ToDoc(lhs), ToDoc(rhs), IdDoc(annotation)); } template - TVM_DLL static const AssignDoc ToAssign(const T& lhs, const String& rhs, - const String& annotation = "") { - Optional rhs_doc; + TVM_DLL static const AssignDoc ToAssign(const T& lhs, const ffi::String& rhs, + const ffi::String& annotation = "") { + ffi::Optional rhs_doc; if (rhs.size() > 0) { rhs_doc = IdDoc(rhs); } else { rhs_doc = std::nullopt; } - Optional annotation_doc; + ffi::Optional annotation_doc; if (annotation.size() > 0) { annotation_doc = IdDoc(annotation); } else { @@ -130,7 +130,7 @@ class DocUtils { * \return The AttrAccessDoc. */ template - TVM_DLL static const AttrAccessDoc ToAttrAccess(const T& value, const String& name) { + TVM_DLL static const AttrAccessDoc ToAttrAccess(const T& value, const ffi::String& name) { return AttrAccessDoc(ToDoc(value), name); } @@ -139,15 +139,15 @@ class DocUtils { * \return The List of Docs. 
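ToDeclare below exercises two of ffi::Array's constructors (an initializer list for `doc_indices`, a default-constructed empty array), and rewrite_utils.cc later uses the fill form. A sketch of the three construction styles with hypothetical values:

```cpp
#include <tvm/ffi/container/array.h>
#include <tvm/ffi/string.h>  // header path is an assumption

using tvm::ffi::Array;
using tvm::ffi::String;

void ArrayConstructionForms() {
  Array<String> from_list{"NCHW", "NHWC"};     // initializer list
  Array<String> filled(3, String("float32"));  // n copies of one value
  Array<String> grown;                         // default-empty, then push_back
  grown.push_back("x");
  // Array::Map builds a transformed copy, as fuse_tuple.cc does for free_vars.
  Array<String> mapped = from_list.Map([](const String& s) { return s + "_layout"; });
  (void)filled;
  (void)mapped;
}
```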
*/ template - TVM_DLL static const Array ToDocList(const std::vector& values) { - Array elements; + TVM_DLL static const ffi::Array ToDocList(const std::vector& values) { + ffi::Array elements; for (const auto& v : values) { elements.push_back(ToDoc(v)); } return elements; } template - TVM_DLL static const Array ToDocList(const Array& values) { + TVM_DLL static const ffi::Array ToDocList(const ffi::Array& values) { std::vector v_values; for (const auto& v : values) { v_values.push_back(v); @@ -168,7 +168,7 @@ class DocUtils { return StrictListDoc(ListDoc(), false); } template - TVM_DLL static const StrictListDoc ToList(const Array& values, bool allow_empty = false) { + TVM_DLL static const StrictListDoc ToList(const ffi::Array& values, bool allow_empty = false) { std::vector v_values; for (const auto& v : values) { v_values.push_back(v); @@ -182,9 +182,9 @@ class DocUtils { */ TVM_DLL static const StrictListDoc ToStrList(const std::vector& values, bool allow_empty = false); - TVM_DLL static const StrictListDoc ToStrList(const std::vector& values, + TVM_DLL static const StrictListDoc ToStrList(const std::vector& values, bool allow_empty = false); - TVM_DLL static const StrictListDoc ToStrList(const Array& values, + TVM_DLL static const StrictListDoc ToStrList(const ffi::Array& values, bool allow_empty = false); /*! @@ -193,21 +193,21 @@ class DocUtils { */ template TVM_DLL static const IndexDoc ToIndex(const VT& value, const IT& index) { - Array doc_indices; + ffi::Array doc_indices; doc_indices.push_back(ToDoc(index)); return IndexDoc(ToDoc(value), doc_indices); } template TVM_DLL static const IndexDoc ToIndices(const VT& value, const std::vector& indices) { - Array doc_indices; + ffi::Array doc_indices; for (const auto& i : indices) { doc_indices.push_back(ToDoc(i)); } return IndexDoc(ToDoc(value), doc_indices); } template - TVM_DLL static const IndexDoc ToIndices(const VT& value, const Array& indices) { - Array doc_indices; + TVM_DLL static const IndexDoc ToIndices(const VT& value, const ffi::Array& indices) { + ffi::Array doc_indices; for (const auto& i : indices) { doc_indices.push_back(ToDoc(i)); } @@ -218,13 +218,13 @@ class DocUtils { * \brief Convert the docs to Stmts. * \return The Stmts. */ - TVM_DLL static const Array ToStmts(const Array& docs); + TVM_DLL static const ffi::Array ToStmts(const ffi::Array& docs); /*! * \brief Convert the docs to StmtBlock. * \return The StmtBlockDoc. 
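The ToDocList and ToStrList overloads above all funnel through the same std::vector-to-ffi::Array copy loop; a generic, self-contained sketch of that conversion:

```cpp
#include <vector>

#include <tvm/ffi/container/array.h>

// Mirrors the copy loops in DocUtils; ffi::Array is copy-on-write, so the
// returned handle can be shared cheaply afterwards.
template <typename T>
tvm::ffi::Array<T> ToFfiArray(const std::vector<T>& values) {
  tvm::ffi::Array<T> out;
  for (const auto& v : values) {
    out.push_back(v);
  }
  return out;
}
```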
*/ - TVM_DLL static const StmtBlockDoc ToStmtBlock(const Array& docs); + TVM_DLL static const StmtBlockDoc ToStmtBlock(const ffi::Array& docs); }; } // namespace msc diff --git a/src/contrib/msc/core/printer/prototxt_printer.cc b/src/contrib/msc/core/printer/prototxt_printer.cc index d62e5ac2a8f6..ffaf035385f1 100644 --- a/src/contrib/msc/core/printer/prototxt_printer.cc +++ b/src/contrib/msc/core/printer/prototxt_printer.cc @@ -43,9 +43,9 @@ LiteralDoc PrototxtPrinter::ToLiteralDoc(const ffi::Any& obj) { return LiteralDoc::Str(obj_des.str(), std::nullopt); } -DictDoc PrototxtPrinter::ToDictDoc(const Map& dict) { - Array keys; - Array values; +DictDoc PrototxtPrinter::ToDictDoc(const ffi::Map& dict) { + ffi::Array keys; + ffi::Array values; for (const auto& pair : dict) { keys.push_back(IdDoc(pair.first)); if (pair.second.as()) { @@ -57,9 +57,9 @@ DictDoc PrototxtPrinter::ToDictDoc(const Map& dict) { return DictDoc(keys, values); } -DictDoc PrototxtPrinter::ToDictDoc(const std::vector>& dict) { - Array keys; - Array values; +DictDoc PrototxtPrinter::ToDictDoc(const std::vector>& dict) { + ffi::Array keys; + ffi::Array values; for (const auto& pair : dict) { keys.push_back(IdDoc(pair.first)); if (pair.second.as()) { @@ -71,18 +71,18 @@ DictDoc PrototxtPrinter::ToDictDoc(const std::vector>& di return DictDoc(keys, values); } -void PrototxtPrinter::Append(const Map& dict) { +void PrototxtPrinter::Append(const ffi::Map& dict) { DictDoc doc = ToDictDoc(dict); PrintDoc(doc, false); } -void PrototxtPrinter::Append(const std::vector>& dict) { +void PrototxtPrinter::Append(const std::vector>& dict) { DictDoc doc = ToDictDoc(dict); PrintDoc(doc, false); } -void PrototxtPrinter::AppendPair(const String& key, const ffi::Any& value) { - Map dict; +void PrototxtPrinter::AppendPair(const ffi::String& key, const ffi::Any& value) { + ffi::Map dict; dict.Set(key, value); return Append(dict); } diff --git a/src/contrib/msc/core/printer/prototxt_printer.h b/src/contrib/msc/core/printer/prototxt_printer.h index e760a179d8dd..f304dcdd5819 100644 --- a/src/contrib/msc/core/printer/prototxt_printer.h +++ b/src/contrib/msc/core/printer/prototxt_printer.h @@ -53,19 +53,19 @@ class PrototxtPrinter : public MSCBasePrinter { static LiteralDoc ToLiteralDoc(const ffi::Any& obj); /*! \brief Change map to DictDoc*/ - static DictDoc ToDictDoc(const Map& dict); + static DictDoc ToDictDoc(const ffi::Map& dict); /*! \brief Change ordered pairs to DictDoc*/ - static DictDoc ToDictDoc(const std::vector>& dict); + static DictDoc ToDictDoc(const std::vector>& dict); /*! \brief Append a map into the final content*/ - void Append(const Map& dict); + void Append(const ffi::Map& dict); /*! \brief Append ordered pairs into the final content*/ - void Append(const std::vector>& dict); + void Append(const std::vector>& dict); /*! \brief Append a map pair into the final content*/ - void AppendPair(const String& key, const ffi::Any& value); + void AppendPair(const ffi::String& key, const ffi::Any& value); protected: /*! 
* \brief Print a DictDoc to prototxt format*/ diff --git a/src/contrib/msc/core/printer/python_printer.cc b/src/contrib/msc/core/printer/python_printer.cc index df75887ce1b6..eb087f7f40e6 100644 --- a/src/contrib/msc/core/printer/python_printer.cc +++ b/src/contrib/msc/core/printer/python_printer.cc @@ -248,7 +248,7 @@ void PythonPrinter::MaybePrintComment(const StmtDoc& stmt, bool multi_lines) { } } -void PythonPrinter::PrintIndentedBlock(const Array& docs) { +void PythonPrinter::PrintIndentedBlock(const ffi::Array& docs) { IncreaseIndent(); for (const StmtDoc& d : docs) { PrintDoc(d); @@ -259,7 +259,7 @@ void PythonPrinter::PrintIndentedBlock(const Array& docs) { DecreaseIndent(); } -void PythonPrinter::PrintDecorators(const Array& decorators) { +void PythonPrinter::PrintDecorators(const ffi::Array& decorators) { for (const ExprDoc& decorator : decorators) { output_ << "@"; PrintDoc(decorator, false); diff --git a/src/contrib/msc/core/printer/python_printer.h b/src/contrib/msc/core/printer/python_printer.h index 31f380bc87be..3e09b1fcdabc 100644 --- a/src/contrib/msc/core/printer/python_printer.h +++ b/src/contrib/msc/core/printer/python_printer.h @@ -92,10 +92,10 @@ class PythonPrinter : public MSCBasePrinter { private: /*! \brief Print block with indent*/ - void PrintIndentedBlock(const Array& docs); + void PrintIndentedBlock(const ffi::Array& docs); /*! \brief Print decorators for function and class*/ - void PrintDecorators(const Array& decorators); + void PrintDecorators(const ffi::Array& decorators); }; } // namespace msc diff --git a/src/contrib/msc/core/transform/bind_named_params.cc b/src/contrib/msc/core/transform/bind_named_params.cc index dec4616f5e38..630f5d473ba8 100644 --- a/src/contrib/msc/core/transform/bind_named_params.cc +++ b/src/contrib/msc/core/transform/bind_named_params.cc @@ -34,23 +34,23 @@ namespace tvm { namespace relax { using namespace tvm::contrib::msc; -std::tuple, Map> NormalizeNamedBindings( - const Function& func, const Map& untyped_params) { +std::tuple, ffi::Map> NormalizeNamedBindings( + const Function& func, const ffi::Map& untyped_params) { ICHECK(func.defined()); ICHECK(untyped_params.defined()); // Map from string to the variable(s) with that name. 
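NormalizeNamedBindings below takes its params as an ffi::Map of ffi::Any values, so heterogeneous bindings travel in one map. A self-contained sketch of building such a dict (header paths are assumptions; the keys are hypothetical):

```cpp
#include <tvm/ffi/any.h>  // header path is an assumption
#include <tvm/ffi/container/map.h>
#include <tvm/ffi/string.h>

using tvm::ffi::Any;
using tvm::ffi::Map;
using tvm::ffi::String;

Map<String, Any> MakeUntypedParams() {
  Map<String, Any> params;
  params.Set("scale", 2);              // Any absorbs plain ints...
  params.Set("mode", String("fast"));  // ...and ffi strings alike
  return params;
}
```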
- std::unordered_map> string_lookup; + std::unordered_map> string_lookup; std::unordered_set var_set; for (const auto& param : func->params) { string_lookup[param->name_hint()].push_back(param); var_set.insert(param.get()); } - Map relax_var_remap; + ffi::Map relax_var_remap; auto normalize_key = [&](ffi::Any obj) -> relax::Var { - if (auto opt_str = obj.as()) { + if (auto opt_str = obj.as()) { std::string str = opt_str.value(); auto it = string_lookup.find(str); CHECK(it != string_lookup.end()) @@ -96,7 +96,7 @@ std::tuple, Map> NormalizeNamedBindings( } arith::Analyzer analyzer; - Map symbolic_var_map = InferSymbolicVarMap(relax_var_remap, &analyzer); + ffi::Map symbolic_var_map = InferSymbolicVarMap(relax_var_remap, &analyzer); return {relax_var_remap, symbolic_var_map}; } @@ -107,7 +107,8 @@ std::tuple, Map> NormalizeNamedBindings( * \param params params dict * \return Function */ -Function FunctionBindNamedParams(Function func, const Map& untyped_params) { +Function FunctionBindNamedParams(Function func, + const ffi::Map& untyped_params) { auto [bind_dict, symbolic_var_map] = NormalizeNamedBindings(func, untyped_params); Expr bound_expr = Bind(func, bind_dict, symbolic_var_map); @@ -121,33 +122,37 @@ Function FunctionBindNamedParams(Function func, const Map& * \param param The param dict * \return The module after binding params. */ -IRModule BindNamedParam(IRModule m, String func_name, Map bind_params) { +IRModule BindNamedParam(IRModule m, ffi::String func_name, + ffi::Map bind_params) { IRModuleNode* new_module = m.CopyOnWrite(); - Map functions = m->functions; + ffi::Map functions = m->functions; for (const auto& func_pr : functions) { if (const auto* relax_f = func_pr.second.as()) { if (relax_f->GetLinkageType() == LinkageType::kExternal) { // Use global_symbol if it's external linkage - Optional gsymbol = relax_f->GetAttr(tvm::attr::kGlobalSymbol); + ffi::Optional gsymbol = + relax_f->GetAttr(tvm::attr::kGlobalSymbol); if (gsymbol.has_value() && gsymbol.value() == func_name) { - Function f_after_bind = FunctionBindNamedParams(GetRef(relax_f), bind_params); + Function f_after_bind = + FunctionBindNamedParams(ffi::GetRef(relax_f), bind_params); new_module->Update(func_pr.first, f_after_bind); } } else { // Use global var's name_hint if it's internal linkage if (func_pr.first->name_hint == func_name) { - Function f_after_bind = FunctionBindNamedParams(GetRef(relax_f), bind_params); + Function f_after_bind = + FunctionBindNamedParams(ffi::GetRef(relax_f), bind_params); new_module->Update(func_pr.first, f_after_bind); } } } } - return GetRef(new_module); + return ffi::GetRef(new_module); } namespace transform { -Pass BindNamedParams(String func_name, Map params) { +Pass BindNamedParams(ffi::String func_name, ffi::Map params) { auto pass_func = [=](IRModule mod, PassContext pc) { return BindNamedParam(std::move(mod), func_name, params); }; diff --git a/src/contrib/msc/core/transform/bind_shape.cc b/src/contrib/msc/core/transform/bind_shape.cc index b7c3491bff1a..c85c821c145a 100644 --- a/src/contrib/msc/core/transform/bind_shape.cc +++ b/src/contrib/msc/core/transform/bind_shape.cc @@ -37,7 +37,8 @@ namespace relax { */ class ShapeBinder : public ExprMutator { public: - explicit ShapeBinder(IRModule ctx_module, const String& entry_name) : ExprMutator(ctx_module) { + explicit ShapeBinder(IRModule ctx_module, const ffi::String& entry_name) + : ExprMutator(ctx_module) { mod_ = ctx_module; entry_name_ = entry_name; } @@ -51,7 +52,7 @@ class ShapeBinder : public ExprMutator { continue; } 
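The normalize_key lambda above probes an ffi::Any with as<ffi::String>(), which yields an optional-style result instead of throwing; cast<T>() is the asserting variant seen elsewhere in this diff. The pattern isolated as a sketch:

```cpp
#include <string>

#include <tvm/ffi/any.h>  // header path is an assumption
#include <tvm/ffi/string.h>

// Returns the held string, or a fallback if the Any holds something else.
std::string KeyToString(const tvm::ffi::Any& obj) {
  if (auto opt_str = obj.as<tvm::ffi::String>()) {
    std::string str = opt_str.value();  // ffi::String converts to std::string
    return str;
  }
  return "<non-string key>";
}
```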
if (func->IsInstance()) { - Array new_params; + ffi::Array new_params; for (const auto& p : Downcast(func)->params) { auto struct_info = GetStructInfo(p); if (struct_info->IsInstance()) { @@ -76,7 +77,7 @@ class ShapeBinder : public ExprMutator { } void VisitBinding_(const VarBindingNode* binding, const CallNode* call_node) final { - Array new_args; + ffi::Array new_args; for (const auto& a : call_node->args) { auto struct_info = GetStructInfo(a); if (a->IsInstance() && struct_info->IsInstance()) { @@ -92,7 +93,7 @@ class ShapeBinder : public ExprMutator { } else if (const auto* op_node = call_node->op.as()) { ICHECK(op_node->name == "relax.reshape" || op_node->name == "relax.image.resize2d") << "Expect ShapeExpr consumer as reshape or image.resize2d, get " - << GetRef(call_node); + << ffi::GetRef(call_node); const auto& opt_shape = Downcast(GetStructInfo(call_node->args[1]))->values; ICHECK(opt_shape.defined()) << "Expected shape defined, get " << call_node->args[1]; new_args.push_back(ShapeExpr(opt_shape.value())); @@ -101,7 +102,7 @@ class ShapeBinder : public ExprMutator { ReEmitBinding(binding, builder_->Normalize(new_call)); } else if (const auto* gv_node = call_node->op.as()) { const auto& func_info = Downcast(gv_node->struct_info_); - Array params_info; + ffi::Array params_info; for (const auto& a : new_args) { ICHECK(a->struct_info_.defined()) << "Global func argument without defined struct info " << a; @@ -113,22 +114,22 @@ class ShapeBinder : public ExprMutator { Call(call_node->op, new_args, call_node->attrs, call_node->sinfo_args, call_node->span); ReEmitBinding(binding, builder_->Normalize(new_call)); } else { - LOG_FATAL << "Unexpected shape consumer " << GetRef(call_node); + LOG_FATAL << "Unexpected shape consumer " << ffi::GetRef(call_node); } } private: IRModule mod_; - String entry_name_; + ffi::String entry_name_; }; -IRModule BindShape(IRModule mod, const String& entry_name) { +IRModule BindShape(IRModule mod, const ffi::String& entry_name) { return ShapeBinder(mod, entry_name).Bind(); } namespace transform { -Pass BindShape(const String& entry_name) { +Pass BindShape(const ffi::String& entry_name) { auto pass_func = [=](IRModule m, PassContext pc) { return relax::BindShape(m, entry_name); }; return CreateModulePass(pass_func, 0, "BindShape", {}); } diff --git a/src/contrib/msc/core/transform/fuse_tuple.cc b/src/contrib/msc/core/transform/fuse_tuple.cc index 19b8f08f4780..692ff826e150 100644 --- a/src/contrib/msc/core/transform/fuse_tuple.cc +++ b/src/contrib/msc/core/transform/fuse_tuple.cc @@ -41,7 +41,7 @@ using namespace tvm::contrib::msc; */ class TupleFuser : public ExprMutator { public: - explicit TupleFuser(IRModule ctx_module, const String& target, const String& entry_name) + explicit TupleFuser(IRModule ctx_module, const ffi::String& target, const ffi::String& entry_name) : ExprMutator(ctx_module) { mod_ = ctx_module; target_ = target + "."; @@ -54,7 +54,7 @@ class TupleFuser : public ExprMutator { if (gv->name_hint == entry_name_) { main_var = gv; } else { - const auto& name_opt = func->GetAttr(attr::kComposite); + const auto& name_opt = func->GetAttr(attr::kComposite); if (name_opt.has_value() && StringUtils::StartsWith(name_opt.value(), target_)) { target_funcs_.Set(gv, Downcast(func)); } @@ -70,12 +70,12 @@ class TupleFuser : public ExprMutator { void VisitBinding_(const VarBindingNode* binding, const CallNode* val) final { bool has_tuple_arg = false; if (target_funcs_.count(val->op)) { - Array new_args; + ffi::Array new_args; for (size_t i = 0; i < 
val->args.size(); i++) { const auto& arg = val->args[i]; if (arg->IsInstance()) { - String tuple_name; - const auto& name_opt = target_funcs_[val->op]->GetAttr(msc_attr::kUnique); + ffi::String tuple_name; + const auto& name_opt = target_funcs_[val->op]->GetAttr(msc_attr::kUnique); if (name_opt.has_value()) { if (val->args.size() == 1) { tuple_name = name_opt.value() + "_input"; @@ -114,7 +114,7 @@ } } if (on_target) { - ReEmitFunc(binding, GetRef(val)); + ReEmitFunc(binding, ffi::GetRef(val)); } else { ExprMutator::VisitBinding_(binding, val); } @@ -122,16 +122,16 @@ void VisitBinding_(const VarBindingNode* binding, const TupleGetItemNode* val) final { if (target_funcs_.count(val->tuple)) { - ReEmitFunc(binding, GetRef(val)); + ReEmitFunc(binding, ffi::GetRef(val)); } else { ExprMutator::VisitBinding_(binding, val); } } private: - Call AddFunc(const Expr& expr, const String tuple_name = "") { + Call AddFunc(const Expr& expr, const ffi::String tuple_name = "") { builder_->BeginDataflowBlock(); - Array inputs; + ffi::Array inputs; if (const auto* v_node = expr.as()) { inputs = v_node->fields; } else if (const auto* g_node = expr.as()) { @@ -139,17 +139,17 @@ } else { LOG_FATAL << "Unexpected expr " << expr; } - Array func_inputs; - Array call_inputs; - Array params; - Map added_params; + ffi::Array func_inputs; + ffi::Array call_inputs; + ffi::Array params; + ffi::Map added_params; for (size_t i = 0; i < inputs.size(); i++) { if (inputs[i]->IsInstance()) { func_inputs.push_back(inputs[i]); continue; } if (!added_params.count(inputs[i])) { - const auto& name = String("param_" + std::to_string(i)); + const auto& name = ffi::String("param_" + std::to_string(i)); const auto& var = Var(std::move(name), GetStructInfo(inputs[i])); added_params.Set(inputs[i], var); } @@ -159,7 +159,7 @@ } Expr out_expr; - String func_name; + ffi::String func_name; Span expr_span = expr->span; if (!expr_span.defined()) { ICHECK(tuple_name.size() > 0) << "Missing tuple for " << expr; @@ -180,7 +180,7 @@ Expr body = builder_->Normalize(output); body = builder_->Normalize(SeqExpr({new_block}, body)); - Map func_attrs; + ffi::Map func_attrs; func_attrs.Set(attr::kPrimitive, true); func_attrs.Set(attr::kComposite, target_ + func_name); func_attrs.Set(msc_attr::kUnique, SpanUtils::GetAttr(expr_span, msc_attr::kName)); @@ -190,7 +190,7 @@ /*ret_struct_info=*/std::nullopt, // /*is_pure=*/true, // /*attrs=*/DictAttrs(func_attrs)); - Array free_vars = + ffi::Array free_vars = FreeSymbolicVars(function).Map([](const tir::Var& var) -> PrimExpr { return var; }); if (!free_vars.empty()) { params.push_back(Var("tir_vars", ShapeStructInfo(free_vars))); @@ -214,18 +214,18 @@ } IRModule mod_; - String target_; - String entry_name_; - Map target_funcs_; + ffi::String target_; + ffi::String entry_name_; + ffi::Map target_funcs_; }; -IRModule FuseTuple(IRModule mod, const String& target, const String& entry_name) { +IRModule FuseTuple(IRModule mod, const ffi::String& target, const ffi::String& entry_name) { return TupleFuser(mod, target, entry_name).Fuse(); } namespace transform { -Pass FuseTuple(const String& target, const String& entry_name) { +Pass FuseTuple(const ffi::String& target, const ffi::String& entry_name) { auto pass_func = [=](IRModule m,
PassContext pc) { return relax::FuseTuple(m, target, entry_name); }; diff --git a/src/contrib/msc/core/transform/inline_params.cc b/src/contrib/msc/core/transform/inline_params.cc index 086c475f6d1f..eb59713e7111 100644 --- a/src/contrib/msc/core/transform/inline_params.cc +++ b/src/contrib/msc/core/transform/inline_params.cc @@ -40,7 +40,8 @@ using namespace tvm::contrib::msc; */ class ParamsInliner : public ExprMutator { public: - explicit ParamsInliner(IRModule ctx_module, const String& entry_name) : ExprMutator(ctx_module) { + explicit ParamsInliner(IRModule ctx_module, const ffi::String& entry_name) + : ExprMutator(ctx_module) { mod_ = ctx_module; entry_name_ = entry_name; } @@ -54,22 +55,22 @@ continue; } if (func->IsInstance()) { - Array new_params; - Array attrs; + ffi::Array new_params; + ffi::Array attrs; for (const auto& p : Downcast(func)->params) { auto struct_info = GetStructInfo(p); if (struct_info->IsInstance()) { continue; } if (struct_info->IsInstance()) { - const auto& optype_opt = func->GetAttr(msc_attr::kOptype); + const auto& optype_opt = func->GetAttr(msc_attr::kOptype); ICHECK(optype_opt.has_value()) << "Can not find attr " << msc_attr::kOptype << " from extern func"; extern_types_.Set(p, optype_opt.value()); continue; } if (const auto* tuple_info = struct_info.as()) { - Array new_fields; + ffi::Array new_fields; for (const auto& i : tuple_info->fields) { if (i->IsInstance()) { new_fields.push_back(i); @@ -88,7 +89,7 @@ continue; } const auto& new_func = Downcast(VisitExpr(func)); - Map func_attrs = new_func->attrs->dict; + ffi::Map func_attrs = new_func->attrs->dict; if (attrs.size() > 0) { func_attrs.Set(msc_attr::kOpattrs, attrs); } @@ -105,7 +106,7 @@ } void VisitBinding_(const VarBindingNode* binding, const CallNode* call_node) final { - Array new_args; + ffi::Array new_args; bool has_inline = false; for (const auto& a : call_node->args) { auto struct_info = GetStructInfo(a); @@ -124,8 +125,8 @@ has_inline = true; } else if (call_node->op->IsInstance() && a->IsInstance()) { const auto& tuple = Downcast(a); - Array new_fields; - Array new_infos; + ffi::Array new_fields; + ffi::Array new_infos; for (const auto& f : tuple->fields) { if (f->IsInstance()) { @@ -152,7 +153,7 @@ ReEmitBinding(binding, builder_->Normalize(new_call)); } else if (const auto* gv_node = call_node->op.as()) { const auto& func_info = Downcast(gv_node->struct_info_); - Array params_info; + ffi::Array params_info; for (const auto& a : new_args) { ICHECK(a->struct_info_.defined()) << "Global func argument without defined struct info " << a; @@ -164,23 +165,23 @@ Call(call_node->op, new_args, call_node->attrs, call_node->sinfo_args, call_node->span); ReEmitBinding(binding, builder_->Normalize(new_call)); } else { - LOG_FATAL << "Unexpected shape consumer " << GetRef(call_node); + LOG_FATAL << "Unexpected shape consumer " << ffi::GetRef(call_node); } } private: IRModule mod_; - String entry_name_; - Map extern_types_; + ffi::String entry_name_; + ffi::Map extern_types_; }; -IRModule InlineParams(IRModule mod, const String& entry_name) { +IRModule InlineParams(IRModule mod, const ffi::String& entry_name) { return ParamsInliner(mod, entry_name).Bind(); } namespace transform { -Pass
InlineParams(const ffi::String& entry_name) { auto pass_func = [=](IRModule m, PassContext pc) { return relax::InlineParams(m, entry_name); }; return CreateModulePass(pass_func, 0, "InlineParams", {}); } diff --git a/src/contrib/msc/core/transform/layout_utils.cc b/src/contrib/msc/core/transform/layout_utils.cc index a634b8e9e36a..a4f46dce7fe4 100644 --- a/src/contrib/msc/core/transform/layout_utils.cc +++ b/src/contrib/msc/core/transform/layout_utils.cc @@ -57,12 +57,12 @@ LayoutDecision LayoutUtils::InferLayoutDecisionAt(const Expr& expr, } bool LayoutUtils::LayoutInfered(const Expr& expr) { - const String& layout = SpanUtils::GetAttr(expr->span, msc_attr::kLayout); + const ffi::String& layout = SpanUtils::GetAttr(expr->span, msc_attr::kLayout); return layout.size() > 0; } bool LayoutUtils::SetLayout(const Expr& expr, const NLayout& layout) { - const String& saved_layout = SpanUtils::GetAttr(expr->span, msc_attr::kLayout); + const ffi::String& saved_layout = SpanUtils::GetAttr(expr->span, msc_attr::kLayout); const auto& sinfo = GetStructInfo(expr); if (sinfo->IsInstance() || sinfo->IsInstance()) { if (!layout.IsLeaf()) { @@ -80,8 +80,8 @@ bool LayoutUtils::SetLayout(const Expr& expr, const NLayout& layout) { if (layout.IsLeaf()) { return false; } - String layout_str; - Array nested_layouts = layout.NestedArray(); + ffi::String layout_str; + ffi::Array nested_layouts = layout.NestedArray(); for (size_t i = 0; i < nested_layouts.size(); i++) { if (!nested_layouts[i].IsLeaf()) { return false; @@ -109,7 +109,7 @@ const NLayout LayoutUtils::GetNLayout(const Expr& expr) { return LayoutDecision(SpanUtils::GetAttr(expr->span, msc_attr::kLayout)); } if (sinfo->IsInstance()) { - String layout_str = SpanUtils::GetAttr(expr->span, msc_attr::kLayout); + ffi::String layout_str = SpanUtils::GetAttr(expr->span, msc_attr::kLayout); std::vector output_layout; for (const auto& l : StringUtils::Split(layout_str, ",")) { output_layout.push_back(LayoutDecision(l)); @@ -134,7 +134,7 @@ bool LayoutUtils::HasUnknownDimTensor(const NLayout& nlayout) { return find; } -bool LayoutUtils::HasUnknownDimTensor(const Array& args) { +bool LayoutUtils::HasUnknownDimTensor(const ffi::Array& args) { for (const auto& arg : args) { if (IsNestedTensor(arg)) { if (HasUnknownDimTensor(GetNLayout(arg))) { @@ -204,8 +204,8 @@ const LayoutDecision LayoutUtils::ReduceLayout(const LayoutDecision& src_layout, } const LayoutDecision LayoutUtils::PermuteLayout(const LayoutDecision& src_layout, - const Array& axes) { - String layout_str; + const ffi::Array& axes) { + ffi::String layout_str; for (const auto& a : axes) { layout_str = layout_str + src_layout->layout[a->value].name(); } @@ -214,7 +214,7 @@ const LayoutDecision LayoutUtils::PermuteLayout(const LayoutDecision& src_layout const LayoutDecision LayoutUtils::PermuteLayout(const LayoutDecision& src_layout, const std::vector& axes) { - String layout_str; + ffi::String layout_str; for (const auto& a : axes) { layout_str = layout_str + src_layout->layout[a].name(); } diff --git a/src/contrib/msc/core/transform/layout_utils.h b/src/contrib/msc/core/transform/layout_utils.h index 787c73cc8404..88bcc5703589 100644 --- a/src/contrib/msc/core/transform/layout_utils.h +++ b/src/contrib/msc/core/transform/layout_utils.h @@ -100,7 +100,7 @@ class LayoutUtils { * \brief Check if the args has unknown dim tensor. * \return Whether the args has unknown dim tensor. 
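 *
 * Example (an illustrative sketch; assumes `data` and `weight` are relax Exprs whose struct info is already populated):
 *   ffi::Array<Expr> args{data, weight};
 *   if (LayoutUtils::HasUnknownDimTensor(args)) {
 *     // at least one tensor argument carries an unknown/symbolic dim, so skip layout propagation
 *   }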
*/ - TVM_DLL static bool HasUnknownDimTensor(const Array& args); + TVM_DLL static bool HasUnknownDimTensor(const ffi::Array& args); /*! * \brief Insert axes to the Layout @@ -120,7 +120,7 @@ class LayoutUtils { * \return The new layout. */ TVM_DLL static const LayoutDecision PermuteLayout(const LayoutDecision& src_layout, - const Array& axes); + const ffi::Array& axes); TVM_DLL static const LayoutDecision PermuteLayout(const LayoutDecision& src_layout, const std::vector& axes); diff --git a/src/contrib/msc/core/transform/rewrite_utils.cc b/src/contrib/msc/core/transform/rewrite_utils.cc index c88cad3e64f7..a20e7d5ac3b0 100644 --- a/src/contrib/msc/core/transform/rewrite_utils.cc +++ b/src/contrib/msc/core/transform/rewrite_utils.cc @@ -29,18 +29,18 @@ namespace tvm { namespace contrib { namespace msc { -Var RewriteUtils::ReEmit(BlockBuilder builder, const String& name, const Expr& expr) { +Var RewriteUtils::ReEmit(BlockBuilder builder, const ffi::String& name, const Expr& expr) { expr->span = SpanUtils::SetAttr(expr->span, msc_attr::kName, name); return builder->Emit(expr, name); } -Var RewriteUtils::MakeCall(BlockBuilder builder, const String& name, Expr op, Array args, - Attrs attrs) { +Var RewriteUtils::MakeCall(BlockBuilder builder, const ffi::String& name, Expr op, + ffi::Array args, Attrs attrs) { const auto& call = Call(op, args, attrs); return ReEmit(builder, name, call); } -Expr RewriteUtils::MakeConstant(BlockBuilder builder, const String& name, double value, +Expr RewriteUtils::MakeConstant(BlockBuilder builder, const ffi::String& name, double value, const DataType& dtype, size_t ndim) { const auto& data = support::FloatImmToTensor(FloatImm(dtype, value)); Span span = SpanUtils::CreateWithAttr(msc_attr::kName, name); @@ -49,7 +49,7 @@ Expr RewriteUtils::MakeConstant(BlockBuilder builder, const String& name, double return constant; } static const Op& reshape_op = Op::Get("relax.reshape"); - Array exp_shape(ndim, Integer(1)); + ffi::Array exp_shape(ndim, Integer(1)); return MakeCall(builder, name + "_exp", reshape_op, {constant, ShapeExpr(exp_shape)}); } diff --git a/src/contrib/msc/core/transform/rewrite_utils.h b/src/contrib/msc/core/transform/rewrite_utils.h index 307581b274ec..b5dc5e4f2a64 100644 --- a/src/contrib/msc/core/transform/rewrite_utils.h +++ b/src/contrib/msc/core/transform/rewrite_utils.h @@ -49,20 +49,20 @@ class RewriteUtils { * \brief Emit call with span name. * \return The emitted var. */ - TVM_DLL static Var ReEmit(BlockBuilder builder, const String& name, const Expr& expr); + TVM_DLL static Var ReEmit(BlockBuilder builder, const ffi::String& name, const Expr& expr); /*! * \brief Make and emit a call binding with span. * \return The emitted var. */ - TVM_DLL static Var MakeCall(BlockBuilder builder, const String& name, Expr op, Array args, - Attrs attrs = Attrs()); + TVM_DLL static Var MakeCall(BlockBuilder builder, const ffi::String& name, Expr op, + ffi::Array args, Attrs attrs = Attrs()); /*! * \brief Make and emit a (shaped)constant with span. * \return The constant/reshape. 
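 *
 * Example (an illustrative sketch; assumes a BlockBuilder `builder` is in scope):
 *   // emits 1.0 as a float32 constant and, since ndim == 4, reshapes it to [1, 1, 1, 1] for broadcasting
 *   Expr one = RewriteUtils::MakeConstant(builder, "one", 1.0, DataType::Float(32), 4);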
*/ - TVM_DLL static Expr MakeConstant(BlockBuilder builder, const String& name, double value, + TVM_DLL static Expr MakeConstant(BlockBuilder builder, const ffi::String& name, double value, const DataType& dtype, size_t ndim = 0); }; diff --git a/src/contrib/msc/core/transform/set_byoc_attrs.cc b/src/contrib/msc/core/transform/set_byoc_attrs.cc index 85819ea58dc6..c6b35129a8df 100644 --- a/src/contrib/msc/core/transform/set_byoc_attrs.cc +++ b/src/contrib/msc/core/transform/set_byoc_attrs.cc @@ -41,7 +41,8 @@ using namespace tvm::contrib::msc; */ class ByocNameSetter : public ExprMutator { public: - explicit ByocNameSetter(IRModule ctx_module, const String& target, const String& entry_name) + explicit ByocNameSetter(IRModule ctx_module, const ffi::String& target, + const ffi::String& entry_name) : ExprMutator(ctx_module) { mod_ = ctx_module; target_ = target; @@ -54,9 +55,9 @@ class ByocNameSetter : public ExprMutator { if (gv->name_hint == entry_name_) { continue; } - const auto& name_opt = func->GetAttr(attr::kCodegen); + const auto& name_opt = func->GetAttr(attr::kCodegen); if (name_opt.has_value() && name_opt.value() == target_) { - const String& func_name = target_ + "_" + std::to_string(func_cnt); + const ffi::String& func_name = target_ + "_" + std::to_string(func_cnt); const auto& new_func = Downcast(VisitExpr(func)); builder_->UpdateFunction(gv, WithAttr(new_func, msc_attr::kUnique, func_name)); func_cnt += 1; @@ -66,7 +67,7 @@ class ByocNameSetter : public ExprMutator { } void VisitBinding_(const VarBindingNode* binding, const FunctionNode* val) final { - local_funcs_.Set(binding->var, GetRef(val)); + local_funcs_.Set(binding->var, ffi::GetRef(val)); ExprMutator::VisitBinding_(binding, val); } @@ -74,7 +75,7 @@ class ByocNameSetter : public ExprMutator { ExprMutator::VisitBinding_(binding, val); if (val->op->IsInstance()) { ICHECK(local_funcs_.count(val->op)) << "Can not find local func " << val->op; - const auto& name_opt = local_funcs_[val->op]->GetAttr(msc_attr::kUnique); + const auto& name_opt = local_funcs_[val->op]->GetAttr(msc_attr::kUnique); if (name_opt.has_value()) { val->span = SpanUtils::SetAttr(val->span, "name", name_opt.value()); } @@ -83,19 +84,19 @@ class ByocNameSetter : public ExprMutator { private: IRModule mod_; - String target_; - String entry_name_; - Map new_funcs_; - Map local_funcs_; + ffi::String target_; + ffi::String entry_name_; + ffi::Map new_funcs_; + ffi::Map local_funcs_; }; -IRModule SetBYOCAttrs(IRModule mod, const String& target, const String& entry_name) { +IRModule SetBYOCAttrs(IRModule mod, const ffi::String& target, const ffi::String& entry_name) { return ByocNameSetter(mod, target, entry_name).SetNames(); } namespace transform { -Pass SetBYOCAttrs(const String& target, const String& entry_name) { +Pass SetBYOCAttrs(const ffi::String& target, const ffi::String& entry_name) { auto pass_func = [=](IRModule m, PassContext pc) { return relax::SetBYOCAttrs(m, target, entry_name); }; diff --git a/src/contrib/msc/core/transform/set_expr_layout.cc b/src/contrib/msc/core/transform/set_expr_layout.cc index 59711a99188d..1e38ecd147b0 100644 --- a/src/contrib/msc/core/transform/set_expr_layout.cc +++ b/src/contrib/msc/core/transform/set_expr_layout.cc @@ -35,9 +35,9 @@ namespace relax { using namespace tvm::contrib::msc; -std::tuple AccumulateMatch(const Array& input_shape, - const Array& output_shape, size_t in_start, - size_t out_start) { +std::tuple AccumulateMatch(const ffi::Array& input_shape, + const ffi::Array& output_shape, + size_t in_start, 
size_t out_start) { // find input position in_pos and output position out_pos // cumsum(in_shape[in_start:in_pos])==cumsum(out_shape[out_start:out_pos]) std::vector in_shape, out_shape; @@ -84,7 +84,8 @@ std::tuple AccumulateMatch(const Array& input_shape, } std::tuple, std::vector> InferReshapeAxes( - const Array& input_shape, const Array& output_shape, int batch_dim) { + const ffi::Array& input_shape, const ffi::Array& output_shape, + int batch_dim) { std::vector expand_axes, reduce_axes; size_t in_start = 0; while (in_start < input_shape.size()) { @@ -120,11 +121,11 @@ std::tuple, std::vector> InferReshapeAxes( } // Forward and Backward infer -InferLayoutOutput MSCInferLayoutConv(const Call& call, - const Map>& desired_layouts, - const VarLayoutMap& var_layout_map) { +InferLayoutOutput MSCInferLayoutConv( + const Call& call, const ffi::Map>& desired_layouts, + const VarLayoutMap& var_layout_map) { LayoutDecision data_layout, kernel_layout, out_layout; - const String& op_name = Downcast(call->op)->name; + const ffi::String& op_name = Downcast(call->op)->name; if (op_name == "relax.nn.conv1d") { const auto* attrs = call->attrs.as(); data_layout = LayoutDecision(attrs->data_layout); @@ -144,11 +145,11 @@ InferLayoutOutput MSCInferLayoutConv(const Call& call, return InferLayoutOutput({data_layout, kernel_layout}, {out_layout}, Attrs()); } -InferLayoutOutput MSCInferLayoutPool2d(const Call& call, - const Map>& desired_layouts, - const VarLayoutMap& var_layout_map) { +InferLayoutOutput MSCInferLayoutPool2d( + const Call& call, const ffi::Map>& desired_layouts, + const VarLayoutMap& var_layout_map) { LayoutDecision layout, out_layout; - const String& op_name = Downcast(call->op)->name; + const ffi::String& op_name = Downcast(call->op)->name; if (op_name == "relax.nn.adaptive_avg_pool2d") { const auto* attrs = call->attrs.as(); layout = LayoutDecision(attrs->layout); @@ -161,9 +162,9 @@ InferLayoutOutput MSCInferLayoutPool2d(const Call& call, return InferLayoutOutput({layout}, {out_layout}, Attrs()); } -InferLayoutOutput MSCInferLayoutResize2d(const Call& call, - const Map>& desired_layouts, - const VarLayoutMap& var_layout_map) { +InferLayoutOutput MSCInferLayoutResize2d( + const Call& call, const ffi::Map>& desired_layouts, + const VarLayoutMap& var_layout_map) { const auto* attrs = call->attrs.as(); const auto& data_layout = LayoutDecision(attrs->layout); const auto& shape_layout = LayoutDecision("O"); @@ -171,10 +172,10 @@ InferLayoutOutput MSCInferLayoutResize2d(const Call& call, } // Forward Infer -InferLayoutOutput ForwardInferLayoutCommon(const Call& call, - const Map>& desired_layouts, - const VarLayoutMap& var_layout_map) { - Array input_layouts; +InferLayoutOutput ForwardInferLayoutCommon( + const Call& call, const ffi::Map>& desired_layouts, + const VarLayoutMap& var_layout_map) { + ffi::Array input_layouts; LayoutDecision layout_hint; for (const auto& arg : call->args) { const auto& in_layout = LayoutUtils::InferLayoutDecision(arg, var_layout_map); @@ -190,7 +191,7 @@ InferLayoutOutput ForwardInferLayoutCommon(const Call& call, if (sinfo->IsInstance()) { return InferLayoutOutput(input_layouts, {layout_hint}, Attrs()); } - Array output_layouts; + ffi::Array output_layouts; if (const auto* tuple_sinfo = sinfo.as()) { for (size_t i = 0; i < tuple_sinfo->fields.size(); i++) { output_layouts.push_back(layout_hint); @@ -200,10 +201,10 @@ InferLayoutOutput ForwardInferLayoutCommon(const Call& call, return InferLayoutOutput(); } -InferLayoutOutput ForwardInferLayoutBroadcast(const Call& 
call, - const Map>& desired_layouts, - const VarLayoutMap& var_layout_map) { - Array input_layouts; +InferLayoutOutput ForwardInferLayoutBroadcast( + const Call& call, const ffi::Map>& desired_layouts, + const VarLayoutMap& var_layout_map) { + ffi::Array input_layouts; LayoutDecision layout_hint; for (const auto& arg : call->args) { const auto& in_layout = LayoutUtils::InferLayoutDecision(arg, var_layout_map); @@ -224,15 +225,15 @@ InferLayoutOutput ForwardInferLayoutBroadcast(const Call& call, return InferLayoutOutput(); } -InferLayoutOutput ForwardInferLayoutInplace(const Call& call, - const Map>& desired_layouts, - const VarLayoutMap& var_layout_map) { +InferLayoutOutput ForwardInferLayoutInplace( + const Call& call, const ffi::Map>& desired_layouts, + const VarLayoutMap& var_layout_map) { return ForwardInferLayoutCommon(call, desired_layouts, var_layout_map); } -InferLayoutOutput ForwardInferLayoutBinary(const Call& call, - const Map>& desired_layouts, - const VarLayoutMap& var_layout_map) { +InferLayoutOutput ForwardInferLayoutBinary( + const Call& call, const ffi::Map>& desired_layouts, + const VarLayoutMap& var_layout_map) { const auto& output = ForwardInferLayoutCommon(call, desired_layouts, var_layout_map); if (!output.defined()) { return output; @@ -256,9 +257,9 @@ InferLayoutOutput ForwardInferLayoutBinary(const Call& call, return InferLayoutOutput(input_layouts, output->output_layouts, Attrs()); } -InferLayoutOutput ForwardInferLayoutArgMaxMin(const Call& call, - const Map>& desired_layouts, - const VarLayoutMap& var_layout_map) { +InferLayoutOutput ForwardInferLayoutArgMaxMin( + const Call& call, const ffi::Map>& desired_layouts, + const VarLayoutMap& var_layout_map) { LayoutDecision input_layout = LayoutUtils::InferLayoutDecision(call->args[0], var_layout_map); if (!input_layout->layout.defined()) { return InferLayoutOutput(); @@ -280,9 +281,9 @@ InferLayoutOutput ForwardInferLayoutArgMaxMin(const Call& call, return InferLayoutOutput({input_layout}, {output_layout}, Attrs()); } -InferLayoutOutput ForwardInferLayoutBatchNorm(const Call& call, - const Map>& desired_layouts, - const VarLayoutMap& var_layout_map) { +InferLayoutOutput ForwardInferLayoutBatchNorm( + const Call& call, const ffi::Map>& desired_layouts, + const VarLayoutMap& var_layout_map) { const auto& input_shape = ExprUtils::GetShape(call->args[0]); if (input_shape.size() == 0) { return InferLayoutOutput(); @@ -300,9 +301,9 @@ InferLayoutOutput ForwardInferLayoutBatchNorm(const Call& call, {{in_layout, g_layout, g_layout}}, Attrs()); } -InferLayoutOutput ForkwardInferLayoutExpandDims(const Call& call, - const Map>& desired_layouts, - const VarLayoutMap& var_layout_map) { +InferLayoutOutput ForkwardInferLayoutExpandDims( + const Call& call, const ffi::Map>& desired_layouts, + const VarLayoutMap& var_layout_map) { LayoutDecision input_layout = LayoutUtils::InferLayoutDecision(call->args[0], var_layout_map); if (!input_layout->layout.defined()) { return InferLayoutOutput(); @@ -320,9 +321,9 @@ InferLayoutOutput ForkwardInferLayoutExpandDims(const Call& call, return InferLayoutOutput({input_layout}, {output_layout}, Attrs()); } -InferLayoutOutput ForwardInferLayoutNormalize(const Call& call, - const Map>& desired_layouts, - const VarLayoutMap& var_layout_map) { +InferLayoutOutput ForwardInferLayoutNormalize( + const Call& call, const ffi::Map>& desired_layouts, + const VarLayoutMap& var_layout_map) { const auto& input_shape = ExprUtils::GetShape(call->args[0]); if (input_shape.size() == 0) { return 
InferLayoutOutput(); @@ -339,9 +340,9 @@ InferLayoutOutput ForwardInferLayoutNormalize(const Call& call, return InferLayoutOutput({in_layout, g_layout, g_layout}, {in_layout}, Attrs()); } -InferLayoutOutput ForwardInferLayoutMatmul(const Call& call, - const Map>& desired_layouts, - const VarLayoutMap& var_layout_map) { +InferLayoutOutput ForwardInferLayoutMatmul( + const Call& call, const ffi::Map>& desired_layouts, + const VarLayoutMap& var_layout_map) { const auto& a_shape = ExprUtils::GetShape(call->args[0]); const auto& b_shape = ExprUtils::GetShape(call->args[1]); if (a_shape.size() == 0) { @@ -358,7 +359,7 @@ InferLayoutOutput ForwardInferLayoutMatmul(const Call& call, } } size_t start = a_layout->layout.ndim() - b_shape.size(); - String pre_layout; + ffi::String pre_layout; for (size_t i = start; i < a_layout->layout.ndim() - 2; i++) { pre_layout = pre_layout + a_layout->layout[i].name(); } @@ -366,9 +367,9 @@ InferLayoutOutput ForwardInferLayoutMatmul(const Call& call, return InferLayoutOutput({a_layout, b_layout}, {a_layout}, Attrs()); } -InferLayoutOutput ForwardInferLayoutPermute(const Call& call, - const Map>& desired_layouts, - const VarLayoutMap& var_layout_map) { +InferLayoutOutput ForwardInferLayoutPermute( + const Call& call, const ffi::Map>& desired_layouts, + const VarLayoutMap& var_layout_map) { LayoutDecision input_layout = LayoutUtils::InferLayoutDecision(call->args[0], var_layout_map); if (!input_layout->layout.defined()) { return InferLayoutOutput(); @@ -388,9 +389,9 @@ InferLayoutOutput ForwardInferLayoutPermute(const Call& call, return InferLayoutOutput({input_layout}, {output_layout}, Attrs()); } -InferLayoutOutput ForwardInferLayoutReduceAxis(const Call& call, - const Map>& desired_layouts, - const VarLayoutMap& var_layout_map) { +InferLayoutOutput ForwardInferLayoutReduceAxis( + const Call& call, const ffi::Map>& desired_layouts, + const VarLayoutMap& var_layout_map) { LayoutDecision input_layout = LayoutUtils::InferLayoutDecision(call->args[0], var_layout_map); if (!input_layout->layout.defined()) { return InferLayoutOutput(); @@ -414,9 +415,9 @@ InferLayoutOutput ForwardInferLayoutReduceAxis(const Call& call, return InferLayoutOutput({input_layout}, {output_layout}, Attrs()); } -InferLayoutOutput ForwardInferLayoutReshape(const Call& call, - const Map>& desired_layouts, - const VarLayoutMap& var_layout_map) { +InferLayoutOutput ForwardInferLayoutReshape( + const Call& call, const ffi::Map>& desired_layouts, + const VarLayoutMap& var_layout_map) { LayoutDecision input_layout = LayoutUtils::InferLayoutDecision(call->args[0], var_layout_map); if (!input_layout->layout.defined()) { return InferLayoutOutput(); @@ -444,9 +445,9 @@ InferLayoutOutput ForwardInferLayoutReshape(const Call& call, return InferLayoutOutput({input_layout, LayoutDecision("O")}, {output_layout}, Attrs()); } -InferLayoutOutput ForwardInferLayoutSqueeze(const Call& call, - const Map>& desired_layouts, - const VarLayoutMap& var_layout_map) { +InferLayoutOutput ForwardInferLayoutSqueeze( + const Call& call, const ffi::Map>& desired_layouts, + const VarLayoutMap& var_layout_map) { LayoutDecision input_layout = LayoutUtils::InferLayoutDecision(call->args[0], var_layout_map); if (!input_layout->layout.defined()) { return InferLayoutOutput(); @@ -475,9 +476,9 @@ InferLayoutOutput ForwardInferLayoutSqueeze(const Call& call, return InferLayoutOutput({input_layout}, {output_layout}, Attrs()); } -InferLayoutOutput ForwardInferLayoutTake(const Call& call, - const Map>& desired_layouts, - const 
VarLayoutMap& var_layout_map) { +InferLayoutOutput ForwardInferLayoutTake( + const Call& call, const ffi::Map>& desired_layouts, + const VarLayoutMap& var_layout_map) { LayoutDecision input_layout = LayoutUtils::InferLayoutDecision(call->args[0], var_layout_map); LayoutDecision indices_layout = LayoutUtils::InferLayoutDecision(call->args[1], var_layout_map); const auto& input_shape = ExprUtils::GetShape(call->args[0]); @@ -508,9 +509,9 @@ InferLayoutOutput ForwardInferLayoutTake(const Call& call, return InferLayoutOutput(); } -InferLayoutOutput ForwardInferLayoutPlugin(const Call& call, - const Map>& desired_layouts, - const VarLayoutMap& var_layout_map) { +InferLayoutOutput ForwardInferLayoutPlugin( + const Call& call, const ffi::Map>& desired_layouts, + const VarLayoutMap& var_layout_map) { if (!call->args[0]->IsInstance()) { return InferLayoutOutput(); } @@ -626,9 +627,9 @@ TVM_REGISTER_OP("relax.call_dps_packed") .set_attr("FMSCForwardInferLayout", ForwardInferLayoutPlugin); // Backward Infer -InferLayoutOutput BackwardInferLayoutCommon(const Call& call, - const Map>& desired_layouts, - const VarLayoutMap& var_layout_map) { +InferLayoutOutput BackwardInferLayoutCommon( + const Call& call, const ffi::Map>& desired_layouts, + const VarLayoutMap& var_layout_map) { NLayout output_layout = LayoutUtils::InferNLayout(call, var_layout_map); LayoutDecision layout_hint; if (output_layout.IsLeaf()) { @@ -643,7 +644,7 @@ InferLayoutOutput BackwardInferLayoutCommon(const Call& call, if (!layout_hint->layout.defined()) { return InferLayoutOutput(); } - Array input_layouts; + ffi::Array input_layouts; for (const auto& arg : call->args) { const auto& saved_layout = LayoutUtils::InferLayoutDecision(arg, var_layout_map); if (saved_layout->layout.defined()) { @@ -655,9 +656,9 @@ InferLayoutOutput BackwardInferLayoutCommon(const Call& call, return InferLayoutOutput(input_layouts, {output_layout}, Attrs()); } -InferLayoutOutput BackwardInferLayoutBinary(const Call& call, - const Map>& desired_layouts, - const VarLayoutMap& var_layout_map) { +InferLayoutOutput BackwardInferLayoutBinary( + const Call& call, const ffi::Map>& desired_layouts, + const VarLayoutMap& var_layout_map) { const auto& output = BackwardInferLayoutCommon(call, desired_layouts, var_layout_map); if (!output.defined()) { return output; @@ -681,15 +682,15 @@ InferLayoutOutput BackwardInferLayoutBinary(const Call& call, return InferLayoutOutput(input_layouts, output->output_layouts, Attrs()); } -InferLayoutOutput BackwardInferLayoutInplace(const Call& call, - const Map>& desired_layouts, - const VarLayoutMap& var_layout_map) { +InferLayoutOutput BackwardInferLayoutInplace( + const Call& call, const ffi::Map>& desired_layouts, + const VarLayoutMap& var_layout_map) { return BackwardInferLayoutCommon(call, desired_layouts, var_layout_map); } -InferLayoutOutput BackwardInferLayoutArgMaxMin(const Call& call, - const Map>& desired_layouts, - const VarLayoutMap& var_layout_map) { +InferLayoutOutput BackwardInferLayoutArgMaxMin( + const Call& call, const ffi::Map>& desired_layouts, + const VarLayoutMap& var_layout_map) { LayoutDecision output_layout = LayoutUtils::InferLayoutDecision(call, var_layout_map); if (!output_layout->layout.defined()) { return InferLayoutOutput(); @@ -708,9 +709,9 @@ InferLayoutOutput BackwardInferLayoutArgMaxMin(const Call& call, return InferLayoutOutput({input_layout}, {output_layout}, Attrs()); } -InferLayoutOutput BackwardInferLayoutBatchNorm(const Call& call, - const Map>& desired_layouts, - const VarLayoutMap& 
var_layout_map) { +InferLayoutOutput BackwardInferLayoutBatchNorm( + const Call& call, const ffi::Map>& desired_layouts, + const VarLayoutMap& var_layout_map) { LayoutDecision output_layout = LayoutUtils::InferLayoutDecisionAt(call, var_layout_map, 0); if (!output_layout->layout.defined()) { return InferLayoutOutput(); @@ -720,9 +721,9 @@ InferLayoutOutput BackwardInferLayoutBatchNorm(const Call& call, {{output_layout, g_layout, g_layout}}, Attrs()); } -InferLayoutOutput BackwardInferLayoutExpandDims(const Call& call, - const Map>& desired_layouts, - const VarLayoutMap& var_layout_map) { +InferLayoutOutput BackwardInferLayoutExpandDims( + const Call& call, const ffi::Map>& desired_layouts, + const VarLayoutMap& var_layout_map) { LayoutDecision output_layout = LayoutUtils::InferLayoutDecision(call, var_layout_map); if (!output_layout->layout.defined()) { return InferLayoutOutput(); @@ -740,9 +741,9 @@ InferLayoutOutput BackwardInferLayoutExpandDims(const Call& call, return InferLayoutOutput({input_layout}, {output_layout}, Attrs()); } -InferLayoutOutput BackwardInferLayoutNormalize(const Call& call, - const Map>& desired_layouts, - const VarLayoutMap& var_layout_map) { +InferLayoutOutput BackwardInferLayoutNormalize( + const Call& call, const ffi::Map>& desired_layouts, + const VarLayoutMap& var_layout_map) { LayoutDecision output_layout = LayoutUtils::InferLayoutDecisionAt(call, var_layout_map, 0); if (!output_layout->layout.defined()) { return InferLayoutOutput(); @@ -751,9 +752,9 @@ InferLayoutOutput BackwardInferLayoutNormalize(const Call& call, return InferLayoutOutput({output_layout, g_layout, g_layout}, {output_layout}, Attrs()); } -InferLayoutOutput BackwardInferLayoutMatmul(const Call& call, - const Map>& desired_layouts, - const VarLayoutMap& var_layout_map) { +InferLayoutOutput BackwardInferLayoutMatmul( + const Call& call, const ffi::Map>& desired_layouts, + const VarLayoutMap& var_layout_map) { LayoutDecision output_layout = LayoutUtils::InferLayoutDecision(call, var_layout_map); if (!output_layout->layout.defined()) { return InferLayoutOutput(); @@ -763,7 +764,7 @@ InferLayoutOutput BackwardInferLayoutMatmul(const Call& call, return InferLayoutOutput(); } size_t start = output_layout->layout.ndim() - b_shape.size(); - String pre_layout; + ffi::String pre_layout; for (size_t i = start; i < output_layout->layout.ndim() - 2; i++) { pre_layout = pre_layout + output_layout->layout[i].name(); } @@ -771,9 +772,9 @@ InferLayoutOutput BackwardInferLayoutMatmul(const Call& call, return InferLayoutOutput({output_layout, b_layout}, {output_layout}, Attrs()); } -InferLayoutOutput BackwardInferLayoutPermute(const Call& call, - const Map>& desired_layouts, - const VarLayoutMap& var_layout_map) { +InferLayoutOutput BackwardInferLayoutPermute( + const Call& call, const ffi::Map>& desired_layouts, + const VarLayoutMap& var_layout_map) { LayoutDecision output_layout = LayoutUtils::InferLayoutDecision(call, var_layout_map); if (!output_layout->layout.defined()) { return InferLayoutOutput(); @@ -802,9 +803,9 @@ InferLayoutOutput BackwardInferLayoutPermute(const Call& call, return InferLayoutOutput({input_layout}, {output_layout}, Attrs()); } -InferLayoutOutput BackwardInferLayoutReduceAxis(const Call& call, - const Map>& desired_layouts, - const VarLayoutMap& var_layout_map) { +InferLayoutOutput BackwardInferLayoutReduceAxis( + const Call& call, const ffi::Map>& desired_layouts, + const VarLayoutMap& var_layout_map) { LayoutDecision output_layout = LayoutUtils::InferLayoutDecision(call, 
var_layout_map); if (!output_layout->layout.defined()) { return InferLayoutOutput(); @@ -825,9 +826,9 @@ InferLayoutOutput BackwardInferLayoutReduceAxis(const Call& call, return InferLayoutOutput({input_layout}, {output_layout}, Attrs()); } -InferLayoutOutput BackwardInferLayoutReshape(const Call& call, - const Map>& desired_layouts, - const VarLayoutMap& var_layout_map) { +InferLayoutOutput BackwardInferLayoutReshape( + const Call& call, const ffi::Map>& desired_layouts, + const VarLayoutMap& var_layout_map) { LayoutDecision output_layout = LayoutUtils::InferLayoutDecision(call, var_layout_map); if (!output_layout->layout.defined()) { return InferLayoutOutput(); @@ -855,9 +856,9 @@ InferLayoutOutput BackwardInferLayoutReshape(const Call& call, return InferLayoutOutput({input_layout, LayoutDecision("O")}, {output_layout}, Attrs()); } -InferLayoutOutput BackwardInferLayoutSqueeze(const Call& call, - const Map>& desired_layouts, - const VarLayoutMap& var_layout_map) { +InferLayoutOutput BackwardInferLayoutSqueeze( + const Call& call, const ffi::Map>& desired_layouts, + const VarLayoutMap& var_layout_map) { LayoutDecision output_layout = LayoutUtils::InferLayoutDecision(call, var_layout_map); if (!output_layout->layout.defined()) { return InferLayoutOutput(); @@ -886,9 +887,9 @@ InferLayoutOutput BackwardInferLayoutSqueeze(const Call& call, return InferLayoutOutput({input_layout}, {output_layout}, Attrs()); } -InferLayoutOutput BackwardInferLayoutTake(const Call& call, - const Map>& desired_layouts, - const VarLayoutMap& var_layout_map) { +InferLayoutOutput BackwardInferLayoutTake( + const Call& call, const ffi::Map>& desired_layouts, + const VarLayoutMap& var_layout_map) { LayoutDecision output_layout = LayoutUtils::InferLayoutDecision(call, var_layout_map); LayoutDecision input_layout = LayoutUtils::InferLayoutDecision(call->args[0], var_layout_map); LayoutDecision indices_layout = LayoutUtils::InferLayoutDecision(call->args[1], var_layout_map); @@ -912,9 +913,9 @@ InferLayoutOutput BackwardInferLayoutTake(const Call& call, return InferLayoutOutput({input_layout, indices_layout}, {output_layout}, Attrs()); } -InferLayoutOutput BackwardInferLayoutTupleInputs(const Call& call, - const Map>& desired_layouts, - const VarLayoutMap& var_layout_map) { +InferLayoutOutput BackwardInferLayoutTupleInputs( + const Call& call, const ffi::Map>& desired_layouts, + const VarLayoutMap& var_layout_map) { LayoutDecision output_layout = LayoutUtils::InferLayoutDecision(call, var_layout_map); if (!output_layout->layout.defined()) { return InferLayoutOutput(); @@ -1091,16 +1092,17 @@ class LayoutInfer : public ExprVisitor { continue; } // Infer by op_node - Op op = Downcast(GetRef(op_node)); + Op op = Downcast(ffi::GetRef(op_node)); InferLayoutOutput infered_layout; const auto& msc_infer_map = Op::GetAttrMap("FMSCBackwardInferLayout"); try { if (msc_infer_map.count(op)) { FRelaxInferLayout f = msc_infer_map[op]; - infered_layout = f(call, Map>(), var_layout_map_); - } else { infered_layout = - BackwardInferLayoutCommon(call, Map>(), var_layout_map_); + f(call, ffi::Map>(), var_layout_map_); + } else { + infered_layout = BackwardInferLayoutCommon( + call, ffi::Map>(), var_layout_map_); } } catch (runtime::InternalError& err) { LOG(WARNING) << "Failed to backward infer layout " << expr << " : " << err.what(); @@ -1118,7 +1120,7 @@ class LayoutInfer : public ExprVisitor { void VisitBinding_(const VarBindingNode* binding, const CallNode* call_node) final { ExprVisitor::VisitBinding_(binding, call_node); - const 
auto& call = GetRef(call_node); + const auto& call = ffi::GetRef(call_node); if (const auto* v_node = call->op.as()) { const auto& func = Downcast(ref_module_->Lookup(v_node->name_hint)); RecordExpr(binding->var, call); @@ -1143,7 +1145,7 @@ class LayoutInfer : public ExprVisitor { } if (infer_outputs) { // infer layouts - Op op = Downcast(GetRef(op_node)); + Op op = Downcast(ffi::GetRef(op_node)); InferLayoutOutput infered_layout; const auto& msc_infer_map = Op::GetAttrMap("FMSCForwardInferLayout"); const auto& relax_infer_map = Op::GetAttrMap("FRelaxInferLayout"); @@ -1151,14 +1153,16 @@ class LayoutInfer : public ExprVisitor { try { if (msc_infer_map.count(op)) { FRelaxInferLayout f = msc_infer_map[op]; - infered_layout = f(call, Map>(), var_layout_map_); - } else if (!relax_infer_map.count(op)) { infered_layout = - ForwardInferLayoutCommon(call, Map>(), var_layout_map_); + f(call, ffi::Map>(), var_layout_map_); + } else if (!relax_infer_map.count(op)) { + infered_layout = ForwardInferLayoutCommon( + call, ffi::Map>(), var_layout_map_); } if (relax_infer_map.count(op) && !infered_layout.defined()) { FRelaxInferLayout f = relax_infer_map[op]; - infered_layout = f(call, Map>(), var_layout_map_); + infered_layout = + f(call, ffi::Map>(), var_layout_map_); set_inputs = false; } } catch (runtime::InternalError& err) { @@ -1187,14 +1191,14 @@ class LayoutInfer : public ExprVisitor { } void VisitBinding_(const VarBindingNode* binding, const FunctionNode* val) final { - local_funcs_.Set(binding->var, GetRef(val)); + local_funcs_.Set(binding->var, ffi::GetRef(val)); } void VisitBinding_(const VarBindingNode* binding, const TupleNode* val) final { ExprVisitor::VisitBinding_(binding, val); - RecordExpr(binding->var, GetRef(val)); + RecordExpr(binding->var, ffi::GetRef(val)); if (IsNestedTensor(binding->var)) { - Array input_layouts; + ffi::Array input_layouts; for (const auto& field : val->fields) { input_layouts.push_back(LayoutUtils::InferLayoutDecision(field, var_layout_map_)); } @@ -1204,15 +1208,15 @@ class LayoutInfer : public ExprVisitor { void VisitBinding_(const VarBindingNode* binding, const TupleGetItemNode* val) final { ExprVisitor::VisitBinding_(binding, val); - RecordExpr(binding->var, GetRef(val)); - const auto& out_layout = LayoutUtils::InferLayoutDecisionAt(GetRef(val)->tuple, - var_layout_map_, val->index); + RecordExpr(binding->var, ffi::GetRef(val)); + const auto& out_layout = LayoutUtils::InferLayoutDecisionAt( + ffi::GetRef(val)->tuple, var_layout_map_, val->index); SetExprLayout(binding->var, out_layout); } void VisitBinding_(const VarBindingNode* binding, const ShapeExprNode* val) final { ExprVisitor::VisitBinding_(binding, val); - RecordExpr(binding->var, GetRef(val)); + RecordExpr(binding->var, ffi::GetRef(val)); SetExprLayout(binding->var, LayoutDecision("O")); } @@ -1252,7 +1256,7 @@ class LayoutInfer : public ExprVisitor { } } - void SetInputLayouts(const Call& call, const Array& input_layouts) { + void SetInputLayouts(const Call& call, const ffi::Array& input_layouts) { if (input_layouts.size() == call->args.size()) { for (size_t i = 0; i < input_layouts.size(); i++) { SetExprLayout(call->args[i], input_layouts[i]); @@ -1309,10 +1313,10 @@ class LayoutInfer : public ExprVisitor { IRModule ref_module_; bool infered_; - Map var_map_; - Array ordered_exprs_; + ffi::Map var_map_; + ffi::Array ordered_exprs_; std::unordered_map var_layout_map_; - Map local_funcs_; + ffi::Map local_funcs_; }; // class LayoutInfer class LayoutChecker : public ExprVisitor { @@ -1326,14 
+1330,14 @@ class LayoutChecker : public ExprVisitor { void VisitExpr_(const CallNode* call) final { ExprVisitor::VisitExpr_(call); - if (!LayoutUtils::LayoutInfered(GetRef(call))) { + if (!LayoutUtils::LayoutInfered(ffi::GetRef(call))) { missing_num_++; } } void VisitExpr_(const ConstantNode* cn) final { ExprVisitor::VisitExpr_(cn); - if (!LayoutUtils::LayoutInfered(GetRef(cn))) { + if (!LayoutUtils::LayoutInfered(ffi::GetRef(cn))) { missing_num_++; } } @@ -1352,7 +1356,7 @@ void SetExprLayout(const IRModule& ref_module, const Expr& func, bool allow_miss namespace transform { -Pass SetExprLayout(bool allow_missing, const String& entry_name) { +Pass SetExprLayout(bool allow_missing, const ffi::String& entry_name) { auto pass_func = [=](IRModule m, PassContext pc) { relax::SetExprLayout(m, m->Lookup(entry_name), allow_missing); return m; diff --git a/src/contrib/msc/core/transform/set_expr_name.cc b/src/contrib/msc/core/transform/set_expr_name.cc index 14ea3ccfec7b..ecf1afd9940f 100644 --- a/src/contrib/msc/core/transform/set_expr_name.cc +++ b/src/contrib/msc/core/transform/set_expr_name.cc @@ -36,10 +36,10 @@ namespace relax { class FuncNameGetter : public ExprVisitor { public: - explicit FuncNameGetter(const Array& arg_names) : arg_names_(arg_names) {} + explicit FuncNameGetter(const ffi::Array& arg_names) : arg_names_(arg_names) {} - /*! \brief Get the attributes from prim value as Map*/ - String HintName(const Expr& expr) { + /*! \brief Get the attributes from prim value as ffi::Map*/ + ffi::String HintName(const Expr& expr) { name_ = ""; ExprVisitor::VisitExpr(expr); return name_; @@ -73,8 +73,8 @@ class FuncNameGetter : public ExprVisitor { } private: - String name_; - Array arg_names_; + ffi::String name_; + ffi::Array arg_names_; }; /*! @@ -82,16 +82,16 @@ class FuncNameGetter : public ExprVisitor { */ class RelaxExprNameSetter : public ExprVisitor { public: - explicit RelaxExprNameSetter(const IRModule& ref_module, const String& target, - const Map& var_names) + explicit RelaxExprNameSetter(const IRModule& ref_module, const ffi::String& target, + const ffi::Map& var_names) : ref_module_(ref_module), target_{target}, var_names_{var_names} {} void VisitBindingBlock(const BindingBlock& block) final { - String block_name = SpanUtils::GetAttr(block->span, msc_attr::kName); + ffi::String block_name = SpanUtils::GetAttr(block->span, msc_attr::kName); if (block_name.size() == 0) { block_name = "block"; } - const String& prefix = StringUtils::Join(block_stack_, "."); + const ffi::String& prefix = StringUtils::Join(block_stack_, "."); if (setted_blocks_.count(prefix + "." + block_name)) { int cnt = 1; while (setted_blocks_.count(prefix + "." + block_name + "_" + std::to_string(cnt))) { @@ -101,7 +101,7 @@ class RelaxExprNameSetter : public ExprVisitor { } setted_blocks_.insert(prefix + "." 
+ block_name); block_stack_.push_back(block_name); - const String& unique_name = StringUtils::Join(block_stack_, "."); + const ffi::String& unique_name = StringUtils::Join(block_stack_, "."); block->span = SpanUtils::SetAttr(block->span, msc_attr::kName, unique_name); ExprVisitor::VisitBindingBlock(block); block_stack_.pop_back(); @@ -109,16 +109,16 @@ class RelaxExprNameSetter : public ExprVisitor { void VisitExpr_(const ConstantNode* val) { ExprVisitor::VisitExpr_(val); - const String& unique_name = GetUniqueName(GetRef(val), "const"); + const ffi::String& unique_name = GetUniqueName(ffi::GetRef(val), "const"); if (unique_name != SpanUtils::GetAttr(val->span, msc_attr::kName)) { val->span = SpanUtils::SetAttr(val->span, msc_attr::kName, unique_name); } - expr_names_.Set(GetRef(val), unique_name); + expr_names_.Set(ffi::GetRef(val), unique_name); } void VisitBinding_(const VarBindingNode* binding, const ConstantNode* val) { ExprVisitor::VisitBinding_(binding, val); - const String& unique_name = GetUniqueName(GetRef(val), "const"); + const ffi::String& unique_name = GetUniqueName(ffi::GetRef(val), "const"); if (unique_name != SpanUtils::GetAttr(val->span, msc_attr::kName)) { val->span = SpanUtils::SetAttr(val->span, msc_attr::kName, unique_name); } @@ -127,7 +127,7 @@ class RelaxExprNameSetter : public ExprVisitor { void VisitBinding_(const VarBindingNode* binding, const ShapeExprNode* val) { ExprVisitor::VisitBinding_(binding, val); - const String& unique_name = GetUniqueName(GetRef(val), "shape"); + const ffi::String& unique_name = GetUniqueName(ffi::GetRef(val), "shape"); if (unique_name != SpanUtils::GetAttr(val->span, msc_attr::kName)) { val->span = SpanUtils::SetAttr(val->span, msc_attr::kName, unique_name); } @@ -136,7 +136,7 @@ class RelaxExprNameSetter : public ExprVisitor { void VisitBinding_(const VarBindingNode* binding, const TupleNode* val) { ExprVisitor::VisitBinding_(binding, val); - const String& unique_name = GetUniqueName(GetRef(val), "tuple"); + const ffi::String& unique_name = GetUniqueName(ffi::GetRef(val), "tuple"); if (unique_name != SpanUtils::GetAttr(val->span, msc_attr::kName)) { val->span = SpanUtils::SetAttr(val->span, msc_attr::kName, unique_name); } @@ -145,7 +145,7 @@ class RelaxExprNameSetter : public ExprVisitor { void VisitBinding_(const VarBindingNode* binding, const TupleGetItemNode* val) { ExprVisitor::VisitBinding_(binding, val); - String unique_name; + ffi::String unique_name; if (expr_names_.count(val->tuple)) { unique_name = expr_names_[val->tuple] + "." 
+ std::to_string(val->index); } else if (const auto* v_node = val->tuple.as()) { @@ -159,15 +159,15 @@ class RelaxExprNameSetter : public ExprVisitor { void VisitBinding_(const VarBindingNode* binding, const FunctionNode* val) { ExprVisitor::VisitBinding_(binding, val); - const auto& name_opt = val->GetAttr(attr::kComposite); + const auto& name_opt = val->GetAttr(attr::kComposite); if (name_opt.has_value()) { - local_funcs_.Set(binding->var, GetRef(val)); + local_funcs_.Set(binding->var, ffi::GetRef(val)); } } void VisitBinding_(const VarBindingNode* binding, const CallNode* val) { ExprVisitor::VisitBinding_(binding, val); - String name_hint, optype; + ffi::String name_hint, optype; bool use_unique = true; if (var_names_.count(binding->var->name_hint())) { name_hint = var_names_[binding->var->name_hint()]; @@ -177,7 +177,7 @@ class RelaxExprNameSetter : public ExprVisitor { const auto& func = Downcast(val->args[0]); name_hint = func->global_symbol; optype = func->global_symbol; - const String& input_name = GetUniqueName(val->args[1], "plugin_inputs"); + const ffi::String& input_name = GetUniqueName(val->args[1], "plugin_inputs"); if (input_name != SpanUtils::GetAttr(val->args[1]->span, msc_attr::kName)) { val->args[1]->span = SpanUtils::SetAttr(val->args[1]->span, msc_attr::kName, input_name); } @@ -190,27 +190,28 @@ class RelaxExprNameSetter : public ExprVisitor { const auto& func = Downcast(ref_module_->Lookup(v_node->name_hint)); ExprVisitor::VisitExpr(func); optype = GetFuncType(func); - name_hint = GetFuncName(GetRef(val), func); + name_hint = GetFuncName(ffi::GetRef(val), func); use_unique = false; } else if (local_funcs_.count(val->op)) { ExprVisitor::VisitExpr(local_funcs_[val->op]); optype = GetFuncType(local_funcs_[val->op]); - name_hint = GetFuncName(GetRef(val), local_funcs_[val->op]); + name_hint = GetFuncName(ffi::GetRef(val), local_funcs_[val->op]); use_unique = false; } if (name_hint.size() > 0) { // set name - const String& unique_name = - use_unique ? GetUniqueName(GetRef(val), name_hint) : name_hint; + const ffi::String& unique_name = + use_unique ? 
GetUniqueName(ffi::GetRef(val), name_hint) : name_hint; if (unique_name != SpanUtils::GetAttr(val->span, msc_attr::kName)) { val->span = SpanUtils::SetAttr(val->span, msc_attr::kName, unique_name); } // set constant consumer && shared_ref - Array input_types; + ffi::Array input_types; try { input_types = ExprUtils::GetInputTypes(optype, val->args.size(), true); } catch (runtime::InternalError& err) { - LOG(WARNING) << "Failed to GetInputTypes for " << GetRef(val) << " : " << err.what(); + LOG(WARNING) << "Failed to GetInputTypes for " << ffi::GetRef(val) << " : " + << err.what(); throw err; } for (size_t i = 0; i < input_types.size(); i++) { @@ -218,7 +219,7 @@ class RelaxExprNameSetter : public ExprVisitor { continue; } if (const auto* c_node = val->args[i].as()) { - const String& const_name = SpanUtils::GetAttr(c_node->span, msc_attr::kName); + const ffi::String& const_name = SpanUtils::GetAttr(c_node->span, msc_attr::kName); if (constant_consumers_.count(const_name)) { val->span = SpanUtils::SetAttr(val->span, msc_attr::kSharedRef, constant_consumers_[const_name]); @@ -232,8 +233,8 @@ class RelaxExprNameSetter : public ExprVisitor { } private: - const String GetUniqueName(const Expr& expr, const String& name_hint) { - String expr_name = SpanUtils::GetAttr(expr->span, msc_attr::kName); + const ffi::String GetUniqueName(const Expr& expr, const ffi::String& name_hint) { + ffi::String expr_name = SpanUtils::GetAttr(expr->span, msc_attr::kName); if (expr_name.size() == 0) { expr_name = name_hint; } @@ -256,10 +257,10 @@ class RelaxExprNameSetter : public ExprVisitor { return expr_name; } - const String GetFuncType(const Function& func) { - String optype; - const auto& comp_opt = func->GetAttr(attr::kComposite); - const auto& code_opt = func->GetAttr(attr::kCodegen); + const ffi::String GetFuncType(const Function& func) { + ffi::String optype; + const auto& comp_opt = func->GetAttr(attr::kComposite); + const auto& code_opt = func->GetAttr(attr::kCodegen); if (comp_opt.has_value()) { optype = comp_opt.value(); } else if (code_opt.has_value()) { @@ -273,15 +274,15 @@ class RelaxExprNameSetter : public ExprVisitor { return optype; } - const String GetFuncName(const Call& call, const Function& func) { - String name; + const ffi::String GetFuncName(const Call& call, const Function& func) { + ffi::String name; // get from unique - const auto& name_opt = func->GetAttr(msc_attr::kUnique); + const auto& name_opt = func->GetAttr(msc_attr::kUnique); if (name_opt.has_value()) { return name_opt.value(); } // get from exprs in the func - Array arg_names; + ffi::Array arg_names; for (const auto& a : call->args) { arg_names.push_back(expr_names_.count(a) ? 
expr_names_[a] : ""); } @@ -298,26 +299,26 @@ class RelaxExprNameSetter : public ExprVisitor { return GetUniqueName(call, name); } - Map setted_names_; - Map constant_consumers_; - std::set setted_blocks_; - Array block_stack_; - Map expr_names_; - Map local_funcs_; + ffi::Map setted_names_; + ffi::Map constant_consumers_; + std::set setted_blocks_; + ffi::Array block_stack_; + ffi::Map expr_names_; + ffi::Map local_funcs_; IRModule ref_module_; - String target_; - Map var_names_; + ffi::String target_; + ffi::Map var_names_; }; // class ExprNameSetter -void SetRelaxExprName(const IRModule& ref_module, const Expr& e, const String& target, - const Map& var_names) { +void SetRelaxExprName(const IRModule& ref_module, const Expr& e, const ffi::String& target, + const ffi::Map& var_names) { RelaxExprNameSetter(ref_module, target, var_names).VisitExpr(e); } namespace transform { -Pass SetRelaxExprName(const String& entry_name, const String& target, - const Map& var_names) { +Pass SetRelaxExprName(const ffi::String& entry_name, const ffi::String& target, + const ffi::Map& var_names) { auto pass_func = [=](IRModule m, PassContext pc) { relax::SetRelaxExprName(m, m->Lookup(entry_name), target, var_names); return m; diff --git a/src/contrib/msc/core/utils.cc b/src/contrib/msc/core/utils.cc index f4a79602f506..720574cfa9a9 100644 --- a/src/contrib/msc/core/utils.cc +++ b/src/contrib/msc/core/utils.cc @@ -69,8 +69,8 @@ int CommonUtils::CompareVersion(const std::vector& given_version, return 0; } -int CommonUtils::CompareVersion(const Array& given_version, - const Array& target_version) { +int CommonUtils::CompareVersion(const ffi::Array& given_version, + const ffi::Array& target_version) { std::vector int_given_version; std::vector int_target_version; for (const auto& v : given_version) { @@ -82,7 +82,7 @@ int CommonUtils::CompareVersion(const Array& given_version, return CompareVersion(int_given_version, int_target_version); } -const String CommonUtils::ToAttrKey(const String& key) { +const ffi::String CommonUtils::ToAttrKey(const ffi::String& key) { if (key == "name") { return msc_attr::kName; } @@ -111,7 +111,7 @@ const String CommonUtils::ToAttrKey(const String& key) { TVM_FFI_UNREACHABLE(); } -bool StringUtils::Contains(const String& src_string, const String& sub_string) { +bool StringUtils::Contains(const ffi::String& src_string, const ffi::String& sub_string) { if (src_string.size() == 0) { return false; } @@ -125,7 +125,7 @@ bool StringUtils::Contains(const String& src_string, const String& sub_string) { return pos >= 0; } -bool StringUtils::StartsWith(const String& src_string, const String& sub_string) { +bool StringUtils::StartsWith(const ffi::String& src_string, const ffi::String& sub_string) { if (src_string.size() == 0) { return false; } @@ -138,7 +138,7 @@ bool StringUtils::StartsWith(const String& src_string, const String& sub_string) return pos == 0; } -bool StringUtils::EndsWith(const String& src_string, const String& sub_string) { +bool StringUtils::EndsWith(const ffi::String& src_string, const ffi::String& sub_string) { if (src_string.size() == 0) { return false; } @@ -154,8 +154,9 @@ bool StringUtils::EndsWith(const String& src_string, const String& sub_string) { return static_cast(pos) == src_cstring.size() - sub_cstring.size(); } -const Array StringUtils::Split(const String& src_string, const String& sep) { - Array sub_strings; +const ffi::Array StringUtils::Split(const ffi::String& src_string, + const ffi::String& sep) { + ffi::Array sub_strings; if (src_string.size() == 0) { 
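// empty source string: nothing to split, return the empty array as-is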
return sub_strings; } @@ -175,26 +176,27 @@ const Array StringUtils::Split(const String& src_string, const String& s return sub_strings; } -const String StringUtils::Join(const Array& sub_strings, const String& joint) { - String join_str = ""; +const ffi::String StringUtils::Join(const ffi::Array& sub_strings, + const ffi::String& joint) { + ffi::String join_str = ""; for (size_t i = 0; i < sub_strings.size(); i++) { join_str = join_str + sub_strings[i] + (i == sub_strings.size() - 1 ? "" : joint); } return join_str; } -const String StringUtils::Join(const std::vector& sub_strings, - const std::string& joint) { - Array new_strings; +const ffi::String StringUtils::Join(const std::vector& sub_strings, + const std::string& joint) { + ffi::Array new_strings; for (const auto& s : sub_strings) { new_strings.push_back(s); } return Join(new_strings, joint); } -const String StringUtils::Replace(const String& src_string, const String& old_str, - const String& new_str) { - String new_string; +const ffi::String StringUtils::Replace(const ffi::String& src_string, const ffi::String& old_str, + const ffi::String& new_str) { + ffi::String new_string; const auto& sub_strings = Split(src_string, old_str); for (size_t i = 0; i < sub_strings.size(); i++) { new_string = new_string + sub_strings[i] + (i == sub_strings.size() - 1 ? "" : new_str); @@ -202,10 +204,11 @@ const String StringUtils::Replace(const String& src_string, const String& old_st return new_string; } -const std::tuple StringUtils::SplitOnce(const String& src_string, const String& sep, - bool from_left) { +const std::tuple StringUtils::SplitOnce(const ffi::String& src_string, + const ffi::String& sep, + bool from_left) { if (src_string.size() == 0) { - return std::make_tuple(String(), String()); + return std::make_tuple(ffi::String(), ffi::String()); } std::string src_cstring = src_string; const std::string& csep = sep; @@ -213,17 +216,18 @@ const std::tuple StringUtils::SplitOnce(const String& src_string if (pos >= 0) { return std::make_tuple(src_cstring.substr(0, pos), src_cstring.substr(pos + csep.size())); } - return std::make_tuple(src_string, String()); + return std::make_tuple(src_string, ffi::String()); } -const Array StringUtils::GetClosures(const String& src_string, const String& left, - const String& right) { - Array tokens; +const ffi::Array StringUtils::GetClosures(const ffi::String& src_string, + const ffi::String& left, + const ffi::String& right) { + ffi::Array tokens; if (src_string.size() == 0) { return tokens; } - String token = "start"; - String left_str = src_string; + ffi::String token = "start"; + ffi::String left_str = src_string; while (token.size() > 0) { std::tie(token, left_str) = StringUtils::SplitOnce(left_str, left); if (left_str.size() > 0) { @@ -238,35 +242,36 @@ const Array StringUtils::GetClosures(const String& src_string, const Str return tokens; } -const String StringUtils::GetClosureOnce(const String& src_string, const String& left, - const String& right, bool from_left) { +const ffi::String StringUtils::GetClosureOnce(const ffi::String& src_string, + const ffi::String& left, const ffi::String& right, + bool from_left) { if (src_string.size() == 0) { return ""; } - String val = std::get<1>(SplitOnce(src_string, left, from_left)); + ffi::String val = std::get<1>(SplitOnce(src_string, left, from_left)); if (val.size() > 0) { val = std::get<0>(StringUtils::SplitOnce(val, right, from_left)); } return val; } -const String StringUtils::Upper(const String& src_string) { +const ffi::String StringUtils::Upper(const 
ffi::String& src_string) { std::string str = std::string(src_string); std::transform(str.begin(), str.end(), str.begin(), ::toupper); return str; } -const String StringUtils::Lower(const String& src_string) { +const ffi::String StringUtils::Lower(const ffi::String& src_string) { std::string str = std::string(src_string); std::transform(str.begin(), str.end(), str.begin(), ::tolower); return str; } -const String StringUtils::ToString(const ffi::Any& obj) { - String obj_string; +const ffi::String StringUtils::ToString(const ffi::Any& obj) { + ffi::String obj_string; if (obj == nullptr) { obj_string = ""; - } else if (auto opt_str = obj.as()) { + } else if (auto opt_str = obj.as()) { obj_string = *opt_str; } else if (const auto* n = obj.as()) { obj_string = std::to_string(n->value); @@ -291,7 +296,8 @@ const String StringUtils::ToString(const ffi::Any& obj) { return obj_string; } -bool ArrayUtils::CompareArrays(const Array& left, const Array& right, int size) { +bool ArrayUtils::CompareArrays(const ffi::Array& left, + const ffi::Array& right, int size) { if (left.size() == right.size() && left.size() == 0) { return true; } @@ -314,7 +320,7 @@ bool ArrayUtils::CompareArrays(const Array& left, const Array& r return true; } -PrimExpr ArrayUtils::Accumulate(const Array& array, int pos) { +PrimExpr ArrayUtils::Accumulate(const ffi::Array& array, int pos) { size_t t_pos = pos < 0 ? array.size() + pos + 1 : pos; PrimExpr accumulate = Integer(1); for (size_t i = 0; i < t_pos; i++) { @@ -323,7 +329,7 @@ PrimExpr ArrayUtils::Accumulate(const Array& array, int pos) { return accumulate; } -bool ArrayUtils::Broadcastable(const Array& lhs, const Array& rhs) { +bool ArrayUtils::Broadcastable(const ffi::Array& lhs, const ffi::Array& rhs) { if (lhs.size() != rhs.size()) { return false; } @@ -345,16 +351,16 @@ bool ArrayUtils::Broadcastable(const Array& lhs, const Array return true; } -const Span SpanUtils::SetAttr(const Span& span, const String& key, const String& value) { +const Span SpanUtils::SetAttr(const Span& span, const ffi::String& key, const ffi::String& value) { if (value.size() == 0) { return span; } - String new_source; - Array tokens{"<" + key + ">", ""}; + ffi::String new_source; + ffi::Array tokens{"<" + key + ">", ""}; if (span.defined() && span->source_name.defined()) { - const String& source_str = span->source_name->name; - String left = std::get<0>(StringUtils::SplitOnce(source_str, tokens[0])); - String right = std::get<1>(StringUtils::SplitOnce(source_str, tokens[1])); + const ffi::String& source_str = span->source_name->name; + ffi::String left = std::get<0>(StringUtils::SplitOnce(source_str, tokens[0])); + ffi::String right = std::get<1>(StringUtils::SplitOnce(source_str, tokens[1])); if (StringUtils::Contains(source_str, tokens[0]) && StringUtils::Contains(source_str, tokens[1])) { new_source = left + tokens[0] + value + tokens[1] + right; @@ -371,29 +377,29 @@ const Span SpanUtils::SetAttr(const Span& span, const String& key, const String& return Span(SourceName::Get(new_source), 0, 0, 0, 0); } -String SpanUtils::GetAttr(const Span& span, const String& key) { +ffi::String SpanUtils::GetAttr(const Span& span, const ffi::String& key) { if (span.defined() && span->source_name.defined()) { - Array tokens{"<" + key + ">", ""}; + ffi::Array tokens{"<" + key + ">", ""}; return StringUtils::GetClosureOnce(span->source_name->name, tokens[0], tokens[1]); } return ""; } -const Map SpanUtils::GetAttrs(const Span& span) { - Map attrs; +const ffi::Map SpanUtils::GetAttrs(const Span& span) { + 
ffi::Map attrs; for (const auto& key : StringUtils::GetClosures(span->source_name->name, "")) { attrs.Set(key, GetAttr(span, key)); } return attrs; } -const Span SpanUtils::CreateWithAttr(const String& key, const String& value) { +const Span SpanUtils::CreateWithAttr(const ffi::String& key, const ffi::String& value) { return SetAttr(Span(), key, value); } -const Array ExprUtils::GetInputTypes(const String& optype, size_t inputs_num, - bool as_relax) { - Array input_types; +const ffi::Array ExprUtils::GetInputTypes(const ffi::String& optype, size_t inputs_num, + bool as_relax) { + ffi::Array input_types; if (as_relax && (optype == "broadcast_to" || optype == "reshape")) { input_types.push_back("input"); if (inputs_num > 1) { @@ -490,12 +496,12 @@ const Array ExprUtils::GetInputTypes(const String& optype, size_t inputs return input_types; } -const Array ExprUtils::GetInputTypes(const Call& call) { - const String& optype = StringUtils::Replace(Downcast(call->op)->name, "relax.", ""); +const ffi::Array ExprUtils::GetInputTypes(const Call& call) { + const ffi::String& optype = StringUtils::Replace(Downcast(call->op)->name, "relax.", ""); return GetInputTypes(optype, call->args.size(), true); } -const String ExprUtils::GetSpanName(const Expr& expr, const String& suffix) { +const ffi::String ExprUtils::GetSpanName(const Expr& expr, const ffi::String& suffix) { const auto& name = SpanUtils::GetAttr(expr->span, msc_attr::kName); if (suffix.size() > 0) { return name + "_" + suffix; @@ -503,13 +509,13 @@ const String ExprUtils::GetSpanName(const Expr& expr, const String& suffix) { return name; } -const Array ExprUtils::GetShape(const TensorStructInfo& sinfo, bool as_int) { +const ffi::Array ExprUtils::GetShape(const TensorStructInfo& sinfo, bool as_int) { const auto& shape_opt = sinfo->GetShape(); if (!shape_opt.defined()) { - return Array(); + return ffi::Array(); } if (as_int) { - Array shape; + ffi::Array shape; for (const auto& s : shape_opt.value()) { shape.push_back(s->IsInstance() ? 
s : Integer(-1)); } @@ -518,7 +524,7 @@ const Array ExprUtils::GetShape(const TensorStructInfo& sinfo, bool as return shape_opt.value(); } -const Array ExprUtils::GetShape(const Expr& expr, bool as_int) { +const ffi::Array ExprUtils::GetShape(const Expr& expr, bool as_int) { return GetShape(Downcast(GetStructInfo(expr)), as_int); } @@ -532,20 +538,20 @@ TVM_FFI_STATIC_INIT_BLOCK({ .def("msc.core.SpanGetAttr", SpanUtils::GetAttr) .def("msc.core.SpanGetAttrs", SpanUtils::GetAttrs) .def("msc.core.SpanCreateWithAttr", - [](const String& key, const String& value) -> Span { + [](const ffi::String& key, const ffi::String& value) -> Span { return SpanUtils::CreateWithAttr(key, value); }) .def("msc.core.SpanSetAttr", - [](const Span& span, const String& key, const String& value) -> Span { + [](const Span& span, const ffi::String& key, const ffi::String& value) -> Span { return SpanUtils::SetAttr(span, key, value); }) - .def( - "msc.core.CompareVersion", - [](const Array& given_version, const Array& target_version) -> Integer { - return Integer(CommonUtils::CompareVersion(given_version, target_version)); - }) + .def("msc.core.CompareVersion", + [](const ffi::Array& given_version, + const ffi::Array& target_version) -> Integer { + return Integer(CommonUtils::CompareVersion(given_version, target_version)); + }) .def("msc.core.ToAttrKey", - [](const String& key) -> String { return CommonUtils::ToAttrKey(key); }); + [](const ffi::String& key) -> ffi::String { return CommonUtils::ToAttrKey(key); }); }); } // namespace msc diff --git a/src/contrib/msc/core/utils.h b/src/contrib/msc/core/utils.h index 19ad0020e5ca..a0732d5848ac 100644 --- a/src/contrib/msc/core/utils.h +++ b/src/contrib/msc/core/utils.h @@ -82,13 +82,13 @@ class CommonUtils { */ TVM_DLL static int CompareVersion(const std::vector& given_version, const std::vector& target_version); - TVM_DLL static int CompareVersion(const Array& given_version, - const Array& target_version); + TVM_DLL static int CompareVersion(const ffi::Array& given_version, + const ffi::Array& target_version); /*! * \brief Get attr key. * \return The attr key. */ - TVM_DLL static const String ToAttrKey(const String& key); + TVM_DLL static const ffi::String ToAttrKey(const ffi::String& key); }; /*! @@ -97,83 +97,87 @@ class StringUtils { public: /*! - * \brief Check if the String contains a substring. + * \brief Check if the ffi::String contains a substring. * \return Whether the substring is contained. */ - TVM_DLL static bool Contains(const String& src_string, const String& sub_string); + TVM_DLL static bool Contains(const ffi::String& src_string, const ffi::String& sub_string); /*! - * \brief Check if the String starts with a substring. + * \brief Check if the ffi::String starts with a substring. * \return Whether the string starts with the substring. */ - TVM_DLL static bool StartsWith(const String& src_string, const String& sub_string); + TVM_DLL static bool StartsWith(const ffi::String& src_string, const ffi::String& sub_string); /*! - * \brief Check if the String ends with a substring. + * \brief Check if the ffi::String ends with a substring. * \return Whether the string ends with the substring. */ - TVM_DLL static bool EndsWith(const String& src_string, const String& sub_string); + TVM_DLL static bool EndsWith(const ffi::String& src_string, const ffi::String& sub_string); /*! - * \brief Split the String into sub Strings. + * \brief Split the ffi::String into sub Strings. * \return The SubStrings. */
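A quick illustration of the Split/Join semantics documented above, as a standalone sketch: std::string stands in for tvm::ffi::String, and the helper names are hypothetical (the template arguments of the real signatures are elided in this hunk, so the exact container types are assumed).

#include <iostream>
#include <string>
#include <vector>

// Sketch of the documented StringUtils::Split/Join behaviour; not the actual implementation.
static std::vector<std::string> SplitSketch(const std::string& src, const std::string& sep) {
  std::vector<std::string> parts;
  if (sep.empty()) return {src};  // guard: an empty separator would loop forever
  size_t start = 0;
  size_t pos;
  while ((pos = src.find(sep, start)) != std::string::npos) {
    parts.push_back(src.substr(start, pos - start));
    start = pos + sep.size();
  }
  parts.push_back(src.substr(start));
  return parts;
}

static std::string JoinSketch(const std::vector<std::string>& parts, const std::string& joint) {
  std::string out;
  for (size_t i = 0; i < parts.size(); ++i) {
    if (i != 0) out += joint;
    out += parts[i];
  }
  return out;
}

int main() {
  const auto parts = SplitSketch("relax.nn.conv2d", ".");  // {"relax", "nn", "conv2d"}
  std::cout << JoinSketch(parts, "/") << std::endl;        // prints relax/nn/conv2d
  return 0;
}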
- TVM_DLL static const Array Split(const String& src_string, const String& sep); + TVM_DLL static const ffi::Array Split(const ffi::String& src_string, + const ffi::String& sep); /*! * \brief Join the SubStrings into a String. * \return The String. */ - TVM_DLL static const String Join(const Array& sub_strings, const String& joint); - TVM_DLL static const String Join(const std::vector& sub_strings, - const std::string& joint); + TVM_DLL static const ffi::String Join(const ffi::Array& sub_strings, + const ffi::String& joint); + TVM_DLL static const ffi::String Join(const std::vector& sub_strings, + const std::string& joint); /*! * \brief Replace the old substring with the new one in the String. * \return The replaced String. */ - TVM_DLL static const String Replace(const String& src_string, const String& old_str, - const String& new_str); + TVM_DLL static const ffi::String Replace(const ffi::String& src_string, + const ffi::String& old_str, const ffi::String& new_str); /*! - * \brief Split the String into two sub Strings, only split by the first sep. + * \brief Split the ffi::String into two sub Strings, only split by the first sep. * \return The SubStrings. */ - TVM_DLL static const std::tuple SplitOnce(const String& src_string, - const String& sep, - bool from_left = false); + TVM_DLL static const std::tuple SplitOnce(const ffi::String& src_string, + const ffi::String& sep, + bool from_left = false); /*! * \brief Get the tokens between left and right. * \return The Tokens. */ - TVM_DLL static const Array GetClosures(const String& src_string, const String& left, - const String& right); + TVM_DLL static const ffi::Array GetClosures(const ffi::String& src_string, + const ffi::String& left, + const ffi::String& right); /*! * \brief Get the first token between left and right. * \return The Token. */ - TVM_DLL static const String GetClosureOnce(const String& src_string, const String& left, - const String& right, bool from_left = true); + TVM_DLL static const ffi::String GetClosureOnce(const ffi::String& src_string, + const ffi::String& left, const ffi::String& right, + bool from_left = true); /*! * \brief Change the string to upper case. * \return The String. */ - TVM_DLL static const String Upper(const String& src_string); + TVM_DLL static const ffi::String Upper(const ffi::String& src_string); /*! * \brief Change the string to lower case. * \return The String. */ - TVM_DLL static const String Lower(const String& src_string); + TVM_DLL static const ffi::String Lower(const ffi::String& src_string); /*! * \brief Convert the Object to a String. * \return The String. */ - TVM_DLL static const String ToString(const ffi::Any& obj); + TVM_DLL static const ffi::String ToString(const ffi::Any& obj); }; /*! @@ -186,9 +190,9 @@ class ArrayUtils { * \return The replaced Array. */
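GetClosures/GetClosureOnce above extract the tokens enclosed by left/right marker strings; SpanUtils builds on them to stash attributes in span source names. A minimal standalone sketch of the documented behaviour, assuming markers of the <key>...</key> form (the literal marker strings in this diff were lost to extraction, so they are assumptions here):

#include <iostream>
#include <string>

// Sketch of the documented GetClosureOnce behaviour: the first substring enclosed
// by `left` and `right`, or "" when either marker is missing. Illustrative only.
static std::string GetClosureOnceSketch(const std::string& src, const std::string& left,
                                        const std::string& right) {
  const size_t l = src.find(left);
  if (l == std::string::npos) return "";
  const size_t start = l + left.size();
  const size_t r = src.find(right, start);
  if (r == std::string::npos) return "";
  return src.substr(start, r - start);
}

int main() {
  // Assumed encoding: a span attribute stored as <name>value</name> in the source name.
  std::cout << GetClosureOnceSketch("<name>conv1</name>", "<name>", "</name>") << std::endl;
  return 0;  // prints conv1
}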
template - TVM_DLL static const Array Replace(const Array& src_array, const T& old_ele, - const T& new_ele) { - Array new_array; + TVM_DLL static const ffi::Array Replace(const ffi::Array& src_array, const T& old_ele, + const T& new_ele) { + ffi::Array new_array; for (const auto& a : src_array) { if (a == old_ele) { new_array.push_back(new_ele); @@ -218,8 +222,8 @@ * \return The downcasted array */ template - TVM_DLL static const Array Cast(const Array& src_array) { - Array new_array; + TVM_DLL static const ffi::Array Cast(const ffi::Array& src_array) { + ffi::Array new_array; for (const auto& s : src_array) { new_array.push_back(Downcast(s)); } @@ -231,21 +235,21 @@ * \return The product array */ template - TVM_DLL static const Array> Product(const Array>& arrays) { - Array> p_arrays; + TVM_DLL static const ffi::Array> Product(const ffi::Array>& arrays) { + ffi::Array> p_arrays; if (arrays.size() == 1) { for (const auto& a : arrays[0]) { - p_arrays.push_back(Array{a}); + p_arrays.push_back(ffi::Array{a}); } return p_arrays; } - Array> sub_arrays; + ffi::Array> sub_arrays; for (size_t i = 0; i < arrays.size() - 1; i++) { sub_arrays.push_back(arrays[i]); } for (const auto& p_array : Product(sub_arrays)) { for (const auto& a : arrays[arrays.size() - 1]) { - Array sub_array = p_array; + ffi::Array sub_array = p_array; sub_array.push_back(a); p_arrays.push_back(sub_array); } @@ -254,22 +258,23 @@ } /*! - * \brief Compare String arrays. + * \brief Compare ffi::String arrays. * \return Whether the two arrays are the same. */ - TVM_DLL static bool CompareArrays(const Array& left, const Array& right, - int size = -1); + TVM_DLL static bool CompareArrays(const ffi::Array& left, + const ffi::Array& right, int size = -1); /*! * \brief Accumulate the array. * \return The accumulated result */ - TVM_DLL static PrimExpr Accumulate(const Array& array, int pos = -1); + TVM_DLL static PrimExpr Accumulate(const ffi::Array& array, int pos = -1); /*! * \brief Check if lhs array is broadcastable to rhs. * \return Whether broadcastable. */ - TVM_DLL static bool Broadcastable(const Array& lhs, const Array& rhs); + TVM_DLL static bool Broadcastable(const ffi::Array& lhs, + const ffi::Array& rhs); }; /*! @@ -281,25 +286,26 @@ class SpanUtils { * \brief Set a key:value attribute on the Span. * \return The new Span. */ - TVM_DLL static const Span SetAttr(const Span& span, const String& key, const String& value); + TVM_DLL static const Span SetAttr(const Span& span, const ffi::String& key, + const ffi::String& value); /*! * \brief Get the value of the given key from the Span. * \return The value String. */ - TVM_DLL static String GetAttr(const Span& span, const String& key); + TVM_DLL static ffi::String GetAttr(const Span& span, const ffi::String& key); /*! * \brief Get all the key:value attributes from the Span. * \return The Attrs Map. */ - TVM_DLL static const Map GetAttrs(const Span& span); + TVM_DLL static const ffi::Map GetAttrs(const Span& span); /*! * \brief Create a span with a key:value attribute. * \return The created Span. */ - TVM_DLL static const Span CreateWithAttr(const String& key, const String& value); + TVM_DLL static const Span CreateWithAttr(const ffi::String& key, const ffi::String& value); }; /*! @@ -311,14 +317,14 @@ class ExprUtils { /*! * \brief Get the input types of call. * \return The input types.
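ArrayUtils::Product above computes a cartesian product recursively: the product of the first n-1 arrays is extended by each element of the last one. The same recursion as a runnable standalone sketch, with std::vector standing in for ffi::Array:

#include <iostream>
#include <vector>

// Sketch mirroring the recursion of ArrayUtils::Product shown above.
static std::vector<std::vector<int>> ProductSketch(const std::vector<std::vector<int>>& arrays) {
  std::vector<std::vector<int>> result;
  if (arrays.size() == 1) {
    for (int a : arrays[0]) result.push_back({a});
    return result;
  }
  // product of all arrays but the last ...
  const std::vector<std::vector<int>> heads(arrays.begin(), arrays.end() - 1);
  for (const auto& prefix : ProductSketch(heads)) {
    // ... extended by each element of the last array
    for (int a : arrays.back()) {
      auto row = prefix;
      row.push_back(a);
      result.push_back(row);
    }
  }
  return result;
}

int main() {
  // {{1,2},{3,4}} -> {1,3} {1,4} {2,3} {2,4}
  for (const auto& row : ProductSketch({{1, 2}, {3, 4}})) {
    for (int v : row) std::cout << v << ' ';
    std::cout << '\n';
  }
  return 0;
}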
*/ - TVM_DLL static const Array GetInputTypes(const String& optype, size_t inputs_num, - bool as_relax); + TVM_DLL static const ffi::Array GetInputTypes(const ffi::String& optype, + size_t inputs_num, bool as_relax); /*! * \brief Get the input types of call. * \return The input types. */ - TVM_DLL static const Array GetInputTypes(const Call& call); + TVM_DLL static const ffi::Array GetInputTypes(const Call& call); /*! * \brief Get the scalar value of ndarray. @@ -371,14 +377,15 @@ class ExprUtils { * \brief Get name in span. * \return The name. */ - TVM_DLL static const String GetSpanName(const Expr& expr, const String& suffix = ""); + TVM_DLL static const ffi::String GetSpanName(const Expr& expr, const ffi::String& suffix = ""); /*! * \brief Get shape of expr. * \return The shape. */ - TVM_DLL static const Array GetShape(const TensorStructInfo& sinfo, bool as_int = true); - TVM_DLL static const Array GetShape(const Expr& expr, bool as_int = true); + TVM_DLL static const ffi::Array GetShape(const TensorStructInfo& sinfo, + bool as_int = true); + TVM_DLL static const ffi::Array GetShape(const Expr& expr, bool as_int = true); /*! * \brief Get dtype of expr. diff --git a/src/contrib/msc/framework/tensorflow/codegen.cc b/src/contrib/msc/framework/tensorflow/codegen.cc index 6a77440b7204..954341114df7 100644 --- a/src/contrib/msc/framework/tensorflow/codegen.cc +++ b/src/contrib/msc/framework/tensorflow/codegen.cc @@ -88,7 +88,7 @@ void TensorflowCodeGen::CodeGenGraph() { } CodeGenNode(node, config()->use_tools); } - Array idx_outputs; + ffi::Array idx_outputs; for (const auto& o : graph()->GetOutputs()) { const auto& pair = graph()->FindProducerAndIdx(o); idx_outputs.push_back(IdxOutputBase(pair.first, pair.second)); @@ -139,7 +139,7 @@ void TensorflowCodeGen::CodeGenInference() { .scope_end(); } -const Array TensorflowCodeGen::GetOpCodes(const MSCJoint& node) { +const ffi::Array TensorflowCodeGen::GetOpCodes(const MSCJoint& node) { const auto& ops_map = GetTFV1OpCodes(); auto it = ops_map->find(node->optype); ICHECK(it != ops_map->end()) << "Unsupported tensorflow op(" << node->optype << "): " << node; @@ -155,8 +155,8 @@ const Array TensorflowCodeGen::GetOpCodes(const MSCJoint& node) { TVM_FFI_STATIC_INIT_BLOCK({ namespace refl = tvm::ffi::reflection; refl::GlobalDef().def("msc.framework.tensorflow.GetTensorflowSources", - [](const MSCGraph& graph, const String& codegen_config, - const String& print_config) -> Map { + [](const MSCGraph& graph, const ffi::String& codegen_config, + const ffi::String& print_config) -> ffi::Map { TensorflowCodeGen codegen = TensorflowCodeGen(graph, codegen_config); codegen.Init(); return codegen.GetSources(print_config); diff --git a/src/contrib/msc/framework/tensorflow/codegen.h b/src/contrib/msc/framework/tensorflow/codegen.h index af2579980a39..5052c11004d2 100644 --- a/src/contrib/msc/framework/tensorflow/codegen.h +++ b/src/contrib/msc/framework/tensorflow/codegen.h @@ -59,10 +59,10 @@ class TensorflowCodeGen : public PyCodeGen GetOpCodes(const MSCJoint& node) final; + const ffi::Array GetOpCodes(const MSCJoint& node) final; /*! 
\brief Get tensor type of the framework*/ - const String TensorType() const final { return "tf_v1.Tensor"; } + const ffi::String TensorType() const final { return "tf_v1.Tensor"; } }; } // namespace msc diff --git a/src/contrib/msc/framework/tensorflow/tf_v1_opcode.cc b/src/contrib/msc/framework/tensorflow/tf_v1_opcode.cc index 570088ee35c2..d47021d84da5 100644 --- a/src/contrib/msc/framework/tensorflow/tf_v1_opcode.cc +++ b/src/contrib/msc/framework/tensorflow/tf_v1_opcode.cc @@ -29,17 +29,16 @@ namespace tvm { namespace contrib { namespace msc { -const Array TFV1OpCode::GetDocs() { +const ffi::Array TFV1OpCode::GetDocs() { stack_.Config(this); CodeGenBuild(); return stack_.GetDocs(); } -const std::pair> TFV1OpCode::GetPadding(const String& strides_key, - const String& kernel_key, - const String& padding_key) { - String pad_mod = ""; - Array padding; +const std::pair> TFV1OpCode::GetPadding( + const ffi::String& strides_key, const ffi::String& kernel_key, const ffi::String& padding_key) { + ffi::String pad_mod = ""; + ffi::Array padding; std::vector kernel_size; if (node()->optype == "nn.conv2d" || node()->optype == "msc.conv2d_bias") { const auto& weight = node()->WeightAt("weight"); @@ -98,7 +97,7 @@ const std::pair> TFV1OpCode::GetPadding(const String& stri #define TFV1_OP_CODEGEN_METHODS(TypeName) \ public: \ - TypeName(const String& func_name) : TFV1OpCode(func_name) {} + TypeName(const ffi::String& func_name) : TFV1OpCode(func_name) {} class TFV1ArgMaxMinCodeGen : public TFV1OpCode { TFV1_OP_CODEGEN_METHODS(TFV1ArgMaxMinCodeGen) @@ -128,23 +127,25 @@ class TFV1AstypeCodeGen : public TFV1OpCode { class TFV1AxesCodeGen : public TFV1OpCode { public: - TFV1AxesCodeGen(const String& func_name, const String& attr_name) : TFV1OpCode(func_name) { + TFV1AxesCodeGen(const ffi::String& func_name, const ffi::String& attr_name) + : TFV1OpCode(func_name) { attr_name_ = attr_name; } protected: void CodeGenBuild() final { - const String& key = node()->HasAttr("axes") ? "axes" : "axis"; + const ffi::String& key = node()->HasAttr("axes") ? 
"axes" : "axis"; stack_.op_call().op_input_arg().op_list_arg(key, attr_name_).op_name_arg(); } private: - String attr_name_; + ffi::String attr_name_; }; class TFV1AxisCodeGen : public TFV1OpCode { public: - TFV1AxisCodeGen(const String& func_name, const String& attr_name) : TFV1OpCode(func_name) { + TFV1AxisCodeGen(const ffi::String& func_name, const ffi::String& attr_name) + : TFV1OpCode(func_name) { attr_name_ = attr_name; } @@ -154,7 +155,7 @@ class TFV1AxisCodeGen : public TFV1OpCode { } private: - String attr_name_; + ffi::String attr_name_; }; class TFV1BatchnormCodeGen : public TFV1OpCode { @@ -168,8 +169,8 @@ class TFV1BatchnormCodeGen : public TFV1OpCode { .op_arg("center") .op_arg("momentum") .op_arg("epsilon"); - Array weight_names{"gamma", "beta", "mean", "var"}; - Array init_names{"gamma", "beta", "moving_mean", "moving_variance"}; + ffi::Array weight_names{"gamma", "beta", "mean", "var"}; + ffi::Array init_names{"gamma", "beta", "moving_mean", "moving_variance"}; for (size_t i = 0; i < weight_names.size(); i++) { const auto& w_doc = DocUtils::ToStr(node()->WeightAt(weight_names[i])->name); stack_.inplace_start("tf_v1.constant_initializer", init_names[i] + "_initializer") @@ -219,7 +220,7 @@ class TFV1ConstantCodeGen : public TFV1OpCode { class TFV1ConvCodeGen : public TFV1OpCode { public: - TFV1ConvCodeGen(const String& func_name, bool use_bias) : TFV1OpCode(func_name) { + TFV1ConvCodeGen(const ffi::String& func_name, bool use_bias) : TFV1OpCode(func_name) { use_bias_ = use_bias; } @@ -318,19 +319,19 @@ class TFV1PadCodeGen : public TFV1OpCode { protected: void CodeGenBuild() final { - String mode; + ffi::String mode; const auto& attr_mode = node()->GetTypeAttr("pad_mode"); if (attr_mode == "constant") { mode = "CONSTANT"; } else { LOG_FATAL << "Unexpected pad mode " << node(); } - Array pad_width; + ffi::Array pad_width; const auto& attr_pad_width = node()->GetTypeArrayAttr("pad_width"); ICHECK(attr_pad_width.size() % 2 == 0) << "pad_width should be multiple of 2, get " << node(); for (size_t i = 0; i < attr_pad_width.size(); i += 2) { - const String& cur_pad = "[" + std::to_string(attr_pad_width[i]) + ", " + - std::to_string(attr_pad_width[i + 1]) + "]"; + const ffi::String& cur_pad = "[" + std::to_string(attr_pad_width[i]) + ", " + + std::to_string(attr_pad_width[i + 1]) + "]"; pad_width.push_back(cur_pad); } const auto& val_producer = node()->ProducerOf(1); @@ -349,7 +350,7 @@ class TFV1Pool2dCodeGen : public TFV1OpCode { protected: void CodeGenBuild() final { - String pooling_type; + ffi::String pooling_type; if (node()->optype == "nn.avg_pool2d") { pooling_type = "AVG"; } else if (node()->optype == "nn.max_pool2d") { @@ -413,7 +414,7 @@ class TFV1Resize2dCodeGen : public TFV1OpCode { protected: void CodeGenBuild() final { - String func_name; + ffi::String func_name; const auto& method = node()->GetTypeAttr("method"); const auto& coordinate_transformation_mode = node()->GetTypeAttr("coordinate_transformation_mode"); @@ -502,8 +503,10 @@ class TFV1TupleCodeGen : public TFV1OpCode { void CodeGenBuild() final { stack_.op_call().op_inputs_arg(); } }; -const std::shared_ptr>> GetTFV1OpCodes() { - static auto map = std::make_shared>>(); +const std::shared_ptr>> +GetTFV1OpCodes() { + static auto map = + std::make_shared>>(); if (!map->empty()) return map; // binary && unary ops map->emplace("abs", std::make_shared("tf_v1.abs")); diff --git a/src/contrib/msc/framework/tensorflow/tf_v1_opcode.h b/src/contrib/msc/framework/tensorflow/tf_v1_opcode.h index 
bda7e6e99336..a744ffc701e4 100644 --- a/src/contrib/msc/framework/tensorflow/tf_v1_opcode.h +++ b/src/contrib/msc/framework/tensorflow/tf_v1_opcode.h @@ -50,14 +50,14 @@ class TFV1OpCode : public BaseOpCode * \param func_name the function name for the node. * \param config the config json for the node. */ - explicit TFV1OpCode(const String& func_name) + explicit TFV1OpCode(const ffi::String& func_name) : BaseOpCode(func_name) {} /*! \brief Convert node to docs*/ - const Array GetDocs() final; + const ffi::Array GetDocs() final; /*! \brief Get dtype string*/ - const String DType(const DataType& dtype) final { + const ffi::String DType(const DataType& dtype) final { return "tf_v1." + BaseOpCode::DType(dtype); } @@ -68,16 +68,17 @@ class TFV1OpCode : public BaseOpCode virtual void CodeGenBuild() = 0; /*! \brief Get padding mode or array*/ - const std::pair> GetPadding(const String& strides_key, - const String& kernel_key = "", - const String& padding_key = "padding"); + const std::pair> GetPadding( + const ffi::String& strides_key, const ffi::String& kernel_key = "", + const ffi::String& padding_key = "padding"); }; /*! * \brief Get the map of available TFV1OpCode, use optype as key * \return Map of */ -const std::shared_ptr>> GetTFV1OpCodes(); +const std::shared_ptr>> +GetTFV1OpCodes(); } // namespace msc } // namespace contrib diff --git a/src/contrib/msc/framework/tensorrt/codegen.cc b/src/contrib/msc/framework/tensorrt/codegen.cc index 7acd0f215502..b0d290328d62 100644 --- a/src/contrib/msc/framework/tensorrt/codegen.cc +++ b/src/contrib/msc/framework/tensorrt/codegen.cc @@ -48,7 +48,7 @@ void TensorRTCodeGen::CodeGenClassDeclare() { } // plugin headers if (config()->use_plugin) { - std::set plugins; + std::set plugins; for (const auto& n : graph()->node_names) { const auto& node = graph()->FindNode(n); if (IsPlugin(node->optype) && !plugins.count(node->optype)) { @@ -95,7 +95,7 @@ void TensorRTCodeGen::CodeGenClassDeclare() { void TensorRTCodeGen::CodeGenClassDefine() { auto malloc_buffer = [this](const MSCTensor& tensor) { - const String& idx_var = "idx_" + IdxTensor(tensor); + const ffi::String& idx_var = "idx_" + IdxTensor(tensor); this->stack_ .func_call("getBindingIndex", DocUtils::ToDeclare("int", idx_var), DocUtils::ToPtr("engine")) @@ -121,8 +121,8 @@ void TensorRTCodeGen::CodeGenClassDefine() { // save codegen before build if (config()->use_tools) { const auto pf = tvm::ffi::Function::GetGlobalRequired("msc_tool.codegen_step"); - before_build_codes_ = - pf(GetStepCtx(), "before_build", graph()->name, config()->tools_tag).cast>(); + before_build_codes_ = pf(GetStepCtx(), "before_build", graph()->name, config()->tools_tag) + .cast>(); } if (graph()->weight_holders.size() > 0) { stack_.func_call("TRTUtils::LoadWeights", "mWeights") @@ -144,7 +144,7 @@ void TensorRTCodeGen::CodeGenClassDefine() { stack_.comment("Mark batch size"); stack_.func_call("createOptimizationProfile", DocUtils::ToDeclare("auto", "profile"), DocUtils::ToPtr("builder")); - Array batch_flags{"MIN", "MAX", "OPT"}; + ffi::Array batch_flags{"MIN", "MAX", "OPT"}; for (const auto& i : graph()->GetInputs()) { for (const auto& f : batch_flags) { stack_.func_call("setDimensions", std::nullopt, DocUtils::ToPtr("profile")) @@ -207,8 +207,8 @@ void TensorRTCodeGen::CodeGenClassDefine() { // save codegen after build if (config()->use_tools) { const auto pf = tvm::ffi::Function::GetGlobalRequired("msc_tool.codegen_step"); - after_build_codes_ = - pf(GetStepCtx(), "after_build", graph()->name, config()->tools_tag).cast>(); + 
after_build_codes_ = pf(GetStepCtx(), "after_build", graph()->name, config()->tools_tag) + .cast>(); } // end define build method stack_.func_end("true"); @@ -470,7 +470,7 @@ void TensorRTCodeGen::CodeGenCmake() { if (config()->use_plugin) { stack_.line("add_definitions(-DPLUGIN_SUPPORT_TENSORRT)").line(); } - String link_libs = " ${TRT_LIBS}"; + ffi::String link_libs = " ${TRT_LIBS}"; if (config()->extern_libs.size() > 0) { stack_.line("set(EXTERN_LIBS " + StringUtils::Join(config()->extern_libs, " ") + ")"); link_libs = link_libs + " ${EXTERN_LIBS}"; @@ -481,17 +481,18 @@ void TensorRTCodeGen::CodeGenCmake() { .line("target_link_libraries(" + graph()->name + link_libs + ")"); } -const String TensorRTCodeGen::IdxTensor(const MSCTensor& tensor) { +const ffi::String TensorRTCodeGen::IdxTensor(const MSCTensor& tensor) { const auto& pair = graph()->FindProducerAndIdx(tensor); - const String& prefix = "tensor_" + std::to_string(pair.first->index); + const ffi::String& prefix = "tensor_" + std::to_string(pair.first->index); if (pair.first->outputs.size() > 1) { return prefix + "_" + std::to_string(pair.second); } return prefix; } -const String TensorRTCodeGen::CppDType(const DataType& dtype) { - const String& dtype_name = CppCodeGen::DType(dtype); +const ffi::String TensorRTCodeGen::CppDType(const DataType& dtype) { + const ffi::String& dtype_name = + CppCodeGen::DType(dtype); if (dtype_name == "int32") { return "int"; } @@ -507,11 +508,11 @@ const String TensorRTCodeGen::CppDType(const DataType& dtype) { return dtype_name; } -const String TensorRTCodeGen::GetTensorBytes(const MSCTensor& tensor) { +const ffi::String TensorRTCodeGen::GetTensorBytes(const MSCTensor& tensor) { return std::to_string(tensor->GetSize()->value) + " * sizeof(" + CppDType(tensor->dtype) + ")"; } -void TensorRTCodeGen::ReturnOnFail(const String& flag, const String& err) { +void TensorRTCodeGen::ReturnOnFail(const ffi::String& flag, const ffi::String& err) { stack_.cond_if("!" + flag) .func_call("logger.log") .call_arg("ILogger::Severity::kERROR") @@ -521,11 +522,11 @@ void TensorRTCodeGen::ReturnOnFail(const String& flag, const String& err) { } template -const String TensorRTCodeGen::ToDims(const std::vector& dims, bool use_ndim) { +const ffi::String TensorRTCodeGen::ToDims(const std::vector& dims, bool use_ndim) { if (dims.size() == 2 && !use_ndim) { return "DimsHW{" + std::to_string(dims[0]) + "," + std::to_string(dims[1]) + "}"; } - String dims_str = "Dims({" + std::to_string(dims.size()) + ",{"; + ffi::String dims_str = "Dims({" + std::to_string(dims.size()) + ",{"; for (size_t i = 0; i < dims.size(); i++) { dims_str = dims_str + std::to_string(dims[i]) + (i < dims.size() - 1 ? 
"," : ""); } @@ -533,7 +534,7 @@ const String TensorRTCodeGen::ToDims(const std::vector& dims, bool use_ndim) return dims_str; } -const String TensorRTCodeGen::ToDims(const Array& dims, bool use_ndim) { +const ffi::String TensorRTCodeGen::ToDims(const ffi::Array& dims, bool use_ndim) { std::vector int_dims; for (const auto& d : dims) { int_dims.push_back(d->value); @@ -541,7 +542,7 @@ const String TensorRTCodeGen::ToDims(const Array& dims, bool use_ndim) return ToDims(int_dims, use_ndim); } -const Array TensorRTCodeGen::GetOpCodes(const MSCJoint& node) { +const ffi::Array TensorRTCodeGen::GetOpCodes(const MSCJoint& node) { const auto& ops_map = GetTensorRTOpCodes(); auto it = ops_map->find(GetOpType(node)); ICHECK(it != ops_map->end()) << "Unsupported tensorrt op(" << node->optype << "): " << node; @@ -554,8 +555,8 @@ const Array TensorRTCodeGen::GetOpCodes(const MSCJoint& node) { } } -const Map TensorRTCodeGen::GetTensorCtx(const MSCTensor& tensor) { - Map tensor_ctx; +const ffi::Map TensorRTCodeGen::GetTensorCtx(const MSCTensor& tensor) { + ffi::Map tensor_ctx; tensor_ctx.Set("ctx", "network"); for (const auto& pair : CppCodeGen::GetTensorCtx(tensor)) { @@ -564,8 +565,8 @@ const Map TensorRTCodeGen::GetTensorCtx(const MSCTensor& tensor) return tensor_ctx; } -const Map TensorRTCodeGen::GetStepCtx() { - Map step_ctx; +const ffi::Map TensorRTCodeGen::GetStepCtx() { + ffi::Map step_ctx; step_ctx.Set("network", "network"); step_ctx.Set("config", "config"); step_ctx.Set("builder", "builder"); @@ -579,13 +580,13 @@ TVM_FFI_STATIC_INIT_BLOCK({ namespace refl = tvm::ffi::reflection; refl::GlobalDef() .def("msc.framework.tensorrt.GetTensorRTSources", - [](const MSCGraph& graph, const String& codegen_config, - const String& print_config) -> Map { + [](const MSCGraph& graph, const ffi::String& codegen_config, + const ffi::String& print_config) -> ffi::Map { TensorRTCodeGen codegen = TensorRTCodeGen(graph, codegen_config); codegen.Init(); return codegen.GetSources(print_config); }) - .def("msc.framework.tensorrt.GetTensorRTRoot", []() -> String { + .def("msc.framework.tensorrt.GetTensorRTRoot", []() -> ffi::String { #ifdef TENSORRT_ROOT_DIR return TENSORRT_ROOT_DIR; #else @@ -599,18 +600,18 @@ TVM_FFI_STATIC_INIT_BLOCK({ * \param functions The extern functions to be compiled via TensorRT * \return Runtime modules. 
*/ -Array MSCTensorRTCompiler(Array functions, - Map target_option, - Map constant_names) { - Array compiled_functions; +ffi::Array MSCTensorRTCompiler(ffi::Array functions, + ffi::Map target_option, + ffi::Map constant_names) { + ffi::Array compiled_functions; for (const auto& func : functions) { VLOG(1) << "MSC.TensorRT partition:" << std::endl << func; - const auto& name_opt = func->GetAttr(msc_attr::kUnique); + const auto& name_opt = func->GetAttr(msc_attr::kUnique); ICHECK(name_opt.has_value()) << "Can not find " << msc_attr::kUnique << " from attrs"; const auto& name = name_opt.value(); std::string func_name = GetExtSymbol(func); ICHECK(target_option.count(name)) << "Can not find target option for " << name; - const auto& options = Downcast(target_option[name]); + const auto& options = Downcast(target_option[name]); MSCJSONSerializer serializer(constant_names, options); serializer.serialize(func); std::string graph_json = serializer.GetJSON(); diff --git a/src/contrib/msc/framework/tensorrt/codegen.h b/src/contrib/msc/framework/tensorrt/codegen.h index ea06a17f7c2b..87b4c330e40b 100644 --- a/src/contrib/msc/framework/tensorrt/codegen.h +++ b/src/contrib/msc/framework/tensorrt/codegen.h @@ -60,34 +60,34 @@ class TensorRTCodeGen : public CppCodeGen GetOpCodes(const MSCJoint& node) final; + const ffi::Array GetOpCodes(const MSCJoint& node) final; /*! \brief Get the tensor context for codegen_tensor*/ - const Map GetTensorCtx(const MSCTensor& tensor) final; + const ffi::Map GetTensorCtx(const MSCTensor& tensor) final; /*! \brief Get the step context for codegen_step*/ - const Map GetStepCtx() final; + const ffi::Map GetStepCtx() final; /*! \brief Generate return on fail codes*/ - void ReturnOnFail(const String& flag, const String& err); + void ReturnOnFail(const ffi::String& flag, const ffi::String& err); /*! \brief Get the index tensor*/ - const String IdxTensor(const MSCTensor& tensor); + const ffi::String IdxTensor(const MSCTensor& tensor); /*! \brief Get the dtype from the datatype*/ - const String CppDType(const DataType& dtype); + const ffi::String CppDType(const DataType& dtype); /*! \brief Generate describe for tensor bytes*/ - const String GetTensorBytes(const MSCTensor& tensor); + const ffi::String GetTensorBytes(const MSCTensor& tensor); /*! \brief Get the tensorrt dims from dims*/ template - const String ToDims(const std::vector& dims, bool use_ndim = true); - const String ToDims(const Array& dims, bool use_ndim = true); + const ffi::String ToDims(const std::vector& dims, bool use_ndim = true); + const ffi::String ToDims(const ffi::Array& dims, bool use_ndim = true); private: - Array before_build_codes_; - Array after_build_codes_; + ffi::Array before_build_codes_; + ffi::Array after_build_codes_; }; } // namespace msc diff --git a/src/contrib/msc/framework/tensorrt/codegen_utils.h b/src/contrib/msc/framework/tensorrt/codegen_utils.h index f006b21b816e..3a16e668fe96 100644 --- a/src/contrib/msc/framework/tensorrt/codegen_utils.h +++ b/src/contrib/msc/framework/tensorrt/codegen_utils.h @@ -40,8 +40,8 @@ namespace msc { class TensorRTCodeGenHelper : public BaseCodeGenHelper { public: /*! 
\brief Get describe for default node input*/ - const String IdxInputBase(const MSCJoint& node, const String& prefix = "", int idx = 0, - const String& suffix = "", bool process = false) final { + const ffi::String IdxInputBase(const MSCJoint& node, const ffi::String& prefix = "", int idx = 0, + const ffi::String& suffix = "", bool process = false) final { const auto& pair = node->ProducerAndIdxOf(idx); if (pair.first->optype == "input") { return "*" + IdxNodeBase(pair.first, prefix, suffix); @@ -53,8 +53,8 @@ class TensorRTCodeGenHelper : public BaseCodeGenHelper { } /*! \brief Get describe for default node output*/ - const String IdxOutputBase(const MSCJoint& node, const String& prefix = "", int idx = 0, - const String& suffix = "", bool mark_exit = false) final { + const ffi::String IdxOutputBase(const MSCJoint& node, const ffi::String& prefix = "", int idx = 0, + const ffi::String& suffix = "", bool mark_exit = false) final { if (node->optype == "argmax" || node->optype == "argmin") { ICHECK_EQ(idx, 0) << "argmax and argmin only has 1 output, get " << idx; return IdxNodeBase(node, prefix, suffix) + "->getOutput(1)"; @@ -70,8 +70,8 @@ class TensorRTCodeGenHelper : public BaseCodeGenHelper { } /*! \brief Get describe for default node weight*/ - const String IdxWeightBase(const MSCJoint& node, const String& wtype, const String& suffix = "", - bool process = false) final { + const ffi::String IdxWeightBase(const MSCJoint& node, const ffi::String& wtype, + const ffi::String& suffix = "", bool process = false) final { return "mWeights[\"" + node->WeightAt(wtype)->name + "\"]"; } }; diff --git a/src/contrib/msc/framework/tensorrt/tensorrt_opcode.cc b/src/contrib/msc/framework/tensorrt/tensorrt_opcode.cc index 5a63ecbc7d06..4fde2bf8bc2e 100644 --- a/src/contrib/msc/framework/tensorrt/tensorrt_opcode.cc +++ b/src/contrib/msc/framework/tensorrt/tensorrt_opcode.cc @@ -31,7 +31,7 @@ namespace tvm { namespace contrib { namespace msc { -const Array TensorRTOpCode::GetDocs() { +const ffi::Array TensorRTOpCode::GetDocs() { stack_.Config(this); CodeGenBuild(); if (node()->optype == "tuple") { @@ -52,7 +52,7 @@ const Array TensorRTOpCode::GetDocs() { return stack_.GetDocs(); } -void TensorRTOpCode::SetPadding(const String& key) { +void TensorRTOpCode::SetPadding(const ffi::String& key) { const auto& padding = node()->GetTypeArrayAttr("padding"); if (padding.size() == 1) { SetLayerByDimsValue("Padding", std::vector{padding[0], padding[0]}, false); @@ -67,8 +67,8 @@ void TensorRTOpCode::SetPadding(const String& key) { } } -const String TensorRTOpCode::DeclareInputs(bool simplify) { - const String& inputs_ref = "inputs_" + std::to_string(node()->index); +const ffi::String TensorRTOpCode::DeclareInputs(bool simplify) { + const ffi::String& inputs_ref = "inputs_" + std::to_string(node()->index); if (node()->parents.size() == 1 && simplify) { const auto& idx_input = StringUtils::Replace(IdxInput(), "*", ""); stack_.declare("std::vector", inputs_ref + "_vec") @@ -85,9 +85,10 @@ const String TensorRTOpCode::DeclareInputs(bool simplify) { return inputs_ref; } -const String TensorRTOpCode::DType(const DataType& dtype) { - const String& dtype_name = BaseOpCode::DType(dtype); - String dtype_enum; +const ffi::String TensorRTOpCode::DType(const DataType& dtype) { + const ffi::String& dtype_name = + BaseOpCode::DType(dtype); + ffi::String dtype_enum; if (dtype_name == "int8") { dtype_enum = "DataType::kINT8"; } else if (dtype_name == "int32") { @@ -105,11 +106,11 @@ const String TensorRTOpCode::DType(const DataType& 
dtype) { } template -const String TensorRTOpCode::ToDims(const std::vector& dims, bool use_ndim) { +const ffi::String TensorRTOpCode::ToDims(const std::vector& dims, bool use_ndim) { if (dims.size() == 2 && !use_ndim) { return "DimsHW{" + std::to_string(dims[0]) + "," + std::to_string(dims[1]) + "}"; } - String dims_str = "Dims({" + std::to_string(dims.size()) + ",{"; + ffi::String dims_str = "Dims({" + std::to_string(dims.size()) + ",{"; for (size_t i = 0; i < dims.size(); i++) { dims_str = dims_str + std::to_string(dims[i]) + (i < dims.size() - 1 ? "," : ""); } @@ -117,7 +118,7 @@ const String TensorRTOpCode::ToDims(const std::vector& dims, bool use_ndim) { return dims_str; } -const String TensorRTOpCode::ToDims(const Array& dims, bool use_ndim) { +const ffi::String TensorRTOpCode::ToDims(const ffi::Array& dims, bool use_ndim) { std::vector int_dims; for (const auto& d : dims) { int_dims.push_back(d->value); @@ -125,7 +126,7 @@ const String TensorRTOpCode::ToDims(const Array& dims, bool use_ndim) { return ToDims(int_dims, use_ndim); } -const String TensorRTOpCode::AttrToDims(const String& key, bool use_ndim) { +const ffi::String TensorRTOpCode::AttrToDims(const ffi::String& key, bool use_ndim) { const auto& dims = node()->GetTypeArrayAttr(key); return ToDims(dims, use_ndim); } @@ -139,7 +140,7 @@ const size_t TensorRTOpCode::ToReduceAxis(const std::vector& axes, size_t n return reduce_axis; } -const size_t TensorRTOpCode::AttrToReduceAxis(const String& key, size_t ndim) { +const size_t TensorRTOpCode::AttrToReduceAxis(const ffi::String& key, size_t ndim) { std::vector axes; if (node()->GetAttr(key, &axes)) { return ToReduceAxis(axes, ndim); @@ -149,56 +150,57 @@ const size_t TensorRTOpCode::AttrToReduceAxis(const String& key, size_t ndim) { return ToReduceAxis(std::vector{axis}, ndim); } -const size_t TensorRTOpCode::AttrToAxis(const String& key, size_t ndim) { +const size_t TensorRTOpCode::AttrToAxis(const ffi::String& key, size_t ndim) { size_t valid_ndim = ndim == 0 ? 
node()->InputAt(0)->Ndim() : ndim; int axis = node()->GetTypeAttr(key); return CommonUtils::GetIndex(axis, valid_ndim); } template -void TensorRTOpCode::SetLayerByAttr(const String& method, const String& key) { +void TensorRTOpCode::SetLayerByAttr(const ffi::String& method, const ffi::String& key) { stack_.func_call("set" + method, std::nullopt, DocUtils::ToPtr(IdxNode())).op_arg(key, ""); } template -void TensorRTOpCode::SetLayerByValue(const String& method, const T& value) { +void TensorRTOpCode::SetLayerByValue(const ffi::String& method, const T& value) { stack_.func_call("set" + method, std::nullopt, DocUtils::ToPtr(IdxNode())).call_arg(value); } -void TensorRTOpCode::SetLayerByDimsAttr(const String& method, const String& key, bool use_ndim) { +void TensorRTOpCode::SetLayerByDimsAttr(const ffi::String& method, const ffi::String& key, + bool use_ndim) { stack_.func_call("set" + method, std::nullopt, DocUtils::ToPtr(IdxNode())) .call_arg(AttrToDims(key, use_ndim)); } template -void TensorRTOpCode::SetLayerByDimsValue(const String& method, const std::vector& value, +void TensorRTOpCode::SetLayerByDimsValue(const ffi::String& method, const std::vector& value, bool use_ndim) { stack_.func_call("set" + method, std::nullopt, DocUtils::ToPtr(IdxNode())) .call_arg(ToDims(value, use_ndim)); } -void TensorRTOpCode::SetLayerByDimsValue(const String& method, const Array& value, - bool use_ndim) { +void TensorRTOpCode::SetLayerByDimsValue(const ffi::String& method, + const ffi::Array& value, bool use_ndim) { stack_.func_call("set" + method, std::nullopt, DocUtils::ToPtr(IdxNode())) .call_arg(ToDims(value, use_ndim)); } #define TENSORRT_OP_CODEGEN_METHODS(TypeName) \ public: \ - TypeName(const String& func_name) : TensorRTOpCode(func_name) {} + TypeName(const ffi::String& func_name) : TensorRTOpCode(func_name) {} -#define TENSORRT_FLAG_OP_CODEGEN_METHODS(TypeName) \ - public: \ - TypeName(const String& func_name, const String& symbol) : TensorRTOpCode(func_name) { \ - symbol_ = symbol; \ - } \ - \ - private: \ - String symbol_; +#define TENSORRT_FLAG_OP_CODEGEN_METHODS(TypeName) \ + public: \ + TypeName(const ffi::String& func_name, const ffi::String& symbol) : TensorRTOpCode(func_name) { \ + symbol_ = symbol; \ + } \ + \ + private: \ + ffi::String symbol_; class TensorRTActivationCodeGen : public TensorRTOpCode { public: - explicit TensorRTActivationCodeGen(const String& symbol) : TensorRTOpCode("Activation") { + explicit TensorRTActivationCodeGen(const ffi::String& symbol) : TensorRTOpCode("Activation") { symbol_ = symbol; } @@ -214,7 +216,7 @@ class TensorRTActivationCodeGen : public TensorRTOpCode { } private: - String symbol_; + ffi::String symbol_; }; class TensorRTAdaptivePool2dCodeGen : public TensorRTOpCode { @@ -232,7 +234,7 @@ class TensorRTAdaptivePool2dCodeGen : public TensorRTOpCode { stride.push_back(in_sizes[i] / out_sizes[i]); kernel.push_back((in_sizes[i] - (out_sizes[i] - 1) * stride[i])); } - const String& suffix = CompareVersion(8, 0, 0) >= 0 ? "Nd" : ""; + const ffi::String& suffix = CompareVersion(8, 0, 0) >= 0 ? 
"Nd" : ""; stack_.op_call() .op_input_arg() .call_arg("PoolingType::k" + symbol_) @@ -243,7 +245,7 @@ class TensorRTAdaptivePool2dCodeGen : public TensorRTOpCode { class TensorRTArgmaxminCodeGen : public TensorRTOpCode { public: - explicit TensorRTArgmaxminCodeGen(const String& symbol) : TensorRTOpCode("TopK") { + explicit TensorRTArgmaxminCodeGen(const ffi::String& symbol) : TensorRTOpCode("TopK") { symbol_ = symbol; } @@ -258,7 +260,7 @@ class TensorRTArgmaxminCodeGen : public TensorRTOpCode { } private: - String symbol_; + ffi::String symbol_; }; class TensorRTAstypeCodeGen : public TensorRTOpCode { @@ -318,7 +320,7 @@ class TensorRTConstantCodeGen : public TensorRTOpCode { class TensorRTConvCodeGen : public TensorRTOpCode { public: - TensorRTConvCodeGen(const String& func_name, bool use_bias) : TensorRTOpCode(func_name) { + TensorRTConvCodeGen(const ffi::String& func_name, bool use_bias) : TensorRTOpCode(func_name) { use_bias_ = use_bias; } @@ -342,7 +344,7 @@ class TensorRTConvCodeGen : public TensorRTOpCode { } else { stack_.call_arg("mWeights[\"" + node()->name + ".bias\"]"); } - const String& suffix = CompareVersion(8, 0, 0) >= 0 ? "Nd" : ""; + const ffi::String& suffix = CompareVersion(8, 0, 0) >= 0 ? "Nd" : ""; SetLayerByDimsAttr("Stride" + suffix, "strides", false); SetLayerByDimsAttr("Dilation" + suffix, "dilation", false); SetLayerByAttr("NbGroups", "groups"); @@ -355,7 +357,7 @@ class TensorRTConvCodeGen : public TensorRTOpCode { class TensorRTElemwiseCodeGen : public TensorRTOpCode { public: - explicit TensorRTElemwiseCodeGen(const String& symbol) : TensorRTOpCode("ElementWise") { + explicit TensorRTElemwiseCodeGen(const ffi::String& symbol) : TensorRTOpCode("ElementWise") { symbol_ = symbol; } @@ -365,7 +367,7 @@ class TensorRTElemwiseCodeGen : public TensorRTOpCode { } private: - String symbol_; + ffi::String symbol_; }; class TensorRTGetItemCodeGen : public TensorRTOpCode { @@ -396,7 +398,7 @@ class TensorRTInputCodeGen : public TensorRTOpCode { class TensorRTLinearCodeGen : public TensorRTOpCode { public: - TensorRTLinearCodeGen(const String& func_name, bool use_bias) : TensorRTOpCode(func_name) { + TensorRTLinearCodeGen(const ffi::String& func_name, bool use_bias) : TensorRTOpCode(func_name) { use_bias_ = use_bias; } @@ -464,7 +466,7 @@ class TensorRTPermuteDimsCodeGen : public TensorRTOpCode { axes.push_back(i - 1); } } - const String& perm_ref = "perm_" + std::to_string(node()->index); + const ffi::String& perm_ref = "perm_" + std::to_string(node()->index); stack_.op_call().op_input_arg().declare("Permutation", perm_ref); for (size_t i = 0; i < axes.size(); i++) { stack_.assign(perm_ref + ".order[" + std::to_string(i) + "]", @@ -476,7 +478,7 @@ class TensorRTPermuteDimsCodeGen : public TensorRTOpCode { class TensorRTPool2dCodeGen : public TensorRTOpCode { public: - explicit TensorRTPool2dCodeGen(const String& symbol) : TensorRTOpCode("PoolingNd") { + explicit TensorRTPool2dCodeGen(const ffi::String& symbol) : TensorRTOpCode("PoolingNd") { symbol_ = symbol; } @@ -486,7 +488,7 @@ class TensorRTPool2dCodeGen : public TensorRTOpCode { .op_input_arg() .call_arg("PoolingType::k" + symbol_) .call_arg(AttrToDims("pool_size", false)); - const String& suffix = CompareVersion(8, 0, 0) >= 0 ? "Nd" : ""; + const ffi::String& suffix = CompareVersion(8, 0, 0) >= 0 ? 
"Nd" : ""; SetLayerByDimsAttr("Stride" + suffix, "strides", false); if (node()->GetTypeAttr("ceil_mode")) { SetLayerByValue("PaddingMode", "PaddingMode::kEXPLICIT_ROUND_UP"); @@ -498,12 +500,12 @@ class TensorRTPool2dCodeGen : public TensorRTOpCode { } private: - String symbol_; + ffi::String symbol_; }; class TensorRTReduceCodeGen : public TensorRTOpCode { public: - explicit TensorRTReduceCodeGen(const String& symbol) : TensorRTOpCode("Reduce") { + explicit TensorRTReduceCodeGen(const ffi::String& symbol) : TensorRTOpCode("Reduce") { symbol_ = symbol; } @@ -517,7 +519,7 @@ class TensorRTReduceCodeGen : public TensorRTOpCode { } private: - String symbol_; + ffi::String symbol_; }; class TensorRTReshapeCodeGen : public TensorRTOpCode { @@ -540,7 +542,7 @@ class TensorRTResize2dCodeGen : public TensorRTOpCode { void CodeGenBuild() final { stack_.op_call().op_input_arg(); const auto& method = node()->GetTypeAttr("method"); - String resize_mode; + ffi::String resize_mode; if (method == "linear") { resize_mode = "LINEAR"; } else if (method == "nearest_neighbor") { @@ -663,7 +665,7 @@ class TensorRTTopkCodeGen : public TensorRTOpCode { protected: void CodeGenBuild() final { - const String& symbol = node()->GetTypeAttr("largest") ? "MAX" : "MIN"; + const ffi::String& symbol = node()->GetTypeAttr("largest") ? "MAX" : "MIN"; stack_.op_call() .op_input_arg() .call_arg("TopKOperation::k" + symbol) @@ -685,7 +687,7 @@ class TensorRTTupleCodeGen : public TensorRTOpCode { class TensorRTUnaryCodeGen : public TensorRTOpCode { public: - explicit TensorRTUnaryCodeGen(const String& symbol) : TensorRTOpCode("Unary") { + explicit TensorRTUnaryCodeGen(const ffi::String& symbol) : TensorRTOpCode("Unary") { symbol_ = symbol; } @@ -695,7 +697,7 @@ class TensorRTUnaryCodeGen : public TensorRTOpCode { } private: - String symbol_; + ffi::String symbol_; }; class TensorRTWhereCodeGen : public TensorRTOpCode { @@ -718,9 +720,9 @@ class TensorRTPluginOpCodeGen : public TensorRTOpCode { const auto& plugin = GetPlugin(node()->optype); const auto& input_ref = "inputs_" + std::to_string(producer->index); - const String& func_name = "plugin::" + node()->optype + "DynamicPlugin"; - const String& plugin_ref = "plugin_" + std::to_string(node()->index); - const String& layouts_ref = "layouts_" + std::to_string(node()->index); + const ffi::String& func_name = "plugin::" + node()->optype + "DynamicPlugin"; + const ffi::String& plugin_ref = "plugin_" + std::to_string(node()->index); + const ffi::String& layouts_ref = "layouts_" + std::to_string(node()->index); stack_.declare("std::vector", layouts_ref, 0, false); for (const auto& i : node()->GetInputs()) { stack_.declare_arg(DocUtils::ToStr(i->layout.name())); @@ -735,9 +737,10 @@ class TensorRTPluginOpCodeGen : public TensorRTOpCode { } }; -const std::shared_ptr>> +const std::shared_ptr>> GetTensorRTOpCodes() { - static auto map = std::make_shared>>(); + static auto map = + std::make_shared>>(); if (!map->empty()) return map; // unary ops map->emplace("abs", std::make_shared("ABS")); diff --git a/src/contrib/msc/framework/tensorrt/tensorrt_opcode.h b/src/contrib/msc/framework/tensorrt/tensorrt_opcode.h index 2d9bcb6acfa2..ddf7fb1522be 100644 --- a/src/contrib/msc/framework/tensorrt/tensorrt_opcode.h +++ b/src/contrib/msc/framework/tensorrt/tensorrt_opcode.h @@ -49,22 +49,22 @@ class TensorRTOpCode : public BaseOpCode(func_name) {} /*! \brief Convert node to docs*/ - const Array GetDocs() final; + const ffi::Array GetDocs() final; /*! 
\brief Get func_name for the default node*/ - const String callee_name() final { + const ffi::String callee_name() final { return "network->add" + BaseOpCode::callee_name(); } /*! \brief Get valid return name for the default node*/ - const String ret_name() final { return "auto " + IdxNode(); } + const ffi::String ret_name() final { return "auto " + IdxNode(); } /*! \brief Get the dtype from the datatype*/ - const String DType(const DataType& dtype) final; + const ffi::String DType(const DataType& dtype) final; protected: TensorRTOpCodeStack stack_; @@ -73,50 +73,52 @@ class TensorRTOpCode : public BaseOpCode - const String ToDims(const std::vector& dims, bool use_ndim = true); - const String ToDims(const Array& dims, bool use_ndim = true); + const ffi::String ToDims(const std::vector& dims, bool use_ndim = true); + const ffi::String ToDims(const ffi::Array& dims, bool use_ndim = true); /*! \brief Get the tensorrt dims from attribute*/ - const String AttrToDims(const String& key, bool use_ndim = true); + const ffi::String AttrToDims(const ffi::String& key, bool use_ndim = true); /*! \brief Get the tensorrt reduce axis from dims*/ const size_t ToReduceAxis(const std::vector& axes, size_t ndim = 0); /*! \brief Get the tensorrt reduce axis from attribute*/ - const size_t AttrToReduceAxis(const String& key = "axis", size_t ndim = 0); + const size_t AttrToReduceAxis(const ffi::String& key = "axis", size_t ndim = 0); /*! \brief Get the attribute axis from attribute*/ - const size_t AttrToAxis(const String& key = "axis", size_t ndim = 0); + const size_t AttrToAxis(const ffi::String& key = "axis", size_t ndim = 0); /*! \brief Set layer by attribute*/ template - void SetLayerByAttr(const String& method, const String& key); + void SetLayerByAttr(const ffi::String& method, const ffi::String& key); /*! \brief Set layer by value*/ template - void SetLayerByValue(const String& method, const T& value); + void SetLayerByValue(const ffi::String& method, const T& value); /*! \brief Set layer by dims attribute*/ - void SetLayerByDimsAttr(const String& method, const String& key, bool use_ndim = true); + void SetLayerByDimsAttr(const ffi::String& method, const ffi::String& key, bool use_ndim = true); /*! \brief Set layer by dims value*/ template - void SetLayerByDimsValue(const String& method, const std::vector& value, bool use_ndim = true); - void SetLayerByDimsValue(const String& method, const Array& value, bool use_ndim = true); + void SetLayerByDimsValue(const ffi::String& method, const std::vector& value, + bool use_ndim = true); + void SetLayerByDimsValue(const ffi::String& method, const ffi::Array& value, + bool use_ndim = true); }; /*! 
* \brief Get the map of available TensorRTOpCode, use optype as key * \return Map of */ -const std::shared_ptr>> +const std::shared_ptr>> GetTensorRTOpCodes(); } // namespace msc diff --git a/src/contrib/msc/framework/tensorrt/transform_tensorrt.cc b/src/contrib/msc/framework/tensorrt/transform_tensorrt.cc index 3d43c74958ec..06f694d463d7 100644 --- a/src/contrib/msc/framework/tensorrt/transform_tensorrt.cc +++ b/src/contrib/msc/framework/tensorrt/transform_tensorrt.cc @@ -58,7 +58,7 @@ struct TensorRTTransConfig { } }; -const TensorRTTransConfig ParseConfig(const String& config_str) { +const TensorRTTransConfig ParseConfig(const ffi::String& config_str) { TensorRTTransConfig config; if (config_str.size() > 0) { std::istringstream is(config_str); @@ -70,12 +70,12 @@ const TensorRTTransConfig ParseConfig(const String& config_str) { using FRewriteTensorRT = ffi::TypedFunction& new_calls, const String& config)>; + const ffi::Map& new_calls, const ffi::String& config)>; -const Array BroadcastShape(const Array& src_shape, - const Array& out_shape) { +const ffi::Array BroadcastShape(const ffi::Array& src_shape, + const ffi::Array& out_shape) { size_t diff = out_shape.size() - src_shape.size(); - Array leading_shape, tailing_shape; + ffi::Array leading_shape, tailing_shape; for (size_t i = 0; i < diff; i++) { leading_shape.push_back(Integer(1)); } @@ -95,7 +95,7 @@ const Array BroadcastShape(const Array& src_shape, } Expr RewriteElemwise(BlockBuilder builder, const Var& var, const Call& src_call, - const Map& new_calls, const String& config) { + const ffi::Map& new_calls, const ffi::String& config) { const auto& call = new_calls.count(src_call) ? new_calls[src_call] : src_call; const auto& shape_a = ExprUtils::GetShape(call->args[0]); const auto& shape_b = ExprUtils::GetShape(call->args[1]); @@ -118,7 +118,7 @@ Expr RewriteElemwise(BlockBuilder builder, const Var& var, const Call& src_call, } Expr RewriteAdd(BlockBuilder builder, const Var& var, const Call& src_call, - const Map& new_calls, const String& config) { + const ffi::Map& new_calls, const ffi::String& config) { const auto& call = new_calls.count(src_call) ? new_calls[src_call] : src_call; if (new_calls.count(call->args[0]) && new_calls[call->args[0]]->op == Op::Get("relax.nn.conv1d")) { @@ -135,7 +135,7 @@ Expr RewriteAdd(BlockBuilder builder, const Var& var, const Call& src_call, const auto* conv_attrs = conv2d->attrs.as(); if (conv_attrs->data_layout == "NCHW") { // expand bias reshape - Array exp_bias_shape{bias_shape[0], bias_shape[1], Integer(1), bias_shape[2]}; + ffi::Array exp_bias_shape{bias_shape[0], bias_shape[1], Integer(1), bias_shape[2]}; static const Op& reshape_op = Op::Get("relax.reshape"); const auto& exp_bias = RewriteUtils::MakeCall(builder, ExprUtils::GetSpanName(call, "exp_bias"), reshape_op, @@ -155,14 +155,14 @@ Expr RewriteAdd(BlockBuilder builder, const Var& var, const Call& src_call, } Expr RewriteArgmaxmin(BlockBuilder builder, const Var& var, const Call& src_call, - const Map& new_calls, const String& config) { + const ffi::Map& new_calls, const ffi::String& config) { const auto& call = new_calls.count(src_call) ? 
new_calls[src_call] : src_call; const auto& out_dtype = ExprUtils::GetDataType(var); const auto* src_attrs = src_call->attrs.as(); ICHECK(out_dtype == DataType::Int(32) || out_dtype == DataType::Int(64)) << "Unexpected out dtype " << out_dtype; static const Op& topk_op = Op::Get("relax.topk"); - auto topk_attrs = make_object(); + auto topk_attrs = ffi::make_object(); topk_attrs->k = 1; if (src_attrs->axis.has_value()) { topk_attrs->axis = src_attrs->axis.value(); @@ -187,7 +187,7 @@ Expr RewriteArgmaxmin(BlockBuilder builder, const Var& var, const Call& src_call } Expr RewriteAttention(BlockBuilder builder, const Var& var, const Call& src_call, - const Map& new_calls, const String& config) { + const ffi::Map& new_calls, const ffi::String& config) { const auto& call = new_calls.count(src_call) ? new_calls[src_call] : src_call; const auto& in_dtype = ExprUtils::GetDataType(call->args[0]); const auto* src_attrs = src_call->attrs.as(); @@ -218,8 +218,8 @@ Expr RewriteAttention(BlockBuilder builder, const Var& var, const Call& src_call static const Op& exp_op = Op::Get("relax.exp"); // prepare q,k,v - auto permute_attrs = make_object(); - Array axes{Integer(0), Integer(2), Integer(1), Integer(3)}; + auto permute_attrs = ffi::make_object(); + ffi::Array axes{Integer(0), Integer(2), Integer(1), Integer(3)}; permute_attrs->axes = axes; const auto& q_trans = RewriteUtils::MakeCall(builder, ExprUtils::GetSpanName(call, "q_trans"), permute_dims_op, @@ -230,17 +230,17 @@ Expr RewriteAttention(BlockBuilder builder, const Var& var, const Call& src_call const auto& v_trans = RewriteUtils::MakeCall(builder, ExprUtils::GetSpanName(call, "v_trans"), permute_dims_op, {call->args[2]}, Attrs(permute_attrs)); - Array q_shape({batch_size * num_head, seq_len, head_dim}); + ffi::Array q_shape({batch_size * num_head, seq_len, head_dim}); const auto& q_reshape = RewriteUtils::MakeCall(builder, ExprUtils::GetSpanName(call, "q_reshape"), reshape_op, {q_trans, ShapeExpr(q_shape)}); - Array k_shape({batch_size * num_head, seq_len_kv, head_dim}); + ffi::Array k_shape({batch_size * num_head, seq_len_kv, head_dim}); const auto& k_reshape = RewriteUtils::MakeCall(builder, ExprUtils::GetSpanName(call, "k_reshape"), reshape_op, {k_trans, ShapeExpr(k_shape)}); - Array v_shape({batch_size * num_head, seq_len_kv, head_dim_v}); + ffi::Array v_shape({batch_size * num_head, seq_len_kv, head_dim_v}); const auto& v_reshape = RewriteUtils::MakeCall(builder, ExprUtils::GetSpanName(call, "v_reshape"), reshape_op, {v_trans, ShapeExpr(v_shape)}); - auto reduce_permute_attrs = make_object(); - Array v_axes{Integer(0), Integer(2), Integer(1)}; + auto reduce_permute_attrs = ffi::make_object(); + ffi::Array v_axes{Integer(0), Integer(2), Integer(1)}; reduce_permute_attrs->axes = v_axes; // transpose for batch_matmul const auto& k_reshape_trans = @@ -248,7 +248,7 @@ Expr RewriteAttention(BlockBuilder builder, const Var& var, const Call& src_call permute_dims_op, {k_reshape}, Attrs(reduce_permute_attrs)); // calculate product - auto matmul_attrs = make_object(); + auto matmul_attrs = ffi::make_object(); matmul_attrs->out_dtype = in_dtype; const auto& qk_prod = RewriteUtils::MakeCall(builder, ExprUtils::GetSpanName(call, "qk_prod"), matmul_op, @@ -273,8 +273,8 @@ Expr RewriteAttention(BlockBuilder builder, const Var& var, const Call& src_call // bias Expr prod = p_scale; if (call->args.size() == 4) { - Array exp_shape{batch_size, num_head, seq_len, seq_len_kv}; - Array reduce_shape{batch_size * num_head, seq_len, seq_len_kv}; + ffi::Array 
exp_shape{batch_size, num_head, seq_len, seq_len_kv}; + ffi::Array reduce_shape{batch_size * num_head, seq_len, seq_len_kv}; const auto& prod_exp = RewriteUtils::MakeCall(builder, ExprUtils::GetSpanName(call, "prod_exp"), reshape_op, {prod, ShapeExpr(exp_shape)}); const auto& prod_add = RewriteUtils::MakeCall(builder, ExprUtils::GetSpanName(call, "prod_add"), @@ -286,7 +286,7 @@ Expr RewriteAttention(BlockBuilder builder, const Var& var, const Call& src_call // causal_mask Expr s_value; if (!src_attrs->causal_mask.has_value()) { - auto softmax_attrs = make_object(); + auto softmax_attrs = ffi::make_object(); softmax_attrs->axis = 2; s_value = RewriteUtils::MakeCall(builder, ExprUtils::GetSpanName(call, "act"), softmax_op, {prod}, Attrs(softmax_attrs)); @@ -302,8 +302,8 @@ Expr RewriteAttention(BlockBuilder builder, const Var& var, const Call& src_call } const auto& p_masked = RewriteUtils::MakeCall(builder, ExprUtils::GetSpanName(call, "p_masked"), tril_op, {prod, tril_k}); - auto reduce_attrs = make_object(); - Array axis{Integer(2)}; + auto reduce_attrs = ffi::make_object(); + ffi::Array axis{Integer(2)}; reduce_attrs->axis = axis; reduce_attrs->keepdims = true; const auto& p_max = RewriteUtils::MakeCall(builder, ExprUtils::GetSpanName(call, "p_max"), @@ -324,18 +324,18 @@ Expr RewriteAttention(BlockBuilder builder, const Var& var, const Call& src_call // final calculation const auto& o_prod = RewriteUtils::MakeCall(builder, ExprUtils::GetSpanName(call, "o_prod"), matmul_op, {s_value, v_reshape}, Attrs(matmul_attrs)); - Array o_shape{batch_size, num_head, seq_len, head_dim_v}; + ffi::Array o_shape{batch_size, num_head, seq_len, head_dim_v}; return Call(reshape_op, {o_prod, ShapeExpr(o_shape)}, Attrs(), call->sinfo_args, call->span); } Expr RewriteBatchNorm(BlockBuilder builder, const Var& var, const Call& src_call, - const Map& new_calls, const String& config) { + const ffi::Map& new_calls, const ffi::String& config) { const auto& call = new_calls.count(src_call) ? new_calls[src_call] : src_call; const auto& input_shape = ExprUtils::GetShape(call->args[0]); const auto& in_dtype = ExprUtils::GetDataType(call->args[0]); const auto* src_attrs = src_call->attrs.as(); // define expand shape - Array exp_shape(input_shape.size(), Integer(1)); + ffi::Array exp_shape(input_shape.size(), Integer(1)); exp_shape.Set(src_attrs->axis, input_shape[src_attrs->axis]); // create eps constant @@ -380,11 +380,11 @@ Expr RewriteBatchNorm(BlockBuilder builder, const Var& var, const Call& src_call res = RewriteUtils::MakeCall(builder, ExprUtils::GetSpanName(call, "offset"), add_op, {res, exp_offset}); } - return Tuple(Array{res}, call->span); + return Tuple(ffi::Array{res}, call->span); } Expr RewriteBroadcastTo(BlockBuilder builder, const Var& var, const Call& src_call, - const Map& new_calls, const String& config) { + const ffi::Map& new_calls, const ffi::String& config) { const auto& call = new_calls.count(src_call) ? 
   const auto& call = new_calls.count(src_call) ? new_calls[src_call] : src_call;
   const auto& input_shape = ExprUtils::GetShape(call->args[0]);
   const auto& output_shape = ExprUtils::GetShape(var);
@@ -394,8 +394,8 @@ Expr RewriteBroadcastTo(BlockBuilder builder, const Var& var, const Call& src_ca
     int64_t in_dim = Downcast(input_shape[i])->value;
     int64_t out_dim = Downcast(output_shape[i])->value;
     if (in_dim != out_dim) {
-      Array concat_inputs(out_dim / in_dim, concat_input);
-      auto concat_attrs = make_object();
+      ffi::Array concat_inputs(out_dim / in_dim, concat_input);
+      auto concat_attrs = ffi::make_object();
       concat_attrs->axis = i;
       concat_input = RewriteUtils::MakeCall(
           builder, ExprUtils::GetSpanName(call, "concat_" + std::to_string(i)), concat_op,
@@ -406,17 +406,19 @@ Expr RewriteBroadcastTo(BlockBuilder builder, const Var& var, const Call& src_ca
 }
 
 Expr RewriteConv1d(BlockBuilder builder, const Var& var, const Call& src_call,
-                   const Map& new_calls, const String& config) {
+                   const ffi::Map& new_calls, const ffi::String& config) {
   const auto& call = new_calls.count(src_call) ? new_calls[src_call] : src_call;
   const auto* src_attrs = src_call->attrs.as();
   const auto& input_shape = ExprUtils::GetShape(call->args[0]);
   const auto& weight_shape = ExprUtils::GetShape(call->args[1]);
   const auto& output_shape = ExprUtils::GetShape(var);
   if (src_attrs->data_layout == "NCW") {
-    Array new_args;
+    ffi::Array new_args;
     // expand inputs
-    Array exp_input_shape{input_shape[0], input_shape[1], Integer(1), input_shape[2]};
-    Array exp_weight_shape{weight_shape[0], weight_shape[1], Integer(1), weight_shape[2]};
+    ffi::Array exp_input_shape{input_shape[0], input_shape[1], Integer(1),
+                               input_shape[2]};
+    ffi::Array exp_weight_shape{weight_shape[0], weight_shape[1], Integer(1),
+                                weight_shape[2]};
     static const Op& reshape_op = Op::Get("relax.reshape");
     new_args.push_back(RewriteUtils::MakeCall(builder,
                                               ExprUtils::GetSpanName(call, "exp_input"), reshape_op,
@@ -426,11 +428,11 @@ Expr RewriteConv1d(BlockBuilder builder, const Var& var, const Call& src_call,
                                               {call->args[1], ShapeExpr(exp_weight_shape)}));
     // change to conv2d
     static const Op& conv2d_op = Op::Get("relax.nn.conv2d");
-    auto conv_attrs = make_object();
-    conv_attrs->strides = Array{src_attrs->strides[0], Integer(1)};
+    auto conv_attrs = ffi::make_object();
+    conv_attrs->strides = ffi::Array{src_attrs->strides[0], Integer(1)};
     conv_attrs->padding =
-        Array{Integer(0), src_attrs->padding[0], Integer(0), src_attrs->padding[1]};
-    conv_attrs->dilation = Array{src_attrs->dilation[0], Integer(1)};
+        ffi::Array{Integer(0), src_attrs->padding[0], Integer(0), src_attrs->padding[1]};
+    conv_attrs->dilation = ffi::Array{src_attrs->dilation[0], Integer(1)};
     conv_attrs->groups = src_attrs->groups;
     conv_attrs->data_layout = "NCHW";
     conv_attrs->kernel_layout = "OIHW";
@@ -448,7 +450,7 @@ Expr RewriteConv1d(BlockBuilder builder, const Var& var, const Call& src_call,
 }
 
 Expr RewriteGelu(BlockBuilder builder, const Var& var, const Call& src_call,
-                 const Map& new_calls, const String& config) {
+                 const ffi::Map& new_calls, const ffi::String& config) {
   // 0.5 * x * (1 + erf(sqrt(0.5) * x))
   const auto& call = new_calls.count(src_call) ? new_calls[src_call] : src_call;
   size_t in_dim = ExprUtils::GetShape(call->args[0]).size();
@@ -476,7 +478,7 @@ Expr RewriteGelu(BlockBuilder builder, const Var& var, const Call& src_call,
 }
 
 Expr RewriteGeluTanh(BlockBuilder builder, const Var& var, const Call& src_call,
-                     const Map& new_calls, const String& config) {
+                     const ffi::Map& new_calls, const ffi::String& config) {
   // 0.5 * x * (1 + tanh(sqrt(2/pi) * (0.044715F * pow(x, 3) + x)))
   const auto& call = new_calls.count(src_call) ? new_calls[src_call] : src_call;
   size_t in_dim = ExprUtils::GetShape(call->args[0]).size();
@@ -517,13 +519,13 @@ Expr RewriteGeluTanh(BlockBuilder builder, const Var& var, const Call& src_call,
 }
 
 Expr RewriteGroupNorm(BlockBuilder builder, const Var& var, const Call& src_call,
-                      const Map& new_calls, const String& config) {
+                      const ffi::Map& new_calls, const ffi::String& config) {
   const auto& call = new_calls.count(src_call) ? new_calls[src_call] : src_call;
   const auto& input_shape = ExprUtils::GetShape(call->args[0]);
   const auto& in_dtype = ExprUtils::GetDataType(call->args[0]);
   const auto* src_attrs = src_call->attrs.as();
-  Array group_shape = input_shape;
-  Array exp_shape(input_shape.size(), Integer(1));
+  ffi::Array group_shape = input_shape;
+  ffi::Array exp_shape(input_shape.size(), Integer(1));
   size_t axis = CommonUtils::GetIndex(src_attrs->channel_axis, input_shape.size());
   int64_t channel_dim = Downcast(input_shape[axis])->value *
                         Downcast(input_shape[axis + 1])->value / src_attrs->num_groups;
@@ -551,7 +553,7 @@ Expr RewriteGroupNorm(BlockBuilder builder, const Var& var, const Call& src_call
                                             {call->args[0], ShapeExpr(group_shape)});
 
   // mean(input)
-  auto mean_attrs = make_object();
+  auto mean_attrs = ffi::make_object();
   mean_attrs->axis = src_attrs->axes;
   mean_attrs->keepdims = true;
   const auto& mean = RewriteUtils::MakeCall(builder, ExprUtils::GetSpanName(call, "mean"), mean_op,
@@ -566,7 +568,7 @@ Expr RewriteGroupNorm(BlockBuilder builder, const Var& var, const Call& src_call
                                             mean_op, {square}, Attrs(mean_attrs));
 
   // sqrt(var + epsilon)
-  Array exp_eps_shape(input_shape.size(), Integer(1));
+  ffi::Array exp_eps_shape(input_shape.size(), Integer(1));
   const auto& exp_eps = RewriteUtils::MakeCall(builder, ExprUtils::GetSpanName(call, "exp_eps"),
                                                reshape_op, {eps, ShapeExpr(exp_eps_shape)});
   const auto& eps_add = RewriteUtils::MakeCall(builder, ExprUtils::GetSpanName(call, "eps_add"),
@@ -599,12 +601,12 @@ Expr RewriteGroupNorm(BlockBuilder builder, const Var& var, const Call& src_call
 }
 
 Expr RewriteLayerNorm(BlockBuilder builder, const Var& var, const Call& src_call,
-                      const Map& new_calls, const String& config) {
+                      const ffi::Map& new_calls, const ffi::String& config) {
   const auto& call = new_calls.count(src_call) ? new_calls[src_call] : src_call;
   const auto& input_shape = ExprUtils::GetShape(call->args[0]);
   const auto& in_dtype = ExprUtils::GetDataType(call->args[0]);
   const auto* src_attrs = src_call->attrs.as();
-  Array exp_shape(input_shape.size(), Integer(1));
+  ffi::Array exp_shape(input_shape.size(), Integer(1));
   for (const auto& a : src_attrs->axes) {
     size_t index = CommonUtils::GetIndex(static_cast(a->value), input_shape.size());
     exp_shape.Set(index, input_shape[index]);
@@ -624,7 +626,7 @@ Expr RewriteLayerNorm(BlockBuilder builder, const Var& var, const Call& src_call
   static const Op& subtract_op = Op::Get("relax.subtract");
 
   // mean(input)
-  auto mean_attrs = make_object();
+  auto mean_attrs = ffi::make_object();
   mean_attrs->axis = src_attrs->axes;
   mean_attrs->keepdims = true;
   const auto& mean = RewriteUtils::MakeCall(builder, ExprUtils::GetSpanName(call, "mean"), mean_op,
@@ -639,7 +641,7 @@ Expr RewriteLayerNorm(BlockBuilder builder, const Var& var, const Call& src_call
                                             mean_op, {square}, Attrs(mean_attrs));
 
   // sqrt(var + epsilon)
-  Array exp_eps_shape(input_shape.size(), Integer(1));
+  ffi::Array exp_eps_shape(input_shape.size(), Integer(1));
   const auto& exp_eps = RewriteUtils::MakeCall(builder, ExprUtils::GetSpanName(call, "exp_eps"),
                                                reshape_op, {eps, ShapeExpr(exp_eps_shape)});
   const auto& eps_add = RewriteUtils::MakeCall(builder, ExprUtils::GetSpanName(call, "eps_add"),
@@ -676,7 +678,7 @@ Expr RewriteLayerNorm(BlockBuilder builder, const Var& var, const Call& src_call
 }
 
 Expr RewriteMatmul(BlockBuilder builder, const Var& var, const Call& src_call,
-                   const Map& new_calls, const String& config) {
+                   const ffi::Map& new_calls, const ffi::String& config) {
   const auto& trt_config = ParseConfig(config);
   const auto& call = new_calls.count(src_call) ? new_calls[src_call] : src_call;
   const auto& shape_a = ExprUtils::GetShape(call->args[0]);
@@ -686,27 +688,27 @@ Expr RewriteMatmul(BlockBuilder builder, const Var& var, const Call& src_call,
       trt_config.linear_to_conv) {
     const auto& out_shape = ExprUtils::GetShape(var);
     PrimExpr accumulate = ArrayUtils::Accumulate(shape_a, shape_a.size() - 1);
-    Array exp_shape{accumulate, shape_a[shape_a.size() - 1], Integer(1), Integer(1)};
+    ffi::Array exp_shape{accumulate, shape_a[shape_a.size() - 1], Integer(1), Integer(1)};
     const auto& exp_in = RewriteUtils::MakeCall(builder, ExprUtils::GetSpanName(call, "exp_in"),
                                                 reshape_op, {call->args[0], ShapeExpr(exp_shape)});
     // transpose and expand weight to OIHW
     static const Op& permute_dims_op = Op::Get("relax.permute_dims");
-    auto permute_attrs = make_object();
-    Array axes{Integer(1), Integer(0)};
+    auto permute_attrs = ffi::make_object();
+    ffi::Array axes{Integer(1), Integer(0)};
     permute_attrs->axes = axes;
     const auto& trans_weight =
         RewriteUtils::MakeCall(builder, ExprUtils::GetSpanName(call, "trans_weight"),
                                permute_dims_op, {call->args[1]}, Attrs(permute_attrs));
-    Array weight_shape{shape_b[1], shape_b[0], Integer(1), Integer(1)};
+    ffi::Array weight_shape{shape_b[1], shape_b[0], Integer(1), Integer(1)};
     const auto& exp_weight =
         RewriteUtils::MakeCall(builder, ExprUtils::GetSpanName(call, "exp_weight"), reshape_op,
                                {trans_weight, ShapeExpr(weight_shape)});
     // to conv2d
     static const Op& conv2d_op = Op::Get("relax.nn.conv2d");
-    auto conv_attrs = make_object();
-    conv_attrs->strides = Array{Integer(1), Integer(1)};
-    conv_attrs->padding = Array{Integer(0), Integer(0), Integer(0), Integer(0)};
-    conv_attrs->dilation = Array{Integer(1), Integer(1)};
+    auto conv_attrs = ffi::make_object();
+    conv_attrs->strides = ffi::Array{Integer(1), Integer(1)};
+    conv_attrs->padding = ffi::Array{Integer(0), Integer(0), Integer(0), Integer(0)};
+    conv_attrs->dilation = ffi::Array{Integer(1), Integer(1)};
     conv_attrs->groups = 1;
     conv_attrs->data_layout = "NCHW";
     conv_attrs->kernel_layout = "OIHW";
@@ -717,7 +719,7 @@ Expr RewriteMatmul(BlockBuilder builder, const Var& var, const Call& src_call,
     return Call(reshape_op, {conv2d, ShapeExpr(out_shape)}, Attrs(), call->sinfo_args, call->span);
   }
   if (shape_a.size() > shape_b.size()) {
-    Array exp_shape(shape_a.size(), Integer(1));
+    ffi::Array exp_shape(shape_a.size(), Integer(1));
     size_t diff = shape_a.size() - shape_b.size();
     for (size_t i = diff; i < shape_a.size(); i++) {
       exp_shape.Set(i, shape_b[i - diff]);
@@ -728,7 +730,7 @@ Expr RewriteMatmul(BlockBuilder builder, const Var& var, const Call& src_call,
     return Call(call->op, {call->args[0], expand_b}, call->attrs, call->sinfo_args, call->span);
   }
   if (shape_a.size() < shape_b.size()) {
-    Array exp_shape(shape_b.size(), Integer(1));
+    ffi::Array exp_shape(shape_b.size(), Integer(1));
     size_t diff = shape_b.size() - shape_a.size();
     for (size_t i = diff; i < shape_b.size(); i++) {
       exp_shape.Set(i, shape_a[i - diff]);
@@ -742,7 +744,7 @@ Expr RewriteMatmul(BlockBuilder builder, const Var& var, const Call& src_call,
 }
 
 Expr RewriteRsqrt(BlockBuilder builder, const Var& var, const Call& src_call,
-                  const Map& new_calls, const String& config) {
+                  const ffi::Map& new_calls, const ffi::String& config) {
   const auto& call = new_calls.count(src_call) ? new_calls[src_call] : src_call;
   const auto& input_shape = ExprUtils::GetShape(call->args[0]);
   const auto& in_dtype = ExprUtils::GetDataType(call->args[0]);
@@ -761,7 +763,7 @@ Expr RewriteRsqrt(BlockBuilder builder, const Var& var, const Call& src_call,
 }
 
 Expr RewriteSilu(BlockBuilder builder, const Var& var, const Call& src_call,
-                 const Map& new_calls, const String& config) {
+                 const ffi::Map& new_calls, const ffi::String& config) {
   const auto& call = new_calls.count(src_call) ? new_calls[src_call] : src_call;
   // create ops
   static const Op& multiply_op = Op::Get("relax.multiply");
@@ -773,7 +775,7 @@ Expr RewriteSilu(BlockBuilder builder, const Var& var, const Call& src_call,
 }
 
 Expr RewriteShapeLike(BlockBuilder builder, const Var& var, const Call& src_call,
-                      const Map& new_calls, const String& config) {
+                      const ffi::Map& new_calls, const ffi::String& config) {
   const auto& call = new_calls.count(src_call) ? new_calls[src_call] : src_call;
   const auto& output_shape = ExprUtils::GetShape(var);
   static const Op& reshape_op = Op::Get("relax.reshape");
@@ -782,7 +784,7 @@ Expr RewriteShapeLike(BlockBuilder builder, const Var& var, const Call& src_call
 }
 
 Expr RewriteSplit(BlockBuilder builder, const Var& var, const Call& src_call,
-                  const Map& new_calls, const String& config) {
+                  const ffi::Map& new_calls, const ffi::String& config) {
   const auto& call = new_calls.count(src_call) ? new_calls[src_call] : src_call;
   const auto& input_shape = ExprUtils::GetShape(call->args[0]);
   const auto* src_attrs = src_call->attrs.as();
@@ -797,7 +799,7 @@ Expr RewriteSplit(BlockBuilder builder, const Var& var, const Call& src_call,
       split_ends.push_back(i * size + size);
     }
   } else if (src_attrs->indices_or_sections->IsInstance()) {
-    const auto& indices = Downcast>(src_attrs->indices_or_sections);
+    const auto& indices = Downcast>(src_attrs->indices_or_sections);
     int64_t last_index = 0;
     for (size_t i = 0; i < indices.size(); ++i) {
       split_begins.push_back(last_index);
@@ -811,14 +813,15 @@ Expr RewriteSplit(BlockBuilder builder, const Var& var, const Call& src_call,
         << src_attrs->indices_or_sections->GetTypeKey() << ")";
   }
   // create strided_slices
-  Array outputs;
+  ffi::Array outputs;
   for (size_t i = 0; i < split_begins.size(); i++) {
     static const Op& strided_slice_op = Op::Get("relax.strided_slice");
-    const auto& axes = Tuple(Array{PrimValue(IntImm(DataType::Int(64), axis))});
-    const auto& begin = Tuple(Array{PrimValue(IntImm(DataType::Int(64), split_begins[i]))});
-    const auto& end = Tuple(Array{PrimValue(IntImm(DataType::Int(64), split_ends[i]))});
-    const auto& strides = Tuple(Array{PrimValue(IntImm(DataType::Int(64), 1))});
-    auto attrs = make_object();
+    const auto& axes = Tuple(ffi::Array{PrimValue(IntImm(DataType::Int(64), axis))});
+    const auto& begin =
+        Tuple(ffi::Array{PrimValue(IntImm(DataType::Int(64), split_begins[i]))});
+    const auto& end = Tuple(ffi::Array{PrimValue(IntImm(DataType::Int(64), split_ends[i]))});
+    const auto& strides = Tuple(ffi::Array{PrimValue(IntImm(DataType::Int(64), 1))});
+    auto attrs = ffi::make_object();
     attrs->assume_inbound = true;
     const auto& slice = RewriteUtils::MakeCall(
         builder, ExprUtils::GetSpanName(call, "slice_" + std::to_string(i)), strided_slice_op,
@@ -872,17 +875,17 @@ TVM_REGISTER_OP("relax.split").set_attr("FRewriteTensorRT", Re
 
 class TensorRTTransformer : public ExprMutator {
  public:
-  explicit TensorRTTransformer(IRModule ctx_module, const String& config)
+  explicit TensorRTTransformer(IRModule ctx_module, const ffi::String& config)
       : ExprMutator(ctx_module) {
     config_ = config;
   }
 
   void VisitBinding_(const VarBindingNode* binding, const CallNode* call_node) final {
     if (const auto* op_node = call_node->op.as()) {
-      const auto& op = Downcast(GetRef(op_node));
+      const auto& op = Downcast(ffi::GetRef(op_node));
       const auto& rewrite_map = Op::GetAttrMap("FRewriteTensorRT");
       if (rewrite_map.count(op)) {
-        const auto& call = GetRef(call_node);
+        const auto& call = ffi::GetRef(call_node);
         FRewriteTensorRT f = rewrite_map[op];
         const auto& new_call = f(builder_, binding->var, call, new_calls_, config_);
         if (new_call != call) {
@@ -897,17 +900,18 @@ class TensorRTTransformer : public ExprMutator {
   }
 
  private:
-  Map new_calls_;
-  String config_;
+  ffi::Map new_calls_;
+  ffi::String config_;
 };
 
-Function TransformTensorRT(const Function& func, const IRModule& module, const String& config) {
+Function TransformTensorRT(const Function& func, const IRModule& module,
+                           const ffi::String& config) {
   return Downcast(TensorRTTransformer(module, config).VisitExpr(func));
 }
 
 namespace transform {
 
-Pass TransformTensorRT(const String& config) {
+Pass TransformTensorRT(const ffi::String& config) {
   auto pass_func = [=](Function f, IRModule m, PassContext pc) {
     return relax::TransformTensorRT(f, m, config);
   };
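Note: every rewrite in this file follows the same `FRewriteTensorRT` convention: fetch the possibly-already-rewritten call from `new_calls`, emit replacement expressions through the `BlockBuilder`, and attach the function to the op. A minimal sketch of a rule in that style (the op name, the `Expr`/`Call` map element types, and the `set_attr` template argument are assumptions here, since the generic arguments are elided throughout this patch):

    Expr RewriteNoop(BlockBuilder builder, const Var& var, const Call& src_call,
                     const ffi::Map<Expr, Call>& new_calls, const ffi::String& config) {
      // Reuse the updated call if an earlier rule already rewrote this one.
      const auto& call = new_calls.count(src_call) ? new_calls[src_call] : src_call;
      return call;  // identity rewrite: leave the call unchanged
    }

    TVM_REGISTER_OP("relax.nn.relu")
        .set_attr<FRewriteTensorRT>("FRewriteTensorRT", RewriteNoop);
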
diff --git a/src/contrib/msc/framework/torch/codegen.cc b/src/contrib/msc/framework/torch/codegen.cc
index 68c55bb9cbce..b1ab14b9fd06 100644
--- a/src/contrib/msc/framework/torch/codegen.cc
+++ b/src/contrib/msc/framework/torch/codegen.cc
@@ -92,7 +92,7 @@ void TorchCodeGen::CodeGenGraph() {
     }
     CodeGenNode(node, config()->use_tools);
   }
-  Array idx_outputs;
+  ffi::Array idx_outputs;
   for (const auto& o : graph()->GetOutputs()) {
     const auto& pair = graph()->FindProducerAndIdx(o);
     idx_outputs.push_back(IdxOutputBase(pair.first, pair.second, true));
@@ -140,7 +140,7 @@ void TorchCodeGen::CodeGenInference() {
   }
 }
 
-const Array TorchCodeGen::GetOpCodes(const MSCJoint& node) {
+const ffi::Array TorchCodeGen::GetOpCodes(const MSCJoint& node) {
   const auto& ops_map = GetTorchOpCodes();
   auto it = ops_map->find(GetOpType(node));
   ICHECK(it != ops_map->end()) << "Unsupported torch op(" << node->optype << "): " << node;
@@ -156,8 +156,8 @@ const Array TorchCodeGen::GetOpCodes(const MSCJoint& node) {
 TVM_FFI_STATIC_INIT_BLOCK({
   namespace refl = tvm::ffi::reflection;
   refl::GlobalDef().def("msc.framework.torch.GetTorchSources",
-                        [](const MSCGraph& graph, const String& codegen_config,
-                           const String& print_config) -> Map {
+                        [](const MSCGraph& graph, const ffi::String& codegen_config,
+                           const ffi::String& print_config) -> ffi::Map {
                           TorchCodeGen codegen = TorchCodeGen(graph, codegen_config);
                           codegen.Init();
                           return codegen.GetSources(print_config);
diff --git a/src/contrib/msc/framework/torch/codegen.h b/src/contrib/msc/framework/torch/codegen.h
index 0ee860bb55c8..1e5032309cb6 100644
--- a/src/contrib/msc/framework/torch/codegen.h
+++ b/src/contrib/msc/framework/torch/codegen.h
@@ -56,10 +56,10 @@ class TorchCodeGen : public PyCodeGen {
   void CodeGenInference() final;
 
   /*! \brief Get the docs for the op*/
-  const Array GetOpCodes(const MSCJoint& node) final;
+  const ffi::Array GetOpCodes(const MSCJoint& node) final;
 
   /*! \brief Get tensor type of the framework*/
-  const String TensorType() const final { return "torch.Tensor"; }
+  const ffi::String TensorType() const final { return "torch.Tensor"; }
 
  private:
   bool is_init_;
diff --git a/src/contrib/msc/framework/torch/codegen_utils.h b/src/contrib/msc/framework/torch/codegen_utils.h
index c63de27519e0..13dee306e942 100644
--- a/src/contrib/msc/framework/torch/codegen_utils.h
+++ b/src/contrib/msc/framework/torch/codegen_utils.h
@@ -39,8 +39,8 @@ namespace msc {
 class TorchCodeGenHelper : public BaseCodeGenHelper {
  public:
   /*! \brief Get describe for default node input*/
-  const String IdxOutputBase(const MSCJoint& node, const String& prefix = "", int idx = 0,
-                             const String& suffix = "", bool mark_exit = false) final {
+  const ffi::String IdxOutputBase(const MSCJoint& node, const ffi::String& prefix = "", int idx = 0,
+                                  const ffi::String& suffix = "", bool mark_exit = false) final {
     if ((node->optype == "max" || node->optype == "min") && node->OutputAt(0)->Ndim() > 0) {
       ICHECK(idx == 0) << "max and min op only support 1 outputs, get " << node;
       return IdxNodeBase(node, prefix, suffix) + ".values";
diff --git a/src/contrib/msc/framework/torch/torch_opcode.cc b/src/contrib/msc/framework/torch/torch_opcode.cc
index 9e3652f04118..8f649469855e 100644
--- a/src/contrib/msc/framework/torch/torch_opcode.cc
+++ b/src/contrib/msc/framework/torch/torch_opcode.cc
@@ -30,7 +30,7 @@ namespace tvm {
 namespace contrib {
 namespace msc {
 
-const Array TorchOpCode::GetDocs() {
+const ffi::Array TorchOpCode::GetDocs() {
   stack_.Config(this);
   if (is_init()) {
     CodeGenInit();
@@ -50,7 +50,7 @@ void TorchOpCode::CodeGenInit() {
 
 void TorchOpCode::CodeGenForward() { stack_.op_call().op_inputs_arg(false); }
 
-const StrictListDoc TorchOpCode::GetPadding(const String& key) {
+const StrictListDoc TorchOpCode::GetPadding(const ffi::String& key) {
   std::vector padding, src_padding;
   ICHECK(node()->GetAttr(key, &src_padding));
   if (node()->optype == "nn.conv1d" || node()->optype == "msc.conv1d_bias") {
@@ -76,9 +76,9 @@ const StrictListDoc TorchOpCode::GetPadding(const String& key) {
   return DocUtils::ToList(padding);
 }
 
-#define TORCH_OP_CODEGEN_METHODS(TypeName) \
- public:                                   \
-  TypeName(const String& module_name, const String& func_name) \
+#define TORCH_OP_CODEGEN_METHODS(TypeName) \
+ public:                                   \
+  TypeName(const ffi::String& module_name, const ffi::String& func_name) \
   : TorchOpCode(module_name, func_name) {}
 
 class TorchAdaptivePoolCodeGen : public TorchOpCode {
@@ -118,7 +118,7 @@ class TorchAxesCodeGen : public TorchOpCode {
 protected:
  void CodeGenInit() final {
    if (module_name().size() > 0) {
-      const String& key = node()->HasAttr("axes") ? "axes" : "axis";
+      const ffi::String& key = node()->HasAttr("axes") ? "axes" : "axis";
      stack_.op_call().op_list_arg(key, "");
    } else {
      TorchOpCode::CodeGenInit();
@@ -129,7 +129,7 @@ class TorchAxesCodeGen : public TorchOpCode {
     if (module_name().size() > 0) {
       TorchOpCode::CodeGenForward();
     } else {
"axes" : "axis"; stack_.op_call().op_input_arg().op_list_arg(key, ""); } } @@ -268,7 +268,7 @@ class TorchConstantCodeGen : public TorchOpCode { class TorchConvCodeGen : public TorchOpCode { public: - TorchConvCodeGen(const String& module_name, const String& func_name, bool use_bias) + TorchConvCodeGen(const ffi::String& module_name, const ffi::String& func_name, bool use_bias) : TorchOpCode(module_name, func_name), use_bias_(use_bias) {} protected: @@ -343,9 +343,9 @@ class TorchExpandDimsCodeGen : public TorchOpCode { protected: void CodeGenForward() final { const auto& axes = node()->GetTypeArrayAttr("axis"); - String idx_input = IdxInput(); + ffi::String idx_input = IdxInput(); for (size_t i = 0; i < axes.size(); i++) { - String idx_out = IdxNode(); + ffi::String idx_out = IdxNode(); if (i < axes.size() - 1) { idx_out = idx_out + "_" + std::to_string(i); } @@ -400,7 +400,7 @@ class TorchLayerNormCodeGen : public TorchOpCode { << "Only support center and scale batchnorm, get " << node(); const auto& axes = CommonUtils::GetIndices(node()->GetTypeArrayAttr("axes"), node()->InputAt(0)->Ndim()); - Array normalized_shape; + ffi::Array normalized_shape; for (const auto& a : axes) { normalized_shape.push_back(node()->InputAt(0)->DimAt(a)); } @@ -412,7 +412,7 @@ class TorchLayerNormCodeGen : public TorchOpCode { class TorchLinearCodeGen : public TorchOpCode { public: - TorchLinearCodeGen(const String& module_name, const String& func_name, bool use_bias) + TorchLinearCodeGen(const ffi::String& module_name, const ffi::String& func_name, bool use_bias) : TorchOpCode(module_name, func_name), use_bias_(use_bias) {} protected: @@ -546,7 +546,7 @@ class TorchReshapeCodeGen : public TorchOpCode { protected: void CodeGenForward() final { - Array shape = node()->OutputAt(0)->shape; + ffi::Array shape = node()->OutputAt(0)->shape; const auto& out_layout = node()->OutputAt(0)->layout; if (out_layout.defined()) { int32_t batch_dim = out_layout.IndexOf(tvm::tir::LayoutAxis::Get("N")); @@ -564,7 +564,7 @@ class TorchResize2dCodeGen : public TorchOpCode { protected: void CodeGenForward() final { const auto& method = node()->GetTypeAttr("method"); - String v_method; + ffi::String v_method; if (method == "nearest_neighbor") { v_method = "nearest"; } else { @@ -657,7 +657,7 @@ class TorchStridedSliceCodeGen : public TorchOpCode { for (size_t i = 0; i < axes.size(); i++) { axes_map[axes[i]] = i; } - Array slice; + ffi::Array slice; for (size_t i = 0; i < node()->InputAt(0)->Ndim(); i++) { if (axes_map.count(i)) { size_t idx = axes_map[i]; @@ -712,8 +712,10 @@ class TorchPluginOpCodeGen : public TorchOpCode { void CodeGenForward() final { stack_.op_call().op_inputs_arg(false); } }; -const std::shared_ptr>> GetTorchOpCodes() { - static auto map = std::make_shared>>(); +const std::shared_ptr>> +GetTorchOpCodes() { + static auto map = + std::make_shared>>(); if (!map->empty()) return map; // simple ops diff --git a/src/contrib/msc/framework/torch/torch_opcode.h b/src/contrib/msc/framework/torch/torch_opcode.h index 80b7f5c60d1d..e732e502ce31 100644 --- a/src/contrib/msc/framework/torch/torch_opcode.h +++ b/src/contrib/msc/framework/torch/torch_opcode.h @@ -49,31 +49,31 @@ class TorchOpCode : public BaseOpCode { * \param func_name the function name for the node. * \param config the config json for the node. 
diff --git a/src/contrib/msc/framework/torch/torch_opcode.h b/src/contrib/msc/framework/torch/torch_opcode.h
index 80b7f5c60d1d..e732e502ce31 100644
--- a/src/contrib/msc/framework/torch/torch_opcode.h
+++ b/src/contrib/msc/framework/torch/torch_opcode.h
@@ -49,31 +49,31 @@ class TorchOpCode : public BaseOpCode {
    * \param func_name the function name for the node.
    * \param config the config json for the node.
    */
-  explicit TorchOpCode(const String& module_name, const String& func_name)
+  explicit TorchOpCode(const ffi::String& module_name, const ffi::String& func_name)
       : BaseOpCode(func_name) {
     module_name_ = module_name;
   }
 
   /*! \brief Config the TorchOpCode*/
   void Config(const MSCJoint& node, const std::shared_ptr config, bool is_init,
-              const Map& prims) {
+              const ffi::Map& prims) {
     BaseOpCode::Config(node, config, prims);
     is_init_ = is_init;
     module_ref_ = "self." + StringUtils::Replace(node->name, ".", "_");
   }
 
   /*! \brief Get return describe for default node*/
-  const String IdxNode() final {
+  const ffi::String IdxNode() final {
     return is_init_ ? module_ref_ : BaseOpCode::IdxNode();
   };
 
   /*! \brief Get dtype string*/
-  const String DType(const DataType& dtype) final {
+  const ffi::String DType(const DataType& dtype) final {
     return "torch." + BaseOpCode::DType(dtype);
   }
 
   /*! \brief Get func_name for the default node*/
-  const String callee_name() final {
+  const ffi::String callee_name() final {
     if (is_init_) {
       return module_name_;
     }
@@ -84,7 +84,7 @@ class TorchOpCode : public BaseOpCode {
   }
 
   /*! \brief Convert node to docs*/
-  const Array GetDocs() final;
+  const ffi::Array GetDocs() final;
 
 protected:
  TorchOpCodeStack stack_;
@@ -96,28 +96,29 @@ class TorchOpCode : public BaseOpCode {
   virtual void CodeGenForward();
 
   /*! \brief Get the padding from op*/
-  const StrictListDoc GetPadding(const String& key = "padding");
+  const StrictListDoc GetPadding(const ffi::String& key = "padding");
 
   /*! \brief Get the is_init_ of codegen*/
   bool is_init() { return is_init_; }
 
   /*! \brief Get the module_name of codegen*/
-  const String module_name() { return module_name_; }
+  const ffi::String module_name() { return module_name_; }
 
   /*! \brief Get the module_ref of codegen*/
-  const String module_ref() { return module_ref_; }
+  const ffi::String module_ref() { return module_ref_; }
 
 private:
  bool is_init_;
-  String module_name_;
-  String module_ref_;
+  ffi::String module_name_;
+  ffi::String module_ref_;
 };
 
 /*!
  * \brief Get the map of available TorchOpCode, use optype as key
 * \return Map of
 */
-const std::shared_ptr>> GetTorchOpCodes();
+const std::shared_ptr>>
+GetTorchOpCodes();
 
 }  // namespace msc
 }  // namespace contrib
diff --git a/src/contrib/msc/framework/tvm/codegen.cc b/src/contrib/msc/framework/tvm/codegen.cc
index 7c42ba8d142a..2a9ed4c8f703 100644
--- a/src/contrib/msc/framework/tvm/codegen.cc
+++ b/src/contrib/msc/framework/tvm/codegen.cc
@@ -35,7 +35,7 @@ void RelaxCodeGen::CodeGenHeader() {
 
 void RelaxCodeGen::CodeGenGraph() {
   stack_.func_def(graph()->name, "tvm.IRModule");
-  Array idx_inputs;
+  ffi::Array idx_inputs;
   for (const auto& i : graph()->GetInputs()) {
     const auto& pair = graph()->FindProducerAndIdx(i);
     const auto& idx_input = IdxOutputBase(pair.first, pair.second);
@@ -89,13 +89,13 @@ void RelaxCodeGen::CodeGenGraph() {
   }
   // mark outputs
   stack_.comment("Emit the outputs");
-  Array idx_exits;
+  ffi::Array idx_exits;
   for (const auto& e : graph()->GetExits()) {
"_exit" : ""); if (config()->use_tools) { if (e->outputs.size() > 1) { - Array tuple_outputs; + ffi::Array tuple_outputs; for (size_t o_idx = 0; o_idx < e->outputs.size(); o_idx++) { const auto& t_output = IdxOutputBase(e, o_idx, true); tuple_outputs.push_back(t_output); @@ -151,7 +151,7 @@ void RelaxCodeGen::CodeGenInference() { const auto& producer = graph()->FindProducer(i); stack_.call_arg(IdxNodeBase(producer)); } - String target, device; + ffi::String target, device; if (config()->test_device == "cpu") { target = "llvm"; device = "tvm.cpu()"; @@ -189,7 +189,7 @@ void RelaxCodeGen::CodeGenInference() { } } -const String RelaxCodeGen::DescribePrim(const MSCPrim& prim) { +const ffi::String RelaxCodeGen::DescribePrim(const MSCPrim& prim) { if (prim->optype == "shape") { const auto& producer = graph()->FindNode(prim->GetTypeAttr("producer")); int out_idx = prim->GetTypeAttr("out_idx"); @@ -199,7 +199,7 @@ const String RelaxCodeGen::DescribePrim(const MSCPrim& prim) { return PyCodeGen::DescribePrim(prim); } -const Array RelaxCodeGen::GetOpCodes(const MSCJoint& node) { +const ffi::Array RelaxCodeGen::GetOpCodes(const MSCJoint& node) { const auto& ops_map = GetRelaxOpCodes(); auto it = ops_map->find(GetOpType(node)); ICHECK(it != ops_map->end()) << "Unsupported relax op(" << node->optype << "): " << node; @@ -215,8 +215,8 @@ const Array RelaxCodeGen::GetOpCodes(const MSCJoint& node) { TVM_FFI_STATIC_INIT_BLOCK({ namespace refl = tvm::ffi::reflection; refl::GlobalDef().def("msc.framework.tvm.GetRelaxSources", - [](const MSCGraph& graph, const String& codegen_config, - const String& print_config) -> Map { + [](const MSCGraph& graph, const ffi::String& codegen_config, + const ffi::String& print_config) -> ffi::Map { RelaxCodeGen codegen = RelaxCodeGen(graph, codegen_config); codegen.Init(); return codegen.GetSources(print_config); diff --git a/src/contrib/msc/framework/tvm/codegen.h b/src/contrib/msc/framework/tvm/codegen.h index 249105b5a50b..0874e21acd4d 100644 --- a/src/contrib/msc/framework/tvm/codegen.h +++ b/src/contrib/msc/framework/tvm/codegen.h @@ -56,13 +56,13 @@ class RelaxCodeGen : public PyCodeGen { void CodeGenInference() final; /*! \brief Describe the prim*/ - const String DescribePrim(const MSCPrim& prim) final; + const ffi::String DescribePrim(const MSCPrim& prim) final; /*! \brief Get the docs for the op*/ - const Array GetOpCodes(const MSCJoint& node) final; + const ffi::Array GetOpCodes(const MSCJoint& node) final; /*! 
diff --git a/src/contrib/msc/framework/tvm/codegen.h b/src/contrib/msc/framework/tvm/codegen.h
index 249105b5a50b..0874e21acd4d 100644
--- a/src/contrib/msc/framework/tvm/codegen.h
+++ b/src/contrib/msc/framework/tvm/codegen.h
@@ -56,13 +56,13 @@ class RelaxCodeGen : public PyCodeGen {
   void CodeGenInference() final;
 
   /*! \brief Describe the prim*/
-  const String DescribePrim(const MSCPrim& prim) final;
+  const ffi::String DescribePrim(const MSCPrim& prim) final;
 
   /*! \brief Get the docs for the op*/
-  const Array GetOpCodes(const MSCJoint& node) final;
+  const ffi::Array GetOpCodes(const MSCJoint& node) final;
 
   /*! \brief Get tensor type of the framework*/
-  const String TensorType() const final { return "relax.Expr"; }
+  const ffi::String TensorType() const final { return "relax.Expr"; }
 };
 
 }  // namespace msc
diff --git a/src/contrib/msc/framework/tvm/relax_opcode.cc b/src/contrib/msc/framework/tvm/relax_opcode.cc
index a4be884858dc..54d55721ac4a 100644
--- a/src/contrib/msc/framework/tvm/relax_opcode.cc
+++ b/src/contrib/msc/framework/tvm/relax_opcode.cc
@@ -29,7 +29,7 @@ namespace tvm {
 namespace contrib {
 namespace msc {
 
-const Array RelaxOpCode::GetDocs() {
+const ffi::Array RelaxOpCode::GetDocs() {
   stack_.Config(this);
   CodeGenBuild();
   bool emit_var = true;
@@ -43,14 +43,14 @@ const Array RelaxOpCode::GetDocs() {
   return stack_.GetDocs();
 }
 
-void RelaxOpCode::BuilderEmit(const String& ret, const String& name) {
+void RelaxOpCode::BuilderEmit(const ffi::String& ret, const ffi::String& name) {
   stack_.func_call("block_builder.emit", ret).call_arg(ret);
   if (name.size() > 0) {
     stack_.call_arg(DocUtils::ToStr(name), "name_hint");
   }
 }
 
-const ExprDoc RelaxOpCode::GetOutDtype(const String& key, int input_idx) {
+const ExprDoc RelaxOpCode::GetOutDtype(const ffi::String& key, int input_idx) {
   if (config()->use_tools && input_idx >= 0 &&
       node()->inputs.size() > static_cast(input_idx)) {
     return DocUtils::ToDoc(IdxInput(input_idx) + ".struct_info.dtype");
@@ -62,7 +62,7 @@ const ExprDoc RelaxOpCode::GetOutDtype(const String& key, int input_idx) {
   return DocUtils::ToStr(out_dtype);
 }
 
-const std::vector RelaxOpCode::GetAxes(const String& key) {
+const std::vector RelaxOpCode::GetAxes(const ffi::String& key) {
   std::vector axes;
   int axis;
   if (!node()->GetAttr(key, &axes) && node()->GetAttr(key, &axis)) {
@@ -73,7 +73,7 @@ const std::vector RelaxOpCode::GetAxes(const String& key) {
 
 #define RELAX_OP_CODEGEN_METHODS(TypeName) \
  public:                                   \
-  TypeName(const String& func_name) : RelaxOpCode(func_name) {}
+  TypeName(const ffi::String& func_name) : RelaxOpCode(func_name) {}
 
 class RelaxAdaptivePool2dCodeGen : public RelaxOpCode {
   RELAX_OP_CODEGEN_METHODS(RelaxAdaptivePool2dCodeGen)
@@ -101,7 +101,7 @@ class RelaxAttentionCodeGen : public RelaxOpCode {
 protected:
  void CodeGenBuild() final {
    for (size_t i = 0; i < 3; i++) {
-      const String& axes_key = i == 0 ? "axes" : "axes_" + std::to_string(i);
+      const ffi::String& axes_key = i == 0 ? "axes" : "axes_" + std::to_string(i);
      stack_.op_call("relax.op.permute_dims", IdxInput(i))
          .op_input_arg(i)
          .op_list_arg(axes_key, "axes");
@@ -129,7 +129,7 @@ class RelaxAxesCodeGen : public RelaxOpCode {
 
  protected:
   void CodeGenBuild() final {
"axes" : "axis"; stack_.op_call().op_input_arg().call_arg(DocUtils::ToList(GetAxes(key)), key); } }; @@ -210,7 +210,7 @@ class RelaxBiasAddCodeGen : public RelaxOpCode { protected: void CodeGenBuild() final { int axis = CommonUtils::GetIndex(node()->GetTypeAttr("axis"), node()->OutputAt(0)->Ndim()); - Array expand_shape; + ffi::Array expand_shape; for (size_t i = 0; i < node()->InputAt(0)->Ndim(); i++) { if (i == static_cast(axis)) { expand_shape.push_back(node()->InputAt(0)->DimAt(i)); @@ -263,7 +263,7 @@ class RelaxConstantCodeGen : public RelaxOpCode { class RelaxConvCodeGen : public RelaxOpCode { public: - RelaxConvCodeGen(const String& func_name, bool use_bias) + RelaxConvCodeGen(const ffi::String& func_name, bool use_bias) : RelaxOpCode(func_name), use_bias_(use_bias) {} protected: @@ -286,7 +286,7 @@ class RelaxConvCodeGen : public RelaxOpCode { << "out_layout or data_layout should be given, get " << node(); } const auto& out_layout = tir::Layout(out_layout_str); - Array expand_shape; + ffi::Array expand_shape; for (size_t i = 0; i < node()->OutputAt(0)->Ndim(); i++) { if (out_layout[i].name() == "C") { expand_shape.push_back(node()->OutputAt(0)->DimAt(i)); @@ -335,7 +335,7 @@ class RelaxEinsumCodeGen : public RelaxOpCode { protected: void CodeGenBuild() final { - const String& key = config()->from_relay ? "equation" : "subscripts"; + const ffi::String& key = config()->from_relay ? "equation" : "subscripts"; stack_.op_call().op_inputs_arg().op_str_arg(key, "subscripts"); } }; @@ -480,12 +480,12 @@ class RelaxPadCodeGen : public RelaxOpCode { protected: void CodeGenBuild() final { - Array pad_width; + ffi::Array pad_width; const auto& attr_pad_width = node()->GetTypeArrayAttr("pad_width"); ICHECK(attr_pad_width.size() % 2 == 0) << "pad_width should be multiple of 2, get " << node(); for (size_t i = 0; i < attr_pad_width.size(); i += 2) { - const String& cur_pad = "[" + std::to_string(attr_pad_width[i]) + ", " + - std::to_string(attr_pad_width[i + 1]) + "]"; + const ffi::String& cur_pad = "[" + std::to_string(attr_pad_width[i]) + ", " + + std::to_string(attr_pad_width[i + 1]) + "]"; pad_width.push_back(cur_pad); } stack_.op_call() @@ -530,7 +530,7 @@ class RelaxPermuteDimsCodeGen : public RelaxOpCode { class RelaxReduceAxisCodeGen : public RelaxOpCode { public: - RelaxReduceAxisCodeGen(const String& func_name, bool as_list) + RelaxReduceAxisCodeGen(const ffi::String& func_name, bool as_list) : RelaxOpCode(func_name), as_list_(as_list) {} protected: @@ -602,7 +602,7 @@ class RelaxResize2dCodeGen : public RelaxOpCode { protected: void CodeGenBuild() final { // roi has forced to be float list - Array roi_list; + ffi::Array roi_list; std::vector roi = node()->GetTypeArrayAttr("roi"); for (const auto& r : roi) { roi_list.push_back("float(" + std::to_string(r) + ")"); @@ -680,7 +680,7 @@ class RelaxTileCodeGen : public RelaxOpCode { protected: void CodeGenBuild() final { - const String& key = config()->from_relay ? "reps" : "repeats"; + const ffi::String& key = config()->from_relay ? "reps" : "repeats"; stack_.op_call().op_input_arg().op_list_arg(key, "repeats"); } }; @@ -698,7 +698,7 @@ class RelaxTriCodeGen : public RelaxOpCode { protected: void CodeGenBuild() final { if (node()->optype == "trilu") { - const String& func_name = + const ffi::String& func_name = node()->GetTypeAttr("upper") ? 
"relax.op.triu" : "relax.op.tril"; stack_.op_call(func_name).op_input_arg().op_arg("k"); } else { @@ -720,8 +720,10 @@ class RelaxPluginOpCodeGen : public RelaxOpCode { } }; -const std::shared_ptr>> GetRelaxOpCodes() { - static auto map = std::make_shared>>(); +const std::shared_ptr>> +GetRelaxOpCodes() { + static auto map = + std::make_shared>>(); if (!map->empty()) return map; // binary && unary ops map->emplace("abs", std::make_shared("relax.op.abs")); diff --git a/src/contrib/msc/framework/tvm/relax_opcode.h b/src/contrib/msc/framework/tvm/relax_opcode.h index e5914149184e..bbbee44d822d 100644 --- a/src/contrib/msc/framework/tvm/relax_opcode.h +++ b/src/contrib/msc/framework/tvm/relax_opcode.h @@ -49,11 +49,11 @@ class RelaxOpCode : public BaseOpCode { * \param func_name the function name for the node. * \param config the config json for the node. */ - explicit RelaxOpCode(const String& func_name) + explicit RelaxOpCode(const ffi::String& func_name) : BaseOpCode(func_name) {} /*! \brief Convert node to docs*/ - const Array GetDocs() final; + const ffi::Array GetDocs() final; protected: RelaxOpCodeStack stack_; @@ -62,20 +62,21 @@ class RelaxOpCode : public BaseOpCode { virtual void CodeGenBuild() = 0; /*! \brief coda stack emit docs*/ - void BuilderEmit(const String& ret, const String& name = ""); + void BuilderEmit(const ffi::String& ret, const ffi::String& name = ""); /*! \brief Get the out_dtype attribute*/ - const ExprDoc GetOutDtype(const String& key = "out_dtype", int input_idx = 0); + const ExprDoc GetOutDtype(const ffi::String& key = "out_dtype", int input_idx = 0); /*! \brief Get the axes attribute*/ - const std::vector GetAxes(const String& key = "axes"); + const std::vector GetAxes(const ffi::String& key = "axes"); }; /*! * \brief Get the map of available RelaxOpCode, use optype as key * \return Map of */ -const std::shared_ptr>> GetRelaxOpCodes(); +const std::shared_ptr>> +GetRelaxOpCodes(); } // namespace msc } // namespace contrib diff --git a/src/contrib/msc/plugin/base_codegen.h b/src/contrib/msc/plugin/base_codegen.h index cd5f03ff7716..fcb1f3982f79 100644 --- a/src/contrib/msc/plugin/base_codegen.h +++ b/src/contrib/msc/plugin/base_codegen.h @@ -66,13 +66,15 @@ class BasePluginCodeGen { virtual ~BasePluginCodeGen() = default; /*! 
diff --git a/src/contrib/msc/framework/tvm/relax_opcode.h b/src/contrib/msc/framework/tvm/relax_opcode.h
index e5914149184e..bbbee44d822d 100644
--- a/src/contrib/msc/framework/tvm/relax_opcode.h
+++ b/src/contrib/msc/framework/tvm/relax_opcode.h
@@ -49,11 +49,11 @@ class RelaxOpCode : public BaseOpCode {
    * \param func_name the function name for the node.
    * \param config the config json for the node.
    */
-  explicit RelaxOpCode(const String& func_name)
+  explicit RelaxOpCode(const ffi::String& func_name)
       : BaseOpCode(func_name) {}
 
   /*! \brief Convert node to docs*/
-  const Array GetDocs() final;
+  const ffi::Array GetDocs() final;
 
 protected:
  RelaxOpCodeStack stack_;
@@ -62,20 +62,21 @@ class RelaxOpCode : public BaseOpCode {
   virtual void CodeGenBuild() = 0;
 
   /*! \brief code stack emit docs*/
-  void BuilderEmit(const String& ret, const String& name = "");
+  void BuilderEmit(const ffi::String& ret, const ffi::String& name = "");
 
   /*! \brief Get the out_dtype attribute*/
-  const ExprDoc GetOutDtype(const String& key = "out_dtype", int input_idx = 0);
+  const ExprDoc GetOutDtype(const ffi::String& key = "out_dtype", int input_idx = 0);
 
   /*! \brief Get the axes attribute*/
-  const std::vector GetAxes(const String& key = "axes");
+  const std::vector GetAxes(const ffi::String& key = "axes");
 };
 
 /*!
  * \brief Get the map of available RelaxOpCode, use optype as key
 * \return Map of
 */
-const std::shared_ptr>> GetRelaxOpCodes();
+const std::shared_ptr>>
+GetRelaxOpCodes();
 
 }  // namespace msc
 }  // namespace contrib
diff --git a/src/contrib/msc/plugin/base_codegen.h b/src/contrib/msc/plugin/base_codegen.h
index cd5f03ff7716..fcb1f3982f79 100644
--- a/src/contrib/msc/plugin/base_codegen.h
+++ b/src/contrib/msc/plugin/base_codegen.h
@@ -66,13 +66,15 @@ class BasePluginCodeGen {
   virtual ~BasePluginCodeGen() = default;
 
   /*! \brief Get plugin sources*/
-  virtual const Map GetBuildSources(const std::string& print_options = "") {
-    Map sources;
+  virtual const ffi::Map GetBuildSources(
+      const std::string& print_options = "") {
+    ffi::Map sources;
     // plugin sources
     for (const auto& name : ListPluginNames()) {
       const auto& plugin = GetPlugin(name);
       // attr declare
-      const String& attr_macro = "TVM_CONTRIB_MSC_" + StringUtils::Upper(plugin->name) + "_ATTR_H_";
+      const ffi::String& attr_macro =
+          "TVM_CONTRIB_MSC_" + StringUtils::Upper(plugin->name) + "_ATTR_H_";
       this->stack_.line("#ifndef " + attr_macro)
           .line("#define " + attr_macro)
           .line()
@@ -90,7 +92,8 @@ class BasePluginCodeGen {
       EndNamespace();
       sources.Set(plugin->name + "_attr.cc", ToCppSource(print_options));
       // op declare
-      const String& op_macro = "TVM_CONTRIB_MSC_" + StringUtils::Upper(plugin->name) + "_OP_H_";
+      const ffi::String& op_macro =
+          "TVM_CONTRIB_MSC_" + StringUtils::Upper(plugin->name) + "_OP_H_";
       this->stack_.line("#ifndef " + op_macro).line("#define " + op_macro).line();
       CodeGenOpHeader(plugin);
       StartNamespace();
@@ -114,7 +117,7 @@ class BasePluginCodeGen {
       }
     }
     // cmakelists
-    std::set devices;
+    std::set devices;
     for (const auto& name : ListPluginNames()) {
       const auto& plugin = GetPlugin(name);
       for (const auto& pair : plugin->externs) {
@@ -129,8 +132,9 @@ class BasePluginCodeGen {
   }
 
   /*! \brief Get manager sources*/
-  virtual const Map GetManagerSources(const std::string& print_options = "") {
-    Map sources;
+  virtual const ffi::Map GetManagerSources(
+      const std::string& print_options = "") {
+    ffi::Map sources;
     CodeGenManagerDepends();
     this->stack_.class_def("PluginManager(object)").class_start();
     CodeGenManagerMethods();
@@ -138,7 +142,7 @@ class BasePluginCodeGen {
       CodeGenOpBuilder(GetPlugin(name));
     }
     if (this->config()->need_convert) {
-      Map symbols;
+      ffi::Map symbols;
       this->stack_.func_def("get_convert_map")
           .func_decorator("classmethod")
          .func_arg("cls", "object")
@@ -165,7 +169,7 @@ class BasePluginCodeGen {
   /*! \brief Header of plugin files*/
   virtual void CodeGenOpHeader(const Plugin& plugin) {
     this->stack_.line("#include \"" + plugin->name + "_attr.h\"");
-    std::set include_headers;
+    std::set include_headers;
     for (const auto& pair : plugin->externs) {
       if (pair.second->header.size() > 0 && !include_headers.count(pair.second->header)) {
         this->stack_.line("#include \"" + pair.second->header + "\"");
@@ -194,7 +198,8 @@ class BasePluginCodeGen {
 
   /*! \brief Codegen safe call extern*/
   void CodeGenSafeCall(const PluginExtern& extern_func,
-                       const Array& call_args = Array(), const String& ret = "") {
+                       const ffi::Array& call_args = ffi::Array(),
+                       const ffi::String& ret = "") {
     this->stack_.scope_start("try {").func_call(extern_func->name, ret);
     for (const auto& arg : call_args) {
       this->stack_.call_arg(arg);
@@ -244,14 +249,15 @@ class BasePluginCodeGen {
   virtual void CodeGenOpRuntime(const Plugin& plugin) {}
 
   /*! \brief Codegen cmake file*/
-  virtual void CodeGenCmake(const std::set& devices) {
+  virtual void CodeGenCmake(const std::set& devices) {
     CodeGenPreCmake(devices);
     CodeGenPostCmake(devices);
   }
   /*! \brief Codegen cmake start*/
-  void CodeGenPreCmake(const std::set& devices,
-                       const Map& extra_flags = Map()) {
+  void CodeGenPreCmake(const std::set& devices,
+                       const ffi::Map& extra_flags =
+                           ffi::Map()) {
     const auto& p_name = this->config()->project_name;
     stack_.line("cmake_minimum_required(VERSION " + this->config()->cmake_version + " FATAL_ERROR)")
         .line("project(" + p_name + ")");
@@ -277,9 +283,9 @@ class BasePluginCodeGen {
   }
 
   /*! \brief Codegen cmake end*/
-  void CodeGenPostCmake(const std::set& devices,
-                        const Array& extra_includes = Array(),
-                        const Array& extra_libs = Array()) {
+  void CodeGenPostCmake(const std::set& devices,
+                        const ffi::Array& extra_includes = ffi::Array(),
+                        const ffi::Array& extra_libs = ffi::Array()) {
     const auto& p_name = this->config()->project_name;
     stack_.line()
         .line("file(GLOB_RECURSE PLUGIN_HEADERS src/*.h)")
@@ -293,7 +299,7 @@ class BasePluginCodeGen {
       stack_.line("add_library(" + p_name + " SHARED ${PLUGIN_CC_SRCS})");
     }
     // define includes
-    String includes = StringUtils::Join(extra_includes, " ");
+    ffi::String includes = StringUtils::Join(extra_includes, " ");
     if (this->config()->includes.size() > 0) {
       includes = includes + " " + StringUtils::Join(this->config()->includes, " ");
     }
@@ -301,7 +307,7 @@ class BasePluginCodeGen {
       stack_.line("target_include_directories(" + p_name + " PUBLIC " + includes + ")");
     }
     // define libs
-    String link_libs = StringUtils::Join(extra_libs, " ");
+    ffi::String link_libs = StringUtils::Join(extra_libs, " ");
     const auto& libs = StringUtils::Join(this->config()->libs, " ");
     if (libs.size() > 0) {
       link_libs = link_libs + " " + libs;
@@ -496,10 +502,10 @@ class BasePluginCodeGen {
   }
 
   /*! \brief Codegen convert function for plugin*/
-  virtual const String CodeGenOpConvert(const Plugin& plugin) { return plugin->name; }
+  virtual const ffi::String CodeGenOpConvert(const Plugin& plugin) { return plugin->name; }
 
   /*! \brief Change code stack to cpp source*/
-  const String ToCppSource(const std::string& print_options = "") {
+  const ffi::String ToCppSource(const std::string& print_options = "") {
     CppPrinter printer(print_options);
     for (const auto& d : this->stack_.GetDocs()) {
       printer.Append(d);
@@ -509,7 +515,7 @@ class BasePluginCodeGen {
   }
   /*! \brief Change code stack to python source*/
-  const String ToPySource(const std::string& print_options = "") {
+  const ffi::String ToPySource(const std::string& print_options = "") {
     PythonPrinter printer(print_options);
     for (const auto& d : this->stack_.GetDocs()) {
       printer.Append(d);
@@ -518,23 +524,23 @@ class BasePluginCodeGen {
     return printer.GetString();
   }
 
-  std::vector> GetDtypeMatrix(const Plugin& plugin) {
-    std::vector> matrix;
+  std::vector> GetDtypeMatrix(const Plugin& plugin) {
+    std::vector> matrix;
     if (plugin->support_dtypes.size() == 0) {
-      std::unordered_map dtypes;
+      std::unordered_map dtypes;
       for (size_t i = 0; i < plugin->inputs.size(); i++) {
         dtypes[i] = plugin->inputs[i]->dtype;
       }
       matrix.push_back(dtypes);
     } else {
-      Array templates;
-      Array> condidates;
+      ffi::Array templates;
+      ffi::Array> condidates;
       for (const auto& pair : plugin->support_dtypes) {
         templates.push_back(pair.first);
         condidates.push_back(pair.second);
       }
       for (const auto& t_dtypes : ArrayUtils::Product(condidates)) {
-        std::unordered_map dtypes;
+        std::unordered_map dtypes;
         for (size_t i = 0; i < templates.size(); i++) {
           for (size_t in_idx = 0; in_idx < plugin->inputs.size(); in_idx++) {
             if (plugin->inputs[in_idx]->dtype == templates[i]) {
@@ -554,11 +560,11 @@ class BasePluginCodeGen {
     return matrix;
   }
 
-  const Map GetTensorDtypes(const Plugin& plugin,
-                            const std::unordered_map& dtypes) {
-    Map tensor_dtypes;
+  const ffi::Map GetTensorDtypes(
+      const Plugin& plugin, const std::unordered_map& dtypes) {
+    ffi::Map tensor_dtypes;
     for (const auto& pair : dtypes) {
-      const String& ref_dtype = plugin->inputs[pair.first]->dtype;
+      const ffi::String& ref_dtype = plugin->inputs[pair.first]->dtype;
       for (const auto& t : plugin->inputs) {
         if (t->dtype == ref_dtype) {
           tensor_dtypes.Set(t->name, pair.second);
@@ -579,8 +585,8 @@ class BasePluginCodeGen {
   }
 
   /*! \brief Change plugin comment in python*/
-  const String GetPyComment(const Plugin& plugin) {
-    String comment = "Python wrapper for " + plugin->name + "\nInputs\n------";
+  const ffi::String GetPyComment(const Plugin& plugin) {
+    ffi::String comment = "Python wrapper for " + plugin->name + "\nInputs\n------";
     for (const auto& t : plugin->inputs) {
       comment = comment + "\n" + t->name + ": " + t->dtype + "\n " + t->describe;
     }
@@ -598,16 +604,16 @@ class BasePluginCodeGen {
   }
 
   /*! \brief Get class name for meta attrs*/
-  const String MetaAttrCls(const Plugin& plugin) const { return plugin->name + "MetaAttr"; }
+  const ffi::String MetaAttrCls(const Plugin& plugin) const { return plugin->name + "MetaAttr"; }
 
   /*! \brief Get converter name for plugin*/
-  const String ConverterName(const Plugin& plugin) const { return plugin->name + "Converter"; }
+  const ffi::String ConverterName(const Plugin& plugin) const { return plugin->name + "Converter"; }
 
   /*! \brief Check if the type is list type. */
-  bool IsListType(const String& type) { return StringUtils::StartsWith(type, "list"); }
+  bool IsListType(const ffi::String& type) { return StringUtils::StartsWith(type, "list"); }
 
   /*! \brief Get type of element. */
-  const String GetEleType(const String& type) {
+  const ffi::String GetEleType(const ffi::String& type) {
     if (!IsListType(type)) {
       return "";
     }
   /*! \brief Type name in cpp*/
-  virtual const String ToCppType(const String& type) {
+  virtual const ffi::String ToCppType(const ffi::String& type) {
     if (IsListType(type)) {
       const auto& ele_type = GetEleType(type);
       return "std::vector<" + ToCppType(ele_type) + ">";
@@ -636,7 +642,7 @@ class BasePluginCodeGen {
   }
 
   /*! \brief Type name in python*/
-  virtual const String ToPyType(const String& type) {
+  virtual const ffi::String ToPyType(const ffi::String& type) {
     if (IsListType(type)) {
       const auto& ele_type = GetEleType(type);
       return "List[" + ToPyType(ele_type) + "]";
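Note: `ToCppType` and `ToPyType` above peel `list` layers recursively and map the element type into each target language. A rough standalone sketch of that recursion (the `"list<...>"` spelling and the scalar pass-through are assumptions; `IsListType`/`GetEleType` are only partially visible here):

    #include <string>

    std::string ToCppTypeSketch(const std::string& type) {
      const std::string prefix = "list<";
      if (type.rfind(prefix, 0) == 0 && type.back() == '>') {
        // peel one list layer and recurse on the element type
        std::string ele = type.substr(prefix.size(), type.size() - prefix.size() - 1);
        return "std::vector<" + ToCppTypeSketch(ele) + ">";
      }
      return type == "string" ? "std::string" : type;  // scalars pass through
    }

    // ToCppTypeSketch("list<list<int>>") yields "std::vector<std::vector<int>>"
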
"IPluginV2DynamicExt" : "IPluginV2"; if (in_declare) { stack_.comment("common methods for " + op_cls); stack_.constructor_def(op_cls).constructor_arg("name", "const std::string&"); @@ -567,7 +567,7 @@ void TensorRTPluginCodeGen::CodegenOpCommonMethods(const Plugin& plugin, bool dy .line("assert(char_buf == (start_buf + getSerializationSize()));") .func_end(); // getPluginType - const String& plugin_type = plugin->name + (dynamic ? "_dynamic" : ""); + const ffi::String& plugin_type = plugin->name + (dynamic ? "_dynamic" : ""); stack_.func_def(op_cls + "::getPluginType", "const char*") .func_decorator("const noexcept") .func_start() @@ -644,7 +644,7 @@ void TensorRTPluginCodeGen::CodegenOpMembers(const Plugin& plugin, bool dynamic) void TensorRTPluginCodeGen::CodegenCreator(const Plugin& plugin, bool dynamic, bool in_declare) { const auto& creator_cls = CreatorCls(plugin, dynamic); - const String& plugin_cls = dynamic ? "IPluginV2DynamicExt" : "IPluginV2"; + const ffi::String& plugin_cls = dynamic ? "IPluginV2DynamicExt" : "IPluginV2"; if (in_declare) { stack_.class_def(creator_cls + " : public IPluginCreator") .class_start() @@ -679,7 +679,7 @@ void TensorRTPluginCodeGen::CodegenCreator(const Plugin& plugin, bool dynamic, b .line() .class_end(); } else { - const String& attr_name = MetaAttrCls(plugin); + const ffi::String& attr_name = MetaAttrCls(plugin); // static members stack_.comment("static members and register for " + plugin->name) .declare("PluginFieldCollection", creator_cls + "::collection_") @@ -705,7 +705,7 @@ void TensorRTPluginCodeGen::CodegenCreator(const Plugin& plugin, bool dynamic, b .func_call("data", fields_doc, DocUtils::ToDoc("fields_")) .constructor_end(); // getPluginName - const String& plugin_type = plugin->name + (dynamic ? "_dynamic" : ""); + const ffi::String& plugin_type = plugin->name + (dynamic ? 
"_dynamic" : ""); stack_.func_def(creator_cls + "::getPluginName", "const char*") .func_decorator("const noexcept") .func_start() @@ -753,7 +753,7 @@ void TensorRTPluginCodeGen::CodegenCreator(const Plugin& plugin, bool dynamic, b .for_start("i", plugin->attrs.size(), fields_size); for (size_t i = 0; i < plugin->inputs.size(); i++) { const auto& tensor = plugin->inputs[i]; - const String& cond = "strcmp(fields[i].name, \"layout_" + tensor->name + "\") == 0"; + const ffi::String& cond = "strcmp(fields[i].name, \"layout_" + tensor->name + "\") == 0"; if (i == 0) { stack_.switch_start(cond); } else { @@ -794,7 +794,7 @@ void TensorRTPluginCodeGen::CodegenCreator(const Plugin& plugin, bool dynamic, b } void TensorRTPluginCodeGen::CodegenOutputInfer(const Plugin& plugin, bool as_desc) { - Array infer_args{"input_metas_", "meta_attr_", "false"}; + ffi::Array infer_args{"input_metas_", "meta_attr_", "false"}; stack_.line("assert(n_inputs == " + std::to_string(plugin->inputs.size()) + ");") .func_call("resize", "", "input_metas_") .call_arg(plugin->inputs.size()) @@ -810,7 +810,7 @@ void TensorRTPluginCodeGen::CodegenOutputInfer(const Plugin& plugin, bool as_des } void TensorRTPluginCodeGen::CodegenBufferInfer(const Plugin& plugin) { - Array infer_args{"input_metas_", "meta_attr_", "false"}; + ffi::Array infer_args{"input_metas_", "meta_attr_", "false"}; CodeGenSafeCall(plugin->externs["infer_buffer"], infer_args, "buffer_metas_"); stack_.for_start("b", "buffer_metas_") .assign("size", "size + max_batch * b.size(false)") @@ -820,12 +820,12 @@ void TensorRTPluginCodeGen::CodegenBufferInfer(const Plugin& plugin) { void TensorRTPluginCodeGen::CodegenEnqueue(const Plugin& plugin, bool dynamic) { ICHECK(plugin->externs.count("cuda_compute")) << "cuda_compute is needed fo TensorRT plugin"; auto prepare_tensor = [this, &dynamic](const PluginTensor& tensor, - const Map& dtypes, size_t idx, - const String& collect) { - const String& t_name = "d_" + tensor->name; - const String& t_dtype = dtypes.count(tensor->name) ? dtypes[tensor->name] : tensor->dtype; - const String& tensor_type = "DataTensor<" + t_dtype + ">"; - const String& anno = collect == "input" ? "const " + tensor_type + "&" : tensor_type; + const ffi::Map& dtypes, + size_t idx, const ffi::String& collect) { + const ffi::String& t_name = "d_" + tensor->name; + const ffi::String& t_dtype = dtypes.count(tensor->name) ? dtypes[tensor->name] : tensor->dtype; + const ffi::String& tensor_type = "DataTensor<" + t_dtype + ">"; + const ffi::String& anno = collect == "input" ? 
"const " + tensor_type + "&" : tensor_type; stack_.func_call("TRTUtils::To" + tensor_type, DocUtils::ToDeclare(anno, t_name)); const auto& t_meta = DocUtils::ToIndex(collect + "_metas_", idx); if (dynamic) { @@ -844,8 +844,8 @@ void TensorRTPluginCodeGen::CodegenEnqueue(const Plugin& plugin, bool dynamic) { }; for (const auto& dtypes : GetDtypeMatrix(plugin)) { const auto& tensor_dtypes = GetTensorDtypes(plugin, dtypes); - Array compute_args; - String dtype_cond = ""; + ffi::Array compute_args; + ffi::String dtype_cond = ""; if (dynamic) { for (size_t i = 0; i < plugin->inputs.size(); i++) { dtype_cond = dtype_cond + "input_descs[" + std::to_string(i) + @@ -858,19 +858,19 @@ void TensorRTPluginCodeGen::CodegenEnqueue(const Plugin& plugin, bool dynamic) { // prepare compute datas stack_.cond_if(dtype_cond).comment("prepare compute datas"); for (size_t i = 0; i < plugin->inputs.size(); i++) { - const String& t_name = prepare_tensor(plugin->inputs[i], tensor_dtypes, i, "input"); + const ffi::String& t_name = prepare_tensor(plugin->inputs[i], tensor_dtypes, i, "input"); compute_args.push_back(t_name); } for (size_t i = 0; i < plugin->outputs.size(); i++) { - const String& t_name = prepare_tensor(plugin->outputs[i], tensor_dtypes, i, "output"); + const ffi::String& t_name = prepare_tensor(plugin->outputs[i], tensor_dtypes, i, "output"); compute_args.push_back(t_name); } if (plugin->buffers.size() > 0) { stack_.assign("offset", 0, "size_t"); for (size_t i = 0; i < plugin->buffers.size(); i++) { - const String& t_name = prepare_tensor(plugin->outputs[i], tensor_dtypes, i, "buffer"); + const ffi::String& t_name = prepare_tensor(plugin->outputs[i], tensor_dtypes, i, "buffer"); compute_args.push_back(t_name); - const String& size_name = "size_" + plugin->buffers[i]->name; + const ffi::String& size_name = "size_" + plugin->buffers[i]->name; stack_ .func_call("size", DocUtils::ToDeclare("size_t", size_name), DocUtils::ToIndex("buffer_metas_", i)) @@ -888,8 +888,8 @@ void TensorRTPluginCodeGen::CodegenEnqueue(const Plugin& plugin, bool dynamic) { TVM_FFI_STATIC_INIT_BLOCK({ namespace refl = tvm::ffi::reflection; refl::GlobalDef().def("msc.plugin.GetTensorRTPluginSources", - [](const String& codegen_config, const String& print_config, - const String& codegen_type) -> Map { + [](const ffi::String& codegen_config, const ffi::String& print_config, + const ffi::String& codegen_type) -> ffi::Map { TensorRTPluginCodeGen codegen = TensorRTPluginCodeGen(codegen_config); if (codegen_type == "build") { return codegen.GetBuildSources(print_config); @@ -897,7 +897,7 @@ TVM_FFI_STATIC_INIT_BLOCK({ if (codegen_type == "manager") { return codegen.GetManagerSources(print_config); } - return Map(); + return ffi::Map(); }); }); diff --git a/src/contrib/msc/plugin/tensorrt_codegen.h b/src/contrib/msc/plugin/tensorrt_codegen.h index 24fb4e5dfca2..c5b0e585a139 100644 --- a/src/contrib/msc/plugin/tensorrt_codegen.h +++ b/src/contrib/msc/plugin/tensorrt_codegen.h @@ -79,25 +79,25 @@ class TensorRTPluginCodeGen : public BasePluginCodeGen& devices) final; + void CodeGenCmake(const std::set& devices) final; /*! \brief Codegen manager methods*/ void CodeGenManagerMethods() final; private: /*! \brief Op class name of plugin*/ - const String OpCls(const Plugin& plugin, bool dynamic) const { + const ffi::String OpCls(const Plugin& plugin, bool dynamic) const { return plugin->name + (dynamic ? "DynamicPlugin" : "Plugin"); } /*! 
\brief Creator class name of plugin*/ - const String CreatorCls(const Plugin& plugin, bool dynamic) const { + const ffi::String CreatorCls(const Plugin& plugin, bool dynamic) const { return plugin->name + (dynamic ? "DynamicCreator" : "Creator"); } bool IsMixPrecision(const Plugin& plugin) { for (const auto& dtypes : GetDtypeMatrix(plugin)) { - String ref_dtype = ""; + ffi::String ref_dtype = ""; for (const auto& pair : dtypes) { if (ref_dtype.size() == 0) { ref_dtype = pair.second; diff --git a/src/contrib/msc/plugin/torch_codegen.cc b/src/contrib/msc/plugin/torch_codegen.cc index 63d068acab34..79c61d13e965 100644 --- a/src/contrib/msc/plugin/torch_codegen.cc +++ b/src/contrib/msc/plugin/torch_codegen.cc @@ -153,7 +153,7 @@ void TorchPluginCodeGen::CodeGenOpDefine(const Plugin& plugin) { CodeGenMalloc(plugin, plugin->buffers, "buffer"); } // do the compute - String device_cond = ""; + ffi::String device_cond = ""; for (size_t i = 0; i < plugin->inputs.size(); i++) { if (plugin->inputs[i]->device == "cuda" || plugin->inputs[i]->device == "default") { device_cond = device_cond + "input_tensors[" + std::to_string(i) + "].is_cuda()"; @@ -216,15 +216,15 @@ void TorchPluginCodeGen::CodeGenOpDefine(const Plugin& plugin) { .func_end(); } -void TorchPluginCodeGen::CodeGenCmake(const std::set& devices) { - Map flags; +void TorchPluginCodeGen::CodeGenCmake(const std::set& devices) { + ffi::Map flags; flags.Set("PLUGIN_SUPPORT_TORCH", ""); CodeGenPreCmake(devices, flags); stack_.line() .line("set(CMAKE_CXX_STANDARD 17)") .line("list(APPEND CMAKE_PREFIX_PATH \"" + config()->torch_prefix + "\")") .line("find_package(Torch REQUIRED)"); - Array includes, libs; + ffi::Array includes, libs; libs.push_back("${TORCH_LIBRARIES}"); CodeGenPostCmake(devices, includes, libs); } @@ -366,14 +366,14 @@ void TorchPluginCodeGen::CodeGenConvertDepends() { .line(); } -const String TorchPluginCodeGen::CodeGenOpConvert(const Plugin& plugin) { +const ffi::String TorchPluginCodeGen::CodeGenOpConvert(const Plugin& plugin) { stack_.func_def(ConverterName(plugin), "relax.Var") .func_arg("node", "fx.node.Node") .func_arg("ctx", "TorchFXImporter") .func_start() .func_call("retrieve_args", "args", "ctx") .call_arg("node"); - Array args; + ffi::Array args; for (size_t i = 0; i < plugin->inputs.size(); i++) { const auto& tensor = plugin->inputs[i]; stack_.assign(tensor->name, DocUtils::ToIndex("args", i + 1)); @@ -407,9 +407,9 @@ const String TorchPluginCodeGen::CodeGenOpConvert(const Plugin& plugin) { .call_arg("op") .call_arg("name"); if (plugin->outputs.size() == 1) { - stack_.func_end(DocUtils::ToList(Array{"var"})); + stack_.func_end(DocUtils::ToList(ffi::Array{"var"})); } else { - Array outputs; + ffi::Array outputs; for (size_t i = 0; i < plugin->outputs.size(); i++) { const auto& tensor = plugin->outputs[i]; stack_.func_call("relax.TupleGetItem", tensor->name).call_arg("var").call_arg(i); @@ -420,9 +420,10 @@ const String TorchPluginCodeGen::CodeGenOpConvert(const Plugin& plugin) { return EntryName(plugin); } -void TorchPluginCodeGen::CodeGenMalloc(const Plugin& plugin, const Array& tensors, - const String& collect) { - Array call_args{"input_metas", "meta_attr_", "true"}; +void TorchPluginCodeGen::CodeGenMalloc(const Plugin& plugin, + const ffi::Array& tensors, + const ffi::String& collect) { + ffi::Array call_args{"input_metas", "meta_attr_", "true"}; stack_.line().comment("malloc " + collect).declare("std::vector", collect + "_metas"); CodeGenSafeCall(plugin->externs["infer_" + collect], call_args, collect + "_metas"); 
for (size_t i = 0; i < tensors.size(); i++) { @@ -442,13 +443,14 @@ void TorchPluginCodeGen::CodeGenMalloc(const Plugin& plugin, const Array& dtypes, - size_t idx, const String& collect) { - const String& t_name = "d_" + tensor->name; - const String& t_dtype = dtypes.count(tensor->name) ? dtypes[tensor->name] : tensor->dtype; - const String& tensor_type = "DataTensor<" + t_dtype + ">"; - const String& anno = collect == "input" ? "const " + tensor_type + "&" : tensor_type; +void TorchPluginCodeGen::CodeGenCompute(const Plugin& plugin, const ffi::String& device) { + auto prepare_tensor = [this](const PluginTensor& tensor, + const ffi::Map& dtypes, size_t idx, + const ffi::String& collect) { + const ffi::String& t_name = "d_" + tensor->name; + const ffi::String& t_dtype = dtypes.count(tensor->name) ? dtypes[tensor->name] : tensor->dtype; + const ffi::String& tensor_type = "DataTensor<" + t_dtype + ">"; + const ffi::String& anno = collect == "input" ? "const " + tensor_type + "&" : tensor_type; stack_.func_call("TorchUtils::To" + tensor_type, DocUtils::ToDeclare(anno, t_name)) .call_arg(DocUtils::ToIndex(collect + "_tensors", idx)) .call_arg(DocUtils::ToIndex(collect + "_metas", idx)) @@ -459,8 +461,8 @@ void TorchPluginCodeGen::CodeGenCompute(const Plugin& plugin, const String& devi if (plugin->externs.count(device + "_compute")) { for (const auto& dtypes : GetDtypeMatrix(plugin)) { const auto& tensor_dtypes = GetTensorDtypes(plugin, dtypes); - Array compute_args; - String dtype_cond = ""; + ffi::Array compute_args; + ffi::String dtype_cond = ""; for (size_t i = 0; i < plugin->inputs.size(); i++) { dtype_cond = dtype_cond + "input_metas[" + std::to_string(i) + "].data_type() == DataUtils::ToMetaType(\"" + dtypes.at(i) + "\")"; @@ -469,15 +471,15 @@ void TorchPluginCodeGen::CodeGenCompute(const Plugin& plugin, const String& devi // prepare compute datas stack_.cond_if(dtype_cond).comment("prepare compute datas"); for (size_t i = 0; i < plugin->inputs.size(); i++) { - const String& t_name = prepare_tensor(plugin->inputs[i], tensor_dtypes, i, "input"); + const ffi::String& t_name = prepare_tensor(plugin->inputs[i], tensor_dtypes, i, "input"); compute_args.push_back(t_name); } for (size_t i = 0; i < plugin->outputs.size(); i++) { - const String& t_name = prepare_tensor(plugin->outputs[i], tensor_dtypes, i, "output"); + const ffi::String& t_name = prepare_tensor(plugin->outputs[i], tensor_dtypes, i, "output"); compute_args.push_back(t_name); } for (size_t i = 0; i < plugin->buffers.size(); i++) { - const String& t_name = prepare_tensor(plugin->buffers[i], tensor_dtypes, i, "buffer"); + const ffi::String& t_name = prepare_tensor(plugin->buffers[i], tensor_dtypes, i, "buffer"); compute_args.push_back(t_name); } compute_args.push_back("meta_attr_"); @@ -497,8 +499,8 @@ void TorchPluginCodeGen::CodeGenCompute(const Plugin& plugin, const String& devi TVM_FFI_STATIC_INIT_BLOCK({ namespace refl = tvm::ffi::reflection; refl::GlobalDef().def("msc.plugin.GetTorchPluginSources", - [](const String& codegen_config, const String& print_config, - const String& codegen_type) -> Map { + [](const ffi::String& codegen_config, const ffi::String& print_config, + const ffi::String& codegen_type) -> ffi::Map { TorchPluginCodeGen codegen = TorchPluginCodeGen(codegen_config); if (codegen_type == "build") { return codegen.GetBuildSources(print_config); @@ -506,7 +508,7 @@ TVM_FFI_STATIC_INIT_BLOCK({ if (codegen_type == "manager") { return codegen.GetManagerSources(print_config); } - return Map(); + return ffi::Map(); }); 
}); diff --git a/src/contrib/msc/plugin/torch_codegen.h b/src/contrib/msc/plugin/torch_codegen.h index 4452650e2271..1dae9134e704 100644 --- a/src/contrib/msc/plugin/torch_codegen.h +++ b/src/contrib/msc/plugin/torch_codegen.h @@ -79,7 +79,7 @@ class TorchPluginCodeGen : public BasePluginCodeGen { void CodeGenOpDefine(const Plugin& plugin) final; /*! \brief Codegen cmake file*/ - void CodeGenCmake(const std::set& devices) final; + void CodeGenCmake(const std::set& devices) final; /*! \brief Codegen manager depends*/ void CodeGenManagerDepends() final; @@ -94,18 +94,18 @@ class TorchPluginCodeGen : public BasePluginCodeGen { void CodeGenConvertDepends() final; /*! \brief Codegen convert function for plugin*/ - const String CodeGenOpConvert(const Plugin& plugin) final; + const ffi::String CodeGenOpConvert(const Plugin& plugin) final; private: /*! \brief Codegen malloc for outputs/buffers*/ - void CodeGenMalloc(const Plugin& plugin, const Array& tensors, - const String& collect); + void CodeGenMalloc(const Plugin& plugin, const ffi::Array& tensors, + const ffi::String& collect); /*! \brief Codegen compute*/ - void CodeGenCompute(const Plugin& plugin, const String& device); + void CodeGenCompute(const Plugin& plugin, const ffi::String& device); /*! \brief Entry name of torch function*/ - const String EntryName(const Plugin& plugin) { + const ffi::String EntryName(const Plugin& plugin) { std::string lower_name; const std::string& name = std::string(plugin->name); for (size_t i = 0; i < name.size(); i++) { @@ -119,7 +119,7 @@ class TorchPluginCodeGen : public BasePluginCodeGen { } /*! \brief Type name in torch*/ - const String ToTorchType(const String& type) { + const ffi::String ToTorchType(const ffi::String& type) { if (type == "float") { return "double"; } diff --git a/src/contrib/msc/plugin/tvm_codegen.cc b/src/contrib/msc/plugin/tvm_codegen.cc index 7410867aaf25..373e9aaac294 100644 --- a/src/contrib/msc/plugin/tvm_codegen.cc +++ b/src/contrib/msc/plugin/tvm_codegen.cc @@ -35,7 +35,7 @@ void TVMPluginCodeGen::CodeGenAttrDeclare(const Plugin& plugin) { stack_.comment("convert exprs to meta attrs method") .func_def(attr_name + "_from_exprs", "const " + attr_name); for (const auto& a : plugin->attrs) { - const String& anno = IsListType(a->type) ? "Tuple" : "PrimValue"; + const ffi::String& anno = IsListType(a->type) ? "Tuple" : "PrimValue"; stack_.func_arg(a->name, "const " + anno + "&"); } // args to meta_attr @@ -50,12 +50,12 @@ void TVMPluginCodeGen::CodeGenAttrDefine(const Plugin& plugin) { // exprs to meta_attr stack_.func_def(attr_name + "_from_exprs", "const " + attr_name); for (const auto& a : plugin->attrs) { - const String& anno = IsListType(a->type) ? "Tuple" : "PrimValue"; + const ffi::String& anno = IsListType(a->type) ? "Tuple" : "PrimValue"; stack_.func_arg(a->name, "const " + anno + "&"); } stack_.func_start().declare(attr_name, "meta_attr"); for (const auto& a : plugin->attrs) { - const String& convert = IsListType(a->type) ? "AttrFromPrims" : "AttrFromPrim"; + const ffi::String& convert = IsListType(a->type) ? 
"AttrFromPrims" : "AttrFromPrim"; stack_.func_call("TVMUtils::" + convert) .call_arg(a->name) .call_arg(DocUtils::ToAttrAccess("meta_attr", a->name)); @@ -92,30 +92,30 @@ void TVMPluginCodeGen::CodeGenAttrDefine(const Plugin& plugin) { void TVMPluginCodeGen::CodeGenOpDeclare(const Plugin& plugin) { // infer struct info - stack_.func_def("InferStructInfo" + plugin->name, "Array"); + stack_.func_def("InferStructInfo" + plugin->name, "ffi::Array"); for (const auto& t : plugin->inputs) { stack_.func_arg(t->name, "const Expr&"); } for (const auto& a : plugin->attrs) { - const String& anno = IsListType(a->type) ? "Tuple" : "PrimValue"; + const ffi::String& anno = IsListType(a->type) ? "Tuple" : "PrimValue"; stack_.func_arg(a->name, "const " + anno + "&"); } // infer layout stack_.func_def("InferLayout" + plugin->name, "InferLayoutOutput") - .func_arg("inputs", "const Array&") + .func_arg("inputs", "const ffi::Array&") .func_arg("var_layout_map", "const VarLayoutMap&"); } void TVMPluginCodeGen::CodeGenOpDefine(const Plugin& plugin) { const auto& attr_name = MetaAttrCls(plugin); // infer struct info - Array infer_args{"input_metas", "meta_attr", "false"}; - stack_.func_def("InferStructInfo" + plugin->name, "Array"); + ffi::Array infer_args{"input_metas", "meta_attr", "false"}; + stack_.func_def("InferStructInfo" + plugin->name, "ffi::Array"); for (const auto& t : plugin->inputs) { stack_.func_arg(t->name, "const Expr&"); } for (const auto& a : plugin->attrs) { - const String& anno = IsListType(a->type) ? "Tuple" : "PrimValue"; + const ffi::String& anno = IsListType(a->type) ? "Tuple" : "PrimValue"; stack_.func_arg(a->name, "const " + anno + "&"); } stack_.func_start() @@ -133,7 +133,7 @@ void TVMPluginCodeGen::CodeGenOpDefine(const Plugin& plugin) { } stack_.declare("std::vector", "output_metas"); CodeGenSafeCall(plugin->externs["infer_output"], infer_args, "output_metas"); - stack_.declare("Array", "output_sinfo"); + stack_.declare("ffi::Array", "output_sinfo"); for (size_t i = 0; i < plugin->outputs.size(); i++) { stack_.func_call("push_back", "", "output_sinfo") .inplace_start("TVMUtils::ToTensorStructInfo") @@ -152,20 +152,20 @@ void TVMPluginCodeGen::CodeGenOpDefine(const Plugin& plugin) { // infer layout stack_.func_def("InferLayout" + plugin->name, "InferLayoutOutput") - .func_arg("inputs", "const Array&") + .func_arg("inputs", "const ffi::Array&") .func_arg("var_layout_map", "const VarLayoutMap&") .func_start() .comment("define attrs"); for (size_t i = 0; i < plugin->attrs.size(); i++) { const auto& attr = plugin->attrs[i]; - const String& anno = IsListType(attr->type) ? "Tuple" : "PrimValue"; + const ffi::String& anno = IsListType(attr->type) ? 
"Tuple" : "PrimValue"; stack_ .func_call("Downcast<" + anno + ">", DocUtils::ToDeclare("const auto&", "attr_" + attr->name)) .call_arg(DocUtils::ToIndex("inputs", i + plugin->inputs.size())); } - stack_.declare("Array", "arg_layouts") - .declare("Array", "output_layouts") + stack_.declare("ffi::Array", "arg_layouts") + .declare("ffi::Array", "output_layouts") .comment("extract meta attrs") .func_call(attr_name + "_from_exprs", "const " + attr_name + "& meta_attr"); for (const auto& a : plugin->attrs) { @@ -201,7 +201,7 @@ void TVMPluginCodeGen::CodeGenOpDefine(const Plugin& plugin) { .call_arg(DocUtils::ToAttrAccess(DocUtils::ToIndex("output_metas", "i"), "layout_name()")) .inplace_end() .for_end() - .declare("Array", "input_layouts") + .declare("ffi::Array", "input_layouts") .func_call("push_back", "", "input_layouts") .inplace_start("LayoutDecision") .call_arg(DocUtils::ToStr("")) @@ -229,10 +229,10 @@ void TVMPluginCodeGen::CodeGenOpRuntime(const Plugin& plugin) { ICHECK(!plugin->externs.count("infer_buffer")) << "infer_buffer is not supported for tvm runtime"; const auto& attr_name = MetaAttrCls(plugin); const auto& func_name = ComputeName(plugin); - String device_cond = ""; - String device_index = ""; + ffi::String device_cond = ""; + ffi::String device_index = ""; for (size_t i = 0; i < plugin->inputs.size(); i++) { - String device_type = ""; + ffi::String device_type = ""; if (plugin->inputs[i]->device == "cuda" || plugin->inputs[i]->device == "default") { device_type = "DLDeviceType::kDLCUDA"; } else { @@ -267,8 +267,8 @@ void TVMPluginCodeGen::CodeGenOpRuntime(const Plugin& plugin) { .line(); } -void TVMPluginCodeGen::CodeGenCmake(const std::set& devices) { - Map flags; +void TVMPluginCodeGen::CodeGenCmake(const std::set& devices) { + ffi::Map flags; flags.Set("PLUGIN_SUPPORT_TVM", ""); CodeGenPreCmake(devices, flags); stack_.line("set(CMAKE_CXX_STANDARD 17)") @@ -276,7 +276,7 @@ void TVMPluginCodeGen::CodeGenCmake(const std::set& devices) { .line() .line("set(TVM_ROOT " + config()->tvm_root + ")") .line("find_library(TVM_LIB NAMES tvm HINTS ${TVM_ROOT}/build NO_DEFAULT_PATH)"); - Array includes, libs; + ffi::Array includes, libs; includes.push_back("${TVM_ROOT}/include"); includes.push_back("${TVM_ROOT}/3rdparty/dmlc-core/include"); includes.push_back("${TVM_ROOT}/3rdparty/dlpack/include"); @@ -318,7 +318,7 @@ void TVMPluginCodeGen::CodeGenOpBuilder(const Plugin& plugin) { stack_.func_arg(attr->name, ToPyType(attr->type), attr->default_value); } stack_.func_arg("name", "str", "\"" + plugin->name + "\"").func_start(); - Array args; + ffi::Array args; for (const auto& t : plugin->inputs) { args.push_back(t->name); } @@ -345,15 +345,17 @@ void TVMPluginCodeGen::CodeGenOpBuilder(const Plugin& plugin) { stack_.func_end("op").comment(GetPyComment(plugin), true); } -void TVMPluginCodeGen::CodeGenCompute(const Plugin& plugin, const String& device) { +void TVMPluginCodeGen::CodeGenCompute(const Plugin& plugin, const ffi::String& device) { if (plugin->externs.count(device + "_compute")) { // compute with dtype - auto prepare_tensor = [this](const PluginTensor& tensor, const Map& dtypes, - size_t idx, const String& collect) { - const String& t_name = "d_" + tensor->name; - const String& t_dtype = dtypes.count(tensor->name) ? dtypes[tensor->name] : tensor->dtype; - const String& tensor_type = "DataTensor<" + t_dtype + ">"; - const String& anno = collect == "input" ? 
"const " + tensor_type + "&" : tensor_type; + auto prepare_tensor = [this](const PluginTensor& tensor, + const ffi::Map& dtypes, size_t idx, + const ffi::String& collect) { + const ffi::String& t_name = "d_" + tensor->name; + const ffi::String& t_dtype = + dtypes.count(tensor->name) ? dtypes[tensor->name] : tensor->dtype; + const ffi::String& tensor_type = "DataTensor<" + t_dtype + ">"; + const ffi::String& anno = collect == "input" ? "const " + tensor_type + "&" : tensor_type; stack_.func_call("TVMUtils::To" + tensor_type, DocUtils::ToDeclare(anno, t_name)) .call_arg(tensor->name) .call_arg(collect == "input"); @@ -361,8 +363,8 @@ void TVMPluginCodeGen::CodeGenCompute(const Plugin& plugin, const String& device }; for (const auto& dtypes : GetDtypeMatrix(plugin)) { const auto& tensor_dtypes = GetTensorDtypes(plugin, dtypes); - Array compute_args; - String dtype_cond = ""; + ffi::Array compute_args; + ffi::String dtype_cond = ""; for (size_t i = 0; i < plugin->inputs.size(); i++) { const auto& t_name = plugin->inputs[i]->name; dtype_cond = dtype_cond + "TVMUtils::ToMetaType(" + t_name + @@ -372,11 +374,11 @@ void TVMPluginCodeGen::CodeGenCompute(const Plugin& plugin, const String& device // prepare compute datas stack_.cond_if(dtype_cond).comment("prepare compute datas"); for (size_t i = 0; i < plugin->inputs.size(); i++) { - const String& t_name = prepare_tensor(plugin->inputs[i], tensor_dtypes, i, "input"); + const ffi::String& t_name = prepare_tensor(plugin->inputs[i], tensor_dtypes, i, "input"); compute_args.push_back(t_name); } for (size_t i = 0; i < plugin->outputs.size(); i++) { - const String& t_name = prepare_tensor(plugin->outputs[i], tensor_dtypes, i, "output"); + const ffi::String& t_name = prepare_tensor(plugin->outputs[i], tensor_dtypes, i, "output"); compute_args.push_back(t_name); } ICHECK(plugin->buffers.size() == 0) << "Plugin with buffers is not supported in tvm"; @@ -397,8 +399,8 @@ void TVMPluginCodeGen::CodeGenCompute(const Plugin& plugin, const String& device TVM_FFI_STATIC_INIT_BLOCK({ namespace refl = tvm::ffi::reflection; refl::GlobalDef().def("msc.plugin.GetTVMPluginSources", - [](const String& codegen_config, const String& print_config, - const String& codegen_type) -> Map { + [](const ffi::String& codegen_config, const ffi::String& print_config, + const ffi::String& codegen_type) -> ffi::Map { TVMPluginCodeGen codegen = TVMPluginCodeGen(codegen_config); if (codegen_type == "build") { return codegen.GetBuildSources(print_config); @@ -406,7 +408,7 @@ TVM_FFI_STATIC_INIT_BLOCK({ if (codegen_type == "manager") { return codegen.GetManagerSources(print_config); } - return Map(); + return ffi::Map(); }); }); diff --git a/src/contrib/msc/plugin/tvm_codegen.h b/src/contrib/msc/plugin/tvm_codegen.h index 520e35de95c6..926c5162005a 100644 --- a/src/contrib/msc/plugin/tvm_codegen.h +++ b/src/contrib/msc/plugin/tvm_codegen.h @@ -82,7 +82,7 @@ class TVMPluginCodeGen : public BasePluginCodeGen { void CodeGenOpRuntime(const Plugin& plugin) final; /*! \brief Codegen cmake file*/ - void CodeGenCmake(const std::set& devices) final; + void CodeGenCmake(const std::set& devices) final; /*! \brief Codegen manager depends*/ void CodeGenManagerDepends() final; @@ -95,13 +95,13 @@ class TVMPluginCodeGen : public BasePluginCodeGen { private: /*! \brief Func name of compute*/ - const String ComputeName(const Plugin& plugin) { return plugin->name + "_compute"; } + const ffi::String ComputeName(const Plugin& plugin) { return plugin->name + "_compute"; } /*! 
\brief Codegen compute*/ - void CodeGenCompute(const Plugin& plugin, const String& device); + void CodeGenCompute(const Plugin& plugin, const ffi::String& device); /*! \brief Type name in tvm*/ - const String ToTVMType(const String& type) { + const ffi::String ToTVMType(const ffi::String& type) { if (type == "string") { return "StringImm"; } diff --git a/src/ir/analysis.cc b/src/ir/analysis.cc index 41c75c875b78..72fc1803715d 100644 --- a/src/ir/analysis.cc +++ b/src/ir/analysis.cc @@ -29,17 +29,17 @@ namespace tvm { namespace ir { -Map> CollectCallMap(const IRModule& mod) { +ffi::Map> CollectCallMap(const IRModule& mod) { struct CalleeCollectorImpl : CalleeCollector { void Mark(GlobalVar gvar) override { gvars.push_back(gvar); } support::OrderedSet gvars; }; - Map> call_map; + ffi::Map> call_map; for (const auto& [gvar, base_func] : mod->functions) { CalleeCollectorImpl collector; CalleeCollector::vtable()(base_func, &collector); - call_map.Set(gvar, Array{collector.gvars.begin(), collector.gvars.end()}); + call_map.Set(gvar, ffi::Array{collector.gvars.begin(), collector.gvars.end()}); } return call_map; } diff --git a/src/ir/apply_pass_to_function.cc b/src/ir/apply_pass_to_function.cc index 3436d49b02ee..bf5138924b7f 100644 --- a/src/ir/apply_pass_to_function.cc +++ b/src/ir/apply_pass_to_function.cc @@ -56,7 +56,7 @@ BaseFunc BaseFuncWithoutAttr(BaseFunc func, const std::string& attr_key) { } } // namespace -Pass ApplyPassToFunction(Pass pass, String func_name_regex, +Pass ApplyPassToFunction(Pass pass, ffi::String func_name_regex, bool error_if_no_function_matches_regex) { auto pass_name = static_cast(std::stringstream() << "ApplyPassTo" << func_name_regex) @@ -65,15 +65,15 @@ Pass ApplyPassToFunction(Pass pass, String func_name_regex, auto pass_func = [pass, func_name_regex, error_if_no_function_matches_regex]( IRModule mod, PassContext) -> IRModule { bool at_least_one_function_matched_regex = false; - std::unordered_set keep_original_version; - std::unordered_set internal_functions; + std::unordered_set keep_original_version; + std::unordered_set internal_functions; IRModule subset; for (auto [gvar, func] : mod->functions) { std::string name = gvar->name_hint; if (tvm::runtime::regex_match(name, func_name_regex)) { at_least_one_function_matched_regex = true; - if (!func->GetAttr(tvm::attr::kGlobalSymbol).has_value()) { + if (!func->GetAttr(tvm::attr::kGlobalSymbol).has_value()) { // Function may be mutated, but is an internal function. 
Mark // it as externally-exposed, so that any call-tracing internal // transforms do not remove this function, in case it its @@ -97,7 +97,7 @@ Pass ApplyPassToFunction(Pass pass, String func_name_regex, if (error_if_no_function_matches_regex) { CHECK(at_least_one_function_matched_regex) << "No function matched regex '" << func_name_regex << "', out of functions " << [&]() { - Array function_names; + ffi::Array function_names; for (const auto& [gvar, func] : mod->functions) { function_names.push_back(gvar->name_hint); } diff --git a/src/ir/attrs.cc b/src/ir/attrs.cc index 66a43f93c7d5..911e829ea9c9 100644 --- a/src/ir/attrs.cc +++ b/src/ir/attrs.cc @@ -33,7 +33,7 @@ TVM_FFI_STATIC_INIT_BLOCK({ DictAttrsNode::RegisterReflection(); }); -DictAttrs WithAttrs(DictAttrs attrs, Map new_attrs) { +DictAttrs WithAttrs(DictAttrs attrs, ffi::Map new_attrs) { if (new_attrs.empty()) { return attrs; } @@ -45,7 +45,7 @@ DictAttrs WithAttrs(DictAttrs attrs, Map new_attrs) { return attrs; } -DictAttrs WithAttr(DictAttrs attrs, String key, ffi::Any value) { +DictAttrs WithAttr(DictAttrs attrs, ffi::String key, ffi::Any value) { attrs.CopyOnWrite()->dict.Set(key, value); return attrs; } @@ -57,14 +57,14 @@ DictAttrs WithoutAttr(DictAttrs attrs, const std::string& key) { void DictAttrsNode::InitByPackedArgs(const ffi::PackedArgs& args, bool allow_unknown) { for (int i = 0; i < args.size(); i += 2) { - String key = args[i].cast(); + ffi::String key = args[i].cast(); ffi::AnyView val = args[i + 1]; dict.Set(key, val); } } -DictAttrs::DictAttrs(Map dict) { - ObjectPtr n = make_object(); +DictAttrs::DictAttrs(ffi::Map dict) { + ObjectPtr n = ffi::make_object(); n->dict = std::move(dict); data_ = std::move(n); } diff --git a/src/ir/diagnostic.cc b/src/ir/diagnostic.cc index fa48ceba288b..ac8b11575239 100644 --- a/src/ir/diagnostic.cc +++ b/src/ir/diagnostic.cc @@ -41,13 +41,13 @@ TVM_FFI_STATIC_INIT_BLOCK({ TVM_FFI_STATIC_INIT_BLOCK({ namespace refl = tvm::ffi::reflection; - refl::GlobalDef().def("diagnostics.Diagnostic", [](int level, Span span, String message) { + refl::GlobalDef().def("diagnostics.Diagnostic", [](int level, Span span, ffi::String message) { return Diagnostic(static_cast(level), span, message); }); }); Diagnostic::Diagnostic(DiagnosticLevel level, Span span, const std::string& message) { - auto n = make_object(); + auto n = ffi::make_object(); n->level = level; n->span = span; n->message = message; @@ -94,13 +94,15 @@ DiagnosticBuilder Diagnostic::Help(ObjectRef loc) { return DiagnosticBuilder(DiagnosticLevel::kHelp, loc); } -DiagnosticBuilder Diagnostic::Bug(const Object* loc) { return Bug(GetRef(loc)); } +DiagnosticBuilder Diagnostic::Bug(const Object* loc) { return Bug(ffi::GetRef(loc)); } -DiagnosticBuilder Diagnostic::Error(const Object* loc) { return Error(GetRef(loc)); } +DiagnosticBuilder Diagnostic::Error(const Object* loc) { + return Error(ffi::GetRef(loc)); +} -DiagnosticBuilder Diagnostic::Note(const Object* loc) { return Note(GetRef(loc)); } +DiagnosticBuilder Diagnostic::Note(const Object* loc) { return Note(ffi::GetRef(loc)); } -DiagnosticBuilder Diagnostic::Help(const Object* loc) { return Help(GetRef(loc)); } +DiagnosticBuilder Diagnostic::Help(const Object* loc) { return Help(ffi::GetRef(loc)); } /* Diagnostic Renderer */ @@ -108,7 +110,7 @@ void DiagnosticRenderer::Render(const DiagnosticContext& ctx) { (*this)->rendere TVM_DLL DiagnosticRenderer::DiagnosticRenderer( ffi::TypedFunction renderer) { - auto n = make_object(); + auto n = ffi::make_object(); n->renderer = renderer; 
data_ = std::move(n); } @@ -152,7 +154,7 @@ TVM_FFI_STATIC_INIT_BLOCK({ DiagnosticContext::DiagnosticContext(const IRModule& module, const DiagnosticRenderer& renderer) { CHECK(renderer.defined()) << "can not initialize a diagnostic renderer with a null function"; - auto n = make_object(); + auto n = ffi::make_object(); n->module = module; n->renderer = renderer; data_ = std::move(n); diff --git a/src/ir/env_func.cc b/src/ir/env_func.cc index bc91db0ce45d..77c346eabcce 100644 --- a/src/ir/env_func.cc +++ b/src/ir/env_func.cc @@ -42,13 +42,13 @@ TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) ObjectPtr CreateEnvNode(const std::string& name) { auto f = tvm::ffi::Function::GetGlobal(name); ICHECK(f.has_value()) << "Cannot find global function \'" << name << '\''; - ObjectPtr n = make_object(); + ObjectPtr n = ffi::make_object(); n->func = *f; n->name = name; return n; } -EnvFunc EnvFunc::Get(const String& name) { return EnvFunc(CreateEnvNode(name)); } +EnvFunc EnvFunc::Get(const ffi::String& name) { return EnvFunc(CreateEnvNode(name)); } TVM_FFI_STATIC_INIT_BLOCK({ namespace refl = tvm::ffi::reflection; diff --git a/src/ir/expr.cc b/src/ir/expr.cc index 43112335988f..101a00cf5a5d 100644 --- a/src/ir/expr.cc +++ b/src/ir/expr.cc @@ -48,7 +48,7 @@ PrimExpr::PrimExpr(int32_t value) : PrimExpr(IntImm(DataType::Int(32), value)) { PrimExpr::PrimExpr(float value) : PrimExpr(FloatImm(DataType::Float(32), value)) {} -PrimExpr PrimExpr::ConvertFallbackValue(String value) { return tir::StringImm(value); } +PrimExpr PrimExpr::ConvertFallbackValue(ffi::String value) { return tir::StringImm(value); } IntImm::IntImm(DataType dtype, int64_t value, Span span) { ICHECK(dtype.is_scalar()) << "ValueError: IntImm can only take scalar, but " << dtype @@ -71,7 +71,7 @@ IntImm::IntImm(DataType dtype, int64_t value, Span span) { ICHECK_LT(value, 1LL << (dtype.bits() - 1)) << "ValueError: Literal value " << value << " exceeds maximum of " << dtype; } - ObjectPtr node = make_object(); + ObjectPtr node = ffi::make_object(); node->dtype = dtype; node->value = value; node->span = span; @@ -174,7 +174,7 @@ FloatImm::FloatImm(DataType dtype, double value, Span span) { << dtype; } } - ObjectPtr node = make_object(); + ObjectPtr node = ffi::make_object(); node->dtype = dtype; node->value = value; node->span = span; @@ -189,17 +189,17 @@ TVM_FFI_STATIC_INIT_BLOCK({ }); Range::Range(PrimExpr begin, PrimExpr end, Span span) - : Range(make_object(begin, tir::is_zero(begin) ? end : (end - begin), span)) {} + : Range(ffi::make_object(begin, tir::is_zero(begin) ? 
end : (end - begin), span)) {} Range Range::FromMinExtent(PrimExpr min, PrimExpr extent, Span span) { - return Range(make_object(min, extent, span)); + return Range(ffi::make_object(min, extent, span)); } TVM_FFI_STATIC_INIT_BLOCK({ namespace refl = tvm::ffi::reflection; refl::GlobalDef() .def("ir.Range_from_min_extent", Range::FromMinExtent) - .def("ir.Range", [](PrimExpr begin, Optional end, Span span) -> Range { + .def("ir.Range", [](PrimExpr begin, ffi::Optional end, Span span) -> Range { if (end.defined()) { return Range(begin, end.value(), span); } else { @@ -208,8 +208,8 @@ TVM_FFI_STATIC_INIT_BLOCK({ }); }); -GlobalVar::GlobalVar(String name_hint, Span span) { - ObjectPtr n = make_object(); +GlobalVar::GlobalVar(ffi::String name_hint, Span span) { + ObjectPtr n = ffi::make_object(); n->name_hint = std::move(name_hint); n->span = std::move(span); data_ = std::move(n); @@ -218,7 +218,7 @@ GlobalVar::GlobalVar(String name_hint, Span span) { TVM_FFI_STATIC_INIT_BLOCK({ namespace refl = tvm::ffi::reflection; refl::GlobalDef() - .def("ir.GlobalVar", [](String name) { return GlobalVar(name); }) + .def("ir.GlobalVar", [](ffi::String name) { return GlobalVar(name); }) .def("ir.DebugPrint", [](ObjectRef ref) { std::stringstream ss; ss << ref; diff --git a/src/ir/function.cc b/src/ir/function.cc index cb30325ffff9..21fdb7975b89 100644 --- a/src/ir/function.cc +++ b/src/ir/function.cc @@ -36,7 +36,7 @@ TVM_FFI_STATIC_INIT_BLOCK({ .def("ir.BaseFunc_Attrs", [](BaseFunc func) { return func->attrs; }) .def("ir.BaseFuncCopy", [](BaseFunc func) { return func; }) .def("ir.BaseFuncWithAttr", - [](ffi::RValueRef func_ref, String key, Any value) -> BaseFunc { + [](ffi::RValueRef func_ref, ffi::String key, Any value) -> BaseFunc { BaseFunc func = *std::move(func_ref); if (func->IsInstance()) { return WithAttr(Downcast(std::move(func)), key, value); @@ -49,13 +49,14 @@ TVM_FFI_STATIC_INIT_BLOCK({ } }) .def("ir.BaseFuncWithAttrs", - [](ffi::RValueRef func_ref, Map attr_map) -> BaseFunc { + [](ffi::RValueRef func_ref, + ffi::Map attr_map) -> BaseFunc { BaseFunc func = *std::move(func_ref); if (func->IsInstance()) { return WithAttrs(Downcast(std::move(func)), attr_map); } if (const auto f = tvm::ffi::Function::GetGlobal("relax.FuncWithAttrs")) { - if (auto ret = (*f)(func, attr_map).cast>()) { + if (auto ret = (*f)(func, attr_map).cast>()) { return ret.value(); } } @@ -65,17 +66,18 @@ TVM_FFI_STATIC_INIT_BLOCK({ LOG(FATAL) << "Do not support function type " << func->GetTypeKey(); TVM_FFI_UNREACHABLE(); }) - .def("ir.BaseFuncWithoutAttr", [](ffi::RValueRef func_ref, String key) -> BaseFunc { - BaseFunc func = *std::move(func_ref); - if (func->IsInstance()) { - return WithoutAttr(Downcast(std::move(func)), key); - } else if (func->IsInstance()) { - return WithoutAttr(Downcast(std::move(func)), key); - } else { - LOG(FATAL) << "Do not support function type " << func->GetTypeKey(); - TVM_FFI_UNREACHABLE(); - } - }); + .def("ir.BaseFuncWithoutAttr", + [](ffi::RValueRef func_ref, ffi::String key) -> BaseFunc { + BaseFunc func = *std::move(func_ref); + if (func->IsInstance()) { + return WithoutAttr(Downcast(std::move(func)), key); + } else if (func->IsInstance()) { + return WithoutAttr(Downcast(std::move(func)), key); + } else { + LOG(FATAL) << "Do not support function type " << func->GetTypeKey(); + TVM_FFI_UNREACHABLE(); + } + }); }); } // namespace tvm diff --git a/src/ir/global_info.cc b/src/ir/global_info.cc index 566702f5dd63..b318c86b0f00 100644 --- a/src/ir/global_info.cc +++ b/src/ir/global_info.cc @@ 
-34,13 +34,13 @@ TVM_FFI_STATIC_INIT_BLOCK({ TVM_FFI_STATIC_INIT_BLOCK({ namespace refl = tvm::ffi::reflection; refl::GlobalDef().def("ir.DummyGlobalInfo", []() { - auto n = DummyGlobalInfo(make_object()); + auto n = DummyGlobalInfo(ffi::make_object()); return n; }); }); VDevice::VDevice(Target tgt, int dev_id, MemoryScope mem_scope) { - ObjectPtr n = make_object(); + ObjectPtr n = ffi::make_object(); n->target = std::move(tgt); n->vdevice_id = std::move(dev_id); n->memory_scope = std::move(mem_scope); diff --git a/src/ir/global_var_supply.cc b/src/ir/global_var_supply.cc index 9d4e66bfa466..71505430c5cc 100644 --- a/src/ir/global_var_supply.cc +++ b/src/ir/global_var_supply.cc @@ -36,15 +36,15 @@ TVM_FFI_STATIC_INIT_BLOCK({ GlobalVarSupplyNode::RegisterReflection(); }); GlobalVarSupply::GlobalVarSupply(const NameSupply& name_supply, std::unordered_map name_to_var_map) { - auto n = make_object(name_supply, name_to_var_map); + auto n = ffi::make_object(name_supply, name_to_var_map); data_ = std::move(n); } std::string GetModuleName(const IRModule& module) { - return module->GetAttr(tvm::attr::kModuleName).value_or("tvmgen_default"); + return module->GetAttr(tvm::attr::kModuleName).value_or("tvmgen_default"); } -GlobalVarSupply::GlobalVarSupply(const Array& modules) : GlobalVarSupply() { +GlobalVarSupply::GlobalVarSupply(const ffi::Array& modules) : GlobalVarSupply() { if (!modules.empty()) { IRModule first_mod = modules.front(); this->operator->()->name_supply_->prefix_ = GetModuleName(first_mod); @@ -57,7 +57,7 @@ GlobalVarSupply::GlobalVarSupply(const Array& modules) : GlobalVarSupp } GlobalVarSupply::GlobalVarSupply(const IRModule module) - : GlobalVarSupply(Array{module}) {} + : GlobalVarSupply(ffi::Array{module}) {} void GlobalVarSupplyNode::ReserveGlobalVar(const GlobalVar& var, bool allow_conflict) { name_supply_->ReserveName(var->name_hint, false); @@ -72,8 +72,8 @@ GlobalVarSupplyNode::GlobalVarSupplyNode(NameSupply name_supply, std::unordered_map name_to_var_map) : name_supply_(std::move(name_supply)), name_to_var_map_(std::move(name_to_var_map)) {} -GlobalVar GlobalVarSupplyNode::UniqueGlobalFor(const String& name, bool add_prefix) { - String final_name = name_supply_->ReserveName(name, add_prefix); +GlobalVar GlobalVarSupplyNode::UniqueGlobalFor(const ffi::String& name, bool add_prefix) { + ffi::String final_name = name_supply_->ReserveName(name, add_prefix); auto it = name_to_var_map_.find(final_name); if (it != name_to_var_map_.end()) { @@ -85,8 +85,8 @@ GlobalVar GlobalVarSupplyNode::UniqueGlobalFor(const String& name, bool add_pref } } -GlobalVar GlobalVarSupplyNode::FreshGlobal(String name, bool add_prefix) { - String final_name = name_supply_->FreshName(name, add_prefix); +GlobalVar GlobalVarSupplyNode::FreshGlobal(ffi::String name, bool add_prefix) { + ffi::String final_name = name_supply_->FreshName(name, add_prefix); ICHECK(name_to_var_map_.find(final_name) == name_to_var_map_.end()) << "GlobalVar already exists for name " << final_name; GlobalVar var = GlobalVar(final_name); @@ -102,7 +102,7 @@ TVM_FFI_STATIC_INIT_BLOCK({ .def("ir.GlobalVarSupply_IRModule", [](IRModule mod) { return GlobalVarSupply(std::move(mod)); }) .def("ir.GlobalVarSupply_IRModules", - [](const Array& mods) { return GlobalVarSupply(mods); }) + [](const ffi::Array& mods) { return GlobalVarSupply(mods); }) .def_method("ir.GlobalVarSupply_FreshGlobal", &GlobalVarSupplyNode::FreshGlobal) .def_method("ir.GlobalVarSupply_UniqueGlobalFor", &GlobalVarSupplyNode::UniqueGlobalFor) 
.def_method("ir.GlobalVarSupply_ReserveGlobalVar", &GlobalVarSupplyNode::ReserveGlobalVar); diff --git a/src/ir/instrument.cc b/src/ir/instrument.cc index 74176cb373cc..463235cc97f6 100644 --- a/src/ir/instrument.cc +++ b/src/ir/instrument.cc @@ -110,7 +110,7 @@ class BasePassInstrument : public PassInstrument { * \param run_after_pass_callback Callback to call after a pass run. */ TVM_DLL BasePassInstrument( - String name, ffi::TypedFunction enter_pass_ctx_callback, + ffi::String name, ffi::TypedFunction enter_pass_ctx_callback, ffi::TypedFunction exit_pass_ctx_callback, ffi::TypedFunction should_run_callback, ffi::TypedFunction @@ -122,12 +122,12 @@ class BasePassInstrument : public PassInstrument { }; BasePassInstrument::BasePassInstrument( - String name, ffi::TypedFunction enter_pass_ctx_callback, + ffi::String name, ffi::TypedFunction enter_pass_ctx_callback, ffi::TypedFunction exit_pass_ctx_callback, ffi::TypedFunction should_run_callback, ffi::TypedFunction run_before_pass_callback, ffi::TypedFunction run_after_pass_callback) { - auto pi = make_object(); + auto pi = ffi::make_object(); pi->name = std::move(name); pi->enter_pass_ctx_callback = std::move(enter_pass_ctx_callback); @@ -180,7 +180,7 @@ TVM_FFI_STATIC_INIT_BLOCK({ namespace refl = tvm::ffi::reflection; refl::GlobalDef().def( "instrument.PassInstrument", - [](String name, ffi::TypedFunction enter_pass_ctx, + [](ffi::String name, ffi::TypedFunction enter_pass_ctx, ffi::TypedFunction exit_pass_ctx, ffi::TypedFunction should_run, ffi::TypedFunction run_before_pass, @@ -204,7 +204,7 @@ struct PassProfile { using Time = std::chrono::time_point; /*! \brief The name of the pass being profiled. */ - String name; + ffi::String name; /*! \brief The time when the pass was entered. */ Time start; /*! \brief The time when the pass completed. */ @@ -214,13 +214,13 @@ struct PassProfile { /*! \brief PassProfiles for all sub-passes invoked during the execution of the pass. */ std::vector children; - explicit PassProfile(String name) + explicit PassProfile(ffi::String name) : name(name), start(Clock::now()), end(Clock::now()), children() {} /*! \brief Gets the PassProfile of the currently executing pass. */ static PassProfile* Current(); /*! \brief Pushes a new PassProfile with the given pass name. */ - static void EnterPass(String name); + static void EnterPass(ffi::String name); /*! \brief Pops the current PassProfile. */ static void ExitPass(); }; @@ -237,7 +237,7 @@ struct PassProfileThreadLocalEntry { /*! \brief Thread local store to hold the pass profiling data. 
*/ typedef dmlc::ThreadLocalStore PassProfileThreadLocalStore; -void PassProfile::EnterPass(String name) { +void PassProfile::EnterPass(ffi::String name) { PassProfile* cur = PassProfile::Current(); cur->children.emplace_back(name); PassProfileThreadLocalStore::Get()->profile_stack.push(&cur->children.back()); @@ -260,13 +260,13 @@ PassProfile* PassProfile::Current() { } } -String RenderPassProfiles() { +ffi::String RenderPassProfiles() { PassProfileThreadLocalEntry* entry = PassProfileThreadLocalStore::Get(); CHECK(entry->profile_stack.empty()) << "cannot print pass profile while still in a pass!"; if (entry->root.children.empty()) { LOG(WARNING) << "no passes have been profiled, did you enable pass profiling?"; - return String(); + return ffi::String(); } // (depth, parent_duration, pass) diff --git a/src/ir/module.cc b/src/ir/module.cc index 3ca4457b9871..05eaca3a4764 100644 --- a/src/ir/module.cc +++ b/src/ir/module.cc @@ -38,9 +38,9 @@ namespace tvm { TVM_FFI_STATIC_INIT_BLOCK({ IRModuleNode::RegisterReflection(); }); -IRModule::IRModule(tvm::Map functions, SourceMap source_map, DictAttrs attrs, - Map> global_infos) { - auto n = make_object(); +IRModule::IRModule(tvm::ffi::Map functions, SourceMap source_map, + DictAttrs attrs, ffi::Map> global_infos) { + auto n = ffi::make_object(); n->functions = std::move(functions); n->global_var_map_ = {}; n->source_map = source_map; @@ -109,11 +109,11 @@ uint64_t IRModuleNode::SHash(uint64_t init_hash, return hash_value; } -bool IRModuleNode::ContainGlobalVar(const String& name) const { +bool IRModuleNode::ContainGlobalVar(const ffi::String& name) const { return global_var_map_.find(name) != global_var_map_.end(); } -GlobalVar IRModuleNode::GetGlobalVar(const String& name) const { +GlobalVar IRModuleNode::GetGlobalVar(const ffi::String& name) const { auto it = global_var_map_.find(name); if (it == global_var_map_.end()) { std::ostringstream msg; @@ -132,7 +132,7 @@ GlobalVar IRModuleNode::GetGlobalVar(const String& name) const { return (*it).second; } -tvm::Array IRModuleNode::GetGlobalVars() const { +tvm::ffi::Array IRModuleNode::GetGlobalVars() const { std::vector global_vars; for (const auto& pair : global_var_map_) { global_vars.push_back(pair.second); @@ -140,7 +140,7 @@ tvm::Array IRModuleNode::GetGlobalVars() const { std::sort(global_vars.begin(), global_vars.end(), [](const GlobalVar& lhs, const GlobalVar& rhs) { return lhs->name_hint < rhs->name_hint; }); - return tvm::Array(global_vars); + return tvm::ffi::Array(global_vars); } void IRModuleNode::Add(const GlobalVar& var, const BaseFunc& f, bool update) { @@ -165,7 +165,7 @@ void IRModuleNode::Update(const GlobalVar& var, const BaseFunc& func) { this->Add(var, func, true); } -void IRModuleNode::UpdateGlobalInfo(const String& name, const Array& info) { +void IRModuleNode::UpdateGlobalInfo(const ffi::String& name, const ffi::Array& info) { this->global_infos.Set(name, info); } @@ -182,7 +182,7 @@ BaseFunc IRModuleNode::Lookup(const GlobalVar& var) const { return (*it).second; } -BaseFunc IRModuleNode::Lookup(const String& name) const { +BaseFunc IRModuleNode::Lookup(const ffi::String& name) const { GlobalVar id = this->GetGlobalVar(name); return this->Lookup(id); } @@ -199,15 +199,15 @@ IRModule IRModuleNode::ShallowCopy() { } IRModule IRModule::FromExpr(const RelaxExpr& expr, - const tvm::Map& global_funcs) { + const tvm::ffi::Map& global_funcs) { auto mod = IRModule(global_funcs); - String gv_name; + ffi::String gv_name; // All global definitions must be functions. 
BaseFunc func; if (auto func_node = expr.as()) { func = func_node.value(); - if (auto opt = func->GetAttr(tvm::attr::kGlobalSymbol)) { + if (auto opt = func->GetAttr(tvm::attr::kGlobalSymbol)) { // Function literal has been annotated with it's required global symbol. gv_name = opt.value(); } @@ -229,18 +229,18 @@ TVM_FFI_STATIC_INIT_BLOCK({ namespace refl = tvm::ffi::reflection; refl::GlobalDef() .def("ir.IRModule", - [](tvm::Map funcs, tvm::ObjectRef attrs, - Map> global_infos) { + [](tvm::ffi::Map funcs, tvm::ObjectRef attrs, + ffi::Map> global_infos) { auto dict_attrs = [&attrs]() { if (!attrs.defined()) { return DictAttrs(); } else if (auto* as_dict_attrs = attrs.as()) { - return GetRef(as_dict_attrs); + return ffi::GetRef(as_dict_attrs); } else if (attrs.as()) { - return tvm::DictAttrs(Downcast>(attrs)); + return tvm::DictAttrs(Downcast>(attrs)); } else { - LOG(FATAL) - << "Expected attrs argument to be either DictAttrs or Map"; + LOG(FATAL) << "Expected attrs argument to be either DictAttrs or " + "ffi::Map"; } }(); @@ -259,11 +259,11 @@ TVM_FFI_STATIC_INIT_BLOCK({ return mod; }) .def("ir.Module_Remove", - [](IRModule mod, Variant var) -> IRModule { + [](IRModule mod, ffi::Variant var) -> IRModule { GlobalVar gvar = [&]() { if (auto opt = var.as()) { return opt.value(); - } else if (auto opt = var.as()) { + } else if (auto opt = var.as()) { return mod->GetGlobalVar(opt.value()); } else { LOG(FATAL) << "InternalError: " @@ -274,10 +274,10 @@ TVM_FFI_STATIC_INIT_BLOCK({ return mod; }) .def("ir.Module_Contains", - [](IRModule mod, Variant var) -> bool { + [](IRModule mod, ffi::Variant var) -> bool { if (auto opt = var.as()) { return mod->functions.count(opt.value()); - } else if (auto opt = var.as()) { + } else if (auto opt = var.as()) { return mod->global_var_map_.count(opt.value()); } else { LOG(FATAL) << "InternalError: " @@ -288,30 +288,30 @@ TVM_FFI_STATIC_INIT_BLOCK({ .def_method("ir.Module_GetGlobalVars", &IRModuleNode::GetGlobalVars) .def_method("ir.Module_ContainGlobalVar", &IRModuleNode::ContainGlobalVar) .def("ir.Module_Lookup", [](IRModule mod, GlobalVar var) { return mod->Lookup(var); }) - .def("ir.Module_Lookup_str", [](IRModule mod, String var) { return mod->Lookup(var); }) + .def("ir.Module_Lookup_str", [](IRModule mod, ffi::String var) { return mod->Lookup(var); }) .def("ir.Module_FromExpr", &IRModule::FromExpr) .def("ir.Module_Update", [](IRModule mod, IRModule from) { mod->Update(from); }) .def("ir.Module_UpdateFunction", [](IRModule mod, GlobalVar gv, BaseFunc func) { mod->Update(gv, func); }) .def("ir.Module_UpdateGlobalInfo", - [](IRModule mod, String name, Array global_info) { + [](IRModule mod, ffi::String name, ffi::Array global_info) { mod->UpdateGlobalInfo(name, global_info); }) .def("ir.Module_GetAttrs", [](IRModule mod) -> ObjectRef { return mod->GetAttrs(); }) .def("ir.Module_WithAttr", - [](ffi::RValueRef mod, String key, ffi::Any value) -> IRModule { + [](ffi::RValueRef mod, ffi::String key, ffi::Any value) -> IRModule { return WithAttr(*std::move(mod), key, value); }) .def("ir.Module_WithoutAttr", - [](ffi::RValueRef mod, String key) -> IRModule { + [](ffi::RValueRef mod, ffi::String key) -> IRModule { return WithoutAttr(*std::move(mod), key); }) .def("ir.Module_WithAttrs", - [](ffi::RValueRef mod, Map attr_map) -> IRModule { + [](ffi::RValueRef mod, ffi::Map attr_map) -> IRModule { return WithAttrs(*std::move(mod), attr_map); }) .def("ir.Module_GetAttr", - [](IRModule mod, String key) -> ObjectRef { return mod->GetAttr(key); }); + [](IRModule mod, 
ffi::String key) -> ObjectRef { return mod->GetAttr(key); }); }); } // namespace tvm diff --git a/src/ir/name_supply.cc b/src/ir/name_supply.cc index 24b5e72735a0..253812470313 100644 --- a/src/ir/name_supply.cc +++ b/src/ir/name_supply.cc @@ -30,13 +30,13 @@ namespace tvm { -NameSupply::NameSupply(const String& prefix, std::unordered_map name_map) { - auto n = make_object(prefix, std::move(name_map)); +NameSupply::NameSupply(const ffi::String& prefix, std::unordered_map name_map) { + auto n = ffi::make_object(prefix, std::move(name_map)); data_ = std::move(n); } -String NameSupplyNode::ReserveName(const String& name, bool add_prefix) { - String final_name = name; +ffi::String NameSupplyNode::ReserveName(const ffi::String& name, bool add_prefix) { + ffi::String final_name = name; if (add_prefix) { final_name = add_prefix_to_name(name); } @@ -44,8 +44,9 @@ String NameSupplyNode::ReserveName(const String& name, bool add_prefix) { return final_name; } -String NameSupplyNode::FreshName(const String& name, bool add_prefix, bool add_underscore) { - String unique_name = name; +ffi::String NameSupplyNode::FreshName(const ffi::String& name, bool add_prefix, + bool add_underscore) { + ffi::String unique_name = name; if (add_prefix) { unique_name = add_prefix_to_name(name); } @@ -53,8 +54,8 @@ String NameSupplyNode::FreshName(const String& name, bool add_prefix, bool add_u return unique_name; } -bool NameSupplyNode::ContainsName(const String& name, bool add_prefix) { - String unique_name = name; +bool NameSupplyNode::ContainsName(const ffi::String& name, bool add_prefix) { + ffi::String unique_name = name; if (add_prefix) { unique_name = add_prefix_to_name(name); } @@ -62,7 +63,7 @@ bool NameSupplyNode::ContainsName(const String& name, bool add_prefix) { return name_map.count(unique_name); } -String NameSupplyNode::add_prefix_to_name(const String& name) { +ffi::String NameSupplyNode::add_prefix_to_name(const ffi::String& name) { if (prefix_.empty()) { return name; } @@ -93,7 +94,7 @@ std::string NameSupplyNode::GetUniqueName(std::string name, bool add_underscore) TVM_FFI_STATIC_INIT_BLOCK({ namespace refl = tvm::ffi::reflection; refl::GlobalDef() - .def("ir.NameSupply", [](String prefix) { return NameSupply(prefix); }) + .def("ir.NameSupply", [](ffi::String prefix) { return NameSupply(prefix); }) .def_method("ir.NameSupply_FreshName", &NameSupplyNode::FreshName) .def_method("ir.NameSupply_ReserveName", &NameSupplyNode::ReserveName) .def_method("ir.NameSupply_ContainsName", &NameSupplyNode::ContainsName); diff --git a/src/ir/op.cc b/src/ir/op.cc index 1bb0e7007b28..a57fcea8e0a2 100644 --- a/src/ir/op.cc +++ b/src/ir/op.cc @@ -44,36 +44,38 @@ using tir::FLowerIntrinsic; using OpRegistry = AttrRegistry; // find operator by name -const Op& Op::Get(const String& name) { +const Op& Op::Get(const ffi::String& name) { const OpRegEntry* reg = OpRegistry::Global()->Get(name); ICHECK(reg != nullptr) << "AttributeError: Operator " << name << " is not registered"; return reg->op(); } OpRegEntry::OpRegEntry(uint32_t reg_index) { - ObjectPtr n = make_object(); + ObjectPtr n = ffi::make_object(); n->index_ = reg_index; op_ = Op(n); } -OpRegEntry& OpRegEntry::RegisterOrGet(const String& name) { +OpRegEntry& OpRegEntry::RegisterOrGet(const ffi::String& name) { return OpRegistry::Global()->RegisterOrGet(name); } // Get attribute map by key -const AttrRegistryMapContainerMap& Op::GetAttrMapContainer(const String& attr_name) { +const AttrRegistryMapContainerMap& Op::GetAttrMapContainer(const ffi::String& attr_name) { 
return OpRegistry::Global()->GetAttrMap(attr_name); } // Check if a key is present in the registry. -bool Op::HasAttrMap(const String& attr_name) { return OpRegistry::Global()->HasAttrMap(attr_name); } +bool Op::HasAttrMap(const ffi::String& attr_name) { + return OpRegistry::Global()->HasAttrMap(attr_name); +} // Resets attr of the OpAttrMap. void OpRegEntry::reset_attr(const std::string& attr_name) { OpRegistry::Global()->ResetAttr(attr_name, op_); } -void OpRegEntry::UpdateAttr(const String& key, ffi::Any value, int plevel) { +void OpRegEntry::UpdateAttr(const ffi::String& key, ffi::Any value, int plevel) { OpRegistry::Global()->UpdateAttr(key, op_, value, plevel); } @@ -82,9 +84,9 @@ TVM_FFI_STATIC_INIT_BLOCK({ namespace refl = tvm::ffi::reflection; refl::GlobalDef() .def("ir.ListOpNames", []() { return OpRegistry::Global()->ListAllNames(); }) - .def("ir.GetOp", [](String name) -> Op { return Op::Get(name); }) + .def("ir.GetOp", [](ffi::String name) -> Op { return Op::Get(name); }) .def("ir.OpGetAttr", - [](Op op, String attr_name) -> ffi::Any { + [](Op op, ffi::String attr_name) -> ffi::Any { auto op_map = Op::GetAttrMap(attr_name); ffi::Any rv; if (op_map.count(op)) { @@ -93,19 +95,19 @@ TVM_FFI_STATIC_INIT_BLOCK({ return rv; }) .def("ir.OpHasAttr", - [](Op op, String attr_name) -> bool { return Op::HasAttrMap(attr_name); }) + [](Op op, ffi::String attr_name) -> bool { return Op::HasAttrMap(attr_name); }) .def("ir.OpSetAttr", - [](Op op, String attr_name, ffi::AnyView value, int plevel) { + [](Op op, ffi::String attr_name, ffi::AnyView value, int plevel) { auto& reg = OpRegistry::Global()->RegisterOrGet(op->name).set_name(); reg.set_attr(attr_name, value, plevel); }) .def("ir.OpResetAttr", - [](Op op, String attr_name) { + [](Op op, ffi::String attr_name) { auto& reg = OpRegistry::Global()->RegisterOrGet(op->name); reg.reset_attr(attr_name); }) .def("ir.RegisterOp", - [](String op_name, String descr) { + [](ffi::String op_name, ffi::String descr) { const OpRegEntry* reg = OpRegistry::Global()->Get(op_name); ICHECK(reg == nullptr) << "AttributeError: Operator " << op_name << " is registered before"; @@ -113,7 +115,7 @@ TVM_FFI_STATIC_INIT_BLOCK({ op.describe(descr); }) .def("ir.OpAddArgument", - [](Op op, String name, String type, String description) { + [](Op op, ffi::String name, ffi::String type, ffi::String description) { auto& reg = OpRegistry::Global()->RegisterOrGet(op->name).set_name(); reg.add_argument(name, type, description); }) @@ -128,12 +130,12 @@ TVM_FFI_STATIC_INIT_BLOCK({ reg.set_num_inputs(n); }) .def("ir.OpSetAttrsTypeKey", - [](Op op, String key) { + [](Op op, ffi::String key) { auto& reg = OpRegistry::Global()->RegisterOrGet(op->name).set_name(); reg.set_attrs_type_key(key); }) .def("ir.RegisterOpAttr", - [](String op_name, String attr_key, ffi::AnyView value, int plevel) { + [](ffi::String op_name, ffi::String attr_key, ffi::AnyView value, int plevel) { auto& reg = OpRegistry::Global()->RegisterOrGet(op_name).set_name(); // enable registration and override of certain properties if (attr_key == "num_inputs" && plevel > 128) { @@ -145,18 +147,18 @@ TVM_FFI_STATIC_INIT_BLOCK({ } }) .def("ir.RegisterOpLowerIntrinsic", - [](String name, ffi::Function f, String target, int plevel) { + [](ffi::String name, ffi::Function f, ffi::String target, int plevel) { tvm::OpRegEntry::RegisterOrGet(name).set_attr( target + ".FLowerIntrinsic", f, plevel); }); // override OpNode to use name as the repr refl::TypeAttrDef() .def("__data_to_json__", - [](const OpNode* node) -> String {
+ [](const OpNode* node) -> ffi::String { // simply save as the string return node->name; }) - .def("__data_from_json__", [](const String& name) -> Op { return Op::Get(name); }); + .def("__data_from_json__", [](const ffi::String& name) -> Op { return Op::Get(name); }); }); TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) diff --git a/src/ir/replace_global_vars.cc b/src/ir/replace_global_vars.cc index 9887a111f958..13337dca36a6 100644 --- a/src/ir/replace_global_vars.cc +++ b/src/ir/replace_global_vars.cc @@ -31,7 +31,7 @@ namespace tvm { namespace transform { -IRModule ReplaceGlobalVars(IRModule mod, Map replacements) { +IRModule ReplaceGlobalVars(IRModule mod, ffi::Map replacements) { if (replacements.empty()) { return mod; } @@ -69,26 +69,30 @@ TVM_FFI_STATIC_INIT_BLOCK({ }); IRModule ModuleReplaceGlobalVars( - IRModule mod, Map, Variant> replacements) { - Map gvar_replacements; + IRModule mod, + ffi::Map, ffi::Variant> + replacements) { + ffi::Map gvar_replacements; for (const auto& [before, after] : replacements) { GlobalVar gvar_before; if (auto gvar = before.as()) { gvar_before = gvar.value(); - } else if (auto str = before.as()) { + } else if (auto str = before.as()) { gvar_before = mod->GetGlobalVar(str.value()); } else { - LOG(FATAL) << "Variant must contain either String or GlobalVar"; + LOG(FATAL) + << "ffi::Variant must contain either ffi::String or GlobalVar"; } GlobalVar gvar_after; if (auto gvar = after.as()) { gvar_after = gvar.value(); - } else if (auto str = after.as()) { + } else if (auto str = after.as()) { gvar_after = gvar_before; gvar_after.CopyOnWrite()->name_hint = str.value(); } else { - LOG(FATAL) << "Variant must contain either String or GlobalVar"; + LOG(FATAL) + << "ffi::Variant must contain either ffi::String or GlobalVar"; } gvar_replacements.Set(gvar_before, gvar_after); diff --git a/src/ir/source_map.cc b/src/ir/source_map.cc index 588efe9c6a4e..26fbe07cf6d3 100644 --- a/src/ir/source_map.cc +++ b/src/ir/source_map.cc @@ -46,14 +46,14 @@ TVM_FFI_STATIC_INIT_BLOCK({ .def("__data_from_json__", SourceName::Get); }); -ObjectPtr GetSourceNameNode(const String& name) { +ObjectPtr GetSourceNameNode(const ffi::String& name) { // always return pointer as the reference can change as map re-allocate. 
// or use another level of indirection by creating a unique_ptr - static std::unordered_map> source_map; + static std::unordered_map> source_map; auto sn = source_map.find(name); if (sn == source_map.end()) { - ObjectPtr n = make_object(); + ObjectPtr n = ffi::make_object(); source_map[name] = n; n->name = std::move(name); return n; @@ -66,7 +66,7 @@ ObjectPtr GetSourceNameNodeByStr(const std::string& name) { return GetSourceNameNode(name); } -SourceName SourceName::Get(const String& name) { return SourceName(GetSourceNameNode(name)); } +SourceName SourceName::Get(const ffi::String& name) { return SourceName(GetSourceNameNode(name)); } TVM_FFI_STATIC_INIT_BLOCK({ namespace refl = tvm::ffi::reflection; @@ -80,7 +80,7 @@ TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) }); Span::Span(SourceName source_name, int line, int end_line, int column, int end_column) { - auto n = make_object(); + auto n = ffi::make_object(); n->source_name = std::move(source_name); n->line = line; n->end_line = end_line; @@ -99,9 +99,9 @@ Span Span::Merge(const Span& other) const { std::max((*this)->end_column, other->end_column)); } -SequentialSpan::SequentialSpan(tvm::Array spans) { - auto n = make_object(); - tvm::Array tmp_spans; +SequentialSpan::SequentialSpan(tvm::ffi::Array spans) { + auto n = ffi::make_object(); + tvm::ffi::Array tmp_spans; for (const Span& s : spans) { if (const SequentialSpanNode* seq_s = s.as()) { tmp_spans.insert(tmp_spans.end(), seq_s->spans.begin(), seq_s->spans.end()); @@ -120,9 +120,9 @@ SequentialSpan::SequentialSpan(tvm::Array spans) { } SequentialSpan::SequentialSpan(std::initializer_list init) { - auto n = make_object(); - tvm::Array spans = tvm::Array(init); - tvm::Array tmp_spans; + auto n = ffi::make_object(); + tvm::ffi::Array spans = tvm::ffi::Array(init); + tvm::ffi::Array tmp_spans; for (const Span& s : spans) { if (const SequentialSpanNode* seq_s = s.as()) { tmp_spans.insert(tmp_spans.end(), seq_s->spans.begin(), seq_s->spans.end()); @@ -147,7 +147,7 @@ TVM_FFI_STATIC_INIT_BLOCK({ [](SourceName source_name, int line, int end_line, int column, int end_column) { return Span(source_name, line, end_line, column, end_column); }) - .def("ir.SequentialSpan", [](tvm::Array spans) { return SequentialSpan(spans); }); + .def("ir.SequentialSpan", [](tvm::ffi::Array spans) { return SequentialSpan(spans); }); }); TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) @@ -172,7 +172,7 @@ TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) /*! \brief Construct a source from a string. */ Source::Source(SourceName src_name, std::string source) { - auto n = make_object(); + auto n = ffi::make_object(); n->source_name = std::move(src_name); n->source = std::move(source); @@ -201,7 +201,7 @@ Source::Source(SourceName src_name, std::string source) { data_ = n; } -tvm::String Source::GetLine(int line) { +tvm::ffi::String Source::GetLine(int line) { VLOG(1) << "Source::GetLine: line=" << line; ICHECK(line - 1 < static_cast((*this)->line_map.size())) << "requested line: " << line << "at index: " << (line - 1) @@ -212,14 +212,14 @@ tvm::String Source::GetLine(int line) { int line_start = range.first; int line_length = range.second; VLOG(1) << "Source::GetLine: line_start=" << line_start << " line_length=" << line_length; - // TODO(@jroesch): expose substring on tvm::String. + // TODO(@jroesch): expose substring on tvm::ffi::String. 
auto line_text = std::string((*this)->source).substr(line_start, line_length); VLOG(1) << "Source::GetLine: line_text=" << line_text; return line_text; } -SourceMap::SourceMap(Map source_map) { - auto n = make_object(); +SourceMap::SourceMap(ffi::Map source_map) { + auto n = ffi::make_object(); n->source_map = std::move(source_map); data_ = std::move(n); } @@ -228,7 +228,7 @@ void SourceMap::Add(const Source& source) { (*this)->source_map.Set(source->sour TVM_FFI_STATIC_INIT_BLOCK({ namespace refl = tvm::ffi::reflection; - refl::GlobalDef().def("SourceMapAdd", [](SourceMap map, String name, String content) { + refl::GlobalDef().def("SourceMapAdd", [](SourceMap map, ffi::String name, ffi::String content) { auto src_name = SourceName::Get(name); Source source(src_name, content); map.Add(source); diff --git a/src/ir/transform.cc b/src/ir/transform.cc index d82f02f3dfb9..cd7349f1e489 100644 --- a/src/ir/transform.cc +++ b/src/ir/transform.cc @@ -54,7 +54,9 @@ struct PassContextThreadLocalEntry { /*! \brief The current pass context. */ std::stack context_stack; - PassContextThreadLocalEntry() { default_context = PassContext(make_object()); } + PassContextThreadLocalEntry() { + default_context = PassContext(ffi::make_object()); + } }; /*! \brief Thread local store to hold the pass context. */ @@ -86,7 +88,7 @@ PassContext PassContext::Current() { } // linearly scan the pass array to match pass_name -bool PassArrayContains(const Array& pass_array, const std::string& pass_name) { +bool PassArrayContains(const ffi::Array& pass_array, const std::string& pass_name) { for (auto x : pass_array) { if (x == pass_name) return true; } @@ -107,7 +109,7 @@ bool PassContext::PassEnabled(const PassInfo& info) const { class PassConfigManager { public: - void Register(std::string key, String value_type_str, + void Register(std::string key, ffi::String value_type_str, std::function legalization) { ICHECK_EQ(key2vtype_.count(key), 0U); ValueTypeInfo info; @@ -117,7 +119,7 @@ class PassConfigManager { } // Trying to validate and legalize a config. 
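// [Illustrative sketch, not part of the patch] Legalize() walks the
// ffi::Map-based config and checks each key against the value type recorded by
// Register(). A hypothetical option that would flow through this path:
//
//   TVM_REGISTER_PASS_CONFIG_OPTION("example.my_flag", Bool);
//   // a config {"example.my_flag": true} is then type-checked and legalized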
- void Legalize(Map* config) { + void Legalize(ffi::Map* config) { std::vector> update; for (auto [key, value] : *config) { auto it = key2vtype_.find(key); @@ -149,10 +151,10 @@ class PassConfigManager { } } - Map> ListConfigs() { - Map> configs; + ffi::Map> ListConfigs() { + ffi::Map> configs; for (const auto& kv : key2vtype_) { - Map metadata; + ffi::Map metadata; metadata.Set("type", kv.second.type_str); configs.Set(kv.first, metadata); } @@ -173,20 +175,20 @@ class PassConfigManager { std::unordered_map key2vtype_; }; -void PassContext::RegisterConfigOption(const char* key, String value_type_str, +void PassContext::RegisterConfigOption(const char* key, ffi::String value_type_str, std::function legalization) { PassConfigManager::Global()->Register(key, value_type_str, legalization); } -Map> PassContext::ListConfigs() { +ffi::Map> PassContext::ListConfigs() { return PassConfigManager::Global()->ListConfigs(); } -PassContext PassContext::Create() { return PassContext(make_object()); } +PassContext PassContext::Create() { return PassContext(ffi::make_object()); } namespace { struct ClearOnError { - Array* instruments{nullptr}; + ffi::Array* instruments{nullptr}; ~ClearOnError() { if (instruments) { @@ -244,7 +246,7 @@ struct ExitPassSuccesses { bool all_initialized{false}; std::vector successes; - Array* instruments{nullptr}; + ffi::Array* instruments{nullptr}; }; } // namespace @@ -378,8 +380,9 @@ class ModulePass : public Pass { TVM_DEFINE_OBJECT_REF_METHODS(ModulePass, Pass, ModulePassNode); }; -PassInfo::PassInfo(int opt_level, String name, tvm::Array required, bool traceable) { - auto pass_info = make_object(); +PassInfo::PassInfo(int opt_level, ffi::String name, tvm::ffi::Array required, + bool traceable) { + auto pass_info = ffi::make_object(); pass_info->opt_level = opt_level; pass_info->name = std::move(name); pass_info->required = std::move(required); @@ -389,7 +392,7 @@ PassInfo::PassInfo(int opt_level, String name, tvm::Array required, bool ModulePass::ModulePass(std::function pass_func, PassInfo pass_info) { - auto n = make_object(); + auto n = ffi::make_object(); n->pass_func = std::move(pass_func); n->pass_info = std::move(pass_info); data_ = std::move(n); @@ -429,15 +432,15 @@ IRModule ModulePassNode::operator()(IRModule mod, const PassContext& pass_ctx) c return mod; } -Sequential::Sequential(tvm::Array passes, PassInfo pass_info) { - auto n = make_object(); +Sequential::Sequential(tvm::ffi::Array passes, PassInfo pass_info) { + auto n = ffi::make_object(); n->passes = std::move(passes); n->pass_info = std::move(pass_info); data_ = std::move(n); } -Sequential::Sequential(tvm::Array passes, String name) { - auto n = make_object(); +Sequential::Sequential(tvm::ffi::Array passes, ffi::String name) { + auto n = ffi::make_object(); n->passes = std::move(passes); PassInfo pass_info = PassInfo(0, std::move(name), {}, /* traceable */ false); n->pass_info = std::move(pass_info); @@ -457,7 +460,7 @@ void SequentialNode::ResolveDependency(const IRModule& mod) { << "\n"; } -Pass GetPass(const String& pass_name) { +Pass GetPass(const ffi::String& pass_name) { std::optional f; if (pass_name.operator std::string().find("transform.") != std::string::npos) { f = tvm::ffi::Function::GetGlobal(pass_name); @@ -492,7 +495,7 @@ IRModule SequentialNode::operator()(IRModule mod, const PassContext& pass_ctx) c } Pass CreateModulePass(std::function pass_func, int opt_level, - String name, tvm::Array required, bool traceable) { + ffi::String name, tvm::ffi::Array required, bool traceable) { PassInfo 
pass_info = PassInfo(opt_level, name, required, traceable); return ModulePass(std::move(pass_func), pass_info); } @@ -501,9 +504,8 @@ TVM_FFI_STATIC_INIT_BLOCK({ namespace refl = tvm::ffi::reflection; refl::GlobalDef() .def("transform.PassInfo", - [](int opt_level, String name, tvm::Array required, bool traceable) { - return PassInfo(opt_level, name, required, traceable); - }) + [](int opt_level, ffi::String name, tvm::ffi::Array required, + bool traceable) { return PassInfo(opt_level, name, required, traceable); }) .def_packed("transform.Info", [](ffi::PackedArgs args, ffi::Any* ret) { Pass pass = args[0].cast(); *ret = pass->Info(); @@ -561,10 +563,10 @@ TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) TVM_FFI_STATIC_INIT_BLOCK({ namespace refl = tvm::ffi::reflection; refl::GlobalDef().def_packed("transform.Sequential", [](ffi::PackedArgs args, ffi::Any* ret) { - auto passes = args[0].cast>(); + auto passes = args[0].cast>(); int opt_level = args[1].cast(); std::string name = args[2].cast(); - auto required = args[3].cast>(); + auto required = args[3].cast>(); bool traceable = args[4].cast(); PassInfo pass_info = PassInfo(opt_level, name, required, /* traceable */ traceable); *ret = Sequential(passes, pass_info); @@ -589,8 +591,9 @@ TVM_FFI_STATIC_INIT_BLOCK({ namespace refl = tvm::ffi::reflection; refl::GlobalDef().def( "transform.PassContext", - [](int opt_level, Array required, Array disabled, - Array instruments, Optional> config) { + [](int opt_level, ffi::Array required, ffi::Array disabled, + ffi::Array instruments, + ffi::Optional> config) { auto pctx = PassContext::Create(); pctx->opt_level = opt_level; @@ -634,14 +637,14 @@ TVM_FFI_STATIC_INIT_BLOCK({ .def("transform.EnterPassContext", PassContext::Internal::EnterScope) .def("transform.ExitPassContext", PassContext::Internal::ExitScope) .def("transform.OverrideInstruments", - [](PassContext pass_ctx, Array instruments) { + [](PassContext pass_ctx, ffi::Array instruments) { pass_ctx.InstrumentExitPassContext(); pass_ctx->instruments = instruments; pass_ctx.InstrumentEnterPassContext(); }); }); -Pass PrintIR(String header, bool show_meta_data) { +Pass PrintIR(ffi::String header, bool show_meta_data) { auto pass_func = [header, show_meta_data](IRModule mod, const PassContext& ctx) { LOG(INFO) << "PrintIR(" << header << "):\n" << mod; return mod; diff --git a/src/ir/type.cc b/src/ir/type.cc index 4afa785aaedd..dc2bfb984b22 100644 --- a/src/ir/type.cc +++ b/src/ir/type.cc @@ -36,7 +36,7 @@ TVM_FFI_STATIC_INIT_BLOCK({ }); PrimType::PrimType(runtime::DataType dtype, Span span) { - ObjectPtr n = make_object(); + ObjectPtr n = ffi::make_object(); n->dtype = dtype; n->span = std::move(span); data_ = std::move(n); @@ -47,8 +47,8 @@ TVM_FFI_STATIC_INIT_BLOCK({ refl::GlobalDef().def("ir.PrimType", [](runtime::DataType dtype) { return PrimType(dtype); }); }); -PointerType::PointerType(Type element_type, String storage_scope) { - ObjectPtr n = make_object(); +PointerType::PointerType(Type element_type, ffi::String storage_scope) { + ObjectPtr n = ffi::make_object(); if (storage_scope.empty()) { n->storage_scope = "global"; } else { @@ -60,13 +60,13 @@ PointerType::PointerType(Type element_type, String storage_scope) { TVM_FFI_STATIC_INIT_BLOCK({ namespace refl = tvm::ffi::reflection; - refl::GlobalDef().def("ir.PointerType", [](Type element_type, String storage_scope = "") { + refl::GlobalDef().def("ir.PointerType", [](Type element_type, ffi::String storage_scope = "") { return PointerType(element_type, storage_scope); }); }); 
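// [Illustrative sketch, not part of the patch] The PointerType constructor
// above normalizes an empty storage scope to "global", so the reflection
// default of "" behaves like an explicit "global". Hypothetical check:
//
//   PointerType ptr(PrimType(runtime::DataType::Float(32)), "");
//   ICHECK_EQ(std::string(ptr->storage_scope), "global");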
-FuncType::FuncType(tvm::Array arg_types, Type ret_type, Span span) { - ObjectPtr n = make_object(); +FuncType::FuncType(tvm::ffi::Array arg_types, Type ret_type, Span span) { + ObjectPtr n = ffi::make_object(); n->arg_types = std::move(arg_types); n->ret_type = std::move(ret_type); n->span = std::move(span); @@ -75,29 +75,29 @@ FuncType::FuncType(tvm::Array arg_types, Type ret_type, Span span) { TVM_FFI_STATIC_INIT_BLOCK({ namespace refl = tvm::ffi::reflection; - refl::GlobalDef().def("ir.FuncType", [](tvm::Array arg_types, Type ret_type) { + refl::GlobalDef().def("ir.FuncType", [](tvm::ffi::Array arg_types, Type ret_type) { return FuncType(arg_types, ret_type); }); }); -TupleType::TupleType(Array fields, Span span) { - ObjectPtr n = make_object(); +TupleType::TupleType(ffi::Array fields, Span span) { + ObjectPtr n = ffi::make_object(); n->fields = std::move(fields); n->span = std::move(span); data_ = std::move(n); } -TupleType TupleType::Empty() { return TupleType(Array()); } +TupleType TupleType::Empty() { return TupleType(ffi::Array()); } TVM_FFI_STATIC_INIT_BLOCK({ namespace refl = tvm::ffi::reflection; refl::GlobalDef() - .def("ir.TupleType", [](Array fields) { return TupleType(fields); }) + .def("ir.TupleType", [](ffi::Array fields) { return TupleType(fields); }) .def("ir.TensorMapType", [](Span span) { return TensorMapType(span); }); }); TensorMapType::TensorMapType(Span span) { - ObjectPtr n = make_object(); + ObjectPtr n = ffi::make_object(); n->span = std::move(span); data_ = std::move(n); } diff --git a/src/ir/type_functor.cc b/src/ir/type_functor.cc index 774c9d8f245f..3c81ca107eab 100644 --- a/src/ir/type_functor.cc +++ b/src/ir/type_functor.cc @@ -49,7 +49,7 @@ Type TypeMutator::VisitType(const Type& t) { } // Type Mutator. -Array TypeMutator::MutateArray(Array arr) { +ffi::Array TypeMutator::MutateArray(ffi::Array arr) { // The array will do copy on write // If no changes are made, the original array will be returned. 
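// [Illustrative note, not part of the patch] ffi::Array::Map is copy-on-write:
// if every mapped element is the same object as its input, the original
// container is returned unchanged. Hypothetical check:
//
//   ffi::Array<Type> types = {PrimType(runtime::DataType::Int(32))};
//   auto mapped = types.Map([](const Type& t) { return t; });
//   ICHECK(mapped.same_as(types));  // identity map triggers no copy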
return arr.Map([this](const Type& ty) { return VisitType(ty); }); @@ -58,32 +58,32 @@ Array TypeMutator::MutateArray(Array arr) { Type TypeMutator::VisitType_(const FuncTypeNode* op) { bool changed = false; - Array new_args = MutateArray(op->arg_types); + ffi::Array new_args = MutateArray(op->arg_types); changed = changed || !new_args.same_as(op->arg_types); Type new_ret_type = VisitType(op->ret_type); changed = changed || !new_ret_type.same_as(op->ret_type); - if (!changed) return GetRef(op); + if (!changed) return ffi::GetRef(op); return FuncType(new_args, new_ret_type); } Type TypeMutator::VisitType_(const TupleTypeNode* op) { - Array new_fields = MutateArray(op->fields); + ffi::Array new_fields = MutateArray(op->fields); if (new_fields.same_as(op->fields)) { - return GetRef(op); + return ffi::GetRef(op); } else { return TupleType(new_fields); } } -Type TypeMutator::VisitType_(const PrimTypeNode* op) { return GetRef(op); } +Type TypeMutator::VisitType_(const PrimTypeNode* op) { return ffi::GetRef(op); } Type TypeMutator::VisitType_(const PointerTypeNode* op) { Type element_type = VisitType(op->element_type); if (element_type.same_as(op->element_type)) { - return GetRef(op); + return ffi::GetRef(op); } else { return PointerType(element_type, op->storage_scope); } diff --git a/src/meta_schedule/arg_info.cc b/src/meta_schedule/arg_info.cc index 12c6e29eb295..44fa338fefa1 100644 --- a/src/meta_schedule/arg_info.cc +++ b/src/meta_schedule/arg_info.cc @@ -40,7 +40,7 @@ inline tir::PrimFunc FindEntryFunc(const IRModule& mod) { if (const auto* func = base_func.as()) { last_func = func; if (func->HasNonzeroAttr(tir::attr::kIsEntryFunc)) { - return GetRef(func); + return ffi::GetRef(func); } if (gv->name_hint == "main") { main_func = func; @@ -50,7 +50,7 @@ inline tir::PrimFunc FindEntryFunc(const IRModule& mod) { } // Priority 2: PrimFunc whose name is `main` if (main_func != nullptr) { - return GetRef(main_func); + return ffi::GetRef(main_func); } // Priority 3: The only PrimFunc in the IRModule if (num_prim_func == 0) { @@ -61,7 +61,7 @@ inline tir::PrimFunc FindEntryFunc(const IRModule& mod) { "annotated with `kIsEntryFunc`, i.e. `tir.is_entry_func`" << mod; } - return GetRef(last_func); + return ffi::GetRef(last_func); } /******** ArgInfo ********/ @@ -69,11 +69,11 @@ ArgInfo ArgInfo::FromJSON(const ObjectRef& json_obj) { // The JSON object is always an array whose first element is a tag. For example: // `['TENSOR', 'float32', [1, 224, 224, 3]] // Step 1. 
Extract the tag - Optional tag{std::nullopt}; + ffi::Optional tag{std::nullopt}; try { const ffi::ArrayObj* json_array = json_obj.as(); CHECK(json_array && json_array->size() >= 1); - tag = json_array->at(0).cast(); + tag = json_array->at(0).cast(); } catch (const std::runtime_error& e) { // includes tvm::Error and dmlc::Error LOG(FATAL) << "ValueError: Unable to parse the JSON object: " << json_obj << "\nThe error is: " << e.what(); @@ -86,12 +86,12 @@ ArgInfo ArgInfo::FromJSON(const ObjectRef& json_obj) { throw; } -Array ArgInfo::FromPrimFunc(const tir::PrimFunc& func) { +ffi::Array ArgInfo::FromPrimFunc(const tir::PrimFunc& func) { using support::AsVector; - Array result; + ffi::Array result; result.reserve(func->params.size()); for (const tir::Var& arg : func->params) { - if (Optional _buffer = func->buffer_map.Get(arg)) { + if (ffi::Optional _buffer = func->buffer_map.Get(arg)) { tir::Buffer buffer = _buffer.value(); result.push_back(TensorInfo(/*dtype=*/buffer->dtype, /*shape=*/AsVector(buffer->shape))); @@ -102,7 +102,7 @@ Array ArgInfo::FromPrimFunc(const tir::PrimFunc& func) { return result; } -Array ArgInfo::FromEntryFunc(const IRModule& mod, bool remove_preproc) { +ffi::Array ArgInfo::FromEntryFunc(const IRModule& mod, bool remove_preproc) { if (remove_preproc) { IRModule new_mod = tir::transform::RemoveWeightLayoutRewriteBlock(/*skip_tensor_rewrite*/ true)(mod); @@ -114,28 +114,28 @@ Array ArgInfo::FromEntryFunc(const IRModule& mod, bool remove_preproc) /******** TensorInfo ********/ TensorInfo::TensorInfo(runtime::DataType dtype, ffi::Shape shape) { - ObjectPtr n = make_object(); + ObjectPtr n = ffi::make_object(); n->dtype = dtype; n->shape = shape; this->data_ = std::move(n); } ObjectRef TensorInfoNode::AsJSON() const { - static String tag = "TENSOR"; - String dtype = DLDataTypeToString(this->dtype); - Array shape = support::AsArray(this->shape); - return Array{tag, dtype, shape}; + static ffi::String tag = "TENSOR"; + ffi::String dtype = DLDataTypeToString(this->dtype); + ffi::Array shape = support::AsArray(this->shape); + return ffi::Array{tag, dtype, shape}; } TensorInfo TensorInfo::FromJSON(const ObjectRef& json_obj) { DLDataType dtype; - Array shape; + ffi::Array shape; try { const ffi::ArrayObj* json_array = json_obj.as(); CHECK(json_array && json_array->size() == 3); // Load json[1] => dtype { - String dtype_str = json_array->at(1).cast(); + ffi::String dtype_str = json_array->at(1).cast(); dtype = StringToDLDataType(dtype_str); } // Load json[2] => shape diff --git a/src/meta_schedule/builder/builder.cc b/src/meta_schedule/builder/builder.cc index 5657a362acce..c4822f41971c 100644 --- a/src/meta_schedule/builder/builder.cc +++ b/src/meta_schedule/builder/builder.cc @@ -26,23 +26,24 @@ namespace meta_schedule { /******** Constructors ********/ BuilderInput::BuilderInput(IRModule mod, Target target, - Optional> params) { - ObjectPtr n = make_object(); + ffi::Optional> params) { + ObjectPtr n = ffi::make_object(); n->mod = std::move(mod); n->target = std::move(target); n->params = std::move(params); data_ = std::move(n); } -BuilderResult::BuilderResult(Optional artifact_path, Optional error_msg) { - ObjectPtr n = make_object(); +BuilderResult::BuilderResult(ffi::Optional artifact_path, + ffi::Optional error_msg) { + ObjectPtr n = ffi::make_object(); n->artifact_path = std::move(artifact_path); n->error_msg = std::move(error_msg); data_ = std::move(n); } Builder Builder::PyBuilder(BuilderNode::FBuild f_build) { - ObjectPtr n = make_object(); + ObjectPtr n = 
ffi::make_object(); n->f_build = std::move(f_build); return Builder(std::move(n)); } @@ -59,12 +60,13 @@ TVM_FFI_STATIC_INIT_BLOCK({ namespace refl = tvm::ffi::reflection; refl::GlobalDef() .def("meta_schedule.BuilderInput", - [](IRModule mod, Target target, Optional> params) - -> BuilderInput { return BuilderInput(mod, target, params); }) - .def("meta_schedule.BuilderResult", - [](Optional artifact_path, Optional error_msg) -> BuilderResult { - return BuilderResult(artifact_path, error_msg); + [](IRModule mod, Target target, + ffi::Optional> params) -> BuilderInput { + return BuilderInput(mod, target, params); }) + .def("meta_schedule.BuilderResult", + [](ffi::Optional artifact_path, ffi::Optional error_msg) + -> BuilderResult { return BuilderResult(artifact_path, error_msg); }) .def_method("meta_schedule.BuilderBuild", &BuilderNode::Build) .def("meta_schedule.BuilderPyBuilder", Builder::PyBuilder); }); diff --git a/src/meta_schedule/cost_model/cost_model.cc b/src/meta_schedule/cost_model/cost_model.cc index 242939802885..dddb798af2fe 100644 --- a/src/meta_schedule/cost_model/cost_model.cc +++ b/src/meta_schedule/cost_model/cost_model.cc @@ -23,24 +23,25 @@ namespace tvm { namespace meta_schedule { -void PyCostModelNode::Load(const String& path) { +void PyCostModelNode::Load(const ffi::String& path) { ICHECK(f_load != nullptr) << "PyCostModel's Load method not implemented!"; f_load(path); } -void PyCostModelNode::Save(const String& path) { +void PyCostModelNode::Save(const ffi::String& path) { ICHECK(f_save != nullptr) << "PyCostModel's Save method not implemented!"; f_save(path); } -void PyCostModelNode::Update(const TuneContext& context, const Array& candidates, - const Array& results) { +void PyCostModelNode::Update(const TuneContext& context, + const ffi::Array& candidates, + const ffi::Array& results) { ICHECK(f_update != nullptr) << "PyCostModel's Update method not implemented!"; f_update(context, candidates, results); } std::vector PyCostModelNode::Predict(const TuneContext& context, - const Array& candidates) { + const ffi::Array& candidates) { ICHECK(f_predict != nullptr) << "PyCostModel's Predict method not implemented!"; std::vector result(candidates.size(), 0.0); f_predict(context, candidates, result.data()); @@ -52,7 +53,7 @@ CostModel CostModel::PyCostModel(PyCostModelNode::FLoad f_load, // PyCostModelNode::FUpdate f_update, // PyCostModelNode::FPredict f_predict, // PyCostModelNode::FAsString f_as_string) { - ObjectPtr n = make_object(); + ObjectPtr n = ffi::make_object(); n->f_load = std::move(f_load); n->f_save = std::move(f_save); n->f_update = std::move(f_update); @@ -77,9 +78,9 @@ TVM_FFI_STATIC_INIT_BLOCK({ .def_method("meta_schedule.CostModelSave", &CostModelNode::Save) .def_method("meta_schedule.CostModelUpdate", &CostModelNode::Update) .def("meta_schedule.CostModelPredict", - [](CostModel model, // - const TuneContext& context, // - Array candidates, // + [](CostModel model, // + const TuneContext& context, // + ffi::Array candidates, // void* p_addr) -> void { std::vector result = model->Predict(context, candidates); std::copy(result.begin(), result.end(), static_cast(p_addr)); diff --git a/src/meta_schedule/database/database.cc b/src/meta_schedule/database/database.cc index 3b96ed0ca8b0..b3c02607bddc 100644 --- a/src/meta_schedule/database/database.cc +++ b/src/meta_schedule/database/database.cc @@ -46,7 +46,7 @@ ObjectRef WorkloadNode::AsJSON() const { // Dump the JSON string to base64 std::string b64_mod = Base64Encode(json_mod); // Output - return 
Array{SHash2Str(this->shash), String(b64_mod)}; + return ffi::Array{SHash2Str(this->shash), ffi::String(b64_mod)}; } Workload Workload::FromJSON(const ObjectRef& json_obj) { @@ -56,10 +56,10 @@ Workload Workload::FromJSON(const ObjectRef& json_obj) { const ffi::ArrayObj* json_array = json_obj.as(); CHECK(json_array && json_array->size() == 2); // Load json[0] => shash - String str_shash = json_array->at(0).cast(); + ffi::String str_shash = json_array->at(0).cast(); // Load json[1] => mod { - String b64_mod = json_array->at(1).cast(); + ffi::String b64_mod = json_array->at(1).cast(); std::string json_mod = Base64Decode(b64_mod); mod = LoadJSON(json_mod).cast(); std::stringstream(str_shash) >> shash; @@ -73,9 +73,11 @@ Workload Workload::FromJSON(const ObjectRef& json_obj) { /******** TuningRecord ********/ -TuningRecord::TuningRecord(tir::Trace trace, Workload workload, Optional> run_secs, - Optional target, Optional> args_info) { - ObjectPtr n = make_object(); +TuningRecord::TuningRecord(tir::Trace trace, Workload workload, + ffi::Optional> run_secs, + ffi::Optional target, + ffi::Optional> args_info) { + ObjectPtr n = ffi::make_object(); n->trace = trace; n->workload = workload; n->run_secs = run_secs; @@ -96,10 +98,10 @@ MeasureCandidate TuningRecordNode::AsMeasureCandidate() const { } ObjectRef TuningRecordNode::AsJSON() const { - Optional> json_args_info; - Optional json_target; + ffi::Optional> json_args_info; + ffi::Optional json_target; if (args_info.defined()) { - Array info; + ffi::Array info; info.reserve(args_info.value().size()); for (const ArgInfo& arg_info : args_info.value()) { info.push_back(arg_info->AsJSON()); @@ -109,10 +111,10 @@ ObjectRef TuningRecordNode::AsJSON() const { if (target.defined()) { json_target = target.value()->Export(); } - return Array{trace->AsJSON(false), // - run_secs, // - json_target, // - json_args_info}; + return ffi::Array{trace->AsJSON(false), // + run_secs, // + json_target, // + json_args_info}; } bool TuningRecordNode::IsValid() const { @@ -132,9 +134,9 @@ bool TuningRecordNode::IsValid() const { TuningRecord TuningRecord::FromJSON(const ObjectRef& json_obj, const Workload& workload) { tir::Trace trace{nullptr}; - Optional> run_secs; - Optional target; - Optional> args_info; + ffi::Optional> run_secs; + ffi::Optional target; + ffi::Optional> args_info; try { const ffi::ArrayObj* json_array = json_obj.as(); CHECK(json_array && json_array->size() == 4); @@ -144,12 +146,12 @@ TuningRecord TuningRecord::FromJSON(const ObjectRef& json_obj, const Workload& w } // Load json[2] => target if (json_array->at(2) != nullptr) { - target = Target(json_array->at(2).cast>()); + target = Target(json_array->at(2).cast>()); } // Load json[3] => args_info if (json_array->at(3) != nullptr) { const ffi::ArrayObj* json_args_info = json_array->at(3).cast(); - Array info; + ffi::Array info; info.reserve(json_args_info->size()); for (Any json_arg_info : *json_args_info) { info.push_back(ArgInfo::FromJSON(json_arg_info.cast())); @@ -173,15 +175,18 @@ TuningRecord TuningRecord::FromJSON(const ObjectRef& json_obj, const Workload& w } /******** Database ********/ -DatabaseNode::DatabaseNode(String mod_eq_name) { mod_eq_ = ModuleEquality::Create(mod_eq_name); } +DatabaseNode::DatabaseNode(ffi::String mod_eq_name) { + mod_eq_ = ModuleEquality::Create(mod_eq_name); +} DatabaseNode::~DatabaseNode() = default; -Optional DatabaseNode::QueryTuningRecord(const IRModule& mod, const Target& target, - const String& workload_name) { +ffi::Optional 
DatabaseNode::QueryTuningRecord(const IRModule& mod, + const Target& target, + const ffi::String& workload_name) { if (!this->HasWorkload(mod)) { return std::nullopt; } - Array records = this->GetTopK(this->CommitWorkload(mod), 1); + ffi::Array records = this->GetTopK(this->CommitWorkload(mod), 1); if (records.empty()) { return std::nullopt; } @@ -189,9 +194,10 @@ Optional DatabaseNode::QueryTuningRecord(const IRModule& mod, cons return records[0]; } -Optional DatabaseNode::QuerySchedule(const IRModule& mod, const Target& target, - const String& workload_name) { - if (Optional opt_record = this->QueryTuningRecord(mod, target, workload_name)) { +ffi::Optional DatabaseNode::QuerySchedule(const IRModule& mod, const Target& target, + const ffi::String& workload_name) { + if (ffi::Optional opt_record = + this->QueryTuningRecord(mod, target, workload_name)) { TuningRecord record = opt_record.value(); tir::Schedule sch = tir::Schedule::Traced(record->workload->mod, /*seed=*/-1, /*debug_mask=*/0, @@ -203,9 +209,9 @@ Optional DatabaseNode::QuerySchedule(const IRModule& mod, const T } } -Optional DatabaseNode::QueryIRModule(const IRModule& mod, const Target& target, - const String& workload_name) { - if (Optional opt_sch = this->QuerySchedule(mod, target, workload_name)) { +ffi::Optional DatabaseNode::QueryIRModule(const IRModule& mod, const Target& target, + const ffi::String& workload_name) { + if (ffi::Optional opt_sch = this->QuerySchedule(mod, target, workload_name)) { return opt_sch.value()->mod(); } else { return std::nullopt; @@ -244,7 +250,7 @@ void Database::EnterWithScope() { ThreadLocalDatabases()->push_back(*this); } void Database::ExitWithScope() { ThreadLocalDatabases()->pop_back(); } -Optional Database::Current() { +ffi::Optional Database::Current() { std::vector* tls = ThreadLocalDatabases(); if (tls->empty()) { return std::nullopt; @@ -254,7 +260,7 @@ Optional Database::Current() { } /******** PyDatabase ********/ -PyDatabaseNode::PyDatabaseNode(String mod_eq_name) : DatabaseNode(mod_eq_name) {} +PyDatabaseNode::PyDatabaseNode(ffi::String mod_eq_name) : DatabaseNode(mod_eq_name) {} Database Database::PyDatabase(PyDatabaseNode::FHasWorkload f_has_workload, PyDatabaseNode::FCommitWorkload f_commit_workload, @@ -264,8 +270,8 @@ Database Database::PyDatabase(PyDatabaseNode::FHasWorkload f_has_workload, PyDatabaseNode::FQueryTuningRecord f_query_tuning_record, PyDatabaseNode::FQuerySchedule f_query_schedule, PyDatabaseNode::FQueryIRModule f_query_ir_module, - PyDatabaseNode::FSize f_size, String mod_eq_name) { - ObjectPtr n = make_object(mod_eq_name); + PyDatabaseNode::FSize f_size, ffi::String mod_eq_name) { + ObjectPtr n = ffi::make_object(mod_eq_name); n->f_has_workload = f_has_workload; n->f_commit_workload = f_commit_workload; n->f_commit_tuning_record = f_commit_tuning_record; @@ -293,8 +299,8 @@ TVM_FFI_STATIC_INIT_BLOCK({ .def_method("meta_schedule.WorkloadAsJSON", &WorkloadNode::AsJSON) .def("meta_schedule.WorkloadFromJSON", &Workload::FromJSON) .def("meta_schedule.TuningRecord", - [](tir::Trace trace, Workload workload, Optional> run_secs, - Optional target, Optional> args_info) { + [](tir::Trace trace, Workload workload, ffi::Optional> run_secs, + ffi::Optional target, ffi::Optional> args_info) { return TuningRecord(trace, workload, run_secs, target, args_info); }) .def_method("meta_schedule.TuningRecordAsMeasureCandidate", diff --git a/src/meta_schedule/database/database_utils.cc b/src/meta_schedule/database/database_utils.cc index fd24072aae8f..10274fd2f792 100644 --- 
a/src/meta_schedule/database/database_utils.cc +++ b/src/meta_schedule/database/database_utils.cc @@ -57,10 +57,10 @@ void JSONDumps(Any json_obj, std::ostringstream& os) { os << "]"; } else if (const auto* dict = json_obj.as()) { int n = dict->size(); - std::vector> key_values; + std::vector> key_values; key_values.reserve(n); for (const auto& kv : *dict) { - if (auto key = kv.first.try_cast()) { + if (auto key = kv.first.try_cast()) { key_values.emplace_back(key.value(), kv.second); } else { LOG(FATAL) << "TypeError: Only string keys are supported in JSON dumps, but got: " @@ -81,7 +81,7 @@ } os << "}"; } else if (json_obj.as()) { - JSONDumps(String(SaveJSON(json_obj)), os); + JSONDumps(ffi::String(SaveJSON(json_obj)), os); } else { LOG(FATAL) << "TypeError: Unsupported type in JSON object: " << json_obj.GetTypeKey(); } @@ -241,7 +241,7 @@ class JSONTokenizer { LOG(FATAL) << "ValueError: Unexpected end of string"; } ++cur_; - *token = Token{TokenType::kString, String(str)}; + *token = Token{TokenType::kString, ffi::String(str)}; return true; } @@ -315,9 +315,9 @@ class JSONParser { } } - Array ParseArray() { + ffi::Array ParseArray() { bool is_first = true; - Array results; + ffi::Array results; for (;;) { Token token; if (is_first) { @@ -347,9 +347,9 @@ class JSONParser { return results; } - Map ParseDict() { + ffi::Map ParseDict() { bool is_first = true; - Map results; + ffi::Map results; for (;;) { Token token; if (is_first) { @@ -376,7 +376,7 @@ CHECK(token.type == TokenType::kColon) << "ValueError: Unexpected token before: " << tokenizer_.cur_; Any value = ParseObject(tokenizer_.Next()); - results.Set(Downcast(key), value); + results.Set(Downcast(key), value); continue; } else { LOG(FATAL) << "ValueError: Unexpected token before: " << tokenizer_.cur_; diff --git a/src/meta_schedule/database/json_database.cc b/src/meta_schedule/database/json_database.cc index aeae22f4ca41..cef4b6437ba2 100644 --- a/src/meta_schedule/database/json_database.cc +++ b/src/meta_schedule/database/json_database.cc @@ -35,10 +35,10 @@ namespace meta_schedule { * \param allow_missing Whether to create new file when the given path is not found. * \return An array containing lines read from the json file. */ -std::vector JSONFileReadLines(const String& path, int num_threads, bool allow_missing) { +std::vector JSONFileReadLines(const ffi::String& path, int num_threads, bool allow_missing) { std::ifstream is(path); if (is.good()) { - std::vector json_strs; + std::vector json_strs; for (std::string str; std::getline(is, str);) { json_strs.push_back(str); } @@ -61,7 +61,7 @@ * \param path The path to the json file. * \param line The line to append. */ -void JSONFileAppendLine(const String& path, const std::string& line) { +void JSONFileAppendLine(const ffi::String& path, const std::string& line) { std::ofstream os(path, std::ofstream::app); CHECK(os.good()) << "ValueError: Cannot open the file to write: " << path; os << line << std::endl; @@ -70,14 +70,14 @@ /*! \brief The default database implementation, which mimics two database tables with two files.
*/ class JSONDatabaseNode : public DatabaseNode { public: - explicit JSONDatabaseNode(String mod_eq_name = "structural") + explicit JSONDatabaseNode(ffi::String mod_eq_name = "structural") : DatabaseNode(mod_eq_name), workloads2idx_(/*bucket_count*/ 0, WorkloadHash(), WorkloadEqual(GetModuleEquality())) {} /*! \brief The path to the workload table */ - String path_workload; + ffi::String path_workload; /*! \brief The path to the tuning record table */ - String path_tuning_record; + ffi::String path_tuning_record; /*! \brief All the workloads in the database */ std::unordered_map workloads2idx_; /*! \brief All the tuning records in the database */ @@ -115,18 +115,18 @@ class JSONDatabaseNode : public DatabaseNode { void CommitTuningRecord(const TuningRecord& record) { this->tuning_records_.insert(record); JSONFileAppendLine(this->path_tuning_record, - JSONDumps(Array{ + JSONDumps(ffi::Array{ /*workload_index=*/Integer(this->workloads2idx_.at(record->workload)), /*tuning_record=*/record->AsJSON() // })); } - Array GetTopK(const Workload& workload, int top_k) { + ffi::Array GetTopK(const Workload& workload, int top_k) { CHECK_GE(top_k, 0) << "ValueError: top_k must be non-negative"; if (top_k == 0) { return {}; } - Array results; + ffi::Array results; results.reserve(top_k); for (const TuningRecord& record : this->tuning_records_) { auto run_secs = record->run_secs; @@ -144,8 +144,8 @@ class JSONDatabaseNode : public DatabaseNode { return results; } - Array GetAllTuningRecords() { - Array results; + ffi::Array GetAllTuningRecords() { + ffi::Array results; results.reserve(Size()); for (const TuningRecord& record : this->tuning_records_) { results.push_back(record); @@ -156,10 +156,10 @@ class JSONDatabaseNode : public DatabaseNode { int64_t Size() { return tuning_records_.size(); } }; -Database Database::JSONDatabase(String path_workload, String path_tuning_record, bool allow_missing, - String mod_eq_name) { +Database Database::JSONDatabase(ffi::String path_workload, ffi::String path_tuning_record, + bool allow_missing, ffi::String mod_eq_name) { int num_threads = std::thread::hardware_concurrency(); - ObjectPtr n = make_object(mod_eq_name); + ObjectPtr n = ffi::make_object(mod_eq_name); // Load `n->workloads2idx_` from `path_workload` std::vector workloads; { @@ -173,7 +173,7 @@ Database Database::JSONDatabase(String path_workload, String path_tuning_record, // Todo(tvm-team): re-enable the shash check when we get environment // independent structural hash values. 
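// [Illustrative note, not part of the patch] When the recomputed hash differs
// from the stored one, the record is rebuilt around a copied node instead of
// being mutated in place, since Workload objects may be shared; the code below
// follows this pattern:
//
//   ObjectPtr<WorkloadNode> wkl = ffi::make_object<WorkloadNode>(*workload.get());
//   wkl->shash = recalc_hash;
//   workload = Workload(wkl);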
if (recalc_hash != workload->shash) { - ObjectPtr wkl = make_object(*workload.get()); + ObjectPtr wkl = ffi::make_object(*workload.get()); wkl->shash = recalc_hash; workload = Workload(wkl); } diff --git a/src/meta_schedule/database/memory_database.cc b/src/meta_schedule/database/memory_database.cc index ec08fd62a232..8c355dc0e5c5 100644 --- a/src/meta_schedule/database/memory_database.cc +++ b/src/meta_schedule/database/memory_database.cc @@ -26,10 +26,10 @@ namespace meta_schedule { class MemoryDatabaseNode : public DatabaseNode { public: - explicit MemoryDatabaseNode(String mod_eq_name = "structural") : DatabaseNode(mod_eq_name) {} + explicit MemoryDatabaseNode(ffi::String mod_eq_name = "structural") : DatabaseNode(mod_eq_name) {} - Array records; - Array workloads; + ffi::Array records; + ffi::Array workloads; static void RegisterReflection() { namespace refl = tvm::ffi::reflection; @@ -64,7 +64,7 @@ class MemoryDatabaseNode : public DatabaseNode { void CommitTuningRecord(const TuningRecord& record) final { records.push_back(record); } - Array GetTopK(const Workload& workload, int top_k) final { + ffi::Array GetTopK(const Workload& workload, int top_k) final { CHECK_GE(top_k, 0) << "ValueError: top_k must be non-negative"; if (top_k == 0) { return {}; @@ -88,13 +88,13 @@ class MemoryDatabaseNode : public DatabaseNode { } } - Array GetAllTuningRecords() final { return records; } + ffi::Array GetAllTuningRecords() final { return records; } int64_t Size() final { return records.size(); } }; -Database Database::MemoryDatabase(String mod_eq_name) { - ObjectPtr n = make_object(mod_eq_name); +Database Database::MemoryDatabase(ffi::String mod_eq_name) { + ObjectPtr n = ffi::make_object(mod_eq_name); n->records.clear(); n->workloads.clear(); return Database(n); diff --git a/src/meta_schedule/database/ordered_union_database.cc b/src/meta_schedule/database/ordered_union_database.cc index 07526fbc45ab..3446517132a4 100644 --- a/src/meta_schedule/database/ordered_union_database.cc +++ b/src/meta_schedule/database/ordered_union_database.cc @@ -25,7 +25,7 @@ namespace meta_schedule { class OrderedUnionDatabaseNode : public DatabaseNode { public: - Array databases; + ffi::Array databases; static void RegisterReflection() { namespace refl = tvm::ffi::reflection; @@ -37,10 +37,10 @@ class OrderedUnionDatabaseNode : public DatabaseNode { TVM_DECLARE_FINAL_OBJECT_INFO(OrderedUnionDatabaseNode, DatabaseNode); public: - Optional QueryTuningRecord(const IRModule& mod, const Target& target, - const String& task_name) final { + ffi::Optional QueryTuningRecord(const IRModule& mod, const Target& target, + const ffi::String& task_name) final { for (const Database& db : databases) { - if (Optional record = db->QueryTuningRecord(mod, target, task_name)) { + if (ffi::Optional record = db->QueryTuningRecord(mod, target, task_name)) { return record; } } @@ -62,12 +62,12 @@ class OrderedUnionDatabaseNode : public DatabaseNode { throw; } - Array GetTopK(const Workload& workload, int top_k) final { + ffi::Array GetTopK(const Workload& workload, int top_k) final { LOG(FATAL) << "NotImplementedError: OrderedUnionDatabase.GetTopK"; throw; } - Array GetAllTuningRecords() final { + ffi::Array GetAllTuningRecords() final { LOG(FATAL) << "NotImplementedError: OrderedUnionDatabase.GetAllTuningRecords"; throw; } @@ -78,8 +78,8 @@ class OrderedUnionDatabaseNode : public DatabaseNode { } }; -Database Database::OrderedUnionDatabase(Array databases) { - ObjectPtr n = make_object(); +Database Database::OrderedUnionDatabase(ffi::Array 
databases) { + ObjectPtr n = ffi::make_object(); n->databases = std::move(databases); return Database(n); } diff --git a/src/meta_schedule/database/schedule_fn_database.cc b/src/meta_schedule/database/schedule_fn_database.cc index 1f85654cfa0c..32c6e0194f49 100644 --- a/src/meta_schedule/database/schedule_fn_database.cc +++ b/src/meta_schedule/database/schedule_fn_database.cc @@ -25,7 +25,8 @@ namespace meta_schedule { class ScheduleFnDatabaseNode : public DatabaseNode { public: - explicit ScheduleFnDatabaseNode(String mod_eq_name = "structural") : DatabaseNode(mod_eq_name) {} + explicit ScheduleFnDatabaseNode(ffi::String mod_eq_name = "structural") + : DatabaseNode(mod_eq_name) {} ffi::TypedFunction schedule_fn; @@ -39,9 +40,9 @@ class ScheduleFnDatabaseNode : public DatabaseNode { TVM_DECLARE_FINAL_OBJECT_INFO(ScheduleFnDatabaseNode, DatabaseNode); public: - Optional QueryTuningRecord(const IRModule& mod, const Target& target, - const String& workload_name) final { - if (Optional sch = this->QuerySchedule(mod, target, workload_name)) { + ffi::Optional QueryTuningRecord(const IRModule& mod, const Target& target, + const ffi::String& workload_name) final { + if (ffi::Optional sch = this->QuerySchedule(mod, target, workload_name)) { return TuningRecord(sch.value()->trace().value(), /*workload=*/Workload(mod, 0), // /*run_secs=*/std::nullopt, // @@ -51,8 +52,8 @@ class ScheduleFnDatabaseNode : public DatabaseNode { return std::nullopt; } - Optional QuerySchedule(const IRModule& mod, const Target& target, - const String& workload_name) final { + ffi::Optional QuerySchedule(const IRModule& mod, const Target& target, + const ffi::String& workload_name) final { tir::Schedule sch = tir::Schedule::Traced(WithAttr(mod, "task_name", workload_name), /*rand_state=*/-1, @@ -79,12 +80,12 @@ class ScheduleFnDatabaseNode : public DatabaseNode { throw; } - Array GetTopK(const Workload& workload, int top_k) final { + ffi::Array GetTopK(const Workload& workload, int top_k) final { LOG(FATAL) << "NotImplementedError: ScheduleFnDatabase.GetTopK"; throw; } - Array GetAllTuningRecords() final { + ffi::Array GetAllTuningRecords() final { LOG(FATAL) << "NotImplementedError: ScheduleFnDatabase.GetAllTuningRecords"; throw; } @@ -96,8 +97,8 @@ class ScheduleFnDatabaseNode : public DatabaseNode { }; Database Database::ScheduleFnDatabase(ffi::TypedFunction schedule_fn, - String mod_eq_name) { - ObjectPtr n = make_object(mod_eq_name); + ffi::String mod_eq_name) { + ObjectPtr n = ffi::make_object(mod_eq_name); n->schedule_fn = std::move(schedule_fn); return Database(n); } diff --git a/src/meta_schedule/database/union_database.cc b/src/meta_schedule/database/union_database.cc index 38864a5fcc03..82e76ad43f2d 100644 --- a/src/meta_schedule/database/union_database.cc +++ b/src/meta_schedule/database/union_database.cc @@ -25,7 +25,7 @@ namespace meta_schedule { class UnionDatabaseNode : public DatabaseNode { public: - Array databases; + ffi::Array databases; static void RegisterReflection() { namespace refl = tvm::ffi::reflection; @@ -36,17 +36,17 @@ class UnionDatabaseNode : public DatabaseNode { TVM_DECLARE_FINAL_OBJECT_INFO(UnionDatabaseNode, DatabaseNode); public: - Optional QueryTuningRecord(const IRModule& mod, const Target& target, - const String& task_name) final { + ffi::Optional QueryTuningRecord(const IRModule& mod, const Target& target, + const ffi::String& task_name) final { std::vector results; results.reserve(databases.size()); for (const Database& db : databases) { - if (Optional record = 
db->QueryTuningRecord(mod, target, task_name)) { + if (ffi::Optional record = db->QueryTuningRecord(mod, target, task_name)) { results.push_back(record.value()); } } std::stable_sort(results.begin(), results.end(), SortTuningRecordByMeanRunSecs()); - return results.empty() ? Optional(std::nullopt) : results[0]; + return results.empty() ? ffi::Optional(std::nullopt) : results[0]; } bool HasWorkload(const IRModule& mod) final { @@ -64,12 +64,12 @@ class UnionDatabaseNode : public DatabaseNode { throw; } - Array GetTopK(const Workload& workload, int top_k) final { + ffi::Array GetTopK(const Workload& workload, int top_k) final { LOG(FATAL) << "NotImplementedError: UnionDatabase.GetTopK"; throw; } - Array GetAllTuningRecords() final { + ffi::Array GetAllTuningRecords() final { LOG(FATAL) << "NotImplementedError: UnionDatabase.GetAllTuningRecords"; throw; } @@ -80,8 +80,8 @@ class UnionDatabaseNode : public DatabaseNode { } }; -Database Database::UnionDatabase(Array databases) { - ObjectPtr n = make_object(); +Database Database::UnionDatabase(ffi::Array databases) { + ObjectPtr n = ffi::make_object(); n->databases = std::move(databases); return Database(n); } diff --git a/src/meta_schedule/extracted_task.cc b/src/meta_schedule/extracted_task.cc index 41980adc0034..ad93f1d5e8ab 100644 --- a/src/meta_schedule/extracted_task.cc +++ b/src/meta_schedule/extracted_task.cc @@ -28,9 +28,9 @@ namespace tvm { namespace meta_schedule { -ExtractedTask::ExtractedTask(String task_name, IRModule mod, Target target, - Array dispatched, int weight) { - ObjectPtr n = make_object(); +ExtractedTask::ExtractedTask(ffi::String task_name, IRModule mod, Target target, + ffi::Array dispatched, int weight) { + ObjectPtr n = ffi::make_object(); n->task_name = task_name; n->mod = mod; n->target = target; @@ -44,8 +44,8 @@ TVM_FFI_STATIC_INIT_BLOCK({ ExtractedTaskNode::RegisterReflection(); }); TVM_FFI_STATIC_INIT_BLOCK({ namespace refl = tvm::ffi::reflection; refl::GlobalDef().def("meta_schedule.ExtractedTask", - [](String task_name, IRModule mod, Target target, - Array dispatched, int weight) -> ExtractedTask { + [](ffi::String task_name, IRModule mod, Target target, + ffi::Array dispatched, int weight) -> ExtractedTask { return ExtractedTask(task_name, mod, target, dispatched, weight); }); }); diff --git a/src/meta_schedule/feature_extractor/feature_extractor.cc b/src/meta_schedule/feature_extractor/feature_extractor.cc index e2fa1fc176b4..983d24ed25c6 100644 --- a/src/meta_schedule/feature_extractor/feature_extractor.cc +++ b/src/meta_schedule/feature_extractor/feature_extractor.cc @@ -23,8 +23,8 @@ namespace tvm { namespace meta_schedule { -Array PyFeatureExtractorNode::ExtractFrom( - const TuneContext& context, const Array& candidates) { +ffi::Array PyFeatureExtractorNode::ExtractFrom( + const TuneContext& context, const ffi::Array& candidates) { ICHECK(f_extract_from != nullptr) << "PyFeatureExtractor's ExtractFrom method not implemented!"; return f_extract_from(context, candidates); } @@ -32,7 +32,7 @@ Array PyFeatureExtractorNode::ExtractFrom( FeatureExtractor FeatureExtractor::PyFeatureExtractor( PyFeatureExtractorNode::FExtractFrom f_extract_from, // PyFeatureExtractorNode::FAsString f_as_string) { - ObjectPtr n = make_object(); + ObjectPtr n = ffi::make_object(); n->f_extract_from = std::move(f_extract_from); n->f_as_string = std::move(f_as_string); return FeatureExtractor(n); diff --git a/src/meta_schedule/feature_extractor/per_store_feature.cc b/src/meta_schedule/feature_extractor/per_store_feature.cc index 
7c9a809e7178..549e3d58541d 100644 --- a/src/meta_schedule/feature_extractor/per_store_feature.cc +++ b/src/meta_schedule/feature_extractor/per_store_feature.cc @@ -84,7 +84,8 @@ std::vector GetBufferShape(const Buffer& buffer, arith::Analyzer* analy * \return The value of `pragma_auto_unroll_max_step` if it exists, or -1 if it does not exist */ int64_t GetPragmaAutoUnroll(const ForNode* loop) { - if (Optional auto_unroll = GetAnn(loop, tir::attr::pragma_auto_unroll_max_step)) { + if (ffi::Optional auto_unroll = + GetAnn(loop, tir::attr::pragma_auto_unroll_max_step)) { return auto_unroll.value()->value; } return -1; @@ -267,16 +268,16 @@ Pass SimplifyForFeatureExtraction() { PrimExpr VisitExpr_(const SelectNode* node) final { if (HasBufferLoad(node->true_value) || HasBufferLoad(node->false_value) || HasBufferLoad(node->condition)) { - return GetRef(node); + return ffi::GetRef(node); } return make_const(node->dtype, 1.0); } PrimExpr VisitExpr_(const VarNode* var) final { - if (unit_vars_.count(GetRef(var))) { + if (unit_vars_.count(ffi::GetRef(var))) { return make_const(var->dtype, 0.0); } - return GetRef(var); + return ffi::GetRef(var); } Stmt VisitStmt_(const ForNode* loop) final { @@ -859,7 +860,7 @@ void Feature::SubFeature::SetStride(const LoopNest& loop_nest, arith::Analyzer* // For each buffer, we find the loop stride on it const BufferNode* buffer = this->buffer; int ndim = this->buffer->shape.size(); - IntVec buffer_shape = utils::GetBufferShape(GetRef(buffer), analyzer); + IntVec buffer_shape = utils::GetBufferShape(ffi::GetRef(buffer), analyzer); // Calculate the buffer's stride from its shape IntVec buffer_stride(ndim); if (ndim >= 1) { @@ -1398,8 +1399,8 @@ class PerStoreFeatureNode : public FeatureExtractorNode { } } - Array ExtractFrom(const TuneContext& tune_context, - const Array& candidates) { + ffi::Array ExtractFrom(const TuneContext& tune_context, + const ffi::Array& candidates) { auto& target_keys = tune_context->target.value()->keys; bool is_gpu = std::find(target_keys.begin(), target_keys.end(), "gpu") != target_keys.end(); std::vector results; @@ -1430,7 +1431,7 @@ class PerStoreFeatureNode : public FeatureExtractorNode { FeatureExtractor FeatureExtractor::PerStoreFeature(int buffers_per_store, int arith_intensity_curve_num_samples, int cache_line_bytes, bool extract_workload) { - ObjectPtr n = make_object(); + ObjectPtr n = ffi::make_object(); n->buffers_per_store = buffers_per_store; n->arith_intensity_curve_num_samples = arith_intensity_curve_num_samples; n->cache_line_bytes = cache_line_bytes; diff --git a/src/meta_schedule/measure_callback/add_to_database.cc b/src/meta_schedule/measure_callback/add_to_database.cc index 89b2934fe28e..320233bdf848 100644 --- a/src/meta_schedule/measure_callback/add_to_database.cc +++ b/src/meta_schedule/measure_callback/add_to_database.cc @@ -26,9 +26,9 @@ namespace meta_schedule { class AddToDatabaseNode : public MeasureCallbackNode { public: void Apply(const TaskScheduler& task_scheduler, int task_id, - const Array& measure_candidates, - const Array& builder_results, - const Array& runner_results) final { + const ffi::Array& measure_candidates, + const ffi::Array& builder_results, + const ffi::Array& runner_results) final { if (!task_scheduler->database_.defined()) { return; } @@ -42,11 +42,11 @@ class AddToDatabaseNode : public MeasureCallbackNode { for (int i = 0; i < n; ++i) { RunnerResult result = runner_results[i]; MeasureCandidate candidate = measure_candidates[i]; - Array run_secs{nullptr}; + ffi::Array run_secs{nullptr}; if
(result->run_secs.defined()) { run_secs = result->run_secs.value(); } else { - run_secs = Array{FloatImm(DataType::Float(32), 1e10)}; + run_secs = ffi::Array{FloatImm(DataType::Float(32), 1e10)}; } database->CommitTuningRecord(TuningRecord( /*trace=*/candidate->sch->trace().value(), @@ -62,7 +62,7 @@ class AddToDatabaseNode : public MeasureCallbackNode { }; MeasureCallback MeasureCallback::AddToDatabase() { - ObjectPtr n = make_object(); + ObjectPtr n = ffi::make_object(); return MeasureCallback(n); } diff --git a/src/meta_schedule/measure_callback/measure_callback.cc b/src/meta_schedule/measure_callback/measure_callback.cc index 08feaf354eee..dbc6b634665d 100644 --- a/src/meta_schedule/measure_callback/measure_callback.cc +++ b/src/meta_schedule/measure_callback/measure_callback.cc @@ -23,11 +23,11 @@ namespace tvm { namespace meta_schedule { -void PyMeasureCallbackNode::Apply(const TaskScheduler& task_scheduler, // - int task_id, // - const Array& measure_candidates, // - const Array& builds, // - const Array& results) { +void PyMeasureCallbackNode::Apply(const TaskScheduler& task_scheduler, // + int task_id, // + const ffi::Array& measure_candidates, // + const ffi::Array& builds, // + const ffi::Array& results) { ICHECK(f_apply != nullptr) << "PyMeasureCallback's Apply method not implemented!"; auto _ = Profiler::TimedScope("MeasureCallback/" + this->f_as_string()); return f_apply(task_scheduler, task_id, measure_candidates, builds, results); @@ -35,13 +35,13 @@ void PyMeasureCallbackNode::Apply(const TaskScheduler& task_scheduler, MeasureCallback MeasureCallback::PyMeasureCallback(PyMeasureCallbackNode::FApply f_apply, // PyMeasureCallbackNode::FAsString f_as_string) { - ObjectPtr n = make_object(); + ObjectPtr n = ffi::make_object(); n->f_apply = std::move(f_apply); n->f_as_string = std::move(f_as_string); return MeasureCallback(n); } -Array MeasureCallback::Default() { +ffi::Array MeasureCallback::Default() { return { MeasureCallback::AddToDatabase(), MeasureCallback::RemoveBuildArtifact(), diff --git a/src/meta_schedule/measure_callback/remove_build_artifact.cc b/src/meta_schedule/measure_callback/remove_build_artifact.cc index 69fcd186f3c4..455eaeba0fc3 100644 --- a/src/meta_schedule/measure_callback/remove_build_artifact.cc +++ b/src/meta_schedule/measure_callback/remove_build_artifact.cc @@ -26,13 +26,13 @@ namespace meta_schedule { class RemoveBuildArtifactNode : public MeasureCallbackNode { public: void Apply(const TaskScheduler& task_scheduler, int task_id, - const Array& measure_candidates, - const Array& builder_results, - const Array& runner_results) final { + const ffi::Array& measure_candidates, + const ffi::Array& builder_results, + const ffi::Array& runner_results) final { static auto f_rm = tvm::ffi::Function::GetGlobalRequired("meta_schedule.remove_build_dir"); auto _ = Profiler::TimedScope("MeasureCallback/RemoveBuildArtifact"); for (const BuilderResult& build_result : builder_results) { - if (Optional path = build_result->artifact_path) { + if (ffi::Optional path = build_result->artifact_path) { f_rm(path.value()); } } @@ -43,7 +43,7 @@ class RemoveBuildArtifactNode : public MeasureCallbackNode { }; MeasureCallback MeasureCallback::RemoveBuildArtifact() { - ObjectPtr n = make_object(); + ObjectPtr n = ffi::make_object(); return MeasureCallback(n); } diff --git a/src/meta_schedule/measure_callback/update_cost_model.cc b/src/meta_schedule/measure_callback/update_cost_model.cc index 1db62d5e5068..80353e3546a4 100644 --- 
a/src/meta_schedule/measure_callback/update_cost_model.cc +++ b/src/meta_schedule/measure_callback/update_cost_model.cc @@ -26,9 +26,9 @@ namespace meta_schedule { class UpdateCostModelNode : public MeasureCallbackNode { public: void Apply(const TaskScheduler& task_scheduler, int task_id, - const Array& measure_candidates, - const Array& builder_results, - const Array& runner_results) final { + const ffi::Array& measure_candidates, + const ffi::Array& builder_results, + const ffi::Array& runner_results) final { auto _ = Profiler::TimedScope("MeasureCallback/UpdateCostModel"); const TaskRecord& task = task_scheduler->tasks_[task_id]; if (!task_scheduler->cost_model_.defined()) { @@ -39,8 +39,8 @@ class UpdateCostModelNode : public MeasureCallbackNode { ICHECK_EQ(measure_candidates.size(), builder_results.size()); ICHECK_EQ(runner_results.size(), builder_results.size()); int n = builder_results.size(); - Array pruned_candidate; - Array pruned_runner_result; + ffi::Array pruned_candidate; + ffi::Array pruned_runner_result; pruned_candidate.reserve(n); pruned_runner_result.reserve(n); for (int i = 0; i < n; i++) { @@ -60,7 +60,7 @@ class UpdateCostModelNode : public MeasureCallbackNode { }; MeasureCallback MeasureCallback::UpdateCostModel() { - ObjectPtr n = make_object(); + ObjectPtr n = ffi::make_object(); return MeasureCallback(n); } diff --git a/src/meta_schedule/module_equality.cc b/src/meta_schedule/module_equality.cc index c3b38cf341d9..8eb1f46b0b22 100644 --- a/src/meta_schedule/module_equality.cc +++ b/src/meta_schedule/module_equality.cc @@ -34,7 +34,7 @@ class ModuleEqualityStructural : public ModuleEquality { public: size_t Hash(IRModule mod) const { return tvm::StructuralHash()(mod); } bool Equal(IRModule lhs, IRModule rhs) const { return tvm::StructuralEqual()(lhs, rhs); } - String GetName() const { return "structural"; } + ffi::String GetName() const { return "structural"; } }; class ModuleEqualityIgnoreTensor : public ModuleEquality { @@ -47,7 +47,7 @@ class ModuleEqualityIgnoreTensor : public ModuleEquality { return tvm::ffi::StructuralEqual::Equal(lhs, rhs, /*map_free_vars=*/false, /*skip_tensor_content=*/true); } - String GetName() const { return "ignore-tensor"; } + ffi::String GetName() const { return "ignore-tensor"; } }; // The Tensor-ignoring variant of structural equal / hash is used for the module equality @@ -56,7 +56,7 @@ class ModuleEqualityAnchorBlock : public ModuleEquality { size_t Hash(IRModule mod) const { auto anchor_block = tir::FindAnchorBlock(mod); if (anchor_block) { - return ffi::StructuralHash::Hash(GetRef(anchor_block), + return ffi::StructuralHash::Hash(ffi::GetRef(anchor_block), /*map_free_vars=*/false, /*skip_tensor_content=*/true); } @@ -66,14 +66,14 @@ class ModuleEqualityAnchorBlock : public ModuleEquality { auto anchor_block_lhs = tir::FindAnchorBlock(lhs); auto anchor_block_rhs = tir::FindAnchorBlock(rhs); if (anchor_block_lhs && anchor_block_rhs) { - return tvm::ffi::StructuralEqual::Equal(GetRef(anchor_block_lhs), - GetRef(anchor_block_rhs), + return tvm::ffi::StructuralEqual::Equal(ffi::GetRef(anchor_block_lhs), + ffi::GetRef(anchor_block_rhs), /*map_free_vars=*/false, /*skip_tensor_content=*/true); } return ModuleEqualityIgnoreTensor().Equal(lhs, rhs); } - String GetName() const { return "anchor-block"; } + ffi::String GetName() const { return "anchor-block"; } }; std::unique_ptr ModuleEquality::Create(const std::string& mod_eq_name) { diff --git a/src/meta_schedule/module_equality.h b/src/meta_schedule/module_equality.h index 
cd337c6d7ede..f9546438157d 100644 --- a/src/meta_schedule/module_equality.h +++ b/src/meta_schedule/module_equality.h @@ -34,7 +34,7 @@ class ModuleEquality { virtual size_t Hash(IRModule mod) const = 0; virtual bool Equal(IRModule lhs, IRModule rhs) const = 0; - virtual String GetName() const = 0; + virtual ffi::String GetName() const = 0; /*! * \brief Create a ModuleEquality instance diff --git a/src/meta_schedule/mutator/mutate_compute_location.cc b/src/meta_schedule/mutator/mutate_compute_location.cc index 7825e8909429..f5be3f36788d 100644 --- a/src/meta_schedule/mutator/mutate_compute_location.cc +++ b/src/meta_schedule/mutator/mutate_compute_location.cc @@ -47,10 +47,10 @@ class MutateComputeLocationNode : public MutatorNode { this->json_mod_ = SaveJSON(context->mod.value()); } // Inherit from `MutatorNode` - Optional Apply(const Trace& trace, TRandState* rand_state) final; + ffi::Optional Apply(const Trace& trace, TRandState* rand_state) final; // Inherit from `MutatorNode` Mutator Clone() const final { - ObjectPtr n = make_object(*this); + ObjectPtr n = ffi::make_object(*this); return Mutator(n); } @@ -86,9 +86,9 @@ std::vector MutateComputeLocationNode::Fin InstructionKind::Get("SampleComputeLocation"); std::vector candidates; - auto f_decision_provider = [&](const tir::Instruction& inst, // - const Array& inputs, // - const Array& attrs, // + auto f_decision_provider = [&](const tir::Instruction& inst, // + const ffi::Array& inputs, // + const ffi::Array& attrs, // const Any& decision) -> Any { if (inst->kind.same_as(inst_sample_compute_location)) { // Step 1. Extract the instruction input and the old decision. @@ -118,7 +118,7 @@ std::vector MutateComputeLocationNode::Fin return candidates; } -Optional MutateComputeLocationNode::Apply(const Trace& trace, TRandState* rand_state) { +ffi::Optional MutateComputeLocationNode::Apply(const Trace& trace, TRandState* rand_state) { std::vector candidates = FindCandidates(trace, rand_state); if (candidates.empty()) { return std::nullopt; @@ -129,7 +129,7 @@ Optional MutateComputeLocationNode::Apply(const Trace& trace, TRandState* } Mutator Mutator::MutateComputeLocation() { - return Mutator(make_object()); + return Mutator(ffi::make_object()); } TVM_FFI_STATIC_INIT_BLOCK({ MutateComputeLocationNode::RegisterReflection(); }); diff --git a/src/meta_schedule/mutator/mutate_parallel.cc b/src/meta_schedule/mutator/mutate_parallel.cc index b7c532ae5b0f..8a5fc485cf9b 100644 --- a/src/meta_schedule/mutator/mutate_parallel.cc +++ b/src/meta_schedule/mutator/mutate_parallel.cc @@ -37,7 +37,7 @@ bool IsAnnotateWithParallel(const Instruction& inst) { return false; } ICHECK_EQ(inst->attrs.size(), 1); - String ann_key = Downcast(inst->attrs[0]); + ffi::String ann_key = Downcast(inst->attrs[0]); return ann_key == attr::meta_schedule_parallel; } @@ -79,13 +79,13 @@ const BlockRVNode* GetInstGetBlockOutput(const Instruction& inst) { * \return The parallel structure */ std::vector> AnalyzeParallel(const ScheduleState& self, - const String& block_name, const String& func_name, - int64_t limit) { - Array block_srefs = + const ffi::String& block_name, + const ffi::String& func_name, int64_t limit) { + ffi::Array block_srefs = tir::GetBlocks(self, block_name, self->mod->GetGlobalVar(func_name)); ICHECK_EQ(block_srefs.size(), 1); const BlockNode* block = TVM_SREF_TO_BLOCK(block_srefs[0]); - ScopeBlockLoopInfo info = GetScopeBlockLoopInfo(GetRef(block)); + ScopeBlockLoopInfo info = GetScopeBlockLoopInfo(ffi::GetRef(block)); std::vector> results; 
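// [Illustrative note, not part of the patch] ffi::GetRef is the ffi-namespaced
// replacement for the old tvm::GetRef used above: it wraps a raw node pointer
// in its reference type while sharing ownership. Hypothetical usage with the
// block obtained above:
//
//   const tir::BlockNode* raw = block;                // non-owning pointer
//   tir::Block owned = ffi::GetRef<tir::Block>(raw);  // refcounted handle, no copy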
results.reserve(info.realizes.size()); for (const BlockRealize& realize : info.realizes) { @@ -189,10 +189,10 @@ class MutateParallelNode : public MutatorNode { this->json_mod_ = SaveJSON(context->mod.value()); } // Inherit from `MutatorNode` - Optional Apply(const Trace& trace, TRandState* rand_state) final; + ffi::Optional Apply(const Trace& trace, TRandState* rand_state) final; // Inherit from `MutatorNode` Mutator Clone() const final { - ObjectPtr n = make_object(*this); + ObjectPtr n = ffi::make_object(*this); return Mutator(n); } }; @@ -204,9 +204,9 @@ struct MutateParallelNode::Candidate { /*! \brief The current parallel extent */ int64_t parallel_extent; /*! \brief The name of the root block */ - String block_name; + ffi::String block_name; /*! \brief The name of the PrimFunc */ - String func_name; + ffi::String func_name; }; /*! @@ -241,14 +241,14 @@ bool FindParallelDecision(const Trace& trace, TRandState* rand_state, const InstructionNode* get_block_inst = get_block_insts.at(Downcast(ann_inst->inputs[0]).get()); ICHECK_EQ(get_block_inst->attrs.size(), 2); - candidate->inst = GetRef(ann_inst); + candidate->inst = ffi::GetRef(ann_inst); candidate->parallel_extent = Downcast(ann_inst->inputs[1])->value; - candidate->block_name = Downcast(get_block_inst->attrs[0]); - candidate->func_name = Downcast(get_block_inst->attrs[1]); + candidate->block_name = Downcast(get_block_inst->attrs[0]); + candidate->func_name = Downcast(get_block_inst->attrs[1]); return true; } -Optional MutateParallelNode::Apply(const Trace& trace, TRandState* rand_state) { +ffi::Optional MutateParallelNode::Apply(const Trace& trace, TRandState* rand_state) { // Step 1. Find a parallel decision. Candidate candidate; if (!FindParallelDecision(trace, rand_state, &candidate)) { @@ -293,7 +293,7 @@ Optional MutateParallelNode::Apply(const Trace& trace, TRandState* rand_s } int64_t limit = it->second; // Step 6. 
Assemble a new trace - Array insts; + ffi::Array insts; insts.reserve(trace->insts.size()); for (const Instruction& inst : trace->insts) { if (inst.same_as(candidate.inst)) { @@ -308,7 +308,7 @@ Optional MutateParallelNode::Apply(const Trace& trace, TRandState* rand_s } Mutator Mutator::MutateParallel(int64_t max_jobs_per_core) { - ObjectPtr n = make_object(); + ObjectPtr n = ffi::make_object(); n->max_jobs_per_core = max_jobs_per_core; return Mutator(n); } diff --git a/src/meta_schedule/mutator/mutate_thread_binding.cc b/src/meta_schedule/mutator/mutate_thread_binding.cc index 26e3a4709a91..aff00a600e77 100644 --- a/src/meta_schedule/mutator/mutate_thread_binding.cc +++ b/src/meta_schedule/mutator/mutate_thread_binding.cc @@ -47,10 +47,10 @@ class MutateThreadBindingNode : public MutatorNode { this->json_mod_ = SaveJSON(context->mod.value()); } // Inherit from `MutatorNode` - Optional Apply(const Trace& trace, TRandState* rand_state) final; + ffi::Optional Apply(const Trace& trace, TRandState* rand_state) final; // Inherit from `MutatorNode` Mutator Clone() const final { - ObjectPtr n = make_object(*this); + ObjectPtr n = ffi::make_object(*this); return Mutator(n); } @@ -111,7 +111,7 @@ std::vector MutateThreadBindingNode::FindCan } ICHECK_EQ(inst->inputs.size(), 1); ICHECK_EQ(inst->attrs.size(), 1); - if (Downcast(inst->attrs[0]) != "threadIdx.x") return false; + if (Downcast(inst->attrs[0]) != "threadIdx.x") return false; return sampled_split_insts.find(Downcast(inst->inputs[0]).get()) != sampled_split_insts.end(); @@ -143,17 +143,17 @@ std::vector MutateThreadBindingNode::FindCan ICHECK(sample_it != sample_insts.end()); const InstructionNode* sample_inst = sample_it->second; - int decision = Downcast(trace->decisions[GetRef(sample_inst)])->value; + int decision = Downcast(trace->decisions[ffi::GetRef(sample_inst)])->value; std::vector probs = - support::AsVector(Downcast>(sample_inst->attrs[1])); + support::AsVector(Downcast>(sample_inst->attrs[1])); - candidates.emplace_back(GetRef(sample_inst), probs, decision); + candidates.emplace_back(ffi::GetRef(sample_inst), probs, decision); } return candidates; } -Optional MutateThreadBindingNode::Apply(const Trace& trace, TRandState* rand_state) { +ffi::Optional MutateThreadBindingNode::Apply(const Trace& trace, TRandState* rand_state) { std::vector candidates = FindCandidates(trace, rand_state); if (candidates.empty()) { return std::nullopt; @@ -168,7 +168,9 @@ Optional MutateThreadBindingNode::Apply(const Trace& trace, TRandState* r return trace->WithDecision(candidate.inst, Integer(result), /*remove_postproc=*/true); } -Mutator Mutator::MutateThreadBinding() { return Mutator(make_object()); } +Mutator Mutator::MutateThreadBinding() { + return Mutator(ffi::make_object()); +} TVM_FFI_STATIC_INIT_BLOCK({ MutateThreadBindingNode::RegisterReflection(); }); diff --git a/src/meta_schedule/mutator/mutate_tile_size.cc b/src/meta_schedule/mutator/mutate_tile_size.cc index fc56feedfba8..963906bac600 100644 --- a/src/meta_schedule/mutator/mutate_tile_size.cc +++ b/src/meta_schedule/mutator/mutate_tile_size.cc @@ -37,7 +37,7 @@ using tir::Trace; */ std::vector DowncastTilingDecision(const ObjectRef& decision) { const auto* arr = TVM_TYPE_AS(decision, ffi::ArrayObj); - return support::AsVector(GetRef>(arr)); + return support::AsVector(ffi::GetRef>(arr)); } /*! 
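// ---------------------------------------------------------------------------
// Editor's sketch (not part of the patch): the type-erased-to-typed round
// trip that DowncastTilingDecision above migrates. `ffi::GetRef` rebuilds a
// strongly typed handle from a raw node pointer. Header paths are assumed
// from the ffi/include layout used elsewhere in this patch; `tvm::Integer`
// is the IR integer box seen in the surrounding hunks.
#include <tvm/ffi/cast.h>             // ffi::GetRef (assumed path)
#include <tvm/ffi/container/array.h>  // ffi::Array / ffi::ArrayObj (assumed path)
#include <tvm/ir/expr.h>              // tvm::Integer (assumed path)

#include <cstdint>
#include <vector>

std::vector<int64_t> AsInt64Vector(const tvm::ObjectRef& decision) {
  namespace ffi = tvm::ffi;
  // View the erased sampling decision as an array node first...
  const auto* arr = decision.as<ffi::ArrayObj>();
  if (arr == nullptr) return {};
  // ...then recover the typed Array with ffi::GetRef and unbox each element,
  // mirroring `support::AsVector(ffi::GetRef<ffi::Array<Integer>>(arr))` above.
  std::vector<int64_t> result;
  for (const tvm::Integer& v : ffi::GetRef<ffi::Array<tvm::Integer>>(arr)) {
    result.push_back(v->value);
  }
  return result;
}
// ---------------------------------------------------------------------------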
@@ -68,10 +68,10 @@ class MutateTileSizeNode : public MutatorNode { // Inherit from `MutatorNode` void InitializeWithTuneContext(const TuneContext& context) final {} // Inherit from `MutatorNode` - Optional Apply(const Trace& trace, TRandState* rand_state) final; + ffi::Optional Apply(const Trace& trace, TRandState* rand_state) final; // Inherit from `MutatorNode` Mutator Clone() const final { - ObjectPtr n = make_object(*this); + ObjectPtr n = ffi::make_object(*this); return Mutator(n); } }; @@ -119,7 +119,7 @@ void FindSampleVectorize(const Trace& trace, std::vector* inst, if (inst->kind.same_as(inst_annotate)) { ICHECK_EQ(inst->attrs.size(), 1); ICHECK_EQ(inst->inputs.size(), 2); - if (Downcast(inst->attrs[0]) == tir::attr::meta_schedule_cooperative_fetch) { + if (Downcast(inst->attrs[0]) == tir::attr::meta_schedule_cooperative_fetch) { const auto* ann_val = inst->inputs[1].as(); ICHECK(ann_val); annotated.insert(ann_val); @@ -134,7 +134,7 @@ void FindSampleVectorize(const Trace& trace, std::vector* inst, if (annotated.count(inst->outputs[0].as())) { ICHECK_EQ(inst->attrs.size(), 2); std::vector probs = - support::AsVector(Downcast>(inst->attrs[1])); + support::AsVector(Downcast>(inst->attrs[1])); if (probs.size() == 1) { // Skip mutating the sampling instructions who have only single candidate. continue; @@ -191,8 +191,8 @@ struct FactorMemo { std::mutex mutex_; }; -Optional MutateSampleTileSize(const Trace& trace, Instruction inst, - std::vector tiles, TRandState* rand_state) { +ffi::Optional MutateSampleTileSize(const Trace& trace, Instruction inst, + std::vector tiles, TRandState* rand_state) { int n_splits = tiles.size(); // Step 1. Choose two loops, `x` and `y` int x, y; @@ -235,11 +235,11 @@ Optional MutateSampleTileSize(const Trace& trace, Instruction inst, } } -Optional MutateSampleVectorize(const Trace& trace, Instruction inst, - int64_t original_decision, TRandState* rand_state) { +ffi::Optional MutateSampleVectorize(const Trace& trace, Instruction inst, + int64_t original_decision, TRandState* rand_state) { ICHECK_EQ(inst->attrs.size(), 2); std::vector probs = - support::AsVector(Downcast>(inst->attrs[1])); + support::AsVector(Downcast>(inst->attrs[1])); probs.erase(probs.begin() + original_decision); int result = tir::MakeMultinomialSampler(rand_state, probs)(); if (result >= original_decision) { @@ -248,7 +248,7 @@ Optional MutateSampleVectorize(const Trace& trace, Instruction inst, return trace->WithDecision(inst, Integer(result), /*remove_postproc=*/true); } -Optional MutateTileSizeNode::Apply(const Trace& trace, TRandState* rand_state) { +ffi::Optional MutateTileSizeNode::Apply(const Trace& trace, TRandState* rand_state) { std::vector sample_perfect_tile_insts; std::vector sample_vectorize_insts; std::vector> sample_perfect_tile_tiles; @@ -271,7 +271,7 @@ Optional MutateTileSizeNode::Apply(const Trace& trace, TRandState* rand_s } } -Mutator Mutator::MutateTileSize() { return Mutator(make_object()); } +Mutator Mutator::MutateTileSize() { return Mutator(ffi::make_object()); } TVM_FFI_STATIC_INIT_BLOCK({ MutateTileSizeNode::RegisterReflection(); }); diff --git a/src/meta_schedule/mutator/mutate_unroll.cc b/src/meta_schedule/mutator/mutate_unroll.cc index 74b3cae05d52..4e021ffcb2e7 100644 --- a/src/meta_schedule/mutator/mutate_unroll.cc +++ b/src/meta_schedule/mutator/mutate_unroll.cc @@ -35,7 +35,7 @@ bool IsAnnotateWithUnroll(const Instruction& inst) { return false; } ICHECK_EQ(inst->attrs.size(), 1); - String ann_key = Downcast(inst->attrs[0]); + ffi::String ann_key = 
Downcast(inst->attrs[0]); return ann_key == attr::meta_schedule_unroll_explicit || ann_key == attr::meta_schedule_unroll_implicit; } @@ -65,10 +65,10 @@ class MutateUnrollNode : public MutatorNode { // Inherit from `MutatorNode` void InitializeWithTuneContext(const TuneContext& context) final {} // Inherit from `MutatorNode` - Optional Apply(const Trace& trace, TRandState* rand_state) final; + ffi::Optional Apply(const Trace& trace, TRandState* rand_state) final; // Inherit from `MutatorNode` Mutator Clone() const final { - ObjectPtr n = make_object(*this); + ObjectPtr n = ffi::make_object(*this); return Mutator(n); } }; @@ -118,14 +118,15 @@ bool FindUnrollDecision(const Trace& trace, TRandState* rand_state, ICHECK(sample_insts.count(var_rv)); const InstructionNode* sample_inst = sample_insts.at(var_rv); ICHECK_EQ(sample_inst->attrs.size(), 2); - candidate->inst = GetRef(sample_inst); - candidate->decision = Downcast(trace->decisions[GetRef(sample_inst)])->value; + candidate->inst = ffi::GetRef(sample_inst); + candidate->decision = + Downcast(trace->decisions[ffi::GetRef(sample_inst)])->value; candidate->probs = - support::AsVector(Downcast>(sample_inst->attrs[1])); + support::AsVector(Downcast>(sample_inst->attrs[1])); return true; } -Optional MutateUnrollNode::Apply(const Trace& trace, TRandState* rand_state) { +ffi::Optional MutateUnrollNode::Apply(const Trace& trace, TRandState* rand_state) { Candidate candidate; if (!FindUnrollDecision(trace, rand_state, &candidate)) { return std::nullopt; @@ -141,7 +142,7 @@ Optional MutateUnrollNode::Apply(const Trace& trace, TRandState* rand_sta return trace->WithDecision(candidate.inst, Integer(result), /*remove_postproc=*/true); } -Mutator Mutator::MutateUnroll() { return Mutator(make_object()); } +Mutator Mutator::MutateUnroll() { return Mutator(ffi::make_object()); } TVM_FFI_STATIC_INIT_BLOCK({ MutateUnrollNode::RegisterReflection(); }); diff --git a/src/meta_schedule/mutator/mutator.cc b/src/meta_schedule/mutator/mutator.cc index 50ab81f95f27..6862a9b202cc 100644 --- a/src/meta_schedule/mutator/mutator.cc +++ b/src/meta_schedule/mutator/mutator.cc @@ -29,7 +29,7 @@ void PyMutatorNode::InitializeWithTuneContext(const TuneContext& context) { f_initialize_with_tune_context(context); } -Optional PyMutatorNode::Apply( +ffi::Optional PyMutatorNode::Apply( const tir::Trace& trace, support::LinearCongruentialEngine::TRandState* rand_state) { ICHECK(f_apply != nullptr) << "PyMutator's Apply method not implemented!"; return f_apply(trace, *rand_state); @@ -45,7 +45,7 @@ Mutator Mutator::PyMutator( PyMutatorNode::FApply f_apply, // PyMutatorNode::FClone f_clone, // PyMutatorNode::FAsString f_as_string) { - ObjectPtr n = make_object(); + ObjectPtr n = ffi::make_object(); n->f_initialize_with_tune_context = std::move(f_initialize_with_tune_context); n->f_apply = std::move(f_apply); n->f_clone = std::move(f_clone); @@ -53,25 +53,25 @@ Mutator Mutator::PyMutator( return Mutator(n); } -Map Mutator::DefaultLLVM() { - return Map{ +ffi::Map Mutator::DefaultLLVM() { + return ffi::Map{ {Mutator::MutateTileSize(), FloatImm(DataType::Float(64), 0.9)}, {Mutator::MutateComputeLocation(), FloatImm(DataType::Float(64), 0.05)}, {Mutator::MutateUnroll(), FloatImm(DataType::Float(64), 0.03)}, {Mutator::MutateParallel(/*max_jobs_per_core=*/16), FloatImm(DataType::Float(64), 0.02)}}; } -Map Mutator::DefaultCUDA() { - return Map{ +ffi::Map Mutator::DefaultCUDA() { + return ffi::Map{ {Mutator::MutateTileSize(), FloatImm(DataType::Float(64), 0.9)}, {Mutator::MutateUnroll(), 
FloatImm(DataType::Float(64), 0.08)}, {Mutator::MutateThreadBinding(), FloatImm(DataType::Float(64), 0.02)}}; } -Map Mutator::DefaultCUDATensorCore() { return Mutator::DefaultCUDA(); } +ffi::Map Mutator::DefaultCUDATensorCore() { return Mutator::DefaultCUDA(); } -Map Mutator::DefaultHexagon() { - return Map{ +ffi::Map Mutator::DefaultHexagon() { + return ffi::Map{ {Mutator::MutateTileSize(), FloatImm(DataType::Float(64), 0.9)}, {Mutator::MutateComputeLocation(), FloatImm(DataType::Float(64), 0.05)}, {Mutator::MutateUnroll(), FloatImm(DataType::Float(64), 0.03)}, @@ -98,7 +98,7 @@ TVM_FFI_STATIC_INIT_BLOCK({ .def_method("meta_schedule.MutatorInitializeWithTuneContext", &MutatorNode::InitializeWithTuneContext) .def("meta_schedule.MutatorApply", - [](Mutator self, tir::Trace trace, TRandState seed) -> Optional { + [](Mutator self, tir::Trace trace, TRandState seed) -> ffi::Optional { TRandState seed_ = (seed != -1) ? seed : support::LinearCongruentialEngine::DeviceRandom(); return self->Apply(trace, &seed_); diff --git a/src/meta_schedule/postproc/disallow_async_strided_mem_copy.cc b/src/meta_schedule/postproc/disallow_async_strided_mem_copy.cc index 0aef44c58bcf..88b6c2c649fb 100644 --- a/src/meta_schedule/postproc/disallow_async_strided_mem_copy.cc +++ b/src/meta_schedule/postproc/disallow_async_strided_mem_copy.cc @@ -83,7 +83,7 @@ struct AsyncStridedMemCopyFinder : private StmtExprVisitor { } // map loop variable to zero for the store index & simplify - Array store_index = bufferstorenode->indices; + ffi::Array store_index = bufferstorenode->indices; // Use DetectIterMap to detect whether store index is non-contiguous. arith::Analyzer analyzer; @@ -94,7 +94,7 @@ struct AsyncStridedMemCopyFinder : private StmtExprVisitor { } // map loop variable to zero for the load index & simplify - Array load_index = bufferloadnode->indices; + ffi::Array load_index = bufferloadnode->indices; // Use DetectIterMap to detect whether load index is non-contiguous. 
auto load_iter_map = DetectIterMap(load_index, input_iters, 1, @@ -110,7 +110,7 @@ struct AsyncStridedMemCopyFinder : private StmtExprVisitor { } bool found_ = false; - Map input_iters = Map(); + ffi::Map input_iters = ffi::Map(); }; } // namespace tir @@ -135,7 +135,7 @@ class DisallowAsyncStridedMemCopyNode : public PostprocNode { if (const auto* prim_func = base_func.as()) { IRModule lowered{nullptr}; try { - auto pass_list = Array(); + auto pass_list = ffi::Array(); pass_list.push_back(tir::transform::BindTarget(this->target)); pass_list.push_back(tir::transform::LowerInitBlock()); pass_list.push_back(tir::transform::PlanAndUpdateBufferAllocationLocation()); @@ -152,9 +152,10 @@ class DisallowAsyncStridedMemCopyNode : public PostprocNode { pass_list.push_back(tir::transform::InjectDoubleBuffer()); pass_list.push_back(tir::transform::VectorizeLoop(true)); pass_list.push_back(tir::transform::StorageRewrite()); - tir::PrimFunc f = - WithAttr(GetRef(prim_func), "global_symbol", String(g_var->name_hint)); - IRModule mod = IRModule(Map({{GlobalVar(g_var->name_hint), f}})); + tir::PrimFunc f = WithAttr(ffi::GetRef(prim_func), "global_symbol", + ffi::String(g_var->name_hint)); + IRModule mod = + IRModule(ffi::Map({{GlobalVar(g_var->name_hint), f}})); lowered = tvm::transform::Sequential(pass_list)(std::move(mod)); } catch (const dmlc::Error& e) { return false; @@ -169,7 +170,7 @@ class DisallowAsyncStridedMemCopyNode : public PostprocNode { // Inherited from PostprocNode Postproc Clone() const { ObjectPtr n = - make_object(*this); + ffi::make_object(*this); return Postproc(n); } @@ -181,7 +182,8 @@ class DisallowAsyncStridedMemCopyNode : public PostprocNode { }; Postproc Postproc::DisallowAsyncStridedMemCopy() { - ObjectPtr n = make_object(); + ObjectPtr n = + ffi::make_object(); return Postproc(n); } diff --git a/src/meta_schedule/postproc/disallow_dynamic_loop.cc b/src/meta_schedule/postproc/disallow_dynamic_loop.cc index 47588c42a0a5..88993a010989 100644 --- a/src/meta_schedule/postproc/disallow_dynamic_loop.cc +++ b/src/meta_schedule/postproc/disallow_dynamic_loop.cc @@ -71,7 +71,7 @@ class DisallowDynamicLoopNode : public PostprocNode { bool Apply(const tir::Schedule& sch) final { return !tir::DynamicExtentFinder::Find(sch->mod()); } // Inherited from PostprocNode Postproc Clone() const { - ObjectPtr n = make_object(*this); + ObjectPtr n = ffi::make_object(*this); return Postproc(n); } @@ -80,7 +80,7 @@ class DisallowDynamicLoopNode : public PostprocNode { }; Postproc Postproc::DisallowDynamicLoop() { - ObjectPtr n = make_object(); + ObjectPtr n = ffi::make_object(); return Postproc(n); } diff --git a/src/meta_schedule/postproc/postproc.cc b/src/meta_schedule/postproc/postproc.cc index 6d119296480a..b93f47c69fa6 100644 --- a/src/meta_schedule/postproc/postproc.cc +++ b/src/meta_schedule/postproc/postproc.cc @@ -44,7 +44,7 @@ Postproc Postproc::PyPostproc( PyPostprocNode::FApply f_apply, // PyPostprocNode::FClone f_clone, // PyPostprocNode::FAsString f_as_string) { - ObjectPtr n = make_object(); + ObjectPtr n = ffi::make_object(); n->f_initialize_with_tune_context = std::move(f_initialize_with_tune_context); n->f_apply = std::move(f_apply); n->f_clone = std::move(f_clone); @@ -52,8 +52,8 @@ Postproc Postproc::PyPostproc( return Postproc(n); } -Array Postproc::DefaultLLVM() { - return Array{ +ffi::Array Postproc::DefaultLLVM() { + return ffi::Array{ Postproc::DisallowDynamicLoop(), Postproc::RewriteParallelVectorizeUnroll(), Postproc::RewriteReductionBlock(), @@ -61,24 +61,24 @@ Array 
Postproc::DefaultLLVM() { }; } -Array Postproc::DefaultCPUTensorization() { - return Array{ +ffi::Array Postproc::DefaultCPUTensorization() { + return ffi::Array{ Postproc::DisallowDynamicLoop(), Postproc::RewriteParallelVectorizeUnroll(), Postproc::RewriteReductionBlock(), Postproc::RewriteTensorize(/*vectorize_init_loop=*/true), Postproc::RewriteLayout(), }; } -Array Postproc::DefaultRISCV() { - return Array{ +ffi::Array Postproc::DefaultRISCV() { + return ffi::Array{ Postproc::DisallowDynamicLoop(), Postproc::RewriteParallelVectorizeUnroll(), Postproc::RewriteReductionBlock(), Postproc::RewriteTensorize(/*vectorize_init_loop=*/false), Postproc::RewriteLayout(), }; } -Array Postproc::DefaultCUDA() { - return Array{ +ffi::Array Postproc::DefaultCUDA() { + return ffi::Array{ Postproc::DisallowDynamicLoop(), Postproc::RewriteCooperativeFetch(), Postproc::RewriteUnboundBlock(/*max_threadblocks=*/256), @@ -88,8 +88,8 @@ Array Postproc::DefaultCUDA() { }; } -Array Postproc::DefaultCUDATensorCore() { - return Array{ +ffi::Array Postproc::DefaultCUDATensorCore() { + return ffi::Array{ Postproc::DisallowDynamicLoop(), Postproc::RewriteCooperativeFetch(), Postproc::RewriteUnboundBlock(/*max_threadblocks=*/256), @@ -102,8 +102,8 @@ Array Postproc::DefaultCUDATensorCore() { }; } -Array Postproc::DefaultHexagon() { - return Array{ +ffi::Array Postproc::DefaultHexagon() { + return ffi::Array{ Postproc::DisallowDynamicLoop(), Postproc::RewriteParallelVectorizeUnroll(), Postproc::RewriteReductionBlock(), Postproc::RewriteLayout(), Postproc::VerifyVTCMLimit(), diff --git a/src/meta_schedule/postproc/rewrite_cooperative_fetch.cc b/src/meta_schedule/postproc/rewrite_cooperative_fetch.cc index d7009c0596f5..67620e6e9540 100644 --- a/src/meta_schedule/postproc/rewrite_cooperative_fetch.cc +++ b/src/meta_schedule/postproc/rewrite_cooperative_fetch.cc @@ -30,14 +30,15 @@ namespace tir { * \param axis The axis name expected * \return std::nullopt if parsing fails; Otherwise, the extent of thread axis */ -Optional ParseThreadBinding(const Schedule& sch, const Instruction& inst, String axis) { +ffi::Optional ParseThreadBinding(const Schedule& sch, const Instruction& inst, + ffi::String axis) { static InstructionKind inst_kind_bind = InstructionKind::Get("Bind"); if (!inst->kind.same_as(inst_kind_bind)) { return std::nullopt; } ICHECK_EQ(inst->inputs.size(), 1); ICHECK_EQ(inst->attrs.size(), 1); - String thread_axis = Downcast(inst->attrs[0]); + ffi::String thread_axis = Downcast(inst->attrs[0]); if (thread_axis != axis) { return std::nullopt; } @@ -51,15 +52,15 @@ Optional ParseThreadBinding(const Schedule& sch, const Instruction& ins * \param vector_lane The number of vector lane in vectorized cooperative fetching * \return std::nullopt if parsing fails; Otherwise, the annotated block */ -Optional ParseAnnotate(const Schedule& sch, const Instruction& inst, - int64_t* vector_lane) { +ffi::Optional ParseAnnotate(const Schedule& sch, const Instruction& inst, + int64_t* vector_lane) { static InstructionKind inst_kind_annotate = InstructionKind::Get("Annotate"); if (!inst->kind.same_as(inst_kind_annotate)) { return std::nullopt; } ICHECK_EQ(inst->inputs.size(), 2); ICHECK_EQ(inst->attrs.size(), 1); - String ann_key = Downcast(inst->attrs[0]); + ffi::String ann_key = Downcast(inst->attrs[0]); if (ann_key != attr::meta_schedule_cooperative_fetch) { return std::nullopt; } @@ -80,7 +81,7 @@ bool ParseWarpExecutionAnn(const Schedule& sch, const Instruction& inst) { } ICHECK_EQ(inst->inputs.size(), 2); 
ICHECK_EQ(inst->attrs.size(), 1); - String ann_key = Downcast(inst->attrs[0]); + ffi::String ann_key = Downcast(inst->attrs[0]); return ann_key == attr::warp_execution; } @@ -124,7 +125,7 @@ class RewriteCooperativeFetchNode : public PostprocNode { // Inherited from PostprocNode void InitializeWithTuneContext(const TuneContext& context) final { - if (Optional v = context->target.value()->GetAttr("thread_warp_size")) { + if (ffi::Optional v = context->target.value()->GetAttr("thread_warp_size")) { this->thread_warp_size_ = v.value()->value; } else { TVM_PY_LOG(INFO, context->logger) << "'thread_warp_size' is not defined in the target"; @@ -135,7 +136,7 @@ class RewriteCooperativeFetchNode : public PostprocNode { bool Apply(const tir::Schedule& sch) final; Postproc Clone() const { - ObjectPtr n = make_object(*this); + ObjectPtr n = ffi::make_object(*this); return Postproc(n); } @@ -153,11 +154,13 @@ bool RewriteCooperativeFetchNode::Apply(const tir::Schedule& sch) { int64_t vector_lane = 1; std::vector> tasks; for (const tir::Instruction& inst : trace->insts) { - if (Optional new_thread_extent = tir::ParseThreadBinding(sch, inst, "threadIdx.x")) { + if (ffi::Optional new_thread_extent = + tir::ParseThreadBinding(sch, inst, "threadIdx.x")) { thread_extent_x = new_thread_extent.value()->value; continue; } - if (Optional new_thread_extent = tir::ParseThreadBinding(sch, inst, "threadIdx.y")) { + if (ffi::Optional new_thread_extent = + tir::ParseThreadBinding(sch, inst, "threadIdx.y")) { thread_extent_y = new_thread_extent.value()->value; continue; } @@ -165,7 +168,7 @@ bool RewriteCooperativeFetchNode::Apply(const tir::Schedule& sch) { thread_extent_x = thread_warp_size_; continue; } - Optional opt_block_rv = tir::ParseAnnotate(sch, inst, &vector_lane); + ffi::Optional opt_block_rv = tir::ParseAnnotate(sch, inst, &vector_lane); if (!opt_block_rv.defined()) { continue; } @@ -191,29 +194,30 @@ bool RewriteCooperativeFetchNode::Apply(const tir::Schedule& sch) { } if (thread_extent_y != -1) { if (vector_lane > 1) { - Array split = sch->Split(fused, {std::nullopt, // - Integer(thread_extent_y), // - Integer(thread_extent_x), // - Integer(vector_lane)}); + ffi::Array split = sch->Split(fused, {std::nullopt, // + Integer(thread_extent_y), // + Integer(thread_extent_x), // + Integer(vector_lane)}); sch->Vectorize(split[3]); sch->Bind(split[2], "threadIdx.x"); sch->Bind(split[1], "threadIdx.y"); } else { - Array split = sch->Split(fused, {std::nullopt, // - Integer(thread_extent_y), // - Integer(thread_extent_x)}); + ffi::Array split = sch->Split(fused, {std::nullopt, // + Integer(thread_extent_y), // + Integer(thread_extent_x)}); sch->Bind(split[2], "threadIdx.x"); sch->Bind(split[1], "threadIdx.y"); } } else { if (vector_lane > 1) { - Array split = sch->Split(fused, {std::nullopt, // - Integer(thread_extent_x), // - Integer(vector_lane)}); + ffi::Array split = sch->Split(fused, {std::nullopt, // + Integer(thread_extent_x), // + Integer(vector_lane)}); sch->Vectorize(split[2]); sch->Bind(split[1], "threadIdx.x"); } else { - Array split = sch->Split(fused, {std::nullopt, Integer(thread_extent_x)}); + ffi::Array split = + sch->Split(fused, {std::nullopt, Integer(thread_extent_x)}); sch->Bind(split[1], "threadIdx.x"); } } @@ -227,7 +231,7 @@ bool RewriteCooperativeFetchNode::Apply(const tir::Schedule& sch) { } Postproc Postproc::RewriteCooperativeFetch() { - ObjectPtr n = make_object(); + ObjectPtr n = ffi::make_object(); return Postproc(n); } diff --git a/src/meta_schedule/postproc/rewrite_layout.cc 
b/src/meta_schedule/postproc/rewrite_layout.cc index 0d645fcf8b21..27768d162b63 100644 --- a/src/meta_schedule/postproc/rewrite_layout.cc +++ b/src/meta_schedule/postproc/rewrite_layout.cc @@ -36,17 +36,17 @@ class BufferReadPosCollector : public StmtExprVisitor { const std::pair& GetBufferLocation() const { return buffer_loc_; } - const Optional GetBufferIndexMap() const { return buffer_index_map_; } + const ffi::Optional GetBufferIndexMap() const { return buffer_index_map_; } private: void VisitStmt_(const ForNode* op) final { - loop_stack_.push_back(GetRef(op)); + loop_stack_.push_back(ffi::GetRef(op)); StmtVisitor::VisitStmt_(op); loop_stack_.pop_back(); } void VisitStmt_(const BlockRealizeNode* op) final { - BlockRealize outer_block_realize = GetRef(op); + BlockRealize outer_block_realize = ffi::GetRef(op); std::swap(outer_block_realize, cur_realize_); StmtVisitor::VisitStmt_(op); std::swap(cur_realize_, outer_block_realize); @@ -57,13 +57,13 @@ class BufferReadPosCollector : public StmtExprVisitor { const Buffer& buffer = op->buffer; if (buffer_ == buffer.get()) { - Map subst_map; + ffi::Map subst_map; for (size_t i = 0; i < cur_realize_->iter_values.size(); i++) { const Var& var = cur_realize_->block->iter_vars[i]->var; const PrimExpr& value = cur_realize_->iter_values[i]; subst_map.Set(var, value); } - Array subst_indices; + ffi::Array subst_indices; for (const PrimExpr& e : op->indices) { subst_indices.push_back(Substitute(e, subst_map)); } @@ -93,10 +93,10 @@ class BufferReadPosCollector : public StmtExprVisitor { /*! \brief The block that consumes the buffer and the corresponding read index. */ std::pair buffer_loc_; /*! \brief The proposed IndexMap. */ - Optional buffer_index_map_; + ffi::Optional buffer_index_map_; /*! \brief Loop stack for calculating IndexMap. */ - Array loop_stack_; + ffi::Array loop_stack_; /*! \brief Arithmetic analyzer. */ arith::Analyzer analyzer_; /*! 
\brief Current BlockRealize scope, used in recursive visit */ @@ -108,7 +108,7 @@ class LayoutFreeBufferCollector : public StmtVisitor { void VisitStmt_(const BlockNode* block) final { StmtVisitor::VisitStmt_(block); if (auto ann = block->annotations.Get("layout_free_placeholders")) { - for (Buffer buffer : Downcast>(ann.value())) { + for (Buffer buffer : Downcast>(ann.value())) { buffers.insert(buffer); } } @@ -117,12 +117,12 @@ class LayoutFreeBufferCollector : public StmtVisitor { std::unordered_set buffers; }; -Array CollectLayoutFreeBuffers(const PrimFuncNode* func) { +ffi::Array CollectLayoutFreeBuffers(const PrimFuncNode* func) { // Only rewrite PrimFuncs with attr "layout_free_buffers" - Array layout_free_buffer_index = - func->GetAttr(attr::layout_free_buffers, Array()).value(); + ffi::Array layout_free_buffer_index = + func->GetAttr(attr::layout_free_buffers, ffi::Array()).value(); - Array layout_free_buffers; + ffi::Array layout_free_buffers; for (const Integer& index : layout_free_buffer_index) { ICHECK(static_cast(index->value) < func->params.size()); const Var& param = func->params[index->value]; @@ -182,14 +182,14 @@ std::vector GetCacheReadChain(const Buffer& buf, const PrimFuncNode } bool RewriteLayout(const Schedule& sch) { - std::vector> results; + std::vector> results; auto add_layout_rewrite_block = [&sch](BlockRV consumer_block_rv, int buffer_index) { BlockRV rewrite_block_rv = sch->CacheRead(consumer_block_rv, buffer_index, "global"); sch->Annotate(rewrite_block_rv, attr::meta_schedule_layout_rewrite_preproc, true); }; for (const auto& [g_var, base_func] : sch->mod()->functions) { - const String& func_name = g_var->name_hint; + const ffi::String& func_name = g_var->name_hint; const auto* prim_func = base_func.as(); // Only consider PrimFunc if (prim_func == nullptr) { @@ -261,7 +261,7 @@ class RewriteLayoutNode : public PostprocNode { } Postproc Clone() const { - ObjectPtr n = make_object(*this); + ObjectPtr n = ffi::make_object(*this); return Postproc(n); } @@ -270,7 +270,7 @@ class RewriteLayoutNode : public PostprocNode { }; Postproc Postproc::RewriteLayout() { - auto n = make_object(); + auto n = ffi::make_object(); return Postproc(n); } diff --git a/src/meta_schedule/postproc/rewrite_parallel_vectorize_unroll.cc b/src/meta_schedule/postproc/rewrite_parallel_vectorize_unroll.cc index 945b9adbc948..f0047d688a80 100644 --- a/src/meta_schedule/postproc/rewrite_parallel_vectorize_unroll.cc +++ b/src/meta_schedule/postproc/rewrite_parallel_vectorize_unroll.cc @@ -146,7 +146,7 @@ void RemoveParsedAnn(const Schedule& sch, const BlockRV& block_rv, const ParsedA } } -int CalculateNumRewritableLoops(const Array& loop_srefs, +int CalculateNumRewritableLoops(const ffi::Array& loop_srefs, const std::vector& loop_types) { int rw_loops_num = 0; ICHECK_EQ(loop_srefs.size(), loop_types.size()); @@ -174,7 +174,7 @@ int CalculateNumRewritableLoops(const Array& loop_srefs, } void AdjustParallelVectorize(const Schedule& sch, const BlockRV& block_rv, - const Array& loop_rvs, ParsedAnnotation* parsed) { + const ffi::Array& loop_rvs, ParsedAnnotation* parsed) { StmtSRef block_sref = sch->GetSRef(block_rv); if (parsed->max_parallel_extent == -1 && parsed->max_vectorize_extent == -1) { return; @@ -186,7 +186,7 @@ void AdjustParallelVectorize(const Schedule& sch, const BlockRV& block_rv, return; } // Extract loop_srefs, and calculate the iterator types - Array loop_srefs; + ffi::Array loop_srefs; std::vector loop_types; { loop_srefs.reserve(n_loops); @@ -198,7 +198,7 @@ void 
AdjustParallelVectorize(const Schedule& sch, const BlockRV& block_rv, } // check the maximal number of axes that are vectorizable (contiguous memory access) BlockRealize realize = GetBlockRealize(sch->state(), block_sref); - Array buffer_access(realize->block->reads); + ffi::Array buffer_access(realize->block->reads); buffer_access.insert(buffer_access.end(), realize->block->writes.begin(), realize->block->writes.end()); std::unordered_map binding_map; @@ -357,10 +357,11 @@ bool FindAnnotatedRootBlock(const Schedule& sch, ParsedAnnotation* parsed, Block return false; } -void RewriteFuseSplitParallelVectorize(const Schedule& sch, Array* loop_rvs, int vec_len) { +void RewriteFuseSplitParallelVectorize(const Schedule& sch, ffi::Array* loop_rvs, + int vec_len) { size_t n_loops = loop_rvs->size(); LoopRV fused = sch->Fuse({loop_rvs->begin(), loop_rvs->end()}); - Array split = sch->Split(fused, {std::nullopt, Integer(vec_len)}); + ffi::Array split = sch->Split(fused, {std::nullopt, Integer(vec_len)}); ICHECK_EQ(split.size(), 2); const LoopRV& outer = split[0]; const LoopRV& inner = split[1]; @@ -372,7 +373,7 @@ void RewriteFuseSplitParallelVectorize(const Schedule& sch, Array* loop_ loop_rvs->Set(n_loops - 1, inner); } -void RewriteParallel(const Schedule& sch, size_t n, Array* loop_rvs) { +void RewriteParallel(const Schedule& sch, size_t n, ffi::Array* loop_rvs) { ICHECK_LE(n, loop_rvs->size()); LoopRV fused = sch->Fuse({loop_rvs->begin(), loop_rvs->begin() + n}); sch->Parallel(fused); @@ -381,7 +382,7 @@ void RewriteParallel(const Schedule& sch, size_t n, Array* loop_rvs) { } } -void RewriteVectorize(const Schedule& sch, size_t n, Array* loop_rvs) { +void RewriteVectorize(const Schedule& sch, size_t n, ffi::Array* loop_rvs) { size_t n_loops = loop_rvs->size(); ICHECK_LE(n, n_loops); LoopRV fused = sch->Fuse({loop_rvs->end() - n, loop_rvs->end()}); @@ -417,7 +418,7 @@ class RewriteParallelVectorizeUnrollNode : public PostprocNode { tir::BlockRV root_rv{nullptr}; while (tir::FindAnnotatedRootBlock(sch, &parsed_root, &root_rv)) { for (tir::BlockRV block_rv : sch->GetChildBlocks(root_rv)) { - Array loop_rvs = sch->GetLoops(block_rv); + ffi::Array loop_rvs = sch->GetLoops(block_rv); if (loop_rvs.empty()) { continue; } @@ -451,7 +452,7 @@ class RewriteParallelVectorizeUnrollNode : public PostprocNode { Postproc Clone() const { ObjectPtr n = - make_object(*this); + ffi::make_object(*this); return Postproc(n); } @@ -461,7 +462,7 @@ class RewriteParallelVectorizeUnrollNode : public PostprocNode { Postproc Postproc::RewriteParallelVectorizeUnroll() { ObjectPtr n = - make_object(); + ffi::make_object(); return Postproc(n); } diff --git a/src/meta_schedule/postproc/rewrite_reduction_block.cc b/src/meta_schedule/postproc/rewrite_reduction_block.cc index bd78855d8684..7c997f8261b3 100644 --- a/src/meta_schedule/postproc/rewrite_reduction_block.cc +++ b/src/meta_schedule/postproc/rewrite_reduction_block.cc @@ -27,8 +27,8 @@ namespace tir { struct ReductionBlockFinder : private StmtVisitor { public: /*! \brief Find all the reduction blocks that should be decomposed */ - static std::vector> Find(const ScheduleState& self) { - std::vector> results; + static std::vector> Find(const ScheduleState& self) { + std::vector> results; for (const auto& kv : self->mod->functions) { GlobalVar g_var = kv.first; BaseFunc base_func = kv.second; @@ -92,7 +92,7 @@ struct ReductionBlockFinder : private StmtVisitor { * or -1 if the `init` does not need to be decomposed. 
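 * (editor's note, not part of the original patch): the decompose point is
 * the index of the first loop in the block's surrounding loop nest whose
 * iterator is not data-parallel; Apply below feeds it to DecomposeReduction
 * so the reduction's init statement is computed separately above that loop.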
*/ int FindDecomposePoint(const StmtSRef& block_sref) { - Array loop_srefs = GetLoops(block_sref); + ffi::Array loop_srefs = GetLoops(block_sref); int n = loop_srefs.size(); for (int i = 0; i < n; ++i) { if (GetLoopIterType(loop_srefs[i]) != IterVarType::kDataPar) { @@ -122,7 +122,7 @@ class RewriteReductionBlockNode : public PostprocNode { bool Apply(const tir::Schedule& sch) final; Postproc Clone() const { - ObjectPtr n = make_object(*this); + ObjectPtr n = ffi::make_object(*this); return Postproc(n); } @@ -132,26 +132,27 @@ class RewriteReductionBlockNode : public PostprocNode { bool RewriteReductionBlockNode::Apply(const tir::Schedule& sch) { for (;;) { - std::vector> results = + std::vector> results = tir::ReductionBlockFinder::Find(sch->state()); int rewritten = 0; for (const auto& kv : results) { const tir::StmtSRef& block_sref = kv.first; - const String& global_var_name = kv.second; + const ffi::String& global_var_name = kv.second; int decompose_point = tir::FindDecomposePoint(block_sref); if (decompose_point == -1) { continue; } tir::BlockRV block_rv = GetRVFromSRef(sch, block_sref, global_var_name); - Array loop_rvs = sch->GetLoops(block_rv); + ffi::Array loop_rvs = sch->GetLoops(block_rv); tir::BlockRV init_block_rv = sch->DecomposeReduction(block_rv, loop_rvs[decompose_point]); // Rewrite auto tensorization related annotations - if (tir::GetAnn(block_sref, tir::attr::meta_schedule_auto_tensorize).has_value()) { + if (tir::GetAnn(block_sref, tir::attr::meta_schedule_auto_tensorize) + .has_value()) { // Remove tensorization annotation as it shouldn't be propagated to the init block. sch->Unannotate(init_block_rv, tir::attr::meta_schedule_auto_tensorize); - Optional tensorize_init = - tir::GetAnn(block_sref, tir::attr::meta_schedule_auto_tensorize_init); + ffi::Optional tensorize_init = + tir::GetAnn(block_sref, tir::attr::meta_schedule_auto_tensorize_init); // The annotation of tensorization of the init statement should be moved to the init block // after 'DecomposeReduction'. // Annotate to hint `RewriteTensorize` postprocessor even if tensorize_init is std::nullopt. 
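// ---------------------------------------------------------------------------
// Editor's sketch (not part of the patch): the std::nullopt convention these
// hunks adopt for ffi::Optional. `ann` stands in for the result of the
// GetAnn helper used above; header paths are assumed, and ffi::String is
// presumed reachable through them.
#include <tvm/ffi/optional.h>  // ffi::Optional (assumed path)

tvm::ffi::Optional<tvm::ffi::String> NonEmptyIntrin(
    const tvm::ffi::Optional<tvm::ffi::String>& ann) {
  // Absent annotation: propagate emptiness as std::nullopt, exactly as the
  // migrated Apply methods do.
  if (!ann.has_value()) return std::nullopt;
  // Present but empty: treat as "no intrinsic", mirroring the
  // `intrin_name.value() != ""` check in RewriteTensorize below.
  if (ann.value() == "") return std::nullopt;
  return ann;
}
// ---------------------------------------------------------------------------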
@@ -172,7 +173,7 @@ bool RewriteReductionBlockNode::Apply(const tir::Schedule& sch) { } Postproc Postproc::RewriteReductionBlock() { - ObjectPtr n = make_object(); + ObjectPtr n = ffi::make_object(); return Postproc(n); } diff --git a/src/meta_schedule/postproc/rewrite_tensorize.cc b/src/meta_schedule/postproc/rewrite_tensorize.cc index 596bc7cb1f24..e97202461e9f 100644 --- a/src/meta_schedule/postproc/rewrite_tensorize.cc +++ b/src/meta_schedule/postproc/rewrite_tensorize.cc @@ -30,15 +30,15 @@ using tir::BlockRV; using tir::LoopRV; void CollectTensorizationJobs( - const tir::Schedule& sch, const String& func_name, const tir::PrimFuncNode* func, + const tir::Schedule& sch, const ffi::String& func_name, const tir::PrimFuncNode* func, bool vectorize_init_loop, - std::vector>>* jobs) { + std::vector>>* jobs) { tir::PostOrderVisit(func->body, [=, &jobs](const ObjectRef& obj) { if (const auto* block = obj.as()) { tir::StmtSRef block_sref = sch->GetSRef(block); std::string block_name = block_sref->StmtAs()->name_hint; - if (Optional intrin_name = - tir::GetAnn(block_sref, tir::attr::meta_schedule_auto_tensorize)) { + if (ffi::Optional intrin_name = + tir::GetAnn(block_sref, tir::attr::meta_schedule_auto_tensorize)) { if (intrin_name.value() != "") { jobs->emplace_back(block_name, func_name, [sch, intrin_name](tir::BlockRV block) { try { @@ -49,9 +49,9 @@ void CollectTensorizationJobs( }); } else if (block_name.find("init") && vectorize_init_loop) { jobs->emplace_back(block_name, func_name, [sch](tir::BlockRV block) { - Array child_blocks = sch->GetChildBlocks(block); + ffi::Array child_blocks = sch->GetChildBlocks(block); ICHECK(child_blocks.size() == 1); - Array init_loops = sch->GetLoops(child_blocks[0]); + ffi::Array init_loops = sch->GetLoops(child_blocks[0]); ICHECK(init_loops.size() == 1); sch->Vectorize(init_loops[0]); }); @@ -73,7 +73,7 @@ class RewriteTensorizeNode : public PostprocNode { bool Apply(const tir::Schedule& sch) final; Postproc Clone() const { - ObjectPtr n = make_object(*this); + ObjectPtr n = ffi::make_object(*this); return Postproc(n); } @@ -85,7 +85,7 @@ class RewriteTensorizeNode : public PostprocNode { bool RewriteTensorizeNode::Apply(const tir::Schedule& sch) { // The rewriting jobs, 3-tuple (block_name, func_name, job_func) - std::vector>> jobs; + std::vector>> jobs; for (const auto& kv : sch->mod()->functions) { GlobalVar g_var = kv.first; BaseFunc base_func = kv.second; @@ -94,8 +94,8 @@ bool RewriteTensorizeNode::Apply(const tir::Schedule& sch) { } } for (const auto& job : jobs) { - const String& block_name = std::get<0>(job); - const String& func_name = std::get<1>(job); + const ffi::String& block_name = std::get<0>(job); + const ffi::String& func_name = std::get<1>(job); const auto& job_func = std::get<2>(job); BlockRV block = sch->GetBlock(block_name, func_name); sch->Unannotate(block, tir::attr::meta_schedule_auto_tensorize); @@ -105,7 +105,7 @@ bool RewriteTensorizeNode::Apply(const tir::Schedule& sch) { } Postproc Postproc::RewriteTensorize(bool vectorize_init_loop) { - ObjectPtr n = make_object(); + ObjectPtr n = ffi::make_object(); n->vectorize_init_loop = vectorize_init_loop; return Postproc(n); } diff --git a/src/meta_schedule/postproc/rewrite_unbound_block.cc b/src/meta_schedule/postproc/rewrite_unbound_block.cc index acebeb71cdf7..529e3509569b 100644 --- a/src/meta_schedule/postproc/rewrite_unbound_block.cc +++ b/src/meta_schedule/postproc/rewrite_unbound_block.cc @@ -27,7 +27,7 @@ namespace tir { /*! 
\brief Find all the blocks that are not bound */ class UnboundBlockFinder : private StmtVisitor { public: - static std::vector> Find(const ScheduleState& self) { + static std::vector> Find(const ScheduleState& self) { UnboundBlockFinder finder(self); for (const auto& kv : self->mod->functions) { GlobalVar g_var = kv.first; @@ -68,13 +68,13 @@ class UnboundBlockFinder : private StmtVisitor { /*! \brief The schedule state */ const ScheduleState& self_; /*! \brief The list of unbound blocks */ - std::vector> blocks_; + std::vector> blocks_; /*! \brief The number of blockIdx above the current stmt */ int n_block_idx_; /*! \brief The number of threadIdx above the current stmt */ int n_thread_idx_; /*! \brief The name of the global var */ - String global_var_name_; + ffi::String global_var_name_; }; } // namespace tir @@ -89,7 +89,7 @@ class RewriteUnboundBlockNode : public PostprocNode { // Inherited from PostprocNode void InitializeWithTuneContext(const TuneContext& context) final { CHECK(context->target.defined()) << "ValueError: target is not defined"; - Optional max_threads_per_block = + ffi::Optional max_threads_per_block = context->target.value()->GetAttr("max_threads_per_block"); CHECK(max_threads_per_block.defined()) << "ValueError: missing attribute `max_threads_per_block` in the target"; @@ -100,7 +100,7 @@ class RewriteUnboundBlockNode : public PostprocNode { bool Apply(const tir::Schedule& sch) final; Postproc Clone() const { - ObjectPtr n = make_object(*this); + ObjectPtr n = ffi::make_object(*this); return Postproc(n); } @@ -128,11 +128,11 @@ bool RewriteUnboundBlockNode::Apply(const tir::Schedule& sch) { auto get_factor = [t = this->max_threads_per_block_](int max_extent) -> ExprRV { return Integer(std::min(t, max_extent)); }; - std::vector> unbound_blocks = + std::vector> unbound_blocks = tir::UnboundBlockFinder::Find(sch->state()); for (const auto& kv : unbound_blocks) { tir::StmtSRef block_sref = kv.first; - String global_var_name = kv.second; + ffi::String global_var_name = kv.second; BlockRV block_rv = GetRVFromSRef(sch, block_sref, global_var_name); BindBlockThreadIdx(sch, block_rv, max_threadblocks_, max_threads_per_block_, get_factor); } @@ -140,7 +140,7 @@ bool RewriteUnboundBlockNode::Apply(const tir::Schedule& sch) { } Postproc Postproc::RewriteUnboundBlock(int max_threadblocks) { - ObjectPtr n = make_object(); + ObjectPtr n = ffi::make_object(); n->max_threadblocks_ = max_threadblocks; n->max_threads_per_block_ = -1; return Postproc(n); diff --git a/src/meta_schedule/postproc/verify_gpu_code.cc b/src/meta_schedule/postproc/verify_gpu_code.cc index 20cd0735431d..5aaf756d43bb 100644 --- a/src/meta_schedule/postproc/verify_gpu_code.cc +++ b/src/meta_schedule/postproc/verify_gpu_code.cc @@ -73,9 +73,9 @@ class ThreadExtentChecker : private StmtVisitor { if (block->annotations.count(attr::warp_execution)) { thread_idx_x = thread_warp_size_; } - if (Optional low_inclusive = + if (ffi::Optional low_inclusive = GetAnn(block, attr::meta_schedule_thread_extent_low_inclusive)) { - if (Optional high_inclusive = + if (ffi::Optional high_inclusive = GetAnn(block, attr::meta_schedule_thread_extent_high_inclusive)) { int64_t low = low_inclusive.value()->value; int64_t high = high_inclusive.value()->value; @@ -104,7 +104,7 @@ namespace meta_schedule { /*! \brief Extract attribute from a target. 
*/ Integer Extract(const Target& target, const char* name) { ICHECK(target.defined()); - if (Optional v = target->GetAttr(name)) { + if (ffi::Optional v = target->GetAttr(name)) { return v.value(); } LOG(FATAL) << "AttributedError: \"" << name << "\" is not defined in the target"; @@ -115,13 +115,13 @@ Integer Extract(const Target& target, const char* name) { class VerifyGPUCodeNode : public PostprocNode { public: Target target_{nullptr}; - Map target_constraints_{nullptr}; + ffi::Map target_constraints_{nullptr}; int thread_warp_size_ = -1; void InitializeWithTuneContext(const TuneContext& context) final { ICHECK(context->target.defined()); this->target_ = context->target.value(); - this->target_constraints_ = Map{ + this->target_constraints_ = ffi::Map{ {"max_shared_memory_per_block", Extract(this->target_, "max_shared_memory_per_block")}, {"max_threads_per_block", Extract(this->target_, "max_threads_per_block")}, {"max_vthread", Integer(8)}, @@ -152,7 +152,7 @@ class VerifyGPUCodeNode : public PostprocNode { } IRModule lowered{nullptr}; try { - auto pass_list = Array(); + auto pass_list = ffi::Array(); // Phase 1 pass_list.push_back(tir::transform::LowerCrossThreadReduction()); pass_list.push_back(tir::transform::LowerInitBlock()); @@ -180,14 +180,15 @@ class VerifyGPUCodeNode : public PostprocNode { pass_list.push_back(tir::transform::LowerIntrin()); // Convert Function to IRModule transform::PassContext pass_ctx = transform::PassContext::Current(); - tir::PrimFunc f = - WithAttr(GetRef(prim_func), "global_symbol", String(g_var->name_hint)); + tir::PrimFunc f = WithAttr(ffi::GetRef(prim_func), "global_symbol", + ffi::String(g_var->name_hint)); f = WithAttr(f, tvm::attr::kTarget, this->target_); // Required for LowerIntrin bool noalias = pass_ctx->GetConfig("tir.noalias", true).value(); if (noalias) { f = WithAttr(std::move(f), "tir.noalias", true); } - IRModule mod = IRModule(Map({{GlobalVar(g_var->name_hint), f}})); + IRModule mod = + IRModule(ffi::Map({{GlobalVar(g_var->name_hint), f}})); lowered = tvm::transform::Sequential(pass_list)(std::move(mod)); } catch (const std::exception&) { return false; @@ -201,7 +202,7 @@ class VerifyGPUCodeNode : public PostprocNode { } Postproc Clone() const { - ObjectPtr n = make_object(*this); + ObjectPtr n = ffi::make_object(*this); n->target_constraints_ = this->target_constraints_; return Postproc(n); } @@ -211,7 +212,7 @@ class VerifyGPUCodeNode : public PostprocNode { }; Postproc Postproc::VerifyGPUCode() { - ObjectPtr n = make_object(); + ObjectPtr n = ffi::make_object(); return Postproc(n); } diff --git a/src/meta_schedule/postproc/verify_vtcm_limit.cc b/src/meta_schedule/postproc/verify_vtcm_limit.cc index ee9394f16b17..09a61ebd855f 100644 --- a/src/meta_schedule/postproc/verify_vtcm_limit.cc +++ b/src/meta_schedule/postproc/verify_vtcm_limit.cc @@ -56,7 +56,7 @@ class VerifyVTCMLimitNode : public PostprocNode { } Postproc Clone() const { - ObjectPtr n = make_object(*this); + ObjectPtr n = ffi::make_object(*this); return Postproc(n); } @@ -65,7 +65,7 @@ class VerifyVTCMLimitNode : public PostprocNode { }; Postproc Postproc::VerifyVTCMLimit() { - ObjectPtr n = make_object(); + ObjectPtr n = ffi::make_object(); return Postproc(n); } diff --git a/src/meta_schedule/profiler.cc b/src/meta_schedule/profiler.cc index d133e67eadef..2a71aeed69ca 100644 --- a/src/meta_schedule/profiler.cc +++ b/src/meta_schedule/profiler.cc @@ -28,22 +28,22 @@ namespace meta_schedule { /**************** Profiler ****************/ -Map ProfilerNode::Get() const { - Map 
ret; +ffi::Map ProfilerNode::Get() const { + ffi::Map ret; for (const auto& kv : stats_sec) { ret.Set(kv.first, FloatImm(DataType::Float(64), kv.second)); } return ret; } -String ProfilerNode::Table() const { +ffi::String ProfilerNode::Table() const { CHECK(!stats_sec.empty()) << "ValueError: The stats are empty. Please run the profiler first."; CHECK(stats_sec.count("Total")) << "ValueError: The total time is not recorded. This method should be called only after " "exiting the profiler's with scope."; double total = stats_sec.at("Total"); struct Entry { - String name; + ffi::String name; double minutes; double percentage; bool operator<(const Entry& other) const { return percentage > other.percentage; } @@ -71,14 +71,14 @@ String ProfilerNode::Table() const { } Profiler::Profiler() { - ObjectPtr n = make_object(); + ObjectPtr n = ffi::make_object(); n->stats_sec.clear(); n->total_timer = nullptr; data_ = n; } -ffi::Function ProfilerTimedScope(String name) { - if (Optional opt_profiler = Profiler::Current()) { +ffi::Function ProfilerTimedScope(ffi::String name) { + if (ffi::Optional opt_profiler = Profiler::Current()) { return ffi::TypedFunction([profiler = opt_profiler.value(), // tik = std::chrono::high_resolution_clock::now(), // name = std::move(name)]() { @@ -91,7 +91,7 @@ ffi::Function ProfilerTimedScope(String name) { return nullptr; } -ScopedTimer Profiler::TimedScope(String name) { return ScopedTimer(ProfilerTimedScope(name)); } +ScopedTimer Profiler::TimedScope(ffi::String name) { return ScopedTimer(ProfilerTimedScope(name)); } /**************** Context Manager ****************/ @@ -113,7 +113,7 @@ void Profiler::ExitWithScope() { } } -Optional Profiler::Current() { +ffi::Optional Profiler::Current() { std::vector* profilers = ThreadLocalProfilers(); if (profilers->empty()) { return std::nullopt; diff --git a/src/meta_schedule/runner/runner.cc b/src/meta_schedule/runner/runner.cc index 08ecb7aaa22d..d59d57ec64d4 100644 --- a/src/meta_schedule/runner/runner.cc +++ b/src/meta_schedule/runner/runner.cc @@ -23,30 +23,32 @@ namespace tvm { namespace meta_schedule { -RunnerInput::RunnerInput(String artifact_path, String device_type, Array args_info) { - ObjectPtr n = make_object(); +RunnerInput::RunnerInput(ffi::String artifact_path, ffi::String device_type, + ffi::Array args_info) { + ObjectPtr n = ffi::make_object(); n->artifact_path = artifact_path; n->device_type = device_type; n->args_info = args_info; this->data_ = n; } -RunnerResult::RunnerResult(Optional> run_secs, Optional error_msg) { - ObjectPtr n = make_object(); +RunnerResult::RunnerResult(ffi::Optional> run_secs, + ffi::Optional error_msg) { + ObjectPtr n = ffi::make_object(); n->run_secs = run_secs; n->error_msg = error_msg; this->data_ = n; } RunnerFuture::RunnerFuture(RunnerFuture::FDone f_done, RunnerFuture::FResult f_result) { - ObjectPtr n = make_object(); + ObjectPtr n = ffi::make_object(); n->f_done = f_done; n->f_result = f_result; this->data_ = n; } Runner Runner::PyRunner(Runner::FRun f_run) { - ObjectPtr n = make_object(); + ObjectPtr n = ffi::make_object(); n->f_run = f_run; return Runner(n); } @@ -64,13 +66,11 @@ TVM_FFI_STATIC_INIT_BLOCK({ namespace refl = tvm::ffi::reflection; refl::GlobalDef() .def("meta_schedule.RunnerInput", - [](String artifact_path, String device_type, Array args_info) -> RunnerInput { - return RunnerInput(artifact_path, device_type, args_info); - }) + [](ffi::String artifact_path, ffi::String device_type, ffi::Array args_info) + -> RunnerInput { return RunnerInput(artifact_path, 
device_type, args_info); }) .def("meta_schedule.RunnerResult", - [](Optional> run_secs, Optional error_msg) -> RunnerResult { - return RunnerResult(run_secs, error_msg); - }) + [](ffi::Optional> run_secs, ffi::Optional error_msg) + -> RunnerResult { return RunnerResult(run_secs, error_msg); }) .def("meta_schedule.RunnerFuture", [](RunnerFuture::FDone f_done, RunnerFuture::FResult f_result) -> RunnerFuture { return RunnerFuture(f_done, f_result); diff --git a/src/meta_schedule/schedule/cpu/winograd.cc b/src/meta_schedule/schedule/cpu/winograd.cc index 9d2cdaedbde3..e8afb71d6b7f 100644 --- a/src/meta_schedule/schedule/cpu/winograd.cc +++ b/src/meta_schedule/schedule/cpu/winograd.cc @@ -26,21 +26,21 @@ namespace meta_schedule { using namespace tvm::tir; -static Array ScheduleDataPack(tir::Schedule sch, tir::BlockRV block, - std::vector tiled, std::vector unrolled) { +static ffi::Array ScheduleDataPack(tir::Schedule sch, tir::BlockRV block, + std::vector tiled, std::vector unrolled) { using namespace tvm::tir; ICHECK_EQ(tiled.size(), 2); ICHECK_EQ(unrolled.size(), 4); - Array factors{nullptr}; - Array loops = sch->GetLoops(block); + ffi::Array factors{nullptr}; + ffi::Array loops = sch->GetLoops(block); ICHECK_EQ(loops.size(), 6); factors = sch->SamplePerfectTile(loops[tiled[0]], /*n=*/2, /*max_innermost_factor=*/64); - Array t0 = sch->Split(loops[tiled[0]], {factors.begin(), factors.end()}); + ffi::Array t0 = sch->Split(loops[tiled[0]], {factors.begin(), factors.end()}); ICHECK_EQ(t0.size(), 2); factors = sch->SamplePerfectTile(loops[tiled[1]], /*n=*/2, /*max_innermost_factor=*/64); - Array t1 = sch->Split(loops[tiled[1]], {factors.begin(), factors.end()}); + ffi::Array t1 = sch->Split(loops[tiled[1]], {factors.begin(), factors.end()}); ICHECK_EQ(t1.size(), 2); sch->Unroll(loops[unrolled[0]]); @@ -64,7 +64,7 @@ TVM_FFI_STATIC_INIT_BLOCK({ namespace refl = tvm::ffi::reflection; refl::GlobalDef() .def("meta_schedule.cpu.conv2d_nhwc_winograd_data_pack", - [](Schedule sch, BlockRV data_pack) -> Array { + [](Schedule sch, BlockRV data_pack) -> ffi::Array { BlockRV input_tile = GetWinogradProducerAndInlineConst(sch, data_pack); BlockRV data_pad = GetWinogradProducerAndInlineConst(sch, input_tile); ScheduleDataPack(sch, data_pack, {2, 3}, {0, 1, 4, 5}); @@ -75,13 +75,13 @@ TVM_FFI_STATIC_INIT_BLOCK({ return {sch}; }) .def("meta_schedule.cpu.conv2d_nhwc_winograd_inverse", - [](Schedule sch, BlockRV block) -> Array { + [](Schedule sch, BlockRV block) -> ffi::Array { GetWinogradProducerAndInlineConst(sch, block); ScheduleDataPack(sch, block, {2, 3}, {0, 1, 4, 5}); return {sch}; }) .def("meta_schedule.cpu.conv2d_nchw_winograd_data_pack", - [](Schedule sch, BlockRV data_pack) -> Array { + [](Schedule sch, BlockRV data_pack) -> ffi::Array { BlockRV input_tile = GetWinogradProducerAndInlineConst(sch, data_pack); BlockRV data_pad = GetWinogradProducerAndInlineConst(sch, input_tile); ScheduleDataPack(sch, data_pack, {2, 3}, {0, 1, 4, 5}); @@ -92,7 +92,7 @@ TVM_FFI_STATIC_INIT_BLOCK({ return {sch}; }) .def("meta_schedule.cpu.conv2d_nchw_winograd_inverse", - [](Schedule sch, BlockRV block) -> Array { + [](Schedule sch, BlockRV block) -> ffi::Array { GetWinogradProducerAndInlineConst(sch, block); ScheduleDataPack(sch, block, {0, 1}, {2, 3, 4, 5}); return {sch}; diff --git a/src/meta_schedule/schedule/cuda/thread_bind.cc b/src/meta_schedule/schedule/cuda/thread_bind.cc index 287f764a4640..b71ea9164ecf 100644 --- a/src/meta_schedule/schedule/cuda/thread_bind.cc +++ b/src/meta_schedule/schedule/cuda/thread_bind.cc 
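// ---------------------------------------------------------------------------
// Editor's sketch (not part of the patch): the ffi::Map surface the hunks
// above lean on (ProfilerNode::Get, VerifyGPUCode's target_constraints_).
// The header path and the Integer box are assumptions taken from this patch.
#include <tvm/ffi/container/map.h>  // ffi::Map / ffi::String (assumed path)
#include <tvm/ir/expr.h>            // tvm::Integer (assumed path)

#include <cstdint>

int64_t LookupConstraint() {
  namespace ffi = tvm::ffi;
  ffi::Map<ffi::String, tvm::Integer> constraints;
  constraints.Set("max_threads_per_block", tvm::Integer(1024));
  // Map::Get returns an ffi::Optional, so lookups compose with the same
  // std::nullopt handling adopted throughout this patch.
  if (auto v = constraints.Get("max_threads_per_block")) {
    return v.value()->value;
  }
  return -1;
}
// ---------------------------------------------------------------------------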
@@ -31,10 +31,10 @@ namespace meta_schedule { using namespace tvm::tir; -std::function MakeFactorSampler(Schedule sch, Array thread_extents) { +std::function MakeFactorSampler(Schedule sch, ffi::Array thread_extents) { return [sch = std::move(sch), thread_extents = std::move(thread_extents)](int64_t max_extent) -> ExprRV { - Array extents; + ffi::Array extents; extents.reserve(thread_extents.size()); for (const Integer extent : thread_extents) { if (extent->value <= max_extent) { @@ -48,14 +48,14 @@ std::function MakeFactorSampler(Schedule sch, Array th if (n == 1) { return Integer(extents[0]); } - Array probs(n, FloatImm(DataType::Float(32), 1.0 / n)); + ffi::Array probs(n, FloatImm(DataType::Float(32), 1.0 / n)); return sch->SampleCategorical(extents, probs); }; } -Array BindSpatialLoop(Schedule sch, LoopRV loop, int64_t max_threadblocks, - int64_t max_threads_per_block, - std::function get_factor) { +ffi::Array BindSpatialLoop(Schedule sch, LoopRV loop, int64_t max_threadblocks, + int64_t max_threads_per_block, + std::function get_factor) { int64_t extent = -1; if (const int64_t* e = as_const_int(sch->Get(loop)->extent)) { extent = *e; @@ -67,15 +67,15 @@ Array BindSpatialLoop(Schedule sch, LoopRV loop, int64_t max_threadblock get_factor = MakeFactorSampler(sch, {32, 64, 128, 256, 512, 1024}); } ExprRV factor = get_factor(std::min(extent, max_threads_per_block)); - Array splits = sch->Split(loop, {std::nullopt, factor}); + ffi::Array splits = sch->Split(loop, {std::nullopt, factor}); ICHECK_EQ(splits.size(), 2); sch->Bind(splits[0], "blockIdx.x"); sch->Bind(splits[1], "threadIdx.x"); return {splits[0], splits[1]}; } else { - Array splits = sch->Split(loop, {std::nullopt, - Integer(max_threadblocks), // - Integer(max_threads_per_block)}); + ffi::Array splits = sch->Split(loop, {std::nullopt, + Integer(max_threadblocks), // + Integer(max_threads_per_block)}); ICHECK_EQ(splits.size(), 3); sch->Reorder({splits[1], splits[2], splits[0]}); sch->Bind(splits[1], "blockIdx.x"); @@ -95,7 +95,7 @@ void BindBlockThreadIdx(tir::Schedule sch, tir::BlockRV block_rv, // if (tir::HasBeenMultiLevelTiled(block_sref)) { return; } - Array loops = tir::GetLoops(block_sref); + ffi::Array loops = tir::GetLoops(block_sref); int n = loops.size(); int i_block_idx = -1; int i_thread_idx = -1; @@ -143,7 +143,7 @@ void BindBlockThreadIdx(tir::Schedule sch, tir::BlockRV block_rv, // } LoopRV loop_rv{nullptr}; { - Array loop_rvs = sch->GetLoops(block_rv); + ffi::Array loop_rvs = sch->GetLoops(block_rv); if (i_spatial_loop == -1) { LoopRV spatial_loop_rv{nullptr}; if (loop_rvs.empty()) { @@ -165,7 +165,7 @@ void BindBlockThreadIdx(tir::Schedule sch, tir::BlockRV block_rv, // } if (i_block_idx == -1 && i_thread_idx != -1) { int num_fuse = std::min(std::min(i_multi_child, i_thread_idx), i_spatial_loop + 1); - Array loop_rvs = sch->GetLoops(block_rv); + ffi::Array loop_rvs = sch->GetLoops(block_rv); loop_rv = sch->Fuse({loop_rvs.begin(), loop_rvs.begin() + num_fuse}); sch->Bind(loop_rv, "blockIdx.x"); return; diff --git a/src/meta_schedule/schedule/cuda/winograd.cc b/src/meta_schedule/schedule/cuda/winograd.cc index ea7ee90e1408..759ab9fc721c 100644 --- a/src/meta_schedule/schedule/cuda/winograd.cc +++ b/src/meta_schedule/schedule/cuda/winograd.cc @@ -29,22 +29,22 @@ namespace meta_schedule { using namespace tvm::tir; -static Array ScheduleDataPack(tir::Schedule sch, tir::BlockRV block, - std::vector tiled, std::vector unrolled) { +static ffi::Array ScheduleDataPack(tir::Schedule sch, tir::BlockRV block, + std::vector 
tiled, std::vector unrolled) { // This method is used for NHWC layout only. Will likely be refactored into a more general schedule using namespace tvm::tir; ICHECK_EQ(tiled.size(), 2); ICHECK_EQ(unrolled.size(), 4); - Array factors{nullptr}; - Array loops = sch->GetLoops(block); + ffi::Array factors{nullptr}; + ffi::Array loops = sch->GetLoops(block); ICHECK_EQ(loops.size(), 6); factors = sch->SamplePerfectTile(loops[tiled[0]], /*n=*/2, /*max_innermost_factor=*/64); - Array t0 = sch->Split(loops[tiled[0]], {factors.begin(), factors.end()}); + ffi::Array t0 = sch->Split(loops[tiled[0]], {factors.begin(), factors.end()}); ICHECK_EQ(t0.size(), 2); factors = sch->SamplePerfectTile(loops[tiled[1]], /*n=*/2, /*max_innermost_factor=*/64); - Array t1 = sch->Split(loops[tiled[1]], {factors.begin(), factors.end()}); + ffi::Array t1 = sch->Split(loops[tiled[1]], {factors.begin(), factors.end()}); ICHECK_EQ(t1.size(), 2); sch->Unroll(loops[unrolled[0]]); @@ -68,10 +68,10 @@ TVM_FFI_STATIC_INIT_BLOCK({ namespace refl = tvm::ffi::reflection; refl::GlobalDef() .def("meta_schedule.cuda.conv2d_nhwc_winograd_data_pack", - [](Schedule sch, BlockRV data_pack) -> Array { + [](Schedule sch, BlockRV data_pack) -> ffi::Array { BlockRV input_tile = GetWinogradProducerAndInlineConst(sch, data_pack); BlockRV data_pad = GetWinogradProducerAndInlineConst(sch, input_tile); - Array loops = ScheduleDataPack(sch, data_pack, {2, 3}, {0, 1, 4, 5}); + ffi::Array loops = ScheduleDataPack(sch, data_pack, {2, 3}, {0, 1, 4, 5}); { BlockRV data_pack_local = sch->CacheWrite(data_pack, 0, "local"); sch->ReverseComputeAt(data_pack_local, loops.back(), /*preserve_unit_loops=*/true); @@ -84,7 +84,7 @@ TVM_FFI_STATIC_INIT_BLOCK({ { int64_t max_threadblocks = 256; int64_t max_threads_per_block = 1024; - Array loops = sch->GetLoops(data_pack); + ffi::Array loops = sch->GetLoops(data_pack); ICHECK_EQ(loops.size(), 8); BindSpatialLoop(sch, sch->Fuse({loops[0], loops[1], loops[2], loops[3]}), max_threadblocks, max_threads_per_block); @@ -92,26 +92,26 @@ TVM_FFI_STATIC_INIT_BLOCK({ return {sch}; }) .def("meta_schedule.cuda.conv2d_nhwc_winograd_inverse", - [](Schedule sch, BlockRV inverse) -> Array { + [](Schedule sch, BlockRV inverse) -> ffi::Array { GetWinogradProducerAndInlineConst(sch, inverse); ScheduleDataPack(sch, inverse, /*tiled=*/{2, 3}, /*unrolled=*/{0, 1, 4, 5}); int64_t max_threadblocks = 256; int64_t max_threads_per_block = 1024; - Array loops = sch->GetLoops(inverse); + ffi::Array loops = sch->GetLoops(inverse); ICHECK_EQ(loops.size(), 8); BindSpatialLoop(sch, sch->Fuse({loops[0], loops[1], loops[2], loops[3]}), max_threadblocks, max_threads_per_block); return {sch}; }) .def("meta_schedule.cuda.conv2d_nchw_winograd_data_pack", - [](Schedule sch, BlockRV data_pack) -> Array { + [](Schedule sch, BlockRV data_pack) -> ffi::Array { int64_t max_threadblocks = 256; int64_t max_threads_per_block = 1024; BlockRV input_tile = GetWinogradProducerAndInlineConst(sch, data_pack); BlockRV data_pad = GetWinogradProducerAndInlineConst(sch, input_tile); LoopRV outer{nullptr}; { - Array loops = sch->GetLoops(data_pack); + ffi::Array loops = sch->GetLoops(data_pack); ICHECK_EQ(loops.size(), 6); sch->Reorder({loops[2], loops[3], loops[0], loops[1], loops[4], loops[5]}); sch->Unroll(loops[0]); @@ -134,7 +134,7 @@ TVM_FFI_STATIC_INIT_BLOCK({ return {sch}; }) .def("meta_schedule.cuda.conv2d_nchw_winograd_inverse", - [](Schedule sch, BlockRV inverse) -> Array { + [](Schedule sch, BlockRV inverse) -> ffi::Array { GetWinogradProducerAndInlineConst(sch, 
inverse); // loops on top of the inverse block: [CO, P, tile_size, tile_size, alpha, alpha] int64_t tile_size = @@ -142,17 +142,17 @@ TVM_FFI_STATIC_INIT_BLOCK({ LoopRV outer{nullptr}; { BlockRV output = sch->GetConsumers(inverse)[0]; - Array nchw = sch->GetLoops(output); + ffi::Array nchw = sch->GetLoops(output); ICHECK_EQ(nchw.size(), 4); - Array hs = sch->Split(nchw[2], {std::nullopt, Integer(tile_size)}); - Array ws = sch->Split(nchw[3], {std::nullopt, Integer(tile_size)}); + ffi::Array hs = sch->Split(nchw[2], {std::nullopt, Integer(tile_size)}); + ffi::Array ws = sch->Split(nchw[3], {std::nullopt, Integer(tile_size)}); sch->Reorder({hs[0], ws[0], hs[1], ws[1]}); outer = ws[0]; } { sch->ComputeAt(inverse, /*loop_rv=*/outer, /*preserve_unit_loops=*/true); sch->SetScope(inverse, /*buffer_index=*/0, /*storage_scope=*/"local"); - Array loops = sch->GetLoops(inverse); + ffi::Array loops = sch->GetLoops(inverse); ICHECK_EQ(loops.size(), 10); sch->Unroll(loops[6]); sch->Unroll(loops[7]); diff --git a/src/meta_schedule/schedule/generic/winograd.cc b/src/meta_schedule/schedule/generic/winograd.cc index edb14667bcec..fe41e1e686f1 100644 --- a/src/meta_schedule/schedule/generic/winograd.cc +++ b/src/meta_schedule/schedule/generic/winograd.cc @@ -29,8 +29,8 @@ using namespace tvm::tir; * \return The only producer block. */ BlockRV GetWinogradProducerAndInlineConst(Schedule sch, BlockRV block) { - Array producers = sch->GetProducers(block); - Array results; + ffi::Array producers = sch->GetProducers(block); + ffi::Array results; for (const BlockRV& producer : producers) { if (sch->Get(producer)->reads.empty()) { sch->ComputeInline(producer); diff --git a/src/meta_schedule/schedule_rule/add_rfactor.cc b/src/meta_schedule/schedule_rule/add_rfactor.cc index c2f3a7208f64..81e541c1691f 100644 --- a/src/meta_schedule/schedule_rule/add_rfactor.cc +++ b/src/meta_schedule/schedule_rule/add_rfactor.cc @@ -36,11 +36,11 @@ class AddRFactorNode : public ScheduleRuleNode { } // Inherited from ScheduleRuleNode - Array Apply(const tir::Schedule& sch, const tir::BlockRV& block_rv); + ffi::Array Apply(const tir::Schedule& sch, const tir::BlockRV& block_rv); // Inherited from ScheduleRuleNode ScheduleRule Clone() const final { - ObjectPtr n = make_object(*this); + ObjectPtr n = ffi::make_object(*this); return ScheduleRule(n); } @@ -70,8 +70,8 @@ class AddRFactorNode : public ScheduleRuleNode { }; ScheduleRule ScheduleRule::AddRFactor(int max_jobs_per_core, - Optional max_innermost_factor) { - ObjectPtr n = make_object(); + ffi::Optional max_innermost_factor) { + ObjectPtr n = ffi::make_object(); n->max_jobs_per_core = max_jobs_per_core; n->max_innermost_factor = max_innermost_factor.value_or(Integer(-1))->value; n->max_parallel_extent_ = -1; @@ -79,7 +79,8 @@ ScheduleRule ScheduleRule::AddRFactor(int max_jobs_per_core, return ScheduleRule(n); } -Array AddRFactorNode::Apply(const tir::Schedule& sch, const tir::BlockRV& block_rv) { +ffi::Array AddRFactorNode::Apply(const tir::Schedule& sch, + const tir::BlockRV& block_rv) { tir::StmtSRef block_sref = sch->GetSRef(block_rv); if (!NeedsRFactorOrCrossThreadReduction(sch->state(), block_sref, max_parallel_extent_, max_parallel_basic_)) { @@ -97,16 +98,18 @@ Array AddRFactorNode::Apply(const tir::Schedule& sch, const tir:: ReorderAndFuseReductionLoops(sch, block_rv, &fused_reduce_loop, &num_spatial_loops); // Split the fused reduction loop. 
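// Aside (illustrative sketch, not part of this change): the sample-then-split idiom used
// here and in the winograd schedules above. SamplePerfectTile draws ExprRV factors whose
// product equals the loop extent, and Split consumes them; `sch` and `loop` are
// hypothetical stand-ins:
ffi::Array<tir::ExprRV> example_factors =
    sch->SamplePerfectTile(loop, /*n=*/2, /*max_innermost_factor=*/64);
ffi::Array<tir::LoopRV> example_splits =
    sch->Split(loop, {example_factors.begin(), example_factors.end()});
ICHECK_EQ(example_splits.size(), 2);  // Split yields one LoopRV per sampled factor.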
- Array factors = sch->SamplePerfectTile(fused_reduce_loop, 2, max_innermost_factor); - Array split_loops = sch->Split(fused_reduce_loop, {factors.begin(), factors.end()}); + ffi::Array factors = + sch->SamplePerfectTile(fused_reduce_loop, 2, max_innermost_factor); + ffi::Array split_loops = + sch->Split(fused_reduce_loop, {factors.begin(), factors.end()}); - Array res; + ffi::Array res; for (const tir::LoopRV& split_loop : split_loops) { tir::Schedule sch_tmp = sch->Copy(); sch_tmp->Seed(sch->ForkSeed()); try { const tir::BlockRV& block_rf = sch_tmp->RFactor(split_loop, num_spatial_loops); - Array axes = sch_tmp->GetLoops(block_rf); + ffi::Array axes = sch_tmp->GetLoops(block_rf); ICHECK_GT(axes.size(), num_spatial_loops); // Annotate that the rfactor block, which is now the producer of the original block, needs to diff --git a/src/meta_schedule/schedule_rule/apply_custom_rule.cc b/src/meta_schedule/schedule_rule/apply_custom_rule.cc index 35752b8b73eb..d9000c35cf69 100644 --- a/src/meta_schedule/schedule_rule/apply_custom_rule.cc +++ b/src/meta_schedule/schedule_rule/apply_custom_rule.cc @@ -36,24 +36,25 @@ class ApplyCustomRuleNode : public ScheduleRuleNode { } // Inherited from ScheduleRuleNode - Array Apply(const tir::Schedule& sch, const tir::BlockRV& block_rv) final { + ffi::Array Apply(const tir::Schedule& sch, const tir::BlockRV& block_rv) final { CHECK(this->target_.defined()) << "ValueError: ApplyCustomRule is not initialized with TuneContext that has a Target."; - Array keys = this->target_.value()->keys; - if (Optional ann = tir::GetAnn(sch->GetSRef(block_rv), "schedule_rule")) { + ffi::Array keys = this->target_.value()->keys; + if (ffi::Optional ann = + tir::GetAnn(sch->GetSRef(block_rv), "schedule_rule")) { if (ann.value() != "None") { - for (const String& key : keys) { + for (const ffi::String& key : keys) { if (const auto custom_schedule_fn = tvm::ffi::Function::GetGlobal(GetCustomRuleName(ann.value(), key))) { - Array result = - (*custom_schedule_fn)(sch, block_rv).cast>(); + ffi::Array result = + (*custom_schedule_fn)(sch, block_rv).cast>(); return result; } } std::ostringstream os; os << "Unknown schedule rule \"" << ann.value() << "\" for target keys \"" << keys << "\". 
Checked ffi::Functions:"; - for (const String& key : keys) { + for (const ffi::String& key : keys) { os << "\n " << GetCustomRuleName(ann.value(), key); } LOG(WARNING) << os.str(); @@ -65,13 +66,13 @@ class ApplyCustomRuleNode : public ScheduleRuleNode { // Inherited from ScheduleRuleNode ScheduleRule Clone() const final { - ObjectPtr n = make_object(*this); + ObjectPtr n = ffi::make_object(*this); n->target_ = target_; return ScheduleRule(n); } public: - Optional target_ = std::nullopt; + ffi::Optional target_ = std::nullopt; static void RegisterReflection() { namespace refl = tvm::ffi::reflection; @@ -83,7 +84,7 @@ class ApplyCustomRuleNode : public ScheduleRuleNode { }; ScheduleRule ScheduleRule::ApplyCustomRule() { - ObjectPtr n = make_object(); + ObjectPtr n = ffi::make_object(); return ScheduleRule(n); } diff --git a/src/meta_schedule/schedule_rule/auto_bind.cc b/src/meta_schedule/schedule_rule/auto_bind.cc index 717ec0732575..79bb9607718a 100644 --- a/src/meta_schedule/schedule_rule/auto_bind.cc +++ b/src/meta_schedule/schedule_rule/auto_bind.cc @@ -32,7 +32,7 @@ class AutoBindNode : public ScheduleRuleNode { // Inherited from ScheduleRuleNode void InitializeWithTuneContext(const TuneContext& context) final { CHECK(context->target.defined()) << "ValueError: target is not defined"; - Optional max_threads_per_block = + ffi::Optional max_threads_per_block = context->target.value()->GetAttr("max_threads_per_block"); CHECK(max_threads_per_block.defined()) << "ValueError: missing attribute `max_threads_per_block` in the target"; @@ -40,11 +40,11 @@ class AutoBindNode : public ScheduleRuleNode { } // Inherited from ScheduleRuleNode - Array Apply(const tir::Schedule& sch, const tir::BlockRV& block_rv) final; + ffi::Array Apply(const tir::Schedule& sch, const tir::BlockRV& block_rv) final; // Inherited from ScheduleRuleNode ScheduleRule Clone() const final { - ObjectPtr n = make_object(*this); + ObjectPtr n = ffi::make_object(*this); return ScheduleRule(n); } @@ -54,7 +54,7 @@ class AutoBindNode : public ScheduleRuleNode { /*! \brief The max number of threadblocks in the cuda device */ int64_t max_threadblocks_ = -1; /*! \brief thread_extents Candidates of thread axis extent. 
*/ - Array thread_extents_; + ffi::Array thread_extents_; static void RegisterReflection() { namespace refl = tvm::ffi::reflection; @@ -65,16 +65,17 @@ class AutoBindNode : public ScheduleRuleNode { TVM_DECLARE_FINAL_OBJECT_INFO(AutoBindNode, ScheduleRuleNode); }; -Array AutoBindNode::Apply(const tir::Schedule& sch, const tir::BlockRV& block_rv) { +ffi::Array AutoBindNode::Apply(const tir::Schedule& sch, + const tir::BlockRV& block_rv) { ICHECK_NE(this->max_threads_per_block_, -1); auto get_factor = MakeFactorSampler(sch, this->thread_extents_); BindBlockThreadIdx(sch, block_rv, max_threadblocks_, max_threads_per_block_, get_factor); return {sch}; } -ScheduleRule ScheduleRule::AutoBind(int max_threadblocks, Array thread_extents, +ScheduleRule ScheduleRule::AutoBind(int max_threadblocks, ffi::Array thread_extents, int max_threads_per_block) { - ObjectPtr n = make_object(); + ObjectPtr n = ffi::make_object(); n->max_threadblocks_ = max_threadblocks; n->max_threads_per_block_ = max_threads_per_block; n->thread_extents_ = std::move(thread_extents); diff --git a/src/meta_schedule/schedule_rule/auto_inline.cc b/src/meta_schedule/schedule_rule/auto_inline.cc index 7d0277880cf4..913ee646539e 100644 --- a/src/meta_schedule/schedule_rule/auto_inline.cc +++ b/src/meta_schedule/schedule_rule/auto_inline.cc @@ -39,7 +39,7 @@ bool IsInSpatialPrimFunc(const tir::Schedule& sch, const tir::StmtSRef& block_sr for (; sref->parent != nullptr; sref = sref->parent) { } ICHECK(sref->stmt != nullptr && sref->stmt->IsInstance()); - return IsSpatialPrimFunc(GetRef(GetRootPrimFunc(sch->mod(), sref->stmt, nullptr))); + return IsSpatialPrimFunc(ffi::GetRef(GetRootPrimFunc(sch->mod(), sref->stmt, nullptr))); } /*! \brief The rule that inlines spatial blocks if it satisfies some conditions. */ @@ -52,7 +52,7 @@ class AutoInlineNode : public ScheduleRuleNode { void InitializeWithTuneContext(const TuneContext& context) final {} // Inherited from ScheduleRuleNode - Array Apply(const tir::Schedule& sch, const tir::BlockRV& block_rv) final { + ffi::Array Apply(const tir::Schedule& sch, const tir::BlockRV& block_rv) final { InlineType inline_type = CheckInline(sch, block_rv); if (inline_type == InlineType::kInlineIntoConsumer) { sch->ComputeInline(block_rv); @@ -64,7 +64,7 @@ class AutoInlineNode : public ScheduleRuleNode { // Inherited from ScheduleRuleNode ScheduleRule Clone() const final { - ObjectPtr n = make_object(*this); + ObjectPtr n = ffi::make_object(*this); return ScheduleRule(n); } @@ -82,7 +82,7 @@ class AutoInlineNode : public ScheduleRuleNode { /*! \brief Always require the read-to-write mapping to be ordered to do auto inline */ bool require_ordered; /*! \brief The operators that are disallowed in auto inline */ - Array disallow_op; + ffi::Array disallow_op; static void RegisterReflection() { namespace refl = tvm::ffi::reflection; @@ -114,7 +114,7 @@ inline InlineType AutoInlineNode::CheckInline(const tir::Schedule& sch, } // Cond 2. For a block that generates a constant tensor, ignore all other conditions if (inline_const_tensor && block->reads.empty()) { - Array consumer_srefs = GetConsumers(state, block_sref); + ffi::Array consumer_srefs = GetConsumers(state, block_sref); if (!consumer_srefs.empty() && CanComputeInline(state, block_sref)) { return InlineType::kInlineIntoConsumer; } @@ -144,25 +144,26 @@ inline InlineType AutoInlineNode::CheckInline(const tir::Schedule& sch, } } // Cond 6. 
The block is disallowed for auto inline - if (Optional ann = - tir::GetAnn(block_sref, tir::attr::meta_schedule_inline_rule)) { + if (ffi::Optional ann = + tir::GetAnn(block_sref, tir::attr::meta_schedule_inline_rule)) { if (ann.value() == "disable") return InlineType::kNoInline; } // Last cond: Check inline into the consumers or the spatial producer tir::StmtSRef scope_block = tir::GetScopeRoot(sch->state(), block_sref, /*require_stage_pipeline=*/false); if (into_consumer) { - Array consumer_srefs = GetConsumers(state, block_sref); + ffi::Array consumer_srefs = GetConsumers(state, block_sref); if (!consumer_srefs.empty() && CanComputeInline(state, block_sref)) { return InlineType::kInlineIntoConsumer; } } if (into_producer) { - Array producer_srefs = GetProducers(state, block_sref); + ffi::Array producer_srefs = GetProducers(state, block_sref); if (producer_srefs.size() == 1 && tir::IsCompleteBlock(sch->state(), producer_srefs[0], scope_block) && CanReverseComputeInline(state, block_sref) && - !GetAnn(producer_srefs[0], tir::attr::meta_schedule_auto_tensorize).has_value()) { + !GetAnn(producer_srefs[0], tir::attr::meta_schedule_auto_tensorize) + .has_value()) { return InlineType::kInlineIntoProducer; } } @@ -175,8 +176,8 @@ ScheduleRule ScheduleRule::AutoInline(bool into_producer, // bool disallow_if_then_else, // bool require_injective, // bool require_ordered, // - Optional> disallow_op) { - ObjectPtr n = make_object(); + ffi::Optional> disallow_op) { + ObjectPtr n = ffi::make_object(); n->into_producer = into_producer; n->into_consumer = into_consumer; n->inline_const_tensor = inline_const_tensor; @@ -185,9 +186,9 @@ ScheduleRule ScheduleRule::AutoInline(bool into_producer, // n->require_ordered = require_ordered; n->disallow_op.clear(); if (disallow_op.defined()) { - Array op_names = disallow_op.value(); + ffi::Array op_names = disallow_op.value(); n->disallow_op.reserve(op_names.size()); - for (const String& op_name : op_names) { + for (const ffi::String& op_name : op_names) { n->disallow_op.push_back(Op::Get(op_name)); } } @@ -206,7 +207,7 @@ class InlineConstantScalarsNode : public ScheduleRuleNode { public: void InitializeWithTuneContext(const TuneContext& context) final {} - Array Apply(const tir::Schedule& sch, const tir::BlockRV& block_rv) final { + ffi::Array Apply(const tir::Schedule& sch, const tir::BlockRV& block_rv) final { // Look for a block of the form // block compile_engine_const(iter_var(vi, range(min=0, ext=1))) { // reads([]) @@ -225,7 +226,7 @@ class InlineConstantScalarsNode : public ScheduleRuleNode { } ScheduleRule Clone() const final { - ObjectPtr n = make_object(*this); + ObjectPtr n = ffi::make_object(*this); return ScheduleRule(n); } @@ -239,7 +240,7 @@ class InlineConstantScalarsNode : public ScheduleRuleNode { }; ScheduleRule ScheduleRule::InlineConstantScalars() { - ObjectPtr n = make_object(); + ObjectPtr n = ffi::make_object(); return ScheduleRule(n); } diff --git a/src/meta_schedule/schedule_rule/cross_thread_reduction.cc b/src/meta_schedule/schedule_rule/cross_thread_reduction.cc index ddf603db27ab..219e05254e2f 100644 --- a/src/meta_schedule/schedule_rule/cross_thread_reduction.cc +++ b/src/meta_schedule/schedule_rule/cross_thread_reduction.cc @@ -30,8 +30,9 @@ class CrossThreadReductionNode : public ScheduleRuleNode { ICHECK(context->target.defined()); Target target = context->target.value(); - Optional opt_max_threads_per_block = target->GetAttr("max_threads_per_block"); - Optional opt_warp_size = target->GetAttr("thread_warp_size"); + 
ffi::Optional opt_max_threads_per_block = + target->GetAttr("max_threads_per_block"); + ffi::Optional opt_warp_size = target->GetAttr("thread_warp_size"); if (!opt_max_threads_per_block.defined()) { TVM_PY_LOG(WARNING, context->logger) @@ -48,7 +49,7 @@ class CrossThreadReductionNode : public ScheduleRuleNode { } // Inherited from ScheduleRuleNode - Array Apply(const tir::Schedule& sch, const tir::BlockRV& block_rv) final { + ffi::Array Apply(const tir::Schedule& sch, const tir::BlockRV& block_rv) final { // Step 0. Check the conditions of this rule. if (max_threads_per_block == -1 || warp_size == -1) { return {sch}; @@ -75,7 +76,7 @@ class CrossThreadReductionNode : public ScheduleRuleNode { // Step 3. Try block fusion. int n_candidate = static_cast(thread_extents.size()); - Array probs(n_candidate, FloatImm(DataType::Float(32), 1.0 / n_candidate)); + ffi::Array probs(n_candidate, FloatImm(DataType::Float(32), 1.0 / n_candidate)); tir::ExprRV thread_extent = tmp_sch->SampleCategorical(thread_extents, probs); if (fusible) { ICHECK(target_block.defined()); @@ -87,7 +88,7 @@ class CrossThreadReductionNode : public ScheduleRuleNode { // the loop before binding. // - Otherwise, we search for the extent of "threadIdx.x" and use it as the split factor. if (!InThreadScope(tmp_sch, target_block)) { - const Array& split_res = + const ffi::Array& split_res = tmp_sch->Split(tgt_block_innermost_loop, {std::nullopt, thread_extent}); tmp_sch->Bind(split_res[1], "threadIdx.x"); if (tgt_block_innermost_loop.same_as(target_loop)) { @@ -108,7 +109,7 @@ class CrossThreadReductionNode : public ScheduleRuleNode { tir::LoopRV fused_reduce_loop; ReorderAndFuseReductionLoops(tmp_sch, block_rv, &fused_reduce_loop, &num_spatial_loops); // Step 5. Split the fused reduction loop and bind the inner one to threadIdx. - const Array& split_res = + const ffi::Array& split_res = tmp_sch->Split(fused_reduce_loop, {std::nullopt, thread_extent}); tmp_sch->Bind(split_res[1], "threadIdx.x"); @@ -117,7 +118,7 @@ class CrossThreadReductionNode : public ScheduleRuleNode { // Inherited from ScheduleRuleNode ScheduleRule Clone() const final { - ObjectPtr n = make_object(*this); + ObjectPtr n = ffi::make_object(*this); return ScheduleRule(n); } @@ -130,7 +131,7 @@ class CrossThreadReductionNode : public ScheduleRuleNode { * \return A boolean indicating whether the block is in thread scope. */ bool InThreadScope(const tir::Schedule& sch, const tir::BlockRV& block) { - const Array& axes = sch->GetLoops(block); + const ffi::Array& axes = sch->GetLoops(block); for (const tir::LoopRV& loop_rv : axes) { const tir::For& loop = sch->Get(loop_rv); runtime::ThreadScope thread_scope = tir::GetThreadScope(loop.get()); @@ -172,7 +173,7 @@ class CrossThreadReductionNode : public ScheduleRuleNode { tir::ExprRV GetThreadIdxExtentFromTrace(const tir::Trace& trace) { tir::ExprRV extent{nullptr}; for (const tir::Instruction& inst : trace->insts) { - if (inst->kind->name == "Bind" && Downcast(inst->attrs[0]) == "threadIdx.x") { + if (inst->kind->name == "Bind" && Downcast(inst->attrs[0]) == "threadIdx.x") { if (GetLoopRVExtentSource(trace, Downcast(inst->inputs[0]), &extent)) { return extent; } @@ -202,7 +203,7 @@ class CrossThreadReductionNode : public ScheduleRuleNode { } // Step 1. Get all the consumers of the input block. - Array consumers = sch->GetConsumers(block_rv); + ffi::Array consumers = sch->GetConsumers(block_rv); // Step 2. If the block has no consumer or the first consumer needs multi-level tiling, it is // not fusible. 
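// Aside (illustrative sketch, not part of this change): the trace scan performed by
// GetThreadIdxExtentFromTrace above. Instruction attributes come back type-erased and are
// narrowed with Downcast; `trace` is a hypothetical tir::Trace:
for (const tir::Instruction& inst : trace->insts) {
  if (inst->kind->name == "Bind" &&
      Downcast<ffi::String>(inst->attrs[0]) == "threadIdx.x") {
    // inst->inputs[0] is the LoopRV bound to threadIdx.x; its extent is the
    // split factor the rule wants to reuse.
  }
}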
@@ -225,7 +226,7 @@ class CrossThreadReductionNode : public ScheduleRuleNode { } // Step 4. Get the outer loops of the target block, and get the compute-at position index. - Array tgt_block_loops = sch->GetLoops(consumers[0]); + ffi::Array tgt_block_loops = sch->GetLoops(consumers[0]); int pos = GetComputePosition(sch, sch->GetLoops(block_rv), tgt_block_loops, lca_sref); // Step 5. A negative position index means not fusible, and vice-versa. @@ -248,8 +249,9 @@ class CrossThreadReductionNode : public ScheduleRuleNode { * \param lca_sref The lowest common ancestor of all the consumers of the input block * \return The compute-at position index of the input block */ - int GetComputePosition(const tir::Schedule& sch, const Array& block_loops, - const Array& tgt_block_loops, const tir::StmtSRef& lca_sref) { + int GetComputePosition(const tir::Schedule& sch, const ffi::Array& block_loops, + const ffi::Array& tgt_block_loops, + const tir::StmtSRef& lca_sref) { int n_block_loop = static_cast(block_loops.size()); int n_tgt_block_loop = static_cast(tgt_block_loops.size()); @@ -271,7 +273,7 @@ class CrossThreadReductionNode : public ScheduleRuleNode { /*! \brief The number of threads per warp */ int warp_size; /*! \brief Candidates of thread axis extent (values are required to be positive). */ - Array thread_extents; + ffi::Array thread_extents; static void RegisterReflection() { namespace refl = tvm::ffi::reflection; @@ -285,11 +287,11 @@ class CrossThreadReductionNode : public ScheduleRuleNode { TVM_DECLARE_FINAL_OBJECT_INFO(CrossThreadReductionNode, ScheduleRuleNode); }; -ScheduleRule ScheduleRule::CrossThreadReduction(Array thread_extents) { +ScheduleRule ScheduleRule::CrossThreadReduction(ffi::Array thread_extents) { for (const auto& extent : thread_extents) { CHECK(extent->value > 0) << "ValueError: The candidates of thread extent must be positive"; } - ObjectPtr n = make_object(); + ObjectPtr n = ffi::make_object(); n->thread_extents = std::move(thread_extents); return ScheduleRule(n); } diff --git a/src/meta_schedule/schedule_rule/multi_level_tiling.cc b/src/meta_schedule/schedule_rule/multi_level_tiling.cc index 6a7c6ade45c1..2f796fa6b1da 100644 --- a/src/meta_schedule/schedule_rule/multi_level_tiling.cc +++ b/src/meta_schedule/schedule_rule/multi_level_tiling.cc @@ -57,8 +57,8 @@ using tir::Schedule; TVM_FFI_STATIC_INIT_BLOCK({ MultiLevelTilingNode::RegisterReflection(); }); -State::State(tir::Schedule sch, tir::BlockRV block_rv, Array> tiles) { - ObjectPtr node = make_object(); +State::State(tir::Schedule sch, tir::BlockRV block_rv, ffi::Array> tiles) { + ObjectPtr node = ffi::make_object(); node->sch = std::move(sch); node->block_rv = std::move(block_rv); node->tiles = std::move(tiles); @@ -66,22 +66,23 @@ State::State(tir::Schedule sch, tir::BlockRV block_rv, Array> } State StateNode::Copy() const { - ObjectPtr node = make_object(*this); + ObjectPtr node = ffi::make_object(*this); node->sch = sch->Copy(); return State(node); } // Do nothing; Inherited from ScheduleRuleNode void MultiLevelTilingNode::InitializeWithTuneContext(const TuneContext& context) { - if (Optional v = context->target.value()->GetAttr("max_threads_per_block")) { + if (ffi::Optional v = + context->target.value()->GetAttr("max_threads_per_block")) { this->max_threads_per_block_ = v.value()->value; - if (Optional v = context->target.value()->GetAttr("thread_warp_size")) { + if (ffi::Optional v = context->target.value()->GetAttr("thread_warp_size")) { this->thread_warp_size_ = v.value()->value; } else { TVM_PY_LOG(INFO, 
context->logger) << "'thread_warp_size' is not defined in the target"; } } - if (Optional opt_sm = context->target.value()->GetAttr("arch")) { + if (ffi::Optional opt_sm = context->target.value()->GetAttr("arch")) { std::string sm = opt_sm.value(); if (support::StartsWith(sm, "sm_")) { sm = sm.substr(3); @@ -102,12 +103,12 @@ void MultiLevelTilingNode::InitializeWithTuneContext(const TuneContext& context) } // Entry of the mega rule; Inherited from ScheduleRuleNode -Array MultiLevelTilingNode::Apply(const Schedule& sch, const BlockRV& block_rv) { +ffi::Array MultiLevelTilingNode::Apply(const Schedule& sch, const BlockRV& block_rv) { if ((filter_fn_ && filter_fn_.value()(sch, sch->GetSRef(block_rv)).cast()) || NeedsMultiLevelTiling(sch->state(), sch->GetSRef(block_rv))) { sch->Annotate(block_rv, tir::attr::meta_schedule_tiling_structure, structure); - Array results; + ffi::Array results; for (auto&& state : ApplySubRules({State(sch, block_rv)})) { results.push_back(std::move(state->sch)); } @@ -118,7 +119,7 @@ Array MultiLevelTilingNode::Apply(const Schedule& sch, const BlockRV& // Inherited from ScheduleRuleNode ScheduleRule MultiLevelTilingNode::Clone() const { - ObjectPtr n = make_object(*this); + ObjectPtr n = ffi::make_object(*this); return ScheduleRule(n); } @@ -138,7 +139,7 @@ std::vector MultiLevelTilingNode::AddWriteReuse(State state) const { } std::vector levels = config.levels; ReuseType req = config.req; - if (Optional> ann = tir::GetAnn>( + if (ffi::Optional> ann = tir::GetAnn>( state->sch->GetSRef(state->block_rv), "meta_schedule.write_cache_level")) { req = ReuseType::kMustReuse; levels.clear(); @@ -148,7 +149,7 @@ std::vector MultiLevelTilingNode::AddWriteReuse(State state) const { std::vector results; if (req == ReuseType::kMayReuse) { // Case 1. If the write cache is already there, we don't need to add another. - Array consumer_rvs = state->sch->GetConsumers(state->block_rv); + ffi::Array consumer_rvs = state->sch->GetConsumers(state->block_rv); if (consumer_rvs.size() == 1 && IsWriteCache(state->sch->GetSRef(consumer_rvs[0]))) { for (int level : levels) { State new_state = state->Copy(); @@ -180,14 +181,14 @@ std::vector MultiLevelTilingNode::AddWriteReuse(State state) const { return results; } -std::pair, Array> MultiLevelTilingNode::SplitLoop( +std::pair, ffi::Array> MultiLevelTilingNode::SplitLoop( const Schedule& sch, BlockRV block, LoopRV loop, int n_tiles) const { - Array factors = sch->SamplePerfectTile( + ffi::Array factors = sch->SamplePerfectTile( /*loop=*/loop, /*n=*/n_tiles, /*max_innermost_factor=*/max_innermost_factor); - Array splits = sch->Split(/*loop=*/loop, - /*factors=*/{factors.begin(), factors.end()}); + ffi::Array splits = sch->Split(/*loop=*/loop, + /*factors=*/{factors.begin(), factors.end()}); return {factors, splits}; } @@ -196,7 +197,7 @@ std::vector MultiLevelTilingNode::TileLoopNest(State state, Schedule& sch = state->sch; const BlockRV& block_rv = state->block_rv; // Step 1. Assuming trivial binding, pair the loops and their iter-var-types - Array loops = sch->GetLoops(block_rv); + ffi::Array loops = sch->GetLoops(block_rv); std::vector iter_types = GetBlockVarTypes(sch->GetSRef(state->block_rv)); ICHECK_EQ(loops.size(), iter_types.size()); // Step 2. 
For each loop axis, tile it @@ -210,10 +211,10 @@ std::vector MultiLevelTilingNode::TileLoopNest(State state, if (tile_inner_most_space_loop_num < 0) tile_inner_most_space_loop_num = total_spatial_loop_num; int outer_most_spatial_loop_skipped_num = total_spatial_loop_num - tile_inner_most_space_loop_num; - Array skipped_outer_spatial_loops; - std::vector> tiles(s_indices_.size() + r_indices_.size()); + ffi::Array skipped_outer_spatial_loops; + std::vector> tiles(s_indices_.size() + r_indices_.size()); state->tile_factors.resize(tiles.size()); - std::vector> tile_factors; + std::vector> tile_factors; tile_factors.resize(tiles.size()); for (int i = 0, n = loops.size(); i < n; ++i) { LoopRV loop = loops[i]; @@ -268,7 +269,7 @@ std::vector MultiLevelTilingNode::TileLoopNest(State state, sch->Bind(fused, tile_binds[i]); tiles[i] = {fused}; } - state->tiles = Array>{tiles.begin(), tiles.end()}; + state->tiles = ffi::Array>{tiles.begin(), tiles.end()}; if (this->thread_warp_size_ != -1) { int64_t low_inclusive = 1; int64_t high_inclusive = this->max_threads_per_block_; @@ -308,9 +309,9 @@ std::vector MultiLevelTilingNode::AddReadReuse(State state) const { // Insert cache_read block to the proper place sch->ComputeAt(cache_read_block, loop_rv, true); // Fuse the iterators of the cache_read - Array buffer_loops = sch->GetLoops(cache_read_block); - sch->Fuse(Array{buffer_loops.end() - buffer_ndim, // - buffer_loops.end()}); + ffi::Array buffer_loops = sch->GetLoops(cache_read_block); + sch->Fuse(ffi::Array{buffer_loops.end() - buffer_ndim, // + buffer_loops.end()}); AnnotateCooperativeFetching(&sch, cache_read_block); new_state->read_reuse.emplace(i, cache_read_block); } @@ -330,7 +331,7 @@ std::vector MultiLevelTilingNode::AddAsyncPipeline(State state) const { // therefore it matches the notation array size in the following code tir::StmtSRef r_loop_sref = state->sch->GetSRef(state->tiles[r_indices_[0]].back()); const tir::ForNode* r_for_loop = TVM_SREF_TO_FOR(r_loop_sref); - Array seq = Downcast(r_for_loop->body)->seq; + ffi::Array seq = Downcast(r_for_loop->body)->seq; if (seq.size() != 3) { return {state}; } @@ -346,11 +347,11 @@ std::vector MultiLevelTilingNode::AddAsyncPipeline(State state) const { State new_state = state->Copy(); LoopRV r_loop_fused = new_state->sch->Fuse(new_state->tiles[r_indices_[0]]); new_state->sch->Annotate(r_loop_fused, tir::attr::software_pipeline_stage, - Array{0, 0, stage - 2}); + ffi::Array{0, 0, stage - 2}); new_state->sch->Annotate(r_loop_fused, tir::attr::software_pipeline_order, - Array{0, 1, 2}); + ffi::Array{0, 1, 2}); new_state->sch->Annotate(r_loop_fused, tir::attr::software_pipeline_async_stages, - Array{0}); + ffi::Array{0}); ret.push_back(std::move(new_state)); } return ret; @@ -386,19 +387,20 @@ void MultiLevelTilingNode::AnnotateCooperativeFetching(Schedule* sch, double prob = 1.0 / n; tir::ExprRV vector_load_len = (*sch)->SampleCategorical(support::AsArray(valid_vector_lens), - Array(n, FloatImm(DataType::Float(32), prob))); + ffi::Array(n, FloatImm(DataType::Float(32), prob))); (*sch)->Annotate(block, tir::attr::meta_schedule_cooperative_fetch, vector_load_len); } } // Constructor -ScheduleRule ScheduleRule::MultiLevelTiling(String structure, Optional> tile_binds, - Optional max_innermost_factor, - Optional> vector_load_lens, - Optional> reuse_read, - Optional> reuse_write, - Optional filter_fn) { +ScheduleRule ScheduleRule::MultiLevelTiling( + ffi::String structure, ffi::Optional> tile_binds, + ffi::Optional max_innermost_factor, + ffi::Optional> 
vector_load_lens, + ffi::Optional> reuse_read, + ffi::Optional> reuse_write, + ffi::Optional filter_fn) { auto node = MultiLevelTilingInitCommon( structure, tile_binds, max_innermost_factor, vector_load_lens, reuse_read, reuse_write); node->filter_fn_ = filter_fn; diff --git a/src/meta_schedule/schedule_rule/multi_level_tiling.h b/src/meta_schedule/schedule_rule/multi_level_tiling.h index 2b03d749f2b5..8de89b5ba0b7 100644 --- a/src/meta_schedule/schedule_rule/multi_level_tiling.h +++ b/src/meta_schedule/schedule_rule/multi_level_tiling.h @@ -64,7 +64,7 @@ enum class ReuseType : int32_t { * \param str The string to be converted. * \return The converted ReuseType. */ -inline ReuseType Str2ReuseType(const String& str) { +inline ReuseType Str2ReuseType(const ffi::String& str) { if (str == "no") { return ReuseType::kNoReuse; } else if (str == "may") { @@ -84,16 +84,16 @@ struct ReuseConfig { /*! \brief Which levels are caching stage inserted at */ std::vector levels; /*! \brief The storage scope */ - String scope; + ffi::String scope; /*! \brief Default constructor: no data reuse */ ReuseConfig() : req(ReuseType::kNoReuse) {} /*! \brief Construct from a configuration dictionary */ - explicit ReuseConfig(const Map& config) - : req(Str2ReuseType(Downcast(config.at("req")))), - levels(support::AsVector(Downcast>(config.at("levels")))), - scope(Downcast(config.at("scope"))) { + explicit ReuseConfig(const ffi::Map& config) + : req(Str2ReuseType(Downcast(config.at("req")))), + levels(support::AsVector(Downcast>(config.at("levels")))), + scope(Downcast(config.at("scope"))) { ICHECK_EQ(config.size(), 3); } }; @@ -109,9 +109,9 @@ class StateNode : public Object { /*! \brief The block to be tiled */ tir::BlockRV block_rv; /*! \brief The loop tiles */ - Array> tiles; + ffi::Array> tiles; /*! \brief The factors of the loop tiles. */ - Array> tile_factors; + ffi::Array> tile_factors; /*! \brief The mapping from buffer index to read cache block. */ std::unordered_map read_reuse; /*! \brief The mapping from buffer index to write cache block. */ @@ -131,7 +131,8 @@ class StateNode : public Object { class State : public ObjectRef { public: /*! 
\brief Default constructor */ - explicit State(tir::Schedule sch, tir::BlockRV block_rv, Array> tiles = {}); + explicit State(tir::Schedule sch, tir::BlockRV block_rv, + ffi::Array> tiles = {}); TVM_DEFINE_MUTABLE_OBJECT_REF_METHODS(State, ObjectRef, StateNode); }; @@ -173,7 +174,7 @@ class MultiLevelTilingNode : public ScheduleRuleNode { void InitializeWithTuneContext(const TuneContext& context) final; // Entry of the mega rule; Inherited from ScheduleRuleNode - Array Apply(const tir::Schedule& sch, const tir::BlockRV& block_rv) override; + ffi::Array Apply(const tir::Schedule& sch, const tir::BlockRV& block_rv) override; // Inherited from ScheduleRuleNode ScheduleRule Clone() const override; @@ -181,10 +182,8 @@ class MultiLevelTilingNode : public ScheduleRuleNode { protected: virtual std::vector ApplySubRules(std::vector states); - virtual std::pair, Array> SplitLoop(const tir::Schedule& sch, - tir::BlockRV block, - tir::LoopRV loop, - int n_tiles) const; + virtual std::pair, ffi::Array> SplitLoop( + const tir::Schedule& sch, tir::BlockRV block, tir::LoopRV loop, int n_tiles) const; // Annotate a block to use cooperative fetching void AnnotateCooperativeFetching(tir::Schedule* sch, const tir::BlockRV& block) const; @@ -195,9 +194,9 @@ class MultiLevelTilingNode : public ScheduleRuleNode { * - 'SSRSRS' on CPU * - 'SSSRRSRS' on GPU */ - String structure; + ffi::String structure; /*! \brief For each level of tiles, which thread axis it is bound to */ - Array tile_binds; + ffi::Array tile_binds; /*! \brief The maximum size of the innermost factor */ int max_innermost_factor; /*! \brief The length of vector lane in vectorized cooperative fetching */ @@ -219,7 +218,7 @@ class MultiLevelTilingNode : public ScheduleRuleNode { /*! \brief The logging function */ ffi::Function logger; /*! \brief The function to overwrite the default condition for applying MultiLevelTiling. */ - Optional filter_fn_; + ffi::Optional filter_fn_; static void RegisterReflection() { namespace refl = tvm::ffi::reflection; @@ -234,12 +233,13 @@ class MultiLevelTilingNode : public ScheduleRuleNode { }; template -ObjectPtr MultiLevelTilingInitCommon(String structure, Optional> tile_binds, - Optional max_innermost_factor, - Optional> vector_load_lens, - Optional> reuse_read, - Optional> reuse_write) { - ObjectPtr n = make_object(); +ObjectPtr MultiLevelTilingInitCommon( + ffi::String structure, ffi::Optional> tile_binds, + ffi::Optional max_innermost_factor, + ffi::Optional> vector_load_lens, + ffi::Optional> reuse_read, + ffi::Optional> reuse_write) { + ObjectPtr n = ffi::make_object(); n->structure = structure; n->tile_binds = tile_binds.value_or({}); n->max_innermost_factor = max_innermost_factor.value_or(Integer(-1))->value; diff --git a/src/meta_schedule/schedule_rule/multi_level_tiling_tensor_core.cc b/src/meta_schedule/schedule_rule/multi_level_tiling_tensor_core.cc index 22f9699c9180..0bbccbdffe7a 100644 --- a/src/meta_schedule/schedule_rule/multi_level_tiling_tensor_core.cc +++ b/src/meta_schedule/schedule_rule/multi_level_tiling_tensor_core.cc @@ -36,11 +36,11 @@ using tir::LoopRV; using tir::Schedule; struct TensorCoreIntrinGroup { - String init_intrin; - String load_a_intrin; - String load_b_intrin; - String compute_intrin; - String store_intrin; + ffi::String init_intrin; + ffi::String load_a_intrin; + ffi::String load_b_intrin; + ffi::String compute_intrin; + ffi::String store_intrin; /*! \brief Create TensorCoreIntrinGroup from config in a map. 
The map should contain the * following keys: @@ -52,11 +52,12 @@ struct TensorCoreIntrinGroup { * The values of the keys should be the names of the corresponding intrinsics and should be * registered via TensorIntrin.Register beforehand. */ - static TensorCoreIntrinGroup FromConfig(const Map& config); + static TensorCoreIntrinGroup FromConfig(const ffi::Map& config); }; -TensorCoreIntrinGroup TensorCoreIntrinGroup::FromConfig(const Map& config) { - auto f_initialize_intrin = [&config](String key_name, String* intrin_name) { +TensorCoreIntrinGroup TensorCoreIntrinGroup::FromConfig( + const ffi::Map& config) { + auto f_initialize_intrin = [&config](ffi::String key_name, ffi::String* intrin_name) { CHECK(config.count(key_name)) << "ValueError: " << key_name << " is not set."; *intrin_name = config.at(key_name); // Check the existence of the intrin @@ -98,15 +99,17 @@ class TensorCoreState : public State { public: explicit TensorCoreState(TensorCoreIntrinGroup intrin_group, tir::AutoTensorizeMappingInfo mapping_info, Schedule sch, - BlockRV block_rv, bool use_async, Array> tiles = {}); + BlockRV block_rv, bool use_async, + ffi::Array> tiles = {}); TVM_DEFINE_MUTABLE_OBJECT_REF_METHODS(TensorCoreState, State, TensorCoreStateNode); }; TensorCoreState::TensorCoreState(TensorCoreIntrinGroup intrin_group, tir::AutoTensorizeMappingInfo mapping_info, Schedule sch, - BlockRV block_rv, bool use_async, Array> tiles) { - ObjectPtr node = make_object(); + BlockRV block_rv, bool use_async, + ffi::Array> tiles) { + ObjectPtr node = ffi::make_object(); node->intrin_group = intrin_group; node->mapping_info = mapping_info; node->sch = std::move(sch); @@ -118,7 +121,7 @@ TensorCoreState::TensorCoreState(TensorCoreIntrinGroup intrin_group, } State TensorCoreStateNode::Copy() const { - ObjectPtr node = make_object(*this); + ObjectPtr node = ffi::make_object(*this); node->sch = sch->Copy(); return State(node); } @@ -145,11 +148,9 @@ class MultiLevelTilingTensorCoreNode : public MultiLevelTilingNode { // Subrule: Add software pipeline inline std::vector AddSoftwarePipeline(TensorCoreState state) const; // Subrule: split loop for mma using sample partitioned tile - inline std::pair, Array> MMASplitLoop(const Schedule& sch, - BlockRV block, LoopRV loop, - int n_tiles, - int partition_pos, - int innerpart_factor) const; + inline std::pair, ffi::Array> MMASplitLoop( + const Schedule& sch, BlockRV block, LoopRV loop, int n_tiles, int partition_pos, + int innerpart_factor) const; // Subrule: tile loop nest for mma // Basically the same as MultiLevelTilingNode::TileLoopNest, but changes SamplePerfectTile to // SamplePartitionedTile @@ -159,12 +160,12 @@ class MultiLevelTilingTensorCoreNode : public MultiLevelTilingNode { std::vector ApplySubRules(std::vector states) final; // Override Apply to apply tensorization-specific analysis before applying sub-rules - Array Apply(const Schedule& sch, const BlockRV& block_rv) final; + ffi::Array Apply(const Schedule& sch, const BlockRV& block_rv) final; // Inherited from ScheduleRuleNode ScheduleRule Clone() const final { ObjectPtr n = - make_object(*this); + ffi::make_object(*this); return ScheduleRule(n); } @@ -174,16 +175,17 @@ class MultiLevelTilingTensorCoreNode : public MultiLevelTilingNode { * \param intrin_name The name of the tensor intrin * \return The loop to be tensorized. std::nullopt if the workload can't be tensorized. 
*/ - Optional TransformWithTensorIntrin(TensorCoreStateNode* state, - const String& intrin_name) const; + ffi::Optional TransformWithTensorIntrin(TensorCoreStateNode* state, + const ffi::String& intrin_name) const; /*! * \brief Tile, blockize and annotate for tensorization with the given intrin * \param block_rv The block to be tensorized * \param intrin_name The name of the tensor intrin */ - void TileAndAnnotateTensorize(Schedule* sch, const BlockRV& block_rv, const String& intrin_name, - const String& permuted_layout_annotate_value) const; + void TileAndAnnotateTensorize(Schedule* sch, const BlockRV& block_rv, + const ffi::String& intrin_name, + const ffi::String& permuted_layout_annotate_value) const; public: /*! \brief The candidate tensor core intrin groups to apply */ @@ -197,8 +199,8 @@ class MultiLevelTilingTensorCoreNode : public MultiLevelTilingNode { }; // Entry of the mega rule; Inherited from ScheduleRuleNode -Array MultiLevelTilingTensorCoreNode::Apply(const Schedule& sch, - const BlockRV& block_rv) { +ffi::Array MultiLevelTilingTensorCoreNode::Apply(const Schedule& sch, + const BlockRV& block_rv) { if (!NeedsMultiLevelTiling(sch->state(), sch->GetSRef(block_rv))) { return {sch}; } @@ -206,7 +208,7 @@ Array MultiLevelTilingTensorCoreNode::Apply(const Schedule& sch, std::unordered_map intrin_group_to_mapping_info; for (int i = 0, n = intrin_groups.size(); i < n; ++i) { TensorCoreIntrinGroup intrin_group = intrin_groups[i]; - Optional mapping_info = tir::GetAutoTensorizeMappingInfo( + ffi::Optional mapping_info = tir::GetAutoTensorizeMappingInfo( sch->state(), sch->GetSRef(block_rv), tir::TensorIntrin::Get(intrin_groups[i].compute_intrin).value()->desc); if (mapping_info.defined()) { @@ -231,7 +233,7 @@ Array MultiLevelTilingTensorCoreNode::Apply(const Schedule& sch, new_sch->Annotate(block_rv, tir::attr::meta_schedule_tiling_structure, structure); initial_states.push_back(TensorCoreState(intrin_group, mapping_info, new_sch, block_rv, true)); } - Array results; + ffi::Array results; for (auto&& state : ApplySubRules(initial_states)) { TVM_PY_LOG(INFO, logger) << "Sketch " << results.size() << ": tensorizing with " << state.as()->intrin_group.compute_intrin; @@ -273,9 +275,9 @@ std::vector MultiLevelTilingTensorCoreNode::ApplySubRules(std::vector loop = TileWithTensorIntrin(*sch, block_rv, intrin_name).value(); + Schedule* sch, const BlockRV& block_rv, const ffi::String& intrin_name, + const ffi::String& permuted_layout_annotate_value) const { + ffi::Optional loop = TileWithTensorIntrin(*sch, block_rv, intrin_name).value(); ICHECK(loop.defined()); BlockRV blockized_outer = (*sch)->Blockize(loop.value()); (*sch)->Annotate(blockized_outer, tir::attr::meta_schedule_auto_tensorize, intrin_name); @@ -308,8 +310,9 @@ std::vector MultiLevelTilingTensorCoreNode::MMAAddReadReuse(TensorCoreSta BlockRV cache_read_block = sch->ReadAt(loop_rv, block_rv, i, config.scope); new_state->read_reuse.emplace(i, cache_read_block); if (state->is_mma) { - new_state->sch->Annotate(cache_read_block, "permuted_layout", - String(std::string("g2s_") + std::string(i == 0 ? "A" : "B"))); + new_state->sch->Annotate( + cache_read_block, "permuted_layout", + ffi::String(std::string("g2s_") + std::string(i == 0 ? 
"A" : "B"))); } } results.push_back(std::move(new_state)); @@ -317,16 +320,17 @@ std::vector MultiLevelTilingTensorCoreNode::MMAAddReadReuse(TensorCoreSta return results; } -std::pair, Array> MultiLevelTilingTensorCoreNode::MMASplitLoop( - const Schedule& sch, BlockRV block, LoopRV loop, int n_tiles, int partition_pos, - int innerpart_factor) const { - Array factors = sch->SamplePartitionedTile( +std::pair, ffi::Array> +MultiLevelTilingTensorCoreNode::MMASplitLoop(const Schedule& sch, BlockRV block, LoopRV loop, + int n_tiles, int partition_pos, + int innerpart_factor) const { + ffi::Array factors = sch->SamplePartitionedTile( /*loop=*/loop, /*n=*/n_tiles, /*partition_pos=*/partition_pos, /*innerpart_factor=*/innerpart_factor); - Array splits = sch->Split(/*loop=*/loop, - /*factors=*/{factors.begin(), factors.end()}); + ffi::Array splits = sch->Split(/*loop=*/loop, + /*factors=*/{factors.begin(), factors.end()}); return {factors, splits}; } @@ -334,7 +338,7 @@ std::vector MultiLevelTilingTensorCoreNode::MMATileLoopNest(TensorCoreSta Schedule& sch = state->sch; const BlockRV& block_rv = state->block_rv; // Step 1. Assuming trivial binding, pair the loops and their iter-var-types - Array loops = sch->GetLoops(block_rv); + ffi::Array loops = sch->GetLoops(block_rv); if (!(loops.size() == 3 || !state->is_mma)) { LOG(DEBUG) << "The MMA tensor core only supports SSR loops now"; return {}; @@ -343,9 +347,9 @@ std::vector MultiLevelTilingTensorCoreNode::MMATileLoopNest(TensorCoreSta ICHECK_EQ(loops.size(), iter_types.size()); // Step 2. For each loop axis, tile it int64_t spatial_loop_product = 1; - std::vector> tiles(s_indices_.size() + r_indices_.size()); + std::vector> tiles(s_indices_.size() + r_indices_.size()); state->tile_factors.resize(tiles.size()); - std::vector> tile_factors; + std::vector> tile_factors; tile_factors.resize(tiles.size()); for (int i = 0, n = loops.size(); i < n; ++i) { LoopRV loop = loops[i]; @@ -397,7 +401,7 @@ std::vector MultiLevelTilingTensorCoreNode::MMATileLoopNest(TensorCoreSta sch->Bind(fused, tile_binds[i]); tiles[i] = {fused}; } - state->tiles = Array>{tiles.begin(), tiles.end()}; + state->tiles = ffi::Array>{tiles.begin(), tiles.end()}; if (this->thread_warp_size_ != -1) { int64_t low_inclusive = 1; int64_t high_inclusive = this->max_threads_per_block_; @@ -445,7 +449,7 @@ std::vector MultiLevelTilingTensorCoreNode::TransformIntermediateOutputLa // This function computes the product of tile_factors[i][loop_idx] for i > tile_index_warp_id. // `loop_idx` can be negative, in which case it is counted from the end. auto f_get_inner_tile_product = [&](int loop_idx) { - Array factors; + ffi::Array factors; for (int i = tile_index_warp_id + 1; i < static_cast(s_indices_.size()); ++i) { auto s_factors = state->tile_factors[s_indices_[i]]; if (loop_idx < 0) { @@ -479,8 +483,8 @@ std::vector MultiLevelTilingTensorCoreNode::TransformIntermediateOutputLa // frag_shape_m and frag_shape_n are structural bindings that cannot // not be automatically captured until c++20 [&, frag_shape_m = frag_shape_m, - frag_shape_n = frag_shape_n](const Array& indices) { - Array result; + frag_shape_n = frag_shape_n](const ffi::Array& indices) { + ffi::Array result; result.reserve(indices.size() + 4); for (int i = 0; i < num_higher_dims; ++i) { result.push_back(indices[i]); @@ -547,7 +551,7 @@ std::vector MultiLevelTilingTensorCoreNode::AddWriteReuseTensorCore( // Get the loops other than the innermost two loops (accum_m and accum_n). 
auto f_get_loops = [&](const BlockRV& block_rv) -> std::array { - Array buffer_loops = sch->GetLoops(block_rv); + ffi::Array buffer_loops = sch->GetLoops(block_rv); ICHECK_GT(buffer_loops.size(), 6); return {buffer_loops[buffer_loops.size() - 6], buffer_loops[buffer_loops.size() - 5], buffer_loops[buffer_loops.size() - 4], buffer_loops[buffer_loops.size() - 3]}; @@ -571,24 +575,24 @@ std::vector MultiLevelTilingTensorCoreNode::AddWriteReuseTensorCore( sch->Annotate(blockized_store, tir::attr::meta_schedule_auto_tensorize, state->intrin_group.store_intrin); - Array buffer_loops = sch->GetLoops(state->write_reuse[0]); + ffi::Array buffer_loops = sch->GetLoops(state->write_reuse[0]); ICHECK_GT(buffer_loops.size(), 5); - sch->Fuse(Array{buffer_loops.end() - 5, // The src shmem is always 2D - buffer_loops.end()}); + sch->Fuse(ffi::Array{buffer_loops.end() - 5, // The src shmem is always 2D + buffer_loops.end()}); AnnotateCooperativeFetching(&sch, state->write_reuse[0]); return {state}; } std::vector MultiLevelTilingTensorCoreNode::AddReadReuseTensorCore( TensorCoreState state) const { - const Array& r_tiles = state->tiles[r_indices_[1]]; + const ffi::Array& r_tiles = state->tiles[r_indices_[1]]; Schedule& sch = state->sch; ICHECK(!r_tiles.empty()) << "ValueError: Cannot find the suitable reduction loop in the block"; - auto f_tensorize_load = [&](int read_index, String scope, String intrin_name) { + auto f_tensorize_load = [&](int read_index, ffi::String scope, ffi::String intrin_name) { auto cache_read = sch->CacheRead(state->block_rv, read_index, scope); state->sch->ComputeAt(cache_read, r_tiles.back(), true); - String permuted_layout_annotate_value = + ffi::String permuted_layout_annotate_value = state->is_mma ? std::string("s2l_") + std::string(read_index == 0 ? "A" : "B") : ""; TileAndAnnotateTensorize(&sch, cache_read, intrin_name, permuted_layout_annotate_value); }; @@ -603,7 +607,7 @@ std::vector MultiLevelTilingTensorCoreNode::AddReadReuseTensorCore( sch->ComputeInline(sch->GetProducers(cache_read)[0]); const tir::BlockNode* cache_read_block = sch->GetSRef(cache_read)->StmtAs(); tir::Buffer cache_read_buffer = tir::GetNthAccessBuffer( - sch->state(), GetRef(cache_read_block), 0, tir::BufferIndexType::kWrite); + sch->state(), ffi::GetRef(cache_read_block), 0, tir::BufferIndexType::kWrite); const DataType& dtype = cache_read_buffer->dtype; if (dtype.is_float16()) { sch->StorageAlign(cache_read, 0, -2, 32, 8); @@ -631,7 +635,7 @@ std::vector MultiLevelTilingTensorCoreNode::AddSoftwarePipeline( // Check reduction length after blockize. 
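// Aside (illustrative sketch, not part of this change): reading a static loop extent, as the
// reduction-length check below does. sch->Get(loop_rv) returns the tir::For, whose extent is
// narrowed to an integer constant; `sch` and `tile` are hypothetical:
const auto* example_extent = sch->Get(tile)->extent.as<tir::IntImmNode>();
int64_t example_len = (example_extent != nullptr) ? example_extent->value
                                                  : -1;  // -1 marks a dynamic extent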
int64_t reduction_length = 1; for (int r_index : r_indices_) { - const Array& tiles = state->tiles[r_index]; + const ffi::Array& tiles = state->tiles[r_index]; for (const LoopRV& tile : tiles) { const auto* extent = sch->Get(tile)->extent.as(); ICHECK(extent != nullptr) << "Dynamic extent is not supported."; @@ -686,16 +690,16 @@ std::vector MultiLevelTilingTensorCoreNode::AddSoftwarePipeline( // compute matmul with fragment K1 - 1 // sch->Annotate(state->tiles[r_indices_[1]].back(), tir::attr::software_pipeline_stage, - Array{0, 0, 1}); + ffi::Array{0, 0, 1}); sch->Annotate(state->tiles[r_indices_[1]].back(), tir::attr::software_pipeline_order, - Array{0, 1, 2}); + ffi::Array{0, 1, 2}); if (state->is_mma && state->use_async) { sch->Annotate(state->tiles[r_indices_[0]].back(), tir::attr::software_pipeline_async_stages, - Array{0}); + ffi::Array{0}); sch->Annotate(state->tiles[r_indices_[0]].back(), tir::attr::software_pipeline_stage, - Array{0, 0, 1, 2, 2}); + ffi::Array{0, 0, 1, 2, 2}); sch->Annotate(state->tiles[r_indices_[0]].back(), tir::attr::software_pipeline_order, - Array{0, 1, 3, 2, 4}); + ffi::Array{0, 1, 3, 2, 4}); } else { // Outer software pipeline: Interleave the outer loop with the (pipelined) inner loop. // The prefetching stage of the inner pipeline is executed by one iteration in the outer loop. @@ -738,16 +742,16 @@ std::vector MultiLevelTilingTensorCoreNode::AddSoftwarePipeline( // compute matmul with fragment K1 - 1 of tile K0 - 1 // sch->Annotate(state->tiles[r_indices_[0]].back(), tir::attr::software_pipeline_stage, - Array{0, 0, 0, 0, 0, 1, 1}); + ffi::Array{0, 0, 0, 0, 0, 1, 1}); sch->Annotate(state->tiles[r_indices_[0]].back(), tir::attr::software_pipeline_order, - Array{0, 3, 1, 4, 5, 2, 6}); + ffi::Array{0, 3, 1, 4, 5, 2, 6}); } return {state}; } -Optional MultiLevelTilingTensorCoreNode::TransformWithTensorIntrin( - TensorCoreStateNode* state, const String& intrin_name) const { +ffi::Optional MultiLevelTilingTensorCoreNode::TransformWithTensorIntrin( + TensorCoreStateNode* state, const ffi::String& intrin_name) const { BlockRV block_rv = state->block_rv; const tir::AutoTensorizeMappingInfo& mapping_info = state->mapping_info; tir::StmtSRef block_sref = state->sch->GetSRef(state->block_rv); @@ -755,7 +759,7 @@ Optional MultiLevelTilingTensorCoreNode::TransformWithTensorIntrin( // Add reindex stages const tir::BlockNode* block = TVM_SREF_TO_BLOCK(block_sref); // Hold the reference of the block before reindex - const tir::Block block_before_reindex = GetRef(block); + const tir::Block block_before_reindex = ffi::GetRef(block); if (block->reads.size() != 2 || block->writes.size() != 1) { // only matmul-like computation is allowed return std::nullopt; @@ -792,7 +796,7 @@ Optional MultiLevelTilingTensorCoreNode::TransformWithTensorIntrin( for (int i = 0; i < offset; ++i) { const tir::VarNode* var_ptr = index_map->final_indices[i].as(); ICHECK(var_ptr != nullptr); - unmapped_index_map_src.insert(GetRef(var_ptr)); + unmapped_index_map_src.insert(ffi::GetRef(var_ptr)); } for (int i = offset; i < static_cast(index_map->final_indices.size()); ++i) { rhs_to_index_map_tgt[mapping_info->rhs_iters[i - offset]->var] = index_map->final_indices[i]; @@ -806,7 +810,7 @@ Optional MultiLevelTilingTensorCoreNode::TransformWithTensorIntrin( ICHECK(tir::is_one(range->extent)); const tir::VarNode* var_ptr = range->min.as(); ICHECK(var_ptr != nullptr); - const tir::Var& lhs_representer = lhs_to_index_map_src[GetRef(var_ptr)]; + const tir::Var& lhs_representer = 
lhs_to_index_map_src[ffi::GetRef(var_ptr)]; sub_index_map_src.push_back(lhs_representer); if (unmapped_index_map_src.count(lhs_representer)) { sub_index_map_tgt.push_back(lhs_representer); @@ -815,15 +819,15 @@ Optional MultiLevelTilingTensorCoreNode::TransformWithTensorIntrin( for (size_t i = 0; i < mapping_info->rhs_buffer_indices[rhs_buffer].size(); ++i) { const tir::VarNode* var = mapping_info->rhs_buffer_indices[rhs_buffer][i].as(); ICHECK(var != nullptr); - sub_index_map_tgt.push_back(rhs_to_index_map_tgt[GetRef(var)]); + sub_index_map_tgt.push_back(rhs_to_index_map_tgt[ffi::GetRef(var)]); } return tir::IndexMap(sub_index_map_src, sub_index_map_tgt); }; std::unordered_set visited_buffers; - Map buffer_sub_index_map; // cache of the sub index map associated - // with each buffer + ffi::Map buffer_sub_index_map; // cache of the sub index map + // associated with each buffer auto f_transform_buffer_layout = [&](tir::BufferIndexType index_type, int buffer_index) { const tir::Buffer& lhs_buffer = tir::GetNthAccessBuffer( @@ -835,7 +839,7 @@ Optional MultiLevelTilingTensorCoreNode::TransformWithTensorIntrin( // Refresh block pointer (block sref is not invalidated) block = TVM_SREF_TO_BLOCK(block_sref); const tir::BufferRegion& reindexed_buffer_region = tir::GetNthAccessBufferRegion( - state->sch->state(), GetRef(block), buffer_index, index_type); + state->sch->state(), ffi::GetRef(block), buffer_index, index_type); auto sub_index_map = f_get_sub_index_map(lhs_buffer, reindexed_buffer_region->region); buffer_sub_index_map.Set(lhs_buffer, sub_index_map); state->sch->TransformLayout(state->block_rv, buffer_index, index_type, sub_index_map, @@ -868,7 +872,7 @@ Optional MultiLevelTilingTensorCoreNode::TransformWithTensorIntrin( inline std::vector MultiLevelTilingTensorCoreNode::TransformForTensorization( TensorCoreState state) const { // Do reindex and layout transformations. - Optional transformed_loop_rv = + ffi::Optional transformed_loop_rv = TransformWithTensorIntrin(state.operator->(), state->intrin_group.compute_intrin); if (!transformed_loop_rv.defined()) { // The workload can't be tensorized. 
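// Aside (illustrative sketch, not part of this change): the ffi::Optional protocol used
// throughout this file -- probe with defined(), then unwrap with value(). Here `maybe_loop`
// is a hypothetical stand-in for the return value of TransformWithTensorIntrin:
ffi::Optional<tir::LoopRV> maybe_loop = std::nullopt;
if (!maybe_loop.defined()) {
  // Not tensorizable: bail out without mutating the schedule further.
} else {
  tir::LoopRV loop = maybe_loop.value();  // safe only after the defined() check
}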
@@ -888,12 +892,13 @@ inline std::vector MultiLevelTilingTensorCoreNode::TransformForTensorizat } ScheduleRule ScheduleRule::MultiLevelTilingTensorCore( - Array> intrin_groups, String structure, Optional> tile_binds, - Optional max_innermost_factor, Optional> vector_load_lens, - Optional> reuse_read, Optional> reuse_write, - bool use_software_pipeline) { + ffi::Array> intrin_groups, ffi::String structure, + ffi::Optional> tile_binds, ffi::Optional max_innermost_factor, + ffi::Optional> vector_load_lens, + ffi::Optional> reuse_read, + ffi::Optional> reuse_write, bool use_software_pipeline) { if (tile_binds.defined()) { - for (const String& tile_bind : tile_binds.value()) { + for (const ffi::String& tile_bind : tile_binds.value()) { CHECK_NE(tile_bind, "threadIdx.x") << "Cannot bind to threadIdx.x when using tensor core."; } } diff --git a/src/meta_schedule/schedule_rule/multi_level_tiling_wide_vector.cc b/src/meta_schedule/schedule_rule/multi_level_tiling_wide_vector.cc index a560248ee2b2..3397945afd42 100644 --- a/src/meta_schedule/schedule_rule/multi_level_tiling_wide_vector.cc +++ b/src/meta_schedule/schedule_rule/multi_level_tiling_wide_vector.cc @@ -46,16 +46,18 @@ class MultiLevelTilingWideVectorNode : public MultiLevelTilingNode { protected: ScheduleRule Clone() const final { ObjectPtr n = - make_object(*this); + ffi::make_object(*this); return ScheduleRule(n); } - std::pair, Array> SplitLoop(const Schedule& sch, BlockRV block, - LoopRV loop, int n_tiles) const; + std::pair, ffi::Array> SplitLoop(const Schedule& sch, + BlockRV block, LoopRV loop, + int n_tiles) const; }; -std::pair, Array> MultiLevelTilingWideVectorNode::SplitLoop( - const Schedule& sch, BlockRV block_rv, LoopRV loop_rv, int n_tiles) const { +std::pair, ffi::Array> +MultiLevelTilingWideVectorNode::SplitLoop(const Schedule& sch, BlockRV block_rv, LoopRV loop_rv, + int n_tiles) const { const tir::ForNode* loop = TVM_SREF_TO_FOR(sch->GetSRef(loop_rv)); const tir::StmtSRef block_sref = sch->GetSRef(block_rv); const tir::BlockNode* block_node = block_sref->StmtAs(); @@ -93,32 +95,33 @@ std::pair, Array> MultiLevelTilingWideVectorNode // We split the innermost spatial loop in a way that always uses the maximum vector length. 
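// Aside (illustrative sketch, not part of this change): the shape of the split performed
// below, with hypothetical numbers. For a loop of extent 512 and vec_len 16:
ffi::Array<tir::LoopRV> example_inner = sch->Split(loop_rv, {std::nullopt, PrimExpr(16)});
// example_inner[1] is the 16-lane vector axis; example_inner[0] keeps the remaining
// 512 / 16 = 32 iterations, which SamplePerfectTile then tiles into the outer
// n_tiles - 1 levels.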
const int64_t* extent_int = tir::GetLoopIntExtent(loop); if (extent_int && *extent_int > vec_len) { - Array inner_splits = sch->Split(/*loop=*/loop_rv, - /*factors=*/{std::nullopt, PrimExpr(vec_len)}); - Array outer_factors = sch->SamplePerfectTile( + ffi::Array inner_splits = + sch->Split(/*loop=*/loop_rv, + /*factors=*/{std::nullopt, PrimExpr(vec_len)}); + ffi::Array outer_factors = sch->SamplePerfectTile( /*loop=*/inner_splits[0], /*n=*/n_tiles - 1, /*max_innermost_factor=*/max_innermost_factor); - Array outer_splits = sch->Split( + ffi::Array outer_splits = sch->Split( /*loop=*/inner_splits[0], /*factors=*/{outer_factors.begin(), outer_factors.end()}); outer_splits.push_back(inner_splits[1]); outer_factors.push_back(PrimExpr(vec_len)); return {outer_factors, outer_splits}; } else { - Array factors(n_tiles - 1, PrimExpr(1)); + ffi::Array factors(n_tiles - 1, PrimExpr(1)); factors.push_back(loop->extent); - Array splits = sch->Split(/*loop=*/loop_rv, - /*factors=*/{factors.begin(), factors.end()}); + ffi::Array splits = sch->Split(/*loop=*/loop_rv, + /*factors=*/{factors.begin(), factors.end()}); return {factors, splits}; } } } -ScheduleRule ScheduleRule::MultiLevelTilingWideVector(String structure, - Integer vector_length_in_bits, - Optional max_innermost_factor, - Optional> reuse_read, - Optional> reuse_write) { +ScheduleRule ScheduleRule::MultiLevelTilingWideVector( + ffi::String structure, Integer vector_length_in_bits, + ffi::Optional max_innermost_factor, + ffi::Optional> reuse_read, + ffi::Optional> reuse_write) { auto node = MultiLevelTilingInitCommon( structure, std::nullopt, max_innermost_factor, std::nullopt, reuse_read, reuse_write); node->vector_length_in_bits = vector_length_in_bits->value; diff --git a/src/meta_schedule/schedule_rule/multi_level_tiling_with_intrin.cc b/src/meta_schedule/schedule_rule/multi_level_tiling_with_intrin.cc index 85c9243e6bb1..5747746a52a5 100644 --- a/src/meta_schedule/schedule_rule/multi_level_tiling_with_intrin.cc +++ b/src/meta_schedule/schedule_rule/multi_level_tiling_with_intrin.cc @@ -31,15 +31,15 @@ namespace meta_schedule { * \brief Tile a subset of loops in the block according to the given tensor intrinsic, and annotate * the tiled block for tensorization by postproc rewrite. 
*/ -Optional TileForIntrin(tir::Schedule sch, tir::BlockRV block, - const std::string& intrin_name) { - Optional tiled_loop_rv = TileWithTensorIntrin(sch, block, intrin_name); +ffi::Optional TileForIntrin(tir::Schedule sch, tir::BlockRV block, + const std::string& intrin_name) { + ffi::Optional tiled_loop_rv = TileWithTensorIntrin(sch, block, intrin_name); if (!tiled_loop_rv) { return std::nullopt; } ICHECK(tiled_loop_rv.defined()); tir::BlockRV outer_block = sch->Blockize(tiled_loop_rv.value()); - sch->Annotate(outer_block, tir::attr::meta_schedule_auto_tensorize, String(intrin_name)); + sch->Annotate(outer_block, tir::attr::meta_schedule_auto_tensorize, ffi::String(intrin_name)); return outer_block; } @@ -48,7 +48,7 @@ Optional TileForIntrin(tir::Schedule sch, tir::BlockRV block, */ class MultiLevelTilingWithIntrinNode : public MultiLevelTilingNode { protected: - Array Apply(const tir::Schedule& sch, const tir::BlockRV& block_rv) final { + ffi::Array Apply(const tir::Schedule& sch, const tir::BlockRV& block_rv) final { auto desc_func = tir::TensorIntrin::Get(intrin_name).value()->desc; if (!CheckAutoTensorizeApplicable(sch, block_rv, desc_func)) { TVM_PY_LOG(INFO, logger) << "The workload cannot be tensorized."; @@ -68,7 +68,7 @@ class MultiLevelTilingWithIntrinNode : public MultiLevelTilingNode { // Inherited from ScheduleRuleNode ScheduleRule Clone() const final { ObjectPtr n = - make_object(*this); + ffi::make_object(*this); return ScheduleRule(n); } @@ -87,18 +87,18 @@ class MultiLevelTilingWithIntrinNode : public MultiLevelTilingNode { public: /*! \brief The name of a tensor intrinsic. */ - String intrin_name; + ffi::String intrin_name; static constexpr const char* _type_key = "meta_schedule.MultiLevelTilingWithIntrin"; TVM_DECLARE_FINAL_OBJECT_INFO(MultiLevelTilingWithIntrinNode, MultiLevelTilingNode); }; -ScheduleRule ScheduleRule::MultiLevelTilingWithIntrin(String intrin_name, String structure, - Optional> tile_binds, - Optional max_innermost_factor, - Optional> vector_load_lens, - Optional> reuse_read, - Optional> reuse_write) { +ScheduleRule ScheduleRule::MultiLevelTilingWithIntrin( + ffi::String intrin_name, ffi::String structure, + ffi::Optional> tile_binds, ffi::Optional max_innermost_factor, + ffi::Optional> vector_load_lens, + ffi::Optional> reuse_read, + ffi::Optional> reuse_write) { ICHECK(tir::TensorIntrin::Get(intrin_name).defined()) << "Provided tensor intrinsic " << intrin_name << " is not registered."; auto node = MultiLevelTilingInitCommon( diff --git a/src/meta_schedule/schedule_rule/parallel_vectorize_unroll.cc b/src/meta_schedule/schedule_rule/parallel_vectorize_unroll.cc index 28929d933762..dd3684e3aa05 100644 --- a/src/meta_schedule/schedule_rule/parallel_vectorize_unroll.cc +++ b/src/meta_schedule/schedule_rule/parallel_vectorize_unroll.cc @@ -30,7 +30,7 @@ bool IsRootBlock(const Schedule& sch, const BlockRV& block_rv) { bool CheckSpatialPrimFunc(const Schedule& sch, const BlockRV& root_block_rv) { return IsSpatialPrimFunc( - GetRef(GetRootPrimFunc(sch->mod(), sch->Get(root_block_rv).get(), nullptr))); + ffi::GetRef(GetRootPrimFunc(sch->mod(), sch->Get(root_block_rv).get(), nullptr))); } } // namespace tir @@ -51,7 +51,7 @@ class ParallelizeVectorizeUnrollNode : public ScheduleRuleNode { } // Inherited from ScheduleRuleNode - Array Apply(const tir::Schedule& sch, const tir::BlockRV& root_rv) { + ffi::Array Apply(const tir::Schedule& sch, const tir::BlockRV& root_rv) { // Currently only mark the root block with annotations. 
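// Illustrative sketch (not from the patch): "mark the root block with
// annotations" means this rule only attaches hints here; the loops are
// rewritten later by a postprocessor. A plausible shape for the
// parallel/vectorize hints, mirroring the unroll hint shown below;
// `max_parallel_extent_` is an assumed name, not shown in the patch.
if (max_jobs_per_core != -1) {
  sch->Annotate(root_rv, tir::attr::meta_schedule_parallel, Integer(max_parallel_extent_));
}
if (max_vectorize_extent != -1) {
  sch->Annotate(root_rv, tir::attr::meta_schedule_vectorize, Integer(max_vectorize_extent));
}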
if (!tir::IsRootBlock(sch, root_rv)) {
  return {sch};
@@ -70,7 +70,7 @@ class ParallelizeVectorizeUnrollNode : public ScheduleRuleNode {
    if (!unroll_max_steps.empty() && !tir::CheckSpatialPrimFunc(sch, root_rv)) {
      int n = unroll_max_steps.size();
      double prob = 1.0 / n;
-     Array<FloatImm> probs(n, FloatImm(DataType::Float(32), prob));
+     ffi::Array<FloatImm> probs(n, FloatImm(DataType::Float(32), prob));
      PrimExpr max_step = sch->SampleCategorical(unroll_max_steps, probs);
      if (unroll_explicit) {
        sch->Annotate(root_rv, tir::attr::meta_schedule_unroll_explicit, max_step);
@@ -84,7 +84,7 @@ class ParallelizeVectorizeUnrollNode : public ScheduleRuleNode {
  // Inherited from ScheduleRuleNode
  ScheduleRule Clone() const final {
    ObjectPtr<ParallelizeVectorizeUnrollNode> n =
-       make_object<ParallelizeVectorizeUnrollNode>(*this);
+       ffi::make_object<ParallelizeVectorizeUnrollNode>(*this);
    return ScheduleRule(n);
  }
@@ -104,7 +104,7 @@ class ParallelizeVectorizeUnrollNode : public ScheduleRuleNode {
   * \brief The candidate maximum numbers of unroll steps.
   * Use an empty array to disable unrolling.
   */
-  Array<Integer> unroll_max_steps;
+  ffi::Array<Integer> unroll_max_steps;
  /*! \brief Whether to explicitly unroll the loop, or just add an "unroll" pragma. */
  bool unroll_explicit;
  /*! \brief The maximum number of jobs available per CPU core. */
@@ -125,9 +125,9 @@ class ParallelizeVectorizeUnrollNode : public ScheduleRuleNode {
 ScheduleRule ScheduleRule::ParallelizeVectorizeUnroll(int max_jobs_per_core,
                                                       int max_vectorize_extent,
-                                                      Array<Integer> unroll_max_steps,
+                                                      ffi::Array<Integer> unroll_max_steps,
                                                       bool unroll_explicit) {
-  ObjectPtr<ParallelizeVectorizeUnrollNode> n = make_object<ParallelizeVectorizeUnrollNode>();
+  ObjectPtr<ParallelizeVectorizeUnrollNode> n = ffi::make_object<ParallelizeVectorizeUnrollNode>();
  n->max_jobs_per_core = max_jobs_per_core;
  n->max_vectorize_extent = max_vectorize_extent;
  n->unroll_max_steps = unroll_max_steps;
diff --git a/src/meta_schedule/schedule_rule/random_compute_location.cc b/src/meta_schedule/schedule_rule/random_compute_location.cc
index a2bfa2644b1e..fa84ecffe217 100644
--- a/src/meta_schedule/schedule_rule/random_compute_location.cc
+++ b/src/meta_schedule/schedule_rule/random_compute_location.cc
@@ -29,7 +29,7 @@ class RandomComputeLocationNode : public ScheduleRuleNode {
  void InitializeWithTuneContext(const TuneContext& context) final {}
  // Inherited from ScheduleRuleNode
-  Array<tir::Schedule> Apply(const tir::Schedule& sch, const tir::BlockRV& block_rv) final {
+  ffi::Array<tir::Schedule> Apply(const tir::Schedule& sch, const tir::BlockRV& block_rv) final {
    if (!CheckConditions(sch, block_rv)) {
      return {sch};
    }
@@ -40,7 +40,7 @@ class RandomComputeLocationNode : public ScheduleRuleNode {
    // decision of Sample-Compute-Location is "compute-inline" for the input block, we can no longer
    // access the input block. Hence we collect its producer ahead of time.
    // - Note that only a single producer is allowed in this case.
-    Array<tir::BlockRV> producers{nullptr};
+    ffi::Array<tir::BlockRV> producers{nullptr};
    if (tir::HasAnn(sch->GetSRef(block_rv), tir::attr::meta_schedule_random_compute_producer,
                    true)) {
      producers = sch->GetProducers(block_rv);
@@ -61,7 +61,7 @@ class RandomComputeLocationNode : public ScheduleRuleNode {
  // Inherited from ScheduleRuleNode
  ScheduleRule Clone() const final {
-    ObjectPtr<RandomComputeLocationNode> n = make_object<RandomComputeLocationNode>(*this);
+    ObjectPtr<RandomComputeLocationNode> n = ffi::make_object<RandomComputeLocationNode>(*this);
    return ScheduleRule(n);
  }
@@ -82,7 +82,7 @@ class RandomComputeLocationNode : public ScheduleRuleNode {
  }
  // Cond 3 & 4. The block has at least one outer loop, and the outermost loop has only one child
  // block.
- Array loop_srefs = tir::GetLoops(block_sref); + ffi::Array loop_srefs = tir::GetLoops(block_sref); if (loop_srefs.empty()) { return false; } @@ -123,7 +123,7 @@ class RandomComputeLocationNode : public ScheduleRuleNode { }; ScheduleRule ScheduleRule::RandomComputeLocation() { - return ScheduleRule(make_object()); + return ScheduleRule(ffi::make_object()); } TVM_FFI_STATIC_INIT_BLOCK({ RandomComputeLocationNode::RegisterReflection(); }); diff --git a/src/meta_schedule/schedule_rule/schedule_rule.cc b/src/meta_schedule/schedule_rule/schedule_rule.cc index e23ca117c616..2aad6a8df548 100644 --- a/src/meta_schedule/schedule_rule/schedule_rule.cc +++ b/src/meta_schedule/schedule_rule/schedule_rule.cc @@ -30,8 +30,8 @@ void PyScheduleRuleNode::InitializeWithTuneContext(const TuneContext& context) { f_initialize_with_tune_context(context); } -Array PyScheduleRuleNode::Apply(const tir::Schedule& sch, - const tir::BlockRV& block) { +ffi::Array PyScheduleRuleNode::Apply(const tir::Schedule& sch, + const tir::BlockRV& block) { ICHECK(f_apply != nullptr) << "PyScheduleRule's Apply method not implemented!"; return f_apply(sch, block); } @@ -46,7 +46,7 @@ ScheduleRule ScheduleRule::PyScheduleRule( PyScheduleRuleNode::FApply f_apply, // PyScheduleRuleNode::FClone f_clone, // PyScheduleRuleNode::FAsString f_as_string) { - ObjectPtr n = make_object(); + ObjectPtr n = ffi::make_object(); n->f_initialize_with_tune_context = std::move(f_initialize_with_tune_context); n->f_apply = std::move(f_apply); n->f_clone = std::move(f_clone); @@ -54,7 +54,7 @@ ScheduleRule ScheduleRule::PyScheduleRule( return ScheduleRule(n); } -Array ScheduleRule::DefaultLLVM() { +ffi::Array ScheduleRule::DefaultLLVM() { return { ScheduleRule::ApplyCustomRule(), ScheduleRule::InlineConstantScalars(), @@ -65,7 +65,7 @@ Array ScheduleRule::DefaultLLVM() { /*disallow_if_then_else=*/true, /*require_injective=*/true, /*require_ordered=*/true, - /*disallow_op=*/Array{"tir.exp"}), + /*disallow_op=*/ffi::Array{"tir.exp"}), ScheduleRule::AddRFactor( /*max_jobs_per_core=*/16, /*max_innermost_factor=*/Integer(64)), @@ -76,21 +76,21 @@ Array ScheduleRule::DefaultLLVM() { /*vector_load_lens=*/std::nullopt, /*reuse_read=*/std::nullopt, /*reuse_write=*/ - Map{{"req", String("may")}, - {"levels", Array{1, 2}}, - {"scope", String("global")}}), + ffi::Map{{"req", ffi::String("may")}, + {"levels", ffi::Array{1, 2}}, + {"scope", ffi::String("global")}}), ScheduleRule::ParallelizeVectorizeUnroll( /*max_jobs_per_core=*/16, /*max_vectorize_extent=*/64, - /*unroll_max_steps=*/Array{0, 16, 64, 512}, + /*unroll_max_steps=*/ffi::Array{0, 16, 64, 512}, /*unroll_explicit=*/true), ScheduleRule::RandomComputeLocation(), }; } -Array ScheduleRule::DefaultX86(const String& type) { - static const Map intrins = {{"vnni", "dot_16x4_vnni"}, - {"avx512", "dot_16x4_avx512"}}; +ffi::Array ScheduleRule::DefaultX86(const ffi::String& type) { + static const ffi::Map intrins = {{"vnni", "dot_16x4_vnni"}, + {"avx512", "dot_16x4_avx512"}}; return { ScheduleRule::ApplyCustomRule(), ScheduleRule::InlineConstantScalars(), @@ -101,7 +101,7 @@ Array ScheduleRule::DefaultX86(const String& type) { /*disallow_if_then_else=*/true, /*require_injective=*/true, /*require_ordered=*/true, - /*disallow_op=*/Array{"tir.exp"}), + /*disallow_op=*/ffi::Array{"tir.exp"}), ScheduleRule::AddRFactor( /*max_jobs_per_core=*/16, /*max_innermost_factor=*/Integer(64)), @@ -113,9 +113,9 @@ Array ScheduleRule::DefaultX86(const String& type) { /*vector_load_lens=*/std::nullopt, /*reuse_read=*/std::nullopt, 
/*reuse_write=*/ - Map{{"req", String("may")}, - {"levels", Array{1, 2}}, - {"scope", String("global")}}), + ffi::Map{{"req", ffi::String("may")}, + {"levels", ffi::Array{1, 2}}, + {"scope", ffi::String("global")}}), ScheduleRule::MultiLevelTiling( /*structure=*/"SSRSRS", /*tile_binds=*/std::nullopt, @@ -123,34 +123,34 @@ Array ScheduleRule::DefaultX86(const String& type) { /*vector_load_lens=*/std::nullopt, /*reuse_read=*/std::nullopt, /*reuse_write=*/ - Map{{"req", String("may")}, - {"levels", Array{1, 2}}, - {"scope", String("global")}}), + ffi::Map{{"req", ffi::String("may")}, + {"levels", ffi::Array{1, 2}}, + {"scope", ffi::String("global")}}), ScheduleRule::ParallelizeVectorizeUnroll( /*max_jobs_per_core=*/16, /*max_vectorize_extent=*/64, - /*unroll_max_steps=*/Array{0, 16, 64, 512}, + /*unroll_max_steps=*/ffi::Array{0, 16, 64, 512}, /*unroll_explicit=*/true), ScheduleRule::RandomComputeLocation(), }; } -Array ScheduleRule::DefaultCUDA() { +ffi::Array ScheduleRule::DefaultCUDA() { return { ScheduleRule::ApplyCustomRule(), ScheduleRule::MultiLevelTiling( /*structure=*/"SSSRRSRS", - /*tile_binds=*/Array{"blockIdx.x", "vthread.x", "threadIdx.x"}, + /*tile_binds=*/ffi::Array{"blockIdx.x", "vthread.x", "threadIdx.x"}, /*max_innermost_factor=*/Integer(64), - /*vector_load_lens=*/Array{1, 2, 3, 4, 8, 16}, + /*vector_load_lens=*/ffi::Array{1, 2, 3, 4, 8, 16}, /*reuse_read=*/ - Map{{"req", String("must")}, - {"levels", Array{4}}, // - {"scope", String("shared")}}, + ffi::Map{{"req", ffi::String("must")}, + {"levels", ffi::Array{4}}, // + {"scope", ffi::String("shared")}}, /*reuse_write=*/ - Map{{"req", String("must")}, - {"levels", Array{3}}, // - {"scope", String("local")}}), + ffi::Map{{"req", ffi::String("must")}, + {"levels", ffi::Array{3}}, // + {"scope", ffi::String("local")}}), ScheduleRule::InlineConstantScalars(), ScheduleRule::AutoInline( /*into_producer=*/true, @@ -159,22 +159,22 @@ Array ScheduleRule::DefaultCUDA() { /*disallow_if_then_else=*/false, /*require_injective=*/false, /*require_ordered=*/false, - /*disallow_op=*/Array{}), + /*disallow_op=*/ffi::Array{}), ScheduleRule::CrossThreadReduction( - /*thread_extents=*/Array{4, 8, 16, 32, 64, 128, 256, 512}), + /*thread_extents=*/ffi::Array{4, 8, 16, 32, 64, 128, 256, 512}), ScheduleRule::ParallelizeVectorizeUnroll( /*max_jobs_per_core=*/-1, /*max_vectorize_extent=*/-1, - /*unroll_max_steps=*/Array{0, 16, 64, 512, 1024}, + /*unroll_max_steps=*/ffi::Array{0, 16, 64, 512, 1024}, /*unroll_explicit=*/true), ScheduleRule::AutoBind( /*max_threadblocks=*/256, - /*thread_extents*/ Array{32, 64, 128, 256, 512, 1024}), + /*thread_extents*/ ffi::Array{32, 64, 128, 256, 512, 1024}), }; } -Array ScheduleRule::DefaultCUDATensorCore() { - Array> wmma_intrin_groups = { +ffi::Array ScheduleRule::DefaultCUDATensorCore() { + ffi::Array> wmma_intrin_groups = { // Tensor Cores f32 += f16 * f16 { {"init", "wmma_fill_16x16x16_f32"}, @@ -221,7 +221,7 @@ Array ScheduleRule::DefaultCUDATensorCore() { {"store", "wmma_store_16x16x16_s32_shared_dyn"}, }, }; - Array> mma_intrin_groups = { + ffi::Array> mma_intrin_groups = { // Tensor Core MMA { {"init", "mma_init_m16n8k8_f16"}, @@ -238,45 +238,45 @@ Array ScheduleRule::DefaultCUDATensorCore() { {"store", "mma_store_m16n8k8_f32_global"}, }, }; - Array results{ + ffi::Array results{ ScheduleRule::ApplyCustomRule(), ScheduleRule::MultiLevelTilingTensorCore( /*intrin_groups=*/wmma_intrin_groups, /*structure=*/"SSSRRSRS", - /*tile_binds=*/Array{"blockIdx.y", "blockIdx.x", "threadIdx.y"}, + 
/*tile_binds=*/ffi::Array{"blockIdx.y", "blockIdx.x", "threadIdx.y"}, /*max_innermost_factor=*/Integer(4), - /*vector_load_lens=*/Array{1, 2, 3, 4, 8, 16}, + /*vector_load_lens=*/ffi::Array{1, 2, 3, 4, 8, 16}, /*reuse_read=*/ - Map{{"req", String("must")}, - {"levels", Array{4}}, // - {"scope", String("shared.dyn")}}, + ffi::Map{{"req", ffi::String("must")}, + {"levels", ffi::Array{4}}, // + {"scope", ffi::String("shared.dyn")}}, /*reuse_write=*/ - Map{{"req", String("must")}, - {"levels", Array{2}}, // - {"scope", String("shared.dyn")}}, + ffi::Map{{"req", ffi::String("must")}, + {"levels", ffi::Array{2}}, // + {"scope", ffi::String("shared.dyn")}}, /*use_software_pipeline=*/false), // ScheduleRule::MultiLevelTilingTensorCore( /*intrin_groups=*/mma_intrin_groups, /*structure=*/"SSSRRSRS", - /*tile_binds=*/Array{"blockIdx.y", "blockIdx.x", "threadIdx.y"}, + /*tile_binds=*/ffi::Array{"blockIdx.y", "blockIdx.x", "threadIdx.y"}, /*max_innermost_factor=*/Integer(4), - /*vector_load_lens=*/Array{1, 2, 3, 4, 8, 16}, + /*vector_load_lens=*/ffi::Array{1, 2, 3, 4, 8, 16}, /*reuse_read=*/ - Map{{"req", String("must")}, - {"levels", Array{4}}, // - {"scope", String("shared.dyn")}}, + ffi::Map{{"req", ffi::String("must")}, + {"levels", ffi::Array{4}}, // + {"scope", ffi::String("shared.dyn")}}, /*reuse_write=*/ - Map{{"req", String("no")}, - {"levels", Array{2}}, // - {"scope", String("shared.dyn")}}, + ffi::Map{{"req", ffi::String("no")}, + {"levels", ffi::Array{2}}, // + {"scope", ffi::String("shared.dyn")}}, /*use_software_pipeline=*/true) // }; - Array append = ScheduleRule::DefaultCUDA(); + ffi::Array append = ScheduleRule::DefaultCUDA(); results.insert(results.end(), append.begin() + 1, append.end()); return results; } -Array ScheduleRule::DefaultHexagon() { +ffi::Array ScheduleRule::DefaultHexagon() { return { ScheduleRule::ApplyCustomRule(), ScheduleRule::InlineConstantScalars(), @@ -287,26 +287,26 @@ Array ScheduleRule::DefaultHexagon() { /*disallow_if_then_else=*/true, /*require_injective=*/true, /*require_ordered=*/true, - /*disallow_op=*/Array{"tir.exp"}), + /*disallow_op=*/ffi::Array{"tir.exp"}), ScheduleRule::MultiLevelTilingWideVector( /*structure=*/"SRSRS", /*vector_length_in_bits=*/1024, /*max_innermost_factor=*/Integer(128), /*reuse_read=*/std::nullopt, /*reuse_write=*/ - Map{{"req", String("may")}, - {"levels", Array{1, 2}}, - {"scope", String("global")}}), + ffi::Map{{"req", ffi::String("may")}, + {"levels", ffi::Array{1, 2}}, + {"scope", ffi::String("global")}}), ScheduleRule::ParallelizeVectorizeUnroll( /*max_jobs_per_core=*/16, /*max_vectorize_extent=*/128, - /*unroll_max_steps=*/Array{0, 16, 64, 512}, + /*unroll_max_steps=*/ffi::Array{0, 16, 64, 512}, /*unroll_explicit=*/true), }; } -Array ScheduleRule::DefaultRISCV(const int vlen) { - Array rules; +ffi::Array ScheduleRule::DefaultRISCV(const int vlen) { + ffi::Array rules; rules.push_back(ScheduleRule::ApplyCustomRule()); rules.push_back(ScheduleRule::InlineConstantScalars()); rules.push_back(ScheduleRule::AutoInline( @@ -316,15 +316,15 @@ Array ScheduleRule::DefaultRISCV(const int vlen) { /*disallow_if_then_else=*/true, /*require_injective=*/true, /*require_ordered=*/true, - /*disallow_op=*/Array{"tir.exp"})); + /*disallow_op=*/ffi::Array{"tir.exp"})); rules.push_back(ScheduleRule::AddRFactor( /*max_jobs_per_core=*/16, /*max_innermost_factor=*/Integer(64))); auto current_target = tvm::Target::Current(); const auto reg_rvv_intrinsics = tvm::ffi::Function::GetGlobalRequired("tir.tensor_intrin.register_rvv_isa_intrinsics"); - 
const auto rvv_kernels_inventory = - reg_rvv_intrinsics(current_target, /* inventory_only */ true).cast>(); + const auto rvv_kernels_inventory = reg_rvv_intrinsics(current_target, /* inventory_only */ true) + .cast>(); for (const auto& intrin : rvv_kernels_inventory) { if (!tir::TensorIntrin::Get(intrin.first, /*allow_missing*/ true)) { // on demand intrinsic register @@ -338,9 +338,9 @@ Array ScheduleRule::DefaultRISCV(const int vlen) { /*vector_load_lens=*/std::nullopt, /*reuse_read=*/std::nullopt, /*reuse_write=*/ - Map{{"req", String("may")}, - {"levels", Array{1, 2}}, - {"scope", String("global")}})); + ffi::Map{{"req", ffi::String("may")}, + {"levels", ffi::Array{1, 2}}, + {"scope", ffi::String("global")}})); } rules.push_back(ScheduleRule::MultiLevelTiling( /*structure=*/"SSRSRS", @@ -349,74 +349,75 @@ Array ScheduleRule::DefaultRISCV(const int vlen) { /*vector_load_lens=*/std::nullopt, /*reuse_read=*/std::nullopt, /*reuse_write=*/ - Map{ - {"req", String("may")}, {"levels", Array{1, 2}}, {"scope", String("global")}})); + ffi::Map{{"req", ffi::String("may")}, + {"levels", ffi::Array{1, 2}}, + {"scope", ffi::String("global")}})); rules.push_back(ScheduleRule::ParallelizeVectorizeUnroll( /*max_jobs_per_core=*/16, /*max_vectorize_extent=*/64, - /*unroll_max_steps=*/Array{0, 16, 64, 512}, + /*unroll_max_steps=*/ffi::Array{0, 16, 64, 512}, /*unroll_explicit=*/true)); rules.push_back(ScheduleRule::RandomComputeLocation()); return rules; } -Array GetARMNeonSpecificRules() { +ffi::Array GetARMNeonSpecificRules() { return { ScheduleRule::MultiLevelTilingWithIntrin( - /*intrin_name=*/String("dot_4x4_i8i8s32_neon"), + /*intrin_name=*/ffi::String("dot_4x4_i8i8s32_neon"), /*structure=*/"SSRSRS", /*tile_binds=*/std::nullopt, /*max_innermost_factor=*/Integer(32), /*vector_load_lens=*/std::nullopt, /*reuse_read=*/std::nullopt, /*reuse_write=*/ - Map{{"req", String("may")}, - {"levels", Array{1, 2}}, - {"scope", String("global")}}), + ffi::Map{{"req", ffi::String("may")}, + {"levels", ffi::Array{1, 2}}, + {"scope", ffi::String("global")}}), }; } -Array GetARMDotprodSpecificRules() { +ffi::Array GetARMDotprodSpecificRules() { return { ScheduleRule::MultiLevelTilingWithIntrin( - /*intrin_name=*/String("dot_4x4_i8i8s32_sdot"), + /*intrin_name=*/ffi::String("dot_4x4_i8i8s32_sdot"), /*structure=*/"SSRSRS", /*tile_binds=*/std::nullopt, /*max_innermost_factor=*/Integer(32), /*vector_load_lens=*/std::nullopt, /*reuse_read=*/std::nullopt, /*reuse_write=*/ - Map{{"req", String("may")}, - {"levels", Array{1, 2}}, - {"scope", String("global")}}), + ffi::Map{{"req", ffi::String("may")}, + {"levels", ffi::Array{1, 2}}, + {"scope", ffi::String("global")}}), ScheduleRule::MultiLevelTilingWithIntrin( - /*intrin_name=*/String("dot_4x4_u8u8u32_udot"), + /*intrin_name=*/ffi::String("dot_4x4_u8u8u32_udot"), /*structure=*/"SSRSRS", /*tile_binds=*/std::nullopt, /*max_innermost_factor=*/Integer(32), /*vector_load_lens=*/std::nullopt, /*reuse_read=*/std::nullopt, /*reuse_write=*/ - Map{{"req", String("may")}, - {"levels", Array{1, 2}}, - {"scope", String("global")}}), + ffi::Map{{"req", ffi::String("may")}, + {"levels", ffi::Array{1, 2}}, + {"scope", ffi::String("global")}}), ScheduleRule::MultiLevelTilingWithIntrin( - /*intrin_name=*/String("dot_4x4_u8u8i32_hdot"), + /*intrin_name=*/ffi::String("dot_4x4_u8u8i32_hdot"), /*structure=*/"SSRSRS", /*tile_binds=*/std::nullopt, /*max_innermost_factor=*/Integer(32), /*vector_load_lens=*/std::nullopt, /*reuse_read=*/std::nullopt, /*reuse_write=*/ - Map{{"req", String("may")}, - 
{"levels", Array{1, 2}}, - {"scope", String("global")}}), + ffi::Map{{"req", ffi::String("may")}, + {"levels", ffi::Array{1, 2}}, + {"scope", ffi::String("global")}}), }; } -Array ScheduleRule::DefaultARM(const String& type) { - return Array::Agregate( +ffi::Array ScheduleRule::DefaultARM(const ffi::String& type) { + return ffi::Array::Agregate( ScheduleRule::ApplyCustomRule(), ScheduleRule::InlineConstantScalars(), ScheduleRule::AutoInline( /*into_producer=*/false, @@ -425,12 +426,12 @@ Array ScheduleRule::DefaultARM(const String& type) { /*disallow_if_then_else=*/true, /*require_injective=*/true, /*require_ordered=*/true, - /*disallow_op=*/Array{"tir.exp"}), + /*disallow_op=*/ffi::Array{"tir.exp"}), ScheduleRule::AddRFactor( /*max_jobs_per_core=*/8, /*max_innermost_factor=*/Integer(32)), - "neon" == type ? GetARMNeonSpecificRules() : Array{}, - "dotprod" == type ? GetARMDotprodSpecificRules() : Array{}, + "neon" == type ? GetARMNeonSpecificRules() : ffi::Array{}, + "dotprod" == type ? GetARMDotprodSpecificRules() : ffi::Array{}, ScheduleRule::MultiLevelTiling( /*structure=*/"SSRSRS", /*tile_binds=*/std::nullopt, @@ -438,13 +439,13 @@ Array ScheduleRule::DefaultARM(const String& type) { /*vector_load_lens=*/std::nullopt, /*reuse_read=*/std::nullopt, /*reuse_write=*/ - Map{{"req", String("may")}, - {"levels", Array{1, 2}}, - {"scope", String("global")}}), + ffi::Map{{"req", ffi::String("may")}, + {"levels", ffi::Array{1, 2}}, + {"scope", ffi::String("global")}}), ScheduleRule::ParallelizeVectorizeUnroll( /*max_jobs_per_core=*/8, /*max_vectorize_extent=*/32, - /*unroll_max_steps=*/Array{0, 8, 32, 256}, + /*unroll_max_steps=*/ffi::Array{0, 8, 32, 256}, /*unroll_explicit=*/true), ScheduleRule::RandomComputeLocation()); } diff --git a/src/meta_schedule/search_strategy/evolutionary_search.cc b/src/meta_schedule/search_strategy/evolutionary_search.cc index 82c0dcb746c6..306a3634d9d1 100644 --- a/src/meta_schedule/search_strategy/evolutionary_search.cc +++ b/src/meta_schedule/search_strategy/evolutionary_search.cc @@ -115,7 +115,7 @@ struct PerThreadData { IRModule mod{nullptr}; TRandState rand_state{-1}; std::function trace_sampler = nullptr; - std::function()> mutator_sampler = nullptr; + std::function()> mutator_sampler = nullptr; /*! * \brief Set the value for the trace and mutator samplers per thread. @@ -124,7 +124,7 @@ struct PerThreadData { * \param mutator_probs The probability of each mutator as a dict. 
 */
  void Set(const std::vector<double>& scores, double genetic_mutate_prob,
-          const Map<Mutator, FloatImm>& mutator_probs) {
+          const ffi::Map<Mutator, FloatImm>& mutator_probs) {
    trace_sampler = tir::MakeMultinomialSampler(&rand_state, scores);
    mutator_sampler = MakeMutatorSampler(genetic_mutate_prob, mutator_probs, &rand_state);
  }
@@ -135,11 +135,11 @@ struct PerThreadData {
   * \param rand_state The random state for sampling
   * \return The sampler created
   */
-  static std::function<Optional<Mutator>()> MakeMutatorSampler(
-      double genetic_mutate_prob,  //
-      const Map<Mutator, FloatImm>& mutator_probs,  //
+  static std::function<ffi::Optional<Mutator>()> MakeMutatorSampler(
+      double genetic_mutate_prob,  //
+      const ffi::Map<Mutator, FloatImm>& mutator_probs,  //
      TRandState* rand_state) {
-    std::vector<Optional<Mutator>> mutators;
+    std::vector<ffi::Optional<Mutator>> mutators;
    std::vector<double> masses;
    mutators.push_back(std::nullopt);
    masses.push_back(1.0 - genetic_mutate_prob);
@@ -165,7 +165,7 @@ struct PerThreadData {
      }
    }
    return [idx_sampler = tir::MakeMultinomialSampler(rand_state, masses),
-           mutators = std::move(mutators)]() -> Optional<Mutator> {
+           mutators = std::move(mutators)]() -> ffi::Optional<Mutator> {
      int i = idx_sampler();
      return mutators[i];
    };
@@ -212,8 +212,8 @@ struct ConcurrentBitmask {
 * \param traces The picked candidate traces.
 * \return The assembled measure candidates.
 */
-Array<MeasureCandidate> AssembleCandidates(const std::vector<Schedule>& picks) {
-  Array<MeasureCandidate> measure_inputs;
+ffi::Array<MeasureCandidate> AssembleCandidates(const std::vector<Schedule>& picks) {
+  ffi::Array<MeasureCandidate> measure_inputs;
  measure_inputs.reserve(picks.size());
  for (const Schedule& sch : picks) {
    measure_inputs.push_back(
@@ -261,7 +261,7 @@ class EvolutionarySearchNode : public SearchStrategyNode {
  /*! \brief The counter of returning empty results. */
  int num_empty_iters;
  /*! \brief The design spaces. Decisions are not used, so traces only. */
-  Array<tir::Trace> design_spaces;
+  ffi::Array<tir::Trace> design_spaces;
  /*! \brief Per-thread data, including the module to be tuned and the random state. */
  std::vector<PerThreadData> per_thread_data_;
  /*!
@@ -277,7 +277,8 @@ class EvolutionarySearchNode : public SearchStrategyNode {
    Workload token_{nullptr};
    explicit State(EvolutionarySearchNode* self, int max_trials, int num_trials_per_iter,
-                  Array<tir::Schedule> design_space_schedules, Database database, CostModel cost_model)
+                  ffi::Array<tir::Schedule> design_space_schedules, Database database,
+                  CostModel cost_model)
        : self(self),
          max_trials(max_trials),
          num_trials_per_iter(num_trials_per_iter),
@@ -331,10 +332,10 @@ class EvolutionarySearchNode : public SearchStrategyNode {
    inline std::vector<Schedule> PickWithEpsGreedy(const std::vector<Schedule>& inits,
                                                   const std::vector<Schedule>& bests, int num);
    /*! \brief An interface method to be called by its counterpart in EvolutionarySearchNode */
-    inline Optional<Array<MeasureCandidate>> GenerateMeasureCandidates();
+    inline ffi::Optional<ffi::Array<MeasureCandidate>> GenerateMeasureCandidates();
    /*! \brief An interface method to be called by its counterpart in EvolutionarySearchNode */
-    inline void NotifyRunnerResults(const Array<MeasureCandidate>& measure_candidates,
-                                    const Array<RunnerResult>& results);
+    inline void NotifyRunnerResults(const ffi::Array<MeasureCandidate>& measure_candidates,
+                                    const ffi::Array<RunnerResult>& results);
    /*!
     * \brief Compute the hash for the given module.
     * \param mod The input TIR module.
@@ -346,9 +347,9 @@ class EvolutionarySearchNode : public SearchStrategyNode {
  /*! \brief The tuning context of the evolutionary search strategy. */
  const TuneContextNode* ctx_{nullptr};
  /*! \brief The postprocessors. */
-  Array<Postproc> postprocs_;
+  ffi::Array<Postproc> postprocs_;
  /*! \brief The mutators and their probability. */
-  Map<Mutator, FloatImm> mutator_probs_;
+  ffi::Map<Mutator, FloatImm> mutator_probs_;
  /*! \brief The random state. To be initialized with TuneContext. */
  TRandState rand_state_;
  /*! \brief The state of the search strategy.
*/ @@ -413,8 +414,9 @@ class EvolutionarySearchNode : public SearchStrategyNode { this->state_.reset(); } - void PreTuning(int max_trials, int num_trials_per_iter, const Array& design_spaces, - const Optional& database, const Optional& cost_model) final { + void PreTuning(int max_trials, int num_trials_per_iter, const ffi::Array& design_spaces, + const ffi::Optional& database, + const ffi::Optional& cost_model) final { ICHECK(!design_spaces.empty()); CHECK(this->ctx_ != nullptr) << "ValueError: Did you forget to initialize the TuneContext?"; CHECK(database.defined()) @@ -439,19 +441,19 @@ class EvolutionarySearchNode : public SearchStrategyNode { this->state_.reset(); } - Optional> GenerateMeasureCandidates() final { + ffi::Optional> GenerateMeasureCandidates() final { ICHECK(this->state_ != nullptr); return this->state_->GenerateMeasureCandidates(); } - void NotifyRunnerResults(const Array& measure_candidates, - const Array& results) final { + void NotifyRunnerResults(const ffi::Array& measure_candidates, + const ffi::Array& results) final { ICHECK(this->state_ != nullptr); this->state_->NotifyRunnerResults(measure_candidates, results); } SearchStrategy Clone() const final { - ObjectPtr n = make_object(); + ObjectPtr n = ffi::make_object(); n->population_size = this->population_size; n->num_empty_iters_before_early_stop = this->num_empty_iters_before_early_stop; n->init_measured_ratio = this->init_measured_ratio; @@ -472,7 +474,7 @@ std::vector EvolutionarySearchNode::State::PickBestFromDatabase(int nu auto _ = Profiler::TimedScope("EvoSearch/PickBestFromDatabase"); std::vector measured_traces; measured_traces.reserve(num); - Array top_records = this->database_->GetTopK(this->token_, num); + ffi::Array top_records = this->database_->GetTopK(this->token_, num); for (TuningRecord record : top_records) { measured_traces.push_back(record->trace); } @@ -487,7 +489,7 @@ std::vector EvolutionarySearchNode::State::PickBestFromDatabase(int nu tir::Trace trace = measured_traces.at(trace_id); Schedule& result = results.at(trace_id); ICHECK(!result.defined()); - if (Optional sch = pp.Apply(mod, trace, rand_state)) { + if (ffi::Optional sch = pp.Apply(mod, trace, rand_state)) { result = sch.value(); } else { LOG(FATAL) << "ValueError: Cannot postprocess the trace:\n" << trace; @@ -514,7 +516,7 @@ std::vector EvolutionarySearchNode::State::SampleInitPopulation(int nu ICHECK(!result.defined()); int design_space_index = tir::SampleInt(rand_state, 0, design_spaces.size()); tir::Trace trace(design_spaces[design_space_index]->insts, {}); - if (Optional sch = pp.Apply(mod, trace, rand_state)) { + if (ffi::Optional sch = pp.Apply(mod, trace, rand_state)) { result = sch.value(); } }; @@ -546,7 +548,7 @@ std::vector EvolutionarySearchNode::State::EvolveWithCostModel( for (int iter = 0;; ++iter) { // Predict normalized score with the cost model, std::vector scores = - PredictNormalizedScore(population, GetRef(self->ctx_), this->cost_model_); + PredictNormalizedScore(population, ffi::GetRef(self->ctx_), this->cost_model_); { auto _ = Profiler::TimedScope("EvoSearch/Evolve/Misc"); @@ -583,7 +585,7 @@ std::vector EvolutionarySearchNode::State::EvolveWithCostModel( TRandState* rand_state = &data.rand_state; const IRModule& mod = data.mod; std::function& trace_sampler = data.trace_sampler; - std::function()>& mutator_sampler = data.mutator_sampler; + std::function()>& mutator_sampler = data.mutator_sampler; Schedule& result = next_population.at(trace_id); int sampled_trace_id = -1; // Loop until success @@ -591,11 
+593,11 @@ std::vector EvolutionarySearchNode::State::EvolveWithCostModel( sampled_trace_id = trace_sampler(); sampled_trace_id = sampled_trace_id % self->population_size; tir::Trace trace = population.at(sampled_trace_id)->trace().value(); - if (Optional opt_mutator = mutator_sampler()) { + if (ffi::Optional opt_mutator = mutator_sampler()) { // Decision: mutate Mutator mutator = opt_mutator.value(); - if (Optional new_trace = mutator->Apply(trace, rand_state)) { - if (Optional sch = pp.Apply(mod, new_trace.value(), rand_state)) { + if (ffi::Optional new_trace = mutator->Apply(trace, rand_state)) { + if (ffi::Optional sch = pp.Apply(mod, new_trace.value(), rand_state)) { // note that sch's trace is different from new_trace // because it contains post-processing information result = sch.value(); @@ -694,7 +696,8 @@ std::vector EvolutionarySearchNode::State::PickWithEpsGreedy( return results; } -Optional> EvolutionarySearchNode::State::GenerateMeasureCandidates() { +ffi::Optional> +EvolutionarySearchNode::State::GenerateMeasureCandidates() { if (st >= max_trials) { return std::nullopt; } @@ -737,7 +740,8 @@ Optional> EvolutionarySearchNode::State::GenerateMeasure } void EvolutionarySearchNode::State::NotifyRunnerResults( - const Array& measure_candidates, const Array& results) { + const ffi::Array& measure_candidates, + const ffi::Array& results) { st += results.size(); ed += results.size(); } @@ -757,7 +761,7 @@ SearchStrategy SearchStrategy::EvolutionarySearch(int population_size, / TVM_META_SCHEDULE_CHECK_PROB_RANGE(init_measured_ratio, "Initial measured ratio"); TVM_META_SCHEDULE_CHECK_PROB_RANGE(genetic_mutate_prob, "Mutation probability"); TVM_META_SCHEDULE_CHECK_PROB_RANGE(eps_greedy, "Greedy pick probability"); - ObjectPtr n = make_object(); + ObjectPtr n = ffi::make_object(); n->population_size = population_size; n->num_empty_iters_before_early_stop = 5; n->init_measured_ratio = init_measured_ratio; @@ -776,14 +780,15 @@ class EvolutionarySearch : public SearchStrategy { EvolutionarySearchNode); }; -Array EvolutionarySearchSampleInitPopulation(EvolutionarySearch self, int num) { +ffi::Array EvolutionarySearchSampleInitPopulation(EvolutionarySearch self, int num) { std::vector results = self->state_->SampleInitPopulation(num); - return Array(results.begin(), results.end()); + return ffi::Array(results.begin(), results.end()); } -Array EvolutionarySearchEvolveWithCostModel(EvolutionarySearch self, - Array population, int num) { - Array result; +ffi::Array EvolutionarySearchEvolveWithCostModel(EvolutionarySearch self, + ffi::Array population, + int num) { + ffi::Array result; std::vector population_vec = std::vector(population.begin(), population.end()); std::vector schs = self->state_->EvolveWithCostModel(population_vec, num); diff --git a/src/meta_schedule/search_strategy/replay_func.cc b/src/meta_schedule/search_strategy/replay_func.cc index c9a219777053..d9233e307443 100644 --- a/src/meta_schedule/search_strategy/replay_func.cc +++ b/src/meta_schedule/search_strategy/replay_func.cc @@ -49,16 +49,16 @@ class ReplayFuncNode : public SearchStrategyNode { << "ValueError: The search strategy has not been initialized."; } - inline Optional> GenerateMeasureCandidates(); - inline void NotifyRunnerResults(const Array& results); + inline ffi::Optional> GenerateMeasureCandidates(); + inline void NotifyRunnerResults(const ffi::Array& results); }; /*! \brief The random state. -1 means using random number. */ TRandState rand_state_ = -1; /*! \brief The IRModule to be scheduled from TuneContext. 
*/ - Optional mod_ = std::nullopt; + ffi::Optional mod_ = std::nullopt; /*! \brief The space generator from TuneContext. */ - Optional space_generator_ = std::nullopt; + ffi::Optional space_generator_ = std::nullopt; /*! \brief The state of the search strategy. */ std::unique_ptr state_ = nullptr; @@ -85,8 +85,10 @@ class ReplayFuncNode : public SearchStrategyNode { this->state_.reset(); } - void PreTuning(int max_trials, int num_trials_per_iter, const Array& design_spaces, - const Optional& database, const Optional& cost_model) final { + void PreTuning(int max_trials, int num_trials_per_iter, + const ffi::Array& design_spaces, + const ffi::Optional& database, + const ffi::Optional& cost_model) final { CHECK(this->state_ == nullptr) << "ValueError: `PreTuning` is already invoked without corresponding `PostTuning`."; this->state_ = std::make_unique(this, max_trials, num_trials_per_iter); @@ -98,19 +100,19 @@ class ReplayFuncNode : public SearchStrategyNode { this->state_.reset(); } - Optional> GenerateMeasureCandidates() final { + ffi::Optional> GenerateMeasureCandidates() final { ICHECK(this->state_ != nullptr); return this->state_->GenerateMeasureCandidates(); } - void NotifyRunnerResults(const Array& measure_candidates, - const Array& results) final { + void NotifyRunnerResults(const ffi::Array& measure_candidates, + const ffi::Array& results) final { ICHECK(this->state_ != nullptr); this->state_->NotifyRunnerResults(results); } SearchStrategy Clone() const final { - ObjectPtr n = make_object(); + ObjectPtr n = ffi::make_object(); n->rand_state_ = -1; n->mod_ = std::nullopt; n->space_generator_ = std::nullopt; @@ -119,17 +121,18 @@ class ReplayFuncNode : public SearchStrategyNode { } }; -inline Optional> ReplayFuncNode::State::GenerateMeasureCandidates() { +inline ffi::Optional> +ReplayFuncNode::State::GenerateMeasureCandidates() { if (st >= max_trials) { return std::nullopt; } ed = std::min(ed, max_trials); - Array result; + ffi::Array result; IRModule mod = self->mod_.value(); - Array postprocs = self->space_generator_.value()->postprocs.value_or({}); + ffi::Array postprocs = self->space_generator_.value()->postprocs.value_or({}); for (int i = st; i < ed; i++) { for (;;) { - Array schs = self->space_generator_.value()->GenerateDesignSpace(mod); + ffi::Array schs = self->space_generator_.value()->GenerateDesignSpace(mod); int design_space_index = tir::SampleInt(&self->rand_state_, 0, schs.size()); tir::Schedule sch = schs[design_space_index]; sch->EnterPostproc(); @@ -141,7 +144,7 @@ inline Optional> ReplayFuncNode::State::GenerateMeasureC } } if (!failed) { - Array args_info = ArgInfo::FromEntryFunc(sch->mod(), /*remove_preproc=*/true); + ffi::Array args_info = ArgInfo::FromEntryFunc(sch->mod(), /*remove_preproc=*/true); result.push_back(MeasureCandidate(sch, args_info)); break; } @@ -150,13 +153,13 @@ inline Optional> ReplayFuncNode::State::GenerateMeasureC return result; } -inline void ReplayFuncNode::State::NotifyRunnerResults(const Array& results) { +inline void ReplayFuncNode::State::NotifyRunnerResults(const ffi::Array& results) { st += num_trials_per_iter; ed += num_trials_per_iter; } SearchStrategy SearchStrategy::ReplayFunc() { - ObjectPtr n = make_object(); + ObjectPtr n = ffi::make_object(); return SearchStrategy(n); } diff --git a/src/meta_schedule/search_strategy/replay_trace.cc b/src/meta_schedule/search_strategy/replay_trace.cc index 151d502ec078..33e43e3574b6 100644 --- a/src/meta_schedule/search_strategy/replay_trace.cc +++ 
b/src/meta_schedule/search_strategy/replay_trace.cc @@ -31,7 +31,7 @@ class ReplayTraceNode : public SearchStrategyNode { /*! \brief The search strategy itself */ ReplayTraceNode* self; /*! \brief The design spaces. */ - Array design_spaces; + ffi::Array design_spaces; /*! \brief The number of total trials. */ int max_trials; /*! \brief The number of trials per iteration. */ @@ -42,9 +42,9 @@ class ReplayTraceNode : public SearchStrategyNode { int ed; /*! \brief The module to be tuned. */ - Array per_thread_mod_{nullptr}; + ffi::Array per_thread_mod_{nullptr}; - explicit State(ReplayTraceNode* self, Array design_spaces, int max_trials, + explicit State(ReplayTraceNode* self, ffi::Array design_spaces, int max_trials, int num_trials_per_iter) : self(self), design_spaces(design_spaces), @@ -59,8 +59,8 @@ class ReplayTraceNode : public SearchStrategyNode { } } - inline Optional> GenerateMeasureCandidates(); - inline void NotifyRunnerResults(const Array& results); + inline ffi::Optional> GenerateMeasureCandidates(); + inline void NotifyRunnerResults(const ffi::Array& results); }; /*! \brief The max number of failures during trace replaying. */ @@ -69,11 +69,11 @@ class ReplayTraceNode : public SearchStrategyNode { /*! \brief The random state. -1 means using random number. */ TRandState rand_state_ = -1; /*! \brief The IRModule to be scheduled from TuneContext. */ - Optional mod_ = std::nullopt; + ffi::Optional mod_ = std::nullopt; /*! \brief The number of threads to be used. */ int num_threads_ = -1; /*! \brief The postprocessors. */ - Array postprocs_ = {}; + ffi::Array postprocs_ = {}; /*! \brief The state of the search strategy. */ std::unique_ptr state_ = nullptr; @@ -102,12 +102,14 @@ class ReplayTraceNode : public SearchStrategyNode { this->state_.reset(); } - void PreTuning(int max_trials, int num_trials_per_iter, const Array& design_spaces, - const Optional& database, const Optional& cost_model) final { + void PreTuning(int max_trials, int num_trials_per_iter, + const ffi::Array& design_spaces, + const ffi::Optional& database, + const ffi::Optional& cost_model) final { ICHECK(!design_spaces.empty()); CHECK(this->state_ == nullptr) << "ValueError: `PreTuning` is already invoked without corresponding `PostTuning`."; - Array design_space_traces; + ffi::Array design_space_traces; design_space_traces.reserve(design_spaces.size()); for (const tir::Schedule& space : design_spaces) { design_space_traces.push_back(space->trace().value()->Simplified(true)); @@ -121,19 +123,19 @@ class ReplayTraceNode : public SearchStrategyNode { this->state_.reset(); } - Optional> GenerateMeasureCandidates() final { + ffi::Optional> GenerateMeasureCandidates() final { ICHECK(this->state_ != nullptr); return this->state_->GenerateMeasureCandidates(); } - void NotifyRunnerResults(const Array& measure_candidates, - const Array& results) final { + void NotifyRunnerResults(const ffi::Array& measure_candidates, + const ffi::Array& results) final { ICHECK(this->state_ != nullptr); this->state_->NotifyRunnerResults(results); } SearchStrategy Clone() const final { - ObjectPtr n = make_object(); + ObjectPtr n = ffi::make_object(); n->max_fail_count = this->max_fail_count; n->rand_state_ = this->rand_state_; n->state_ = nullptr; // cleared the state @@ -141,14 +143,15 @@ class ReplayTraceNode : public SearchStrategyNode { } }; -inline Optional> ReplayTraceNode::State::GenerateMeasureCandidates() { +inline ffi::Optional> +ReplayTraceNode::State::GenerateMeasureCandidates() { if (st >= max_trials) { return std::nullopt; } ed 
= std::min(ed, max_trials); ICHECK_LT(st, ed); std::vector per_thread_rand_state = ForkSeed(&self->rand_state_, self->num_threads_); - Array> per_task_result(ed - st, std::nullopt); + ffi::Array> per_task_result(ed - st, std::nullopt); ThreadedTraceApply pp(self->postprocs_); auto f_worker = [this, &per_thread_rand_state, &per_task_result, &pp](int thread_id, int task_id) -> void { @@ -159,31 +162,31 @@ inline Optional> ReplayTraceNode::State::GenerateMeasure int design_space_index = tir::SampleInt(&rand_state, 0, design_spaces.size()); tir::Trace trace = design_spaces[design_space_index]; tir::Trace new_trace = tir::Trace(trace->insts, {}); - if (Optional opt_sch = pp.Apply(mod, new_trace, &rand_state)) { + if (ffi::Optional opt_sch = pp.Apply(mod, new_trace, &rand_state)) { tir::Schedule sch = opt_sch.value(); - Array args_info = ArgInfo::FromEntryFunc(sch->mod(), /*remove_preproc=*/true); + ffi::Array args_info = ArgInfo::FromEntryFunc(sch->mod(), /*remove_preproc=*/true); per_task_result.Set(task_id, MeasureCandidate(sch, args_info)); break; } } }; support::parallel_for_dynamic(0, ed - st, self->num_threads_, f_worker); - Array filtered; + ffi::Array filtered; filtered.reserve(ed - st); - for (Optional result : per_task_result) + for (ffi::Optional result : per_task_result) if (result.has_value()) { filtered.push_back(*std::move(result)); } return filtered; } -inline void ReplayTraceNode::State::NotifyRunnerResults(const Array& results) { +inline void ReplayTraceNode::State::NotifyRunnerResults(const ffi::Array& results) { st += num_trials_per_iter; ed += num_trials_per_iter; } SearchStrategy SearchStrategy::ReplayTrace(int max_fail_count) { - ObjectPtr n = make_object(); + ObjectPtr n = ffi::make_object(); n->max_fail_count = max_fail_count; return SearchStrategy(n); } diff --git a/src/meta_schedule/search_strategy/search_strategy.cc b/src/meta_schedule/search_strategy/search_strategy.cc index 66d063b2dcba..3d0941c3632f 100644 --- a/src/meta_schedule/search_strategy/search_strategy.cc +++ b/src/meta_schedule/search_strategy/search_strategy.cc @@ -23,8 +23,8 @@ namespace tvm { namespace meta_schedule { -MeasureCandidate::MeasureCandidate(tir::Schedule sch, Array args_info) { - ObjectPtr n = make_object(); +MeasureCandidate::MeasureCandidate(tir::Schedule sch, ffi::Array args_info) { + ObjectPtr n = ffi::make_object(); n->sch = sch; n->args_info = args_info; data_ = std::move(n); @@ -37,9 +37,9 @@ void PySearchStrategyNode::InitializeWithTuneContext(const TuneContext& context) } void PySearchStrategyNode::PreTuning(int max_trials, int num_trials_per_iter, - const Array& design_spaces, - const Optional& database, - const Optional& cost_model) { + const ffi::Array& design_spaces, + const ffi::Optional& database, + const ffi::Optional& cost_model) { ICHECK(f_pre_tuning != nullptr) << "PySearchStrategy's PreTuning method not implemented!"; f_pre_tuning(max_trials, num_trials_per_iter, design_spaces, database, cost_model); } @@ -49,14 +49,15 @@ void PySearchStrategyNode::PostTuning() { f_post_tuning(); } -Optional> PySearchStrategyNode::GenerateMeasureCandidates() { +ffi::Optional> PySearchStrategyNode::GenerateMeasureCandidates() { ICHECK(f_generate_measure_candidates != nullptr) << "PySearchStrategy's GenerateMeasureCandidates method not implemented!"; return f_generate_measure_candidates(); } -void PySearchStrategyNode::NotifyRunnerResults(const Array& measure_candidates, - const Array& results) { +void PySearchStrategyNode::NotifyRunnerResults( + const ffi::Array& measure_candidates, + 
const ffi::Array& results) { ICHECK(f_notify_runner_results != nullptr) << "PySearchStrategy's NotifyRunnerResults method not implemented!"; f_notify_runner_results(measure_candidates, results); @@ -74,7 +75,7 @@ SearchStrategy SearchStrategy::PySearchStrategy( PySearchStrategyNode::FGenerateMeasureCandidates f_generate_measure_candidates, // PySearchStrategyNode::FNotifyRunnerResults f_notify_runner_results, // PySearchStrategyNode::FClone f_clone) { - ObjectPtr n = make_object(); + ObjectPtr n = ffi::make_object(); n->f_initialize_with_tune_context = f_initialize_with_tune_context; n->f_pre_tuning = f_pre_tuning; n->f_post_tuning = f_post_tuning; @@ -93,7 +94,7 @@ TVM_FFI_STATIC_INIT_BLOCK({ namespace refl = tvm::ffi::reflection; refl::GlobalDef() .def("meta_schedule.MeasureCandidate", - [](tir::Schedule sch, Optional> args_info) -> MeasureCandidate { + [](tir::Schedule sch, ffi::Optional> args_info) -> MeasureCandidate { return MeasureCandidate(sch, args_info.value_or({})); }) .def("meta_schedule.SearchStrategyPySearchStrategy", SearchStrategy::PySearchStrategy) diff --git a/src/meta_schedule/space_generator/post_order_apply.cc b/src/meta_schedule/space_generator/post_order_apply.cc index 86f21f43e817..1c41b1f96522 100644 --- a/src/meta_schedule/space_generator/post_order_apply.cc +++ b/src/meta_schedule/space_generator/post_order_apply.cc @@ -45,8 +45,8 @@ class PostOrderApplyNode : public SpaceGeneratorNode { this->rand_state_ = ForkSeed(&context->rand_state); } - Array GenerateDesignSpace(const IRModule& mod) final { - using ScheduleAndUnvisitedBlocks = std::pair>; + ffi::Array GenerateDesignSpace(const IRModule& mod) final { + using ScheduleAndUnvisitedBlocks = std::pair>; CHECK(sch_rules.defined()) << "ValueError: `sch_rules` is not set in PostOrderApply"; tir::Schedule sch = tir::Schedule::Traced( /*mod=*/mod, @@ -55,8 +55,8 @@ class PostOrderApplyNode : public SpaceGeneratorNode { /*error_render_level=*/tir::ScheduleErrorRenderLevel::kDetail); std::vector stack; - Array result{sch}; - Array all_blocks = BlockCollector::Collect(sch, f_block_filter_); + ffi::Array result{sch}; + ffi::Array all_blocks = BlockCollector::Collect(sch, f_block_filter_); for (ScheduleRule sch_rule : sch_rules.value()) { for (const tir::Schedule& sch : result) { @@ -80,12 +80,12 @@ class PostOrderApplyNode : public SpaceGeneratorNode { continue; } if (!ScheduleRule::IsApplyCustomRule(sch_rule)) { - if (tir::GetAnn(sch->GetSRef(block_rv), "schedule_rule").has_value()) { + if (tir::GetAnn(sch->GetSRef(block_rv), "schedule_rule").has_value()) { stack.emplace_back(sch, blocks); continue; } } - Array applied = sch_rule->Apply(sch, /*block=*/block_rv); + ffi::Array applied = sch_rule->Apply(sch, /*block=*/block_rv); for (const tir::Schedule& sch : applied) { stack.emplace_back(sch, blocks); } @@ -95,7 +95,7 @@ class PostOrderApplyNode : public SpaceGeneratorNode { } SpaceGenerator Clone() const final { - ObjectPtr n = make_object(*this); + ObjectPtr n = ffi::make_object(*this); CloneRules(this, n.get()); return SpaceGenerator(n); } @@ -103,11 +103,11 @@ class PostOrderApplyNode : public SpaceGeneratorNode { TVM_DECLARE_FINAL_OBJECT_INFO(PostOrderApplyNode, SpaceGeneratorNode); }; -SpaceGenerator SpaceGenerator::PostOrderApply(ffi::Function f_block_filter, - Optional> sch_rules, - Optional> postprocs, - Optional> mutator_probs) { - ObjectPtr n = make_object(); +SpaceGenerator SpaceGenerator::PostOrderApply( + ffi::Function f_block_filter, ffi::Optional> sch_rules, + ffi::Optional> postprocs, + ffi::Optional> 
mutator_probs) { + ObjectPtr n = ffi::make_object(); n->sch_rules = std::move(sch_rules); n->postprocs = std::move(postprocs); n->mutator_probs = std::move(mutator_probs); diff --git a/src/meta_schedule/space_generator/schedule_fn.cc b/src/meta_schedule/space_generator/schedule_fn.cc index 1112aca88762..537551ba7436 100644 --- a/src/meta_schedule/space_generator/schedule_fn.cc +++ b/src/meta_schedule/space_generator/schedule_fn.cc @@ -40,7 +40,7 @@ class ScheduleFnNode : public SpaceGeneratorNode { this->rand_state_ = ForkSeed(&context->rand_state); } - Array GenerateDesignSpace(const IRModule& mod) final { + ffi::Array GenerateDesignSpace(const IRModule& mod) final { tir::Schedule sch = tir::Schedule::Traced( /*mod=*/mod, /*rand_state=*/ForkSeed(&this->rand_state_), @@ -56,7 +56,7 @@ class ScheduleFnNode : public SpaceGeneratorNode { return {sch.value()}; } if (const auto* arr = obj.as()) { - Array result; + ffi::Array result; result.reserve(arr->size()); for (Any val : *arr) { if (auto sch = val.as()) { @@ -76,7 +76,7 @@ class ScheduleFnNode : public SpaceGeneratorNode { } SpaceGenerator Clone() const final { - ObjectPtr n = make_object(*this); + ObjectPtr n = ffi::make_object(*this); CloneRules(this, n.get()); return SpaceGenerator(n); } @@ -85,11 +85,11 @@ class ScheduleFnNode : public SpaceGeneratorNode { TVM_DECLARE_FINAL_OBJECT_INFO(ScheduleFnNode, SpaceGeneratorNode); }; -SpaceGenerator SpaceGenerator::ScheduleFn(ffi::Function schedule_fn, - Optional> sch_rules, - Optional> postprocs, - Optional> mutator_probs) { - ObjectPtr n = make_object(); +SpaceGenerator SpaceGenerator::ScheduleFn( + ffi::Function schedule_fn, ffi::Optional> sch_rules, + ffi::Optional> postprocs, + ffi::Optional> mutator_probs) { + ObjectPtr n = ffi::make_object(); n->sch_rules = std::move(sch_rules); n->postprocs = std::move(postprocs); n->mutator_probs = std::move(mutator_probs); diff --git a/src/meta_schedule/space_generator/space_generator.cc b/src/meta_schedule/space_generator/space_generator.cc index 20d2d3626843..e6f01fa51760 100644 --- a/src/meta_schedule/space_generator/space_generator.cc +++ b/src/meta_schedule/space_generator/space_generator.cc @@ -24,7 +24,7 @@ namespace tvm { namespace meta_schedule { -String GetRuleKindFromTarget(const Target& target) { +ffi::String GetRuleKindFromTarget(const Target& target) { if (target->kind->name == "llvm") { static auto target_has_feature_fn_ptr = tvm::ffi::Function::GetGlobalRequired("target.target_has_feature"); @@ -59,7 +59,7 @@ String GetRuleKindFromTarget(const Target& target) { return "hexagon"; } if (target->kind->name == "cuda") { - if (Optional opt_sm = target->GetAttr("arch")) { + if (ffi::Optional opt_sm = target->GetAttr("arch")) { std::string sm = opt_sm.value(); if (support::StartsWith(sm, "sm_")) { sm = sm.substr(3); @@ -92,10 +92,10 @@ void SpaceGeneratorNode::InitializeWithTuneContext(const TuneContext& context) { !(sch_rules.defined() && // postprocs.defined() && // mutator_probs.defined())) { - String kind = GetRuleKindFromTarget(context->target.value()); - Array default_sch_rules; - Array default_postprocs; - Map default_mutator_probs; + ffi::String kind = GetRuleKindFromTarget(context->target.value()); + ffi::Array default_sch_rules; + ffi::Array default_postprocs; + ffi::Map default_mutator_probs; // for target with skylake-avx512 if (kind == "llvm") { default_sch_rules = ScheduleRule::DefaultLLVM(); @@ -174,7 +174,7 @@ void PySpaceGeneratorNode::InitializeWithTuneContext(const TuneContext& context) 
f_initialize_with_tune_context(context); } -Array PySpaceGeneratorNode::GenerateDesignSpace(const IRModule& mod) { +ffi::Array PySpaceGeneratorNode::GenerateDesignSpace(const IRModule& mod) { ICHECK(f_generate_design_space != nullptr) << "PySpaceGenerator's GenerateDesignSpace method not implemented!"; return f_generate_design_space(mod); @@ -186,11 +186,12 @@ SpaceGenerator PySpaceGeneratorNode::Clone() const { } SpaceGenerator SpaceGenerator::PySpaceGenerator( - Optional> sch_rules, Optional> postprocs, - Optional> mutator_probs, + ffi::Optional> sch_rules, + ffi::Optional> postprocs, + ffi::Optional> mutator_probs, FInitializeWithTuneContext f_initialize_with_tune_context, FGenerateDesignSpace f_generate_design_space, FClone f_clone) { - ObjectPtr n = make_object(); + ObjectPtr n = ffi::make_object(); n->sch_rules = sch_rules; n->postprocs = postprocs; n->mutator_probs = mutator_probs; diff --git a/src/meta_schedule/space_generator/space_generator_union.cc b/src/meta_schedule/space_generator/space_generator_union.cc index f9a8c2e71c8b..4151265b2718 100644 --- a/src/meta_schedule/space_generator/space_generator_union.cc +++ b/src/meta_schedule/space_generator/space_generator_union.cc @@ -27,7 +27,7 @@ namespace meta_schedule { class SpaceGeneratorUnionNode : public SpaceGeneratorNode { public: /*! \brief The array of design space generators unioned, could be recursive. */ - Array space_generators; + ffi::Array space_generators; static void RegisterReflection() { namespace refl = tvm::ffi::reflection; @@ -42,11 +42,11 @@ class SpaceGeneratorUnionNode : public SpaceGeneratorNode { } } - Array GenerateDesignSpace(const IRModule& mod) final { - Array design_spaces; + ffi::Array GenerateDesignSpace(const IRModule& mod) final { + ffi::Array design_spaces; for (const SpaceGenerator& space_generator : space_generators) { // Generate partial design spaces from each design space generator. - Array partial = space_generator->GenerateDesignSpace(mod); + ffi::Array partial = space_generator->GenerateDesignSpace(mod); // Merge the partial design spaces. design_spaces.insert(design_spaces.end(), partial.begin(), partial.end()); } @@ -54,8 +54,8 @@ class SpaceGeneratorUnionNode : public SpaceGeneratorNode { } SpaceGenerator Clone() const final { - ObjectPtr n = make_object(*this); - n->space_generators = Array(); + ObjectPtr n = ffi::make_object(*this); + n->space_generators = ffi::Array(); for (const SpaceGenerator& space_generator : this->space_generators) { n->space_generators.push_back(space_generator->Clone()); } @@ -72,11 +72,11 @@ class SpaceGeneratorUnionNode : public SpaceGeneratorNode { * \param space_generators Array of the design space generators to be unioned. * \return The design space generator created. 
*/ -SpaceGenerator SpaceGenerator::SpaceGeneratorUnion(Array space_generators, - Optional> sch_rules, - Optional> postprocs, - Optional> mutator_probs) { - ObjectPtr n = make_object(); +SpaceGenerator SpaceGenerator::SpaceGeneratorUnion( + ffi::Array space_generators, ffi::Optional> sch_rules, + ffi::Optional> postprocs, + ffi::Optional> mutator_probs) { + ObjectPtr n = ffi::make_object(); n->sch_rules = std::move(sch_rules); n->postprocs = std::move(postprocs); n->mutator_probs = std::move(mutator_probs); diff --git a/src/meta_schedule/task_scheduler/gradient_based.cc b/src/meta_schedule/task_scheduler/gradient_based.cc index a19754b49ccd..3ec066e7e882 100644 --- a/src/meta_schedule/task_scheduler/gradient_based.cc +++ b/src/meta_schedule/task_scheduler/gradient_based.cc @@ -44,10 +44,10 @@ class GradientBasedNode final : public TaskSchedulerNode { TVM_DECLARE_FINAL_OBJECT_INFO(GradientBasedNode, TaskSchedulerNode); public: - void Tune(Array tasks, Array task_weights, int max_trials_global, + void Tune(ffi::Array tasks, ffi::Array task_weights, int max_trials_global, int max_trials_per_task, int num_trials_per_iter, Builder builder, Runner runner, - Array measure_callbacks, Optional database, - Optional cost_model) final { + ffi::Array measure_callbacks, ffi::Optional database, + ffi::Optional cost_model) final { int n_tasks = tasks.size(); round_robin_rounds_ = 0; best_latency_history_.resize(n_tasks, std::vector()); @@ -122,8 +122,8 @@ class GradientBasedNode final : public TaskSchedulerNode { return task_id; } - Array JoinRunningTask(int task_id) final { - Array results = TaskSchedulerNode::JoinRunningTask(task_id); + ffi::Array JoinRunningTask(int task_id) final { + ffi::Array results = TaskSchedulerNode::JoinRunningTask(task_id); TaskRecordNode* task = this->tasks_[task_id].get(); if (task->latency_ms.size() > 0) { this->best_latency_history_.at(task_id).push_back( @@ -136,7 +136,7 @@ class GradientBasedNode final : public TaskSchedulerNode { TaskScheduler TaskScheduler::GradientBased(ffi::Function logger, double alpha, int window_size, support::LinearCongruentialEngine::TRandState seed) { - ObjectPtr n = make_object(); + ObjectPtr n = ffi::make_object(); n->logger = logger; n->alpha = alpha; n->window_size = window_size; diff --git a/src/meta_schedule/task_scheduler/round_robin.cc b/src/meta_schedule/task_scheduler/round_robin.cc index 9bb5a20188ec..cc45ded7f40b 100644 --- a/src/meta_schedule/task_scheduler/round_robin.cc +++ b/src/meta_schedule/task_scheduler/round_robin.cc @@ -58,7 +58,7 @@ class RoundRobinNode final : public TaskSchedulerNode { }; TaskScheduler TaskScheduler::RoundRobin(ffi::Function logger) { - ObjectPtr n = make_object(); + ObjectPtr n = ffi::make_object(); n->logger = logger; n->task_id = -1; return TaskScheduler(n); diff --git a/src/meta_schedule/task_scheduler/task_scheduler.cc b/src/meta_schedule/task_scheduler/task_scheduler.cc index 21827ba8ad03..cc337d99a3a4 100644 --- a/src/meta_schedule/task_scheduler/task_scheduler.cc +++ b/src/meta_schedule/task_scheduler/task_scheduler.cc @@ -48,9 +48,9 @@ TaskRecord::TaskRecord(TuneContext ctx, double task_weight) { void SendToBuilder(TaskRecordNode* self, const Builder& builder) { auto _ = Profiler::TimedScope("SendToBuilder"); - Array candidates = self->measure_candidates.value(); + ffi::Array candidates = self->measure_candidates.value(); Target target = self->ctx->target.value(); - Array inputs; + ffi::Array inputs; inputs.reserve(candidates.size()); for (const MeasureCandidate& candidate : candidates) { 
inputs.push_back(BuilderInput(candidate->sch->mod(), target)); @@ -60,13 +60,13 @@ void SendToBuilder(TaskRecordNode* self, const Builder& builder) { void SendToRunner(TaskRecordNode* self, const Runner& runner) { auto _ = Profiler::TimedScope("SendToRunner"); - Array candidates = self->measure_candidates.value(); - Array builder_results = self->builder_results.value(); + ffi::Array candidates = self->measure_candidates.value(); + ffi::Array builder_results = self->builder_results.value(); Target target = self->ctx->target.value(); ICHECK_EQ(candidates.size(), builder_results.size()); int n = candidates.size(); int n_build_errors = 0; - Array inputs; + ffi::Array inputs; inputs.reserve(n); for (int i = 0; i < n; ++i) { const MeasureCandidate& candidate = candidates[i]; @@ -79,12 +79,12 @@ void SendToRunner(TaskRecordNode* self, const Runner& runner) { /*device_type=*/target->kind->name, /*args_info=*/candidate->args_info)); } - Array futures = runner->Run(inputs); + ffi::Array futures = runner->Run(inputs); if (n_build_errors == 0) { self->runner_futures = futures; return; } - Array results; + ffi::Array results; results.reserve(n); for (int i = 0, j = 0; i < n; ++i) { const BuilderResult& builder_result = builder_results[i]; @@ -102,7 +102,7 @@ void SendToRunner(TaskRecordNode* self, const Runner& runner) { self->runner_futures = results; } -void TaskCleanUp(TaskRecordNode* self, int task_id, const Array& results) { +void TaskCleanUp(TaskRecordNode* self, int task_id, const ffi::Array& results) { ICHECK_EQ(self->builder_results.value().size(), results.size()); ICHECK_EQ(self->runner_futures.value().size(), results.size()); int n = results.size(); @@ -112,7 +112,7 @@ void TaskCleanUp(TaskRecordNode* self, int task_id, const Array& r const BuilderResult& builder_result = self->builder_results.value()[i]; const MeasureCandidate& candidate = self->measure_candidates.value()[i]; const RunnerResult& runner_result = results[i]; - Optional error_msg = std::nullopt; + ffi::Optional error_msg = std::nullopt; int trials = self->latency_ms.size() + 1; double run_ms = 1e9; if ((error_msg = builder_result->error_msg)) { @@ -148,11 +148,12 @@ void TaskCleanUp(TaskRecordNode* self, int task_id, const Array& r self->runner_futures = std::nullopt; } -void TaskSchedulerNode::Tune(Array ctxs, Array task_weights, +void TaskSchedulerNode::Tune(ffi::Array ctxs, ffi::Array task_weights, int max_trials_global, int max_trials_per_task, int num_trials_per_iter, Builder builder, Runner runner, - Array measure_callbacks, Optional database, - Optional cost_model) { + ffi::Array measure_callbacks, + ffi::Optional database, + ffi::Optional cost_model) { CHECK_EQ(ctxs.size(), task_weights.size()) << "ValueError: `task_weights` must have the same " "length as `ctxs`"; int n_tasks = this->remaining_tasks_ = ctxs.size(); @@ -167,7 +168,7 @@ void TaskSchedulerNode::Tune(Array ctxs, Array task_weigh TVM_PY_LOG(INFO, this->logger) << "Initializing Task #" << i << ": " << ctx->task_name; TVM_PY_LOG(INFO, ctx->logger) << "Initializing Task #" << i << ": " << ctx->task_name; this->tasks_.push_back(TaskRecord(ctx, weight)); - Array design_spaces = + ffi::Array design_spaces = ctx->space_generator.value()->GenerateDesignSpace(ctx->mod.value()); TVM_PY_LOG(INFO, ctx->logger) << "Total " << design_spaces.size() << " design space(s) generated"; @@ -194,7 +195,7 @@ void TaskSchedulerNode::Tune(Array ctxs, Array task_weigh TerminateTask(task_id); continue; } - if (Optional> candidates = task->measure_candidates = + if (ffi::Optional> 
candidates = task->measure_candidates = task->ctx->search_strategy.value()->GenerateMeasureCandidates()) { int num_candidates = candidates.value().size(); num_trials_already += num_candidates; @@ -218,13 +219,13 @@ void TaskSchedulerNode::Tune(Array ctxs, Array task_weigh } } -Array TaskSchedulerNode::JoinRunningTask(int task_id) { +ffi::Array TaskSchedulerNode::JoinRunningTask(int task_id) { TaskRecordNode* task = this->tasks_[task_id].get(); ICHECK(task->runner_futures.defined()); - Array results; + ffi::Array results; { auto _ = Profiler::TimedScope("JoinRunnerFutures"); - Array futures = task->runner_futures.value(); + ffi::Array futures = task->runner_futures.value(); results.reserve(futures.size()); for (RunnerFuture future : futures) { results.push_back(future->Result()); @@ -237,7 +238,7 @@ Array TaskSchedulerNode::JoinRunningTask(int task_id) { ICHECK_EQ(results.size(), task->measure_candidates.value().size()); ICHECK_EQ(results.size(), task->builder_results.value().size()); for (const MeasureCallback& callback : this->measure_callbacks_) { - callback->Apply(GetRef(this), task_id, task->measure_candidates.value(), + callback->Apply(ffi::GetRef(this), task_id, task->measure_candidates.value(), task->builder_results.value(), results); } TaskCleanUp(task, task_id, results); @@ -333,7 +334,7 @@ TaskScheduler TaskScheduler::PyTaskScheduler( ffi::Function logger, PyTaskSchedulerNode::FNextTaskId f_next_task_id, PyTaskSchedulerNode::FJoinRunningTask f_join_running_task, PyTaskSchedulerNode::FTune f_tune) { CHECK(f_next_task_id != nullptr) << "ValueError: next_task_id is not defined"; - ObjectPtr n = make_object(); + ObjectPtr n = ffi::make_object(); n->logger = logger; n->f_next_task_id = f_next_task_id; n->f_join_running_task = f_join_running_task; @@ -346,7 +347,7 @@ int PyTaskSchedulerNode::NextTaskId() { return f_next_task_id(); } -Array PyTaskSchedulerNode::JoinRunningTask(int task_id) { +ffi::Array PyTaskSchedulerNode::JoinRunningTask(int task_id) { if (f_join_running_task == nullptr) { return TaskSchedulerNode::JoinRunningTask(task_id); } else { @@ -354,11 +355,12 @@ Array PyTaskSchedulerNode::JoinRunningTask(int task_id) { } } -void PyTaskSchedulerNode::Tune(Array tasks, Array task_weights, +void PyTaskSchedulerNode::Tune(ffi::Array tasks, ffi::Array task_weights, int max_trials_global, int max_trials_per_task, int num_trials_per_iter, Builder builder, Runner runner, - Array measure_callbacks, - Optional database, Optional cost_model) { + ffi::Array measure_callbacks, + ffi::Optional database, + ffi::Optional cost_model) { if (f_tune == nullptr) { TaskSchedulerNode::Tune(tasks, task_weights, max_trials_global, max_trials_per_task, num_trials_per_iter, builder, runner, measure_callbacks, database, diff --git a/src/meta_schedule/trace_apply.cc b/src/meta_schedule/trace_apply.cc index 114afc0ad72e..d9096e4b9c3d 100644 --- a/src/meta_schedule/trace_apply.cc +++ b/src/meta_schedule/trace_apply.cc @@ -56,7 +56,7 @@ void InlinePostBlocks(Schedule sch, Trace anchor_trace, Target target) { std::unordered_set get_block_names; for (const auto& inst : anchor_trace->insts) { if (inst->kind.same_as(kind_get_block)) { - auto block_name = Downcast(inst->attrs[0]); + auto block_name = Downcast(inst->attrs[0]); get_block_names.insert(block_name); } } @@ -140,9 +140,10 @@ std::vector ApplyAnchorTrace(Schedule sch, Trace anchor_trace) { continue; } - Array inputs = TranslateInputRVs(inst->inputs, rv_map); + ffi::Array inputs = TranslateInputRVs(inst->inputs, rv_map); - if 
(inst->kind.same_as(kind_get_block) && !HasBlock(sch, Downcast(inst->attrs[0]))) { + if (inst->kind.same_as(kind_get_block) && + !HasBlock(sch, Downcast(inst->attrs[0]))) { // The anchor trace does get_block on a block that is not part of the target schedule. auto block = Downcast(inst->outputs[0]); foreign_blocks.insert(block); @@ -174,7 +175,7 @@ std::vector ApplyAnchorTrace(Schedule sch, Trace anchor_trace) { } Any decision = anchor_trace->GetDecision(inst); - Array outputs = inst->kind->f_apply_to_schedule(sch, inputs, inst->attrs, decision); + ffi::Array outputs = inst->kind->f_apply_to_schedule(sch, inputs, inst->attrs, decision); if (inst->kind.same_as(kind_get_child_blocks)) { // We want to allow a trace generated for a single conv2d block to be applied to @@ -184,9 +185,9 @@ std::vector ApplyAnchorTrace(Schedule sch, Trace anchor_trace) { // new_outputs.size(). We workaround this problem by assuming that the prefix of the "new" // outputs matches with the "old" outputs, and truncating the new outputs accordingly. ICHECK(inst->outputs.size() <= outputs.size()); - TranslateAddOutputRVs(inst->outputs, - Array(outputs.begin(), outputs.begin() + inst->outputs.size()), - &rv_map); + TranslateAddOutputRVs( + inst->outputs, ffi::Array(outputs.begin(), outputs.begin() + inst->outputs.size()), + &rv_map); } else { TranslateAddOutputRVs(inst->outputs, outputs, &rv_map); } @@ -248,7 +249,7 @@ void ScheduleUsingAnchorTrace(Schedule sch, const Trace& anchor_trace, const tvm auto auto_bind_rule = ScheduleRule::AutoBind(/*max_threadblocks=*/256, - /*thread_extents*/ Array{32, 64, 128, 256, 512, 1024}, + /*thread_extents*/ ffi::Array{32, 64, 128, 256, 512, 1024}, max_threads_per_block.value()->value); auto_bind_rule->Apply(sch, last_block); } diff --git a/src/meta_schedule/tune_context.cc b/src/meta_schedule/tune_context.cc index 1b2cb9d0c140..857fc5b2977c 100644 --- a/src/meta_schedule/tune_context.cc +++ b/src/meta_schedule/tune_context.cc @@ -25,12 +25,13 @@ namespace tvm { namespace meta_schedule { -TuneContext::TuneContext(Optional mod, Optional target, - Optional space_generator, - Optional search_strategy, Optional task_name, - int num_threads, TRandState rand_state, ffi::Function logger) { +TuneContext::TuneContext(ffi::Optional mod, ffi::Optional target, + ffi::Optional space_generator, + ffi::Optional search_strategy, + ffi::Optional task_name, int num_threads, + TRandState rand_state, ffi::Function logger) { CHECK(rand_state == -1 || rand_state >= 0) << "ValueError: Invalid random state: " << rand_state; - ObjectPtr n = make_object(); + ObjectPtr n = ffi::make_object(); n->mod = mod; n->target = target; n->space_generator = space_generator; @@ -43,7 +44,7 @@ TuneContext::TuneContext(Optional mod, Optional target, } TuneContext TuneContextNode::Clone() const { - ObjectPtr n = make_object(*this); + ObjectPtr n = ffi::make_object(*this); if (this->space_generator.defined()) { n->space_generator = this->space_generator.value()->Clone(); } @@ -57,10 +58,10 @@ TuneContext TuneContextNode::Clone() const { void TuneContextNode::Initialize() { if (this->space_generator.defined()) { - this->space_generator.value()->InitializeWithTuneContext(GetRef(this)); + this->space_generator.value()->InitializeWithTuneContext(ffi::GetRef(this)); } if (this->search_strategy.defined()) { - this->search_strategy.value()->InitializeWithTuneContext(GetRef(this)); + this->search_strategy.value()->InitializeWithTuneContext(ffi::GetRef(this)); } } @@ -70,10 +71,10 @@ TVM_FFI_STATIC_INIT_BLOCK({ namespace refl = 
tvm::ffi::reflection; refl::GlobalDef() .def("meta_schedule.TuneContext", - [](Optional mod, Optional target, - Optional space_generator, Optional search_strategy, - Optional task_name, int num_threads, TRandState rand_state, - ffi::Function logger) -> TuneContext { + [](ffi::Optional mod, ffi::Optional target, + ffi::Optional space_generator, + ffi::Optional search_strategy, ffi::Optional task_name, + int num_threads, TRandState rand_state, ffi::Function logger) -> TuneContext { return TuneContext(mod, target, space_generator, search_strategy, task_name, num_threads, rand_state, logger); }) diff --git a/src/meta_schedule/utils.h b/src/meta_schedule/utils.h index 21483d3b98a4..732a3a083d03 100644 --- a/src/meta_schedule/utils.h +++ b/src/meta_schedule/utils.h @@ -136,7 +136,7 @@ inline bool using_ipython() { * \brief Print out the performance table interactively in jupyter notebook. * \param str The serialized performance table. */ -inline void print_interactive_table(const String& data) { +inline void print_interactive_table(const ffi::String& data) { const auto f_print_interactive_table = tvm::ffi::Function::GetGlobal("meta_schedule.print_interactive_table"); ICHECK(f_print_interactive_table.has_value()) @@ -214,14 +214,14 @@ std::string JSONDumps(Any json_obj); * \param hash_code The hash code * \return The string representation of the hash code */ -inline String SHash2Str(Workload::THashCode hash_code) { return std::to_string(hash_code); } +inline ffi::String SHash2Str(Workload::THashCode hash_code) { return std::to_string(hash_code); } /*! * \brief Converts an TVM object to the hex string representation of its structural hash. * \param obj The TVM object. * \return The hex string representation of the hash code. */ -inline String SHash2Hex(const ObjectRef& obj) { +inline ffi::String SHash2Hex(const ObjectRef& obj) { std::ostringstream os; size_t hash_code = 0; if (obj.defined()) { @@ -272,7 +272,7 @@ inline IRModule DeepCopyIRModule(IRModule mod) { return LoadJSON(SaveJSON(mod)). * \param delim The delimiter * \return The concatenated string */ -inline std::string Concat(const Array& strs, const std::string& delim) { +inline std::string Concat(const ffi::Array& strs, const std::string& delim) { if (strs.empty()) { return ""; } @@ -292,7 +292,7 @@ inline std::string Concat(const Array& strs, const std::string& delim) { * \return The BlockRV */ inline tir::BlockRV GetRVFromSRef(const tir::Schedule& sch, const tir::StmtSRef& block_sref, - const String& global_var_name) { + const ffi::String& global_var_name) { const tir::BlockNode* block = TVM_SREF_TO_BLOCK(block_sref); return sch->GetBlock(block->name_hint, global_var_name); } @@ -303,7 +303,7 @@ inline tir::BlockRV GetRVFromSRef(const tir::Schedule& sch, const tir::StmtSRef& */ struct ThreadedTraceApply { /*! 
\brief Constructor */ - explicit ThreadedTraceApply(const Array& postprocs) + explicit ThreadedTraceApply(const ffi::Array& postprocs) : n_(postprocs.size()), items_(new Item[n_]) { for (int i = 0; i < n_; ++i) { items_[i].postproc = postprocs[i]; @@ -321,8 +321,8 @@ struct ThreadedTraceApply { * \param rand_state The random seed * \return The schedule created, or std::nullopt if any postprocessor fails */ - Optional Apply(const IRModule& mod, const tir::Trace& trace, - TRandState* rand_state) { + ffi::Optional Apply(const IRModule& mod, const tir::Trace& trace, + TRandState* rand_state) { tir::Schedule sch = tir::Schedule::Traced(mod, /*rand_state=*/ForkSeed(rand_state), @@ -397,7 +397,7 @@ inline int GetTargetNumCores(const Target& target) { * \return The median of the running time in millisecond */ inline double GetRunMsMedian(const RunnerResult& runner_result) { - Array run_secs = runner_result->run_secs.value(); + ffi::Array run_secs = runner_result->run_secs.value(); ICHECK(!run_secs.empty()); std::vector v; v.reserve(run_secs.size()); @@ -417,10 +417,10 @@ inline double GetRunMsMedian(const RunnerResult& runner_result) { * \param obj The object to be converted * \return The array of floating point numbers */ -inline Array AsFloatArray(const ObjectRef& obj) { +inline ffi::Array AsFloatArray(const ObjectRef& obj) { const ffi::ArrayObj* arr = obj.as(); ICHECK(arr) << "TypeError: Expect an array, but gets: " << obj->GetTypeKey(); - Array results; + ffi::Array results; results.reserve(arr->size()); for (Any val : *arr) { auto float_value = [&]() -> FloatImm { @@ -444,10 +444,10 @@ inline Array AsFloatArray(const ObjectRef& obj) { * \param obj The object to be converted * \return The array of integers */ -inline Array AsIntArray(const ObjectRef& obj) { +inline ffi::Array AsIntArray(const ObjectRef& obj) { const ffi::ArrayObj* arr = obj.as(); ICHECK(arr) << "TypeError: Expect an array, but gets: " << obj->GetTypeKey(); - Array results; + ffi::Array results; results.reserve(arr->size()); for (Any val : *arr) { auto int_value = [&]() -> int64_t { @@ -467,7 +467,7 @@ inline Array AsIntArray(const ObjectRef& obj) { struct SortTuningRecordByMeanRunSecs { static const constexpr double kMaxMeanTime = 1e10; - static double Mean(const Array& a) { + static double Mean(const ffi::Array& a) { if (a.empty()) { return kMaxMeanTime; } @@ -492,8 +492,8 @@ struct SortTuningRecordByMeanRunSecs { */ inline void CloneRules(const SpaceGeneratorNode* src, SpaceGeneratorNode* dst) { if (src->sch_rules.defined()) { - Array original = src->sch_rules.value(); - Array sch_rules; + ffi::Array original = src->sch_rules.value(); + ffi::Array sch_rules; sch_rules.reserve(original.size()); for (const ScheduleRule& sch_rule : original) { sch_rules.push_back(sch_rule->Clone()); @@ -501,8 +501,8 @@ inline void CloneRules(const SpaceGeneratorNode* src, SpaceGeneratorNode* dst) { dst->sch_rules = std::move(sch_rules); } if (src->postprocs.defined()) { - Array original = src->postprocs.value(); - Array postprocs; + ffi::Array original = src->postprocs.value(); + ffi::Array postprocs; postprocs.reserve(original.size()); for (const Postproc& postproc : original) { postprocs.push_back(postproc->Clone()); @@ -510,8 +510,8 @@ inline void CloneRules(const SpaceGeneratorNode* src, SpaceGeneratorNode* dst) { dst->postprocs = std::move(postprocs); } if (src->mutator_probs.defined()) { - Map original = src->mutator_probs.value(); - Map mutator_probs; + ffi::Map original = src->mutator_probs.value(); + ffi::Map mutator_probs; for (const 
auto& kv : original) { mutator_probs.Set(kv.first->Clone(), kv.second); } @@ -532,7 +532,7 @@ inline bool IsGPUTarget(const std::string& target_name) { * \return The AutoInline schedule rule for the given target. */ inline ScheduleRule GetDefaultAutoInline(const std::string& target_name) { - Array rules{nullptr}; + ffi::Array rules{nullptr}; if (target_name == "llvm") { rules = ScheduleRule::DefaultLLVM(); } else if (target_name == "hexagon") { @@ -557,7 +557,7 @@ inline ScheduleRule GetDefaultAutoInline(const std::string& target_name) { * \param arr The array of FloatImm. * \return The summary of the values in the given array. */ -inline double Sum(const Array& arr) { +inline double Sum(const ffi::Array& arr) { double sum = 0; for (const FloatImm& f : arr) { sum += f->value; @@ -568,21 +568,21 @@ inline double Sum(const Array& arr) { /*! \brief Collecting all the blocks */ class BlockCollector : public tir::StmtVisitor { public: - static Array Collect(const tir::Schedule& sch, - const ffi::Function f_block_filter = nullptr) { // + static ffi::Array Collect(const tir::Schedule& sch, + const ffi::Function f_block_filter = nullptr) { // return BlockCollector(sch, f_block_filter).Run(); } private: /*! \brief Entry point */ - Array Run() { + ffi::Array Run() { std::vector results; - auto f_collect = [this, &results](tir::PrimFunc func, String func_name) { + auto f_collect = [this, &results](tir::PrimFunc func, ffi::String func_name) { func_name_ = func_name; block_names_.clear(); blocks_to_collect_.clear(); VisitStmt(func->body); - for (const String& name : blocks_to_collect_) { + for (const ffi::String& name : blocks_to_collect_) { results.push_back(sch_->GetBlock(name, func_name_)); } }; @@ -596,7 +596,7 @@ class BlockCollector : public tir::StmtVisitor { // `gv->name_hint` is the name of the function // `base_func` can be PrimFunc or relax::Function if (const auto* func = base_func.as()) { - f_collect(GetRef(func), gv->name_hint); + f_collect(ffi::GetRef(func), gv->name_hint); } } } @@ -617,7 +617,7 @@ class BlockCollector : public tir::StmtVisitor { // Otherwise collect all blocks. Bool collect_block = Bool(true); if (f_block_filter_ != nullptr) { - collect_block = f_block_filter_(GetRef(block)).cast(); + collect_block = f_block_filter_(ffi::GetRef(block)).cast(); } if (collect_block) { blocks_to_collect_.push_back(block->name_hint); @@ -629,15 +629,15 @@ class BlockCollector : public tir::StmtVisitor { /*! \brief An optional packed func that allows only certain blocks to be collected. */ const ffi::Function f_block_filter_; /*! \brief The set of func name and block name pair */ - std::unordered_set block_names_; + std::unordered_set block_names_; /* \brief The list of blocks to collect in order */ - Array blocks_to_collect_; + ffi::Array blocks_to_collect_; /*! \brief Name of the current PrimFunc */ - String func_name_; + ffi::String func_name_; }; -void JSONFileAppendLine(const String& path, const std::string& line); -std::vector JSONFileReadLines(const String& path, int num_threads, bool allow_missing); +void JSONFileAppendLine(const ffi::String& path, const std::string& line); +std::vector JSONFileReadLines(const ffi::String& path, int num_threads, bool allow_missing); } // namespace meta_schedule } // namespace tvm diff --git a/src/node/attr_registry.h b/src/node/attr_registry.h index 334c15b3be97..fee7eeb26cab 100644 --- a/src/node/attr_registry.h +++ b/src/node/attr_registry.h @@ -50,7 +50,7 @@ class AttrRegistry { * \param name The name of the item. 
* \return The corresponding entry. */ - const EntryType* Get(const String& name) const { + const EntryType* Get(const ffi::String& name) const { auto it = entry_map_.find(name); if (it != entry_map_.end()) return it->second; return nullptr; @@ -61,7 +61,7 @@ class AttrRegistry { * \param name The name of the item. * \return The corresponding entry. */ - EntryType& RegisterOrGet(const String& name) { + EntryType& RegisterOrGet(const ffi::String& name) { auto it = entry_map_.find(name); if (it != entry_map_.end()) return *it->second; uint32_t registry_index = static_cast(entries_.size()); @@ -77,8 +77,8 @@ class AttrRegistry { * \brief List all the entry names in the registry. * \return The entry names. */ - Array ListAllNames() const { - Array names; + ffi::Array ListAllNames() const { + ffi::Array names; for (const auto& kv : entry_map_) { names.push_back(kv.first); } @@ -92,7 +92,7 @@ class AttrRegistry { * \param value The value to be set. * \param plevel The support level. */ - void UpdateAttr(const String& attr_name, const KeyType& key, Any value, int plevel) { + void UpdateAttr(const ffi::String& attr_name, const KeyType& key, Any value, int plevel) { using ffi::Any; auto& op_map = attrs_[attr_name]; if (op_map == nullptr) { @@ -119,7 +119,7 @@ class AttrRegistry { * \param attr_name The name of the attribute. * \param key The key to the attribute table. */ - void ResetAttr(const String& attr_name, const KeyType& key) { + void ResetAttr(const ffi::String& attr_name, const KeyType& key) { auto& op_map = attrs_[attr_name]; if (op_map == nullptr) { return; @@ -135,7 +135,7 @@ class AttrRegistry { * \param attr_name The name of the attribute. * \return The result attribute map. */ - const AttrRegistryMapContainerMap& GetAttrMap(const String& attr_name) { + const AttrRegistryMapContainerMap& GetAttrMap(const ffi::String& attr_name) { auto it = attrs_.find(attr_name); if (it == attrs_.end()) { LOG(FATAL) << "Attribute \'" << attr_name << "\' is not registered"; @@ -148,7 +148,7 @@ class AttrRegistry { * \param attr_name The name of the attribute. * \return The check result. */ - bool HasAttrMap(const String& attr_name) { return attrs_.count(attr_name); } + bool HasAttrMap(const ffi::String& attr_name) { return attrs_.count(attr_name); } /*! * \return a global singleton of the registry. @@ -162,9 +162,9 @@ class AttrRegistry { // entries in the registry std::vector> entries_; // map from name to entries. - std::unordered_map entry_map_; + std::unordered_map entry_map_; // storage of additional attribute table. 
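// ----------------------------------------------------------------------------
// Sketch (annotation, not part of the patch): AttrRegistry keys its entry_map_
// above by ffi::String inside a std::unordered_map, which assumes the ffi
// headers provide std::hash support for ffi::String, as the patch relies on.
// `DemoLookup` and the header path <tvm/ffi/string.h> are assumptions.
#include <tvm/ffi/string.h>

#include <string>
#include <unordered_map>

inline int DemoLookup() {
  std::unordered_map<tvm::ffi::String, int> table;
  tvm::ffi::String key = "relax.analysis.demo";  // implicit from a literal
  table[key] = 1;
  std::string round_trip = key;  // ffi::String converts back to std::string
  (void)round_trip;
  return table.count("relax.analysis.demo");  // literal converts to key type
}
// ----------------------------------------------------------------------------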
- std::unordered_map>> attrs_; + std::unordered_map>> attrs_; }; } // namespace tvm diff --git a/src/node/reflection.cc b/src/node/reflection.cc index e666b434f8f5..82060f0e857b 100644 --- a/src/node/reflection.cc +++ b/src/node/reflection.cc @@ -38,12 +38,12 @@ using ffi::PackedArgs; // key1, value1, ..., key_n, value_n void MakeNode(const ffi::PackedArgs& args, ffi::Any* rv) { // TODO(tvm-team): consider further simplify by removing DictAttrsNode special handling - String type_key = args[0].cast(); + ffi::String type_key = args[0].cast(); int32_t type_index; TVMFFIByteArray type_key_array = TVMFFIByteArray{type_key.data(), type_key.size()}; TVM_FFI_CHECK_SAFE_CALL(TVMFFITypeKeyToIndex(&type_key_array, &type_index)); if (type_index == DictAttrsNode::RuntimeTypeIndex()) { - ObjectPtr attrs = make_object(); + ObjectPtr attrs = ffi::make_object(); attrs->InitByPackedArgs(args.Slice(1), false); *rv = ObjectRef(attrs); } else { diff --git a/src/node/script_printer.cc b/src/node/script_printer.cc index 9b1565d2ab3a..68b2b392105b 100644 --- a/src/node/script_printer.cc +++ b/src/node/script_printer.cc @@ -35,7 +35,8 @@ TVMScriptPrinter::FType& TVMScriptPrinter::vtable() { return inst; } -std::string TVMScriptPrinter::Script(const ObjectRef& node, const Optional& cfg) { +std::string TVMScriptPrinter::Script(const ObjectRef& node, + const ffi::Optional& cfg) { if (!TVMScriptPrinter::vtable().can_dispatch(node)) { std::ostringstream os; ReprPrinter printer(os); @@ -59,34 +60,34 @@ bool IsIdentifier(const std::string& name) { [](char c) { return std::isalnum(c) || c == '_'; }); } -PrinterConfig::PrinterConfig(Map config_dict) { - runtime::ObjectPtr n = make_object(); +PrinterConfig::PrinterConfig(ffi::Map config_dict) { + runtime::ObjectPtr n = ffi::make_object(); if (auto v = config_dict.Get("name")) { - n->binding_names.push_back(Downcast(v.value())); + n->binding_names.push_back(Downcast(v.value())); } if (auto v = config_dict.Get("show_meta")) { n->show_meta = v.value().cast(); } if (auto v = config_dict.Get("ir_prefix")) { - n->ir_prefix = Downcast(v.value()); + n->ir_prefix = Downcast(v.value()); } if (auto v = config_dict.Get("tir_prefix")) { - n->tir_prefix = Downcast(v.value()); + n->tir_prefix = Downcast(v.value()); } if (auto v = config_dict.Get("relax_prefix")) { - n->relax_prefix = Downcast(v.value()); + n->relax_prefix = Downcast(v.value()); } if (auto v = config_dict.Get("module_alias")) { - n->module_alias = Downcast(v.value()); + n->module_alias = Downcast(v.value()); } if (auto v = config_dict.Get("buffer_dtype")) { - n->buffer_dtype = DataType(StringToDLDataType(Downcast(v.value()))); + n->buffer_dtype = DataType(ffi::StringToDLDataType(Downcast(v.value()))); } if (auto v = config_dict.Get("int_dtype")) { - n->int_dtype = DataType(StringToDLDataType(Downcast(v.value()))); + n->int_dtype = DataType(ffi::StringToDLDataType(Downcast(v.value()))); } if (auto v = config_dict.Get("float_dtype")) { - n->float_dtype = DataType(StringToDLDataType(Downcast(v.value()))); + n->float_dtype = DataType(ffi::StringToDLDataType(Downcast(v.value()))); } if (auto v = config_dict.Get("verbose_expr")) { n->verbose_expr = v.value().cast(); @@ -101,18 +102,20 @@ PrinterConfig::PrinterConfig(Map config_dict) { n->num_context_lines = v.value().cast(); } if (auto v = config_dict.Get("path_to_underline")) { - n->path_to_underline = Downcast>>(v).value_or(Array()); + n->path_to_underline = + Downcast>>(v).value_or(ffi::Array()); } if (auto v = config_dict.Get("path_to_annotate")) { - n->path_to_annotate = - 
Downcast>>(v).value_or(Map()); + n->path_to_annotate = Downcast>>(v).value_or( + ffi::Map()); } if (auto v = config_dict.Get("obj_to_underline")) { - n->obj_to_underline = Downcast>>(v).value_or(Array()); + n->obj_to_underline = + Downcast>>(v).value_or(ffi::Array()); } if (auto v = config_dict.Get("obj_to_annotate")) { - n->obj_to_annotate = - Downcast>>(v).value_or(Map()); + n->obj_to_annotate = Downcast>>(v).value_or( + ffi::Map()); } if (auto v = config_dict.Get("syntax_sugar")) { n->syntax_sugar = v.value().cast(); @@ -134,8 +137,8 @@ PrinterConfig::PrinterConfig(Map config_dict) { this->data_ = std::move(n); } -Array PrinterConfigNode::GetBuiltinKeywords() { - Array result{this->ir_prefix, this->tir_prefix, this->relax_prefix}; +ffi::Array PrinterConfigNode::GetBuiltinKeywords() { + ffi::Array result{this->ir_prefix, this->tir_prefix, this->relax_prefix}; if (!this->module_alias.empty()) { result.push_back(this->module_alias); } @@ -146,7 +149,7 @@ TVM_FFI_STATIC_INIT_BLOCK({ namespace refl = tvm::ffi::reflection; refl::GlobalDef() .def("node.PrinterConfig", - [](Map config_dict) { return PrinterConfig(config_dict); }) + [](ffi::Map config_dict) { return PrinterConfig(config_dict); }) .def("node.TVMScriptPrinterScript", TVMScriptPrinter::Script); }); diff --git a/src/node/structural_hash.cc b/src/node/structural_hash.cc index 1810efa1bf2e..24916fb18803 100644 --- a/src/node/structural_hash.cc +++ b/src/node/structural_hash.cc @@ -50,12 +50,12 @@ TVM_FFI_STATIC_INIT_BLOCK({ refl::TypeAttrDef() .def("__data_to_json__", [](const ffi::ModuleObj* node) { - std::string bytes = codegen::SerializeModuleToBytes(GetRef(node), + std::string bytes = codegen::SerializeModuleToBytes(ffi::GetRef(node), /*export_dso*/ false); return ffi::Base64Encode(ffi::Bytes(bytes)); }) - .def("__data_from_json__", [](const String& base64_bytes) { - Bytes bytes = ffi::Base64Decode(base64_bytes); + .def("__data_from_json__", [](const ffi::String& base64_bytes) { + ffi::Bytes bytes = ffi::Base64Decode(base64_bytes); ffi::Module rtmod = codegen::DeserializeModuleFromBytes(bytes.operator std::string()); return rtmod; }); @@ -68,7 +68,7 @@ TVM_FFI_STATIC_INIT_BLOCK({ support::Base64OutStream b64strm(&mstrm); runtime::SaveDLTensor(&b64strm, node); b64strm.Finish(); - return String(blob); + return ffi::String(blob); }) .def("__data_from_json__", [](const std::string& blob) { dmlc::MemoryStringStream mstrm(const_cast(&blob)); diff --git a/src/relax/analysis/analysis.cc b/src/relax/analysis/analysis.cc index f1f47910f8b1..c2d29f9837bd 100644 --- a/src/relax/analysis/analysis.cc +++ b/src/relax/analysis/analysis.cc @@ -46,9 +46,9 @@ struct InsertionSet { class VarVisitor : protected ExprVisitor { public: - Array Free(const Expr& expr) { + ffi::Array Free(const Expr& expr) { this->VisitExpr(expr); - Array ret; + ffi::Array ret; for (const auto& v : vars_.data) { if (bound_vars_.set.count(v) == 0) { ret.push_back(v); @@ -57,31 +57,31 @@ class VarVisitor : protected ExprVisitor { return ret; } - Array Collect() { - Array ret; + ffi::Array Collect() { + ffi::Array ret; for (const auto& v : bound_vars_.data) { ret.push_back(v); } return ret; } - Array Bound(const Expr& expr) { + ffi::Array Bound(const Expr& expr) { this->VisitExpr(expr); return Collect(); } - Array All(const Expr& expr) { + ffi::Array All(const Expr& expr) { this->VisitExpr(expr); - Array ret; + ffi::Array ret; for (const auto& v : vars_.data) { ret.push_back(v); } return ret; } - Array AllGlobalVars(const Expr& expr) { + ffi::Array AllGlobalVars(const Expr& 
expr) { this->VisitExpr(expr); - Array ret; + ffi::Array ret; for (const auto& v : global_vars_.data) { ret.push_back(v); } @@ -93,7 +93,7 @@ class VarVisitor : protected ExprVisitor { vars_.Insert(v); } - void VisitExpr_(const VarNode* var) final { vars_.Insert(GetRef(var)); } + void VisitExpr_(const VarNode* var) final { vars_.Insert(ffi::GetRef(var)); } void VisitExpr_(const FunctionNode* op) final { for (const auto& param : op->params) { @@ -102,7 +102,9 @@ class VarVisitor : protected ExprVisitor { VisitExpr(op->body); } - void VisitExpr_(const GlobalVarNode* op) final { global_vars_.Insert(GetRef(op)); } + void VisitExpr_(const GlobalVarNode* op) final { + global_vars_.Insert(ffi::GetRef(op)); + } void VisitExpr_(const CallNode* call_node) final { VisitSpan(call_node->span); @@ -134,25 +136,27 @@ class VarVisitor : protected ExprVisitor { InsertionSet global_vars_; }; -tvm::Array FreeVars(const Expr& expr) { return VarVisitor().Free(expr); } +tvm::ffi::Array FreeVars(const Expr& expr) { return VarVisitor().Free(expr); } -tvm::Array BoundVars(const Expr& expr) { return VarVisitor().Bound(expr); } +tvm::ffi::Array BoundVars(const Expr& expr) { return VarVisitor().Bound(expr); } -tvm::Array AllVars(const Expr& expr) { return VarVisitor().All(expr); } +tvm::ffi::Array AllVars(const Expr& expr) { return VarVisitor().All(expr); } -tvm::Array AllGlobalVars(const Expr& expr) { return VarVisitor().AllGlobalVars(expr); } +tvm::ffi::Array AllGlobalVars(const Expr& expr) { + return VarVisitor().AllGlobalVars(expr); +} -Optional FindImpureCall(const Expr& expr, const Optional& own_name) { +ffi::Optional FindImpureCall(const Expr& expr, const ffi::Optional& own_name) { class ImpureCallChecker : public ExprVisitor { public: - static Optional Check(const Expr& expr, const Optional& own_name) { + static ffi::Optional Check(const Expr& expr, const ffi::Optional& own_name) { ImpureCallChecker visitor(own_name); visitor.VisitExpr(expr); return visitor.impure_expr_; } private: - explicit ImpureCallChecker(const Optional& own_name) : own_name_(own_name) {} + explicit ImpureCallChecker(const ffi::Optional& own_name) : own_name_(own_name) {} void VisitExpr(const Expr& expr) override { // Early bail-out if we found an impure expression @@ -169,7 +173,7 @@ Optional FindImpureCall(const Expr& expr, const Optional& own_name) void VisitExpr_(const CallNode* call) override { // ignore recursive calls if we find one bool is_recursive = (own_name_ && own_name_.value().same_as(call->op)); - auto expr = GetRef(call); + auto expr = ffi::GetRef(call); if (!is_recursive && IsImpureCall(expr)) { impure_expr_ = expr; } else { @@ -178,8 +182,8 @@ Optional FindImpureCall(const Expr& expr, const Optional& own_name) } private: - const Optional& own_name_; - Optional impure_expr_ = std::nullopt; + const ffi::Optional& own_name_; + ffi::Optional impure_expr_ = std::nullopt; }; if (own_name) { @@ -194,7 +198,7 @@ Optional FindImpureCall(const Expr& expr, const Optional& own_name) return ImpureCallChecker::Check(to_check, own_name); } -bool ContainsImpureCall(const Expr& expr, const Optional& own_name) { +bool ContainsImpureCall(const Expr& expr, const ffi::Optional& own_name) { return FindImpureCall(expr, own_name).defined(); } diff --git a/src/relax/analysis/collect_call_map.cc b/src/relax/analysis/collect_call_map.cc index 3e0170d3444d..85099d88ff57 100644 --- a/src/relax/analysis/collect_call_map.cc +++ b/src/relax/analysis/collect_call_map.cc @@ -38,7 +38,9 @@ using ir::CalleeCollector; struct Visitor : ExprVisitor { 
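// ----------------------------------------------------------------------------
// Sketch (annotation, not part of the patch): the ffi::GetRef idiom these
// visitor hunks migrate to. A visitor callback receives a raw `const VarNode*`;
// GetRef rewraps the same object as a managed Var (a refcount bump, no copy),
// as in `collector->Mark(ffi::GetRef(node))` just below. `DemoMark` is a
// hypothetical free function.
#include <tvm/ffi/cast.h>
#include <tvm/ffi/container/array.h>
#include <tvm/relax/expr.h>

inline void DemoMark(const tvm::relax::VarNode* node,
                     tvm::ffi::Array<tvm::relax::Var>* out) {
  out->push_back(tvm::ffi::GetRef<tvm::relax::Var>(node));
}
// ----------------------------------------------------------------------------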
explicit Visitor(CalleeCollector* collector) : collector(collector) {} CalleeCollector* collector; - void VisitExpr_(const GlobalVarNode* node) override { collector->Mark(GetRef(node)); } + void VisitExpr_(const GlobalVarNode* node) override { + collector->Mark(ffi::GetRef(node)); + } }; } // namespace diff --git a/src/relax/analysis/computable_at_compile_time.cc b/src/relax/analysis/computable_at_compile_time.cc index 8b8665445d98..5ce64fcef220 100644 --- a/src/relax/analysis/computable_at_compile_time.cc +++ b/src/relax/analysis/computable_at_compile_time.cc @@ -35,10 +35,10 @@ namespace relax { namespace { class CompileTimeCollector : ExprVisitor { public: - static Array Collect(const Function& func) { + static ffi::Array Collect(const Function& func) { CompileTimeCollector visitor; visitor(func); - return Array(visitor.known_relax_vars_.begin(), visitor.known_relax_vars_.end()); + return ffi::Array(visitor.known_relax_vars_.begin(), visitor.known_relax_vars_.end()); } private: @@ -89,7 +89,7 @@ class CompileTimeCollector : ExprVisitor { }; } // namespace -Array ComputableAtCompileTime(const Function& func) { +ffi::Array ComputableAtCompileTime(const Function& func) { return CompileTimeCollector::Collect(func); } diff --git a/src/relax/analysis/detect_recursion.cc b/src/relax/analysis/detect_recursion.cc index 73ad8a31f8a5..05260d18d89e 100644 --- a/src/relax/analysis/detect_recursion.cc +++ b/src/relax/analysis/detect_recursion.cc @@ -87,7 +87,7 @@ class DependencyGatherer : public ExprVisitor { void VisitExpr_(const GlobalVarNode* gv) override { // disregard PrimFuncs - if (!m_->Lookup(GetRef(gv)).as()) { + if (!m_->Lookup(ffi::GetRef(gv)).as()) { return; } deps_.insert(gv->name_hint); @@ -111,7 +111,7 @@ adjacency_map GatherDependencyGraph(const IRModule& m) { continue; } std::string name = gv_func.first->name_hint; - auto deps = DependencyGatherer(m).Track(GetRef(func)); + auto deps = DependencyGatherer(m).Track(ffi::GetRef(func)); ret.insert({name, deps}); } return ret; @@ -369,7 +369,7 @@ std::vector CoalesceCircuits(const std::vector& circuits) { return ret; } -tvm::Array> DetectRecursion(const IRModule& m) { +tvm::ffi::Array> DetectRecursion(const IRModule& m) { auto graph = GatherDependencyGraph(m); // have to decide on some ordering for names @@ -382,9 +382,9 @@ tvm::Array> DetectRecursion(const IRModule& m) { auto groups = CoalesceCircuits(DetectElementaryCircuits(indices)); // convert to expected representation - tvm::Array> ret; + tvm::ffi::Array> ret; for (auto group : groups) { - tvm::Array found; + tvm::ffi::Array found; for (size_t node : group) { found.push_back(m->GetGlobalVar(name_ordering[node])); } diff --git a/src/relax/analysis/graph_partitioner.cc b/src/relax/analysis/graph_partitioner.cc index 00f4da400657..d68626160fe9 100644 --- a/src/relax/analysis/graph_partitioner.cc +++ b/src/relax/analysis/graph_partitioner.cc @@ -252,11 +252,11 @@ size_t GraphPartitioner::CountArgs_(IndexedForwardGraph::Node* src, } return 0; }; - if (auto call_node = GetRef(src->ref).as()) { + if (auto call_node = ffi::GetRef(src->ref).as()) { for (auto& it : call_node->args) { sum += calc_args_number(it); } - } else if (auto tuple_node = GetRef(src->ref).as()) { + } else if (auto tuple_node = ffi::GetRef(src->ref).as()) { for (auto& it : tuple_node->fields) { sum += calc_args_number(it); } @@ -288,19 +288,19 @@ size_t GraphPartitioner::CountFusedArgs(const IndexedForwardGraph& graph, void GraphPartitioner::InitGroups(const IndexedForwardGraph& graph) { auto args_counter = [](const 
tvm::Object* obj) { size_t args_num = 0; - if (auto call_node = GetRef(obj).as()) { + if (auto call_node = ffi::GetRef(obj).as()) { for (auto& it : call_node->args) { if (it.as() || it.as()) { args_num++; } } - } else if (auto tuple_node = GetRef(obj).as()) { + } else if (auto tuple_node = ffi::GetRef(obj).as()) { for (auto& it : tuple_node->fields) { if (it.as() || it.as()) { args_num++; } } - } else if (GetRef(obj).as()) { + } else if (ffi::GetRef(obj).as()) { args_num++; } return args_num; diff --git a/src/relax/analysis/graph_partitioner.h b/src/relax/analysis/graph_partitioner.h index 3afb9888a162..09bf68734cc8 100644 --- a/src/relax/analysis/graph_partitioner.h +++ b/src/relax/analysis/graph_partitioner.h @@ -83,7 +83,7 @@ class IndexedForwardGraph { std::ostringstream os; for (size_t i = 0; i < post_dfs_order.size(); ++i) { Node* node = post_dfs_order[i]; - os << "node[" << i << "], " << GetRef(node->ref) << " outputs=["; + os << "node[" << i << "], " << ffi::GetRef(node->ref) << " outputs=["; for (auto* link = node->outputs.head; link != nullptr; link = link->next) { os << link->value.node->index << ", "; } @@ -194,7 +194,7 @@ class GraphPartitioner { size_t args_num{0}; /*! \brief Optional attributes to annotate the grouped function. */ - Map attrs; + ffi::Map attrs; /*! * \brief Find the group root, perform path compression * \return The root type node. diff --git a/src/relax/analysis/layout_transformation.cc b/src/relax/analysis/layout_transformation.cc index 109af127df2e..aa5ceea01560 100644 --- a/src/relax/analysis/layout_transformation.cc +++ b/src/relax/analysis/layout_transformation.cc @@ -40,8 +40,8 @@ using namespace tir; /********** Helper Functions **********/ /*! \brief Checks if a transformation is bijective affine over the given ranges */ -static bool IsBijectiveAffine(const IndexMap& m, const Array& ranges) { - Map input_iters; +static bool IsBijectiveAffine(const IndexMap& m, const ffi::Array& ranges) { + ffi::Map input_iters; ICHECK_EQ(m->initial_indices.size(), ranges.size()); for (size_t i = 0; i < ranges.size(); i++) { input_iters.Set(m->initial_indices[i], ranges[i]); @@ -61,7 +61,7 @@ static bool IsBijectiveAffine(const IndexMap& m, const Array& ranges) { */ class IndexAnalyzer : public ExprVisitor { public: - Array Analyze(const arith::IterSumExpr& expr) { + ffi::Array Analyze(const arith::IterSumExpr& expr) { VisitExpr(expr); return iterators_; } @@ -86,14 +86,14 @@ class IndexAnalyzer : public ExprVisitor { void VisitIterMark(const arith::IterMark& op) { if (const auto* var = op->source.as()) - iterators_.push_back(GetRef(var)); + iterators_.push_back(ffi::GetRef(var)); else VisitExpr(op->source); VisitExpr(op->extent); } private: - Array iterators_; + ffi::Array iterators_; }; /*! 
@@ -111,13 +111,13 @@ class IndexAnalyzer : public ExprVisitor { * SpatialLayout(A[s0, constant, r0, s1]) = {s0, null, null, s1} * SpatialLayout(A[s0 * c + s1]) = undefined */ -using SpatialLayout = Array>; +using SpatialLayout = ffi::Array>; static SpatialLayout GetSpatialLayout(const arith::IterMapResult& iter_map_result) { ICHECK(!iter_map_result->indices.empty()); SpatialLayout result; for (const arith::IterSumExpr& index : iter_map_result->indices) { IndexAnalyzer index_analyzer; - Array iter_vars = index_analyzer.Analyze(index); + ffi::Array iter_vars = index_analyzer.Analyze(index); if (iter_vars.size() >= 2) { LOG(WARNING) << "[LayoutInference] Unable to get spatial layout of access: " << arith::NormalizeIterMapToExpr(index); @@ -173,7 +173,7 @@ static bool AreIdenticalTransforms(const IndexMap& t0, const IndexMap& t1) { if (t0->final_indices.size() != t1->final_indices.size()) return false; // Create a new shape expression. - Array t1_initial_indices = + ffi::Array t1_initial_indices = t1->initial_indices.Map([](tir::Var i) -> PrimExpr { return i; }); arith::Analyzer analyzer; auto t0_output = t0->MapIndices(t1_initial_indices, &analyzer); @@ -213,9 +213,9 @@ static bool AreIdenticalTransforms(const IndexMap& t0, const IndexMap& t1) { * target transformation = lambda dim, C, H, W -> (dim, H, W, C // 4, C %4) */ using VarSet = std::unordered_set; -static Optional InferLayoutTransformation(const SpatialLayout& src_spatial_layout, - const IndexMap& src_transformation, - const SpatialLayout& tgt_spatial_layout) { +static ffi::Optional InferLayoutTransformation(const SpatialLayout& src_spatial_layout, + const IndexMap& src_transformation, + const SpatialLayout& tgt_spatial_layout) { // Copy over the src transformation intial and final indices auto initial_indices = support::AsList(src_transformation->initial_indices); auto final_indices = support::AsList(src_transformation->final_indices); @@ -244,7 +244,7 @@ static Optional InferLayoutTransformation(const SpatialLayout& src_spa auto final_indices_it = final_indices.begin(); while (final_indices_it != final_indices.end()) { // Collect all the vars used in this final index. - Array used_vars = tir::UndefinedVars(*final_indices_it); + ffi::Array used_vars = tir::UndefinedVars(*final_indices_it); ICHECK(!used_vars.empty()) << "IndexMap expression must always contain tir::Var nodes but found none in: " << *final_indices_it; @@ -318,7 +318,7 @@ static Optional InferLayoutTransformation(const SpatialLayout& src_spa */ class BlockAnalyzer : public StmtExprVisitor { public: - explicit BlockAnalyzer(const Block& block, const Map& transformation_cache, + explicit BlockAnalyzer(const Block& block, const ffi::Map& transformation_cache, IndexMap write_transformation) : can_transform_block_(true), write_transformation_(write_transformation), @@ -380,7 +380,7 @@ class BlockAnalyzer : public StmtExprVisitor { } block_transformation_ = maybe_block_transformation.value(); - Array block_ranges = block_->iter_vars.Map([](const IterVar& i) { return i->dom; }); + ffi::Array block_ranges = block_->iter_vars.Map([](const IterVar& i) { return i->dom; }); if (!IsBijectiveAffine(block_transformation_, block_ranges)) { can_transform_block_ = false; LOG(WARNING) << "[LayoutInference] Inferred block transformation is not bijective affine, " @@ -437,7 +437,7 @@ class BlockAnalyzer : public StmtExprVisitor { }; // Helper to break down the indices of buffer access. 
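// ----------------------------------------------------------------------------
// Sketch (annotation, not part of the patch): ffi::Map usage as in
// IsBijectiveAffine's input_iters earlier in this file. Set inserts or
// overwrites under copy-on-write; Get returns an ffi::Optional, so missing
// keys need no exception handling. `DemoMapUsage` is hypothetical; header
// paths follow ffi/include/tvm/ffi/container/.
#include <tvm/ffi/container/map.h>
#include <tvm/ffi/string.h>

inline void DemoMapUsage() {
  tvm::ffi::Map<tvm::ffi::String, tvm::ffi::String> layouts;
  layouts.Set("conv2d", "NCHW");
  layouts.Set("conv2d", "NHWC");          // overwrite under copy-on-write
  if (auto v = layouts.Get("conv2d")) {   // ffi::Optional<ffi::String>
    tvm::ffi::String current = v.value();
    (void)current;
  }
  for (const auto& [op_name, layout] : layouts) {  // structured bindings work
    (void)op_name;
    (void)layout;
  }
}
// ----------------------------------------------------------------------------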
- SpatialLayout DetectBufferAccessIterMap(Array indices) { + SpatialLayout DetectBufferAccessIterMap(ffi::Array indices) { auto result = arith::DetectIterMap( /*indices=*/indices, /*input_iters*/ spatial_dom_, /*predicate*/ 1, /*check_level*/ arith::IterMapLevel::NoCheck, &arith_analyzer_); @@ -516,19 +516,19 @@ class BlockAnalyzer : public StmtExprVisitor { public: bool CanBeTransformed() { return can_transform_block_; } IndexMap GetBlockTransformation() { return block_transformation_; } - Map GetReadBufferTransformations() { return read_buffer_transformations_; } + ffi::Map GetReadBufferTransformations() { return read_buffer_transformations_; } private: bool can_transform_block_; IndexMap write_transformation_; - Map spatial_dom_; + ffi::Map spatial_dom_; arith::Analyzer arith_analyzer_; Block block_; IndexMap block_transformation_; - Map read_buffer_transformations_; - const Map& buffer_transformation_cache_; + ffi::Map read_buffer_transformations_; + const ffi::Map& buffer_transformation_cache_; std::unordered_map buffer_access_info_; }; @@ -542,14 +542,14 @@ class BlockAnalyzer : public StmtExprVisitor { */ class PrimFuncAnalyzer : public StmtExprVisitor { public: - explicit PrimFuncAnalyzer(const PrimFunc& func, Array write_transformations) { + explicit PrimFuncAnalyzer(const PrimFunc& func, ffi::Array write_transformations) { ICHECK_LE(write_transformations.size(), func->params.size()) << "Incompatible PrimFunc and write_transformations"; size_t first_write_index = func->params.size() - write_transformations.size(); for (size_t i = 0; i < write_transformations.size(); ++i) { auto param = func->params[first_write_index + i]; - Optional param_buf = func->buffer_map.Get(param); + ffi::Optional param_buf = func->buffer_map.Get(param); ICHECK(param_buf.defined()); ICHECK_EQ(param_buf.value()->shape.size(), write_transformations[i]->initial_indices.size()) << "Mismatch between output buffer shape and index map"; @@ -557,10 +557,10 @@ class PrimFuncAnalyzer : public StmtExprVisitor { } VisitStmt(func->body); } - Map> GetSuggestedTransforms() { - Map> result; + ffi::Map> GetSuggestedTransforms() { + ffi::Map> result; for (const auto& [block, index_map] : block_transformations_) { - Map block_transformations; + ffi::Map block_transformations; block_transformations.Set(block, index_map); for (const auto& buffer : block_to_buffer_[block]) { block_transformations.Set(buffer, buffer_transformation_cache_[buffer]); @@ -578,7 +578,7 @@ class PrimFuncAnalyzer : public StmtExprVisitor { return; } - Block block = GetRef(op); + Block block = ffi::GetRef(op); // Get block write buffer transformation. if (block->writes.size() != 1) return; auto write_buffer = block->writes[0]->buffer; @@ -601,13 +601,13 @@ class PrimFuncAnalyzer : public StmtExprVisitor { } private: - Map buffer_transformation_cache_; - Map block_transformations_; - std::unordered_map, ObjectPtrHash, ObjectPtrEqual> block_to_buffer_; + ffi::Map buffer_transformation_cache_; + ffi::Map block_transformations_; + std::unordered_map, ObjectPtrHash, ObjectPtrEqual> block_to_buffer_; }; -Map> SuggestLayoutTransforms( - const PrimFunc& prim_func, Array write_buffer_transformations) { +ffi::Map> SuggestLayoutTransforms( + const PrimFunc& prim_func, ffi::Array write_buffer_transformations) { // No changes to the PrimFunc are required if no transformations on output buffers. 
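// ----------------------------------------------------------------------------
// Sketch (annotation, not part of the patch): the ffi::Optional checks used
// throughout PrimFuncAnalyzer above, e.g. `buffer_map.Get(param)` yielding an
// optional Buffer. defined()/has_value() test presence, value() extracts,
// value_or() supplies a default. `DemoName` and the header path
// <tvm/ffi/optional.h> are assumptions.
#include <tvm/ffi/optional.h>
#include <tvm/ffi/string.h>

inline tvm::ffi::String DemoName(tvm::ffi::Optional<tvm::ffi::String> name) {
  if (!name.defined()) {  // equivalent to !name.has_value()
    return "unnamed";
  }
  return name.value();    // safe after the presence check
}
// Callers may pass std::nullopt, mirroring `self->runner_futures = std::nullopt`
// in task_scheduler.cc earlier; DemoName(std::nullopt) yields "unnamed".
// ----------------------------------------------------------------------------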
if (write_buffer_transformations.empty()) return {}; @@ -618,7 +618,7 @@ Map> SuggestLayoutTransforms( TVM_FFI_STATIC_INIT_BLOCK({ namespace refl = tvm::ffi::reflection; refl::GlobalDef().def("relax.analysis.suggest_layout_transforms", - [](PrimFunc fn, Array write_buffer_transformations) { + [](PrimFunc fn, ffi::Array write_buffer_transformations) { return SuggestLayoutTransforms(fn, write_buffer_transformations); }); }); diff --git a/src/relax/analysis/shape_analysis.cc b/src/relax/analysis/shape_analysis.cc index 70ce5ac06e90..e2f624937773 100644 --- a/src/relax/analysis/shape_analysis.cc +++ b/src/relax/analysis/shape_analysis.cc @@ -29,7 +29,7 @@ namespace tvm { namespace relax { -bool CanProveShapeEqual(const Array& lhs, const Array& rhs, +bool CanProveShapeEqual(const ffi::Array& lhs, const ffi::Array& rhs, arith::Analyzer* ana) { if (lhs.same_as(rhs)) return true; if (lhs.size() != rhs.size()) return false; diff --git a/src/relax/analysis/struct_info_analysis.cc b/src/relax/analysis/struct_info_analysis.cc index 389fb003c6d3..53f76cadcbba 100644 --- a/src/relax/analysis/struct_info_analysis.cc +++ b/src/relax/analysis/struct_info_analysis.cc @@ -57,14 +57,14 @@ class StaticTypeDeriver : public StructInfoFunctor { // end-module: distributed Type VisitStructInfo_(const TupleStructInfoNode* op) final { - Array fields = + ffi::Array fields = op->fields.Map([this](const StructInfo& sinfo) { return this->VisitStructInfo(sinfo); }); return TupleType(fields, op->span); } Type VisitStructInfo_(const FuncStructInfoNode* op) final { if (op->IsOpaque()) return PackedFuncType(op->span); - Array params = op->params.value().Map( + ffi::Array params = op->params.value().Map( [this](const StructInfo& sinfo) { return this->VisitStructInfo(sinfo); }); Type ret = this->VisitStructInfo(op->ret); return FuncType(params, ret, op->span); @@ -93,13 +93,13 @@ StructInfo StructInfoFromType(const Type& type) { } else if (const TensorTypeNode* tensor_type = type.as()) { return TensorStructInfo(tensor_type->dtype, tensor_type->ndim); } else if (const TupleTypeNode* tuple_type = type.as()) { - Array fields; + ffi::Array fields; for (const Type& field : tuple_type->fields) { fields.push_back(StructInfoFromType(field)); } return TupleStructInfo(fields, type->span); } else if (const FuncTypeNode* func_type = type.as()) { - Array params = + ffi::Array params = func_type->arg_types.Map([](const Type& param) { return StructInfoFromType(param); }); StructInfo ret = StructInfoFromType(func_type->ret_type); // TODO(relax-team): Maybe add purity into the type as well @@ -117,13 +117,14 @@ class WellDefinedEraser : public StructInfoMutator, public ExprMutatorBase, public tir::ExprMutator { public: - WellDefinedEraser(std::function(const tir::Var& var)> f_shape_var_map, - std::function(const Var& var)> f_var_map, arith::Analyzer* ana) + WellDefinedEraser(std::function(const tir::Var& var)> f_shape_var_map, + std::function(const Var& var)> f_var_map, + arith::Analyzer* ana) : f_shape_var_map_(f_shape_var_map), f_var_map_(f_var_map), ana_(ana) {} StructInfo VisitStructInfo_(const PrimStructInfoNode* op) final { bool has_undefined = false; - Optional value; + ffi::Optional value; if (op->value.defined()) { std::swap(has_undefined_, has_undefined); @@ -134,7 +135,7 @@ class WellDefinedEraser : public StructInfoMutator, // erase symbolic shape if we have undefined. 
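// ----------------------------------------------------------------------------
// Sketch (annotation, not part of the patch): the callback shape consumed by
// WellDefinedEraser here. A var map returns ffi::Optional<PrimExpr>;
// std::nullopt means "no replacement", which the eraser records as an
// undefined symbolic var. `MakeShapeVarMap` is hypothetical; capturing the
// map by value keeps the sketch self-contained (EraseToWellDefined itself
// captures by reference).
#include <functional>
#include <optional>

#include <tvm/ffi/container/map.h>
#include <tvm/tir/var.h>

using FShapeVarMap =
    std::function<tvm::ffi::Optional<tvm::PrimExpr>(const tvm::tir::Var&)>;

inline FShapeVarMap MakeShapeVarMap(
    tvm::ffi::Map<tvm::tir::Var, tvm::PrimExpr> map) {
  return [map](const tvm::tir::Var& var) -> tvm::ffi::Optional<tvm::PrimExpr> {
    auto it = map.find(var);
    if (it != map.end()) return (*it).second;  // found: substitute
    return std::nullopt;                       // absent: leave undefined
  };
}
// ----------------------------------------------------------------------------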
if (!has_undefined) { if (value.same_as(op->value)) { - return GetRef(op); + return ffi::GetRef(op); } else { return PrimStructInfo(value.value(), op->span); } @@ -145,7 +146,7 @@ class WellDefinedEraser : public StructInfoMutator, StructInfo VisitStructInfo_(const ShapeStructInfoNode* op) final { bool has_undefined = false; - Optional> values; + ffi::Optional> values; if (op->values.defined()) { std::swap(has_undefined_, has_undefined); @@ -155,7 +156,7 @@ class WellDefinedEraser : public StructInfoMutator, // erase symbolic shape if we have undefined. if (!has_undefined) { if (values.same_as(op->values)) { - return GetRef(op); + return ffi::GetRef(op); } else { return ShapeStructInfo(values.value(), op->span); } @@ -166,7 +167,7 @@ class WellDefinedEraser : public StructInfoMutator, StructInfo VisitStructInfo_(const TensorStructInfoNode* op) final { bool has_undefined = false; - Optional shape; + ffi::Optional shape; if (op->shape.defined()) { std::swap(has_undefined_, has_undefined); @@ -179,7 +180,7 @@ class WellDefinedEraser : public StructInfoMutator, // erase symbolic shape if we have undefined. if (!has_undefined) { if (shape.same_as(op->shape)) { - return GetRef(op); + return ffi::GetRef(op); } else { if (shape.defined()) { return TensorStructInfo(shape.value(), op->dtype, vdev, op->span); @@ -197,7 +198,7 @@ class WellDefinedEraser : public StructInfoMutator, // // All the occuring symbolic variables are defined in parameters' // struct info annotations. So there is no needed to erase. - return GetRef(op); + return ffi::GetRef(op); } using relax::ExprMutatorBase::VisitExpr_; @@ -215,22 +216,22 @@ class WellDefinedEraser : public StructInfoMutator, } Expr VisitExpr_(const VarNode* var) final { - Optional ret; + ffi::Optional ret; if (f_var_map_ != nullptr) { - ret = f_var_map_(GetRef(var)); + ret = f_var_map_(ffi::GetRef(var)); } has_undefined_ = has_undefined_ || !ret.defined(); if (ret.defined()) { ICHECK(ret.as() || ret.as()) << "Only allow Expr in StructInfo to be ShapeExpr or Var"; } - return ret.value_or(GetRef(var)); + return ret.value_or(ffi::GetRef(var)); } PrimExpr VisitExpr_(const tir::VarNode* var) final { - Optional ret; + ffi::Optional ret; if (f_shape_var_map_ != nullptr) { - ret = f_shape_var_map_(GetRef(var)); + ret = f_shape_var_map_(ffi::GetRef(var)); } has_undefined_ = has_undefined_ || !ret.defined(); @@ -242,20 +243,21 @@ class WellDefinedEraser : public StructInfoMutator, ICHECK(value.dtype() == DataType::Int(64)) << "Can only provide i64 expressions in shape"; return value; } else { - return GetRef(var); + return ffi::GetRef(var); } } private: bool has_undefined_ = false; - std::function(const tir::Var& var)> f_shape_var_map_; - std::function(const Var& var)> f_var_map_; + std::function(const tir::Var& var)> f_shape_var_map_; + std::function(const Var& var)> f_var_map_; arith::Analyzer* ana_; }; StructInfo EraseToWellDefined( - const StructInfo& info, std::function(const tir::Var& var)> f_shape_var_map, - std::function(const Var& var)> f_var_map, arith::Analyzer* ana) { + const StructInfo& info, + std::function(const tir::Var& var)> f_shape_var_map, + std::function(const Var& var)> f_var_map, arith::Analyzer* ana) { if (ana == nullptr) { arith::Analyzer inst; return WellDefinedEraser(f_shape_var_map, f_var_map, &inst).VisitStructInfo(info); @@ -264,13 +266,13 @@ StructInfo EraseToWellDefined( } } -StructInfo EraseToWellDefined(const StructInfo& info, Map shape_var_map, - Map var_map, arith::Analyzer* ana) { - std::function(const tir::Var& var)> 
f_shape_var_map = nullptr; - std::function(const Var& var)> f_var_map = nullptr; +StructInfo EraseToWellDefined(const StructInfo& info, ffi::Map shape_var_map, + ffi::Map var_map, arith::Analyzer* ana) { + std::function(const tir::Var& var)> f_shape_var_map = nullptr; + std::function(const Var& var)> f_var_map = nullptr; if (!shape_var_map.empty()) { - f_shape_var_map = [&](const tir::Var& var) -> Optional { + f_shape_var_map = [&](const tir::Var& var) -> ffi::Optional { auto it = shape_var_map.find(var); if (it != shape_var_map.end()) return (*it).second; return std::nullopt; @@ -278,7 +280,7 @@ StructInfo EraseToWellDefined(const StructInfo& info, Map sh } if (!var_map.empty()) { - f_var_map = [&](const Var& var) -> Optional { + f_var_map = [&](const Var& var) -> ffi::Optional { auto it = var_map.find(var); if (it != var_map.end()) return (*it).second; return std::nullopt; @@ -292,9 +294,8 @@ TVM_FFI_STATIC_INIT_BLOCK({ namespace refl = tvm::ffi::reflection; refl::GlobalDef().def( "relax.analysis.EraseToWellDefined", - [](const StructInfo& info, Map shape_var_map, Map var_map) { - return EraseToWellDefined(info, shape_var_map, var_map); - }); + [](const StructInfo& info, ffi::Map shape_var_map, + ffi::Map var_map) { return EraseToWellDefined(info, shape_var_map, var_map); }); }); //-------------------------- @@ -472,7 +473,7 @@ class StructInfoBaseChecker // // Given we only do best effort checking in these cases, and such cases // are likely not a primary concern atm, we take this approach here. - if (struct_equal_(GetRef(lhs), other)) return BaseCheckResult::kPass; + if (struct_equal_(ffi::GetRef(lhs), other)) return BaseCheckResult::kPass; auto param_check = FuncParamsCheck(lhs->params.value(), rhs->params.value()); auto ret_check = this->VisitStructInfo(lhs->ret, rhs->ret); @@ -511,7 +512,8 @@ class StructInfoBaseChecker * \param rhs The right hand shape. * \return CheckResult. */ - virtual BaseCheckResult ShapeMatchCheck(const Array& lhs, const Array& rhs) { + virtual BaseCheckResult ShapeMatchCheck(const ffi::Array& lhs, + const ffi::Array& rhs) { if (lhs.size() != rhs.size()) return BaseCheckResult::kFailL0; BaseCheckResult ret = BaseCheckResult::kPass; @@ -546,8 +548,8 @@ class StructInfoBaseChecker * \param rhs The right hand params. * \return Check result. */ - virtual BaseCheckResult FuncParamsCheck(const Array& lhs, - const Array& rhs) { + virtual BaseCheckResult FuncParamsCheck(const ffi::Array& lhs, + const ffi::Array& rhs) { auto res = ArrayCheck(lhs, rhs); // treat L1 failures in params checking as L2. if (res == BaseCheckResult::kFailL1) res = BaseCheckResult::kFailL2; @@ -578,7 +580,7 @@ class StructInfoBaseChecker * \param lhs The left operand. * \param rhs The right operand. 
*/ - BaseCheckResult ArrayCheck(const Array& lhs, const Array& rhs) { + BaseCheckResult ArrayCheck(const ffi::Array& lhs, const ffi::Array& rhs) { if (lhs.size() != rhs.size()) return BaseCheckResult::kFailL0; BaseCheckResult ret = BaseCheckResult::kPass; @@ -789,7 +791,7 @@ class StructInfoBasePreconditionCollector } private: - PrimExpr ArrayCheck(const Array& lhs, const Array& rhs) { + PrimExpr ArrayCheck(const ffi::Array& lhs, const ffi::Array& rhs) { if (lhs.size() != rhs.size()) { return Bool(false); } @@ -801,7 +803,7 @@ class StructInfoBasePreconditionCollector return all_equal; } - PrimExpr ArrayCheck(const Array& lhs, const Array& rhs) { + PrimExpr ArrayCheck(const ffi::Array& lhs, const ffi::Array& rhs) { if (lhs.size() != rhs.size()) { return Bool(false); } @@ -877,8 +879,8 @@ class CallRetStructInfoDeriver : public StructInfoBaseChecker { // Whether to populate map in params. bool populate_mapping_{true}; // for simplicity, we make these fields public so the user can access them. - Map shape_var_map_; - Map var_map_; + ffi::Map shape_var_map_; + ffi::Map var_map_; using StructInfoBaseChecker::ShapeMatchCheck; @@ -889,7 +891,7 @@ class CallRetStructInfoDeriver : public StructInfoBaseChecker { } if (auto* ptr = param.as()) { - auto var = GetRef(ptr); + auto var = ffi::GetRef(ptr); auto it = shape_var_map_.find(var); // not populated if (it == shape_var_map_.end()) { @@ -916,7 +918,7 @@ class CallRetStructInfoDeriver : public StructInfoBaseChecker { } if (auto* ptr = lhs.as()) { - auto var = GetRef(ptr); + auto var = ffi::GetRef(ptr); auto it = var_map_.find(var); // not populated if (it == var_map_.end()) { @@ -936,8 +938,8 @@ class CallRetStructInfoDeriver : public StructInfoBaseChecker { return ShapeMatchCheck(lhs_shape->values, rhs_shape->values); } - BaseCheckResult FuncParamsCheck(const Array& lhs, - const Array& rhs) final { + BaseCheckResult FuncParamsCheck(const ffi::Array& lhs, + const ffi::Array& rhs) final { // Set populate mapping to false // so we do not pick up symbolic vars in params with function type. // @@ -990,7 +992,7 @@ class StructInfoLCAFinder // Object is based of everything, unify to object. StructInfo VisitStructInfo_(const ObjectStructInfoNode* lhs, const StructInfo& other) final { - return GetRef(lhs); + return ffi::GetRef(lhs); } StructInfo VisitStructInfo_(const PrimStructInfoNode* lhs, const StructInfo& other) final { @@ -1008,13 +1010,13 @@ class StructInfoLCAFinder if (!lhs->value.defined()) { // If the mismatch was due to extra information in the RHS, // prefer to avoid constructing a new object. 
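// ----------------------------------------------------------------------------
// Sketch (annotation, not part of the patch): the reuse idiom repeated through
// StructInfoLCAFinder below. `same_as` is pointer identity, not structural
// equality; when a rewrite leaves a field untouched, the visitor returns
// ffi::GetRef(lhs) instead of allocating a fresh node, so unchanged subtrees
// stay shared. `DemoReuse` is a hypothetical free function.
#include <tvm/ffi/cast.h>
#include <tvm/relax/struct_info.h>

inline tvm::relax::StructInfo DemoReuse(
    const tvm::relax::TupleStructInfoNode* lhs,
    tvm::ffi::Array<tvm::relax::StructInfo> new_fields) {
  if (new_fields.same_as(lhs->fields)) {
    return tvm::ffi::GetRef<tvm::relax::StructInfo>(lhs);  // no allocation
  }
  return tvm::relax::TupleStructInfo(new_fields, lhs->span);
}
// ----------------------------------------------------------------------------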
- return GetRef(lhs); + return ffi::GetRef(lhs); } else { return PrimStructInfo(lhs->dtype, lhs->span); } } - return GetRef(lhs); + return ffi::GetRef(lhs); } StructInfo VisitStructInfo_(const ShapeStructInfoNode* lhs, const StructInfo& other) final { @@ -1026,13 +1028,13 @@ class StructInfoLCAFinder !CanProveShapeEqual(lhs->values.value(), rhs->values.value(), analyzer_)) { // prefers return same when possible if (!lhs->values.defined() && lhs->ndim == ndim) { - return GetRef(lhs); + return ffi::GetRef(lhs); } else { return ShapeStructInfo(ndim, lhs->span); } } // equals to each other - return GetRef(lhs); + return ffi::GetRef(lhs); } StructInfo VisitStructInfo_(const TensorStructInfoNode* lhs, const StructInfo& other) final { @@ -1054,7 +1056,7 @@ class StructInfoLCAFinder // reuse lhs when possible if (!lhs->shape.defined() && lhs->dtype == dtype && lhs->ndim == ndim && (!lhs->vdevice.defined() || vdev.defined())) { - return GetRef(lhs); + return ffi::GetRef(lhs); } else { return TensorStructInfo(dtype, ndim, vdev, lhs->span); } @@ -1063,14 +1065,14 @@ class StructInfoLCAFinder if (lhs->dtype != dtype || (lhs->vdevice.defined() && !vdev.defined())) { return TensorStructInfo(lhs->shape.value(), dtype, vdev, lhs->span); } else { - return GetRef(lhs); + return ffi::GetRef(lhs); } } StructInfo VisitStructInfo_(const TupleStructInfoNode* lhs, const StructInfo& other) final { auto* rhs = other.as(); if (rhs == nullptr) return ObjectStructInfo(lhs->span); - Optional> fields = UnifyArray(lhs->fields, rhs->fields); + ffi::Optional> fields = UnifyArray(lhs->fields, rhs->fields); // tuple length not the same. if (!fields.defined()) return ObjectStructInfo(lhs->span); @@ -1078,7 +1080,7 @@ class StructInfoLCAFinder if (!fields.same_as(lhs->fields)) { return TupleStructInfo(fields.value(), lhs->span); } else { - return GetRef(lhs); + return ffi::GetRef(lhs); } } @@ -1093,7 +1095,7 @@ class StructInfoLCAFinder if (lhs->IsOpaque()) { if (lhs->derive_func.defined()) { if (lhs->derive_func.same_as(rhs->derive_func)) { - return GetRef(lhs); + return ffi::GetRef(lhs); } else { // Create a new opaque with object return return FuncStructInfo::OpaqueFunc(ObjectStructInfo(), purity, lhs->span); @@ -1101,7 +1103,7 @@ class StructInfoLCAFinder } else { // no derivation function, only depends on ret StructInfo ret = this->VisitStructInfo(lhs->ret, rhs->ret); - if (ret.same_as(lhs->ret)) return GetRef(lhs); + if (ret.same_as(lhs->ret)) return ffi::GetRef(lhs); return FuncStructInfo::OpaqueFunc(ret, purity, lhs->span); } } @@ -1128,15 +1130,15 @@ class StructInfoLCAFinder // // Given we only do best effort checking in these cases, and such cases // are likely not a primary concern atm, we take this approach here. 
- if (struct_equal_(GetRef(lhs), GetRef(rhs))) { - return GetRef(lhs); + if (struct_equal_(ffi::GetRef(lhs), ffi::GetRef(rhs))) { + return ffi::GetRef(lhs); } auto params = UnifyArray(lhs->params.value(), rhs->params.value()); auto ret = this->VisitStructInfo(lhs->ret, rhs->ret); if (params.same_as(lhs->params) && ret.same_as(lhs->ret)) { - return GetRef(lhs); + return ffi::GetRef(lhs); } else { // fail to unify the params if (!params.defined()) { @@ -1154,8 +1156,8 @@ class StructInfoLCAFinder StructuralEqual struct_equal_; // check arrays - Optional> UnifyArray(const Array& lhs, - const Array& rhs) { + ffi::Optional> UnifyArray(const ffi::Array& lhs, + const ffi::Array& rhs) { if (lhs.same_as(rhs)) return lhs; if (lhs.size() != rhs.size()) return std::nullopt; size_t index = 0; @@ -1191,7 +1193,7 @@ class TIRVarsDetector : public StructInfoVisitor { }; TIRVarsDetector(VarType collection_type) : collection_type(collection_type) {} - Array GetTIRVars() const { return tir_vars_; } + ffi::Array GetTIRVars() const { return tir_vars_; } private: void VisitPrimExpr(PrimExpr expr) { @@ -1208,7 +1210,7 @@ class TIRVarsDetector : public StructInfoVisitor { } } - void VisitShape(Array shape) { + void VisitShape(ffi::Array shape) { for (const PrimExpr& expr : shape) { VisitPrimExpr(expr); } @@ -1239,19 +1241,19 @@ class TIRVarsDetector : public StructInfoVisitor { } } - Array tir_vars_; + ffi::Array tir_vars_; std::unordered_set used_tir_vars_dedup_; VarType collection_type; }; -Array TIRVarsInStructInfo(const StructInfo& sinfo) { +ffi::Array TIRVarsInStructInfo(const StructInfo& sinfo) { TIRVarsDetector detector(TIRVarsDetector::VarType::Usage); detector(sinfo); return detector.GetTIRVars(); } -Array DefinableTIRVarsInStructInfo(const StructInfo& sinfo) { +ffi::Array DefinableTIRVarsInStructInfo(const StructInfo& sinfo) { TIRVarsDetector detector(TIRVarsDetector::VarType::Definition); detector(sinfo); return detector.GetTIRVars(); @@ -1266,7 +1268,7 @@ TVM_FFI_STATIC_INIT_BLOCK({ class NonNegativeExpressionCollector : relax::StructInfoVisitor { public: - static Array Collect(const StructInfo& sinfo) { + static ffi::Array Collect(const StructInfo& sinfo) { NonNegativeExpressionCollector visitor; visitor(sinfo); return visitor.expressions_; @@ -1298,11 +1300,11 @@ class NonNegativeExpressionCollector : relax::StructInfoVisitor { } } - Array expressions_; + ffi::Array expressions_; std::unordered_set dedup_lookup_; }; -Array CollectNonNegativeExpressions(const StructInfo& sinfo) { +ffi::Array CollectNonNegativeExpressions(const StructInfo& sinfo) { return NonNegativeExpressionCollector::Collect(sinfo); } @@ -1316,18 +1318,19 @@ class SymbolicVarCollector : public relax::ExprVisitor, public relax::StructInfoVisitor, public tir::ExprVisitor { public: - static Array Free(const Expr& expr) { + static ffi::Array Free(const Expr& expr) { SymbolicVarCollector collector; collector.VisitExpr(expr); - Array ret{collector.free_symbolic_var_.begin(), collector.free_symbolic_var_.end()}; + ffi::Array ret{collector.free_symbolic_var_.begin(), + collector.free_symbolic_var_.end()}; return ret; } - static Array Defined(const Expr& expr) { + static ffi::Array Defined(const Expr& expr) { SymbolicVarCollector collector; collector.VisitExpr(expr); - Array ret{collector.defined_symbolic_var_.begin(), - collector.defined_symbolic_var_.end()}; + ffi::Array ret{collector.defined_symbolic_var_.begin(), + collector.defined_symbolic_var_.end()}; return ret; } @@ -1429,7 +1432,7 @@ class SymbolicVarCollector : public 
relax::ExprVisitor, } void VisitExpr_(const tir::VarNode* op) final { - tir::Var var = GetRef(op); + tir::Var var = ffi::GetRef(op); // default mode, check defined. if (defined_symbolic_var_.count(var) == 0) { free_symbolic_var_.insert(var); @@ -1452,10 +1455,10 @@ class SymbolicVarCollector : public relax::ExprVisitor, std::unordered_set free_symbolic_var_; }; -Array DefinedSymbolicVars(const Expr& expr) { +ffi::Array DefinedSymbolicVars(const Expr& expr) { return SymbolicVarCollector::Defined(expr); } -Array FreeSymbolicVars(const Expr& expr) { return SymbolicVarCollector::Free(expr); } +ffi::Array FreeSymbolicVars(const Expr& expr) { return SymbolicVarCollector::Free(expr); } TVM_FFI_STATIC_INIT_BLOCK({ namespace refl = tvm::ffi::reflection; diff --git a/src/relax/analysis/tir_op_pattern_kind.cc b/src/relax/analysis/tir_op_pattern_kind.cc index b6809c0f35bb..0d9e92c17a84 100644 --- a/src/relax/analysis/tir_op_pattern_kind.cc +++ b/src/relax/analysis/tir_op_pattern_kind.cc @@ -35,7 +35,7 @@ class PatternKindAnalyzer : public StmtExprVisitor { public: explicit PatternKindAnalyzer(const tir::PrimFunc& func) { for (const tir::Var& param : func->params) { - Optional param_buf = func->buffer_map.Get(param); + ffi::Optional param_buf = func->buffer_map.Get(param); if (param_buf.defined()) { param_buffers_.insert(param_buf.value()); } @@ -59,12 +59,12 @@ class PatternKindAnalyzer : public StmtExprVisitor { kind_ = kOpaque; return; } - store_ = GetRef(op); + store_ = ffi::GetRef(op); StmtVisitor::VisitStmt_(op); } void VisitExpr_(const BufferLoadNode* op) final { - loads_.push_back(GetRef(op)); + loads_.push_back(ffi::GetRef(op)); ExprVisitor::VisitExpr_(op); } @@ -130,7 +130,7 @@ class PatternKindAnalyzer : public StmtExprVisitor { // Step 4. Checking if the block contains reduce axis by looking into block iterators. bool has_reduction = false; - Array reduce_vars; + ffi::Array reduce_vars; for (const IterVar& it : op->iter_vars) { if (it->iter_type == tir::IterVarType::kCommReduce) { has_reduction = true; @@ -162,7 +162,7 @@ class PatternKindAnalyzer : public StmtExprVisitor { /********** Helper Functions **********/ /*! \brief Checking if two arrays contains same elements. */ - static bool IsSameArray(const Array& lhs, const Array& rhs) { + static bool IsSameArray(const ffi::Array& lhs, const ffi::Array& rhs) { if (lhs.size() != rhs.size()) { return false; } @@ -293,8 +293,9 @@ class PatternKindAnalyzer : public StmtExprVisitor { if (!lhs || !rhs) { return false; } - return IsAllowReusePattern(GetRef(store), GetRef(lhs)) && - IsAllowReusePattern(GetRef(store), GetRef(rhs)); + return IsAllowReusePattern(ffi::GetRef(store), + ffi::GetRef(lhs)) && + IsAllowReusePattern(ffi::GetRef(store), ffi::GetRef(rhs)); } } } @@ -308,7 +309,7 @@ class PatternKindAnalyzer : public StmtExprVisitor { * A[i] = sum(B[i, j + k]) is not pure reduce * pooling is not pure reduce */ - static bool IsPureReducePattern(Array reduce_loops, Array indices) { + static bool IsPureReducePattern(ffi::Array reduce_loops, ffi::Array indices) { for (const PrimExpr& e : indices) { int id = -1; if (UsesVar(e, [&](const tir::VarNode* var) { @@ -333,9 +334,9 @@ class PatternKindAnalyzer : public StmtExprVisitor { * \brief The BufferStore node in the current block. * \note We only support one BufferStore node in a block (usually generated by TE compute) */ - Optional store_; + ffi::Optional store_; /*! \brief The BufferLoad nodes in the current block. */ - Array loads_; + ffi::Array loads_; /*! \brief The result of op pattern. 
*/ OpPatternKind kind_ = kElemWise; /*! \brief The buffers from function params. I.e. the input and output buffers. */ @@ -379,8 +380,8 @@ bool HasReshapePattern(const PrimFunc& func) { // binding values. The mapping will be used in the substitution of // the flattened buffer access index. const Block& block = block_realize->block; - const Array& block_iter = block->iter_vars; - const Array& iter_values = block_realize->iter_values; + const ffi::Array& block_iter = block->iter_vars; + const ffi::Array& iter_values = block_realize->iter_values; ICHECK_EQ(block_iter.size(), iter_values.size()); int n_iter = block_iter.size(); for (int i = 0; i < n_iter; ++i) { @@ -401,7 +402,7 @@ bool HasReshapePattern(const PrimFunc& func) { return; } - Map var_range; + ffi::Map var_range; for (const IterVar& v : block->iter_vars) { ana_.Bind(v->var, Range::FromMinExtent(v->dom->min, v->dom->extent)); var_range.Set(v->var, Range::FromMinExtent(v->dom->min, v->dom->extent)); @@ -429,7 +430,7 @@ bool HasReshapePattern(const PrimFunc& func) { // This check requires at least one of the src/dst side is a trivial buffer // access (e.g., buf[ax0, ax1, ax2]). - auto f_calc_flattened_idx = [&](const Buffer& buffer, const Array& indices) { + auto f_calc_flattened_idx = [&](const Buffer& buffer, const ffi::Array& indices) { ICHECK_EQ(indices.size(), buffer->shape.size()); int ndim = indices.size(); PrimExpr idx = 0; @@ -447,7 +448,7 @@ bool HasReshapePattern(const PrimFunc& func) { }; auto f_is_trivial_indices = [block, this](const Buffer& buffer, - const Array& indices) { + const ffi::Array& indices) { if (indices.size() != block->iter_vars.size()) { return false; } @@ -462,7 +463,7 @@ bool HasReshapePattern(const PrimFunc& func) { return true; }; - Array nontrivial_indices{nullptr}; + ffi::Array nontrivial_indices{nullptr}; Buffer nontrivial_buffer{nullptr}; if (f_is_trivial_indices(dst_buffer_, buffer_store->indices)) { nontrivial_indices = buffer_load->indices; @@ -476,7 +477,7 @@ bool HasReshapePattern(const PrimFunc& func) { DataType dtype = !block->iter_vars.empty() ? 
block->iter_vars[0]->var->dtype : DataType::Int(64); tir::Var fused_var("fused", dtype); - Map inverse_indices_map; + ffi::Map inverse_indices_map; PrimExpr stride = IntImm(dtype, /*value=*/1); for (int i = static_cast(block->iter_vars.size()) - 1; i >= 0; --i) { inverse_indices_map.Set( @@ -487,7 +488,7 @@ bool HasReshapePattern(const PrimFunc& func) { PrimExpr flattened_idx = f_calc_flattened_idx(nontrivial_buffer, nontrivial_indices); flattened_idx = Substitute(std::move(flattened_idx), inverse_indices_map); - Array simplify_res = arith::IterMapSimplify( + ffi::Array simplify_res = arith::IterMapSimplify( /*indices=*/{flattened_idx}, /*input_iters=*/{{fused_var, Range(IntImm(dtype, /*value=*/0), stride)}}, /*input_pred=*/Bool(true), @@ -519,7 +520,7 @@ bool HasReshapePattern(const PrimFunc& func) { arith::Analyzer ana_; }; - Array buffer_args; + ffi::Array buffer_args; for (const auto& param : func->params) { if (auto buffer = func->buffer_map.Get(param)) { buffer_args.push_back(buffer.value()); diff --git a/src/relax/analysis/udchain.cc b/src/relax/analysis/udchain.cc index 6ec8dcfb5769..0045753ff619 100644 --- a/src/relax/analysis/udchain.cc +++ b/src/relax/analysis/udchain.cc @@ -44,23 +44,23 @@ class UDChain : relax::ExprVisitor { UDChain visitor; visitor.VisitExpr(expr); - Array output(visitor.outputs.begin(), visitor.outputs.end()); + ffi::Array output(visitor.outputs.begin(), visitor.outputs.end()); - Map> use_def; + ffi::Map> use_def; for (const auto& [var, usage] : visitor.usage_map) { - use_def.Set(var, Array(usage.begin(), usage.end())); + use_def.Set(var, ffi::Array(usage.begin(), usage.end())); } return VarUsageInfo{visitor.bound_values, use_def, output}; } private: - Map bound_values; + ffi::Map bound_values; std::unordered_set forward_declarations; std::unordered_map> usage_map; support::OrderedSet outputs; - Optional cur_user_; + ffi::Optional cur_user_; void VisitBinding_(const VarBindingNode* binding) override { CHECK(!bound_values.count(binding->var)) @@ -89,7 +89,7 @@ class UDChain : relax::ExprVisitor { } } void VisitExpr_(const VarNode* op) override { - auto var = GetRef(op); + auto var = ffi::GetRef(op); if (cur_user_) { usage_map[var].insert(cur_user_.value()); @@ -109,13 +109,13 @@ class UDChain : relax::ExprVisitor { } }; -std::pair>, Array> FunctionUseDef(const Expr& fn) { +std::pair>, ffi::Array> FunctionUseDef(const Expr& fn) { auto usage = UDChain::Collect(fn); return {usage.downstream_usage, usage.outputs}; } -Map> DataflowBlockUseDef(const DataflowBlock& dfb) { - auto usage = UDChain::Collect(SeqExpr({dfb}, Tuple(Array()))); +ffi::Map> DataflowBlockUseDef(const DataflowBlock& dfb) { + auto usage = UDChain::Collect(SeqExpr({dfb}, Tuple(ffi::Array()))); return usage.downstream_usage; } diff --git a/src/relax/analysis/var2value.cc b/src/relax/analysis/var2value.cc index 1f28ba9edbf7..3a8a5c0ce80a 100644 --- a/src/relax/analysis/var2value.cc +++ b/src/relax/analysis/var2value.cc @@ -26,7 +26,7 @@ namespace tvm { namespace relax { class Var2ValAnalysis : public relax::ExprVisitor { public: - Map var2value_; + ffi::Map var2value_; void VisitBinding_(const VarBindingNode* binding) override { var2value_.Set(binding->var, binding->value); // Recursively visit the value to handle local functions. 
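Note: after this migration the analysis entry points keep their call patterns; only the container spellings change. A minimal, illustrative sketch of driving the Var2Val analysis (here `fn` is a placeholder for a relax::Function obtained elsewhere):

    // Illustrative only: query the binding map produced by AnalyzeVar2Value.
    tvm::ffi::Map<Var, Expr> var2val = AnalyzeVar2Value(fn);
    for (const auto& [var, value] : var2val) {
      LOG(INFO) << var->name_hint() << " is bound to " << value;
    }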
@@ -34,25 +34,25 @@ class Var2ValAnalysis : public relax::ExprVisitor {
   }
 };

-Map<Var, Expr> AnalyzeVar2Value(const Expr& expr) {
+ffi::Map<Var, Expr> AnalyzeVar2Value(const Expr& expr) {
   Var2ValAnalysis var2val_analysis;
   var2val_analysis.VisitExpr(expr);
   return std::move(var2val_analysis.var2value_);
 }

-Map<Var, Expr> AnalyzeVar2Value(const DataflowBlock& dfb) {
+ffi::Map<Var, Expr> AnalyzeVar2Value(const DataflowBlock& dfb) {
   Var2ValAnalysis var2val_analysis;
   var2val_analysis.VisitBindingBlock_(dfb.get());
   return std::move(var2val_analysis.var2value_);
 }

-Map<Var, Expr> AnalyzeVar2Value(const IRModule& m) {
+ffi::Map<Var, Expr> AnalyzeVar2Value(const IRModule& m) {
   Var2ValAnalysis var2val_analysis;

   for (const auto& it : m->functions) {
     // visit relax.Function
     if (auto* n = it.second.as<FunctionNode>()) {
-      var2val_analysis.VisitExpr(GetRef<Function>(n));
+      var2val_analysis.VisitExpr(ffi::GetRef<Function>(n));
     }
   }
@@ -69,23 +69,24 @@ class Name2BindingAnalysis : public relax::ExprVisitor {
 public:
  // Map is not suitable for doing in-place update.
  // so we use standard container for internal usage.
-  std::map<String, Array<Binding>> name2bindings_;
+  std::map<ffi::String, ffi::Array<Binding>> name2bindings_;
   void VisitBinding_(const VarBindingNode* binding) override {
     const auto& vname = binding->var->name_hint();
-    name2bindings_[vname].push_back(GetRef<VarBinding>(binding));
+    name2bindings_[vname].push_back(ffi::GetRef<VarBinding>(binding));
   }

   void VisitBinding_(const MatchCastNode* binding) override {
     const auto& vname = binding->var->name_hint();
-    name2bindings_[vname].push_back(GetRef<MatchCast>(binding));
+    name2bindings_[vname].push_back(ffi::GetRef<MatchCast>(binding));
   }
 };

-Map<String, Array<Binding>> NameToBinding(const Function& fn) {
+ffi::Map<ffi::String, ffi::Array<Binding>> NameToBinding(const Function& fn) {
   Name2BindingAnalysis analysis{};
   analysis.VisitExpr_(fn.get());
-  return Map<String, Array<Binding>>(std::make_move_iterator(analysis.name2bindings_.begin()),
-                                     std::make_move_iterator(analysis.name2bindings_.end()));
+  return ffi::Map<ffi::String, ffi::Array<Binding>>(
+      std::make_move_iterator(analysis.name2bindings_.begin()),
+      std::make_move_iterator(analysis.name2bindings_.end()));
 }

 TVM_FFI_STATIC_INIT_BLOCK({
diff --git a/src/relax/analysis/well_formed.cc b/src/relax/analysis/well_formed.cc
index a1bc99ee75bf..14694b31f4da 100644
--- a/src/relax/analysis/well_formed.cc
+++ b/src/relax/analysis/well_formed.cc
@@ -86,7 +86,7 @@ class WellFormedChecker : public relax::ExprVisitor,
                           public relax::StructInfoVisitor,
                           public tir::ExprVisitor {
 public:
-  static bool Check(Variant<IRModule, Function> obj, bool check_struct_info) {
+  static bool Check(ffi::Variant<IRModule, Function> obj, bool check_struct_info) {
     WellFormedChecker well_formed_checker =
         WellFormedChecker(obj.as<IRModule>(), check_struct_info);
@@ -94,13 +94,13 @@ class WellFormedChecker : public relax::ExprVisitor,
       for (const auto& it : mod->functions) {
         // visit relax.Function
         if (auto* n = it.second.as<FunctionNode>()) {
-          Function func = GetRef<Function>(n);
+          Function func = ffi::GetRef<Function>(n);
           well_formed_checker.CheckGlobalVarAndGsymbolConsistency(it.first, func);
           well_formed_checker.VisitExpr(func);
         }
       }
     } else if (const auto* func = obj.as<FunctionNode>()) {
-      well_formed_checker.VisitExpr(GetRef<Function>(func));
+      well_formed_checker.VisitExpr(ffi::GetRef<Function>(func));
     } else {
       LOG(FATAL) << "Unreachable, "
                  << "variant did not contain any of the allowed types";
@@ -109,7 +109,7 @@
   }

 private:
-  WellFormedChecker(Optional<IRModule> mod, bool check_struct_info)
+  WellFormedChecker(ffi::Optional<IRModule> mod, bool check_struct_info)
       : mod_(std::move(mod)), check_struct_info_(check_struct_info), cur_visited_func_(nullptr) {}

   using relax::ExprVisitor::VisitExpr_;
@@ -139,7 +139,7 @@
     // to check again
     // check name in global var
and gsymbol - Optional gsymbol = func->GetAttr(tvm::attr::kGlobalSymbol); + ffi::Optional gsymbol = func->GetAttr(tvm::attr::kGlobalSymbol); if (gsymbol.has_value() && gsymbol != var->name_hint) { Malformed(Diagnostic::Error(func->span) << "Name in GlobalVar is not equal to name in gsymbol: " << var @@ -155,18 +155,20 @@ class WellFormedChecker : public relax::ExprVisitor, } void VisitExpr_(const GlobalVarNode* op) final { - GlobalVar var = GetRef(op); + GlobalVar var = ffi::GetRef(op); if (mod_.defined()) { if (!(mod_.value()->ContainGlobalVar(var->name_hint) && mod_.value()->GetGlobalVar(var->name_hint).same_as(var))) { - Malformed(Diagnostic::Error(var) << "GlobalVar " << GetRef(op) << " is not defined."); + Malformed(Diagnostic::Error(var) + << "GlobalVar " << ffi::GetRef(op) << " is not defined."); } } if (op->struct_info_.defined()) { if (!op->struct_info_->IsInstance()) { - Malformed(Diagnostic::Error(var) << "The struct_info_ of GlobalVar " << GetRef(op) - << " must be either FuncStructInfo."); + Malformed(Diagnostic::Error(var) + << "The struct_info_ of GlobalVar " << ffi::GetRef(op) + << " must be either FuncStructInfo."); } } @@ -198,21 +200,22 @@ class WellFormedChecker : public relax::ExprVisitor, } void VisitExpr_(const VarNode* op) final { - Var var = GetRef(op); + Var var = ffi::GetRef(op); if (var_set_.count(var) == 0 && recur_vars_.count(var) == 0) { - Malformed(Diagnostic::Error(var) << "Var " << GetRef(op) << " is not defined."); + Malformed(Diagnostic::Error(var) << "Var " << ffi::GetRef(op) << " is not defined."); } CheckStructInfo(op); } void VisitExpr_(const DataflowVarNode* op) final { - DataflowVar var = GetRef(op); + DataflowVar var = ffi::GetRef(op); if (!is_dataflow_) { Malformed(Diagnostic::Error(var) - << "DataflowVar " << GetRef(op) << " is used outside DataflowBlock."); + << "DataflowVar " << ffi::GetRef(op) << " is used outside DataflowBlock."); } if (dataflow_var_set_.count(var) == 0) { - Malformed(Diagnostic::Error(var) << "DataflowVar " << GetRef(op) << " is not defined."); + Malformed(Diagnostic::Error(var) + << "DataflowVar " << ffi::GetRef(op) << " is not defined."); } CheckStructInfo(op); } @@ -244,8 +247,8 @@ class WellFormedChecker : public relax::ExprVisitor, // ensure the purity attributes are valid if (op->GetAttr(relax::attr::kForcePure).value_or(false) && !op->is_pure) { Malformed(Diagnostic::Error(op->span) - << "Function " << GetRef(op) << " has true for " << relax::attr::kForcePure - << " but false for is_pure; " << relax::attr::kForcePure + << "Function " << ffi::GetRef(op) << " has true for " + << relax::attr::kForcePure << " but false for is_pure; " << relax::attr::kForcePure << " should be true only if is_pure is also true."); } @@ -318,7 +321,7 @@ class WellFormedChecker : public relax::ExprVisitor, CheckStructInfo(call); if (is_dataflow_ && check_struct_info_) { - if (auto impure = FindImpureCall(GetRef(call))) { + if (auto impure = FindImpureCall(ffi::GetRef(call))) { Malformed(Diagnostic::Error(call) << "Impure function call " << impure << " occurs within a dataflow block."); } @@ -331,8 +334,8 @@ class WellFormedChecker : public relax::ExprVisitor, if (auto func_normalize = op_map_normalize_.get(call->op, nullptr); func_normalize != nullptr) { auto dummy_builder = tvm::relax::BlockBuilder::Create(mod_); - Call before_normalize = GetRef(call); - Optional after_normalize = std::nullopt; + Call before_normalize = ffi::GetRef(call); + ffi::Optional after_normalize = std::nullopt; try { after_normalize = func_normalize(dummy_builder, 
before_normalize); } catch (std::exception& err) { @@ -355,7 +358,7 @@ class WellFormedChecker : public relax::ExprVisitor, if (auto func_validate = op_map_validate_.get(call->op, nullptr); func_validate != nullptr) { try { - func_validate(GetRef(call)); + func_validate(ffi::GetRef(call)); } catch (std::exception& err) { Malformed(Diagnostic::Error(call) << "Operator-specific validation (FValidate) for " << call->op << " identified error: \n" @@ -369,13 +372,13 @@ class WellFormedChecker : public relax::ExprVisitor, // an expression that does not yet have `StructInfo`. auto dummy_builder = tvm::relax::BlockBuilder::Create(mod_); Call copied(call->op, call->args, call->attrs, call->sinfo_args); - Optional normalized = std::nullopt; + ffi::Optional normalized = std::nullopt; try { normalized = dummy_builder->Normalize(copied); } catch (std::exception& err) { Malformed(Diagnostic::Error(call) << "Each Relax expression must be able to have its StructInfo inferred. " - << "However, inferring the struct info of expression " << GetRef(call) + << "However, inferring the struct info of expression " << ffi::GetRef(call) << " resulted in the error: \n" << err.what()); } @@ -400,8 +403,9 @@ class WellFormedChecker : public relax::ExprVisitor, BaseCheckResult::kFailL1) { Malformed(Diagnostic::Error(call) << "All information in StructInfo annotations must be correct. " - << "However, while the expression " << GetRef(call) << " is annotated as " - << current_struct_info << ", the expression outputs " << inferred_struct_info); + << "However, while the expression " << ffi::GetRef(call) + << " is annotated as " << current_struct_info << ", the expression outputs " + << inferred_struct_info); } } } @@ -513,7 +517,7 @@ class WellFormedChecker : public relax::ExprVisitor, Malformed(Diagnostic::Error(var) << "DataflowVar " << var << " is defined outside DataflowBlock."); } - DataflowVar lv = GetRef(var); + DataflowVar lv = ffi::GetRef(var); if (dataflow_var_set_.count(lv) == 1) { Malformed(Diagnostic::Error(var) << "DataflowVar " << lv << " is defined more than once."); } @@ -523,7 +527,7 @@ class WellFormedChecker : public relax::ExprVisitor, } void VisitVarDef_(const VarNode* var) final { - Var gv = GetRef(var); + Var gv = ffi::GetRef(var); if (var_set_.count(gv) == 1) { Malformed(Diagnostic::Error(var) << "Var " << gv << " is defined more than once."); } @@ -533,7 +537,7 @@ class WellFormedChecker : public relax::ExprVisitor, } void VisitExpr_(const tir::VarNode* op) final { - tir::Var var = GetRef(op); + tir::Var var = ffi::GetRef(op); // default mode, check defined. 
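    // (In kMatchVarDef mode a first occurrence would instead be recorded as a
    // definition; here, in the default mode, an undefined symbolic var is malformed.)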
if (symbolic_var_set_.count(var) == 0) { this->Malformed(Diagnostic::Error(var) << "Symbolic Var " << var << " is not defined."); @@ -571,7 +575,7 @@ class WellFormedChecker : public relax::ExprVisitor, if (mode_ == VisitMode::kMatchVarDef) { // populate symbolic var in first occurrence if (auto* op = expr.as()) { - auto var = GetRef(op); + auto var = ffi::GetRef(op); if (var_set_.count(var) == 0) { var_set_.insert(var); } @@ -590,7 +594,7 @@ class WellFormedChecker : public relax::ExprVisitor, if (mode_ == VisitMode::kMatchVarDef) { // populate symbolic var in first occurrence if (auto* op = expr.as()) { - auto var = GetRef(op); + auto var = ffi::GetRef(op); if (symbolic_var_set_.count(var) == 0) { symbolic_var_set_.insert(var); } @@ -607,7 +611,7 @@ class WellFormedChecker : public relax::ExprVisitor, auto* sinfo = op->struct_info_.as(); if (sinfo != nullptr) { - this->VisitStructInfo(GetRef(sinfo)); + this->VisitStructInfo(ffi::GetRef(sinfo)); } else { Malformed(Diagnostic::Error(op) << "Expr must have struct_info populated. " << " Expr.type_key=" << op->GetTypeKey()); @@ -622,7 +626,7 @@ class WellFormedChecker : public relax::ExprVisitor, std::swap(mode_, mode); } - Optional mod_; + ffi::Optional mod_; const bool check_struct_info_; bool well_formed_ = true; bool is_dataflow_; @@ -642,7 +646,7 @@ class WellFormedChecker : public relax::ExprVisitor, tvm::OpAttrMap op_map_validate_ = Op::GetAttrMap("FValidate"); }; -bool WellFormed(Variant obj, bool check_struct_info) { +bool WellFormed(ffi::Variant obj, bool check_struct_info) { return WellFormedChecker::Check(obj, check_struct_info); } diff --git a/src/relax/backend/contrib/clml/codegen.cc b/src/relax/backend/contrib/clml/codegen.cc index b25bfbdb22a7..8103d2a3140d 100644 --- a/src/relax/backend/contrib/clml/codegen.cc +++ b/src/relax/backend/contrib/clml/codegen.cc @@ -113,7 +113,8 @@ class CollectCLMLFromCompositeFunctionBody : public ExprVisitor { */ class OpenCLMLJSONSerializer : public JSONSerializer { public: - explicit OpenCLMLJSONSerializer(Map constant_names, Map bindings) + explicit OpenCLMLJSONSerializer(ffi::Map constant_names, + ffi::Map bindings) : JSONSerializer(constant_names), bindings_(bindings) {} /*! @@ -135,9 +136,9 @@ class OpenCLMLJSONSerializer : public JSONSerializer { // The call must be to an inline "Composite" function const auto* fn_var = call_node->op.as(); ICHECK(fn_var); - const auto fn = Downcast(bindings_[GetRef(fn_var)]); + const auto fn = Downcast(bindings_[ffi::GetRef(fn_var)]); - auto opt_composite = fn->GetAttr(attr::kComposite); + auto opt_composite = fn->GetAttr(attr::kComposite); ICHECK(opt_composite.has_value()); std::string name = opt_composite.value(); @@ -177,7 +178,7 @@ class OpenCLMLJSONSerializer : public JSONSerializer { VLOG(1) << name << " has " << node->GetInputs().size() << " inputs"; } - return AddNode(node, GetRef(call_node)); + return AddNode(node, ffi::GetRef(call_node)); } /*! 
@@ -191,8 +192,8 @@ class OpenCLMLJSONSerializer : public JSONSerializer { const auto* fn_var = cn->op.as(); ICHECK(fn_var); - const auto fn = Downcast(bindings_[GetRef(fn_var)]); - auto opt_composite = fn->GetAttr(attr::kComposite); + const auto fn = Downcast(bindings_[ffi::GetRef(fn_var)]); + auto opt_composite = fn->GetAttr(attr::kComposite); ICHECK(opt_composite.has_value()); nodes.pad = backend::TryGetOpInFunction(fn, "relax.nn.pad"); @@ -220,8 +221,8 @@ class OpenCLMLJSONSerializer : public JSONSerializer { const auto* fn_var = cn->op.as(); ICHECK(fn_var); - const auto fn = Downcast(bindings_[GetRef(fn_var)]); - auto opt_composite = fn->GetAttr(attr::kComposite); + const auto fn = Downcast(bindings_[ffi::GetRef(fn_var)]); + auto opt_composite = fn->GetAttr(attr::kComposite); ICHECK(opt_composite.has_value()); std::string name = opt_composite.value(); @@ -292,11 +293,11 @@ class OpenCLMLJSONSerializer : public JSONSerializer { private: /*! \brief The bindings to look up composite functions. */ - Map bindings_; + ffi::Map bindings_; }; void CollectCLMLFromCompositeFunctionBody::VisitExpr_(const ConstantNode* constant_node) { - for (const auto& entry : serializer_->VisitExpr(GetRef(constant_node))) { + for (const auto& entry : serializer_->VisitExpr(ffi::GetRef(constant_node))) { args_.emplace_back(entry); } } @@ -311,9 +312,10 @@ void CollectCLMLFromCompositeFunctionBody::VisitExpr_(const CallNode* call_node) * \param functions The extern functions to be compiled via OpenCLML * \return Runtime modules. */ -Array OpenCLMLCompiler(Array functions, Map /*unused*/, - Map constant_names) { - Array compiled_functions; +ffi::Array OpenCLMLCompiler(ffi::Array functions, + ffi::Map /*unused*/, + ffi::Map constant_names) { + ffi::Array compiled_functions; for (const auto& func : functions) { VLOG(1) << "OpenCLML partition:" << std::endl << func; OpenCLMLJSONSerializer serializer(constant_names, AnalyzeVar2Value(func)); diff --git a/src/relax/backend/contrib/codegen_c/codegen_c.h b/src/relax/backend/contrib/codegen_c/codegen_c.h index 611e63de8954..3c6469423890 100644 --- a/src/relax/backend/contrib/codegen_c/codegen_c.h +++ b/src/relax/backend/contrib/codegen_c/codegen_c.h @@ -47,7 +47,7 @@ struct GenerateBodyOutput { std::string decl; std::vector buffers; std::vector outputs; - Array headers; + ffi::Array headers; }; // The base class to generate the declaration functions in C. @@ -115,7 +115,7 @@ class CodegenCBase { * * \code * - * Array foo_consts; + * ffi::Array foo_consts; * * // An example code for the generated C function. * int foo_wrapper_(DLTensor* arg0, @@ -129,7 +129,7 @@ class CodegenCBase { * * TVM_FFI_DLL_EXPORT_TYPED_FUNC(foo, foo_wrapper_); * - * int foo_init_wrapper_(Array arr) { + * int foo_init_wrapper_(ffi::Array arr) { * foo_consts = arr; * return 0; * } @@ -220,7 +220,7 @@ class CodegenCBase { // codegen. Moreover, in microTVM we dont expect this part to be generated. 
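      // The "#ifdef __cplusplus" guard emitted below keeps the tvm::ffi::Array-typed
      // const-array init wrapper out of plain-C translation units.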
code_stream_ << "#ifdef __cplusplus\n"; code_stream_ << "int " << func_name - << "_init_wrapper_(tvm::Array arr) {\n"; + << "_init_wrapper_(tvm::ffi::Array arr) {\n"; EnterScope(); PrintIndents(); code_stream_ << func_name << "_consts = arr;\n"; @@ -233,7 +233,7 @@ class CodegenCBase { } } - void GenerateBackendCFunc(const std::string& func_name, const Array& args, + void GenerateBackendCFunc(const std::string& func_name, const ffi::Array& args, const std::string& const_arr_name, const std::vector& outs, bool pass_dl_tensor = false) { std::vector arg_types; @@ -266,7 +266,7 @@ class CodegenCBase { * * \return The emitted code string. */ - std::string JitImpl(const std::string& ext_func_id, const Array& args, + std::string JitImpl(const std::string& ext_func_id, const ffi::Array& args, const std::vector& buf_decl, const std::vector& body, const std::string& const_arr_name, const std::vector& outs) { @@ -390,7 +390,7 @@ class CodegenCBase { * \return The created declaration */ std::string CreateTensorPool(const std::string& symbol) const { - return "tvm::Array " + symbol + "_consts;"; + return "tvm::ffi::Array " + symbol + "_consts;"; } /*! diff --git a/src/relax/backend/contrib/codegen_json/codegen_json.h b/src/relax/backend/contrib/codegen_json/codegen_json.h index 1ea03a63c0dc..505696254209 100644 --- a/src/relax/backend/contrib/codegen_json/codegen_json.h +++ b/src/relax/backend/contrib/codegen_json/codegen_json.h @@ -87,7 +87,7 @@ class OpAttrExtractor { void Visit(const char* key, std::string* value) { SetNodeAttr(key, {*value}); } - void Visit(const char* key, Optional* value) { + void Visit(const char* key, ffi::Optional* value) { if (value->has_value()) { SetNodeAttr(key, {Fp2String(value->value())}); } else { @@ -95,7 +95,7 @@ class OpAttrExtractor { } } - void Visit(const char* key, Optional* value) { + void Visit(const char* key, ffi::Optional* value) { if (value->has_value()) { SetNodeAttr(key, {std::to_string(value->value())}); } else { @@ -119,7 +119,7 @@ class OpAttrExtractor { attr.push_back(std::to_string(im->value)); } else if (const auto* fm = (*an)[i].as()) { attr.push_back(Fp2String(fm->value)); - } else if (auto opt_str = (*an)[i].as()) { + } else if (auto opt_str = (*an)[i].as()) { attr.push_back(*opt_str); } else { LOG(FATAL) << "Not supported type: " << (*an)[i].GetTypeKey(); @@ -201,7 +201,7 @@ class JSONSerializer : public relax::MemoizedExprTranslator { * \brief Constructor * \param constant_names The names of all constants in the original module. */ - explicit JSONSerializer(const Map& constant_names) + explicit JSONSerializer(const ffi::Map& constant_names) : constant_names_(constant_names) {} void serialize(Function func) { @@ -214,7 +214,7 @@ class JSONSerializer : public relax::MemoizedExprTranslator { } /*!\brief Return the required constants. */ - Array GetConstantNames() const { return constants_used_; } + ffi::Array GetConstantNames() const { return constants_used_; } /*!\brief Return the generated json. 
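   *
   * A minimal usage sketch (illustrative only; `func` and `names` stand in for values
   * produced by the surrounding codegen):
   *
   * \code
   * JSONSerializer serializer(names);  // names: the constant-name map
   * serializer.serialize(func);
   * std::string graph = serializer.GetJSON();
   * \endcode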
*/ std::string GetJSON() { @@ -284,7 +284,7 @@ class JSONSerializer : public relax::MemoizedExprTranslator { extractor.Extract(const_cast(call_attr)); } else if (const auto* fn = cn->op.as()) { ICHECK(false); - auto pattern = fn->GetAttr(attr::kPartitionedFromPattern); + auto pattern = fn->GetAttr(attr::kPartitionedFromPattern); ICHECK(pattern.has_value()); std::vector values; values.push_back(pattern.value()); @@ -361,12 +361,12 @@ class JSONSerializer : public relax::MemoizedExprTranslator { } NodeEntries VisitExpr_(const ConstantNode* cn) { - auto name = constant_names_.find(GetRef(cn)); + auto name = constant_names_.find(ffi::GetRef(cn)); ICHECK(name != constant_names_.end()) - << "Cannot find the name of the constant: " << GetRef(cn); + << "Cannot find the name of the constant: " << ffi::GetRef(cn); constants_used_.push_back((*name).second); auto node = std::make_shared((*name).second, "const" /* op_type_ */); - return AddNode(node, GetRef(cn)); + return AddNode(node, ffi::GetRef(cn)); } NodeEntries VisitExpr_(const TupleNode* tn) { @@ -379,12 +379,12 @@ class JSONSerializer : public relax::MemoizedExprTranslator { } NodeEntries VisitExpr_(const CallNode* cn) { - Expr expr = GetRef(cn); + Expr expr = ffi::GetRef(cn); std::string name; if (const auto* op_node = cn->op.as()) { name = op_node->name; } else if (const auto* fn = cn->op.as()) { - auto comp = fn->GetAttr(attr::kComposite); + auto comp = fn->GetAttr(attr::kComposite); ICHECK(comp.has_value()) << "JSON runtime only supports composite functions."; name = comp.value(); } else { @@ -404,7 +404,7 @@ class JSONSerializer : public relax::MemoizedExprTranslator { "kernel", /* op_type_ */ inputs, 1 /* num_outputs_ */); SetCallNodeAttribute(node, cn); - return AddNode(node, GetRef(cn)); + return AddNode(node, ffi::GetRef(cn)); } NodeEntries VisitExpr_(const TupleGetItemNode* gtn) { @@ -413,7 +413,7 @@ class JSONSerializer : public relax::MemoizedExprTranslator { } NodeEntries VisitExpr_(const FunctionNode* fn) { - ICHECK(fn->GetAttr(attr::kComposite).has_value()) + ICHECK(fn->GetAttr(attr::kComposite).has_value()) << "JSON runtime only supports composite functions"; // FunctionNode should be handled by the caller. @@ -453,9 +453,9 @@ class JSONSerializer : public relax::MemoizedExprTranslator { /*! \brief Output of the JSON graph. */ NodeEntries heads_; /*! \brief The list of required constants, ordered. */ - Array constants_used_; + ffi::Array constants_used_; /*! \brief The names of all constants in the original module. 
*/ - const Map& constant_names_; + const ffi::Map& constant_names_; }; } // namespace contrib diff --git a/src/relax/backend/contrib/cublas/codegen.cc b/src/relax/backend/contrib/cublas/codegen.cc index 0cd0150970e6..c403cac30696 100644 --- a/src/relax/backend/contrib/cublas/codegen.cc +++ b/src/relax/backend/contrib/cublas/codegen.cc @@ -41,7 +41,7 @@ using backend::contrib::NodeEntries; class CublasJSONSerializer : public JSONSerializer { public: - CublasJSONSerializer(Map constant_names, Map bindings) + CublasJSONSerializer(ffi::Map constant_names, ffi::Map bindings) : JSONSerializer(constant_names), bindings_(bindings) {} using JSONSerializer::VisitExpr_; @@ -49,10 +49,10 @@ class CublasJSONSerializer : public JSONSerializer { NodeEntries VisitExpr_(const CallNode* call_node) final { const auto* fn_var = call_node->op.as(); ICHECK(fn_var); - const auto fn = Downcast(bindings_[GetRef(fn_var)]); + const auto fn = Downcast(bindings_[ffi::GetRef(fn_var)]); ICHECK(fn.defined()) << "Expects the callee to be a function."; - auto composite_opt = fn->GetAttr(attr::kComposite); + auto composite_opt = fn->GetAttr(attr::kComposite); ICHECK(composite_opt.has_value()) << "Only composite functions are supported."; std::string composite_name = composite_opt.value(); @@ -101,17 +101,18 @@ class CublasJSONSerializer : public JSONSerializer { const CallNode* root_call = backend::GetOpInFunction(fn, "relax.matmul"); SetCallNodeAttribute(node, root_call); - return AddNode(node, GetRef(call_node)); + return AddNode(node, ffi::GetRef(call_node)); } private: /*! \brief The bindings to look up composite functions. */ - Map bindings_; + ffi::Map bindings_; }; -Array CublasCompiler(Array functions, Map /*unused*/, - Map constant_names) { - Array compiled_functions; +ffi::Array CublasCompiler(ffi::Array functions, + ffi::Map /*unused*/, + ffi::Map constant_names) { + ffi::Array compiled_functions; for (const auto& func : functions) { CublasJSONSerializer serializer(constant_names, AnalyzeVar2Value(func)); diff --git a/src/relax/backend/contrib/cudnn/codegen.cc b/src/relax/backend/contrib/cudnn/codegen.cc index a0201ccfda77..b612a9aa3b02 100644 --- a/src/relax/backend/contrib/cudnn/codegen.cc +++ b/src/relax/backend/contrib/cudnn/codegen.cc @@ -40,7 +40,7 @@ using backend::contrib::NodeEntries; class cuDNNJSONSerializer : public JSONSerializer { public: - cuDNNJSONSerializer(Map constant_names, Map bindings) + cuDNNJSONSerializer(ffi::Map constant_names, ffi::Map bindings) : JSONSerializer(constant_names), bindings_(bindings) {} using JSONSerializer::VisitExpr_; @@ -48,10 +48,10 @@ class cuDNNJSONSerializer : public JSONSerializer { NodeEntries VisitExpr_(const CallNode* call_node) final { const auto* fn_var = call_node->op.as(); ICHECK(fn_var); - const auto fn = Downcast(bindings_[GetRef(fn_var)]); + const auto fn = Downcast(bindings_[ffi::GetRef(fn_var)]); ICHECK(fn.defined()) << "Expects the callee to be a function."; - auto composite_opt = fn->GetAttr(attr::kComposite); + auto composite_opt = fn->GetAttr(attr::kComposite); ICHECK(composite_opt.has_value()) << "Only composite functions are supported."; std::string composite_name = composite_opt.value(); @@ -89,7 +89,7 @@ class cuDNNJSONSerializer : public JSONSerializer { const CallNode* root_call = backend::GetOpInFunction(fn, "relax.nn.conv2d"); SetCallNodeAttribute(node, root_call); - return AddNode(node, GetRef(call_node)); + return AddNode(node, ffi::GetRef(call_node)); } NodeEntries HandleAttention(const CallNode* call_node, const Function& fn, @@ -125,17 
+125,18 @@ class cuDNNJSONSerializer : public JSONSerializer { node->SetAttr("head_size", to_str_array(head_size)); node->SetAttr("head_size_v", to_str_array(head_size_v)); node->SetAttr("layout", std::vector{std::vector{layout}}); - return AddNode(node, GetRef(call_node)); + return AddNode(node, ffi::GetRef(call_node)); } private: /*! \brief The bindings to look up composite functions. */ - Map bindings_; + ffi::Map bindings_; }; -Array cuDNNCompiler(Array functions, Map /*unused*/, - Map constant_names) { - Array compiled_functions; +ffi::Array cuDNNCompiler(ffi::Array functions, + ffi::Map /*unused*/, + ffi::Map constant_names) { + ffi::Array compiled_functions; for (const auto& func : functions) { cuDNNJSONSerializer serializer(constant_names, AnalyzeVar2Value(func)); diff --git a/src/relax/backend/contrib/cutlass/codegen.cc b/src/relax/backend/contrib/cutlass/codegen.cc index 29ad2de412d8..dcfcc77b989a 100644 --- a/src/relax/backend/contrib/cutlass/codegen.cc +++ b/src/relax/backend/contrib/cutlass/codegen.cc @@ -55,7 +55,7 @@ std::string EmitSignature(const std::vector& out, const std::string& fun return code_stream_.str(); } -ffi::Module Finalize(const std::string& code, const Array& func_names) { +ffi::Module Finalize(const std::string& code, const ffi::Array& func_names) { ICHECK(!func_names.empty()) << "Should only create CUTLASS CSourceModule if there is at least one CUTLASS partition"; @@ -71,14 +71,14 @@ ffi::Module Finalize(const std::string& code, const Array& func_names) { const auto pf = tvm::ffi::Function::GetGlobalRequired("runtime.CSourceModuleCreate"); VLOG(1) << "Generated CUTLASS code:" << std::endl << code; return pf(default_headers.str() + code, "cu", func_names, - /*const_vars=*/Array()) + /*const_vars=*/ffi::Array()) .cast(); } class CodegenResultNode : public Object { public: - String code; - Array headers; + ffi::String code; + ffi::Array headers; static void RegisterReflection() { namespace refl = tvm::ffi::reflection; @@ -93,8 +93,8 @@ class CodegenResultNode : public Object { class CodegenResult : public ObjectRef { public: - CodegenResult(String code, Array headers) { - auto n = make_object(); + CodegenResult(ffi::String code, ffi::Array headers) { + auto n = ffi::make_object(); n->code = std::move(code); n->headers = std::move(headers); data_ = std::move(n); @@ -107,15 +107,16 @@ TVM_FFI_STATIC_INIT_BLOCK({ CodegenResultNode::RegisterReflection(); }); TVM_FFI_STATIC_INIT_BLOCK({ namespace refl = tvm::ffi::reflection; - refl::GlobalDef().def("contrib.cutlass.CodegenResult", [](String code, Array headers) { - return CodegenResult(code, headers); - }); + refl::GlobalDef().def("contrib.cutlass.CodegenResult", + [](ffi::String code, ffi::Array headers) { + return CodegenResult(code, headers); + }); }); GenerateBodyOutput GenerateBody(const std::string& func_name, const std::string& ext_func_id, const std::vector& output_types, - const Array& func_args, const Map& attrs, - int* buf_idx) { + const ffi::Array& func_args, + const ffi::Map& attrs, int* buf_idx) { // Make function call with input buffers when visiting arguements ICHECK_GT(func_args.size(), 0); std::ostringstream decl_stream; @@ -150,7 +151,7 @@ using OutputType = std::vector; class CodegenCutlass : public relax::MemoizedExprTranslator, public relax::contrib::CodegenCBase { public: - CodegenCutlass(const std::string& id, const Map& bindings) + CodegenCutlass(const std::string& id, const ffi::Map& bindings) : ext_func_id_(id), bindings_(bindings) {} void AddParm(Var param) { @@ -195,7 +196,7 @@ class 
CodegenCutlass : public relax::MemoizedExprTranslator, return code_stream_.str(); } - Array GetHeaders() { return headers_; } + ffi::Array GetHeaders() { return headers_; } protected: OutputType VisitExpr_(const VarNode* node) final { @@ -209,8 +210,8 @@ class CodegenCutlass : public relax::MemoizedExprTranslator, OutputType VisitExpr_(const CallNode* call) final { const auto* fn_var = call->op.as(); ICHECK(fn_var); - const auto func = Downcast(bindings_[GetRef(fn_var)]); - const auto pattern_name_opt = func->GetAttr(attr::kComposite); + const auto func = Downcast(bindings_[ffi::GetRef(fn_var)]); + const auto pattern_name_opt = func->GetAttr(attr::kComposite); ICHECK(pattern_name_opt) << "Only composite function is supported for CUTLASS."; auto ret = GenerateBody(call, pattern_name_opt.value(), func->attrs->dict); ext_func_body_.push_back(ret.decl); @@ -219,7 +220,7 @@ class CodegenCutlass : public relax::MemoizedExprTranslator, } OutputType VisitExpr_(const FunctionNode* fn) final { - ICHECK(fn->GetAttr(attr::kComposite).has_value()) + ICHECK(fn->GetAttr(attr::kComposite).has_value()) << "JSON runtime only supports composite functions"; // FunctionNode should be handled by the caller. return {}; @@ -282,8 +283,8 @@ class CodegenCutlass : public relax::MemoizedExprTranslator, } private: - Array GetArgumentNames(const CallNode* call) { - Array arg_names; + ffi::Array GetArgumentNames(const CallNode* call) { + ffi::Array arg_names; for (size_t i = 0; i < call->args.size(); ++i) { auto res = VisitExpr(call->args[i]); for (const auto& out : res) { @@ -294,9 +295,9 @@ class CodegenCutlass : public relax::MemoizedExprTranslator, } GenerateBodyOutput GenerateBody(const CallNode* call, const std::string& func_name, - const Map& attrs) { + const ffi::Map& attrs) { auto func_args = GetArgumentNames(call); - auto struct_info = GetStructInfo(GetRef(call)); + auto struct_info = GetStructInfo(ffi::GetRef(call)); std::vector out_types; if (const auto* tensor_sinfo = struct_info.as()) { @@ -316,15 +317,15 @@ class CodegenCutlass : public relax::MemoizedExprTranslator, */ int buf_idx_{0}; /*! \brief The arguments used by a wrapped function that calls CUTLASS kernels. */ - Array ext_func_args_; + ffi::Array ext_func_args_; /*! \brief The statements of the function that will be compiled using CUTLASS kernels. */ std::vector ext_func_body_; /*! \brief The declaration of intermediate buffers. */ std::vector buf_decl_; /*! \brief The binding to look up composite functions. */ - Map bindings_; + ffi::Map bindings_; /*! \brief Required header-file names. */ - Array headers_; + ffi::Array headers_; /*! * \brief A mapping from a variable to its unique name. * We use this since sometimes different parameters to the same function end up having the same @@ -337,7 +338,8 @@ class CodegenCutlass : public relax::MemoizedExprTranslator, class CutlassModuleCodegen { public: - ffi::Module CreateCSourceModule(Array functions, const Map& options) { + ffi::Module CreateCSourceModule(ffi::Array functions, + const ffi::Map& options) { std::string headers = ""; std::string code = ""; for (const auto& f : functions) { @@ -351,8 +353,8 @@ class CutlassModuleCodegen { } private: - std::pair> GenCutlassFunc(const Function& function, - const Map& options) { + std::pair> GenCutlassFunc( + const Function& function, const ffi::Map& options) { ICHECK(function.defined()) << "Input error: expect a Relax function."; auto sid = GetExtSymbol(function); @@ -369,17 +371,18 @@ class CutlassModuleCodegen { } /*! 
\brief The accumulated function names. */ - Array func_names_; + ffi::Array func_names_; }; -Array CUTLASSCompiler(Array functions, Map options, - Map /*unused*/) { +ffi::Array CUTLASSCompiler(ffi::Array functions, + ffi::Map options, + ffi::Map /*unused*/) { const auto tune_func = tvm::ffi::Function::GetGlobal("contrib.cutlass.tune_relax_function"); ICHECK(tune_func.has_value()) << "The packed function contrib.cutlass.tune_relax_function not found, " "please import tvm.contrib.cutlass.build"; - auto annotated_functions = (*tune_func)(functions, options).cast>(); + auto annotated_functions = (*tune_func)(functions, options).cast>(); auto source_mod = CutlassModuleCodegen().CreateCSourceModule(annotated_functions, options); const auto pf = tvm::ffi::Function::GetGlobal("contrib.cutlass.compile"); diff --git a/src/relax/backend/contrib/dnnl/codegen.cc b/src/relax/backend/contrib/dnnl/codegen.cc index efa4e1b685c7..6db5ae7dd628 100644 --- a/src/relax/backend/contrib/dnnl/codegen.cc +++ b/src/relax/backend/contrib/dnnl/codegen.cc @@ -40,7 +40,7 @@ using backend::contrib::NodeEntries; class DNNLJSONSerializer : public JSONSerializer { public: - DNNLJSONSerializer(Map constant_names, Map bindings) + DNNLJSONSerializer(ffi::Map constant_names, ffi::Map bindings) : JSONSerializer(constant_names), bindings_(bindings) {} using JSONSerializer::VisitExpr_; @@ -48,10 +48,10 @@ class DNNLJSONSerializer : public JSONSerializer { NodeEntries VisitExpr_(const CallNode* call_node) final { const auto* fn_var = call_node->op.as(); ICHECK(fn_var); - const auto fn = Downcast(bindings_[GetRef(fn_var)]); + const auto fn = Downcast(bindings_[ffi::GetRef(fn_var)]); ICHECK(fn.defined()) << "Expects the callee to be a function."; - auto composite_opt = fn->GetAttr(attr::kComposite); + auto composite_opt = fn->GetAttr(attr::kComposite); ICHECK(composite_opt.has_value()) << "Only composite functions are supported."; std::string composite_name = composite_opt.value(); @@ -73,17 +73,18 @@ class DNNLJSONSerializer : public JSONSerializer { } SetCallNodeAttribute(node, root_call); - return AddNode(node, GetRef(call_node)); + return AddNode(node, ffi::GetRef(call_node)); } private: /*! \brief The bindings to look up composite functions. 
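   * (Populated from AnalyzeVar2Value over the partitioned function, as in the
   * compiler entry points below.)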
*/ - Map bindings_; + ffi::Map bindings_; }; -Array DNNLCompiler(Array functions, Map /*unused*/, - Map constant_names) { - Array compiled_functions; +ffi::Array DNNLCompiler(ffi::Array functions, + ffi::Map /*unused*/, + ffi::Map constant_names) { + ffi::Array compiled_functions; for (const auto& func : functions) { DNNLJSONSerializer serializer(constant_names, AnalyzeVar2Value(func)); diff --git a/src/relax/backend/contrib/hipblas/codegen.cc b/src/relax/backend/contrib/hipblas/codegen.cc index e1104ac3d6c7..872ac23c5909 100644 --- a/src/relax/backend/contrib/hipblas/codegen.cc +++ b/src/relax/backend/contrib/hipblas/codegen.cc @@ -40,7 +40,8 @@ using backend::contrib::NodeEntries; class HipblasJSONSerializer : public JSONSerializer { public: - HipblasJSONSerializer(Map constant_names, Map bindings) + HipblasJSONSerializer(ffi::Map constant_names, + ffi::Map bindings) : JSONSerializer(constant_names), bindings_(bindings) {} using JSONSerializer::VisitExpr_; @@ -48,10 +49,10 @@ class HipblasJSONSerializer : public JSONSerializer { NodeEntries VisitExpr_(const CallNode* call_node) final { const auto* fn_var = call_node->op.as(); ICHECK(fn_var); - const auto fn = Downcast(bindings_[GetRef(fn_var)]); + const auto fn = Downcast(bindings_[ffi::GetRef(fn_var)]); ICHECK(fn.defined()) << "Expects the callee to be a function."; - auto composite_opt = fn->GetAttr(attr::kComposite); + auto composite_opt = fn->GetAttr(attr::kComposite); ICHECK(composite_opt.has_value()) << "Only composite functions are supported."; std::string composite_name = composite_opt.value(); @@ -78,17 +79,18 @@ class HipblasJSONSerializer : public JSONSerializer { const CallNode* root_call = backend::GetOpInFunction(fn, "relax.matmul"); SetCallNodeAttribute(node, root_call); - return AddNode(node, GetRef(call_node)); + return AddNode(node, ffi::GetRef(call_node)); } private: /*! \brief The bindings to look up composite functions. 
*/ - Map bindings_; + ffi::Map bindings_; }; -Array HipblasCompiler(Array functions, Map /*unused*/, - Map constant_names) { - Array compiled_functions; +ffi::Array HipblasCompiler(ffi::Array functions, + ffi::Map /*unused*/, + ffi::Map constant_names) { + ffi::Array compiled_functions; for (const auto& func : functions) { HipblasJSONSerializer serializer(constant_names, AnalyzeVar2Value(func)); diff --git a/src/relax/backend/contrib/nnapi/codegen.cc b/src/relax/backend/contrib/nnapi/codegen.cc index f045e5b9c2c0..37f16ebf1493 100644 --- a/src/relax/backend/contrib/nnapi/codegen.cc +++ b/src/relax/backend/contrib/nnapi/codegen.cc @@ -190,17 +190,18 @@ class CollectFromCompositeFunctionBody : public ExprVisitor { class NNAPIJSONSerializer : public JSONSerializer { public: - explicit NNAPIJSONSerializer(Map constant_names, Map bindings) + explicit NNAPIJSONSerializer(ffi::Map constant_names, + ffi::Map bindings) : JSONSerializer(constant_names), bindings_(bindings) {} using JSONSerializer::VisitExpr_; std::vector VisitExpr_(const CallNode* call_node) final { const auto* fn_var = call_node->op.as(); ICHECK(fn_var); - const auto fn = Downcast(bindings_[GetRef(fn_var)]); + const auto fn = Downcast(bindings_[ffi::GetRef(fn_var)]); ICHECK(fn.defined()) << "Expects the callee to be a function."; - auto composite_opt = fn->GetAttr(attr::kComposite); + auto composite_opt = fn->GetAttr(attr::kComposite); ICHECK(composite_opt.has_value()) << "Only composite functions are supported."; std::string composite_name = composite_opt.value(); @@ -221,11 +222,11 @@ class NNAPIJSONSerializer : public JSONSerializer { VLOG(1) << "Adding node " << composite_name << " with " << node->GetInputs().size() << " inputs"; - return AddNode(node, GetRef(call_node)); + return AddNode(node, ffi::GetRef(call_node)); } private: - Map bindings_; + ffi::Map bindings_; }; void CollectFromCompositeFunctionBody::VisitExpr_(const CallNode* call_node) { @@ -247,11 +248,12 @@ void CollectFromCompositeFunctionBody::VisitExpr_(const CallNode* call_node) { ExprVisitor::VisitExpr_(call_node); } -Array NNAPICompiler(Array functions, Map /*unused*/, - Map constant_names) { +ffi::Array NNAPICompiler(ffi::Array functions, + ffi::Map /*unused*/, + ffi::Map constant_names) { VLOG(1) << "NNAPI Compiler"; - Array compiled_functions; + ffi::Array compiled_functions; for (const auto& func : functions) { NNAPIJSONSerializer serializer(constant_names, AnalyzeVar2Value(func)); serializer.serialize(func); diff --git a/src/relax/backend/contrib/tensorrt/codegen.cc b/src/relax/backend/contrib/tensorrt/codegen.cc index 6dd8216469c2..73a10bec187b 100644 --- a/src/relax/backend/contrib/tensorrt/codegen.cc +++ b/src/relax/backend/contrib/tensorrt/codegen.cc @@ -46,7 +46,7 @@ namespace contrib { /*! \brief Attributes to store the compiler options for TensorRT. 
*/ struct TensorRTCompilerConfigNode : public AttrsNodeReflAdapter { - Array tensorrt_version; + ffi::Array tensorrt_version; bool use_implicit_batch; size_t max_workspace_size; bool remove_no_mac_subgraphs; @@ -58,7 +58,7 @@ struct TensorRTCompilerConfigNode : public AttrsNodeReflAdapter() .def_ro("tensorrt_version", &TensorRTCompilerConfigNode::tensorrt_version, "TensorRT version as (major, minor, patch).", - refl::DefaultValue(Array({6, 0, 1}))) + refl::DefaultValue(ffi::Array({6, 0, 1}))) .def_ro("use_implicit_batch", &TensorRTCompilerConfigNode::use_implicit_batch, "Use implicit batch", refl::DefaultValue(true)) .def_ro("max_workspace_size", &TensorRTCompilerConfigNode::max_workspace_size, @@ -128,7 +128,8 @@ class CollectFromCompositeFunctionBody : public ExprVisitor { */ class TensorRTJSONSerializer : public JSONSerializer { public: - explicit TensorRTJSONSerializer(Map constant_names, Map bindings) + explicit TensorRTJSONSerializer(ffi::Map constant_names, + ffi::Map bindings) : JSONSerializer(constant_names), bindings_(bindings) {} using JSONSerializer::VisitExpr_; @@ -137,9 +138,9 @@ class TensorRTJSONSerializer : public JSONSerializer { // The call must be to an inline "Composite" function const auto* fn_var = call_node->op.as(); ICHECK(fn_var); - const auto fn = Downcast(bindings_[GetRef(fn_var)]); + const auto fn = Downcast(bindings_[ffi::GetRef(fn_var)]); - auto opt_composite = fn->GetAttr(attr::kComposite); + auto opt_composite = fn->GetAttr(attr::kComposite); ICHECK(opt_composite.has_value()); std::string name = opt_composite.value(); @@ -172,7 +173,7 @@ class TensorRTJSONSerializer : public JSONSerializer { VLOG(1) << name << " has " << node->GetInputs().size() << " inputs"; - return AddNode(node, GetRef(call_node)); + return AddNode(node, ffi::GetRef(call_node)); } static void SaveGlobalAttributes(std::shared_ptr node) { @@ -206,11 +207,11 @@ class TensorRTJSONSerializer : public JSONSerializer { private: /*! \brief The bindings to look up composite functions. */ - Map bindings_; + ffi::Map bindings_; }; void CollectFromCompositeFunctionBody::VisitExpr_(const ConstantNode* constant_node) { - for (const auto& entry : serializer_->VisitExpr(GetRef(constant_node))) { + for (const auto& entry : serializer_->VisitExpr(ffi::GetRef(constant_node))) { args_.emplace_back(entry); } } @@ -225,9 +226,10 @@ void CollectFromCompositeFunctionBody::VisitExpr_(const CallNode* call_node) { * \param functions The extern functions to be compiled via TensorRT * \return Runtime modules. */ -Array TensorRTCompiler(Array functions, Map /*unused*/, - Map constant_names) { - Array compiled_functions; +ffi::Array TensorRTCompiler(ffi::Array functions, + ffi::Map /*unused*/, + ffi::Map constant_names) { + ffi::Array compiled_functions; for (const auto& func : functions) { VLOG(1) << "TensorRT partition:" << std::endl << func; TensorRTJSONSerializer serializer(constant_names, AnalyzeVar2Value(func)); @@ -265,7 +267,7 @@ inline constexpr bool IsTensorRTRuntimeEnabled() { * \return Array of three integers for major, minor, and patch, or empty array if TensorRT graph * runtime is not enabled. 
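 *
 * An illustrative usage sketch (not part of this change):
 *
 * \code
 * ffi::Array<Integer> ver = GetTensorRTVersion();
 * if (!ver.empty()) {
 *   LOG(INFO) << "TensorRT " << ver[0] << "." << ver[1] << "." << ver[2];
 * }
 * \endcode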
*/ -Array GetTensorRTVersion() { +ffi::Array GetTensorRTVersion() { #if TVM_GRAPH_EXECUTOR_TENSORRT return {Integer(NV_TENSORRT_MAJOR), Integer(NV_TENSORRT_MINOR), Integer(NV_TENSORRT_PATCH)}; #else diff --git a/src/relax/backend/contrib/utils.cc b/src/relax/backend/contrib/utils.cc index b555d1fc0f74..3855c67702ff 100644 --- a/src/relax/backend/contrib/utils.cc +++ b/src/relax/backend/contrib/utils.cc @@ -31,8 +31,8 @@ namespace tvm { namespace relax { namespace backend { -Map ExtractArgIdx(String pattern_name, Function f) { - Map arg_idx; +ffi::Map ExtractArgIdx(ffi::String pattern_name, Function f) { + ffi::Map arg_idx; auto pattern = backend::GetPattern(pattern_name); ICHECK(pattern) << "Unsupported op_type " << pattern_name; @@ -44,7 +44,7 @@ Map ExtractArgIdx(String pattern_name, Function f) { << "\", expected to find a match for " << pattern.value()->pattern << ". However, the function did not include this pattern " << f; - auto find_index = [](const Array& params, Var v) -> std::optional { + auto find_index = [](const ffi::Array& params, Var v) -> std::optional { for (size_t i = 0; i < params.size(); ++i) { if (params[i] == v) { return i; @@ -56,7 +56,7 @@ Map ExtractArgIdx(String pattern_name, Function f) { for (const auto& [name, pat] : pattern.value()->annotation_patterns) { auto exp = matched_expr.value()[pat]; if (auto arg_var = exp.as()) { - if (auto idx = find_index(f->params, GetRef(arg_var))) { + if (auto idx = find_index(f->params, ffi::GetRef(arg_var))) { arg_idx.Set(name, IntImm(DataType::Int(64), *idx)); } } diff --git a/src/relax/backend/contrib/utils.h b/src/relax/backend/contrib/utils.h index bbff798b8623..e1bcfd0aee1e 100644 --- a/src/relax/backend/contrib/utils.h +++ b/src/relax/backend/contrib/utils.h @@ -43,7 +43,7 @@ namespace backend { * \return The converted shape in std::vector */ -inline std::vector GetIntShape(const Array& shape) { +inline std::vector GetIntShape(const ffi::Array& shape) { std::vector ret; for (const auto& dim : shape) { const int64_t* pval = tir::as_const_int(dim); @@ -71,7 +71,7 @@ inline std::string DType2String(const tvm::DataType dtype) { inline bool IsOp(const CallNode* call, const std::string& op_name) { const auto* op_node = call->op.as(); if (!op_node) return false; - Op op = GetRef(op_node); + Op op = ffi::GetRef(op_node); return op == Op::Get(op_name); } @@ -116,12 +116,12 @@ inline const CallNode* GetOpInFunction(Function f, const std::string& op_name) { * \return A mapping between variable pattern names and their positions in the partitioned * function parameter list. */ -Map ExtractArgIdx(String pattern_name, Function f); +ffi::Map ExtractArgIdx(ffi::String pattern_name, Function f); /*! * \brief Converts a numeric value to std::string. * \param value A numeric value to convert. - * \return String representation of a numeric value. + * \return ffi::String representation of a numeric value. 
*/ template std::string to_str(const Type& value) { diff --git a/src/relax/backend/pattern_registry.cc b/src/relax/backend/pattern_registry.cc index 6689aca2f9f4..fe6ef60073d6 100644 --- a/src/relax/backend/pattern_registry.cc +++ b/src/relax/backend/pattern_registry.cc @@ -31,15 +31,15 @@ static std::vector* GetRegistryTable() { return &table; } -void RegisterPatterns(Array entries) { +void RegisterPatterns(ffi::Array entries) { auto* table = GetRegistryTable(); for (const auto& entry : entries) { table->push_back(entry); } } -void RemovePatterns(Array names) { - std::unordered_set name_set{names.begin(), names.end()}; +void RemovePatterns(ffi::Array names) { + std::unordered_set name_set{names.begin(), names.end()}; auto* table = GetRegistryTable(); table->erase( @@ -48,9 +48,9 @@ void RemovePatterns(Array names) { table->end()); } -Array GetPatternsWithPrefix(const String& prefix) { +ffi::Array GetPatternsWithPrefix(const ffi::String& prefix) { auto* table = GetRegistryTable(); - Array result; + ffi::Array result; for (auto it = table->rbegin(); it != table->rend(); ++it) { if (support::StartsWith((*it)->name, prefix.data())) { result.push_back(*it); @@ -59,7 +59,7 @@ Array GetPatternsWithPrefix(const String& prefix) { return result; } -Optional GetPattern(const String& pattern_name) { +ffi::Optional GetPattern(const ffi::String& pattern_name) { auto* table = GetRegistryTable(); for (auto it = table->rbegin(); it != table->rend(); ++it) { if ((*it)->name == pattern_name) { diff --git a/src/relax/backend/pattern_registry.h b/src/relax/backend/pattern_registry.h index 2c1f385a2dda..72956c33d625 100644 --- a/src/relax/backend/pattern_registry.h +++ b/src/relax/backend/pattern_registry.h @@ -44,27 +44,27 @@ using transform::FusionPattern; * \param patterns Patterns to be registered. Patterns that appear later in the list have * higher priority when partitioning DataflowBlock. */ -void RegisterPatterns(Array patterns); +void RegisterPatterns(ffi::Array patterns); /*! * \brief Remove patterns from the registry by their name. * \param names The names of the patterns to be removed */ -void RemovePatterns(Array names); +void RemovePatterns(ffi::Array names); /*! * \brief Find patterns whose name starts with a particular prefix. * \param prefix The pattern name prefix. * \return Matched patterns, ordered by priority from high to low. */ -Array GetPatternsWithPrefix(const String& prefix); +ffi::Array GetPatternsWithPrefix(const ffi::String& prefix); /*! * \brief Find the pattern with a particular name. * \param name The pattern name. * \return The matched pattern. std::nullopt if not found. */ -Optional GetPattern(const String& name); +ffi::Optional GetPattern(const ffi::String& name); } // namespace backend } // namespace relax diff --git a/src/relax/backend/task_extraction.cc b/src/relax/backend/task_extraction.cc index b0571913049c..97dd75945ce5 100644 --- a/src/relax/backend/task_extraction.cc +++ b/src/relax/backend/task_extraction.cc @@ -67,15 +67,16 @@ class BlockCounter : public tir::StmtVisitor { class TaskExtractor : public ExprVisitor { public: - static Array ExtractTask(IRModule mod, Target target, String mod_eq_name) { + static ffi::Array ExtractTask(IRModule mod, Target target, + ffi::String mod_eq_name) { TaskExtractor extractor(mod, target, mod_eq_name); // We go through each Relax function in the module.
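// Note: the loop below funnels every call_tir callee into an extracted
// tuning task, with duplicates folded according to the ModuleEquality
// strategy named by `mod_eq_name`; illustratively (assumed arguments):
//   TaskExtractor::ExtractTask(mod, Target("llvm"), "structural");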
for (const auto& kv : mod->functions) { if (const auto* func = kv.second.as()) { - extractor(GetRef(func)); + extractor(ffi::GetRef(func)); } } - Array tasks; + ffi::Array tasks; for (const auto& it : extractor.func2task_) { tasks.push_back(it.second); } @@ -83,7 +84,7 @@ class TaskExtractor : public ExprVisitor { } private: - explicit TaskExtractor(IRModule mod, Target target, String mod_eq_name) + explicit TaskExtractor(IRModule mod, Target target, ffi::String mod_eq_name) : mod_(std::move(mod)), target_(std::move(target)), mod_eq_(ModuleEquality::Create(mod_eq_name)), @@ -143,7 +144,7 @@ class TaskExtractor : public ExprVisitor { TVM_FFI_STATIC_INIT_BLOCK({ namespace refl = tvm::ffi::reflection; refl::GlobalDef().def("relax.backend.MetaScheduleExtractTask", [](IRModule mod, Target target, - String mod_eq_name) { + ffi::String mod_eq_name) { return TaskExtractor::ExtractTask(std::move(mod), std::move(target), std::move(mod_eq_name)); }); }); diff --git a/src/relax/backend/vm/codegen_vm.cc b/src/relax/backend/vm/codegen_vm.cc index c26c043e7483..e29f580793b1 100644 --- a/src/relax/backend/vm/codegen_vm.cc +++ b/src/relax/backend/vm/codegen_vm.cc @@ -60,7 +60,7 @@ class CodeGenVM : public ExprFunctor { // Remove relax function and turn into TIR func. for (const auto& [gvar, f] : mod->functions) { if (auto* func = f.as()) { - codegen.Codegen(GetRef(func)); + codegen.Codegen(ffi::GetRef(func)); res_mod->Remove(gvar); } } @@ -82,11 +82,11 @@ class CodeGenVM : public ExprFunctor { } void Codegen(const Function& func) { - Optional gsymbol = func->GetAttr(tvm::attr::kGlobalSymbol); + ffi::Optional gsymbol = func->GetAttr(tvm::attr::kGlobalSymbol); ICHECK(gsymbol.has_value()) << "there should be no local functions in Relax VM codegen phase. " "Did you forget to apply LambdaLift or AttachGlobalSymbol Pass?"; - Array param_names; + ffi::Array param_names; for (Var param : func->params) { param_names.push_back(param->name_hint()); } @@ -132,7 +132,7 @@ class CodeGenVM : public ExprFunctor { } Instruction::Arg VisitExpr_(const CallNode* call_node) final { - Call call = GetRef(call_node); + Call call = ffi::GetRef(call_node); if (call_node->op == null_value_op_) { return Instruction::Arg::Register(Instruction::kVoidRegister); @@ -163,7 +163,7 @@ class CodeGenVM : public ExprFunctor { } Instruction::Arg VisitExpr_(const IfNode* op) final { - const If& ife = GetRef(op); + const If& ife = ffi::GetRef(op); Instruction::Arg cond_value = this->VisitExpr(ife->cond); // Reserve a register for cond @@ -207,7 +207,7 @@ class CodeGenVM : public ExprFunctor { } Instruction::Arg VisitExpr_(const VarNode* op) final { - Var var = GetRef(op); + Var var = ffi::GetRef(op); auto it = this->var_arg_map_.find(var); ICHECK(it != this->var_arg_map_.end()) << "Var " << var << " is not defined"; return it->second; @@ -236,7 +236,8 @@ class CodeGenVM : public ExprFunctor { return builder_->ConvertConstant(float_imm->value); } else { LOG(FATAL) << "PrimValue should only contain constant after VMShapeLower, " - << "but received " << GetRef(op) << " with type " << op->value->GetTypeKey(); + << "but received " << ffi::GetRef(op) << " with type " + << op->value->GetTypeKey(); } } @@ -249,7 +250,7 @@ class CodeGenVM : public ExprFunctor { } Instruction::Arg VisitExpr_(const TupleNode* op) final { - Tuple tuple = GetRef(op); + Tuple tuple = ffi::GetRef(op); std::vector args; for (Expr arg : tuple->fields) { args.push_back(this->VisitExpr(arg)); @@ -261,7 +262,7 @@ class CodeGenVM : public ExprFunctor { } Instruction::Arg 
VisitExpr_(const TupleGetItemNode* op) final { - TupleGetItem expr = GetRef(op); + TupleGetItem expr = ffi::GetRef(op); std::vector args = {this->VisitExpr(expr->tuple)}; args.push_back(builder_->ConvertConstant(expr->index)); @@ -273,8 +274,8 @@ class CodeGenVM : public ExprFunctor { } Instruction::Arg VisitExpr_(const GlobalVarNode* op) final { - GlobalVar gvar = GetRef(op); - Optional symbol; + GlobalVar gvar = ffi::GetRef(op); + ffi::Optional symbol; VMFuncInfo::FuncKind kind = VMFuncInfo::FuncKind::kPackedFunc; // Run a look up in the env to see if it maps to an extern func. @@ -306,10 +307,10 @@ class CodeGenVM : public ExprFunctor { Instruction::Arg VisitExpr_(const ExternFuncNode* op) final { static const constexpr char* kCSource = "c_source"; static const constexpr char* kCSourceFmt = "c_source_fmt"; - if (Optional opt_code = op->attrs.GetAttr(kCSource)) { - String sym = op->global_symbol; - String fmt = op->attrs.GetAttr(kCSourceFmt).value_or("c"); - String code = opt_code.value(); + if (ffi::Optional opt_code = op->attrs.GetAttr(kCSource)) { + ffi::String sym = op->global_symbol; + ffi::String fmt = op->attrs.GetAttr(kCSourceFmt).value_or("c"); + ffi::String code = opt_code.value(); ffi::Module c_source_module = codegen::CSourceModuleCreate(/*code=*/code, /*fmt=*/fmt, /*func_names=*/{sym}, /*const_vars=*/{}); @@ -388,7 +389,7 @@ class CodeGenVM : public ExprFunctor { builder_->EmitCall(name, args, dst_reg); } - std::vector VisitArray(const Array& arr) { + std::vector VisitArray(const ffi::Array& arr) { std::vector ret; for (size_t i = 0; i < arr.size(); ++i) { ret.push_back(this->VisitExpr(arr[i])); @@ -440,8 +441,8 @@ TVM_FFI_STATIC_INIT_BLOCK({ * module(s). * \return The created module. */ -void LinkModules(ObjectPtr exec, const Map& params, - const tvm::ffi::Module& lib, const Array& ext_libs) { +void LinkModules(ObjectPtr exec, const ffi::Map& params, + const tvm::ffi::Module& lib, const ffi::Array& ext_libs) { // query if we need const loader for ext_modules // Wrap all submodules in the initialization wrapper. std::unordered_map> const_vars_by_symbol; @@ -450,8 +451,8 @@ void LinkModules(ObjectPtr exec, const MapGetFunction("get_const_vars"); std::vector symbol_const_vars; if (pf_sym.has_value() && pf_var.has_value()) { - String symbol = (*pf_sym)().cast(); - Array variables = (*pf_var)().cast>(); + ffi::String symbol = (*pf_sym)().cast(); + ffi::Array variables = (*pf_var)().cast>(); for (size_t i = 0; i < variables.size(); i++) { symbol_const_vars.push_back(variables[i].operator std::string()); } @@ -484,11 +485,12 @@ void LinkModules(ObjectPtr exec, const Map lib, - Array ext_libs, Map params) { +ffi::Module VMLink(ExecBuilder builder, Target target, ffi::Optional lib, + ffi::Array ext_libs, + ffi::Map params) { ObjectPtr executable = builder->Get(); if (!lib.defined()) { - lib = codegen::CSourceModuleCreate(";", "c", Array{}); + lib = codegen::CSourceModuleCreate(";", "c", ffi::Array{}); } LinkModules(executable, params, lib.value(), ext_libs); return ffi::Module(executable); diff --git a/src/relax/backend/vm/codegen_vm_tir.cc b/src/relax/backend/vm/codegen_vm_tir.cc index c7cf06ea9d7f..a4e7f3f16bb9 100644 --- a/src/relax/backend/vm/codegen_vm_tir.cc +++ b/src/relax/backend/vm/codegen_vm_tir.cc @@ -50,11 +50,11 @@ using vm::VMFuncInfo; * \note Skip CallPacked with special attrs for now, as they can be * further simplified with PrimValue. 
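 *
 * The emitted TIR mirrors the register-based VM: each generated function
 * receives the VM context plus AnyList handles for registers, constants,
 * and functions, and every Relax call is lowered into call_packed /
 * call_cpacked operating on slots of those lists.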
*/ -class CodeGenVMTIR : public ExprFunctor(const Expr&)> { public: explicit CodeGenVMTIR(relax::ExecBuilder builder, IRModule ctx_mod) : builder_(builder), ctx_mod_(ctx_mod) { - system_lib_prefix_ = ctx_mod_->GetAttr(tvm::attr::kSystemLibPrefix); + system_lib_prefix_ = ctx_mod_->GetAttr(tvm::attr::kSystemLibPrefix); } static IRModule Run(relax::ExecBuilder builder, IRModule mod) { @@ -66,8 +66,8 @@ class CodeGenVMTIR : public ExprFunctor(const Expr&)> { // Remove relax function and turn into TIR func. for (auto& p : mod->functions) { if (auto* func = p.second.as()) { - auto tir_func = codegen.Codegen(GetRef(func)); - auto gsymbol = tir_func->GetAttr(tvm::attr::kGlobalSymbol); + auto tir_func = codegen.Codegen(ffi::GetRef(func)); + auto gsymbol = tir_func->GetAttr(tvm::attr::kGlobalSymbol); res_mod->Add(GlobalVar(gsymbol.value()), tir_func); res_mod->Remove(p.first); } } @@ -105,8 +105,9 @@ class CodeGenVMTIR : public ExprFunctor(const Expr&)> { stmt_stack_.back().emplace_back(stmt); } - void EmitCallPacked(String name, const Array& args, int64_t dst_anylist_slot = -1) { - Array all_args; + void EmitCallPacked(ffi::String name, const ffi::Array& args, + int64_t dst_anylist_slot = -1) { + ffi::Array all_args; // a negative index indicates the return value can be discarded; emit call_packed if (dst_anylist_slot >= 0) { all_args = {reg_anylist_handle_, ConstInt32(dst_anylist_slot)}; @@ -124,11 +125,11 @@ class CodeGenVMTIR : public ExprFunctor(const Expr&)> { } } - void EmitCallCPacked(const tir::PrimFunc& prim_func, const Array& args, + void EmitCallCPacked(const tir::PrimFunc& prim_func, const ffi::Array& args, int64_t dst_anylist_slot = -1) { - Optional gsymbol = prim_func->GetAttr(tvm::attr::kGlobalSymbol); + ffi::Optional gsymbol = prim_func->GetAttr(tvm::attr::kGlobalSymbol); ICHECK(gsymbol.has_value()) << "All functions must have a global symbol at this phase"; - Array all_args; + ffi::Array all_args; // a negative index indicates the return value can be discarded; emit call_packed if (dst_anylist_slot >= 0) { all_args = {reg_anylist_handle_, ConstInt32(dst_anylist_slot)}; @@ -147,7 +148,7 @@ class CodeGenVMTIR : public ExprFunctor(const Expr&)> { } tir::PrimFunc Codegen(const Function& func) { - Optional gsymbol = func->GetAttr(tvm::attr::kGlobalSymbol); + ffi::Optional gsymbol = func->GetAttr(tvm::attr::kGlobalSymbol); ICHECK(gsymbol.has_value()) << "there should be no local functions in Relax VM codegen phase.
" "Did you forget to apply LambdaLift or AttachGlobalSymbol Pass?"; // initialize the state @@ -159,7 +160,7 @@ class CodeGenVMTIR : public ExprFunctor(const Expr&)> { func_anylist_handle_ = tir::Var("f", DataType::Handle()); const_anylist_handle_ = tir::Var("c", DataType::Handle()); - Array param_names; + ffi::Array param_names; for (Var param : func->params) { param_names.push_back(param->name_hint()); } @@ -174,7 +175,7 @@ class CodeGenVMTIR : public ExprFunctor(const Expr&)> { size_t ret_reg = NewRegister(); tir::Stmt body = WithNewScope([&]() { - Optional ret = ExprFunctor::VisitExpr(func->body); + ffi::Optional ret = ExprFunctor::VisitExpr(func->body); if (ret.defined()) { this->EmitCallPacked("vm.builtin.copy", {ret.value()}, ret_reg); } @@ -186,9 +187,9 @@ class CodeGenVMTIR : public ExprFunctor(const Expr&)> { builder_->EndFunction(gsymbol.value()); Type ret_type = VoidType(); - Array tir_params = {ctx_ptr_, reg_anylist_handle_, const_anylist_handle_, - func_anylist_handle_}; - String tir_func_name = system_lib_prefix_.value_or("") + "__vmtir__" + gsymbol.value(); + ffi::Array tir_params = {ctx_ptr_, reg_anylist_handle_, const_anylist_handle_, + func_anylist_handle_}; + ffi::String tir_func_name = system_lib_prefix_.value_or("") + "__vmtir__" + gsymbol.value(); tir::PrimFunc tir_func(tir_params, body, ret_type, {}); tir_func = WithAttr(tir_func, "global_symbol", tir_func_name); registers_num_ = 0; @@ -197,11 +198,11 @@ class CodeGenVMTIR : public ExprFunctor(const Expr&)> { return tir_func; } - Optional VisitExpr_(const SeqExprNode* op) final { + ffi::Optional VisitExpr_(const SeqExprNode* op) final { for (auto block : op->blocks) { for (Binding binding : block->bindings) { Expr expr = GetBoundValue(binding); - Optional value = VisitExpr(expr); + ffi::Optional value = VisitExpr(expr); if (expr.as() && value.defined()) { // For a normalized relax module, there should be one @@ -220,8 +221,8 @@ class CodeGenVMTIR : public ExprFunctor(const Expr&)> { return this->VisitExpr(op->body); } - Optional VisitExpr_(const CallNode* call_node) final { - Call call = GetRef(call_node); + ffi::Optional VisitExpr_(const CallNode* call_node) final { + Call call = ffi::GetRef(call_node); if (call_node->op == null_value_op_) { return tir::Call(DataType::Handle(), tir::builtin::reinterpret(), @@ -252,7 +253,7 @@ class CodeGenVMTIR : public ExprFunctor(const Expr&)> { } } - Optional VisitExpr_(const IfNode* op) final { + ffi::Optional VisitExpr_(const IfNode* op) final { // Reserve a register for return size_t merge_register = NewRegister(); PrimExpr cond_value = this->VisitExpr(op->cond).value(); @@ -272,18 +273,18 @@ class CodeGenVMTIR : public ExprFunctor(const Expr&)> { return RegListGet(merge_register); } - Optional VisitExpr_(const VarNode* op) final { - Var var = GetRef(op); + ffi::Optional VisitExpr_(const VarNode* op) final { + Var var = ffi::GetRef(op); auto it = this->var_map_.find(var); ICHECK(it != this->var_map_.end()) << "Var " << var << " is not defined"; return it->second; } - Optional VisitExpr_(const ConstantNode* op) final { + ffi::Optional VisitExpr_(const ConstantNode* op) final { return ConstListGet(builder_->ConvertConstant(op->data).value()); } - Optional VisitExpr_(const ShapeExprNode* op) final { + ffi::Optional VisitExpr_(const ShapeExprNode* op) final { std::vector shape; for (PrimExpr e : op->values) { if (auto* int_value = e.as()) { @@ -295,19 +296,19 @@ class CodeGenVMTIR : public ExprFunctor(const Expr&)> { return 
ConstListGet(builder_->ConvertConstant(ffi::Shape(shape)).value()); } - Optional VisitExpr_(const PrimValueNode* op) final { return op->value; } + ffi::Optional VisitExpr_(const PrimValueNode* op) final { return op->value; } - Optional VisitExpr_(const StringImmNode* op) final { + ffi::Optional VisitExpr_(const StringImmNode* op) final { return ConstListGet(builder_->ConvertConstant(op->value).value()); } - Optional VisitExpr_(const DataTypeImmNode* op) final { + ffi::Optional VisitExpr_(const DataTypeImmNode* op) final { return ConstListGet(builder_->ConvertConstant(op->value).value()); } - Optional VisitExpr_(const TupleNode* op) final { - Tuple tuple = GetRef(op); - Array args; + ffi::Optional VisitExpr_(const TupleNode* op) final { + Tuple tuple = ffi::GetRef(op); + ffi::Array args; for (auto arg : tuple->fields) { args.push_back(this->VisitExpr(arg).value()); } @@ -316,9 +317,9 @@ class CodeGenVMTIR : public ExprFunctor(const Expr&)> { return RegListGet(dst_register); } - Optional VisitExpr_(const TupleGetItemNode* op) final { - TupleGetItem expr = GetRef(op); - Array args = {this->VisitExpr(expr->tuple).value()}; + ffi::Optional VisitExpr_(const TupleGetItemNode* op) final { + TupleGetItem expr = ffi::GetRef(op); + ffi::Array args = {this->VisitExpr(expr->tuple).value()}; args.push_back(ConstInt64(expr->index)); @@ -328,12 +329,12 @@ class CodeGenVMTIR : public ExprFunctor(const Expr&)> { } // Lookup the function and see if it matches - Optional LookupFunction(const Expr& expr, VMFuncInfo::FuncKind* kind) { + ffi::Optional LookupFunction(const Expr& expr, VMFuncInfo::FuncKind* kind) { if (auto* ext_func = expr.as()) { *kind = VMFuncInfo::FuncKind::kPackedFunc; return ext_func->global_symbol; } else if (auto* gvar_ptr = expr.as()) { - GlobalVar gvar = GetRef(gvar_ptr); + GlobalVar gvar = ffi::GetRef(gvar_ptr); // Run a look up in the env to see if it maps to an extern func. 
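// The (symbol, kind) pair recorded here is what later lets EmitNormalCall
// choose between a direct cpacked call to a PrimFunc in this module and a
// closure-style invocation through the VM function table.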
auto it = ctx_mod_->functions.find(gvar); if (it != ctx_mod_->functions.end()) { @@ -362,7 +363,7 @@ class CodeGenVMTIR : public ExprFunctor(const Expr&)> { } // Lookup PrimFunc in the same module // We can do direct PrimFunc call in such cases - Optional LookupPrimFunc(const String& name) { + ffi::Optional LookupPrimFunc(const ffi::String& name) { if (!ctx_mod_->ContainGlobalVar(name)) return std::nullopt; GlobalVar gvar = ctx_mod_->GetGlobalVar(name); @@ -370,28 +371,28 @@ class CodeGenVMTIR : public ExprFunctor(const Expr&)> { if (it != ctx_mod_->functions.end()) { BaseFunc func = (*it).second; if (auto* prim_func = func.as()) { - return GetRef(prim_func); + return ffi::GetRef(prim_func); } } return std::nullopt; } - Optional VisitExpr_(const GlobalVarNode* op) final { + ffi::Optional VisitExpr_(const GlobalVarNode* op) final { VMFuncInfo::FuncKind kind; - auto symbol = LookupFunction(GetRef(op), &kind); + auto symbol = LookupFunction(ffi::GetRef(op), &kind); ICHECK(symbol.has_value()); builder_->DeclareFunction(symbol.value(), kind); return FuncListGet(builder_->GetFunction(symbol.value()).value()); } - Optional VisitExpr_(const ExternFuncNode* op) final { + ffi::Optional VisitExpr_(const ExternFuncNode* op) final { builder_->DeclareFunction(op->global_symbol, VMFuncInfo::FuncKind::kPackedFunc); return FuncListGet(builder_->GetFunction(op->global_symbol).value()); } void EmitAllocStorage(const Call& call_node, int64_t dst_reg) { // Handle args of the call - Array args; + ffi::Array args; args.push_back(ctx_ptr_); for (Expr arg : call_node->args) { args.push_back(this->VisitExpr(arg).value()); @@ -401,7 +402,7 @@ class CodeGenVMTIR : public ExprFunctor(const Expr&)> { void EmitAllocTensor(const Call& call_node, int64_t dst_reg) { ICHECK_EQ(call_node->args.size(), 4); - Array args; + ffi::Array args; args.reserve(4); for (Expr arg : call_node->args) { args.push_back(this->VisitExpr(arg).value()); @@ -429,7 +430,7 @@ class CodeGenVMTIR : public ExprFunctor(const Expr&)> { } void EmitCallBuiltinWithCtx(const Call& call_node, int64_t dst_reg) { - Array args; + ffi::Array args; // if context is required, pass as first argument. args.push_back(ctx_ptr_); auto* func = call_node->args[0].as(); @@ -446,7 +447,7 @@ class CodeGenVMTIR : public ExprFunctor(const Expr&)> { } void EmitNormalCall(const Call& call_node, int64_t dst_reg) { - Array args = VisitArray(call_node->args); + ffi::Array args = VisitArray(call_node->args); // A function can be a closure that comes from parent // Do call closure to be safe. VMFuncInfo::FuncKind kind; @@ -455,14 +456,14 @@ class CodeGenVMTIR : public ExprFunctor(const Expr&)> { if (symbol.has_value() && kind == VMFuncInfo::FuncKind::kPackedFunc) { // primfunc in the same module. 
// use cpacked to directly invoke without name-based lookup - if (Optional prim_func = LookupPrimFunc(symbol.value())) { + if (ffi::Optional prim_func = LookupPrimFunc(symbol.value())) { this->EmitCallCPacked(prim_func.value(), args, dst_reg); } else { this->EmitCallPacked(symbol.value(), args, dst_reg); } } else { // Default path, leverage function table and invoke as closure - Array all_args; + ffi::Array all_args; all_args.push_back(ctx_ptr_); all_args.push_back(this->VisitExpr(call_node->op).value()); for (auto arg : args) { @@ -481,8 +482,8 @@ class CodeGenVMTIR : public ExprFunctor(const Expr&)> { return stmt; } - Array VisitArray(const Array& arr) { - Array ret; + ffi::Array VisitArray(const ffi::Array& arr) { + ffi::Array ret; for (size_t i = 0; i < arr.size(); ++i) { ret.push_back(this->VisitExpr(arr[i]).value()); } @@ -506,11 +507,11 @@ class CodeGenVMTIR : public ExprFunctor(const Expr&)> { /*! \brief Stack to build up statements */ std::vector> stmt_stack_; /*! \brief Map from var to Expr. */ - std::unordered_map> var_map_; + std::unordered_map> var_map_; /*! \brief the context module. */ IRModule ctx_mod_; /*! \brief system lib prefix */ - Optional system_lib_prefix_; + ffi::Optional system_lib_prefix_; /*! \brief Cache ops that need to be frequently used later to reduce lookup overhead. */ const Op& alloc_storage_op_ = Op::Get("relax.vm.alloc_storage"); const Op& alloc_tensor_op_ = Op::Get("relax.vm.alloc_tensor"); diff --git a/src/relax/backend/vm/exec_builder.cc b/src/relax/backend/vm/exec_builder.cc index 8e229c4fe641..dfb466d038de 100644 --- a/src/relax/backend/vm/exec_builder.cc +++ b/src/relax/backend/vm/exec_builder.cc @@ -33,8 +33,8 @@ using namespace vm; TVM_FFI_STATIC_INIT_BLOCK({ ExecBuilderNode::RegisterReflection(); }); ExecBuilder ExecBuilderNode::Create() { - ExecBuilder ret(make_object()); - ret->exec_ = make_object(); + ExecBuilder ret(ffi::make_object()); + ret->exec_ = ffi::make_object(); return ret; } @@ -90,7 +90,7 @@ vm::Instruction::Arg ExecBuilderNode::GetFunction(const std::string& func_name) } void ExecBuilderNode::EmitFunction(const std::string& func_name, int64_t num_inputs, - Optional> param_names, + ffi::Optional> param_names, vm::VMFuncInfo::FuncKind kind, int64_t init_register_size) { auto it = exec_->func_map.find(func_name); if (it == exec_->func_map.end()) { @@ -331,17 +331,17 @@ TVM_FFI_STATIC_INIT_BLOCK({ *ret = builder->ConvertConstant(rt).data(); }) .def("relax.ExecBuilderEmitFunction", - [](ExecBuilder builder, String func, int64_t num_inputs, - Optional> param_names) { + [](ExecBuilder builder, ffi::String func, int64_t num_inputs, + ffi::Optional> param_names) { builder->EmitFunction(func, num_inputs, param_names); }) .def_method("relax.ExecBuilderEndFunction", &ExecBuilderNode::EndFunction) .def("relax.ExecBuilderDeclareFunction", - [](ExecBuilder builder, String name, int32_t kind) { + [](ExecBuilder builder, ffi::String name, int32_t kind) { builder->DeclareFunction(name, static_cast(kind)); }) .def("relax.ExecBuilderEmitCall", - [](ExecBuilder builder, String name, Array args, int64_t dst) { + [](ExecBuilder builder, ffi::String name, ffi::Array args, int64_t dst) { std::vector args_; for (size_t i = 0; i < args.size(); ++i) { args_.push_back(Instruction::Arg::FromData(args[i]->value)); @@ -370,8 +370,9 @@ TVM_FFI_STATIC_INIT_BLOCK({ [](ExecBuilder builder, int64_t value) { return Instruction::Arg::ConstIdx(value).data(); }) - .def("relax.ExecBuilderF", - [](ExecBuilder builder, String value) { return
builder->GetFunction(value).data(); }) + .def( + "relax.ExecBuilderF", + [](ExecBuilder builder, ffi::String value) { return builder->GetFunction(value).data(); }) .def("relax.ExecBuilderGet", [](ExecBuilder builder) { ObjectPtr p_exec = builder->Get(); return ffi::Module(p_exec); diff --git a/src/relax/backend/vm/lower_runtime_builtin.cc b/src/relax/backend/vm/lower_runtime_builtin.cc index 06adc3daba4c..cb5b8e8b1360 100644 --- a/src/relax/backend/vm/lower_runtime_builtin.cc +++ b/src/relax/backend/vm/lower_runtime_builtin.cc @@ -60,7 +60,7 @@ class LowerRuntimeBuiltinMutator : public ExprMutator { return InvokeClosure(call); } else if (call->op == alloc_tensor_op_) { LOG(FATAL) << "VMBuiltinLower encountered " << call->op << " in expression " - << GetRef(call_node) << ". " + << ffi::GetRef(call_node) << ". " << "This operation should have been lowered earlier " << "using the 'relax.transform.LowerAllocTensor' pass."; } else if (call->op == mem_alloc_storage_op_) { @@ -70,7 +70,7 @@ class LowerRuntimeBuiltinMutator : public ExprMutator { } else if (call->op == mem_kill_storage_op_ || call->op == mem_kill_tensor_op_) { return MakeMemKillObject(call); } else if (const auto* op_node = call->op.as()) { - Op op = GetRef(op_node); + Op op = ffi::GetRef(op_node); if (lower_builtin_fmap.count(op)) { return lower_builtin_fmap[op](builder_, call); } @@ -101,7 +101,7 @@ class LowerRuntimeBuiltinMutator : public ExprMutator { ICHECK(call_node->args.size() == 2); ICHECK(call_node->args[0]->IsInstance()); ICHECK(call_node->args[1]->IsInstance()); - Array args; + ffi::Array args; auto tir_args = Downcast(call_node->args[1]); args.push_back(call_node->args[0]); @@ -144,7 +144,7 @@ class LowerRuntimeBuiltinMutator : public ExprMutator { ICHECK(call_node->args.size() == 1); ICHECK(call_node->struct_info_.defined()); auto attrs = call_node->attrs.as(); - Array args; + ffi::Array args; args.push_back(call_node->args[0]); // Get the DLDeviceType and device_id from VDevice VDevice vdev = attrs->dst_vdevice; @@ -160,7 +160,7 @@ class LowerRuntimeBuiltinMutator : public ExprMutator { ICHECK(call_node->args[0]->IsInstance()); ICHECK(call_node->args[1]->IsInstance()); - Array args; + ffi::Array args; auto func = call_node->args[0]; auto closure_args = Downcast(call_node->args[1]); @@ -177,7 +177,7 @@ class LowerRuntimeBuiltinMutator : public ExprMutator { ICHECK(call_node->args[0]->IsInstance()); ICHECK(call_node->args[1]->IsInstance()); - Array args; + ffi::Array args; args.push_back(call_node->args[0]); @@ -192,7 +192,7 @@ class LowerRuntimeBuiltinMutator : public ExprMutator { const Op& call_builtin_with_ctx_op_ = Op::Get("relax.call_builtin_with_ctx"); const StructInfo object_sinfo_ = ObjectStructInfo(); - const StructInfo void_sinfo_ = TupleStructInfo(Array({})); + const StructInfo void_sinfo_ = TupleStructInfo(ffi::Array({})); // object to pattern match. const Op& call_tir_dyn_op_ = Op::Get("relax.vm.call_tir_dyn"); const Op& reshape_op_ = Op::Get("relax.reshape"); diff --git a/src/relax/backend/vm/vm_shape_lower.cc b/src/relax/backend/vm/vm_shape_lower.cc index 397490023cbe..da9f1a029a44 100644 --- a/src/relax/backend/vm/vm_shape_lower.cc +++ b/src/relax/backend/vm/vm_shape_lower.cc @@ -63,8 +63,8 @@ struct PrimExprSlot { */ struct MatchShapeTodoItem { Expr input; - Array pattern; - String err_ctx; + ffi::Array pattern; + ffi::String err_ctx; }; /*! \brief Slot map used for shape lowering. 
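 * Each PrimExprSlot pairs a symbolic shape expression with its index in the
 * runtime shape heap; values are stored to and loaded from that heap as
 * match-shape obligations are resolved.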
*/ @@ -200,7 +200,7 @@ class PrimExprSlotCollector : public ExprVisitor, public StructInfoVisitor { */ class VMShapeLowerMutator : public ExprMutator, - public StructInfoFunctor*)> { public: static IRModule Lower(IRModule mod, bool emit_err_ctx) { @@ -208,7 +208,7 @@ class VMShapeLowerMutator for (auto& kv : mod->functions) { if (auto* func = kv.second.as()) { - Function updated_func = mutator.Rewrite(kv.first, GetRef(func)); + Function updated_func = mutator.Rewrite(kv.first, ffi::GetRef(func)); mutator.builder_->UpdateFunction(kv.first, updated_func); } } @@ -235,7 +235,7 @@ class VMShapeLowerMutator // prepare slot information this->PopulateSlotInfo(); - Array blocks; + ffi::Array blocks; builder_->BeginScope(func->params); @@ -305,7 +305,7 @@ class VMShapeLowerMutator for (auto& kv : slot_map_) { auto* slot = kv.second; if (!slot->expr.as()) { - Array dep_vars = tir::UndefinedVars(slot->expr); + ffi::Array dep_vars = tir::UndefinedVars(slot->expr); for (auto var : dep_vars) { auto it = slot_map_.find(var); ICHECK(it != slot_map_.end()) @@ -323,7 +323,7 @@ class VMShapeLowerMutator //------------------------------------------------------- // Helper functions //------------------------------------------------------- - StringImm GetErrContext(String err_ctx) const { + StringImm GetErrContext(ffi::String err_ctx) const { return emit_err_ctx_ ? StringImm(err_ctx) : StringImm(""); } @@ -350,7 +350,7 @@ class VMShapeLowerMutator Expr VisitExpr_(const FunctionNode* op) final { LOG(FATAL) << "VMShapeLower does not work for local functions, make sure " << "to run it after LambdaLift"; - return GetRef(op); + return ffi::GetRef(op); } std::pair MakeSymbolicShapeArg(const PrimExpr& expr) { @@ -376,10 +376,10 @@ class VMShapeLowerMutator bool is_const_value = op->value->IsInstance() || op->value->IsInstance(); if (is_const_value) { - return GetRef(op); + return ffi::GetRef(op); } - Array args = {shape_heap_}; + ffi::Array args = {shape_heap_}; auto [code, value_or_index] = MakeSymbolicShapeArg(op->value); args.push_back(code); args.push_back(value_or_index); @@ -396,10 +396,11 @@ class VMShapeLowerMutator return e->IsInstance(); }); if (is_const_shape) { - return GetRef(op); + return ffi::GetRef(op); } - Array args = {shape_heap_, PrimValue::Int64(static_cast(op->values.size()))}; + ffi::Array args = {shape_heap_, + PrimValue::Int64(static_cast(op->values.size()))}; for (PrimExpr expr : op->values) { auto [code, value_or_index] = MakeSymbolicShapeArg(expr); args.push_back(code); @@ -502,7 +503,7 @@ class VMShapeLowerMutator bool all_nop = true; bool any_nop = false; - Array args = {item.input, shape_heap_}; + ffi::Array args = {item.input, shape_heap_}; Expr match_op; if (item.input->struct_info_.as()) { @@ -567,18 +568,18 @@ class VMShapeLowerMutator ICHECK_GT(heap_size_->value, 0); // construct a PrimFunc that computes the shape.
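// The generated PrimFunc takes the shape heap H as its single buffer
// parameter; for each outstanding slot it loads the dependent vars from H
// and stores the evaluated expression back into H[slot->index].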
tir::Var heap("heap", DataType::Handle()); - Array buffer_shape{heap_size_}; + ffi::Array buffer_shape{heap_size_}; tir::Buffer buffer = tir::decl_buffer(buffer_shape, ShapeDType(), "H", "global"); - Map buffer_map; + ffi::Map buffer_map; buffer_map.Set(heap, buffer); - auto var_map = [&](const tir::Var& var) -> Optional { + auto var_map = [&](const tir::Var& var) -> ffi::Optional { auto it = slot_map_.find(var); ICHECK(it != slot_map_.end()); return tir::BufferLoad(buffer, {IntImm(ShapeDType(), it->second->index)}); }; - Array seq; + ffi::Array seq; for (PrimExprSlot* slot : to_compute) { ICHECK(!slot->value_computed); slot->value_computed = true; @@ -587,7 +588,7 @@ class VMShapeLowerMutator } tir::Stmt body = tir::SeqStmt::Flatten(seq); - Array params{heap}; + ffi::Array params{heap}; Type ret_type = VoidType(); // TODO(relax-team): Consider attach the target attribute to @@ -623,14 +624,14 @@ class VMShapeLowerMutator * visit the match cast. */ void CheckMatchCast(const StructInfo& struct_info, Expr value, bool always_check, - bool dynamic_only, const String& err_ctx, + bool dynamic_only, const ffi::String& err_ctx, std::vector* match_todos) { return this->VisitStructInfo(struct_info, value, always_check, dynamic_only, err_ctx, match_todos); } void VisitStructInfo(const StructInfo& struct_info, Expr value, bool always_check, - bool dynamic_only, const String& err_ctx, + bool dynamic_only, const ffi::String& err_ctx, std::vector* match_todos) final { // short-cut, if the struct info already satisfies the // constraint during match cast, we can skip matching @@ -640,11 +641,11 @@ class VMShapeLowerMutator } void VisitStructInfo_(const ObjectStructInfoNode* op, Expr value, bool always_check, - bool dynamic_only, const String& err_ctx, + bool dynamic_only, const ffi::String& err_ctx, std::vector* match_todos) final {} void VisitStructInfo_(const PrimStructInfoNode* op, Expr value, bool always_check, - bool dynamic_only, const String& err_ctx, + bool dynamic_only, const ffi::String& err_ctx, std::vector* match_todos) final { // emit runtime check of shape if (always_check || !IsBaseOf(PrimStructInfo(op->dtype), GetStructInfo(value))) { @@ -663,7 +664,7 @@ class VMShapeLowerMutator } void VisitStructInfo_(const ShapeStructInfoNode* op, Expr value, bool always_check, - bool dynamic_only, const String& err_ctx, + bool dynamic_only, const ffi::String& err_ctx, std::vector* match_todos) final { // emit runtime check of shape if (always_check || !IsBaseOf(ShapeStructInfo(op->ndim), GetStructInfo(value))) { @@ -683,7 +684,7 @@ class VMShapeLowerMutator } void VisitStructInfo_(const TensorStructInfoNode* op, Expr value, bool always_check, - bool dynamic_only, const String& err_ctx, + bool dynamic_only, const ffi::String& err_ctx, std::vector* match_todos) final { // emit runtime check of shape auto* shape_expr = op->shape.as(); @@ -734,7 +735,7 @@ class VMShapeLowerMutator } void VisitStructInfo_(const TupleStructInfoNode* op, Expr value, bool always_check, - bool dynamic_only, const String& err_ctx, + bool dynamic_only, const ffi::String& err_ctx, std::vector* match_todos) final { auto* value_tinfo = GetStructInfoAs(value); if (value_tinfo) { @@ -757,7 +758,7 @@ class VMShapeLowerMutator } void VisitStructInfo_(const FuncStructInfoNode* op, Expr value, bool always_check, - bool dynamic_only, const String& err_ctx, + bool dynamic_only, const ffi::String& err_ctx, std::vector* match_todos) final { // we only check function is callable. 
if (!always_check && MatchStructInfo(value)) return; @@ -779,7 +780,7 @@ class VMShapeLowerMutator std::vector> slot_vec_; /*! \brief Expr => slot. */ PrimExprSlotMap slot_map_; - Optional current_gvar_ = std::nullopt; + ffi::Optional current_gvar_ = std::nullopt; /*! * \brief List of vars that are being defined but * have not gone through the outstanding shape compute check. */ @@ -790,7 +791,7 @@ class VMShapeLowerMutator const Op& null_value_op_ = Op::Get("relax.null_value"); // common struct info const StructInfo object_sinfo_ = ObjectStructInfo(); - const StructInfo void_sinfo_ = TupleStructInfo(Array({})); + const StructInfo void_sinfo_ = TupleStructInfo(ffi::Array({})); // check function const ExternFunc builtin_alloc_shape_heap_{"vm.builtin.alloc_shape_heap"}; const ExternFunc builtin_match_shape_{"vm.builtin.match_shape"}; diff --git a/src/relax/distributed/axis_group_graph.cc b/src/relax/distributed/axis_group_graph.cc index 491ffc12fa57..12feeacc8b0b 100644 --- a/src/relax/distributed/axis_group_graph.cc +++ b/src/relax/distributed/axis_group_graph.cc @@ -29,7 +29,8 @@ namespace tvm { namespace tir { -Var GetShardingVarFromIndex(PrimExpr index, Map var_range, arith::Analyzer* analyzer) { +Var GetShardingVarFromIndex(PrimExpr index, ffi::Map var_range, + arith::Analyzer* analyzer) { if (index.as()) { return Downcast(index); } @@ -47,12 +48,12 @@ Var GetShardingVarFromIndex(PrimExpr index, Map var_range, arith::An return Var(); } // the floormod must have no effect - if (!analyzer->CanProve( - floordiv(var_range[GetRef(source_var)]->extent, highest_iter_split->lower_factor) <= - highest_iter_split->extent)) { + if (!analyzer->CanProve(floordiv(var_range[ffi::GetRef(source_var)]->extent, + highest_iter_split->lower_factor) <= + highest_iter_split->extent)) { return Var(); } - return GetRef(source_var); + return ffi::GetRef(source_var); } } // namespace tir } // namespace tvm @@ -75,7 +76,7 @@ const TensorStructInfoNode* GetTensorStructInfo(Expr tensor) { throw; } -void UnaryOpHelper(Array tensor_list, distributed::AxisGroupGraph* axis_group_graph) { +void UnaryOpHelper(ffi::Array tensor_list, distributed::AxisGroupGraph* axis_group_graph) { int n_dim = GetTensorStructInfo(tensor_list[0])->ndim; for (const auto& tensor : tensor_list) { ICHECK(GetTensorStructInfo(tensor)->ndim == n_dim); @@ -91,7 +92,7 @@ void UnaryOpHelper(Array tensor_list, distributed::AxisGroupGraph* axis_gr void BuildAxisGraphUnary(const Var& output_var, const Call& call, distributed::AxisGroupGraph* axis_group_graph) { - Array tensor_list; // vars in param and output + ffi::Array tensor_list; // vars in param and output if (call->args[0]->IsInstance()) { tensor_list.push_back(call->args[0]); } @@ -101,7 +102,7 @@ void BuildAxisGraphUnary(const Var& output_var, const Call& call, void BuildAxisGraphBinary(const Var& output_var, const Call& call, distributed::AxisGroupGraph* axis_group_graph) { - Array tensor_list; // vars in param and output + ffi::Array tensor_list; // vars in param and output if (call->args[0]->struct_info_.as() || call->args[0]->struct_info_.as()) { tensor_list.push_back(call->args[0]); @@ -162,7 +163,7 @@ void BuildAxisGraphBinary(const Var& output_var, const Call& call, void BuildAxisGraphReduce(const Var& output_var, const Call& call, distributed::AxisGroupGraph* axis_group_graph) { Expr input_tensor = call->args[0]; - Array axes; + ffi::Array axes; bool keepdims; if (const auto* attrs = call->attrs.as()) { if (attrs->axis.defined()) { @@ -228,10 +229,10 @@ void BuildAxisGraphMatmul(const Var&
output_var, const Call& call, const auto* x1_shape = x1_sinfo->shape.as(); const auto* x2_shape = x2_sinfo->shape.as(); ICHECK(x1_shape && x2_shape); - Array x1_shape_prefix{x1_shape->values.begin(), - x1_shape->values.end() - 2 + x1_prepended}; - Array x2_shape_prefix{x2_shape->values.begin(), - x2_shape->values.end() - 2 + x2_appended}; + ffi::Array x1_shape_prefix{x1_shape->values.begin(), + x1_shape->values.end() - 2 + x1_prepended}; + ffi::Array x2_shape_prefix{x2_shape->values.begin(), + x2_shape->values.end() - 2 + x2_appended}; int x1_prefix_ndim = x1_shape_prefix.size(); int x2_prefix_ndim = x2_shape_prefix.size(); @@ -311,8 +312,8 @@ void BuildAxisGraphReshape(const Var& output_var, const Call& call, const auto* new_shape_sinfo = GetStructInfoAs(call->args[1]); const auto* old_shape_sinfo = GetStructInfoAs(tensor_sinfo->shape.value()); ICHECK_NOTNULL(old_shape_sinfo); - Array old_shape_values = old_shape_sinfo->values.value(); - Array new_shape_values = new_shape_sinfo->values.value(); + ffi::Array old_shape_values = old_shape_sinfo->values.value(); + ffi::Array new_shape_values = new_shape_sinfo->values.value(); int i = old_shape_values.size(); int j = new_shape_values.size(); PrimExpr old_shape_product = 1, new_shape_product = 1; @@ -349,8 +350,8 @@ inline int GetNumOutput(Call call) { void BuildAxisGraphCallTIR(const Var& output_var, const Call& call, const tir::PrimFunc& func, distributed::AxisGroupGraph* axis_group_graph) { auto tir_var_axis_group_list = tir::BufferAxisGraphExtractor::GetTIRVarAxisGraph(func); - Map input_var_to_relax_expr; - Array input_list = Downcast(call->args[1])->fields; + ffi::Map input_var_to_relax_expr; + ffi::Array input_list = Downcast(call->args[1])->fields; input_list.push_back(output_var); for (int i = 0; i < static_cast(input_list.size()); i++) { if (func->buffer_map.count(func->params[i])) { diff --git a/src/relax/distributed/global_info.cc b/src/relax/distributed/global_info.cc index b4f435569330..4ac44d252560 100644 --- a/src/relax/distributed/global_info.cc +++ b/src/relax/distributed/global_info.cc @@ -26,12 +26,12 @@ namespace distributed { TVM_FFI_STATIC_INIT_BLOCK({ DeviceMeshNode::RegisterReflection(); }); -DeviceMesh::DeviceMesh(ffi::Shape shape, Array device_ids) { +DeviceMesh::DeviceMesh(ffi::Shape shape, ffi::Array device_ids) { int prod = 1; for (int i = 0; i < static_cast(shape.size()); i++) { prod *= shape[i]; } - ObjectPtr n = make_object(); + ObjectPtr n = ffi::make_object(); CHECK_EQ(prod, static_cast(device_ids.size())) << "The number of device ids must match the product of the shape"; n->shape = std::move(shape); @@ -40,8 +40,8 @@ DeviceMesh::DeviceMesh(ffi::Shape shape, Array device_ids) { } DeviceMesh::DeviceMesh(ffi::Shape shape, Range device_range) { - ObjectPtr n = make_object(); - Array device_ids; + ObjectPtr n = ffi::make_object(); + ffi::Array device_ids; int range_start = device_range->min.as()->value; int range_extent = device_range->extent.as()->value; for (int i = range_start; i < range_start + range_extent; i++) { @@ -63,7 +63,7 @@ TVM_FFI_STATIC_INIT_BLOCK({ namespace refl = tvm::ffi::reflection; refl::GlobalDef().def( "relax.distributed.DeviceMesh", - [](ffi::Shape shape, Array device_ids, Optional device_range) { + [](ffi::Shape shape, ffi::Array device_ids, ffi::Optional device_range) { if (device_range.defined()) return DeviceMesh(shape, device_range.value()); else diff --git a/src/relax/distributed/struct_info.cc b/src/relax/distributed/struct_info.cc index 0b6f3624cc10..64ee815b19ba 100644 --- 
a/src/relax/distributed/struct_info.cc +++ b/src/relax/distributed/struct_info.cc @@ -35,14 +35,14 @@ TVM_FFI_STATIC_INIT_BLOCK({ }); PlacementSpec PlacementSpec::Sharding(int axis) { - ObjectPtr n = make_object(); + ObjectPtr n = ffi::make_object(); n->axis = axis; n->kind = PlacementSpecKind::kSharding; return PlacementSpec(n); } PlacementSpec PlacementSpec::Replica() { - ObjectPtr n = make_object(); + ObjectPtr n = ffi::make_object(); n->axis = -1; n->kind = PlacementSpecKind::kReplica; return PlacementSpec(n); @@ -55,7 +55,7 @@ TVM_FFI_STATIC_INIT_BLOCK({ .def("relax.distributed.Replica", []() { return PlacementSpec::Replica(); }); }); -String PlacementNode::ToString() const { +ffi::String PlacementNode::ToString() const { std::stringstream ss; for (size_t i = 0; i < dim_specs.size(); ++i) { if (i != 0) { @@ -70,14 +70,14 @@ String PlacementNode::ToString() const { return ss.str(); } -Placement::Placement(Array dim_specs) { - ObjectPtr n = make_object(); +Placement::Placement(ffi::Array dim_specs) { + ObjectPtr n = ffi::make_object(); n->dim_specs = std::move(dim_specs); data_ = std::move(n); } -Placement Placement::FromText(String text_repr) { - Array dim_specs; +Placement Placement::FromText(ffi::String text_repr) { + ffi::Array dim_specs; std::stringstream ss(text_repr); while (true) { char indicator = 0; @@ -114,7 +114,7 @@ TVM_FFI_STATIC_INIT_BLOCK({ refl::GlobalDef() .def("relax.distributed.PlacementFromText", Placement::FromText) .def("relax.distributed.Placement", - [](Array dim_specs) { return Placement(dim_specs); }); + [](ffi::Array dim_specs) { return Placement(dim_specs); }); }); // DTensor @@ -127,7 +127,7 @@ DTensorStructInfo::DTensorStructInfo(TensorStructInfo tensor_sinfo, DeviceMesh d CHECK_LT(spec->axis, tensor_sinfo->ndim) << "ValueError: Sharding dimension should be smaller than tensor ndim"; } - ObjectPtr n = make_object(); + ObjectPtr n = ffi::make_object(); n->device_mesh = std::move(device_mesh); n->placement = std::move(placement); n->tensor_sinfo = std::move(tensor_sinfo); diff --git a/src/relax/distributed/transform/legalize_redistribute.cc b/src/relax/distributed/transform/legalize_redistribute.cc index 47f28252ff51..d9a786867453 100644 --- a/src/relax/distributed/transform/legalize_redistribute.cc +++ b/src/relax/distributed/transform/legalize_redistribute.cc @@ -55,7 +55,7 @@ class RedistributeLegalizer : public ExprMutator { continue; } Expr new_func_body = VisitExpr(func_->body); - auto new_func = make_object(*func_); + auto new_func = ffi::make_object(*func_); new_func->body = new_func_body; builder_->UpdateFunction(gv, Function(new_func)); } diff --git a/src/relax/distributed/transform/lower_distir.cc b/src/relax/distributed/transform/lower_distir.cc index 036867043f71..e4131549f487 100644 --- a/src/relax/distributed/transform/lower_distir.cc +++ b/src/relax/distributed/transform/lower_distir.cc @@ -52,10 +52,10 @@ class DistIRSharder : public ExprMutator { auto mod = builder_->GetContextIRModule(); for (const auto& [gv, base_func] : mod->functions) { const auto* func_ = base_func.as(); - if (func_ == nullptr || !IsDistIRFunc(GetRef(func_))) { + if (func_ == nullptr || !IsDistIRFunc(ffi::GetRef(func_))) { continue; } - Function func = RewriteFunction(GetRef(func_)); + Function func = RewriteFunction(ffi::GetRef(func_)); builder_->UpdateFunction(gv, func); } return builder_->GetContextIRModule(); @@ -63,7 +63,7 @@ class DistIRSharder : public ExprMutator { ShapeExpr ShardShape(ShapeExpr orig_shape, DeviceMesh device_mesh, Placement placement) { 
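// Each device-mesh axis whose PlacementSpec is kSharding divides the
// annotated tensor dimension by that axis' extent, yielding the per-device
// local view; e.g. (illustrative values) a [128, 64] tensor on a mesh of
// shape [4] with placement "S[0]" gets local shape [32, 64].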
ffi::Shape device_mesh_shape = device_mesh->shape; - Array new_tensor_shape_value = orig_shape->values; + ffi::Array new_tensor_shape_value = orig_shape->values; for (int i = 0; i < static_cast(device_mesh_shape.size()); i++) { if (placement->dim_specs[i]->kind == PlacementSpecKind::kSharding) { int shard_size = device_mesh_shape[i]; @@ -78,25 +78,25 @@ class DistIRSharder : public ExprMutator { TensorStructInfo tensor_sinfo = orig_sinfo->tensor_sinfo; ICHECK(tensor_sinfo->shape); const auto* orig_shape = tensor_sinfo->shape.as(); - auto new_tensor_sinfo = make_object(*tensor_sinfo.get()); - new_tensor_sinfo->shape = - ShardShape(GetRef(orig_shape), orig_sinfo->device_mesh, orig_sinfo->placement); + auto new_tensor_sinfo = ffi::make_object(*tensor_sinfo.get()); + new_tensor_sinfo->shape = ShardShape(ffi::GetRef(orig_shape), + orig_sinfo->device_mesh, orig_sinfo->placement); return TensorStructInfo(new_tensor_sinfo); } StructInfo ConvertSinfo(StructInfo orig_sinfo, bool shard_shape) { if (const auto* dtensor_sinfo = orig_sinfo.as()) { if (shard_shape) { - return ShardDTensorSinfo(GetRef(dtensor_sinfo)); + return ShardDTensorSinfo(ffi::GetRef(dtensor_sinfo)); } else { return dtensor_sinfo->tensor_sinfo; } } else if (const auto* tuple_sinfo = orig_sinfo.as()) { - Array new_fields; + ffi::Array new_fields; for (const auto& field_sinfo : tuple_sinfo->fields) { if (const auto* dtensor_sinfo = field_sinfo.as()) { if (shard_shape) { - new_fields.push_back(ShardDTensorSinfo(GetRef(dtensor_sinfo))); + new_fields.push_back(ShardDTensorSinfo(ffi::GetRef(dtensor_sinfo))); } else { new_fields.push_back(dtensor_sinfo->tensor_sinfo); } @@ -157,12 +157,13 @@ class DistIRSharder : public ExprMutator { for (int i = 0; i < static_cast(func_->params.size()); i++) { Var param = func_->params[i]; if (const auto* dtensor_sinfo = GetStructInfoAs(param)) { - EmitBroadcastOrScatter(param, new_params_[i], GetRef(dtensor_sinfo)); + EmitBroadcastOrScatter(param, new_params_[i], + ffi::GetRef(dtensor_sinfo)); } else if (const auto* tuple_sinfo = GetStructInfoAs(param)) { for (int j = 0; j < static_cast(tuple_sinfo->fields.size()); j++) { if (const auto* dtensor_sinfo = tuple_sinfo->fields[j].as()) { EmitBroadcastOrScatter(TupleGetItem(param, j), TupleGetItem(new_params_[i], j), - GetRef(dtensor_sinfo)); + ffi::GetRef(dtensor_sinfo)); } } } @@ -170,7 +171,7 @@ class DistIRSharder : public ExprMutator { } Function RewriteFunction(Function func) { - Array new_params; + ffi::Array new_params; for (const Var& var : func->params) { Var new_param = Downcast(ShardInputParamTensorAndConstant(var)); var_remap_[var->vid] = new_param; @@ -184,8 +185,8 @@ class DistIRSharder : public ExprMutator { } void VisitBinding_(const VarBindingNode* binding, const TupleGetItemNode* val) { - if (tuple_getitem_remap_.count(GetRef(val))) { - var_remap_[binding->var->vid] = tuple_getitem_remap_[GetRef(val)]; + if (tuple_getitem_remap_.count(ffi::GetRef(val))) { + var_remap_[binding->var->vid] = tuple_getitem_remap_[ffi::GetRef(val)]; } else { ExprMutator::VisitBinding_(binding, val); } @@ -217,19 +218,19 @@ class DistIRSharder : public ExprMutator { ICHECK(call->args[1].as()); const auto* out_sinfo = GetStructInfoAs(binding_var); ICHECK(out_sinfo); - auto new_call_node = make_object(*call); + auto new_call_node = ffi::make_object(*call); new_call_node->args.Set(1, ShardShape(Downcast(call->args[1]), out_sinfo->device_mesh, out_sinfo->placement)); return Call(new_call_node); } else if (call->op.same_as(call_tir_local_view_op)) { - auto 
new_call_node = make_object(*call); + auto new_call_node = ffi::make_object(*call); new_call_node->op = call_tir_op; new_call_node->sinfo_args = {ConvertSinfo(GetStructInfo(binding_var), true)}; return Call(new_call_node); } else if (call->op.same_as(call_tir_op)) { LOG(FATAL) << "call_tir should be lowered to call_tir_local_view before lowering to relax"; } else if (const auto* extern_func = call->op.as()) { - auto new_call_node = make_object(*call); + auto new_call_node = ffi::make_object(*call); if (extern_func->global_symbol == "vm.builtin.distributed.attention_kv_cache_append") { new_call_node->op = ExternFunc("vm.builtin.attention_kv_cache_append"); } else if (extern_func->global_symbol == "vm.builtin.distributed.attention_kv_cache_view") { @@ -243,7 +244,7 @@ class DistIRSharder : public ExprMutator { } return Call(new_call_node); } - return GetRef(call); + return ffi::GetRef(call); } void VisitBinding_(const VarBindingNode* binding, const CallNode* val) { @@ -253,7 +254,7 @@ class DistIRSharder : public ExprMutator { } Function func_; - Array new_params_; + ffi::Array new_params_; std::unordered_map tuple_getitem_remap_; }; diff --git a/src/relax/distributed/transform/lower_global_view_to_local_view.cc b/src/relax/distributed/transform/lower_global_view_to_local_view.cc index 7baf49508d58..b93deb9d2b13 100644 --- a/src/relax/distributed/transform/lower_global_view_to_local_view.cc +++ b/src/relax/distributed/transform/lower_global_view_to_local_view.cc @@ -36,18 +36,18 @@ using namespace tvm::relax::distributed; class DistBufferReplacer : public StmtExprMutator { public: - static Stmt BufferReplace(Stmt stmt, Map buffer_map) { + static Stmt BufferReplace(Stmt stmt, ffi::Map buffer_map) { DistBufferReplacer replacer(buffer_map); return replacer(stmt); } private: - explicit DistBufferReplacer(Map buffer_map) : buffer_map_(buffer_map) {} + explicit DistBufferReplacer(ffi::Map buffer_map) : buffer_map_(buffer_map) {} Stmt VisitStmt_(const BufferStoreNode* _store) final { BufferStore store = Downcast(StmtExprMutator::VisitStmt_(_store)); if (buffer_map_.count(store->buffer)) { - ObjectPtr new_store = make_object(*store.get()); + ObjectPtr new_store = ffi::make_object(*store.get()); new_store->buffer = buffer_map_[store->buffer]; return BufferStore(new_store); } @@ -57,7 +57,7 @@ class DistBufferReplacer : public StmtExprMutator { PrimExpr VisitExpr_(const BufferLoadNode* _load) final { BufferLoad load = Downcast(StmtExprMutator::VisitExpr_(_load)); if (buffer_map_.count(load->buffer)) { - ObjectPtr new_load = make_object(*load.get()); + ObjectPtr new_load = ffi::make_object(*load.get()); new_load->buffer = buffer_map_[load->buffer]; return BufferLoad(new_load); } @@ -65,15 +65,15 @@ class DistBufferReplacer : public StmtExprMutator { } Stmt VisitStmt_(const BlockNode* _block) final { - Block old_block = GetRef(_block); + Block old_block = ffi::GetRef(_block); Block block = Downcast(StmtExprMutator::VisitStmt_(_block)); - ObjectPtr new_block = make_object(*block.get()); + ObjectPtr new_block = ffi::make_object(*block.get()); new_block->reads = ReplaceBuffer(new_block->reads, buffer_map_); new_block->writes = ReplaceBuffer(new_block->writes, buffer_map_); return Block(new_block); } - Map buffer_map_; + ffi::Map buffer_map_; }; class DistBlockInfoCollector : public StmtExprVisitor { @@ -136,7 +136,7 @@ class DistBlockInfoCollector : public StmtExprVisitor { Buffer reduce_buffer_; public: - std::unordered_map>, ObjectPtrHash, ObjectPtrEqual> + std::unordered_map>, ObjectPtrHash, 
                  ObjectPtrEqual> buffer_access_indices;
   std::string reduce_kind;
 };
@@ -151,8 +151,8 @@ class DistributedBufferCompactor : StmtExprMutator {
       const std::vector& sharding_specs, PrimFunc prim_func) {
     prim_func = RenewDefs(prim_func);
     DistributedBufferCompactor compactor(sharding_specs, prim_func);
-    Map new_func_buffer_map;
-    Map replace_buffer_map;
+    ffi::Map new_func_buffer_map;
+    ffi::Map replace_buffer_map;
     for (const auto& pr : prim_func->buffer_map) {
       Buffer shard_buffer = compactor.ShardBuffer(pr.second);
       new_func_buffer_map.Set(pr.first, shard_buffer);
@@ -162,7 +162,7 @@ class DistributedBufferCompactor : StmtExprMutator {
     }
     Stmt new_body = compactor(prim_func->body);
     new_body = DistBufferReplacer::BufferReplace(new_body, replace_buffer_map);
-    ObjectPtr new_func = make_object(*prim_func.get());
+    ObjectPtr new_func = ffi::make_object(*prim_func.get());
     new_func->buffer_map = new_func_buffer_map;
     new_func->body = new_body;
     return std::make_tuple(PrimFunc(new_func), compactor.add_allreduce_kind_);
@@ -200,10 +200,9 @@ class DistributedBufferCompactor : StmtExprMutator {
     }
   }

-  Array ShardIterVar(
-      Block block,
-      const std::unordered_map>, ObjectPtrHash, ObjectPtrEqual>&
-          buffer_access_indices) {
+  ffi::Array ShardIterVar(
+      Block block, const std::unordered_map>, ObjectPtrHash,
+                                              ObjectPtrEqual>& buffer_access_indices) {
     std::vector buffers;
     for (const auto& read : block->reads) {
       buffers.push_back(read->buffer);
@@ -211,7 +210,7 @@ class DistributedBufferCompactor : StmtExprMutator {
     for (const auto& write : block->writes) {
       buffers.push_back(write->buffer);
     }
-    Map iter_var_range;
+    ffi::Map iter_var_range;
     for (const auto& iter_var : block->iter_vars) {
       iter_var_range.Set(iter_var->var, iter_var->dom);
     }
@@ -220,7 +219,7 @@ class DistributedBufferCompactor : StmtExprMutator {
       if (buffer_access_indices.count(buffer) == 0 || buffer_shards_.count(buffer) == 0) {
         continue;
       }
-      Array> access_indices = buffer_access_indices.at(buffer);
+      ffi::Array> access_indices = buffer_access_indices.at(buffer);
       DimShard dim_shards = buffer_shards_[buffer];
       for (const auto& access_index : access_indices) {
         for (const auto& pr : dim_shards) {
@@ -234,7 +233,7 @@ class DistributedBufferCompactor : StmtExprMutator {
       }
     }

-    Array new_iter_vars;
+    ffi::Array new_iter_vars;
     for (const auto& iter_var : block->iter_vars) {
       if (iter_var_shards_.count(iter_var->var)) {
         int shard = iter_var_shards_[iter_var->var];
@@ -259,7 +258,7 @@ class DistributedBufferCompactor : StmtExprMutator {
       return buffer;
     }
     DimShard dim_shards = buffer_shards_[buffer];
-    Array shape;
+    ffi::Array shape;
     for (int i = 0; i < static_cast(buffer->shape.size()); i++) {
       if (dim_shards.count(i)) {
         shape.push_back(floordiv(buffer->shape[i], dim_shards[i]));
@@ -267,7 +266,7 @@ class DistributedBufferCompactor : StmtExprMutator {
         shape.push_back(buffer->shape[i]);
       }
     }
-    ObjectPtr new_buffer = make_object(*buffer.get());
+    ObjectPtr new_buffer = ffi::make_object(*buffer.get());
     new_buffer->shape = shape;
     return Buffer(new_buffer);
   }
@@ -276,9 +275,9 @@ class DistributedBufferCompactor : StmtExprMutator {
     Block block = Downcast(StmtExprMutator::VisitStmt_(op));
     DistBlockInfoCollector collector;
     collector(block);
-    Array new_iter_vars = ShardIterVar(block, collector.buffer_access_indices);
-    Array new_alloc_buffers;
-    Map buffer_map;
+    ffi::Array new_iter_vars = ShardIterVar(block, collector.buffer_access_indices);
+    ffi::Array new_alloc_buffers;
+    ffi::Map buffer_map;
     for (const Buffer& buffer : block->alloc_buffers) {
       Buffer sharded_buffer = ShardBuffer(buffer);
       if (!sharded_buffer.same_as(buffer)) {
@@ -295,7 +294,7 @@ class DistributedBufferCompactor : StmtExprMutator {
         break;
       }
     }
-    ObjectPtr new_block = make_object(*block.operator->());
+    ObjectPtr new_block = ffi::make_object(*block.operator->());
     new_block->iter_vars = new_iter_vars;
     new_block->alloc_buffers = new_alloc_buffers;
     if (new_block->name_hint == "root") {
@@ -340,7 +339,7 @@ class DistributedBufferCompactor : StmtExprMutator {
   std::unordered_map iter_var_shards_;
   std::unordered_map loop_var_shards_;
-  Array allocated_buffer_under_root;
+  ffi::Array allocated_buffer_under_root;
   BufferAxisGraphExtractor extractor_;
   std::vector sharding_specs_;
   std::unordered_map buffer_shards_;
@@ -362,11 +361,11 @@ class LowerTIRToLocalView : public ExprMutator {
     auto mod = builder_->GetContextIRModule();
     for (const auto& [gv, base_func] : mod->functions) {
       const auto* func_ = base_func.as();
-      if (func_ == nullptr || !IsDistIRFunc(GetRef(func_))) {
+      if (func_ == nullptr || !IsDistIRFunc(ffi::GetRef(func_))) {
         continue;
       }
       Expr new_func_body = this->VisitExpr(func_->body);
-      ObjectPtr new_func = make_object(*func_);
+      ObjectPtr new_func = ffi::make_object(*func_);
       new_func->body = new_func_body;
       builder_->UpdateFunction(gv, Function(new_func));
     }
@@ -374,11 +373,11 @@ class LowerTIRToLocalView : public ExprMutator {
   }

  private:
-  inline Array ExtractDTensorStructInfo(Var var) {
+  inline ffi::Array ExtractDTensorStructInfo(Var var) {
     if (const auto* dtensor_sinfo = GetStructInfoAs(var)) {
-      return {GetRef(dtensor_sinfo)};
+      return {ffi::GetRef(dtensor_sinfo)};
     } else if (const auto* tuple_sinfo = GetStructInfoAs(var)) {
-      Array ret;
+      ffi::Array ret;
       for (const auto& field : tuple_sinfo->fields) {
         ret.push_back(Downcast(field));
       }
@@ -395,14 +394,14 @@ class LowerTIRToLocalView : public ExprMutator {
       return;
     }
     std::vector sharding_specs;
-    Array args = Downcast(val->args[1])->fields;
+    ffi::Array args = Downcast(val->args[1])->fields;
     for (const auto& arg : args) {
       const auto* sinfo = GetStructInfoAs(arg);
       ICHECK(sinfo);
       sharding_specs.push_back(ShardingSpec(sinfo->device_mesh, sinfo->placement));
     }
     Var output_var = binding->var;
-    Array output_sinfos = ExtractDTensorStructInfo(output_var);
+    ffi::Array output_sinfos = ExtractDTensorStructInfo(output_var);
     for (const auto& sinfo : output_sinfos) {
       sharding_specs.push_back(ShardingSpec(sinfo->device_mesh, sinfo->placement));
     }
@@ -414,12 +413,12 @@ class LowerTIRToLocalView : public ExprMutator {
         tir::DistributedBufferCompactor::DistBufferCompact(sharding_specs, prim_func);
     auto new_gvar = builder_->AddFunction(new_prim_func, gvar->name_hint);
     Call call = Downcast(this->VisitExpr(binding->value));
-    ObjectPtr new_call_node = make_object(*call.get());
+    ObjectPtr new_call_node = ffi::make_object(*call.get());
     new_call_node->op = Op::Get("relax.dist.call_tir_local_view");
     new_call_node->args.Set(0, new_gvar);
     Call new_call(new_call_node);
     if (allreduce_kind != "") {
-      ObjectPtr attrs = make_object();
+      ObjectPtr attrs = ffi::make_object();
       attrs->op_type = allreduce_kind;
       new_call = Call(Op::Get("relax.ccl.allreduce"), {new_call}, Attrs(attrs), {});
     }
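The hunks above repeatedly use the clone-then-mutate idiom, now spelled through `ffi::make_object`. A minimal sketch of the pattern, assuming a TVM build (nothing here is new API; it restates what the call sites above do):

```cpp
// Sketch only: copy an IR node, mutate the copy, rewrap it in a reference.
#include <tvm/ffi/memory.h>
#include <tvm/tir/function.h>

tvm::tir::PrimFunc WithNewBody(tvm::tir::PrimFunc func, tvm::tir::Stmt new_body) {
  // Copy-construct a fresh node from the existing one; the original stays intact.
  auto n = tvm::ffi::make_object<tvm::tir::PrimFuncNode>(*func.get());
  n->body = std::move(new_body);
  return tvm::tir::PrimFunc(n);  // re-wrap as a managed reference
}
```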
diff --git a/src/relax/distributed/transform/propagate_sharding.cc b/src/relax/distributed/transform/propagate_sharding.cc
index 1f46b54cfe50..71e27e8ffd52 100644
--- a/src/relax/distributed/transform/propagate_sharding.cc
+++ b/src/relax/distributed/transform/propagate_sharding.cc
@@ -48,7 +48,7 @@ void CollectAxisGraphBinary(const VarBindingNode* binding, const CallNode* call,
   for (const auto& op_name : binary_op_names) {
     const Op& binary_op = Op::Get("relax." + op_name);
     if (call->op.same_as(binary_op)) {
-      BuildAxisGraphBinary(binding->var, GetRef(call), axis_group_graph);
+      BuildAxisGraphBinary(binding->var, ffi::GetRef(call), axis_group_graph);
       break;
     }
   }
@@ -71,7 +71,7 @@ void CollectAxisGraphUnary(const VarBindingNode* binding, const CallNode* call,
   for (const auto& op_name : unary_op_names) {
     const Op& unary_op = Op::Get("relax." + op_name);
     if (call->op.same_as(unary_op)) {
-      BuildAxisGraphUnary(binding->var, GetRef(call), axis_group_graph);
+      BuildAxisGraphUnary(binding->var, ffi::GetRef(call), axis_group_graph);
     }
   }
 }
@@ -83,7 +83,7 @@ void CollectAxisGraphReduce(const VarBindingNode* binding, const CallNode* call,
   for (const auto& op_name : reduction_op_names) {
     const Op& reduction_op = Op::Get("relax." + op_name);
     if (call->op.same_as(reduction_op)) {
-      BuildAxisGraphReduce(binding->var, GetRef(call), axis_group_graph);
+      BuildAxisGraphReduce(binding->var, ffi::GetRef(call), axis_group_graph);
       break;
     }
   }
@@ -93,7 +93,7 @@ void CollectAxisGraphMatmul(const VarBindingNode* binding, const CallNode* call,
   static const Op& matmul_op = Op::Get("relax.matmul");
   if (call->op.same_as(matmul_op)) {
-    BuildAxisGraphMatmul(binding->var, GetRef(call), axis_group_graph);
+    BuildAxisGraphMatmul(binding->var, ffi::GetRef(call), axis_group_graph);
   }
 }
@@ -101,7 +101,7 @@ void CollectAxisGraphPermuteDims(const VarBindingNode* binding, const CallNode*
   static const Op& permute_dims_op = Op::Get("relax.permute_dims");
   if (call->op.same_as(permute_dims_op)) {
-    BuildAxisGraphPermuteDims(binding->var, GetRef(call), axis_group_graph);
+    BuildAxisGraphPermuteDims(binding->var, ffi::GetRef(call), axis_group_graph);
   }
 }
@@ -109,15 +109,15 @@ void CollectAxisGraphReshape(const VarBindingNode* binding, const CallNode* call
   static const Op& reshape_op = Op::Get("relax.reshape");
   if (call->op.same_as(reshape_op)) {
-    BuildAxisGraphReshape(binding->var, GetRef(call), axis_group_graph);
+    BuildAxisGraphReshape(binding->var, ffi::GetRef(call), axis_group_graph);
   }
 }

 void CollectAxisGraphForDeviceMesh(const VarBindingNode* binding, const CallNode* call,
                                    AxisGroupGraph* axis_group_graph) {
-  Array tensor_list;
+  ffi::Array tensor_list;
   static const Op& call_tir_op = Op::Get("relax.call_tir");
-  Array args;
+  ffi::Array args;
   if (call->op.same_as(call_tir_op)) {
     args = Downcast(call->args[1])->fields;
   } else {
@@ -158,8 +158,9 @@ class AxisGroupGraphBuilder : public ExprVisitor {
     CollectAxisGraphReshape(binding, val, axis_group_graph_);
     static const Op& call_tir_op = Op::Get("relax.call_tir");
     if (val->op.same_as(call_tir_op)) {
-      if (Optional func = MatchPrimFunc(mod_, val->args[0])) {
-        BuildAxisGraphCallTIR(binding->var, GetRef(val), func.value(), axis_group_graph_);
+      if (ffi::Optional func = MatchPrimFunc(mod_, val->args[0])) {
+        BuildAxisGraphCallTIR(binding->var, ffi::GetRef(val), func.value(),
+                              axis_group_graph_);
       }
     }
     CollectAxisGraphForDeviceMesh(binding, val, axis_group_graph_);
@@ -183,9 +184,9 @@ class AxisGroupGraphBuilder : public ExprVisitor {
   }

   void VisitBinding_(const VarBindingNode* binding, const VarNode* val) {
-    Array tensor_sinfos;
+    ffi::Array tensor_sinfos;
     if (const auto* tensor_sinfo = binding->var->struct_info_.as()) {
-      tensor_sinfos.push_back(GetRef(tensor_sinfo));
+      tensor_sinfos.push_back(ffi::GetRef(tensor_sinfo));
     } else if (const auto* tuple_sinfo = binding->var->struct_info_.as()) {
       ICHECK(tuple_sinfo);
       for (const auto& sinfo : tuple_sinfo->fields) {
@@ -271,7 +272,7 @@ class ShardingConflictHandler : public ExprVisitor {
     ICHECK(shape);
     int ndim = sinfo->ndim;
     std::unordered_set sharded_mesh_dim;
-    Optional device_mesh;
+    ffi::Optional device_mesh;
     for (int i = -1; i < ndim; i++) {
       AxisShardingSpec sharding_spec;
       int has_sharding_spec;
@@ -318,7 +319,7 @@ class ShardingConflictHandler : public ExprVisitor {
   }

   void VisitExpr_(const CallNode* op) final {
-    Array args = GetCallArgs(GetRef(op));
+    ffi::Array args = GetCallArgs(ffi::GetRef(op));
     for (const auto& arg : args) {
       if (arg.as()) {
         CheckConstantNoSharding(Downcast(arg));
@@ -348,10 +349,10 @@ class DistributedIRBuilder : public ExprMutator {
     auto mod = builder_->GetContextIRModule();
     for (const auto& [gv, base_func] : mod->functions) {
       const auto* func_ = base_func.as();
-      if (func_ == nullptr || !IsShardingAnnotatedFunc(GetRef(func_))) {
+      if (func_ == nullptr || !IsShardingAnnotatedFunc(ffi::GetRef(func_))) {
         continue;
       }
-      Function func = RewriteFunction(GetRef(func_), mod);
+      Function func = RewriteFunction(ffi::GetRef(func_), mod);
       builder_->UpdateFunction(gv, func);
     }
     return builder_->GetContextIRModule();
@@ -366,7 +367,7 @@ class DistributedIRBuilder : public ExprMutator {
     DeviceMesh device_mesh =
         std::get<0>(axis_group_graph_.GetAxisShardingSpec({expr.get(), -1, tuple_idx})).first;
     ICHECK(device_mesh.defined()) << expr << "[" << tuple_idx << "] is not assigned device mesh";
-    Array placement_specs(
+    ffi::Array placement_specs(
         std::vector(device_mesh->shape.size(), PlacementSpec::Replica()));
     for (int i = 0; i < ndim; i++) {
       AxisShardingSpec sharding_spec;
@@ -387,7 +388,7 @@ class DistributedIRBuilder : public ExprMutator {
       new_sinfo = ConvertToDTensorStructInfo(Downcast(tensor->struct_info_), tensor);
     } else if (const auto* tuple = tensor->struct_info_.as()) {
-      Array tuple_sinfo_fields;
+      ffi::Array tuple_sinfo_fields;
       for (int i = 0; i < static_cast(tuple->fields.size()); i++) {
         if (tuple->fields[i].as()) {
           tuple_sinfo_fields.push_back(
@@ -419,7 +420,7 @@ class DistributedIRBuilder : public ExprMutator {
     // Step 3. Handle Sharding Conflict
     ShardingConflictHandler::HandleShardingConflict(&axis_group_graph_, func);
     // Step 4. Rewrite Function
-    Array new_params;
+    ffi::Array new_params;
     for (const Var& var : func->params) {
       if (GetStructInfoAs(var) || GetStructInfoAs(var)) {
         Var new_param = Downcast(RewriteInputTensorAndConstant(var));
@@ -437,20 +438,20 @@ class DistributedIRBuilder : public ExprMutator {
   Expr VisitExpr_(const CallNode* call) final {
     static const Op& call_tir_op = Op::Get("relax.call_tir");
     FBuildAxisGraph f = [&](const Var& var, const Call& call, AxisGroupGraph* axis_group_graph) {
-      Optional prim_func =
+      ffi::Optional prim_func =
           MatchPrimFunc(this->builder_->GetContextIRModule(), call->args[0]);
       ICHECK(prim_func);
       return BuildAxisGraphCallTIR(var, call, prim_func.value(), axis_group_graph);
     };
     Call new_call = Downcast(ExprMutator::VisitExpr_(call));
-    Array args = GetCallArgs(new_call);
+    ffi::Array args = GetCallArgs(new_call);
     for (int i = 0; i < static_cast(args.size()); i++) {
       if (args[i].as()) {
         args.Set(i, RewriteInputTensorAndConstant(args[i]));
       }
     }
-    ObjectPtr n = make_object(*new_call.get());
+    ObjectPtr n = ffi::make_object(*new_call.get());
     if (new_call->op.same_as(call_tir_op)) {
       // do not infer output sinfo when arg size is 0
       if (!args.empty()) {
@@ -484,13 +485,13 @@ class DistributedIRBuilder : public ExprMutator {
     return redistribute(expr, device_mesh, placement);
   }

-  Call RewriteOutSinfo(Call call, DeviceMesh device_mesh, Array placements) {
+  Call RewriteOutSinfo(Call call, DeviceMesh device_mesh, ffi::Array placements) {
     // in cases when infer fails (like arg size is 0), we use propagated sinfo for output
     Call new_call = call;
     static Op call_tir_op = Op::Get("relax.call_tir");
     if (const auto* extern_func = call->op.as()) {
       if (extern_func->global_symbol == "vm.builtin.distributed.attention_kv_cache_view") {
-        ObjectPtr new_call_node = make_object(*call.get());
+        ObjectPtr new_call_node = ffi::make_object(*call.get());
         StructInfo new_dtensor_sinfo = DTensorStructInfo(
             Downcast(call->sinfo_args[0]), device_mesh, placements[0]);
         new_call_node->sinfo_args = {new_dtensor_sinfo};
@@ -500,14 +501,14 @@ class DistributedIRBuilder : public ExprMutator {
     } else if (call->op.same_as(call_tir_op)) {
       ICHECK(call->sinfo_args.size() == 1);
       if (!SinfoCompatibleWithDistIR(call->sinfo_args)) {
-        ObjectPtr new_call_node = make_object(*call.get());
+        ObjectPtr new_call_node = ffi::make_object(*call.get());
         if (placements.size() == 1) {
           new_call_node->sinfo_args = {DTensorStructInfo(
               Downcast(call->sinfo_args[0]), device_mesh, placements[0])};
         } else {
           const auto* tuple_sinfo = call->sinfo_args[0].as();
           ICHECK(placements.size() == tuple_sinfo->fields.size());
-          Array new_tuple_sinfo_fields;
+          ffi::Array new_tuple_sinfo_fields;
           for (int i = 0; i < static_cast(placements.size()); i++) {
             new_tuple_sinfo_fields.push_back(DTensorStructInfo(
                 Downcast(tuple_sinfo->fields[i]), device_mesh, placements[i]));
@@ -522,9 +523,9 @@ class DistributedIRBuilder : public ExprMutator {
   }

   void VisitBinding_(const VarBindingNode* binding, const CallNode* val) {
-    Array orig_output_tensor_sinfos;
+    ffi::Array orig_output_tensor_sinfos;
     if (const auto* tensor_sinfo = GetStructInfoAs(binding->var)) {
-      orig_output_tensor_sinfos.push_back(GetRef(tensor_sinfo));
+      orig_output_tensor_sinfos.push_back(ffi::GetRef(tensor_sinfo));
     } else if (const auto* tuple_sinfo = GetStructInfoAs(binding->var)) {
       for (const auto& sinfo : tuple_sinfo->fields) {
         orig_output_tensor_sinfos.push_back(Downcast(sinfo));
       }
@@ -537,9 +538,9 @@ class DistributedIRBuilder : public ExprMutator {
     DeviceMesh device_mesh =
         std::get<0>(axis_group_graph_.GetAxisShardingSpec({binding->var.get(), -1})).first;
     ICHECK(device_mesh.defined());
-    Array placements;  // every tuple element has a placement
+    ffi::Array placements;  // every tuple element has a placement
     for (int idx = 0; idx < static_cast(orig_output_tensor_sinfos.size()); idx++) {
-      Array placement_specs(
+      ffi::Array placement_specs(
           std::vector(device_mesh->shape.size(), PlacementSpec::Replica()));
       for (int i = 0; i < orig_output_tensor_sinfos[idx]->ndim; i++) {
         AxisShardingSpec sharding_spec;
@@ -565,7 +566,7 @@ class DistributedIRBuilder : public ExprMutator {
       new_value = InsertRedistribute(new_value, device_mesh, placements[0]);
     }
     if (const auto* var = new_value.as()) {
-      var_remap_[binding->var->vid] = GetRef(var);
+      var_remap_[binding->var->vid] = ffi::GetRef(var);
     } else {
       ReEmitBinding(binding, builder_->Normalize(new_value));
     }
@@ -589,22 +590,22 @@ class DistributedIRBuilder : public ExprMutator {
   }

   void VisitBinding_(const VarBindingNode* binding, const TupleGetItemNode* val) {
-    if (tuple_getitem_remap_.count(GetRef(val))) {
-      var_remap_[binding->var->vid] = tuple_getitem_remap_[GetRef(val)];
+    if (tuple_getitem_remap_.count(ffi::GetRef(val))) {
+      var_remap_[binding->var->vid] = tuple_getitem_remap_[ffi::GetRef(val)];
     } else {
       ExprMutator::VisitBinding_(binding, val);
     }
   }

   Expr VisitExpr_(const VarNode* var) final {
-    auto it = input_tensor_remap_.find(GetRef(var));
+    auto it = input_tensor_remap_.find(ffi::GetRef(var));
     if (it != input_tensor_remap_.end()) {
       var_remap_[var->vid] = (*it).second;
     }
     return ExprMutator::VisitExpr_(var);
   }

-  Map input_tensor_remap_;
+  ffi::Map input_tensor_remap_;
   std::unordered_map tuple_getitem_remap_;
   AxisGroupGraph axis_group_graph_;
 };
diff --git a/src/relax/distributed/transform/utils.cc b/src/relax/distributed/transform/utils.cc
index 42b914617e73..0bcd730d42c8 100644
--- a/src/relax/distributed/transform/utils.cc
+++ b/src/relax/distributed/transform/utils.cc
@@ -22,7 +22,7 @@ namespace tvm {
 namespace relax {
 namespace distributed {

-bool SinfoCompatibleWithDistIR(Array sinfos) {
+bool SinfoCompatibleWithDistIR(ffi::Array sinfos) {
   bool compatible = true;
   for (const auto& sinfo : sinfos) {
     if (const auto* tuple_sinfo = sinfo.as()) {
@@ -34,7 +34,7 @@ bool SinfoCompatibleWithDistIR(Array sinfos) {
   return compatible;
 }

-bool SinfoCompatibleWithRelax(Array sinfos) {
+bool SinfoCompatibleWithRelax(ffi::Array sinfos) {
   bool compatible = true;
   for (const auto& sinfo : sinfos) {
     if (const auto* tuple_sinfo = sinfo.as()) {
@@ -46,7 +46,7 @@ bool SinfoCompatibleWithRelax(Array sinfos) {
   return compatible;
 }

 bool IsDistIRFunc(Function func) {
-  Array param_sinfos;
+  ffi::Array param_sinfos;
   for (const auto& param : func->params) {
     ICHECK(param->struct_info_);
     param_sinfos.push_back(Downcast(param->struct_info_.value()));
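Several hunks above replace bare `GetRef` with the fully qualified `ffi::GetRef`. A minimal sketch of what the call does at a typical visitor call site, assuming a TVM build:

```cpp
// Sketch only: promote a borrowed *Node pointer (e.g. obtained from .as<T>()
// inside a visitor) back to a reference-counted handle, as the call sites
// above now do via ffi::GetRef.
#include <tvm/ffi/cast.h>
#include <tvm/relax/expr.h>

tvm::relax::Call ToHandle(const tvm::relax::CallNode* call_node) {
  // No copy is made; the returned handle shares ownership of the same node.
  return tvm::ffi::GetRef<tvm::relax::Call>(call_node);
}
```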
diff --git a/src/relax/distributed/transform/utils.h b/src/relax/distributed/transform/utils.h
index 2680c892695c..963efc15f6a0 100644
--- a/src/relax/distributed/transform/utils.h
+++ b/src/relax/distributed/transform/utils.h
@@ -33,12 +33,12 @@ namespace distributed {
  * \brief Pattern match op to a TIR function and look it up.
  * \return The TIR function, or nullopt if pattern match fails.
  */
-inline Optional MatchPrimFunc(const IRModule& mod_, const Expr& op) {
+inline ffi::Optional MatchPrimFunc(const IRModule& mod_, const Expr& op) {
   const GlobalVar& global_var = Downcast(op);
   // NOTE: as check works for nullptr(returns null)
-  Optional base_func = mod_->functions.Get(global_var);
+  ffi::Optional base_func = mod_->functions.Get(global_var);
   if (auto* pfunc = base_func.as()) {
-    return GetRef(pfunc);
+    return ffi::GetRef(pfunc);
   }
   return std::nullopt;
 }
@@ -46,7 +46,7 @@ inline Optional MatchPrimFunc(const IRModule& mod_, const Expr& o
 * \brief Check whether the given struct infos can appear in DistIR
 * \return Whether the given struct infos can appear in DistIR
 */
-bool SinfoCompatibleWithDistIR(Array sinfos);
+bool SinfoCompatibleWithDistIR(ffi::Array sinfos);

 /*!
  * \brief Check whether the given function is a DistIR function
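`MatchPrimFunc` above follows the `ffi::Optional` convention: absence is `std::nullopt`, presence is the wrapped handle. A minimal, self-contained sketch of that convention (the header path is an assumption based on this refactor's layout):

```cpp
#include <tvm/ffi/optional.h>
#include <tvm/ffi/string.h>

// Sketch: std::nullopt converts to an empty ffi::Optional.
tvm::ffi::Optional<tvm::ffi::String> MaybeName(bool has) {
  if (has) return tvm::ffi::String("main");
  return std::nullopt;
}

bool HasName() { return MaybeName(true).has_value(); }
```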
diff --git a/src/relax/ir/binding_rewrite.cc b/src/relax/ir/binding_rewrite.cc
index 9dae9175ef27..44688e27e162 100644
--- a/src/relax/ir/binding_rewrite.cc
+++ b/src/relax/ir/binding_rewrite.cc
@@ -39,7 +39,7 @@ namespace relax {
 TVM_FFI_STATIC_INIT_BLOCK({ DataflowBlockRewriteNode::RegisterReflection(); });

 DataflowBlockRewrite::DataflowBlockRewrite(DataflowBlock dfb, Function root_fn) {
-  auto n = make_object();
+  auto n = ffi::make_object();
   n->dfb_ = dfb;
   n->root_fn_ = root_fn;
   n->original_fn_ptr_ = root_fn.get();
@@ -73,7 +73,7 @@ void DataflowBlockRewriteNode::ReplaceAllUses(Var old_var, Var new_var) {
     using ExprMutator::VisitExpr_;

     Expr VisitExpr_(const VarNode* op) override {
-      return (op == old_var.get()) ? new_var : GetRef(op);
+      return (op == old_var.get()) ? new_var : ffi::GetRef(op);
     }

     BindingBlock VisitBindingBlock_(const DataflowBlockNode* op) override {
@@ -177,7 +177,7 @@ void DataflowBlockRewriteNode::Add(Binding binding) {
   }

   for (const VarNode* v : used_vars) {
-    auto var = GetRef(v);
+    auto var = ffi::GetRef(v);
     if (auto users = to_users_.Get(var)) {
       users.value().push_back(var);
     }
@@ -190,7 +190,7 @@ TVM_FFI_STATIC_INIT_BLOCK({
       .def("relax.dfb_rewrite_add_binding",
           [](DataflowBlockRewrite rwt, Binding vb) { rwt->Add(vb); })
       .def("relax.dfb_rewrite_add",
-          [](DataflowBlockRewrite rwt, Expr expr, Optional name, bool is_dfvar) {
+          [](DataflowBlockRewrite rwt, Expr expr, ffi::Optional name, bool is_dfvar) {
             if (name.has_value()) {
               rwt->Add(name.value(), expr, is_dfvar);
             } else {
@@ -199,7 +199,7 @@ TVM_FFI_STATIC_INIT_BLOCK({
           });
 });

-std::set GetUnusedVars(Map> users_map, Array fn_outputs) {
+std::set GetUnusedVars(ffi::Map> users_map, ffi::Array fn_outputs) {
   std::vector unused;

   // iterative dataflow algorithm.
@@ -227,7 +227,7 @@ std::set GetUnusedVars(Map> users_map, Array fn_output
     // remove def site.
     for (const auto& used_var : used) {
       ICHECK(users_map.count(used_var));
-      Array var_users = users_map[used_var];
+      ffi::Array var_users = users_map[used_var];
       // remove the unused var from the use site.
       if (auto it = std::find(var_users.begin(), var_users.end(), unused[i]);
           it != var_users.end()) {
@@ -244,11 +244,11 @@ std::set GetUnusedVars(Map> users_map, Array fn_output
 class RemoveUnusedVars : public ExprMutator {
  public:
   std::set unused_vars;
-  Optional caught_rewrite = std::nullopt;
+  ffi::Optional caught_rewrite = std::nullopt;

   RemoveUnusedVars(std::set unused_vars) : unused_vars(std::move(unused_vars)) {}

-  RemoveUnusedVars(Map> users, Array fn_outputs)
+  RemoveUnusedVars(ffi::Map> users, ffi::Array fn_outputs)
       : RemoveUnusedVars(GetUnusedVars(users, fn_outputs)) {}

   void VisitBinding_(const VarBindingNode* binding) override {
@@ -345,7 +345,7 @@ Expr RemoveAllUnused(Expr expr) {
   }

   RemoveUnusedVars remover(var_usage.downstream_usage,
-                           Array(externally_exposed.begin(), externally_exposed.end()));
+                           ffi::Array(externally_exposed.begin(), externally_exposed.end()));

   return remover.VisitExpr(std::move(expr));
 }
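For readers following the `GetUnusedVars` hunks above: the worklist sweep it implements can be restated with plain STL containers. Everything below is illustrative only; the names and the `reads` map are hypothetical stand-ins, while the real code works over `ffi::Map`/`ffi::Array` and Relax `Var`s.

```cpp
#include <algorithm>
#include <map>
#include <set>
#include <string>
#include <vector>

// Hypothetical restatement of the iterative dead-variable sweep.
std::set<std::string> UnusedVars(
    std::map<std::string, std::vector<std::string>> users,         // var -> vars reading it
    const std::map<std::string, std::vector<std::string>>& reads,  // var -> vars it reads
    const std::set<std::string>& outputs) {
  std::vector<std::string> unused;
  for (const auto& [var, us] : users) {
    if (us.empty() && !outputs.count(var)) unused.push_back(var);  // seed: no users
  }
  // The worklist grows as removing one definition strands the vars it read.
  for (size_t i = 0; i < unused.size(); ++i) {
    auto it = reads.find(unused[i]);
    if (it == reads.end()) continue;
    for (const auto& used : it->second) {
      auto& us = users[used];
      us.erase(std::remove(us.begin(), us.end(), unused[i]), us.end());
      if (us.empty() && !outputs.count(used)) unused.push_back(used);
    }
  }
  return {unused.begin(), unused.end()};
}
```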
diff --git a/src/relax/ir/block_builder.cc b/src/relax/ir/block_builder.cc
index 1a725db904b0..c3ead8cb4676 100644
--- a/src/relax/ir/block_builder.cc
+++ b/src/relax/ir/block_builder.cc
@@ -74,13 +74,13 @@ class BlockBuilderImpl : public BlockBuilderNode {

   IRModule Finalize() final { return transform::NormalizeGlobalVar()(context_mod_); }

-  GlobalVar AddFunction(const BaseFunc& func, String func_name_hint) final {
+  GlobalVar AddFunction(const BaseFunc& func, ffi::String func_name_hint) final {
     LazyInitCtxFuncDedupMap();
     auto it = ctx_func_dedup_map_->find(func);
     if (it == ctx_func_dedup_map_->end()) {
       context_mod_.CopyOnWrite();
-      String func_name = GetUniqueName(func_name_hint);
+      ffi::String func_name = GetUniqueName(func_name_hint);
       while (context_mod_->ContainGlobalVar(func_name)) {
         func_name = GetUniqueName(func_name_hint);
       }
@@ -160,7 +160,7 @@ class BlockBuilderImpl : public BlockBuilderNode {
   //-------------------------------
   // Scope management
   //-------------------------------
-  Optional LookupBinding(const Var& var) final {
+  ffi::Optional LookupBinding(const Var& var) final {
     auto it = binding_table_.find(var->vid);
     if (it == binding_table_.end()) return std::nullopt;
     return it->second;
@@ -170,7 +170,7 @@ class BlockBuilderImpl : public BlockBuilderNode {

   void BeginBindingBlock() final { block_stack_.emplace_back(BlockFrame{{}, false}); }

-  void BeginScope(Optional> params) final {
+  void BeginScope(ffi::Optional> params) final {
     // The current implementation handles the collection of shape var
     // defined in parameter struct info annotations. The implementation
     // is correct (since we will simply erase all relax Vars in EraseToWellDefined),
@@ -205,7 +205,7 @@ class BlockBuilderImpl : public BlockBuilderNode {
     // defined in parameter struct info annotations. The implementation
     // is correct (since we will simply erase all relax Vars in EraseToWellDefined),
     // but can be further improved.
-    Map var_map = StructInfoVarCollector::Collect(GetStructInfo(var));
+    ffi::Map var_map = StructInfoVarCollector::Collect(GetStructInfo(var));
     for (const auto& kv : var_map) {
       const tir::Var& shape_var = kv.first;
       const PrimExpr& shape_expr = kv.second;
@@ -239,11 +239,11 @@ class BlockBuilderImpl : public BlockBuilderNode {

   bool CurrentBlockIsDataFlow() final { return CurrentBlockFrame()->is_dataflow; }

-  Var Emit(Expr expr, String name_hint) final {
+  Var Emit(Expr expr, ffi::String name_hint) final {
     return this->Emit(expr, CurrentBlockFrame()->is_dataflow, name_hint);
   }

-  Var EmitMatchCast(Expr value, StructInfo struct_info, String name_hint) final {
+  Var EmitMatchCast(Expr value, StructInfo struct_info, ffi::String name_hint) final {
     value = this->Normalize(value);

     CHECK(StructInfoBaseCheck(GetStructInfo(value), struct_info) != BaseCheckResult::kFailL0)
@@ -265,7 +265,7 @@ class BlockBuilderImpl : public BlockBuilderNode {
     return var;
   }

-  Var EmitOutput(Expr output, String name_hint) final {
+  Var EmitOutput(Expr output, ffi::String name_hint) final {
     BlockFrame* cur_frame = CurrentBlockFrame();

     ICHECK(cur_frame->is_dataflow) << "EmitOutput has to be called inside dataflow block.";
@@ -317,7 +317,7 @@ class BlockBuilderImpl : public BlockBuilderNode {
     /*!
      * \brief List of bindings
      */
-    Array bindings;
+    ffi::Array bindings;
     /*! \brief Whether current block is dataflow block. */
     bool is_dataflow;
     /*!
@@ -341,7 +341,7 @@ class BlockBuilderImpl : public BlockBuilderNode {
     //
     // TODO(relax-team) tracks the var defined also through match-cast.
     /*! \brief set of defined symbolic vars, value as themself. */
-    Map shape_var_map;
+    ffi::Map shape_var_map;
   };

   /*! \brief A stack to store block frames. */
@@ -391,7 +391,7 @@ class BlockBuilderImpl : public BlockBuilderNode {
   * and performs shape/type deductions by calling Normalize.
   * \return The new variable that \p expr is bound to.
   */
-  Var Emit(Expr expr, bool is_dataflow, String name_hint) {
+  Var Emit(Expr expr, bool is_dataflow, ffi::String name_hint) {
     expr = this->Normalize(expr);

     Var var = CreateVar(is_dataflow, name_hint);
@@ -413,7 +413,7 @@ class BlockBuilderImpl : public BlockBuilderNode {
   * \param name_hint Name hint for the bound variable.
   * \return The created var.
   */
-  Var CreateVar(bool is_dataflow, String name_hint) {
+  Var CreateVar(bool is_dataflow, ffi::String name_hint) {
     if (name_hint.empty()) {
       name_hint = is_dataflow ? "lv" : "gv";
     }
@@ -466,7 +466,7 @@ class BlockBuilderImpl : public BlockBuilderNode {
   // shape vars as defined when calling BeginScope(params)
   class StructInfoVarCollector : public StructInfoVisitor {
    public:
-    static Map Collect(const StructInfo& struct_info) {
+    static ffi::Map Collect(const StructInfo& struct_info) {
       StructInfoVarCollector collector;
       collector(struct_info);
       return collector.shape_var_map_;
@@ -478,17 +478,17 @@ class BlockBuilderImpl : public BlockBuilderNode {
       for (const PrimExpr& s : shape_expr->values) {
         // Only collect single var defined shape. Ignore something like `R.Tensor((m + 1, n + 1))
         if (const auto* var = s.as()) {
-          shape_var_map_.Set(GetRef(var), s);
+          shape_var_map_.Set(ffi::GetRef(var), s);
         }
       }
     }

     void VisitStructInfo_(const ShapeStructInfoNode* op) final {
-      for (const PrimExpr& s : op->values.value_or(Array())) {
+      for (const PrimExpr& s : op->values.value_or(ffi::Array())) {
         // Only collect single var defined shape. Ignore something like `R.Shape((m + 1, n + 1))
         if (const auto* var = s.as()) {
-          shape_var_map_.Set(GetRef(var), s);
+          shape_var_map_.Set(ffi::GetRef(var), s);
         }
       }
     }
@@ -503,7 +503,7 @@ class BlockBuilderImpl : public BlockBuilderNode {
     }

    private:
-    Map shape_var_map_;
+    ffi::Map shape_var_map_;
   };
 };
@@ -511,7 +511,7 @@ class BlockBuilderImpl : public BlockBuilderNode {
 // Normalization
 //---------------------------------------
 #define RELAX_EXPR_NORMALIZER_LEAF(OP) \
-  Expr VisitExpr_(const OP* op) final { return GetRef(op); }
+  Expr VisitExpr_(const OP* op) final { return ffi::GetRef(op); }

 // TODO(relax-team): Check normalize logic after struct info.
@@ -589,13 +589,13 @@ class Normalizer : public BlockBuilderImpl, private ExprFunctor
     ICHECK(var->struct_info_.defined())
         << "Var " << var->name_hint() << " does not have struct info.";
-    return GetRef(var);
+    return ffi::GetRef(var);
   }

   Expr VisitExpr_(const VarNode* var_ptr) final {
     auto var = VisitVar_(var_ptr);
     if (HasVoidStructInfo(var)) {
-      return VisitExpr(Tuple(Array{}));
+      return VisitExpr(Tuple(ffi::Array{}));
     } else {
       return var;
     }
@@ -617,7 +617,7 @@ class Normalizer : public BlockBuilderImpl, private ExprFunctor
-    Array new_fields;
+    ffi::Array new_fields;

     for (const Expr& field : op->fields) {
       Expr new_field = this->NormalizeArgument(field);
@@ -625,10 +625,10 @@ class Normalizer : public BlockBuilderImpl, private ExprFunctor
     }

-    Tuple tuple = unchanged ? GetRef(op) : Tuple(new_fields, op->span);
+    Tuple tuple = unchanged ? ffi::GetRef(op) : Tuple(new_fields, op->span);
     // Update tuple fields.
     if (!tuple->struct_info_.defined()) {
-      Array tuple_sinfo;
+      ffi::Array tuple_sinfo;
       for (Expr field : tuple->fields) {
         tuple_sinfo.push_back(GetStructInfo(field));
       }
@@ -641,7 +641,7 @@ class Normalizer : public BlockBuilderImpl, private ExprFunctor
     Expr new_body = this->VisitWithNewScope(op->body, op->params);

     if (new_body.same_as(op->body)) {
-      return GetRef(op);
+      return ffi::GetRef(op);
     } else {
       return Function(op->params, new_body, op->ret_struct_info, op->is_pure, op->attrs);
     }
@@ -650,11 +650,12 @@ class Normalizer : public BlockBuilderImpl, private ExprFunctor
     Expr new_op = this->NormalizeArgument(op->op);
-    Array new_args = op->args.Map([this](const Expr& arg) { return NormalizeArgument(arg); });
+    ffi::Array new_args =
+        op->args.Map([this](const Expr& arg) { return NormalizeArgument(arg); });

     Call call;
     if (new_op.same_as(op->op) && new_args.same_as(op->args)) {
-      call = GetRef(op);
+      call = ffi::GetRef(op);
     } else {
       call = Call(new_op, new_args, op->attrs, op->sinfo_args);
     }
@@ -670,7 +671,7 @@ class Normalizer : public BlockBuilderImpl, private ExprFunctor
     if (auto func_normalize = op_map_normalize_.get(call->op, nullptr); func_normalize != nullptr) {
-      Expr normalized = func_normalize(GetRef(this), call);
+      Expr normalized = func_normalize(ffi::GetRef(this), call);
       if (!normalized.same_as(call)) {
         return VisitExpr(normalized);
       }
@@ -682,7 +683,7 @@ class Normalizer : public BlockBuilderImpl, private ExprFunctor
   {
-    Array new_blocks;
+    ffi::Array new_blocks;
     for (BindingBlock block : op->blocks) {
       BindingBlock new_block = this->VisitBindingBlock(block);
       new_blocks.push_back(new_block);
@@ -711,12 +712,12 @@ class Normalizer : public BlockBuilderImpl, private ExprFunctor
-    Array normalized_blocks = NormalizeBlocks(new_blocks);
+    ffi::Array normalized_blocks = NormalizeBlocks(new_blocks);
     unchanged &= normalized_blocks.same_as(new_blocks);

     SeqExpr seq_expr;
     if (unchanged) {
-      seq_expr = GetRef(op);
+      seq_expr = ffi::GetRef(op);
     } else {
       seq_expr = SeqExpr(normalized_blocks, new_body, op->span);
     }
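The `Emit`/`EmitMatchCast`/`EmitOutput` signatures above now take `ffi::String` name hints. A minimal usage sketch of that emit flow, assuming a TVM build (the op name and the "lv"/"gv" hints are just examples):

```cpp
#include <tvm/ir/op.h>
#include <tvm/relax/block_builder.h>
#include <tvm/relax/expr.h>

// Sketch: emit `a + b` inside a dataflow block and expose it as an output var.
tvm::relax::Var EmitAdd(tvm::relax::BlockBuilder bb, tvm::relax::Expr a, tvm::relax::Expr b) {
  static const tvm::Op& add_op = tvm::Op::Get("relax.add");
  bb->BeginDataflowBlock();
  tvm::relax::Var lv = bb->Emit(tvm::relax::Call(add_op, {a, b}), "lv");  // normalize + bind
  tvm::relax::Var gv = bb->EmitOutput(lv, "gv");  // visible outside the block
  bb->EndBlock();
  return gv;
}
```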
@@ -736,7 +737,7 @@ class Normalizer : public BlockBuilderImpl, private ExprFunctor
     if (new_cond.same_as(op->cond) && new_true.same_as(op->true_branch) &&
         new_false.same_as(op->false_branch)) {
-      if_node = GetRef(op);
+      if_node = ffi::GetRef(op);
     } else {
       if_node = If(new_cond, new_true, new_false, op->span);
     }
@@ -751,7 +752,7 @@ class Normalizer : public BlockBuilderImpl, private ExprFunctor
     Expr new_tuple = this->NormalizeArgument(op->tuple);

-    TupleGetItem node = new_tuple.same_as(op->tuple) ? GetRef(op)
+    TupleGetItem node = new_tuple.same_as(op->tuple) ? ffi::GetRef(op)
                                                      : TupleGetItem(new_tuple, op->index);

     if (!node->struct_info_.defined()) {
@@ -767,11 +768,11 @@ class Normalizer : public BlockBuilderImpl, private ExprFunctor
     if (const auto* var_binding = binding.as()) {
-      return this->VisitVarBinding(GetRef(var_binding));
+      return this->VisitVarBinding(ffi::GetRef(var_binding));
     } else {
       auto* match_cast = binding.as();
       ICHECK(match_cast) << "Unsupported binding type: " << binding->GetTypeKey();
-      return this->VisitMatchCast(GetRef(match_cast));
+      return this->VisitMatchCast(ffi::GetRef(match_cast));
     }
   }
@@ -824,7 +825,7 @@ class Normalizer : public BlockBuilderImpl, private ExprFunctor
     if (auto* op_ptr = call->op.as()) {
       // Case 1: the op field is a primitive op, look up FInferStructInfo attribute
-      Op op = GetRef(op_ptr);
+      Op op = ffi::GetRef(op_ptr);
       bool is_dist_op = false;
       for (const auto& arg : call->args) {
         if (arg->struct_info_.as()) {
@@ -839,18 +840,18 @@ class Normalizer : public BlockBuilderImpl, private ExprFunctor
             << " Cannot find the dist.FInferStructInfo attribute registered to op: " << op->name;
-        return op_map_dist_infer_struct_info_[op](call, GetRef(this));
+        return op_map_dist_infer_struct_info_[op](call, ffi::GetRef(this));
       }
       ICHECK(op_map_infer_struct_info_.count(op))
           << " Cannot find the FInferStructInfo attribute registered to op: " << op->name;
-      return op_map_infer_struct_info_[op](call, GetRef(this));
+      return op_map_infer_struct_info_[op](call, ffi::GetRef(this));
     } else {
       // derive using function parameters
       ICHECK(call->op->struct_info_.defined());
       auto opt = MatchStructInfo(call->op);
       ICHECK(opt) << "Call->op must contains a function struct info";
       FuncStructInfo finfo = opt.value();
-      return DeriveCallRetStructInfo(finfo, call, GetRef(this), &analyzer_);
+      return DeriveCallRetStructInfo(finfo, call, ffi::GetRef(this), &analyzer_);
     }
   }
@@ -862,7 +863,7 @@ class Normalizer : public BlockBuilderImpl, private ExprFunctor
-    auto f_shape_var_map = [curr_scope](tir::Var var) -> Optional {
+    auto f_shape_var_map = [curr_scope](tir::Var var) -> ffi::Optional {
       auto it = curr_scope->shape_var_map.find(var);
       if (it != curr_scope->shape_var_map.end()) return (*it).second;
       return std::nullopt;
@@ -870,7 +871,7 @@ class Normalizer : public BlockBuilderImpl, private ExprFunctor
-  Expr VisitWithNewScope(const Expr& expr, Optional> params = std::nullopt) {
+  Expr VisitWithNewScope(const Expr& expr, ffi::Optional> params = std::nullopt) {
     if (params.defined()) {
       this->BeginScope(params.value());
     } else {
@@ -891,7 +892,7 @@ class Normalizer : public BlockBuilderImpl, private ExprFunctor
     if (prologue.as() && prologue->bindings.empty()) {
       return post;
     }
-    Array bindings;
+    ffi::Array bindings;
     if (!prologue->bindings.empty()) {
       bindings.push_back(prologue);
     }
@@ -906,15 +907,15 @@ class Normalizer : public BlockBuilderImpl, private ExprFunctor
-  Array FlattenBlocks(const Array& blocks) {
+  ffi::Array FlattenBlocks(const ffi::Array& blocks) {
     // If there is a binding that is a seq expr, split the current block,
     // add the nested blocks prior to the seq expr, and bind the seq expr body
     // to the var
-    Array ret;
+    ffi::Array ret;
     bool changed = false;
     for (const BindingBlock& block : blocks) {
       bool is_dataflow = block->IsInstance();
-      Array current;
+      ffi::Array current;
       for (const Binding& binding : block->bindings) {
         Expr value;
         if (const auto* var_binding = binding.as()) {
@@ -950,8 +951,8 @@ class Normalizer : public BlockBuilderImpl, private ExprFunctor
-      auto free_vars = FreeVars(SeqExpr({block}, Tuple(Array{})));
-      Array free_dataflow_vars;
+      auto free_vars = FreeVars(SeqExpr({block}, Tuple(ffi::Array{})));
+      ffi::Array free_dataflow_vars;
       for (const auto& var : free_vars) {
         if (auto opt = var.as()) {
           free_dataflow_vars.push_back(opt.value());
@@ -987,9 +988,9 @@ class Normalizer : public BlockBuilderImpl, private ExprFunctor
-  Array NormalizeBlocks(const Array& blocks) {
+  ffi::Array NormalizeBlocks(const ffi::Array& blocks) {
     bool changed = false;
-    Array ret;
+    ffi::Array ret;
     auto flattened = FlattenBlocks(blocks);
     if (!flattened.same_as(blocks)) {
       changed = true;
@@ -1003,11 +1004,11 @@ class Normalizer : public BlockBuilderImpl, private ExprFunctor
       if (const auto* dataflow_block = ret.back().as()) {
-        auto n = make_object(*dataflow_block);
+        auto n = ffi::make_object(*dataflow_block);
         n->bindings.insert(n->bindings.end(), block->bindings.begin(), block->bindings.end());
         merged = DataflowBlock(n);
       } else if (const auto* binding_block = ret.back().as()) {
-        auto n = make_object(*binding_block);
+        auto n = ffi::make_object(*binding_block);
         n->bindings.insert(n->bindings.end(), block->bindings.begin(), block->bindings.end());
         merged = BindingBlock(n);
       } else {
@@ -1036,14 +1037,14 @@ class Normalizer : public BlockBuilderImpl, private ExprFunctor

-BlockBuilder BlockBuilder::Create(Optional mod) {
-  ObjectPtr n = make_object(mod.value_or(IRModule()));
+BlockBuilder BlockBuilder::Create(ffi::Optional mod) {
+  ObjectPtr n = ffi::make_object(mod.value_or(IRModule()));
   return BlockBuilder(n);
 }

-BlockBuilder BlockBuilder::Create(Optional mod,
+BlockBuilder BlockBuilder::Create(ffi::Optional mod,
                                   BlockBuilder::DisableOperatorSpecificNormalizationForTVMScript) {
-  ObjectPtr n = make_object(
+  ObjectPtr n = ffi::make_object(
       mod.value_or(IRModule()), BlockBuilder::DisableOperatorSpecificNormalizationForTVMScript());
   return BlockBuilder(n);
 }
@@ -1056,27 +1057,27 @@ TVM_FFI_STATIC_INIT_BLOCK({
   namespace refl = tvm::ffi::reflection;
   refl::GlobalDef()
       .def("relax.BlockBuilderCreate",
-           [](Optional mod) { return BlockBuilder::Create(mod); })
+           [](ffi::Optional mod) { return BlockBuilder::Create(mod); })
       .def_method("relax.BlockBuilderBeginDataflowBlock", &BlockBuilderNode::BeginDataflowBlock)
       .def_method("relax.BlockBuilderBeginBindingBlock", &BlockBuilderNode::BeginBindingBlock)
       .def_method("relax.BlockBuilderEndBlock", &BlockBuilderNode::EndBlock)
       .def_method("relax.BlockBuilderNormalize", &BlockBuilderNode::Normalize)
       .def("relax.BlockBuilderEmit",
-           [](BlockBuilder builder, Expr expr, String name_hint) {
+           [](BlockBuilder builder, Expr expr, ffi::String name_hint) {
             return builder->Emit(expr, name_hint);
           })
       .def("relax.BlockBuilderEmitMatchCast",
-           [](BlockBuilder builder, Expr value, StructInfo struct_info, String name_hint) {
+           [](BlockBuilder builder, Expr value, StructInfo struct_info, ffi::String name_hint) {
             return builder->EmitMatchCast(value, struct_info, name_hint);
           })
       .def("relax.BlockBuilderEmitOutput",
-           [](BlockBuilder builder, const Expr& output, String name_hint) {
+           [](BlockBuilder builder, const Expr& output, ffi::String name_hint) {
             return builder->EmitOutput(output, name_hint);
           })
       .def("relax.BlockBuilderEmitNormalized",
           [](BlockBuilder builder, Binding binding) { return builder->EmitNormalized(binding); })
       .def("relax.BlockBuilderGetUniqueName",
-           [](BlockBuilder builder, String name_hint) {
+           [](BlockBuilder builder, ffi::String name_hint) {
             return builder->name_supply()->FreshName(name_hint, /*add_prefix*/ false,
                                                      /*add_underscore*/ false);
           })
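Both `BlockBuilder::Create` overloads above now take `ffi::Optional<IRModule>`. A small sketch of the factory call, assuming a TVM build:

```cpp
#include <tvm/ir/module.h>
#include <tvm/relax/block_builder.h>

tvm::IRModule BuildEmptyModule() {
  // std::nullopt selects a fresh, empty context module.
  auto bb = tvm::relax::BlockBuilder::Create(std::nullopt);
  return bb->Finalize();
}
```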
diff --git a/src/relax/ir/dataflow_block_rewriter.cc b/src/relax/ir/dataflow_block_rewriter.cc
index def0e61c986c..b6479f702d44 100644
--- a/src/relax/ir/dataflow_block_rewriter.cc
+++ b/src/relax/ir/dataflow_block_rewriter.cc
@@ -135,7 +135,7 @@ struct MatchState {
 static std::optional TryMatch(const PNode& p, const RNode& r,
                               const MatchState& current_match, DFPatternMatcher* m,
                               const MatcherUseDefAnalysis& ud_analysis) {
-  if (!m->Match(GetRef(p.ptr), GetRef(r.ptr))) return std::nullopt;
+  if (!m->Match(ffi::GetRef(p.ptr), ffi::GetRef(r.ptr))) return std::nullopt;

   MatchState new_match;
@@ -192,15 +192,15 @@ static std::optional TryValidate(
     const std::vector& validation_constraints, arith::Analyzer* analyzer) {
   MatchState new_match;

-  std::function(const DFPatternNode*)> query_match_state =
-      [&pattern2node, &current_match](const DFPatternNode* pattern) -> Optional {
+  std::function(const DFPatternNode*)> query_match_state =
+      [&pattern2node, &current_match](const DFPatternNode* pattern) -> ffi::Optional {
     auto it = pattern2node.find(pattern);
     ICHECK(it != pattern2node.end())
-        << "DFConstraint attempted to access DFPattern " << GetRef(pattern)
+        << "DFConstraint attempted to access DFPattern " << ffi::GetRef(pattern)
         << ", which does not appear in the PatternContext";
     const auto& p_node = it->second;
     if (auto ptr = current_match.matched(p_node)) {
-      return GetRef(ptr);
+      return ffi::GetRef(ptr);
     } else {
       return std::nullopt;
     }
@@ -289,9 +289,9 @@ static std::optional MatchTree(
   return std::nullopt;
 }

-Optional> MatchGraph(const PatternContext& ctx,
-                     const Array& binding_arr,
-                     const Map& bindings) {
+ffi::Optional> MatchGraph(const PatternContext& ctx,
+                          const ffi::Array& binding_arr,
+                          const ffi::Map& bindings) {
   // TODO(@ganler): Handle non-may external use.
   ICHECK(ctx->allow_extern_use == PatternContextNode::kMay) << "Only kMay is supported yet.";
   DFPatternMatcher matcher(bindings);
@@ -351,15 +351,16 @@ Optional> MatchGraph(const PatternContext& ctx,
     return std::nullopt;
   }

-  Map ret;
+  ffi::Map ret;
   for (const auto& [pat, p_node] : pattern2node) {
     ICHECK(match->matched(p_node));
-    ret.Set(GetRef(pat), GetRef(match->matched(p_node)));
+    ret.Set(ffi::GetRef(pat), ffi::GetRef(match->matched(p_node)));
   }
   return ret;
 }

-Optional> MatchGraph(const PatternContext& ctx, const DataflowBlock& dfb) {
+ffi::Optional> MatchGraph(const PatternContext& ctx,
+                          const DataflowBlock& dfb) {
   return MatchGraph(ctx, dfb->bindings, AnalyzeVar2Value(dfb));
 }
@@ -373,9 +374,10 @@ TVM_FFI_STATIC_INIT_BLOCK({
 class PatternContextRewriterNode : public PatternMatchingRewriterNode {
  public:
   PatternContext pattern;
-  ffi::TypedFunction(Map, Map)> rewriter_func;
+  ffi::TypedFunction(ffi::Map, ffi::Map)>
+      rewriter_func;

-  RewriteSpec RewriteBindings(const Array& bindings) const override;
+  RewriteSpec RewriteBindings(const ffi::Array& bindings) const override;

   static void RegisterReflection() {
     namespace refl = tvm::ffi::reflection;
@@ -388,14 +390,14 @@ class PatternContextRewriterNode : public PatternMatchingRewriterNode {
   TVM_DECLARE_FINAL_OBJECT_INFO(PatternContextRewriterNode, PatternMatchingRewriterNode);

  private:
-  Optional> MatchBindings(const Array& bindings) const {
-    Map var_lookup;
+  ffi::Optional> MatchBindings(const ffi::Array& bindings) const {
+    ffi::Map var_lookup;
     for (const auto& binding : bindings) {
       var_lookup.Set(binding->var, GetBoundValue(binding));
     }

     if (auto matches = MatchGraph(pattern, bindings, var_lookup)) {
-      Map replacements = rewriter_func(matches.value(), var_lookup);
+      ffi::Map replacements = rewriter_func(matches.value(), var_lookup);
       if (replacements.size()) {
         return replacements;
       }
@@ -409,16 +411,17 @@ class PatternContextRewriter : public PatternMatchingRewriter {
  public:
   PatternContextRewriter(
       PatternContext pattern,
-      ffi::TypedFunction(Map, Map)> rewriter_func);
+      ffi::TypedFunction(ffi::Map, ffi::Map)>
+          rewriter_func);

   TVM_DEFINE_OBJECT_REF_METHODS(PatternContextRewriter, PatternMatchingRewriter,
                                 PatternContextRewriterNode);
 };

-RewriteSpec PatternContextRewriterNode::RewriteBindings(const Array& bindings) const {
+RewriteSpec PatternContextRewriterNode::RewriteBindings(const ffi::Array& bindings) const {
   std::vector remaining_bindings{bindings.begin(), bindings.end()};

-  Map variable_rewrites;
+  ffi::Map variable_rewrites;
   while (auto opt = MatchBindings(remaining_bindings)) {
     auto new_rewrites = opt.value();
     remaining_bindings.erase(std::remove_if(remaining_bindings.begin(), remaining_bindings.end(),
@@ -436,8 +439,9 @@ RewriteSpec PatternContextRewriterNode::RewriteBindings(const Array& bi

 PatternContextRewriter::PatternContextRewriter(
     PatternContext pattern,
-    ffi::TypedFunction(Map, Map)> rewriter_func) {
-  auto node = make_object();
+    ffi::TypedFunction(ffi::Map, ffi::Map)>
+        rewriter_func) {
+  auto node = ffi::make_object();
   node->pattern = std::move(pattern);
   node->rewriter_func = std::move(rewriter_func);
   data_ = std::move(node);
@@ -445,7 +449,7 @@ PatternContextRewriter::PatternContextRewriter(

 Function RewriteBindings(
     const PatternContext& ctx,
-    ffi::TypedFunction(Map, Map)> rewriter, Function func) {
+    ffi::TypedFunction(ffi::Map, ffi::Map)> rewriter,
+    Function func) {
   // return BlockPatternRewriter::Run(ctx, rewriter, func);
   return Downcast(PatternContextRewriter(ctx, rewriter)(func));
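The rewriter callback threaded through `PatternContextRewriter` above receives the pattern matches plus the block's var-to-value bindings and returns the vars to replace. A sketch of a conforming callback; the template arguments are inferred from the surrounding code and should be treated as assumptions:

```cpp
#include <tvm/relax/dataflow_pattern.h>
#include <tvm/relax/expr.h>

using tvm::ffi::Map;
using tvm::relax::DFPattern;
using tvm::relax::Expr;
using tvm::relax::Var;

// An intentional no-op rewriter: an empty map means "no replacement found",
// which terminates the match-and-rewrite loop in RewriteBindings above.
Map<Var, Expr> NoopRewriter(Map<DFPattern, Var> matches, Map<Var, Expr> bindings) {
  return Map<Var, Expr>{};
}
```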
diff --git a/src/relax/ir/dataflow_expr_rewriter.cc b/src/relax/ir/dataflow_expr_rewriter.cc
index 21000fec0cb8..a01bdddb9804 100644
--- a/src/relax/ir/dataflow_expr_rewriter.cc
+++ b/src/relax/ir/dataflow_expr_rewriter.cc
@@ -45,11 +45,11 @@ namespace relax {
 namespace {
 class GlobalVarReplacer : public ExprMutator {
  public:
-  explicit GlobalVarReplacer(Map gvar_map) : gvar_map_(gvar_map) {}
+  explicit GlobalVarReplacer(ffi::Map gvar_map) : gvar_map_(gvar_map) {}

   using ExprMutator::VisitExpr_;
   Expr VisitExpr_(const GlobalVarNode* op) override {
-    auto gvar = GetRef(op);
+    auto gvar = ffi::GetRef(op);
     if (auto opt = gvar_map_.Get(gvar)) {
       gvar = opt.value();
     }
@@ -57,10 +57,10 @@ class GlobalVarReplacer : public ExprMutator {
   }

  private:
-  Map gvar_map_;
+  ffi::Map gvar_map_;
 };

-Array TopologicalSort(const Array& bindings) {
+ffi::Array TopologicalSort(const ffi::Array& bindings) {
   std::unordered_set remaining_bindings;
   for (const auto& binding : bindings) {
     remaining_bindings.insert(binding->var);
@@ -74,7 +74,7 @@ Array TopologicalSort(const Array& bindings) {
     bool emitted;
   };
   std::vector delayed_bindings;
-  Array sorted_bindings;
+  ffi::Array sorted_bindings;

   // Utility function to append the
   auto push_sorted_binding = [&](Binding binding) {
@@ -159,7 +159,7 @@ void RewriteSpec::Append(RewriteSpec other) {
     gvar_name_supply->ReserveName(gvar->name_hint);
   }

-  Map gvar_rewrites;
+  ffi::Map gvar_rewrites;
   for (auto [gvar, func] : other.new_subroutines) {
     if (auto it = new_subroutines.find(gvar); it != new_subroutines.end()) {
       // The two rewrites provide the same GlobalVar.
@@ -197,14 +197,14 @@ TVM_FFI_STATIC_INIT_BLOCK({
   refl::GlobalDef()
       .def("relax.dpl.PatternMatchingRewriterFromPattern",
           [](DFPattern pattern,
-             ffi::TypedFunction(Expr, Map)> func) {
+             ffi::TypedFunction(Expr, ffi::Map)> func) {
             return PatternMatchingRewriter::FromPattern(pattern, func);
           })
       .def("relax.dpl.PatternMatchingRewriterFromModule",
           [](IRModule mod) { return PatternMatchingRewriter::FromModule(mod); })
       .def("relax.dpl.PatternMatchingRewriterApply",
           [](PatternMatchingRewriter rewriter,
-             Variant obj) -> Variant {
+             ffi::Variant obj) -> ffi::Variant {
             if (auto expr = obj.as()) {
               return rewriter(expr.value());
             } else if (auto mod = obj.as()) {
@@ -215,9 +215,9 @@ TVM_FFI_STATIC_INIT_BLOCK({
           });
 });

-RewriteSpec ExprPatternRewriterNode::RewriteBindings(const Array& bindings) const {
-  Map variable_rewrites;
-  Map binding_lookup;
+RewriteSpec ExprPatternRewriterNode::RewriteBindings(const ffi::Array& bindings) const {
+  ffi::Map variable_rewrites;
+  ffi::Map binding_lookup;
   for (const auto& binding : bindings) {
     auto bound_value = GetBoundValue(binding);
     if (auto new_expr = RewriteExpr(bound_value, binding_lookup)) {
@@ -233,8 +233,8 @@ RewriteSpec ExprPatternRewriterNode::RewriteBindings(const Array& bindi
   }
 }

-Optional ExprPatternRewriterNode::RewriteExpr(const Expr& expr,
-                                              const Map& bindings) const {
+ffi::Optional ExprPatternRewriterNode::RewriteExpr(
+    const Expr& expr, const ffi::Map& bindings) const {
   if (auto opt_matches = ExtractMatchedExpr(pattern, expr, bindings)) {
     auto matches = opt_matches.value();
     if (additional_bindings) {
@@ -249,7 +249,7 @@ Optional ExprPatternRewriterNode::RewriteExpr(const Expr& expr,
       }
     }

-    Optional rewritten_expr = func(expr, matches);
+    ffi::Optional rewritten_expr = func(expr, matches);
     if (rewritten_expr.defined() && !rewritten_expr.same_as(expr)) {
       return rewritten_expr.value();
     }
@@ -261,15 +261,18 @@ TVM_FFI_STATIC_INIT_BLOCK({
   namespace refl = tvm::ffi::reflection;
   refl::GlobalDef().def(
       "relax.dpl.PatternRewriter",
-      [](DFPattern pattern, ffi::TypedFunction(Expr, Map)> func) {
+      [](DFPattern pattern,
+         ffi::TypedFunction(Expr, ffi::Map)> func) {
        return ExprPatternRewriter(pattern, func);
      });
 });

 ExprPatternRewriter::ExprPatternRewriter(
-    DFPattern pattern, ffi::TypedFunction(Expr, Map)> func,
-    Optional> additional_bindings, Map new_subroutines) {
-  auto node = make_object();
+    DFPattern pattern,
+    ffi::TypedFunction(Expr, ffi::Map)> func,
+    ffi::Optional> additional_bindings,
+    ffi::Map new_subroutines) {
+  auto node = ffi::make_object();
   node->pattern = std::move(pattern);
   node->func = std::move(func);
   node->additional_bindings = std::move(additional_bindings);
@@ -277,7 +280,7 @@ ExprPatternRewriter::ExprPatternRewriter(
   data_ = std::move(node);
 }

-RewriteSpec OrRewriterNode::RewriteBindings(const Array& bindings) const {
+RewriteSpec OrRewriterNode::RewriteBindings(const ffi::Array& bindings) const {
   auto lhs_match = lhs->RewriteBindings(bindings);
   if (!lhs_match) {
     // If no rewrites found on LHS, RHS is allowed to modify any
@@ -291,7 +294,7 @@ RewriteSpec OrRewriterNode::RewriteBindings(const Array& bindings) cons
   // the LHS. Variable replacements from the RHS may still occur,
   // but will need to wait for the next round of
   // iterate-until-converged.
-  Array remaining_bindings;
+  ffi::Array remaining_bindings;
   for (const auto& binding : bindings) {
     if (!lhs_match.variable_rewrites.count(binding->var)) {
       remaining_bindings.push_back(binding);
     }
@@ -316,17 +319,17 @@ TVM_FFI_STATIC_INIT_BLOCK({
 });

 OrRewriter::OrRewriter(PatternMatchingRewriter lhs, PatternMatchingRewriter rhs) {
-  auto node = make_object();
+  auto node = ffi::make_object();
   node->lhs = std::move(lhs);
   node->rhs = std::move(rhs);
   data_ = std::move(node);
 }

-RewriteSpec TupleRewriterNode::RewriteBindings(const Array& bindings) const {
+RewriteSpec TupleRewriterNode::RewriteBindings(const ffi::Array& bindings) const {
   CHECK_LE(patterns.size(), 3) << "For performance reasons, "
                                << "matching of implicit tuple patterns is currently limited"
                                << " to tuples with 3 elements or fewer.";
-  Map variable_rewrites = GenerateVariableRewrites(bindings);
+  ffi::Map variable_rewrites = GenerateVariableRewrites(bindings);

   if (variable_rewrites.size()) {
     return RewriteSpec{variable_rewrites, new_subroutines};
@@ -335,10 +338,11 @@ RewriteSpec TupleRewriterNode::RewriteBindings(const Array& bindings) c
   }
 }

-Map TupleRewriterNode::GenerateVariableRewrites(const Array& bindings) const {
-  Map rewrites;
+ffi::Map TupleRewriterNode::GenerateVariableRewrites(
+    const ffi::Array& bindings) const {
+  ffi::Map rewrites;

-  Map binding_lookup;
+  ffi::Map binding_lookup;

   std::vector info_vec;
@@ -534,7 +538,7 @@ std::optional> TupleRewriterNode::TryMatchByBindingIndex(
     }
   }

-  Map merged_matches = info_vec[indices[0]].matches[0].value();
+  ffi::Map merged_matches = info_vec[indices[0]].matches[0].value();
   for (size_t i = 1; i < indices.size(); i++) {
     for (const auto& [pat, expr] : info_vec[indices[i]].matches[i].value()) {
       if (auto it = merged_matches.find(pat); it != merged_matches.end()) {
@@ -572,7 +576,7 @@ std::optional> TupleRewriterNode::TryMatchByBindingIndex(
   }

   auto full_tuple = [&]() -> relax::Expr {
-    Array fields;
+    ffi::Array fields;
     for (size_t index : indices) {
       fields.push_back(info_vec[index].expr);
     }
@@ -606,18 +610,20 @@ std::optional> TupleRewriterNode::TryMatchByBindingIndex(

 TVM_FFI_STATIC_INIT_BLOCK({
   namespace refl = tvm::ffi::reflection;
-  refl::GlobalDef().def("relax.dpl.TupleRewriter",
-                        [](Array patterns,
-                           ffi::TypedFunction(Expr, Map)> func) {
-                          return TupleRewriter(patterns, func);
-                        });
+  refl::GlobalDef().def(
+      "relax.dpl.TupleRewriter",
+      [](ffi::Array patterns,
+         ffi::TypedFunction(Expr, ffi::Map)> func) {
+        return TupleRewriter(patterns, func);
+      });
 });

-TupleRewriter::TupleRewriter(Array patterns,
-                             ffi::TypedFunction(Expr, Map)> func,
-                             Optional> additional_bindings,
-                             Map new_subroutines) {
-  auto node = make_object();
+TupleRewriter::TupleRewriter(
+    ffi::Array patterns,
+    ffi::TypedFunction(Expr, ffi::Map)> func,
+    ffi::Optional> additional_bindings,
+    ffi::Map new_subroutines) {
+  auto node = ffi::make_object();
   node->patterns = std::move(patterns);
   node->func = std::move(func);
   node->additional_bindings = std::move(additional_bindings);
@@ -626,8 +632,10 @@ TupleRewriter::TupleRewriter(Array patterns,
 }

 PatternMatchingRewriter PatternMatchingRewriter::FromPattern(
-    DFPattern pattern, ffi::TypedFunction(Expr, Map)> func,
-    Optional> additional_bindings, Map new_subroutines) {
+    DFPattern pattern,
+    ffi::TypedFunction(Expr, ffi::Map)> func,
+    ffi::Optional> additional_bindings,
+    ffi::Map new_subroutines) {
   if (auto or_pattern = pattern.as()) {
     auto new_additional_bindings = additional_bindings.value_or({});
     new_additional_bindings.push_back(pattern);
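`OrRewriterNode::RewriteBindings` above gives the LHS rewriter priority and only hands untouched bindings to the RHS. The same short-circuit policy, restated over hypothetical standalone types for clarity (the real classes operate on Relax bindings):

```cpp
#include <functional>
#include <optional>
#include <string>

using Rewrite = std::function<std::optional<std::string>(const std::string&)>;

Rewrite OrElse(Rewrite lhs, Rewrite rhs) {
  return [lhs, rhs](const std::string& in) -> std::optional<std::string> {
    if (auto out = lhs(in)) return out;  // LHS wins when it makes progress
    return rhs(in);                      // otherwise RHS sees the same input
  };
}
```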
@@ -678,10 +686,10 @@ PatternMatchingRewriter PatternMatchingRewriter::FromModule(IRModule mod) {
     return Downcast(base_func);
   }();

-  Map new_subroutines;
+  ffi::Map new_subroutines;
   for (const auto& [gvar, func] : mod->functions) {
     if (gvar->name_hint != "pattern" && gvar->name_hint != "replacement") {
-      bool is_public = func->GetAttr(tvm::attr::kGlobalSymbol).has_value();
+      bool is_public = func->GetAttr(tvm::attr::kGlobalSymbol).has_value();
       CHECK(!is_public) << "ValueError: "
                         << "Expected module to have no publicly-exposed functions "
                         << "other than 'pattern' and 'replacement'. "
@@ -699,8 +707,8 @@ PatternMatchingRewriter PatternMatchingRewriter::FromModule(IRModule mod) {
       << "but the pattern has struct info " << sinfo_pattern
       << ", while the replacement has struct info " << sinfo_replacement;

-  Array param_wildcards;
-  Map pattern_lookup;
+  ffi::Array param_wildcards;
+  ffi::Map pattern_lookup;
   for (const auto& param : func_pattern->params) {
     WildcardPattern wildcard;
     param_wildcards.push_back(wildcard);
@@ -752,15 +760,15 @@ PatternMatchingRewriter PatternMatchingRewriter::FromModule(IRModule mod) {

   DFPattern top_pattern = make_pattern(func_pattern->body->body);

-  ffi::TypedFunction(Expr, Map)> rewriter_func =
+  ffi::TypedFunction(Expr, ffi::Map)> rewriter_func =
       [param_wildcards = std::move(param_wildcards),
        orig_func_replacement = std::move(func_replacement)](
-          Expr expr, Map matches) -> Optional {
+          Expr expr, ffi::Map matches) -> ffi::Optional {
     auto func_replacement = CopyWithNewVars(orig_func_replacement);

-    Array new_blocks;
+    ffi::Array new_blocks;

-    Array wildcard_bindings;
+    ffi::Array wildcard_bindings;
     ICHECK_EQ(param_wildcards.size(), func_replacement->params.size());
     for (size_t i = 0; i < param_wildcards.size(); i++) {
       Expr matched_expr = matches[param_wildcards[i]];
@@ -787,8 +795,8 @@ PatternMatchingRewriter PatternMatchingRewriter::FromModule(IRModule mod) {
       new_subroutines);
 }

-Optional> ExtractMatchedExpr(DFPattern pattern, Expr expr,
-                             Optional> bindings_opt) {
+ffi::Optional> ExtractMatchedExpr(
+    DFPattern pattern, Expr expr, ffi::Optional> bindings_opt) {
   auto bindings = bindings_opt.value_or({});
   DFPatternMatcher matcher(bindings);
@@ -804,7 +812,7 @@ TVM_FFI_STATIC_INIT_BLOCK({
   refl::GlobalDef().def("relax.dpl.extract_matched_expr", ExtractMatchedExpr);
 });

-bool MatchExpr(DFPattern pattern, Expr expr, Optional> bindings_opt) {
+bool MatchExpr(DFPattern pattern, Expr expr, ffi::Optional> bindings_opt) {
   return static_cast(ExtractMatchedExpr(pattern, expr, bindings_opt));
 }
@@ -823,7 +831,7 @@ class PatternMatchingMutator : public ExprMutator {
   PatternMatchingMutator(const PatternMatchingRewriterNode* rewriter) : rewriter_(rewriter) {}

-  Map GetNewSubroutines() const { return new_subroutines_; }
+  ffi::Map GetNewSubroutines() const { return new_subroutines_; }

   Expr VisitExpr_(const SeqExprNode* seq) override {
     SeqExpr prev = Downcast(ExprMutator::VisitExpr_(seq));
@@ -861,13 +869,13 @@ class PatternMatchingMutator : public ExprMutator {
     return prev;
   }

-  Optional TryRewriteSeqExpr(const SeqExpr& seq) {
-    Array old_blocks = seq->blocks;
+  ffi::Optional TryRewriteSeqExpr(const SeqExpr& seq) {
+    ffi::Array old_blocks = seq->blocks;
     // If the SeqExpr's output is not a variable, treat it as if it
     // were the last variable binding of the last block. This
     // simplifies the special handling of the SeqExpr's body.
-    Optional dummy_output_var = std::nullopt;
+    ffi::Optional dummy_output_var = std::nullopt;
     if (!seq->body->IsInstance()) {
       dummy_output_var = Var("dummy_output_var", GetStructInfo(seq->body));
       VarBinding dummy_binding(dummy_output_var.value(), seq->body);
@@ -878,7 +886,7 @@ class PatternMatchingMutator : public ExprMutator {
         old_blocks.pop_back();
         return last_block;
       } else {
-        return BindingBlock(Array{});
+        return BindingBlock(ffi::Array{});
       }
     }();
@@ -886,7 +894,7 @@ class PatternMatchingMutator : public ExprMutator {
       old_blocks.push_back(last_block);
     }

-    auto rewrite_block = [&](Array orig_bindings) -> Array {
+    auto rewrite_block = [&](ffi::Array orig_bindings) -> ffi::Array {
       auto rewrites = rewriter_->RewriteBindings(orig_bindings);
       if (!rewrites) return orig_bindings;
@@ -921,7 +929,7 @@ class PatternMatchingMutator : public ExprMutator {
     // Utility function to return the rewrites that should be applied
     // to a given block.
-    auto get_rewrites = [&](BindingBlock block) -> Array {
+    auto get_rewrites = [&](BindingBlock block) -> ffi::Array {
       if (block.as()) {
         // Early return for DataflowBlock. Since neither control flow
         // nor impure functions are allowed within the dataflow block,
@@ -931,8 +939,8 @@ class PatternMatchingMutator : public ExprMutator {

       RewriteSpec rewrites;

-      Array collected_bindings;
-      Array finalized_bindings;
+      ffi::Array collected_bindings;
+      ffi::Array finalized_bindings;

       auto handle_collected_rewrites = [&]() {
         if (collected_bindings.size()) {
@@ -1029,17 +1037,17 @@ class PatternMatchingMutator : public ExprMutator {

  private:
   const PatternMatchingRewriterNode* rewriter_;
-  Map new_subroutines_;
+  ffi::Map new_subroutines_;
 };

 Expr PatternMatchingRewriter::operator()(Expr expr) {
   PatternMatchingMutator mutator(get());
   auto new_expr = mutator(expr);
   auto new_subroutines = mutator.GetNewSubroutines();
-  CHECK_EQ(new_subroutines.size(), 0)
-      << "If PatternMatchingRewriter provides subroutines, "
-      << "then it must be applied to an entire IRModule. "
-      << "However, PatternMatchingRewriter produced subroutines " << [&]() -> Array {
+  CHECK_EQ(new_subroutines.size(), 0) << "If PatternMatchingRewriter provides subroutines, "
+                                      << "then it must be applied to an entire IRModule. "
+                                      << "However, PatternMatchingRewriter produced subroutines "
+                                      << [&]() -> ffi::Array {
     std::vector vec;
     for (const auto& [gvar, func] : new_subroutines) {
       vec.push_back(gvar);
     }
@@ -1079,7 +1087,8 @@ tvm::transform::PassInfo PatternMatchingRewriterNode::Info() const {
 }

 Function RewriteCall(const DFPattern& pat,
-                     ffi::TypedFunction)> rewriter, Function func) {
+                     ffi::TypedFunction)> rewriter,
+                     Function func) {
   return Downcast(PatternMatchingRewriter::FromPattern(pat, rewriter)(func));
 }
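`ExtractMatchedExpr` and `MatchExpr` above both thread an optional var-to-value map through the matcher so it can look through `lv = ...` bindings. A usage sketch, assuming a TVM build (header locations are assumptions):

```cpp
#include <tvm/relax/analysis.h>
#include <tvm/relax/dataflow_matcher.h>

bool PatternAppearsIn(const tvm::relax::DataflowBlock& block,
                      const tvm::relax::DFPattern& pattern,
                      const tvm::relax::Expr& root) {
  // The var2val map lets the matcher unwrap variables to their bound values.
  auto var2val = tvm::relax::AnalyzeVar2Value(block);
  return tvm::relax::MatchExpr(pattern, root, var2val);
}
```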
diff --git a/src/relax/ir/dataflow_matcher.cc b/src/relax/ir/dataflow_matcher.cc
index b70c97cc3d13..5c0fd6d8f554 100644
--- a/src/relax/ir/dataflow_matcher.cc
+++ b/src/relax/ir/dataflow_matcher.cc
@@ -60,7 +60,7 @@ using tvm::arith::Analyzer;
 * \param attributes The attributes to match.
 * \return True if the attributes match, false otherwise.
 */
-bool MatchAttrs(const Any& attrs, const Map& attributes) {
+bool MatchAttrs(const Any& attrs, const ffi::Map& attributes) {
   // TODO(tqchen): consider lift to common utils
   if (auto* dict_attrs = attrs.as()) {
     for (auto kv : attributes) {
@@ -85,7 +85,7 @@ bool MatchAttrs(const Any& attrs, const Map& attributes) {
   const Object* obj = attrs.cast();
   ffi::reflection::ForEachFieldInfoWithEarlyStop(
       type_info, [&](const TVMFFIFieldInfo* field_info) {
-        String field_name(field_info->name);
+        ffi::String field_name(field_info->name);
         if (attributes.count(field_name)) {
           ffi::reflection::FieldGetter field_getter(field_info);
           ffi::Any field_value = field_getter(obj);
@@ -108,12 +108,12 @@ bool DFPatternMatcher::Match(const DFPattern& pattern, const Expr& expr) {
   return VisitDFPattern(pattern, expr);
 }

-Expr DFPatternMatcher::UnwrapBindings(Expr expr, const Map& var2val) {
-  auto unwrap = [&](Expr expr) -> Optional {
+Expr DFPatternMatcher::UnwrapBindings(Expr expr, const ffi::Map& var2val) {
+  auto unwrap = [&](Expr expr) -> ffi::Optional {
     // Unwrap variables into the value to which they are bound.
     if (var2val.size()) {
       if (const VarNode* var = expr.as()) {
-        if (auto may = var2val.Get(GetRef(var))) {
+        if (auto may = var2val.Get(ffi::GetRef(var))) {
           return may.value();
         }
       }
@@ -187,7 +187,7 @@ bool DFPatternMatcher::VisitDFPattern_(const AttrPatternNode* attr_pattern, cons
   VLOG(1) << "considering AttrPatternNode at:\n" << expr;
   auto attributes = attr_pattern->attrs.as()->dict;
   if (const auto* op_node = expr.as()) {
-    Op op = GetRef(op_node);
+    Op op = ffi::GetRef(op_node);
     for (auto kv : attributes) {
       auto attr_name = kv.first;
       auto attr_value = kv.second;
@@ -257,8 +257,8 @@ bool DFPatternMatcher::VisitDFPattern_(const CallPatternNode* op, const Expr& ex
   if (matches_op) {
     auto watermark2 = matched_nodes_.size();

-    auto match_args = [this, &watermark2](const Array& pattern_args, auto expr_begin,
-                                          auto expr_end) {
+    auto match_args = [this, &watermark2](const ffi::Array& pattern_args,
+                                          auto expr_begin, auto expr_end) {
       bool matches = true;
       auto pattern_it = pattern_args.begin();
       auto expr_it = expr_begin;
@@ -385,8 +385,8 @@ bool DFPatternMatcher::VisitDFPattern_(const TuplePatternNode* op, const Expr& e
   return matches;
 }

-bool DFPatternMatcher::TryUnorderedMatch(size_t idx, const tvm::Array patterns,
-                                         const tvm::Array fields,
+bool DFPatternMatcher::TryUnorderedMatch(size_t idx, const tvm::ffi::Array patterns,
+                                         const tvm::ffi::Array fields,
                                          std::vector& match_cache, std::vector& matched) {
   if (idx >= patterns.size()) return true;
@@ -456,7 +456,7 @@ PrimExpr DFPatternMatcher::SimplifyCondition(PrimExpr condition) {
     return condition;
   }

-  auto sort_key = [](PrimExpr expr) -> String {
+  auto sort_key = [](PrimExpr expr) -> ffi::String {
     if (const auto* equal = expr.as()) {
       if (const auto* var = equal->a.as()) {
         return var->name_hint;
@@ -476,7 +476,8 @@ PrimExpr DFPatternMatcher::SimplifyCondition(PrimExpr condition) {
   return analyzer_.Simplify(sorted_condition);
 }

-static bool ShapeEqual(Analyzer* analyzer, const Array& lhs, const Array& rhs) {
+static bool ShapeEqual(Analyzer* analyzer, const ffi::Array& lhs,
+                       const ffi::Array& rhs) {
   if (lhs.size() != rhs.size()) return false;
   for (size_t i = 0; i < lhs.size(); ++i)
     if (!tir::is_one(analyzer->Simplify(lhs[i] == rhs[i]))) return false;
@@ -495,8 +496,8 @@ bool DFPatternMatcher::VisitDFPattern_(const ShapePatternNode* op, const Expr& e
 }

 std::tuple SameShapeConstraintNode::AsPrimExpr(
-    std::function(const DFPatternNode*)> match_state) const {
-  Optional> expected_shape;
+    std::function(const DFPatternNode*)> match_state) const {
+  ffi::Optional> expected_shape;
   bool all_shapes_defined = true;

   // The expression that must be true in order
@@ -505,7 +506,7 @@ std::tuple SameShapeConstraintNode::AsPrimExpr(
   for (const auto& arg : args) {
     if (auto opt_var = match_state(arg.get())) {
       auto var = opt_var.value();
-      auto opt_var_shape = [&]() -> Optional> {
+      auto opt_var_shape = [&]() -> ffi::Optional> {
         auto sinfo = GetStructInfo(var);
         if (auto tensor = sinfo.as()) {
           return tensor->GetShape();
diff --git a/src/relax/ir/dataflow_matcher.h b/src/relax/ir/dataflow_matcher.h
index 71fa4a4c35c1..bece0af12070 100644
--- a/src/relax/ir/dataflow_matcher.h
+++ b/src/relax/ir/dataflow_matcher.h
@@ -38,15 +38,15 @@ namespace relax {

 class DFPatternMatcher : public DFPatternFunctor {
  public:
-  using var2val_t = Map;
+  using var2val_t = ffi::Map;

   explicit DFPatternMatcher() {}
   explicit DFPatternMatcher(var2val_t var2val) : var2val_(std::move(var2val)) {}
   bool Match(const DFPattern& pattern, const Expr& expr);
-  Map GetMemo() { return memo_; }
+  ffi::Map GetMemo() { return memo_; }

   /* \brief Unwrap trivial expressions/bindings */
-  static Expr UnwrapBindings(Expr expr, const Map& bindings);
+  static Expr UnwrapBindings(Expr expr, const ffi::Map& bindings);

  protected:
   bool VisitDFPattern(const DFPattern& pattern, const Expr& expr) override;
@@ -73,8 +73,8 @@ class DFPatternMatcher : public DFPatternFunctor
-  bool TryUnorderedMatch(size_t idx, const tvm::Array patterns,
-                         const tvm::Array fields, std::vector& match_cache,
+  bool TryUnorderedMatch(size_t idx, const tvm::ffi::Array patterns,
+                         const tvm::ffi::Array fields, std::vector& match_cache,
                          std::vector& matched);

   /* \brief Simplify a boolean condition using the analyzer
diff --git a/src/relax/ir/dataflow_pattern.cc b/src/relax/ir/dataflow_pattern.cc
index f0f40e4df1a1..581752e6257f 100644
--- a/src/relax/ir/dataflow_pattern.cc
+++ b/src/relax/ir/dataflow_pattern.cc
@@ -63,29 +63,29 @@ TVM_FFI_STATIC_INIT_BLOCK({
     REPR_LAMBDA(p, node); \
   })

-ExternFuncPattern::ExternFuncPattern(String global_symbol) {
-  ObjectPtr n = make_object();
+ExternFuncPattern::ExternFuncPattern(ffi::String global_symbol) {
+  ObjectPtr n = ffi::make_object();
   n->global_symbol_ = std::move(global_symbol);
   data_ = std::move(n);
 }
 TVM_FFI_STATIC_INIT_BLOCK({
   namespace refl = tvm::ffi::reflection;
   refl::GlobalDef().def("relax.dpl.ExternFuncPattern",
-                        [](String global_symbol) { return ExternFuncPattern(global_symbol); });
+                        [](ffi::String global_symbol) { return ExternFuncPattern(global_symbol); });
 });
 RELAX_PATTERN_PRINTER_DEF(ExternFuncPatternNode, [](auto p, auto node) {
   p->stream << "ExternFuncPattern(" << node->global_symbol() << ")";
 });

-VarPattern::VarPattern(String name_hint) {
-  ObjectPtr n = make_object();
+VarPattern::VarPattern(ffi::String name_hint) {
+  ObjectPtr n = ffi::make_object();
   n->name = std::move(name_hint);
   data_ = std::move(n);
 }
 TVM_FFI_STATIC_INIT_BLOCK({
   namespace refl = tvm::ffi::reflection;
   refl::GlobalDef().def("relax.dpl.VarPattern",
-                        [](String name_hint) { return VarPattern(name_hint); });
+                        [](ffi::String name_hint) { return VarPattern(name_hint); });
 });
 RELAX_PATTERN_PRINTER_DEF(VarPatternNode, [](auto p, auto node) {
   p->stream << "VarPattern(" << node->name_hint() << ")";
@@ -94,10 +94,10 @@ RELAX_PATTERN_PRINTER_DEF(VarPatternNode, [](auto p, auto node) {
 TVM_FFI_STATIC_INIT_BLOCK({
   namespace refl = tvm::ffi::reflection;
   refl::GlobalDef().def("relax.dpl.DataflowVarPattern",
-                        [](String name_hint) { return DataflowVarPattern(name_hint); });
+ [](ffi::String name_hint) { return DataflowVarPattern(name_hint); }); }); -DataflowVarPattern::DataflowVarPattern(String name_hint) { - ObjectPtr n = make_object(); +DataflowVarPattern::DataflowVarPattern(ffi::String name_hint) { + ObjectPtr n = ffi::make_object(); n->name = std::move(name_hint); data_ = std::move(n); } @@ -105,22 +105,22 @@ RELAX_PATTERN_PRINTER_DEF(DataflowVarPatternNode, [](auto p, auto node) { p->stream << "DataflowVarPattern(" << node->name_hint() << ")"; }); -GlobalVarPattern::GlobalVarPattern(String name_hint) { - ObjectPtr n = make_object(); +GlobalVarPattern::GlobalVarPattern(ffi::String name_hint) { + ObjectPtr n = ffi::make_object(); n->name = std::move(name_hint); data_ = std::move(n); } TVM_FFI_STATIC_INIT_BLOCK({ namespace refl = tvm::ffi::reflection; refl::GlobalDef().def("relax.dpl.GlobalVarPattern", - [](String name_hint) { return GlobalVarPattern(name_hint); }); + [](ffi::String name_hint) { return GlobalVarPattern(name_hint); }); }); RELAX_PATTERN_PRINTER_DEF(GlobalVarPatternNode, [](auto p, auto node) { p->stream << "GlobalVarPattern(" << node->name_hint() << ")"; }); ExprPattern::ExprPattern(Expr expr) { - ObjectPtr n = make_object(); + ObjectPtr n = ffi::make_object(); n->expr = std::move(expr); data_ = std::move(n); } @@ -133,15 +133,15 @@ RELAX_PATTERN_PRINTER_DEF(ExprPatternNode, [](auto p, auto node) { p->Print(node TVM_FFI_STATIC_INIT_BLOCK({ namespace refl = tvm::ffi::reflection; refl::GlobalDef().def("relax.dpl.ConstantPattern", []() { - auto c = ConstantPattern(make_object()); + auto c = ConstantPattern(ffi::make_object()); return c; }); }); RELAX_PATTERN_PRINTER_DEF(ConstantPatternNode, [](auto p, auto node) { p->stream << "ConstantPattern()"; }); -CallPattern::CallPattern(DFPattern op, Array args, bool varg_default_wildcard) { - ObjectPtr n = make_object(); +CallPattern::CallPattern(DFPattern op, ffi::Array args, bool varg_default_wildcard) { + ObjectPtr n = ffi::make_object(); n->op = std::move(op); n->args = std::move(args); n->varg_default_wildcard = varg_default_wildcard; @@ -150,7 +150,7 @@ CallPattern::CallPattern(DFPattern op, Array args, bool varg_default_ TVM_FFI_STATIC_INIT_BLOCK({ namespace refl = tvm::ffi::reflection; refl::GlobalDef().def("relax.dpl.CallPattern", - [](DFPattern op, Array args, bool varg_default_wildcard) { + [](DFPattern op, ffi::Array args, bool varg_default_wildcard) { return CallPattern(op, args, varg_default_wildcard); }); }); @@ -167,66 +167,67 @@ RELAX_PATTERN_PRINTER_DEF(CallPatternNode, [](auto p, auto node) { p->stream << ")"; }); -PrimArrPattern::PrimArrPattern(Array arr) { - ObjectPtr n = make_object(); +PrimArrPattern::PrimArrPattern(ffi::Array arr) { + ObjectPtr n = ffi::make_object(); n->fields = std::move(arr); data_ = std::move(n); } TVM_FFI_STATIC_INIT_BLOCK({ namespace refl = tvm::ffi::reflection; refl::GlobalDef().def("relax.dpl.PrimArrPattern", - [](Array arr) { return PrimArrPattern(std::move(arr)); }); + [](ffi::Array arr) { return PrimArrPattern(std::move(arr)); }); }); RELAX_PATTERN_PRINTER_DEF(PrimArrPatternNode, [](auto p, auto node) { p->stream << "PrimArrPattern(" << node->fields << ")"; }); -FunctionPattern::FunctionPattern(Array params, DFPattern body) { - ObjectPtr n = make_object(); +FunctionPattern::FunctionPattern(ffi::Array params, DFPattern body) { + ObjectPtr n = ffi::make_object(); n->params = std::move(params); n->body = std::move(body); data_ = std::move(n); } TVM_FFI_STATIC_INIT_BLOCK({ namespace refl = tvm::ffi::reflection; - 
refl::GlobalDef().def("relax.dpl.FunctionPattern", [](Array params, DFPattern body) { - return FunctionPattern(params, body); - }); + refl::GlobalDef().def( + "relax.dpl.FunctionPattern", + [](ffi::Array params, DFPattern body) { return FunctionPattern(params, body); }); }); RELAX_PATTERN_PRINTER_DEF(FunctionPatternNode, [](auto p, auto node) { p->stream << "FunctionPattern(" << node->params << ", " << node->body << ")"; }); -TuplePattern::TuplePattern(tvm::Array fields) { - ObjectPtr n = make_object(); +TuplePattern::TuplePattern(tvm::ffi::Array fields) { + ObjectPtr n = ffi::make_object(); n->fields = std::move(fields); data_ = std::move(n); } TVM_FFI_STATIC_INIT_BLOCK({ namespace refl = tvm::ffi::reflection; refl::GlobalDef().def("relax.dpl.TuplePattern", - [](tvm::Array fields) { return TuplePattern(fields); }); + [](tvm::ffi::Array fields) { return TuplePattern(fields); }); }); RELAX_PATTERN_PRINTER_DEF(TuplePatternNode, [](auto p, auto node) { p->stream << "TuplePattern(" << node->fields << ")"; }); -UnorderedTuplePattern::UnorderedTuplePattern(tvm::Array fields) { - ObjectPtr n = make_object(); +UnorderedTuplePattern::UnorderedTuplePattern(tvm::ffi::Array fields) { + ObjectPtr n = ffi::make_object(); n->fields = std::move(fields); data_ = std::move(n); } TVM_FFI_STATIC_INIT_BLOCK({ namespace refl = tvm::ffi::reflection; - refl::GlobalDef().def("relax.dpl.UnorderedTuplePattern", - [](tvm::Array fields) { return UnorderedTuplePattern(fields); }); + refl::GlobalDef().def("relax.dpl.UnorderedTuplePattern", [](tvm::ffi::Array fields) { + return UnorderedTuplePattern(fields); + }); }); RELAX_PATTERN_PRINTER_DEF(UnorderedTuplePatternNode, [](auto p, auto node) { p->stream << "UnorderedTuplePattern(" << node->fields << ")"; }); TupleGetItemPattern::TupleGetItemPattern(DFPattern tuple, int index) { - ObjectPtr n = make_object(); + ObjectPtr n = ffi::make_object(); n->tuple = std::move(tuple); n->index = index; data_ = std::move(n); @@ -242,7 +243,7 @@ RELAX_PATTERN_PRINTER_DEF(TupleGetItemPatternNode, [](auto p, auto node) { }); AndPattern::AndPattern(DFPattern left, DFPattern right) { - ObjectPtr n = make_object(); + ObjectPtr n = ffi::make_object(); n->left = std::move(left); n->right = std::move(right); data_ = std::move(n); @@ -257,7 +258,7 @@ RELAX_PATTERN_PRINTER_DEF(AndPatternNode, [](auto p, auto node) { }); OrPattern::OrPattern(DFPattern left, DFPattern right) { - ObjectPtr n = make_object(); + ObjectPtr n = ffi::make_object(); n->left = std::move(left); n->right = std::move(right); data_ = std::move(n); @@ -272,7 +273,7 @@ RELAX_PATTERN_PRINTER_DEF(OrPatternNode, [](auto p, auto node) { }); NotPattern::NotPattern(DFPattern reject) { - ObjectPtr n = make_object(); + ObjectPtr n = ffi::make_object(); n->reject = std::move(reject); data_ = std::move(n); } @@ -284,7 +285,7 @@ TVM_FFI_STATIC_INIT_BLOCK({ RELAX_PATTERN_PRINTER_DEF(NotPatternNode, [](auto p, auto node) { p->stream << "!(" << node->reject << ")"; }); -WildcardPattern::WildcardPattern() { data_ = make_object(); } +WildcardPattern::WildcardPattern() { data_ = ffi::make_object(); } TVM_FFI_STATIC_INIT_BLOCK({ namespace refl = tvm::ffi::reflection; refl::GlobalDef().def("relax.dpl.WildcardPattern", []() { return WildcardPattern(); }); @@ -292,7 +293,7 @@ TVM_FFI_STATIC_INIT_BLOCK({ RELAX_PATTERN_PRINTER_DEF(WildcardPatternNode, [](auto p, auto node) { p->stream << "*"; }); StructInfoPattern::StructInfoPattern(DFPattern pattern, StructInfo struct_info) { - ObjectPtr n = make_object(); + ObjectPtr n = ffi::make_object(); 
n->pattern = std::move(pattern); n->struct_info = std::move(struct_info); data_ = std::move(n); @@ -309,24 +310,24 @@ RELAX_PATTERN_PRINTER_DEF(StructInfoPatternNode, [](auto p, auto node) { << node->struct_info << ")"; }); -ShapePattern::ShapePattern(DFPattern pattern, Array shape) { - ObjectPtr n = make_object(); +ShapePattern::ShapePattern(DFPattern pattern, ffi::Array shape) { + ObjectPtr n = ffi::make_object(); n->pattern = std::move(pattern); n->shape = std::move(shape); data_ = std::move(n); } TVM_FFI_STATIC_INIT_BLOCK({ namespace refl = tvm::ffi::reflection; - refl::GlobalDef().def("relax.dpl.ShapePattern", [](DFPattern pattern, Array shape) { - return ShapePattern(pattern, shape); - }); + refl::GlobalDef().def( + "relax.dpl.ShapePattern", + [](DFPattern pattern, ffi::Array shape) { return ShapePattern(pattern, shape); }); }); RELAX_PATTERN_PRINTER_DEF(ShapePatternNode, [](auto p, auto node) { p->stream << "ShapePattern(" << node->pattern << " has shape " << node->shape << ")"; }); -SameShapeConstraint::SameShapeConstraint(Array args) { - ObjectPtr n = make_object(); +SameShapeConstraint::SameShapeConstraint(ffi::Array args) { + ObjectPtr n = ffi::make_object(); n->args = std::move(args); data_ = std::move(n); @@ -337,7 +338,7 @@ SameShapeConstraint::SameShapeConstraint(Array args) { TVM_FFI_STATIC_INIT_BLOCK({ namespace refl = tvm::ffi::reflection; refl::GlobalDef().def("relax.dpl.SameShapeConstraint", - [](Array args) { return SameShapeConstraint(args); }); + [](ffi::Array args) { return SameShapeConstraint(args); }); }); RELAX_PATTERN_PRINTER_DEF(SameShapeConstraintNode, [](auto p, auto node) { p->stream << "SameShapeConstraint("; @@ -351,7 +352,7 @@ RELAX_PATTERN_PRINTER_DEF(SameShapeConstraintNode, [](auto p, auto node) { }); DataTypePattern::DataTypePattern(DFPattern pattern, DataType dtype) { - ObjectPtr n = make_object(); + ObjectPtr n = ffi::make_object(); n->pattern = std::move(pattern); n->dtype = std::move(dtype); data_ = std::move(n); @@ -367,7 +368,7 @@ RELAX_PATTERN_PRINTER_DEF(DataTypePatternNode, [](auto p, auto node) { }); AttrPattern::AttrPattern(DFPattern pattern, DictAttrs attrs) { - ObjectPtr n = make_object(); + ObjectPtr n = ffi::make_object(); n->pattern = std::move(pattern); n->attrs = std::move(attrs); data_ = std::move(n); @@ -396,10 +397,10 @@ class DFPatternDuplicator : public DFPatternFunctor DFPattern VisitDFPattern_(const NotPatternNode* op) override { return NotPattern(op->reject); } DFPattern VisitDFPattern_(const VarPatternNode* op) override { return VarPattern(op->name); } DFPattern VisitDFPattern_(const ConstantPatternNode* op) override { - return ConstantPattern(make_object()); + return ConstantPattern(ffi::make_object()); } DFPattern VisitDFPattern_(const WildcardPatternNode* op) override { - return WildcardPattern(make_object()); + return WildcardPattern(ffi::make_object()); } DFPattern VisitDFPattern_(const ExprPatternNode* op) override { return ExprPattern(op->expr); } DFPattern VisitDFPattern_(const GlobalVarPatternNode* op) override { @@ -443,7 +444,7 @@ class DFPatternDuplicator : public DFPatternFunctor // Syntactic Sugar CallPattern DFPattern::operator()(const std::vector& args) const { - return CallPattern(*this, Array(args)); + return CallPattern(*this, ffi::Array(args)); } OrPattern DFPattern::operator|(const DFPattern& other) const { return OrPattern(*this, other); } @@ -451,7 +452,7 @@ AndPattern DFPattern::operator&(const DFPattern& other) const { return AndPatter NotPattern DFPattern::operator~() const { return
NotPattern(*this); } -AttrPattern DFPattern::HasAttr(const Map& attrs) const { +AttrPattern DFPattern::HasAttr(const ffi::Map& attrs) const { return AttrPattern(*this, DictAttrs(attrs)); } StructInfoPattern DFPattern::HasStructInfo(const StructInfo& struct_info) const { @@ -463,7 +464,7 @@ DataTypePattern DFPattern::HasDtype(const DataType& dtype) const { DataTypePattern DFPattern::HasDtype(const std::string& dtype) const { return HasDtype(DataType(ffi::StringToDLDataType(dtype))); } -ShapePattern DFPattern::HasShape(const Array& shape) const { +ShapePattern DFPattern::HasShape(const ffi::Array& shape) const { return ShapePattern(*this, shape); } @@ -474,13 +475,13 @@ std::stack& pattern_ctx_stack() { return graph_pattern_managers; } -Optional PatternContext::Current() { +ffi::Optional PatternContext::Current() { if (pattern_ctx_stack().empty()) return std::nullopt; return pattern_ctx_stack().top(); } PatternContext::PatternContext(bool incremental) { - auto n = make_object(); + auto n = ffi::make_object(); if (incremental) { ICHECK(!pattern_ctx_stack().empty()) << "Incremental context needs to be built inside an existing context."; @@ -506,16 +507,16 @@ static void sync_graph_constraints(const DFPattern& lhs, const DFPattern& rhs, P } PatternSeq::PatternSeq(DFPattern init_pattern) { - ObjectPtr n = make_object(); + ObjectPtr n = ffi::make_object(); n->patterns = {init_pattern}; n->pair_constraints = {}; data_ = std::move(n); } -PatternSeq::PatternSeq(tvm::Array patterns, bool only_used_by) { +PatternSeq::PatternSeq(tvm::ffi::Array patterns, bool only_used_by) { ICHECK_GE(patterns.size(), 1) << "PatternSeq must have at least one pattern"; const auto cons = PairCons(only_used_by ? PairCons::kOnlyUsedBy : PairCons::kUsedBy); - ObjectPtr n = make_object(); + ObjectPtr n = ffi::make_object(); n->patterns = std::move(patterns); n->pair_constraints = std::vector(n->patterns.size() - 1, cons); data_ = std::move(n); @@ -532,8 +533,8 @@ PatternSeq PatternSeq::OnlyUsedBy(PatternSeq other, int index) const { PatternSeq PatternSeq::dup() const { PatternSeq ret; - ObjectPtr n = make_object(); - n->patterns = Array{}; + ObjectPtr n = ffi::make_object(); + n->patterns = ffi::Array{}; n->patterns.reserve(get()->patterns.size()); n->pair_constraints = this->get()->pair_constraints; @@ -549,9 +550,10 @@ PatternSeq PatternSeq::dup() const { } TVM_FFI_STATIC_INIT_BLOCK({ namespace refl = tvm::ffi::reflection; - refl::GlobalDef().def("relax.dpl.PatternSeq", [](Array patterns, bool only_used_by) { - return PatternSeq(std::move(patterns), only_used_by); - }); + refl::GlobalDef().def("relax.dpl.PatternSeq", + [](ffi::Array patterns, bool only_used_by) { + return PatternSeq(std::move(patterns), only_used_by); + }); }); RELAX_PATTERN_PRINTER_DEF(PatternSeqNode, [](auto p, auto node) { p->stream << "["; @@ -580,7 +582,7 @@ PatternSeq UsedBy(const PatternSeq& lhs, const PatternSeq& rhs, int index) { sync_graph_constraints(lhs->patterns.back(), rhs->patterns.front(), PairCons{PairCons::kUsedBy, index}); - Array patterns; + ffi::Array patterns; patterns.reserve(lhs->patterns.size() + rhs->patterns.size()); patterns.insert(patterns.end(), lhs->patterns.begin(), lhs->patterns.end()); patterns.insert(patterns.end(), rhs->patterns.begin(), rhs->patterns.end()); @@ -591,7 +593,7 @@ PatternSeq UsedBy(const PatternSeq& lhs, const PatternSeq& rhs, int index) { pair_constraints.insert(pair_constraints.end(), rhs->pair_constraints.begin(), rhs->pair_constraints.end()); - ObjectPtr n = make_object(); + ObjectPtr n =
ffi::make_object(); n->patterns = std::move(patterns); n->pair_constraints = std::move(pair_constraints); ret.data_ = std::move(n); @@ -607,7 +609,7 @@ PatternSeq OnlyUsedBy(const PatternSeq& lhs, const PatternSeq& rhs, int index) { sync_graph_constraints(lhs->patterns.back(), rhs->patterns.front(), constraint); - Array patterns; + ffi::Array patterns; patterns.reserve(lhs->patterns.size() + rhs->patterns.size()); patterns.insert(patterns.end(), lhs->patterns.begin(), lhs->patterns.end()); patterns.insert(patterns.end(), rhs->patterns.begin(), rhs->patterns.end()); @@ -618,7 +620,7 @@ PatternSeq OnlyUsedBy(const PatternSeq& lhs, const PatternSeq& rhs, int index) { pair_constraints.insert(pair_constraints.end(), rhs->pair_constraints.begin(), rhs->pair_constraints.end()); - ObjectPtr n = make_object(); + ObjectPtr n = ffi::make_object(); n->patterns = std::move(patterns); n->pair_constraints = std::move(pair_constraints); ret.data_ = std::move(n); @@ -627,13 +629,13 @@ PatternSeq OnlyUsedBy(const PatternSeq& lhs, const PatternSeq& rhs, int index) { } PatternSeq operator>>(const PatternSeq& lhs, const PatternSeq& rhs) { return lhs.OnlyUsedBy(rhs); } -VarPattern IsVar(const String& name) { return VarPattern(name); } -ConstantPattern IsConst() { return ConstantPattern(make_object()); } -WildcardPattern Wildcard() { return WildcardPattern(make_object()); } +VarPattern IsVar(const ffi::String& name) { return VarPattern(name); } +ConstantPattern IsConst() { return ConstantPattern(ffi::make_object()); } +WildcardPattern Wildcard() { return WildcardPattern(ffi::make_object()); } ExprPattern IsExpr(const Expr& expr) { return ExprPattern(expr); } -ExprPattern IsOp(const String& op_name) { return IsExpr(Op::Get(op_name)); } -CallPattern IsCallTIR(const String& name, Optional var_args, - Optional tir_vars) { +ExprPattern IsOp(const ffi::String& op_name) { return IsExpr(Op::Get(op_name)); } +CallPattern IsCallTIR(const ffi::String& name, ffi::Optional var_args, + ffi::Optional tir_vars) { DFPattern arg_pattern; if (!var_args.defined()) { arg_pattern = Wildcard(); @@ -647,10 +649,10 @@ CallPattern IsCallTIR(const String& name, Optional var_args, return IsOp("relax.call_tir")(GlobalVarPattern(name), arg_pattern); } -CallPattern IsCallTIR(const String& name, TuplePattern var_args) { +CallPattern IsCallTIR(const ffi::String& name, TuplePattern var_args) { return IsOp("relax.call_tir")(GlobalVarPattern(name), var_args); } -CallPattern IsCallDPSPacked(const String& name, Optional var_args) { +CallPattern IsCallDPSPacked(const ffi::String& name, ffi::Optional var_args) { DFPattern arg_pattern; if (!var_args.defined()) { arg_pattern = Wildcard(); @@ -661,11 +663,11 @@ CallPattern IsCallDPSPacked(const String& name, Optional var_args) return IsOp("relax.call_dps_packed")(GlobalVarPattern(name), arg_pattern); } -CallPattern IsCallDPSPacked(const String& name, TuplePattern var_args) { +CallPattern IsCallDPSPacked(const ffi::String& name, TuplePattern var_args) { return IsOp("relax.call_dps_packed")(GlobalVarPattern(name), var_args); } -DFPattern IsTuple(const Array& fields, bool unordered) { +DFPattern IsTuple(const ffi::Array& fields, bool unordered) { if (unordered) return UnorderedTuplePattern(fields); else diff --git a/src/relax/ir/dataflow_rewriter.h b/src/relax/ir/dataflow_rewriter.h index 6b64226d77b7..c6fe514bbc9f 100644 --- a/src/relax/ir/dataflow_rewriter.h +++ b/src/relax/ir/dataflow_rewriter.h @@ -40,8 +40,8 @@ namespace tvm { namespace relax { struct RewriteSpec { - Map variable_rewrites; - Map 
new_subroutines; + ffi::Map variable_rewrites; + ffi::Map new_subroutines; explicit operator bool() const { return variable_rewrites.size(); } @@ -50,7 +50,7 @@ struct RewriteSpec { class PatternMatchingRewriterNode : public tvm::transform::PassNode { public: - virtual RewriteSpec RewriteBindings(const Array& bindings) const { + virtual RewriteSpec RewriteBindings(const ffi::Array& bindings) const { return RewriteSpec(); } @@ -68,9 +68,10 @@ class PatternMatchingRewriterNode : public tvm::transform::PassNode { class PatternMatchingRewriter : public tvm::transform::Pass { public: static PatternMatchingRewriter FromPattern( - DFPattern pattern, ffi::TypedFunction(Expr, Map)> func, - Optional> additional_bindings = std::nullopt, - Map new_subroutines = {}); + DFPattern pattern, + ffi::TypedFunction(Expr, ffi::Map)> func, + ffi::Optional> additional_bindings = std::nullopt, + ffi::Map new_subroutines = {}); static PatternMatchingRewriter FromModule(IRModule mod); @@ -83,13 +84,13 @@ class PatternMatchingRewriter : public tvm::transform::Pass { class ExprPatternRewriterNode : public PatternMatchingRewriterNode { public: DFPattern pattern; - ffi::TypedFunction(Expr, Map)> func; - Optional> additional_bindings; - Map new_subroutines; + ffi::TypedFunction(Expr, ffi::Map)> func; + ffi::Optional> additional_bindings; + ffi::Map new_subroutines; - RewriteSpec RewriteBindings(const Array& bindings) const final; + RewriteSpec RewriteBindings(const ffi::Array& bindings) const final; - Optional RewriteExpr(const Expr& expr, const Map& bindings) const; + ffi::Optional RewriteExpr(const Expr& expr, const ffi::Map& bindings) const; static void RegisterReflection() { namespace refl = tvm::ffi::reflection; @@ -105,9 +106,9 @@ class ExprPatternRewriterNode : public PatternMatchingRewriterNode { class ExprPatternRewriter : public PatternMatchingRewriter { public: ExprPatternRewriter(DFPattern pattern, - ffi::TypedFunction(Expr, Map)> func, - Optional> additional_bindings = std::nullopt, - Map new_subroutines = {}); + ffi::TypedFunction(Expr, ffi::Map)> func, + ffi::Optional> additional_bindings = std::nullopt, + ffi::Map new_subroutines = {}); TVM_DEFINE_OBJECT_REF_METHODS(ExprPatternRewriter, PatternMatchingRewriter, ExprPatternRewriterNode); @@ -118,7 +119,7 @@ class OrRewriterNode : public PatternMatchingRewriterNode { PatternMatchingRewriter lhs; PatternMatchingRewriter rhs; - RewriteSpec RewriteBindings(const Array& bindings) const override; + RewriteSpec RewriteBindings(const ffi::Array& bindings) const override; static void RegisterReflection() { namespace refl = tvm::ffi::reflection; @@ -140,12 +141,12 @@ class OrRewriter : public PatternMatchingRewriter { class TupleRewriterNode : public PatternMatchingRewriterNode { public: - Array patterns; - ffi::TypedFunction(Expr, Map)> func; - Optional> additional_bindings; - Map new_subroutines; + ffi::Array patterns; + ffi::TypedFunction(Expr, ffi::Map)> func; + ffi::Optional> additional_bindings; + ffi::Map new_subroutines; - RewriteSpec RewriteBindings(const Array& bindings) const override; + RewriteSpec RewriteBindings(const ffi::Array& bindings) const override; static void RegisterReflection() { namespace refl = tvm::ffi::reflection; @@ -161,12 +162,12 @@ class TupleRewriterNode : public PatternMatchingRewriterNode { struct VarInfo { Var var; Expr expr; - Array>> matches; + ffi::Array>> matches; std::unordered_set downstream_usage; bool used = false; }; - Map GenerateVariableRewrites(const Array& bindings) const; + ffi::Map GenerateVariableRewrites(const 
ffi::Array& bindings) const; std::optional> TryMatchByBindingIndex(const std::vector& info_vec, const std::vector& indices) const; @@ -174,10 +175,10 @@ class TupleRewriterNode : public PatternMatchingRewriterNode { class TupleRewriter : public PatternMatchingRewriter { public: - TupleRewriter(Array patterns, - ffi::TypedFunction(Expr, Map)> func, - Optional> additional_bindings = std::nullopt, - Map new_subroutines = {}); + TupleRewriter(ffi::Array patterns, + ffi::TypedFunction(Expr, ffi::Map)> func, + ffi::Optional> additional_bindings = std::nullopt, + ffi::Map new_subroutines = {}); TVM_DEFINE_OBJECT_REF_METHODS(TupleRewriter, PatternMatchingRewriter, TupleRewriterNode); }; diff --git a/src/relax/ir/emit_te.cc b/src/relax/ir/emit_te.cc index d46b634ca7c9..a57434567185 100644 --- a/src/relax/ir/emit_te.cc +++ b/src/relax/ir/emit_te.cc @@ -38,8 +38,8 @@ TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) TVM_FFI_STATIC_INIT_BLOCK({ RXPlaceholderOpNode::RegisterReflection(); }); -te::Tensor TETensor(Expr value, Map tir_var_map, std::string name) { - auto n = make_object(); +te::Tensor TETensor(Expr value, ffi::Map tir_var_map, std::string name) { + auto n = ffi::make_object(); n->name = name; n->value = value; @@ -51,7 +51,7 @@ te::Tensor TETensor(Expr value, Map tir_var_map, std::string int ndim = constant->data->ndim; ffi::Shape shape_tuple = constant->data.Shape(); - Array shape; + ffi::Array shape; shape.reserve(ndim); for (int i = 0; i < ndim; ++i) { shape.push_back(IntImm(DataType::Int(64), shape_tuple[i])); diff --git a/src/relax/ir/emit_te.h b/src/relax/ir/emit_te.h index aa7cb9db538e..af0dace29c07 100644 --- a/src/relax/ir/emit_te.h +++ b/src/relax/ir/emit_te.h @@ -64,7 +64,7 @@ class RXPlaceholderOpNode : public te::PlaceholderOpNode { * shape of the input Expr. * \param name The name of the created tensor. */ -te::Tensor TETensor(Expr value, Map tir_var_map, std::string name); +te::Tensor TETensor(Expr value, ffi::Map tir_var_map, std::string name); } // namespace relax } // namespace tvm diff --git a/src/relax/ir/expr.cc b/src/relax/ir/expr.cc index 844fd890e1fd..b7123259456c 100644 --- a/src/relax/ir/expr.cc +++ b/src/relax/ir/expr.cc @@ -52,20 +52,21 @@ TVM_FFI_STATIC_INIT_BLOCK({ ExternFuncNode::RegisterReflection(); }); -Id::Id(String name_hint) { - ObjectPtr n = make_object(); +Id::Id(ffi::String name_hint) { + ObjectPtr n = ffi::make_object(); n->name_hint = std::move(name_hint); data_ = std::move(n); } -Call::Call(Expr op, Array args, Attrs attrs, Array sinfo_args, Span span) { +Call::Call(Expr op, ffi::Array args, Attrs attrs, ffi::Array sinfo_args, + Span span) { CHECK(!op->struct_info_.defined() || op->struct_info_->IsInstance()) << "ValueError: " << "Call expects its operator to have FuncStructInfo, " << "but operator " << op << ", which was called with arguments " << args << ", has struct info " << op->struct_info_; - ObjectPtr n = make_object(); + ObjectPtr n = ffi::make_object(); n->op = std::move(op); n->args = std::move(args); n->attrs = std::move(attrs); @@ -74,14 +75,15 @@ Call::Call(Expr op, Array args, Attrs attrs, Array sinfo_args, data_ = std::move(n); } -Call WithFields(Call call, Optional opt_op, Optional> opt_args, - Optional opt_attrs, Optional> opt_sinfo_args, - Optional opt_span) { +Call WithFields(Call call, ffi::Optional opt_op, ffi::Optional> opt_args, + ffi::Optional opt_attrs, + ffi::Optional> opt_sinfo_args, + ffi::Optional opt_span) { // Collect new values for fields. 
Expr op = opt_op.value_or(call->op); - Array args = opt_args.value_or(call->args); + ffi::Array args = opt_args.value_or(call->args); Attrs attrs = opt_attrs.value_or(call->attrs); - Array sinfo_args = opt_sinfo_args.value_or(call->sinfo_args); + ffi::Array sinfo_args = opt_sinfo_args.value_or(call->sinfo_args); Span span = opt_span.value_or(call->span); // Check if anything changed. @@ -119,13 +121,14 @@ Call WithFields(Call call, Optional opt_op, Optional> opt_args TVM_FFI_STATIC_INIT_BLOCK({ namespace refl = tvm::ffi::reflection; - refl::GlobalDef().def("relax.Call", - [](Expr op, Array args, Attrs attrs, Array sinfo_args, - Span span) { return Call(op, args, attrs, sinfo_args, span); }); + refl::GlobalDef().def("relax.Call", [](Expr op, ffi::Array args, Attrs attrs, + ffi::Array sinfo_args, Span span) { + return Call(op, args, attrs, sinfo_args, span); + }); }); If::If(Expr cond, Expr true_branch, Expr false_branch, Span span) { - ObjectPtr n = make_object(); + ObjectPtr n = ffi::make_object(); n->cond = std::move(cond); n->true_branch = std::move(true_branch); n->false_branch = std::move(false_branch); @@ -133,8 +136,8 @@ If::If(Expr cond, Expr true_branch, Expr false_branch, Span span) { data_ = std::move(n); } -If WithFields(If if_expr, Optional opt_cond, Optional opt_true_branch, - Optional opt_false_branch, Optional opt_span) { +If WithFields(If if_expr, ffi::Optional opt_cond, ffi::Optional opt_true_branch, + ffi::Optional opt_false_branch, ffi::Optional opt_span) { Expr cond = opt_cond.value_or(if_expr->cond); Expr true_branch = opt_true_branch.value_or(if_expr->true_branch); Expr false_branch = opt_false_branch.value_or(if_expr->false_branch); @@ -160,9 +163,9 @@ TVM_FFI_STATIC_INIT_BLOCK({ }); }); -Tuple::Tuple(tvm::Array fields, Span span) { - Optional tuple_sinfo = [&]() -> Optional { - Array field_sinfo; +Tuple::Tuple(tvm::ffi::Array fields, Span span) { + ffi::Optional tuple_sinfo = [&]() -> ffi::Optional { + ffi::Array field_sinfo; for (const auto& field : fields) { if (field->struct_info_.defined()) { field_sinfo.push_back(GetStructInfo(field)); @@ -173,7 +176,7 @@ Tuple::Tuple(tvm::Array fields, Span span) { return TupleStructInfo(field_sinfo); }(); - ObjectPtr n = make_object(); + ObjectPtr n = ffi::make_object(); n->fields = std::move(fields); n->span = std::move(span); n->struct_info_ = tuple_sinfo; @@ -182,12 +185,13 @@ Tuple::Tuple(tvm::Array fields, Span span) { TVM_FFI_STATIC_INIT_BLOCK({ namespace refl = tvm::ffi::reflection; - refl::GlobalDef().def("relax.Tuple", - [](tvm::Array fields, Span span) { return Tuple(fields, span); }); + refl::GlobalDef().def( + "relax.Tuple", [](tvm::ffi::Array fields, Span span) { return Tuple(fields, span); }); }); -Tuple WithFields(Tuple tuple, Optional> opt_fields, Optional opt_span) { - Array fields = opt_fields.value_or(tuple->fields); +Tuple WithFields(Tuple tuple, ffi::Optional> opt_fields, + ffi::Optional opt_span) { + ffi::Array fields = opt_fields.value_or(tuple->fields); Span span = opt_span.value_or(tuple->span); bool all_fields_unchanged = true; @@ -211,7 +215,7 @@ Tuple WithFields(Tuple tuple, Optional> opt_fields, Optional o TupleGetItem::TupleGetItem(Expr tuple, int index, Span span) { CHECK_GE(index, 0) << "Index out of bounds: Tuple " << tuple << " cannot be accessed with negative index " << index; - ObjectPtr n = make_object(); + ObjectPtr n = ffi::make_object(); if (auto* tuple_info = tuple->struct_info_.as()) { CHECK_LT(index, tuple_info->fields.size()) @@ -226,8 +230,8 @@ TupleGetItem::TupleGetItem(Expr 
tuple, int index, Span span) { data_ = std::move(n); } -TupleGetItem WithFields(TupleGetItem tuple_get_item, Optional opt_tuple, - Optional opt_index, Optional opt_span) { +TupleGetItem WithFields(TupleGetItem tuple_get_item, ffi::Optional opt_tuple, + ffi::Optional opt_index, ffi::Optional opt_span) { Expr tuple = opt_tuple.value_or(tuple_get_item->tuple); Integer index = opt_index.value_or(tuple_get_item->index); Span span = opt_span.value_or(tuple_get_item->span); @@ -250,8 +254,8 @@ TVM_FFI_STATIC_INIT_BLOCK({ }); }); -ShapeExpr::ShapeExpr(Array values, Span span) { - ObjectPtr n = make_object(); +ShapeExpr::ShapeExpr(ffi::Array values, Span span) { + ObjectPtr n = ffi::make_object(); n->values = values.Map([](PrimExpr value) { if (value->IsInstance()) { @@ -268,12 +272,13 @@ ShapeExpr::ShapeExpr(Array values, Span span) { TVM_FFI_STATIC_INIT_BLOCK({ namespace refl = tvm::ffi::reflection; - refl::GlobalDef().def("relax.ShapeExpr", - [](Array values, Span span) { return ShapeExpr(values, span); }); + refl::GlobalDef().def("relax.ShapeExpr", [](ffi::Array values, Span span) { + return ShapeExpr(values, span); + }); }); -Var::Var(Id vid, Optional struct_info_annotation, Span span) { - ObjectPtr n = make_object(); +Var::Var(Id vid, ffi::Optional struct_info_annotation, Span span) { + ObjectPtr n = ffi::make_object(); n->vid = std::move(vid); n->struct_info_ = std::move(struct_info_annotation); n->span = std::move(span); @@ -290,9 +295,9 @@ VarNode* Var::CopyOnWrite() { if (!data_.unique()) { ObjectPtr node; if (auto dataflow_var = as()) { - node = make_object(*dataflow_var); + node = ffi::make_object(*dataflow_var); } else { - node = make_object(*(operator->())); + node = ffi::make_object(*(operator->())); } ObjectPtr(std::move(node)).swap(data_); } @@ -302,15 +307,14 @@ VarNode* Var::CopyOnWrite() { TVM_FFI_STATIC_INIT_BLOCK({ namespace refl = tvm::ffi::reflection; refl::GlobalDef() - .def("relax.Var", [](String name_hint, Optional struct_info_annotation, + .def("relax.Var", [](ffi::String name_hint, ffi::Optional struct_info_annotation, Span span) { return Var(name_hint, struct_info_annotation, span); }) - .def("relax.VarFromId", [](Id vid, Optional struct_info_annotation, Span span) { - return Var(vid, struct_info_annotation, span); - }); + .def("relax.VarFromId", [](Id vid, ffi::Optional struct_info_annotation, + Span span) { return Var(vid, struct_info_annotation, span); }); }); -DataflowVar::DataflowVar(Id vid, Optional struct_info_annotation, Span span) { - ObjectPtr n = make_object(); +DataflowVar::DataflowVar(Id vid, ffi::Optional struct_info_annotation, Span span) { + ObjectPtr n = ffi::make_object(); n->vid = std::move(vid); n->struct_info_ = std::move(struct_info_annotation); n->span = std::move(span); @@ -322,22 +326,23 @@ TVM_FFI_STATIC_INIT_BLOCK({ namespace refl = tvm::ffi::reflection; refl::GlobalDef() .def("relax.DataflowVar", - [](String name_hint, Optional struct_info_annotation, Span span) { + [](ffi::String name_hint, ffi::Optional struct_info_annotation, Span span) { return DataflowVar(name_hint, struct_info_annotation, span); }) .def("relax.DataflowVarFromId", - [](Id vid, Optional struct_info_annotation, Span span) { + [](Id vid, ffi::Optional struct_info_annotation, Span span) { return DataflowVar(vid, struct_info_annotation, span); }); }); -Constant::Constant(runtime::Tensor data, Optional struct_info_annotation, Span span) { - ObjectPtr n = make_object(); +Constant::Constant(runtime::Tensor data, ffi::Optional struct_info_annotation, + Span span) { + ObjectPtr n 
= ffi::make_object(); n->data = std::move(data); n->span = std::move(span); // set struct info. - Array values; + ffi::Array values; auto shape_tuple = n->data.Shape(); for (size_t dim = 0; dim < shape_tuple.size(); ++dim) { values.push_back(IntImm(DataType::Int(64), shape_tuple[dim])); @@ -356,12 +361,12 @@ TVM_FFI_STATIC_INIT_BLOCK({ namespace refl = tvm::ffi::reflection; refl::GlobalDef().def( "relax.Constant", - [](runtime::Tensor data, Optional struct_info_annotation = std::nullopt, + [](runtime::Tensor data, ffi::Optional struct_info_annotation = std::nullopt, Span span = Span()) { return Constant(data, struct_info_annotation, span); }); }); PrimValue::PrimValue(PrimExpr value, Span span) { - ObjectPtr n = make_object(); + ObjectPtr n = ffi::make_object(); n->struct_info_ = PrimStructInfo(value); n->value = std::move(value); n->span = std::move(span); @@ -378,8 +383,8 @@ TVM_FFI_STATIC_INIT_BLOCK({ [](PrimExpr value, Span span) { return PrimValue(value, span); }); }); -StringImm::StringImm(String value, Span span) { - ObjectPtr n = make_object(); +StringImm::StringImm(ffi::String value, Span span) { + ObjectPtr n = ffi::make_object(); n->value = std::move(value); n->span = std::move(span); n->struct_info_ = ObjectStructInfo(); @@ -389,11 +394,11 @@ StringImm::StringImm(String value, Span span) { TVM_FFI_STATIC_INIT_BLOCK({ namespace refl = tvm::ffi::reflection; refl::GlobalDef().def("relax.StringImm", - [](String value, Span span) { return StringImm(value, span); }); + [](ffi::String value, Span span) { return StringImm(value, span); }); }); DataTypeImm::DataTypeImm(DataType value, Span span) { - ObjectPtr n = make_object(); + ObjectPtr n = ffi::make_object(); n->value = std::move(value); n->span = std::move(span); n->struct_info_ = ObjectStructInfo(); @@ -407,7 +412,7 @@ TVM_FFI_STATIC_INIT_BLOCK({ }); MatchCast::MatchCast(Var var, Expr value, StructInfo struct_info, Span span) { - ObjectPtr n = make_object(); + ObjectPtr n = ffi::make_object(); ICHECK(var.defined()) << "MatchCast requires var to be defined"; n->var = std::move(var); n->value = std::move(value); @@ -425,7 +430,7 @@ TVM_FFI_STATIC_INIT_BLOCK({ }); VarBinding::VarBinding(Var var, Expr value, Span span) { - ObjectPtr n = make_object(); + ObjectPtr n = ffi::make_object(); n->var = std::move(var); n->value = std::move(value); n->span = span; @@ -467,8 +472,8 @@ uint64_t VarBindingNode::SHash(uint64_t init_hash, return hash_value; } -BindingBlock::BindingBlock(Array bindings, Span span) { - ObjectPtr n = make_object(); +BindingBlock::BindingBlock(ffi::Array bindings, Span span) { + ObjectPtr n = ffi::make_object(); n->bindings = std::move(bindings); n->span = span; data_ = std::move(n); @@ -484,9 +489,9 @@ BindingBlockNode* BindingBlock::CopyOnWrite() { if (!data_.unique()) { ObjectPtr node; if (auto dataflow_block = as()) { - node = make_object(*dataflow_block); + node = ffi::make_object(*dataflow_block); } else { - node = make_object(*(operator->())); + node = ffi::make_object(*(operator->())); } ObjectPtr(std::move(node)).swap(data_); } @@ -495,13 +500,13 @@ BindingBlockNode* BindingBlock::CopyOnWrite() { TVM_FFI_STATIC_INIT_BLOCK({ namespace refl = tvm::ffi::reflection; - refl::GlobalDef().def("relax.BindingBlock", [](Array bindings, Span span) { + refl::GlobalDef().def("relax.BindingBlock", [](ffi::Array bindings, Span span) { return BindingBlock(bindings, span); }); }); -DataflowBlock::DataflowBlock(Array bindings, Span span) { - ObjectPtr n = make_object(); +DataflowBlock::DataflowBlock(ffi::Array bindings, Span 
span) { + ObjectPtr n = ffi::make_object(); n->bindings = std::move(bindings); n->span = span; data_ = std::move(n); @@ -509,7 +514,7 @@ DataflowBlock::DataflowBlock(Array bindings, Span span) { TVM_FFI_STATIC_INIT_BLOCK({ namespace refl = tvm::ffi::reflection; - refl::GlobalDef().def("relax.DataflowBlock", [](Array bindings, Span span) { + refl::GlobalDef().def("relax.DataflowBlock", [](ffi::Array bindings, Span span) { return DataflowBlock(bindings, span); }); }); @@ -518,12 +523,12 @@ SeqExpr::SeqExpr(Expr body) { if (auto seq = body.as()) { *this = seq.value(); } else { - *this = SeqExpr(Array{}, body); + *this = SeqExpr(ffi::Array{}, body); } } -SeqExpr::SeqExpr(Array blocks, Expr body, Span span) { - ObjectPtr n = make_object(); +SeqExpr::SeqExpr(ffi::Array blocks, Expr body, Span span) { + ObjectPtr n = ffi::make_object(); n->blocks = std::move(blocks); n->body = std::move(body); n->span = span; @@ -532,13 +537,13 @@ SeqExpr::SeqExpr(Array blocks, Expr body, Span span) { TVM_FFI_STATIC_INIT_BLOCK({ namespace refl = tvm::ffi::reflection; - refl::GlobalDef().def("relax.SeqExpr", [](Array blocks, Expr body, Span span) { + refl::GlobalDef().def("relax.SeqExpr", [](ffi::Array blocks, Expr body, Span span) { return SeqExpr(blocks, body, span); }); }); -Function::Function(Array params, Expr body, Optional ret_struct_info, bool is_pure, - DictAttrs attrs, Span span) { +Function::Function(ffi::Array params, Expr body, ffi::Optional ret_struct_info, + bool is_pure, DictAttrs attrs, Span span) { if (!attrs.defined()) { attrs = DictAttrs(); } @@ -546,7 +551,7 @@ Function::Function(Array params, Expr body, Optional ret_struct // Set the function type. // For functions, we take a conservative approach and require the function type // to be known at construction time.
- Array param_sinfo; + ffi::Array param_sinfo; for (const Var& param : params) { CHECK(param->struct_info_.defined()) @@ -554,7 +559,7 @@ Function::Function(Array params, Expr body, Optional ret_struct param_sinfo.push_back(GetStructInfo(param)); } - Optional body_sinfo; + ffi::Optional body_sinfo; if (body->struct_info_.defined()) { body_sinfo = GetStructInfo(body); @@ -580,7 +585,7 @@ Function::Function(Array params, Expr body, Optional ret_struct auto f_shape_var_map = [&] { auto tir_vars = DefinableTIRVarsInStructInfo(TupleStructInfo(params.Map(GetStructInfo))); std::unordered_set lookup(tir_vars.begin(), tir_vars.end()); - return [lookup = std::move(lookup)](const tir::Var& var) -> Optional { + return [lookup = std::move(lookup)](const tir::Var& var) -> ffi::Optional { if (lookup.count(var)) { return var; } else { @@ -594,7 +599,7 @@ Function::Function(Array params, Expr body, Optional ret_struct FuncStructInfo func_sinfo(param_sinfo, ret_struct_info.value(), is_pure); // set the fields - ObjectPtr n = make_object(); + ObjectPtr n = ffi::make_object(); n->params = std::move(params); n->body = std::move(body); n->ret_struct_info = ret_struct_info.value(); @@ -607,16 +612,16 @@ Function::Function(Array params, Expr body, Optional ret_struct TVM_FFI_STATIC_INIT_BLOCK({ namespace refl = tvm::ffi::reflection; - refl::GlobalDef().def("relax.Function", - [](Array params, Expr body, Optional ret_struct_info, - bool is_pure, DictAttrs attrs, Span span) { - return Function(params, body, ret_struct_info, is_pure, attrs, span); - }); + refl::GlobalDef().def("relax.Function", [](ffi::Array params, Expr body, + ffi::Optional ret_struct_info, + bool is_pure, DictAttrs attrs, Span span) { + return Function(params, body, ret_struct_info, is_pure, attrs, span); + }); }); -Function Function::CreateEmpty(Array params, StructInfo ret_struct_info, bool is_pure, +Function Function::CreateEmpty(ffi::Array params, StructInfo ret_struct_info, bool is_pure, DictAttrs attrs, Span span) { - Array param_sinfo; + ffi::Array param_sinfo; for (const Var& param : params) { ICHECK(param->struct_info_.defined()) << "relax.Function requires params to contain struct_info_."; @@ -634,7 +639,7 @@ Function Function::CreateEmpty(Array params, StructInfo ret_struct_info, bo }(); // set the fields - ObjectPtr n = make_object(); + ObjectPtr n = ffi::make_object(); n->params = std::move(params); n->body = std::move(body); n->is_pure = is_pure; @@ -648,8 +653,8 @@ Function Function::CreateEmpty(Array params, StructInfo ret_struct_info, bo TVM_FFI_STATIC_INIT_BLOCK({ namespace refl = tvm::ffi::reflection; refl::GlobalDef().def( - "relax.FunctionCreateEmpty", - [](Array params, StructInfo ret_struct_info, bool is_pure, DictAttrs attrs, Span span) { + "relax.FunctionCreateEmpty", [](ffi::Array params, StructInfo ret_struct_info, + bool is_pure, DictAttrs attrs, Span span) { return Function::CreateEmpty(params, ret_struct_info, is_pure, attrs, span); }); }); @@ -680,15 +685,15 @@ FuncStructInfo GetExternFuncStructInfo() { return FuncStructInfo::OpaqueFunc(derive); } -ExternFunc::ExternFunc(String global_symbol, Span span) +ExternFunc::ExternFunc(ffi::String global_symbol, Span span) : ExternFunc(global_symbol, GetExternFuncStructInfo(), span) {} -ExternFunc::ExternFunc(String global_symbol, StructInfo struct_info, Span span) { +ExternFunc::ExternFunc(ffi::String global_symbol, StructInfo struct_info, Span span) { CHECK(struct_info.as()) << "ExternFunc must have FuncStructInfo, " << "but declaration of '" << global_symbol << "' 
received " << struct_info; - ObjectPtr n = make_object(); + ObjectPtr n = ffi::make_object(); n->global_symbol = std::move(global_symbol); n->span = span; n->struct_info_ = struct_info; @@ -697,14 +702,14 @@ ExternFunc::ExternFunc(String global_symbol, StructInfo struct_info, Span span) TVM_FFI_STATIC_INIT_BLOCK({ namespace refl = tvm::ffi::reflection; - refl::GlobalDef().def("relax.ExternFunc", - [](String global_symbol, Optional struct_info, Span span) { - if (struct_info.defined()) { - return ExternFunc(global_symbol, struct_info.value(), span); - } else { - return ExternFunc(global_symbol, span); - } - }); + refl::GlobalDef().def("relax.ExternFunc", [](ffi::String global_symbol, + ffi::Optional struct_info, Span span) { + if (struct_info.defined()) { + return ExternFunc(global_symbol, struct_info.value(), span); + } else { + return ExternFunc(global_symbol, span); + } + }); }); Expr GetShapeOf(const Expr& expr) { @@ -727,20 +732,20 @@ TVM_FFI_STATIC_INIT_BLOCK({ refl::GlobalDef() .def("relax.GetShapeOf", [](const Expr& expr) { return GetShapeOf(expr); }) .def("relax.FuncWithAttr", - [](BaseFunc func, String key, ObjectRef value) -> Optional { + [](BaseFunc func, ffi::String key, ObjectRef value) -> ffi::Optional { if (func->IsInstance()) { return WithAttr(Downcast(std::move(func)), key, value); } return std::nullopt; }) .def("relax.FuncWithAttrs", - [](BaseFunc func, Map attr_map) -> Optional { + [](BaseFunc func, ffi::Map attr_map) -> ffi::Optional { if (func->IsInstance()) { return WithAttrs(Downcast(std::move(func)), attr_map); } return std::nullopt; }) - .def("relax.FuncWithoutAttr", [](BaseFunc func, String key) -> Optional { + .def("relax.FuncWithoutAttr", [](BaseFunc func, ffi::String key) -> ffi::Optional { if (func->IsInstance()) { return WithoutAttr(Downcast(std::move(func)), key); } diff --git a/src/relax/ir/expr_functor.cc b/src/relax/ir/expr_functor.cc index d772613b5d04..9ddf0f274aff 100644 --- a/src/relax/ir/expr_functor.cc +++ b/src/relax/ir/expr_functor.cc @@ -127,7 +127,7 @@ void ExprVisitor::VisitExpr_(const TupleNode* op) { this->VisitExpr(field); } if (auto* sinfo = op->struct_info_.as()) { - this->VisitExprDepStructInfoField(GetRef(sinfo)); + this->VisitExprDepStructInfoField(ffi::GetRef(sinfo)); } } @@ -135,7 +135,7 @@ void ExprVisitor::VisitExpr_(const TupleNode* op) { void ExprVisitor::VisitExpr_(const VarNode* op) { this->VisitSpan(op->span); if (auto* sinfo = op->struct_info_.as()) { - this->VisitExprDepStructInfoField(GetRef(sinfo)); + this->VisitExprDepStructInfoField(ffi::GetRef(sinfo)); } } @@ -167,7 +167,7 @@ void ExprVisitor::VisitExpr_(const CallNode* op) { } if (auto* sinfo = op->struct_info_.as()) { - this->VisitExprDepStructInfoField(GetRef(sinfo)); + this->VisitExprDepStructInfoField(ffi::GetRef(sinfo)); } } @@ -178,7 +178,7 @@ void ExprVisitor::VisitExpr_(const IfNode* op) { this->VisitExpr(op->false_branch); if (auto* sinfo = op->struct_info_.as()) { - this->VisitExprDepStructInfoField(GetRef(sinfo)); + this->VisitExprDepStructInfoField(ffi::GetRef(sinfo)); } } @@ -189,7 +189,7 @@ void ExprVisitor::VisitExpr_(const TupleGetItemNode* op) { this->VisitExpr(op->tuple); if (auto* sinfo = op->struct_info_.as()) { - this->VisitExprDepStructInfoField(GetRef(sinfo)); + this->VisitExprDepStructInfoField(ffi::GetRef(sinfo)); } } @@ -200,7 +200,7 @@ void ExprVisitor::VisitExpr_(const ShapeExprNode* op) { this->VisitSpan(op->span); if (auto* sinfo = op->struct_info_.as()) { - this->VisitExprDepStructInfoField(GetRef(sinfo)); + 
this->VisitExprDepStructInfoField(ffi::GetRef(sinfo)); } } @@ -217,14 +217,14 @@ void ExprVisitor::VisitExpr_(const SeqExprNode* op) { this->VisitExpr(op->body); if (auto* sinfo = op->struct_info_.as()) { - this->VisitExprDepStructInfoField(GetRef(sinfo)); + this->VisitExprDepStructInfoField(ffi::GetRef(sinfo)); } } void ExprVisitor::VisitExpr_(const PrimValueNode* op) { this->VisitPrimExpr(op->value); if (auto* sinfo = op->struct_info_.as()) { - this->VisitExprDepStructInfoField(GetRef(sinfo)); + this->VisitExprDepStructInfoField(ffi::GetRef(sinfo)); } this->VisitSpan(op->span); } @@ -360,24 +360,24 @@ StructInfo ExprMutatorBase::DefaultStructInfoFieldMutator::VisitStructInfo_( const FuncStructInfoNode* op) { // Do not recurse into function struct info // as they won't contain ref to values in current scope. - return GetRef(op); + return ffi::GetRef(op); } Expr ExprMutatorBase::VisitExpr(const Expr& expr) { return ExprFunctor::VisitExpr(expr); } Expr ExprMutatorBase::VisitExpr_(const ConstantNode* op) { // Constant's struct info won't be affected by Expr/PrimExpr change. - return GetRef(op); + return ffi::GetRef(op); } Expr ExprMutatorBase::VisitExpr_(const GlobalVarNode* op) { // FuncStructInfo won't be affected by Expr/PrimExpr change. - return GetRef(op); + return ffi::GetRef(op); } Expr ExprMutatorBase::VisitExpr_(const TupleNode* op) { bool unchanged = true; - tvm::Array fields; + tvm::ffi::Array fields; for (Expr field : op->fields) { Expr new_field = this->VisitExpr(field); fields.push_back(new_field); @@ -388,7 +388,7 @@ Expr ExprMutatorBase::VisitExpr_(const TupleNode* op) { // If the tuple's struct info changes, it means that // one of its fields' struct info will change, // so unchanged already implies that struct info won't change - return GetRef(op); + return ffi::GetRef(op); } else { // when there is a change, return a new tuple node return Tuple(fields, op->span); @@ -399,7 +399,7 @@ Expr ExprMutatorBase::VisitExpr_(const TupleNode* op) { Expr ExprMutatorBase::VisitExpr_(const VarNode* op) { // struct info of var-use should remain stable // or the var itself will get replaced - return GetRef(op); + return ffi::GetRef(op); } // Visit the use-site of a defined DataflowVar @@ -413,7 +413,7 @@ Expr ExprMutatorBase::VisitExpr_(const FunctionNode* op) { Expr body = this->VisitExpr(op->body); if (body.same_as(op->body)) { - return GetRef(op); + return ffi::GetRef(op); } else { return Function(op->params, body, op->ret_struct_info, op->is_pure, op->attrs); } @@ -423,14 +423,14 @@ Expr ExprMutatorBase::VisitExpr_(const CallNode* call_node) { Expr new_op = this->VisitExpr(call_node->op); bool unchanged = call_node->op.same_as(new_op); - Array sinfo_args; + ffi::Array sinfo_args; for (StructInfo sinfo_arg : call_node->sinfo_args) { StructInfo new_sinfo_arg = this->VisitExprDepStructInfoField(sinfo_arg); sinfo_args.push_back(new_sinfo_arg); unchanged &= new_sinfo_arg.same_as(sinfo_arg); } - tvm::Array call_args; + tvm::ffi::Array call_args; for (Expr arg : call_node->args) { Expr new_arg = this->VisitExpr(arg); call_args.push_back(new_arg); @@ -438,7 +438,7 @@ Expr ExprMutatorBase::VisitExpr_(const CallNode* call_node) { if (unchanged && VisitAndCheckStructInfoFieldUnchanged(call_node->struct_info_)) { - return GetRef(call_node); + return ffi::GetRef(call_node); } else { return Call(new_op, call_args, call_node->attrs, sinfo_args, call_node->span); } @@ -451,20 +451,20 @@ Expr ExprMutatorBase::VisitExpr_(const IfNode* op) { if (op->cond.same_as(guard) && op->true_branch.same_as(true_b) &&
op->false_branch.same_as(false_b) && VisitAndCheckStructInfoFieldUnchanged(op->struct_info_)) { - return GetRef(op); + return ffi::GetRef(op); } else { return If(guard, true_b, false_b, op->span); } } -Expr ExprMutatorBase::VisitExpr_(const OpNode* op) { return GetRef(op); } +Expr ExprMutatorBase::VisitExpr_(const OpNode* op) { return ffi::GetRef(op); } Expr ExprMutatorBase::VisitExpr_(const TupleGetItemNode* op) { auto t = this->VisitExpr(op->tuple); if (op->tuple.same_as(t)) { // struct info can be deterministically derived by tuple and index // if t does not change, then struct info won't change. - return GetRef(op); + return ffi::GetRef(op); } else { return TupleGetItem(t, op->index, op->span); } @@ -475,21 +475,21 @@ Expr ExprMutatorBase::VisitExpr_(const PrimValueNode* op) { if (op->value.same_as(value)) { // struct info can be deterministically derived by value // if value does not change, then struct info won't change. - return GetRef(op); + return ffi::GetRef(op); } return PrimValue(value, op->span); } -Expr ExprMutatorBase::VisitExpr_(const StringImmNode* op) { return GetRef(op); } +Expr ExprMutatorBase::VisitExpr_(const StringImmNode* op) { return ffi::GetRef(op); } -Expr ExprMutatorBase::VisitExpr_(const DataTypeImmNode* op) { return GetRef(op); } +Expr ExprMutatorBase::VisitExpr_(const DataTypeImmNode* op) { return ffi::GetRef(op); } Expr ExprMutatorBase::VisitExpr_(const ShapeExprNode* op) { auto values = op->values.Map([this](const PrimExpr& e) { return this->VisitPrimExpr(e); }); if (values.same_as(op->values)) { // If values does not change, struct info won't change. - return GetRef(op); + return ffi::GetRef(op); } else { return ShapeExpr(values, op->span); } @@ -497,12 +497,12 @@ Expr ExprMutatorBase::VisitExpr_(const ShapeExprNode* op) { Expr ExprMutatorBase::VisitExpr_(const ExternFuncNode* op) { // StructInfo of function remains value independent. - return GetRef(op); + return ffi::GetRef(op); } Expr ExprMutatorBase::VisitExpr_(const SeqExprNode* op) { bool all_blocks_unchanged = true; - Array blocks; + ffi::Array blocks; for (auto block : op->blocks) { BindingBlock new_block = this->VisitBindingBlock(block); if (!new_block->bindings.empty()) { @@ -515,13 +515,13 @@ Expr ExprMutatorBase::VisitExpr_(const SeqExprNode* op) { if (all_blocks_unchanged && body.same_as(op->body) && VisitAndCheckStructInfoFieldUnchanged(op->struct_info_)) { - return GetRef(op); + return ffi::GetRef(op); } return SeqExpr(blocks, body); } BindingBlock ExprMutatorBase::VisitBindingBlock(const BindingBlock& block) { - Array bindings; + ffi::Array bindings; if (const auto* node = block.as()) { for (auto binding : node->bindings) { if (auto var_binding = binding.as()) { @@ -562,7 +562,7 @@ Expr ExprMutator::VisitExpr_(const VarNode* op) { } // default case return self. 
- return GetRef(op); + return ffi::GetRef(op); } // Visit the use-site of a defined DataflowVar @@ -571,7 +571,7 @@ Expr ExprMutator::VisitExpr_(const DataflowVarNode* op) { } Expr ExprMutator::VisitExpr_(const FunctionNode* op) { - tvm::Array params; + tvm::ffi::Array params; bool all_params_unchanged = true; for (Var param : op->params) { Var new_param = this->VisitVarDef(param); @@ -586,7 +586,7 @@ Expr ExprMutator::VisitExpr_(const FunctionNode* op) { if (all_params_unchanged && body.same_as(op->body)) { // No changes to the function, return the original object - return GetRef(op); + return ffi::GetRef(op); } else if (IsBaseOf(GetStructInfo(body), op->ret_struct_info)) { // If the function was mutated into a form that can no longer // propagate shape information all the way to the return value, we @@ -615,7 +615,7 @@ Expr ExprMutator::VisitExpr_(const IfNode* op) { if (op->cond.same_as(guard) && op->true_branch.same_as(true_b) && op->false_branch.same_as(false_b) && VisitAndCheckStructInfoFieldUnchanged(op->struct_info_)) { - return GetRef(op); + return ffi::GetRef(op); } else { return If(guard, true_b, false_b, op->span); } @@ -623,7 +623,7 @@ Expr ExprMutator::VisitExpr_(const SeqExprNode* op) { bool all_blocks_unchanged = true; - Array blocks; + ffi::Array blocks; for (auto block : op->blocks) { BindingBlock new_block = this->VisitBindingBlock(block); if (!new_block->bindings.empty()) { @@ -642,7 +642,7 @@ if (all_blocks_unchanged && body.same_as(op->body) && VisitAndCheckStructInfoFieldUnchanged(op->struct_info_)) { - return GetRef(op); + return ffi::GetRef(op); } else { return SeqExpr(blocks, body); } @@ -671,7 +671,7 @@ void ExprMutator::ReEmitBinding(const VarBindingNode* binding, Expr new_value) { // fast path: re-emit binding if nothing changes if (new_var.same_as(binding->var) && new_value.same_as(binding->value)) { - builder_->EmitNormalized(GetRef(binding)); + builder_->EmitNormalized(ffi::GetRef(binding)); return; } @@ -704,7 +704,7 @@ void ExprMutator::VisitBinding_(const MatchCastNode* binding) { if (new_var.same_as(binding->var) && new_value.same_as(binding->value) && new_struct_info.same_as(binding->struct_info)) { // re-emit old binding if nothing changes - return GetRef(binding); + return ffi::GetRef(binding); } else { new_value = builder_->NormalizeArgument(new_value); new_var = WithStructInfo(new_var, new_struct_info); @@ -749,14 +749,14 @@ Var ExprMutator::VisitVarDef_(const DataflowVarNode* var) { Var ExprMutator::VisitVarDef_(const VarNode* var) { if (auto* sinfo = var->struct_info_.as()) { - StructInfo struct_info = this->VisitExprDepStructInfoField(GetRef(sinfo)); + StructInfo struct_info = this->VisitExprDepStructInfoField(ffi::GetRef(sinfo)); if (struct_info.same_as(var->struct_info_)) { - return GetRef(var); + return ffi::GetRef(var); } else { return Var(var->vid, struct_info, var->span); } } else { - return GetRef(var); + return ffi::GetRef(var); } } @@ -794,7 +794,7 @@ Var ExprMutator::VisitVarDef(const Var& var) { return ret; } -Expr ExprMutator::VisitWithNewScope(const Expr& expr, Optional> params) { +Expr ExprMutator::VisitWithNewScope(const Expr& expr, ffi::Optional> params) { ICHECK(expr->IsInstance()) << "Normal form requires all new scopes to be stored as SeqExpr"; @@ -838,7 +838,9 @@ Expr ExprMutator::VisitWithInnerScope(const Expr& expr) { return ret; } -Optional ExprMutator::LookupBinding(const Var& var) { return builder_->LookupBinding(var); }
+ffi::Optional<Expr> ExprMutator::LookupBinding(const Var& var) {
+  return builder_->LookupBinding(var);
+}
 
 Var ExprMutator::WithStructInfo(Var var, StructInfo struct_info) {
   ICHECK(struct_info.defined());
diff --git a/src/relax/ir/py_expr_functor.cc b/src/relax/ir/py_expr_functor.cc
index 299839d31f4b..11867dee6db4 100644
--- a/src/relax/ir/py_expr_functor.cc
+++ b/src/relax/ir/py_expr_functor.cc
@@ -110,28 +110,29 @@ class PyExprVisitorNode : public Object, public ExprVisitor {
       PY_EXPR_VISITOR_DEFAULT(binding, f_visit_binding, ExprVisitor::VisitBinding(binding));
 
   void VisitBinding_(const VarBindingNode* binding)
-      PY_EXPR_VISITOR_DEFAULT(GetRef<VarBinding>(binding), f_visit_var_binding_,
+      PY_EXPR_VISITOR_DEFAULT(ffi::GetRef<VarBinding>(binding), f_visit_var_binding_,
                               ExprVisitor::VisitBinding_(binding));
   void VisitBinding_(const MatchCastNode* binding)
-      PY_EXPR_VISITOR_DEFAULT(GetRef<MatchCast>(binding), f_visit_match_cast_,
+      PY_EXPR_VISITOR_DEFAULT(ffi::GetRef<MatchCast>(binding), f_visit_match_cast_,
                               ExprVisitor::VisitBinding_(binding));
 
   void VisitBindingBlock(const BindingBlock& block)
       PY_EXPR_VISITOR_DEFAULT(block, f_visit_binding_block, ExprVisitor::VisitBindingBlock(block));
 
   void VisitBindingBlock_(const BindingBlockNode* block)
-      PY_EXPR_VISITOR_DEFAULT(GetRef<BindingBlock>(block), f_visit_binding_block_,
+      PY_EXPR_VISITOR_DEFAULT(ffi::GetRef<BindingBlock>(block), f_visit_binding_block_,
                               ExprVisitor::VisitBindingBlock_(block));
   void VisitBindingBlock_(const DataflowBlockNode* block)
-      PY_EXPR_VISITOR_DEFAULT(GetRef<DataflowBlock>(block), f_visit_dataflow_block_,
+      PY_EXPR_VISITOR_DEFAULT(ffi::GetRef<DataflowBlock>(block), f_visit_dataflow_block_,
                               ExprVisitor::VisitBindingBlock_(block));
 
   void VisitVarDef(const Var& var)
       PY_EXPR_VISITOR_DEFAULT(var, f_visit_var_def, ExprVisitor::VisitVarDef(var));
   void VisitVarDef_(const VarNode* var)
-      PY_EXPR_VISITOR_DEFAULT(GetRef<Var>(var), f_visit_var_def_, ExprVisitor::VisitVarDef_(var));
+      PY_EXPR_VISITOR_DEFAULT(ffi::GetRef<Var>(var), f_visit_var_def_,
+                              ExprVisitor::VisitVarDef_(var));
   void VisitVarDef_(const DataflowVarNode* var)
-      PY_EXPR_VISITOR_DEFAULT(GetRef<DataflowVar>(var), f_visit_dataflow_var_def_,
+      PY_EXPR_VISITOR_DEFAULT(ffi::GetRef<DataflowVar>(var), f_visit_dataflow_var_def_,
                               ExprVisitor::VisitVarDef_(var));
 
   void VisitSpan(const Span& span)
@@ -227,7 +228,7 @@ class PyExprVisitor : public ObjectRef {
                             ffi::Function f_visit_dataflow_block_, ffi::Function f_visit_var_def,
                             ffi::Function f_visit_var_def_, ffi::Function f_visit_dataflow_var_def_,
                             ffi::Function f_visit_span) {
-    ObjectPtr<PyExprVisitorNode> n = make_object<PyExprVisitorNode>();
+    ObjectPtr<PyExprVisitorNode> n = ffi::make_object<PyExprVisitorNode>();
     n->f_visit_expr = f_visit_expr;
     n->f_visit_binding = f_visit_binding;
     n->f_visit_binding_block = f_visit_binding_block;
@@ -348,14 +349,14 @@ class PyExprMutatorNode : public Object, public ExprMutator {
 
   void VisitBinding_(const VarBindingNode* binding) {
     if (f_visit_var_binding_ != nullptr)
-      f_visit_var_binding_(GetRef<VarBinding>(binding));
+      f_visit_var_binding_(ffi::GetRef<VarBinding>(binding));
     else
       ExprMutator::VisitBinding_(binding);
   }
 
   void VisitBinding_(const MatchCastNode* binding) {
     if (f_visit_match_cast_ != nullptr)
-      f_visit_match_cast_(GetRef<MatchCast>(binding));
+      f_visit_match_cast_(ffi::GetRef<MatchCast>(binding));
     else
       ExprMutator::VisitBinding_(binding);
   }
@@ -365,18 +366,19 @@ class PyExprMutatorNode : public Object, public ExprMutator {
                               BindingBlock);
 
   BindingBlock VisitBindingBlock_(const BindingBlockNode* block)
-      PY_EXPR_MUTATOR_DEFAULT(GetRef<BindingBlock>(block), f_visit_binding_block_,
+      PY_EXPR_MUTATOR_DEFAULT(ffi::GetRef<BindingBlock>(block), f_visit_binding_block_,
                               ExprMutator::VisitBindingBlock_(block), BindingBlock);
   BindingBlock VisitBindingBlock_(const DataflowBlockNode* block)
-      PY_EXPR_MUTATOR_DEFAULT(GetRef<DataflowBlock>(block), f_visit_dataflow_block_,
+      PY_EXPR_MUTATOR_DEFAULT(ffi::GetRef<DataflowBlock>(block), f_visit_dataflow_block_,
                               ExprMutator::VisitBindingBlock_(block), BindingBlock);
 
   Var VisitVarDef(const Var& var)
       PY_EXPR_MUTATOR_DEFAULT(var, f_visit_var_def, ExprMutator::VisitVarDef(var), Var);
-  Var VisitVarDef_(const VarNode* var) PY_EXPR_MUTATOR_DEFAULT(GetRef<Var>(var), f_visit_var_def_,
-                                                               ExprMutator::VisitVarDef_(var), Var);
+  Var VisitVarDef_(const VarNode* var)
+      PY_EXPR_MUTATOR_DEFAULT(ffi::GetRef<Var>(var), f_visit_var_def_,
+                              ExprMutator::VisitVarDef_(var), Var);
   Var VisitVarDef_(const DataflowVarNode* var)
-      PY_EXPR_MUTATOR_DEFAULT(GetRef<DataflowVar>(var), f_visit_dataflow_var_def_,
+      PY_EXPR_MUTATOR_DEFAULT(ffi::GetRef<DataflowVar>(var), f_visit_dataflow_var_def_,
                               ExprMutator::VisitVarDef_(var), Var);
 
   /*!
@@ -510,7 +512,7 @@ class PyExprMutator : public ObjectRef {
                             ffi::Function f_visit_dataflow_block_, ffi::Function f_visit_var_def,
                             ffi::Function f_visit_var_def_, ffi::Function f_visit_dataflow_var_def_,
                             ffi::Function f_visit_span) {
-    ObjectPtr<PyExprMutatorNode> n = make_object<PyExprMutatorNode>();
+    ObjectPtr<PyExprMutatorNode> n = ffi::make_object<PyExprMutatorNode>();
     n->builder_ = builder_;
     n->f_visit_expr = f_visit_expr;
     n->f_visit_constant_ = f_visit_constant_;
diff --git a/src/relax/ir/struct_info.cc b/src/relax/ir/struct_info.cc
index d2460a42ce75..945c2e69ac89 100644
--- a/src/relax/ir/struct_info.cc
+++ b/src/relax/ir/struct_info.cc
@@ -41,7 +41,7 @@ TVM_FFI_STATIC_INIT_BLOCK({
 });
 
 ObjectStructInfo::ObjectStructInfo(Span span) {
-  ObjectPtr<ObjectStructInfoNode> n = make_object<ObjectStructInfoNode>();
+  ObjectPtr<ObjectStructInfoNode> n = ffi::make_object<ObjectStructInfoNode>();
   n->span = span;
   data_ = std::move(n);
 }
@@ -53,7 +53,7 @@ TVM_FFI_STATIC_INIT_BLOCK({
 
 // Prim
 PrimStructInfo::PrimStructInfo(PrimExpr value, Span span) {
-  ObjectPtr<PrimStructInfoNode> n = make_object<PrimStructInfoNode>();
+  ObjectPtr<PrimStructInfoNode> n = ffi::make_object<PrimStructInfoNode>();
   n->dtype = value->dtype;
   n->value = std::move(value);
   n->span = span;
@@ -61,7 +61,7 @@ PrimStructInfo::PrimStructInfo(PrimExpr value, Span span) {
 }
 
 PrimStructInfo::PrimStructInfo(DataType dtype, Span span) {
-  ObjectPtr<PrimStructInfoNode> n = make_object<PrimStructInfoNode>();
+  ObjectPtr<PrimStructInfoNode> n = ffi::make_object<PrimStructInfoNode>();
   n->dtype = dtype;
   n->value = std::nullopt;
   n->span = span;
@@ -78,8 +78,8 @@ TVM_FFI_STATIC_INIT_BLOCK({
 });
 
 // Shape
-ShapeStructInfo::ShapeStructInfo(Array<PrimExpr> values, Span span) {
-  ObjectPtr<ShapeStructInfoNode> n = make_object<ShapeStructInfoNode>();
+ShapeStructInfo::ShapeStructInfo(ffi::Array<PrimExpr> values, Span span) {
+  ObjectPtr<ShapeStructInfoNode> n = ffi::make_object<ShapeStructInfoNode>();
   n->ndim = static_cast<int>(values.size());
   n->values = values.Map([](PrimExpr value) {
     if (value->IsInstance<IntImmNode>()) {
@@ -94,7 +94,7 @@ ShapeStructInfo::ShapeStructInfo(Array<PrimExpr> values, Span span) {
 }
 
 ShapeStructInfo::ShapeStructInfo(int ndim, Span span) {
-  ObjectPtr<ShapeStructInfoNode> n = make_object<ShapeStructInfoNode>();
+  ObjectPtr<ShapeStructInfoNode> n = ffi::make_object<ShapeStructInfoNode>();
   CHECK_GE(ndim, -1) << "ndim of ShapeStructInfo must be >= -1, but got " << ndim;
   n->ndim = ndim;
   n->span = span;
@@ -104,7 +104,7 @@ ShapeStructInfo::ShapeStructInfo(int ndim, Span span) {
 TVM_FFI_STATIC_INIT_BLOCK({
   namespace refl = tvm::ffi::reflection;
   refl::GlobalDef().def(
-      "relax.ShapeStructInfo", [](Optional<Array<PrimExpr>> values, int ndim, Span span) {
+      "relax.ShapeStructInfo", [](ffi::Optional<ffi::Array<PrimExpr>> values, int ndim, Span span) {
        if (values.defined()) {
          CHECK_EQ(ndim, kUnknownNDim) << "ValueError: Cannot both specify values and ndim";
          return ShapeStructInfo(values.value(), span);
@@ -115,11 +115,11 @@ TVM_FFI_STATIC_INIT_BLOCK({
 });
 
 // Tensor
-TensorStructInfo::TensorStructInfo(Expr shape, DataType dtype, Optional<VDevice> vdevice,
+TensorStructInfo::TensorStructInfo(Expr shape, DataType dtype, ffi::Optional<VDevice> vdevice,
                                    Span span) {
-  ObjectPtr<TensorStructInfoNode> n = make_object<TensorStructInfoNode>();
+  ObjectPtr<TensorStructInfoNode> n = ffi::make_object<TensorStructInfoNode>();
   // assign ndim before move
-  Optional<ShapeStructInfo> sinfo = MatchStructInfo<ShapeStructInfo>(shape);
+  ffi::Optional<ShapeStructInfo> sinfo = MatchStructInfo<ShapeStructInfo>(shape);
   ICHECK(sinfo) << "We expect shape to contain pre-set shape struct info";
   ICHECK(shape.defined()) << "Must provide a shape in this constructor";
   ICHECK(shape->IsInstance<ShapeExprNode>() || shape->IsInstance<VarNode>())
@@ -133,8 +133,9 @@ TensorStructInfo::TensorStructInfo(Expr shape, DataType dtype, Optional<VDevice>
   data_ = std::move(n);
 }
 
-TensorStructInfo::TensorStructInfo(DataType dtype, int ndim, Optional<VDevice> vdevice, Span span) {
-  ObjectPtr<TensorStructInfoNode> n = make_object<TensorStructInfoNode>();
+TensorStructInfo::TensorStructInfo(DataType dtype, int ndim, ffi::Optional<VDevice> vdevice,
+                                   Span span) {
+  ObjectPtr<TensorStructInfoNode> n = ffi::make_object<TensorStructInfoNode>();
   CHECK_GE(ndim, -1) << "ndim of TensorStructInfo must be >= -1, but got " << ndim;
   n->ndim = ndim;
   n->dtype = dtype;
@@ -145,20 +146,21 @@ TensorStructInfo::TensorStructInfo(DataType dtype, int ndim, Optional<VDevice> v
 
 TVM_FFI_STATIC_INIT_BLOCK({
   namespace refl = tvm::ffi::reflection;
-  refl::GlobalDef().def("relax.TensorStructInfo", [](Optional<Expr> shape, Optional<DataType> dtype,
-                                                     int ndim, VDevice vdevice, Span span) {
-    if (shape.defined()) {
-      CHECK_EQ(ndim, kUnknownNDim) << "ValueError: Cannot both specify shape and ndim";
-      return TensorStructInfo(shape.value(), dtype.value_or(DataType::Void()), vdevice, span);
-    } else {
-      return TensorStructInfo(dtype.value_or(DataType::Void()), ndim, vdevice, span);
-    }
-  });
+  refl::GlobalDef().def(
+      "relax.TensorStructInfo", [](ffi::Optional<Expr> shape, ffi::Optional<DataType> dtype,
+                                   int ndim, VDevice vdevice, Span span) {
+        if (shape.defined()) {
+          CHECK_EQ(ndim, kUnknownNDim) << "ValueError: Cannot both specify shape and ndim";
+          return TensorStructInfo(shape.value(), dtype.value_or(DataType::Void()), vdevice, span);
+        } else {
+          return TensorStructInfo(dtype.value_or(DataType::Void()), ndim, vdevice, span);
+        }
+      });
 });
 
 // Tuple
-TupleStructInfo::TupleStructInfo(Array<StructInfo> fields, Span span) {
-  ObjectPtr<TupleStructInfoNode> n = make_object<TupleStructInfoNode>();
+TupleStructInfo::TupleStructInfo(ffi::Array<StructInfo> fields, Span span) {
+  ObjectPtr<TupleStructInfoNode> n = ffi::make_object<TupleStructInfoNode>();
   n->fields = std::move(fields);
   n->span = span;
   data_ = std::move(n);
@@ -166,14 +168,15 @@ TupleStructInfo::TupleStructInfo(Array<StructInfo> fields, Span span) {
 
 TVM_FFI_STATIC_INIT_BLOCK({
   namespace refl = tvm::ffi::reflection;
-  refl::GlobalDef().def("relax.TupleStructInfo", [](Array<StructInfo> fields, Span span) {
+  refl::GlobalDef().def("relax.TupleStructInfo", [](ffi::Array<StructInfo> fields, Span span) {
     return TupleStructInfo(fields, span);
   });
 });
 
 // Func
-FuncStructInfo::FuncStructInfo(Array<StructInfo> params, StructInfo ret, bool purity, Span span) {
-  ObjectPtr<FuncStructInfoNode> n = make_object<FuncStructInfoNode>();
+FuncStructInfo::FuncStructInfo(ffi::Array<StructInfo> params, StructInfo ret, bool purity,
+                               Span span) {
+  ObjectPtr<FuncStructInfoNode> n = ffi::make_object<FuncStructInfoNode>();
   n->params = std::move(params);
   n->ret = std::move(ret);
   n->purity = std::move(purity);
@@ -183,7 +186,7 @@ FuncStructInfo::FuncStructInfo(Array<StructInfo> params, StructInfo ret, bool pu
 
 FuncStructInfo FuncStructInfo::OpaqueFunc(StructInfoDeriveFunc derive_func, bool purity,
                                           Span span) {
-  ObjectPtr<FuncStructInfoNode> n = make_object<FuncStructInfoNode>();
+  ObjectPtr<FuncStructInfoNode> n = ffi::make_object<FuncStructInfoNode>();
   n->derive_func = std::move(derive_func);
   n->ret = ObjectStructInfo();
   n->purity = std::move(purity);
@@ -192,7 +195,7 @@ FuncStructInfo FuncStructInfo::OpaqueFunc(StructInfoDeriveFunc derive_func, bool
 }
 
 FuncStructInfo FuncStructInfo::OpaqueFunc(StructInfo ret, bool purity, Span span) {
-  ObjectPtr<FuncStructInfoNode> n = make_object<FuncStructInfoNode>();
+  ObjectPtr<FuncStructInfoNode> n = ffi::make_object<FuncStructInfoNode>();
   n->ret = std::move(ret);
   n->purity = std::move(purity);
   n->span = span;
@@ -203,12 +206,12 @@
 TVM_FFI_STATIC_INIT_BLOCK({
   namespace refl = tvm::ffi::reflection;
   refl::GlobalDef()
       .def("relax.FuncStructInfo",
-           [](Array<StructInfo> params, StructInfo ret, bool purity, Span span) {
+           [](ffi::Array<StructInfo> params, StructInfo ret, bool purity, Span span) {
             return FuncStructInfo(params, ret, purity, span);
           })
       .def("relax.FuncStructInfoOpaqueFunc",
-           [](Optional<StructInfo> ret, Optional<StructInfoDeriveFunc> derive_func, bool purity,
-              Span span) {
+           [](ffi::Optional<StructInfo> ret, ffi::Optional<StructInfoDeriveFunc> derive_func,
+              bool purity, Span span) {
             if (derive_func.defined()) {
               ICHECK(!ret.defined()) << "ValueError: Cannot specify both ret and derive_func";
               return FuncStructInfo::OpaqueFunc(derive_func.value(), purity, span);
diff --git a/src/relax/ir/struct_info_functor.cc b/src/relax/ir/struct_info_functor.cc
index ea8f1da8f04b..58df3c24ff8e 100644
--- a/src/relax/ir/struct_info_functor.cc
+++ b/src/relax/ir/struct_info_functor.cc
@@ -68,24 +68,24 @@ void StructInfoVisitor::VisitStructInfo_(const FuncStructInfoNode* op) {
 }
 
 StructInfo StructInfoMutator::VisitStructInfo_(const ObjectStructInfoNode* op) {
-  return GetRef<StructInfo>(op);
+  return ffi::GetRef<StructInfo>(op);
 }
 
 StructInfo StructInfoMutator::VisitStructInfo_(const PrimStructInfoNode* op) {
   if (!op->value.defined()) {
-    return GetRef<StructInfo>(op);
+    return ffi::GetRef<StructInfo>(op);
   }
 
   auto new_expr = VisitStructInfoExprField(op->value.value());
   if (new_expr.same_as(op->value)) {
-    return GetRef<StructInfo>(op);
+    return ffi::GetRef<StructInfo>(op);
   } else {
     return PrimStructInfo(new_expr);
   }
 }
 
 StructInfo StructInfoMutator::VisitStructInfo_(const ShapeStructInfoNode* op) {
-  Optional<Array<PrimExpr>> values;
+  ffi::Optional<ffi::Array<PrimExpr>> values;
 
   if (op->values.defined()) {
     // if no changes are made the original array will be returned.
@@ -94,14 +94,14 @@ StructInfo StructInfoMutator::VisitStructInfo_(const ShapeStructInfoNode* op) {
   }
 
   if (values.same_as(op->values)) {
-    return GetRef<StructInfo>(op);
+    return ffi::GetRef<StructInfo>(op);
   } else {
     return ShapeStructInfo(values.value(), op->span);
   }
 }
 
 StructInfo StructInfoMutator::VisitStructInfo_(const TensorStructInfoNode* op) {
-  Optional<Expr> shape;
+  ffi::Optional<Expr> shape;
 
   if (op->shape.defined()) {
     shape = this->VisitStructInfoExprField(op->shape.value());
@@ -110,7 +110,7 @@ StructInfo StructInfoMutator::VisitStructInfo_(const TensorStructInfoNode* op) {
   VDevice vdev = op->vdevice.value_or(VDevice());
 
   if (shape.same_as(op->shape)) {
-    return GetRef<StructInfo>(op);
+    return ffi::GetRef<StructInfo>(op);
   } else {
     return TensorStructInfo(shape.value(), op->dtype, vdev, op->span);
   }
@@ -123,18 +123,18 @@ StructInfo StructInfoMutator::VisitStructInfo_(const distributed::DTensorStructI
 }
 
 StructInfo StructInfoMutator::VisitStructInfo_(const TupleStructInfoNode* op) {
-  Array<StructInfo> fields =
+  ffi::Array<StructInfo> fields =
       op->fields.Map([this](const StructInfo& sinfo) { return this->VisitStructInfo(sinfo); });
 
   if (fields.same_as(op->fields)) {
-    return GetRef<StructInfo>(op);
+    return ffi::GetRef<StructInfo>(op);
   } else {
     return TupleStructInfo(fields, op->span);
   }
 }
 
 StructInfo StructInfoMutator::VisitStructInfo_(const FuncStructInfoNode* op) {
-  Optional<Array<StructInfo>> params;
+  ffi::Optional<ffi::Array<StructInfo>> params;
 
   if (op->params.defined()) {
     params = op->params.value().Map(
@@ -144,7 +144,7 @@ StructInfo StructInfoMutator::VisitStructInfo_(const FuncStructInfoNode* op) {
   StructInfo ret = this->VisitStructInfo(op->ret);
 
   if (params.same_as(op->params) && ret.same_as(op->ret)) {
-    return GetRef<StructInfo>(op);
+    return ffi::GetRef<StructInfo>(op);
   } else {
     ICHECK(ret.defined()) << "FuncStructInfo that contains params must contain ret";
     return FuncStructInfo(params.value(), ret, op->purity, op->span);
diff --git a/src/relax/ir/tir_pattern.cc b/src/relax/ir/tir_pattern.cc
index ab2d91abcc86..b5bd9df27777 100644
--- a/src/relax/ir/tir_pattern.cc
+++ b/src/relax/ir/tir_pattern.cc
@@ -24,9 +24,9 @@ namespace relax {
 
 TVM_FFI_STATIC_INIT_BLOCK({ MatchResultNode::RegisterReflection(); });
 
-MatchResult::MatchResult(TIRPattern pattern, Array<PrimExpr> symbol_values,
-                         Array<tir::Buffer> matched_buffers) {
-  auto n = make_object<MatchResultNode>();
+MatchResult::MatchResult(TIRPattern pattern, ffi::Array<PrimExpr> symbol_values,
+                         ffi::Array<tir::Buffer> matched_buffers) {
+  auto n = ffi::make_object<MatchResultNode>();
   n->pattern = std::move(pattern);
   n->symbol_values = std::move(symbol_values);
   n->matched_buffers = std::move(matched_buffers);
diff --git a/src/relax/ir/transform.cc b/src/relax/ir/transform.cc
index fb106e2092db..b33b5f82cb7e 100644
--- a/src/relax/ir/transform.cc
+++ b/src/relax/ir/transform.cc
@@ -103,7 +103,7 @@ class FunctionPass : public Pass {
 
 FunctionPass::FunctionPass(std::function<Function(Function, IRModule, PassContext)> pass_func,
                            PassInfo pass_info) {
-  auto n = make_object<FunctionPassNode>();
+  auto n = ffi::make_object<FunctionPassNode>();
   n->pass_func = std::move(pass_func);
   n->pass_info = std::move(pass_info);
   data_ = std::move(n);
@@ -138,7 +138,7 @@ IRModule FunctionPassNode::operator()(IRModule mod, const PassContext& pass_ctx)
   for (const auto& it : updated_mod->functions) {
     // only picks up relax::Function
     if (auto* n = it.second.as<FunctionNode>()) {
-      Function func = GetRef<Function>(n);
+      Function func = ffi::GetRef<Function>(n);
       auto updated_func = pass_func(func, updated_mod, pass_ctx);
       updates.push_back({it.first, updated_func});
     }
@@ -160,7 +160,8 @@ IRModule FunctionPassNode::operator()(IRModule mod, const PassContext& pass_ctx)
 }
 
 Pass CreateFunctionPass(std::function<Function(Function, IRModule, PassContext)> pass_func,
-                        int opt_level, String name, tvm::Array<String> required, bool traceable) {
+                        int opt_level, ffi::String name, tvm::ffi::Array<ffi::String> required,
+                        bool traceable) {
   PassInfo pass_info = PassInfo(opt_level, name, required, traceable);
   return FunctionPass(std::move(pass_func), pass_info);
 }
@@ -238,14 +239,14 @@ class DataflowBlockMutator : public ExprMutator {
    */
  BindingBlock VisitBindingBlock_(const DataflowBlockNode* n) final {
    // collect Global Scope Vars and Symbolic Vars inside the DataflowBlock
-    Map<String, Var> global_scope_vars;
-    Map<String, tir::Var> symbolic_vars;
+    ffi::Map<ffi::String, Var> global_scope_vars;
+    ffi::Map<ffi::String, tir::Var> symbolic_vars;
    for (const Binding& binding : n->bindings) {
      Var var = binding->var;
      if (const auto* match_cast = binding.as<MatchCastNode>()) {
        auto collected_vars = SymbolicVarCollector::Collect(match_cast->struct_info);
        for (const tir::VarNode* var : collected_vars) {
-          symbolic_vars.Set(var->name_hint, GetRef<tir::Var>(var));
+          symbolic_vars.Set(var->name_hint, ffi::GetRef<tir::Var>(var));
        }
      }
      if (!var.as<DataflowVarNode>()) {
@@ -254,7 +255,7 @@ class DataflowBlockMutator : public ExprMutator {
    }
 
    // apply pass_func_ to the DataflowBlock
-    DataflowBlock block = GetRef<DataflowBlock>(n);
+    DataflowBlock block = ffi::GetRef<DataflowBlock>(n);
    DataflowBlock updated_block = pass_func_(block, mod_, pass_ctx_);
 
    // raise error if there are updates of recorded Global Scope Vars and Symbolic Vars
@@ -325,7 +326,7 @@ class DataflowBlockPass : public Pass {
 
 DataflowBlockPass::DataflowBlockPass(
    std::function<DataflowBlock(DataflowBlock, IRModule, PassContext)> pass_func,
    PassInfo pass_info) {
-  auto n = make_object<DataflowBlockPassNode>();
+  auto n = ffi::make_object<DataflowBlockPassNode>();
   n->pass_func = std::move(pass_func);
   n->pass_info = std::move(pass_info);
   data_ = std::move(n);
@@ -361,7 +362,7 @@ IRModule DataflowBlockPassNode::operator()(IRModule mod, const PassContext& pass
   for (const auto& it : updated_mod->functions) {
     // only picks up relax::Function
     if (auto* n = it.second.as<FunctionNode>()) {
-      Function func = GetRef<Function>(n);
+      Function func = ffi::GetRef<Function>(n);
       Function updated_func = Downcast<Function>(dataflow_block_mutator.VisitExpr(func));
       updates.push_back({it.first, updated_func});
     }
@@ -384,7 +385,7 @@ IRModule DataflowBlockPassNode::operator()(IRModule mod, const PassContext& pass
 
 Pass CreateDataflowBlockPass(
     std::function<DataflowBlock(DataflowBlock, IRModule, PassContext)> pass_func, int opt_level,
-    String name, tvm::Array<String> required, bool traceable) {
+    ffi::String name, tvm::ffi::Array<ffi::String> required, bool traceable) {
   PassInfo pass_info = PassInfo(opt_level, name, required, traceable);
   return DataflowBlockPass(std::move(pass_func), pass_info);
 }
diff --git a/src/relax/ir/type.cc b/src/relax/ir/type.cc
index 1f0de47f1f83..9288801ab6dd 100644
--- a/src/relax/ir/type.cc
+++ b/src/relax/ir/type.cc
@@ -36,7 +36,7 @@ TVM_FFI_STATIC_INIT_BLOCK({
 });
 
 ShapeType::ShapeType(int ndim, Span span) {
-  ObjectPtr<ShapeTypeNode> n = make_object<ShapeTypeNode>();
+  ObjectPtr<ShapeTypeNode> n = ffi::make_object<ShapeTypeNode>();
   n->ndim = ndim;
   n->span = span;
   data_ = std::move(n);
@@ -49,7 +49,7 @@ TVM_FFI_STATIC_INIT_BLOCK({
 });
 
 ObjectType::ObjectType(Span span) {
-  ObjectPtr<ObjectTypeNode> n = make_object<ObjectTypeNode>();
+  ObjectPtr<ObjectTypeNode> n = ffi::make_object<ObjectTypeNode>();
   n->span = span;
   data_ = std::move(n);
 }
@@ -60,7 +60,7 @@ TVM_FFI_STATIC_INIT_BLOCK({
 });
 
 TensorType::TensorType(int ndim, DataType dtype, Span span) {
-  ObjectPtr<TensorTypeNode> n = make_object<TensorTypeNode>();
+  ObjectPtr<TensorTypeNode> n = ffi::make_object<TensorTypeNode>();
   n->ndim = std::move(ndim);
   n->dtype = std::move(dtype);
   n->span = span;
@@ -68,7 +68,7 @@ TensorType::TensorType(int ndim, DataType dtype, Span span) {
 }
 
 TensorType TensorType::CreateUnknownNDim(DataType dtype, Span span) {
-  ObjectPtr<TensorTypeNode> n = make_object<TensorTypeNode>();
+  ObjectPtr<TensorTypeNode> n = ffi::make_object<TensorTypeNode>();
   n->ndim = -1;
   n->dtype = std::move(dtype);
   n->span = std::move(span);
@@ -83,7 +83,7 @@ TVM_FFI_STATIC_INIT_BLOCK({
 });
 
 PackedFuncType::PackedFuncType(Span span) {
-  ObjectPtr<PackedFuncTypeNode> n = make_object<PackedFuncTypeNode>();
+  ObjectPtr<PackedFuncTypeNode> n = ffi::make_object<PackedFuncTypeNode>();
   n->span = span;
   data_ = std::move(n);
 }
diff --git a/src/relax/op/ccl/ccl.cc b/src/relax/op/ccl/ccl.cc
index f46150654f0e..9f48f72a3fec 100644
--- a/src/relax/op/ccl/ccl.cc
+++ b/src/relax/op/ccl/ccl.cc
@@ -34,8 +34,8 @@ TVM_FFI_STATIC_INIT_BLOCK({
   ScatterCollectiveAttrs::RegisterReflection();
 });
 
-Expr allreduce(Expr x, String op_type, bool in_group) {
-  ObjectPtr<AllReduceAttrs> attrs = make_object<AllReduceAttrs>();
+Expr allreduce(Expr x, ffi::String op_type, bool in_group) {
+  ObjectPtr<AllReduceAttrs> attrs = ffi::make_object<AllReduceAttrs>();
   attrs->op_type = std::move(op_type);
   attrs->in_group = std::move(in_group);
 
@@ -64,7 +64,7 @@ TVM_REGISTER_OP("relax.ccl.allreduce")
 
 /* relax.ccl.allgather */
 Expr allgather(Expr x, int num_workers, bool in_group) {
-  ObjectPtr<AllGatherAttrs> attrs = make_object<AllGatherAttrs>();
+  ObjectPtr<AllGatherAttrs> attrs = ffi::make_object<AllGatherAttrs>();
   attrs->num_workers = std::move(num_workers);
   attrs->in_group = std::move(in_group);
 
@@ -88,7 +88,7 @@ StructInfo InferStructInfoAllGather(const Call& call, const BlockBuilder& ctx) {
   if (!input_shape.defined()) {
     return input_sinfo;
   }
-  Array<PrimExpr> output_shape = input_shape.value();
+  ffi::Array<PrimExpr> output_shape = input_shape.value();
   output_shape.Set(0, floor(output_shape[0] * num_workers));
   return TensorStructInfo(ShapeExpr(output_shape), output_dtype, input_sinfo->vdevice);
 }
@@ -126,7 +126,7 @@ TVM_REGISTER_OP("relax.ccl.broadcast_from_worker0")
 
 /* relax.ccl.scatter_from_worker0 */
 Expr scatter_from_worker0(Expr data, int num_workers, int axis) {
-  ObjectPtr<ScatterCollectiveAttrs> attrs = make_object<ScatterCollectiveAttrs>();
+  ObjectPtr<ScatterCollectiveAttrs> attrs = ffi::make_object<ScatterCollectiveAttrs>();
   attrs->num_workers = std::move(num_workers);
   attrs->axis = std::move(axis);
   static const Op& op = Op::Get("relax.ccl.scatter_from_worker0");
@@ -158,7 +158,7 @@ StructInfo InferStructInfoScatter(const Call& call, const BlockBuilder& ctx) {
                      << " while num_workers is " << num_workers);
   }
 
-  Array<PrimExpr> output_shape = input_shape.value();
+  ffi::Array<PrimExpr> output_shape = input_shape.value();
   output_shape.Set(attrs->axis, div(output_shape[attrs->axis], num_workers));
   return TensorStructInfo(ShapeExpr(output_shape), output_dtype, input_sinfo->vdevice);
 }
diff --git a/src/relax/op/ccl/ccl.h b/src/relax/op/ccl/ccl.h
index 82ea3935675d..1d049382d0ae 100644
--- a/src/relax/op/ccl/ccl.h
+++ b/src/relax/op/ccl/ccl.h
@@ -33,7 +33,7 @@ namespace tvm {
 namespace relax {
 
 /*! \brief AllReduce. */
-Expr allreduce(Expr data, String op_type, bool in_group);
+Expr allreduce(Expr data, ffi::String op_type, bool in_group);
 
 /*! \brief AllGather. */
 Expr allgather(Expr data, int num_workers, bool in_group);
diff --git a/src/relax/op/distributed/binary.h b/src/relax/op/distributed/binary.h
index 7e89c6497dcc..127dec433afa 100644
--- a/src/relax/op/distributed/binary.h
+++ b/src/relax/op/distributed/binary.h
@@ -36,7 +36,8 @@ namespace distributed {
 template <typename FType>
 StructInfo InferDistStructInfoBroadcast(const Call& call, const BlockBuilder& ctx,
                                         FType f_compute_out_dtype) {
-  Array<distributed::DTensorStructInfo> input_dtensor_sinfos = GetInputDTensorStructInfo(call, ctx);
+  ffi::Array<distributed::DTensorStructInfo> input_dtensor_sinfos =
+      GetInputDTensorStructInfo(call, ctx);
   TensorStructInfo x1_sinfo, x2_sinfo;
   x1_sinfo = input_dtensor_sinfos[0]->tensor_sinfo;
   x2_sinfo = input_dtensor_sinfos[1]->tensor_sinfo;
@@ -55,7 +56,7 @@ StructInfo InferDistStructInfoBroadcast(const Call& call, const BlockBuilder& ct
   // Shapes and ndims
   if (x1_shape && x2_shape) {
     // If all inputs have shapes, directly infer shapes
-    Optional<Array<PrimExpr>> output_shape =
+    ffi::Optional<ffi::Array<PrimExpr>> output_shape =
         InferBinaryBroadcastShape(call, ctx, x1_shape->values, x2_shape->values);
     if (!output_shape.defined()) {
       output_tensor_sinfo = TensorStructInfo(output_dtype, /*ndim=*/output_ndim);
diff --git a/src/relax/op/distributed/ccl.cc b/src/relax/op/distributed/ccl.cc
index 885b084856a1..6ba63986980e 100644
--- a/src/relax/op/distributed/ccl.cc
+++ b/src/relax/op/distributed/ccl.cc
@@ -25,7 +25,7 @@ namespace relax {
 namespace distributed {
 
 StructInfo InferDistStructInfoAllReduce(const Call& call, const BlockBuilder& ctx) {
-  Array<DTensorStructInfo> input_dtensor_sinfos = GetInputDTensorStructInfo(call, ctx);
+  ffi::Array<DTensorStructInfo> input_dtensor_sinfos = GetInputDTensorStructInfo(call, ctx);
   ICHECK(input_dtensor_sinfos.size() == 1);
   DTensorStructInfo input_dtensor_sinfo = input_dtensor_sinfos[0];
   TensorStructInfo tensor_sinfo = input_dtensor_sinfo->tensor_sinfo;
diff --git a/src/relax/op/distributed/distributed.cc b/src/relax/op/distributed/distributed.cc
index f9651d8225a4..87118074c95f 100644
--- a/src/relax/op/distributed/distributed.cc
+++ b/src/relax/op/distributed/distributed.cc
@@ -43,7 +43,7 @@ TVM_FFI_STATIC_INIT_BLOCK({ DistributionAttrs::RegisterReflection(); });
 
 Expr annotate_sharding(Expr input, distributed::DeviceMesh device_mesh,
                        distributed::Placement placement) {
-  ObjectPtr<DistributionAttrs> attrs = make_object<DistributionAttrs>();
+  ObjectPtr<DistributionAttrs> attrs = ffi::make_object<DistributionAttrs>();
   attrs->device_mesh = device_mesh;
   attrs->placement = placement;
 
@@ -71,7 +71,7 @@ TVM_REGISTER_OP("relax.dist.annotate_sharding")
 
 Expr redistribute(Expr input, distributed::DeviceMesh device_mesh,
                   distributed::Placement placement) {
-  ObjectPtr<DistributionAttrs> attrs = make_object<DistributionAttrs>();
+  ObjectPtr<DistributionAttrs> attrs = ffi::make_object<DistributionAttrs>();
   attrs->device_mesh = device_mesh;
   attrs->placement = placement;
 
@@ -120,8 +120,8 @@ TVM_REGISTER_OP("relax.dist.call_tir_local_view")
     .set_attr<Bool>("FPurity", Bool(true));
 
 Expr MakeCallTIRLocalView(Expr func, Tuple args,
-                          Array<distributed::DTensorStructInfo> out_sinfo_list,
-                          Optional<Expr> packed_ints) {
+                          ffi::Array<distributed::DTensorStructInfo> out_sinfo_list,
+                          ffi::Optional<Expr> packed_ints) {
   for (const distributed::DTensorStructInfo& sinfo : out_sinfo_list) {
     const auto* shape = sinfo->tensor_sinfo->shape.as<ShapeExprNode>();
     CHECK(shape != nullptr)
@@ -175,14 +175,14 @@ StructInfo InferStructInfoRtoS(const Call& call, const BlockBuilder& ctx) {
                      << " while num_workers is " << num_workers);
   }
 
-  Array<PrimExpr> output_shape = input_shape.value();
+  ffi::Array<PrimExpr> output_shape = input_shape.value();
   output_shape.Set(attrs->axis, div(output_shape[attrs->axis], num_workers));
   return TensorStructInfo(ShapeExpr(output_shape), output_dtype, input_sinfo->vdevice);
 }
 
 StructInfo InferDistStructInfoRtoS(const Call& call, const BlockBuilder& ctx) {
   using namespace distributed;
-  Array<DTensorStructInfo> input_dtensor_sinfos = GetInputDTensorStructInfo(call, ctx);
+  ffi::Array<DTensorStructInfo> input_dtensor_sinfos = GetInputDTensorStructInfo(call, ctx);
   ICHECK(input_dtensor_sinfos.size() == 1);
   DTensorStructInfo input_dtensor_sinfo = input_dtensor_sinfos[0];
   TensorStructInfo tensor_sinfo = input_dtensor_sinfo->tensor_sinfo;
@@ -212,7 +212,7 @@ StructInfo InferDistStructInfoRtoS(const Call& call, const BlockBuilder& ctx) {
 }
 
 Expr redistribute_replica_to_shard(Expr input, int num_workers, int axis) {
-  ObjectPtr<ScatterCollectiveAttrs> attrs = make_object<ScatterCollectiveAttrs>();
+  ObjectPtr<ScatterCollectiveAttrs> attrs = ffi::make_object<ScatterCollectiveAttrs>();
   attrs->num_workers = std::move(num_workers);
   attrs->axis = std::move(axis);
   static const Op& op = Op::Get("relax.dist.redistribute_replica_to_shard");
diff --git a/src/relax/op/distributed/linear_algebra.cc b/src/relax/op/distributed/linear_algebra.cc
index 727b52c462ec..8fc9cd58d1fc 100644
--- a/src/relax/op/distributed/linear_algebra.cc
+++ b/src/relax/op/distributed/linear_algebra.cc
@@ -25,7 +25,8 @@ namespace relax {
 namespace distributed {
 
 StructInfo InferDistStructInfoMatmul(const Call& call, const BlockBuilder& ctx) {
-  Array<distributed::DTensorStructInfo> input_dtensor_sinfos = GetInputDTensorStructInfo(call, ctx);
+  ffi::Array<distributed::DTensorStructInfo> input_dtensor_sinfos =
+      GetInputDTensorStructInfo(call, ctx);
   TensorStructInfo x1_sinfo, x2_sinfo;
   x1_sinfo = input_dtensor_sinfos[0]->tensor_sinfo;
   x2_sinfo = input_dtensor_sinfos[1]->tensor_sinfo;
@@ -67,11 +68,11 @@ StructInfo InferDistStructInfoMatmul(const Call& call, const BlockBuilder& ctx)
     ctx->ReportFatal(Diagnostic::Error(call) << "input of distributed operator must have shape");
   }
 
-  Array<PrimExpr> x1_shape_prefix{x1_shape->values.begin(),
-                                  x1_shape->values.end() - 2 + x1_prepended};
-  Array<PrimExpr> x2_shape_prefix{x2_shape->values.begin(),
-                                  x2_shape->values.end() - 2 + x2_appended};
-  Optional<Array<PrimExpr>> output_shape_prefix =
+  ffi::Array<PrimExpr> x1_shape_prefix{x1_shape->values.begin(),
+                                       x1_shape->values.end() - 2 + x1_prepended};
+  ffi::Array<PrimExpr> x2_shape_prefix{x2_shape->values.begin(),
+                                       x2_shape->values.end() - 2 + x2_appended};
+  ffi::Optional<ffi::Array<PrimExpr>> output_shape_prefix =
       InferBinaryBroadcastShape(call, ctx, x1_shape_prefix, x2_shape_prefix);
   ICHECK(output_shape_prefix.defined()) << "Failed to infer output shape of Matmul";
   arith::Analyzer* analyzer = ctx->GetAnalyzer();
@@ -84,7 +85,7 @@ StructInfo InferDistStructInfoMatmul(const Call& call, const BlockBuilder& ctx)
                      << x1_reduction_length << " and " << x2_reduction_length << " respectively.");
   }
 
-  Array<PrimExpr> output_shape = output_shape_prefix.value();
+  ffi::Array<PrimExpr> output_shape = output_shape_prefix.value();
   if (!x1_prepended) {
     output_shape.push_back(x1_shape->values[x1_ndim - 2]);
   }
diff --git a/src/relax/op/distributed/manipulate.cc b/src/relax/op/distributed/manipulate.cc
index 8b18b9578eda..edd5fa7ee7f9 100644
--- a/src/relax/op/distributed/manipulate.cc
+++ b/src/relax/op/distributed/manipulate.cc
@@ -29,7 +29,8 @@ namespace relax {
 namespace distributed {
 
 StructInfo InferDistStructInfoPermuteDims(const Call& call, const BlockBuilder& ctx) {
-  Array<distributed::DTensorStructInfo> input_dtensor_sinfos = GetInputDTensorStructInfo(call, ctx);
+  ffi::Array<distributed::DTensorStructInfo> input_dtensor_sinfos =
+      GetInputDTensorStructInfo(call, ctx);
   TensorStructInfo data_sinfo = input_dtensor_sinfos[0]->tensor_sinfo;
 
   const auto* attrs = call->attrs.as<PermuteDimsAttrs>();
@@ -84,7 +85,8 @@ StructInfo InferDistStructInfoReshape(const Call& call, const BlockBuilder& ctx)
   if (call->args.size() != 2) {
     ctx->ReportFatal(Diagnostic::Error(call) << "Reshape op should take 2 arguments");
   }
-  Array<distributed::DTensorStructInfo> input_dtensor_sinfos = GetInputDTensorStructInfo(call, ctx);
+  ffi::Array<distributed::DTensorStructInfo> input_dtensor_sinfos =
+      GetInputDTensorStructInfo(call, ctx);
   TensorStructInfo data_sinfo = input_dtensor_sinfos[0]->tensor_sinfo;
   const auto* new_shape_sinfo = GetStructInfoAs<ShapeStructInfoNode>(call->args[1]);
@@ -100,7 +102,7 @@ StructInfo InferDistStructInfoReshape(const Call& call, const BlockBuilder& ctx)
                      << call->args[1]->struct_info_->GetTypeKey());
   }
 
-  Optional<Array<PrimExpr>> old_shape_values;
+  ffi::Optional<ffi::Array<PrimExpr>> old_shape_values;
   if (data_sinfo->shape.defined()) {
     const auto* old_shape_sinfo = GetStructInfoAs<ShapeStructInfoNode>(data_sinfo->shape.value());
     ICHECK_NOTNULL(old_shape_sinfo);
diff --git a/src/relax/op/distributed/nn.cc b/src/relax/op/distributed/nn.cc
index ec0bdaeb3242..b020d7902f9b 100644
--- a/src/relax/op/distributed/nn.cc
+++ b/src/relax/op/distributed/nn.cc
@@ -24,7 +24,8 @@ namespace relax {
 namespace distributed {
 
 StructInfo InferDistStructInfoSoftmax(const Call& call, const BlockBuilder& ctx) {
-  Array<distributed::DTensorStructInfo> input_dtensor_sinfos = GetInputDTensorStructInfo(call, ctx);
+  ffi::Array<distributed::DTensorStructInfo> input_dtensor_sinfos =
+      GetInputDTensorStructInfo(call, ctx);
   ICHECK(input_dtensor_sinfos.size() == 1);
   TensorStructInfo input_tensor_sinfo = input_dtensor_sinfos[0]->tensor_sinfo;
diff --git a/src/relax/op/distributed/statistical.cc b/src/relax/op/distributed/statistical.cc
index 3bd0f0651718..44ee90e78976 100644
--- a/src/relax/op/distributed/statistical.cc
+++ b/src/relax/op/distributed/statistical.cc
@@ -25,7 +25,8 @@ namespace relax {
 namespace distributed {
 
 StructInfo InferDistStructInfoStatistical(const Call& call, const BlockBuilder& ctx) {
-  Array<distributed::DTensorStructInfo> input_dtensor_sinfos = GetInputDTensorStructInfo(call, ctx);
+  ffi::Array<distributed::DTensorStructInfo> input_dtensor_sinfos =
+      GetInputDTensorStructInfo(call, ctx);
   TensorStructInfo data_sinfo = input_dtensor_sinfos[0]->tensor_sinfo;
 
   const auto* attrs = call->attrs.as<StatisticalAttrs>();
@@ -60,7 +61,7 @@ StructInfo InferDistStructInfoStatistical(const Call& call, const BlockBuilder&
     ctx->ReportFatal(Diagnostic::Error(call) << "Input of distributed operator must be known shape");
   }
 
-  Array<PrimExpr> out_shape;
+  ffi::Array<PrimExpr> out_shape;
   out_shape.reserve(out_ndim);
   for (int i = 0; i < data_sinfo->ndim; ++i) {
     if (attrs->axis.defined() && std::find(axes.begin(), axes.end(), i) == axes.end()) {
diff --git a/src/relax/op/distributed/unary.h b/src/relax/op/distributed/unary.h
index cfde689421f7..727707a98525 100644
--- a/src/relax/op/distributed/unary.h
+++ b/src/relax/op/distributed/unary.h
@@ -34,7 +34,8 @@ namespace distributed {
 template <typename FType>
 StructInfo InferDistStructInfoUnary(const Call& call, const BlockBuilder& ctx,
                                     FType f_compute_out_dtype) {
-  Array<distributed::DTensorStructInfo> input_dtensor_sinfos = GetInputDTensorStructInfo(call, ctx);
+  ffi::Array<distributed::DTensorStructInfo> input_dtensor_sinfos =
+      GetInputDTensorStructInfo(call, ctx);
   ICHECK(input_dtensor_sinfos.size() == 1);
   distributed::DTensorStructInfo input_dtensor_sinfo = input_dtensor_sinfos[0];
   TensorStructInfo input_tensor_sinfo = input_dtensor_sinfo->tensor_sinfo;
@@ -47,7 +48,7 @@ StructInfo InferDistStructInfoUnary(const Call& call, const BlockBuilder& ctx,
                      << " requires the input tensor to have float dtype. However, the given input "
                      << "dtype is " << input_tensor_sinfo->dtype);
   }
-  auto output_sinfo = make_object<TensorStructInfoNode>(*input_tensor_sinfo.get());
+  auto output_sinfo = ffi::make_object<TensorStructInfoNode>(*input_tensor_sinfo.get());
   output_sinfo->dtype = f_compute_out_dtype(input_tensor_sinfo);
   TensorStructInfo out_tensor_sinfo(output_sinfo);
   return distributed::DTensorStructInfo(out_tensor_sinfo, input_dtensor_sinfo->device_mesh,
diff --git a/src/relax/op/distributed/utils.cc b/src/relax/op/distributed/utils.cc
index 39bdeea037c5..ffa7dbfa3085 100644
--- a/src/relax/op/distributed/utils.cc
+++ b/src/relax/op/distributed/utils.cc
@@ -24,16 +24,16 @@ namespace tvm {
 namespace relax {
 namespace distributed {
 
-Array<DTensorStructInfo> GetInputDTensorStructInfo(const Call& call,
-                                                   const BlockBuilder& ctx) {
+ffi::Array<DTensorStructInfo> GetInputDTensorStructInfo(const Call& call,
+                                                        const BlockBuilder& ctx) {
   Op op = Downcast<Op>(call->op);
-  Array<Expr> args = GetCallArgs(call);
-  Array<DTensorStructInfo> input_tensor_sinfo;
+  ffi::Array<Expr> args = GetCallArgs(call);
+  ffi::Array<DTensorStructInfo> input_tensor_sinfo;
   input_tensor_sinfo.reserve(args.size());
   for (const Expr& arg : args) {
     const auto* sinfo = GetStructInfoAs<DTensorStructInfoNode>(arg);
     if (sinfo != nullptr) {
-      input_tensor_sinfo.push_back(GetRef<DTensorStructInfo>(sinfo));
+      input_tensor_sinfo.push_back(ffi::GetRef<DTensorStructInfo>(sinfo));
     }
   }
   return input_tensor_sinfo;
@@ -42,7 +42,8 @@ Array<DTensorStructInfo> GetInputDTensorStructInfo(const Call& call
 
 StructInfo InferShardingSpec(const Call& call, const BlockBuilder& ctx,
                              const StructInfo& orig_output_sinfo,
                              distributed::FBuildAxisGraph f_build_graph) {
-  Array<DTensorStructInfo> input_dtensor_sinfos = GetInputDTensorStructInfo(call, ctx);
+  ffi::Array<DTensorStructInfo> input_dtensor_sinfos =
+      GetInputDTensorStructInfo(call, ctx);
   for (int i = 1; i < static_cast<int>(input_dtensor_sinfos.size()); i++) {
     ICHECK(StructuralEqual()(input_dtensor_sinfos[0]->device_mesh,
                              input_dtensor_sinfos[i]->device_mesh));
@@ -51,7 +52,7 @@ StructInfo InferShardingSpec(const Call& call, const BlockBuilder& ctx,
   Var output_var("output", orig_output_sinfo);
   distributed::AxisGroupGraph axis_group_graph;
   f_build_graph(output_var, call, &axis_group_graph);
-  Array<Expr> args = GetCallArgs(call);
+  ffi::Array<Expr> args = GetCallArgs(call);
   int n_input_var = input_dtensor_sinfos.size();
   for (int i = 0; i < n_input_var; i++) {
     distributed::DTensorStructInfo dtensor_sinfo = input_dtensor_sinfos[i];
@@ -66,9 +67,9 @@ StructInfo InferShardingSpec(const Call& call, const BlockBuilder& ctx,
     }
   }
   axis_group_graph.PropagateShardingSpec();
-  Array<TensorStructInfo> orig_output_tensor_sinfos;
+  ffi::Array<TensorStructInfo> orig_output_tensor_sinfos;
   if (const auto* tensor_sinfo = orig_output_sinfo.as<TensorStructInfoNode>()) {
-    orig_output_tensor_sinfos.push_back(GetRef<TensorStructInfo>(tensor_sinfo));
+    orig_output_tensor_sinfos.push_back(ffi::GetRef<TensorStructInfo>(tensor_sinfo));
   } else {
     const auto* tuple_sinfo = orig_output_sinfo.as<TupleStructInfoNode>();
     ICHECK(tuple_sinfo);
@@ -76,9 +77,9 @@ StructInfo InferShardingSpec(const Call& call, const BlockBuilder& ctx,
       orig_output_tensor_sinfos.push_back(Downcast<TensorStructInfo>(sinfo));
     }
   }
-  Array<DTensorStructInfo> new_output_dtensor_sinfos;
+  ffi::Array<DTensorStructInfo> new_output_dtensor_sinfos;
   for (int idx = 0; idx < static_cast<int>(orig_output_tensor_sinfos.size()); idx++) {
-    Array<distributed::PlacementSpec> output_placement_specs(
+    ffi::Array<distributed::PlacementSpec> output_placement_specs(
         std::vector<distributed::PlacementSpec>(device_mesh->shape.size(),
                                                 distributed::PlacementSpec::Replica()));
     for (int i = 0; i < orig_output_tensor_sinfos[idx]->ndim; i++) {
diff --git a/src/relax/op/distributed/utils.h b/src/relax/op/distributed/utils.h
index 1656df286784..125a2d242ba5 100644
--- a/src/relax/op/distributed/utils.h
+++ b/src/relax/op/distributed/utils.h
@@ -42,8 +42,8 @@ namespace distributed {
  * \return The dtensor struct info of each input.
  * \note This function require every input tensor to be DTensor.
  */
-Array<DTensorStructInfo> GetInputDTensorStructInfo(const Call& call,
-                                                   const BlockBuilder& ctx);
+ffi::Array<DTensorStructInfo> GetInputDTensorStructInfo(const Call& call,
+                                                        const BlockBuilder& ctx);
 
 /*!
  * \brief Perform a local sharding spec propagation to infer the output dtensor
diff --git a/src/relax/op/image/resize.cc b/src/relax/op/image/resize.cc
index f6923ecb3ab4..e0aba16d8311 100644
--- a/src/relax/op/image/resize.cc
+++ b/src/relax/op/image/resize.cc
@@ -35,10 +35,11 @@ TVM_FFI_STATIC_INIT_BLOCK({ Resize2DAttrs::RegisterReflection(); });
 
 /* relax.resize2d */
-Expr resize2d(Expr data, Expr size, Array<FloatImm> roi, String layout, String method,
-              String coordinate_transformation_mode, String rounding_method, double cubic_alpha,
-              int cubic_exclude, double extrapolation_value, Optional<DataType> out_dtype) {
-  ObjectPtr<Resize2DAttrs> attrs = make_object<Resize2DAttrs>();
+Expr resize2d(Expr data, Expr size, ffi::Array<FloatImm> roi, ffi::String layout,
+              ffi::String method, ffi::String coordinate_transformation_mode,
+              ffi::String rounding_method, double cubic_alpha, int cubic_exclude,
+              double extrapolation_value, ffi::Optional<DataType> out_dtype) {
+  ObjectPtr<Resize2DAttrs> attrs = ffi::make_object<Resize2DAttrs>();
   attrs->roi = std::move(roi);
   attrs->layout = std::move(layout);
   attrs->method = std::move(method);
@@ -93,30 +94,30 @@ StructInfo InferStructInfoResize2D(const Call& call, const BlockBuilder& ctx) {
 
   DataType out_dtype = attrs->out_dtype.is_void() ? data_sinfo->dtype : attrs->out_dtype;
 
-  Optional<ShapeStructInfo> data_shape =
-      CheckNdimPerLayoutAndGetShape(call, ctx, GetRef<TensorStructInfo>(data_sinfo), data_layout);
+  ffi::Optional<ShapeStructInfo> data_shape = CheckNdimPerLayoutAndGetShape(
+      call, ctx, ffi::GetRef<TensorStructInfo>(data_sinfo), data_layout);
   if (!data_shape.defined() || size_value == nullptr) {
     return TensorStructInfo(out_dtype, data_layout.ndim(), data_sinfo->vdevice);
   }
 
-  Array<PrimExpr> data_NCHW_shape = data2NCHW.ForwardShape(data_shape.value()->values);
-  Array<PrimExpr> out_NCHW_shape(data_NCHW_shape);
+  ffi::Array<PrimExpr> data_NCHW_shape = data2NCHW.ForwardShape(data_shape.value()->values);
+  ffi::Array<PrimExpr> out_NCHW_shape(data_NCHW_shape);
   out_NCHW_shape.Set(2, size_value->values[0]);
   out_NCHW_shape.Set(3, size_value->values[1]);
 
-  Array<PrimExpr> out_shape = data2NCHW.BackwardShape(out_NCHW_shape);
+  ffi::Array<PrimExpr> out_shape = data2NCHW.BackwardShape(out_NCHW_shape);
   return TensorStructInfo(ShapeExpr(out_shape), out_dtype, data_sinfo->vdevice);
 }
 
-InferLayoutOutput InferLayoutResize2d(const Call& call,
-                                      const Map<String, Array<String>>& desired_layouts,
-                                      const VarLayoutMap& var_layout_map) {
+InferLayoutOutput InferLayoutResize2d(
+    const Call& call, const ffi::Map<ffi::String, ffi::Array<ffi::String>>& desired_layouts,
+    const VarLayoutMap& var_layout_map) {
   const auto& it = desired_layouts.find("relax.image.resize2d");
   const auto* attrs = call->attrs.as<Resize2DAttrs>();
   ICHECK(attrs) << "Invalid Call";
 
   LayoutDecision data_layout;
-  ObjectPtr<Resize2DAttrs> new_attrs = make_object<Resize2DAttrs>(*attrs);
+  ObjectPtr<Resize2DAttrs> new_attrs = ffi::make_object<Resize2DAttrs>(*attrs);
 
   if (it != desired_layouts.end()) {
     // We have a desired layout for resize2d.
diff --git a/src/relax/op/image/resize.h b/src/relax/op/image/resize.h
index 3af171c7bfff..5125a17804a8 100644
--- a/src/relax/op/image/resize.h
+++ b/src/relax/op/image/resize.h
@@ -33,9 +33,10 @@ namespace tvm {
 namespace relax {
 
 /*! \brief Image resize2d operator. */
-Expr resize2d(Expr data, Expr size, Array<FloatImm> roi, String layout, String method,
-              String coordinate_transformation_mode, String rounding_method, double cubic_alpha,
-              int cubic_exclude, double extrapolation_value, Optional<DataType> out_dtype);
+Expr resize2d(Expr data, Expr size, ffi::Array<FloatImm> roi, ffi::String layout,
+              ffi::String method, ffi::String coordinate_transformation_mode,
+              ffi::String rounding_method, double cubic_alpha, int cubic_exclude,
+              double extrapolation_value, ffi::Optional<DataType> out_dtype);
 
 }  // namespace relax
 }  // namespace tvm
diff --git a/src/relax/op/memory/view.cc b/src/relax/op/memory/view.cc
index 87f6864824ae..5c7fc47057d7 100644
--- a/src/relax/op/memory/view.cc
+++ b/src/relax/op/memory/view.cc
@@ -30,8 +30,9 @@ namespace tvm {
 namespace relax {
 
 /* relax.op.memory.view */
-Expr view(Expr x, Optional<Expr> shape, Optional<DataType> dtype, Optional<Expr> relative_byte_offset) {
-  Tuple void_expr(Array<Expr>{});
+Expr view(Expr x, ffi::Optional<Expr> shape, ffi::Optional<DataType> dtype,
+          ffi::Optional<Expr> relative_byte_offset) {
+  Tuple void_expr(ffi::Array<Expr>{});
   static const Op& op = Op::Get("relax.memory.view");
   return Call(op, {
@@ -123,7 +124,7 @@ StructInfo InferStructInfoView(const Call& call, const BlockBuilder& ctx) {
     }
   }();
 
-  auto view_relative_byte_offset = [&]() -> Optional<PrimExpr> {
+  auto view_relative_byte_offset = [&]() -> ffi::Optional<PrimExpr> {
     StructInfo sinfo = GetStructInfo(arg_relative_byte_offset);
 
     if (HasVoidStructInfo(arg_relative_byte_offset)) {
@@ -152,9 +153,9 @@ StructInfo InferStructInfoView(const Call& call, const BlockBuilder& ctx) {
     }
   }();
 
-  Optional<Array<PrimExpr>> input_shape = data_sinfo->GetShape();
+  ffi::Optional<ffi::Array<PrimExpr>> input_shape = data_sinfo->GetShape();
 
-  Optional<Array<PrimExpr>> output_shape = std::nullopt;
+  ffi::Optional<ffi::Array<PrimExpr>> output_shape = std::nullopt;
   int output_ndim = kUnknownNDim;
   if (view_shape_sinfo && view_shape_sinfo->values.defined()) {
     output_shape = view_shape_sinfo->values.value();
@@ -171,7 +172,7 @@ StructInfo InferStructInfoView(const Call& call, const BlockBuilder& ctx) {
   // Helper function, returns the number of bytes per vectorized
   // element.  Cannot use `DataType::bytes`, as it returns the
   // number of bytes per scalar element.
-  auto get_size_bytes = [](const DataType& dtype) -> Optional<PrimExpr> {
+  auto get_size_bytes = [](const DataType& dtype) -> ffi::Optional<PrimExpr> {
     if (dtype.is_void()) {
       return std::nullopt;
     } else {
@@ -182,7 +183,8 @@ StructInfo InferStructInfoView(const Call& call, const BlockBuilder& ctx) {
 
   // Helper function, returns the number of elements in an array,
   // given the shape of that array.
-  auto get_num_elements = [&ctx](const Optional<Array<PrimExpr>>& shape) -> Optional<PrimExpr> {
+  auto get_num_elements =
+      [&ctx](const ffi::Optional<ffi::Array<PrimExpr>>& shape) -> ffi::Optional<PrimExpr> {
     if (!shape.defined()) {
       return std::nullopt;
     }
@@ -194,11 +196,11 @@ StructInfo InferStructInfoView(const Call& call, const BlockBuilder& ctx) {
     return ctx->GetAnalyzer()->Simplify(num_elements);
   };
 
-  Optional<PrimExpr> input_nelements = get_num_elements(input_shape);
-  Optional<PrimExpr> output_nelements = get_num_elements(output_shape);
+  ffi::Optional<PrimExpr> input_nelements = get_num_elements(input_shape);
+  ffi::Optional<PrimExpr> output_nelements = get_num_elements(output_shape);
 
-  Optional<PrimExpr> input_element_size = get_size_bytes(data_sinfo->dtype);
-  Optional<PrimExpr> output_element_size = get_size_bytes(output_dtype);
+  ffi::Optional<PrimExpr> input_element_size = get_size_bytes(data_sinfo->dtype);
+  ffi::Optional<PrimExpr> output_element_size = get_size_bytes(output_dtype);
 
   if (input_nelements && output_nelements && input_element_size && output_element_size &&
       view_relative_byte_offset) {
diff --git a/src/relax/op/memory/view.h b/src/relax/op/memory/view.h
index 77ec7e9833cc..6c23ef7b27a2 100644
--- a/src/relax/op/memory/view.h
+++ b/src/relax/op/memory/view.h
@@ -30,7 +30,8 @@ namespace tvm {
 namespace relax {
 
 /*! \brief View a tensor with different properties. */
-Expr view(Expr x, Optional<Expr> shape, Optional<DataType> dtype, Optional<Expr> relative_byte_offset);
+Expr view(Expr x, ffi::Optional<Expr> shape, ffi::Optional<DataType> dtype,
+          ffi::Optional<Expr> relative_byte_offset);
 
 /*! \brief Ensure the tensor has elem_offset == 0. A copy will be made if necessary. */
 Expr ensure_aligned(const Expr& x);
diff --git a/src/relax/op/nn/attention.cc b/src/relax/op/nn/attention.cc
index 916fa2f39f33..288214cebb6b 100644
--- a/src/relax/op/nn/attention.cc
+++ b/src/relax/op/nn/attention.cc
@@ -28,9 +28,10 @@ namespace relax {
 
 /* relax.nn.attention */
 
-Expr attention(Expr query, Expr key, Expr value, Optional<Expr> bias, Optional<FloatImm> scale,
-               Optional<String> causal_mask, Optional<IntImm> window_size) {
-  ObjectPtr<AttentionAttrs> attrs = make_object<AttentionAttrs>();
+Expr attention(Expr query, Expr key, Expr value, ffi::Optional<Expr> bias,
+               ffi::Optional<FloatImm> scale, ffi::Optional<ffi::String> causal_mask,
+               ffi::Optional<IntImm> window_size) {
+  ObjectPtr<AttentionAttrs> attrs = ffi::make_object<AttentionAttrs>();
   attrs->scale = scale;
   attrs->causal_mask = causal_mask;
   attrs->window_size = window_size;
@@ -45,9 +46,9 @@ Expr attention(Expr query, Expr key, Expr value, Optional<Expr> bias, Optional<F
-                       Expr max_seqlen_q, Expr max_seqlen_k, Optional<FloatImm> scale,
-                       Optional<String> causal_mask, Optional<IntImm> window_size) {
-  ObjectPtr<AttentionAttrs> attrs = make_object<AttentionAttrs>();
+                       Expr max_seqlen_q, Expr max_seqlen_k, ffi::Optional<FloatImm> scale,
+                       ffi::Optional<ffi::String> causal_mask, ffi::Optional<IntImm> window_size) {
+  ObjectPtr<AttentionAttrs> attrs = ffi::make_object<AttentionAttrs>();
   attrs->scale = scale;
   attrs->causal_mask = causal_mask;
   attrs->window_size = window_size;
@@ -65,11 +66,11 @@ TVM_FFI_STATIC_INIT_BLOCK({
 });
 
 StructInfo InferStructInfoAttention(const Call& call, const BlockBuilder& ctx) {
-  Array<TensorStructInfo> input_sinfo = GetInputTensorStructInfo(call, ctx);
+  ffi::Array<TensorStructInfo> input_sinfo = GetInputTensorStructInfo(call, ctx);
   TensorStructInfo q_sinfo = input_sinfo[0];
   TensorStructInfo k_sinfo = input_sinfo[1];
   TensorStructInfo v_sinfo = input_sinfo[2];
-  auto diag_dim = [&](TensorStructInfo sinfo, String name) {
+  auto diag_dim = [&](TensorStructInfo sinfo, ffi::String name) {
     if (sinfo->ndim != 4) {
       ctx->ReportFatal(Diagnostic::Error(call)
                        << "The " << name << " should have 4 dimension, namely "
@@ -89,7 +90,7 @@ StructInfo InferStructInfoAttention(const Call& call, const BlockBuilder& ctx) {
   PrimExpr num_keys = k_shape->values[1];
   PrimExpr head_dim_value = v_shape->values[3];
   arith::Analyzer* analyzer = ctx->GetAnalyzer();
-  auto diag_equal = [&](PrimExpr v1, PrimExpr v2, String m1, String m2, String dim) {
+  auto diag_equal = [&](PrimExpr v1, PrimExpr v2, ffi::String m1, ffi::String m2,
+                        ffi::String dim) {
     if (analyzer->CanProve(v1 != v2)) {
       ctx->ReportFatal(Diagnostic::Error(call)
                        << "The " << m1 << " " << dim << " and the " << m2 << " " << dim
@@ -97,7 +98,8 @@ StructInfo InferStructInfoAttention(const Call& call, const BlockBuilder& ctx) {
                        << v1 << " while the " << dim << " of " << m2 << " is " << v2);
     }
   };
-  auto multiple_of = [&](PrimExpr v1, PrimExpr v2, String m1, String m2, String dim) {
+  auto multiple_of = [&](PrimExpr v1, PrimExpr v2, ffi::String m1, ffi::String m2,
+                         ffi::String dim) {
     if (analyzer->CanProve(indexmod(v1, v2) != 0)) {
       ctx->ReportFatal(Diagnostic::Error(call)
                        << "The " << m1 << " " << dim << " should be a multiple of " << m2 << " "
@@ -121,7 +123,8 @@ StructInfo InferStructInfoAttention(const Call& call, const BlockBuilder& ctx) {
                        << "The bias should have 4 dimensions."
                        << "However, the bias input has " << bias_sinfo->ndim << " dimensions.");
   }
-  auto diag_equal_or_broadcast = [&](PrimExpr v1, PrimExpr v2, String m1, String m2, String dim) {
+  auto diag_equal_or_broadcast = [&](PrimExpr v1, PrimExpr v2, ffi::String m1, ffi::String m2,
+                                     ffi::String dim) {
     if (analyzer->CanProve(v1 != v2) && !tir::is_one(v2)) {
       ctx->ReportFatal(Diagnostic::Error(call)
                        << "The " << m1 << " " << dim << " and the " << m2 << " " << dim
@@ -136,7 +139,7 @@ StructInfo InferStructInfoAttention(const Call& call, const BlockBuilder& ctx) {
     diag_equal(num_keys, bias_shape->values[3], "key", "bias", "sequence length");
   }
 
-  Array<PrimExpr> output_shape = {num_batches, num_queries, num_heads, head_dim_value};
+  ffi::Array<PrimExpr> output_shape = {num_batches, num_queries, num_heads, head_dim_value};
   return TensorStructInfo(ShapeExpr(output_shape), q_sinfo->dtype, q_sinfo->vdevice);
 }
diff --git a/src/relax/op/nn/attention.h b/src/relax/op/nn/attention.h
index 346907f8e938..f4fe8ad88fd4 100644
--- a/src/relax/op/nn/attention.h
+++ b/src/relax/op/nn/attention.h
@@ -33,8 +33,9 @@ namespace tvm {
 namespace relax {
 
 /*! \brief fused multi head attention */
-Expr attention(Expr query, Expr key, Expr value, Optional<Expr> bias, Optional<FloatImm> scale,
-               Optional<String> causal_mask, Optional<IntImm> window_size);
+Expr attention(Expr query, Expr key, Expr value, ffi::Optional<Expr> bias,
+               ffi::Optional<FloatImm> scale, ffi::Optional<ffi::String> causal_mask,
+               ffi::Optional<IntImm> window_size);
 
 }  // namespace relax
 }  // namespace tvm
diff --git a/src/relax/op/nn/convolution.cc b/src/relax/op/nn/convolution.cc
index 7346af3b1c98..b8cf8b95ee46 100644
--- a/src/relax/op/nn/convolution.cc
+++ b/src/relax/op/nn/convolution.cc
@@ -41,9 +41,10 @@ TVM_FFI_STATIC_INIT_BLOCK({
 
 /* relax.nn.conv1d */
 
-Expr conv1d(Expr data, Expr weight, Array<IntImm> strides, Array<IntImm> padding,
-            Array<IntImm> dilation, int groups, String data_layout, String kernel_layout,
-            Optional<String> out_layout, Optional<DataType> out_dtype) {
+Expr conv1d(Expr data, Expr weight, ffi::Array<IntImm> strides, ffi::Array<IntImm> padding,
+            ffi::Array<IntImm> dilation, int groups, ffi::String data_layout,
+            ffi::String kernel_layout, ffi::Optional<ffi::String> out_layout,
+            ffi::Optional<DataType> out_dtype) {
   padding = GetCompletePadding1D(std::move(padding));
 
   CHECK_GT(groups, 0) << "The number of groups in convolution is expected to be positive. However, "
@@ -66,7 +67,7 @@ TVM_FFI_STATIC_INIT_BLOCK({
 });
 
 StructInfo InferStructInfoConv1d(const Call& call, const BlockBuilder& ctx) {
-  Array<TensorStructInfo> input_sinfo = GetInputTensorStructInfo(call, ctx);
+  ffi::Array<TensorStructInfo> input_sinfo = GetInputTensorStructInfo(call, ctx);
   TensorStructInfo data_sinfo = input_sinfo[0];
   TensorStructInfo weight_sinfo = input_sinfo[1];
@@ -81,21 +82,22 @@ StructInfo InferStructInfoConv1d(const Call& call, const BlockBuilder& ctx) {
                                                /*tgt_layout=*/"NCW",  //
                                                /*tensor_name=*/"output");
 
-  Optional<ShapeStructInfo> data_shape =
+  ffi::Optional<ShapeStructInfo> data_shape =
       CheckNdimPerLayoutAndGetShape(call, ctx, data_sinfo, data_layout);
-  Optional<ShapeStructInfo> weight_shape =
+  ffi::Optional<ShapeStructInfo> weight_shape =
       CheckNdimPerLayoutAndGetShape(call, ctx, weight_sinfo, weight_layout);
 
   DataType out_dtype = attrs->out_dtype.is_void()
                            ? InferBinaryArithOpOutDtype(call, ctx, data_sinfo, weight_sinfo)
                            : attrs->out_dtype;
-  Optional<VDevice> vdevice = InferBinaryArithOpOutVDevice(call, ctx, data_sinfo, weight_sinfo);
+  ffi::Optional<VDevice> vdevice =
+      InferBinaryArithOpOutVDevice(call, ctx, data_sinfo, weight_sinfo);
   if (!data_shape.defined() || !weight_shape.defined()) {
     return TensorStructInfo(out_dtype, out_layout.ndim(), vdevice);
   }
 
-  Array<PrimExpr> data_NCW_shape = data2NCW.ForwardShape(data_shape.value()->values);
-  Array<PrimExpr> weight_OIW_shape = weight2OIW.ForwardShape(weight_shape.value()->values);
+  ffi::Array<PrimExpr> data_NCW_shape = data2NCW.ForwardShape(data_shape.value()->values);
+  ffi::Array<PrimExpr> weight_OIW_shape = weight2OIW.ForwardShape(weight_shape.value()->values);
 
   arith::Analyzer* analyzer = ctx->GetAnalyzer();
   PrimExpr input_channel_data = data_NCW_shape[1];
@@ -133,19 +135,19 @@ StructInfo InferStructInfoConv1d(const Call& call, const BlockBuilder& ctx) {
   PrimExpr numerator_w = input_w + padding_w - attrs->dilation[0] * (kernel_w - 1) - 1;
   out_NCW_shape[2] = analyzer->Simplify(floordiv(numerator_w, attrs->strides[0]) + 1);
 
-  Array<PrimExpr> out_shape = out2NCW.BackwardShape(out_NCW_shape);
+  ffi::Array<PrimExpr> out_shape = out2NCW.BackwardShape(out_NCW_shape);
   return TensorStructInfo(ShapeExpr(out_shape), out_dtype, vdevice);
 }
 
-InferLayoutOutput InferLayoutConv1d(const Call& call,
-                                    const Map<String, Array<String>>& desired_layouts,
-                                    const VarLayoutMap& var_layout_map) {
+InferLayoutOutput InferLayoutConv1d(
+    const Call& call, const ffi::Map<ffi::String, ffi::Array<ffi::String>>& desired_layouts,
+    const VarLayoutMap& var_layout_map) {
   const auto& it = desired_layouts.find("relax.nn.conv1d");
   const auto* attrs = call->attrs.as<Conv1DAttrs>();
   ICHECK(attrs) << "Invalid Call";
 
   LayoutDecision data_layout, weight_layout, output_layout;
-  ObjectPtr<Conv1DAttrs> new_attrs = make_object<Conv1DAttrs>(*attrs);
+  ObjectPtr<Conv1DAttrs> new_attrs = ffi::make_object<Conv1DAttrs>(*attrs);
 
   if (it != desired_layouts.end()) {
     // We have a desired layout for conv1d.
@@ -200,9 +202,10 @@ TVM_REGISTER_OP("relax.nn.conv1d")
 
 /* relax.nn.conv2d */
 
-Expr conv2d(Expr data, Expr weight, Array<IntImm> strides, Array<IntImm> padding,
-            Array<IntImm> dilation, int groups, String data_layout, String kernel_layout,
-            Optional<String> out_layout, Optional<DataType> out_dtype) {
+Expr conv2d(Expr data, Expr weight, ffi::Array<IntImm> strides, ffi::Array<IntImm> padding,
+            ffi::Array<IntImm> dilation, int groups, ffi::String data_layout,
+            ffi::String kernel_layout, ffi::Optional<ffi::String> out_layout,
+            ffi::Optional<DataType> out_dtype) {
   padding = GetCompletePadding2D(std::move(padding));
   if (strides.size() == 1) {
     strides.push_back(strides[0]);
@@ -231,7 +234,7 @@ TVM_FFI_STATIC_INIT_BLOCK({
 });
 
 StructInfo InferStructInfoConv2d(const Call& call, const BlockBuilder& ctx) {
-  Array<TensorStructInfo> input_sinfo = GetInputTensorStructInfo(call, ctx);
+  ffi::Array<TensorStructInfo> input_sinfo = GetInputTensorStructInfo(call, ctx);
   TensorStructInfo data_sinfo = input_sinfo[0];
   TensorStructInfo weight_sinfo = input_sinfo[1];
@@ -246,21 +249,22 @@ StructInfo InferStructInfoConv2d(const Call& call, const BlockBuilder& ctx) {
                                                 /*tgt_layout=*/"NCHW",  //
                                                 /*tensor_name=*/"output");
 
-  Optional<ShapeStructInfo> data_shape =
+  ffi::Optional<ShapeStructInfo> data_shape =
       CheckNdimPerLayoutAndGetShape(call, ctx, data_sinfo, data_layout);
-  Optional<ShapeStructInfo> weight_shape =
+  ffi::Optional<ShapeStructInfo> weight_shape =
       CheckNdimPerLayoutAndGetShape(call, ctx, weight_sinfo, weight_layout);
 
   DataType out_dtype = attrs->out_dtype.is_void()
                            ? InferBinaryArithOpOutDtype(call, ctx, data_sinfo, weight_sinfo)
                            : attrs->out_dtype;
-  Optional<VDevice> vdevice = InferBinaryArithOpOutVDevice(call, ctx, data_sinfo, weight_sinfo);
+  ffi::Optional<VDevice> vdevice =
+      InferBinaryArithOpOutVDevice(call, ctx, data_sinfo, weight_sinfo);
   if (!data_shape.defined() || !weight_shape.defined()) {
     return TensorStructInfo(out_dtype, out_layout.ndim(), vdevice);
   }
 
-  Array<PrimExpr> data_NCHW_shape = data2NCHW.ForwardShape(data_shape.value()->values);
-  Array<PrimExpr> weight_OIHW_shape = weight2OIHW.ForwardShape(weight_shape.value()->values);
+  ffi::Array<PrimExpr> data_NCHW_shape = data2NCHW.ForwardShape(data_shape.value()->values);
+  ffi::Array<PrimExpr> weight_OIHW_shape = weight2OIHW.ForwardShape(weight_shape.value()->values);
 
   arith::Analyzer* analyzer = ctx->GetAnalyzer();
   PrimExpr input_channel_data = data_NCHW_shape[1];
@@ -303,19 +307,19 @@ StructInfo InferStructInfoConv2d(const Call& call, const BlockBuilder& ctx) {
   out_NCHW_shape[2] = analyzer->Simplify(floordiv(numerator_h, attrs->strides[0]) + 1);
   out_NCHW_shape[3] = analyzer->Simplify(floordiv(numerator_w, attrs->strides[1]) + 1);
 
-  Array<PrimExpr> out_shape = out2NCHW.BackwardShape(out_NCHW_shape);
+  ffi::Array<PrimExpr> out_shape = out2NCHW.BackwardShape(out_NCHW_shape);
   return TensorStructInfo(ShapeExpr(out_shape), out_dtype, vdevice);
 }
 
-InferLayoutOutput InferLayoutConv2d(const Call& call,
-                                    const Map<String, Array<String>>& desired_layouts,
-                                    const VarLayoutMap& var_layout_map) {
+InferLayoutOutput InferLayoutConv2d(
+    const Call& call, const ffi::Map<ffi::String, ffi::Array<ffi::String>>& desired_layouts,
+    const VarLayoutMap& var_layout_map) {
   const auto& it = desired_layouts.find("relax.nn.conv2d");
   const auto* attrs = call->attrs.as<Conv2DAttrs>();
   ICHECK(attrs) << "Invalid Call";
 
   LayoutDecision data_layout, weight_layout, output_layout;
-  ObjectPtr<Conv2DAttrs> new_attrs = make_object<Conv2DAttrs>(*attrs);
+  ObjectPtr<Conv2DAttrs> new_attrs = ffi::make_object<Conv2DAttrs>(*attrs);
 
   if (it != desired_layouts.end()) {
     // We have a desired layout for conv2d.
@@ -343,8 +347,10 @@ InferLayoutOutput InferLayoutConv2d(const Call& call,
     auto kernel_si = GetStructInfo(call->args[1]);
     TensorStructInfo data_sinfo = data_si.as<TensorStructInfo>().value();
     TensorStructInfo kernel_sinfo = kernel_si.as<TensorStructInfo>().value();
-    Optional<ShapeExpr> data_shape = GetRef<ShapeExpr>(data_sinfo->shape.as<ShapeExprNode>());
-    Optional<ShapeExpr> kernel_shape = GetRef<ShapeExpr>(kernel_sinfo->shape.as<ShapeExprNode>());
+    ffi::Optional<ShapeExpr> data_shape =
+        ffi::GetRef<ShapeExpr>(data_sinfo->shape.as<ShapeExprNode>());
+    ffi::Optional<ShapeExpr> kernel_shape =
+        ffi::GetRef<ShapeExpr>(kernel_sinfo->shape.as<ShapeExprNode>());
     bool can_data_proved =
         CanProveLayoutTransform(input_layout, desired_data_layout, data_shape.value()->values);
@@ -399,9 +405,10 @@ TVM_REGISTER_OP("relax.nn.conv2d")
 
 /* relax.nn.conv3d */
 
-Expr conv3d(Expr data, Expr weight, Array<IntImm> strides, Array<IntImm> padding,
-            Array<IntImm> dilation, int groups, String data_layout, String kernel_layout,
-            Optional<String> out_layout, Optional<DataType> out_dtype) {
+Expr conv3d(Expr data, Expr weight, ffi::Array<IntImm> strides, ffi::Array<IntImm> padding,
+            ffi::Array<IntImm> dilation, int groups, ffi::String data_layout,
+            ffi::String kernel_layout, ffi::Optional<ffi::String> out_layout,
+            ffi::Optional<DataType> out_dtype) {
   padding = GetCompletePadding3D(std::move(padding));
   if (strides.size() == 1) {
     strides.push_back(strides[0]);
@@ -432,7 +439,7 @@ TVM_FFI_STATIC_INIT_BLOCK({
 });
 
 StructInfo InferStructInfoConv3d(const Call& call, const BlockBuilder& ctx) {
-  Array<TensorStructInfo> input_sinfo = GetInputTensorStructInfo(call, ctx);
+  ffi::Array<TensorStructInfo> input_sinfo = GetInputTensorStructInfo(call, ctx);
   TensorStructInfo data_sinfo = input_sinfo[0];
   TensorStructInfo weight_sinfo = input_sinfo[1];
@@ -447,21 +454,22 @@ StructInfo InferStructInfoConv3d(const Call& call, const BlockBuilder& ctx) {
                                                  /*tgt_layout=*/"NCDHW",  //
                                                  /*tensor_name=*/"output");
 
-  Optional<ShapeStructInfo> data_shape =
+  ffi::Optional<ShapeStructInfo> data_shape =
       CheckNdimPerLayoutAndGetShape(call, ctx, data_sinfo, data_layout);
-  Optional<ShapeStructInfo> weight_shape =
+  ffi::Optional<ShapeStructInfo> weight_shape =
       CheckNdimPerLayoutAndGetShape(call, ctx, weight_sinfo, weight_layout);
 
   DataType out_dtype = attrs->out_dtype.is_void()
                            ? InferBinaryArithOpOutDtype(call, ctx, data_sinfo, weight_sinfo)
                            : attrs->out_dtype;
-  Optional<VDevice> vdevice = InferBinaryArithOpOutVDevice(call, ctx, data_sinfo, weight_sinfo);
+  ffi::Optional<VDevice> vdevice =
+      InferBinaryArithOpOutVDevice(call, ctx, data_sinfo, weight_sinfo);
   if (!data_shape.defined() || !weight_shape.defined()) {
     return TensorStructInfo(out_dtype, out_layout.ndim(), vdevice);
   }
 
-  Array<PrimExpr> data_NCDHW_shape = data2NCDHW.ForwardShape(data_shape.value()->values);
-  Array<PrimExpr> weight_OIDHW_shape = weight2OIDHW.ForwardShape(weight_shape.value()->values);
+  ffi::Array<PrimExpr> data_NCDHW_shape = data2NCDHW.ForwardShape(data_shape.value()->values);
+  ffi::Array<PrimExpr> weight_OIDHW_shape = weight2OIDHW.ForwardShape(weight_shape.value()->values);
 
   arith::Analyzer* analyzer = ctx->GetAnalyzer();
   PrimExpr input_channel_data = data_NCDHW_shape[1];
@@ -510,19 +518,19 @@ StructInfo InferStructInfoConv3d(const Call& call, const BlockBuilder& ctx) {
   out_NCDHW_shape[3] = analyzer->Simplify(floordiv(numerator_h, attrs->strides[1]) + 1);
   out_NCDHW_shape[4] = analyzer->Simplify(floordiv(numerator_w, attrs->strides[2]) + 1);
 
-  Array<PrimExpr> out_shape = out2NCDHW.BackwardShape(out_NCDHW_shape);
+  ffi::Array<PrimExpr> out_shape = out2NCDHW.BackwardShape(out_NCDHW_shape);
   return TensorStructInfo(ShapeExpr(out_shape), out_dtype, vdevice);
 }
 
-InferLayoutOutput InferLayoutConv3d(const Call& call,
-                                    const Map<String, Array<String>>& desired_layouts,
-                                    const VarLayoutMap& var_layout_map) {
+InferLayoutOutput InferLayoutConv3d(
+    const Call& call, const ffi::Map<ffi::String, ffi::Array<ffi::String>>& desired_layouts,
+    const VarLayoutMap& var_layout_map) {
   const auto& it = desired_layouts.find("relax.nn.conv3d");
   const auto* attrs = call->attrs.as<Conv3DAttrs>();
   ICHECK(attrs) << "Invalid Call";
 
   LayoutDecision data_layout, weight_layout, output_layout;
-  ObjectPtr<Conv3DAttrs> new_attrs = make_object<Conv3DAttrs>(*attrs);
+  ObjectPtr<Conv3DAttrs> new_attrs = ffi::make_object<Conv3DAttrs>(*attrs);
 
   if (it != desired_layouts.end()) {
     // We have a desired layout for conv3d.
@@ -575,10 +583,11 @@ TVM_REGISTER_OP("relax.nn.conv3d")
     .set_attr<FInferMixedPrecision>("FInferMixedPrecision", InferMixedPrecisionConv3d)
     .set_attr<Bool>("FPurity", Bool(true));
 
-Expr conv1d_transpose(Expr data, Expr weight, Array<IntImm> strides, Array<IntImm> padding,
-                      Array<IntImm> output_padding, Array<IntImm> dilation, int groups,
-                      String data_layout, String kernel_layout, Optional<String> out_layout,
-                      Optional<DataType> out_dtype) {
+Expr conv1d_transpose(Expr data, Expr weight, ffi::Array<IntImm> strides,
+                      ffi::Array<IntImm> padding, ffi::Array<IntImm> output_padding,
+                      ffi::Array<IntImm> dilation, int groups, ffi::String data_layout,
+                      ffi::String kernel_layout, ffi::Optional<ffi::String> out_layout,
+                      ffi::Optional<DataType> out_dtype) {
   padding = GetCompletePadding1D(std::move(padding));
 
   CHECK_GT(groups, 0) << "The number of groups in convolution is expected to be positive. However, "
However, " @@ -593,7 +602,7 @@ Expr conv1d_transpose(Expr data, Expr weight, Array strides, Array(); + auto attrs = ffi::make_object(); attrs->strides = ConvertIntImmToInt64(strides); attrs->padding = ConvertIntImmToInt64(padding); attrs->output_padding = ConvertIntImmToInt64(output_padding); @@ -613,7 +622,7 @@ TVM_FFI_STATIC_INIT_BLOCK({ }); StructInfo InferStructInfoConv1dTranspose(const Call& call, const BlockBuilder& ctx) { - Array input_sinfo = GetInputTensorStructInfo(call, ctx); + ffi::Array input_sinfo = GetInputTensorStructInfo(call, ctx); TensorStructInfo data_sinfo = input_sinfo[0]; TensorStructInfo weight_sinfo = input_sinfo[1]; @@ -627,21 +636,22 @@ StructInfo InferStructInfoConv1dTranspose(const Call& call, const BlockBuilder& auto [out_layout, out2NCW] = CheckTensorLayout(call, ctx, attrs->out_layout, // /*tgt_layout=*/"NCW", // /*tensor_name=*/"output"); - Optional data_shape = + ffi::Optional data_shape = CheckNdimPerLayoutAndGetShape(call, ctx, data_sinfo, data_layout); - Optional weight_shape = + ffi::Optional weight_shape = CheckNdimPerLayoutAndGetShape(call, ctx, weight_sinfo, weight_layout); DataType out_dtype = attrs->out_dtype.is_void() ? InferBinaryArithOpOutDtype(call, ctx, data_sinfo, weight_sinfo) : attrs->out_dtype; - Optional vdevice = InferBinaryArithOpOutVDevice(call, ctx, data_sinfo, weight_sinfo); + ffi::Optional vdevice = + InferBinaryArithOpOutVDevice(call, ctx, data_sinfo, weight_sinfo); if (!data_shape.defined() || !weight_shape.defined()) { return TensorStructInfo(out_dtype, out_layout.ndim(), vdevice); } - Array data_NCW_shape = data2NCW.ForwardShape(data_shape.value()->values); - Array weight_IOW_shape = weight2IOW.ForwardShape(weight_shape.value()->values); + ffi::Array data_NCW_shape = data2NCW.ForwardShape(data_shape.value()->values); + ffi::Array weight_IOW_shape = weight2IOW.ForwardShape(weight_shape.value()->values); arith::Analyzer* analyzer = ctx->GetAnalyzer(); PrimExpr input_channel_data = data_NCW_shape[1]; @@ -689,7 +699,7 @@ StructInfo InferStructInfoConv1dTranspose(const Call& call, const BlockBuilder& attrs->dilation[0] * (kernel_w - 1) + attrs->output_padding[0] + 1; out_NCW_shape[2] = analyzer->Simplify(out_w); - Array out_shape = out2NCW.BackwardShape(out_NCW_shape); + ffi::Array out_shape = out2NCW.BackwardShape(out_NCW_shape); return TensorStructInfo(ShapeExpr(out_shape), out_dtype, vdevice); } @@ -705,10 +715,11 @@ TVM_REGISTER_OP("relax.nn.conv1d_transpose") /* relax.nn.conv2d_transpose */ -Expr conv2d_transpose(Expr data, Expr weight, Array strides, Array padding, - Array output_padding, Array dilation, int groups, - String data_layout, String kernel_layout, Optional out_layout, - Optional out_dtype) { +Expr conv2d_transpose(Expr data, Expr weight, ffi::Array strides, + ffi::Array padding, ffi::Array output_padding, + ffi::Array dilation, int groups, ffi::String data_layout, + ffi::String kernel_layout, ffi::Optional out_layout, + ffi::Optional out_dtype) { padding = GetCompletePadding2D(std::move(padding)); if (output_padding.size() == 1) { output_padding.push_back(output_padding[0]); @@ -732,7 +743,7 @@ Expr conv2d_transpose(Expr data, Expr weight, Array strides, Array(); + auto attrs = ffi::make_object(); attrs->strides = ConvertIntImmToInt64(strides); attrs->padding = ConvertIntImmToInt64(padding); attrs->output_padding = ConvertIntImmToInt64(output_padding); @@ -752,7 +763,7 @@ TVM_FFI_STATIC_INIT_BLOCK({ }); StructInfo InferStructInfoConv2dTranspose(const Call& call, const BlockBuilder& ctx) { - Array input_sinfo = 
GetInputTensorStructInfo(call, ctx); + ffi::Array input_sinfo = GetInputTensorStructInfo(call, ctx); TensorStructInfo data_sinfo = input_sinfo[0]; TensorStructInfo weight_sinfo = input_sinfo[1]; @@ -767,21 +778,22 @@ StructInfo InferStructInfoConv2dTranspose(const Call& call, const BlockBuilder& /*tgt_layout=*/"NCHW", // /*tensor_name=*/"output"); - Optional data_shape = + ffi::Optional data_shape = CheckNdimPerLayoutAndGetShape(call, ctx, data_sinfo, data_layout); - Optional weight_shape = + ffi::Optional weight_shape = CheckNdimPerLayoutAndGetShape(call, ctx, weight_sinfo, weight_layout); DataType out_dtype = attrs->out_dtype.is_void() ? InferBinaryArithOpOutDtype(call, ctx, data_sinfo, weight_sinfo) : attrs->out_dtype; - Optional vdevice = InferBinaryArithOpOutVDevice(call, ctx, data_sinfo, weight_sinfo); + ffi::Optional vdevice = + InferBinaryArithOpOutVDevice(call, ctx, data_sinfo, weight_sinfo); if (!data_shape.defined() || !weight_shape.defined()) { return TensorStructInfo(out_dtype, out_layout.ndim(), vdevice); } - Array data_NCHW_shape = data2NCHW.ForwardShape(data_shape.value()->values); - Array weight_IOHW_shape = weight2IOHW.ForwardShape(weight_shape.value()->values); + ffi::Array data_NCHW_shape = data2NCHW.ForwardShape(data_shape.value()->values); + ffi::Array weight_IOHW_shape = weight2IOHW.ForwardShape(weight_shape.value()->values); arith::Analyzer* analyzer = ctx->GetAnalyzer(); PrimExpr input_channel_data = data_NCHW_shape[1]; @@ -837,7 +849,7 @@ StructInfo InferStructInfoConv2dTranspose(const Call& call, const BlockBuilder& out_NCHW_shape[2] = analyzer->Simplify(out_h); out_NCHW_shape[3] = analyzer->Simplify(out_w); - Array out_shape = out2NCHW.BackwardShape(out_NCHW_shape); + ffi::Array out_shape = out2NCHW.BackwardShape(out_NCHW_shape); return TensorStructInfo(ShapeExpr(out_shape), out_dtype, vdevice); } diff --git a/src/relax/op/nn/convolution.h b/src/relax/op/nn/convolution.h index c99f03388e19..4fc175b5aa07 100644 --- a/src/relax/op/nn/convolution.h +++ b/src/relax/op/nn/convolution.h @@ -36,10 +36,11 @@ namespace tvm { namespace relax { template -inline Expr MakeConv(Expr data, Expr weight, Array strides, Array padding, - Array dilation, int groups, String data_layout, String kernel_layout, - String out_layout, DataType out_dtype, std::string op_name) { - auto attrs = make_object(); +inline Expr MakeConv(Expr data, Expr weight, ffi::Array strides, ffi::Array padding, + ffi::Array dilation, int groups, ffi::String data_layout, + ffi::String kernel_layout, ffi::String out_layout, DataType out_dtype, + std::string op_name) { + auto attrs = ffi::make_object(); attrs->strides = ConvertIntImmToInt64(strides); attrs->padding = ConvertIntImmToInt64(padding); attrs->dilation = ConvertIntImmToInt64(dilation); @@ -53,19 +54,22 @@ inline Expr MakeConv(Expr data, Expr weight, Array strides, Array strides, Array padding, - Array dilation, int groups, String data_layout, String kernel_layout, - Optional out_layout, Optional out_dtype); +Expr conv1d(Expr data, Expr weight, ffi::Array strides, ffi::Array padding, + ffi::Array dilation, int groups, ffi::String data_layout, + ffi::String kernel_layout, ffi::Optional out_layout, + ffi::Optional out_dtype); /*! 
\brief 2D convolution */ -Expr conv2d(Expr data, Expr weight, Array strides, Array padding, - Array dilation, int groups, String data_layout, String kernel_layout, - Optional out_layout, Optional out_dtype); +Expr conv2d(Expr data, Expr weight, ffi::Array strides, ffi::Array padding, + ffi::Array dilation, int groups, ffi::String data_layout, + ffi::String kernel_layout, ffi::Optional out_layout, + ffi::Optional out_dtype); /*! \brief 3D convolution */ -Expr conv3d(Expr data, Expr weight, Array strides, Array padding, - Array dilation, int groups, String data_layout, String kernel_layout, - Optional out_layout, Optional out_dtype); +Expr conv3d(Expr data, Expr weight, ffi::Array strides, ffi::Array padding, + ffi::Array dilation, int groups, ffi::String data_layout, + ffi::String kernel_layout, ffi::Optional out_layout, + ffi::Optional out_dtype); /*! * \brief One dimensional transposed convolution operator. @@ -73,10 +77,11 @@ Expr conv3d(Expr data, Expr weight, Array strides, Array padding * This operator is intended to be the backward operator of conv1d. It can be used to calculate the * gradient of the result of conv1d w.r.t. the input of conv1d. */ -Expr conv1d_transpose(Expr data, Expr weight, Array strides, Array padding, - Array output_padding, Array dilation, int groups, - String data_layout, String kernel_layout, Optional out_layout, - Optional out_dtype); +Expr conv1d_transpose(Expr data, Expr weight, ffi::Array strides, + ffi::Array padding, ffi::Array output_padding, + ffi::Array dilation, int groups, ffi::String data_layout, + ffi::String kernel_layout, ffi::Optional out_layout, + ffi::Optional out_dtype); /*! * \brief Two dimensional transposed convolution operator. @@ -84,10 +89,11 @@ Expr conv1d_transpose(Expr data, Expr weight, Array strides, Array strides, Array padding, - Array output_padding, Array dilation, int groups, - String data_layout, String kernel_layout, Optional out_layout, - Optional out_dtype); +Expr conv2d_transpose(Expr data, Expr weight, ffi::Array strides, + ffi::Array padding, ffi::Array output_padding, + ffi::Array dilation, int groups, ffi::String data_layout, + ffi::String kernel_layout, ffi::Optional out_layout, + ffi::Optional out_dtype); } // namespace relax } // namespace tvm diff --git a/src/relax/op/nn/nn.cc b/src/relax/op/nn/nn.cc index 3597b16a5bcc..7a2bb0e607d2 100644 --- a/src/relax/op/nn/nn.cc +++ b/src/relax/op/nn/nn.cc @@ -61,7 +61,7 @@ RELAX_REGISTER_UNARY_NN_OP_AND_IMPL(silu, "nn.silu", /*require_float_dtype=*/tru /* relax.nn.leakyrelu */ Expr leakyrelu(Expr data, double alpha) { - auto attrs = make_object(); + auto attrs = ffi::make_object(); attrs->alpha = alpha; static const Op& op = Op::Get("relax.nn.leakyrelu"); return Call(op, {data}, Attrs(attrs), {}); @@ -83,7 +83,7 @@ TVM_REGISTER_OP("relax.nn.leakyrelu") /* relax.nn.softplus */ Expr softplus(Expr data, double beta, double threshold) { - auto attrs = make_object(); + auto attrs = ffi::make_object(); attrs->beta = beta; attrs->threshold = threshold; static const Op& op = Op::Get("relax.nn.softplus"); @@ -106,7 +106,7 @@ TVM_REGISTER_OP("relax.nn.softplus") /* relax.nn.prelu */ Expr prelu(Expr data, Expr alpha, int axis = 1) { - auto attrs = make_object(); + auto attrs = ffi::make_object(); attrs->axis = axis; static const Op& op = Op::Get("relax.nn.prelu"); return Call(op, {data, alpha}, Attrs(attrs), {}); @@ -133,9 +133,9 @@ StructInfo InferStructInfoPRelu(const Call& call, const BlockBuilder& ctx) { return data_sinfo; } -InferLayoutOutput InferLayoutPRelu(const Call& call, - 
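// `desired_layouts` maps operator names (e.g. "relax.nn.conv2d") to the
// layouts requested by the layout-conversion pass; prelu declares none, hence
// the NoDesiredLayout check below.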
const Map>& desired_layouts, - const VarLayoutMap& var_layout_map) { +InferLayoutOutput InferLayoutPRelu( + const Call& call, const ffi::Map>& desired_layouts, + const VarLayoutMap& var_layout_map) { ICHECK(NoDesiredLayout(call, desired_layouts)); const auto* attrs = call->attrs.as(); ICHECK(attrs) << "Invalid Call"; @@ -151,7 +151,7 @@ InferLayoutOutput InferLayoutPRelu(const Call& call, layout = LayoutDecision(InitialLayout(ndim)); } - ObjectPtr new_attrs = make_object(*attrs); + ObjectPtr new_attrs = ffi::make_object(*attrs); new_attrs->axis = FindAxis(layout->layout, attrs->axis); LayoutDecision alpha_layout = GetLayoutDecision(var_layout_map, call->args[1]); @@ -170,7 +170,7 @@ TVM_REGISTER_OP("relax.nn.prelu") /* relax.nn.softmax */ Expr softmax(Expr data, int axis) { - auto attrs = make_object(); + auto attrs = ffi::make_object(); attrs->axis = axis; static const Op& op = Op::Get("relax.nn.softmax"); return Call(op, {data}, Attrs(attrs), {}); @@ -198,9 +198,9 @@ StructInfo InferStructInfoSoftmax(const Call& call, const BlockBuilder& ctx) { return data_sinfo; } -InferLayoutOutput InferLayoutSoftmax(const Call& call, - const Map>& desired_layouts, - const VarLayoutMap& var_layout_map) { +InferLayoutOutput InferLayoutSoftmax( + const Call& call, const ffi::Map>& desired_layouts, + const VarLayoutMap& var_layout_map) { ICHECK(NoDesiredLayout(call, desired_layouts)); const auto* attrs = call->attrs.as(); ICHECK(attrs) << "Invalid Call"; @@ -216,7 +216,7 @@ InferLayoutOutput InferLayoutSoftmax(const Call& call, layout = LayoutDecision(InitialLayout(ndim)); } - ObjectPtr new_attrs = make_object(*attrs); + ObjectPtr new_attrs = ffi::make_object(*attrs); new_attrs->axis = FindAxis(layout->layout, attrs->axis); return InferLayoutOutput({layout}, {layout}, Attrs(new_attrs)); } @@ -231,7 +231,7 @@ TVM_REGISTER_OP("relax.nn.softmax") /* relax.nn.log_softmax */ Expr log_softmax(Expr data, int axis) { - auto attrs = make_object(); + auto attrs = ffi::make_object(); attrs->axis = axis; static const Op& op = Op::Get("relax.nn.log_softmax"); return Call(op, {data}, Attrs(attrs), {}); @@ -251,8 +251,8 @@ TVM_REGISTER_OP("relax.nn.log_softmax") /* relax.nn.pad */ -Expr pad(Expr data, Array pad_width, String pad_mode, double pad_value) { - auto attrs = make_object(); +Expr pad(Expr data, ffi::Array pad_width, ffi::String pad_mode, double pad_value) { + auto attrs = ffi::make_object(); attrs->pad_width = std::move(pad_width); attrs->pad_mode = std::move(pad_mode); attrs->pad_value = pad_value; @@ -266,13 +266,13 @@ TVM_FFI_STATIC_INIT_BLOCK({ }); StructInfo InferStructInfoPad(const Call& call, const BlockBuilder& ctx) { - Array input_sinfo = GetInputTensorStructInfo(call, ctx); + ffi::Array input_sinfo = GetInputTensorStructInfo(call, ctx); const auto* attrs = call->attrs.as(); int ndim = input_sinfo[0]->ndim; - Array pad_width = attrs->pad_width; + ffi::Array pad_width = attrs->pad_width; ICHECK(static_cast(pad_width.size()) == 2 * ndim) << "Illegal pad_width"; - Array out_shape; + ffi::Array out_shape; if (input_sinfo[0]->shape.defined()) { // Compute output shape by adding corresponding pad width to each axis. 
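// pad_width is laid out as (before_0, after_0, before_1, after_1, ...), so
// each output extent becomes in_shape[i] + pad_width[2 * i] + pad_width[2 * i + 1].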
const auto* data_shape = input_sinfo[0]->shape.as(); @@ -299,7 +299,7 @@ TVM_REGISTER_OP("relax.nn.pad") /* relax.nn.pixel_shuffle */ Expr pixel_shuffle(Expr data, int upscale_factor) { - auto attrs = make_object(); + auto attrs = ffi::make_object(); attrs->upscale_factor = upscale_factor; static const Op& op = Op::Get("relax.nn.pixel_shuffle"); return Call(op, {data}, Attrs(attrs), {}); @@ -311,7 +311,7 @@ TVM_FFI_STATIC_INIT_BLOCK({ }); StructInfo InferStructInfoPixelShuffle(const Call& call, const BlockBuilder& ctx) { - Array input_sinfo = GetInputTensorStructInfo(call, ctx); + ffi::Array input_sinfo = GetInputTensorStructInfo(call, ctx); const auto* attrs = call->attrs.as(); int r = attrs->upscale_factor; ICHECK_GT(r, 0) << "Upscale factor must be positive"; @@ -325,7 +325,7 @@ StructInfo InferStructInfoPixelShuffle(const Call& call, const BlockBuilder& ctx } const auto* shape = input->shape.as(); - Array in_shape = shape->values; + ffi::Array in_shape = shape->values; int channel_idx = ndim - 3; int h_idx = ndim - 2; @@ -345,7 +345,7 @@ StructInfo InferStructInfoPixelShuffle(const Call& call, const BlockBuilder& ctx << "Number of input channels must be divisible by the square of the upscale factor"; // Output shape: - Array out_shape; + ffi::Array out_shape; for (int i = 0; i < ndim; ++i) { if (i == channel_idx) { out_shape.push_back(c_in / r_squared); @@ -370,7 +370,8 @@ TVM_REGISTER_OP("relax.nn.pixel_shuffle") /* relax.nn.batchnorm */ bool NormCheckDtypeAndShape(const Call& call, const BlockBuilder& ctx, - const Array& input_sinfo, Array axes) { + const ffi::Array& input_sinfo, + ffi::Array axes) { Op op = Downcast(call->op); int n_input = op->arguments.size(); @@ -405,7 +406,7 @@ bool NormCheckDtypeAndShape(const Call& call, const BlockBuilder& ctx, } } - std::vector> axis_lengths; + std::vector> axis_lengths; axis_lengths.reserve(n_input); if (const auto* data_shape = data_sinfo->shape.as()) { std::vector lengths; @@ -442,7 +443,7 @@ bool NormCheckDtypeAndShape(const Call& call, const BlockBuilder& ctx, Expr batch_norm(Expr data, Expr gamma, Expr beta, Expr moving_mean, Expr moving_var, // int axis, double epsilon, bool center, bool scale, double momentum, bool training) { - ObjectPtr attrs = make_object(); + ObjectPtr attrs = ffi::make_object(); attrs->axis = axis; attrs->epsilon = epsilon; attrs->center = center; @@ -462,7 +463,7 @@ TVM_FFI_STATIC_INIT_BLOCK({ }); StructInfo InferStructInfoBatchNorm(const Call& call, const BlockBuilder& ctx) { - Array input_sinfo = GetInputTensorStructInfo(call, ctx); + ffi::Array input_sinfo = GetInputTensorStructInfo(call, ctx); const auto* attrs = call->attrs.as(); bool unknown_shape = NormCheckDtypeAndShape(call, ctx, input_sinfo, {attrs->axis}); @@ -478,9 +479,9 @@ StructInfo InferStructInfoBatchNorm(const Call& call, const BlockBuilder& ctx) { } } -InferLayoutOutput InferLayoutBatchNorm(const Call& call, - const Map>& desired_layouts, - const VarLayoutMap& var_layout_map) { +InferLayoutOutput InferLayoutBatchNorm( + const Call& call, const ffi::Map>& desired_layouts, + const VarLayoutMap& var_layout_map) { ICHECK(NoDesiredLayout(call, desired_layouts)); std::vector initial_layouts; for (size_t i = 0; i < 5; ++i) { @@ -502,7 +503,7 @@ InferLayoutOutput InferLayoutBatchNorm(const Call& call, layout = LayoutDecision(InitialLayout(ndim)); } - ObjectPtr new_attrs = make_object(*attrs); + ObjectPtr new_attrs = ffi::make_object(*attrs); new_attrs->axis = FindAxis(layout->layout, (attrs->axis + ndim) % ndim); return InferLayoutOutput( {layout, 
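// Only the data tensor (argument 0) adopts the inferred layout; gamma, beta,
// moving_mean and moving_var keep their initial layouts.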
initial_layouts[1], initial_layouts[2], initial_layouts[3], initial_layouts[4]}, @@ -523,9 +524,9 @@ TVM_REGISTER_OP("relax.nn.batch_norm") /* relax.nn.layer_norm */ -Expr layer_norm(Expr data, Expr gamma, Expr beta, Array axes, double epsilon, bool center, - bool scale) { - ObjectPtr attrs = make_object(); +Expr layer_norm(Expr data, Expr gamma, Expr beta, ffi::Array axes, double epsilon, + bool center, bool scale) { + ObjectPtr attrs = ffi::make_object(); attrs->axes = std::move(axes); attrs->epsilon = epsilon; attrs->center = center; @@ -541,7 +542,7 @@ TVM_FFI_STATIC_INIT_BLOCK({ }); StructInfo InferStructInfoLayerNorm(const Call& call, const BlockBuilder& ctx) { - Array input_sinfo = GetInputTensorStructInfo(call, ctx); + ffi::Array input_sinfo = GetInputTensorStructInfo(call, ctx); const auto* attrs = call->attrs.as(); bool unknown_shape = NormCheckDtypeAndShape(call, ctx, input_sinfo, attrs->axes); @@ -551,9 +552,9 @@ StructInfo InferStructInfoLayerNorm(const Call& call, const BlockBuilder& ctx) { : input_sinfo[0]; } -InferLayoutOutput InferLayoutLayerNorm(const Call& call, - const Map>& desired_layouts, - const VarLayoutMap& var_layout_map) { +InferLayoutOutput InferLayoutLayerNorm( + const Call& call, const ffi::Map>& desired_layouts, + const VarLayoutMap& var_layout_map) { ICHECK(NoDesiredLayout(call, desired_layouts)); std::vector initial_layouts; for (size_t i = 0; i < 3; ++i) { @@ -566,7 +567,7 @@ InferLayoutOutput InferLayoutLayerNorm(const Call& call, ICHECK(attrs) << "Invalid Call"; LayoutDecision layout = GetLayoutDecision(var_layout_map, call->args[0]); - ObjectPtr new_attrs = make_object(*attrs); + ObjectPtr new_attrs = ffi::make_object(*attrs); const auto* input_sinfo = GetStructInfoAs(call->args[0]); int ndim = input_sinfo->ndim; std::vector new_axis; @@ -592,8 +593,8 @@ TVM_REGISTER_OP("relax.nn.layer_norm") /* relax.nn.group_norm */ Expr group_norm(Expr data, Expr gamma, Expr beta, int num_groups, int channel_axis, - Array axes, double epsilon, bool center, bool scale) { - ObjectPtr attrs = make_object(); + ffi::Array axes, double epsilon, bool center, bool scale) { + ObjectPtr attrs = ffi::make_object(); attrs->num_groups = num_groups; attrs->channel_axis = channel_axis; attrs->axes = std::move(axes); @@ -612,7 +613,7 @@ TVM_FFI_STATIC_INIT_BLOCK({ StructInfo InferStructInfoGroupNorm(const Call& call, const BlockBuilder& ctx) { Op op = Downcast(call->op); - Array input_sinfo = GetInputTensorStructInfo(call, ctx); + ffi::Array input_sinfo = GetInputTensorStructInfo(call, ctx); const auto* attrs = call->attrs.as(); TensorStructInfo data_sinfo = input_sinfo[0]; @@ -666,9 +667,9 @@ StructInfo InferStructInfoGroupNorm(const Call& call, const BlockBuilder& ctx) { return data_sinfo; } -InferLayoutOutput InferLayoutGroupNorm(const Call& call, - const Map>& desired_layouts, - const VarLayoutMap& var_layout_map) { +InferLayoutOutput InferLayoutGroupNorm( + const Call& call, const ffi::Map>& desired_layouts, + const VarLayoutMap& var_layout_map) { ICHECK(NoDesiredLayout(call, desired_layouts)); std::vector initial_layouts; for (size_t i = 0; i < 3; ++i) { @@ -681,7 +682,7 @@ InferLayoutOutput InferLayoutGroupNorm(const Call& call, ICHECK(attrs) << "Invalid Call"; LayoutDecision layout = GetLayoutDecision(var_layout_map, call->args[0]); - ObjectPtr new_attrs = make_object(*attrs); + ObjectPtr new_attrs = ffi::make_object(*attrs); std::vector new_axes; for (const auto& axis : attrs->axes) { new_axes.push_back(FindAxis(layout->layout, axis->value)); @@ -705,9 +706,9 @@ 
TVM_REGISTER_OP("relax.nn.group_norm") /* relax.nn.instance_norm */ -Expr instance_norm(Expr data, Expr gamma, Expr beta, int channel_axis, Array axes, +Expr instance_norm(Expr data, Expr gamma, Expr beta, int channel_axis, ffi::Array axes, double epsilon, bool center, bool scale) { - ObjectPtr attrs = make_object(); + ObjectPtr attrs = ffi::make_object(); attrs->channel_axis = std::move(channel_axis); attrs->axes = std::move(axes); attrs->epsilon = epsilon; @@ -725,7 +726,7 @@ TVM_FFI_STATIC_INIT_BLOCK({ StructInfo InferStructInfoInstanceNorm(const Call& call, const BlockBuilder& ctx) { Op op = Downcast(call->op); - Array input_sinfo = GetInputTensorStructInfo(call, ctx); + ffi::Array input_sinfo = GetInputTensorStructInfo(call, ctx); const auto* attrs = call->attrs.as(); ICHECK(attrs) << "Invalid Call"; TensorStructInfo data_sinfo = input_sinfo[0]; @@ -769,9 +770,9 @@ StructInfo InferStructInfoInstanceNorm(const Call& call, const BlockBuilder& ctx return data_sinfo; } -InferLayoutOutput InferLayoutInstanceNorm(const Call& call, - const Map>& desired_layouts, - const VarLayoutMap& var_layout_map) { +InferLayoutOutput InferLayoutInstanceNorm( + const Call& call, const ffi::Map>& desired_layouts, + const VarLayoutMap& var_layout_map) { ICHECK(NoDesiredLayout(call, desired_layouts)); std::vector initial_layouts; for (size_t i = 0; i < 3; ++i) { @@ -784,7 +785,7 @@ InferLayoutOutput InferLayoutInstanceNorm(const Call& call, ICHECK(attrs) << "Invalid Call"; LayoutDecision layout = GetLayoutDecision(var_layout_map, call->args[0]); - ObjectPtr new_attrs = make_object(*attrs); + ObjectPtr new_attrs = ffi::make_object(*attrs); std::vector new_axes; for (const auto& axis : attrs->axes) { new_axes.push_back(FindAxis(layout->layout, (axis->value))); @@ -807,8 +808,8 @@ TVM_REGISTER_OP("relax.nn.instance_norm") .set_attr("FPurity", Bool(true)); /* relax.nn.rms_norm */ -Expr rms_norm(Expr data, Expr weight, Array axes, double epsilon) { - ObjectPtr attrs = make_object(); +Expr rms_norm(Expr data, Expr weight, ffi::Array axes, double epsilon) { + ObjectPtr attrs = ffi::make_object(); attrs->axes = std::move(axes); attrs->epsilon = epsilon; @@ -822,7 +823,7 @@ TVM_FFI_STATIC_INIT_BLOCK({ }); StructInfo InferStructInfoRMSNorm(const Call& call, const BlockBuilder& ctx) { - Array input_sinfo = GetInputTensorStructInfo(call, ctx); + ffi::Array input_sinfo = GetInputTensorStructInfo(call, ctx); const auto* attrs = call->attrs.as(); bool unknown_shape = NormCheckDtypeAndShape(call, ctx, input_sinfo, attrs->axes); @@ -832,9 +833,9 @@ StructInfo InferStructInfoRMSNorm(const Call& call, const BlockBuilder& ctx) { : input_sinfo[0]; } -InferLayoutOutput InferLayoutRMSNorm(const Call& call, - const Map>& desired_layouts, - const VarLayoutMap& var_layout_map) { +InferLayoutOutput InferLayoutRMSNorm( + const Call& call, const ffi::Map>& desired_layouts, + const VarLayoutMap& var_layout_map) { ICHECK(NoDesiredLayout(call, desired_layouts)); std::vector initial_layouts; for (size_t i = 0; i < 2; ++i) { @@ -847,7 +848,7 @@ InferLayoutOutput InferLayoutRMSNorm(const Call& call, ICHECK(attrs) << "Invalid Call"; LayoutDecision layout = GetLayoutDecision(var_layout_map, call->args[0]); - ObjectPtr new_attrs = make_object(*attrs); + ObjectPtr new_attrs = ffi::make_object(*attrs); std::vector new_axes; for (const auto& axis : attrs->axes) { new_axes.push_back(FindAxis(layout->layout, axis->value)); @@ -869,7 +870,7 @@ TVM_REGISTER_OP("relax.nn.rms_norm") /* relax.nn.dropout */ Expr dropout(Expr data, double rate) { - ObjectPtr 
attrs = make_object(); + ObjectPtr attrs = ffi::make_object(); attrs->rate = rate; static const Op& op = Op::Get("relax.nn.dropout"); @@ -897,7 +898,7 @@ TVM_REGISTER_OP("relax.nn.dropout") /* relax.nn.cross_entropy_with_logits */ StructInfo InferStructInfoCrossEntropy(const Call& call, const BlockBuilder& ctx) { - Array input_sinfo = GetInputTensorStructInfo(call, ctx); + ffi::Array input_sinfo = GetInputTensorStructInfo(call, ctx); TensorStructInfo pred_sinfo = input_sinfo[0]; TensorStructInfo label_sinfo = input_sinfo[1]; @@ -905,7 +906,7 @@ StructInfo InferStructInfoCrossEntropy(const Call& call, const BlockBuilder& ctx DataType dtype = InferBinaryArithOpOutDtype(call, ctx, pred_sinfo, label_sinfo); // infer vdevice - Optional vdevice = InferBinaryArithOpOutVDevice(call, ctx, pred_sinfo, label_sinfo); + ffi::Optional vdevice = InferBinaryArithOpOutVDevice(call, ctx, pred_sinfo, label_sinfo); // infer ndim if (!pred_sinfo->IsUnknownNdim() && !label_sinfo->IsUnknownNdim() && @@ -916,12 +917,12 @@ StructInfo InferStructInfoCrossEntropy(const Call& call, const BlockBuilder& ctx << pred_sinfo->ndim << " while the ndim of labels is " << label_sinfo->ndim); } - Optional> pred_shape_value; + ffi::Optional> pred_shape_value; if (pred_sinfo->shape.defined()) { pred_shape_value = GetStructInfoAs(pred_sinfo->shape.value())->values; } - Optional> label_shape_value; + ffi::Optional> label_shape_value; if (label_sinfo->shape.defined()) { label_shape_value = GetStructInfoAs(label_sinfo->shape.value())->values; } @@ -939,7 +940,7 @@ StructInfo InferStructInfoCrossEntropy(const Call& call, const BlockBuilder& ctx } } } - return TensorStructInfo(ShapeExpr(Array()), dtype, vdevice); + return TensorStructInfo(ShapeExpr(ffi::Array()), dtype, vdevice); } Expr cross_entropy_with_logits(Expr predictions, Expr labels) { @@ -961,9 +962,9 @@ TVM_REGISTER_OP("relax.nn.cross_entropy_with_logits") /* relax.nn.nll_loss */ -Expr nll_loss(Expr predictions, Expr targets, Optional weights, String reduction, +Expr nll_loss(Expr predictions, Expr targets, ffi::Optional weights, ffi::String reduction, int ignore_index) { - ObjectPtr attrs = make_object(); + ObjectPtr attrs = ffi::make_object(); ICHECK(reduction == "none" || reduction == "sum" || reduction == "mean") << "The argument reduction of NLLLoss should be one of the following " @@ -1020,12 +1021,12 @@ StructInfo InferStructInfoNLLLoss(const Call& call, const BlockBuilder& ctx) { // infer dtype, vdevice DataType output_dtype; - Optional vdevice; + ffi::Optional vdevice; if (wgt_sinfo != nullptr) { - output_dtype = InferBinaryArithOpOutDtype(call, ctx, GetRef(pred_sinfo), - GetRef(wgt_sinfo)); - vdevice = InferBinaryArithOpOutVDevice(call, ctx, GetRef(pred_sinfo), - GetRef(wgt_sinfo)); + output_dtype = InferBinaryArithOpOutDtype(call, ctx, ffi::GetRef(pred_sinfo), + ffi::GetRef(wgt_sinfo)); + vdevice = InferBinaryArithOpOutVDevice(call, ctx, ffi::GetRef(pred_sinfo), + ffi::GetRef(wgt_sinfo)); } else { output_dtype = pred_sinfo->dtype; vdevice = pred_sinfo->vdevice; @@ -1066,11 +1067,11 @@ StructInfo InferStructInfoNLLLoss(const Call& call, const BlockBuilder& ctx) { } arith::Analyzer* analyzer = ctx->GetAnalyzer(); - Optional N; - Optional C; - Array output_shape; // N, d1, d2, ..., dk + ffi::Optional N; + ffi::Optional C; + ffi::Array output_shape; // N, d1, d2, ..., dk - Optional> pred_shape_value; + ffi::Optional> pred_shape_value; if (pred_sinfo->shape.defined()) { pred_shape_value = GetStructInfoAs(pred_sinfo->shape.value())->values; } @@ -1085,7 +1086,7 @@ 
StructInfo InferStructInfoNLLLoss(const Call& call, const BlockBuilder& ctx) { ICHECK(pred_sinfo->ndim == static_cast(pred_shape_value.value().size())); N = pred_shape_value.value()[0]; C = pred_shape_value.value()[1]; - output_shape = Array(); + output_shape = ffi::Array(); output_shape.push_back(N.value()); for (size_t i = 2; i < pred_shape_value.value().size(); ++i) { output_shape.push_back(pred_shape_value.value()[i]); @@ -1093,7 +1094,7 @@ StructInfo InferStructInfoNLLLoss(const Call& call, const BlockBuilder& ctx) { } } - Optional> tgt_shape_value; + ffi::Optional> tgt_shape_value; if (tgt_sinfo->shape.defined()) { tgt_shape_value = GetStructInfoAs(tgt_sinfo->shape.value())->values; } @@ -1148,7 +1149,7 @@ StructInfo InferStructInfoNLLLoss(const Call& call, const BlockBuilder& ctx) { } if (wgt_sinfo != nullptr) { - Optional> wgt_shape_value; + ffi::Optional> wgt_shape_value; if (wgt_sinfo->shape.defined()) { wgt_shape_value = GetStructInfoAs(wgt_sinfo->shape.value())->values; } @@ -1166,7 +1167,7 @@ StructInfo InferStructInfoNLLLoss(const Call& call, const BlockBuilder& ctx) { } const auto* attrs = call->attrs.as(); - String reduction = attrs->reduction; + ffi::String reduction = attrs->reduction; if (reduction == "none") { // () or (N,) or (N, d1, d2, ..., dk) @@ -1178,7 +1179,7 @@ StructInfo InferStructInfoNLLLoss(const Call& call, const BlockBuilder& ctx) { } } else { // sum or mean. output is scalar - return TensorStructInfo(/*shape=*/ShapeExpr(Array()), output_dtype, vdevice); + return TensorStructInfo(/*shape=*/ShapeExpr(ffi::Array()), output_dtype, vdevice); } } @@ -1187,7 +1188,7 @@ TVM_REGISTER_OP("relax.nn.nll_loss") .set_num_inputs(3) .add_argument("predictions", "Tensor", "The prediction tensor.") .add_argument("targets", "Tensor", "The target tensor.") - .add_argument("weights", "Optional", "The weight of each target values.") + .add_argument("weights", "ffi::Optional", "The weight of each target values.") .set_attr("FInferStructInfo", InferStructInfoNLLLoss) .set_attr("FPurity", Bool(true)); diff --git a/src/relax/op/nn/nn.h b/src/relax/op/nn/nn.h index 39f8c2d73800..c2f4aad2f8a4 100644 --- a/src/relax/op/nn/nn.h +++ b/src/relax/op/nn/nn.h @@ -83,19 +83,19 @@ Expr batch_norm(Expr data, Expr gamma, Expr beta, Expr moving_mean, Expr moving_ int axis, double epsilon, bool center, bool scale, double momentum, bool training); /*! \brief Compute layer normalization. */ -Expr layer_norm(Expr data, Expr gamma, Expr beta, Array axes, double epsilon, bool center, - bool scale); +Expr layer_norm(Expr data, Expr gamma, Expr beta, ffi::Array axes, double epsilon, + bool center, bool scale); /*! \brief Compute group normalization. */ Expr group_norm(Expr data, Expr gamma, Expr beta, int num_groups, int channel_axis, - Array axes, double epsilon, bool center, bool scale); + ffi::Array axes, double epsilon, bool center, bool scale); /*! \brief Compute instance normalization. */ -Expr instance_norm(Expr data, Expr gamma, Expr beta, int channel_axis, Array axes, +Expr instance_norm(Expr data, Expr gamma, Expr beta, int channel_axis, ffi::Array axes, double epsilon, bool center, bool scale); /*! \brief Compute root mean square normalization. */ -Expr rms_norm(Expr data, Expr weight, Array axes, double epsilon); +Expr rms_norm(Expr data, Expr weight, ffi::Array axes, double epsilon); /*! * \brief Applies the dropout operation to the input tensor. @@ -111,7 +111,7 @@ Expr dropout(Expr data, double rate); Expr cross_entropy_with_logits(Expr predictions, Expr labels); /*! 
\brief Negative log likelihood loss. */ -Expr nll_loss(Expr predictions, Expr targets, Optional weights, String reduction, +Expr nll_loss(Expr predictions, Expr targets, ffi::Optional weights, ffi::String reduction, int ignore_index); } // namespace relax diff --git a/src/relax/op/nn/pooling.cc b/src/relax/op/nn/pooling.cc index 6a12a60a4ee9..fe134a76bb1a 100644 --- a/src/relax/op/nn/pooling.cc +++ b/src/relax/op/nn/pooling.cc @@ -38,9 +38,10 @@ TVM_FFI_STATIC_INIT_BLOCK({ /* relax.nn.max_pool1d */ -Expr MakePool1d(String op_name, Expr data, Array pool_size, Array strides, - Array padding, Array dilation, bool ceil_mode, - bool count_include_pad, String layout, Optional out_layout) { +Expr MakePool1d(ffi::String op_name, Expr data, ffi::Array pool_size, + ffi::Array strides, ffi::Array padding, ffi::Array dilation, + bool ceil_mode, bool count_include_pad, ffi::String layout, + ffi::Optional out_layout) { padding = GetCompletePadding1D(std::move(padding)); CHECK_EQ(pool_size.size(), 1) @@ -52,7 +53,7 @@ Expr MakePool1d(String op_name, Expr data, Array pool_size, Array(); + auto attrs = ffi::make_object(); attrs->pool_size = ConvertIntImmToInt64(pool_size); attrs->strides = ConvertIntImmToInt64(strides); attrs->padding = ConvertIntImmToInt64(padding); @@ -65,9 +66,9 @@ Expr MakePool1d(String op_name, Expr data, Array pool_size, Array pool_size, Array strides, Array padding, - Array dilation, bool ceil_mode, bool count_include_pad, String layout, - Optional out_layout) { +Expr max_pool1d(Expr data, ffi::Array pool_size, ffi::Array strides, + ffi::Array padding, ffi::Array dilation, bool ceil_mode, + bool count_include_pad, ffi::String layout, ffi::Optional out_layout) { return MakePool1d("relax.nn.max_pool1d", data, pool_size, strides, padding, dilation, ceil_mode, count_include_pad, layout, out_layout); } @@ -88,13 +89,13 @@ StructInfo InferStructInfoPool1D(const Call& call, const BlockBuilder& ctx) { /*tgt_layout=*/"NCW", /*tensor_name=*/"output"); - Optional data_shape = + ffi::Optional data_shape = CheckNdimPerLayoutAndGetShape(call, ctx, data_sinfo, data_layout); if (!data_shape.defined()) { return TensorStructInfo(data_sinfo->dtype, out_layout.ndim(), data_sinfo->vdevice); } - Array data_NCW_shape = data2NCW.ForwardShape(data_shape.value()->values); + ffi::Array data_NCW_shape = data2NCW.ForwardShape(data_shape.value()->values); PrimExpr input_w = data_NCW_shape[2]; PrimExpr kernel_w = attrs->pool_size[0]; @@ -112,13 +113,13 @@ StructInfo InferStructInfoPool1D(const Call& call, const BlockBuilder& ctx) { } out_NCW_shape[2] = analyzer->Simplify(floordiv(numerator_w, attrs->strides[0]) + 1); - Array out_shape = out2NCW.BackwardShape(out_NCW_shape); + ffi::Array out_shape = out2NCW.BackwardShape(out_NCW_shape); return TensorStructInfo(ShapeExpr(out_shape), data_sinfo->dtype, data_sinfo->vdevice); } -InferLayoutOutput InferLayoutPool1d(const Call& call, - const Map>& desired_layouts, - const VarLayoutMap& var_layout_map) { +InferLayoutOutput InferLayoutPool1d( + const Call& call, const ffi::Map>& desired_layouts, + const VarLayoutMap& var_layout_map) { ICHECK(NoDesiredLayout(call, desired_layouts)); const auto* tensor_sinfo = GetStructInfoAs(call); ICHECK(tensor_sinfo != nullptr) << "Invalid Call"; @@ -127,7 +128,7 @@ InferLayoutOutput InferLayoutPool1d(const Call& call, ICHECK(attrs) << "Invalid Call"; LayoutDecision layout = GetLayoutDecision(var_layout_map, call->args[0]); - ObjectPtr new_attrs = make_object(*attrs); + ObjectPtr new_attrs = ffi::make_object(*attrs); new_attrs->layout = 
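// Both the data layout and the out layout are permuted with TransposeLike
// against the rank-3 initial layout, keeping the attrs consistent with the
// layout decision made for the input.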
TransposeLike(attrs->layout, InitialLayout(3), layout->layout).name(); new_attrs->out_layout = TransposeLike(attrs->out_layout, InitialLayout(3), layout->layout).name(); return InferLayoutOutput({layout}, {layout}, Attrs(new_attrs)); @@ -144,9 +145,10 @@ TVM_REGISTER_OP("relax.nn.max_pool1d") /* relax.nn.max_pool2d */ -Expr MakePool2d(String op_name, Expr data, Array pool_size, Array strides, - Array padding, Array dilation, bool ceil_mode, - bool count_include_pad, String layout, Optional out_layout) { +Expr MakePool2d(ffi::String op_name, Expr data, ffi::Array pool_size, + ffi::Array strides, ffi::Array padding, ffi::Array dilation, + bool ceil_mode, bool count_include_pad, ffi::String layout, + ffi::Optional out_layout) { padding = GetCompletePadding2D(std::move(padding)); if (pool_size.size() == 1) { pool_size.push_back(pool_size[0]); @@ -167,7 +169,7 @@ Expr MakePool2d(String op_name, Expr data, Array pool_size, Array(); + auto attrs = ffi::make_object(); attrs->pool_size = ConvertIntImmToInt64(pool_size); attrs->strides = ConvertIntImmToInt64(strides); attrs->padding = ConvertIntImmToInt64(padding); @@ -180,9 +182,9 @@ Expr MakePool2d(String op_name, Expr data, Array pool_size, Array pool_size, Array strides, Array padding, - Array dilation, bool ceil_mode, bool count_include_pad, String layout, - Optional out_layout) { +Expr max_pool2d(Expr data, ffi::Array pool_size, ffi::Array strides, + ffi::Array padding, ffi::Array dilation, bool ceil_mode, + bool count_include_pad, ffi::String layout, ffi::Optional out_layout) { return MakePool2d("relax.nn.max_pool2d", data, pool_size, strides, padding, dilation, ceil_mode, count_include_pad, layout, out_layout); } @@ -203,13 +205,13 @@ StructInfo InferStructInfoPool2D(const Call& call, const BlockBuilder& ctx) { /*tgt_layout=*/"NCHW", /*tensor_name=*/"output"); - Optional data_shape = + ffi::Optional data_shape = CheckNdimPerLayoutAndGetShape(call, ctx, data_sinfo, data_layout); if (!data_shape.defined()) { return TensorStructInfo(data_sinfo->dtype, out_layout.ndim(), data_sinfo->vdevice); } - Array data_NCHW_shape = data2NCHW.ForwardShape(data_shape.value()->values); + ffi::Array data_NCHW_shape = data2NCHW.ForwardShape(data_shape.value()->values); PrimExpr input_h = data_NCHW_shape[2]; PrimExpr input_w = data_NCHW_shape[3]; @@ -233,13 +235,13 @@ StructInfo InferStructInfoPool2D(const Call& call, const BlockBuilder& ctx) { out_NCHW_shape[2] = analyzer->Simplify(floordiv(numerator_h, attrs->strides[0]) + 1); out_NCHW_shape[3] = analyzer->Simplify(floordiv(numerator_w, attrs->strides[1]) + 1); - Array out_shape = out2NCHW.BackwardShape(out_NCHW_shape); + ffi::Array out_shape = out2NCHW.BackwardShape(out_NCHW_shape); return TensorStructInfo(ShapeExpr(out_shape), data_sinfo->dtype, data_sinfo->vdevice); } -InferLayoutOutput InferLayoutPool2d(const Call& call, - const Map>& desired_layouts, - const VarLayoutMap& var_layout_map) { +InferLayoutOutput InferLayoutPool2d( + const Call& call, const ffi::Map>& desired_layouts, + const VarLayoutMap& var_layout_map) { ICHECK(NoDesiredLayout(call, desired_layouts)); const auto* tensor_sinfo = GetStructInfoAs(call); ICHECK(tensor_sinfo != nullptr) << "Invalid Call"; @@ -248,14 +250,15 @@ InferLayoutOutput InferLayoutPool2d(const Call& call, ICHECK(attrs) << "Invalid Call"; LayoutDecision layout = GetLayoutDecision(var_layout_map, call->args[0]); - ObjectPtr new_attrs = make_object(*attrs); + ObjectPtr new_attrs = ffi::make_object(*attrs); if (layout->layout.ndim() != layout->layout.ndim_primal()) { 
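// ndim() != ndim_primal() means the decided layout carries packed
// sub-dimensions (e.g. NCHW16c); the attrs are rewritten to the matching
// sub-layout only when CanProveLayoutTransform succeeds on a static shape.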
tir::Layout in_layout(attrs->layout, DataType::Int(64)); auto desired_layout = TransposeSubLayoutLike(attrs->layout, InitialLayout(4), layout->layout); auto data_si = GetStructInfo(call->args[0]); TensorStructInfo data_sinfo = data_si.as().value(); - Optional data_shape = GetRef(data_sinfo->shape.as()); + ffi::Optional data_shape = + ffi::GetRef(data_sinfo->shape.as()); if (CanProveLayoutTransform(in_layout, desired_layout, data_shape.value()->values)) { // Not handling out_layout being different from in_layout now. Any use case ? new_attrs->layout = desired_layout.name(); @@ -282,9 +285,10 @@ TVM_REGISTER_OP("relax.nn.max_pool2d") /* relax.nn.max_pool3d */ -Expr MakePool3d(String op_name, Expr data, Array pool_size, Array strides, - Array padding, Array dilation, bool ceil_mode, - bool count_include_pad, String layout, Optional out_layout) { +Expr MakePool3d(ffi::String op_name, Expr data, ffi::Array pool_size, + ffi::Array strides, ffi::Array padding, ffi::Array dilation, + bool ceil_mode, bool count_include_pad, ffi::String layout, + ffi::Optional out_layout) { padding = GetCompletePadding3D(std::move(padding)); if (pool_size.size() == 1) { pool_size.push_back(pool_size[0]); @@ -308,7 +312,7 @@ Expr MakePool3d(String op_name, Expr data, Array pool_size, Array(); + auto attrs = ffi::make_object(); attrs->pool_size = ConvertIntImmToInt64(pool_size); attrs->strides = ConvertIntImmToInt64(strides); attrs->padding = ConvertIntImmToInt64(padding); @@ -321,9 +325,9 @@ Expr MakePool3d(String op_name, Expr data, Array pool_size, Array pool_size, Array strides, Array padding, - Array dilation, bool ceil_mode, bool count_include_pad, String layout, - Optional out_layout) { +Expr max_pool3d(Expr data, ffi::Array pool_size, ffi::Array strides, + ffi::Array padding, ffi::Array dilation, bool ceil_mode, + bool count_include_pad, ffi::String layout, ffi::Optional out_layout) { return MakePool3d("relax.nn.max_pool3d", data, pool_size, strides, padding, dilation, ceil_mode, count_include_pad, layout, out_layout); } @@ -344,13 +348,13 @@ StructInfo InferStructInfoPool3D(const Call& call, const BlockBuilder& ctx) { /*tgt_layout=*/"NCDHW", /*tensor_name=*/"output"); - Optional data_shape = + ffi::Optional data_shape = CheckNdimPerLayoutAndGetShape(call, ctx, data_sinfo, data_layout); if (!data_shape.defined()) { return TensorStructInfo(data_sinfo->dtype, out_layout.ndim(), data_sinfo->vdevice); } - Array data_NCDHW_shape = data2NCDHW.ForwardShape(data_shape.value()->values); + ffi::Array data_NCDHW_shape = data2NCDHW.ForwardShape(data_shape.value()->values); PrimExpr input_d = data_NCDHW_shape[2]; PrimExpr input_h = data_NCDHW_shape[3]; @@ -380,13 +384,13 @@ StructInfo InferStructInfoPool3D(const Call& call, const BlockBuilder& ctx) { out_NCDHW_shape[3] = analyzer->Simplify(floordiv(numerator_h, attrs->strides[1]) + 1); out_NCDHW_shape[4] = analyzer->Simplify(floordiv(numerator_w, attrs->strides[2]) + 1); - Array out_shape = out2NCDHW.BackwardShape(out_NCDHW_shape); + ffi::Array out_shape = out2NCDHW.BackwardShape(out_NCDHW_shape); return TensorStructInfo(ShapeExpr(out_shape), data_sinfo->dtype, data_sinfo->vdevice); } -InferLayoutOutput InferLayoutPool3d(const Call& call, - const Map>& desired_layouts, - const VarLayoutMap& var_layout_map) { +InferLayoutOutput InferLayoutPool3d( + const Call& call, const ffi::Map>& desired_layouts, + const VarLayoutMap& var_layout_map) { ICHECK(NoDesiredLayout(call, desired_layouts)); const auto* tensor_sinfo = GetStructInfoAs(call); ICHECK(tensor_sinfo != nullptr) << 
"Invalid Call"; @@ -395,7 +399,7 @@ InferLayoutOutput InferLayoutPool3d(const Call& call, ICHECK(attrs) << "Invalid Call"; LayoutDecision layout = GetLayoutDecision(var_layout_map, call->args[0]); - ObjectPtr new_attrs = make_object(*attrs); + ObjectPtr new_attrs = ffi::make_object(*attrs); new_attrs->layout = TransposeLike(attrs->layout, InitialLayout(5), layout->layout).name(); new_attrs->out_layout = TransposeLike(attrs->out_layout, InitialLayout(5), layout->layout).name(); return InferLayoutOutput({layout}, {layout}, Attrs(new_attrs)); @@ -411,9 +415,9 @@ TVM_REGISTER_OP("relax.nn.max_pool3d") .set_attr("FPurity", Bool(true)); /* relax.nn.avg_pool1d */ -Expr avg_pool1d(Expr data, Array pool_size, Array strides, Array padding, - Array dilation, bool ceil_mode, bool count_include_pad, String layout, - Optional out_layout) { +Expr avg_pool1d(Expr data, ffi::Array pool_size, ffi::Array strides, + ffi::Array padding, ffi::Array dilation, bool ceil_mode, + bool count_include_pad, ffi::String layout, ffi::Optional out_layout) { return MakePool1d("relax.nn.avg_pool1d", data, pool_size, strides, padding, dilation, ceil_mode, count_include_pad, layout, out_layout); } @@ -433,9 +437,9 @@ TVM_REGISTER_OP("relax.nn.avg_pool1d") .set_attr("FPurity", Bool(true)); /* relax.nn.avg_pool2d */ -Expr avg_pool2d(Expr data, Array pool_size, Array strides, Array padding, - Array dilation, bool ceil_mode, bool count_include_pad, String layout, - Optional out_layout) { +Expr avg_pool2d(Expr data, ffi::Array pool_size, ffi::Array strides, + ffi::Array padding, ffi::Array dilation, bool ceil_mode, + bool count_include_pad, ffi::String layout, ffi::Optional out_layout) { return MakePool2d("relax.nn.avg_pool2d", data, pool_size, strides, padding, dilation, ceil_mode, count_include_pad, layout, out_layout); } @@ -455,9 +459,9 @@ TVM_REGISTER_OP("relax.nn.avg_pool2d") .set_attr("FPurity", Bool(true)); /* relax.nn.avg_pool3d */ -Expr avg_pool3d(Expr data, Array pool_size, Array strides, Array padding, - Array dilation, bool ceil_mode, bool count_include_pad, String layout, - Optional out_layout) { +Expr avg_pool3d(Expr data, ffi::Array pool_size, ffi::Array strides, + ffi::Array padding, ffi::Array dilation, bool ceil_mode, + bool count_include_pad, ffi::String layout, ffi::Optional out_layout) { return MakePool3d("relax.nn.avg_pool3d", data, pool_size, strides, padding, dilation, ceil_mode, count_include_pad, layout, out_layout); } @@ -478,13 +482,13 @@ TVM_REGISTER_OP("relax.nn.avg_pool3d") /* relax.nn.adaptive_avg_pool1d */ -Expr adaptive_avg_pool1d(Expr data, Optional> output_size, String layout, - Optional out_layout) { - ObjectPtr attrs = make_object(); +Expr adaptive_avg_pool1d(Expr data, ffi::Optional> output_size, + ffi::String layout, ffi::Optional out_layout) { + ObjectPtr attrs = ffi::make_object(); attrs->layout = layout; attrs->out_layout = out_layout.value_or(layout); if (output_size.defined()) { - Array _output_size = output_size.value(); + ffi::Array _output_size = output_size.value(); CHECK_EQ(_output_size.size(), 1) << "The output_size length is expected to be 1. 
However, the given output_size is " << _output_size; @@ -511,7 +515,7 @@ StructInfo InferStructInfoAdaptiveAvgPool1D(const Call& call, const BlockBuilder /*tgt_layout=*/"NCW", /*tensor_name=*/"output"); - Optional data_shape = + ffi::Optional data_shape = CheckNdimPerLayoutAndGetShape(call, ctx, data_sinfo, data_layout); if (!data_shape.defined()) { if (data_sinfo->shape.defined() && attrs->out_layout == attrs->layout && @@ -522,19 +526,19 @@ StructInfo InferStructInfoAdaptiveAvgPool1D(const Call& call, const BlockBuilder } } - Array data_NCW_shape = data2NCW.ForwardShape(data_shape.value()->values); - Array out_NCW_shape(data_NCW_shape); + ffi::Array data_NCW_shape = data2NCW.ForwardShape(data_shape.value()->values); + ffi::Array out_NCW_shape(data_NCW_shape); if (attrs->output_size.defined()) { out_NCW_shape.Set(2, attrs->output_size.value()[0]); } - Array out_shape = out2NCW.BackwardShape(out_NCW_shape); + ffi::Array out_shape = out2NCW.BackwardShape(out_NCW_shape); return TensorStructInfo(ShapeExpr(out_shape), data_sinfo->dtype, data_sinfo->vdevice); } -InferLayoutOutput InferLayoutAdaptiveAvgPool1D(const Call& call, - const Map>& desired_layouts, - const VarLayoutMap& var_layout_map) { +InferLayoutOutput InferLayoutAdaptiveAvgPool1D( + const Call& call, const ffi::Map>& desired_layouts, + const VarLayoutMap& var_layout_map) { ICHECK(NoDesiredLayout(call, desired_layouts)); const auto* tensor_sinfo = GetStructInfoAs(call); ICHECK(tensor_sinfo != nullptr) << "Invalid Call"; @@ -543,7 +547,7 @@ InferLayoutOutput InferLayoutAdaptiveAvgPool1D(const Call& call, ICHECK(attrs) << "Invalid Call"; LayoutDecision layout = GetLayoutDecision(var_layout_map, call->args[0]); - ObjectPtr new_attrs = make_object(*attrs); + ObjectPtr new_attrs = ffi::make_object(*attrs); new_attrs->layout = TransposeLike(attrs->layout, InitialLayout(3), layout->layout).name(); new_attrs->out_layout = TransposeLike(attrs->out_layout, InitialLayout(3), layout->layout).name(); return InferLayoutOutput({layout}, {layout}, Attrs(new_attrs)); @@ -560,13 +564,13 @@ TVM_REGISTER_OP("relax.nn.adaptive_avg_pool1d") /* relax.nn.adaptive_avg_pool2d */ -Expr adaptive_avg_pool2d(Expr data, Optional> output_size, String layout, - Optional out_layout) { - ObjectPtr attrs = make_object(); +Expr adaptive_avg_pool2d(Expr data, ffi::Optional> output_size, + ffi::String layout, ffi::Optional out_layout) { + ObjectPtr attrs = ffi::make_object(); attrs->layout = layout; attrs->out_layout = out_layout.value_or(layout); if (output_size.defined()) { - Array _output_size = output_size.value(); + ffi::Array _output_size = output_size.value(); if (_output_size.size() == 1) { _output_size.push_back(_output_size[0]); } @@ -596,7 +600,7 @@ StructInfo InferStructInfoAdaptiveAvgPool2D(const Call& call, const BlockBuilder /*tgt_layout=*/"NCHW", /*tensor_name=*/"output"); - Optional data_shape = + ffi::Optional data_shape = CheckNdimPerLayoutAndGetShape(call, ctx, data_sinfo, data_layout); if (!data_shape.defined()) { if (data_sinfo->shape.defined() && attrs->out_layout == attrs->layout && @@ -607,20 +611,20 @@ StructInfo InferStructInfoAdaptiveAvgPool2D(const Call& call, const BlockBuilder } } - Array data_NCHW_shape = data2NCHW.ForwardShape(data_shape.value()->values); - Array out_NCHW_shape(data_NCHW_shape); + ffi::Array data_NCHW_shape = data2NCHW.ForwardShape(data_shape.value()->values); + ffi::Array out_NCHW_shape(data_NCHW_shape); if (attrs->output_size.defined()) { out_NCHW_shape.Set(2, attrs->output_size.value()[0]); out_NCHW_shape.Set(3, 
attrs->output_size.value()[1]); } - Array out_shape = out2NCHW.BackwardShape(out_NCHW_shape); + ffi::Array out_shape = out2NCHW.BackwardShape(out_NCHW_shape); return TensorStructInfo(ShapeExpr(out_shape), data_sinfo->dtype, data_sinfo->vdevice); } -InferLayoutOutput InferLayoutAdaptiveAvgPool2D(const Call& call, - const Map>& desired_layouts, - const VarLayoutMap& var_layout_map) { +InferLayoutOutput InferLayoutAdaptiveAvgPool2D( + const Call& call, const ffi::Map>& desired_layouts, + const VarLayoutMap& var_layout_map) { ICHECK(NoDesiredLayout(call, desired_layouts)); const auto* tensor_sinfo = GetStructInfoAs(call); ICHECK(tensor_sinfo != nullptr) << "Invalid Call"; @@ -629,13 +633,14 @@ InferLayoutOutput InferLayoutAdaptiveAvgPool2D(const Call& call, ICHECK(attrs) << "Invalid Call"; LayoutDecision layout = GetLayoutDecision(var_layout_map, call->args[0]); - ObjectPtr new_attrs = make_object(*attrs); + ObjectPtr new_attrs = ffi::make_object(*attrs); if (layout->layout.ndim() != layout->layout.ndim_primal()) { tir::Layout in_layout(attrs->layout, DataType::Int(64)); auto desired_layout = TransposeSubLayoutLike(attrs->layout, InitialLayout(4), layout->layout); auto data_si = GetStructInfo(call->args[0]); TensorStructInfo data_sinfo = data_si.as().value(); - Optional data_shape = GetRef(data_sinfo->shape.as()); + ffi::Optional data_shape = + ffi::GetRef(data_sinfo->shape.as()); if (CanProveLayoutTransform(in_layout, desired_layout, data_shape.value()->values)) { // Not handling out_layout being different from in_layout now. Any use case ? new_attrs->layout = desired_layout.name(); @@ -661,13 +666,13 @@ TVM_REGISTER_OP("relax.nn.adaptive_avg_pool2d") /* relax.nn.adaptive_avg_pool3d */ -Expr adaptive_avg_pool3d(Expr data, Optional> output_size, String layout, - Optional out_layout) { - ObjectPtr attrs = make_object(); +Expr adaptive_avg_pool3d(Expr data, ffi::Optional> output_size, + ffi::String layout, ffi::Optional out_layout) { + ObjectPtr attrs = ffi::make_object(); attrs->layout = layout; attrs->out_layout = out_layout.value_or(layout); if (output_size.defined()) { - Array _output_size = output_size.value(); + ffi::Array _output_size = output_size.value(); if (_output_size.size() == 1) { _output_size.push_back(_output_size[0]); } @@ -697,7 +702,7 @@ StructInfo InferStructInfoAdaptiveAvgPool3D(const Call& call, const BlockBuilder /*tgt_layout=*/"NCDHW", /*tensor_name=*/"output"); - Optional data_shape = + ffi::Optional data_shape = CheckNdimPerLayoutAndGetShape(call, ctx, data_sinfo, data_layout); if (!data_shape.defined()) { if (data_sinfo->shape.defined() && attrs->out_layout == attrs->layout && @@ -708,21 +713,21 @@ StructInfo InferStructInfoAdaptiveAvgPool3D(const Call& call, const BlockBuilder } } - Array data_NCDHW_shape = data2NCDHW.ForwardShape(data_shape.value()->values); - Array out_NCDHW_shape(data_NCDHW_shape); + ffi::Array data_NCDHW_shape = data2NCDHW.ForwardShape(data_shape.value()->values); + ffi::Array out_NCDHW_shape(data_NCDHW_shape); if (attrs->output_size.defined()) { out_NCDHW_shape.Set(2, attrs->output_size.value()[0]); out_NCDHW_shape.Set(3, attrs->output_size.value()[1]); out_NCDHW_shape.Set(4, attrs->output_size.value()[2]); } - Array out_shape = out2NCDHW.BackwardShape(out_NCDHW_shape); + ffi::Array out_shape = out2NCDHW.BackwardShape(out_NCDHW_shape); return TensorStructInfo(ShapeExpr(out_shape), data_sinfo->dtype, data_sinfo->vdevice); } -InferLayoutOutput InferLayoutAdaptiveAvgPool3D(const Call& call, - const Map>& desired_layouts, - const VarLayoutMap& 
var_layout_map) { +InferLayoutOutput InferLayoutAdaptiveAvgPool3D( + const Call& call, const ffi::Map>& desired_layouts, + const VarLayoutMap& var_layout_map) { ICHECK(NoDesiredLayout(call, desired_layouts)); const auto* tensor_sinfo = GetStructInfoAs(call); ICHECK(tensor_sinfo != nullptr) << "Invalid Call"; @@ -731,7 +736,7 @@ InferLayoutOutput InferLayoutAdaptiveAvgPool3D(const Call& call, ICHECK(attrs) << "Invalid Call"; LayoutDecision layout = GetLayoutDecision(var_layout_map, call->args[0]); - ObjectPtr new_attrs = make_object(*attrs); + ObjectPtr new_attrs = ffi::make_object(*attrs); new_attrs->layout = TransposeLike(attrs->layout, InitialLayout(5), layout->layout).name(); new_attrs->out_layout = TransposeLike(attrs->out_layout, InitialLayout(5), layout->layout).name(); return InferLayoutOutput({layout}, {layout}, Attrs(new_attrs)); diff --git a/src/relax/op/nn/pooling.h b/src/relax/op/nn/pooling.h index 7fd66f2b44c3..c5435303e82b 100644 --- a/src/relax/op/nn/pooling.h +++ b/src/relax/op/nn/pooling.h @@ -33,18 +33,18 @@ namespace tvm { namespace relax { /*! \brief 2D maximum pooling operator. */ -Expr max_pool2d(Expr data, Array pool_size, Array strides, Array padding, - Array dilation, bool ceil_mode, bool count_include_pad, String layout, - Optional out_layout); +Expr max_pool2d(Expr data, ffi::Array pool_size, ffi::Array strides, + ffi::Array padding, ffi::Array dilation, bool ceil_mode, + bool count_include_pad, ffi::String layout, ffi::Optional out_layout); /*! \brief 2D average pooling operator. */ -Expr avg_pool2d(Expr data, Array pool_size, Array strides, Array padding, - Array dilation, bool ceil_mode, bool count_include_pad, String layout, - Optional out_layout); +Expr avg_pool2d(Expr data, ffi::Array pool_size, ffi::Array strides, + ffi::Array padding, ffi::Array dilation, bool ceil_mode, + bool count_include_pad, ffi::String layout, ffi::Optional out_layout); /*! \brief 2D adaptive average pooling operator. */ -Expr adaptive_avg_pool2d(Expr data, Optional> output_size, String layout, - Optional out_layout); +Expr adaptive_avg_pool2d(Expr data, ffi::Optional> output_size, + ffi::String layout, ffi::Optional out_layout); } // namespace relax } // namespace tvm diff --git a/src/relax/op/op.cc b/src/relax/op/op.cc index 49bf9ae3d93f..ddf6a056f00a 100644 --- a/src/relax/op/op.cc +++ b/src/relax/op/op.cc @@ -57,7 +57,7 @@ bool EqualCheck(const PrimExpr& lhs, const PrimExpr& rhs) { } StructInfo ReturnVoidStructInfo(const Call& call, const BlockBuilder& ctx) { - return TupleStructInfo(Array()); + return TupleStructInfo(ffi::Array()); } StructInfo ReturnObjectStructInfo(const Call& call, const BlockBuilder& ctx) { @@ -112,16 +112,16 @@ StructInfo InferStructInfoCallPurePacked(const Call& call, const BlockBuilder& c TVM_REGISTER_OP("relax.call_pure_packed") .set_num_inputs(-1) - .add_argument("args", "Array", + .add_argument("args", "ffi::Array", "The first argument is the function being called. 
The rest are the " "arguments to that function.") .set_attr("FInferStructInfo", InferStructInfoCallPurePacked) .set_attr("FPurity", Bool(true)); -Expr MakeCallPurePacked(const Expr& callee, Array args, const Attrs& attrs, - Array sinfo_args) { +Expr MakeCallPurePacked(const Expr& callee, ffi::Array args, const Attrs& attrs, + ffi::Array sinfo_args) { static const Op& op = Op::Get("relax.call_pure_packed"); - Array call_args = {callee}; + ffi::Array call_args = {callee}; for (auto arg : args) { call_args.push_back(arg); } @@ -227,7 +227,7 @@ StructInfo InferStructInfoCallInplacePacked(const Call& call, const BlockBuilder TVM_REGISTER_OP("relax.call_inplace_packed") .set_num_inputs(-1) .set_attrs_type() - .add_argument("args", "Array", + .add_argument("args", "ffi::Array", "The first argument is the function being called. The rest are the " "arguments to that function.") .set_attr("FInferStructInfo", InferStructInfoCallInplacePacked) @@ -237,13 +237,13 @@ TVM_REGISTER_OP("relax.call_inplace_packed") // side effects other than modifying the arguments specified as "inplace" .set_attr("FPurity", Bool(true)); -Expr MakeCallInplacePacked(Expr func, Array args, Array inplace_indices, - Array sinfo_args) { - ObjectPtr attrs = make_object(); - attrs->inplace_indices = Array(inplace_indices.begin(), inplace_indices.end()); +Expr MakeCallInplacePacked(Expr func, ffi::Array args, ffi::Array inplace_indices, + ffi::Array sinfo_args) { + ObjectPtr attrs = ffi::make_object(); + attrs->inplace_indices = ffi::Array(inplace_indices.begin(), inplace_indices.end()); static const Op& op = Op::Get("relax.call_inplace_packed"); - Array call_args = {func}; + ffi::Array call_args = {func}; call_args.insert(call_args.end(), args.begin(), args.end()); return Call(op, call_args, Attrs(attrs), sinfo_args); } @@ -285,9 +285,9 @@ TVM_FFI_STATIC_INIT_BLOCK({ * \return The `arg_sinfo`, if it can be inferred from the arguments. * Otherwise, std::nullopt. */ -static Optional InferCallTIROutputStructInfoFromArguments( - StructInfo func_sinfo, StructInfo arg_sinfo, Optional packed_ints_sinfo, - Optional> opt_inplace_indices) { +static ffi::Optional InferCallTIROutputStructInfoFromArguments( + StructInfo func_sinfo, StructInfo arg_sinfo, ffi::Optional packed_ints_sinfo, + ffi::Optional> opt_inplace_indices) { auto opt_callee_sinfo = func_sinfo.as(); CHECK(opt_callee_sinfo) << "TypeError: " << "The first argument to `R.call_tir` must be a function, " @@ -368,16 +368,16 @@ static Optional InferCallTIROutputStructInfoFromArguments( // arguments are used. 
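// A synthetic FuncStructInfo is assembled next: the callee's parameters are
// split into leading input arguments, inferred output parameters, and
// trailing shape integers, and the output slice is moved into the dummy
// return type so ordinary struct-info derivation can recover it.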
auto dummy_callee_sinfo = [&]() -> FuncStructInfo { - Array dummy_params(callee_params.begin(), - callee_params.begin() + num_input_arguments); + ffi::Array dummy_params(callee_params.begin(), + callee_params.begin() + num_input_arguments); for (size_t i = callee_params.size() - num_trailing_int_arguments; i < callee_params.size(); i++) { dummy_params.push_back(callee_params[i]); } - Array dummy_ret(callee_params.begin() + num_input_arguments, - callee_params.end() - num_trailing_int_arguments); + ffi::Array dummy_ret(callee_params.begin() + num_input_arguments, + callee_params.end() - num_trailing_int_arguments); if (opt_inplace_indices) { // For R.call_tir_inplace, the `inplace_indices` are used to @@ -405,8 +405,8 @@ static Optional InferCallTIROutputStructInfoFromArguments( return FuncStructInfo(dummy_params, dummy_out_sinfo); }(); - auto dummy_args = [&]() -> Array { - Array dummy_args = args->fields.Map( + auto dummy_args = [&]() -> ffi::Array { + ffi::Array dummy_args = args->fields.Map( [](const StructInfo& sinfo) -> Expr { return Var("dummy_leading_arg", sinfo); }); for (size_t i = 0; i < num_trailing_int_arguments; i++) { @@ -488,7 +488,7 @@ Expr NormalizeCallTIR(const BlockBuilder& ctx, Call call) { << "R.call_tir should have exactly one `sinfo_args` parameter, " << "which defines the output of the PrimFunc."; - auto unwrap_binding = [&ctx](Expr expr) -> Optional { + auto unwrap_binding = [&ctx](Expr expr) -> ffi::Optional { if (auto var = expr.as()) { if (auto bound_value = ctx->LookupBinding(var.value())) { return bound_value.value(); @@ -519,7 +519,7 @@ Expr NormalizeCallTIR(const BlockBuilder& ctx, Call call) { // and we don't know the value bound to that variable. For // example, if a relax function accepted a tuple as an parameter, // then provided that same tuple as an argument to call_tir. - Array tuple_elements; + ffi::Array tuple_elements; size_t num_fields = Downcast(arg_tuple->struct_info_)->fields.size(); for (size_t i = 0; i < num_fields; i++) { tuple_elements.push_back(TupleGetItem(arg_tuple, i)); @@ -546,7 +546,7 @@ void ValidateCallTIR(Call call) { auto callee = call->args[0]; Expr arg_tuple = call->args[1]; - auto packed_int_sinfo = [&]() -> Optional { + auto packed_int_sinfo = [&]() -> ffi::Optional { if (call->args.size() <= 2) { return std::nullopt; } else { @@ -554,7 +554,7 @@ void ValidateCallTIR(Call call) { } }(); - auto opt_inplace_indices = [&]() -> Optional> { + auto opt_inplace_indices = [&]() -> ffi::Optional> { if (const auto* attrs = call->attrs.as()) { return attrs->inplace_indices; } else { @@ -586,8 +586,8 @@ TVM_REGISTER_OP("relax.call_tir") .set_attr("FValidate", ValidateCallTIR) .set_attr("FPurity", Bool(true)); -Expr MakeCallTIR(Expr func, Tuple args, Array out_sinfo_list, - Optional packed_ints) { +Expr MakeCallTIR(Expr func, Tuple args, ffi::Array out_sinfo_list, + ffi::Optional packed_ints) { for (const TensorStructInfo& sinfo : out_sinfo_list) { const auto* shape = sinfo->shape.as(); CHECK(shape != nullptr) << "out_sinfo of call_tir should have defined ShapeExpr as shape. 
" @@ -633,9 +633,9 @@ TVM_REGISTER_OP("relax.call_tir_with_grad") .set_attr("FValidate", ValidateCallTIR) .set_attr("FPurity", Bool(true)); -Expr MakeCallTIRWithGrad(Expr func, Tuple args, Array out_sinfo_list, - String te_grad_name, Map te_grad_kwargs, - Optional packed_ints) { +Expr MakeCallTIRWithGrad(Expr func, Tuple args, ffi::Array out_sinfo_list, + ffi::String te_grad_name, ffi::Map te_grad_kwargs, + ffi::Optional packed_ints) { for (const TensorStructInfo& sinfo : out_sinfo_list) { const auto* shape = sinfo->shape.as(); CHECK(shape != nullptr) @@ -651,7 +651,7 @@ Expr MakeCallTIRWithGrad(Expr func, Tuple args, Array out_sinf out_sinfo = TupleStructInfo({out_sinfo_list.begin(), out_sinfo_list.end()}); } - ObjectPtr attrs = make_object(); + ObjectPtr attrs = ffi::make_object(); attrs->te_grad_name = te_grad_name; attrs->te_grad_kwargs = te_grad_kwargs; @@ -679,7 +679,7 @@ Expr NormalizeCallTIRInPlace(const BlockBuilder& ctx, Call call) { // may result in an error if performed before normalization. call = Downcast(NormalizeCallTIR(ctx, std::move(call))); - Array sinfo_outputs = [&]() -> Array { + ffi::Array sinfo_outputs = [&]() -> ffi::Array { auto out_sinfo = call->sinfo_args[0]; if (auto* tuple_output = out_sinfo.as()) { return tuple_output->fields; @@ -778,8 +778,9 @@ TVM_REGISTER_OP("relax.call_tir_inplace") // arguments will no longer be live) .set_attr("FPurity", Bool(true)); -Expr MakeCallTIRInplace(Expr func, Tuple args, Array inplace_indices, - Array out_sinfo_list, Optional packed_ints) { +Expr MakeCallTIRInplace(Expr func, Tuple args, ffi::Array inplace_indices, + ffi::Array out_sinfo_list, + ffi::Optional packed_ints) { for (const TensorStructInfo& sinfo : out_sinfo_list) { const auto* shape = sinfo->shape.as(); CHECK(shape != nullptr) << "out_sinfo of call_tir should have defined ShapeExpr as shape. " @@ -787,8 +788,8 @@ Expr MakeCallTIRInplace(Expr func, Tuple args, Array inplace_indices, << sinfo; } - ObjectPtr attrs = make_object(); - attrs->inplace_indices = Array(inplace_indices.begin(), inplace_indices.end()); + ObjectPtr attrs = ffi::make_object(); + attrs->inplace_indices = ffi::Array(inplace_indices.begin(), inplace_indices.end()); StructInfo out_sinfo{nullptr}; if (out_sinfo_list.size() == 1) { @@ -832,7 +833,7 @@ TVM_REGISTER_OP("relax.call_dps_packed") // little reason to use DPS with an impure op .set_attr("FPurity", Bool(true)); -Expr MakeCallDPSPacked(Expr func, Tuple args, Array out_sinfo_list) { +Expr MakeCallDPSPacked(Expr func, Tuple args, ffi::Array out_sinfo_list) { for (const TensorStructInfo& sinfo : out_sinfo_list) { const auto* shape = sinfo->shape.as(); CHECK(shape != nullptr) @@ -861,7 +862,7 @@ TVM_FFI_STATIC_INIT_BLOCK({ StructInfo InferStructInfoCallBuiltinWithCtx(const Call& call, const BlockBuilder& ctx) { if (call->sinfo_args.size() == 0) { // by default return void. 
- return TupleStructInfo(Array()); + return TupleStructInfo(ffi::Array()); } else { ICHECK_EQ(call->sinfo_args.size(), 1); return call->sinfo_args[0]; @@ -876,7 +877,7 @@ TVM_REGISTER_OP("relax.call_builtin_with_ctx") // Most builtins are pure, but some are not, like `vm.builtin.attention_kv_cache_append` .set_attr("FPurity", Bool(false)); -Expr MakeCallBuiltinWithCtx(Expr func, Tuple args, Array sinfo_args) { +Expr MakeCallBuiltinWithCtx(Expr func, Tuple args, ffi::Array sinfo_args) { static const Op& op = Op::Get("relax.call_builtin_with_ctx"); return Call(op, {func, args}, Attrs(), sinfo_args); } @@ -905,15 +906,15 @@ TVM_FFI_STATIC_INIT_BLOCK({ TVM_REGISTER_OP("relax.print") .set_num_inputs(-1) - .add_argument("vals", "Array", + .add_argument("vals", "ffi::Array", "The first value is Python-style format string to use to print. The others " "are values to print") .set_attr("FInferStructInfo", ReturnVoidStructInfo) .set_attr("FCallPacked", "relax.run.print") .set_attr("FPurity", Bool(false)); -Expr MakePrint(Array vals, StringImm format) { - Array params; +Expr MakePrint(ffi::Array vals, StringImm format) { + ffi::Array params; params.push_back(format); for (const auto val : vals) { params.push_back(val); @@ -950,7 +951,7 @@ StructInfo InferAssertStructInfo(const Call& call, const BlockBuilder& ctx) { TVM_REGISTER_OP("relax.assert_op") .set_num_inputs(-1) - .add_argument("vals", "Array", + .add_argument("vals", "ffi::Array", "The first value is used as the assertion condition. The second value is " "Python-style format string to use for displaying an error message, if the " "assert fails. The others are used as format arguments if there is an error.") @@ -958,9 +959,9 @@ TVM_REGISTER_OP("relax.assert_op") .set_attr("FCallPacked", "relax.run.assert_op") .set_attr("FPurity", Bool(false)); -Expr MakeAssertOp(Expr condition, Array vals, StringImm format) { +Expr MakeAssertOp(Expr condition, ffi::Array vals, StringImm format) { static const Op& op = Op::Get("relax.assert_op"); - Array args = {condition}; + ffi::Array args = {condition}; args.push_back(format); for (auto val : vals) { args.push_back(val); @@ -1012,7 +1013,7 @@ TVM_REGISTER_OP("relax.invoke_closure") // Not all closures are pure. 
Use invoke_pure_closure for specifying purity .set_attr("FPurity", Bool(false)); -Expr InvokeClosure(Expr closure, Tuple args, Array sinfo_args) { +Expr InvokeClosure(Expr closure, Tuple args, ffi::Array sinfo_args) { static const Op& op = Op::Get("relax.invoke_closure"); return Call(op, {closure, args}, {}, sinfo_args); } @@ -1031,7 +1032,7 @@ TVM_REGISTER_OP("relax.invoke_pure_closure") .set_attr("FInferStructInfo", InferStructInfoInvokeClosure) .set_attr("FPurity", Bool(true)); -Expr InvokePureClosure(Expr closure, Tuple args, Array sinfo_args) { +Expr InvokePureClosure(Expr closure, Tuple args, ffi::Array sinfo_args) { static const Op& op = Op::Get("relax.invoke_pure_closure"); return Call(op, {closure, args}, {}, sinfo_args); } @@ -1132,7 +1133,7 @@ StructInfo InferStructInfoAllocateTensor(const Call& call, const BlockBuilder& c << "must be DataTypeImm, but got " << call->args[1]->GetTypeKey(); DataType out_dtype; if (const auto* dtype_node = call->args[1].as()) { - const DataTypeImm dtype_imm = GetRef(dtype_node); + const DataTypeImm dtype_imm = ffi::GetRef(dtype_node); out_dtype = dtype_imm->value; } return TensorStructInfo(call->args[0], out_dtype); @@ -1198,7 +1199,7 @@ StructInfo InferStructInfoMemAllocTensor(const Call& call, const BlockBuilder& c << "must be a Expr of ShapeStructInfo, but got " << call->args[1]->GetTypeKey(); DataType out_dtype; if (const auto* dtype_node = call->args[3].as()) { - const DataTypeImm dtype_imm = GetRef(dtype_node); + const DataTypeImm dtype_imm = ffi::GetRef(dtype_node); out_dtype = dtype_imm->value; } return TensorStructInfo(call->args[2], out_dtype); @@ -1295,11 +1296,11 @@ TVM_FFI_STATIC_INIT_BLOCK({ StructInfo InferStructInfoVMAllocTensor(const Call& call, const BlockBuilder& ctx) { DataType out_dtype; if (const auto* dtype_node = call->args[3].as()) { - const DataTypeImm dtype_imm = GetRef(dtype_node); + const DataTypeImm dtype_imm = ffi::GetRef(dtype_node); out_dtype = dtype_imm->value; } if (const auto* output_shape = call->args[2].as()) { - return TensorStructInfo(GetRef(output_shape), out_dtype); + return TensorStructInfo(ffi::GetRef(output_shape), out_dtype); } else if (const auto* shape_sinfo = GetStructInfoAs(call->args[2])) { if (shape_sinfo->values.defined()) { return TensorStructInfo(ShapeExpr(shape_sinfo->values.value()), out_dtype); @@ -1415,7 +1416,7 @@ TVM_REGISTER_OP("relax.to_vdevice") Expr MakeToVDevice(Expr data, VDevice dst_vdev) { static const Op& op = Op::Get("relax.to_vdevice"); - ObjectPtr attrs = make_object(); + ObjectPtr attrs = ffi::make_object(); attrs->dst_vdevice = dst_vdev; return Call(op, {data}, Attrs(attrs), {}); } @@ -1443,7 +1444,7 @@ TVM_REGISTER_OP("relax.hint_on_device") Expr MakeHintOnDevice(Expr data, Device device) { static const Op& op = Op::Get("relax.hint_on_device"); - ObjectPtr attrs = make_object(); + ObjectPtr attrs = ffi::make_object(); attrs->device_type = static_cast(device.device_type); attrs->index = device.device_id; return Call(op, {data}, Attrs(attrs), {}); diff --git a/src/relax/op/op_common.cc b/src/relax/op/op_common.cc index f439a345eb19..5b9ed1e5f529 100644 --- a/src/relax/op/op_common.cc +++ b/src/relax/op/op_common.cc @@ -24,9 +24,9 @@ namespace tvm { namespace relax { -Array GetCallArgs(const Call& call) { +ffi::Array GetCallArgs(const Call& call) { static const Op& call_tir_op = Op::Get("relax.call_tir"); - Array args; + ffi::Array args; if (call->op.same_as(call_tir_op)) { args = Downcast(call->args[1])->fields; } else { @@ -70,19 +70,19 @@ TensorStructInfo 
GetInputTensorStructInfo(const Call& call, size_t i_arg, const } } -Array GetInputTensorStructInfo(const Call& call, const BlockBuilder& ctx) { +ffi::Array GetInputTensorStructInfo(const Call& call, const BlockBuilder& ctx) { CheckNumArguments(call, ctx); Op op = Downcast(call->op); - Array input_tensor_sinfo; + ffi::Array input_tensor_sinfo; for (size_t i = 0; i < call->args.size(); ++i) { input_tensor_sinfo.push_back(GetInputTensorStructInfo(call, i, ctx)); } return input_tensor_sinfo; } -Array GetTensorStructInfoFromTuple(const Call& call, const BlockBuilder& ctx, - const Expr& tup) { +ffi::Array GetTensorStructInfoFromTuple(const Call& call, const BlockBuilder& ctx, + const Expr& tup) { const auto* tuple_sinfo = GetStructInfoAs(tup); if (tuple_sinfo == nullptr) { ctx->ReportFatal(Diagnostic::Error(call) @@ -91,7 +91,7 @@ Array GetTensorStructInfoFromTuple(const Call& call, const Blo << tup->struct_info_->GetTypeKey()); } - Array tensor_sinfo; + ffi::Array tensor_sinfo; tensor_sinfo.reserve(tuple_sinfo->fields.size()); for (StructInfo field_sinfo : tuple_sinfo->fields) { const auto* field_tensor_sinfo = field_sinfo.as(); @@ -101,14 +101,14 @@ Array GetTensorStructInfoFromTuple(const Call& call, const Blo << call->op << " expects the input to be a Tuple of Tensors. However, the given input is " << tup->struct_info_); } - tensor_sinfo.push_back(GetRef(field_tensor_sinfo)); + tensor_sinfo.push_back(ffi::GetRef(field_tensor_sinfo)); } return tensor_sinfo; } -Optional> InferBinaryBroadcastShape(const Call& call, const BlockBuilder& ctx, - const Array& x1_shape, - const Array& x2_shape) { +ffi::Optional> InferBinaryBroadcastShape( + const Call& call, const BlockBuilder& ctx, const ffi::Array& x1_shape, + const ffi::Array& x2_shape) { arith::Analyzer* analyzer = ctx->GetAnalyzer(); int x1_ndim = x1_shape.size(); int x2_ndim = x2_shape.size(); @@ -143,11 +143,11 @@ Optional> InferBinaryBroadcastShape(const Call& call, const Bloc for (; i <= max_ndim; ++i) { output_shape.push_back(longer_shape[max_ndim - i]); } - return Array(output_shape.rbegin(), output_shape.rend()); + return ffi::Array(output_shape.rbegin(), output_shape.rend()); } std::vector NormalizeAxes(const Call& call, const BlockBuilder& ctx, int ndim, - const Array& axes) { + const ffi::Array& axes) { ICHECK_NE(ndim, kUnknownNDim) << "The ndim is required to be known for this function."; std::vector appeared_dims_set; std::vector axes_non_neg; @@ -177,21 +177,21 @@ std::vector NormalizeAxes(const Call& call, const BlockBuilder& ctx, int nd return axes_non_neg; } -InferLayoutOutput InferLayoutUnaryEwise(const Call& call, - const Map>& desired_layouts, - const VarLayoutMap& var_layout_map) { +InferLayoutOutput InferLayoutUnaryEwise( + const Call& call, const ffi::Map>& desired_layouts, + const VarLayoutMap& var_layout_map) { ICHECK(NoDesiredLayout(call, desired_layouts)); LayoutDecision layout = GetLayoutDecision(var_layout_map, call->args[0]); return InferLayoutOutput({layout}, {layout}, Attrs(call->attrs)); } bool CanProveLayoutTransform(const Layout& input_layout, const Layout& desired_layout, - Array shape) { + ffi::Array shape) { bool can_prove = true; try { tir::BijectiveLayout todesired(input_layout, desired_layout); - Array desired_shape = todesired.ForwardShape(shape); - Array back_shape = todesired.BackwardShape(desired_shape); + ffi::Array desired_shape = todesired.ForwardShape(shape); + ffi::Array back_shape = todesired.BackwardShape(desired_shape); arith::Analyzer analyzer; for (size_t i = 0; i < shape.size(); ++i) { if 
(tir::is_const_int(shape[i])) { diff --git a/src/relax/op/op_common.h b/src/relax/op/op_common.h index 4da8b18fcb13..b8cc8a64efe0 100644 --- a/src/relax/op/op_common.h +++ b/src/relax/op/op_common.h @@ -71,7 +71,7 @@ TensorStructInfo GetInputTensorStructInfo(const Call& call, size_t i_arg, const * \note This function require every input to be Tensor. The number of call arguments is required * to match the number of inputs of the op being called. */ -Array GetInputTensorStructInfo(const Call& call, const BlockBuilder& ctx); +ffi::Array GetInputTensorStructInfo(const Call& call, const BlockBuilder& ctx); /*! * \brief Get the tensor struct info of the unary operator input. @@ -93,8 +93,8 @@ inline TensorStructInfo GetUnaryInputTensorStructInfo(const Call& call, const Bl * \return The tensor struct infos of tuple input. * \throw Throw exception if input expression is not a tuple. */ -Array GetTensorStructInfoFromTuple(const Call& call, const BlockBuilder& ctx, - const Expr& tup); +ffi::Array GetTensorStructInfoFromTuple(const Call& call, const BlockBuilder& ctx, + const Expr& tup); namespace detail { /*! \brief Implementation helper for GetArgStructInfo */ @@ -208,7 +208,7 @@ inline StructInfo InferStructInfoUnary(const Call& call, const BlockBuilder& ctx << " requires the input tensor to have float dtype. However, the given input dtype is " << input_sinfo->dtype); } - auto output_sinfo = make_object(*input_sinfo.get()); + auto output_sinfo = ffi::make_object(*input_sinfo.get()); output_sinfo->dtype = f_compute_out_dtype(input_sinfo); return TensorStructInfo(output_sinfo); } @@ -257,9 +257,9 @@ StructInfo InferStructInfoUnaryArith(const Call& call, const BlockBuilder& ctx) * \param var_layout_map The layout of vars. * \return The inferred layout result. */ -InferLayoutOutput InferLayoutUnaryEwise(const Call& call, - const Map>& desired_layouts, - const VarLayoutMap& var_layout_map); +InferLayoutOutput InferLayoutUnaryEwise( + const Call& call, const ffi::Map>& desired_layouts, + const VarLayoutMap& var_layout_map); /*! * \brief Get the element dtype from StructInfo @@ -338,10 +338,11 @@ inline DataType InferBinaryArithOpOutDtype(const Call& call, const BlockBuilder& * \return The inferred output vdevice. * \throw Throw exception if the vdevice of two input TensorStructInfo don’t match */ -inline Optional InferBinaryArithOpOutVDevice(const Call& call, const BlockBuilder& ctx, - const StructInfo& lhs_sinfo, - const StructInfo& rhs_sinfo) { - auto get_vdevice = [&](const StructInfo& sinfo) -> Optional { +inline ffi::Optional InferBinaryArithOpOutVDevice(const Call& call, + const BlockBuilder& ctx, + const StructInfo& lhs_sinfo, + const StructInfo& rhs_sinfo) { + auto get_vdevice = [&](const StructInfo& sinfo) -> ffi::Optional { if (const auto* tensor = sinfo.as()) { return tensor->vdevice; } else { @@ -378,9 +379,10 @@ inline Optional InferBinaryArithOpOutVDevice(const Call& call, const Bl * \return The inferred output shape after broadcasting. Or `std::nullopt` if the output shape * cannot be determined due to symbolic broadcast. */ -Optional> InferBinaryBroadcastShape(const Call& call, const BlockBuilder& ctx, - const Array& x1_shape, - const Array& x2_shape); +ffi::Optional> InferBinaryBroadcastShape(const Call& call, + const BlockBuilder& ctx, + const ffi::Array& x1_shape, + const ffi::Array& x2_shape); /*! 
* \brief Convert all axes to non-negative indices, and meanwhile check if the given array of axes @@ -393,7 +395,7 @@ Optional> InferBinaryBroadcastShape(const Call& call, const Bloc * \throw Throw exception if there exists out-of-range axis index or repetitive indices. */ std::vector NormalizeAxes(const Call& call, const BlockBuilder& ctx, int ndim, - const Array& axes); + const ffi::Array& axes); /*! * \brief Convert the given axis to non-negative index. Meanwhile check if the axis is in range @@ -414,7 +416,7 @@ inline int NormalizeAxis(const Call& call, const BlockBuilder& ctx, int ndim, in * \param shape_values The given shape values. * \return The product of all the given shape values. */ -PrimExpr ComputeShapeProduct(const Array& shape_values); +PrimExpr ComputeShapeProduct(const ffi::Array& shape_values); /*! * \brief Check if the given permutation is identity permutation. @@ -428,7 +430,7 @@ bool IsIdentityPermutation(const std::vector& permutation); * \param int_imms The input IntImms to be converted. * \return The conversion result, where every IntImm has dtype int64 */ -inline Array ConvertIntImmToInt64(const Array& int_imms) { +inline ffi::Array ConvertIntImmToInt64(const ffi::Array& int_imms) { return int_imms.Map([](const IntImm& i) { return Downcast(cast(DataType::Int(64), i)); }); } @@ -442,7 +444,7 @@ inline Array ConvertIntImmToInt64(const Array& int_imms) { * \return The completed padding. * \throws Throws error if the input padding length is neither 1 or 2. */ -inline Array GetCompletePadding1D(Array padding) { +inline ffi::Array GetCompletePadding1D(ffi::Array padding) { if (padding.size() == 1) { return {padding[0], padding[0]}; } else if (padding.size() == 2) { @@ -463,7 +465,7 @@ inline Array GetCompletePadding1D(Array padding) { * \return The completed padding. * \throws Throws error if the input padding length is neither 1, 2 or 4. */ -inline Array GetCompletePadding2D(Array padding) { +inline ffi::Array GetCompletePadding2D(ffi::Array padding) { if (padding.size() == 1) { return {padding[0], padding[0], padding[0], padding[0]}; } else if (padding.size() == 2) { @@ -488,7 +490,7 @@ inline Array GetCompletePadding2D(Array padding) { * \return The completed padding. * \throws Throws error if the input padding length is neither 1, 3 or 6. */ -inline Array GetCompletePadding3D(Array padding) { +inline ffi::Array GetCompletePadding3D(ffi::Array padding) { if (padding.size() == 1) { return {padding[0], padding[0], padding[0], padding[0], padding[0], padding[0]}; } else if (padding.size() == 3) { @@ -514,11 +516,9 @@ inline Array GetCompletePadding3D(Array padding) { * \return The tensor layout and the bijective conversion in tir::Layout and tir::BijectiveLayout * accordingly. */ -inline std::pair CheckTensorLayout(const Call& call, - const BlockBuilder& ctx, - const String& tensor_layout, - const String& tgt_layout, - const String& tensor_name) { +inline std::pair CheckTensorLayout( + const Call& call, const BlockBuilder& ctx, const ffi::String& tensor_layout, + const ffi::String& tgt_layout, const ffi::String& tensor_name) { tir::Layout _tensor_layout(tensor_layout, DataType::Int(64)); tir::BijectiveLayout tensor2tgt(_tensor_layout, tir::Layout(tgt_layout, DataType::Int(64))); if (!tensor2tgt.defined()) { @@ -539,9 +539,10 @@ inline std::pair CheckTensorLayout(const Call * \param layout The layout that the given tensor is expected to have. * \return The shape of the input tensor in ShapeExpr, or `std::nullopt` if the shape is unknown. 
*/ -inline Optional CheckNdimPerLayoutAndGetShape(const Call& call, const BlockBuilder& ctx, - const TensorStructInfo& sinfo, - const tir::Layout& layout) { +inline ffi::Optional CheckNdimPerLayoutAndGetShape(const Call& call, + const BlockBuilder& ctx, + const TensorStructInfo& sinfo, + const tir::Layout& layout) { if (!sinfo->IsUnknownNdim() && sinfo->ndim != static_cast(layout.ndim())) { ctx->ReportFatal(Diagnostic::Error(call) << "In " << call->op << ", layout " << layout << " requires the input to be " @@ -549,7 +550,7 @@ inline Optional CheckNdimPerLayoutAndGetShape(const Call& call, const << sinfo->ndim); } if (const auto* shape_expr = sinfo->shape.as()) { - return GetRef(shape_expr); + return ffi::GetRef(shape_expr); } return std::nullopt; } @@ -568,7 +569,7 @@ Expr MakeAllocTensor(Expr shape, DataTypeImm dtype, PrimValue runtime_device_ind * \param call The call node * \return The arguments of the call */ -Array GetCallArgs(const Call& call); +ffi::Array GetCallArgs(const Call& call); /** * \brief Checks the given shape can be proved from the source layout to dst layout @@ -578,7 +579,7 @@ Array GetCallArgs(const Call& call); * \return true or false depending on the compatibility */ bool CanProveLayoutTransform(const Layout& input_layout, const Layout& desired_layout, - Array shape); + ffi::Array shape); } // namespace relax } // namespace tvm diff --git a/src/relax/op/tensor/binary.cc b/src/relax/op/tensor/binary.cc index 74ae8e9cbc5c..eeb4d552e787 100644 --- a/src/relax/op/tensor/binary.cc +++ b/src/relax/op/tensor/binary.cc @@ -61,7 +61,7 @@ StructInfo InferStructInfoBroadcast(const Call& call, const BlockBuilder& ctx, } // VDevice - Optional vdevice = InferBinaryArithOpOutVDevice(call, ctx, lhs_sinfo, rhs_sinfo); + ffi::Optional vdevice = InferBinaryArithOpOutVDevice(call, ctx, lhs_sinfo, rhs_sinfo); auto get_ndim = [&](const StructInfo& sinfo) -> int { if (sinfo.as()) { @@ -86,9 +86,9 @@ StructInfo InferStructInfoBroadcast(const Call& call, const BlockBuilder& ctx, // Shapes - auto get_shape = [](const StructInfo& sinfo) -> Optional> { + auto get_shape = [](const StructInfo& sinfo) -> ffi::Optional> { if (sinfo.as()) { - return Array{IntImm(DataType::Int(64), 1)}; + return ffi::Array{IntImm(DataType::Int(64), 1)}; } else if (const auto* tensor = sinfo.as()) { return tensor->GetShape(); } else { @@ -101,7 +101,7 @@ StructInfo InferStructInfoBroadcast(const Call& call, const BlockBuilder& ctx, auto lhs_shape = get_shape(lhs_sinfo); auto rhs_shape = get_shape(rhs_sinfo); if (lhs_shape && rhs_shape) { - Optional> output_shape = + ffi::Optional> output_shape = InferBinaryBroadcastShape(call, ctx, lhs_shape.value(), rhs_shape.value()); if (output_shape.defined()) { ICHECK_EQ(static_cast(output_shape.value().size()), output_ndim); @@ -109,7 +109,7 @@ StructInfo InferStructInfoBroadcast(const Call& call, const BlockBuilder& ctx, } } - auto get_shape_expr = [](const StructInfo& sinfo) -> Optional { + auto get_shape_expr = [](const StructInfo& sinfo) -> ffi::Optional { if (const auto* tensor = sinfo.as()) { return tensor->shape; } else { @@ -142,9 +142,9 @@ StructInfo InferStructInfoBroadcastCMP(const Call& call, const BlockBuilder& ctx const StructInfo& rhs_sinfo) { return DataType::Bool(); }); } -InferLayoutOutput InferLayoutBinaryEwise(const Call& call, - const Map>& desired_layouts, - const VarLayoutMap& var_layout_map) { +InferLayoutOutput InferLayoutBinaryEwise( + const Call& call, const ffi::Map>& desired_layouts, + const VarLayoutMap& var_layout_map) { 
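// Illustrative sketch (not from this patch): InferBinaryBroadcastShape, used by
// the InferStructInfoBroadcast hunk above, walks both shapes right-to-left and
// materializes the result through ffi::Array's reverse-iterator range
// constructor. A static-extent analogue as a local lambda -- int64_t stands in
// for PrimExpr, no analyzer is involved, and <vector>/<algorithm> are assumed
// available in this translation unit:
auto broadcast_demo = [](const ffi::Array<int64_t>& a, const ffi::Array<int64_t>& b)
    -> ffi::Optional<ffi::Array<int64_t>> {
  std::vector<int64_t> rev;  // output dims, collected back-to-front
  size_t na = a.size(), nb = b.size(), n = std::max(na, nb);
  for (size_t i = 1; i <= n; ++i) {
    int64_t da = i <= na ? a[na - i] : 1;  // absent dims broadcast as 1
    int64_t db = i <= nb ? b[nb - i] : 1;
    if (da != db && da != 1 && db != 1) return std::nullopt;  // incompatible extents
    rev.push_back(std::max(da, db));
  }
  return ffi::Array<int64_t>(rev.rbegin(), rev.rend());  // reverse-iterator ctor, as in the real helper
};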
ICHECK(NoDesiredLayout(call, desired_layouts)); LayoutDecision layout1 = GetLayoutDecision(var_layout_map, call->args[0]); LayoutDecision layout2 = GetLayoutDecision(var_layout_map, call->args[1]); @@ -155,8 +155,8 @@ InferLayoutOutput InferLayoutBinaryEwise(const Call& call, ICHECK(!x1_sinfo->IsUnknownNdim() && !x2_sinfo->IsUnknownNdim()) << "Unknown dim tensors should not be handled by this function"; - Optional shape1 = GetRef(x1_sinfo->shape.as()); - Optional shape2 = GetRef(x2_sinfo->shape.as()); + ffi::Optional shape1 = ffi::GetRef(x1_sinfo->shape.as()); + ffi::Optional shape2 = ffi::GetRef(x2_sinfo->shape.as()); // Lets handle sub indexing as long as primal dims are matching if (layout1->layout.ndim_primal() == layout2->layout.ndim_primal()) { if ((layout1->layout.ndim() >= layout2->layout.ndim()) && shape2.defined()) { diff --git a/src/relax/op/tensor/create.cc b/src/relax/op/tensor/create.cc index a3bec83f749d..8412fd2784b8 100644 --- a/src/relax/op/tensor/create.cc +++ b/src/relax/op/tensor/create.cc @@ -43,18 +43,19 @@ TVM_FFI_STATIC_INIT_BLOCK({ /* Initialization operators */ /* relax.full */ -Expr full(Variant> shape, Expr fill_value, Optional dtype) { +Expr full(ffi::Variant> shape, Expr fill_value, + ffi::Optional dtype) { Expr shape_in_expr{nullptr}; if (const auto* expr = shape.as()) { - shape_in_expr = GetRef(expr); + shape_in_expr = ffi::GetRef(expr); } else if (const auto* _array = shape.as()) { - shape_in_expr = ShapeExpr(GetRef>(_array)); + shape_in_expr = ShapeExpr(ffi::GetRef>(_array)); } else { LOG(FATAL) << "Full only expects the input shape to be either an Expr or an Array of PrimExpr. "; } - ObjectPtr attrs = make_object(); + ObjectPtr attrs = ffi::make_object(); attrs->dtype = dtype.value_or(DataType::Void()); static const Op& op = Op::Get("relax.full"); @@ -99,8 +100,8 @@ TVM_REGISTER_OP("relax.full") .set_attr("FPurity", Bool(true)); /* relax.full_like */ -Expr full_like(Expr x, Expr fill_value, Optional dtype) { - ObjectPtr attrs = make_object(); +Expr full_like(Expr x, Expr fill_value, ffi::Optional dtype) { + ObjectPtr attrs = ffi::make_object(); attrs->dtype = dtype.value_or(DataType::Void()); static const Op& op = Op::Get("relax.full_like"); return Call(op, {std::move(x), std::move(fill_value)}, Attrs(attrs), {}); @@ -112,7 +113,7 @@ TVM_FFI_STATIC_INIT_BLOCK({ }); StructInfo InferStructInfoFullLike(const Call& call, const BlockBuilder& ctx) { - Array input_sinfo = GetInputTensorStructInfo(call, ctx); + ffi::Array input_sinfo = GetInputTensorStructInfo(call, ctx); TensorStructInfo data_sinfo = input_sinfo[0]; TensorStructInfo fill_value_sinfo = input_sinfo[1]; if (fill_value_sinfo->ndim != 0) { @@ -125,7 +126,7 @@ StructInfo InferStructInfoFullLike(const Call& call, const BlockBuilder& ctx) { if (attrs->dtype.is_void()) { return data_sinfo; } else { - auto output_sinfo = make_object(*data_sinfo.get()); + auto output_sinfo = ffi::make_object(*data_sinfo.get()); output_sinfo->dtype = attrs->dtype; return TensorStructInfo(output_sinfo); } @@ -164,7 +165,7 @@ StructInfo InferStructInfoOnesLikeZerosLike(const Call& call, const BlockBuilder if (attrs->dtype.is_void()) { return data_sinfo; } else { - auto output_sinfo = make_object(*data_sinfo.get()); + auto output_sinfo = ffi::make_object(*data_sinfo.get()); output_sinfo->dtype = attrs->dtype; return TensorStructInfo(output_sinfo); } @@ -173,15 +174,15 @@ StructInfo InferStructInfoOnesLikeZerosLike(const Call& call, const BlockBuilder /* relax.ones & relax.ones_like */ Expr ones(Expr shape, DataType dtype) { 
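// Illustrative sketch (not from this patch): the `full` hunk above dispatches
// on an ffi::Variant by probing each alternative with as<>(). With the template
// arguments (stripped in this rendering) restored -- an editorial
// reconstruction, not verbatim source -- the dispatch reads roughly as:
//
//   Expr full(ffi::Variant<Expr, ffi::Array<PrimExpr>> shape, Expr fill_value,
//             ffi::Optional<DataType> dtype) {
//     Expr shape_in_expr{nullptr};
//     if (const auto* expr = shape.as<ExprNode>()) {
//       shape_in_expr = ffi::GetRef<Expr>(expr);            // already an Expr
//     } else if (const auto* arr = shape.as<ffi::ArrayObj>()) {
//       shape_in_expr = ShapeExpr(ffi::GetRef<ffi::Array<PrimExpr>>(arr));
//     }
//     ...
//   }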
CHECK(!dtype.is_void()) << "Ones op expects the input dtype not to be void"; - ObjectPtr attrs = make_object(); + ObjectPtr attrs = ffi::make_object(); attrs->dtype = dtype; static const Op& op = Op::Get("relax.ones"); return Call(op, {std::move(shape)}, Attrs(attrs), {}); } -Expr ones_like(Expr x, Optional dtype) { - ObjectPtr attrs = make_object(); +Expr ones_like(Expr x, ffi::Optional dtype) { + ObjectPtr attrs = ffi::make_object(); attrs->dtype = dtype.value_or(DataType::Void()); static const Op& op = Op::Get("relax.ones_like"); return Call(op, {std::move(x)}, Attrs(attrs), {}); @@ -210,15 +211,15 @@ TVM_REGISTER_OP("relax.ones_like") /* relax.zeros & relax.zeros_like */ Expr zeros(Expr shape, DataType dtype) { CHECK(!dtype.is_void()) << "Zeros op expects the input dtype not to be void"; - ObjectPtr attrs = make_object(); + ObjectPtr attrs = ffi::make_object(); attrs->dtype = dtype; static const Op& op = Op::Get("relax.zeros"); return Call(op, {std::move(shape)}, Attrs(attrs), {}); } -Expr zeros_like(Expr x, Optional dtype) { - ObjectPtr attrs = make_object(); +Expr zeros_like(Expr x, ffi::Optional dtype) { + ObjectPtr attrs = ffi::make_object(); attrs->dtype = dtype.value_or(DataType::Void()); static const Op& op = Op::Get("relax.zeros_like"); return Call(op, {std::move(x)}, Attrs(attrs), {}); @@ -246,14 +247,14 @@ TVM_REGISTER_OP("relax.zeros_like") /* relax.eye & relax.eye_like */ Expr eye(PrimValue n, PrimValue m, PrimValue k, DataType dtype) { - ObjectPtr attrs = make_object(); + ObjectPtr attrs = ffi::make_object(); attrs->dtype = dtype; static const Op& op = Op::Get("relax.eye"); return Call(op, {std::move(n), std::move(m), std::move(k)}, Attrs(attrs), {}); } -Expr eye_like(Expr x, PrimValue k, Optional dtype) { - ObjectPtr attrs = make_object(); +Expr eye_like(Expr x, PrimValue k, ffi::Optional dtype) { + ObjectPtr attrs = ffi::make_object(); attrs->dtype = dtype.value_or(DataType::Void()); static const Op& op = Op::Get("relax.eye_like"); return Call(op, {std::move(x), std::move(k)}, Attrs(attrs), {}); @@ -332,7 +333,7 @@ TVM_REGISTER_OP("relax.eye_like") /* relax.arange */ Expr arange(PrimValue start, PrimValue stop, PrimValue step, DataType dtype) { - ObjectPtr attrs = make_object(); + ObjectPtr attrs = ffi::make_object(); attrs->dtype = dtype; static const Op& op = Op::Get("relax.arange"); return Call(op, {std::move(start), std::move(stop), std::move(step)}, Attrs(attrs), {}); @@ -388,7 +389,7 @@ TVM_REGISTER_OP("relax.arange") /* relax.hamming_window */ Expr hamming_window(PrimValue window_size, PrimValue periodic, PrimValue alpha, PrimValue beta, DataType dtype) { - ObjectPtr attrs = make_object(); + ObjectPtr attrs = ffi::make_object(); attrs->dtype = dtype; static const Op& op = Op::Get("relax.hamming_window"); return Call(op, {std::move(window_size), std::move(periodic), std::move(alpha), std::move(beta)}, diff --git a/src/relax/op/tensor/create.h b/src/relax/op/tensor/create.h index f252eebf824f..284448111739 100644 --- a/src/relax/op/tensor/create.h +++ b/src/relax/op/tensor/create.h @@ -41,7 +41,8 @@ namespace relax { * If dtype is not given, it will by default use the dtype of fill_value. * \return The result tensor. */ -Expr full(Variant> shape, Expr fill_value, Optional dtype); +Expr full(ffi::Variant> shape, Expr fill_value, + ffi::Optional dtype); /*! * \brief Construct a tensor such that @@ -54,7 +55,7 @@ Expr full(Variant> shape, Expr fill_value, Optional dtype); +Expr full_like(Expr x, Expr fill_value, ffi::Optional dtype); /*! 
* \brief Construct a tensor of all ones, with the input shape and dtype. @@ -72,7 +73,7 @@ Expr ones(Expr shape, DataType dtype); * void, the input tensor's dtype will be used. * \return The result tensor. */ -Expr ones_like(Expr x, Optional dtype); +Expr ones_like(Expr x, ffi::Optional dtype); /*! * \brief Construct a tensor of all zeros, with the input shape and dtype. @@ -90,7 +91,7 @@ Expr zeros(Expr shape, DataType dtype); * void, the input tensor's dtype will be used. * \return The result tensor. */ -Expr zeros_like(Expr x, Optional dtype); +Expr zeros_like(Expr x, ffi::Optional dtype); /*! * \brief Construct a 2-D tensor with ones on the diagonal and zeros elsewhere. @@ -114,7 +115,7 @@ Expr eye(PrimValue n, PrimValue m, PrimValue k, DataType dtype); * void, the input tensor's dtype will be used. * \return The result tensor. */ -Expr eye_like(Expr x, PrimValue k, Optional dtype); +Expr eye_like(Expr x, PrimValue k, ffi::Optional dtype); /*! \brief Construct a tensor with evenly spaced elements. */ Expr arange(PrimValue start, PrimValue stop, PrimValue step, DataType dtype); diff --git a/src/relax/op/tensor/datatype.cc b/src/relax/op/tensor/datatype.cc index 89e7474c1335..da54d25e1bc7 100644 --- a/src/relax/op/tensor/datatype.cc +++ b/src/relax/op/tensor/datatype.cc @@ -39,7 +39,7 @@ TVM_FFI_STATIC_INIT_BLOCK({ /* relax.astype */ Expr astype(Expr x, DataType dtype) { - ObjectPtr attrs = make_object(); + ObjectPtr attrs = ffi::make_object(); attrs->dtype = dtype; static const Op& op = Op::Get("relax.astype"); @@ -54,7 +54,7 @@ TVM_FFI_STATIC_INIT_BLOCK({ StructInfo InferStructInfoAstype(const Call& call, const BlockBuilder& ctx) { TensorStructInfo sinfo = GetUnaryInputTensorStructInfo(call, ctx); const auto* attrs = call->attrs.as(); - ObjectPtr new_sinfo = make_object(*sinfo.get()); + ObjectPtr new_sinfo = ffi::make_object(*sinfo.get()); new_sinfo->dtype = attrs->dtype; return TensorStructInfo(new_sinfo); } @@ -71,7 +71,7 @@ TVM_REGISTER_OP("relax.astype") /* relax.wrap_param */ Expr MakeWrapParam(Expr data, DataType dtype) { - ObjectPtr attrs = make_object(); + ObjectPtr attrs = ffi::make_object(); attrs->dtype = dtype; static const Op& op = Op::Get("relax.wrap_param"); @@ -86,7 +86,7 @@ TVM_FFI_STATIC_INIT_BLOCK({ StructInfo InferStructInfoWrapParam(const Call& call, const BlockBuilder& ctx) { TensorStructInfo sinfo = GetUnaryInputTensorStructInfo(call, ctx); const auto* attrs = call->attrs.as(); - ObjectPtr new_sinfo = make_object(*sinfo.get()); + ObjectPtr new_sinfo = ffi::make_object(*sinfo.get()); new_sinfo->dtype = attrs->dtype; return TensorStructInfo(new_sinfo); } diff --git a/src/relax/op/tensor/grad.cc b/src/relax/op/tensor/grad.cc index 6b0ca941f00c..e120a86470be 100644 --- a/src/relax/op/tensor/grad.cc +++ b/src/relax/op/tensor/grad.cc @@ -103,9 +103,9 @@ TVM_REGISTER_OP("relax.grad.end_checkpoint") .set_attr("FPurity", Bool(true)); /* relax.grad.nll_loss_backward */ -Expr nll_loss_backward(Expr output_grad, Expr predictions, Expr targets, Optional weights, - String reduction, int ignore_index) { - ObjectPtr attrs = make_object(); +Expr nll_loss_backward(Expr output_grad, Expr predictions, Expr targets, + ffi::Optional weights, ffi::String reduction, int ignore_index) { + ObjectPtr attrs = ffi::make_object(); attrs->reduction = reduction; attrs->ignore_index = ignore_index; @@ -136,16 +136,16 @@ TVM_REGISTER_OP("relax.grad.nll_loss_backward") .add_argument("output_grad", "Tensor", "The output gradient.") .add_argument("predictions", "Tensor", "The prediction tensor.") 
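// Illustrative sketch (not from this patch): the astype/wrap_param hunks above
// show the copy-on-write idiom this PR re-qualifies -- copy the node through
// ffi::make_object, mutate only the copy, rewrap. With the stripped template
// arguments restored by the editor:
//
//   ObjectPtr<TensorStructInfoNode> new_sinfo =
//       ffi::make_object<TensorStructInfoNode>(*sinfo.get());  // shallow node copy
//   new_sinfo->dtype = attrs->dtype;                           // mutate the private copy
//   return TensorStructInfo(new_sinfo);                        // rewrap as a managed ref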
.add_argument("targets", "Tensor", "The target tensor.") - .add_argument("weights", "Optional", "The weight of each target values.") + .add_argument("weights", "ffi::Optional", "The weight of each target values.") .set_attr("FInferStructInfo", InferStructInfoNLLLossBackward) .set_attr("FPurity", Bool(true)); /* relax.grad.max_pool2d_backward */ -Expr max_pool2d_backward(Expr output_grad, Expr data, Array pool_size, - Array strides, Array padding, Array dilation, - bool ceil_mode, bool count_include_pad, String layout, - Optional out_layout) { - auto attrs = make_object(); +Expr max_pool2d_backward(Expr output_grad, Expr data, ffi::Array pool_size, + ffi::Array strides, ffi::Array padding, + ffi::Array dilation, bool ceil_mode, bool count_include_pad, + ffi::String layout, ffi::Optional out_layout) { + auto attrs = ffi::make_object(); attrs->pool_size = std::move(pool_size); attrs->strides = ConvertIntImmToInt64(strides); attrs->padding = ConvertIntImmToInt64(padding); @@ -176,11 +176,11 @@ TVM_REGISTER_OP("relax.grad.max_pool2d_backward") .set_attr("FPurity", Bool(true)); /* relax.grad.avg_pool2d_backward */ -Expr avg_pool2d_backward(Expr output_grad, Expr data, Array pool_size, - Array strides, Array padding, Array dilation, - bool ceil_mode, bool count_include_pad, String layout, - Optional out_layout) { - auto attrs = make_object(); +Expr avg_pool2d_backward(Expr output_grad, Expr data, ffi::Array pool_size, + ffi::Array strides, ffi::Array padding, + ffi::Array dilation, bool ceil_mode, bool count_include_pad, + ffi::String layout, ffi::Optional out_layout) { + auto attrs = ffi::make_object(); attrs->pool_size = std::move(pool_size); attrs->strides = ConvertIntImmToInt64(strides); attrs->padding = ConvertIntImmToInt64(padding); @@ -212,8 +212,8 @@ TVM_REGISTER_OP("relax.grad.avg_pool2d_backward") /* relax.grad.take_backward */ -Expr take_backward(Expr output_grad, Expr x, Expr indices, Optional axis) { - ObjectPtr attrs = make_object(); +Expr take_backward(Expr output_grad, Expr x, Expr indices, ffi::Optional axis) { + ObjectPtr attrs = ffi::make_object(); attrs->axis = std::move(axis); static const Op& op = Op::Get("relax.grad.take_backward"); diff --git a/src/relax/op/tensor/grad.h b/src/relax/op/tensor/grad.h index b0a58f7e5c49..406d7a2f779e 100644 --- a/src/relax/op/tensor/grad.h +++ b/src/relax/op/tensor/grad.h @@ -41,26 +41,26 @@ Expr no_grad(Expr input); /*! \brief Backward operator of relax.nll_loss. All parameters except output_grad is the same as * relax.nll_loss. Returns the gradient w.r.t. predictions. */ -Expr nll_loss_backward(Expr output_grad, Expr predictions, Expr targets, Optional weights, - String reduction, int ignore_index); +Expr nll_loss_backward(Expr output_grad, Expr predictions, Expr targets, + ffi::Optional weights, ffi::String reduction, int ignore_index); /*! \brief Backward operator of relax.max_pool2d. All parameters except output_grad is the same as * relax.max_pool2d. Returns the gradient w.r.t. data. */ -Expr max_pool2d_backward(Expr output_grad, Expr data, Array pool_size, - Array strides, Array padding, Array dilation, - bool ceil_mode, bool count_include_pad, String layout, - Optional out_layout); +Expr max_pool2d_backward(Expr output_grad, Expr data, ffi::Array pool_size, + ffi::Array strides, ffi::Array padding, + ffi::Array dilation, bool ceil_mode, bool count_include_pad, + ffi::String layout, ffi::Optional out_layout); /*! \brief Backward operator of relax.avg_pool2d. All parameters except output_grad is the same as * relax.avg_pool2d. 
Returns the gradient w.r.t. data. */ -Expr avg_pool2d_backward(Expr output_grad, Expr data, Array pool_size, - Array strides, Array padding, Array dilation, - bool ceil_mode, bool count_include_pad, String layout, - Optional out_layout); +Expr avg_pool2d_backward(Expr output_grad, Expr data, ffi::Array pool_size, + ffi::Array strides, ffi::Array padding, + ffi::Array dilation, bool ceil_mode, bool count_include_pad, + ffi::String layout, ffi::Optional out_layout); /*! \brief Backward operator of relax.take. All parameters except output_grad is the same as * relax.take. Returns the gradient w.r.t. data. */ -Expr take_backward(Expr output_grad, Expr x, Expr indices, Optional axis); +Expr take_backward(Expr output_grad, Expr x, Expr indices, ffi::Optional axis); } // namespace relax } // namespace tvm diff --git a/src/relax/op/tensor/index.cc b/src/relax/op/tensor/index.cc index dea79b804bb4..5780cd9cce1f 100644 --- a/src/relax/op/tensor/index.cc +++ b/src/relax/op/tensor/index.cc @@ -44,8 +44,8 @@ TVM_FFI_STATIC_INIT_BLOCK({ /* relax.take */ -Expr take(Expr x, Expr indices, Optional axis, String mode) { - ObjectPtr attrs = make_object(); +Expr take(Expr x, Expr indices, ffi::Optional axis, ffi::String mode) { + ObjectPtr attrs = ffi::make_object(); attrs->axis = std::move(axis); attrs->mode = std::move(mode); @@ -70,7 +70,7 @@ StructInfo InferStructInfoTake(const Call& call, const BlockBuilder& ctx) { if (auto tensor_sinfo = sinfo.as()) { return tensor_sinfo.value(); } else if (auto prim_sinfo = sinfo.as()) { - return TensorStructInfo(ShapeExpr(Array{}), prim_sinfo->dtype); + return TensorStructInfo(ShapeExpr(ffi::Array{}), prim_sinfo->dtype); } else { ctx->ReportFatal(Diagnostic::Error(call) << "Operator " << call->op << " requires the indices argument to be " @@ -115,7 +115,7 @@ StructInfo InferStructInfoTake(const Call& call, const BlockBuilder& ctx) { data_sinfo->vdevice); } - Array output_shape; + ffi::Array output_shape; for (int i = 0; i < data_sinfo->ndim; i++) { if (i == axis) { for (int j = 0; j < indices_sinfo->ndim; j++) @@ -137,7 +137,7 @@ TVM_REGISTER_OP("relax.take") /* relax.strided_slice */ -Expr strided_slice(Expr x, Expr axes, Expr begin, Expr end, Optional strides, +Expr strided_slice(Expr x, Expr axes, Expr begin, Expr end, ffi::Optional strides, bool assume_inbound) { // Initial validation of the arguments. A more complete validation // will be done when inferring the StructInfo, but that requires the @@ -165,10 +165,10 @@ Expr strided_slice(Expr x, Expr axes, Expr begin, Expr end, Optional strid check_tuple("end", end); if (strides.defined()) check_tuple("strides", strides.value()); - ObjectPtr attrs = make_object(); + ObjectPtr attrs = ffi::make_object(); attrs->assume_inbound = assume_inbound; - Array args = {x, axes, begin, end}; + ffi::Array args = {x, axes, begin, end}; if (strides.defined()) { args.push_back(strides.value()); } @@ -198,7 +198,7 @@ TVM_FFI_STATIC_INIT_BLOCK({ * a tuple from a `TensorStructInfo`.) * * \tparam PrimType The subtype of PrimExpr to extract. For example, - * extracting an `Array` + * extracting an `ffi::Array` * * \param sinfo The StructInfo to inspect * @@ -207,12 +207,12 @@ TVM_FFI_STATIC_INIT_BLOCK({ */ template >> -Optional> UnpackTupleOfPrimValue(Optional sinfo) { +ffi::Optional> UnpackTupleOfPrimValue(ffi::Optional sinfo) { if (!sinfo) return std::nullopt; // An ObjectStructInfo may contain a tuple of the desired type, but // it isn't yet known whether it does. Return early, as we cannot - // provide a known `Array` to the caller. 
+ // provide a known `ffi::Array` to the caller. if (sinfo.as()) return std::nullopt; auto tuple = sinfo.as(); @@ -220,7 +220,7 @@ Optional> UnpackTupleOfPrimValue(Optional sinfo) { << "The struct info " << sinfo << " cannot contain a tuple whose elements are " << PrimType::ContainerType::_type_key; - Array output; + ffi::Array output; for (size_t i = 0; i < tuple->fields.size(); i++) { auto field = tuple->fields[i]; @@ -235,7 +235,7 @@ Optional> UnpackTupleOfPrimValue(Optional sinfo) { if (!prim_sinfo->value.defined()) return std::nullopt; - Optional element = prim_sinfo->value.as(); + ffi::Optional element = prim_sinfo->value.as(); if (!element) return std::nullopt; output.push_back(element.value()); @@ -257,7 +257,7 @@ Optional> UnpackTupleOfPrimValue(Optional sinfo) { * a tuple from a `TensorStructInfo`.) * * \tparam PrimType The subtype of PrimExpr to extract. For example, - * extracting an `Array` + * extracting an `ffi::Array` * * \param expr The `relax::Expr` to inspect * @@ -266,7 +266,7 @@ Optional> UnpackTupleOfPrimValue(Optional sinfo) { */ template >> -Optional> UnpackTupleOfPrimValue(Optional expr) { +ffi::Optional> UnpackTupleOfPrimValue(ffi::Optional expr) { if (expr) { return UnpackTupleOfPrimValue(GetStructInfo(expr.value())); } else { @@ -285,7 +285,7 @@ StructInfo InferStructInfoStridedSlice(const Call& call, const BlockBuilder& ctx Expr axes = call->args[1]; Expr begin = call->args[2]; Expr end = call->args[3]; - Optional strides = [&]() -> Optional { + ffi::Optional strides = [&]() -> ffi::Optional { if (n_args > 4) { return call->args[4]; } else { @@ -296,7 +296,7 @@ StructInfo InferStructInfoStridedSlice(const Call& call, const BlockBuilder& ctx auto axes_sinfo = GetStructInfo(call->args[1]); auto begin_sinfo = GetStructInfo(call->args[2]); auto end_sinfo = GetStructInfo(call->args[3]); - auto strides_sinfo = [&]() -> Optional { + auto strides_sinfo = [&]() -> ffi::Optional { if (n_args > 4) { return GetStructInfo(call->args[4]); } else { @@ -342,7 +342,7 @@ StructInfo InferStructInfoStridedSlice(const Call& call, const BlockBuilder& ctx const auto* data_sinfo = data->struct_info_.as(); DataType dtype = DataType::Void(); - Optional vdevice = std::nullopt; + ffi::Optional vdevice = std::nullopt; int ndim = kUnknownNDim; if (data_sinfo) { dtype = data_sinfo->dtype; @@ -350,7 +350,7 @@ StructInfo InferStructInfoStridedSlice(const Call& call, const BlockBuilder& ctx ndim = data_sinfo->ndim; } - Optional shape = [&]() -> Optional { + ffi::Optional shape = [&]() -> ffi::Optional { if (!data_sinfo) return std::nullopt; if (!data_sinfo->shape) return std::nullopt; @@ -378,14 +378,14 @@ StructInfo InferStructInfoStridedSlice(const Call& call, const BlockBuilder& ctx << "However, there are " << axes_tuple.size() << " axes specified (" << axes_tuple << ") and " << end_tuple.size() << " 'end' indices specified (" << end_tuple << ")"; - Array strides_tuple; + ffi::Array strides_tuple; if (strides.defined()) { auto opt_strides_tuple = UnpackTupleOfPrimValue(strides); if (!opt_strides_tuple) return std::nullopt; strides_tuple = opt_strides_tuple.value(); } else { - strides_tuple = Array(axes_tuple.size(), IntImm(DataType::Int(64), 1)); + strides_tuple = ffi::Array(axes_tuple.size(), IntImm(DataType::Int(64), 1)); } CHECK_EQ(axes_tuple.size(), strides_tuple.size()) @@ -406,7 +406,7 @@ StructInfo InferStructInfoStridedSlice(const Call& call, const BlockBuilder& ctx std::vector axes = NormalizeAxes(call, ctx, data_sinfo->ndim, axes_tuple); auto attrs = call->attrs.as(); - Array 
output_shape = data_sinfo->GetShape().value(); + ffi::Array output_shape = data_sinfo->GetShape().value(); for (size_t i = 0; i < axes.size(); i++) { size_t axis = axes[i]; PrimExpr input_dim = output_shape[axis]; @@ -436,9 +436,9 @@ StructInfo InferStructInfoStridedSlice(const Call& call, const BlockBuilder& ctx } } -InferLayoutOutput InferLayoutStridedSlice(const Call& call, - const Map>& desired_layouts, - const VarLayoutMap& var_layout_map) { +InferLayoutOutput InferLayoutStridedSlice( + const Call& call, const ffi::Map>& desired_layouts, + const VarLayoutMap& var_layout_map) { ICHECK(NoDesiredLayout(call, desired_layouts)); const auto* attrs = call->attrs.as(); @@ -460,9 +460,9 @@ InferLayoutOutput InferLayoutStridedSlice(const Call& call, << " requires slices to be along static axes. " << "However, expression " << call << " slices along non-static axes " << call->args[1]; - Array axes_tuple = opt_axes_tuple.value(); + ffi::Array axes_tuple = opt_axes_tuple.value(); - Array new_axes; + ffi::Array new_axes; for (const auto& axis : axes_tuple) { int new_axis = FindAxis(existing_layout->layout, axis->value); new_axes.push_back(relax::PrimValue::Int64(new_axis)); @@ -515,7 +515,7 @@ StructInfo InferStructInfoDynStridedSlice(const Call& call, const BlockBuilder& } int n_axis = data_sinfo->ndim; - auto diag_def = [&](const TensorStructInfoNode* sinfo, String name) { + auto diag_def = [&](const TensorStructInfoNode* sinfo, ffi::String name) { ICHECK(sinfo) << "Dynamic strided slice requires the input " << name << " to be have the struct info. Please try normalizing the inputs."; CHECK_EQ(sinfo->ndim, 1) << "Dynamic strided slice requires " << name @@ -524,7 +524,7 @@ StructInfo InferStructInfoDynStridedSlice(const Call& call, const BlockBuilder& ICHECK(shape) << "Dynamic strided slice requires the input " << name << " to have well-defined shape."; // NOTE(tvm-team): This strong restriction seems necessary for now until we have a generic - // solution in converting 1d Tensor with unknown num_elem to Array. + // solution in converting 1d Tensor with unknown num_elem to ffi::Array. const auto* num_elem = shape->values[0].as(); ICHECK(num_elem) << "Dynamic strided slice requires the input " << name << " to have a known integer shape value."; diff --git a/src/relax/op/tensor/index.h b/src/relax/op/tensor/index.h index a45fb93792ed..0c5b45c68f2c 100644 --- a/src/relax/op/tensor/index.h +++ b/src/relax/op/tensor/index.h @@ -41,7 +41,7 @@ namespace relax { * \param mode The mode for handling out-of-bounds indices. * \return The taken result. */ -Expr take(Expr x, Expr indices, Optional axis, String mode = "fast"); +Expr take(Expr x, Expr indices, ffi::Optional axis, ffi::String mode = "fast"); /*! * \brief Strided slice of a tensor. @@ -55,8 +55,8 @@ Expr take(Expr x, Expr indices, Optional axis, String mode = "fast"); * \param assume_inbound Whether to assume the indices are in bound. 
* \return The sliced result */ -Expr strided_slice(Expr x, Expr axes, Expr begin, Expr end, Optional strides = std::nullopt, - bool assume_inbound = false); +Expr strided_slice(Expr x, Expr axes, Expr begin, Expr end, + ffi::Optional strides = std::nullopt, bool assume_inbound = false); } // namespace relax } // namespace tvm diff --git a/src/relax/op/tensor/inspect.cc b/src/relax/op/tensor/inspect.cc index 7dd193ce37cb..01843ba0a3c0 100644 --- a/src/relax/op/tensor/inspect.cc +++ b/src/relax/op/tensor/inspect.cc @@ -85,7 +85,7 @@ std::tuple GetTensorArgInfoWithIndex(const Cal << ", but " << arg << ".shape only has " << tensor_sinfo->ndim << " elements"; } - return {GetRef(tensor_sinfo), GetRef(axis_sinfo)}; + return {ffi::GetRef(tensor_sinfo), ffi::GetRef(axis_sinfo)}; } DataType GetTensorDataType(const Call& call) { return GetTensorArgInfo(call)->dtype; } @@ -103,7 +103,7 @@ tir::PrimFunc GetDLTensorField(tir::builtin::TVMStructFieldKind field, DataType DictAttrs attrs({{"tir.is_scheduled", true}, {"tir.is_host", true}}); - tir::PrimFunc func(Array{dlpack_handle}, body, PrimType(field_dtype), {}, attrs); + tir::PrimFunc func(ffi::Array{dlpack_handle}, body, PrimType(field_dtype), {}, attrs); FuncStructInfo sinfo({TensorStructInfo(DataType::Void(), kUnknownNDim)}, PrimStructInfo(field_dtype)); diff --git a/src/relax/op/tensor/linear_algebra.cc b/src/relax/op/tensor/linear_algebra.cc index dcd2a1e24fca..e50ca70f60ce 100644 --- a/src/relax/op/tensor/linear_algebra.cc +++ b/src/relax/op/tensor/linear_algebra.cc @@ -41,8 +41,8 @@ TVM_FFI_STATIC_INIT_BLOCK({ /* relax.matmul */ -Expr matmul(Expr x1, Expr x2, Optional out_dtype) { - ObjectPtr attrs = make_object(); +Expr matmul(Expr x1, Expr x2, ffi::Optional out_dtype) { + ObjectPtr attrs = ffi::make_object(); attrs->out_dtype = out_dtype.value_or(DataType::Void()); static const Op& op = Op::Get("relax.matmul"); @@ -55,7 +55,7 @@ TVM_FFI_STATIC_INIT_BLOCK({ }); StructInfo InferStructInfoMatmul(const Call& call, const BlockBuilder& ctx) { - Array input_sinfo = GetInputTensorStructInfo(call, ctx); + ffi::Array input_sinfo = GetInputTensorStructInfo(call, ctx); Expr lhs = call->args[0]; Expr rhs = call->args[1]; TensorStructInfo x1_sinfo = input_sinfo[0]; @@ -121,11 +121,11 @@ StructInfo InferStructInfoMatmul(const Call& call, const BlockBuilder& ctx) { return TensorStructInfo(out_dtype, output_ndim); } - Array x1_shape_prefix{x1_shape->values.begin(), - x1_shape->values.end() - 2 + x1_prepended}; - Array x2_shape_prefix{x2_shape->values.begin(), - x2_shape->values.end() - 2 + x2_appended}; - Optional> output_shape_prefix = + ffi::Array x1_shape_prefix{x1_shape->values.begin(), + x1_shape->values.end() - 2 + x1_prepended}; + ffi::Array x2_shape_prefix{x2_shape->values.begin(), + x2_shape->values.end() - 2 + x2_appended}; + ffi::Optional> output_shape_prefix = InferBinaryBroadcastShape(call, ctx, x1_shape_prefix, x2_shape_prefix); if (!output_shape_prefix.defined()) { if (vdev.defined()) { @@ -146,7 +146,7 @@ StructInfo InferStructInfoMatmul(const Call& call, const BlockBuilder& ctx) { << x2_reduction_length << " are not equal."); } - Array output_shape = output_shape_prefix.value(); + ffi::Array output_shape = output_shape_prefix.value(); if (!x1_prepended) { output_shape.push_back(x1_shape->values[x1_ndim - 2]); } @@ -175,8 +175,8 @@ TVM_REGISTER_OP("relax.matmul") /* relax.einsum */ -Expr einsum(Expr operands, String subscripts) { - ObjectPtr attrs = make_object(); +Expr einsum(Expr operands, ffi::String subscripts) { + ObjectPtr attrs = 
ffi::make_object(); attrs->subscripts = std::move(subscripts); static const Op& op = Op::Get("relax.einsum"); @@ -192,7 +192,7 @@ StructInfo InferStructInfoEinsum(const Call& call, const BlockBuilder& ctx) { if (call->args.size() != 1) { ctx->ReportFatal(Diagnostic::Error(call) << "Einsum op should take 1 argument"); } - Array operands_tensor_sinfo = + ffi::Array operands_tensor_sinfo = GetTensorStructInfoFromTuple(call, ctx, call->args[0]); if (operands_tensor_sinfo.empty()) { ctx->ReportFatal(Diagnostic::Error(call) @@ -219,10 +219,10 @@ StructInfo InferStructInfoEinsum(const Call& call, const BlockBuilder& ctx) { } } - String subscripts = attrs->subscripts; + ffi::String subscripts = attrs->subscripts; DataType operand_dtype = operands_tensor_sinfo[0]->dtype; - std::vector> input_shapes; + std::vector> input_shapes; input_shapes.reserve(operands_tensor_sinfo.size()); for (TensorStructInfo tensor_sinfo : operands_tensor_sinfo) { @@ -246,7 +246,7 @@ StructInfo InferStructInfoEinsum(const Call& call, const BlockBuilder& ctx) { } } // Calculate output shape using InferEinsumShape in topi - Array oshape = topi::InferEinsumShape(subscripts, input_shapes); + ffi::Array oshape = topi::InferEinsumShape(subscripts, input_shapes); if (!vdevice_unknown) { return TensorStructInfo(ShapeExpr(oshape), operand_dtype, vdev); @@ -290,7 +290,7 @@ StructInfo InferStructInfoOuter(const Call& call, const BlockBuilder& ctx) { if (!x1_shape || !x2_shape) { return TensorStructInfo(x1_sinfo->dtype, 2); } - Array output_shape = {x1_shape->values[0], x2_shape->values[0]}; + ffi::Array output_shape = {x1_shape->values[0], x2_shape->values[0]}; return TensorStructInfo(ShapeExpr(output_shape), x1_sinfo->dtype); } diff --git a/src/relax/op/tensor/linear_algebra.h b/src/relax/op/tensor/linear_algebra.h index eb003fed1c76..ddfceae4dc35 100644 --- a/src/relax/op/tensor/linear_algebra.h +++ b/src/relax/op/tensor/linear_algebra.h @@ -41,7 +41,7 @@ namespace relax { * When it is not specified, the output dtype will be the same as input dtype. * \return The computed result. */ -Expr matmul(Expr x1, Expr x2, Optional out_dtype); +Expr matmul(Expr x1, Expr x2, ffi::Optional out_dtype); /*! * \brief Einstein summation on the operands. @@ -49,7 +49,7 @@ Expr matmul(Expr x1, Expr x2, Optional out_dtype); * \param subscripts The einsum expression string. * \return The computed result. */ -Expr einsum(Expr operands, String subscripts); +Expr einsum(Expr operands, ffi::String subscripts); /*! * \brief Compute the outer product of two input expressions. 
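// Illustrative sketch (not from this patch): InferStructInfoMatmul above slices
// each shape's batch prefix with ffi::Array's iterator-range constructor before
// broadcasting the prefixes. A standalone stand-in -- the header path and the
// BatchPrefix name are the editor's assumptions, and int64_t stands in for
// PrimExpr:

#include <tvm/ffi/container/array.h>

namespace ffi = tvm::ffi;

ffi::Array<int64_t> BatchPrefix(const ffi::Array<int64_t>& shape) {
  // Requires shape.size() >= 2; drops the two trailing matrix dims, mirroring
  // x1_shape_prefix{values.begin(), values.end() - 2} in the hunk above.
  return ffi::Array<int64_t>(shape.begin(), shape.end() - 2);
}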
diff --git a/src/relax/op/tensor/manipulate.cc b/src/relax/op/tensor/manipulate.cc index 83b157034279..1e3844982d4b 100644 --- a/src/relax/op/tensor/manipulate.cc +++ b/src/relax/op/tensor/manipulate.cc @@ -107,8 +107,8 @@ StructInfo InferStructInfoBroadcastTo(const Call& call, const BlockBuilder& ctx) } arith::Analyzer* analyzer = ctx->GetAnalyzer(); - Array old_shape_value = shape_sinfo->values.value(); - Array tgt_shape_value = tgt_shape_sinfo->values.value(); + ffi::Array old_shape_value = shape_sinfo->values.value(); + ffi::Array tgt_shape_value = tgt_shape_sinfo->values.value(); int old_ndim = old_shape_value.size(); int tgt_ndim = tgt_shape_value.size(); for (int i = 0; i < old_ndim; ++i) { @@ -141,8 +141,8 @@ TVM_REGISTER_OP("relax.broadcast_to") /* relax.concat */ -Expr concat(Expr tensors, Optional axis) { - ObjectPtr attrs = make_object(); +Expr concat(Expr tensors, ffi::Optional axis) { + ObjectPtr attrs = ffi::make_object(); attrs->axis = std::move(axis); static const Op& op = Op::Get("relax.concat"); @@ -154,9 +154,9 @@ TVM_FFI_STATIC_INIT_BLOCK({ refl::GlobalDef().def("relax.op.concat", concat); }); -Optional> CheckConcatOutputShape(const Call& call, const BlockBuilder& ctx, - const std::vector>& shape_values, - int axis) { +ffi::Optional> CheckConcatOutputShape( + const Call& call, const BlockBuilder& ctx, + const std::vector>& shape_values, int axis) { bool shape_unknown = false; arith::Analyzer* analyzer = ctx->GetAnalyzer(); PrimExpr concat_sum = [&]() { @@ -174,7 +174,7 @@ Optional> CheckConcatOutputShape(const Call& call, const BlockBu // General case, add up the dimensions along the specified axis. PrimExpr concat_sum = IntImm(DataType::Int(64), 0); - for (Array shape_value : shape_values) { + for (ffi::Array shape_value : shape_values) { concat_sum += shape_value[axis]; } return concat_sum; @@ -201,7 +201,7 @@ Optional> CheckConcatOutputShape(const Call& call, const BlockBu if (shape_unknown) { return std::nullopt; } - Array output_shape = shape_values[0]; + ffi::Array output_shape = shape_values[0]; output_shape.Set(axis, concat_sum); return output_shape; } @@ -210,7 +210,8 @@ StructInfo InferStructInfoConcat(const Call& call, const BlockBuilder& ctx) { if (call->args.size() != 1) { ctx->ReportFatal(Diagnostic::Error(call) << "Concat op should have 1 argument"); } - Array tensor_sinfo = GetTensorStructInfoFromTuple(call, ctx, call->args[0]); + ffi::Array tensor_sinfo = + GetTensorStructInfoFromTuple(call, ctx, call->args[0]); if (tensor_sinfo.empty()) { ctx->ReportFatal(Diagnostic::Error(call) << "Concat op expects at least one tensor in the input Tuple. However, the " @@ -220,11 +221,11 @@ StructInfo InferStructInfoConcat(const Call& call, const BlockBuilder& ctx) { const auto* attrs = call->attrs.as(); int output_ndim = attrs->axis.has_value() ? kUnknownNDim : 1; DataType output_dtype = DataType::Void(); - Optional vdev = std::nullopt; + ffi::Optional vdev = std::nullopt; bool shape_unknown = false; bool is_void_dtype = false; bool vdevice_unknown = false; - std::vector> shape_values; + std::vector> shape_values; shape_values.reserve(tensor_sinfo.size()); for (TensorStructInfo sinfo : tensor_sinfo) { @@ -310,7 +311,8 @@ StructInfo InferStructInfoConcat(const Call& call, const BlockBuilder& ctx) { } // As long as there is a known shape value, we will do a best-effort check to ensure safety. 
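// Illustrative sketch (not from this patch): CheckConcatOutputShape above sums
// the concatenation axis and writes it back with ffi::Array::Set, which copies
// on write when the backing storage is shared. A static-extent analogue as a
// local lambda; the concat_shape_demo name is hypothetical.
auto concat_shape_demo = [](const std::vector<ffi::Array<int64_t>>& shapes, int axis) {
  int64_t concat_sum = 0;
  for (const ffi::Array<int64_t>& s : shapes) concat_sum += s[axis];  // add up the axis dims
  ffi::Array<int64_t> out = shapes[0];  // still shares storage with shapes[0]
  out.Set(axis, concat_sum);            // copy-on-write: shapes[0] is untouched
  return out;
};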
- Optional> output_shape = CheckConcatOutputShape(call, ctx, shape_values, axis); + ffi::Optional> output_shape = + CheckConcatOutputShape(call, ctx, shape_values, axis); if (shape_unknown || !output_shape.defined()) { if (!vdevice_unknown) { @@ -325,9 +327,9 @@ StructInfo InferStructInfoConcat(const Call& call, const BlockBuilder& ctx) { } } -InferLayoutOutput InferLayoutConcat(const Call& call, - const Map>& desired_layouts, - const VarLayoutMap& var_layout_map) { +InferLayoutOutput InferLayoutConcat( + const Call& call, const ffi::Map>& desired_layouts, + const VarLayoutMap& var_layout_map) { ICHECK(NoDesiredLayout(call, desired_layouts)); const auto* attrs = call->attrs.as(); @@ -338,12 +340,12 @@ InferLayoutOutput InferLayoutConcat(const Call& call, int n_tensor = nlayout.NestedArray().size(); LayoutDecision layout = nlayout.NestedArray()[0].LeafValue(); - Array input_layouts, output_layouts; + ffi::Array input_layouts, output_layouts; for (int i = 0; i < n_tensor; ++i) { input_layouts.push_back(layout); } output_layouts.push_back(layout); - ObjectPtr new_attrs = make_object(*attrs); + ObjectPtr new_attrs = ffi::make_object(*attrs); new_attrs->axis = FindAxis(layout->layout, attrs->axis.value_or(0)); return InferLayoutOutput({NLayout(input_layouts)}, output_layouts, Attrs(new_attrs)); } @@ -359,8 +361,8 @@ TVM_REGISTER_OP("relax.concat") /* relax.expand_dims */ -Expr expand_dims(Expr x, Array axis) { - ObjectPtr attrs = make_object(); +Expr expand_dims(Expr x, ffi::Array axis) { + ObjectPtr attrs = ffi::make_object(); attrs->axis = std::move(axis); static const Op& op = Op::Get("relax.expand_dims"); @@ -411,9 +413,9 @@ StructInfo InferStructInfoExpandDims(const Call& call, const BlockBuilder& ctx) return TensorStructInfo(ShapeExpr(output_shape), data_sinfo->dtype, data_sinfo->vdevice); } -InferLayoutOutput InferLayoutExpandDims(const Call& call, - const Map>& desired_layouts, - const VarLayoutMap& var_layout_map) { +InferLayoutOutput InferLayoutExpandDims( + const Call& call, const ffi::Map>& desired_layouts, + const VarLayoutMap& var_layout_map) { ICHECK(NoDesiredLayout(call, desired_layouts)); const auto* attrs = call->attrs.as(); ICHECK(attrs != nullptr) << "Invalid Call"; @@ -462,7 +464,7 @@ TVM_REGISTER_OP("relax.expand_dims") .set_attr("FPurity", Bool(true)); // Helper function for flatten and reshape. 
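The `InferLayout*` callbacks now take the desired layouts as an `ffi::Map` from op name to an `ffi::Array` of layout strings. A sketch of building such a map on the caller side, assuming the map supports the usual initializer-list construction (the op name and layout strings are placeholders, not values from this patch):

    // Hypothetical desired-layout table for layout conversion.
    tvm::ffi::Map<tvm::ffi::String, tvm::ffi::Array<tvm::ffi::String>> desired_layouts{
        {"relax.nn.conv2d", {"NHWC", "OHWI"}}};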
-PrimExpr ComputeShapeProduct(const Array& shape_values) { +PrimExpr ComputeShapeProduct(const ffi::Array& shape_values) { PrimExpr shape_prod = IntImm(DataType::Int(64), 1); for (PrimExpr value : shape_values) { shape_prod *= value; @@ -525,7 +527,8 @@ StructInfo InferStructInfoIndexTensor(const Call& call, const BlockBuilder& ctx) } TensorStructInfo data_sinfo = GetInputTensorStructInfo(call, 0, ctx); - Array indices_sinfo = GetTensorStructInfoFromTuple(call, ctx, call->args[1]); + ffi::Array indices_sinfo = + GetTensorStructInfoFromTuple(call, ctx, call->args[1]); if (indices_sinfo.empty()) { ctx->ReportFatal(Diagnostic::Error(call) @@ -534,7 +537,7 @@ StructInfo InferStructInfoIndexTensor(const Call& call, const BlockBuilder& ctx) DataType output_dtype = data_sinfo->dtype; int n_indices = static_cast(indices_sinfo.size()); - Optional vdev = data_sinfo->vdevice; + ffi::Optional vdev = data_sinfo->vdevice; // Indices must be integers for (int i = 0; i < n_indices; ++i) { @@ -555,7 +558,7 @@ StructInfo InferStructInfoIndexTensor(const Call& call, const BlockBuilder& ctx) arith::Analyzer* analyzer = ctx->GetAnalyzer(); bool all_index_have_shape_value = true; - std::vector> index_shapes; + std::vector> index_shapes; int max_index_ndim = 0; for (const auto& s : indices_sinfo) { @@ -571,12 +574,12 @@ StructInfo InferStructInfoIndexTensor(const Call& call, const BlockBuilder& ctx) } } - Optional> broadcast_shape; + ffi::Optional> broadcast_shape; bool shape_unknown = !all_index_have_shape_value; if (all_index_have_shape_value) { // initialise broadcast result with 1's - Array out_shape; + ffi::Array out_shape; for (int i = 0; i < max_index_ndim; ++i) { out_shape.push_back(IntImm(DataType::Int(64), 1)); } @@ -636,7 +639,7 @@ StructInfo InferStructInfoIndexTensor(const Call& call, const BlockBuilder& ctx) if (broadcast_shape.defined()) { const auto* data_shape_expr = data_sinfo->shape.as(); if (data_shape_expr) { - Array result_shape = broadcast_shape.value(); + ffi::Array result_shape = broadcast_shape.value(); for (int i = n_indices; i < data_sinfo->ndim; ++i) { result_shape.push_back(data_shape_expr->values[i]); } @@ -657,10 +660,10 @@ TVM_REGISTER_OP("relax.index_tensor") /* relax.layout_transform */ -Expr layout_transform(Expr x, tir::IndexMap index_map, Optional pad_value, - Optional> axis_separators, - Optional> input_axis_separators) { - ObjectPtr attrs = make_object(); +Expr layout_transform(Expr x, tir::IndexMap index_map, ffi::Optional pad_value, + ffi::Optional> axis_separators, + ffi::Optional> input_axis_separators) { + ObjectPtr attrs = ffi::make_object(); attrs->index_map = std::move(index_map); attrs->pad_value = std::move(pad_value); attrs->axis_separators = std::move(axis_separators); @@ -679,7 +682,7 @@ StructInfo InferStructInfoLayoutTransform(const Call& call, const BlockBuilder& TensorStructInfo data_sinfo = GetUnaryInputTensorStructInfo(call, ctx); const auto* attrs = call->attrs.as(); tir::IndexMap index_map = attrs->index_map; - Optional optional_pad_value = attrs->pad_value; + ffi::Optional optional_pad_value = attrs->pad_value; // Check pad_value has same dtype as input. 
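`ComputeShapeProduct` above is a plain left fold over an `ffi::Array<PrimExpr>`; for readers skimming the patch, an equivalent standalone version:

    tvm::PrimExpr Product(const tvm::ffi::Array<tvm::PrimExpr>& shape_values) {
      tvm::PrimExpr prod = tvm::IntImm(tvm::DataType::Int(64), 1);
      for (const tvm::PrimExpr& value : shape_values) {
        prod = prod * value;  // symbolic multiply; the result stays an unevaluated PrimExpr
      }
      return prod;
    }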
if (optional_pad_value.defined()) { @@ -717,7 +720,7 @@ StructInfo InferStructInfoLayoutTransform(const Call& call, const BlockBuilder& } arith::Analyzer analyzer; - Array output_shape = index_map->MapShape(shape_sinfo->values.value(), &analyzer); + ffi::Array output_shape = index_map->MapShape(shape_sinfo->values.value(), &analyzer); return TensorStructInfo(ShapeExpr(output_shape), data_sinfo->dtype, data_sinfo->vdevice); } @@ -731,8 +734,8 @@ TVM_REGISTER_OP("relax.layout_transform") /* relax.permute_dims */ -Expr permute_dims(Expr x, Optional> axes) { - ObjectPtr attrs = make_object(); +Expr permute_dims(Expr x, ffi::Optional> axes) { + ObjectPtr attrs = ffi::make_object(); attrs->axes = std::move(axes); static const Op& op = Op::Get("relax.permute_dims"); @@ -798,9 +801,9 @@ StructInfo InferStructInfoPermuteDims(const Call& call, const BlockBuilder& ctx) return TensorStructInfo(ShapeExpr(new_shape), data_sinfo->dtype, data_sinfo->vdevice); } -InferLayoutOutput InferLayoutPermuteDims(const Call& call, - const Map>& desired_layouts, - const VarLayoutMap& var_layout_map) { +InferLayoutOutput InferLayoutPermuteDims( + const Call& call, const ffi::Map>& desired_layouts, + const VarLayoutMap& var_layout_map) { ICHECK(NoDesiredLayout(call, desired_layouts)); const auto* attrs = call->attrs.as(); @@ -817,7 +820,7 @@ InferLayoutOutput InferLayoutPermuteDims(const Call& call, existing_layout = LayoutDecision(InitialLayout(ndim)); } - Array order; + ffi::Array order; if (attrs->axes.defined()) { order = attrs->axes.value(); } else { @@ -830,13 +833,13 @@ InferLayoutOutput InferLayoutPermuteDims(const Call& call, for (const auto& axis : order) { order_str.push_back(axis->value + 'A'); } - String new_axes = + ffi::String new_axes = TransposeStrLike(InitialLayout(ndim).name(), existing_layout->layout, order_str); - Array new_order; + ffi::Array new_order; for (size_t i = 0; i < new_axes.size(); ++i) { new_order.push_back(Integer(new_axes.at(i) - 'A')); } - ObjectPtr new_attrs = make_object(*attrs); + ObjectPtr new_attrs = ffi::make_object(*attrs); new_attrs->axes = new_order; return InferLayoutOutput({existing_layout}, {InitialLayoutDecision(ndim)}, Attrs(new_attrs)); } @@ -851,14 +854,15 @@ TVM_REGISTER_OP("relax.permute_dims") .set_attr("FPurity", Bool(true)); /* relax.reshape */ -Expr ConvertNewShapeToExpr(const Expr& data, const Variant>& shape) { +Expr ConvertNewShapeToExpr(const Expr& data, + const ffi::Variant>& shape) { const ffi::ArrayObj* array; // Treat shape expressions as constant arrays to handle special values. if (const auto* e = shape.as()) { array = e->values.as(); // Other non-shape expressions are used directly. } else if (const auto* e = shape.as()) { - return GetRef(e); + return ffi::GetRef(e); // Process special values in constants and produce an expression. } else { array = shape.as(); } @@ -874,7 +878,7 @@ Expr ConvertNewShapeToExpr(const Expr& data, const Variant CHECK(_len != nullptr) << "Reshape only expects the input new shape to be either an Expr or an " "Array of PrimExprs. However, the given new shape is " << shape; - PrimExpr len = GetRef(_len); + PrimExpr len = ffi::GetRef(_len); CHECK(len->dtype.is_int()) << "Reshape requires the new shape values to be all " "integers. However, the given new shape is " << shape; @@ -895,7 +899,7 @@ Expr ConvertNewShapeToExpr(const Expr& data, const Variant } } - Array array_ref = GetRef>(array); + ffi::Array array_ref = ffi::GetRef>(array); // When there is no dimension to infer, just return the input array as ShapeExpr.
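The `ffi::Variant` handling in `ConvertNewShapeToExpr` follows the general dispatch pattern used throughout this patch: `as<T>()` returns an empty `ffi::Optional` when the held alternative does not match. A reduced sketch under that assumption (names are illustrative):

    using ShapeArg = tvm::ffi::Variant<tvm::relax::Expr, tvm::ffi::Array<tvm::PrimExpr>>;
    tvm::relax::Expr AsShape(ShapeArg shape) {
      if (auto arr = shape.as<tvm::ffi::Array<tvm::PrimExpr>>()) {
        return tvm::relax::ShapeExpr(arr.value());  // literal dims become a ShapeExpr
      }
      return shape.as<tvm::relax::Expr>().value();  // otherwise it is already an Expr
    }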
if (dim_to_infer == -1 && zero_dims.empty()) { return ShapeExpr(array_ref); @@ -944,7 +948,7 @@ Expr ConvertNewShapeToExpr(const Expr& data, const Variant return ShapeExpr(array_ref); } -Expr reshape(Expr x, Variant> shape) { +Expr reshape(Expr x, ffi::Variant> shape) { Expr shape_in_expr = ConvertNewShapeToExpr(x, shape); static const Op& op = Op::Get("relax.reshape"); return Call(op, {std::move(x), std::move(shape_in_expr)}, Attrs(), {}); @@ -973,7 +977,7 @@ StructInfo InferStructInfoReshape(const Call& call, const BlockBuilder& ctx) { << call->args[1]->struct_info_->GetTypeKey()); } - Optional> old_shape_values; + ffi::Optional> old_shape_values; if (data_sinfo->shape.defined()) { const auto* old_shape_sinfo = GetStructInfoAs(data_sinfo->shape.value()); ICHECK_NOTNULL(old_shape_sinfo); @@ -1011,8 +1015,8 @@ TVM_REGISTER_OP("relax.reshape") /* relax.split */ -Expr split(Expr x, Variant> indices_or_sections, int axis) { - ObjectPtr attrs = make_object(); +Expr split(Expr x, ffi::Variant> indices_or_sections, int axis) { + ObjectPtr attrs = ffi::make_object(); ObjectRef indices_or_sections_obj; if (const auto* indices = indices_or_sections.as()) { @@ -1022,7 +1026,7 @@ Expr split(Expr x, Variant> indices_or_sections, int axis) "However, the given indices " << indices_or_sections << " contains some non-integer."; } - indices_or_sections_obj = ConvertIntImmToInt64(GetRef>(indices)); + indices_or_sections_obj = ConvertIntImmToInt64(ffi::GetRef>(indices)); } else if (const auto* n_section = indices_or_sections.as()) { CHECK_GT(n_section->value, 0) << "Split op expects the input number of sections to be a " "positive integer. However, the given number of sections is " @@ -1051,7 +1055,7 @@ StructInfo InferStructInfoSplit(const Call& call, const BlockBuilder& ctx) { int axis = data_sinfo->IsUnknownNdim() ? -1 : NormalizeAxis(call, ctx, data_sinfo->ndim, attrs->axis); - if (auto opt_indices = attrs->indices_or_sections.as>()) { + if (auto opt_indices = attrs->indices_or_sections.as>()) { auto p_indices = opt_indices.value(); // When there is no index, return the input tensor's struct info. if (p_indices.size() == 0) { @@ -1059,7 +1063,7 @@ StructInfo InferStructInfoSplit(const Call& call, const BlockBuilder& ctx) { } // Fall back to unknown shape when the input tensor doesn't have ShapeExpr as shape. if (data_shape == nullptr) { - return TupleStructInfo(Array( + return TupleStructInfo(ffi::Array( p_indices.size() + 1, TensorStructInfo(data_sinfo->dtype, data_sinfo->ndim, data_sinfo->vdevice))); } @@ -1091,7 +1095,7 @@ StructInfo InferStructInfoSplit(const Call& call, const BlockBuilder& ctx) { split_dim = tvm::max(split_dim, 0); split_dim = ctx->GetAnalyzer()->Simplify(split_dim); - Array shape = data_shape->values; + ffi::Array shape = data_shape->values; shape.Set(axis, split_dim); output_sinfo.push_back( TensorStructInfo(ShapeExpr(shape), data_sinfo->dtype, data_sinfo->vdevice)); @@ -1106,7 +1110,7 @@ StructInfo InferStructInfoSplit(const Call& call, const BlockBuilder& ctx) { } // Fall back to unknown shape when the input tensor doesn't have ShapeExpr as shape. if (data_shape == nullptr) { - return TupleStructInfo(Array( + return TupleStructInfo(ffi::Array( n_section, TensorStructInfo(data_sinfo->dtype, data_sinfo->ndim, data_sinfo->vdevice))); } ICHECK_NE(axis, -1); @@ -1114,7 +1118,7 @@ StructInfo InferStructInfoSplit(const Call& call, const BlockBuilder& ctx) { split_len = ctx->GetAnalyzer()->Simplify(split_len); // Construct struct info for tensors except the last one.
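Two `ffi::Array` idioms recur throughout the split hunks above: the fill constructor (n copies of one element, as in the `TupleStructInfo` fallbacks) and the copy-on-write `Set`. In isolation:

    void ArrayIdioms() {
      // fill constructor: four identical elements
      tvm::ffi::Array<tvm::Integer> fill(4, tvm::Integer(0));
      // Set copies the backing storage only when it is shared (copy-on-write),
      // so `fill` is left unchanged by the update below
      tvm::ffi::Array<tvm::Integer> copy = fill;
      copy.Set(1, tvm::Integer(7));
      (void)copy;
    }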
- Array shape = data_shape->values; + ffi::Array shape = data_shape->values; shape.Set(axis, split_len); std::vector output_sinfo( n_section - 1, TensorStructInfo(ShapeExpr(shape), data_sinfo->dtype, data_sinfo->vdevice)); @@ -1131,9 +1135,9 @@ StructInfo InferStructInfoSplit(const Call& call, const BlockBuilder& ctx) { throw; } -InferLayoutOutput InferLayoutSplit(const Call& call, - const Map>& desired_layouts, - const VarLayoutMap& var_layout_map) { +InferLayoutOutput InferLayoutSplit( + const Call& call, const ffi::Map>& desired_layouts, + const VarLayoutMap& var_layout_map) { ICHECK(NoDesiredLayout(call, desired_layouts)); const auto* attrs = call->attrs.as(); @@ -1157,7 +1161,8 @@ InferLayoutOutput InferLayoutSplit(const Call& call, "output structinfo, but got " << si; auto sinfo = Downcast(si); - Optional shape_expr = GetRef(sinfo->shape.as()); + ffi::Optional shape_expr = + ffi::GetRef(sinfo->shape.as()); CHECK(shape_expr.defined()); auto shape_arr = shape_expr.value(); if (!CanProveLayoutTransform(InitialLayout(tensor_sinfo->ndim), existing_layout->layout, @@ -1168,10 +1173,10 @@ InferLayoutOutput InferLayoutSplit(const Call& call, } } - ObjectPtr new_attrs = make_object(*attrs); + ObjectPtr new_attrs = ffi::make_object(*attrs); new_attrs->axis = FindAxis(existing_layout->layout, attrs->axis); ICHECK(out_tuple != nullptr) << "Invalid Call"; - NLayout tuple_layouts(Array(out_tuple->fields.size(), existing_layout)); + NLayout tuple_layouts(ffi::Array(out_tuple->fields.size(), existing_layout)); return InferLayoutOutput({existing_layout}, {tuple_layouts}, Attrs(new_attrs)); } @@ -1186,8 +1191,8 @@ TVM_REGISTER_OP("relax.split") /* relax.squeeze */ -Expr squeeze(Expr x, Optional> axis) { - ObjectPtr attrs = make_object(); +Expr squeeze(Expr x, ffi::Optional> axis) { + ObjectPtr attrs = ffi::make_object(); attrs->axis = std::move(axis); static const Op& op = Op::Get("relax.squeeze"); @@ -1210,7 +1215,7 @@ StructInfo InferStructInfoSqueeze(const Call& call, const BlockBuilder& ctx) { return TensorStructInfo(data_sinfo->dtype, kUnknownNDim, data_sinfo->vdevice); } - Optional> shape_value; + ffi::Optional> shape_value; if (data_sinfo->shape.defined()) { shape_value = Downcast(data_sinfo->shape.value()->struct_info_)->values; } @@ -1280,9 +1285,9 @@ StructInfo InferStructInfoSqueeze(const Call& call, const BlockBuilder& ctx) { } } -InferLayoutOutput InferLayoutSqueeze(const Call& call, - const Map>& desired_layouts, - const VarLayoutMap& var_layout_map) { +InferLayoutOutput InferLayoutSqueeze( + const Call& call, const ffi::Map>& desired_layouts, + const VarLayoutMap& var_layout_map) { ICHECK(NoDesiredLayout(call, desired_layouts)); const auto* attrs = call->attrs.as(); @@ -1295,7 +1300,7 @@ InferLayoutOutput InferLayoutSqueeze(const Call& call, const auto* shape = tensor_sinfo->shape.as(); ICHECK(shape != nullptr) << "Only support static shape for now"; - Array axis; + ffi::Array axis; if (attrs->axis.defined()) { axis = attrs->axis.value(); } else { @@ -1322,8 +1327,9 @@ InferLayoutOutput InferLayoutSqueeze(const Call& call, if (existing_layout->layout.ndim() != existing_layout->layout.ndim_primal()) { existing_layout = LayoutDecision(InitialLayout(ndim)); } - String new_axis_str = TransposeStrLike(axis_str, InitialLayout(ndim), existing_layout->layout); - Array new_axis; + ffi::String new_axis_str = + TransposeStrLike(axis_str, InitialLayout(ndim), existing_layout->layout); + ffi::Array new_axis; for (size_t i = 0; i < new_axis_str.size(); ++i) { if (new_axis_str.at(i) == '1') { 
new_axis.push_back(Integer(i)); @@ -1333,7 +1339,7 @@ InferLayoutOutput InferLayoutSqueeze(const Call& call, output_layout.erase(std::remove(output_layout.begin(), output_layout.end(), '1'), output_layout.end()); - ObjectPtr new_attrs = make_object(*attrs); + ObjectPtr new_attrs = ffi::make_object(*attrs); new_attrs->axis = new_axis; return InferLayoutOutput({existing_layout}, {LayoutDecision(Layout(output_layout))}, Attrs(new_attrs)); @@ -1349,7 +1355,8 @@ TVM_REGISTER_OP("relax.squeeze") .set_attr("FPurity", Bool(true)); void CheckCollapseShape(const Call& call, const BlockBuilder& ctx, - const Array& data_shape, const Array& target_shape) { + const ffi::Array& data_shape, + const ffi::Array& target_shape) { arith::Analyzer* analyzer = ctx->GetAnalyzer(); int data_ndim = data_shape.size(); @@ -1388,8 +1395,8 @@ void CheckCollapseShape(const Call& call, const BlockBuilder& ctx, /* relax.stack */ -Expr stack(Expr tensors, Optional axis) { - ObjectPtr attrs = make_object(); +Expr stack(Expr tensors, ffi::Optional axis) { + ObjectPtr attrs = ffi::make_object(); attrs->axis = std::move(axis); static const Op& op = Op::Get("relax.stack"); @@ -1401,9 +1408,9 @@ TVM_FFI_STATIC_INIT_BLOCK({ refl::GlobalDef().def("relax.op.stack", stack); }); -Optional> CheckStackOutputShape(const Call& call, const BlockBuilder& ctx, - const std::vector>& shape_values, - int axis) { +ffi::Optional> CheckStackOutputShape( + const Call& call, const BlockBuilder& ctx, + const std::vector>& shape_values, int axis) { bool shape_unknown = false; arith::Analyzer* analyzer = ctx->GetAnalyzer(); @@ -1426,7 +1433,7 @@ Optional> CheckStackOutputShape(const Call& call, const BlockBui } // Insert new dimension at axis position - Array output_shape; + ffi::Array output_shape; for (int i = 0; i < axis; ++i) { output_shape.push_back(shape_values[0][i]); } @@ -1442,7 +1449,8 @@ StructInfo InferStructInfoStack(const Call& call, const BlockBuilder& ctx) { ctx->ReportFatal(Diagnostic::Error(call) << "Stack op should have 1 argument"); } - Array tensor_sinfo = GetTensorStructInfoFromTuple(call, ctx, call->args[0]); + ffi::Array tensor_sinfo = + GetTensorStructInfoFromTuple(call, ctx, call->args[0]); if (tensor_sinfo.empty()) { ctx->ReportFatal(Diagnostic::Error(call) << "Stack op expects at least one tensor in the input Tuple. 
" @@ -1455,11 +1463,11 @@ StructInfo InferStructInfoStack(const Call& call, const BlockBuilder& ctx) { // Default axis is 0 if not specified int output_ndim = tensor_sinfo[0]->ndim + 1; // Stack adds one dimension DataType output_dtype = DataType::Void(); - Optional vdev = std::nullopt; + ffi::Optional vdev = std::nullopt; bool shape_unknown = false; bool is_void_dtype = false; bool vdevice_unknown = false; - std::vector> shape_values; + std::vector> shape_values; shape_values.reserve(tensor_sinfo.size()); for (TensorStructInfo sinfo : tensor_sinfo) { @@ -1522,7 +1530,7 @@ StructInfo InferStructInfoStack(const Call& call, const BlockBuilder& ctx) { } return TensorStructInfo(output_dtype, output_ndim); } - Array output_shape; + ffi::Array output_shape; for (int i = 0; i < axis; ++i) { output_shape.push_back(shape_values[0][i]); } @@ -1544,7 +1552,8 @@ StructInfo InferStructInfoStack(const Call& call, const BlockBuilder& ctx) { return TensorStructInfo(output_dtype, output_ndim); } - Optional> output_shape = CheckStackOutputShape(call, ctx, shape_values, axis); + ffi::Optional> output_shape = + CheckStackOutputShape(call, ctx, shape_values, axis); if (shape_unknown || !output_shape.defined()) { if (!vdevice_unknown) { return TensorStructInfo(output_dtype, output_ndim, vdev); @@ -1558,9 +1567,9 @@ StructInfo InferStructInfoStack(const Call& call, const BlockBuilder& ctx) { } } -InferLayoutOutput InferLayoutStack(const Call& call, - const Map>& desired_layouts, - const VarLayoutMap& var_layout_map) { +InferLayoutOutput InferLayoutStack( + const Call& call, const ffi::Map>& desired_layouts, + const VarLayoutMap& var_layout_map) { ICHECK(NoDesiredLayout(call, desired_layouts)); const auto* attrs = call->attrs.as(); @@ -1571,7 +1580,7 @@ InferLayoutOutput InferLayoutStack(const Call& call, int n_tensor = nlayout.NestedArray().size(); LayoutDecision layout = nlayout.NestedArray()[0].LeafValue(); - Array input_layouts, output_layouts; + ffi::Array input_layouts, output_layouts; for (int i = 0; i < n_tensor; ++i) { input_layouts.push_back(layout); } @@ -1583,7 +1592,7 @@ InferLayoutOutput InferLayoutStack(const Call& call, Layout output_layout = Layout(layout_str); output_layouts.push_back(LayoutDecision(output_layout)); - ObjectPtr new_attrs = make_object(*attrs); + ObjectPtr new_attrs = ffi::make_object(*attrs); new_attrs->axis = Integer(FindAxis(layout->layout, axis)); return InferLayoutOutput({NLayout(input_layouts)}, output_layouts, Attrs(new_attrs)); } @@ -1609,17 +1618,17 @@ TVM_FFI_STATIC_INIT_BLOCK({ }); StructInfo InferStructInfoCollapseSumLike(const Call& call, const BlockBuilder& ctx) { - Array input_sinfo = GetInputTensorStructInfo(call, ctx); + ffi::Array input_sinfo = GetInputTensorStructInfo(call, ctx); TensorStructInfo data_sinfo = input_sinfo[0]; TensorStructInfo collapse_target_sinfo = input_sinfo[1]; DataType output_dtype = data_sinfo->dtype; - Optional> data_shape_value; + ffi::Optional> data_shape_value; if (data_sinfo->shape.defined()) { data_shape_value = GetStructInfoAs(data_sinfo->shape.value())->values; } - Optional> collapse_target_shape_value; + ffi::Optional> collapse_target_shape_value; if (collapse_target_sinfo->shape.defined()) { collapse_target_shape_value = GetStructInfoAs(collapse_target_sinfo->shape.value())->values; @@ -1680,7 +1689,7 @@ StructInfo InferStructInfoCollapseSumTo(const Call& call, const BlockBuilder& ct DataType output_dtype = data_sinfo->dtype; - Optional> data_shape_value; + ffi::Optional> data_shape_value; if (data_sinfo->shape.defined()) { 
data_shape_value = GetStructInfoAs(data_sinfo->shape.value())->values; } @@ -1700,8 +1709,8 @@ TVM_REGISTER_OP("relax.collapse_sum_to") /* relax.repeat */ -Expr repeat(Expr data, int repeats, Optional axis) { - auto attrs = make_object(); +Expr repeat(Expr data, int repeats, ffi::Optional axis) { + auto attrs = ffi::make_object(); attrs->repeats = std::move(repeats); attrs->axis = std::move(axis); @@ -1748,7 +1757,7 @@ StructInfo InferStructInfoRepeat(const Call& call, const BlockBuilder& ctx) { if (!attrs->axis.has_value()) { PrimExpr new_shape = analyzer->Simplify(ComputeShapeProduct(data_shape->values) * attrs->repeats); - return TensorStructInfo(ShapeExpr(Array({new_shape})), data_sinfo->dtype, + return TensorStructInfo(ShapeExpr(ffi::Array({new_shape})), data_sinfo->dtype, data_sinfo->vdevice); } @@ -1768,8 +1777,8 @@ TVM_REGISTER_OP("relax.repeat") /* relax.tile */ -Expr tile(Expr data, Array repeats) { - auto attrs = make_object(); +Expr tile(Expr data, ffi::Array repeats) { + auto attrs = ffi::make_object(); attrs->repeats = std::move(repeats); static const Op& op = Op::Get("relax.tile"); @@ -1809,7 +1818,7 @@ StructInfo InferStructInfoTile(const Call& call, const BlockBuilder& ctx) { int out_ndim = std::max(l, ndim); int l_delta = out_ndim - l; int ndim_delta = out_ndim - ndim; - Array out_shape; + ffi::Array out_shape; for (int i = 0; i < out_ndim; ++i) { if (i < l_delta) { out_shape.push_back(data_shape->values[i - ndim_delta]); @@ -1835,7 +1844,7 @@ TVM_REGISTER_OP("relax.tile") /* relax.flip */ Expr flip(Expr data, Integer axis) { - auto attrs = make_object(); + auto attrs = ffi::make_object(); attrs->axis = std::move(axis); static const Op& op = Op::Get("relax.flip"); return Call(op, {std::move(data)}, Attrs{attrs}, {}); @@ -1874,7 +1883,7 @@ TVM_REGISTER_OP("relax.flip") /* relax.gather_elements */ Expr gather_elements(Expr data, Expr indices, int axis) { - auto attrs = make_object(); + auto attrs = ffi::make_object(); attrs->axis = Integer(axis); static const Op& op = Op::Get("relax.gather_elements"); return Call(op, {data, indices}, Attrs(attrs), {}); @@ -1945,7 +1954,7 @@ TVM_REGISTER_OP("relax.gather_elements") /* relax.gather_nd */ Expr gather_nd(Expr data, Expr indices, int batch_dims) { - auto attrs = make_object(); + auto attrs = ffi::make_object(); attrs->batch_dims = Integer(batch_dims); static const Op& op = Op::Get("relax.gather_nd"); return Call(op, {data, indices}, Attrs(attrs), {}); @@ -2012,7 +2021,7 @@ StructInfo InferStructInfoGatherND(const Call& call, const BlockBuilder& ctx) { } // In this condition, all input shapes are known - Array out_shape; + ffi::Array out_shape; if (l > input_dims - batch_dims) { ctx->ReportFatal(Diagnostic::Error(call) << "GatherND requires the last dimension of indices to be less than or " @@ -2041,7 +2050,7 @@ TVM_REGISTER_OP("relax.gather_nd") /* relax.index_put */ Expr index_put(Expr data, Expr indices, Expr values, bool accumulate) { - auto attrs = make_object(); + auto attrs = ffi::make_object(); attrs->accumulate = std::move(accumulate); static const Op& op = Op::Get("relax.index_put"); return Call(op, {data, indices, values}, Attrs(attrs), {}); @@ -2056,7 +2065,7 @@ StructInfo InferStructInfoIndexPut(const Call& call, const BlockBuilder& ctx) { const auto* data_sinfo = GetStructInfoAs(call->args[0]); const auto* values_sinfo = GetStructInfoAs(call->args[2]); - auto diag_def = [&](const TensorStructInfoNode* sinfo, String name, String type_key) { + auto diag_def = [&](const TensorStructInfoNode* sinfo, ffi::String name, 
ffi::String type_key) { if (sinfo == nullptr) { ctx->ReportFatal(Diagnostic::Error(call) << "IndexPut requires the input " << name @@ -2068,7 +2077,7 @@ StructInfo InferStructInfoIndexPut(const Call& call, const BlockBuilder& ctx) { diag_def(values_sinfo, "values", call->args[2]->struct_info_->GetTypeKey()); // Handle indices: either a single tensor or a tuple of tensors - Array indices_tensors; + ffi::Array indices_tensors; if (const auto* tuple_sinfo = GetStructInfoAs(call->args[1])) { // Indices is a tuple of tensors @@ -2080,11 +2089,11 @@ StructInfo InferStructInfoIndexPut(const Call& call, const BlockBuilder& ctx) { << "However, element " << i << " is " << tuple_sinfo->fields[i]->GetTypeKey()); } - indices_tensors.push_back(GetRef(tensor_sinfo)); + indices_tensors.push_back(ffi::GetRef(tensor_sinfo)); } } else if (const auto* tensor_sinfo = GetStructInfoAs(call->args[1])) { // Indices is a single tensor - indices_tensors.push_back(GetRef(tensor_sinfo)); + indices_tensors.push_back(ffi::GetRef(tensor_sinfo)); } else { ctx->ReportFatal(Diagnostic::Error(call) << "IndexPut requires indices to be a Tensor or a tuple of Tensors. " @@ -2123,7 +2132,7 @@ StructInfo InferStructInfoIndexPut(const Call& call, const BlockBuilder& ctx) { // Check data and values dtype compatibility if (data_sinfo->IsUnknownDtype() || values_sinfo->IsUnknownDtype()) { - auto diag_dtype = [&](const TensorStructInfoNode* sinfo, String name) { + auto diag_dtype = [&](const TensorStructInfoNode* sinfo, ffi::String name) { if (sinfo->IsUnknownDtype()) { LOG(WARNING) << "Data type of " << name << " has not been specified. Assume it has an integer type."; @@ -2165,8 +2174,8 @@ TVM_REGISTER_OP("relax.index_put") /* relax.meshgrid */ -Expr meshgrid(Expr tensors, Optional indexing) { - ObjectPtr attrs = make_object(); +Expr meshgrid(Expr tensors, ffi::Optional indexing) { + ObjectPtr attrs = ffi::make_object(); attrs->indexing = indexing; static const Op& op = Op::Get("relax.meshgrid"); return Call(op, {std::move(tensors)}, Attrs(attrs), {}); @@ -2181,7 +2190,7 @@ StructInfo InferStructInfoMeshgrid(const Call& call, const BlockBuilder& ctx) { if (call->args.size() != 1) { ctx->ReportFatal(Diagnostic::Error(call) << "meshgrid op expects 1 Tuple input argument."); } - Array input_sinfo = GetTensorStructInfoFromTuple(call, ctx, call->args[0]); + ffi::Array input_sinfo = GetTensorStructInfoFromTuple(call, ctx, call->args[0]); int n_inputs = input_sinfo.size(); @@ -2193,7 +2202,7 @@ StructInfo InferStructInfoMeshgrid(const Call& call, const BlockBuilder& ctx) { std::vector lengths; DataType common_dtype = DataType::Void(); bool shape_unknown = false; - Optional vdev = std::nullopt; + ffi::Optional vdev = std::nullopt; bool vdevice_unknown = false; for (int i = 0; i < n_inputs; ++i) { @@ -2233,14 +2242,14 @@ StructInfo InferStructInfoMeshgrid(const Call& call, const BlockBuilder& ctx) { } } - Array out_shape; + ffi::Array out_shape; if (!shape_unknown && lengths.size() == static_cast(n_inputs)) { for (const PrimExpr& dim : lengths) { out_shape.push_back(dim); } } - Array out_fields; + ffi::Array out_fields; for (int i = 0; i < n_inputs; ++i) { if (!out_shape.empty()) { if (!vdevice_unknown) { @@ -2270,8 +2279,8 @@ TVM_REGISTER_OP("relax.meshgrid") /* relax.scatter_elements */ -Expr scatter_elements(Expr data, Expr indices, Expr updates, int axis, String reduction) { - auto attrs = make_object(); +Expr scatter_elements(Expr data, Expr indices, Expr updates, int axis, ffi::String reduction) { + auto attrs = ffi::make_object(); 
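The `GetRef` calls in the hunks above now spell out the `ffi::` namespace; the function itself is unchanged and still turns an unowned `*Node` pointer back into a strong reference. A hedged sketch of the usual check-then-promote pattern (the wrapper name is illustrative):

    tvm::ffi::Optional<tvm::relax::TensorStructInfo> AsTensor(
        const tvm::relax::StructInfo& sinfo) {
      if (const auto* node = sinfo.as<tvm::relax::TensorStructInfoNode>()) {
        return tvm::ffi::GetRef<tvm::relax::TensorStructInfo>(node);  // unowned -> owned
      }
      return std::nullopt;
    }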
attrs->axis = std::move(axis); attrs->reduction = std::move(reduction); static const Op& op = Op::Get("relax.scatter_elements"); @@ -2289,7 +2298,7 @@ StructInfo InferStructInfoScatterElements(const Call& call, const BlockBuilder& const auto* indices_sinfo = GetStructInfoAs(call->args[1]); const auto* updates_sinfo = GetStructInfoAs(call->args[2]); - auto diag_def = [&](const TensorStructInfoNode* sinfo, String name, String type_key) { + auto diag_def = [&](const TensorStructInfoNode* sinfo, ffi::String name, ffi::String type_key) { if (sinfo == nullptr) { ctx->ReportFatal(Diagnostic::Error(call) << "ScatterElements requires the input " << name @@ -2325,7 +2334,7 @@ StructInfo InferStructInfoScatterElements(const Call& call, const BlockBuilder& } if (data_sinfo->IsUnknownDtype() || updates_sinfo->IsUnknownDtype()) { - auto diag_dtype = [&](const TensorStructInfoNode* sinfo, String name) { + auto diag_dtype = [&](const TensorStructInfoNode* sinfo, ffi::String name) { if (sinfo->IsUnknownDtype()) { // TODO(tvm-team): Do we have an equivalent of `ctx->ReportFatal` for warning? LOG(WARNING) << "Data type of " << name @@ -2387,8 +2396,8 @@ TVM_REGISTER_OP("relax.scatter_elements") /* relax.scatter_nd */ -Expr scatter_nd(Expr data, Expr indices, Expr updates, String reduction) { - auto attrs = make_object(); +Expr scatter_nd(Expr data, Expr indices, Expr updates, ffi::String reduction) { + auto attrs = ffi::make_object(); attrs->reduction = std::move(reduction); static const Op& op = Op::Get("relax.scatter_nd"); return Call(op, {data, indices, updates}, Attrs(attrs), {}); @@ -2479,14 +2488,15 @@ StructInfo InferStructInfoScatterND(const Call& call, const BlockBuilder& ctx) { << "data: " << ShapeExpr(data_shape->values) << ", indices: " << ShapeExpr(indices_shape->values)); } - Array expected_updates_shape; + ffi::Array expected_updates_shape; for (size_t i = 0; i < indices_ndim - 1; i++) { expected_updates_shape.push_back(indices_shape->values[i]); } for (size_t i = k_dim->value; i < data_ndim; i++) { expected_updates_shape.push_back(data_shape->values[i]); } - auto check_shape = [&](const Array& expected, const Array& actual) { + auto check_shape = [&](const ffi::Array& expected, + const ffi::Array& actual) { if (expected.size() != actual.size()) { return false; } @@ -2524,7 +2534,7 @@ TVM_REGISTER_OP("relax.scatter_nd") /* relax.scatter_nd */ Expr slice_scatter(Expr input, Expr src, int axis, PrimValue start, PrimValue end, PrimValue step) { - auto attrs = make_object(); + auto attrs = ffi::make_object(); attrs->axis = std::move(axis); static const Op& op = Op::Get("relax.slice_scatter"); return Call(op, {input, src, start, end, step}, Attrs(attrs), {}); @@ -2542,7 +2552,7 @@ StructInfo InferStructInfoSliceScatter(const Call& call, const BlockBuilder& ctx auto* attrs = call->attrs.as(); auto diag_tensor_check = [&](const TensorStructInfoNode* sinfo, const Expr& arg_expr, - String name) { + ffi::String name) { if (sinfo == nullptr) { ctx->ReportFatal(Diagnostic::Error(call) << "SliceScatter requires the input " << name << " to be a Tensor. 
However, the given one is " @@ -2576,7 +2586,7 @@ StructInfo InferStructInfoSliceScatter(const Call& call, const BlockBuilder& ctx } if (data_sinfo->IsUnknownDtype() || src_sinfo->IsUnknownDtype()) { - auto diag_dtype_warn = [&](const TensorStructInfoNode* sinfo, String name) { + auto diag_dtype_warn = [&](const TensorStructInfoNode* sinfo, ffi::String name) { if (sinfo->IsUnknownDtype()) { LOG(WARNING) << "SliceScatter: Data type of " << name << " has not been specified for call node " << call @@ -2681,7 +2691,7 @@ TVM_REGISTER_OP("relax.slice_scatter") /* relax.one_hot */ Expr one_hot(Expr indices, PrimValue on_value, PrimValue off_value, int depth, int axis) { - ObjectPtr attrs = make_object(); + ObjectPtr attrs = ffi::make_object(); attrs->depth = depth; attrs->axis = axis; @@ -2732,7 +2742,7 @@ StructInfo InferStructInfoOneHot(const Call& call, const BlockBuilder& ctx) { return TensorStructInfo(dtype, indices_sinfo->ndim + 1, indices_sinfo->vdevice); } - Array output_shape = indices_shape->values; + ffi::Array output_shape = indices_shape->values; int axis = attrs->axis; if (axis < 0) { axis += output_shape.size() + 1; } diff --git a/src/relax/op/tensor/manipulate.h b/src/relax/op/tensor/manipulate.h index cc15d5d4ab76..84d53addcc69 100644 --- a/src/relax/op/tensor/manipulate.h +++ b/src/relax/op/tensor/manipulate.h @@ -44,7 +44,7 @@ Expr broadcast_to(Expr x, Expr shape); * If it is `std::nullopt`, the input tensor is required to be flattened before concatenation. * \return The concatenated tensor. */ -Expr concat(Expr tensors, Optional axis); +Expr concat(Expr tensors, ffi::Optional axis); /*! * \brief Insert new axes at the positions given by `axis`. @@ -52,7 +52,7 @@ Expr concat(Expr tensors, Optional axis); * \param axis The axes at which the input array is expanded. * \return The transformed result. */ -Expr expand_dims(Expr x, Array axis); +Expr expand_dims(Expr x, ffi::Array axis); /*! * \brief Flatten all the tensor dimensions into one. @@ -72,9 +72,9 @@ Expr flatten(Expr x); * \param input_axis_separators Array of values for input buffer. * \return The transformed result. */ -Expr layout_transform(Expr x, tir::IndexMap index_map, Optional pad_value, - Optional> axis_separators, - Optional> input_axis_separators = std::nullopt); +Expr layout_transform(Expr x, tir::IndexMap index_map, ffi::Optional pad_value, + ffi::Optional> axis_separators, + ffi::Optional> input_axis_separators = std::nullopt); /*! * \brief Permutes the dimensions of an array. @@ -82,7 +82,7 @@ Expr layout_transform(Expr x, tir::IndexMap index_map, Optional pad_v * \param axes The target axes order, reverse order if not specified. * \return The transposed result. */ -Expr permute_dims(Expr x, Optional> axes); +Expr permute_dims(Expr x, ffi::Optional> axes); /*! * \brief Reshape the input array, supporting `-1` inference in the new @@ -92,7 +92,7 @@ Expr permute_dims(Expr x, Optional> axes); * It is required to be either an Array of PrimExpr, or a Shape in Relax * \return The reshaped result. */ -Expr reshape(Expr x, Variant> shape); +Expr reshape(Expr x, ffi::Variant> shape); /*! * \brief Split input tensor along axis by sections or indices. @@ -107,7 +107,7 @@ Expr reshape(Expr x, Variant> shape); * \param axis The axis over which to split. * \return The computed result. */ -Expr split(Expr x, Variant> indices_or_sections, int axis); +Expr split(Expr x, ffi::Variant> indices_or_sections, int axis); /*! * \brief Squeeze axes in the array.
@@ -117,14 +117,14 @@ Expr split(Expr x, Variant> indices_or_sections, int axis) * If any specified axis has dimension that does not equal 1, it is an error. * \return The squeezed result. */ -Expr squeeze(Expr x, Optional> axis); +Expr squeeze(Expr x, ffi::Optional> axis); /*! * \brief Stack tensors along the specified axis. * \param tensors The input tensors to be stacked. * \param axis The axis along which the tensors will be stacked. * \return The stacked result. */ -Expr stack(Expr tensors, Optional axis); +Expr stack(Expr tensors, ffi::Optional axis); /*! * \brief Return a summation of data to the shape of collapse_target. * For details, please see the operator `relax.collapse_sum_to`. @@ -154,7 +154,7 @@ Expr collapse_sum_to(Expr data, Expr shape); * from the backward. By default, use the flattened input array, and return a flat output array. * \return The computed result. */ -Expr repeat(Expr data, int repeats, Optional axis = std::nullopt); +Expr repeat(Expr data, int repeats, ffi::Optional axis = std::nullopt); /*! * \brief Construct an array by repeating data the number of times given by reps. @@ -171,7 +171,7 @@ Expr repeat(Expr data, int repeats, Optional axis = std::nullopt); * \param repeats The number of repetitions of data along each axis. * \return The computed result. */ -Expr tile(Expr data, Array repeats); +Expr tile(Expr data, ffi::Array repeats); /*! * \brief Reverses the order of elements along given axis. @@ -238,7 +238,7 @@ Expr index_put(Expr data, Expr indices, Expr values, bool accumulate = false); * \param indexing Indexing mode, either "ij" (matrix indexing) or "xy" (Cartesian indexing). * \return A tuple of tensors representing the coordinate grids. */ -Expr meshgrid(Expr tensors, Optional indexing = String("ij")); +Expr meshgrid(Expr tensors, ffi::Optional indexing = ffi::String("ij")); /*! * \brief Scatter updates into an array according to indices. @@ -250,7 +250,7 @@ Expr meshgrid(Expr tensors, Optional indexing = String("ij")); * either "update", "add", "mul", "mean", "max" or "min". * \return The computed result. */ -Expr scatter_elements(Expr data, Expr indices, Expr updates, int axis, String reduction); +Expr scatter_elements(Expr data, Expr indices, Expr updates, int axis, ffi::String reduction); /*! * \brief Scatter updates into an array according to indices. @@ -271,7 +271,7 @@ Expr scatter_elements(Expr data, Expr indices, Expr updates, int axis, String re * The shape of `updates` must match the shape of `indices` except for the last dimension, * which must match the slice shape at each index. */ -Expr scatter_nd(Expr data, Expr indices, Expr updates, String reduction); +Expr scatter_nd(Expr data, Expr indices, Expr updates, ffi::String reduction); /*! * \brief Embeds the values of the src tensor into input at the given dimension. 
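Note the defaulted arguments in the header above (`repeat` with `axis = std::nullopt`, `meshgrid` with `indexing = ffi::String("ij")`): call sites can keep omitting them after the migration. A sketch, where the `tensors` argument is hypothetical and the "xy" mode comes from the doc comment above:

    tvm::relax::Expr Grids(tvm::relax::Expr tensors) {
      tvm::relax::Expr ij = tvm::relax::meshgrid(tensors);  // default "ij" (matrix indexing)
      (void)ij;
      return tvm::relax::meshgrid(tensors, tvm::ffi::String("xy"));  // explicit Cartesian mode
    }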
diff --git a/src/relax/op/tensor/qdq.cc b/src/relax/op/tensor/qdq.cc index a51d85820e40..7d51020be806 100644 --- a/src/relax/op/tensor/qdq.cc +++ b/src/relax/op/tensor/qdq.cc @@ -39,7 +39,7 @@ TVM_FFI_STATIC_INIT_BLOCK({ QuantizeAttrs::RegisterReflection(); }); /* relax.quantize */ Expr quantize(Expr data, Expr scale, Expr zero_point, int axis, DataType out_dtype) { - ObjectPtr attrs = make_object(); + ObjectPtr attrs = ffi::make_object(); attrs->axis = axis; attrs->out_dtype = out_dtype; static const Op& op = Op::Get("relax.quantize"); @@ -93,7 +93,7 @@ StructInfo InferStructInfoQuantize(const Call& call, const BlockBuilder& ctx) { } auto check_param_size = [&](const TensorStructInfo& param_sinfo, - const TensorStructInfo& data_sinfo, String param_name) { + const TensorStructInfo& data_sinfo, ffi::String param_name) { const PrimExpr& param_dim = param_sinfo->GetShape().value()[0]; const PrimExpr& input_dim = data_sinfo->GetShape().value()[axis]; if (!ctx->GetAnalyzer()->CanProveEqual(param_dim, input_dim)) { @@ -108,7 +108,7 @@ StructInfo InferStructInfoQuantize(const Call& call, const BlockBuilder& ctx) { if (!IsScalarTensor(scale_sinfo)) check_param_size(scale_sinfo, input_sinfo, "scale"); if (!IsScalarTensor(zp_sinfo)) check_param_size(zp_sinfo, input_sinfo, "zero_point"); - auto output_sinfo = make_object(*input_sinfo.get()); + auto output_sinfo = ffi::make_object(*input_sinfo.get()); output_sinfo->dtype = attrs->out_dtype; return TensorStructInfo(output_sinfo); } @@ -125,7 +125,7 @@ TVM_REGISTER_OP("relax.quantize") /* relax.dequantize */ Expr dequantize(Expr data, Expr scale, Expr zero_point, int axis, DataType out_dtype) { - ObjectPtr attrs = make_object(); + ObjectPtr attrs = ffi::make_object(); attrs->axis = axis; attrs->out_dtype = out_dtype; static const Op& op = Op::Get("relax.dequantize"); @@ -181,7 +181,7 @@ StructInfo InferStructInfoDequantize(const Call& call, const BlockBuilder& ctx) } auto check_param_size = [&](const TensorStructInfo& param_sinfo, - const TensorStructInfo& data_sinfo, String param_name) { + const TensorStructInfo& data_sinfo, ffi::String param_name) { const PrimExpr& param_dim = param_sinfo->GetShape().value()[0]; const PrimExpr& input_dim = data_sinfo->GetShape().value()[axis]; if (!ctx->GetAnalyzer()->CanProveEqual(param_dim, input_dim)) { @@ -196,7 +196,7 @@ StructInfo InferStructInfoDequantize(const Call& call, const BlockBuilder& ctx) if (!IsScalarTensor(scale_sinfo)) check_param_size(scale_sinfo, input_sinfo, "scale"); if (!IsScalarTensor(zp_sinfo)) check_param_size(zp_sinfo, input_sinfo, "zero_point"); - auto output_sinfo = make_object(*input_sinfo.get()); + auto output_sinfo = ffi::make_object(*input_sinfo.get()); output_sinfo->dtype = attrs->out_dtype; return TensorStructInfo(output_sinfo); } diff --git a/src/relax/op/tensor/sampling.cc b/src/relax/op/tensor/sampling.cc index 803e0a654d1c..7507ef4357c7 100644 --- a/src/relax/op/tensor/sampling.cc +++ b/src/relax/op/tensor/sampling.cc @@ -37,7 +37,7 @@ TVM_FFI_STATIC_INIT_BLOCK({ MultinomialFromUniformAttrs::RegisterReflection(); } /* relax.multinomial_from_uniform */ Expr multinomial_from_uniform(Expr prob, Expr uniform_sample, Expr sample_indices, DataType dtype) { - ObjectPtr attrs = make_object(); + ObjectPtr attrs = ffi::make_object(); attrs->dtype = dtype; static const Op& op = Op::Get("relax.multinomial_from_uniform"); diff --git a/src/relax/op/tensor/search.cc b/src/relax/op/tensor/search.cc index d1ebae3a4fdc..3db995837a97 100644 --- a/src/relax/op/tensor/search.cc +++ 
b/src/relax/op/tensor/search.cc @@ -40,7 +40,7 @@ TVM_FFI_STATIC_INIT_BLOCK({ /* relax.bucketize */ Expr bucketize(Expr input_tensor, Expr boundaries, bool out_int32, bool right) { - auto attrs = make_object(); + auto attrs = ffi::make_object(); attrs->out_int32 = std::move(out_int32); attrs->right = std::move(right); static const Op& op = Op::Get("relax.bucketize"); @@ -53,7 +53,7 @@ TVM_FFI_STATIC_INIT_BLOCK({ }); StructInfo InferStructInfoBucketize(const Call& call, const BlockBuilder& ctx) { - Array input_sinfo = GetInputTensorStructInfo(call, ctx); + ffi::Array input_sinfo = GetInputTensorStructInfo(call, ctx); TensorStructInfo input_tensor_info = input_sinfo[0]; TensorStructInfo boundaries_info = input_sinfo[1]; @@ -99,7 +99,7 @@ TVM_FFI_STATIC_INIT_BLOCK({ }); StructInfo InferStructInfoWhere(const Call& call, const BlockBuilder& ctx) { - Array input_sinfo = GetInputTensorStructInfo(call, ctx); + ffi::Array input_sinfo = GetInputTensorStructInfo(call, ctx); TensorStructInfo cond_sinfo = input_sinfo[0]; TensorStructInfo x1_sinfo = input_sinfo[1]; TensorStructInfo x2_sinfo = input_sinfo[2]; @@ -139,7 +139,7 @@ StructInfo InferStructInfoWhere(const Call& call, const BlockBuilder& ctx) { const auto* x2_shape = x2_sinfo->shape.as(); if (cond_shape && x1_shape && x2_shape) { // Step 1. Compute the broadcasted shape of x1's and x2's - Optional> broadcasted_shape = + ffi::Optional> broadcasted_shape = InferBinaryBroadcastShape(call, ctx, x1_shape->values, x2_shape->values); if (!broadcasted_shape.defined()) { if (vdev.defined()) { @@ -220,12 +220,13 @@ StructInfo InferStructInfoArgmaxArgmin(const Call& call, const BlockBuilder& ctx const auto* data_shape = data_sinfo->shape.as(); if (data_shape == nullptr) { if (!attrs->axis.has_value() && attrs->keepdims && out_ndim != kUnknownNDim) { - return TensorStructInfo(ShapeExpr(Array(out_ndim, IntImm(out_dtype, /*value=*/1))), - out_dtype, data_sinfo->vdevice); + return TensorStructInfo( + ShapeExpr(ffi::Array(out_ndim, IntImm(out_dtype, /*value=*/1))), out_dtype, + data_sinfo->vdevice); } else { - return out_ndim == 0 - ? TensorStructInfo(ShapeExpr(Array()), out_dtype, data_sinfo->vdevice) - : TensorStructInfo(out_dtype, out_ndim, data_sinfo->vdevice); + return out_ndim == 0 ? TensorStructInfo(ShapeExpr(ffi::Array()), out_dtype, + data_sinfo->vdevice) + : TensorStructInfo(out_dtype, out_ndim, data_sinfo->vdevice); } } @@ -233,7 +234,7 @@ StructInfo InferStructInfoArgmaxArgmin(const Call& call, const BlockBuilder& ctx out_dtype = data_shape->values[0]->dtype; } - Array out_shape; + ffi::Array out_shape; out_shape.reserve(out_ndim); for (int i = 0; i < data_sinfo->ndim; ++i) { if (attrs->axis.has_value() && i != axis) { @@ -247,8 +248,8 @@ StructInfo InferStructInfoArgmaxArgmin(const Call& call, const BlockBuilder& ctx } #define RELAX_REGISTER_ARGMAX_ARGMIN_OP(OpName) \ - Expr OpName(Expr x, Optional axis, bool keepdims) { \ - ObjectPtr attrs = make_object(); \ + Expr OpName(Expr x, ffi::Optional axis, bool keepdims) { \ + ObjectPtr attrs = ffi::make_object(); \ attrs->axis = std::move(axis); \ attrs->keepdims = std::move(keepdims); \ static const Op& op = Op::Get("relax." #OpName); \ diff --git a/src/relax/op/tensor/search.h b/src/relax/op/tensor/search.h index 333b5afe76c7..d1cc6e39f43c 100644 --- a/src/relax/op/tensor/search.h +++ b/src/relax/op/tensor/search.h @@ -48,10 +48,10 @@ Expr bucketize(Expr input_tensor, Expr boundaries, bool out_int32, bool right); Expr where(Expr condition, Expr x1, Expr x2); /*! 
\brief Computes the argmax of tensor elements over given axis. */ -Expr argmax(Expr x, Optional axis, bool keepdims); +Expr argmax(Expr x, ffi::Optional axis, bool keepdims); /*! \brief Computes the argmin of tensor elements over given axis. */ -Expr argmin(Expr x, Optional axis, bool keepdims); +Expr argmin(Expr x, ffi::Optional axis, bool keepdims); } // namespace relax } // namespace tvm diff --git a/src/relax/op/tensor/set.cc b/src/relax/op/tensor/set.cc index f0fc3871371c..eb03725f8587 100644 --- a/src/relax/op/tensor/set.cc +++ b/src/relax/op/tensor/set.cc @@ -36,7 +36,7 @@ namespace relax { /* relax.unique */ Expr unique(Expr x, PrimValue sorted, PrimValue return_index, PrimValue return_inverse, - PrimValue return_counts, Optional axis) { + PrimValue return_counts, ffi::Optional axis) { static const Op& op = Op::Get("relax.unique"); Call call; if (!axis) { @@ -58,7 +58,7 @@ StructInfo InferStructInfoUnique(const Call& call, const BlockBuilder& ctx) { PrimValue axis, return_index, return_inverse, return_counts; if (call->args.size() == 6) { if (auto* prim_value_node = call->args[5].as()) { - axis = GetRef(prim_value_node); + axis = ffi::GetRef(prim_value_node); } } if (!data_sinfo->IsUnknownNdim() && axis.defined()) { @@ -79,7 +79,7 @@ StructInfo InferStructInfoUnique(const Call& call, const BlockBuilder& ctx) { CHECK(value->IsInstance()) << value << " expects to be IntImm, but gets " << value->GetTypeKey(); const auto* val_node = value.as(); - auto val_imm = GetRef(val_node); + auto val_imm = ffi::GetRef(val_node); return val_imm->value; }; diff --git a/src/relax/op/tensor/set.h b/src/relax/op/tensor/set.h index 251dd1975e9f..4af7478d61ef 100644 --- a/src/relax/op/tensor/set.h +++ b/src/relax/op/tensor/set.h @@ -49,7 +49,7 @@ namespace relax { * Additional return values depend on `return_index`, `return_inverse`, and `return_counts`. */ Expr unique(Expr x, PrimValue sorted, PrimValue return_index, PrimValue return_inverse, - PrimValue return_counts, Optional axis); + PrimValue return_counts, ffi::Optional axis); /*! * \brief Returns the indices of the non-zero elements of the input tensor. diff --git a/src/relax/op/tensor/sorting.cc b/src/relax/op/tensor/sorting.cc index 57e13fa26e01..de28f981567f 100644 --- a/src/relax/op/tensor/sorting.cc +++ b/src/relax/op/tensor/sorting.cc @@ -40,7 +40,7 @@ TVM_FFI_STATIC_INIT_BLOCK({ /* relax.sort */ Expr sort(Expr data, int axis, bool descending) { - auto attrs = make_object(); + auto attrs = ffi::make_object(); attrs->axis = std::move(axis); attrs->descending = std::move(descending); @@ -67,7 +67,7 @@ TVM_REGISTER_OP("relax.sort") /* relax.argsort */ Expr argsort(Expr data, int axis, bool descending, DataType dtype) { - auto attrs = make_object(); + auto attrs = ffi::make_object(); attrs->axis = std::move(axis); attrs->descending = std::move(descending); attrs->dtype = std::move(dtype); @@ -100,8 +100,8 @@ TVM_REGISTER_OP("relax.argsort") /* relax.topk */ -Expr topk(Expr data, int k, int axis, String ret_type, bool largest, DataType dtype) { - auto attrs = make_object(); +Expr topk(Expr data, int k, int axis, ffi::String ret_type, bool largest, DataType dtype) { + auto attrs = ffi::make_object(); attrs->k = std::move(k); attrs->axis = std::move(axis); attrs->ret_type = std::move(ret_type); @@ -124,7 +124,7 @@ StructInfo InferStructInfoTopK(const Call& call, const BlockBuilder& ctx) { DataType indices_type = attrs->dtype.is_void() ? 
data_sinfo->dtype : attrs->dtype; int ndim = data_sinfo->ndim; int k = attrs->k; - String ret_type = attrs->ret_type; + ffi::String ret_type = attrs->ret_type; int axis = attrs->axis; if (axis < 0 && ndim > 0) { axis += ndim; @@ -137,7 +137,7 @@ StructInfo InferStructInfoTopK(const Call& call, const BlockBuilder& ctx) { TensorStructInfo(data_sinfo->dtype, data_sinfo->ndim, data_sinfo->vdevice)); output_sinfos.push_back(TensorStructInfo(indices_type, data_sinfo->ndim, data_sinfo->vdevice)); } else { - Array out_shape = data_shape->values; + ffi::Array out_shape = data_shape->values; const auto* int_dim = out_shape[axis].as(); if (k > 0 && (int_dim == nullptr || k < int_dim->value)) { out_shape.Set(axis, k); diff --git a/src/relax/op/tensor/sorting.h b/src/relax/op/tensor/sorting.h index 8a785bc4e2b8..a4154ce416ad 100644 --- a/src/relax/op/tensor/sorting.h +++ b/src/relax/op/tensor/sorting.h @@ -63,7 +63,7 @@ Expr argsort(Expr data, int axis, bool descending, DataType dtype); * \param dtype The data type of the indices output. * \return The computed result. */ -Expr topk(Expr data, int k, int axis, String ret_type, bool largest, DataType dtype); +Expr topk(Expr data, int k, int axis, ffi::String ret_type, bool largest, DataType dtype); } // namespace relax } // namespace tvm diff --git a/src/relax/op/tensor/statistical.cc b/src/relax/op/tensor/statistical.cc index 700016b223ef..cb52a48ee848 100644 --- a/src/relax/op/tensor/statistical.cc +++ b/src/relax/op/tensor/statistical.cc @@ -69,16 +69,16 @@ StructInfo InferStructInfoStatistical(const Call& call, const BlockBuilder& ctx) if (data_shape == nullptr) { if (!attrs->axis.defined() && attrs->keepdims && out_ndim != kUnknownNDim) { return TensorStructInfo( - ShapeExpr(Array(out_ndim, IntImm(DataType::Int(64), /*value=*/1))), + ShapeExpr(ffi::Array(out_ndim, IntImm(DataType::Int(64), /*value=*/1))), data_sinfo->dtype, data_sinfo->vdevice); } else { - return out_ndim == 0 ? TensorStructInfo(ShapeExpr(Array()), data_sinfo->dtype, + return out_ndim == 0 ? 
TensorStructInfo(ShapeExpr(ffi::Array()), data_sinfo->dtype, data_sinfo->vdevice) : TensorStructInfo(data_sinfo->dtype, out_ndim, data_sinfo->vdevice); } } - Array out_shape; + ffi::Array out_shape; out_shape.reserve(out_ndim); for (int i = 0; i < data_sinfo->ndim; ++i) { if (attrs->axis.defined() && std::find(axes.begin(), axes.end(), i) == axes.end()) { @@ -91,9 +91,9 @@ StructInfo InferStructInfoStatistical(const Call& call, const BlockBuilder& ctx) return TensorStructInfo(ShapeExpr(out_shape), data_sinfo->dtype, data_sinfo->vdevice); } -InferLayoutOutput InferLayoutStatistical(const Call& call, - const Map>& desired_layouts, - const VarLayoutMap& var_layout_map) { +InferLayoutOutput InferLayoutStatistical( + const Call& call, const ffi::Map>& desired_layouts, + const VarLayoutMap& var_layout_map) { ICHECK(NoDesiredLayout(call, desired_layouts)); const auto* attrs = call->attrs.as(); @@ -103,7 +103,7 @@ InferLayoutOutput InferLayoutStatistical(const Call& call, ICHECK(!tensor_sinfo->IsUnknownNdim()) << "Only support known ndim"; int ndim = tensor_sinfo->ndim; - Array axis; + ffi::Array axis; if (attrs->axis.defined()) { axis = attrs->axis.value(); } else { @@ -131,7 +131,7 @@ InferLayoutOutput InferLayoutStatistical(const Call& call, [](unsigned char c) { return std::isdigit(c); }), new_axis_str.end()); - Array new_axis; + ffi::Array new_axis; for (size_t i = 0; i < new_axis_str.size(); ++i) { if (new_axis_str.at(i) == '#') { new_axis.push_back(Integer(i)); @@ -145,7 +145,7 @@ InferLayoutOutput InferLayoutStatistical(const Call& call, output_layout.push_back(output_layout_ref[i]); } - ObjectPtr new_attrs = make_object(*attrs); + ObjectPtr new_attrs = ffi::make_object(*attrs); new_attrs->axis = new_axis; return InferLayoutOutput({exisiting_layout}, {attrs->keepdims ? exisiting_layout : Layout(output_layout)}, @@ -168,7 +168,7 @@ StructInfo InferStructInfoScan(const Call& call, const BlockBuilder& ctx) { for (const auto v : data_shape->values) { flattened_d *= v; } - return TensorStructInfo(ShapeExpr(Array({flattened_d})), out_type, + return TensorStructInfo(ShapeExpr(ffi::Array({flattened_d})), out_type, data_sinfo->vdevice); } } @@ -181,8 +181,9 @@ StructInfo InferStructInfoScan(const Call& call, const BlockBuilder& ctx) { } /* relax.cumprod */ -Expr cumprod(Expr data, Optional axis, Optional dtype, Bool exclusive) { - auto attrs = make_object(); +Expr cumprod(Expr data, ffi::Optional axis, ffi::Optional dtype, + Bool exclusive) { + auto attrs = ffi::make_object(); attrs->axis = std::move(axis); attrs->dtype = std::move(dtype.value_or(DataType::Void())); attrs->exclusive = std::move(exclusive); @@ -204,8 +205,8 @@ TVM_REGISTER_OP("relax.cumprod") .set_attr("FPurity", Bool(true)); /* relax.cumsum */ -Expr cumsum(Expr data, Optional axis, Optional dtype, Bool exclusive) { - auto attrs = make_object(); +Expr cumsum(Expr data, ffi::Optional axis, ffi::Optional dtype, Bool exclusive) { + auto attrs = ffi::make_object(); attrs->axis = std::move(axis); attrs->dtype = std::move(dtype.value_or(DataType::Void())); attrs->exclusive = std::move(exclusive); diff --git a/src/relax/op/tensor/statistical.h b/src/relax/op/tensor/statistical.h index e79ce1d4aeaa..e100b544fb83 100644 --- a/src/relax/op/tensor/statistical.h +++ b/src/relax/op/tensor/statistical.h @@ -43,8 +43,8 @@ namespace relax { * 2. be prepended with a prefix "relax." as the identifier string in the operator registry. 
*/ #define RELAX_REGISTER_STATISTICAL_OP_INTERFACE(OpName) \ - Expr OpName(Expr x, Optional> axis, bool keepdims) { \ - ObjectPtr attrs = make_object(); \ + Expr OpName(Expr x, ffi::Optional> axis, bool keepdims) { \ + ObjectPtr attrs = ffi::make_object(); \ attrs->axis = std::move(axis); \ attrs->keepdims = keepdims; \ static const Op& op = Op::Get("relax." #OpName); \ @@ -67,22 +67,22 @@ namespace relax { * reduced are left in the result as dimensions with size one. With this option, the result will * broadcast correctly against the input tensor. \return The result after reduction. */ -Expr max(Expr x, Optional> axis, bool keepdims); +Expr max(Expr x, ffi::Optional> axis, bool keepdims); /*! \brief Computes the mean of tensor elements over given axes. */ -Expr mean(Expr x, Optional> axis, bool keepdims); +Expr mean(Expr x, ffi::Optional> axis, bool keepdims); /*! \brief Computes the min of tensor elements over given axes. */ -Expr min(Expr x, Optional> axis, bool keepdims); +Expr min(Expr x, ffi::Optional> axis, bool keepdims); /*! \brief Computes the product of tensor elements over given axes. */ -Expr prod(Expr x, Optional> axis, bool keepdims); +Expr prod(Expr x, ffi::Optional> axis, bool keepdims); /*! \brief Computes the standard deviation of tensor elements over given axes. */ -Expr std(Expr x, Optional> axis, bool keepdims); +Expr std(Expr x, ffi::Optional> axis, bool keepdims); /*! \brief Computes the sum of tensor elements over given axes. */ -Expr sum(Expr x, Optional> axis, bool keepdims); +Expr sum(Expr x, ffi::Optional> axis, bool keepdims); /*! * \brief Numpy style cumprod op. Return the cumulative inclusive product of the elements along @@ -97,8 +97,8 @@ Expr sum(Expr x, Optional> axis, bool keepdims); * \return The computed * result. */ -Expr cumprod(Expr data, Optional axis = std::nullopt, - Optional dtype = std::nullopt, Bool exclusive = Bool(false)); +Expr cumprod(Expr data, ffi::Optional axis = std::nullopt, + ffi::Optional dtype = std::nullopt, Bool exclusive = Bool(false)); /*! * \brief Numpy style cumsum op. Return the cumulative inclusive sum of the elements along @@ -112,11 +112,11 @@ Expr cumprod(Expr data, Optional axis = std::nullopt, * which the first element is not included. * \return The computed result. */ -Expr cumsum(Expr data, Optional axis = std::nullopt, - Optional dtype = std::nullopt, Bool exclusive = Bool(false)); +Expr cumsum(Expr data, ffi::Optional axis = std::nullopt, + ffi::Optional dtype = std::nullopt, Bool exclusive = Bool(false)); /*! \brief Computes the variance of tensor elements over given axes. 
*/ -Expr variance(Expr x, Optional> axis, bool keepdims); +Expr variance(Expr x, ffi::Optional> axis, bool keepdims); } // namespace relax } // namespace tvm diff --git a/src/relax/op/tensor/ternary.cc b/src/relax/op/tensor/ternary.cc index b60344e351d6..db7eea4661bc 100644 --- a/src/relax/op/tensor/ternary.cc +++ b/src/relax/op/tensor/ternary.cc @@ -30,7 +30,7 @@ namespace tvm { namespace relax { StructInfo InferStructInfoEwiseFMA(const Call& call, const BlockBuilder& ctx) { - Array input_sinfo = GetInputTensorStructInfo(call, ctx); + ffi::Array input_sinfo = GetInputTensorStructInfo(call, ctx); TensorStructInfo t1 = input_sinfo[0]; TensorStructInfo t2 = input_sinfo[1]; TensorStructInfo t3 = input_sinfo[2]; @@ -87,7 +87,7 @@ StructInfo InferStructInfoEwiseFMA(const Call& call, const BlockBuilder& ctx) { auto* s3 = t3->shape.as(); arith::Analyzer* analyzer = ctx->GetAnalyzer(); if (s1 && s2 && s3) { - Array output_shape; + ffi::Array output_shape; for (int i = 0; i < ndim; ++i) { PrimExpr dim1 = s1->values[i]; PrimExpr dim2 = s2->values[i]; @@ -115,9 +115,9 @@ StructInfo InferStructInfoEwiseFMA(const Call& call, const BlockBuilder& ctx) { return TensorStructInfo(output_dtype, ndim); } -InferLayoutOutput InferLayoutEwiseFMA(const Call& call, - const Map>& desired_layouts, - const VarLayoutMap& var_layout_map) { +InferLayoutOutput InferLayoutEwiseFMA( + const Call& call, const ffi::Map>& desired_layouts, + const VarLayoutMap& var_layout_map) { ICHECK(NoDesiredLayout(call, desired_layouts)); LayoutDecision layout0 = GetLayoutDecision(var_layout_map, call->args[0]); diff --git a/src/relax/training/utils.cc b/src/relax/training/utils.cc index 49e5b862e900..2edb40cd2c80 100644 --- a/src/relax/training/utils.cc +++ b/src/relax/training/utils.cc @@ -39,13 +39,13 @@ namespace relax { /*! 
\brief Append the loss function to the backbone function in an IRModule.*/ class AppendLossMutator : private ExprMutator { public: - static IRModule Transform(IRModule mod, String func_name, Function loss_function, - int num_backbone_outputs, Optional new_func_name) { + static IRModule Transform(IRModule mod, ffi::String func_name, Function loss_function, + int num_backbone_outputs, ffi::Optional new_func_name) { auto* old_func = mod->Lookup(func_name).as(); CHECK(old_func) << func_name << "is not a Relax Function"; // functions should be copied to satisfy the well-formed check - Function new_func = CopyWithNewVars(GetRef(old_func)); + Function new_func = CopyWithNewVars(ffi::GetRef(old_func)); Function new_loss_func = CopyWithNewVars(loss_function); AppendLossMutator mutator(mod, new_loss_func, num_backbone_outputs); @@ -53,7 +53,7 @@ class AppendLossMutator : private ExprMutator { WithAttr(Downcast(mutator.VisitExpr(new_func)), tvm::attr::kGlobalSymbol, new_func_name.value_or(func_name + "_loss")); - auto new_module = GetRef(mod.CopyOnWrite()); + auto new_module = ffi::GetRef(mod.CopyOnWrite()); auto new_var = GlobalVar(new_func_name.value_or(func_name + "_loss")); new_module->Add(new_var, new_func_transformed); return new_module; @@ -73,7 +73,7 @@ class AppendLossMutator : private ExprMutator { CheckAndRemapBackboneReturn(); CheckAndRemapLossParams(loss_function_->params); - Array new_params = func->params; + ffi::Array new_params = func->params; new_params.insert(new_params.end(), loss_function_->params.begin() + num_backbone_outputs_, loss_function_->params.end()); Expr new_body = this->VisitExpr(func->body); @@ -85,8 +85,8 @@ class AppendLossMutator : private ExprMutator { CHECK(seq_expr->blocks.size() == 1 && seq_expr->blocks[0]->IsInstance()) << "Backbone should have only one DataflowBlock"; - auto new_blocks = Array({this->VisitBindingBlock(seq_expr->blocks[0])}); - auto ret = Array({loss_body_->body}); + auto new_blocks = ffi::Array({this->VisitBindingBlock(seq_expr->blocks[0])}); + auto ret = ffi::Array({loss_body_->body}); ret.insert(ret.end(), backbone_return_arr_.begin() + num_backbone_outputs_, backbone_return_arr_.end()); return SeqExpr(new_blocks, ret.size() == 1 ? ret[0] : Tuple(ret)); @@ -118,22 +118,22 @@ class AppendLossMutator : private ExprMutator { CHECK(loss_body_->blocks.size() == 1 && loss_body_->blocks[0]->IsInstance()) << "The loss function should have only one DataflowBlock"; auto var_node = loss_body_->body.as(); - CHECK(var_node && IsScalarTensor(GetRef(var_node))) + CHECK(var_node && IsScalarTensor(ffi::GetRef(var_node))) << "The loss function must return a scalar(0-dim Tensor) Var"; } /*! - * \brief Convert the return value of the backbone to Array. The backbone should return one - * or a tuple of Vars. + * \brief Convert the return value of the backbone to ffi::Array. The backbone should return + * one or a tuple of Vars. 
*/ void BackboneReturnToArr(const Expr& backbone_return) { if (auto* var = backbone_return.as()) { - backbone_return_arr_.push_back(GetRef(var)); + backbone_return_arr_.push_back(ffi::GetRef(var)); } else if (auto* tuple = backbone_return.as()) { for (auto i : tuple->fields) { auto var = i.as(); CHECK(var) << "The return value of the backbone should be either a Var or a Tuple of Vars"; - backbone_return_arr_.push_back(GetRef(var)); + backbone_return_arr_.push_back(ffi::GetRef(var)); } } else { LOG(FATAL) << "The return value of the backbone should be either a Var or a Tuple of Vars"; @@ -145,7 +145,7 @@ class AppendLossMutator : private ExprMutator { * and the elements in backbone_return_arr_ and loss_func_params have matched struct_info. Also * sets up var_remap_ from loss parameter Vars to backbone returned Vars. */ - void CheckAndRemapLossParams(const Array& loss_func_params) { + void CheckAndRemapLossParams(const ffi::Array& loss_func_params) { static StructuralEqual checker; CHECK(static_cast(loss_func_params.size()) >= num_backbone_outputs_) << "The number of parameters of the loss function is " << loss_func_params.size() @@ -199,13 +199,13 @@ class AppendLossMutator : private ExprMutator { /*! \brief The body of the loss function */ SeqExpr loss_body_; /*! \brief The unpacked return values of the backbone. All return values should be Vars. */ - Array backbone_return_arr_; + ffi::Array backbone_return_arr_; }; namespace transform { -Pass AppendLoss(String func_name, Function loss_function, int num_backbone_outputs, - Optional new_func_name) { +Pass AppendLoss(ffi::String func_name, Function loss_function, int num_backbone_outputs, + ffi::Optional new_func_name) { auto pass_func = [=](IRModule mod, PassContext pc) { return relax::AppendLossMutator::Transform(mod, func_name, loss_function, num_backbone_outputs, new_func_name); diff --git a/src/relax/training/utils.h b/src/relax/training/utils.h index 1bfb20da3521..c22588804d08 100644 --- a/src/relax/training/utils.h +++ b/src/relax/training/utils.h @@ -50,8 +50,8 @@ namespace transform { * will be `func_name + "_loss"`. * \return The Pass. 
*/ -TVM_DLL Pass AppendLoss(String func_name, Function loss_function, int num_backbone_outputs = 1, - Optional new_func_name = std::nullopt); +TVM_DLL Pass AppendLoss(ffi::String func_name, Function loss_function, int num_backbone_outputs = 1, + ffi::Optional new_func_name = std::nullopt); } // namespace transform } // namespace relax diff --git a/src/relax/transform/adjust_matmul_order.cc b/src/relax/transform/adjust_matmul_order.cc index 55ca86c306eb..7b8dad43b5da 100644 --- a/src/relax/transform/adjust_matmul_order.cc +++ b/src/relax/transform/adjust_matmul_order.cc @@ -40,7 +40,7 @@ namespace tvm { namespace relax { namespace { -std::tuple)>> CreatePatterns( +std::tuple)>> CreatePatterns( const Function& func) { auto compile_time_arr = ComputableAtCompileTime(func); std::unordered_set compile_time_lookup(compile_time_arr.begin(), compile_time_arr.end()); @@ -73,15 +73,15 @@ std::tuple)>> Crea pat_permuted_matmul_on_rhs; PrimExpr symbolic_var_constraints = Bool(true); - if (auto upper_bounds = func->GetAttr>("tir_var_upper_bound")) { - Map name_lookup; + if (auto upper_bounds = func->GetAttr>("tir_var_upper_bound")) { + ffi::Map name_lookup; for (const auto& tir_var : TIRVarsInStructInfo(GetStructInfo(func))) { name_lookup.Set(tir_var->name_hint, tir_var); symbolic_var_constraints = symbolic_var_constraints && (0 <= tir_var); } for (const auto& [key, obj_bound] : upper_bounds.value()) { - auto tir_var_name = Downcast(key); + auto tir_var_name = Downcast(key); if (auto opt_var = name_lookup.Get(tir_var_name)) { auto var = opt_var.value(); auto expr_bound = Downcast(obj_bound); @@ -90,7 +90,7 @@ std::tuple)>> Crea } } - auto rewriter = [=](Expr expr, Map matches) -> Expr { + auto rewriter = [=](Expr expr, ffi::Map matches) -> Expr { auto expr_a = matches[pat_a]; auto expr_b = matches[pat_b]; auto expr_c = matches[pat_c]; @@ -102,7 +102,7 @@ std::tuple)>> Crea return expr; } - auto get_shape = [](Expr expr) -> Optional> { + auto get_shape = [](Expr expr) -> ffi::Optional> { auto sinfo = expr->struct_info_.as(); if (sinfo) { return sinfo->GetShape(); diff --git a/src/relax/transform/allocate_workspace.cc b/src/relax/transform/allocate_workspace.cc index d0b462bb1e5b..3af7b486bae3 100644 --- a/src/relax/transform/allocate_workspace.cc +++ b/src/relax/transform/allocate_workspace.cc @@ -52,18 +52,18 @@ class ExternFunctionRewriter : ExprMutator { } Expr VisitExpr_(const FunctionNode* func_node) override { - if (!func_node->GetAttr(attr::kCodegen) && - !func_node->GetAttr(attr::kComposite)) { + if (!func_node->GetAttr(attr::kCodegen) && + !func_node->GetAttr(attr::kComposite)) { return ExprMutator::VisitExpr_(func_node); } if (auto workspace = func_node->GetAttr(attr::kWorkspaceSize)) { // Append the workspace parameter to this function. - Array new_params = func_node->params; + ffi::Array new_params = func_node->params; auto sinfo = TensorStructInfo(ShapeExpr({Integer(max_workspace_size_)}), DataType::UInt(8)); Var workspace_param(name_sup_->FreshName("workspace"), sinfo); - if (func_node->GetAttr(attr::kCodegen)) { + if (func_node->GetAttr(attr::kCodegen)) { workspace_var_param_ = workspace_param; } @@ -81,7 +81,7 @@ class ExternFunctionRewriter : ExprMutator { if (auto var = new_op.as()) { if (auto callee = builder_->LookupBinding(var.value()); callee && callee->IsInstance() && - Downcast(callee.value())->GetAttr(attr::kComposite)) { + Downcast(callee.value())->GetAttr(attr::kComposite)) { // Append the workspace argument to this call. 
The callee should have been updated to accept // a workspace as the last parameter. auto new_args = call_node->args; @@ -127,13 +127,13 @@ class WorkspaceProvider : ExprMutator { WithAttr(f, tvm::attr::kGlobalSymbol, new_gvar->name_hint)); gvar_map_[gvar] = new_gvar; new_gvars_.insert(new_gvar); - builder_->GetContextIRModule()->Remove(GetRef(gvar)); + builder_->GetContextIRModule()->Remove(ffi::GetRef(gvar)); } for (const auto& [gvar, f] : mod_->functions) { workspace_var_main_ = Var(); - if (!f->IsInstance() || f->GetAttr(attr::kCodegen) || - f->GetAttr(attr::kComposite)) { + if (!f->IsInstance() || f->GetAttr(attr::kCodegen) || + f->GetAttr(attr::kComposite)) { continue; } auto func = Downcast(mod_->Lookup(gvar)); diff --git a/src/relax/transform/alter_op_impl.cc b/src/relax/transform/alter_op_impl.cc index 4013d3aad17e..492219f013a1 100644 --- a/src/relax/transform/alter_op_impl.cc +++ b/src/relax/transform/alter_op_impl.cc @@ -43,17 +43,17 @@ using namespace tir; static constexpr const char* kOperatorName = "operator_name"; /*! \brief Construct ranges from shape dimensions */ -static Array ConstructRangeFromShape(const Array& shape) { +static ffi::Array ConstructRangeFromShape(const ffi::Array& shape) { return shape.Map([](const PrimExpr& dim) { return Range(tir::make_zero(dim.dtype()), dim); }); } -static Array GetShapeFromTensorStructInfo(const TensorStructInfo& tensor_sinfo) { +static ffi::Array GetShapeFromTensorStructInfo(const TensorStructInfo& tensor_sinfo) { auto shape = tensor_sinfo->GetShape(); ICHECK(shape.defined()); return shape.value(); } -static Array GetShapeFromTensor(const Expr& expr) { +static ffi::Array GetShapeFromTensor(const Expr& expr) { const auto& tensor_sinfo = Downcast(expr->struct_info_); return GetShapeFromTensorStructInfo(tensor_sinfo); } @@ -64,8 +64,8 @@ static IndexMap DeepCopyIndexMap(const IndexMap& index_map) { /*! \brief Checks if the \p transform is bijective on the shape of \p expr */ bool IsTransformBijective(const Expr& expr, const IndexMap& transform) { - Array input_shape = GetShapeFromTensor(expr); - Array initial_ranges = ConstructRangeFromShape(input_shape); + ffi::Array input_shape = GetShapeFromTensor(expr); + ffi::Array initial_ranges = ConstructRangeFromShape(input_shape); arith::Analyzer analyzer; auto [inverse, padding_predicate] = transform.NonSurjectiveInverse(initial_ranges, &analyzer); (void)inverse; // to avoid unused variable warning; @@ -80,10 +80,12 @@ bool IsTransformBijective(const Expr& expr, const IndexMap& transform) { */ class AlterOpImplMutator : public ExprMutator { public: - AlterOpImplMutator(const IRModule& mod, const Map& op_impl_map, - const Map>& op_buffer_transforms_, - const Map>>>& axis_separators_, - const Map>>>& input_axis_separators_) + AlterOpImplMutator( + const IRModule& mod, const ffi::Map& op_impl_map, + const ffi::Map>& op_buffer_transforms_, + const ffi::Map>>>& axis_separators_, + const ffi::Map>>>& + input_axis_separators_) : ExprMutator(mod), mod_(mod), op_impl_map_(op_impl_map), @@ -119,7 +121,7 @@ class AlterOpImplMutator : public ExprMutator { ICHECK(call->args[0]->IsInstance()); const tir::PrimFunc& old_func = Downcast(mod_->Lookup(Downcast(call->args[0]))); - Optional maybe_op_kind = old_func->attrs.GetAttr(kOperatorName); + ffi::Optional maybe_op_kind = old_func->attrs.GetAttr(kOperatorName); // If the callee does not have kOperatorName attribute or no replacement is requested for // it, nothing to do here. 
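Aside on the container API: `ConstructRangeFromShape` in this file uses `ffi::Array::Map`, which applies a callable to every element and returns a new array. A minimal standalone sketch of the same pattern, with template arguments written out; the header paths are assumptions based on mainline TVM:

    #include <tvm/ffi/container/array.h>
    #include <tvm/ir/expr.h>  // tvm::Range, tvm::PrimExpr
    #include <tvm/tir/op.h>   // tvm::tir::make_zero

    using tvm::PrimExpr;
    using tvm::Range;
    using tvm::ffi::Array;

    // One Range [0, dim) per shape dimension, as the helper above computes.
    Array<Range> RangesFromShape(const Array<PrimExpr>& shape) {
      return shape.Map(
          [](const PrimExpr& dim) { return Range(tvm::tir::make_zero(dim.dtype()), dim); });
    }

When the callable maps an element type back to itself, `Map` can reuse uniquely-owned storage instead of reallocating, which a manual `push_back` loop cannot do.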
@@ -128,9 +130,9 @@ class AlterOpImplMutator : public ExprMutator { const auto& replacement_func = op_impl_map_[op_kind]; - Array buffer_transforms; - Optional>> axis_separators; - Optional>> input_axis_separators; + ffi::Array buffer_transforms; + ffi::Optional>> axis_separators; + ffi::Optional>> input_axis_separators; if (op_buffer_transforms__.count(op_kind)) buffer_transforms = op_buffer_transforms__[op_kind]; if (op_buffer_axis_separators__.count(op_kind)) axis_separators = op_buffer_axis_separators__[op_kind]; @@ -145,7 +147,7 @@ class AlterOpImplMutator : public ExprMutator { GlobalVar replacement_gv = GetOrCreateGlobalVarForFunc(replacement_func, op_kind); - auto call_tir_inputs_tuple = GetRef(call->args[1].as()); + auto call_tir_inputs_tuple = ffi::GetRef(call->args[1].as()); Tuple updated_inputs = UpdateInputs(call_tir_inputs_tuple, buffer_transforms, axis_separators, input_axis_separators); @@ -159,18 +161,18 @@ class AlterOpImplMutator : public ExprMutator { input_axis_separators); } - Array GetTensorStructInfoPerOutput(const StructInfo& output_sinfo) { + ffi::Array GetTensorStructInfoPerOutput(const StructInfo& output_sinfo) { if (const auto* tensor_sinfo = output_sinfo.as()) - return {GetRef(tensor_sinfo)}; + return {ffi::GetRef(tensor_sinfo)}; const auto* tuple_sinfo = output_sinfo.as(); ICHECK(tuple_sinfo); - Array arr_tensor_sinfo; + ffi::Array arr_tensor_sinfo; arr_tensor_sinfo.reserve(tuple_sinfo->fields.size()); for (const auto& sinfo : tuple_sinfo->fields) { const auto* tensor_sinfo = sinfo.as(); ICHECK(tensor_sinfo) << "Nested tuples in output of call_tir is not supported yet"; - arr_tensor_sinfo.push_back(GetRef(tensor_sinfo)); + arr_tensor_sinfo.push_back(ffi::GetRef(tensor_sinfo)); } return arr_tensor_sinfo; } @@ -183,12 +185,12 @@ class AlterOpImplMutator : public ExprMutator { } Expr TransformLayout(const Expr& expr, const IndexMap& index_map, - const Array& axis_separators, - const Array& input_axis_separators) { + const ffi::Array& axis_separators, + const ffi::Array& input_axis_separators) { if (IsScalarConstant(expr) || index_map.get() == nullptr) { return expr; } - ObjectPtr attrs = make_object(); + ObjectPtr attrs = ffi::make_object(); // We want to avoid two layout_transform ops to share the same index map even if they are // identical. The scope of vars used in index map initial indices is local to the op. Not doing // so would confuse the structural equality check. @@ -202,13 +204,13 @@ class AlterOpImplMutator : public ExprMutator { * \brief Adds the \p remove_pad op to the module if it has not already been added before. * \returns The global var associated with the remove_pad PrimFunc. 
*/ - GlobalVar GetOrCreateRemovePadOp(const Array& old_shape, const DataType& dtype) { + GlobalVar GetOrCreateRemovePadOp(const ffi::Array& old_shape, const DataType& dtype) { int t_shape = old_shape.size(); if (remove_pad_map_.count(t_shape) != 0) { return remove_pad_map_[t_shape]; } // Create dynamic shapes for input and output tensors - Array dyn_padded_shape, dyn_old_shape; + ffi::Array dyn_padded_shape, dyn_old_shape; for (int i = 0; i < t_shape; i++) { tir::Var var1("p" + std::to_string(i), old_shape[i].dtype()); tir::Var var2("i" + std::to_string(i), old_shape[i].dtype()); @@ -221,12 +223,12 @@ class AlterOpImplMutator : public ExprMutator { // Output tensor of remove_pad op te::Tensor output_tensor = te::compute( dyn_old_shape, - [&placeholder_tensor](const Array& indices) { + [&placeholder_tensor](const ffi::Array& indices) { return placeholder_tensor(indices); }, "output", topi::kElementWise); - String op_name = "remove_pad"; + ffi::String op_name = "remove_pad"; // Create PrimFunc and add op_name to func.attrs PrimFunc remove_pad_with_frozen_layout = WithAttr(CreatePrimFunc({placeholder_tensor, output_tensor}), kOperatorName, op_name); @@ -242,13 +244,13 @@ class AlterOpImplMutator : public ExprMutator { Expr TransformLayoutInverse(const Expr& expr, const IndexMap& index_map, const TensorStructInfo& old_tensor_sinfo, - const Array& axis_separator, - const Array& input_axis_separator) { + const ffi::Array& axis_separator, + const ffi::Array& input_axis_separator) { if (IsScalarConstant(expr) || index_map.get() == nullptr) { return expr; } - Array old_shape = GetShapeFromTensorStructInfo(old_tensor_sinfo); - Array initial_ranges = ConstructRangeFromShape(old_shape); + ffi::Array old_shape = GetShapeFromTensorStructInfo(old_tensor_sinfo); + ffi::Array initial_ranges = ConstructRangeFromShape(old_shape); arith::Analyzer analyzer; auto [inverse_index_map, padding_predicate] = index_map.NonSurjectiveInverse(initial_ranges, &analyzer); @@ -269,7 +271,8 @@ class AlterOpImplMutator : public ExprMutator { * \brief Adds the \p replacement_func to the module if it has not already been added before. * \returns The global var associated with the PrimFunc. */ - GlobalVar GetOrCreateGlobalVarForFunc(const PrimFunc& replacement_func, const String& op_kind) { + GlobalVar GetOrCreateGlobalVarForFunc(const PrimFunc& replacement_func, + const ffi::String& op_kind) { if (cache_.count(replacement_func) != 0) { return cache_[replacement_func]; } @@ -287,22 +290,22 @@ class AlterOpImplMutator : public ExprMutator { /*! 
* \brief Updates call inputs with layout transformed inputs */ - Tuple UpdateInputs(const Tuple& inputs, const Array& transforms, - const Optional>>& axis_separators, - const Optional>>& input_axis_separators) { + Tuple UpdateInputs(const Tuple& inputs, const ffi::Array& transforms, + const ffi::Optional>>& axis_separators, + const ffi::Optional>>& input_axis_separators) { if (transforms.empty()) return inputs; - Array updated_inputs; + ffi::Array updated_inputs; int index = 0; for (const auto& input : inputs->fields) { - Array axis_separator; - Array input_axis_separator; + ffi::Array axis_separator; + ffi::Array input_axis_separator; if (axis_separators.defined()) { - Array> axis_separators_value = axis_separators.value(); + ffi::Array> axis_separators_value = axis_separators.value(); axis_separator = axis_separators_value[index]; } if (input_axis_separators.defined()) { - Array> input_axis_separators_value = input_axis_separators.value(); + ffi::Array> input_axis_separators_value = input_axis_separators.value(); input_axis_separator = input_axis_separators_value[index]; } auto transform = transforms[index++]; @@ -314,7 +317,7 @@ class AlterOpImplMutator : public ExprMutator { /*! \brief Updates output struct info */ StructInfo UpdateStructInfo(const StructInfo& out_sinfo, - const Array& buffer_transforms) { + const ffi::Array& buffer_transforms) { if (buffer_transforms.empty()) return out_sinfo; if (out_sinfo->IsInstance()) @@ -327,7 +330,7 @@ class AlterOpImplMutator : public ExprMutator { << out_sinfo; const auto& tuple_sinfo = Downcast(out_sinfo); - Array sinfo_fields; + ffi::Array sinfo_fields; size_t first_output_index = buffer_transforms.size() - tuple_sinfo->fields.size(); size_t i = 0; for (const auto& si : tuple_sinfo->fields) { @@ -354,15 +357,16 @@ class AlterOpImplMutator : public ExprMutator { return TensorStructInfo(ShapeExpr(new_shape), tensor_sinfo->dtype); } - Expr TransformOutputs(const Expr& expr, const Array& buffer_transforms, - const StructInfo& old_struct_info, - const Optional>>& axis_separators, - const Optional>>& input_axis_separators) { + Expr TransformOutputs( + const Expr& expr, const ffi::Array& buffer_transforms, + const StructInfo& old_struct_info, + const ffi::Optional>>& axis_separators, + const ffi::Optional>>& input_axis_separators) { if (buffer_transforms.empty()) return expr; - Array old_output_sinfo = GetTensorStructInfoPerOutput(old_struct_info); + ffi::Array old_output_sinfo = GetTensorStructInfoPerOutput(old_struct_info); - Array axis_sep, input_axis_sep; + ffi::Array axis_sep, input_axis_sep; size_t num_outputs = old_output_sinfo.size(); if (num_outputs == 0) return expr; @@ -371,11 +375,11 @@ class AlterOpImplMutator : public ExprMutator { if (num_outputs == 1) { IndexMap output_map = buffer_transforms[first_output_index]; if (axis_separators.defined()) { - Array> axis_separators_value = axis_separators.value(); + ffi::Array> axis_separators_value = axis_separators.value(); axis_sep = axis_separators_value[first_output_index]; } if (input_axis_separators.defined()) { - Array> input_axis_separators_value = input_axis_separators.value(); + ffi::Array> input_axis_separators_value = input_axis_separators.value(); input_axis_sep = input_axis_separators_value[first_output_index]; } return TransformLayoutInverse(expr, output_map, old_output_sinfo[0], axis_sep, @@ -384,15 +388,15 @@ class AlterOpImplMutator : public ExprMutator { // In case of more than one output, we would have to get each item of the output tuple, // transform it and return a tuple 
of all transformed outputs. - Array transformed_outputs; + ffi::Array transformed_outputs; for (size_t i = 0; i + first_output_index < buffer_transforms.size(); ++i) { const auto& output_map = buffer_transforms[i + first_output_index]; if (axis_separators.defined()) { - Array> axis_separators_value = axis_separators.value(); + ffi::Array> axis_separators_value = axis_separators.value(); axis_sep = axis_separators_value[i + first_output_index]; } if (input_axis_separators.defined()) { - Array> input_axis_separators_value = input_axis_separators.value(); + ffi::Array> input_axis_separators_value = input_axis_separators.value(); input_axis_sep = input_axis_separators_value[i + first_output_index]; } auto output = builder_->Normalize(TupleGetItem(expr, static_cast(i))); @@ -404,19 +408,21 @@ class AlterOpImplMutator : public ExprMutator { private: /*! \brief Cache to keep track of the GlobalVar associated with the new PrimFunc added */ - Map cache_; + ffi::Map cache_; /*! \brief Input IRModule */ const IRModule& mod_; /*! \brief Map from shape_dim.size to the remove_pad GlobalVar */ std::unordered_map remove_pad_map_; /*! \brief Map from kOperatorName attribute to the replacement PrimFunc */ - const Map& op_impl_map_; + const ffi::Map& op_impl_map_; /*! \brief Map from kOperatorName attribute to the layout transforms on i/o buffers */ - const Map>& op_buffer_transforms__; + const ffi::Map>& op_buffer_transforms__; /*! \brief Map from kOperatorName attribute to the axis separatos on i/o buffers */ - const Map>>>& op_buffer_axis_separators__; + const ffi::Map>>>& + op_buffer_axis_separators__; /*! \brief Map from kOperatorName attribute to the input axis separatos */ - const Map>>>& op_buffer_input_axis_separators__; + const ffi::Map>>>& + op_buffer_input_axis_separators__; const Op& call_tir_op_ = Op::Get("relax.call_tir"); const Op& layout_transform_op_ = Op::Get("relax.layout_transform"); @@ -424,10 +430,12 @@ class AlterOpImplMutator : public ExprMutator { namespace transform { -Pass AlterOpImpl(const Map& op_impl_map, - const Map>& op_buffer_transforms_, - const Map>>>& axis_separators_, - const Map>>>& input_axis_separators_) { +Pass AlterOpImpl( + const ffi::Map& op_impl_map, + const ffi::Map>& op_buffer_transforms_, + const ffi::Map>>>& axis_separators_, + const ffi::Map>>>& + input_axis_separators_) { auto pass_func = [=](IRModule mod, PassContext pc) { return AlterOpImplMutator(mod, op_impl_map, op_buffer_transforms_, axis_separators_, input_axis_separators_) diff --git a/src/relax/transform/attach_attr_layout_free_buffers.cc b/src/relax/transform/attach_attr_layout_free_buffers.cc index a7c8013a56fd..f2cc2fc842b8 100644 --- a/src/relax/transform/attach_attr_layout_free_buffers.cc +++ b/src/relax/transform/attach_attr_layout_free_buffers.cc @@ -70,9 +70,9 @@ class AttrAttacher : public ExprMutator { return call; } GlobalVar gv = Downcast(call->args[0]); - Array call_tir_args = Downcast(call->args[1])->fields; + ffi::Array call_tir_args = Downcast(call->args[1])->fields; // Compute the layout free buffers - Array layout_free_buffers; + ffi::Array layout_free_buffers; for (size_t i = 0; i < call_tir_args.size(); i++) { if (layout_free_exprs_.count(call_tir_args[i].get())) { layout_free_buffers.push_back(i); @@ -88,7 +88,7 @@ class AttrAttacher : public ExprMutator { // So we don't need to worry about the duplicate insertion GlobalVar new_gv = builder_->AddFunction(func, gv->name_hint); // Create a new call node with the updated tir::PrimFunc - auto n = make_object(*op); + auto n = 
ffi::make_object(*op); n->args = {new_gv, Tuple(call_tir_args)}; return Call(n); } diff --git a/src/relax/transform/attach_global_symbol.cc b/src/relax/transform/attach_global_symbol.cc index 9ef135608dc4..324789d3f490 100644 --- a/src/relax/transform/attach_global_symbol.cc +++ b/src/relax/transform/attach_global_symbol.cc @@ -34,25 +34,26 @@ namespace transform { Pass AttachGlobalSymbol() { auto pass_func = [=](IRModule mod, PassContext pc) { - String c_prefix = mod->GetAttr(tvm::attr::kSystemLibPrefix).value_or(""); + ffi::String c_prefix = mod->GetAttr(tvm::attr::kSystemLibPrefix).value_or(""); IRModule updates; - Map gvar_updates; + ffi::Map gvar_updates; for (const auto& [gvar, func] : mod->functions) { - Optional old_name = func->GetAttr(tvm::attr::kGlobalSymbol); + ffi::Optional old_name = func->GetAttr(tvm::attr::kGlobalSymbol); // TODO(tvm-team): re-enable once fix relax integration part // if (old_name) continue; - Optional new_name; + ffi::Optional new_name; BaseFunc new_func; if (auto* prim_func = func.as()) { new_name = c_prefix + gvar->name_hint; - new_func = WithAttr(GetRef(prim_func), tvm::attr::kGlobalSymbol, new_name); + new_func = + WithAttr(ffi::GetRef(prim_func), tvm::attr::kGlobalSymbol, new_name); } else if (auto* relax_func = func.as()) { new_name = gvar->name_hint; - new_func = WithAttr(GetRef(relax_func), tvm::attr::kGlobalSymbol, new_name); + new_func = WithAttr(ffi::GetRef(relax_func), tvm::attr::kGlobalSymbol, new_name); } if (new_name.has_value() && (!old_name.has_value() || old_name.value() != new_name.value())) { diff --git a/src/relax/transform/bind_params.cc b/src/relax/transform/bind_params.cc index 1940a7a24d64..e2074ef085be 100644 --- a/src/relax/transform/bind_params.cc +++ b/src/relax/transform/bind_params.cc @@ -32,7 +32,7 @@ namespace tvm { namespace relax { void MatchSymbolicVar(const Expr& arg, const Expr& constant, - Map* symbolic_var_map, arith::Analyzer* analyzer_) { + ffi::Map* symbolic_var_map, arith::Analyzer* analyzer_) { auto opt_arg_sinfo = MatchStructInfo(arg); CHECK(opt_arg_sinfo) << "The struct info of the bound parameter is expected to be TensorStructInfo, but got: " @@ -70,9 +70,9 @@ void MatchSymbolicVar(const Expr& arg, const Expr& constant, const PrimExpr& const_dim = const_shape->values[i]; ICHECK(tir::is_const_int(const_dim)); if (const auto* shape_var = arg_shape->values[i].as()) { - auto it = symbolic_var_map->find(GetRef(shape_var)); + auto it = symbolic_var_map->find(ffi::GetRef(shape_var)); if (it == symbolic_var_map->end()) { - symbolic_var_map->Set(GetRef(shape_var), const_dim); + symbolic_var_map->Set(ffi::GetRef(shape_var), const_dim); } else { CHECK(analyzer_->CanProveEqual((*it).second, const_dim)) << "The shape of the bound parameter is expected to be " << (*it).second @@ -82,23 +82,23 @@ void MatchSymbolicVar(const Expr& arg, const Expr& constant, } } -std::tuple, Map> NormalizeBindings( - const Function& func, const Map& untyped_params) { +std::tuple, ffi::Map> NormalizeBindings( + const Function& func, const ffi::Map& untyped_params) { ICHECK(func.defined()); ICHECK(untyped_params.defined()); // Map from string to the variable(s) with that name. 
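`symbolic_var_map` above and `relax_var_remap` just below are `ffi::Map`s, which share the copy-on-write value semantics of `ffi::Array`. A short sketch of the `Set`/`Get`/`find` surface these functions use; the header paths and direct `int64_t` value support are assumptions:

    #include <cstdint>
    #include <tvm/ffi/container/map.h>
    #include <tvm/ffi/string.h>

    using tvm::ffi::Map;
    using tvm::ffi::String;

    void Demo() {
      Map<String, int64_t> m;
      m.Set("batch", 8);             // copy-on-write insert
      if (auto v = m.Get("batch")) {
        int64_t got = v.value();     // Get returns an optional-like handle; got == 8
        (void)got;
      }
      if (auto it = m.find("batch"); it != m.end()) {
        int64_t got = (*it).second;  // iterators dereference to (key, value) pairs
        (void)got;
      }
    }

Because `Set` on a shared map clones the underlying container first, passing a `Map` by value never lets the callee mutate the caller's copy.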
- std::unordered_map> string_lookup; + std::unordered_map> string_lookup; std::unordered_set var_set; for (const auto& param : func->params) { string_lookup[param->name_hint()].push_back(param); var_set.insert(param.get()); } - Map relax_var_remap; + ffi::Map relax_var_remap; auto normalize_key = [&](ffi::Any obj) -> relax::Var { - if (auto opt_str = obj.as()) { + if (auto opt_str = obj.as()) { std::string str = opt_str.value(); auto it = string_lookup.find(str); CHECK(it != string_lookup.end()) @@ -143,7 +143,7 @@ std::tuple, Map> NormalizeBindings( } arith::Analyzer analyzer; - Map symbolic_var_map = InferSymbolicVarMap(relax_var_remap, &analyzer); + ffi::Map symbolic_var_map = InferSymbolicVarMap(relax_var_remap, &analyzer); // for (const auto& [bind_param, bind_expr] : relax_var_remap) { // MatchSymbolicVar(bind_param, bind_expr, &symbolic_var_map, &analyzer); @@ -158,7 +158,7 @@ std::tuple, Map> NormalizeBindings( * \param params params dict * \return Function */ -Function FunctionBindParams(Function func, const Map& untyped_params) { +Function FunctionBindParams(Function func, const ffi::Map& untyped_params) { auto [bind_dict, symbolic_var_map] = NormalizeBindings(func, untyped_params); Expr bound_expr = Bind(func, bind_dict, symbolic_var_map); @@ -172,28 +172,29 @@ Function FunctionBindParams(Function func, const Map& untyped_pa * \param param The param dict * \return The module after binding params. */ -IRModule BindParam(IRModule m, String func_name, Map bind_params) { +IRModule BindParam(IRModule m, ffi::String func_name, ffi::Map bind_params) { IRModuleNode* new_module = m.CopyOnWrite(); - Map functions = m->functions; + ffi::Map functions = m->functions; for (const auto& func_pr : functions) { if (const auto* relax_f = func_pr.second.as()) { if (relax_f->GetLinkageType() == LinkageType::kExternal) { // Use global_symbol if it's external linkage - Optional gsymbol = relax_f->GetAttr(tvm::attr::kGlobalSymbol); + ffi::Optional gsymbol = + relax_f->GetAttr(tvm::attr::kGlobalSymbol); if (gsymbol.has_value() && gsymbol.value() == func_name) { - Function f_after_bind = FunctionBindParams(GetRef(relax_f), bind_params); + Function f_after_bind = FunctionBindParams(ffi::GetRef(relax_f), bind_params); new_module->Update(func_pr.first, f_after_bind); } } else { // Use global var's name_hint if it's internal linkage if (func_pr.first->name_hint == func_name) { - Function f_after_bind = FunctionBindParams(GetRef(relax_f), bind_params); + Function f_after_bind = FunctionBindParams(ffi::GetRef(relax_f), bind_params); new_module->Update(func_pr.first, f_after_bind); } } } } - return GetRef(new_module); + return ffi::GetRef(new_module); } TVM_FFI_STATIC_INIT_BLOCK({ @@ -203,7 +204,7 @@ TVM_FFI_STATIC_INIT_BLOCK({ namespace transform { -Pass BindParams(String func_name, Map params) { +Pass BindParams(ffi::String func_name, ffi::Map params) { auto pass_func = [=](IRModule mod, PassContext pc) { return BindParam(std::move(mod), func_name, params); }; diff --git a/src/relax/transform/bind_symbolic_vars.cc b/src/relax/transform/bind_symbolic_vars.cc index 5ba25b7e16e1..b87597c118a2 100644 --- a/src/relax/transform/bind_symbolic_vars.cc +++ b/src/relax/transform/bind_symbolic_vars.cc @@ -31,17 +31,17 @@ namespace tvm { namespace relax { -Function FunctionBindSymbolicVars(Function func, - Map, PrimExpr> obj_remap) { +Function FunctionBindSymbolicVars( + Function func, ffi::Map, PrimExpr> obj_remap) { // Early bail-out if no updates need to be made. 
if (obj_remap.empty()) { return func; } - Array old_symbolic_vars = DefinedSymbolicVars(func); + ffi::Array old_symbolic_vars = DefinedSymbolicVars(func); // Map from string to the variable(s) with that name. - std::unordered_map> string_lookup; + std::unordered_map> string_lookup; std::unordered_set symbolic_var_set; for (const auto& var : old_symbolic_vars) { string_lookup[var->name_hint].push_back(var); @@ -49,10 +49,10 @@ Function FunctionBindSymbolicVars(Function func, } // Replacement map to be used when rewriting the function. - Map var_remap; + ffi::Map var_remap; for (const auto& [key, replacement] : obj_remap) { if (auto opt = key.as()) { - String string_key = opt.value(); + ffi::String string_key = opt.value(); auto it = string_lookup.find(string_key); CHECK(it != string_lookup.end()) << "Function does not use symbolic var with name \"" << string_key << "\". " @@ -91,8 +91,8 @@ Function FunctionBindSymbolicVars(Function func, } namespace { -IRModule ModuleBindSymbolicVars(IRModule mod, - Map, PrimExpr> binding_map) { +IRModule ModuleBindSymbolicVars( + IRModule mod, ffi::Map, PrimExpr> binding_map) { std::unordered_set used; IRModule updates; for (const auto& [gvar, base_func] : mod->functions) { @@ -100,7 +100,7 @@ IRModule ModuleBindSymbolicVars(IRModule mod, auto func = opt.value(); // Collect bindings that are used by this function. - auto func_binding_map = [&]() -> Map, PrimExpr> { + auto func_binding_map = [&]() -> ffi::Map, PrimExpr> { std::unordered_set var_names; std::unordered_set vars; for (const auto& var : DefinedSymbolicVars(func)) { @@ -108,10 +108,10 @@ IRModule ModuleBindSymbolicVars(IRModule mod, vars.insert(var.get()); } - Map, PrimExpr> out; + ffi::Map, PrimExpr> out; for (const auto& [key, replacement] : binding_map) { bool used_by_function = false; - if (auto opt = key.as()) { + if (auto opt = key.as()) { used_by_function = var_names.count(opt.value()); } else if (auto ptr = key.as()) { used_by_function = vars.count(ptr); @@ -134,7 +134,7 @@ IRModule ModuleBindSymbolicVars(IRModule mod, } } - Array unused; + ffi::Array unused; for (const auto& [key, replacement] : binding_map) { if (!used.count(key)) { unused.push_back(key); @@ -158,8 +158,8 @@ TVM_FFI_STATIC_INIT_BLOCK({ namespace transform { -Pass BindSymbolicVars(Map, PrimExpr> binding_map, - Optional func_name) { +Pass BindSymbolicVars(ffi::Map, PrimExpr> binding_map, + ffi::Optional func_name) { auto pass_func = [=](IRModule mod, PassContext context) -> IRModule { if (func_name) { auto gvar = mod->GetGlobalVar(func_name.value()); diff --git a/src/relax/transform/bundle_model_params.cc b/src/relax/transform/bundle_model_params.cc index 16b7348b8dc7..faf5e6838f17 100644 --- a/src/relax/transform/bundle_model_params.cc +++ b/src/relax/transform/bundle_model_params.cc @@ -36,11 +36,11 @@ namespace relax { class ModelParamBundler : public ExprMutator { public: - explicit ModelParamBundler(Optional param_tuple_name) + explicit ModelParamBundler(ffi::Optional param_tuple_name) : param_tuple_name_(param_tuple_name) {} Expr VisitExpr_(const FunctionNode* op) override { - Function func = GetRef(op); + Function func = ffi::GetRef(op); auto opt_num_input = func->attrs.GetAttr(attr::kNumInput); if (!opt_num_input) return func; auto signed_num_input = opt_num_input.value()->value; @@ -51,12 +51,12 @@ class ModelParamBundler : public ExprMutator { << "but only has " << func->params.size() << " parameters total."; size_t num_input = signed_num_input; - Array params; + ffi::Array params; for (size_t i = 0; i < 
num_input; i++) { params.push_back(func->params[i]); } - Array param_tuple; + ffi::Array param_tuple; for (size_t i = num_input; i < func->params.size(); i++) { param_tuple.push_back(GetStructInfo(func->params[i])); } @@ -74,7 +74,7 @@ class ModelParamBundler : public ExprMutator { } Expr VisitExpr_(const VarNode* op) override { - auto var = GetRef(op); + auto var = ffi::GetRef(op); if (auto it = var_to_expr_.find(var); it != var_to_expr_.end()) { return builder_->Emit((*it).second, op->name_hint()); } else { @@ -83,17 +83,17 @@ class ModelParamBundler : public ExprMutator { } private: - Optional param_tuple_name_; - Map var_to_expr_; + ffi::Optional param_tuple_name_; + ffi::Map var_to_expr_; }; -Function BundleModelParams(const Function& func, Optional param_tuple_name) { +Function BundleModelParams(const Function& func, ffi::Optional param_tuple_name) { ModelParamBundler mutator(param_tuple_name); return Downcast(mutator(func)); } namespace transform { -Pass BundleModelParams(Optional param_tuple_name) { +Pass BundleModelParams(ffi::Optional param_tuple_name) { auto pass_func = [=](IRModule mod, PassContext pc) { IRModule updates; diff --git a/src/relax/transform/call_tir_rewrite.cc b/src/relax/transform/call_tir_rewrite.cc index a47b9bfe5105..10508382731f 100644 --- a/src/relax/transform/call_tir_rewrite.cc +++ b/src/relax/transform/call_tir_rewrite.cc @@ -74,7 +74,7 @@ class CallTIRMutator : public ExprMutator { call->op == call_dps_packed_op) { bool is_inplace = (call->op == call_tir_inplace_op); const auto* inplace_attrs = call->attrs.as(); - Array outs; + ffi::Array outs; if (const auto& _tensor_sinfo = MatchStructInfo(expr)) { // single output case const TensorStructInfo& tensor_sinfo = _tensor_sinfo.value(); @@ -130,7 +130,7 @@ class CallTIRMutator : public ExprMutator { << expr->struct_info_; } - Array args; + ffi::Array args; if (call->args[1].as()) { args = Downcast(call->args[1])->fields; // for call_tir_inplace, don't reinsert in-place args, only the newly allocated ones @@ -167,7 +167,7 @@ class CallTIRMutator : public ExprMutator { return std::move(Tuple(outs)); } - return GetRef(call); + return ffi::GetRef(call); } /*! \brief The context IRModule. */ diff --git a/src/relax/transform/canonicalize_bindings.cc b/src/relax/transform/canonicalize_bindings.cc index 54c508ff2302..38dd80899fa7 100644 --- a/src/relax/transform/canonicalize_bindings.cc +++ b/src/relax/transform/canonicalize_bindings.cc @@ -59,7 +59,7 @@ class SymbolicVarCanonicalizer : public ExprMutator { << ", while the later definition of Relax variable " << binding->var << " instead implies that TIR variable " << tir_var << " is " << prim_expr; } else { - known_values_[tir_var] = KnownValue{prim_expr, GetRef(binding)}; + known_values_[tir_var] = KnownValue{prim_expr, ffi::GetRef(binding)}; } } ExprMutator::VisitBinding_(binding); @@ -76,7 +76,7 @@ class SymbolicVarCanonicalizer : public ExprMutator { if (op->cond.same_as(guard) && op->true_branch.same_as(true_b) && op->false_branch.same_as(false_b)) { - return GetRef(op); + return ffi::GetRef(op); } // The two branches may have had different TIR variables inlined. 
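A note on the `ffi::GetRef` calls that dominate these hunks: visitors hand out raw node pointers (`const CallNode*`, `const VarNode*`, and so on), and `GetRef` rewraps such a pointer as an owning, reference-counted handle. A minimal sketch; the relax header path is an assumption:

    #include <tvm/ffi/cast.h>
    #include <tvm/relax/expr.h>  // tvm::relax::Var, tvm::relax::VarNode

    using tvm::relax::Var;
    using tvm::relax::VarNode;

    // Rewrap a borrowed node pointer as an owning Var handle.
    // Only valid while some existing handle keeps 'node' alive.
    Var Rewrap(const VarNode* node) { return tvm::ffi::GetRef<Var>(node); }

The precondition matters: `GetRef` assumes the node is already owned elsewhere, so these mutators only apply it to pointers obtained from expressions that are still referenced.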
@@ -119,7 +119,7 @@ class SymbolicVarCanonicalizer : public ExprMutator { if (known_values_.empty()) { return expr; } - PrimExpr output = tir::Substitute(expr, [this](const tir::Var& var) -> Optional { + PrimExpr output = tir::Substitute(expr, [this](const tir::Var& var) -> ffi::Optional { if (auto it = known_values_.find(var); it != known_values_.end()) { return it->second.expr; } else { @@ -144,10 +144,10 @@ class SymbolicVarCanonicalizer : public ExprMutator { }; struct CanonicalizationPlan { - Map replace_usage; - Map replace_binding; + ffi::Map replace_usage; + ffi::Map replace_binding; std::unordered_set bindings_to_remove; - Map inline_constant; + ffi::Map inline_constant; }; /*! \brief Utility class to identify usage location @@ -232,8 +232,8 @@ class CanonicalizePlanner : public ExprVisitor { void VisitExpr_(const FunctionNode* func) override { // for functions, treat any free vars as used outside their home DF block auto cache = current_block_; - current_block_ = Optional(); - auto free_vars = FreeVars(GetRef(func)); + current_block_ = ffi::Optional(); + auto free_vars = FreeVars(ffi::GetRef(func)); for (auto var : free_vars) { used_outside_home_dataflow_.insert(var); } @@ -244,26 +244,26 @@ class CanonicalizePlanner : public ExprVisitor { void VisitExpr_(const SeqExprNode* seq) override { // need to reset current_block_ for nested seq exprs (such as in If nodes) auto cache = current_block_; - current_block_ = Optional(); + current_block_ = ffi::Optional(); ExprVisitor::VisitExpr_(seq); current_block_ = cache; } void VisitBindingBlock_(const BindingBlockNode* block) override { CHECK(!current_block_.defined()) << "Forgetting to unset current block"; - current_block_ = GetRef(block); + current_block_ = ffi::GetRef(block); ExprVisitor::VisitBindingBlock_(block); - current_block_ = Optional(); + current_block_ = ffi::Optional(); } void VisitBindingBlock_(const DataflowBlockNode* block) override { CHECK(!current_block_.defined()) << "Forgetting to unset current block"; - current_block_ = GetRef(block); + current_block_ = ffi::GetRef(block); ExprVisitor::VisitBindingBlock_(block); - current_block_ = Optional(); + current_block_ = ffi::Optional(); } - Optional UnwrapKnownValue(Expr expr) { + ffi::Optional UnwrapKnownValue(Expr expr) { // If the expression is a variable, then it can be unwrapped into // its known value. auto unwrap_var = [this](Expr expr) -> Expr { @@ -299,7 +299,7 @@ class CanonicalizePlanner : public ExprVisitor { // If the expression is a Tuple, and each element is // `TupleGetItem(earlier_tuple, i)`, then this is just a copy of // `earlier_tuple`. 
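`UnwrapKnownValue` above and the `earlier_tuple` helper below both return `ffi::Optional`, and the planner resets `current_block_` by assigning a default-constructed `ffi::Optional<BindingBlock>()`. Its surface mirrors `std::optional`; a short sketch, with the header path assumed:

    #include <tvm/ffi/optional.h>
    #include <tvm/ffi/string.h>

    using tvm::ffi::Optional;
    using tvm::ffi::String;

    String NameOrDefault(Optional<String> name) {
      // has_value()/value()/value_or() behave as in std::optional;
      // a default-constructed Optional holds no value.
      return name.value_or("main");
    }

    // NameOrDefault(Optional<String>()) yields "main".
    // NameOrDefault(String("fused_fn")) yields "fused_fn".

Returning `std::nullopt` from a function declared to return `ffi::Optional<T>`, as the lambda below does, is likewise supported after this migration.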
- auto earlier_tuple = [&]() -> Optional { + auto earlier_tuple = [&]() -> ffi::Optional { auto expr_tuple = expr.as(); if (!expr_tuple) { return std::nullopt; @@ -385,14 +385,14 @@ class CanonicalizePlanner : public ExprVisitor { } void VisitExpr_(const VarNode* var) override { - auto var_ref = GetRef(var); + auto var_ref = ffi::GetRef(var); // if a var is used in a dataflow block but *not* the one // where it was defined, it also needs to be exposed, so also we treat that as // used outside of a dataflow block if (!inside_dataflow() || (def_blocks_.count(var_ref) && (current_block_.defined() && !current_block_.value().same_as(def_blocks_.at(var_ref))))) { - used_outside_home_dataflow_.insert(GetRef(var)); + used_outside_home_dataflow_.insert(ffi::GetRef(var)); } } @@ -400,12 +400,12 @@ class CanonicalizePlanner : public ExprVisitor { return current_block_.defined() && current_block_.value().as(); } - Optional current_block_; - Map def_blocks_; + ffi::Optional current_block_; + ffi::Map def_blocks_; - Map trivial_bindings_; - Map known_bindings_; - Map known_bound_to_constant_; + ffi::Map trivial_bindings_; + ffi::Map known_bindings_; + ffi::Map known_bound_to_constant_; std::unordered_set defined_inside_dataflow_; // Set of vars either used outside a dataflow block altogether or outside their // home dataflow block (the one where they were defined) @@ -440,7 +440,7 @@ class BindingCanonicalizer : public ExprMutator { } Expr VisitExpr_(const VarNode* var) override { - Var new_var = GetRef(var); + Var new_var = ffi::GetRef(var); while (auto opt = plan_.replace_usage.Get(new_var->vid)) { new_var = opt.value(); } @@ -470,7 +470,7 @@ class BindingCanonicalizer : public ExprMutator { // disqualify any vars that appear in the RHS // (for a function literal, consider only free vars) - Array rhs_vars; + ffi::Array rhs_vars; if (!value->IsInstance()) { rhs_vars = FreeVars(value); } else { @@ -494,12 +494,12 @@ class BindingCanonicalizer : public ExprMutator { // disqualify if the RHS is not a single dataflow var // or if the var has been output before if (const auto* rhs_var = value.as()) { - if (output_vars.count(GetRef(rhs_var))) { - disqualified_set.insert(GetRef(rhs_var)); + if (output_vars.count(ffi::GetRef(rhs_var))) { + disqualified_set.insert(ffi::GetRef(rhs_var)); } - output_vars.insert(GetRef(rhs_var)); + output_vars.insert(ffi::GetRef(rhs_var)); } else { - Array disqualified; + ffi::Array disqualified; // for function literal, consider only free vars if (value->IsInstance()) { disqualified = FreeVars(value); @@ -518,7 +518,7 @@ class BindingCanonicalizer : public ExprMutator { // second pass: for each binding where the LHS is a candidate, remove the binding. // If the RHS is a candidate, replace it with the definition - Array new_bindings; + ffi::Array new_bindings; bool changed = false; for (auto binding : new_block->bindings) { if (binding->var->IsInstance() && diff --git a/src/relax/transform/combine_parallel_matmul.cc b/src/relax/transform/combine_parallel_matmul.cc index 9c0318ee3926..34dfa1530c2f 100644 --- a/src/relax/transform/combine_parallel_matmul.cc +++ b/src/relax/transform/combine_parallel_matmul.cc @@ -39,13 +39,13 @@ namespace tvm { namespace relax { -using FCheck = ffi::TypedFunction, Array, Map)>; +using FCheck = ffi::TypedFunction, ffi::Array, ffi::Map)>; /*! \brief Group shapes of the RHS matrices by rank. Matrices in a group whose batch sizes are compatible are combined. 
*/ std::unordered_map> GroupShapes( - const std::vector>& shapes) { + const std::vector>& shapes) { std::unordered_map> indices_map; for (size_t i = 0; i < shapes.size(); ++i) { indices_map[shapes[i].size()].push_back(i); @@ -77,7 +77,7 @@ struct Patterns { struct SplitInfo { Var rhs; - Optional bias; + ffi::Optional bias; PrimExpr split_size; DFPattern pattern_to_replace; }; @@ -116,10 +116,10 @@ Patterns CreatePatterns(const BranchInfo& branch_info) { } /*! \brief Create a rewriter for the given parallel matmul branches. */ -ffi::TypedFunction(Map, Map)> GetRewriter( +ffi::TypedFunction(ffi::Map, ffi::Map)> GetRewriter( const Patterns& patterns, const BranchInfo& branch_info, FCheck check) { auto batch_dims_compatible = [](size_t rhs_dim, const std::vector& indices, - const std::vector>& rhs_shapes) { + const std::vector>& rhs_shapes) { arith::Analyzer ana; for (auto ind : indices) { ICHECK_EQ(static_cast(rhs_shapes[ind].size()), rhs_dim); @@ -133,17 +133,17 @@ ffi::TypedFunction(Map, Map)> GetRewri return true; }; - return [=](Map matchings, Map bindings) { - std::vector> rhs_shapes; + return [=](ffi::Map matchings, ffi::Map bindings) { + std::vector> rhs_shapes; for (const auto& rhs_pat : patterns.rhs) { auto rhs_shape_opt = GetTensorSInfo(matchings[rhs_pat])->GetShape(); if (!rhs_shape_opt) { - return Map{}; + return ffi::Map{}; } rhs_shapes.push_back(rhs_shape_opt.value()); } - Map replacements; + ffi::Map replacements; for (const auto& [rhs_dim, indices] : GroupShapes(rhs_shapes)) { if (indices.size() == 1 || !batch_dims_compatible(rhs_dim, indices, rhs_shapes)) continue; @@ -159,7 +159,7 @@ ffi::TypedFunction(Map, Map)> GetRewri std::vector splits; for (auto index : indices) { Var rhs = matchings[patterns.rhs[index]]; - Optional bias = std::nullopt; + ffi::Optional bias = std::nullopt; if (branch_info.bias_dim.has_value()) { bias = matchings[patterns.bias[index]]; } @@ -190,8 +190,8 @@ ffi::TypedFunction(Map, Map)> GetRewri continue; } - Array rhs; - Array bias; + ffi::Array rhs; + ffi::Array bias; for (const auto& split : splits) { rhs.push_back(split.rhs); if (split.bias) { @@ -228,7 +228,7 @@ ffi::TypedFunction(Map, Map)> GetRewri } int split_index = 0; - Array sections; + ffi::Array sections; for (size_t i = 0; i + 1 < splits.size(); i++) { auto width = splits[i].split_size.as(); ICHECK(width) << "InternalError: " diff --git a/src/relax/transform/convert_dataflow.cc b/src/relax/transform/convert_dataflow.cc index 4fad1f831842..ec768a852543 100644 --- a/src/relax/transform/convert_dataflow.cc +++ b/src/relax/transform/convert_dataflow.cc @@ -39,7 +39,7 @@ class DataflowBlockExtractor : public ExprMutator { explicit DataflowBlockExtractor(size_t min_size) : ExprMutator(), min_size_(min_size) {} Expr VisitExpr_(const SeqExprNode* seq) override { - Array new_blocks; + ffi::Array new_blocks; Expr new_body = VisitExpr(seq->body); bool changed = !new_body.same_as(seq->body); @@ -49,15 +49,15 @@ class DataflowBlockExtractor : public ExprMutator { // make a dataflowblock. Because these bindings occur prior to // `dataflow_bindings`, this array may only be accumulated into // when `dataflow_bindings` is empty. - Array non_dataflow_bindings; + ffi::Array non_dataflow_bindings; // Current bindings that may legally be added to a DataflowBlock. - Array dataflow_bindings; + ffi::Array dataflow_bindings; // If present, a DataflowBlock whose bindings are currently in // `dataflow_bindings`. Used to propagate DataflowBlock to the // output, even if it doesn't meet the minimum size. 
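`FCheck` and the closure returned by `GetRewriter` above are `ffi::TypedFunction`s: type-erased callables that cross the FFI boundary while keeping a typed C++ signature. A minimal sketch of construction and invocation; the header path is an assumption:

    #include <tvm/ffi/function.h>

    using tvm::ffi::TypedFunction;

    int Demo() {
      // Wrap a lambda; arguments and the result are converted through
      // the FFI calling convention on every invocation.
      TypedFunction<int(int, int)> add = [](int a, int b) { return a + b; };
      return add(2, 3);  // == 5
    }

Type erasure is what lets the pattern-rewriting machinery store and invoke user-supplied checks and rewriters uniformly, regardless of where they were defined.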
- Optional input_dataflow_block; + ffi::Optional input_dataflow_block; // Handle any bindings currently in `dataflow_bindings`. These // are either pushed to their own block, or to the end of @@ -134,7 +134,7 @@ class DataflowBlockExtractor : public ExprMutator { if (changed) { return SeqExpr(new_blocks, new_body); } else { - return GetRef(seq); + return ffi::GetRef(seq); } } diff --git a/src/relax/transform/convert_layout.cc b/src/relax/transform/convert_layout.cc index 2ba757c76a70..865b64dcf5e2 100644 --- a/src/relax/transform/convert_layout.cc +++ b/src/relax/transform/convert_layout.cc @@ -78,12 +78,13 @@ using tir::Layout; */ class LayoutConvertMutator : public ExprMutator { public: - explicit LayoutConvertMutator(const Map>& desired_layouts) + explicit LayoutConvertMutator( + const ffi::Map>& desired_layouts) : desired_layouts_(desired_layouts) {} private: - Array LayoutToIntegers(const Layout& layout) { - Array ret; + ffi::Array LayoutToIntegers(const Layout& layout) { + ffi::Array ret; LayoutDecision src = InitialLayoutDecision(layout.ndim()); for (size_t i = 0; i < layout.ndim(); ++i) { ret.push_back(Integer(src->layout.IndexOf(layout[i]))); @@ -93,17 +94,17 @@ class LayoutConvertMutator : public ExprMutator { IndexMap LayoutIndexMap(int ndim, const Layout& src_layout, const Layout& desired_layout) { tir::BijectiveLayout todesired(src_layout, desired_layout); - Optional inverse_index_map; + ffi::Optional inverse_index_map; - Array initial_indices; - Array initial_indices_expr; + ffi::Array initial_indices; + ffi::Array initial_indices_expr; initial_indices.reserve(ndim); for (int i = 0; i < ndim; ++i) { auto var = tvm::tir::Var("i" + std::to_string(i), DataType::Int(32)); initial_indices.push_back(var); initial_indices_expr.push_back(var); } - Array desired_shape = todesired.ForwardIndex(initial_indices_expr); + ffi::Array desired_shape = todesired.ForwardIndex(initial_indices_expr); return IndexMap(initial_indices, desired_shape, std::move(inverse_index_map)); } @@ -125,9 +126,9 @@ class LayoutConvertMutator : public ExprMutator { } else { auto index_map = LayoutIndexMap(from.LeafValue()->layout.ndim(), from.LeafValue()->layout, to.LeafValue()->layout); - ObjectPtr attrs = make_object(); - Array axis_separator; - Array input_axis_separator; + ObjectPtr attrs = ffi::make_object(); + ffi::Array axis_separator; + ffi::Array input_axis_separator; attrs->index_map = Downcast(LoadJSON(SaveJSON(index_map))); attrs->axis_separators = std::move(axis_separator); attrs->input_axis_separators = std::move(input_axis_separator); @@ -141,9 +142,9 @@ class LayoutConvertMutator : public ExprMutator { std::array({GetNLayout(var_layout_map_, expr), to}), fvisitleaf); } - Array RewriteArgs(const Array& args, const Array& to) { - // The `Array args` array contains both tensor and - // non-tensor arguments, where the `Array to` array only + ffi::Array RewriteArgs(const ffi::Array& args, const ffi::Array& to) { + // The `ffi::Array args` array contains both tensor and + // non-tensor arguments, where the `ffi::Array to` array only // contains tensor arguments. The number of tensor arguments in // `args` should match the full extent of `to`. 
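`LayoutIndexMap` above derives a `tir::IndexMap` from a `BijectiveLayout`, but the result is simply a mapping from symbolic source indices to destination index expressions. A hand-written equivalent for the common NCHW to NHWC case; the constructor shape and header paths are assumptions based on mainline TVM:

    #include <tvm/tir/index_map.h>
    #include <tvm/tir/var.h>

    using tvm::PrimExpr;
    using tvm::ffi::Array;
    using tvm::ffi::Optional;
    using tvm::tir::IndexMap;
    using tvm::tir::Var;

    IndexMap MakeNCHWToNHWC() {
      Var n("n"), c("c"), h("h"), w("w");
      Array<Var> initial{n, c, h, w};             // source iteration order
      Array<PrimExpr> final_indices{n, h, w, c};  // destination order
      Optional<IndexMap> inverse;                 // no precomputed inverse
      return IndexMap(initial, final_indices, std::move(inverse));
    }

The mutator then stores such a map into the layout-transform attrs, and the `LoadJSON(SaveJSON(...))` round-trip above gives each call site a structurally fresh copy, so two layout_transform ops never share index-map vars.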
@@ -175,7 +176,7 @@ class LayoutConvertMutator : public ExprMutator { return RewriteExpr(var, InitialNLayout(var)); } - Expr VisitExpr_(const VarNode* op) final { return VisitVars_(GetRef(op)); } + Expr VisitExpr_(const VarNode* op) final { return VisitVars_(ffi::GetRef(op)); } bool HasUnknownDimTensor(const NLayout& nlayout) { bool find = false; @@ -186,7 +187,7 @@ class LayoutConvertMutator : public ExprMutator { return find; } - bool HasUnknownDimTensor(const Array& args) { + bool HasUnknownDimTensor(const ffi::Array& args) { for (const auto& arg : args) { if (IsNestedTensor(arg)) { if (HasUnknownDimTensor(GetNLayout(var_layout_map_, arg))) { @@ -197,17 +198,18 @@ class LayoutConvertMutator : public ExprMutator { return false; } - Optional GetInferLayoutInfo(const CallNode* call_node, - const Map>& desired_layouts, - const VarLayoutMap& var_layout_map) { + ffi::Optional GetInferLayoutInfo( + const CallNode* call_node, + const ffi::Map>& desired_layouts, + const VarLayoutMap& var_layout_map) { const OpNode* op_node = call_node->op.as(); if (op_node == nullptr) return std::nullopt; - Op op = Downcast(GetRef(op_node)); + Op op = Downcast(ffi::GetRef(op_node)); const auto attr_map = Op::GetAttrMap("FRelaxInferLayout"); if (attr_map.count(op) && !HasUnknownDimTensor(call_node->args)) { // If the op has FRelaxInferLayout, and all the input tensors have known ndim FRelaxInferLayout f = attr_map[op]; - return f(GetRef(call_node), desired_layouts, var_layout_map); + return f(ffi::GetRef(call_node), desired_layouts, var_layout_map); } else { // Otherwise, we use the default policy. return std::nullopt; @@ -215,9 +217,9 @@ class LayoutConvertMutator : public ExprMutator { } void VisitBinding_(const VarBindingNode* binding, const CallNode* call_node) final { - Optional res = + ffi::Optional res = GetInferLayoutInfo(call_node, desired_layouts_, var_layout_map_); - ObjectPtr new_call = make_object(*call_node); + ObjectPtr new_call = ffi::make_object(*call_node); new_call->struct_info_ = std::nullopt; if (!res.defined() || (!IsNestedTensor(binding->var) && !binding->var->IsInstance())) { @@ -227,14 +229,14 @@ class LayoutConvertMutator : public ExprMutator { for (const auto& arg : call_node->args) { input_layout.push_back(InitialNLayout(arg)); } - Array new_args = RewriteArgs(call_node->args, std::move(input_layout)); + ffi::Array new_args = RewriteArgs(call_node->args, std::move(input_layout)); new_call->args = std::move(new_args); ReEmitBinding(binding, builder_->Normalize(Call(new_call))); // update the layout map var_layout_map_[binding->var] = InitialNLayout(binding->var); } else { // Convert the layout according to the inferred layout output. 
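The `new_args` arrays built on both sides of this hunk rely on `ffi::Array` being a value type with copy-on-write storage: copying is O(1), and the first mutation of a shared array clones it. A self-contained sketch; direct `int64_t` element support is an assumption about the ffi container:

    #include <cstdint>
    #include <tvm/ffi/container/array.h>

    using tvm::ffi::Array;

    void Demo() {
      Array<int64_t> a = {1, 2, 3};
      Array<int64_t> b = a;  // O(1): shares the backing object
      b.Set(0, 42);          // copy-on-write: clones, then writes
      b.push_back(4);
      // a is still [1, 2, 3]; b is [42, 2, 3, 4].
    }

This is why `RewriteArgs` can return a modified copy of `call_node->args`, and `new_args.Set(...)` can then patch individual entries, without disturbing the original call node.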
- Array new_args = RewriteArgs(call_node->args, res.value()->input_layouts); + ffi::Array new_args = RewriteArgs(call_node->args, res.value()->input_layouts); for (const auto& [i, arg] : res.value()->new_args) { new_args.Set(i->value, arg); } @@ -273,7 +275,7 @@ class LayoutConvertMutator : public ExprMutator { input_layout.push_back(InitialNLayout(field)); } } - Array new_fields = RewriteArgs(val->fields, std::move(input_layout)); + ffi::Array new_fields = RewriteArgs(val->fields, std::move(input_layout)); if (IsNestedTensor(binding->var)) { ReEmitBinding(binding, builder_->Normalize(Tuple(new_fields))); var_layout_map_[binding->var] = input_layout; @@ -322,7 +324,7 @@ class LayoutConvertMutator : public ExprMutator { binding->struct_info, std::array({from_layout, input_layout}), fvisitleaf); // re-emit old binding if nothing changes if (new_struct_info.same_as(binding->struct_info)) { - builder_->EmitNormalized(GetRef(binding)); + builder_->EmitNormalized(ffi::GetRef(binding)); } else { Var new_var = builder_->EmitMatchCast(RewriteExpr(binding->value, input_layout), new_struct_info); @@ -332,18 +334,18 @@ class LayoutConvertMutator : public ExprMutator { } std::unordered_map var_layout_map_; - Map> desired_layouts_; + ffi::Map> desired_layouts_; }; // namespace relax DataflowBlock ConvertLayoutPass(const DataflowBlock& df_block, - Map> desired_layouts) { + ffi::Map> desired_layouts) { LayoutConvertMutator mutator(desired_layouts); return Downcast(mutator.VisitBindingBlock(df_block)); } namespace transform { -Pass ConvertLayout(Map> desired_layouts) { +Pass ConvertLayout(ffi::Map> desired_layouts) { ffi::TypedFunction pass_func = [=](DataflowBlock df_block, IRModule m, PassContext pc) { return Downcast(ConvertLayoutPass(df_block, desired_layouts)); diff --git a/src/relax/transform/dataflow_inplace.cc b/src/relax/transform/dataflow_inplace.cc index fa75669362ad..7460e1004782 100644 --- a/src/relax/transform/dataflow_inplace.cc +++ b/src/relax/transform/dataflow_inplace.cc @@ -48,7 +48,7 @@ std::unordered_map> AnalyzeLiveness(const DataflowBlock Binding b = block->bindings[i]; Var defined_var = b->var; Expr value = GetBoundValue(b); - Array used_vars; + ffi::Array used_vars; // for a function literal, we consider only the free vars // (those captured from the outer scope) if (value.as()) { @@ -105,7 +105,7 @@ class AliasAnalyzer { // (in the case of in-place ops) safe to overwrite. This may not be true of function args. 
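`AnalyzeLiveness` above computes, for every dataflow var, the span from its defining binding to its last use. The same computation in plain C++, independent of the TVM types (the names here are illustrative only); on straight-line bindings a single forward pass that records the latest use is sufficient:

    #include <map>
    #include <string>
    #include <utility>
    #include <vector>

    struct Binding {
      std::string def;                // variable defined by this binding
      std::vector<std::string> uses;  // variables read on the right-hand side
    };

    // var -> {definition index, last-use index}
    std::map<std::string, std::pair<int, int>> Liveness(
        const std::vector<Binding>& bindings) {
      std::map<std::string, std::pair<int, int>> range;
      for (int i = 0; i < static_cast<int>(bindings.size()); ++i) {
        range[bindings[i].def] = {i, i};
        for (const auto& u : bindings[i].uses) {
          auto it = range.find(u);
          if (it != range.end()) it->second.second = i;  // extend to the latest use
        }
      }
      return range;
    }

A var whose range ends at binding i is dead afterwards, which is the precondition the in-place rewriter checks before reusing its buffer.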
std::pair>, std::unordered_map>>> - Analyze(const DataflowBlock& block, const Array& inputs) { + Analyze(const DataflowBlock& block, const ffi::Array& inputs) { for (auto input : inputs) { int curr_idx = get_fresh_idx(); alias_map_[input] = {curr_idx}; @@ -227,7 +227,7 @@ class AliasAnalyzer { // TODO(@slyubomirsky): We will probably want special handling for closures ret.insert(get_fresh_idx()); } else if (auto* target_var_node = value.as()) { - auto target_var = GetRef(target_var_node); + auto target_var = ffi::GetRef(target_var_node); if (alias_map_.count(target_var)) { ret.insert(alias_map_[target_var].begin(), alias_map_[target_var].end()); } else { @@ -324,7 +324,7 @@ std::unordered_set GatherCandidateSin // don't consider cases where we don't know the shape at compile time // (we will use the analyzer to do best-effort analysis where there are vars) if (tensor_info->shape.as()) { - return {GetRef(tensor_info)}; + return {ffi::GetRef(tensor_info)}; } else { return {}; } @@ -337,7 +337,7 @@ std::unordered_set GatherCandidateSin } // at least one field should be eligible to be done in-place if (!ret.empty()) { - ret.insert(GetRef(tuple_info)); + ret.insert(ffi::GetRef(tuple_info)); } return ret; } else { @@ -447,7 +447,7 @@ bool InplaceConditionsMet( const std::unordered_map>>& tuple_map, const std::unordered_set& currently_live, const Expr& target, int binding_idx) { if (auto* var_node = target.as()) { - auto current_var = GetRef(var_node); + auto current_var = ffi::GetRef(var_node); // if the var is live past this point, we can't use it for in-place computations anyway if (live_ranges.count(current_var)) { auto live_range = live_ranges.at(current_var); @@ -523,7 +523,7 @@ class InplaceOpportunityNode : public Object { public: // need to use Array for the benefit of the FFI Integer binding_idx; - Array arg_idxs; + ffi::Array arg_idxs; static void RegisterReflection() { namespace refl = tvm::ffi::reflection; @@ -540,8 +540,8 @@ TVM_FFI_STATIC_INIT_BLOCK({ InplaceOpportunityNode::RegisterReflection(); }); class InplaceOpportunity : public ObjectRef { public: - TVM_DLL InplaceOpportunity(const Integer& binding_idx, const Array& arg_idxs) { - auto node = make_object(); + TVM_DLL InplaceOpportunity(const Integer& binding_idx, const ffi::Array& arg_idxs) { + auto node = ffi::make_object(); node->binding_idx = binding_idx; node->arg_idxs = arg_idxs; data_ = std::move(node); @@ -564,7 +564,7 @@ class InplaceOpportunity : public ObjectRef { // The first element is the index of the *binding* in the block. // All remaining elements are the indices of *eligible arguments* in that call. 
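(Sketch of the encoding just described, using the constructor from this hunk; the indices are example values.)

.. code:: c++

    // Binding #3 of the block; arguments 0 and 2 are eligible in-place targets.
    ffi::Array<Integer> arg_idxs{Integer(0), Integer(2)};
    InplaceOpportunity opp(Integer(3), arg_idxs);
    ICHECK_EQ(opp->binding_idx.IntValue(), 3);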
std::pair, std::vector> -FindInplaceOpportunities(const DataflowBlock& block, const Array& inputs, +FindInplaceOpportunities(const DataflowBlock& block, const ffi::Array& inputs, const BlockBuilder& ctx) { auto live_ranges = AnalyzeLiveness(block); AliasAnalyzer analyzer; @@ -619,7 +619,7 @@ FindInplaceOpportunities(const DataflowBlock& block, const Array& inputs, if (auto* call_node = value.as()) { if (auto* op_node = call_node->op.as()) { - if (!OpSupportsInplace(GetRef(op_node))) { + if (!OpSupportsInplace(ffi::GetRef(op_node))) { continue; } @@ -669,14 +669,14 @@ FindInplaceOpportunities(const DataflowBlock& block, const Array& inputs, } // produce a list of candidates for this index - Array size_candidate_list; + ffi::Array size_candidate_list; for (auto candidate : candidates) { size_candidate_list.push_back(Integer(candidate)); } size_match_list.push_back(InplaceOpportunity(Integer(i), size_candidate_list)); // also gather up the exact match candidates if there are any - Array exact_candidate_list; + ffi::Array exact_candidate_list; for (auto candidate : candidates) { if (!exact_match_candidates.count(candidate)) { continue; @@ -695,10 +695,11 @@ FindInplaceOpportunities(const DataflowBlock& block, const Array& inputs, } // Replace buffers in a PrimFunc according to the mapping. -tir::Stmt RemapBuffers(const tir::Stmt& stmt, const Map& buffer_map) { +tir::Stmt RemapBuffers(const tir::Stmt& stmt, + const ffi::Map& buffer_map) { class BufferMapper : public tir::StmtExprMutator { public: - explicit BufferMapper(const Map& buffer_map) + explicit BufferMapper(const ffi::Map& buffer_map) : buffer_map_(buffer_map) {} tir::Stmt Remap(const tir::Stmt& stmt) { return VisitStmt(stmt); } @@ -766,7 +767,7 @@ tir::Stmt RemapBuffers(const tir::Stmt& stmt, const Map& buffer_map_; + const ffi::Map& buffer_map_; }; BufferMapper mapper(buffer_map); @@ -786,7 +787,7 @@ class ModuleInplaceTransformer : public ExprMutator { if (auto* func_node = kv.second.as()) { auto gv = kv.first; auto func_params = func_node->params; - auto function = Downcast(VisitExpr(GetRef(func_node))); + auto function = Downcast(VisitExpr(ffi::GetRef(func_node))); builder_->UpdateFunction(gv, function); } } @@ -810,14 +811,14 @@ class ModuleInplaceTransformer : public ExprMutator { // the only case we will override: we will visit all binding blocks // and replace any valid calls in them BindingBlock VisitBindingBlock_(const DataflowBlockNode* op) override { - auto block = GetRef(op); + auto block = ffi::GetRef(op); auto old_idxs = inplace_idxs; // For now, only handle exact match cases. // Note: Not passing any input values for now, as we can't make any assumptions // about them. 
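(Usage sketch for the `RemapBuffers` helper defined in this hunk; `old_buffer`, `new_buffer`, and `prim_func` are assumed to exist in the caller.)

.. code:: c++

    ffi::Map<tir::Buffer, tir::Buffer> buffer_map;
    buffer_map.Set(old_buffer, new_buffer);
    // Every access to old_buffer in the body now targets new_buffer.
    tir::Stmt new_body = RemapBuffers(prim_func->body, buffer_map);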
auto matches_found = FindInplaceOpportunities(block, {}, builder_); - Map> new_idxs; + ffi::Map> new_idxs; for (auto match : matches_found.second) { new_idxs.Set(block->bindings[match->binding_idx.IntValue()], match->arg_idxs); } @@ -838,7 +839,7 @@ class ModuleInplaceTransformer : public ExprMutator { } void VisitBinding_(const VarBindingNode* binding) override { - auto binding_ref = GetRef(binding); + auto binding_ref = ffi::GetRef(binding); if (!inplace_idxs.count(binding_ref)) { ExprMutator::VisitBinding_(binding); return; @@ -848,7 +849,7 @@ class ModuleInplaceTransformer : public ExprMutator { } void VisitBinding_(const MatchCastNode* binding) override { - auto binding_ref = GetRef(binding); + auto binding_ref = ffi::GetRef(binding); if (!inplace_idxs.count(binding_ref)) { ExprMutator::VisitBinding_(binding); return; @@ -861,7 +862,7 @@ class ModuleInplaceTransformer : public ExprMutator { // Given the call and indices of arguments that could be done in-place, // replace the call with a call to an in-place PrimFunc. // (Made public for testing.) - Call CreateInplaceCall(const Call& call, const Array& inplace_indices) { + Call CreateInplaceCall(const Call& call, const ffi::Array& inplace_indices) { static const auto& legalize_map = Op::GetAttrMap("FLegalize"); static const auto& call_tir_inplace_op = Op::Get("relax.call_tir_inplace"); @@ -890,8 +891,8 @@ class ModuleInplaceTransformer : public ExprMutator { // 2. For each output var, replace its instances with the corresponding inplace index var // 3. Do the same for the *buffer vars* corresponding to the output vars // 4. Remove the output vars from the param list and buffer map - Map buffer_subst_map; - Map var_subst_map; + ffi::Map buffer_subst_map; + ffi::Map var_subst_map; for (size_t i = 0; i < num_outs; i++) { // we will substitute output i with the corresponding param indicated by inplace indices auto output_var = old_primfunc->params[num_params - num_outs + i]; @@ -907,12 +908,13 @@ class ModuleInplaceTransformer : public ExprMutator { // apply substitutions new_body = RemapBuffers(new_body, buffer_subst_map); - new_body = tir::Substitute(new_body, [&var_subst_map](const tir::Var& v) -> Optional { - if (var_subst_map.count(v)) { - return var_subst_map.at(v); - } - return Optional(); - }); + new_body = + tir::Substitute(new_body, [&var_subst_map](const tir::Var& v) -> ffi::Optional { + if (var_subst_map.count(v)) { + return var_subst_map.at(v); + } + return ffi::Optional(); + }); // remove the now-unused outputs from the buffer map auto new_buffer_map = old_primfunc->buffer_map; @@ -922,8 +924,8 @@ class ModuleInplaceTransformer : public ExprMutator { // now get rid of the last num_outputs arguments // (couldn't do earlier or else it would have thrown off the indexing) - Array new_params(old_primfunc->params.begin(), - old_primfunc->params.begin() + (num_params - num_outs)); + ffi::Array new_params(old_primfunc->params.begin(), + old_primfunc->params.begin() + (num_params - num_outs)); tir::PrimFunc new_primfunc(new_params, new_body, old_primfunc->ret_type, new_buffer_map, old_primfunc->attrs, old_primfunc->span); @@ -935,11 +937,11 @@ class ModuleInplaceTransformer : public ExprMutator { // update the call (change the op, update the argument, change the attrs) legalized_call_cow->op = call_tir_inplace_op; - Array new_args(legalized_call->args.begin(), legalized_call->args.end()); + ffi::Array new_args(legalized_call->args.begin(), legalized_call->args.end()); new_args.Set(0, new_gv); legalized_call_cow->args = new_args; - 
ObjectPtr attrs = make_object(); + ObjectPtr attrs = ffi::make_object(); attrs->inplace_indices = inplace_indices; legalized_call_cow->attrs = Attrs(attrs); @@ -952,43 +954,43 @@ class ModuleInplaceTransformer : public ExprMutator { private: const IRModule& mod_; // Keep track of legalizers we add so we can clean up at the end. - Array legalizers_added; + ffi::Array legalizers_added; // The current function's params will be treated as non-aliased // (we are assuming good behavior on the user's part). - Array func_params; + ffi::Array func_params; // map of eligible bindings to indices of arguments that can be used as the in-place target - Map> inplace_idxs; + ffi::Map> inplace_idxs; }; namespace transform { -Map> DataflowLivenessAnalysis(const DataflowBlock& block) { +ffi::Map> DataflowLivenessAnalysis(const DataflowBlock& block) { auto liveness_ranges = AnalyzeLiveness(block); - Map> ret; + ffi::Map> ret; for (auto kv : liveness_ranges) { ret.Set(kv.first, {kv.second.first, kv.second.second}); } return ret; } -Array DataflowAliasAnalysis(const DataflowBlock& block, Array inputs) { +ffi::Array DataflowAliasAnalysis(const DataflowBlock& block, ffi::Array inputs) { AliasAnalyzer analyzer; auto res = analyzer.Analyze(block, inputs); auto alias_sets = res.first; auto tuple_map = res.second; - Map> new_alias_sets; - Map>> new_tuple_map; + ffi::Map> new_alias_sets; + ffi::Map>> new_tuple_map; for (auto kv : alias_sets) { - Array aliases; + ffi::Array aliases; for (auto alias : kv.second) { aliases.push_back(alias); } new_alias_sets.Set(kv.first, aliases); } for (auto kv : tuple_map) { - Array> elem_aliases; + ffi::Array> elem_aliases; for (auto alias_set : kv.second) { - Array dim_aliases; + ffi::Array dim_aliases; for (auto alias : alias_set) { dim_aliases.push_back(alias); } @@ -1010,12 +1012,12 @@ tvm::transform::Pass DataflowUseInplaceCalls() { 0, "DataflowInsertInPlaceCalls", {}, false); } -Array> DataflowInplaceAnalysis(const DataflowBlock& block, - const Array& inputs, - const IRModule& mod) { +ffi::Array> DataflowInplaceAnalysis(const DataflowBlock& block, + const ffi::Array& inputs, + const IRModule& mod) { auto index_lists = relax::FindInplaceOpportunities(block, inputs, BlockBuilder::Create(mod)); - return {Array(index_lists.first.begin(), index_lists.first.end()), - Array(index_lists.second.begin(), index_lists.second.end())}; + return {ffi::Array(index_lists.first.begin(), index_lists.first.end()), + ffi::Array(index_lists.second.begin(), index_lists.second.end())}; } // these are exposed only for testing @@ -1027,10 +1029,10 @@ TVM_FFI_STATIC_INIT_BLOCK({ .def("relax.testing.transform.DataflowInplaceAnalysis", DataflowInplaceAnalysis) .def("relax.testing.transform.SingleInplaceCall", [](const IRModule& mod, const Call& call, - const Array& inplace_indices) -> Array { + const ffi::Array& inplace_indices) -> ffi::Array { ModuleInplaceTransformer transformer(mod); auto ret_call = transformer.CreateInplaceCall(call, inplace_indices); - return Array{ret_call, transformer.CurrentMod()}; + return ffi::Array{ret_call, transformer.CurrentMod()}; }); }); diff --git a/src/relax/transform/dead_code_elimination.cc b/src/relax/transform/dead_code_elimination.cc index 59874e737778..378239fad0f6 100644 --- a/src/relax/transform/dead_code_elimination.cc +++ b/src/relax/transform/dead_code_elimination.cc @@ -91,7 +91,8 @@ IRModule RemoveUnusedFunctions(IRModule mod, const std::unordered_set return mod; } -IRModule DeadCodeElimination(const IRModule& arg_mod, Array entry_function_names) { +IRModule 
DeadCodeElimination(const IRModule& arg_mod, + ffi::Array entry_function_names) { IRModule mod = arg_mod; // S0: Make a list of all user-specified entry functions and @@ -134,7 +135,7 @@ IRModule DeadCodeElimination(const IRModule& arg_mod, Array entry_functi namespace transform { -Pass DeadCodeElimination(Array entry_functions) { +Pass DeadCodeElimination(ffi::Array entry_functions) { auto pass_func = [=](IRModule m, PassContext pc) { return relax::DeadCodeElimination(m, entry_functions); }; diff --git a/src/relax/transform/decompose_ops.cc b/src/relax/transform/decompose_ops.cc index df57434ebb02..5050ab487dd0 100644 --- a/src/relax/transform/decompose_ops.cc +++ b/src/relax/transform/decompose_ops.cc @@ -36,9 +36,9 @@ TensorStructInfo MatchTensorStructInfo(Expr data) { return _sinfo.value(); } -Expr ExpandToMatchInput(Expr data, int ndim, Array axes) { +Expr ExpandToMatchInput(Expr data, int ndim, ffi::Array axes) { axes = GetOrderedPositiveAxes(axes, ndim); - Array expand_axes; + ffi::Array expand_axes; for (int i = 0, j = 0; i < ndim; ++i) { if (j < static_cast(axes.size()) && i == axes[j]->value) { ++j; @@ -89,7 +89,7 @@ Expr MutateBatchNormForTraining(Call call) { TensorStructInfo sinfo = MatchTensorStructInfo(data); - Array reduce_axes; + ffi::Array reduce_axes; for (int i = 0; i < sinfo->ndim; ++i) { if (i != attrs->axis) { reduce_axes.push_back(i); @@ -148,12 +148,12 @@ Expr TensorToShape(const Call& call_node, const BlockBuilder& builder) { static const Op& call_pure_packed_op = Op::Get("relax.call_pure_packed"); Var call = builder->Emit(Call(call_pure_packed_op, {ExternFunc("vm.builtin.tensor_to_shape"), expr}, {}, - {GetRef(sinfo)})); + {ffi::GetRef(sinfo)})); // Operators like reshape take the output of `TensorToShape` as their output shape. // Because TOPI expects to have such output shape in symbolic shape at least (i.e., - // Array), we define symbolic variables and returns them as a ShapeExpr. - Array shape_var; + // ffi::Array), we define symbolic variables and returns them as a ShapeExpr. + ffi::Array shape_var; for (int i = 0; i < sinfo->ndim; i++) { shape_var.push_back(tir::Var("x", DataType::Int(64))); } @@ -233,7 +233,7 @@ Pass DecomposeOps() { /*required=*/{}); } -Pass DecomposeOpsForInference(Optional func_name) { +Pass DecomposeOpsForInference(ffi::Optional func_name) { if (func_name) { return ApplyPassToFunction(DecomposeOps(), func_name.value()); } else { @@ -241,7 +241,7 @@ Pass DecomposeOpsForInference(Optional func_name) { } } -Pass DecomposeOpsForTraining(Optional func_name) { +Pass DecomposeOpsForTraining(ffi::Optional func_name) { auto module_pass = tvm::transform::Sequential({MutateOpsForTraining(), DecomposeOps()}, "DecomposeOpsForTraining"); if (func_name) { diff --git a/src/relax/transform/eliminate_common_subexpr.cc b/src/relax/transform/eliminate_common_subexpr.cc index 68e37970030a..c88a5bfccb74 100644 --- a/src/relax/transform/eliminate_common_subexpr.cc +++ b/src/relax/transform/eliminate_common_subexpr.cc @@ -48,7 +48,7 @@ namespace { */ struct ReplacementKey { tvm::relax::Expr bound_value; - tvm::Optional match_cast = std::nullopt; + tvm::ffi::Optional match_cast = std::nullopt; explicit ReplacementKey(const tvm::relax::Binding& binding) : bound_value(GetBoundValue(binding)) { @@ -155,7 +155,7 @@ class CommonSubexprEliminator : public ExprMutator { // copy of the mutator, to avoid replacing a child-scope // expression with a parent-scope binding, or vice versa. 
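(Illustration only, not part of the change: driving the updated entry point, which now takes `ffi::Array<ffi::String>` for the entry-function names.)

.. code:: c++

    #include <tvm/relax/transform.h>

    tvm::IRModule RunDCE(tvm::IRModule mod) {
      tvm::ffi::Array<tvm::ffi::String> entries{"main"};
      return tvm::relax::transform::DeadCodeElimination(entries)(mod);
    }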
if (expr_replacements_.size() || var_remap_.size()) { - return VisitWithCleanScope(GetRef(op)); + return VisitWithCleanScope(ffi::GetRef(op)); } else { return ExprMutator::VisitExpr_(op); } @@ -168,7 +168,7 @@ class CommonSubexprEliminator : public ExprMutator { if (op->cond.same_as(cond) && op->true_branch.same_as(true_branch) && op->false_branch.same_as(false_branch) && VisitAndCheckStructInfoFieldUnchanged(op->struct_info_)) { - return GetRef(op); + return ffi::GetRef(op); } else { return If(cond, true_branch, false_branch, op->span); } @@ -193,7 +193,7 @@ class CommonSubexprEliminator : public ExprMutator { static const auto& allocator_attr_map = Op::GetAttrMap("TAllocator"); if (const auto* call = expr.as()) { if (const auto* op = call->op.as()) { - bool is_allocator = allocator_attr_map.get(GetRef(op), Bool(false))->value; + bool is_allocator = allocator_attr_map.get(ffi::GetRef(op), Bool(false))->value; if (is_allocator) { return true; } diff --git a/src/relax/transform/expand_matmul_of_sum.cc b/src/relax/transform/expand_matmul_of_sum.cc index 70662396fe52..a871b007b4c4 100644 --- a/src/relax/transform/expand_matmul_of_sum.cc +++ b/src/relax/transform/expand_matmul_of_sum.cc @@ -41,7 +41,7 @@ namespace tvm { namespace relax { namespace { -std::tuple)>> CreatePatterns( +std::tuple)>> CreatePatterns( const Function& func) { auto compile_time_arr = ComputableAtCompileTime(func); std::unordered_set compile_time_lookup(compile_time_arr.begin(), compile_time_arr.end()); @@ -58,7 +58,7 @@ std::tuple)>> Crea auto pat_matmul = IsOp("relax.matmul")(pat_lhs, pat_rhs); - auto rewriter = [=](Expr expr, Map matches) -> Expr { + auto rewriter = [=](Expr expr, ffi::Map matches) -> Expr { auto lhs = matches[pat_lhs]; auto rhs_a = matches[pat_rhs_a]; auto rhs_b = matches[pat_rhs_b]; diff --git a/src/relax/transform/expand_tuple_arguments.cc b/src/relax/transform/expand_tuple_arguments.cc index 5b711b767562..fbe16e9c1b35 100644 --- a/src/relax/transform/expand_tuple_arguments.cc +++ b/src/relax/transform/expand_tuple_arguments.cc @@ -32,8 +32,8 @@ namespace { template using PMap = std::unordered_map; -Optional ExpandParams(Function func) { - bool is_exposed = func->attrs.GetAttr(tvm::attr::kGlobalSymbol).has_value(); +ffi::Optional ExpandParams(Function func) { + bool is_exposed = func->attrs.GetAttr(tvm::attr::kGlobalSymbol).has_value(); if (is_exposed) return std::nullopt; bool has_tuple_param = std::any_of( @@ -42,12 +42,12 @@ Optional ExpandParams(Function func) { if (!has_tuple_param) return std::nullopt; - Array params; - Array bindings; + ffi::Array params; + ffi::Array bindings; std::function expand_param = [&](const Var& param) { if (auto sinfo = param->struct_info_.as()) { - Array internal_tuple; + ffi::Array internal_tuple; for (size_t i = 0; i < sinfo->fields.size(); i++) { auto name = static_cast(std::stringstream() << param->name_hint() << "_" << i) @@ -89,7 +89,7 @@ class TupleExpander : public ExprMutator { if (auto gvar = node->op.as()) { if (auto it = replacements_.find(gvar.value()); it != replacements_.end()) { - Array new_args; + ffi::Array new_args; std::function expand_arg = [&](const Expr& arg) { if (auto sinfo = arg->struct_info_.as()) { diff --git a/src/relax/transform/few_shot_tuning.cc b/src/relax/transform/few_shot_tuning.cc index 819de35e20f0..091247272a64 100644 --- a/src/relax/transform/few_shot_tuning.cc +++ b/src/relax/transform/few_shot_tuning.cc @@ -42,13 +42,13 @@ tir::PrimFunc FewShotTunePrimFunc(const tir::PrimFunc& prim_func, const Target& 
ICHECK(runner.defined()) << "ValueError: The local runner is not defined!"; } // create an IRModule - IRModule mod = IRModule(Map( - {{GlobalVar("main"), WithAttr(prim_func, tvm::attr::kGlobalSymbol, String("main"))}})); + IRModule mod = IRModule(ffi::Map( + {{GlobalVar("main"), WithAttr(prim_func, tvm::attr::kGlobalSymbol, ffi::String("main"))}})); // fetch the number of physical cores static const auto f_cpu_count = tvm::ffi::Function::GetGlobalRequired("meta_schedule.cpu_count"); int num_threads = f_cpu_count(false).cast(); // store the results - Array results; + ffi::Array results; std::vector costs; // create a TuneContext meta_schedule::TuneContext task = meta_schedule::TuneContext( @@ -72,16 +72,16 @@ tir::PrimFunc FewShotTunePrimFunc(const tir::PrimFunc& prim_func, const Target& /*cost_model=*/std::nullopt); int fail_count = 0, max_fail_count = 100; while (valid_count > 0 && fail_count < max_fail_count) { - Optional> candidates = + ffi::Optional> candidates = task->search_strategy.value()->GenerateMeasureCandidates(); if (!candidates.defined()) break; - Array builder_inputs; + ffi::Array builder_inputs; for (const meta_schedule::MeasureCandidate& candidate : candidates.value()) { builder_inputs.push_back(meta_schedule::BuilderInput( /*mod=*/candidate->sch->mod(), /*target=*/target)); } - Array builder_results = builder->Build(builder_inputs); + ffi::Array builder_results = builder->Build(builder_inputs); ICHECK_EQ(builder_results.size(), candidates.value().size()); int idx = 0; bool no_valid = true; // whether there is no valid schedule in this iteration @@ -95,7 +95,7 @@ tir::PrimFunc FewShotTunePrimFunc(const tir::PrimFunc& prim_func, const Target& } fail_count += no_valid; // increase fail_count if there is no valid schedule if (benchmark) { - Array runner_inputs; + ffi::Array runner_inputs; int idx = 0; for (const meta_schedule::BuilderResult& builder_result : builder_results) { if (!builder_result->error_msg.has_value()) { @@ -106,7 +106,7 @@ tir::PrimFunc FewShotTunePrimFunc(const tir::PrimFunc& prim_func, const Target& } idx++; } - Array runner_futures = runner->Run(runner_inputs); + ffi::Array runner_futures = runner->Run(runner_inputs); for (const meta_schedule::RunnerFuture& runner_future : runner_futures) { meta_schedule::RunnerResult runner_result = runner_future->Result(); if (runner_result->error_msg.has_value()) { @@ -153,12 +153,13 @@ Pass FewShotTuning(int valid_count, bool benchmark) { tvm::Target target = tvm::Target::Current(); ICHECK(target.defined()) << "Target is not set in current context"; // generate the few shot tuned prim funcs. - Map result; + ffi::Map result; for (const auto& [gv, func] : m->functions) { if (func->IsInstance() && !func->HasNonzeroAttr(tir::attr::kIsScheduled)) { - result.Set(gv, FewShotTunePrimFunc(GetRef(func.as()), - target, valid_count, benchmark)); + result.Set(gv, + FewShotTunePrimFunc(ffi::GetRef(func.as()), + target, valid_count, benchmark)); } else { result.Set(gv, func); } diff --git a/src/relax/transform/fold_constant.cc b/src/relax/transform/fold_constant.cc index 93b77387d550..c2f2f48cafdc 100644 --- a/src/relax/transform/fold_constant.cc +++ b/src/relax/transform/fold_constant.cc @@ -50,7 +50,7 @@ class ConstantFolder : public ExprMutator { * \note Only TensorStructInfo is supported at this moment. Return std::nullopt * if the input struct info is not TensorStructInfo. 
*/ - static Optional MatchConstShape(const StructInfo& struct_info) { + static ffi::Optional MatchConstShape(const StructInfo& struct_info) { // Only support single output for call_tir at this moment. const auto* tensor_sinfo = struct_info.as(); if (tensor_sinfo == nullptr) { @@ -73,8 +73,9 @@ class ConstantFolder : public ExprMutator { * \brief Pattern match op to constant array arguments. * \return The constant array arguments, or nullopt if match fails. */ - static Optional> MatchConstArrayArgs(const Array& args) { - Array res; + static ffi::Optional> MatchConstArrayArgs( + const ffi::Array& args) { + ffi::Array res; for (auto arg : args) { auto* ptr = arg.as(); if (!ptr) return std::nullopt; @@ -87,12 +88,12 @@ class ConstantFolder : public ExprMutator { * \brief Pattern match op to a TIR function and look it up. * \return The TIR function, or nullopt if pattern match fails. */ - Optional MatchPrimFunc(const Expr& op) { + ffi::Optional MatchPrimFunc(const Expr& op) { const GlobalVar& global_var = Downcast(op); // NOTE: as check works for nullptr(returns null) - Optional base_func = builder_->GetContextIRModule()->functions.Get(global_var); + ffi::Optional base_func = builder_->GetContextIRModule()->functions.Get(global_var); if (auto* pfunc = base_func.as()) { - return GetRef(pfunc); + return ffi::GetRef(pfunc); } return std::nullopt; } @@ -101,7 +102,7 @@ class ConstantFolder : public ExprMutator { * \brief Get a cached build version of func * \return The cached func, nullopt if func cannot be built. */ - Optional GetCachedBuild(tir::PrimFunc func) { + ffi::Optional GetCachedBuild(tir::PrimFunc func) { // TODO(tvm-team): consider another way of bulk extract and build PrimFunc once // would be helpful for future cases where PrimFunc recursively call into each other Target eval_cpu_target{"llvm"}; @@ -110,7 +111,7 @@ class ConstantFolder : public ExprMutator { if (it != func_build_cache_.end()) { return it->second; } - Optional build_func = std::nullopt; + ffi::Optional build_func = std::nullopt; try { // Not all the primfunc can be directly built via llvm, for example, if a function is @@ -118,7 +119,7 @@ class ConstantFolder : public ExprMutator { // now // TODO(Hongyi): further check and narrow the scope of foldable function const auto pf = tvm::ffi::Function::GetGlobalRequired("tir.build"); - func = WithAttr(func, tvm::attr::kGlobalSymbol, String("tir_function")); + func = WithAttr(func, tvm::attr::kGlobalSymbol, ffi::String("tir_function")); ffi::Module rt_module = pf(func, eval_cpu_target).cast(); build_func = rt_module->GetFunction("tir_function"); } catch (const tvm::Error& err) { @@ -144,10 +145,11 @@ class ConstantFolder : public ExprMutator { // Try constant evaluate the function call // if failed return std::nullopt - Optional ConstEvaluateCallTIR(tir::PrimFunc tir_func, Array arr_args, - ffi::Shape shape, DataType ret_type) { + ffi::Optional ConstEvaluateCallTIR(tir::PrimFunc tir_func, + ffi::Array arr_args, ffi::Shape shape, + DataType ret_type) { // obtain function from the cache. - Optional func = GetCachedBuild(tir_func); + ffi::Optional func = GetCachedBuild(tir_func); if (!func) return std::nullopt; // here the vector size has an additional + 1 because we need to put ret_tensor at the end @@ -174,15 +176,15 @@ class ConstantFolder : public ExprMutator { } // Returns the folded expr if the call is successfully folded to constant, otherwise null. 
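(The file leans on one `ffi::Optional` idiom throughout: probe, then fall back with `value_or`. A distilled sketch, where `TryFoldCallTIR` is a hypothetical stand-in for helpers such as the `VisitCallTIR` below.)

.. code:: c++

    ffi::Optional<Expr> folded = TryFoldCallTIR(call);  // hypothetical helper
    // Keep the original call whenever folding declines.
    Expr result = folded.value_or(call);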
- Optional VisitCallTIR(Call call) { + ffi::Optional VisitCallTIR(Call call) { // call_tir needs to have at least three arguments ICHECK_GE(call->args.size(), 2); - Optional func = MatchPrimFunc(call->args[0]); + ffi::Optional func = MatchPrimFunc(call->args[0]); ICHECK(call->args[1].as()) << "call_tir.args[1] must be Tuple"; - Optional> arr_args = + ffi::Optional> arr_args = MatchConstArrayArgs(call->args[1].as()->fields); ICHECK_EQ(call->sinfo_args.size(), 1) << "call_tir should have exactly one sinfo arg"; - Optional shape = MatchConstShape(call->sinfo_args[0]); + ffi::Optional shape = MatchConstShape(call->sinfo_args[0]); bool output_not_tuple = call->sinfo_args.size() == 1; // Pattern 0: call constant function, const argument with const shape. if (func && arr_args && shape && output_not_tuple) { @@ -216,7 +218,7 @@ class ConstantFolder : public ExprMutator { if (op_node == nullptr) { return post_call; } - auto op = GetRef(op_node); + auto op = ffi::GetRef(op_node); if (op.same_as(call_tir_op)) { return VisitCallTIR(post_call).value_or(post_call); @@ -230,10 +232,10 @@ class ConstantFolder : public ExprMutator { // // gv: R.Tensor(lv2, dtype="float32") = R.reshape(data, R.shape([16, 16])) // - Array new_args; + ffi::Array new_args; for (auto arg : post_call->args) { if (arg->IsInstance()) { - Optional val = LookupBinding(Downcast(arg)); + ffi::Optional val = LookupBinding(Downcast(arg)); if (val.defined() && val.value()->IsInstance()) { new_args.push_back(val.value()); continue; @@ -254,7 +256,7 @@ class ConstantFolder : public ExprMutator { // If the legalized expression is call_tir, try to fold it. const CallNode* call = legalized_expr.as(); if (call && call->op.same_as(call_tir_op)) { - return VisitCallTIR(GetRef(call)).value_or(post_call); + return VisitCallTIR(ffi::GetRef(call)).value_or(post_call); } } else if (op->name == "relax.tensor_to_shape") { // Special handling for composite op "relax.tensor_to_shape" @@ -275,7 +277,7 @@ class ConstantFolder : public ExprMutator { ICHECK_EQ(ndarray->ndim, 1); const int64_t* data = static_cast(ndarray->data); int64_t num_elems = ndarray->shape[0]; - Array shape_values; + ffi::Array shape_values; for (int64_t i = 0; i < num_elems; i++) { shape_values.push_back(IntImm(DataType::Int(64), data[i])); } @@ -286,12 +288,12 @@ class ConstantFolder : public ExprMutator { // TODO(sunggg): revisit this when we extend ConstantFolding to fold ffi::Function. 
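(Condensed sketch of the build-and-look-up flow that `GetCachedBuild` implements above; `func` is a `tir::PrimFunc` and the "tir.build" global is assumed to be registered.)

.. code:: c++

    const auto fbuild = tvm::ffi::Function::GetGlobalRequired("tir.build");
    func = WithAttr(func, tvm::attr::kGlobalSymbol, tvm::ffi::String("tir_function"));
    tvm::ffi::Module rt_mod = fbuild(func, tvm::Target("llvm")).cast<tvm::ffi::Module>();
    auto entry = rt_mod->GetFunction("tir_function");  // ffi::Optional<ffi::Function>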
Expr arg = post_call->args[0]; ShapeExpr shape = Downcast(arg); - Array values = shape->values; - Array arr; + ffi::Array values = shape->values; + ffi::Array arr; bool is_known = true; for (size_t i = 0; i < values.size(); i++) { PrimExpr val = values[i]; - arr.push_back(GetRef(val.as())); + arr.push_back(ffi::GetRef(val.as())); is_known &= (val.dtype() == DataType::Int(64)); } if (is_known) { @@ -306,7 +308,7 @@ class ConstantFolder : public ExprMutator { } Expr VisitExpr_(const VarNode* op) final { - Optional opt = LookupBinding(GetRef(op)); + ffi::Optional opt = LookupBinding(ffi::GetRef(op)); // `as` check checks if opt is not null and is instance of constant if (opt.as()) { return opt.value(); @@ -315,7 +317,7 @@ class ConstantFolder : public ExprMutator { } // cache for function build, via structural equality - std::unordered_map, StructuralHash, StructuralEqual> + std::unordered_map, StructuralHash, StructuralEqual> func_build_cache_; }; diff --git a/src/relax/transform/fuse_ops.cc b/src/relax/transform/fuse_ops.cc index 4deb720342f2..acd54d043e56 100644 --- a/src/relax/transform/fuse_ops.cc +++ b/src/relax/transform/fuse_ops.cc @@ -120,10 +120,10 @@ class GraphCreator : public ExprVisitor { // true. const auto* func = it.second.as(); if (func == nullptr || func->HasNonzeroAttr(attr::kPrimitive) || - func->GetAttr(attr::kCodegen).has_value()) { + func->GetAttr(attr::kCodegen).has_value()) { continue; } - creator(GetRef(func)); + creator(ffi::GetRef(func)); } // The algorithm of the graph creator ensures that each created node will be added to the @@ -195,7 +195,7 @@ class GraphCreator : public ExprVisitor { static const Op& call_tir_inplace_op_ = Op::Get("relax.call_tir_inplace"); OpPatternKind pattern = OpPatternKind::kOpaque; - Array args = call->args; + ffi::Array args = call->args; // - If the op being called is a TIR PrimFunc, we get the function op pattern directly from the // function attribute and visit the arguments one by one. 
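(Sketch of the attribute lookup this comment refers to; the `Integer` template argument is inferred from the surrounding code.)

.. code:: c++

    ffi::Optional<Integer> opt_pattern = prim_func->GetAttr<Integer>("op_pattern");
    OpPatternKind pattern = opt_pattern.defined()
        ? static_cast<OpPatternKind>(opt_pattern.value()->value)
        : OpPatternKind::kOpaque;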
@@ -209,7 +209,7 @@
       // Override args for call_tir
       args = Downcast<Tuple>(call->args[1])->fields;

-      Optional<Integer> opt_pattern = func->GetAttr<Integer>("op_pattern");
+      ffi::Optional<Integer> opt_pattern = func->GetAttr<Integer>("op_pattern");
       if (opt_pattern.defined()) {
         pattern = static_cast<OpPatternKind>(Downcast<IntImm>(opt_pattern)->value);
       } else {
@@ -222,7 +222,7 @@
     for (const Expr& arg : args) {
       ICHECK(IsLeafOrTuple(arg))
           << "FuseOps expects all relax::Call nodes to have non-nested arguments, "
-          << "but " << GetRef<Call>(call) << " has argument " << arg
+          << "but " << ffi::GetRef<Call>(call) << " has argument " << arg
           << ", which is neither a leaf node nor a relax::Tuple";
       VisitLeaf(arg, binding_var_node, pattern);
     }
@@ -297,7 +297,7 @@
    */
   IndexedForwardGraph::Node* CreateNode(const Object* key) {
     ICHECK(graph_.node_map.find(key) == graph_.node_map.end())
-        << "The object " << GetRef<ObjectRef>(key) << " appears at multiple definition sites.";
+        << "The object " << ffi::GetRef<ObjectRef>(key) << " appears at multiple definition sites.";
     auto* node = arena_->make<IndexedForwardGraph::Node>();
     graph_.node_map[key] = node;
     return node;
   }
@@ -312,12 +312,12 @@
   void AddToPostDFSOrder(IndexedForwardGraph::Node* node, const Object* key) {
     auto it = graph_.node_map.find(key);
     ICHECK(it != graph_.node_map.end() && it->second == node)
-        << "Cannot add node " << GetRef<ObjectRef>(key) << " to the post-DFS order, "
+        << "Cannot add node " << ffi::GetRef<ObjectRef>(key) << " to the post-DFS order, "
         << "because the node for this object has not yet been created.";

     // We only set the reference of the node when adding it to the post-dfs order. Thus, if the
     // reference of a node is already set, it must have been appended to the post-dfs order.
-    ICHECK(node->ref == nullptr) << "Cannot add node " << GetRef<ObjectRef>(key)
+    ICHECK(node->ref == nullptr) << "Cannot add node " << ffi::GetRef<ObjectRef>(key)
                                  << " to the post-DFS order, "
                                  << "because it has already been added.";
@@ -354,7 +354,7 @@
    */
   void SetNodePattern(IndexedForwardGraph::Node* node, OpPatternKind pattern) {
     ICHECK(initialized_nodes_.find(node) == initialized_nodes_.end())
-        << "The input node " << GetRef<ObjectRef>(node->ref)
+        << "The input node " << ffi::GetRef<ObjectRef>(node->ref)
         << " cannot have its OpPatternKind set more than once.";
     initialized_nodes_.insert(node);
     node->pattern = pattern;
@@ -481,7 +481,7 @@ class FunctionCreator : public ExprMutator {
    * It will become the value of the kComposite attribute of the created function.
    * \note The created function won't be returned immediately. It's stored in the `function_` field.
    */
-  void CreateFunction(Map group_attrs) {
+  void CreateFunction(ffi::Map group_attrs) {
     // Step 1. Start constructing a new dataflow block.
builder_->BeginDataflowBlock(); @@ -493,16 +493,16 @@ class FunctionCreator : public ExprMutator { ICHECK(!item_indices.empty()); int param_idx = tuple_param_idx_[tuple_arg]; Var param = params_[param_idx]; - String param_name = params_[param_idx]->name_hint(); + ffi::String param_name = params_[param_idx]->name_hint(); TupleStructInfo param_sinfo = Downcast(tuple_arg->struct_info_); - Array item_args; - Array item_params; + ffi::Array item_args; + ffi::Array item_params; item_args.reserve(item_indices.size()); item_params.reserve(item_indices.size()); for (int item_idx : item_indices) { Var item_param(param_name + "_" + std::to_string(item_idx), param_sinfo->fields[item_idx]); - item_args.push_back(TupleGetItem(GetRef(tuple_arg), item_idx)); + item_args.push_back(TupleGetItem(ffi::GetRef(tuple_arg), item_idx)); item_params.push_back(item_param); tuple_get_item_remap[tuple_arg][item_idx] = item_param; } @@ -513,7 +513,7 @@ class FunctionCreator : public ExprMutator { } // Step 3. Visit each binding and collect outputs one by one. - Array outputs(output_vars_.size(), Expr()); + ffi::Array outputs(output_vars_.size(), Expr()); for (const Binding& binding : bindings_) { // Special handing for TupleGetItem. if (const auto* var_binding = binding.as()) { @@ -561,7 +561,7 @@ class FunctionCreator : public ExprMutator { /*ret_struct_info=*/std::nullopt, // /*is_pure=*/true, // /*attrs=*/DictAttrs(group_attrs)); - Array free_vars = + ffi::Array free_vars = FreeSymbolicVars(function).Map([](const tir::Var& var) -> PrimExpr { return var; }); if (!free_vars.empty()) { params_.push_back(Var("tir_vars", ShapeStructInfo(free_vars))); @@ -577,15 +577,15 @@ class FunctionCreator : public ExprMutator { } /*! \brief The original bindings of the function */ - Array bindings_; + ffi::Array bindings_; /*! \brief The parameters of the function */ - Array params_; + ffi::Array params_; /*! \brief The arguments to call the function on the caller side */ - Array arguments_; + ffi::Array arguments_; /*! \brief The name for the fused function */ - String name_hint_ = "fused"; + ffi::String name_hint_ = "fused"; /*! \brief The constructed Relax function */ - Optional function_ = std::nullopt; + ffi::Optional function_ = std::nullopt; private: std::optional GetOutputIndex(Var v) { @@ -612,8 +612,9 @@ class FunctionCreator : public ExprMutator { const auto* var = expr.as(); if ((var == nullptr || defined_vars_.count(var) == 0) && (lift_constant_ || !expr->IsInstance())) { - String name = var != nullptr ? var->name_hint() - : String("param_" + std::to_string(n_param_for_const_++)); + ffi::String name = var != nullptr + ? var->name_hint() + : ffi::String("param_" + std::to_string(n_param_for_const_++)); StructInfo param_sinfo = GetStructInfo(expr); if (!IsInlinableConstants(expr)) { Var param(std::move(name), GetStructInfo(expr)); @@ -719,8 +720,8 @@ class OperatorFusor : public ExprMutator { * \brief The main transformation on the IRModule * \return The new IRModule after transformation */ - IRModule Transform(const Array& entry_function_names = {}) { - Array entry_functions; + IRModule Transform(const ffi::Array& entry_function_names = {}) { + ffi::Array entry_functions; if (entry_function_names.empty()) { entry_functions = mod_->GetGlobalVars(); } else { @@ -733,7 +734,7 @@ class OperatorFusor : public ExprMutator { // Only visit Relax functions with neither attr::kPrimitive nor // attr::kCodegen. 
if (func->IsInstance() && !func->HasNonzeroAttr(attr::kPrimitive) && - !func->GetAttr(attr::kCodegen).has_value()) { + !func->GetAttr(attr::kCodegen).has_value()) { auto updated_func = Downcast(VisitExpr(func)); builder_->UpdateFunction(gv, updated_func); } @@ -882,7 +883,7 @@ class OperatorFusor : public ExprMutator { * \param bindings The bindings to be collected * \note The function update is done by `AppendBinding(...)` */ - void CollectFuncBindings(const Array& bindings) { + void CollectFuncBindings(const ffi::Array& bindings) { for (const Binding& binding : bindings) { // If the binding is the only binding in its group, there is no need to create a new function. Group* group = GetGroupFromBinding(binding); @@ -898,7 +899,7 @@ class OperatorFusor : public ExprMutator { } } - void CollectFuncBoundary(const Array& bindings) { + void CollectFuncBoundary(const ffi::Array& bindings) { for (const Binding& binding : bindings) { // Step 1. Get current binding's group Group* cur_group = GetGroupFromBinding(binding); @@ -969,8 +970,8 @@ class OperatorFusor : public ExprMutator { * \param args The arguments to be updated * \return The updated arguments */ - Array UpdateArgs(const Array& args) { - Array new_args; + ffi::Array UpdateArgs(const ffi::Array& args) { + ffi::Array new_args; new_args.reserve(args.size()); for (const Expr& arg : args) { new_args.push_back(VisitExpr(arg)); @@ -980,7 +981,7 @@ class OperatorFusor : public ExprMutator { private: // Topologically sort bindings according to the group dependency relations. - Array TopoSortByGroupDep(const Array& bindings) { + ffi::Array TopoSortByGroupDep(const ffi::Array& bindings) { std::unordered_map> bindings_per_group; // The order to visit groups should respect the original order of bindings as much as possible. 
std::vector group_order; @@ -1003,7 +1004,7 @@ class OperatorFusor : public ExprMutator { } }; - Array sorted; + ffi::Array sorted; for (auto g : group_order) { dfs_visit(g, [&sorted, &bindings_per_group](Group* leaf) { @@ -1054,7 +1055,7 @@ IRModule FuseOps(IRModule mod, int opt_level, size_t max_fuse_depth) { IRModule MakeGroupedFunctions( IRModule mod, const std::unordered_map& partition, - bool lift_constants, const Array& entry_function_names) { + bool lift_constants, const ffi::Array& entry_function_names) { return OperatorFusor(mod, partition, lift_constants).Transform(entry_function_names); } @@ -1069,19 +1070,20 @@ class PatternBasedPartitioner : ExprVisitor { using PatternCheckContext = transform::PatternCheckContext; using ExprVisitor::VisitExpr_; using FCheckMatch = ffi::TypedFunction; - using FAttrsGetter = ffi::TypedFunction(const Map&)>; + using FAttrsGetter = + ffi::TypedFunction(const ffi::Map&)>; - static GroupMap Run(String pattern_name, DFPattern pattern, - Map annotation_patterns, FCheckMatch check, Expr expr, - support::Arena* arena, FAttrsGetter attrs_getter) { + static GroupMap Run(ffi::String pattern_name, DFPattern pattern, + ffi::Map annotation_patterns, FCheckMatch check, + Expr expr, support::Arena* arena, FAttrsGetter attrs_getter) { PatternBasedPartitioner part(pattern_name, pattern, annotation_patterns, check, arena, attrs_getter); part.VisitExpr(expr); return part.group_map_; } - PatternBasedPartitioner(String pattern_name, DFPattern pattern, - Map annotation_patterns, FCheckMatch check, + PatternBasedPartitioner(ffi::String pattern_name, DFPattern pattern, + ffi::Map annotation_patterns, FCheckMatch check, support::Arena* arena, FAttrsGetter attrs_getter) : pat_name_(pattern_name), pat_(pattern), @@ -1091,7 +1093,7 @@ class PatternBasedPartitioner : ExprVisitor { attrs_getter_(attrs_getter) {} void VisitBindingBlock_(const DataflowBlockNode* block) final { - current_block_use_def_ = DataflowBlockUseDef(GetRef(block)); + current_block_use_def_ = DataflowBlockUseDef(ffi::GetRef(block)); ExprVisitor::VisitBindingBlock_(block); current_block_use_def_ = {}; } @@ -1112,14 +1114,14 @@ class PatternBasedPartitioner : ExprVisitor { void VisitBinding_(const VarBindingNode* binding, const CallNode* call) final { VisitVarDef(binding->var); - if (auto matches_opt = ExtractMatchedExpr(pat_, GetRef(call), bindings_)) { + if (auto matches_opt = ExtractMatchedExpr(pat_, ffi::GetRef(call), bindings_)) { const auto& context = CreatePatternCheckContext(call, matches_opt.value()); if (check_ != nullptr && !check_(context)) { return; } for (const auto& [pat, match] : matches_opt.value()) { - if ((pat->IsInstance() && match != GetRef(call)) || + if ((pat->IsInstance() && match != ffi::GetRef(call)) || pat->IsInstance()) { auto g = GetGroup(match); if (g && g->FindRoot()->num_nodes > 1) { @@ -1164,7 +1166,7 @@ class PatternBasedPartitioner : ExprVisitor { // the previous group. For example, when there are two back-to-back conv2d ops, the output // of the first conv2d is matched to the input of the second conv2d via a wildcard pattern. // But we must avoid merging the first conv2d into the group of the second conv2d. - if ((pat->IsInstance() && match != GetRef(call)) || + if ((pat->IsInstance() && match != ffi::GetRef(call)) || pat->IsInstance()) { // Put the bound variable on the LHS into the same parent group. 
AddToGroup(value_to_bound_var_[match], parent_group); @@ -1196,28 +1198,28 @@ class PatternBasedPartitioner : ExprVisitor { } PatternCheckContext CreatePatternCheckContext(const CallNode* call, - const Map& matched_result) { - Map annotated_expr; + const ffi::Map& matched_result) { + ffi::Map annotated_expr; for (const auto& it : annotation_pat_) { if (matched_result.count(it.second)) { annotated_expr.Set(it.first, matched_result[it.second]); } } - Map matched_bindings; + ffi::Map matched_bindings; for (const auto& [pat, match] : matched_result) { if (pat->IsInstance() || pat->IsInstance()) { matched_bindings.Set(value_to_bound_var_[match], match); } } - return PatternCheckContext(GetRef(call), annotated_expr, matched_bindings, + return PatternCheckContext(ffi::GetRef(call), annotated_expr, matched_bindings, current_block_use_def_, value_to_bound_var_); } // check if a previous matched subgraph is subsumed by the current matched result - bool GraphSubsumedInMatchedValues(const Array& vars_in_graph, - const Map& matched_result) { + bool GraphSubsumedInMatchedValues(const ffi::Array& vars_in_graph, + const ffi::Map& matched_result) { std::set matched_vars; for (const auto& [pat, match] : matched_result) { if ((pat->IsInstance() || pat->IsInstance())) @@ -1230,17 +1232,17 @@ class PatternBasedPartitioner : ExprVisitor { return true; } - String pat_name_; + ffi::String pat_name_; DFPattern pat_; - Map annotation_pat_; + ffi::Map annotation_pat_; FCheckMatch check_; support::Arena* arena_; FAttrsGetter attrs_getter_; - Map bindings_; - Map value_to_bound_var_; - Map> current_block_use_def_; + ffi::Map bindings_; + ffi::Map value_to_bound_var_; + ffi::Map> current_block_use_def_; GroupMap group_map_; - std::map> vars_in_group_; + std::map> vars_in_group_; }; /*! 
@@ -1263,8 +1265,8 @@ class CompositeFunctionAnnotator : public ExprMutator { } const auto& base_func = (*it).second; if (const auto* func = base_func.as()) { - if (func->GetAttr(attr::kComposite).has_value() || - func->GetAttr(attr::kCodegen).has_value()) { + if (func->GetAttr(attr::kComposite).has_value() || + func->GetAttr(attr::kCodegen).has_value()) { continue; } @@ -1284,15 +1286,15 @@ class CompositeFunctionAnnotator : public ExprMutator { if (auto it = gvar_map_.find(gvar); it != gvar_map_.end()) { return Call(it->second, call_node->args); } - auto func = builder_->GetContextIRModule()->Lookup(GetRef(gvar)); - if (auto composite_name = func->GetAttr(attr::kComposite)) { + auto func = builder_->GetContextIRModule()->Lookup(ffi::GetRef(gvar)); + if (auto composite_name = func->GetAttr(attr::kComposite)) { auto new_func = Downcast(VisitExpr(func)); auto codegen_name = GetCodegenName(composite_name.value()); auto gsymbol = gvar->name_hint + "_" + codegen_name; new_func = WithAttrs(new_func, {{attr::kCodegen, codegen_name}, {tvm::attr::kGlobalSymbol, gsymbol}}); new_func = WithoutAttr(std::move(new_func), tvm::relax::attr::kPrimitive); - builder_->GetContextIRModule()->Remove(GetRef(gvar)); + builder_->GetContextIRModule()->Remove(ffi::GetRef(gvar)); auto new_gvar = builder_->AddFunction(new_func, gsymbol); gvar_map_[gvar] = new_gvar; return Call(new_gvar, call_node->args); @@ -1304,7 +1306,7 @@ class CompositeFunctionAnnotator : public ExprMutator { Expr VisitExpr_(const FunctionNode* func_node) final { Function f_inner = Downcast(ExprMutator::VisitExpr_(func_node)); - if (!func_node->GetAttr(attr::kComposite)) { + if (!func_node->GetAttr(attr::kComposite)) { // This lambda function doesn't have `attr::kComposite`, so it // was not produced by FuseOps. return f_inner; @@ -1312,8 +1314,8 @@ class CompositeFunctionAnnotator : public ExprMutator { f_inner = WithoutAttr(std::move(f_inner), tvm::relax::attr::kPrimitive); - Array param_vars; - Array params; + ffi::Array param_vars; + ffi::Array params; for (auto v : func_node->params) { Var new_v(v->name_hint(), GetStructInfo(v)); @@ -1341,13 +1343,13 @@ class CompositeFunctionAnnotator : public ExprMutator { std::unordered_map gvar_map_; }; -IRModule FuseOpsByPattern(const tvm::Array& patterns, IRModule mod, +IRModule FuseOpsByPattern(const tvm::ffi::Array& patterns, IRModule mod, bool bind_constants, bool annotate_codegen, - Array entry_function_names) { + ffi::Array entry_function_names) { support::Arena arena; for (const auto& pattern : patterns) { - Array entry_functions; + ffi::Array entry_functions; if (entry_function_names.size()) { for (const auto& name : entry_function_names) { auto gv = mod->GetGlobalVar(name); @@ -1363,8 +1365,8 @@ IRModule FuseOpsByPattern(const tvm::Array& patterns, } const FunctionNode* function = base_func.as(); if (function->GetAttr(attr::kPrimitive).value_or(false) || - function->GetAttr(attr::kComposite).has_value() || - function->GetAttr(attr::kCodegen).has_value()) { + function->GetAttr(attr::kComposite).has_value() || + function->GetAttr(attr::kCodegen).has_value()) { continue; } entry_functions.push_back(Downcast(base_func)); @@ -1379,7 +1381,7 @@ IRModule FuseOpsByPattern(const tvm::Array& patterns, CHECK(!group_map.count(key)) << "ValueError: " << "IRModule is invalid. 
" - << "The object " << GetRef(key) << " appears in multiple partitions, " + << "The object " << ffi::GetRef(key) << " appears in multiple partitions, " << "which can occur when the IRModule was not single-site assignment"; group_map.insert({key, value}); } @@ -1395,10 +1397,11 @@ IRModule FuseOpsByPattern(const tvm::Array& patterns, namespace transform { -FusionPattern::FusionPattern(String name, DFPattern pattern, - Map annotation_patterns, - Optional check, Optional attrs_getter) { - ObjectPtr n = make_object(); +FusionPattern::FusionPattern(ffi::String name, DFPattern pattern, + ffi::Map annotation_patterns, + ffi::Optional check, + ffi::Optional attrs_getter) { + ObjectPtr n = ffi::make_object(); n->name = std::move(name); n->pattern = std::move(pattern); n->annotation_patterns = std::move(annotation_patterns); @@ -1411,17 +1414,18 @@ TVM_FFI_STATIC_INIT_BLOCK({ namespace refl = tvm::ffi::reflection; refl::GlobalDef().def( "relax.transform.FusionPattern", - [](String name, DFPattern pattern, Map annotation_patterns, - Optional check, Optional attrs_getter) { + [](ffi::String name, DFPattern pattern, ffi::Map annotation_patterns, + ffi::Optional check, ffi::Optional attrs_getter) { return FusionPattern(name, pattern, annotation_patterns, check, attrs_getter); }); }); -PatternCheckContext::PatternCheckContext(Expr matched_expr, Map annotated_expr, - Map matched_bindings, - Map> var_usages, - Map value_to_bound_var) { - ObjectPtr n = make_object(); +PatternCheckContext::PatternCheckContext(Expr matched_expr, + ffi::Map annotated_expr, + ffi::Map matched_bindings, + ffi::Map> var_usages, + ffi::Map value_to_bound_var) { + ObjectPtr n = ffi::make_object(); n->matched_expr = std::move(matched_expr); n->annotated_expr = std::move(annotated_expr); n->matched_bindings = std::move(matched_bindings); @@ -1448,8 +1452,8 @@ TVM_FFI_STATIC_INIT_BLOCK({ refl::GlobalDef().def("relax.transform.FuseOps", FuseOps); }); -Pass FuseOpsByPattern(const tvm::Array& patterns, bool bind_constants, - bool annotate_codegen, const Array& entry_function_names) { +Pass FuseOpsByPattern(const tvm::ffi::Array& patterns, bool bind_constants, + bool annotate_codegen, const ffi::Array& entry_function_names) { auto pass_func = // [=](IRModule m, PassContext pc) { return relax::FuseOpsByPattern(patterns, m, bind_constants, annotate_codegen, diff --git a/src/relax/transform/fuse_tir.cc b/src/relax/transform/fuse_tir.cc index db3916bc2210..61b3a6024810 100644 --- a/src/relax/transform/fuse_tir.cc +++ b/src/relax/transform/fuse_tir.cc @@ -39,10 +39,10 @@ namespace tir { */ class SymbolicMatcher : ExprFunctor { public: - explicit SymbolicMatcher(arith::Analyzer* analyzer, Map* var_remap) + explicit SymbolicMatcher(arith::Analyzer* analyzer, ffi::Map* var_remap) : analyzer_(analyzer), var_remap_(var_remap) {} - void Match(const Array& params, const Array& args) { + void Match(const ffi::Array& params, const ffi::Array& args) { CHECK_EQ(params.size(), args.size()); for (size_t i = 0; i < params.size(); ++i) { Match(params[i], args[i]); @@ -66,15 +66,15 @@ class SymbolicMatcher : ExprFunctor(); \ - if (rhs) { \ - VisitExpr(op->a, rhs->a); \ - VisitExpr(op->b, rhs->b); \ - } else { \ - must_prove_ = must_prove_ && (GetRef(op) == other); \ - } \ +#define TVM_DECLARE_SYMBOLIC_MATCHER_BINOP(OpName) \ + void VisitExpr_(const OpName* op, const PrimExpr& other) { \ + const auto* rhs = other.as(); \ + if (rhs) { \ + VisitExpr(op->a, rhs->a); \ + VisitExpr(op->b, rhs->b); \ + } else { \ + must_prove_ = must_prove_ && (ffi::GetRef(op) == 
 other);                                                   \
+  }                                                        \
  }

  TVM_DECLARE_SYMBOLIC_MATCHER_BINOP(AddNode);
@@ -98,7 +98,7 @@ class SymbolicMatcher : ExprFunctor
  void VisitExpr_(const IntImmNode* op, const PrimExpr& other) {
    const auto* rhs = other.as<IntImmNode>();
    if (!rhs || (op->value != rhs->value)) {
-      LOG(FATAL) << "Parameter expression " << GetRef<PrimExpr>(op)
+      LOG(FATAL) << "Parameter expression " << ffi::GetRef<PrimExpr>(op)
                 << " expected an integer argument with value " << op->value << ", "
                 << "but was provided with the argument " << other;
    }
@@ -107,7 +107,7 @@ class SymbolicMatcher : ExprFunctor
  void VisitExpr_(const FloatImmNode* op, const PrimExpr& other) {
    const auto* rhs = other.as<FloatImmNode>();
    if (!rhs || (op->value != rhs->value)) {
-      LOG(FATAL) << "Parameter expression " << GetRef<PrimExpr>(op)
+      LOG(FATAL) << "Parameter expression " << ffi::GetRef<PrimExpr>(op)
                 << " expected a float argument with value " << op->value << ", "
                 << "but was provided with the argument " << other;
    }
@@ -116,7 +116,7 @@ class SymbolicMatcher : ExprFunctor
  void VisitExpr_(const CastNode* op, const PrimExpr& other) {
    const auto* rhs = other.as<CastNode>();
    if (!rhs) {
-      LOG(FATAL) << "Parameter expression " << GetRef<PrimExpr>(op) << " expected a cast to "
+      LOG(FATAL) << "Parameter expression " << ffi::GetRef<PrimExpr>(op) << " expected a cast to "
                 << op->dtype << " as the argument, "
                 << "but was provided with the argument " << other;
    }
@@ -124,13 +124,14 @@ class SymbolicMatcher : ExprFunctor
  void VisitExpr_(const VarNode* op, const PrimExpr& rhs) {
-    auto lhs = GetRef<Var>(op);
+    auto lhs = ffi::GetRef<Var>(op);

    if (lhs.same_as(rhs)) {
      // Reference identity, no further checks needed.
    } else if (op->dtype.code() != rhs->dtype.code()) {
-      LOG(FATAL) << "Parameter expression " << GetRef<PrimExpr>(op) << " with dtype " << op->dtype
-                 << " cannot match to argument " << rhs << " with dtype " << rhs.dtype();
+      LOG(FATAL) << "Parameter expression " << ffi::GetRef<PrimExpr>(op) << " with dtype "
+                 << op->dtype << " cannot match to argument " << rhs << " with dtype "
+                 << rhs.dtype();
    } else if (auto it = var_remap_->find(lhs); it != var_remap_->end()) {
      VisitExpr((*it).second, rhs);
    } else {
@@ -144,12 +145,12 @@ class SymbolicMatcher : ExprFunctor
      VisitExpr(op->true_value, rhs->true_value);
      VisitExpr(op->false_value, rhs->false_value);
    } else {
-      must_prove_ = must_prove_ && (GetRef<PrimExpr>(op) == other);
+      must_prove_ = must_prove_ && (ffi::GetRef<PrimExpr>(op) == other);
    }
  }

  arith::Analyzer* analyzer_;
-  Map<Var, PrimExpr>* var_remap_;
+  ffi::Map<Var, PrimExpr>* var_remap_;
  PrimExpr must_prove_ = Bool(true);
};
@@ -158,8 +159,8 @@ class SymbolicMatcher : ExprFunctor
-  explicit FuseTIRBufferSubstitutor(const Map<Buffer, Buffer>& buffer_map,
-                                    const Map<Var, PrimExpr>& var_map) {
+  explicit FuseTIRBufferSubstitutor(const ffi::Map<Buffer, Buffer>& buffer_map,
+                                    const ffi::Map<Var, PrimExpr>& var_map) {
    buffer_remap_ = buffer_map;
    var_remap_ = var_map;
    for (const auto& [src, tgt] : buffer_map) {
@@ -171,16 +172,16 @@ class FuseTIRBufferSubstitutor : private StmtExprMutator {
  Buffer SubstituteAllocatedBuffer(Buffer buffer) {
    ICHECK(buffer_remap_.find(buffer) == buffer_remap_.end());
-    Array<PrimExpr> shape =
+    ffi::Array<PrimExpr> shape =
        MutateArray(buffer->shape, [this](const PrimExpr& expr) { return this->VisitExpr(expr); });
-    Array<PrimExpr> strides = MutateArray(
+    ffi::Array<PrimExpr> strides = MutateArray(
        buffer->strides, [this](const PrimExpr& expr) { return this->VisitExpr(expr); });
    PrimExpr elem_offset = this->VisitExpr(buffer->elem_offset);
    if (shape.same_as(buffer->shape) && strides.same_as(buffer->strides) &&
        elem_offset.same_as(buffer->elem_offset)) {
      return buffer;
    } else {
-      auto n = make_object<BufferNode>(*buffer.get());
+      auto n = ffi::make_object<BufferNode>(*buffer.get());
      n->shape = std::move(shape);
      n->strides = std::move(strides);
      n->elem_offset = std::move(elem_offset);
@@ -192,10 +193,10 @@ class FuseTIRBufferSubstitutor : private StmtExprMutator {
 private:
  PrimExpr VisitExpr_(const VarNode* _op) final {
-    if (auto it = var_remap_.find(GetRef<Var>(_op)); it != var_remap_.end()) {
+    if (auto it = var_remap_.find(ffi::GetRef<Var>(_op)); it != var_remap_.end()) {
      return (*it).second;
    } else {
-      return GetRef<Var>(_op);
+      return ffi::GetRef<Var>(_op);
    }
  }
@@ -206,7 +207,7 @@ class FuseTIRBufferSubstitutor : private StmtExprMutator {
      return load;
    } else {
-      auto n = make_object<BufferLoadNode>(*load.get());
+      auto n = ffi::make_object<BufferLoadNode>(*load.get());
      n->buffer = buffer;
      return BufferLoad(n);
    }
@@ -219,7 +220,7 @@ class FuseTIRBufferSubstitutor : private StmtExprMutator {
      return store;
    } else {
-      auto n = make_object<BufferStoreNode>(*store.get());
+      auto n = ffi::make_object<BufferStoreNode>(*store.get());
      n->buffer = buffer;
      return BufferStore(n);
    }
@@ -239,7 +240,7 @@ class FuseTIRBufferSubstitutor : private StmtExprMutator {
        region.same_as(match_buffer->source->region)) {
      return match_buffer;
    } else {
-      auto n = make_object<MatchBufferRegionNode>(*match_buffer.get());
+      auto n = ffi::make_object<MatchBufferRegionNode>(*match_buffer.get());
      n->buffer = tgt_buffer;
      n->source = BufferRegion(src_buffer, region);
      return MatchBufferRegion(n);
    }
@@ -257,15 +258,15 @@ class FuseTIRBufferSubstitutor : private StmtExprMutator {
    };

    // Step 1. Mutate `match_buffers`.
-    Array<MatchBufferRegion> match_buffers =
+    ffi::Array<MatchBufferRegion> match_buffers =
        MutateArray(block->match_buffers, f_mutate_match_buffers);
    // Step 2. Mutate the read/write region.
-    Array<BufferRegion> reads = MutateArray(block->reads, f_mutate_read_write_region);
-    Array<BufferRegion> writes = MutateArray(block->writes, f_mutate_read_write_region);
+    ffi::Array<BufferRegion> reads = MutateArray(block->reads, f_mutate_read_write_region);
+    ffi::Array<BufferRegion> writes = MutateArray(block->writes, f_mutate_read_write_region);
    // Step 3. Mutate the Allocate Buffers.
-    Array<Buffer> alloc_buffers = MutateArray(block->alloc_buffers, [this](const Buffer& buffer) {
-      return SubstituteAllocatedBuffer(buffer);
-    });
+    ffi::Array<Buffer> alloc_buffers =
+        MutateArray(block->alloc_buffers,
+                    [this](const Buffer& buffer) { return SubstituteAllocatedBuffer(buffer); });
    reads = UnionAccessRegion(reads);
    writes = UnionAccessRegion(writes);
@@ -288,16 +289,16 @@ class FuseTIRBufferSubstitutor : private StmtExprMutator {
 private:
  /*! \brief Mapping from src buffer to tgt buffer. */
-  Map<Buffer, Buffer> buffer_remap_;
+  ffi::Map<Buffer, Buffer> buffer_remap_;
  /*! \brief Mapping from src tir var to tgt var. */
-  Map<Var, PrimExpr> var_remap_;
+  ffi::Map<Var, PrimExpr> var_remap_;

-  Array<BufferRegion> UnionAccessRegion(const Array<BufferRegion>& regions) const {
+  ffi::Array<BufferRegion> UnionAccessRegion(const ffi::Array<BufferRegion>& regions) const {
    // For now we only allow Buffer accesses to the same elements.
    // e.g. `[A[vi, vj], A[vi, vj]]` is a legal pattern but needs to be unioned to `A[vi, vj]`
    // However, `A[vi, vj], A[vi, vj + 1]` is not allowed for now.
    // Note: the order of the returned regions should remain the same as their first occurrence.
-    Array<BufferRegion> ret;
+    ffi::Array<BufferRegion> ret;
    std::unordered_map buffer_region_set;

    for (const BufferRegion& region : regions) {
@@ -343,7 +344,7 @@ class BlockNameDeduplicator : public tir::StmtMutator {
  Stmt VisitStmt_(const BlockNode* op) final {
    Block block = Downcast<Block>(tir::StmtMutator::VisitStmt_(op));

-    String name = GetUniqueName(block->name_hint);
+    ffi::String name = GetUniqueName(block->name_hint);

    if (name == block->name_hint) {
      return block;
@@ -355,8 +356,8 @@ class BlockNameDeduplicator : public tir::StmtMutator {
    }
  }

-  String GetUniqueName(const String& prefix) {
-    String unique_prefix = prefix;
+  ffi::String GetUniqueName(const ffi::String& prefix) {
+    ffi::String unique_prefix = prefix;
    auto it = name_count_.find(prefix);
    while (name_count_.count(unique_prefix)) {
      unique_prefix = prefix + "_" + std::to_string(++it->second);
@@ -368,16 +369,16 @@
  // TODO(relax-team): It should detect the number suffix and do renaming properly
  // e.g.
GetUniqueName("name1") should return "name2" instead of "name10". /*! \brief The count map to make block name unique. */ - std::unordered_map name_count_; + std::unordered_map name_count_; }; } // namespace tir namespace relax { -static Array GetInplaceOutputIndices(const Array& inplace_indices, - int num_inputs) { - Array ret; +static ffi::Array GetInplaceOutputIndices(const ffi::Array& inplace_indices, + int num_inputs) { + ffi::Array ret; int last_idx = num_inputs; for (auto idx : inplace_indices) { int i = idx.IntValue(); @@ -396,7 +397,7 @@ static Array GetInplaceOutputIndices(const Array& inplace_indi class RelaxToTIRVarMapCollector : public ExprVisitor { public: explicit RelaxToTIRVarMapCollector(const IRModule& mod) : mod_(mod) {} - static Map Collect(const IRModule& mod, const Function& func) { + static ffi::Map Collect(const IRModule& mod, const Function& func) { RelaxToTIRVarMapCollector visitor(mod); visitor(func->body); return visitor.relax_to_tir_var_map_; @@ -414,7 +415,7 @@ class RelaxToTIRVarMapCollector : public ExprVisitor { ICHECK(call->op == call_tir_op_ || call->op == call_tir_inplace_op_) << "Only call_tir and call_tir_inplace are supported in primitive function, but got: " - << GetRef(call); + << ffi::GetRef(call); CollectVarMapping(call, current_var_, call->op == call_tir_inplace_op_); } @@ -426,7 +427,7 @@ class RelaxToTIRVarMapCollector : public ExprVisitor { const auto& relax_args = Downcast(call->args[1])->fields; - Array relax_results; + ffi::Array relax_results; if (lhs_var->IsInstance()) { relax_results = Downcast(lhs_var)->fields; } else { @@ -437,7 +438,7 @@ class RelaxToTIRVarMapCollector : public ExprVisitor { size_t num_inputs = relax_args.size(); size_t num_outputs = relax_results.size(); - Array output_idxs; + ffi::Array output_idxs; if (in_place) { const auto* attrs = call->attrs.as(); CHECK(attrs) << "Must have CallTIRInplaceAttrs for an in-place call"; @@ -479,7 +480,7 @@ class RelaxToTIRVarMapCollector : public ExprVisitor { private: /*! 
\brief The IRModule */ const IRModule& mod_; - Map relax_to_tir_var_map_; + ffi::Map relax_to_tir_var_map_; Var current_var_; }; @@ -491,8 +492,8 @@ class FusedTIRConstructor : public ExprVisitor { * \param gv The global var of relax subfunction to be fused into one PrimFunc * \return The fused TIR PrimFunc and the in-place indices (non-empty for an in-place call) */ - static std::pair> GetFusedTIR(const IRModule& mod, - const GlobalVar& gv) { + static std::pair> GetFusedTIR(const IRModule& mod, + const GlobalVar& gv) { FusedTIRConstructor visitor(mod, gv->name_hint); BaseFunc f = mod->Lookup(gv); CHECK(f->IsInstance()) @@ -500,7 +501,7 @@ class FusedTIRConstructor : public ExprVisitor { CHECK(f->HasNonzeroAttr(relax::attr::kPrimitive)) << "Expected a function with attr `kPrimitive`"; visitor(Downcast(f)); - Array inplace_indices; + ffi::Array inplace_indices; for (size_t idx : visitor.inplace_indices_) { inplace_indices.push_back(Integer(idx)); } @@ -508,18 +509,19 @@ class FusedTIRConstructor : public ExprVisitor { } private: - explicit FusedTIRConstructor(const IRModule& mod, const String& func_name) + explicit FusedTIRConstructor(const IRModule& mod, const ffi::String& func_name) : mod_(mod), func_name_(func_name) {} void VisitExpr_(const FunctionNode* func) final { - auto relax_to_tir_var_map = RelaxToTIRVarMapCollector::Collect(mod_, GetRef(func)); - std::vector> prim_func_params; + auto relax_to_tir_var_map = + RelaxToTIRVarMapCollector::Collect(mod_, ffi::GetRef(func)); + std::vector> prim_func_params; for (const Var& relax_param : func->params) { size_t size_before = prim_func_params.size(); CollectPrimFuncParams(relax_param, &prim_func_params, relax_to_tir_var_map.Get(relax_param)); - auto param_buffers = [&]() -> Array { - Array out; + auto param_buffers = [&]() -> ffi::Array { + ffi::Array out; for (size_t i = size_before; i < prim_func_params.size(); i++) { if (auto buf = prim_func_params[i].as()) { out.push_back(buf.value()); @@ -565,7 +567,7 @@ class FusedTIRConstructor : public ExprVisitor { ICHECK(it != func_info_.expr2buffers.end()) << "Fail to detect output buffers for function body"; - const Array& buffers = (*it).second; + const ffi::Array& buffers = (*it).second; // map of input buffers to indices (helpful for detecting in-place inputs) std::unordered_map buffer_to_idx; @@ -635,7 +637,7 @@ class FusedTIRConstructor : public ExprVisitor { ICHECK(call->op == call_tir_op_ || call->op == call_tir_inplace_op_) << "Only call_tir and call_tir_inplace are supported in primitive function, but got: " - << GetRef(call); + << ffi::GetRef(call); // Step 1. Get Global var and PrimFunc GlobalVar gv = Downcast(call->args[0]); @@ -659,7 +661,7 @@ class FusedTIRConstructor : public ExprVisitor { // Step 5. 
Map input arguments to buffer MapInputBuffer(prim_func, call->args[1]); - const Array>& output_buffer_shapes = GetCallTIROutputShapes(call); + const ffi::Array>& output_buffer_shapes = GetCallTIROutputShapes(call); AllocateIntermediateBuffer(call, prim_func, output_buffer_shapes); @@ -696,14 +698,14 @@ class FusedTIRConstructor : public ExprVisitor { } end_buf_idx = begin_buf_idx + GetTotalTensorSize(tuple_sinfo->fields[tuple_get_item->index]); func_info_.expr2buffers.Set( - GetRef(tuple_get_item), + ffi::GetRef(tuple_get_item), {(*it).second.begin() + begin_buf_idx, (*it).second.begin() + end_buf_idx}); } } void VisitExpr_(const TupleNode* tuple) final { ExprVisitor::VisitExpr_(tuple); - Array buffers; + ffi::Array buffers; for (const Expr& expr : tuple->fields) { auto it = func_info_.expr2buffers.find(expr); if (it != func_info_.expr2buffers.end()) { @@ -711,7 +713,7 @@ class FusedTIRConstructor : public ExprVisitor { } } if (!buffers.empty()) { - func_info_.expr2buffers.Set(GetRef(tuple), buffers); + func_info_.expr2buffers.Set(ffi::GetRef(tuple), buffers); } } @@ -723,7 +725,7 @@ class FusedTIRConstructor : public ExprVisitor { * \brief Get the output shapes for a call_tir node. * \return The shapes of the outputs. */ - static Array> GetCallTIROutputShapes(const CallNode* call) { + static ffi::Array> GetCallTIROutputShapes(const CallNode* call) { static const Op& call_tir_op_ = Op::Get("relax.call_tir"); static const Op& call_tir_inplace_op_ = Op::Get("relax.call_tir_inplace"); ICHECK(call->op.same_as(call_tir_op_) || call->op.same_as(call_tir_inplace_op_)); @@ -734,7 +736,7 @@ class FusedTIRConstructor : public ExprVisitor { return shape_expr->values; }; if (const auto* tuple_sinfo = call->sinfo_args[0].as()) { - Array> shapes; + ffi::Array> shapes; for (const StructInfo& field : tuple_sinfo->fields) { const auto* tensor_sinfo = field.as(); CHECK(tensor_sinfo) << "CallTIR sinfo_args are expected to be TensorStructInfo or Tuple of " @@ -754,11 +756,11 @@ class FusedTIRConstructor : public ExprVisitor { } /*! \brief Map old TIR func param buffer to new buffer, and then update `buffer_subst_map` */ - void MapArgsToBuffer(const Array args, const Array& buffers) { + void MapArgsToBuffer(const ffi::Array args, const ffi::Array& buffers) { size_t buffer_idx = 0; for (const Expr& arg : args) { if (const auto* v = arg.as()) { - auto it = func_info_.expr2buffers.find(GetRef(v)); + auto it = func_info_.expr2buffers.find(ffi::GetRef(v)); // Substitute the buffer with the already allocated one if it is an intermediate var if (it != func_info_.expr2buffers.end()) { for (const tir::Buffer& target_buffer : (*it).second) { @@ -781,8 +783,8 @@ class FusedTIRConstructor : public ExprVisitor { * \param output_size The number of output params. All output params are at the end of the param list. 
*/ void MapInputBuffer(const tir::PrimFunc& func, const relax::Expr& args) { - Array arg_list; - Array buffer_list; + ffi::Array arg_list; + ffi::Array buffer_list; if (const auto* arg_tuple = args.as()) { arg_list = arg_tuple->fields; } else { @@ -799,14 +801,14 @@ class FusedTIRConstructor : public ExprVisitor { MapArgsToBuffer(arg_list, buffer_list); } - static Array GetPrimFuncOutputParams(const tir::PrimFunc& func, - const Array& output_indices) { + static ffi::Array GetPrimFuncOutputParams(const tir::PrimFunc& func, + const ffi::Array& output_indices) { size_t n = func->params.size(); int symbolic_var_index = -1; size_t output_size = output_indices.size(); ICHECK_GE(n, output_size); - Array ret; + ffi::Array ret; for (auto idx : output_indices) { int i = idx.IntValue(); const tir::Var& param = func->params[static_cast(i)]; @@ -835,15 +837,15 @@ class FusedTIRConstructor : public ExprVisitor { * \param output_shapes The shape of output params. */ void AllocateIntermediateBuffer(const CallNode* call, const tir::PrimFunc& func, - const Array>& output_shapes) { + const ffi::Array>& output_shapes) { bool is_inplace = (call->op == Op::Get("relax.call_tir_inplace")); size_t n = func->params.size(); int num_inputs = Downcast(call->args[1])->fields.size(); size_t output_size = output_shapes.size(); ICHECK_GE(n, output_size); - Array output_buffers; - Array output_idxs; + ffi::Array output_buffers; + ffi::Array output_idxs; if (is_inplace) { const auto* attrs = call->attrs.as(); CHECK(attrs) << "Must have CallTIRInplaceAttrs for an in-place call"; @@ -854,7 +856,7 @@ class FusedTIRConstructor : public ExprVisitor { } } - Array output_params = GetPrimFuncOutputParams(func, output_idxs); + ffi::Array output_params = GetPrimFuncOutputParams(func, output_idxs); auto input_buffers = func_info_.expr2buffers.Get(call->args[1]); for (size_t i = 0; i < output_size; ++i) { const tir::Var& param = output_params[i]; @@ -868,8 +870,8 @@ class FusedTIRConstructor : public ExprVisitor { } auto unify_name_hints = [this, &buffer]() { - String base_name = buffer->name; - String unique_name = base_name + "_intermediate"; + ffi::String base_name = buffer->name; + ffi::String unique_name = base_name + "_intermediate"; size_t unique_id = 0; std::unordered_set names; @@ -883,7 +885,7 @@ class FusedTIRConstructor : public ExprVisitor { return unique_name; }; // Update buffer with new symbolic shape according to the sinfo - auto n = make_object(*buffer.get()); + auto n = ffi::make_object(*buffer.get()); n->shape = output_shapes[i]; n->name = unify_name_hints(); tir::Buffer new_buffer(n); @@ -895,7 +897,7 @@ class FusedTIRConstructor : public ExprVisitor { func_info_.buffer_subst_map.Set(buffer, new_buffer); } // Update expr2buffers - func_info_.expr2buffers.Set(GetRef(call), output_buffers); + func_info_.expr2buffers.Set(ffi::GetRef(call), output_buffers); } /*! 
@@ -905,8 +907,8 @@ class FusedTIRConstructor : public ExprVisitor { * \param out The vector into which to collect the params/buffers */ static void CollectPrimFuncParams(const Var& relax_param, - std::vector>* out, - const Optional& tir_buffer_param) { + std::vector>* out, + const ffi::Optional& tir_buffer_param) { auto struct_info = GetStructInfo(relax_param); CHECK(!struct_info.as()) @@ -955,12 +957,12 @@ class FusedTIRConstructor : public ExprVisitor { * \return The fused TIR */ tir::PrimFunc ConstructFunc() { - Map attr_map; + ffi::Map attr_map; attr_map.Set(tir::attr::kNoAlias, true); tir::FuseTIRBufferSubstitutor subst(func_info_.buffer_subst_map, func_info_.symbolic_var_remap); ICHECK(func_info_.global_name != "fused"); // Remove output buffers from func_info_.alloc_buffers - Array alloc_buffers; + ffi::Array alloc_buffers; for (const tir::Buffer& buf : func_info_.alloc_buffers) { if (func_info_.output_buffers.count(buf.get()) == 0) { alloc_buffers.push_back(subst.SubstituteAllocatedBuffer(buf)); @@ -998,25 +1000,25 @@ class FusedTIRConstructor : public ExprVisitor { /*! \brief auxiliary information for FuseTIR */ struct FuseFuncInfo { /*! \brief The arguments for calling prim_func */ - Array arguments; + ffi::Array arguments; /*! * \brief The map from each dataflow var (intermediate var) to the corresponding buffers * allocated in the fused func */ - Map> expr2buffers; + ffi::Map> expr2buffers; /*! \brief The buffers to allocate in the fused func*/ - Array alloc_buffers; + ffi::Array alloc_buffers; /*! \brief The bodies of the original funcs, which is also the body of the fused func. */ - Array bodies; + ffi::Array bodies; /*! \brief The params of the fused function*/ - Array params; + ffi::Array params; /*! * \brief The map from buffer in original functions to corresponding buffer in the fused * function */ - Map buffer_subst_map; + ffi::Map buffer_subst_map; /*! \brief The `buffer_map` in the fused function*/ - Map buffer_map; + ffi::Map buffer_map; /*! \brief The output buffers in the function buffer_map*/ std::unordered_set output_buffers; /*! \brief The name of the fused function */ @@ -1028,7 +1030,7 @@ class FusedTIRConstructor : public ExprVisitor { * `symbolic_var_matcher`, and must be before it in the struct * order. */ - Map symbolic_var_remap; + ffi::Map symbolic_var_remap; /*! \brief The map from symbolic var to its value in the fused function * @@ -1046,7 +1048,7 @@ class FusedTIRConstructor : public ExprVisitor { /*! \brief The IRModule */ const IRModule& mod_; /*! \brief The name hint for the input func. */ - String func_name_; + ffi::String func_name_; /*! \brief The helper info to fuse TIR prim_func */ FuseFuncInfo func_info_; /*! 
\brief The tir function after fusion*/ @@ -1075,7 +1077,7 @@ class TIRFuseMutator : public ExprMutator { public: static IRModule Transform(IRModule mod) { // Collect all primitive relax functions - Map primitive_relax; + ffi::Map primitive_relax; for (const auto& gvar : mod->GetGlobalVars()) { const auto& base_func = mod->Lookup(gvar); // Only fuse primitive relax functions @@ -1134,7 +1136,7 @@ class TIRFuseMutator : public ExprMutator { struct Replacement { GlobalVar fused_tir_gvar; Function original_function; - Array inplace_indices; + ffi::Array inplace_indices; }; explicit TIRFuseMutator(std::unordered_map replacements) @@ -1145,14 +1147,14 @@ class TIRFuseMutator : public ExprMutator { // Get shape from call tir static Expr GetCallTIRShape(StructInfo sinfo) { if (auto* tuple = sinfo.as()) { - Array fields = tuple->fields.Map([&](StructInfo x) { return GetCallTIRShape(x); }); + ffi::Array fields = tuple->fields.Map([&](StructInfo x) { return GetCallTIRShape(x); }); return Tuple(fields); } else { auto* tensor = sinfo.as(); ICHECK(tensor) << "FuseTIR can only take tensor or tuple type"; auto* shape_expr = tensor->shape.as(); ICHECK(shape_expr) << "FuseTIR requires all intermediate values have shape"; - return GetRef(shape_expr); + return ffi::GetRef(shape_expr); } } @@ -1185,8 +1187,8 @@ class TIRFuseMutator : public ExprMutator { // Step a. Collect all relax/symbolic arguments. Tuple arguments // are not supported by PrimFunc, so this step verifies that // ExpandTupleArguments has already removed them. - Array arg_list; - Array tir_vars; + ffi::Array arg_list; + ffi::Array tir_vars; for (size_t i = 0; i < call->args.size(); ++i) { auto arg = call->args[i]; auto sinfo = GetStructInfo(arg); @@ -1221,7 +1223,7 @@ class TIRFuseMutator : public ExprMutator { } // Step b. Create call_tir or call_tir_inplace - Array call_args = {fused_tir_gv, Tuple(arg_list)}; + ffi::Array call_args = {fused_tir_gv, Tuple(arg_list)}; if (!tir_vars.empty()) { call_args.push_back(ShapeExpr(tir_vars)); } @@ -1229,7 +1231,7 @@ class TIRFuseMutator : public ExprMutator { Attrs call_attrs = call->attrs; if (replacement.inplace_indices.size()) { call_op = call_tir_inplace_op_; - auto inplace_attrs = make_object(); + auto inplace_attrs = ffi::make_object(); inplace_attrs->inplace_indices = replacement.inplace_indices; call_attrs = Attrs(inplace_attrs); } diff --git a/src/relax/transform/gradient.cc b/src/relax/transform/gradient.cc index ff14dc9eef1e..e4af204d323f 100644 --- a/src/relax/transform/gradient.cc +++ b/src/relax/transform/gradient.cc @@ -160,7 +160,7 @@ class CheckpointCollector : private ExprMutator { ICHECK(var) << "The first argument of relax.grad.start_checkpoint and " "relax.grad.end_checkpoint should be a Var"; // var might already be remapped. Find the original var - auto orig_var = Downcast(ExprMutator::VisitExpr(GetRef(var))); + auto orig_var = Downcast(ExprMutator::VisitExpr(ffi::GetRef(var))); // Add remapping from binding->var to new_var if (!binding->var.as() && var->IsInstance()) { // For output binding, emit a dummy binding @@ -203,7 +203,7 @@ class CheckpointGenerator : private ExprMutator { * \param checkpoints The checkpointed vars. 
checkpoints being empty means all Vars are * checkpointed */ - CheckpointGenerator(const BlockBuilder& builder, const Array& orig_params, + CheckpointGenerator(const BlockBuilder& builder, const ffi::Array& orig_params, const DataflowBlock& forward_block, const VarIdSet& checkpoints) : builder_(builder) { // func params will always be checkpointed @@ -238,10 +238,10 @@ class CheckpointGenerator : private ExprMutator { using ExprMutator::VisitExpr_; // Visit the use-site of a defined Var - Expr VisitExpr_(const VarNode* op) final { return VisitVar(GetRef(op)); } + Expr VisitExpr_(const VarNode* op) final { return VisitVar(ffi::GetRef(op)); } // Visit the use-site of a defined DataflowVar - Expr VisitExpr_(const DataflowVarNode* op) final { return VisitVar(GetRef(op)); } + Expr VisitExpr_(const DataflowVarNode* op) final { return VisitVar(ffi::GetRef(op)); } Expr VisitVar(const Var& var) { auto it = checkpoint_map_.find(var); @@ -258,7 +258,7 @@ class CheckpointGenerator : private ExprMutator { Expr VisitExpr_(const CallNode* call_node) final { Expr new_op = this->VisitExpr(call_node->op); - tvm::Array call_args; + tvm::ffi::Array call_args; for (Expr arg : call_node->args) { Expr new_arg = this->VisitExpr(arg); call_args.push_back(new_arg); @@ -268,9 +268,9 @@ class CheckpointGenerator : private ExprMutator { BlockBuilder builder_; // The mapping from the forward vars to the checkpoint vars. - Map checkpoint_map_; + ffi::Map checkpoint_map_; // The mapping from the forward vars to their bindings, used to generate checkpoint bindings - Map binding_map_; + ffi::Map binding_map_; }; /*! @@ -294,8 +294,8 @@ class BackwardBindingGenerator : private ExprVisitor { * \return The return expr of new adjoint function. */ static Expr Generate(const BlockBuilder& builder, const DataflowBlock& forward_block, - const Array& require_grads, const Var& target_var, - const Array& orig_params, const Expr& orig_return_value, + const ffi::Array& require_grads, const Var& target_var, + const ffi::Array& orig_params, const Expr& orig_return_value, const CheckpointCollector& cp_collector) { CheckpointGenerator checkpoint_generator(builder, orig_params, forward_block, cp_collector.checkpoints); @@ -358,7 +358,7 @@ class BackwardBindingGenerator : private ExprVisitor { // Support for checkpointing auto [checkpoint_var, checkpoint_call] = - checkpoint_generator_.UpdateBinding(binding->var, GetRef(call)); + checkpoint_generator_.UpdateBinding(binding->var, ffi::GetRef(call)); if (call_op == Op::Get("relax.call_tir")) { LOG(FATAL) << "Differentiation of call_tir op without registering corresponding gradient " @@ -384,7 +384,7 @@ class BackwardBindingGenerator : private ExprVisitor { } } } else { - const Array& partials = gradient_op_map[call_op]( + const ffi::Array& partials = gradient_op_map[call_op]( checkpoint_var, Downcast(checkpoint_call), adjoint_var, builder_); ICHECK(partials.size() == call->args.size()) << "partials number != inputs number"; for (size_t i = 0; i < partials.size(); ++i) { @@ -406,7 +406,7 @@ class BackwardBindingGenerator : private ExprVisitor { // b_adjoint += a_adjoint_var[0][0], c_adjoint += a_adjoint_var[0][1], // d_adjoint += a_adjoint_var[1] void VisitBinding_(const VarBindingNode* binding, const TupleNode* tuple) final { - UpdateAdjoint(GetRef(tuple), adjoint_var_map_[binding->var]); + UpdateAdjoint(ffi::GetRef(tuple), adjoint_var_map_[binding->var]); } // For TupleGetItem nodes, we do a partial update @@ -422,7 +422,7 @@ class BackwardBindingGenerator : private ExprVisitor { const Var& 
tuple_var = Downcast(tuple_get_item->tuple); if (adjoint_var_map_.count(tuple_var) == 0) { - auto nested_zeros = Downcast(NestedZeros(GetRef(tuple_sinfo))); + auto nested_zeros = Downcast(NestedZeros(ffi::GetRef(tuple_sinfo))); auto tuple_fields = nested_zeros->fields; tuple_fields.Set(tuple_get_item->index, adjoint_var_map_[binding->var]); EmitAdjoint(tuple_var, Tuple(tuple_fields), false); @@ -435,11 +435,11 @@ class BackwardBindingGenerator : private ExprVisitor { // For assign nodes, we add the adjoint of output to the adjoint of input void VisitBinding_(const VarBindingNode* binding, const DataflowVarNode* var) final { - UpdateAdjoint(GetRef(var), adjoint_var_map_[binding->var]); + UpdateAdjoint(ffi::GetRef(var), adjoint_var_map_[binding->var]); } void VisitBinding_(const VarBindingNode* binding, const VarNode* var) final { - UpdateAdjoint(GetRef(var), adjoint_var_map_[binding->var]); + UpdateAdjoint(ffi::GetRef(var), adjoint_var_map_[binding->var]); } // For constant nodes, we do not have to handle it because it does not contribute to the adjoint @@ -479,9 +479,9 @@ class BackwardBindingGenerator : private ExprVisitor { // Returns the new return value, which would be like: // Tuple(original_return_value, // Tuple(adjoint_of_require_grads_1, adjoint_of_require_grads_2, ...)) - Expr Epilogue(const Array& require_grads, const Expr& orig_return_value) { + Expr Epilogue(const ffi::Array& require_grads, const Expr& orig_return_value) { // create adjoint variables for inputs, and then bind adjoints - Array out_adjoints; + ffi::Array out_adjoints; for (Var var : require_grads) { // var might be wrapped in start_checkpoint or end_checkpoint, so we should find the original @@ -520,7 +520,7 @@ class BackwardBindingGenerator : private ExprVisitor { } static Expr AdjointMsgToExpr(AdjointMsg msg) { - return NestedMsgToExpr(msg, [](Optional leaf_expr) { + return NestedMsgToExpr(msg, [](ffi::Optional leaf_expr) { if (!leaf_expr.defined()) { LOG(FATAL) << "Null should not exist in AdjointMsg."; } @@ -559,7 +559,7 @@ class BackwardBindingGenerator : private ExprVisitor { ICHECK(GetStructInfoAs(r_leaf)) << "The leaf of adjoint should have StructInfo and be a Tensor."; Expr res = add(l_leaf, r_leaf); - UpdateStructInfo(res, GetRef(sinfo)); + UpdateStructInfo(res, ffi::GetRef(sinfo)); return res; }); return AdjointMsgToExpr(res); @@ -575,7 +575,7 @@ class BackwardBindingGenerator : private ExprVisitor { auto* sinfo = GetStructInfoAs(tuple); ICHECK(sinfo) << "The first argument of AddInTuple should have tuple struct info."; ICHECK(index >= 0 && index < static_cast(sinfo->fields.size())); - Array res; + ffi::Array res; for (size_t i = 0; i < sinfo->fields.size(); ++i) { Expr field; if (const auto* expr_tuple = tuple.as()) { @@ -594,7 +594,7 @@ class BackwardBindingGenerator : private ExprVisitor { // The block builder of the corresponding GradientMutator, to emit bindings BlockBuilder builder_; // Forward Var to its adjoint Var - Map adjoint_var_map_; + ffi::Map adjoint_var_map_; // information collected by CheckpointCollector CheckpointCollector cp_collector_; // The generator for checkpoint bindings @@ -603,13 +603,13 @@ class BackwardBindingGenerator : private ExprVisitor { class GradientMutator : private ExprMutator { public: - static IRModule Transform(IRModule mod, String func_name, Optional> require_grads, - int target_index) { + static IRModule Transform(IRModule mod, ffi::String func_name, + ffi::Optional> require_grads, int target_index) { // Step 1. 
Copy function auto* old_func = mod->Lookup(func_name).as(); CHECK(old_func) << func_name << " is not a Relax Function"; auto copier = FunctionCopier(); - auto new_func = copier.Copy(GetRef(old_func)); + auto new_func = copier.Copy(ffi::GetRef(old_func)); // Step 2. Handle the checkpoints and eliminate start_checkpoint and end_checkpoint ops auto cp_collector = CheckpointCollector(); @@ -630,7 +630,7 @@ class GradientMutator : private ExprMutator { } private: - GradientMutator(const IRModule& module, const Array& require_grads, int target_index, + GradientMutator(const IRModule& module, const ffi::Array& require_grads, int target_index, const CheckpointCollector& cp_collector) : ExprMutator(module), require_grads_(require_grads), @@ -638,7 +638,7 @@ class GradientMutator : private ExprMutator { target_index_(target_index) {} // Add the adjoint function of func to the IRModule using BlockBuilder - IRModule AddAdjointFunction(const Function& func, const String& func_name, + IRModule AddAdjointFunction(const Function& func, const ffi::String& func_name, bool remove_all_unused = true) { // Step 4.1 forward -> forward + backward auto new_func = Downcast(VisitExpr(func)); @@ -695,7 +695,7 @@ class GradientMutator : private ExprMutator { } // generate backward bindings and the return value - return_expr_ = BackwardBindingGenerator::Generate(builder_, GetRef(block), + return_expr_ = BackwardBindingGenerator::Generate(builder_, ffi::GetRef(block), require_grads_, target_var_, orig_params_, orig_return_expr_, cp_collector_); @@ -715,7 +715,7 @@ class GradientMutator : private ExprMutator { CHECK_EQ(target_index, 0) << "When the function has only one return value, target_index can " "only be 0. But the target_index specified is " << target_index; - target_var_ = GetRef(var); + target_var_ = ffi::GetRef(var); } else if (auto* tuple = e.as()) { CHECK(target_index >= 0 && target_index < static_cast(tuple->fields.size())) << "target_index should be in the range of the number of return values of the " @@ -725,7 +725,7 @@ class GradientMutator : private ExprMutator { auto* var = tuple->fields[target_index].as(); CHECK(var) << "Target must be a Var, but the specified target is " << tuple->fields[target_index]; - target_var_ = GetRef(var); + target_var_ = ffi::GetRef(var); } else { LOG(FATAL) << "The return value of the function must be Var or Tuple. However, the return " "value of the given function is " @@ -742,10 +742,11 @@ class GradientMutator : private ExprMutator { // 1. there should be no duplicate var // 2. every var should be a parameter or an intermediate var in the function // 3. 
the type of the input var should be Tensor of floating point dtype, or Tuple of that - static Array CheckAndMapRequireGrads(const Array& require_grads, - const Map& var_map, const String& func_name) { + static ffi::Array CheckAndMapRequireGrads(const ffi::Array& require_grads, + const ffi::Map& var_map, + const ffi::String& func_name) { VarIdSet var_set; - Array mapped_vars; + ffi::Array mapped_vars; for (const auto& var : require_grads) { auto it = var_map.find(var); CHECK(it != var_map.end()) << "There is no Var named " << var->name_hint() @@ -764,21 +765,22 @@ class GradientMutator : private ExprMutator { } // differentiation sources - Array require_grads_; + ffi::Array require_grads_; // information collected by CheckpointCollector CheckpointCollector cp_collector_; // the differentiation target int target_index_; Var target_var_; // the return value of the original function and the differentiated function - Array orig_params_; + ffi::Array orig_params_; Expr orig_return_expr_; Expr return_expr_; }; namespace transform { -Pass Gradient(String func_name, Optional> require_grads, int target_index) { +Pass Gradient(ffi::String func_name, ffi::Optional> require_grads, + int target_index) { auto pass_func = [=](IRModule mod, PassContext pc) { return relax::GradientMutator::Transform(mod, func_name, require_grads, target_index); }; diff --git a/src/relax/transform/gradient_simplifier.cc b/src/relax/transform/gradient_simplifier.cc index 966e8b7ad692..5388e3706542 100644 --- a/src/relax/transform/gradient_simplifier.cc +++ b/src/relax/transform/gradient_simplifier.cc @@ -112,7 +112,7 @@ class GradientSimplifier : private ExprMutator { if (ndim == 1) { return expr; } - auto axes = Array(); + auto axes = ffi::Array(); for (int i = 0; i < ndim - 2; ++i) { axes.push_back(i); } @@ -140,7 +140,7 @@ class GradientSimplifier : private ExprMutator { } void VisitBinding_(const VarBindingNode* binding, const CallNode* call_node) { - auto result = ExprMutator::VisitExpr(GetRef(call_node)); + auto result = ExprMutator::VisitExpr(ffi::GetRef(call_node)); auto new_call_node = result.as(); auto reemit_and_return = [&]() { ReEmitBinding(binding, result); diff --git a/src/relax/transform/infer_amp_utils.cc b/src/relax/transform/infer_amp_utils.cc index 43bb40b4df4a..ac838d584821 100644 --- a/src/relax/transform/infer_amp_utils.cc +++ b/src/relax/transform/infer_amp_utils.cc @@ -31,33 +31,33 @@ NType NTypeFrom(const StructInfo& sinfo, DataType dtype) { else return NType(DLDataTypeToString(dtype)); }; - return MapToNestedMsg(sinfo, fmapleaf); + return MapToNestedMsg(sinfo, fmapleaf); } NType NTypeFrom(const Expr& expr, DataType dtype) { return NTypeFrom(GetStructInfo(expr), dtype); } NType NTypeMerge(const NType& a, const NType& b) { - auto fcombine = [&](const String& a_str, const String& b_str) -> String { + auto fcombine = [&](const ffi::String& a_str, const ffi::String& b_str) -> ffi::String { if (a_str == "") { return b_str; } else if (b_str == "") { return a_str; } - DataType a = DataType(StringToDLDataType(a_str)); - DataType b = DataType(StringToDLDataType(b_str)); + DataType a = DataType(ffi::StringToDLDataType(a_str)); + DataType b = DataType(ffi::StringToDLDataType(b_str)); ICHECK_EQ(a.code(), b.code()); ICHECK_EQ(a.lanes(), b.lanes()); return a.bits() > b.bits() ? 
a_str : b_str; }; - return CombineNestedMsg(a, b, fcombine); + return CombineNestedMsg(a, b, fcombine); } -Array InferMixedPrecisionFollow(const Call& call, const DataType& out_dtype) { +ffi::Array InferMixedPrecisionFollow(const Call& call, const DataType& out_dtype) { return {Integer(MixedPrecisionPolicyKind::kFollow), call}; } -Array InferMixedPrecisionNever(const Call& call, const DataType& out_dtype) { +ffi::Array InferMixedPrecisionNever(const Call& call, const DataType& out_dtype) { return {Integer(MixedPrecisionPolicyKind::kNever), call}; } diff --git a/src/relax/transform/infer_amp_utils.h b/src/relax/transform/infer_amp_utils.h index a3a86dd2e0c3..e8ac586036a8 100644 --- a/src/relax/transform/infer_amp_utils.h +++ b/src/relax/transform/infer_amp_utils.h @@ -49,11 +49,11 @@ using TMixedPrecisionPolicy = int; // NType is the message we want to track for vars with nested tensorstructinfo // which represents the realization decision of the var. // The string is the name of the dtype decision. -using NType = NestedMsg; +using NType = NestedMsg; struct NTypeEqual { bool operator()(const NType& a, const NType& b) const { - auto dtype_equal = [](const String& a, const String& b) { return a == b; }; + auto dtype_equal = [](const ffi::String& a, const ffi::String& b) { return a == b; }; return Equal(a, b, dtype_equal); } }; @@ -74,9 +74,9 @@ using VarDTypeMap = std::unordered_map; using FInferMixedPrecision = ffi::TypedFunction; -Array InferMixedPrecisionFollow(const Call& call, const DataType& out_dtype); +ffi::Array InferMixedPrecisionFollow(const Call& call, const DataType& out_dtype); -Array InferMixedPrecisionNever(const Call& call, const DataType& out_dtype); +ffi::Array InferMixedPrecisionNever(const Call& call, const DataType& out_dtype); } // namespace relax } // namespace tvm diff --git a/src/relax/transform/infer_layout_utils.cc b/src/relax/transform/infer_layout_utils.cc index b2f647c5c229..ea0bd2474913 100644 --- a/src/relax/transform/infer_layout_utils.cc +++ b/src/relax/transform/infer_layout_utils.cc @@ -67,7 +67,7 @@ Layout TransposeLike(const Layout& input, const Layout& src, const Layout& dst) return Layout(axes); } -String TransposeStrLike(const String& input, const Layout& src, const Layout& dst) { +ffi::String TransposeStrLike(const ffi::String& input, const Layout& src, const Layout& dst) { ICHECK(src.ndim() == dst.ndim() && input.size() == src.ndim()) << "Layouts must have the same size"; std::string axes; @@ -120,7 +120,7 @@ LayoutDecision GetLayoutDecision(const VarLayoutMap& var_layout_map, const Expr& NLayout GetNLayout(const VarLayoutMap& var_layout_map, const Expr& arg) { auto fmapleaf = [&](const Expr& expr) -> NLayout { if (const auto* var = expr.as()) { - auto it = var_layout_map.find(GetRef(var)); + auto it = var_layout_map.find(ffi::GetRef(var)); if (it != var_layout_map.end()) { return (*it).second; } else { @@ -134,7 +134,8 @@ NLayout GetNLayout(const VarLayoutMap& var_layout_map, const Expr& arg) { return MapToNestedMsg(arg, fmapleaf); } -bool NoDesiredLayout(const Call& call, const Map>& desired_layouts) { +bool NoDesiredLayout(const Call& call, + const ffi::Map>& desired_layouts) { const OpNode* op_node = call->op.as(); if (op_node == nullptr) return false; const auto& it = desired_layouts.find(op_node->name); diff --git a/src/relax/transform/infer_layout_utils.h b/src/relax/transform/infer_layout_utils.h index 69148ce0601f..91590b76ef1f 100644 --- a/src/relax/transform/infer_layout_utils.h +++ b/src/relax/transform/infer_layout_utils.h @@ -77,7 
+77,7 @@ class LayoutDecisionNode : public Object { class LayoutDecision : public ObjectRef { public: LayoutDecision(Layout layout, bool is_unknown_dim = false) { // NOLINT(*) - auto n = make_object(); + auto n = ffi::make_object(); n->layout = std::move(layout); n->is_unknown_dim = is_unknown_dim; data_ = n; @@ -105,10 +105,10 @@ using NLayout = NestedMsg; */ class InferLayoutOutputNode : public Object { public: - Array input_layouts; - Array output_layouts; + ffi::Array input_layouts; + ffi::Array output_layouts; Attrs new_attrs; - Map new_args; + ffi::Map new_args; static void RegisterReflection() { namespace refl = tvm::ffi::reflection; @@ -126,9 +126,9 @@ class InferLayoutOutputNode : public Object { class InferLayoutOutput : public ObjectRef { public: - explicit InferLayoutOutput(Array input_layouts, Array output_layouts, - Attrs new_attrs, Map new_args = {}) { - auto n = make_object(); + explicit InferLayoutOutput(ffi::Array input_layouts, ffi::Array output_layouts, + Attrs new_attrs, ffi::Map new_args = {}) { + auto n = ffi::make_object(); n->input_layouts = std::move(input_layouts); n->output_layouts = std::move(output_layouts); n->new_attrs = std::move(new_attrs); @@ -150,7 +150,7 @@ struct NLayoutEqual { } }; -using VarLayoutMap = Map; +using VarLayoutMap = ffi::Map; /*! * \brief Layout conversion interface. @@ -159,7 +159,7 @@ using VarLayoutMap = Map; * \param var_layout_map The layout of the variables. */ using FRelaxInferLayout = ffi::TypedFunction>& desired_layouts, + const Call& call, const ffi::Map>& desired_layouts, const VarLayoutMap& var_layout_map)>; /*! @@ -225,7 +225,7 @@ Layout TransposeLike(const Layout& input, const Layout& src, const Layout& dst); * \param dst The destination layout. * \return The transposed input str. */ -String TransposeStrLike(const String& input, const Layout& src, const Layout& dst); +ffi::String TransposeStrLike(const ffi::String& input, const Layout& src, const Layout& dst); /*! * \brief Find axis in the dst layout. 0 represents the first axis, 1 represents the second axis, @@ -258,7 +258,8 @@ NLayout GetNLayout(const VarLayoutMap& var_layout_map, const Expr& arg); * \param desired_layouts The desired layouts of the operator. * \return True if the op is not in the desired layout. */ -bool NoDesiredLayout(const Call& call, const Map>& desired_layouts); +bool NoDesiredLayout(const Call& call, + const ffi::Map>& desired_layouts); /*! * \brief Let a tensor with ndim follow the src layout decision. 
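The `LayoutDecision` and `InferLayoutOutput` constructors above show the mechanical shape of the whole series: object payloads are allocated through `ffi::make_object`, and container fields are spelled with their `tvm::ffi::` qualification now that the unqualified aliases (`Map`, `Array`, `String`, `Optional`, `Variant`, `GetRef`, `make_object`) are no longer re-exported into the `tvm` namespace. A minimal sketch of the post-migration idiom, using a hypothetical `ExampleNode`/`Example` pair that is not part of this patch (type-registration macros and reflection boilerplate elided):

    // Hypothetical node, following the same convention as InferLayoutOutputNode.
    class ExampleNode : public Object {
     public:
      ffi::String name;                         // previously: String
      ffi::Array<Layout> layouts;               // previously: Array<Layout>
      ffi::Map<ffi::String, Layout> overrides;  // previously: Map<String, Layout>
    };

    class Example : public ObjectRef {
     public:
      explicit Example(ffi::String name, ffi::Array<Layout> layouts) {
        // ffi::make_object replaces the formerly re-exported make_object.
        auto n = ffi::make_object<ExampleNode>();
        n->name = std::move(name);
        n->layouts = std::move(layouts);
        data_ = n;
      }
    };

Because the change is purely a matter of name lookup, no behavior changes: the same allocation and reference-counting machinery is used, only spelled through its owning namespace.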
diff --git a/src/relax/transform/inline_functions.cc b/src/relax/transform/inline_functions.cc index 44363e19464f..e2ab8c1b663c 100644 --- a/src/relax/transform/inline_functions.cc +++ b/src/relax/transform/inline_functions.cc @@ -35,7 +35,8 @@ namespace { class FunctionInliner : public ExprMutator { public: - explicit FunctionInliner(const Map, Function>& replacements) + explicit FunctionInliner( + const ffi::Map, Function>& replacements) : replacements_(replacements) {} using ExprMutator::VisitExpr_; @@ -80,7 +81,7 @@ class FunctionInliner : public ExprMutator { } private: - Optional GetFunction(const GlobalVar& gvar) const { + ffi::Optional GetFunction(const GlobalVar& gvar) const { if (auto opt = replacements_.Get(gvar)) { return opt; } else if (auto opt = replacements_.Get(gvar->name_hint)) { @@ -90,14 +91,14 @@ class FunctionInliner : public ExprMutator { } } - Expr InlinedCall(Function func, const Array& args) const { + Expr InlinedCall(Function func, const ffi::Array& args) const { // Ensures that the inlined instance does not have duplicate usage // with other inlined copies, or with the original callee. func = CopyWithNewVars(std::move(func)); - Array param_bindings; + ffi::Array param_bindings; - Map param_map; + ffi::Map param_map; for (size_t i = 0; i < args.size(); i++) { // Option 1: Use tvm::relax::Bind to substitute arguments into // the body. If the arguments contain DataflowVar instances, @@ -138,7 +139,7 @@ class FunctionInliner : public ExprMutator { return SeqExpr({binding_block}, body); } - const Map, Function>& replacements_; + const ffi::Map, Function>& replacements_; std::unordered_set inline_stack_; }; } // namespace @@ -149,8 +150,8 @@ class FunctionInliner : public ExprMutator { * \param replacements The functions to inline, keyed by GlobalVar or by function name * \return The function with matching call sites inlined */ -Function FunctionInlineFunctions(Function func, - const Map, Function>& replacements) { +Function FunctionInlineFunctions( + Function func, const ffi::Map, Function>& replacements) { for (const auto& [key, func] : replacements) { if (auto ptr = key.as()) { CHECK(!replacements.count(ptr->name_hint)) @@ -174,11 +175,11 @@ namespace transform { Pass InlinePrivateFunctions() { auto pass_func = [=](IRModule mod, PassContext pc) { - Map, Function> replacements; + ffi::Map, Function> replacements; for (const auto& [gvar, base_func] : mod->functions) { if (auto opt = base_func.as()) { auto func = opt.value(); - bool is_private = !func->GetAttr(tvm::attr::kGlobalSymbol).has_value(); + bool is_private = !func->GetAttr(tvm::attr::kGlobalSymbol).has_value(); if (is_private) { replacements.Set(gvar, func); } diff --git a/src/relax/transform/kill_after_last_use.cc b/src/relax/transform/kill_after_last_use.cc index 8c3b76703d8e..7b6e8e502214 100644 --- a/src/relax/transform/kill_after_last_use.cc +++ b/src/relax/transform/kill_after_last_use.cc @@ -169,7 +169,8 @@ class CollectLastUsage : public ExprVisitor { << "Operator " << val->op << " should have one argument, " << "but instead found " << val->args.size() << " arguments: " << val->args; auto killed_object = val->args[0].as(); - ICHECK(killed_object) << "Internal error: non-normalized expression " << GetRef(val); + ICHECK(killed_object) << "Internal error: non-normalized expression " + << ffi::GetRef(val); killed_objects_.insert(killed_object); } else { // Only recursively visit if it isn't one of the special cases. 
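The `CollectLastUsage` hunk above is the other recurring rewrite in this series: every `GetRef` call site gains the `ffi::` qualifier, since the `using ffi::GetRef;` re-export was removed from `ffi/include/tvm/ffi/cast.h`. A minimal sketch of the call-site pattern, using an illustrative visitor that is not part of this patch:

    #include <tvm/relax/expr_functor.h>

    namespace tvm {
    namespace relax {

    // Hypothetical visitor showing the qualified GetRef idiom.
    class ExampleVisitor : public ExprVisitor {
      void VisitExpr_(const CallNode* op) final {
        // The visitor dispatch hands out an unowned raw pointer;
        // ffi::GetRef rewraps it into a reference-counted Call handle.
        Call call = ffi::GetRef<Call>(op);
        ICHECK(call.defined());
        ExprVisitor::VisitExpr_(op);
      }
    };

    }  // namespace relax
    }  // namespace tvm

Since code in `tvm::relax` already sits inside the `tvm` namespace, the short spelling `ffi::GetRef` resolves to `tvm::ffi::GetRef`, which is why the hunks can add the qualifier without touching includes.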
@@ -213,14 +214,14 @@ class CollectLastUsage : public ExprVisitor { class KillInserter : public ExprMutator { private: Expr VisitExpr_(const FunctionNode* op) override { - last_usage_ = CollectLastUsage::Collect(GetRef(op)); + last_usage_ = CollectLastUsage::Collect(ffi::GetRef(op)); auto mutated = ExprMutator::VisitExpr_(op); last_usage_.clear(); return mutated; } Expr VisitExpr_(const SeqExprNode* op) override { - last_usage_ = CollectLastUsage::Collect(GetRef(op)); + last_usage_ = CollectLastUsage::Collect(ffi::GetRef(op)); auto mutated = ExprMutator::VisitExpr_(op); last_usage_.clear(); return mutated; @@ -231,17 +232,17 @@ class KillInserter : public ExprMutator { if (auto it = last_usage_.find(binding->var.get()); it != last_usage_.end()) { static const Op& mem_kill_tensor = Op::Get("relax.memory.kill_tensor"); for (const auto& tensor_obj : it->second.tensors) { - builder_->Emit(Call(mem_kill_tensor, {GetRef(tensor_obj)}), /*name_hint=*/"_"); + builder_->Emit(Call(mem_kill_tensor, {ffi::GetRef(tensor_obj)}), /*name_hint=*/"_"); } static const Op& mem_kill_storage = Op::Get("relax.memory.kill_storage"); for (const VarNode* storage_obj : it->second.storage) { - builder_->Emit(Call(mem_kill_storage, {GetRef(storage_obj)}), /*name_hint=*/"_"); + builder_->Emit(Call(mem_kill_storage, {ffi::GetRef(storage_obj)}), /*name_hint=*/"_"); } static const Op& vm_kill_object = Op::Get("relax.vm.kill_object"); for (const VarNode* obj : it->second.objects) { - builder_->Emit(Call(vm_kill_object, {GetRef(obj)}), /*name_hint=*/"_"); + builder_->Emit(Call(vm_kill_object, {ffi::GetRef(obj)}), /*name_hint=*/"_"); } } } diff --git a/src/relax/transform/lambda_lift.cc b/src/relax/transform/lambda_lift.cc index 1fd82b1cc610..fe8d28964dd5 100644 --- a/src/relax/transform/lambda_lift.cc +++ b/src/relax/transform/lambda_lift.cc @@ -40,7 +40,7 @@ namespace { /* \brief Collect names of functions to be lifted out */ class LambdaNameCollector : ExprVisitor { public: - static std::unordered_map Collect(const IRModule& mod) { + static std::unordered_map Collect(const IRModule& mod) { LambdaNameCollector visitor; for (const auto& [gvar, base_func] : mod->functions) { @@ -60,8 +60,8 @@ class LambdaNameCollector : ExprVisitor { private: void VisitBinding_(const VarBindingNode* binding, const FunctionNode* func) override { - if (auto opt = func->GetAttr(tvm::attr::kGlobalSymbol)) { - String public_name = opt.value(); + if (auto opt = func->GetAttr(tvm::attr::kGlobalSymbol)) { + ffi::String public_name = opt.value(); // If a kGlobalSymbol exists, we must use the name exactly as it // appears, with no modifications. Because these errors would @@ -102,21 +102,22 @@ class LambdaNameCollector : ExprVisitor { } // De-duplication of collected names - std::unordered_map Finalize() const { + std::unordered_map Finalize() const { // The functions which still must be assigned a name - std::unordered_map> remaining_to_name = lambda_location_; + std::unordered_map> remaining_to_name = + lambda_location_; // Collecting the functions that now have a name. - std::unordered_map lifted_names; + std::unordered_map lifted_names; // A lookup for names that are unavailable for use. - std::unordered_set unavailable_names = previous_global_vars_; + std::unordered_set unavailable_names = previous_global_vars_; // A helper function to generate de-duplicated names. 
The // `proposed_name_generation_func` should be a function with // signature: // - // Optional func(const FunctionNode*, const Array&) + // ffi::Optional func(const FunctionNode*, const ffi::Array&) // // The first argument will be the lambda function being lifted. // The second argument will be the nested location where that @@ -135,9 +136,10 @@ class LambdaNameCollector : ExprVisitor { return; } - std::unordered_map new_names; + std::unordered_map new_names; for (const auto& [func, location] : remaining_to_name) { - if (Optional opt_proposed_name = proposed_name_generation_func(func, location)) { + if (ffi::Optional opt_proposed_name = + proposed_name_generation_func(func, location)) { auto proposed_name = opt_proposed_name.value(); if (unavailable_names.count(proposed_name)) { @@ -163,7 +165,8 @@ class LambdaNameCollector : ExprVisitor { }; // 1. Start with any publicly exposed names from kGlobalSymbol - attempt_name_generation([&](const FunctionNode* func, const auto&) -> Optional { + attempt_name_generation([&](const FunctionNode* func, + const auto&) -> ffi::Optional { if (auto it = lifted_with_global_symbol_.find(func); it != lifted_with_global_symbol_.end()) { return it->second; } else { @@ -173,7 +176,7 @@ class LambdaNameCollector : ExprVisitor { // 2. Try concatenating the name of the relax variable with the // name of the function that contains it. - attempt_name_generation([&](const FunctionNode*, const auto& location) -> String { + attempt_name_generation([&](const FunctionNode*, const auto& location) -> ffi::String { std::stringstream stream; stream << location.front() << "_" << location.back(); return stream.str(); @@ -181,26 +184,27 @@ class LambdaNameCollector : ExprVisitor { // 3. Try concatenating the entire path together. Don't include // paths of length 2, as they would already be attempted earlier. - attempt_name_generation([&](const FunctionNode*, const auto& location) -> Optional { - if (location.size() == 2) return std::nullopt; - - std::stringstream stream; - bool is_first = true; - for (const auto& loc : location) { - if (is_first) { - is_first = false; - } else { - stream << "_"; - } - stream << loc; - } - return String(stream.str()); - }); + attempt_name_generation( + [&](const FunctionNode*, const auto& location) -> ffi::Optional { + if (location.size() == 2) return std::nullopt; + + std::stringstream stream; + bool is_first = true; + for (const auto& loc : location) { + if (is_first) { + is_first = false; + } else { + stream << "_"; + } + stream << loc; + } + return ffi::String(stream.str()); + }); // 4. Fallback. Count the number of times a relax variable with // that name was used. 
- std::unordered_map usage_count; - attempt_name_generation([&](const FunctionNode*, const auto& location) -> String { + std::unordered_map usage_count; + attempt_name_generation([&](const FunctionNode*, const auto& location) -> ffi::String { std::stringstream stream; stream << location.front() << "_" << location.back(); int usage = usage_count[stream.str()]++; @@ -215,11 +219,11 @@ class LambdaNameCollector : ExprVisitor { return lifted_names; } - Array name_stack_; - std::unordered_set previous_global_vars_; - std::unordered_map> new_public_names_; - std::unordered_map lifted_with_global_symbol_; - std::unordered_map> lambda_location_; + ffi::Array name_stack_; + std::unordered_set previous_global_vars_; + std::unordered_map> new_public_names_; + std::unordered_map lifted_with_global_symbol_; + std::unordered_map> lambda_location_; }; } // namespace @@ -255,9 +259,9 @@ class LambdaLifter : public ExprMutator { return ExprMutator::VisitExpr_(func_node); } - auto func = GetRef(func_node); + auto func = ffi::GetRef(func_node); - String lift_func_name = [&]() { + ffi::String lift_func_name = [&]() { auto it = lifted_names_.find(func_node); ICHECK(it != lifted_names_.end()) << "InternalError: " @@ -266,7 +270,7 @@ class LambdaLifter : public ExprMutator { return it->second; }(); - Array captured_vars; + ffi::Array captured_vars; bool is_recursive = false; bool is_closure = false; for (const auto& var : FreeVars(func)) { @@ -278,15 +282,15 @@ class LambdaLifter : public ExprMutator { } } - Array typed_captured_vars; - Map rebinding_map; + ffi::Array typed_captured_vars; + ffi::Map rebinding_map; for (auto free_var : captured_vars) { Var var = Var(free_var->name_hint(), GetStructInfo(free_var), free_var->span); typed_captured_vars.push_back(var); rebinding_map.Set(free_var, var); } - tvm::Array lifted_func_params = + tvm::ffi::Array lifted_func_params = func_node->params.Map([this](Var var) { return VisitVarDef(var); }); for (const auto& var : typed_captured_vars) { lifted_func_params.push_back(var); @@ -323,7 +327,7 @@ class LambdaLifter : public ExprMutator { Function lifted_func; if (lifted_func_params.same_as(func_node->params) && body.same_as(func_node->body) && ret_struct_info.same_as(func_node->ret_struct_info)) { - lifted_func = GetRef(func_node); + lifted_func = ffi::GetRef(func_node); } else { lifted_func = Function(lifted_func_params, body, ret_struct_info, func_node->is_pure, func_node->attrs); @@ -354,7 +358,7 @@ class LambdaLifter : public ExprMutator { } Expr VisitExpr_(const CallNode* call_node) final { - auto call = GetRef(call_node); + auto call = ffi::GetRef(call_node); auto orig_sinfo = Downcast(call->struct_info_); @@ -393,7 +397,7 @@ class LambdaLifter : public ExprMutator { if (auto it = nested_closure_map_.find(var); it != nested_closure_map_.end()) { Call nested_call = it->second; - Array new_args = call->args; + ffi::Array new_args = call->args; for (const auto arg : nested_call->args) { new_args.push_back(arg); } @@ -407,7 +411,7 @@ class LambdaLifter : public ExprMutator { } Expr VisitExpr_(const VarNode* op) override { - auto var = GetRef(op); + auto var = ffi::GetRef(op); if (auto it = rebind_map_.find(var); it != rebind_map_.end()) { return it->second; } @@ -436,12 +440,12 @@ class LambdaLifter : public ExprMutator { } } else if (const auto* global_var = val.as()) { - if (closures_.count(GetRef(global_var))) { + if (closures_.count(ffi::GetRef(global_var))) { return true; } IRModule ctx_mod = builder_->GetContextIRModule(); ICHECK(ctx_mod->functions.size() > 0); - 
BaseFunc func = ctx_mod->Lookup(GetRef(global_var)); + BaseFunc func = ctx_mod->Lookup(ffi::GetRef(global_var)); const auto* func_node = func.as(); if (func_node) { return IsClosure(func_node->body); @@ -477,11 +481,11 @@ class LambdaLifter : public ExprMutator { private: std::unordered_map nested_closure_map_; std::unordered_map rebind_map_; - std::unordered_set, ObjectPtrHash, ObjectPtrEqual> closures_; - Optional current_lambda_var_ = std::nullopt; + std::unordered_set, ObjectPtrHash, ObjectPtrEqual> closures_; + ffi::Optional current_lambda_var_ = std::nullopt; IRModule mod_; - std::unordered_map lifted_names_; + std::unordered_map lifted_names_; /*! \brief Cache ops that would be used later to reduce lookup overhead. */ const Op& make_closure_op_ = Op::Get("relax.make_closure"); diff --git a/src/relax/transform/lazy_transform_params.cc b/src/relax/transform/lazy_transform_params.cc index 9b59b680eceb..61e36fae69bc 100644 --- a/src/relax/transform/lazy_transform_params.cc +++ b/src/relax/transform/lazy_transform_params.cc @@ -69,15 +69,15 @@ class LazyInputMutator : public ExprMutator { FuncStructInfo({PrimStructInfo(DataType::Int(64)), ObjectStructInfo()}, ObjectStructInfo())); - Array new_params(func->params.begin(), func->params.begin() + num_input_params); + ffi::Array new_params(func->params.begin(), func->params.begin() + num_input_params); new_params.push_back(fget_param); auto array_externally_visible_vars = DefinableTIRVarsInStructInfo(TupleStructInfo(new_params.Map(GetStructInfo))); std::unordered_set externally_visible_vars(array_externally_visible_vars.begin(), array_externally_visible_vars.end()); - StructInfo new_ret_struct_info = - EraseToWellDefined(func->ret_struct_info, [&](const tir::Var& var) -> Optional { + StructInfo new_ret_struct_info = EraseToWellDefined( + func->ret_struct_info, [&](const tir::Var& var) -> ffi::Optional { if (externally_visible_vars.count(var)) { return var; } else { @@ -85,7 +85,7 @@ class LazyInputMutator : public ExprMutator { } }); - auto node = GetRef(func); + auto node = ffi::GetRef(func); node.CopyOnWrite()->params = new_params; node.CopyOnWrite()->ret_struct_info = new_ret_struct_info; node = WithAttr(node, attr::kNumInput, num_input_params + 1); @@ -98,7 +98,7 @@ class LazyInputMutator : public ExprMutator { Expr VisitExpr_(const VarNode* op) override { if (plan_) { - Var var = GetRef(op); + Var var = ffi::GetRef(op); if (auto it = plan_->param_lookup.find(var); it != plan_->param_lookup.end()) { auto untyped = builder_->Emit(relax::Call(plan_->fget_param, @@ -148,9 +148,10 @@ class LazyOutputMutator : public ExprMutator { define_lookup(0, func_body->body); } - Var fset_output("fset_output", - FuncStructInfo({PrimStructInfo(DataType::Int(64)), ObjectStructInfo()}, - TupleStructInfo(Array{}), /* purity = */ false)); + Var fset_output( + "fset_output", + FuncStructInfo({PrimStructInfo(DataType::Int(64)), ObjectStructInfo()}, + TupleStructInfo(ffi::Array{}), /* purity = */ false)); plan_ = FunctionPlan{std::move(output_lookup), fset_output}; std::optional num_input_params = GetNumInputParams(func); @@ -160,32 +161,32 @@ class LazyOutputMutator : public ExprMutator { fset_output); BindingBlock start_of_func = [&]() { - Array propagated_params; + ffi::Array propagated_params; for (auto param : func->params) { GenerateSetOutputCalls(param, [&](const auto& fset_output_call) { - Var void_output("_void", TupleStructInfo(Array{})); + Var void_output("_void", TupleStructInfo(ffi::Array{})); propagated_params.push_back(VarBinding(void_output, 
fset_output_call)); }); } return BindingBlock(propagated_params); }(); BindingBlock end_of_func = [&]() { - Array propagated_params; + ffi::Array propagated_params; for (const auto& [output_index, expr] : inline_outputs) { Call fset_output_call(fset_output, {PrimValue(IntImm(DataType::Int(64), output_index)), expr}); - Var void_output("_void", TupleStructInfo(Array{})); + Var void_output("_void", TupleStructInfo(ffi::Array{})); propagated_params.push_back(VarBinding(void_output, fset_output_call)); } return BindingBlock(propagated_params); }(); - Array new_blocks = func_body->blocks; + ffi::Array new_blocks = func_body->blocks; new_blocks.insert(new_blocks.begin(), start_of_func); new_blocks.push_back(end_of_func); - Expr new_body = SeqExpr(new_blocks, Tuple(Array{})); + Expr new_body = SeqExpr(new_blocks, Tuple(ffi::Array{})); - auto node = GetRef(func); + auto node = ffi::GetRef(func); { auto write_ptr = node.CopyOnWrite(); write_ptr->params = new_params; @@ -249,7 +250,7 @@ namespace transform { Pass LazyGetInput() { auto pass_func = [](Function func, IRModule, PassContext) -> Function { - if (!func->GetAttr(tvm::attr::kGlobalSymbol).has_value()) { + if (!func->GetAttr(tvm::attr::kGlobalSymbol).has_value()) { return func; } return WithLazyInputs(func); @@ -267,7 +268,7 @@ TVM_FFI_STATIC_INIT_BLOCK({ Pass LazySetOutput() { auto pass_func = [](Function func, IRModule, PassContext) -> Function { - if (!func->GetAttr(tvm::attr::kGlobalSymbol).has_value()) { + if (!func->GetAttr(tvm::attr::kGlobalSymbol).has_value()) { return func; } return WithLazyOutputs(func); diff --git a/src/relax/transform/legalize_ops.cc b/src/relax/transform/legalize_ops.cc index 780de9f57029..c3544314a774 100644 --- a/src/relax/transform/legalize_ops.cc +++ b/src/relax/transform/legalize_ops.cc @@ -60,7 +60,8 @@ bool KnowAllShapeValues(const StructInfo& sinfo) { class LegalizeMutator : public ExprMutator { public: - explicit LegalizeMutator(const IRModule& mod, const Optional>& cmap, + explicit LegalizeMutator(const IRModule& mod, + const ffi::Optional>& cmap, bool enable_warning) : ExprMutator(mod), mod_(std::move(mod)), enable_warning_(enable_warning) { if (cmap) { @@ -130,14 +131,14 @@ class LegalizeMutator : public ExprMutator { Call WrapPureCall(const Call& ret) { static const Op& call_pure_packed_op = Op::Get("relax.call_pure_packed"); - Array ret_args = {ret->op}; + ffi::Array ret_args = {ret->op}; for (auto arg : ret->args) { ret_args.push_back(arg); } return Call(call_pure_packed_op, ret_args, ret->attrs, ret->sinfo_args); } - Optional GetTarget(const Array& sinfos) { + ffi::Optional GetTarget(const ffi::Array& sinfos) { for (auto sinfo : sinfos) { if (const auto* tinfo = sinfo.as()) { if (tinfo->vdevice.defined()) { @@ -236,7 +237,7 @@ class LegalizeMutator : public ExprMutator { if (op_node == nullptr) { return visited_call; } - auto op = GetRef(op_node); + auto op = ffi::GetRef(op_node); bool shapes_are_known_if_required = [&]() -> bool { bool requires_arg_shapes = requires_arg_shapes_map.get(op, Bool(true))->value; @@ -312,7 +313,7 @@ class LegalizeMutator : public ExprMutator { legalization_func = legalize_map[op]; } else if (call_packed_map.count(op)) { // Third choice, use an explicit FCallPacked replacement. 
This does not require the shape - String packed_func_name = call_packed_map[op]; + ffi::String packed_func_name = call_packed_map[op]; legalization_func = [packed_func_name](const BlockBuilder& bb, const Call& call) -> Expr { return Call(ExternFunc(packed_func_name), call->args, Attrs(), {GetStructInfo(call)}); }; @@ -378,7 +379,7 @@ class LegalizeMutator : public ExprMutator { /*! \brief The context IRModule. */ IRModule mod_; /*! \brief The customized legalization function map. */ - Map cmap_; + ffi::Map cmap_; /*! \brief If VDevice annotations produced at least one PrimFunc with a Target attr*/ bool generated_tir_with_target_attr_{false}; /*! @@ -390,7 +391,7 @@ class LegalizeMutator : public ExprMutator { namespace transform { -Pass LegalizeOps(Optional> cmap, bool enable_warning) { +Pass LegalizeOps(ffi::Optional> cmap, bool enable_warning) { auto pass_func = [=](IRModule mod, PassContext pc) { bool apply_legalize_ops = pc->GetConfig("relax.transform.apply_legalize_ops").value_or(Bool(true))->value; diff --git a/src/relax/transform/lift_transform_params.cc b/src/relax/transform/lift_transform_params.cc index 40a1c307cee5..16a50a19a3e3 100644 --- a/src/relax/transform/lift_transform_params.cc +++ b/src/relax/transform/lift_transform_params.cc @@ -64,20 +64,20 @@ struct BaseCollectInfo { * model weights, and computed tensors that require neither model * weights nor runtime arguments (e.g. `R.zeros([16], "float16")`). */ - std::unordered_set, ObjectPtrHash, ObjectPtrEqual> + std::unordered_set, ObjectPtrHash, ObjectPtrEqual> requires_compile_time_param; /*! \brief Variables that are required at runtime */ - std::unordered_set, ObjectPtrHash, ObjectPtrEqual> + std::unordered_set, ObjectPtrHash, ObjectPtrEqual> required_at_runtime; protected: - Array GetCompileTimeOutputsHelper(const Array& params) const { + ffi::Array GetCompileTimeOutputsHelper(const ffi::Array& params) const { // The output of the compile-time function is in the following order: // 1) Any parameter that is required at runtime in the original order, followed by, // 2) Any binding that is computable at compile-time and required at runtime in the original // order. - Array output; + ffi::Array output; for (const auto& param : params) { if (required_at_runtime.count(param)) { output.push_back(param); @@ -93,11 +93,12 @@ struct BaseCollectInfo { return output; } - Function MakeCompileTimeFunctionHelper(const Array params, const Array& bindings, - const Array& output_symbolic_vars, - const Array& outputs) const { - Array output_var_binding; - Array output_exprs; + Function MakeCompileTimeFunctionHelper(const ffi::Array params, + const ffi::Array& bindings, + const ffi::Array& output_symbolic_vars, + const ffi::Array& outputs) const { + ffi::Array output_var_binding; + ffi::Array output_exprs; if (output_symbolic_vars.size()) { output_exprs.push_back( ShapeExpr(output_symbolic_vars.Map([](tir::Var var) -> PrimExpr { return var; }))); @@ -131,14 +132,14 @@ struct BaseCollectInfo { struct GlobalCollectInfo : public BaseCollectInfo { // The original functions - Array orig_functions; + ffi::Array orig_functions; // The parameters of the compile-time function. - Array params; + ffi::Array params; // The cross-function mapping between variables. - Map var_remap; + ffi::Map var_remap; // The cross-function mapping between TIR variables. 
- Map tir_var_remap; - Array GetPropagatedSymbolicVariables() const { + ffi::Map tir_var_remap; + ffi::Array GetPropagatedSymbolicVariables() const { auto vars_from_original_params = DefinableTIRVarsInStructInfo(TupleStructInfo(params.Map(GetStructInfo))); auto vars_from_transformed_params = [&]() -> std::unordered_set { @@ -147,7 +148,7 @@ struct GlobalCollectInfo : public BaseCollectInfo { return {tir_vars.begin(), tir_vars.end()}; }(); - Array output; + ffi::Array output; for (const auto& tir_var : vars_from_original_params) { if (required_at_runtime.count(tir_var) && !vars_from_transformed_params.count(tir_var)) { output.push_back(tir_var); @@ -160,7 +161,7 @@ struct GlobalCollectInfo : public BaseCollectInfo { return MakeCompileTimeFunctionHelper(params, computable_at_compile_time, GetPropagatedSymbolicVariables(), GetCompileTimeOutputs()); } - Array GetCompileTimeOutputs() const { return GetCompileTimeOutputsHelper(params); } + ffi::Array GetCompileTimeOutputs() const { return GetCompileTimeOutputsHelper(params); } }; struct LocalCollectInfo : public BaseCollectInfo { /* \brief The analyzed function */ @@ -171,15 +172,16 @@ struct LocalCollectInfo : public BaseCollectInfo { GlobalCollectInfo* global_info = nullptr; - Array GetCompileTimeInputs() const { - return Array(orig_func->params.begin() + num_runtime_params, orig_func->params.end()); + ffi::Array GetCompileTimeInputs() const { + return ffi::Array(orig_func->params.begin() + num_runtime_params, orig_func->params.end()); } - Array GetRuntimeInputs() const { - return Array(orig_func->params.begin(), orig_func->params.begin() + num_runtime_params); + ffi::Array GetRuntimeInputs() const { + return ffi::Array(orig_func->params.begin(), + orig_func->params.begin() + num_runtime_params); } - Array GetPropagatedSymbolicVariables() const { + ffi::Array GetPropagatedSymbolicVariables() const { auto vars_from_any_param = DefinableTIRVarsInStructInfo(TupleStructInfo(orig_func->params.Map(GetStructInfo))); @@ -195,7 +197,7 @@ struct LocalCollectInfo : public BaseCollectInfo { return {tir_var_vec.begin(), tir_var_vec.end()}; }(); - Array output; + ffi::Array output; for (const auto& tir_var : vars_from_any_param) { if (required_at_runtime.count(tir_var) && !vars_from_runtime_params.count(tir_var) && !vars_from_transformed_params.count(tir_var)) { @@ -205,7 +207,7 @@ struct LocalCollectInfo : public BaseCollectInfo { return output; } - Array GetCompileTimeOutputs() const { + ffi::Array GetCompileTimeOutputs() const { return GetCompileTimeOutputsHelper(GetCompileTimeInputs()); } @@ -216,29 +218,29 @@ struct LocalCollectInfo : public BaseCollectInfo { } Function MakeRuntimeFunction() const { - Array bindings; + ffi::Array bindings; // Any parameter that isn't available until runtime must be an // input, along with any output from the compile-time function. // Compile-time outputs must have a fresh non-dataflow var to // serve as the parameter. This trivial binding will later be // removed with CanonicalizeBindings. - Array params = GetRuntimeInputs(); + ffi::Array params = GetRuntimeInputs(); auto propagated_tir_vars = [&]() { - Array local_tir_vars = GetPropagatedSymbolicVariables(); + ffi::Array local_tir_vars = GetPropagatedSymbolicVariables(); if (!global_info) { return local_tir_vars; } // When global lifting is enabled, the compile-time outputs are the global outputs, but the // variables in the global outputs need to be remapped to the local variables.
- Map reverse_map; + ffi::Map reverse_map; for (const auto& var : local_tir_vars) { if (auto it = global_info->tir_var_remap.find(var); it != global_info->tir_var_remap.end()) { reverse_map.Set(Downcast((*it).second), var); } } - Array global_tir_vars = global_info->GetPropagatedSymbolicVariables(); + ffi::Array global_tir_vars = global_info->GetPropagatedSymbolicVariables(); global_tir_vars = global_tir_vars.Map([&](const tir::Var& var) { if (auto it = reverse_map.find(var); it != reverse_map.end()) { return Downcast((*it).second); @@ -256,20 +258,20 @@ struct LocalCollectInfo : public BaseCollectInfo { Var shape_expr("vars_from_compile_time_params", shape_sinfo); params.push_back(shape_expr); } - Array compile_time_outputs = [&]() { - Array local_outputs = GetCompileTimeOutputs(); + ffi::Array compile_time_outputs = [&]() { + ffi::Array local_outputs = GetCompileTimeOutputs(); if (!global_info) { return local_outputs; } // When global lifting is enabled, the compile-time outputs are the global outputs, but the // variables in the global outputs need to be remapped to the local variables. - Map reverse_map; + ffi::Map reverse_map; for (const auto& var : local_outputs) { if (auto it = global_info->var_remap.find(var); it != global_info->var_remap.end()) { reverse_map.Set(Downcast((*it).second), var); } } - Array global_outputs = global_info->GetCompileTimeOutputs(); + ffi::Array global_outputs = global_info->GetCompileTimeOutputs(); global_outputs = global_outputs.Map([&](const Var& var) { if (auto it = reverse_map.find(var); it != reverse_map.end()) { return Downcast((*it).second); @@ -378,7 +380,7 @@ class BaseLiftableBindingCollector : public ExprVisitor { return true; } - std::unordered_set, ObjectPtrHash, ObjectPtrEqual> liftable_vars_; + std::unordered_set, ObjectPtrHash, ObjectPtrEqual> liftable_vars_; bool is_in_dataflow_block_{false}; }; @@ -389,32 +391,31 @@ class LocalLiftableBindingCollector : public BaseLiftableBindingCollector { visitor(func); visitor.info_.orig_func = func; - auto set_union = - [&](std::unordered_set, ObjectPtrHash, ObjectPtrEqual>& - target_set, - const std::unordered_set, ObjectPtrHash, ObjectPtrEqual>& - source_set, - const Map& var_remap, const Map& tir_var_remap) { - // In-place update the set in global info by unioning with the local set, variable - // mappings are applied. - for (const auto& relax_or_tir_var : source_set) { - if (relax_or_tir_var.as()) { - if (auto it = var_remap.find(Downcast(relax_or_tir_var)); - it != var_remap.end()) { - target_set.insert(Downcast((*it).second)); - } else { - target_set.insert(Downcast(relax_or_tir_var)); - } - } else { - if (auto it = tir_var_remap.find(Downcast(relax_or_tir_var)); - it != tir_var_remap.end()) { - target_set.insert(Downcast((*it).second)); - } else { - target_set.insert(Downcast(relax_or_tir_var)); - } - } + auto set_union = [&](std::unordered_set, ObjectPtrHash, + ObjectPtrEqual>& target_set, + const std::unordered_set, ObjectPtrHash, + ObjectPtrEqual>& source_set, + const ffi::Map& var_remap, + const ffi::Map& tir_var_remap) { + // In-place update the set in global info by unioning with the local set, variable + // mappings are applied.
+ for (const auto& relax_or_tir_var : source_set) { + if (relax_or_tir_var.as()) { + if (auto it = var_remap.find(Downcast(relax_or_tir_var)); it != var_remap.end()) { + target_set.insert(Downcast((*it).second)); + } else { + target_set.insert(Downcast(relax_or_tir_var)); } - }; + } else { + if (auto it = tir_var_remap.find(Downcast(relax_or_tir_var)); + it != tir_var_remap.end()) { + target_set.insert(Downcast((*it).second)); + } else { + target_set.insert(Downcast(relax_or_tir_var)); + } + } + } + }; if (global_info) { set_union(global_info->requires_compile_time_param, visitor.info_.requires_compile_time_param, @@ -508,8 +509,8 @@ class LocalLiftableBindingCollector : public BaseLiftableBindingCollector { /*! \brief Visitor to find the correspondence between parameters in multiple functions. */ class ParamRemapper : private ExprFunctor { public: - static std::pair, Map> GetParamMapping( - const Array& functions) { + static std::pair, ffi::Map> GetParamMapping( + const ffi::Array& functions) { ParamRemapper mapper; if (functions.size()) { auto num_inputs_0 = functions[0]->GetAttr(attr::kNumInput).value()->value; @@ -536,15 +537,15 @@ class ParamRemapper : private ExprFunctor { private: void VisitExpr_(const VarNode* lhs_var, const Expr& rhs_expr) final { auto rhs_var = Downcast(rhs_expr); - if (auto it = var_remap_.find(GetRef(lhs_var)); it != var_remap_.end()) { + if (auto it = var_remap_.find(ffi::GetRef(lhs_var)); it != var_remap_.end()) { CHECK((*it).second.same_as(rhs_var)); } else { - var_remap_.Set(GetRef(lhs_var), rhs_var); + var_remap_.Set(ffi::GetRef(lhs_var), rhs_var); } CHECK(tvm::ffi::StructuralEqual::Equal(lhs_var->struct_info_, rhs_var->struct_info_, /*map_free_vars=*/true)) << "The struct info of the parameters should be the same for all target functions"; - auto lhs_tir_vars = DefinableTIRVarsInStructInfo(GetStructInfo(GetRef(lhs_var))); + auto lhs_tir_vars = DefinableTIRVarsInStructInfo(GetStructInfo(ffi::GetRef(lhs_var))); auto rhs_tir_vars = DefinableTIRVarsInStructInfo(GetStructInfo(rhs_expr)); ICHECK_EQ(lhs_tir_vars.size(), rhs_tir_vars.size()); for (size_t i = 0; i < lhs_tir_vars.size(); i++) { @@ -556,15 +557,15 @@ class ParamRemapper : private ExprFunctor { } } - Map var_remap_; - Map tir_var_remap_; + ffi::Map var_remap_; + ffi::Map tir_var_remap_; }; class GlobalLiftableBindingCollector : public BaseLiftableBindingCollector { public: - static GlobalCollectInfo Collect(const Array& functions, - const Map& var_remap, - const Map& tir_var_remap) { + static GlobalCollectInfo Collect(const ffi::Array& functions, + const ffi::Map& var_remap, + const ffi::Map& tir_var_remap) { GlobalLiftableBindingCollector collector(var_remap, tir_var_remap); ICHECK(functions.size()); for (const auto& func : functions) { @@ -574,9 +575,9 @@ class GlobalLiftableBindingCollector : public BaseLiftableBindingCollector { } collector(func); } - Array params(functions[0]->params.begin() + - functions[0]->GetAttr(attr::kNumInput).value()->value, - functions[0]->params.end()); + ffi::Array params(functions[0]->params.begin() + + functions[0]->GetAttr(attr::kNumInput).value()->value, + functions[0]->params.end()); // todo(@tvm-team): use c++20 designated initializers when windows CI supports it GlobalCollectInfo info = GlobalCollectInfo(); info.orig_functions = functions; @@ -611,8 +612,8 @@ class GlobalLiftableBindingCollector : public BaseLiftableBindingCollector { } private: - GlobalLiftableBindingCollector(const Map& var_remap, - const Map tir_var_remap) + 
GlobalLiftableBindingCollector(const ffi::Map& var_remap, + const ffi::Map tir_var_remap) : var_remap_(var_remap), tir_var_remap_(tir_var_remap) {} void VisitBinding(const Binding& binding) override { CHECK(!binding->IsInstance()) << "MatchCast is not supported in global lifting"; @@ -633,9 +634,9 @@ // The cross-function mapping between variables. This is initialized with the mapping from the // function parameters, and is updated with the mapping between binding variables as the collector // visits the bindings. - Map var_remap_; + ffi::Map var_remap_; // The cross-function mapping between TIR variables. - Map tir_var_remap_; + ffi::Map tir_var_remap_; std::vector unified_bindings_; // The mapping between the unified bindings and the original bindings in different functions. // The unified binding is the binding with all variables replaced by the unified variables as @@ -678,7 +679,7 @@ class ConsumeBundledParams : public ExprMutator { builder_->Emit( Call(call_pure_packed, {builtin_tuple_reset_item, tuple_get_item->tuple, PrimValue(tuple_get_item->index)}, - tvm::Attrs(), {TupleStructInfo(Array{})})); + tvm::Attrs(), {TupleStructInfo(ffi::Array{})})); } else { ExprMutator::VisitBinding_(binding, tuple_get_item); } @@ -700,10 +701,10 @@ class ConsumeBundledParams : public ExprMutator { }; std::vector> GetTargetFunctions( - const IRModule& mod, const Variant>& shared_transform) { + const IRModule& mod, const ffi::Variant>& shared_transform) { std::vector> target_functions; - if (shared_transform.as>().value_or(Array{}).size()) { - auto names = shared_transform.as>().value(); + if (shared_transform.as>().value_or(ffi::Array{}).size()) { + auto names = shared_transform.as>().value(); for (const auto& name : names) { auto gvar = mod->global_var_map_.Get(name); CHECK(gvar) << "When LiftTransformParams is called with a list of function names, " @@ -752,11 +753,11 @@ std::vector> GetTargetFunctions( namespace transform { -Pass PartitionTransformParams(Variant> shared_transform) { +Pass PartitionTransformParams(ffi::Variant> shared_transform) { auto pass_func = [=](IRModule mod, PassContext pc) { std::optional global_collect_info; - CHECK((shared_transform.as() || shared_transform.as>())) + CHECK((shared_transform.as() || shared_transform.as>())) << "shared_transform should be a boolean or an array of function names"; auto target_functions = GetTargetFunctions(mod, shared_transform); @@ -783,7 +784,7 @@ Pass PartitionTransformParams(Variant> shared_transform) { updated_runtime_functions->Add(gvar, new_runtime_func); } - Map lifted_transform_functions; + ffi::Map lifted_transform_functions; if (global_collect_info.has_value()) { auto global_transform = global_collect_info.value().MakeCompileTimeFunc(); lifted_transform_functions.Set("transform_params", global_transform); @@ -818,7 +819,7 @@ Pass PartitionTransformParams(Variant> shared_transform) { return tvm::transform::CreateModulePass(pass_func, 1, "PartitionTransformParams", {}); } -Pass LiftTransformParams(Variant> shared_transform) { +Pass LiftTransformParams(ffi::Variant> shared_transform) { // A post-proc utility as the third step in LiftTransformParams // // 1.
PartitionTransformParams: Partition each function into a diff --git a/src/relax/transform/lower_alloc_tensor.cc b/src/relax/transform/lower_alloc_tensor.cc index 36911cd094d8..00c7092c0220 100644 --- a/src/relax/transform/lower_alloc_tensor.cc +++ b/src/relax/transform/lower_alloc_tensor.cc @@ -38,14 +38,14 @@ class Mutator : public ExprMutator { if (op->op.same_as(alloc_tensor_op)) { CHECK_EQ(op->args.size(), 4) << "Op " << op->op << " should have four arguments, " << "[shape, dtype, runtime_device_index, storage_scope]. " - << "However, received " << GetRef(op); + << "However, received " << ffi::GetRef(op); auto shape_arg = op->args[0]; auto dtype = Downcast(op->args[1]); PrimValue runtime_device_index = Downcast(op->args[2]); StringImm storage_scope = Downcast(op->args[3]); - auto shape = [&]() -> Array { + auto shape = [&]() -> ffi::Array { if (auto ptr = shape_arg.as()) { return ptr->values; } diff --git a/src/relax/transform/merge_composite_functions.cc b/src/relax/transform/merge_composite_functions.cc index 025e91c3c3ab..da9518394468 100644 --- a/src/relax/transform/merge_composite_functions.cc +++ b/src/relax/transform/merge_composite_functions.cc @@ -166,14 +166,14 @@ class CompositeGroupsBuilder : public MemoizedExprTranslator { } private: - Optional GetCodegenName(const Expr& callee) { + ffi::Optional GetCodegenName(const Expr& callee) { auto const* gvar = callee.as(); if (!gvar) { return std::nullopt; } auto composite_name_opt = - mod_->Lookup(GetRef(gvar))->GetAttr(attr::kComposite); + mod_->Lookup(ffi::GetRef(gvar))->GetAttr(attr::kComposite); if (!composite_name_opt) { return std::nullopt; } @@ -181,16 +181,16 @@ class CompositeGroupsBuilder : public MemoizedExprTranslator { return relax::GetCodegenName(composite_name_opt.value()); } - Optional GetCodegenName(Group* group) { + ffi::Optional GetCodegenName(Group* group) { if (auto opt_str = group->attrs.Get(attr::kCodegen)) { - return Downcast(opt_str.value()); + return Downcast(opt_str.value()); } return std::nullopt; } Group* CreateNewGroup(const CallNode* call) { Group* group = arena_->make(); - if (Optional codegen_name = GetCodegenName(call->op)) { + if (ffi::Optional codegen_name = GetCodegenName(call->op)) { group->attrs.Set(attr::kCodegen, codegen_name.value()); } return group; @@ -220,7 +220,7 @@ class CompositeGroupsBuilder : public MemoizedExprTranslator { } } - std::unordered_set GetParentGroupDependencies(const Array& args) { + std::unordered_set GetParentGroupDependencies(const ffi::Array& args) { // Collect groups that parent groups depend on std::unordered_set dependencies; @@ -233,7 +233,7 @@ class CompositeGroupsBuilder : public MemoizedExprTranslator { return dependencies; } - void UpdateGroupDependencies(Group* group, const Array& args) { + void UpdateGroupDependencies(Group* group, const ffi::Array& args) { Group* group_root = group->FindRoot(); std::function visit_expr = [&](Expr expr) { @@ -269,7 +269,7 @@ class CompositeGroupsBuilder : public MemoizedExprTranslator { } std::vector GetGroupsToMerge(const CallNode* call) { - Optional codegen_name = GetCodegenName(call->op); + ffi::Optional codegen_name = GetCodegenName(call->op); if (!codegen_name.has_value()) { return {}; } @@ -279,7 +279,7 @@ class CompositeGroupsBuilder : public MemoizedExprTranslator { for (const auto& arg : call->args) { auto arg_group = memo_[arg]; - Optional arg_codegen_name = GetCodegenName(arg_group); + ffi::Optional arg_codegen_name = GetCodegenName(arg_group); if (arg_codegen_name == codegen_name &&
!parent_dependencies.count(arg_group->FindRoot())) { // If there is a parent group with the same target, which none of the parent dependency // groups depends on, merging "this" call node into the parent group will not form a cyclic @@ -308,7 +308,7 @@ class CompositeInliner : public ExprMutator { using ExprMutator::VisitExpr_; Function Run(Function func) { - inlined_functions_ = Map(); + inlined_functions_ = ffi::Map(); auto new_body = VisitExpr(ToNonDataflow(func->body)); auto new_func = Function(func->params, new_body, func->ret_struct_info, func->is_pure, func->attrs, func->span); @@ -319,7 +319,7 @@ class CompositeInliner : public ExprMutator { if (call->op->IsInstance()) { auto gvar = Downcast(call->op); auto func = Downcast(mod_->Lookup(gvar)); - if (func->GetAttr(attr::kComposite)) { + if (func->GetAttr(attr::kComposite)) { if (!inlined_functions_.count(func)) { auto new_func = CopyWithNewVars(func); new_func = WithoutAttr(new_func, tvm::relax::attr::kPrimitive); @@ -334,7 +334,7 @@ class CompositeInliner : public ExprMutator { private: IRModule mod_; - Map inlined_functions_; + ffi::Map inlined_functions_; }; /*! @@ -361,7 +361,7 @@ class CompositeFunctionAnnotator : public ExprMutator { if (call->op->IsInstance()) { GlobalVar cur_var = Downcast(call->op); auto func = Downcast(mod_->Lookup(cur_var)); - if (auto codegen_name = func->GetAttr(attr::kCodegen)) { + if (auto codegen_name = func->GetAttr(attr::kCodegen)) { GlobalVar new_var; if (var_map_.count(cur_var) > 0) { // if we visited before, we don't need to create the new function, @@ -374,7 +374,7 @@ class CompositeFunctionAnnotator : public ExprMutator { builder_->GetContextIRModule()->Remove(old_var); // rename the function. - String new_func_name = cur_var->name_hint + "_" + codegen_name.value(); + ffi::String new_func_name = cur_var->name_hint + "_" + codegen_name.value(); Function new_func = inliner.Run(Downcast(func)); new_func = WithAttr(new_func, tvm::attr::kGlobalSymbol, new_func_name); new_func = WithoutAttr(std::move(new_func), tvm::relax::attr::kPrimitive); @@ -388,7 +388,7 @@ class CompositeFunctionAnnotator : public ExprMutator { return Call(new_var, call->args); } } - return GetRef(call); + return ffi::GetRef(call); } private: diff --git a/src/relax/transform/meta_schedule.cc b/src/relax/transform/meta_schedule.cc index 5bb8d2d3e305..2d24f0785a15 100644 --- a/src/relax/transform/meta_schedule.cc +++ b/src/relax/transform/meta_schedule.cc @@ -35,9 +35,10 @@ namespace transform { class MetaScheduleTuner { public: - explicit MetaScheduleTuner(Target target, String work_dir, Integer max_trials_global, - Integer max_trials_per_task, Optional> op_names, - Map params = {}) + explicit MetaScheduleTuner(Target target, ffi::String work_dir, Integer max_trials_global, + Integer max_trials_per_task, + ffi::Optional> op_names, + ffi::Map params = {}) : target_(target), work_dir_(work_dir), max_trials_global_(max_trials_global), @@ -64,15 +65,15 @@ class MetaScheduleTuner { private: Target target_; - String work_dir_; + ffi::String work_dir_; Integer max_trials_global_; Integer max_trials_per_task_; - Optional> op_names_; - Map params_; + ffi::Optional> op_names_; + ffi::Map params_; tvm::ffi::Function normalize_mod_func_; }; -Pass MetaScheduleApplyDatabase(Optional work_dir, bool enable_warning = false) { +Pass MetaScheduleApplyDatabase(ffi::Optional work_dir, bool enable_warning = false) { using tvm::meta_schedule::Database; Target target = Target::Current(false); const std::optional normalize_mod_func_ = @@ -85,23 +86,23 
@@ Pass MetaScheduleApplyDatabase(Optional work_dir, bool enable_warning = database = Database::Current().value(); } else { ICHECK(work_dir.has_value()); - String path_workload = work_dir.value() + "/database_workload.json"; - String path_tuning_record = work_dir.value() + "/database_tuning_record.json"; + ffi::String path_workload = work_dir.value() + "/database_workload.json"; + ffi::String path_tuning_record = work_dir.value() + "/database_tuning_record.json"; LOG(WARNING) << "Creating JSONDatabase. Workload at: " << path_workload << ", Tuning records at: " << path_tuning_record; database = meta_schedule::Database::JSONDatabase(path_workload, path_tuning_record, true); } - Map result; + ffi::Map result; auto mod_eq_structural = meta_schedule::ModuleEquality::Create("ignore-tensor"); for (const auto& iter : mod->functions) { GlobalVar gv = iter.first; BaseFunc base_func = iter.second; if (const auto* prim_func_node = base_func.as()) { - tir::PrimFunc prim_func = GetRef(prim_func_node); + tir::PrimFunc prim_func = ffi::GetRef(prim_func_node); IRModule tir_mod = (*normalize_mod_func_)(prim_func).cast(); - if (Optional opt_record = + if (ffi::Optional opt_record = database->QueryTuningRecord(tir_mod, target, gv->name_hint)) { meta_schedule::TuningRecord record = opt_record.value(); tir::Schedule sch{nullptr}; @@ -146,10 +147,10 @@ Pass MetaScheduleApplyDatabase(Optional work_dir, bool enable_warning = return CreateModulePass(pass_func, 0, "MetaScheduleApplyDatabase", {}); } -Pass MetaScheduleTuneIRMod(Map params, String work_dir, +Pass MetaScheduleTuneIRMod(ffi::Map params, ffi::String work_dir, Integer max_trials_global, - Optional max_trials_per_task = std::nullopt, - Optional> op_names = std::nullopt) { + ffi::Optional max_trials_per_task = std::nullopt, + ffi::Optional> op_names = std::nullopt) { Target target = Target::Current(false); auto pass_func = [=](IRModule m, PassContext ctx) { auto max_trials_task = max_trials_per_task.value_or(max_trials_global); @@ -162,7 +163,7 @@ Pass MetaScheduleTuneIRMod(Map params, String work_dir, /*traceable*/ true); } -Pass MetaScheduleTuneTIR(String work_dir, Integer max_trials_global) { +Pass MetaScheduleTuneTIR(ffi::String work_dir, Integer max_trials_global) { Target target = Target::Current(false); ffi::TypedFunction pass_func = [=](tir::PrimFunc f, IRModule mod, PassContext ctx) { diff --git a/src/relax/transform/normalize.cc b/src/relax/transform/normalize.cc index 8bd740009ef8..0002de872aa8 100644 --- a/src/relax/transform/normalize.cc +++ b/src/relax/transform/normalize.cc @@ -46,7 +46,7 @@ class NormalizeMutator : public ExprMutatorBase { Expr body = this->VisitWithNewScope(op->body, op->params); if (body.same_as(op->body)) { - return GetRef(op); + return ffi::GetRef(op); } else { return Function(op->params, body, op->ret_struct_info, op->is_pure, op->attrs); } @@ -58,13 +58,13 @@ class NormalizeMutator : public ExprMutatorBase { Expr false_b = this->VisitWithNewScope(op->false_branch); if (op->cond.same_as(guard) && op->true_branch.same_as(true_b) && op->false_branch.same_as(false_b)) { - return GetRef(op); + return ffi::GetRef(op); } else { return If(guard, true_b, false_b, op->span); } } - Expr VisitWithNewScope(const Expr& expr, Optional> params = std::nullopt) { + Expr VisitWithNewScope(const Expr& expr, ffi::Optional> params = std::nullopt) { builder_->BeginBindingBlock(); if (params.defined()) { builder_->BeginScope(params); @@ -82,7 +82,7 @@ class NormalizeMutator : public ExprMutatorBase { Expr VisitExpr_(const SeqExprNode* op) 
final { bool all_blocks_unchanged = true; - Array blocks; + ffi::Array blocks; for (auto block : op->blocks) { BindingBlock new_block = this->VisitBindingBlock(block); if (!new_block->bindings.empty()) { @@ -100,7 +100,7 @@ class NormalizeMutator : public ExprMutatorBase { } if (all_blocks_unchanged && body.same_as(op->body)) { - return GetRef(op); + return ffi::GetRef(op); } else { return SeqExpr(blocks, body); } @@ -151,7 +151,7 @@ class NormalizeMutator : public ExprMutatorBase { } if (new_value.same_as(binding->value)) { - builder_->EmitNormalized(GetRef(binding)); + builder_->EmitNormalized(ffi::GetRef(binding)); } else { builder_->EmitNormalized(VarBinding(binding->var, new_value)); } @@ -161,7 +161,7 @@ class NormalizeMutator : public ExprMutatorBase { Expr new_value = this->VisitExpr(binding->value); if (new_value.same_as(binding->value)) { - builder_->EmitNormalized(GetRef(binding)); + builder_->EmitNormalized(ffi::GetRef(binding)); } else { builder_->EmitNormalized( MatchCast(binding->var, builder_->NormalizeArgument(new_value), binding->struct_info)); @@ -219,7 +219,7 @@ class GlobalVarNormalizer : private ExprMutator { /*! \brief Check if any function needs to be renamed. */ bool NeedRename() { for (const auto& [gvar, func] : module_->functions) { - auto global_symbol = func->GetAttr("global_symbol"); + auto global_symbol = func->GetAttr("global_symbol"); if (global_symbol && global_symbol.value() != gvar->name_hint) { return true; } @@ -230,7 +230,7 @@ class GlobalVarNormalizer : private ExprMutator { /*! \brief Add public functions to the builder, and update the name supplier. */ void AddPublicFunctions() { for (const auto& [gvar, func] : module_->functions) { - auto global_symbol = func->GetAttr("global_symbol"); + auto global_symbol = func->GetAttr("global_symbol"); if (!global_symbol) { continue; } @@ -250,7 +250,7 @@ class GlobalVarNormalizer : private ExprMutator { */ void AddPrivateFunctions() { for (auto [gvar, func] : module_->functions) { - auto global_symbol = func->GetAttr("global_symbol"); + auto global_symbol = func->GetAttr("global_symbol"); if (global_symbol) { continue; } @@ -262,13 +262,13 @@ class GlobalVarNormalizer : private ExprMutator { } Expr VisitExpr_(const GlobalVarNode* op) final { - ICHECK(gvar_map_.count(GetRef(op))); - return gvar_map_[GetRef(op)]; + ICHECK(gvar_map_.count(ffi::GetRef(op))); + return gvar_map_[ffi::GetRef(op)]; } IRModule module_; NameSupply name_supply_; - Map gvar_map_; + ffi::Map gvar_map_; }; namespace transform { diff --git a/src/relax/transform/realize_vdevice.cc b/src/relax/transform/realize_vdevice.cc index 1034c2640f2a..087579fc309f 100644 --- a/src/relax/transform/realize_vdevice.cc +++ b/src/relax/transform/realize_vdevice.cc @@ -77,12 +77,12 @@ class VDeviceLookup { } private: - Optional> opt_vdevices_ = std::nullopt; + ffi::Optional> opt_vdevices_ = std::nullopt; }; class DeviceHintCollector : ExprVisitor { public: - static std::tuple, Map> Collect(IRModule mod) { + static std::tuple, ffi::Map> Collect(IRModule mod) { DeviceHintCollector visitor{VDeviceLookup(mod)}; for (const auto& [gvar, base_func] : mod->functions) { @@ -178,7 +178,7 @@ class DeviceHintCollector : ExprVisitor { } } - Optional LookupBinding(const Expr& expr) const { + ffi::Optional LookupBinding(const Expr& expr) const { if (auto var = expr.as()) { if (auto bound = binding_lookup_.Get(var.value())) { return bound.value(); @@ -194,14 +194,14 @@ class DeviceHintCollector : ExprVisitor { // A lookup of variable bindings, used to unwrap the variable // 
bindings in functions that return a tuple. - Map binding_lookup_; + ffi::Map binding_lookup_; // A map from Var to the VDevice they are known to occur on. This // only contains variables whose location is explicitly known // (e.g. output of `R.hint_on_device`, variables with explicit // `VDevice` in their struct info), and does not include variables // whose location is only hinted (e.g. input of `R.hint_on_device`). - Map known_vdevice_; + ffi::Map known_vdevice_; // A map from Var to the VDevice they are expected to occur on. If // a variable appears in both `known_vdevice_` and @@ -213,7 +213,7 @@ class DeviceHintCollector : ExprVisitor { // Therefore, we only determine that `A` is located on "cuda:0" if // no other annotation has already provided a known location for // `A`. - Map hint_on_device_inputs_; + ffi::Map hint_on_device_inputs_; // The `R.hint_on_device` operator. const Op& hint_on_device_op_ = Op::Get("relax.hint_on_device"); @@ -223,7 +223,7 @@ // same VDevice. class VDeviceSetCollector : ExprVisitor { public: - static Map> Collect(IRModule mod) { + static ffi::Map> Collect(IRModule mod) { VDeviceSetCollector visitor; for (const auto& [gvar, base_func] : mod->functions) { if (auto func = base_func.as()) { @@ -249,13 +249,13 @@ class VDeviceSetCollector : ExprVisitor { void VisitExpr_(const VarNode* op) override { if (current_binding_) { - auto var = GetRef(op); + auto var = ffi::GetRef(op); var_to_co_located_vars_[current_binding_.value()].push_back(var); var_to_co_located_vars_[var].push_back(current_binding_.value()); } } - Optional current_binding_ = std::nullopt; + ffi::Optional current_binding_ = std::nullopt; // Lookup from relax variable to the set of relax variables which // must be located on the same device. For example, a trivial @@ -267,18 +267,18 @@ class VDeviceSetCollector : ExprVisitor { // `relax::Call` operation must be located on the same device, with // the exception of `R.hint_on_device` and `R.to_vdevice`, which may // introduce a transfer across devices. - std::unordered_map> var_to_co_located_vars_; + std::unordered_map> var_to_co_located_vars_; const Op& hint_on_device_op_ = Op::Get("relax.hint_on_device"); const Op& to_vdevice_op_ = Op::Get("relax.to_vdevice"); }; -Map InferVDevice(IRModule mod) { +ffi::Map InferVDevice(IRModule mod) { auto [explicit_annotations, hint_on_device_args] = DeviceHintCollector::Collect(mod); auto co_located_var_lookup = VDeviceSetCollector::Collect(mod); - Map known_vdevice; + ffi::Map known_vdevice; std::vector to_visit; // A helper function to propagate all `known_vdevice` entries based @@ -324,7 +324,7 @@ Map InferVDevice(IRModule mod) { // Update the module to include the inferred VDevice annotations.
class VDeviceStructInfoUpdater : ExprMutator { public: - static IRModule Apply(IRModule mod, Map vdevice_map) { + static IRModule Apply(IRModule mod, ffi::Map vdevice_map) { VDeviceStructInfoUpdater mutator(VDeviceLookup(mod), vdevice_map); IRModule updates; @@ -346,7 +346,7 @@ class VDeviceStructInfoUpdater : ExprMutator { } private: - VDeviceStructInfoUpdater(VDeviceLookup vdevice_lookup, Map vdevice_map) + VDeviceStructInfoUpdater(VDeviceLookup vdevice_lookup, ffi::Map vdevice_map) : vdevice_lookup_(vdevice_lookup), vdevice_map_(vdevice_map) {} Var VisitVarDef(const Var& old_var) override { @@ -390,14 +390,14 @@ class VDeviceStructInfoUpdater : ExprMutator { if (input_vdevice.defined() && input_vdevice.value() == output_vdevice) { return arg; } else { - ObjectPtr attrs = make_object(); + ObjectPtr attrs = ffi::make_object(); attrs->dst_vdevice = output_vdevice; return Call(to_vdevice_op_, {arg}, Attrs(attrs), {}); } } VDeviceLookup vdevice_lookup_; - Map vdevice_map_; + ffi::Map vdevice_map_; const Op& hint_on_device_op_ = Op::Get("relax.hint_on_device"); const Op& to_vdevice_op_ = Op::Get("relax.to_vdevice"); }; diff --git a/src/relax/transform/remove_purity_checking.cc b/src/relax/transform/remove_purity_checking.cc index d8bb6465da05..b6e038eac1bd 100644 --- a/src/relax/transform/remove_purity_checking.cc +++ b/src/relax/transform/remove_purity_checking.cc @@ -49,13 +49,13 @@ class PurityRemover : public ExprMutator { Expr VisitExpr_(const CallNode* call) override { if (call->op == call_pure_packed_op_) { - auto ret = Call(call->args[0], Array(call->args.begin() + 1, call->args.end()), + auto ret = Call(call->args[0], ffi::Array(call->args.begin() + 1, call->args.end()), call->attrs, call->sinfo_args); return VisitExpr(ret); } if (call->op == call_inplace_packed_op_) { // call_inplace_packed has its own attrs so we don't pass those down - auto ret = Call(call->args[0], Array(call->args.begin() + 1, call->args.end()), + auto ret = Call(call->args[0], ffi::Array(call->args.begin() + 1, call->args.end()), tvm::Attrs(), call->sinfo_args); return VisitExpr(ret); } @@ -68,7 +68,7 @@ class PurityRemover : public ExprMutator { Expr VisitExpr_(const FunctionNode* func) override { // handling inner functions: we will remove purity annotations from them too - return RemovePurity(GetRef(func)); + return RemovePurity(ffi::GetRef(func)); } private: diff --git a/src/relax/transform/remove_unused_outputs.cc b/src/relax/transform/remove_unused_outputs.cc index 26145cde1d48..140e6ae8333e 100644 --- a/src/relax/transform/remove_unused_outputs.cc +++ b/src/relax/transform/remove_unused_outputs.cc @@ -44,7 +44,7 @@ class PartialTupleUsageCollector : ExprVisitor { PMap num_outputs; for (const auto& [gvar, base_func] : mod->functions) { - bool is_exposed = base_func->attrs.GetAttr(tvm::attr::kGlobalSymbol).has_value(); + bool is_exposed = base_func->attrs.GetAttr(tvm::attr::kGlobalSymbol).has_value(); if (!is_exposed) { if (auto relax_func = base_func.as()) { @@ -98,21 +98,21 @@ class PartialTupleUsageCollector : ExprVisitor { CHECK_GE(op->index, 0) << "IndexError: " << "Indices for TupleGetItem must be non-negative, " - << "but expression " << GetRef(op) << " uses a tuple index of " - << op->index; + << "but expression " << ffi::GetRef(op) + << " uses a tuple index of " << op->index; size_t index = op->index; CHECK_LT(index, used_indices.size()) << "IndexError: " << "Indices for TupleGetItem must be less than the size of the tuple, " - << "but expression " << GetRef(op) << " uses a tuple index of " << 
op->index + << "but expression " << ffi::GetRef(op) << " uses a tuple index of " << op->index << " for a tuple of size " << used_indices.size(); used_indices[index] = true; } } void VisitExpr_(const VarNode* op) override { - if (auto* usage_mask_ptr = GetCalleeUsageMask(GetRef(op))) { + if (auto* usage_mask_ptr = GetCalleeUsageMask(ffi::GetRef(op))) { auto& usage_mask = *usage_mask_ptr; for (size_t i = 0; i < usage_mask.size(); i++) { usage_mask[i] = true; @@ -138,7 +138,7 @@ class PartialTupleUsageCollector : ExprVisitor { } Expr UnwrapBindings(Expr expr) const { - auto get_bound_value = [&](const Expr& expr) -> Optional { + auto get_bound_value = [&](const Expr& expr) -> ffi::Optional { if (auto var = expr.as()) { if (auto known_binding = known_bindings_.Get(var.value())) { return known_binding.value(); @@ -153,7 +153,7 @@ class PartialTupleUsageCollector : ExprVisitor { return expr; } - Map known_bindings_; + ffi::Map known_bindings_; PMap> output_usage_mask_; }; @@ -164,7 +164,7 @@ Function UpdateCallee(Function func, const std::vector& usage_mask) { ICHECK(old_ret_sinfo) << "All functions returning non-tuple outputs " << "should have been pruned already by PartialTupleUsageCollector"; - Array outputs; + ffi::Array outputs; // This helper variable will be removed by the post-proc of // CanonicalizeBindings and DeadCodeElimination. @@ -267,7 +267,7 @@ Pass RemoveUnusedOutputs() { num_outputs_used += used; } - Array new_results; + ffi::Array new_results; int new_result_index = 0; for (size_t i = 0; i < usage_mask.size(); i++) { if (usage_mask[i]) { diff --git a/src/relax/transform/remove_unused_parameters.cc b/src/relax/transform/remove_unused_parameters.cc index 2e88ebe417b3..4d203648ffea 100644 --- a/src/relax/transform/remove_unused_parameters.cc +++ b/src/relax/transform/remove_unused_parameters.cc @@ -51,11 +51,11 @@ struct CalleeAnalysis { * * \return The arguments to be used for the modified function */ - std::function(Array)> arg_updater; + std::function(ffi::Array)> arg_updater; }; std::optional AnalyzeCallee(Function func) { - bool is_exposed = func->attrs.GetAttr(tvm::attr::kGlobalSymbol).has_value(); + bool is_exposed = func->attrs.GetAttr(tvm::attr::kGlobalSymbol).has_value(); if (is_exposed) return std::nullopt; auto free_relax_vars = [&]() -> PSet { @@ -66,7 +66,7 @@ std::optional AnalyzeCallee(Function func) { std::vector parameter_mask; parameter_mask.reserve(func->params.size()); - Array params; + ffi::Array params; for (const auto& param : func->params) { bool is_used = free_relax_vars.count(param); parameter_mask.push_back(is_used); @@ -93,7 +93,7 @@ std::optional AnalyzeCallee(Function func) { }(); // Use an array to define the order of the symbolic variables - Array free_tir_vars; + ffi::Array free_tir_vars; for (const auto& tir_var : FreeSymbolicVars(func->body)) { if (!defined_tir_params.count(tir_var)) { free_tir_vars.push_back(tir_var); @@ -110,12 +110,12 @@ std::optional AnalyzeCallee(Function func) { Downcast(func->struct_info_)->purity); auto arg_updater = [parameter_mask, old_relax_params = func->params, - free_tir_vars](Array old_args) -> Array { + free_tir_vars](ffi::Array old_args) -> ffi::Array { ICHECK_EQ(old_args.size(), parameter_mask.size()) << "Call provides " << old_args.size() << ", but the callee accepts " << parameter_mask.size() << " parameters"; - Array new_args; + ffi::Array new_args; for (size_t i = 0; i < old_args.size(); i++) { if (parameter_mask.at(i)) { new_args.push_back(old_args[i]); @@ -123,7 +123,7 @@ std::optional 
AnalyzeCallee(Function func) { } if (free_tir_vars.size()) { - Map old_binding; + ffi::Map old_binding; for (size_t i = 0; i < old_relax_params.size(); i++) { old_binding.Set(old_relax_params[i], old_args[i]); } diff --git a/src/relax/transform/reorder_permute_dims_after_concat.cc b/src/relax/transform/reorder_permute_dims_after_concat.cc index b97981a7f4e5..5c73acb451bb 100644 --- a/src/relax/transform/reorder_permute_dims_after_concat.cc +++ b/src/relax/transform/reorder_permute_dims_after_concat.cc @@ -41,7 +41,7 @@ namespace tvm { namespace relax { namespace { -std::tuple)>> CreatePatterns() { +std::tuple)>> CreatePatterns() { // TODO(Lunderberg): Allow pattern-matching to handle a flexible // number of arguments, each of which matches the same type of // pattern. @@ -73,7 +73,7 @@ std::tuple)>> Crea auto make_pattern_with_num_concat = [&](size_t num_concat) -> DFPattern { ICHECK_LT(num_concat, pat_permute_dims.size()); auto concat_tuple = TuplePattern( - Array(pat_permute_dims.begin(), pat_permute_dims.begin() + num_concat)); + ffi::Array(pat_permute_dims.begin(), pat_permute_dims.begin() + num_concat)); return IsOp("relax.concat")(concat_tuple); }; @@ -82,7 +82,7 @@ std::tuple)>> Crea pat_concat = pat_concat | make_pattern_with_num_concat(i); } - auto get_permute_dims_optional_axes = [](const Expr& expr) -> Optional> { + auto get_permute_dims_optional_axes = [](const Expr& expr) -> ffi::Optional> { auto call = expr.as(); ICHECK(call); auto attrs = call->attrs.as(); @@ -92,12 +92,12 @@ std::tuple)>> Crea }; auto get_permute_dims_axes = - [get_permute_dims_optional_axes](const Expr& expr) -> Array { + [get_permute_dims_optional_axes](const Expr& expr) -> ffi::Array { if (auto opt_axes = get_permute_dims_optional_axes(expr)) { return opt_axes.value(); } else { auto call = Downcast(expr); - Array permutation; + ffi::Array permutation; auto arg_sinfo = call->args[0]->struct_info_.as(); CHECK(arg_sinfo) << "Expected permute_dims to have a single tensor argument, " << "but argument " << call->args[0] << " has struct info " @@ -111,7 +111,7 @@ std::tuple)>> Crea } }; - auto permute_dims_axes_are_compatible = [&](const Array& permute_dims) -> bool { + auto permute_dims_axes_are_compatible = [&](const ffi::Array& permute_dims) -> bool { auto first_axes = get_permute_dims_axes(permute_dims[0]); for (size_t i_arg = 1; i_arg < permute_dims.size(); i_arg++) { auto i_axes = get_permute_dims_axes(permute_dims[i_arg]); @@ -127,9 +127,9 @@ std::tuple)>> Crea return true; }; - auto rewriter = [=](Expr expr, Map matches) -> Expr { - Array args; - Array all_permute_dims; + auto rewriter = [=](Expr expr, ffi::Map matches) -> Expr { + ffi::Array args; + ffi::Array all_permute_dims; for (size_t i = 0; i < max_concat; i++) { if (auto permute_dim_expr = matches.Get(pat_permute_dims[i])) { all_permute_dims.push_back(permute_dim_expr.value()); @@ -145,7 +145,8 @@ std::tuple)>> Crea if (!permute_dims_axes_are_compatible(all_permute_dims)) { return expr; } - Optional> permute_axes = get_permute_dims_optional_axes(all_permute_dims[0]); + ffi::Optional> permute_axes = + get_permute_dims_optional_axes(all_permute_dims[0]); Call concat_call = Downcast(matches[pat_concat]); auto concat_attrs = concat_call->attrs.as(); diff --git a/src/relax/transform/reorder_take_after_matmul.cc b/src/relax/transform/reorder_take_after_matmul.cc index eebec15f52ce..51744a43247d 100644 --- a/src/relax/transform/reorder_take_after_matmul.cc +++ b/src/relax/transform/reorder_take_after_matmul.cc @@ -41,7 +41,7 @@ namespace tvm { 
namespace relax { namespace { -std::tuple)>> CreatePatterns() { +std::tuple)>> CreatePatterns() { auto pat_lhs = WildcardPattern(); auto pat_weights = WildcardPattern(); @@ -50,7 +50,7 @@ std::tuple)>> Crea auto pat_matmul = IsOp("relax.matmul")(pat_lhs, pat_rhs); - auto rewriter = [=](Expr expr, Map matches) -> Expr { + auto rewriter = [=](Expr expr, ffi::Map matches) -> Expr { auto lhs = matches[pat_lhs]; auto weights = matches[pat_weights]; auto indices = matches[pat_indices]; @@ -114,7 +114,7 @@ std::tuple)>> Crea // indices.shape = [batch1] // reordered_weight.shape = [infeatures, table_size, outfeatures] - auto reordered_weight = permute_dims(weights, Array{Integer(1), Integer(0), Integer(2)}); + auto reordered_weight = permute_dims(weights, ffi::Array{Integer(1), Integer(0), Integer(2)}); // fused_weight.shape = [infeatures, table_size * outfeatures] auto fused_weight = reshape(reordered_weight, ShapeExpr({weight_shape[1], weight_shape[0] * weight_shape[2]})); diff --git a/src/relax/transform/replace_global_vars.cc b/src/relax/transform/replace_global_vars.cc index ea5d5e18d8ff..48548de887cd 100644 --- a/src/relax/transform/replace_global_vars.cc +++ b/src/relax/transform/replace_global_vars.cc @@ -37,12 +37,12 @@ namespace { using tvm::transform::GlobalVarReplacer; struct Mutator : ExprMutator { - Map replacements; - explicit Mutator(Map replacements) : replacements(replacements) {} + ffi::Map replacements; + explicit Mutator(ffi::Map replacements) : replacements(replacements) {} using ExprMutator::VisitExpr_; Expr VisitExpr_(const GlobalVarNode* node) override { - auto gvar = GetRef(node); + auto gvar = ffi::GetRef(node); return replacements.Get(gvar).value_or(gvar); } }; @@ -51,14 +51,14 @@ struct Mutator : ExprMutator { TVM_STATIC_IR_FUNCTOR(GlobalVarReplacer, vtable) .set_dispatch([](const ObjectRef& func, - Map replacements) -> BaseFunc { + ffi::Map replacements) -> BaseFunc { Mutator mutator(replacements); auto new_func = Downcast(mutator(Downcast(func))); // If the function is externally exposed, and is being replaced // by a GlobalVar with a new name, then the function's // kGlobalSymbol must be updated to match. - if (auto opt = new_func->GetAttr(tvm::attr::kGlobalSymbol)) { + if (auto opt = new_func->GetAttr(tvm::attr::kGlobalSymbol)) { auto name = opt.value(); for (const auto& [before, after] : replacements) { if (before->name_hint == name) { @@ -75,7 +75,7 @@ TVM_STATIC_IR_FUNCTOR(GlobalVarReplacer, vtable) TVM_STATIC_IR_FUNCTOR(GlobalVarReplacer, vtable) .set_dispatch([](const ObjectRef& func, - Map) -> BaseFunc { + ffi::Map) -> BaseFunc { return Downcast(func); }); diff --git a/src/relax/transform/rewrite_cuda_graph.cc b/src/relax/transform/rewrite_cuda_graph.cc index b1faf5c09271..955b858a0c7c 100644 --- a/src/relax/transform/rewrite_cuda_graph.cc +++ b/src/relax/transform/rewrite_cuda_graph.cc @@ -89,7 +89,7 @@ struct LiftedFunctionRewritePlan { // The corresponding binding vars in the original function of the inputs of the lifted function std::vector inputs; // The tir vars in the original function that are propagated to the lifted function - Optional propogated_tir_vars = std::nullopt; + ffi::Optional propogated_tir_vars = std::nullopt; }; /*! \brief Builder of the lifted function for cuda graph capturing or allocations */ @@ -123,22 +123,22 @@ class FuncBuilder : public ExprMutator { /*! 
\brief Build the new function */ Function Build() { - Array params; - Optional shape_expr = std::nullopt; + ffi::Array params; + ffi::Optional shape_expr = std::nullopt; if (shape_expr_inputs_.size()) { - Array tir_vars; + ffi::Array tir_vars; for (const auto* var : shape_expr_inputs_) { - auto new_var = GetRef(var).copy_with_suffix(""); - tir_var_remap_.Set(GetRef(var), new_var); + auto new_var = ffi::GetRef(var).copy_with_suffix(""); + tir_var_remap_.Set(ffi::GetRef(var), new_var); tir_vars.push_back(new_var); } shape_expr = Var("shape_expr", ShapeStructInfo(tir_vars)); } // Set up the parameters for (const auto* input : inputs_) { - auto new_var = Var( - input->name_hint(), - VisitExprDepStructInfoField(Downcast>(input->struct_info_).value())); + auto new_var = Var(input->name_hint(), + VisitExprDepStructInfoField( + Downcast>(input->struct_info_).value())); var_remap_[input->vid] = new_var; params.push_back(new_var); } @@ -151,14 +151,14 @@ class FuncBuilder : public ExprMutator { VisitBinding_(binding); } // Set up the outputs - Array outputs; + ffi::Array outputs; for (const auto* var : outputs_) { outputs.push_back(VisitExpr_(var)); } auto output = builder_->Emit(Tuple(outputs)); auto block = builder_->EndBlock(); auto body = builder_->Normalize(SeqExpr({block}, output)); - Map attrs; + ffi::Map attrs; attrs.Set(relax::attr::kForcePure, true); auto func = Function(params, body, Downcast(output->struct_info_.value()), /*is_pure=*/true, /*attrs=*/DictAttrs(attrs)); @@ -171,7 +171,7 @@ class FuncBuilder : public ExprMutator { support::OrderedSet outputs_; support::OrderedSet shape_expr_inputs_; std::vector bindings_; - Map tir_var_remap_; + ffi::Map tir_var_remap_; }; // Collect the storage objects that are used as the function output @@ -250,7 +250,7 @@ class CUDAGraphRewritePlanner : public ExprVisitor { func->attrs.GetAttr(attr::kNumInput).value_or(Integer(func->params.size())); auto capture_symbolic_var_name_hints = ExtractSymbolicVarHints(func); for (int i = 0; i < static_cast(func->params.size()); ++i) { - Array symbolic_vars = DefinableTIRVarsInStructInfo( + ffi::Array symbolic_vars = DefinableTIRVarsInStructInfo( Downcast(func->params[i]->struct_info_.value())); if (i < num_inputs.IntValue()) { for (const auto& symbolic_var : symbolic_vars) { @@ -278,9 +278,9 @@ class CUDAGraphRewritePlanner : public ExprVisitor { plan->is_alloc = is_alloc; plan->lifted_bindings = std::move(region->bindings_); if (region->shape_expr_inputs_.size()) { - Array tir_vars; + ffi::Array tir_vars; for (const auto* var : region->shape_expr_inputs_) { - tir_vars.push_back(GetRef(var)); + tir_vars.push_back(ffi::GetRef(var)); } plan->propogated_tir_vars = ShapeExpr(tir_vars); } @@ -306,10 +306,11 @@ class CUDAGraphRewritePlanner : public ExprVisitor { * \brief Extract the name hints of the symbolic variables that are allowed to be captured * from the function attributes. */ - std::unordered_set ExtractSymbolicVarHints(const Function& func) { + std::unordered_set ExtractSymbolicVarHints(const Function& func) { auto symbolic_var_names = - func->attrs.GetAttr>("relax.rewrite_cuda_graph.capture_symbolic_vars") - .value_or(Array()); + func->attrs + .GetAttr>("relax.rewrite_cuda_graph.capture_symbolic_vars") + .value_or(ffi::Array()); return {symbolic_var_names.begin(), symbolic_var_names.end()}; } @@ -365,7 +366,8 @@ class CUDAGraphRewritePlanner : public ExprVisitor { const auto* call_gv = call->op.as(); bool call_prim_func = - call_gv ? mod_->Lookup(GetRef(call_gv))->IsInstance() : false; + call_gv ? 
mod_->Lookup(ffi::GetRef(call_gv))->IsInstance() + : false; // Check whether the call can be lifted to the capture function. It requires all the arguments // to be static and the call to be a kernel launch or a pure operation (e.g. memory view). @@ -399,8 +401,8 @@ if (const auto* op = call->op.as()) { return !support::StartsWith(op->name, "relax.memory") && !support::StartsWith(op->name, "relax.builtin") && op->name != "relax.reshape" && - !GetRef(op).same_as(null_value_op) && - !GetRef(op).same_as(call_builtin_with_ctx_op); + !ffi::GetRef(op).same_as(null_value_op) && + !ffi::GetRef(op).same_as(call_builtin_with_ctx_op); } return false; }(); @@ -442,7 +444,7 @@ class CUDAGraphRewritePlanner : public ExprVisitor { } void VisitBinding_(const VarBindingNode* binding, const VarNode* var) final { - if (IsStatic(GetRef(var))) { + if (IsStatic(ffi::GetRef(var))) { AddStaticBinding(binding, false); MarkAsFuncInput({var}); } else { @@ -525,7 +527,7 @@ class CUDAGraphRewritePlanner : public ExprVisitor { } template - bool IsStatic(const Array& exprs, std::vector* vars_collector = nullptr, + bool IsStatic(const ffi::Array& exprs, std::vector* vars_collector = nullptr, std::vector* tir_vars_collector = nullptr) { bool result = true; for (const auto& expr : exprs) { @@ -657,7 +659,7 @@ Function MergeAllocationPlans(const std::vector& all bool operator<(const StorageRecord& other) const { return size < other.size; } }; // Using an (ordered) map to make sure the result is deterministic - std::map>> storage_records; + std::map>> storage_records; static const auto& mem_alloc_storage_op = Op::Get("relax.memory.alloc_storage"); // Collect the storage records for each storage scope. Storage records are stored separately @@ -675,7 +677,7 @@ Function MergeAllocationPlans(const std::vector& all int64_t virtual_device_id = Downcast(Downcast(alloc_storage->args[1])->value)->value; ICHECK_EQ(virtual_device_id, 0); - String storage_scope = Downcast(alloc_storage->args[2])->value; + ffi::String storage_scope = Downcast(alloc_storage->args[2])->value; auto [it, _] = storage_records.try_emplace(storage_scope, alloc_plans.size()); it->second[plan_id].emplace_back(StorageRecord{size, binding, plan}); } @@ -791,7 +793,7 @@ class CUDAGraphRewriter : public ExprMutator { plan->func, current_func_.value()->name_hint + "_cuda_graph_capture"); StructInfo call_sinfo = plan->func->ret_struct_info; // Arguments of the lifted function - Array args; + ffi::Array args; for (const auto& arg : plan->inputs) { args.push_back(VisitExpr_(arg)); } @@ -803,7 +805,7 @@ class CUDAGraphRewriter : public ExprMutator { const auto& shape_expr = plan->func->params.back(); auto symbolic_params = Downcast(shape_expr->struct_info_.value())->values.value(); - Map tir_var_remap; + ffi::Map tir_var_remap; ICHECK_EQ(symbolic_params.size(), propogated_tir_vars->values.size()); for (int i = 0; i < static_cast(symbolic_params.size()); ++i) { tir_var_remap.Set(Downcast(symbolic_params[i]), propogated_tir_vars->values[i]); @@ -811,8 +813,8 @@ class CUDAGraphRewriter : public ExprMutator { call_sinfo = Bind(call_sinfo, tir_var_remap); } // Arguments of builtin_run_or_capture - Array tuple_arg_fields{gv_func, Tuple(args), - PrimValue(IntImm(DataType::Int(64), index_capture_++))}; + ffi::Array tuple_arg_fields{gv_func, Tuple(args), + PrimValue(IntImm(DataType::Int(64), index_capture_++))}; if (plan->propogated_tir_vars.defined()) { // The shape expr is explicitly passed twice, once as the last argument of
the lifted // function, once as the last argument of builtin_run_or_capture as the cache key. Explicitly @@ -857,7 +859,7 @@ class CUDAGraphRewriter : public ExprMutator { // the original var definition is not visited yet. return EmitRedef(op, it->second); } - return GetRef(op); + return ffi::GetRef(op); } Var EmitRedef(const VarNode* var, const Expr& redef) { @@ -872,8 +874,8 @@ class CUDAGraphRewriter : public ExprMutator { int index_alloc_ = 0; int index_capture_ = 0; support::Arena arena_; - Optional gv_global_alloc_ = std::nullopt; - Optional current_func_ = std::nullopt; + ffi::Optional gv_global_alloc_ = std::nullopt; + ffi::Optional current_func_ = std::nullopt; }; IRModule RewriteCUDAGraph(IRModule mod) { diff --git a/src/relax/transform/rewrite_dataflow_reshape.cc b/src/relax/transform/rewrite_dataflow_reshape.cc index a9e5e8b3c5ff..1ce656a7fb66 100644 --- a/src/relax/transform/rewrite_dataflow_reshape.cc +++ b/src/relax/transform/rewrite_dataflow_reshape.cc @@ -69,7 +69,7 @@ class DataflowReshapeRewriter : public ExprMutator { // We only rewrite the bindings that are not dataflow output (which means they are not // externally referenced) if (!binding->var->IsInstance()) { - this->builder_->EmitNormalized(GetRef(binding)); + this->builder_->EmitNormalized(ffi::GetRef(binding)); } else { ExprMutator::VisitBinding_(binding); } @@ -78,7 +78,7 @@ class DataflowReshapeRewriter : public ExprMutator { Expr VisitExpr_(const CallNode* call) final { static const Op& call_tir_op = Op::Get("relax.call_tir"); if (call->op != call_tir_op) { - return GetRef(call); + return ffi::GetRef(call); } // We bring the calls of reshape PrimFunc back to calls of high-level @@ -94,13 +94,13 @@ class DataflowReshapeRewriter : public ExprMutator { // then flattens the tuple input so that the fused TIR reshape function ends up having // multiple input buffers. But only one of them should be accessed and reshaped. if (used_tensor_arg_indices.size() != 1) { - return GetRef(call); + return ffi::GetRef(call); } auto arg = arg_tuple[used_tensor_arg_indices[0]]; if (!IsCallingTIRReshape(call, arg)) { - return GetRef(call); + return ffi::GetRef(call); } TensorStructInfo res_sinfo = Downcast(call->struct_info_.value()); @@ -111,7 +111,7 @@ class DataflowReshapeRewriter : public ExprMutator { const GlobalVar& global_var = Downcast(call->args[0]); const auto* func = mod_->functions.Get(global_var).value().as(); ICHECK_NOTNULL(func); - if (!HasReshapePattern(GetRef(func))) { + if (!HasReshapePattern(ffi::GetRef(func))) { return false; } @@ -130,7 +130,7 @@ class DataflowReshapeRewriter : public ExprMutator { if (inp_sinfo->IsUnknownNdim() || res_sinfo->IsUnknownNdim()) { return false; } - auto product = [](Array args) -> PrimExpr { + auto product = [](ffi::Array args) -> PrimExpr { PrimExpr p; if (args.empty()) { // Scalar tensors may be empty indicating a single element.
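The run_codegen.cc hunks that follow rely on the lookup idiom in which `ffi::Map::Get` returns an `ffi::Optional` and a default is supplied through `value_or` (e.g. `target_options.Get(target).value_or(OptionMap())`). A minimal standalone sketch of that idiom, assuming only the tvm-ffi container headers; the function and key names here are hypothetical, not taken from this patch:

    #include <tvm/ffi/container/map.h>
    #include <tvm/ffi/string.h>

    using tvm::ffi::Map;
    using tvm::ffi::String;

    // Map::Get returns ffi::Optional<V> instead of throwing on a missing
    // key, so a fallback value can be chained with value_or.
    String LookupBackendFlag(const Map<String, String>& options) {
      // "backend_flag" is a made-up key, for illustration only.
      return options.Get("backend_flag").value_or(String("default"));
    }
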
diff --git a/src/relax/transform/run_codegen.cc b/src/relax/transform/run_codegen.cc index af02225361f3..88389b416ca0 100644 --- a/src/relax/transform/run_codegen.cc +++ b/src/relax/transform/run_codegen.cc @@ -37,12 +37,12 @@ namespace relax { class CodeGenRunner : ExprMutator { public: - using OptionMap = Map; + using OptionMap = ffi::Map; explicit CodeGenRunner(IRModule mod) : ExprMutator(mod) {} - IRModule Run(Optional> target_options, - Array entry_function_names) { + IRModule Run(ffi::Optional> target_options, + ffi::Array entry_function_names) { IRModule mod = builder_->GetContextIRModule(); support::OrderedSet entry_functions; @@ -59,7 +59,8 @@ class CodeGenRunner : ExprMutator { std::vector attr_entry_functions; for (const auto& [gv, func] : mod->functions) { if (func->GetLinkageType() == LinkageType::kExternal && - !func->GetAttr(attr::kCodegen) && func->IsInstance()) { + !func->GetAttr(attr::kCodegen) && + func->IsInstance()) { attr_entry_functions.push_back(gv); } } @@ -80,7 +81,7 @@ class CodeGenRunner : ExprMutator { auto out_mod = builder_->GetContextIRModule(); if (ext_mods.size()) { - if (auto opt_old_ext_mods = mod->GetAttr>(tvm::attr::kExternalMods)) { + if (auto opt_old_ext_mods = mod->GetAttr>(tvm::attr::kExternalMods)) { auto old_ext_mods = opt_old_ext_mods.value(); ext_mods.insert(ext_mods.begin(), old_ext_mods.begin(), old_ext_mods.end()); } @@ -89,7 +90,7 @@ class CodeGenRunner : ExprMutator { if (constant_names.size()) { // Some backends (e.g. TensorRT) expect constants to be passed when they are instantiated - Map constants; + ffi::Map constants; for (const auto& [constant, name] : constant_names) { ICHECK(!constants.count(name)) << "More than one constant with the name " << name; constants.Set(name, constant->data); @@ -106,11 +107,11 @@ class CodeGenRunner : ExprMutator { Expr VisitExpr_(const CallNode* call_node) override { auto call = Downcast(ExprMutator::VisitExpr_(call_node)); if (auto const* gvar_node = call_node->op.as()) { - const GlobalVar gvar = GetRef(gvar_node); + const GlobalVar gvar = ffi::GetRef(gvar_node); auto create_call_dps_packed = [call_node, this](Expr extern_func, StructInfo ret_struct_info) { - Array new_args({extern_func}); + ffi::Array new_args({extern_func}); new_args.push_back(Tuple(call_node->args.Map([this](Expr arg) { return VisitExpr(arg); }))); static const Op& call_op = Op::Get("relax.call_dps_packed"); @@ -139,7 +140,7 @@ class CodeGenRunner : ExprMutator { } } } - Array new_args; + ffi::Array new_args; for (const auto& arg : call_node->args) { new_args.push_back(VisitExpr(arg)); } @@ -148,8 +149,8 @@ class CodeGenRunner : ExprMutator { } Expr VisitExpr_(const FunctionNode* func_node) override { - Function func = GetRef(func_node); - auto opt_codegen = func->GetAttr(attr::kCodegen); + Function func = ffi::GetRef(func_node); + auto opt_codegen = func->GetAttr(attr::kCodegen); if (opt_codegen) { auto ext_symbol = GetExtSymbol(func); size_t count = 0; @@ -168,8 +169,9 @@ class CodeGenRunner : ExprMutator { } private: - Array InvokeCodegen(IRModule mod, Map target_options) { - std::unordered_map> target_functions; + ffi::Array InvokeCodegen(IRModule mod, + ffi::Map target_options) { + std::unordered_map> target_functions; for (const auto& entry : mod->functions) { if (entry.second->IsInstance()) { @@ -178,26 +180,26 @@ class CodeGenRunner : ExprMutator { PostOrderVisit(entry.second, [&target_functions](Expr e) { if (e->IsInstance()) { auto f = Downcast(e); - if (auto target_opt = f->GetAttr(attr::kCodegen)) { - String target = 
target_opt.value(); + if (auto target_opt = f->GetAttr(attr::kCodegen)) { + ffi::String target = target_opt.value(); target_functions[target].push_back(f); } } }); } - Array ext_mods; + ffi::Array ext_mods; for (const auto& [target, functions] : target_functions) { OptionMap options = target_options.Get(target).value_or(OptionMap()); // Start the codegen process. // Get the codegen with its ffi key. - String codegen_name = "relax.ext." + target; + ffi::String codegen_name = "relax.ext." + target; const auto codegen = tvm::ffi::Function::GetGlobal(codegen_name); ICHECK(codegen.has_value()) << "Codegen is not found: " << codegen_name << "\n"; - Array compiled_functions = - (*codegen)(functions, options, constant_names).cast>(); + ffi::Array compiled_functions = + (*codegen)(functions, options, constant_names).cast>(); ext_mods.insert(ext_mods.end(), compiled_functions.begin(), compiled_functions.end()); } @@ -205,7 +207,7 @@ class CodeGenRunner : ExprMutator { } /*! \brief The names of all constants in the original module. */ - Map constant_names; + ffi::Map constant_names; /*! \brief Extern funcs for each global variable. */ std::unordered_map extern_funcs_; }; @@ -213,8 +215,9 @@ class CodeGenRunner : ExprMutator { } // namespace relax namespace transform { -Pass RunCodegen(Optional>> target_options, - Array entry_functions) { +Pass RunCodegen( + ffi::Optional>> target_options, + ffi::Array entry_functions) { auto pass_func = [=](IRModule m, PassContext pc) { return relax::CodeGenRunner(m).Run(target_options, entry_functions); }; diff --git a/src/relax/transform/split_call_tir_by_pattern.cc b/src/relax/transform/split_call_tir_by_pattern.cc index 41528c7d8690..c0dce4db6122 100644 --- a/src/relax/transform/split_call_tir_by_pattern.cc +++ b/src/relax/transform/split_call_tir_by_pattern.cc @@ -50,7 +50,7 @@ using relax::TIRPattern; class ForMatcher : public TensorizeComparator { public: using SymbolMap = std::unordered_map; - explicit ForMatcher(const tir::PrimFunc& pattern, const Array& pattern_vars) + explicit ForMatcher(const tir::PrimFunc& pattern, const ffi::Array& pattern_vars) : TensorizeComparator(IRModule({{GlobalVar(""), pattern}}), false), pattern_(pattern) { for (const auto& pattern_var : pattern_vars) { this->pattern_vars_.insert(pattern_var); @@ -61,7 +61,7 @@ class ForMatcher : public TensorizeComparator { bool Match(const For& top) { const ForNode* pattern_top = pattern_->body.as()->block->body.as(); ICHECK(pattern_top) << "Invalid pattern function"; - if (!VisitStmt(top, GetRef(pattern_top))) { + if (!VisitStmt(top, ffi::GetRef(pattern_top))) { return false; } // Get evaluated symbols, buffers from the pattern. 
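// For reference, the registry round-trip used by InvokeCodegen above follows
// this shape ("relax.ext.example" is a hypothetical key; functions, options,
// and constant_names stand for the locals in that hunk, and the cast target
// is assumed to be the array of compiled runtime modules):
//
//   if (auto codegen = tvm::ffi::Function::GetGlobal("relax.ext.example")) {
//     auto compiled = (*codegen)(functions, options, constant_names)
//                         .cast<tvm::ffi::Array<tvm::runtime::Module>>();
//   }
//
// GetGlobal yields an empty optional when nothing is registered under the
// key, which is what the ICHECK next to it guards against.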
@@ -82,7 +82,7 @@ class ForMatcher : public TensorizeComparator { private: using ExprComparator::VisitExpr_; - Optional QueryEvaluatedSymbols(const Var& var) { + ffi::Optional QueryEvaluatedSymbols(const Var& var) { for (const SymbolMap& symbol_map : evaluated_symbols) { auto it = symbol_map.find(var); if (it != symbol_map.end()) { @@ -94,16 +94,16 @@ class ForMatcher : public TensorizeComparator { bool VisitExpr(const PrimExpr& lhs, const PrimExpr& rhs) final { if (const auto* op = rhs.as()) { - if (pattern_vars_.count(GetRef(op))) { + if (pattern_vars_.count(ffi::GetRef(op))) { // special case for pattern vars const auto* lhs_ptr = lhs.as(); if (lhs_ptr == nullptr) { if (lhs->IsInstance() || lhs->IsInstance()) { - Optional value = QueryEvaluatedSymbols(GetRef(op)); + ffi::Optional value = QueryEvaluatedSymbols(ffi::GetRef(op)); if (value.defined()) { if (!analyzer_.CanProveEqual(lhs, value.value())) return false; } else { - evaluated_symbols.back()[GetRef(op)] = lhs; + evaluated_symbols.back()[ffi::GetRef(op)] = lhs; } return true; } else { @@ -116,7 +116,7 @@ class ForMatcher : public TensorizeComparator { if (const auto* rhs_ptr = rhs.as()) { const auto* operand_a = rhs_ptr->a.as(); const auto* operand_b = rhs_ptr->b.as(); - if (operand_a != nullptr && pattern_vars_.count(GetRef(operand_a))) { + if (operand_a != nullptr && pattern_vars_.count(ffi::GetRef(operand_a))) { // pattern var is on the left evaluated_symbols.push_back(SymbolMap()); bool match = VisitExpr(lhs, rhs_ptr->b); @@ -124,11 +124,12 @@ class ForMatcher : public TensorizeComparator { evaluated_symbols.pop_back(); if (match) { evaluated_symbols.back().insert(symbol_map.begin(), symbol_map.end()); - evaluated_symbols.back()[GetRef(operand_a)] = MakeConstScalar(rhs_ptr->b.dtype(), 1); + evaluated_symbols.back()[ffi::GetRef(operand_a)] = + MakeConstScalar(rhs_ptr->b.dtype(), 1); return true; } } - if (operand_b != nullptr && pattern_vars_.count(GetRef(operand_b))) { + if (operand_b != nullptr && pattern_vars_.count(ffi::GetRef(operand_b))) { // pattern var is on the right evaluated_symbols.push_back(SymbolMap()); bool match = VisitExpr(lhs, rhs_ptr->a); @@ -136,7 +137,8 @@ class ForMatcher : public TensorizeComparator { evaluated_symbols.pop_back(); if (match) { evaluated_symbols.back().insert(symbol_map.begin(), symbol_map.end()); - evaluated_symbols.back()[GetRef(operand_b)] = MakeConstScalar(rhs_ptr->a.dtype(), 1); + evaluated_symbols.back()[ffi::GetRef(operand_b)] = + MakeConstScalar(rhs_ptr->a.dtype(), 1); return true; } } @@ -145,7 +147,7 @@ class ForMatcher : public TensorizeComparator { if (const auto* rhs_ptr = rhs.as()) { const auto* operand_a = rhs_ptr->a.as(); const auto* operand_b = rhs_ptr->b.as(); - if (operand_a != nullptr && pattern_vars_.count(GetRef(operand_a))) { + if (operand_a != nullptr && pattern_vars_.count(ffi::GetRef(operand_a))) { // pattern var is on the left evaluated_symbols.push_back(SymbolMap()); bool match = VisitExpr(lhs, rhs_ptr->b); @@ -153,11 +155,12 @@ class ForMatcher : public TensorizeComparator { evaluated_symbols.pop_back(); if (match) { evaluated_symbols.back().insert(symbol_map.begin(), symbol_map.end()); - evaluated_symbols.back()[GetRef(operand_a)] = MakeConstScalar(rhs_ptr->b.dtype(), 0); + evaluated_symbols.back()[ffi::GetRef(operand_a)] = + MakeConstScalar(rhs_ptr->b.dtype(), 0); return true; } } - if (operand_b != nullptr && pattern_vars_.count(GetRef(operand_b))) { + if (operand_b != nullptr && pattern_vars_.count(ffi::GetRef(operand_b))) { // pattern var is on the right 
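// Worked example for the identity-binding cases above and below: a pattern
// expression n * x (n a pattern variable) may match a bare x by recording
// n = 1, since x * 1 == x; the Add cases likewise record n = 0, since
// x + 0 == x. MakeConstScalar(dtype, 1 or 0) gives the recorded constant the
// dtype of the operand that actually matched.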
evaluated_symbols.push_back(SymbolMap()); bool match = VisitExpr(lhs, rhs_ptr->a); @@ -165,7 +168,8 @@ class ForMatcher : public TensorizeComparator { evaluated_symbols.pop_back(); if (match) { evaluated_symbols.back().insert(symbol_map.begin(), symbol_map.end()); - evaluated_symbols.back()[GetRef(operand_b)] = MakeConstScalar(rhs_ptr->a.dtype(), 0); + evaluated_symbols.back()[ffi::GetRef(operand_b)] = + MakeConstScalar(rhs_ptr->a.dtype(), 0); return true; } } @@ -241,8 +245,8 @@ class ForMatcher : public TensorizeComparator { bool VisitStmt_(const tir::ForNode* op, const Stmt& other) final { const auto* rhs = other.as(); - loop_stack_lhs_.push_back(GetRef(op)); - loop_stack_rhs_.push_back(GetRef(rhs)); + loop_stack_lhs_.push_back(ffi::GetRef(op)); + loop_stack_rhs_.push_back(ffi::GetRef(rhs)); // The body of loop must be loop or BlockRealize if (!op->body->IsInstance() && !op->body->IsInstance()) { return false; @@ -351,7 +355,7 @@ class ForMatcher : public TensorizeComparator { } template - bool CompareArray(const Array& lhs, const Array& rhs, F Self::*cmp) { + bool CompareArray(const ffi::Array& lhs, const ffi::Array& rhs, F Self::*cmp) { if (lhs.same_as(rhs)) return true; if (lhs.size() != rhs.size()) return false; for (size_t i = 0; i < lhs.size(); ++i) { @@ -369,7 +373,7 @@ class ForMatcher : public TensorizeComparator { /*! \brief Analyze the function and match it with a list of patterns */ class TIRPatternMatcher { public: - static Array Match(Array patterns, Stmt body) { + static ffi::Array Match(ffi::Array patterns, Stmt body) { TIRPatternMatcher matcher(patterns); matcher.OpMatternMatch(body); if (matcher.fail_) return {}; @@ -377,13 +381,13 @@ class TIRPatternMatcher { } private: - explicit TIRPatternMatcher(Array patterns) : patterns_(patterns) {} + explicit TIRPatternMatcher(ffi::Array patterns) : patterns_(patterns) {} // Find an op that matches this block bool BlockPatternMatch(const For& top) { for (const TIRPattern& pattern : patterns_) { tir::PrimFunc pattern_func = pattern; - Array pattern_symbolic_vars; + ffi::Array pattern_symbolic_vars; int buffer_count = pattern_func->buffer_map.size(); for (int i = buffer_count; i < static_cast(pattern_func->params.size()); i++) { pattern_symbolic_vars.push_back(pattern_func->params[i]); @@ -391,7 +395,7 @@ class TIRPatternMatcher { ForMatcher block_matcher(pattern_func, pattern_symbolic_vars); if (block_matcher.Match(top)) { // We have found a match - Array symbol_values; + ffi::Array symbol_values; for (int i = buffer_count; i < static_cast(pattern_func->params.size()); i++) { symbol_values.push_back(block_matcher.evaluated_symbols.back()[pattern_func->params[i]]); } @@ -406,7 +410,7 @@ class TIRPatternMatcher { // For each block in the body, try to find its corresponding pattern one by one void OpMatternMatch(const Stmt& body) { - Array blocks; + ffi::Array blocks; if (body->IsInstance()) { // {for} blocks = {body}; @@ -418,7 +422,7 @@ class TIRPatternMatcher { } for (const Stmt& stmt : blocks) { const ForNode* loop = stmt.as(); - if (loop == nullptr || !BlockPatternMatch(GetRef(loop))) { + if (loop == nullptr || !BlockPatternMatch(ffi::GetRef(loop))) { break; } } @@ -429,9 +433,9 @@ class TIRPatternMatcher { /*! \brief Indicate whether we fail to match.*/ bool fail_ = false; /*! \brief The patterns we match the target stmt to.*/ - Array patterns_; + ffi::Array patterns_; /*! \brief The results of the matching process.*/ - Array match_results_; + ffi::Array match_results_; }; /*! 
\brief helper class to partition a function into 2 parts. Return function information which we @@ -444,7 +448,7 @@ class FunctionPartitioner : public StmtExprVisitor { /*! \brief alloc_buffers for the second function */ std::unordered_set allocs2; /*! \brief whether the current block is in the first function */ - Map block_partition; + ffi::Map block_partition; /*! \brief input buffers for the first function */ std::unordered_set input1; /*! \brief input buffers for the second function */ @@ -485,7 +489,7 @@ class FunctionPartitioner : public StmtExprVisitor { input2.insert(write->buffer); } } - block_partition.Set(GetRef(op), Bool(is_matching_)); + block_partition.Set(ffi::GetRef(op), Bool(is_matching_)); } // The number of matched ops in the function size_t num_matched_ops_; @@ -496,7 +500,7 @@ class FunctionPartitioner : public StmtExprVisitor { class BlockRemover : public StmtExprMutator { public: static Stmt RemoveBlockByPartition( - Stmt stmt, const Map& block_partition, + Stmt stmt, const ffi::Map& block_partition, const std::unordered_set& allocs, bool is_library_part) { BlockRemover remover(block_partition, allocs, is_library_part); @@ -504,24 +508,24 @@ class BlockRemover : public StmtExprMutator { } private: - BlockRemover(const Map& block_partition, + BlockRemover(const ffi::Map& block_partition, const std::unordered_set& allocs, bool is_library_part) : block_partition(block_partition), allocs_(allocs), is_library_part_(is_library_part) {} Stmt VisitStmt_(const BlockNode* op) final { Block block = Downcast(StmtExprMutator::VisitStmt_(op)); - ObjectPtr n = make_object(*block.operator->()); + ObjectPtr n = ffi::make_object(*block.operator->()); if (op->name_hint != "root") { - ICHECK(block_partition.count(GetRef(op))); - bool block_is_library = block_partition[GetRef(op)]->value; + ICHECK(block_partition.count(ffi::GetRef(op))); + bool block_is_library = block_partition[ffi::GetRef(op)]->value; if (!(is_library_part_ ^ block_is_library)) { n->body = block->body; } else { erased_ = true; } } - Array alloc_buffers; + ffi::Array alloc_buffers; for (const Buffer& b : block->alloc_buffers) { if (allocs_.count(b)) { alloc_buffers.push_back(b); @@ -532,7 +536,7 @@ class BlockRemover : public StmtExprMutator { } Stmt VisitStmt_(const SeqStmtNode* op) final { - Array seq; + ffi::Array seq; for (const Stmt& s : op->seq) { Stmt new_s = VisitStmt(s); if (erased_) { @@ -545,7 +549,7 @@ class BlockRemover : public StmtExprMutator { } bool erased_ = false; - Map block_partition; + ffi::Map block_partition; std::unordered_set allocs_; bool is_library_part_ = false; }; @@ -560,22 +564,21 @@ class BlockRemover : public StmtExprMutator { * \return A pair of functions, the first one is the library kernel and the second one is the * rest. */ -std::pair> SplitFunctions(PrimFunc func, - std::vector>* arg_partition, - Array patterns, - FCodegen f_codegen) { +std::pair> SplitFunctions( + PrimFunc func, std::vector>* arg_partition, ffi::Array patterns, + FCodegen f_codegen) { // Step 1. Find the library kernel and the rest. 
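// Sketch of how callers consume this helper (SplitMutator further below does
// exactly this), assuming a return type of
// std::pair<PrimFunc, ffi::Optional<PrimFunc>>:
//
//   std::vector<std::vector<int>> arg_partition;
//   auto split = tir::SplitFunctions(func, &arg_partition, patterns, f_codegen);
//   if (!split.second.defined()) {
//     // nothing to offload separately: func as a whole is the library kernel
//   }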
Stmt body = func->body.as()->block->body; - Array match_results = + ffi::Array match_results = TIRPatternMatcher::Match(patterns, func->body.as()->block->body); if (match_results.empty()) { return {func, std::nullopt}; } - Array codegen_result = f_codegen(match_results); + ffi::Array codegen_result = f_codegen(match_results); ICHECK(codegen_result.size() == 3); - String library_code = Downcast(codegen_result[0]); + ffi::String library_code = Downcast(codegen_result[0]); int num_matched_ops = Downcast(codegen_result[1])->value; - Array func1_args = Downcast>(codegen_result[2]); + ffi::Array func1_args = Downcast>(codegen_result[2]); if (num_matched_ops == 0) { return {func, std::nullopt}; } @@ -601,7 +604,7 @@ std::pair> SplitFunctions(PrimFunc func, Stmt body2 = BlockRemover::RemoveBlockByPartition(func->body, partitioner.block_partition, partitioner.allocs2, false); // Step 3. Craft the first function. - Array new_params1; + ffi::Array new_params1; std::vector arg_partition1; ICHECK_LE(func1_args.size(), partitioner.input1.size()); for (const auto& buffer : func1_args) { @@ -616,7 +619,7 @@ std::pair> SplitFunctions(PrimFunc func, } arg_partition->push_back(arg_partition1); new_params1.push_back(Var("output", DataType::Handle())); - Map new_buffer_map1; + ffi::Map new_buffer_map1; for (const auto& kv : func->buffer_map) { if (partitioner.input1.count(kv.second)) { new_buffer_map1.Set(kv.first, kv.second); @@ -626,7 +629,7 @@ std::pair> SplitFunctions(PrimFunc func, PrimFunc func1 = PrimFunc(new_params1, body1, func->ret_type, new_buffer_map1, func->attrs); func1 = WithAttr(func1, kLibraryKernel, library_code); // Step 4. Craft the second function. - Array new_params2; + ffi::Array new_params2; std::vector arg_partition2; new_params2.push_back(Var("input", DataType::Handle())); for (int i = 0; i < static_cast(func->params.size()); i++) { @@ -639,7 +642,7 @@ std::pair> SplitFunctions(PrimFunc func, } } arg_partition->push_back(arg_partition2); - Map new_buffer_map2; + ffi::Map new_buffer_map2; new_buffer_map2.Set(new_params2[0], partitioner.intermediate_buffer); for (const auto& kv : func->buffer_map) { if (partitioner.input2.count(kv.second)) { @@ -659,18 +662,18 @@ void StringReplace(std::string* subject, const std::string& search, const std::s } } -tvm::BaseFunc CodegenWithLibrary(const tir::PrimFuncNode* pf, String global_symbol) { +tvm::BaseFunc CodegenWithLibrary(const tir::PrimFuncNode* pf, ffi::String global_symbol) { using namespace tvm::tir; - Optional library_code = pf->attrs.GetAttr(kLibraryKernel); + ffi::Optional library_code = pf->attrs.GetAttr(kLibraryKernel); if (!library_code.has_value()) { - return GetRef(pf); + return ffi::GetRef(pf); } std::string source = library_code.value(); StringReplace(&source, "{global_symbol}", global_symbol); ExternFunc ret(global_symbol); - ret = WithAttrs(std::move(ret), Map{ - {String(kCSource), String(source)}, - {String(kCSourceFmt), String(kCSourceFmtCuda)}, + ret = WithAttrs(std::move(ret), ffi::Map{ + {ffi::String(kCSource), ffi::String(source)}, + {ffi::String(kCSourceFmt), ffi::String(kCSourceFmtCuda)}, }); return ret; } @@ -678,13 +681,14 @@ tvm::BaseFunc CodegenWithLibrary(const tir::PrimFuncNode* pf, String global_symb /*! \brief Emit 2 calls to the library kernel and the rest of the function. 
*/ class SplitMutator : public ExprMutator { public: - SplitMutator(const tvm::IRModule& mod, Array patterns, FCodegen fcodegen) + SplitMutator(const tvm::IRModule& mod, ffi::Array patterns, FCodegen fcodegen) : ExprMutator(mod), mod_(mod), patterns_(patterns), fcodegen_(fcodegen) {} - static IRModule Transform(const IRModule& mod, Array patterns, FCodegen fcodegen) { + static IRModule Transform(const IRModule& mod, ffi::Array patterns, + FCodegen fcodegen) { SplitMutator mutator(mod, patterns, fcodegen); for (auto& kv : mod->functions) { if (auto* func = kv.second.as()) { - Function new_func = Downcast(mutator(GetRef(func))); + Function new_func = Downcast(mutator(ffi::GetRef(func))); mutator.builder_->UpdateFunction(kv.first, new_func); } } @@ -694,7 +698,7 @@ class SplitMutator : public ExprMutator { private: using ExprMutator::VisitExpr_; - inline Array GetCallTIRArgs(Expr args) { + inline ffi::Array GetCallTIRArgs(Expr args) { if (args.as()) { return args.as()->fields; } else { @@ -710,22 +714,22 @@ class SplitMutator : public ExprMutator { // the first argument is the function to be called const auto* gv_ptr = call->args[0].as(); if (gv_ptr == nullptr) return call; - GlobalVar gv = GetRef(gv_ptr); + GlobalVar gv = ffi::GetRef(gv_ptr); // retrieve the function from the module and split it tir::PrimFunc func = Downcast(mod_->Lookup(gv)); std::vector> arg_partition; // split the function into two functions, one for the library kernel and one for the rest. - std::pair> split_funcs = + std::pair> split_funcs = tir::SplitFunctions(func, &arg_partition, patterns_, fcodegen_); if (!split_funcs.second.defined()) { // no need to split, the function itself is a library kernel tvm::BaseFunc lib_func = CodegenWithLibrary(split_funcs.first.get(), gv->name_hint); - if (lib_func->IsInstance()) return GetRef(op); + if (lib_func->IsInstance()) return ffi::GetRef(op); // Update the function in the module with the library kernel ICHECK(lib_func->IsInstance()); builder_->UpdateFunction(gv, lib_func); // emit the call to the library kernel - ObjectPtr new_call = make_object(*call.operator->()); + ObjectPtr new_call = ffi::make_object(*call.operator->()); new_call->op = this->call_dps_packed_; new_call->args = {lib_func, call->args[1]}; return Call(new_call); @@ -734,13 +738,13 @@ class SplitMutator : public ExprMutator { tir::PrimFunc func2 = tir::RenewDefs(split_funcs.second.value()); ICHECK(arg_partition.size() == 2); // emit the first call to the library kernel - Array args1; + ffi::Array args1; for (int p : arg_partition[0]) { args1.push_back(GetCallTIRArgs(call->args[1])[p]); } // replace the function in the module with the library kernel tvm::BaseFunc lib_func = CodegenWithLibrary(func1.get(), gv->name_hint); - if (lib_func->IsInstance()) return GetRef(op); + if (lib_func->IsInstance()) return ffi::GetRef(op); ICHECK(lib_func->IsInstance()); builder_->UpdateFunction(gv, lib_func); tir::Buffer intermediate_buffer = func1->buffer_map.at(func1->params.back()); @@ -749,7 +753,7 @@ class SplitMutator : public ExprMutator { {TensorStructInfo(ShapeExpr(intermediate_buffer->shape), dtype)}); Var call_var1 = builder_->Emit(call1); // emit the second call to the rest of the function - Array args2; + ffi::Array args2; args2.push_back(call_var1); for (int p : arg_partition[1]) { args2.push_back(GetCallTIRArgs(call->args[1])[p]); } @@ -762,12 +766,12 @@ class SplitMutator : public ExprMutator { const Op& call_dps_packed_ = Op::Get("relax.call_dps_packed"); tvm::IRModule mod_; - Array patterns_; + ffi::Array
patterns_; FCodegen fcodegen_; }; namespace transform { -Pass SplitCallTIRByPattern(Array patterns, FCodegen fcodegen) { +Pass SplitCallTIRByPattern(ffi::Array patterns, FCodegen fcodegen) { auto pass_func = // [=](IRModule m, PassContext pc) { return SplitMutator::Transform(m, patterns, fcodegen); }; return CreateModulePass(/*pass_function=*/pass_func, // diff --git a/src/relax/transform/split_layout_rewrite_preproc.cc b/src/relax/transform/split_layout_rewrite_preproc.cc index 3fa9d52147d3..ccb723a0c163 100644 --- a/src/relax/transform/split_layout_rewrite_preproc.cc +++ b/src/relax/transform/split_layout_rewrite_preproc.cc @@ -35,7 +35,7 @@ namespace tir { class SplitPrimFuncLayoutRewrite : public StmtMutator { public: explicit SplitPrimFuncLayoutRewrite(const PrimFunc& func) : original_func_(func) {} - std::tuple, PrimFunc> Transform(const PrimFunc& func) { + std::tuple, PrimFunc> Transform(const PrimFunc& func) { ICHECK(func->body.as()) << "The body of the primfunc should be a root block."; const auto& block = func->body.as()->block; visit_root_block(block.get()); @@ -58,8 +58,8 @@ class SplitPrimFuncLayoutRewrite : public StmtMutator { ICHECK(rewrite_infos_.size() > 0) << "There should be at least one buffer rewrite."; // Step 2: Create the params for the new PrimFunc - Array params; - Map buffer_map; + ffi::Array params; + ffi::Map buffer_map; for (const auto& info : rewrite_infos_) { params.push_back(Var(info.pre_rewrite_buffer->name, DataType::Handle())); @@ -76,16 +76,16 @@ class SplitPrimFuncLayoutRewrite : public StmtMutator { Stmt body = layout_rewrite_preproc_stmts_.size() == 1 ? layout_rewrite_preproc_stmts_[0] : SeqStmt(layout_rewrite_preproc_stmts_); body = BlockRealize( - /*iter_values=*/Array(), + /*iter_values=*/ffi::Array(), /*predicate=*/const_true(), /*block=*/ Block(/*iter_vars=*/{}, /*reads=*/{}, /*writes=*/{}, /*name_hint=*/"root", body)); - Map dict; + ffi::Map dict; for (const auto& [key, original_value] : original_func_->attrs->dict) { if (key == "global_symbol") { - dict.Set(key, Downcast(original_value) + "_weight_prepack"); + dict.Set(key, Downcast(original_value) + "_weight_prepack"); } else if (key != "layout_free_buffers") { dict.Set(key, original_value); } @@ -98,8 +98,8 @@ class SplitPrimFuncLayoutRewrite : public StmtMutator { PrimFunc create_compute_func() const { // Step 1: Create the params for the new PrimFunc - Array params = original_func_->params; - Map buffer_map = original_func_->buffer_map; + ffi::Array params = original_func_->params; + ffi::Map buffer_map = original_func_->buffer_map; for (const auto& info : rewrite_infos_) { const Var& param = params[info.buffer_index]; ICHECK(buffer_map[param] == info.pre_rewrite_buffer); @@ -109,7 +109,7 @@ class SplitPrimFuncLayoutRewrite : public StmtMutator { // Step 2: Create the body for the new PrimFunc Stmt body = compute_stmts_.size() == 1 ? 
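// Note the symmetry with create_preproc_func above: both halves copy the
// original attrs, rewrite "global_symbol" with a suffix ("_weight_prepack"
// there, "_prepacked" below), and filter out "layout_free_buffers".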
compute_stmts_[0] : SeqStmt(compute_stmts_); Block original_block = original_func_->body.as()->block; - Array alloc_buffers; + ffi::Array alloc_buffers; for (const auto& buffer : original_block->alloc_buffers) { auto it = std::find_if(rewrite_infos_.begin(), rewrite_infos_.end(), @@ -120,7 +120,7 @@ class SplitPrimFuncLayoutRewrite : public StmtMutator { } body = BlockRealize( - /*iter_values=*/Array(), + /*iter_values=*/ffi::Array(), /*predicate=*/const_true(), /*block=*/ Block(/*iter_vars=*/{}, /*reads=*/{}, /*writes=*/{}, @@ -128,10 +128,10 @@ class SplitPrimFuncLayoutRewrite : public StmtMutator { /*init=*/std::nullopt, /*alloc_buffers=*/alloc_buffers)); - Map dict; + ffi::Map dict; for (const auto& [key, original_value] : original_func_->attrs->dict) { if (key == "global_symbol") { - dict.Set(key, Downcast(original_value) + "_prepacked"); + dict.Set(key, Downcast(original_value) + "_prepacked"); } else if (key != "layout_free_buffers") { dict.Set(key, original_value); } @@ -199,7 +199,7 @@ class SplitPrimFuncLayoutRewrite : public StmtMutator { auto new_annotations = op->annotations; new_annotations.erase(attr::meta_schedule_layout_rewrite_preproc); - auto n = make_object(*block.get()); + auto n = ffi::make_object(*block.get()); n->annotations = new_annotations; return Block(n); } @@ -216,9 +216,9 @@ class SplitPrimFuncLayoutRewrite : public StmtMutator { private: /*! \brief The stmts that are used for layout rewrite preproc*/ - Array layout_rewrite_preproc_stmts_; + ffi::Array layout_rewrite_preproc_stmts_; /*! \brief The stmts that are other than layout rewrite preproc*/ - Array compute_stmts_; + ffi::Array compute_stmts_; /*! \brief Whether the current subtree is a layout rewrite preproc subtree. -1: visited a non-layout rewrite preproc block @@ -290,9 +290,9 @@ class SplitLayoutRewritePreproc : public ExprMutator { const auto& rewrite_infos = rewrite_infos_it->second; // Step 5: Emit the preproc call - Array call_tir_args = Downcast(call->args[1])->fields; - Array preproc_args; - Array preproc_sinfo_list; + ffi::Array call_tir_args = Downcast(call->args[1])->fields; + ffi::Array preproc_args; + ffi::Array preproc_sinfo_list; for (const auto& info : rewrite_infos) { preproc_args.push_back(call_tir_args[info.buffer_index]); tir::Buffer rewritten_buffer = info.post_rewrite_buffer; diff --git a/src/relax/transform/static_plan_block_memory.cc b/src/relax/transform/static_plan_block_memory.cc index f2e185ebd2d4..572ea35931d9 100644 --- a/src/relax/transform/static_plan_block_memory.cc +++ b/src/relax/transform/static_plan_block_memory.cc @@ -129,7 +129,7 @@ class StorageTokenNode : public Object { */ class StorageToken : public ObjectRef { public: - explicit StorageToken(Array shape, DataType dtype, std::string storage_scope) { + explicit StorageToken(ffi::Array shape, DataType dtype, std::string storage_scope) { // Compute the tensor size from the shape. int64_t const_coeff = dtype.bytes() * dtype.lanes(); PrimExpr size = tir::make_const(DataType::Int(64), 1); @@ -142,7 +142,7 @@ class StorageToken : public ObjectRef { } size = tir::make_const(DataType::Int(64), const_coeff) * size; - ObjectPtr n = make_object(); + ObjectPtr n = ffi::make_object(); n->bytes = size; n->dtype = dtype; n->storage_scope = std::move(storage_scope); @@ -170,7 +170,7 @@ class TokenAllocator1D { * \return The request result token. Return std::nullopt if there is no * appropriate available token in the pool. 
*/ - Optional RequestReuse(StorageToken prototype) { + ffi::Optional RequestReuse(StorageToken prototype) { // Step 0. Sanity check: the prototype token is supposed not to be allocated with actual storage ICHECK_EQ(prototype->storage_id, -1) << "The token is expected not to be allocated before."; // If the prototype has no reference at all, feel free to allocate new storage. @@ -326,7 +326,7 @@ class StorageAllocatorBaseVisitor : public ExprVisitor { } void VisitExpr_(const TupleNode* tuple) final { - Array tokens; + ffi::Array tokens; tokens.reserve(tuple->fields.size()); for (const Expr& field : tuple->fields) { Tokens field_tokens = GetTokens(field); @@ -343,7 +343,7 @@ class StorageAllocatorBaseVisitor : public ExprVisitor { return; } ICHECK(tokens.IsNested()); - Array field_tokens = tokens.NestedArray(); + ffi::Array field_tokens = tokens.NestedArray(); ICHECK_GT(static_cast(field_tokens.size()), tuple_item->index); ICHECK_GE(tuple_item->index, 0); SetTokens(tuple_item, field_tokens[tuple_item->index]); @@ -372,25 +372,27 @@ class StorageAllocatorBaseVisitor : public ExprVisitor { * \param dom_map The domain map of the TIR variables. */ void SetTIRVarUpperBound(Function func, arith::Analyzer* ana, - Map* dom_map) { + ffi::Map* dom_map) { // Use the attribute-annotated TIR var upper bounds as the TIR var values for // memory planning. // NOTE: we only apply the annotated upper bounds to the TIR variables that // appear in the **function signature**. - Map var_upper_bound_attr_raw = - func->GetAttr>("tir_var_upper_bound").value_or(Map()); - Array non_negative_var_attr_raw = - func->GetAttr>("tir_non_negative_var").value_or(Array()); - std::unordered_map var_upper_bound_attr; - std::unordered_set non_negative_var_attr; + ffi::Map var_upper_bound_attr_raw = + func->GetAttr>("tir_var_upper_bound") + .value_or(ffi::Map()); + ffi::Array non_negative_var_attr_raw = + func->GetAttr>("tir_non_negative_var") + .value_or(ffi::Array()); + std::unordered_map var_upper_bound_attr; + std::unordered_set non_negative_var_attr; // We manually check the value type to ensure the values are all positive IntImm. for (auto [key, value] : var_upper_bound_attr_raw) { var_upper_bound_attr[key] = value; } - for (const String& var_name : non_negative_var_attr_raw) { + for (const ffi::String& var_name : non_negative_var_attr_raw) { non_negative_var_attr.insert(var_name); } - Array var_in_signature = TIRVarsInStructInfo(GetStructInfo(func)); + ffi::Array var_in_signature = TIRVarsInStructInfo(GetStructInfo(func)); for (const tir::Var& tir_var : var_in_signature) { auto it = var_upper_bound_attr.find(tir_var->name_hint); if (it != var_upper_bound_attr.end()) { @@ -414,10 +416,10 @@ void SetTIRVarUpperBound(Function func, arith::Analyzer* ana, * \return The upper-bounded shape. When a dimension's upper bound * cannot be determined, we keep the dimension unchanged. */ -Array GetUpperBoundShape(Array shape, arith::Analyzer* ana, - const Map& dom_map) { +ffi::Array GetUpperBoundShape(ffi::Array shape, arith::Analyzer* ana, + const ffi::Map& dom_map) { // Use the upper bounds of TIR vars as their values. - Array upper_bounded_shape; + ffi::Array upper_bounded_shape; upper_bounded_shape.reserve(shape.size()); for (const PrimExpr& dim_len : shape) { int64_t max_bound = ana->const_int_bound(dim_len)->max_value; @@ -436,7 +438,7 @@ Array GetUpperBoundShape(Array shape, arith::Analyzer* ana, } /*! \brief Check if a shape is static (a.k.a., has no TIR variable). 
*/ -bool IsStaticShape(Array shape) { +bool IsStaticShape(ffi::Array shape) { for (const PrimExpr& dim : shape) { const auto* int_len = dim.as(); if (!int_len) { @@ -471,7 +473,7 @@ class StorageAllocatorInit : public StorageAllocatorBaseVisitor { if (func == nullptr) { continue; } - initializer(GetRef(func)); + initializer(ffi::GetRef(func)); } return initializer.token_map_; } @@ -484,7 +486,7 @@ class StorageAllocatorInit : public StorageAllocatorBaseVisitor { void VisitExpr_(const FunctionNode* func) final { // Set the upper bound of TIR variables in the analyzer. - SetTIRVarUpperBound(GetRef(func), analyzer_, &dom_map_); + SetTIRVarUpperBound(ffi::GetRef(func), analyzer_, &dom_map_); // Recurse into the function to get its tokens. Tokens body_tokens = GetTokens(func->body); // Discard the tokens used by the function return value, as they are external referenced. @@ -513,7 +515,7 @@ class StorageAllocatorInit : public StorageAllocatorBaseVisitor { // potential external reference. if (IsPrimFuncGlobalVar(call->op) || call->op->IsInstance() || call->op == call_tir_dyn_op) { - Array args = + ffi::Array args = call->op == call_tir_dyn_op ? Downcast(call->args[1])->fields : call->args; ICHECK(!block_stack_.empty()); for (const Expr& arg : call->args) { @@ -559,7 +561,7 @@ class StorageAllocatorInit : public StorageAllocatorBaseVisitor { if (global_var == nullptr) { return false; } - auto func_it = ctx_mod_->functions.find(GetRef(global_var)); + auto func_it = ctx_mod_->functions.find(ffi::GetRef(global_var)); if (func_it == ctx_mod_->functions.end()) { return false; } @@ -587,7 +589,8 @@ class StorageAllocatorInit : public StorageAllocatorBaseVisitor { // Use the upper bounds of TIR vars as their values. The upper bound shape can still be dynamic // if the upper bounds of some variables are not provided. - Array upper_bounded_shape = GetUpperBoundShape(shape->values, analyzer_, dom_map_); + ffi::Array upper_bounded_shape = + GetUpperBoundShape(shape->values, analyzer_, dom_map_); // Create and set token. StringImm storage_scope = Downcast(call->args[3]); @@ -664,7 +667,7 @@ class StorageAllocatorInit : public StorageAllocatorBaseVisitor { /*! \brief The arithmetic analyzer. */ arith::Analyzer* analyzer_; /*! \brief The domain map of dynamic TIR variables for analysis. */ - Map dom_map_; + ffi::Map dom_map_; /*! \brief The mapping from each token to the binding block where it is created. */ std::unordered_map token2block_; /*! \brief The mapping from each token to the Exprs that are using this token. */ @@ -780,7 +783,7 @@ class StorageAllocator : public StorageAllocatorBaseVisitor { /*! \brief Request a storage reuse, or allocate storage if no appropriate storage is reusable. 
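 *
 * Callers try RequestReuse first; a std::nullopt result means no pooled
 * token fits, and only then is Alloc invoked with a fresh storage id
 * (see RequestReuseOrAlloc below).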
*/ StorageToken RequestReuseOrAlloc(StorageToken prototype) { - Optional token = allocator_.RequestReuse(prototype); + ffi::Optional token = allocator_.RequestReuse(prototype); if (!token.defined()) { return allocator_.Alloc(prototype, this->n_storage_++); } else { @@ -840,7 +843,7 @@ class StorageAllocationRewriter : public ExprMutator { plan_dynamic_output_ = static_cast( func_->GetAttr(plan_dyn_attr_).value_or(IntImm(DataType::Int(32), 0))->value); if (plan_dynamic_output_) { - SetTIRVarUpperBound(GetRef(func_), &ana_, &dom_map_); + SetTIRVarUpperBound(ffi::GetRef(func_), &ana_, &dom_map_); } token2storage_var_.clear(); Function func = Downcast(this->VisitExpr_(func_)); @@ -903,7 +906,7 @@ class StorageAllocationRewriter : public ExprMutator { ICHECK_NOTNULL(sinfo); const auto* shape = sinfo->shape.as(); ICHECK_NOTNULL(shape); - Array upper_bounded_shape = GetUpperBoundShape(shape->values, &ana_, dom_map_); + ffi::Array upper_bounded_shape = GetUpperBoundShape(shape->values, &ana_, dom_map_); if (!IsStaticShape(shape->values)) { ICHECK(!sinfo->IsUnknownDtype()); ICHECK_EQ(sinfo->dtype, Downcast(call->args[1])->value); @@ -920,7 +923,7 @@ class StorageAllocationRewriter : public ExprMutator { Var storage = builder_->Emit(alloc_storage, "storage"); return Call(mem_alloc_tensor, {storage, // /*offset=*/PrimValue::Int64(0), - /*shape=*/GetRef(shape), // + /*shape=*/ffi::GetRef(shape), // /*dtype=*/DataTypeImm(sinfo->dtype)}); } } @@ -931,7 +934,7 @@ class StorageAllocationRewriter : public ExprMutator { /*! \brief The arithmetic analyzer. */ arith::Analyzer ana_; /*! \brief The domain map of dynamic TIR variables for analysis. */ - Map dom_map_; + ffi::Map dom_map_; /*! \brief A boolean indicating whether to plan dynamic-shape function output tensors. */ bool plan_dynamic_output_; /*! diff --git a/src/relax/transform/to_mixed_precision.cc b/src/relax/transform/to_mixed_precision.cc index 90b343faa628..026e68c3ba6f 100644 --- a/src/relax/transform/to_mixed_precision.cc +++ b/src/relax/transform/to_mixed_precision.cc @@ -44,7 +44,7 @@ int GetMixedPrecisionInfo(const CallNode* call_node) { if (op_node == nullptr) { return -1; } - Op op = GetRef(op_node); + Op op = ffi::GetRef(op_node); auto attr_map = Op::GetAttrMap("TMixedPrecisionPolicy"); return attr_map.count(op) ? 
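// Op::GetAttrMap returns the per-operator table registered under the given
// attr name; count(op) tests whether this op declared a policy at all, and
// operator[] reads it. Ops with no entry fall back to
// MixedPrecisionPolicyKind::kNever below.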
attr_map[op] : MixedPrecisionPolicyKind::kNever; } @@ -146,12 +146,12 @@ class DTypeDecisionCollector : public ExprVisitor { } // merge the message for all vars in the expr list - void RequireArgsToType(Array args, Array to) { + void RequireArgsToType(ffi::Array args, ffi::Array to) { ICHECK(args.size() == to.size()) << "Invalid target dtypes"; for (size_t i = 0; i < args.size(); ++i) { auto fvisitleaf = [&](const Expr& expr, NType to) { if (const auto* var = expr.as()) { - UpdateVarDTypeMap(GetRef(var), to); + UpdateVarDTypeMap(ffi::GetRef(var), to); } else if (expr->IsInstance()) { // Constant can be cast anyway, so we don't need to do anything here return; @@ -164,7 +164,7 @@ class DTypeDecisionCollector : public ExprVisitor { } // merge the message for all vars in the expr list - void RequireArgsToType(Array args, DataType to) { + void RequireArgsToType(ffi::Array args, DataType to) { std::vector arg_arr; std::vector to_arr; for (const Expr& arg : args) { @@ -178,7 +178,7 @@ class DTypeDecisionCollector : public ExprVisitor { } void VisitVars_(const VarNode* op) { - Var var = GetRef(op); + Var var = ffi::GetRef(op); if (IsNestedTensor(var)) { // require the var to be fp32 (its original dtype) UpdateVarDTypeMap(var, NTypeFrom(var, fp32_)); @@ -239,7 +239,7 @@ class DTypeDecisionCollector : public ExprVisitor { } if (auto* sinfo = op->struct_info_.as()) { - this->VisitExprDepStructInfoField(GetRef(sinfo)); + this->VisitExprDepStructInfoField(ffi::GetRef(sinfo)); } } @@ -258,7 +258,7 @@ class DTypeDecisionCollector : public ExprVisitor { this->VisitExpr(op->cond); if (auto* sinfo = op->struct_info_.as()) { - this->VisitExprDepStructInfoField(GetRef(sinfo)); + this->VisitExprDepStructInfoField(ffi::GetRef(sinfo)); } } @@ -301,7 +301,7 @@ class ToMixedPrecisionRewriter : public ExprMutator { } } - Array RemapArgs(const Array& args) { + ffi::Array RemapArgs(const ffi::Array& args) { return args.Map([this](Expr arg) { return VarReplacer::Replace(arg, var_remap_); }); } @@ -317,13 +317,13 @@ class ToMixedPrecisionRewriter : public ExprMutator { // We only rewrite the expr if the dtype is fp16 or fp32, dtypes such as int32 and float64 are not // supported to be rewritten if (tensor->dtype != fp16_ && tensor->dtype != fp32_) return expr; - return astype(expr, DataType(StringToDLDataType(to[0].LeafValue()))); + return astype(expr, DataType(ffi::StringToDLDataType(to[0].LeafValue()))); }; - return TransformTupleLeaf(expr, std::array({to}), fvisitleaf); + return TransformTupleLeaf(expr, std::array({to}), fvisitleaf); } - Array RewriteArgs(const Array& args, DataType to) { - Array new_args; + ffi::Array RewriteArgs(const ffi::Array& args, DataType to) { + ffi::Array new_args; for (const Expr& arg : args) { if (IsNestedTensor(arg)) { new_args.push_back(RewriteExpr(arg, NTypeFrom(arg, to))); @@ -344,7 +344,7 @@ class ToMixedPrecisionRewriter : public ExprMutator { return true; } - bool AllFP16Castable(const Array& args) { + bool AllFP16Castable(const ffi::Array& args) { auto is_fp16 = [](StructInfo sinfo) { if (auto tensor_sinfo = sinfo.as(); tensor_sinfo && tensor_sinfo->dtype == DataType::Float(16)) { @@ -413,11 +413,11 @@ class ToMixedPrecisionRewriter : public ExprMutator { auto it = only_fp16_map_->find(var); if (it == only_fp16_map_->end()) return; // Get the to dtype, cast to fp16 if the var is fp16 only, otherwise do nothing - auto fcombine = [](const String& from, const String& required) -> String { + auto fcombine = [](const ffi::String& from, const ffi::String& required) -> ffi::String {
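// Truth table for fcombine: ("float32", "float16") -> "float16";
// ("float32", "float32") -> "float32"; a "float16" `from` stays "float16"
// either way. In short, a var is only cast down when some consumer
// requires float16.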
return required == "float16" ? required : from; }; NType from = NTypeFrom(cur_var); - NType to = CombineNestedMsg(from, it->second, fcombine); + NType to = CombineNestedMsg(from, it->second, fcombine); Expr rewrite = RewriteExpr(cur_var, to); // If cur_var is not rewritten, we don't need to emit a new var if (!rewrite.same_as(cur_var)) { @@ -439,7 +439,7 @@ class ToMixedPrecisionRewriter : public ExprMutator { if (!builder_->CurrentBlockIsDataFlow()) { return ExprMutator::VisitExpr_(op); } - return VisitVar_(GetRef(op)); + return VisitVar_(ffi::GetRef(op)); } Var VisitVarDef(const Var& var) { return GetRemapped(var); } @@ -464,14 +464,14 @@ class ToMixedPrecisionRewriter : public ExprMutator { // var = Call(op) const auto* op_node = call_node->op.as(); ICHECK(op_node != nullptr); - Op op = GetRef(op_node); + Op op = ffi::GetRef(op_node); if (wrap_param_op.same_as(op)) { // wrap_param ReEmitBinding(binding, call_node->args[0]); return; } - Call new_call = GetRef(call_node); + Call new_call = ffi::GetRef(call_node); // We first to remap the args to the current vars according to the var_remap_ new_call.CopyOnWrite()->args = RemapArgs(new_call->args); @@ -493,7 +493,7 @@ class ToMixedPrecisionRewriter : public ExprMutator { // cast back to the original datatype. if (!new_call->args.same_as(call_node->args)) { - Array new_typed_args; + ffi::Array new_typed_args; for (size_t i = 0; i < call_node->args.size(); i++) { auto arg = new_call->args[i]; auto old_ntype = NTypeFrom(call_node->args[i]); @@ -532,7 +532,7 @@ class ToMixedPrecisionRewriter : public ExprMutator { ExprMutator::VisitBinding_(binding, tuple_node); return; } - ObjectPtr new_tuple = make_object(*tuple_node); + ObjectPtr new_tuple = ffi::make_object(*tuple_node); new_tuple->fields = RemapArgs(tuple_node->fields); new_tuple->struct_info_ = std::nullopt; Expr new_value = builder_->Normalize(Tuple(new_tuple)); @@ -552,7 +552,7 @@ class ToMixedPrecisionRewriter : public ExprMutator { return; } ObjectPtr new_tuple_get_item = - make_object(*tuple_get_item_node); + ffi::make_object(*tuple_get_item_node); new_tuple_get_item->tuple = RemapArgs({tuple_get_item_node->tuple})[0]; new_tuple_get_item->struct_info_ = std::nullopt; Expr new_value = TupleGetItem(new_tuple_get_item); @@ -593,14 +593,14 @@ class ToMixedPrecisionRewriter : public ExprMutator { DataType fp16_ = DataType(DataType::TypeCode::kFloat, 16, 1); DataType fp32_ = DataType(DataType::TypeCode::kFloat, 32, 1); DataType output_dtype_; - Array params_; + ffi::Array params_; std::unordered_set fp16_input_names_; const Op& wrap_param_op = Op::Get("relax.wrap_param"); }; Expr ToMixedPrecision(const Function& f, const DataType& out_dtype, - Optional> fp16_input_names) { + ffi::Optional> fp16_input_names) { VarDTypeMap only_fp16_map = DTypeDecisionCollector::Collect(f, out_dtype); std::unordered_set fp16_input_names_set; if (fp16_input_names) { @@ -612,7 +612,8 @@ Expr ToMixedPrecision(const Function& f, const DataType& out_dtype, namespace transform { -Pass ToMixedPrecision(const DataType& out_dtype, Optional> fp16_input_names) { +Pass ToMixedPrecision(const DataType& out_dtype, + ffi::Optional> fp16_input_names) { auto pass_func = [=](Function f, IRModule m, PassContext pc) { return Downcast(ToMixedPrecision(f, out_dtype, fp16_input_names)); }; diff --git a/src/relax/transform/topological_sort.cc b/src/relax/transform/topological_sort.cc index c9f11b32bee7..7bf2141f75d5 100644 --- a/src/relax/transform/topological_sort.cc +++ b/src/relax/transform/topological_sort.cc @@ -149,7 +149,7 
@@ class BindingOrderCollector : ExprVisitor { } void VisitExpr_(const VarNode* op) override { - Var upstream_requirement = GetRef(op); + Var upstream_requirement = ffi::GetRef(op); auto downstream_user = current_binding_; dependencies_.downstream_users[upstream_requirement].push_back(downstream_user); @@ -167,7 +167,7 @@ class TopologicalSorter : public ExprMutator { Expr VisitExpr_(const FunctionNode* op) override { auto cached = dependencies_; - dependencies_ = BindingOrderCollector::Collect(GetRef(op)); + dependencies_ = BindingOrderCollector::Collect(ffi::GetRef(op)); if (starting_location_ == StartingLocation::FromOutputs) { std::reverse(dependencies_.binding_order.begin(), dependencies_.binding_order.end()); @@ -184,7 +184,7 @@ class TopologicalSorter : public ExprMutator { } BindingBlock VisitBindingBlock_(const DataflowBlockNode* op) override { - auto block = GetRef(op); + auto block = ffi::GetRef(op); // A map from not-yet-defined variables to the binding that will // define the variable. Items are removed from this map as they @@ -309,13 +309,13 @@ class TopologicalSorter : public ExprMutator { << "no bindings should remain to emit. " << "However, bindings " << [&]() { - Array arr; + ffi::Array arr; for (const auto& [var, binding] : to_emit) { arr.push_back(var); } return arr; }() << " still remain after emitting " - << Array(new_bindings.begin(), new_bindings.end()) + << ffi::Array(new_bindings.begin(), new_bindings.end()) .Map([](const Binding& binding) { return binding->var; }); if (starting_location_ == StartingLocation::FromOutputs) { @@ -346,7 +346,8 @@ Pass TopologicalSort(TraversalOrder order, StartingLocation starting_location) { TVM_FFI_STATIC_INIT_BLOCK({ namespace refl = tvm::ffi::reflection; refl::GlobalDef().def( - "relax.transform.TopologicalSort", [](String order_str, String direction_str) -> Pass { + "relax.transform.TopologicalSort", + [](ffi::String order_str, ffi::String direction_str) -> Pass { TraversalOrder order = [&]() { if (order_str == "depth-first") { return TraversalOrder::DepthFirst; diff --git a/src/relax/transform/update_param_struct_info.cc b/src/relax/transform/update_param_struct_info.cc index 85acec6942da..0bf0c6ae6bb6 100644 --- a/src/relax/transform/update_param_struct_info.cc +++ b/src/relax/transform/update_param_struct_info.cc @@ -40,14 +40,14 @@ namespace relax { namespace { class ParamStructInfoMutator : public ExprMutator { public: - explicit ParamStructInfoMutator(ffi::TypedFunction(Var)> sinfo_func) + explicit ParamStructInfoMutator(ffi::TypedFunction(Var)> sinfo_func) : sinfo_func_(sinfo_func) {} using ExprMutator::VisitExpr_; using ExprMutator::VisitVarDef_; Expr VisitExpr_(const FunctionNode* op) override { - auto func = GetRef(op); + auto func = ffi::GetRef(op); auto params = op->params.Map([this](Var param) { if (auto new_sinfo = sinfo_func_(param)) { @@ -65,12 +65,12 @@ class ParamStructInfoMutator : public ExprMutator { return ExprMutator::VisitExpr_(func.get()); } - ffi::TypedFunction(Var)> sinfo_func_; + ffi::TypedFunction(Var)> sinfo_func_; }; } // namespace namespace transform { -Pass UpdateParamStructInfo(ffi::TypedFunction(Var)> sinfo_func) { +Pass UpdateParamStructInfo(ffi::TypedFunction(Var)> sinfo_func) { auto pass_func = [=](IRModule mod, PassContext pc) { ParamStructInfoMutator mutator(sinfo_func); diff --git a/src/relax/transform/update_vdevice.cc b/src/relax/transform/update_vdevice.cc index fc7d8941fe51..77d4f21ee6d3 100644 --- a/src/relax/transform/update_vdevice.cc +++ b/src/relax/transform/update_vdevice.cc @@ 
-35,7 +35,7 @@ class VDeviceMutator : public ExprMutator { public: VDeviceMutator(const IRModule& mod, VDevice new_vdevice, int64_t index) : ExprMutator(mod), mod_(mod), new_vdevice_(new_vdevice) { - Array vdevices = mod->global_infos["vdevice"]; + ffi::Array vdevices = mod->global_infos["vdevice"]; old_vdevice_ = Downcast(vdevices[index]); } @@ -74,7 +74,7 @@ class VDeviceMutator : public ExprMutator { builder_->UpdateFunction(gv, update_func); } } - Array new_vdevices; + ffi::Array new_vdevices; for (auto vdev : mod_->global_infos["vdevice"]) { if (vdev == old_vdevice_) { new_vdevices.push_back(new_vdevice_); diff --git a/src/relax/transform/utils.cc b/src/relax/transform/utils.cc index 19e93bbc0c0e..580b3892e57b 100644 --- a/src/relax/transform/utils.cc +++ b/src/relax/transform/utils.cc @@ -44,15 +44,15 @@ bool IsNestedTensor(const StructInfo& sinfo) { bool IsNestedTensor(const Expr& expr) { return IsNestedTensor(GetStructInfo(expr)); } Function ComposeFunctions(Function func_a, Function func_b) { - Array bindings; + ffi::Array bindings; Var func_a_output("func_a_output", func_a->ret_struct_info); bindings.push_back(VarBinding(func_a_output, func_a->body)); - auto func_a_outputs = [&]() -> Array { + auto func_a_outputs = [&]() -> ffi::Array { if (auto func_a_output_tuple = func_a->ret_struct_info.as()) { - Array outputs; + ffi::Array outputs; for (size_t i = 0; i < func_a_output_tuple->fields.size(); i++) { outputs.push_back(TupleGetItem(func_a_output, i)); } diff --git a/src/relax/transform/utils.h b/src/relax/transform/utils.h index e4fe449ed65e..ff8596cd79e3 100644 --- a/src/relax/transform/utils.h +++ b/src/relax/transform/utils.h @@ -84,8 +84,8 @@ class MemoizedExprTranslator : public ::tvm::relax::ExprFunctor(vn))); - return memo_[GetRef(vn)]; + ICHECK(memo_.count(ffi::GetRef(vn))); + return memo_[ffi::GetRef(vn)]; } virtual OutputType VisitBinding_(const VarBindingNode* binding) { @@ -115,7 +115,7 @@ class MemoizedExprTranslator : public ::tvm::relax::ExprFunctor entry_funcs); +TVM_DLL IRModule DeadCodeElimination(const IRModule& mod, ffi::Array entry_funcs); /*! * \brief Get the external symbol of the Relax function name. @@ -124,7 +124,7 @@ TVM_DLL IRModule DeadCodeElimination(const IRModule& mod, Array entry_fu * \return An external symbol. */ inline std::string GetExtSymbol(const Function& func) { - const auto name_node = func->GetAttr(tvm::attr::kGlobalSymbol); + const auto name_node = func->GetAttr(tvm::attr::kGlobalSymbol); ICHECK(name_node.has_value()) << "Fail to retrieve external symbol."; return std::string(name_node.value()); } @@ -142,7 +142,7 @@ inline std::string GetExtSymbol(const Function& func) { */ IRModule MakeGroupedFunctions( IRModule mod, const std::unordered_map& partition, - bool lift_constants = true, const Array& entry_function_names = {}); + bool lift_constants = true, const ffi::Array& entry_function_names = {}); /*! * \brief Check if the given StructInfo is a scalar tensor. 
The sinfo should be an instance of @@ -172,7 +172,7 @@ bool IsScalarTensor(const Expr& expr); template bool IsNestedTensorConditioned(const StructInfo& sinfo, FType f_condition) { if (const auto* tensor_sinfo = sinfo.as()) { - return f_condition(GetRef(tensor_sinfo)); + return f_condition(ffi::GetRef(tensor_sinfo)); } else if (const auto* tuple_sinfo = sinfo.as()) { return !std::any_of( tuple_sinfo->fields.begin(), tuple_sinfo->fields.end(), @@ -209,7 +209,7 @@ class VarReplacer : public ExprMutator { private: Expr VisitExpr_(const VarNode* op) final { - Var var = GetRef(op); + Var var = ffi::GetRef(op); auto it = var_remap_.find(var->vid); return it == var_remap_.end() ? var : it->second; } @@ -241,19 +241,19 @@ class SymbolicVarRenewMutator : public ExprMutator, tir::ExprMutator { // 1. Visit and replace all tir::Vars at the definition point // 2. Revisit the function again and update the use side. PrimExpr VisitExpr_(const tir::VarNode* op) final { - auto it = var_map_.find(GetRef(op)); + auto it = var_map_.find(ffi::GetRef(op)); if (it != var_map_.end()) { return (*it).second; } else { - auto n = make_object(*op); + auto n = ffi::make_object(*op); tir::Var v(n); - var_map_.Set(GetRef(op), v); + var_map_.Set(ffi::GetRef(op), v); return v; } } Expr VisitExpr_(const FunctionNode* op) { - tvm::Array params; + tvm::ffi::Array params; bool all_params_unchanged = true; for (Var param : op->params) { Var new_param = this->VisitVarDef(param); @@ -267,14 +267,14 @@ class SymbolicVarRenewMutator : public ExprMutator, tir::ExprMutator { Expr body = this->VisitWithNewScope(op->body, params); if (all_params_unchanged && body.same_as(op->body)) { - return GetRef(op); + return ffi::GetRef(op); } else { auto new_ret_sinfo = this->VisitExprDepStructInfoField(op->ret_struct_info); return Function(params, body, new_ret_sinfo, op->is_pure, op->attrs); } } - Map var_map_; + ffi::Map var_map_; }; /*! @@ -286,7 +286,7 @@ class FunctionCopier : public SymbolicVarRenewMutator { public: FunctionCopier() = default; Function Copy(Function func) { return Downcast(VisitExpr(func)); } - Map GetVarMap() { return relax_var_map_; } + ffi::Map GetVarMap() { return relax_var_map_; } private: using relax::ExprMutator::VisitExpr; @@ -295,7 +295,7 @@ class FunctionCopier : public SymbolicVarRenewMutator { Var new_var = SymbolicVarRenewMutator::VisitVarDef_(var); Var copied_var = DataflowVar(new_var->name_hint(), GetStructInfo(new_var), new_var->span); var_remap_[var->vid] = copied_var; - relax_var_map_.Set(GetRef(var), copied_var); + relax_var_map_.Set(ffi::GetRef(var), copied_var); return copied_var; } @@ -303,11 +303,11 @@ class FunctionCopier : public SymbolicVarRenewMutator { Var new_var = SymbolicVarRenewMutator::VisitVarDef_(var); Var copied_var = Var(new_var->name_hint(), GetStructInfo(new_var), new_var->span); var_remap_[var->vid] = copied_var; - relax_var_map_.Set(GetRef(var), copied_var); + relax_var_map_.Set(ffi::GetRef(var), copied_var); return copied_var; } - Map relax_var_map_; + ffi::Map relax_var_map_; }; /*! 
@@ -360,7 +360,7 @@ inline Constant MakeConstantScalar(T value, DataType dtype) { return Constant(arr); } -inline Array GetOrderedPositiveAxes(const Array& axes, int ndim) { +inline ffi::Array GetOrderedPositiveAxes(const ffi::Array& axes, int ndim) { std::vector ret; ret.reserve(axes.size()); for (const auto& axis : axes) { @@ -376,7 +376,7 @@ inline Array GetOrderedPositiveAxes(const Array& axes, int ndi return support::AsArray(ret); } -inline String GetCodegenName(const std::string& composite_name) { +inline ffi::String GetCodegenName(const std::string& composite_name) { auto delim_pos = composite_name.find("."); ICHECK(delim_pos != std::string::npos) << "The pattern name for a composite function should " "start with a compiler name followed by period."; @@ -384,7 +384,7 @@ inline String GetCodegenName(const std::string& composite_name) { } inline int GetDeviceIndex(const IRModule& mod, const VDevice& vdevice) { - Array vdevices = mod->global_infos["vdevice"]; + ffi::Array vdevices = mod->global_infos["vdevice"]; for (int i = 0; i < static_cast(vdevices.size()); ++i) { if (vdevices[i] == vdevice) { return i; @@ -434,7 +434,8 @@ Expr CanonicalizeBindings(Expr expr); * * \ret The updated function. */ -Function BundleModelParams(const Function& func, Optional param_tuple_name = std::nullopt); +Function BundleModelParams(const Function& func, + ffi::Optional param_tuple_name = std::nullopt); /*! \brief Compose two functions * diff --git a/src/relax/utils.cc b/src/relax/utils.cc index 92747a2515d5..d594ce90b499 100644 --- a/src/relax/utils.cc +++ b/src/relax/utils.cc @@ -31,15 +31,15 @@ namespace relax { /*! \brief Helper to implement bind params.*/ class ExprBinder : public ExprMutator { public: - explicit ExprBinder(const tvm::Map& args_map, - const tvm::Map& symbolic_var_map) + explicit ExprBinder(const tvm::ffi::Map& args_map, + const tvm::ffi::Map& symbolic_var_map) : args_map_(args_map), symbolic_var_map_(symbolic_var_map) {} private: using ExprMutator::VisitExpr_; Expr VisitExpr_(const FunctionNode* op) final { - tvm::Array params; + tvm::ffi::Array params; bool all_params_unchanged = true; for (const Var& param : op->params) { if (args_map_.count(param)) { @@ -58,7 +58,7 @@ class ExprBinder : public ExprMutator { // FuncStructInfo does not depend on Expr if (all_params_unchanged && body.same_as(op->body)) { - return GetRef(op); + return ffi::GetRef(op); } else { // purity won't be affected, no need to update annotation return Function(params, body, VisitExprDepStructInfoField(op->ret_struct_info), op->is_pure, @@ -67,7 +67,7 @@ class ExprBinder : public ExprMutator { } Expr VisitExpr_(const VarNode* op) final { - auto id = GetRef(op); + auto id = ffi::GetRef(op); auto it = args_map_.find(id); if (it != args_map_.end()) { return (*it).second; @@ -86,8 +86,8 @@ class ExprBinder : public ExprMutator { } private: - const tvm::Map& args_map_; - const tvm::Map& symbolic_var_map_; + const tvm::ffi::Map& args_map_; + const tvm::ffi::Map& symbolic_var_map_; }; /*! 
@@ -97,18 +97,19 @@ class ExprBinder : public ExprMutator { * \param symbolic_var_map The map from symbolic var to the expr it binds to * \return The result expr after bind params */ -Expr Bind(const Expr& expr, const tvm::Map& binds, - const tvm::Map& symbolic_var_map) { +Expr Bind(const Expr& expr, const tvm::ffi::Map& binds, + const tvm::ffi::Map& symbolic_var_map) { return ExprBinder(binds, symbolic_var_map).VisitExpr(expr); } -StructInfo Bind(const StructInfo& sinfo, const tvm::Map& symbolic_var_map) { +StructInfo Bind(const StructInfo& sinfo, + const tvm::ffi::Map& symbolic_var_map) { return ExprBinder({}, symbolic_var_map).VisitExprDepStructInfoField(sinfo); } -tvm::Map InferSymbolicVarMap( - const tvm::Map& relax_var_remap, arith::Analyzer* analyzer) { - tvm::Map tir_var_remap; +tvm::ffi::Map InferSymbolicVarMap( + const tvm::ffi::Map& relax_var_remap, arith::Analyzer* analyzer) { + tvm::ffi::Map tir_var_remap; auto bind_from_prim_expr = [&tir_var_remap](const PrimExpr& var_shape, const PrimExpr& expr_shape) { @@ -218,7 +219,7 @@ bool IsLeafOrTuple(const Expr& expr) { bool IsImpureCall(const Call& call) { if (auto op_ptr = call->op.as()) { - auto op = GetRef(op_ptr); + auto op = ffi::GetRef(op_ptr); static auto purity_map = Op::GetAttrMap("FPurity"); ICHECK(purity_map.count(op)) << "Cannot find the registered purity of this op: " << op->name; return !(purity_map[op]->value); diff --git a/src/runtime/const_loader_module.cc b/src/runtime/const_loader_module.cc index 6f07e10f62d7..c4604348ba01 100644 --- a/src/runtime/const_loader_module.cc +++ b/src/runtime/const_loader_module.cc @@ -67,7 +67,7 @@ class ConstLoaderModuleObj : public ffi::ModuleObj { } } - ffi::Optional GetFunction(const String& name) final { + ffi::Optional GetFunction(const ffi::String& name) final { VLOG(1) << "ConstLoaderModuleNode::GetFunction(" << name << ")"; // Initialize and memoize the module. // Usually, we have some warmup runs. The module initialization should be @@ -80,7 +80,7 @@ class ConstLoaderModuleObj : public ffi::ModuleObj { if (name == "get_const_var_tensor") { return ffi::Function([_self, this](ffi::PackedArgs args, ffi::Any* rv) { - Map ret_map; + ffi::Map ret_map; for (const auto& kv : const_var_tensor_) { ret_map.Set(kv.first, kv.second); } @@ -109,8 +109,8 @@ class ConstLoaderModuleObj : public ffi::ModuleObj { * \param symbol The symbol that is being queried. * \return The list of needed Tensor. */ - Array GetRequiredConstants(const std::string& symbol) { - Array ret; + ffi::Array GetRequiredConstants(const std::string& symbol) { + ffi::Array ret; ICHECK_GT(const_vars_by_symbol_.count(symbol), 0U) << "No constants known for function '" << symbol << "'"; std::vector vars = const_vars_by_symbol_[symbol]; @@ -139,7 +139,7 @@ class ConstLoaderModuleObj : public ffi::ModuleObj { for (const Any& it : this->imports_) { // Get the initialization function from the imported modules. std::string init_name = "__init_" + symbol; - Optional init = it.cast()->GetFunction(init_name, false); + ffi::Optional init = it.cast()->GetFunction(init_name, false); if (init.has_value()) { auto md = GetRequiredConstants(symbol); // Initialize the module with constants. 
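// Sketch of this handshake with a hypothetical symbol "my_kernel": the
// const-loader asks each imported module for a packed function named
// "__init_my_kernel" and, when it exists, invokes it with the ffi::Array of
// Tensors gathered by GetRequiredConstants("my_kernel").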
@@ -159,7 +159,7 @@ class ConstLoaderModuleObj : public ffi::ModuleObj { std::vector variables; std::vector const_var_tensor; for (const auto& it : const_var_tensor_) { - String var_name = it.first; + ffi::String var_name = it.first; variables.push_back(var_name); const_var_tensor.push_back(it.second); } @@ -232,7 +232,7 @@ class ConstLoaderModuleObj : public ffi::ModuleObj { const_vars_by_symbol[symbols[i]] = const_vars[i]; } - auto n = make_object(const_var_tensor, const_vars_by_symbol); + auto n = ffi::make_object(const_var_tensor, const_vars_by_symbol); return ffi::Module(n); } @@ -251,7 +251,7 @@ class ConstLoaderModuleObj : public ffi::ModuleObj { ffi::Module ConstLoaderModuleCreate( const std::unordered_map& const_var_tensor, const std::unordered_map>& const_vars_by_symbol) { - auto n = make_object(const_var_tensor, const_vars_by_symbol); + auto n = ffi::make_object(const_var_tensor, const_vars_by_symbol); return ffi::Module(n); } diff --git a/src/runtime/contrib/arm_compute_lib/acl_runtime.cc b/src/runtime/contrib/arm_compute_lib/acl_runtime.cc index 92e4bd06e254..5cd6a1746647 100644 --- a/src/runtime/contrib/arm_compute_lib/acl_runtime.cc +++ b/src/runtime/contrib/arm_compute_lib/acl_runtime.cc @@ -61,7 +61,7 @@ class ACLRuntime : public JSONRuntimeBase { * \param const_names The names of each constant in the sub-graph. */ explicit ACLRuntime(const std::string& symbol_name, const std::string& graph_json, - const Array& const_names) + const ffi::Array& const_names) : JSONRuntimeBase(symbol_name, graph_json, const_names) {} /*! @@ -77,7 +77,7 @@ class ACLRuntime : public JSONRuntimeBase { * * \param consts The constant params from compiled model. */ - void Init(const Array& consts) override { + void Init(const ffi::Array& consts) override { ICHECK_EQ(consts.size(), const_idx_.size()) << "The number of input constants must match the number of required."; SetupConstants(consts); @@ -588,9 +588,9 @@ class ACLRuntime : public JSONRuntimeBase { } #endif }; -ffi::Module ACLRuntimeCreate(const String& symbol_name, const String& graph_json, - const Array& const_names) { - auto n = make_object(symbol_name, graph_json, const_names); +ffi::Module ACLRuntimeCreate(const ffi::String& symbol_name, const ffi::String& graph_json, + const ffi::Array& const_names) { + auto n = ffi::make_object(symbol_name, graph_json, const_names); return ffi::Module(n); } diff --git a/src/runtime/contrib/bnns/bnns_json_runtime.cc b/src/runtime/contrib/bnns/bnns_json_runtime.cc index 0386bde3783b..735a5eff7bd2 100644 --- a/src/runtime/contrib/bnns/bnns_json_runtime.cc +++ b/src/runtime/contrib/bnns/bnns_json_runtime.cc @@ -88,12 +88,12 @@ ThreadingConfig getDefaultThreadingConfig() { class BNNSJSONRuntime : public JSONRuntimeBase { public: BNNSJSONRuntime(const std::string& symbol_name, const std::string& graph_json, - const Array const_names) + const ffi::Array const_names) : JSONRuntimeBase(symbol_name, graph_json, const_names) {} const char* kind() const override { return "bnns_json"; } - void Init(const Array& consts) override { + void Init(const ffi::Array& consts) override { ICHECK_EQ(consts.size(), const_idx_.size()) << "The number of input constants must match the number of required."; @@ -557,9 +557,9 @@ class BNNSJSONRuntime : public JSONRuntimeBase { std::vector tensors_eid_; }; -ffi::Module BNNSJSONRuntimeCreate(String symbol_name, String graph_json, - const Array& const_names) { - auto n = make_object(symbol_name, graph_json, const_names); +ffi::Module BNNSJSONRuntimeCreate(ffi::String symbol_name, 
ffi::String graph_json, + const ffi::Array& const_names) { + auto n = ffi::make_object(symbol_name, graph_json, const_names); return ffi::Module(n); } diff --git a/src/runtime/contrib/clml/clml_runtime.cc b/src/runtime/contrib/clml/clml_runtime.cc index 39e38aa8725d..62ba4846f6d1 100644 --- a/src/runtime/contrib/clml/clml_runtime.cc +++ b/src/runtime/contrib/clml/clml_runtime.cc @@ -149,7 +149,7 @@ class CLMLRuntime : public JSONRuntimeBase { * \param const_names The names of each constant in the sub-graph. */ explicit CLMLRuntime(const std::string& symbol_name, const std::string& graph_json, - const Array& const_names) + const ffi::Array& const_names) : JSONRuntimeBase(symbol_name, graph_json, const_names), clml_symbol(symbol_name) {} ~CLMLRuntime() { @@ -201,7 +201,7 @@ class CLMLRuntime : public JSONRuntimeBase { * * \param consts The constant params from compiled model. */ - void Init(const Array& consts) override { + void Init(const ffi::Array& consts) override { ICHECK_EQ(consts.size(), const_idx_.size()) << "The number of input constants must match the number of required."; SetupConstants(consts); @@ -270,7 +270,7 @@ class CLMLRuntime : public JSONRuntimeBase { "same by exporting CLML_DISABLE_RECORDABLE_QUEUE at runtime."; } cl_command_queue queue = CLML_QUEUE; - Map dump_tensors; + ffi::Map dump_tensors; std::ostringstream os; dmlc::JSONWriter writer(&os); writer.BeginObject(); @@ -354,7 +354,7 @@ class CLMLRuntime : public JSONRuntimeBase { std::vector shape = nodes_[nid].GetOpShape()[0]; DLDataType tvm_dtype = nodes_[nid].GetOpDataType()[0]; shape_str.append(profiling::ShapeString(shape, tvm_dtype)); - metrics["Argument Shapes"] = String(shape_str); + metrics["Argument Shapes"] = ffi::String(shape_str); prof->StartCall("CopyIn", cws->tentry->device, metrics); CLML_CALL(clEnqueueCopyMLTensorDataQCOM, queue, layer_.in_placeholder[nid]->tensor, @@ -380,7 +380,7 @@ class CLMLRuntime : public JSONRuntimeBase { std::vector shape = node.GetOpShape()[0]; DLDataType tvm_dtype = node.GetOpDataType()[0]; shape_str.append(profiling::ShapeString(shape, tvm_dtype)); - metrics["Argument Shapes"] = String(shape_str); + metrics["Argument Shapes"] = ffi::String(shape_str); // Launch call prof->StartCall(clml_symbol + "-" + this->layer_.layer_names[i], cws->tentry->device, @@ -412,7 +412,7 @@ class CLMLRuntime : public JSONRuntimeBase { std::vector shape = nodes_[eid].GetOpShape()[0]; DLDataType tvm_dtype = nodes_[eid].GetOpDataType()[0]; shape_str.append(profiling::ShapeString(shape, tvm_dtype)); - metrics["Argument Shapes"] = String(shape_str); + metrics["Argument Shapes"] = ffi::String(shape_str); prof->StartCall("CopyOut", cws->tentry->device, metrics); CLML_CALL(clEnqueueCopyMLTensorDataQCOM, queue, layer_.outputs[i]->tensor, @@ -1826,9 +1826,9 @@ class CLMLRuntime : public JSONRuntimeBase { std::string clml_symbol; }; -ffi::Module CLMLRuntimeCreate(const String& symbol_name, const String& graph_json, - const Array& const_names) { - auto n = make_object(symbol_name, graph_json, const_names); +ffi::Module CLMLRuntimeCreate(const ffi::String& symbol_name, const ffi::String& graph_json, + const ffi::Array& const_names) { + auto n = ffi::make_object(symbol_name, graph_json, const_names); return ffi::Module(n); } diff --git a/src/runtime/contrib/coreml/coreml_runtime.h b/src/runtime/contrib/coreml/coreml_runtime.h index 3f7db78bfc31..9aa8cf839e4c 100644 --- a/src/runtime/contrib/coreml/coreml_runtime.h +++ b/src/runtime/contrib/coreml/coreml_runtime.h @@ -104,7 +104,7 @@ class CoreMLRuntime : 
public ffi::ModuleObj { * \param sptr_to_self The pointer to the module node. * \return The corresponding member function. */ - virtual Optional GetFunction(const String& name); + virtual ffi::Optional GetFunction(const ffi::String& name); /*! \brief Get the property of the runtime module. */ int GetPropertyMask() const final { diff --git a/src/runtime/contrib/coreml/coreml_runtime.mm b/src/runtime/contrib/coreml/coreml_runtime.mm index 5926fb32d62c..e0c1653077a8 100644 --- a/src/runtime/contrib/coreml/coreml_runtime.mm +++ b/src/runtime/contrib/coreml/coreml_runtime.mm @@ -129,7 +129,7 @@ model_ = std::unique_ptr(new CoreMLModel(url)); } -Optional CoreMLRuntime::GetFunction(const String& name) { +ffi::Optional CoreMLRuntime::GetFunction(const ffi::String& name) { // Return member functions during query. if (name == "invoke" || name == "run") { return ffi::Function([this](ffi::PackedArgs args, ffi::Any* rv) { model_->Invoke(); }); @@ -153,7 +153,7 @@ NSDictionary* json = [NSJSONSerialization JSONObjectWithData:data options:NSJSONReadingAllowFragments error:nil]; NSArray* input_names = json[@"inputs"]; // Copy input tensors to corresponding data entries. for (auto i = 0; i < args.size() - 1; ++i) { @@ -186,7 +186,7 @@ } ffi::Module CoreMLRuntimeCreate(const std::string& symbol, const std::string& model_path) { - auto exec = make_object(); + auto exec = ffi::make_object(); exec->Init(symbol, model_path); return ffi::Module(exec); } @@ -250,7 +250,7 @@ BOOL res = [dirWrapper writeToURL:url options:0 originalContentsURL:nil error:nil]; ICHECK(res) << "Failed to create model directory " << [model_path UTF8String]; - auto exec = make_object(); + auto exec = ffi::make_object(); exec->Init(symbol, [model_path UTF8String]); return ffi::Module(exec); } diff --git a/src/runtime/contrib/cublas/cublas_json_runtime.cc b/src/runtime/contrib/cublas/cublas_json_runtime.cc index 99eda5cc89f8..98b05ba31995 100644 --- a/src/runtime/contrib/cublas/cublas_json_runtime.cc +++ b/src/runtime/contrib/cublas/cublas_json_runtime.cc @@ -46,12 +46,12 @@ using namespace tvm::runtime::json; class CublasJSONRuntime : public JSONRuntimeBase { public: CublasJSONRuntime(const std::string& symbol_name, const std::string& graph_json, - const Array const_names) + const ffi::Array const_names) : JSONRuntimeBase(symbol_name, graph_json, const_names) {} - void Init(const Array& consts) override {} + void Init(const ffi::Array& consts) override {} - ffi::Optional GetFunction(const String& name) override { + ffi::Optional GetFunction(const ffi::String& name) override { // JSONRuntimeBase::SetInputOutputBuffers(...) is not thread safe. Since CublasJSONRuntime // can be used by multiple GPUs running on different threads, we avoid using that function // and directly call cuBLAS on the inputs from ffi::PackedArgs.
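The GetFunction override above is the recurring shape of this migration: module entry points now take ffi::String and return ffi::Optional<ffi::Function>, with std::nullopt meaning "no such function" instead of a null PackedFunc. A minimal sketch of a backend written against the migrated JSONRuntimeBase interface follows; the class name, kind string, and the spelled-out element types (ffi::Array<ffi::String>, ffi::Array<Tensor>) are illustrative reconstructions, not text from this patch:

    // Sketch only: a hypothetical JSON runtime on the migrated interface.
    class MyJSONRuntime : public JSONRuntimeBase {
     public:
      MyJSONRuntime(const std::string& symbol_name, const std::string& graph_json,
                    const ffi::Array<ffi::String>& const_names)
          : JSONRuntimeBase(symbol_name, graph_json, const_names) {}

      const char* kind() const override { return "my_json"; }

      // Invoked once via the "__init_<symbol>" entry with the constant pool.
      void Init(const ffi::Array<Tensor>& consts) override { SetupConstants(consts); }

      void Run() override { /* engine-specific execution */ }

      ffi::Optional<ffi::Function> GetFunction(const ffi::String& name) override {
        if (name == "run") {
          // Capture an ObjectPtr so the module outlives the returned closure.
          ObjectPtr<Object> self = ffi::GetObjectPtr<Object>(this);
          return ffi::Function([self, this](ffi::PackedArgs, ffi::Any*) { this->Run(); });
        }
        // Fall back to the base lookup; std::nullopt signals "not found".
        return JSONRuntimeBase::GetFunction(name);
      }
    };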
@@ -153,9 +153,9 @@ class CublasJSONRuntime : public JSONRuntimeBase { void Run() override { LOG(FATAL) << "Unreachable"; } }; -ffi::Module CublasJSONRuntimeCreate(String symbol_name, String graph_json, - const Array& const_names) { - auto n = make_object(symbol_name, graph_json, const_names); +ffi::Module CublasJSONRuntimeCreate(ffi::String symbol_name, ffi::String graph_json, + const ffi::Array& const_names) { + auto n = ffi::make_object(symbol_name, graph_json, const_names); return ffi::Module(n); } diff --git a/src/runtime/contrib/cudnn/cudnn_frontend/attention.h b/src/runtime/contrib/cudnn/cudnn_frontend/attention.h index ae11764ce02c..077ab57966a5 100644 --- a/src/runtime/contrib/cudnn/cudnn_frontend/attention.h +++ b/src/runtime/contrib/cudnn/cudnn_frontend/attention.h @@ -69,7 +69,7 @@ class CuDNNSDPARunnerNode : public tvm::runtime::Object { class CuDNNSDPARunner : public tvm::runtime::ObjectRef { public: static CuDNNSDPARunner Create() { - auto n = make_object(); + auto n = ffi::make_object(); return CuDNNSDPARunner(n); } diff --git a/src/runtime/contrib/cudnn/cudnn_json_runtime.cc b/src/runtime/contrib/cudnn/cudnn_json_runtime.cc index 1e17cf2ecfd4..fa046980e39a 100644 --- a/src/runtime/contrib/cudnn/cudnn_json_runtime.cc +++ b/src/runtime/contrib/cudnn/cudnn_json_runtime.cc @@ -49,10 +49,10 @@ using namespace tvm::runtime::json; class cuDNNJSONRuntime : public JSONRuntimeBase { public: cuDNNJSONRuntime(const std::string& symbol_name, const std::string& graph_json, - const Array const_names) + const ffi::Array const_names) : JSONRuntimeBase(symbol_name, graph_json, const_names) {} - void Init(const Array& consts) override { + void Init(const ffi::Array& consts) override { op_execs_.resize(nodes_.size()); // get some config from the graph for (size_t i = 0; i < nodes_.size(); ++i) { @@ -238,9 +238,9 @@ class cuDNNJSONRuntime : public JSONRuntimeBase { std::vector> op_execs_; }; -ffi::Module cuDNNJSONRuntimeCreate(String symbol_name, String graph_json, - const Array& const_names) { - auto n = make_object(symbol_name, graph_json, const_names); +ffi::Module cuDNNJSONRuntimeCreate(ffi::String symbol_name, ffi::String graph_json, + const ffi::Array& const_names) { + auto n = ffi::make_object(symbol_name, graph_json, const_names); return ffi::Module(n); } diff --git a/src/runtime/contrib/dnnl/dnnl_json_runtime.cc b/src/runtime/contrib/dnnl/dnnl_json_runtime.cc index eccfb913d177..3b9304f11c61 100644 --- a/src/runtime/contrib/dnnl/dnnl_json_runtime.cc +++ b/src/runtime/contrib/dnnl/dnnl_json_runtime.cc @@ -51,7 +51,7 @@ using namespace tvm::runtime::json; class DNNLJSONRuntime : public JSONRuntimeBase { public: DNNLJSONRuntime(const std::string& symbol_name, const std::string& graph_json, - const Array const_names) + const ffi::Array const_names) : JSONRuntimeBase(symbol_name, graph_json, const_names), next_unique_eid_offset_(data_entry_.size()), run_arg_eid_(input_var_eid_) { @@ -60,7 +60,7 @@ class DNNLJSONRuntime : public JSONRuntimeBase { const char* kind() const override { return "dnnl_json"; } - void Init(const Array& consts) override { + void Init(const ffi::Array& consts) override { ICHECK_EQ(consts.size(), const_idx_.size()) << "The number of input constants must match the number of required."; @@ -100,7 +100,7 @@ class DNNLJSONRuntime : public JSONRuntimeBase { } /* Override GetFunction to reimplement Run method */ - ffi::Optional GetFunction(const String& name) override { + ffi::Optional GetFunction(const ffi::String& name) override { ObjectPtr sptr_to_self = 
ffi::GetObjectPtr(this); if (this->symbol_name_ == name) { return ffi::Function([sptr_to_self, this](ffi::PackedArgs args, ffi::Any* rv) { @@ -923,9 +923,9 @@ class DNNLJSONRuntime : public JSONRuntimeBase { std::vector run_arg_eid_; }; -ffi::Module DNNLJSONRuntimeCreate(String symbol_name, String graph_json, - const Array& const_names) { - auto n = make_object(symbol_name, graph_json, const_names); +ffi::Module DNNLJSONRuntimeCreate(ffi::String symbol_name, ffi::String graph_json, + const ffi::Array& const_names) { + auto n = ffi::make_object(symbol_name, graph_json, const_names); return ffi::Module(n); } diff --git a/src/runtime/contrib/edgetpu/edgetpu_runtime.cc b/src/runtime/contrib/edgetpu/edgetpu_runtime.cc index a52da2318b71..34d335c0e900 100644 --- a/src/runtime/contrib/edgetpu/edgetpu_runtime.cc +++ b/src/runtime/contrib/edgetpu/edgetpu_runtime.cc @@ -64,7 +64,7 @@ void EdgeTPURuntime::Init(const std::string& tflite_model_bytes, Device dev) { } ffi::Module EdgeTPURuntimeCreate(const std::string& tflite_model_bytes, Device dev) { - auto exec = make_object(); + auto exec = ffi::make_object(); exec->Init(tflite_model_bytes, dev); return ffi::Module(exec); } diff --git a/src/runtime/contrib/hipblas/hipblas_json_runtime.cc b/src/runtime/contrib/hipblas/hipblas_json_runtime.cc index 046c1c14b30b..6e760b7f0625 100644 --- a/src/runtime/contrib/hipblas/hipblas_json_runtime.cc +++ b/src/runtime/contrib/hipblas/hipblas_json_runtime.cc @@ -44,12 +44,12 @@ using namespace tvm::runtime::json; class HipblasJSONRuntime : public JSONRuntimeBase { public: HipblasJSONRuntime(const std::string& symbol_name, const std::string& graph_json, - const Array const_names) + const ffi::Array const_names) : JSONRuntimeBase(symbol_name, graph_json, const_names) {} - void Init(const Array& consts) override {} + void Init(const ffi::Array& consts) override {} - ffi::Optional GetFunction(const String& name) override { + ffi::Optional GetFunction(const ffi::String& name) override { // JSONRuntimeBase::SetInputOutputBuffers(...) is not thread safe. Since HipblasJSONRuntime // can be used by multiple GPUs running on different threads, we avoid using that function // and directly call hipBLAS on the inputs from ffi::PackedArgs. @@ -140,9 +140,9 @@ class HipblasJSONRuntime : public JSONRuntimeBase { void Run() override { LOG(FATAL) << "Unreachable"; } }; -ffi::Module HipblasJSONRuntimeCreate(String symbol_name, String graph_json, - const Array& const_names) { - auto n = make_object(symbol_name, graph_json, const_names); +ffi::Module HipblasJSONRuntimeCreate(ffi::String symbol_name, ffi::String graph_json, + const ffi::Array& const_names) { + auto n = ffi::make_object(symbol_name, graph_json, const_names); return ffi::Module(n); } diff --git a/src/runtime/contrib/json/json_runtime.h b/src/runtime/contrib/json/json_runtime.h index ea32f7f1f24a..a8bb6c26083f 100644 --- a/src/runtime/contrib/json/json_runtime.h +++ b/src/runtime/contrib/json/json_runtime.h @@ -50,7 +50,7 @@ namespace json { class JSONRuntimeBase : public ffi::ModuleObj { public: JSONRuntimeBase(const std::string& symbol_name, const std::string& graph_json, - const Array const_names) + const ffi::Array const_names) : symbol_name_(symbol_name), graph_json_(graph_json), const_names_(const_names) { LoadGraph(graph_json_); } @@ -63,7 +63,7 @@ class JSONRuntimeBase : public ffi::ModuleObj { } /*! \brief Initialize a specific json runtime. */ - virtual void Init(const Array& consts) = 0; + virtual void Init(const ffi::Array& consts) = 0; /*! 
\brief Invoke the execution engine to interpret a specific json runtime. */ virtual void Run() = 0; @@ -93,7 +93,7 @@ class JSONRuntimeBase : public ffi::ModuleObj { * \param sptr_to_self The pointer to the module node. * \return The packed function. */ - Optional GetFunction(const String& name) override { + ffi::Optional GetFunction(const ffi::String& name) override { ObjectPtr sptr_to_self = ffi::GetObjectPtr(this); if (name == "get_symbol") { return ffi::Function( @@ -123,8 +123,8 @@ class JSONRuntimeBase : public ffi::ModuleObj { // Bind argument tensors to data entries. this->SetInputOutputBuffers(args); - if (auto opt_str = rv->try_cast()) { - String purpose = std::move(opt_str.value()); + if (auto opt_str = rv->try_cast()) { + ffi::String purpose = std::move(opt_str.value()); if ("debug_dump" == purpose) { *rv = this->DebugDump(); } @@ -133,7 +133,7 @@ class JSONRuntimeBase : public ffi::ModuleObj { profiling::Profiler* prof = static_cast(rv->cast()); this->RunProfile(prof); } - // String vendor_prof = this->RunProfile(prof); + // ffi::String vendor_prof = this->RunProfile(prof); }); } else if ("__init_" + this->symbol_name_ == name) { // The function to initialize constant tensors. @@ -141,7 +141,7 @@ class JSONRuntimeBase : public ffi::ModuleObj { ICHECK_EQ(args.size(), 1U); std::lock_guard guard(this->initialize_mutex_); if (!this->initialized_) { - this->Init(args[0].cast>()); + this->Init(args[0].cast>()); this->initialized_ = true; } *rv = 0; @@ -180,11 +180,11 @@ class JSONRuntimeBase : public ffi::ModuleObj { ICHECK(stream->Read(&symbol)) << "Loading symbol name failed"; ICHECK(stream->Read(&graph_json)) << "Loading graph json failed"; ICHECK(stream->Read(&consts)) << "Loading the const name list failed"; - Array const_names; + ffi::Array const_names; for (const auto& it : consts) { const_names.push_back(it); } - auto n = make_object(symbol, graph_json, const_names); + auto n = ffi::make_object(symbol, graph_json, const_names); return ffi::Module(n); } @@ -194,7 +194,7 @@ class JSONRuntimeBase : public ffi::ModuleObj { * \param format the format to return. * \return A string of JSON. */ - String InspectSource(const String& format) const override { return graph_json_; } + ffi::String InspectSource(const ffi::String& format) const override { return graph_json_; } protected: /*! @@ -270,7 +270,7 @@ class JSONRuntimeBase : public ffi::ModuleObj { * * \param consts A list of constant Tensor to be used. */ - void SetupConstants(const Array& consts) { + void SetupConstants(const ffi::Array& consts) { for (size_t i = 0; i < consts.size(); ++i) { data_entry_[EntryID(const_idx_[i], 0)] = consts[i].operator->(); } @@ -313,7 +313,7 @@ class JSONRuntimeBase : public ffi::ModuleObj { /*! \brief The graph. */ std::string graph_json_; /*! \brief The required constant names. */ - Array const_names_; + ffi::Array const_names_; /*! \brief The json graph nodes. */ std::vector nodes_; /*! \brief The input nodes, including variables and constants. */ diff --git a/src/runtime/contrib/mrvl/mrvl_hw_runtime.cc b/src/runtime/contrib/mrvl/mrvl_hw_runtime.cc index f9769d79099a..336367131fc7 100644 --- a/src/runtime/contrib/mrvl/mrvl_hw_runtime.cc +++ b/src/runtime/contrib/mrvl/mrvl_hw_runtime.cc @@ -212,7 +212,7 @@ class MarvellHardwareModuleNode : public ffi::ModuleObj { * \param sptr_to_self The pointer to the module node. * \return The packed function.
*/ - virtual Optional GetFunction(const String& name) { + virtual ffi::Optional GetFunction(const ffi::String& name) { ObjectPtr sptr_to_self = ffi::GetObjectPtr(this); if (name == "get_symbol") { return ffi::Function( @@ -226,8 +226,9 @@ class MarvellHardwareModuleNode : public ffi::ModuleObj { use_dpdk_cb = true; }); } else if (name == "get_const_vars") { - return ffi::Function( - [sptr_to_self, this](ffi::PackedArgs args, ffi::Any* rv) { *rv = Array{}; }); + return ffi::Function([sptr_to_self, this](ffi::PackedArgs args, ffi::Any* rv) { + *rv = ffi::Array{}; + }); } else if (this->symbol_name_ == name) { return ffi::Function([sptr_to_self, this](ffi::PackedArgs args, ffi::Any* rv) { RunInference(args); @@ -274,8 +275,8 @@ class MarvellHardwareModuleNode : public ffi::ModuleObj { ICHECK(stream->Read(&num_inputs)) << "Loading num_inputs failed"; ICHECK(stream->Read(&num_outputs)) << "Loading num_outputs failed"; ICHECK(stream->Read(&batch_size)) << "Loading batch_size failed"; - auto n = make_object(symbol_name, nodes_json, bin_code, num_inputs, - num_outputs, batch_size); + auto n = ffi::make_object(symbol_name, nodes_json, bin_code, + num_inputs, num_outputs, batch_size); return ffi::Module(n); } @@ -285,7 +286,7 @@ class MarvellHardwareModuleNode : public ffi::ModuleObj { * \param format the format to return. * \return A string of JSON. */ - String InspectSource(const String& format) const override { return nodes_json_; } + ffi::String InspectSource(const ffi::String& format) const override { return nodes_json_; } protected: std::string symbol_name_; @@ -469,11 +470,12 @@ class MarvellHardwareModuleNode : public ffi::ModuleObj { } }; -ffi::Module MarvellHardwareModuleRuntimeCreate(const String& symbol_name, const String& nodes_json, - const String& bin_code, int num_input, +ffi::Module MarvellHardwareModuleRuntimeCreate(const ffi::String& symbol_name, + const ffi::String& nodes_json, + const ffi::String& bin_code, int num_input, int num_output, int batch_size) { - auto n = make_object(symbol_name, nodes_json, bin_code, num_input, - num_output, batch_size); + auto n = ffi::make_object(symbol_name, nodes_json, bin_code, num_input, + num_output, batch_size); return ffi::Module(n); } diff --git a/src/runtime/contrib/mrvl/mrvl_runtime.cc b/src/runtime/contrib/mrvl/mrvl_runtime.cc index af384035c96b..8c1ed354d6f5 100644 --- a/src/runtime/contrib/mrvl/mrvl_runtime.cc +++ b/src/runtime/contrib/mrvl/mrvl_runtime.cc @@ -70,14 +70,15 @@ class MarvellSimulatorModuleNode : public ffi::ModuleObj { * \param sptr_to_self The pointer to the module node. * \return The packed function. 
*/ - virtual Optional GetFunction(const String& name) { + virtual ffi::Optional GetFunction(const ffi::String& name) { ObjectPtr sptr_to_self = ffi::GetObjectPtr(this); if (name == "get_symbol") { return ffi::Function( [sptr_to_self, this](ffi::PackedArgs args, ffi::Any* rv) { *rv = this->symbol_name_; }); } else if (name == "get_const_vars") { - return ffi::Function( - [sptr_to_self, this](ffi::PackedArgs args, ffi::Any* rv) { *rv = Array{}; }); + return ffi::Function([sptr_to_self, this](ffi::PackedArgs args, ffi::Any* rv) { + *rv = ffi::Array{}; + }); } else if (this->symbol_name_ == name) { return ffi::Function([sptr_to_self, this](ffi::PackedArgs args, ffi::Any* rv) { Run(args); @@ -111,7 +112,7 @@ class MarvellSimulatorModuleNode : public ffi::ModuleObj { ICHECK(stream->Read(&nodes_json)) << "Marvell-Compiler-ERROR-Internal::Loading nodes json failed"; ICHECK(stream->Read(&bin_code)) << "Marvell-Compiler-ERROR-Internal::Loading bin code failed"; - auto n = make_object(symbol_name, nodes_json, bin_code); + auto n = ffi::make_object(symbol_name, nodes_json, bin_code); return ffi::Module(n); } @@ -121,7 +122,7 @@ class MarvellSimulatorModuleNode : public ffi::ModuleObj { * \param format the format to return. * \return A string of JSON. */ - String InspectSource(const String& format) const override { return nodes_json_; } + ffi::String InspectSource(const ffi::String& format) const override { return nodes_json_; } protected: std::string symbol_name_; @@ -149,9 +150,10 @@ class MarvellSimulatorModuleNode : public ffi::ModuleObj { } }; -ffi::Module MarvellSimulatorModuleRuntimeCreate(const String& symbol_name, const String& nodes_json, - const String& bin_code) { - auto n = make_object(symbol_name, nodes_json, bin_code); +ffi::Module MarvellSimulatorModuleRuntimeCreate(const ffi::String& symbol_name, + const ffi::String& nodes_json, + const ffi::String& bin_code) { + auto n = ffi::make_object(symbol_name, nodes_json, bin_code); return ffi::Module(n); } diff --git a/src/runtime/contrib/mrvl/mrvl_sw_runtime_lib.cc b/src/runtime/contrib/mrvl/mrvl_sw_runtime_lib.cc index 8e68cf7e6963..a7d50f412c9d 100644 --- a/src/runtime/contrib/mrvl/mrvl_sw_runtime_lib.cc +++ b/src/runtime/contrib/mrvl/mrvl_sw_runtime_lib.cc @@ -126,7 +126,7 @@ static void ReadOutputsAndUpdateRuntime(ffi::PackedArgs args, size_t num_inputs, } float f; float* data = new float[tot_dim](); - String outbin = out_bin_prefix + "-" + std::to_string(out - num_inputs) + ".bin"; + ffi::String outbin = out_bin_prefix + "-" + std::to_string(out - num_inputs) + ".bin"; std::ifstream fin(outbin, std::ios::binary); ICHECK(fin.is_open()) << "Cannot open file: " << outbin; int i = 0; diff --git a/src/runtime/contrib/msc/tensorrt_runtime.cc b/src/runtime/contrib/msc/tensorrt_runtime.cc index 3a5f7c02def6..8a837370fa34 100644 --- a/src/runtime/contrib/msc/tensorrt_runtime.cc +++ b/src/runtime/contrib/msc/tensorrt_runtime.cc @@ -62,7 +62,7 @@ class MSCTensorRTRuntime : public JSONRuntimeBase { * \param const_names The names of each constant in the sub-graph. */ explicit MSCTensorRTRuntime(const std::string& symbol_name, const std::string& graph_json, - const Array& const_names) + const ffi::Array& const_names) : JSONRuntimeBase(symbol_name, graph_json, const_names) {} ~MSCTensorRTRuntime() { @@ -87,7 +87,7 @@ class MSCTensorRTRuntime : public JSONRuntimeBase { * * \param consts The constant params from compiled model. 
*/ - void Init(const Array& consts) override { + void Init(const ffi::Array& consts) override { ICHECK_EQ(consts.size(), const_idx_.size()) << "The number of input constants must match the number of required."; LoadGlobalOptions(); @@ -122,14 +122,14 @@ class MSCTensorRTRuntime : public JSONRuntimeBase { if (tool_tag_.size() > 0) { const auto pf = tvm::ffi::Function::GetGlobal("msc_tool.callback_step"); ICHECK(pf.has_value()) << "Cannot find msc_tool.callback_step func."; - Map input_datas; + ffi::Map input_datas; int device_id = 0; for (const auto& pair : input_bindings_) { const auto& tensor_name = engine_->getBindingName(pair.first); input_datas.Set(tensor_name, device_buffers_[pair.first]); device_id = data_entry_[pair.first]->device.device_id; } - Map> context; + ffi::Map> context; context.Set("datas", input_datas); (*pf)(context, "before_forward", graph_name_, tool_tag_); } @@ -155,7 +155,7 @@ class MSCTensorRTRuntime : public JSONRuntimeBase { if (tool_tag_.size() > 0) { const auto pf = tvm::ffi::Function::GetGlobal("msc_tool.callback_step"); ICHECK(pf.has_value()) << "Cannot find msc_tool.callback_step func."; - Map output_datas; + ffi::Map output_datas; for (int bid = 0; bid < engine_->getNbBindings(); bid++) { if (input_bindings_.count(bid)) { continue; @@ -163,13 +163,13 @@ class MSCTensorRTRuntime : public JSONRuntimeBase { const auto& tensor_name = engine_->getBindingName(bid); output_datas.Set(tensor_name, device_buffers_[bid]); } - Map> context; + ffi::Map> context; context.Set("datas", output_datas); (*pf)(context, "after_forward", graph_name_, tool_tag_); } } - bool LoadEngine(const String& engine_file) { + bool LoadEngine(const ffi::String& engine_file) { IRuntime* runtime = createInferRuntime(logger_); // build engine std::ifstream input(engine_file_, std::ifstream::binary); @@ -323,15 +323,15 @@ class MSCTensorRTRuntime : public JSONRuntimeBase { << "Please build with USE_TENSORRT_RUNTIME."; } - bool LoadEngine(const String& engine_file) { return false; } + bool LoadEngine(const ffi::String& engine_file) { return false; } void DestroyEngine() {} #endif // TVM_GRAPH_EXECUTOR_TENSORRT private: - String engine_file_; - String tool_tag_; - String graph_name_; + ffi::String engine_file_; + ffi::String tool_tag_; + ffi::String graph_name_; std::unordered_map> tensor_ids_; #ifdef TVM_GRAPH_EXECUTOR_TENSORRT TensorRTLogger logger_; @@ -345,9 +345,9 @@ class MSCTensorRTRuntime : public JSONRuntimeBase { #endif }; -ffi::Module MSCTensorRTRuntimeCreate(const String& symbol_name, const String& graph_json, - const Array& const_names) { - auto n = make_object(symbol_name, graph_json, const_names); +ffi::Module MSCTensorRTRuntimeCreate(const ffi::String& symbol_name, const ffi::String& graph_json, + const ffi::Array& const_names) { + auto n = ffi::make_object(symbol_name, graph_json, const_names); return ffi::Module(n); } diff --git a/src/runtime/contrib/nnapi/nnapi_runtime.cc b/src/runtime/contrib/nnapi/nnapi_runtime.cc index a1f3b3f132f5..db0f19897bbc 100644 --- a/src/runtime/contrib/nnapi/nnapi_runtime.cc +++ b/src/runtime/contrib/nnapi/nnapi_runtime.cc @@ -51,7 +51,7 @@ using JSONGraphNode = tvm::runtime::json::JSONGraphNode; class NNAPIRuntime : public JSONRuntimeBase { public: explicit NNAPIRuntime(const std::string& symbol_name, const std::string& graph_json, - const Array& const_names) + const ffi::Array& const_names) : JSONRuntimeBase(symbol_name, graph_json, const_names) {} const char* kind() const final { return "nnapi"; } @@ -70,7 +70,7 @@ class NNAPIRuntime : public 
JSONRuntimeBase { std::optional compiled_model_; - void Init(const Array& consts) final { + void Init(const ffi::Array& consts) final { ICHECK_EQ(consts.size(), const_idx_.size()) << "The number of input constants must match the number of required constants."; SetupConstants(consts); @@ -225,7 +225,7 @@ class NNAPIRuntime : public JSONRuntimeBase { std::unordered_map node_output_map_; #else // ifdef TVM_GRAPH_EXECUTOR_NNAPI - void Init(const Array& consts) final { + void Init(const ffi::Array& consts) final { LOG(FATAL) << "NNAPI runtime is not enabled. Build with USE_NNAPI_RUNTIME to enable it."; } @@ -235,9 +235,9 @@ class NNAPIRuntime : public JSONRuntimeBase { #endif // ifdef TVM_GRAPH_EXECUTOR_NNAPI }; -ffi::Module NNAPIRuntimeCreate(const String& symbol_name, const String& graph_json, - const Array& const_names) { - auto n = make_object(symbol_name, graph_json, const_names); +ffi::Module NNAPIRuntimeCreate(const ffi::String& symbol_name, const ffi::String& graph_json, + const ffi::Array& const_names) { + auto n = ffi::make_object(symbol_name, graph_json, const_names); return ffi::Module(n); } diff --git a/src/runtime/contrib/nvshmem/init.cc b/src/runtime/contrib/nvshmem/init.cc index 4cb0558d611b..9082f43b3966 100644 --- a/src/runtime/contrib/nvshmem/init.cc +++ b/src/runtime/contrib/nvshmem/init.cc @@ -80,7 +80,7 @@ void InitNVSHMEM(ffi::Shape uid_64, int num_workers, int worker_id_start) { << ", npes=" << nvshmem_n_pes(); } -void InitNVSHMEMWrapper(String args) { +void InitNVSHMEMWrapper(ffi::String args) { picojson::value v; std::string err = picojson::parse(v, args); if (!err.empty()) { diff --git a/src/runtime/contrib/nvshmem/memory_allocator.cc b/src/runtime/contrib/nvshmem/memory_allocator.cc index 0c816669be9a..4e742a0792e7 100644 --- a/src/runtime/contrib/nvshmem/memory_allocator.cc +++ b/src/runtime/contrib/nvshmem/memory_allocator.cc @@ -68,7 +68,7 @@ class NVSHMEMAllocator final : public PooledAllocator { Buffer buffer_; }; - Buffer buffer = PooledAllocator::Alloc(device, shape, dtype, String("nvshmem")); + Buffer buffer = PooledAllocator::Alloc(device, shape, dtype, ffi::String("nvshmem")); return Tensor::FromNDAlloc(NVSHMEMAlloc(buffer), shape, dtype, device); } diff --git a/src/runtime/contrib/papi/papi.cc b/src/runtime/contrib/papi/papi.cc index d847e05e1bee..6bedf2d4ef6c 100644 --- a/src/runtime/contrib/papi/papi.cc +++ b/src/runtime/contrib/papi/papi.cc @@ -101,7 +101,7 @@ struct PAPIMetricCollectorNode final : public MetricCollectorNode { * collected on that device. You can find the names of available metrics by * running `papi_native_avail`. */ - explicit PAPIMetricCollectorNode(Map> metrics) { + explicit PAPIMetricCollectorNode(ffi::Map> metrics) { for (auto& p : metrics) { papi_metric_names[p.first->device] = {}; for (auto& metric : p.second) { @@ -114,7 +114,7 @@ struct PAPIMetricCollectorNode final : public MetricCollectorNode { /*! \brief Initialization call. * \param devices The devices this collector will be running on */ - void Init(Array devices) { + void Init(ffi::Array devices) { if (!PAPI_is_initialized()) { if (sizeof(long_long) > sizeof(int64_t)) { LOG(WARNING) << "PAPI's long_long is larger than int64_t. 
Overflow may occur when " @@ -225,7 +225,7 @@ struct PAPIMetricCollectorNode final : public MetricCollectorNode { int event_set = it->second; std::vector values(papi_metric_names[dev].size()); PAPI_CALL(PAPI_read(event_set, values.data())); - return ObjectRef(make_object(values, dev)); + return ObjectRef(ffi::make_object(values, dev)); } else { return ObjectRef(nullptr); } @@ -237,19 +237,19 @@ struct PAPIMetricCollectorNode final : public MetricCollectorNode { * \param obj `PAPIEventSetNode` created by a call to `Start`. * \returns A mapping from metric name to value. */ - Map Stop(ObjectRef obj) final { + ffi::Map Stop(ObjectRef obj) final { const PAPIEventSetNode* event_set_node = obj.as(); std::vector end_values(papi_metric_names[event_set_node->dev].size()); PAPI_CALL(PAPI_read(event_sets[event_set_node->dev], end_values.data())); - std::unordered_map reported_metrics; + std::unordered_map reported_metrics; for (size_t i = 0; i < end_values.size(); i++) { if (end_values[i] < event_set_node->start_values[i]) { LOG(WARNING) << "Detected overflow when reading performance counter, setting value to -1."; reported_metrics[papi_metric_names[event_set_node->dev][i]] = - ObjectRef(make_object(-1)); + ObjectRef(ffi::make_object(-1)); } else { reported_metrics[papi_metric_names[event_set_node->dev][i]] = - ObjectRef(make_object(end_values[i] - event_set_node->start_values[i])); + ObjectRef(ffi::make_object(end_values[i] - event_set_node->start_values[i])); } } return reported_metrics; @@ -277,22 +277,24 @@ struct PAPIMetricCollectorNode final : public MetricCollectorNode { /*! \brief Wrapper for `PAPIMetricCollectorNode`. */ class PAPIMetricCollector : public MetricCollector { public: - explicit PAPIMetricCollector(Map> metrics) { - data_ = make_object(metrics); + explicit PAPIMetricCollector(ffi::Map> metrics) { + data_ = ffi::make_object(metrics); } TVM_DEFINE_MUTABLE_OBJECT_REF_METHODS(PAPIMetricCollector, MetricCollector, PAPIMetricCollectorNode); }; -MetricCollector CreatePAPIMetricCollector(Map> metrics) { +MetricCollector CreatePAPIMetricCollector( + ffi::Map> metrics) { return PAPIMetricCollector(metrics); } TVM_FFI_STATIC_INIT_BLOCK({ namespace refl = tvm::ffi::reflection; - refl::GlobalDef().def( - "runtime.profiling.PAPIMetricCollector", - [](Map> metrics) { return PAPIMetricCollector(metrics); }); + refl::GlobalDef().def("runtime.profiling.PAPIMetricCollector", + [](ffi::Map> metrics) { + return PAPIMetricCollector(metrics); + }); }); } // namespace profiling diff --git a/src/runtime/contrib/tensorrt/tensorrt_runtime.cc b/src/runtime/contrib/tensorrt/tensorrt_runtime.cc index d66b1a1c46e1..8620988f8465 100644 --- a/src/runtime/contrib/tensorrt/tensorrt_runtime.cc +++ b/src/runtime/contrib/tensorrt/tensorrt_runtime.cc @@ -68,7 +68,7 @@ class TensorRTRuntime : public JSONRuntimeBase { * \param const_names The names of each constant in the sub-graph. */ explicit TensorRTRuntime(const std::string& symbol_name, const std::string& graph_json, - const Array& const_names) + const ffi::Array& const_names) : JSONRuntimeBase(symbol_name, graph_json, const_names), use_implicit_batch_(true), max_workspace_size_(size_t(1) << 30), @@ -109,7 +109,7 @@ class TensorRTRuntime : public JSONRuntimeBase { * * \param consts The constant params from compiled model. 
*/ - void Init(const Array& consts) override { + void Init(const ffi::Array& consts) override { ICHECK_EQ(consts.size(), const_idx_.size()) << "The number of input constants must match the number of required."; LoadGlobalAttributes(); @@ -519,9 +519,9 @@ class TensorRTRuntime : public JSONRuntimeBase { bool use_fp16_; }; -ffi::Module TensorRTRuntimeCreate(const String& symbol_name, const String& graph_json, - const Array& const_names) { - auto n = make_object(symbol_name, graph_json, const_names); +ffi::Module TensorRTRuntimeCreate(const ffi::String& symbol_name, const ffi::String& graph_json, + const ffi::Array& const_names) { + auto n = ffi::make_object(symbol_name, graph_json, const_names); return ffi::Module(n); } diff --git a/src/runtime/contrib/tflite/tflite_runtime.cc b/src/runtime/contrib/tflite/tflite_runtime.cc index b51b8084cb91..8ddaafbd6cb0 100644 --- a/src/runtime/contrib/tflite/tflite_runtime.cc +++ b/src/runtime/contrib/tflite/tflite_runtime.cc @@ -152,7 +152,7 @@ Tensor TFLiteRuntime::GetOutput(int index) const { return ret; } -ffi::Optional TFLiteRuntime::GetFunction(const String& name) { +ffi::Optional TFLiteRuntime::GetFunction(const ffi::String& name) { ObjectPtr sptr_to_self = ffi::GetObjectPtr(this); // Return member functions during query. if (name == "set_input") { @@ -180,7 +180,7 @@ ffi::Optional TFLiteRuntime::GetFunction(const String& name) { } ffi::Module TFLiteRuntimeCreate(const std::string& tflite_model_bytes, Device dev) { - auto exec = make_object(); + auto exec = ffi::make_object(); exec->Init(tflite_model_bytes, dev); return ffi::Module(exec); } diff --git a/src/runtime/contrib/tflite/tflite_runtime.h b/src/runtime/contrib/tflite/tflite_runtime.h index 590ee4df6f7b..a5703ee70749 100644 --- a/src/runtime/contrib/tflite/tflite_runtime.h +++ b/src/runtime/contrib/tflite/tflite_runtime.h @@ -54,7 +54,7 @@ class TFLiteRuntime : public ffi::ModuleObj { * \param sptr_to_self The pointer to the module node. * \return The corresponding member function. */ - virtual Optional GetFunction(const String& name); + virtual ffi::Optional GetFunction(const ffi::String& name); /*! * \return The type key of the executor. 
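Every *RuntimeCreate helper above is migrated the same way: the bare make_object call becomes an explicit ffi::make_object, consistent with the unqualified aliases being dropped from namespace tvm elsewhere in this patch, and the resulting node is wrapped in ffi::Module. A sketch of the shared pattern, with MyRuntime standing in for any of the concrete module classes (the names are illustrative):

    // Sketch of the shared *Create construction pattern after the migration.
    ffi::Module MyRuntimeCreate(const ffi::String& symbol_name, const ffi::String& graph_json,
                                const ffi::Array<ffi::String>& const_names) {
      // Single-allocation construction of the reference-counted module node.
      auto n = ffi::make_object<MyRuntime>(symbol_name, graph_json, const_names);
      return ffi::Module(n);  // ffi::Module takes ownership of the ObjectPtr
    }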
diff --git a/src/runtime/contrib/vllm/cache_alloc.cc b/src/runtime/contrib/vllm/cache_alloc.cc index 673f83e2e0c1..e5814df8afd5 100644 --- a/src/runtime/contrib/vllm/cache_alloc.cc +++ b/src/runtime/contrib/vllm/cache_alloc.cc @@ -25,9 +25,9 @@ namespace tvm { namespace runtime { namespace vllm { -Array AllocateKVCache(int head_size, int num_layers, int num_heads, int block_size, - int num_blocks) { - Array cache; +ffi::Array AllocateKVCache(int head_size, int num_layers, int num_heads, int block_size, + int num_blocks) { + ffi::Array cache; int element_size = 2; int vec_size = 16 / element_size; diff --git a/src/runtime/contrib/vllm/cache_kernels.cu b/src/runtime/contrib/vllm/cache_kernels.cu index a68fd66d6269..d97c9f8a7aa1 100644 --- a/src/runtime/contrib/vllm/cache_kernels.cu +++ b/src/runtime/contrib/vllm/cache_kernels.cu @@ -184,7 +184,7 @@ TVM_FFI_STATIC_INIT_BLOCK({ return Array{key, value}; }) - .def("tvm.contrib.vllm.copy_blocks", [](Array key_value_caches, + .def("tvm.contrib.vllm.copy_blocks", [](ffi::Array key_value_caches, Tensor block_mapping) { auto num_layers = key_value_caches.size() / 2; auto num_pairs = block_mapping->shape[0] / 2; diff --git a/src/runtime/cuda/cuda_device_api.cc b/src/runtime/cuda/cuda_device_api.cc index 451348afbf1a..d346d4d83e8b 100644 --- a/src/runtime/cuda/cuda_device_api.cc +++ b/src/runtime/cuda/cuda_device_api.cc @@ -336,10 +336,10 @@ class CUDATimerNode : public TimerNode { TVM_FFI_STATIC_INIT_BLOCK({ namespace refl = tvm::ffi::reflection; refl::GlobalDef().def("profiling.timer.cuda", - [](Device dev) { return Timer(make_object()); }); + [](Device dev) { return Timer(ffi::make_object()); }); }); -TVM_DLL String GetCudaFreeMemory() { +TVM_DLL ffi::String GetCudaFreeMemory() { size_t free_mem, total_mem; CUDA_CALL(cudaMemGetInfo(&free_mem, &total_mem)); std::stringstream ss; diff --git a/src/runtime/cuda/cuda_module.cc b/src/runtime/cuda/cuda_module.cc index eb3bee4757bf..9086903d0141 100644 --- a/src/runtime/cuda/cuda_module.cc +++ b/src/runtime/cuda/cuda_module.cc @@ -73,9 +73,9 @@ class CUDAModuleNode : public ffi::ModuleObj { return ffi::Module::kBinarySerializable | ffi::Module::kRunnable; } - Optional GetFunction(const String& name) final; + ffi::Optional GetFunction(const ffi::String& name) final; - void WriteToFile(const String& file_name, const String& format) const final { + void WriteToFile(const ffi::String& file_name, const ffi::String& format) const final { std::string fmt = GetFileFormat(file_name, format); std::string meta_file = GetMetaFilePath(file_name); if (fmt == "cu") { @@ -99,7 +99,7 @@ class CUDAModuleNode : public ffi::ModuleObj { return ffi::Bytes(buffer); } - String InspectSource(const String& format) const final { + ffi::String InspectSource(const ffi::String& format) const final { if (format == fmt_) return data_; if (cuda_source_.length() != 0) { return cuda_source_; @@ -261,7 +261,7 @@ class CUDAPrepGlobalBarrier { mutable std::array pcache_; }; -Optional CUDAModuleNode::GetFunction(const String& name) { +ffi::Optional CUDAModuleNode::GetFunction(const ffi::String& name) { ObjectPtr sptr_to_self = ffi::GetObjectPtr(this); ICHECK_EQ(sptr_to_self.get(), this); if (name == symbol::tvm_prepare_global_barrier) { @@ -278,12 +278,12 @@ Optional CUDAModuleNode::GetFunction(const String& name) { ffi::Module CUDAModuleCreate(std::string data, std::string fmt, std::unordered_map fmap, std::string cuda_source) { - auto n = make_object(data, fmt, fmap, cuda_source); + auto n = ffi::make_object(data, fmt, fmap, cuda_source); 
return ffi::Module(n); } // Load module from file. -ffi::Module CUDAModuleLoadFile(const std::string& file_name, const String& format) { +ffi::Module CUDAModuleLoadFile(const std::string& file_name, const ffi::String& format) { std::string data; std::unordered_map fmap; std::string fmt = GetFileFormat(file_name, format); diff --git a/src/runtime/device_api.cc b/src/runtime/device_api.cc index 16fd3c7b7761..fd7d651df2f4 100644 --- a/src/runtime/device_api.cc +++ b/src/runtime/device_api.cc @@ -107,7 +107,7 @@ static size_t GetDataAlignment(const DLDataType dtype) { return align; } -size_t DeviceAPI::GetDataSize(const DLTensor& arr, Optional mem_scope) { +size_t DeviceAPI::GetDataSize(const DLTensor& arr, ffi::Optional mem_scope) { if (!mem_scope.has_value() || mem_scope.value().empty() || mem_scope.value() == "global") { size_t size = 1; for (int i = 0; i < arr.ndim; ++i) { @@ -121,7 +121,7 @@ size_t DeviceAPI::GetDataSize(const DLTensor& arr, Optional mem_scope) { } void* DeviceAPI::AllocDataSpace(Device dev, int ndim, const int64_t* shape, DLDataType dtype, - Optional mem_scope) { + ffi::Optional mem_scope) { if (!mem_scope.has_value() || mem_scope.value().empty() || mem_scope.value() == "global") { // by default, we can always redirect to the flat memory allocations DLTensor temp; diff --git a/src/runtime/disco/bcast_session.cc b/src/runtime/disco/bcast_session.cc index f4964b12d709..2ea9ef575d05 100644 --- a/src/runtime/disco/bcast_session.cc +++ b/src/runtime/disco/bcast_session.cc @@ -38,7 +38,7 @@ struct BcastSessionObj::Internal { } static DRef MakeDRef(int reg_id, Session session) { - ObjectPtr p = make_object(); + ObjectPtr p = ffi::make_object(); p->reg_id = reg_id; p->session = session; return DRef(std::move(p)); @@ -48,7 +48,7 @@ struct BcastSessionObj::Internal { DRef BcastSessionObj::GetGlobalFunc(const std::string& name) { int reg_id = AllocateReg(); BcastSessionObj::Internal::BroadcastUnpacked(this, DiscoAction::kGetGlobalFunc, reg_id, name); - return BcastSessionObj::Internal::MakeDRef(reg_id, GetRef(this)); + return BcastSessionObj::Internal::MakeDRef(reg_id, ffi::GetRef(this)); } void BcastSessionObj::CopyFromWorker0(const Tensor& host_array, const DRef& remote_array) { @@ -67,11 +67,11 @@ void BcastSessionObj::Shutdown() { BcastSessionObj::Internal::BroadcastUnpacked(this, DiscoAction::kShutDown, 0); } -void BcastSessionObj::InitCCL(String ccl, ffi::Shape device_ids) { +void BcastSessionObj::InitCCL(ffi::String ccl, ffi::Shape device_ids) { const auto pf = tvm::ffi::Function::GetGlobal("runtime.disco." + ccl + ".init_ccl"); CHECK(pf.has_value()) << "ValueError: Cannot initialize CCL `" << ccl << "`, because cannot find function: runtime.disco."
<< ccl << ".init_ccl"; - (*pf)(GetRef(this), device_ids); + (*pf)(ffi::GetRef(this), device_ids); } void BcastSessionObj::SyncWorker(int worker_id) { @@ -97,7 +97,7 @@ DRef BcastSessionObj::CallWithPacked(const ffi::PackedArgs& args) { args_vec[2] = func->reg_id; } this->BroadcastPacked(ffi::PackedArgs(args_vec, args.size())); - return BcastSessionObj::Internal::MakeDRef(reg_id, GetRef(this)); + return BcastSessionObj::Internal::MakeDRef(reg_id, ffi::GetRef(this)); } void BcastSessionObj::DeallocReg(int reg_id) { diff --git a/src/runtime/disco/bcast_session.h b/src/runtime/disco/bcast_session.h index e4ee3bb8a1cb..a850902c5e46 100644 --- a/src/runtime/disco/bcast_session.h +++ b/src/runtime/disco/bcast_session.h @@ -41,7 +41,7 @@ class BcastSessionObj : public SessionObj { void CopyToWorker0(const Tensor& host_array, const DRef& remote_array) override; void SyncWorker(int worker_id) override; void Shutdown() override; - void InitCCL(String ccl, IntTuple device_ids) override; + void InitCCL(ffi::String ccl, IntTuple device_ids) override; ffi::Any DebugGetFromRemote(int64_t reg_id, int worker_id) override = 0; void DebugSetRegister(int64_t reg_id, ffi::AnyView value, int worker_id) override = 0; diff --git a/src/runtime/disco/builtin.cc b/src/runtime/disco/builtin.cc index 2cfd91dfde83..b88c9a36ad5f 100644 --- a/src/runtime/disco/builtin.cc +++ b/src/runtime/disco/builtin.cc @@ -49,17 +49,17 @@ class DSOLibraryCache { std::mutex mutex_; }; -ffi::Module LoadVMModule(std::string path, Optional device) { +ffi::Module LoadVMModule(std::string path, ffi::Optional device) { static DSOLibraryCache cache; ffi::Module dso_mod = cache.Open(path); Device dev = UseDefaultDeviceIfNone(device); - Optional vm_load_executable = dso_mod->GetFunction("vm_load_executable"); + ffi::Optional vm_load_executable = dso_mod->GetFunction("vm_load_executable"); if (!vm_load_executable.has_value()) { // not built by RelaxVM, return the dso_mod directly return dso_mod; } auto mod = (*vm_load_executable)().cast(); - Optional vm_initialization = mod->GetFunction("vm_initialization"); + ffi::Optional vm_initialization = mod->GetFunction("vm_initialization"); if (!vm_initialization.has_value()) { LOG(FATAL) << "ValueError: File `" << path << "` is not built by RelaxVM, because `vm_initialization` does not exist"; @@ -70,7 +70,7 @@ ffi::Module LoadVMModule(std::string path, Optional device) { return mod; } -Tensor DiscoEmptyTensor(ffi::Shape shape, DataType dtype, Optional device) { +Tensor DiscoEmptyTensor(ffi::Shape shape, DataType dtype, ffi::Optional device) { return Tensor::Empty(shape, dtype, UseDefaultDeviceIfNone(device)); } @@ -95,11 +95,11 @@ TVM_DLL void BroadcastFromWorker0(Tensor send, bool in_group, Tensor recv) { GetCCLFunc("broadcast_from_worker0")(send, in_group, recv); } -TVM_DLL void ScatterFromWorker0(Optional send, bool in_group, Tensor recv) { +TVM_DLL void ScatterFromWorker0(ffi::Optional send, bool in_group, Tensor recv) { GetCCLFunc("scatter_from_worker0")(send, in_group, recv); } -void GatherToWorker0(Tensor send, bool in_group, Optional recv) { +void GatherToWorker0(Tensor send, bool in_group, ffi::Optional recv) { GetCCLFunc("gather_to_worker0")(send, in_group, recv); } @@ -130,8 +130,8 @@ TVM_FFI_STATIC_INIT_BLOCK({ refl::GlobalDef() .def("runtime.disco.load_vm_module", LoadVMModule) .def("runtime.disco.empty", - [](ffi::Shape shape, DataType dtype, Optional device, bool worker0_only, - bool in_group) -> Optional { + [](ffi::Shape shape, DataType dtype, ffi::Optional device, bool worker0_only, + 
bool in_group) -> ffi::Optional { int worker_id = WorkerId(); int group_size = DiscoWorker::ThreadLocal()->num_workers / DiscoWorker::ThreadLocal()->num_groups; diff --git a/src/runtime/disco/cuda_ipc/cuda_ipc_memory.cc b/src/runtime/disco/cuda_ipc/cuda_ipc_memory.cc index 37ae2b404101..a02ab2a84c3f 100644 --- a/src/runtime/disco/cuda_ipc/cuda_ipc_memory.cc +++ b/src/runtime/disco/cuda_ipc/cuda_ipc_memory.cc @@ -101,7 +101,7 @@ class CUDAIPCMemoryAllocator final : public memory::PooledAllocator { dev, barrier_ptr_size, alignment, DataType::UInt(32), /*reset_memory_to_zero=*/true); // Create the CUDAIPCMemory object. - ObjectPtr ipc_memory = make_object(); + ObjectPtr ipc_memory = ffi::make_object(); nccl::CCLThreadLocalContext* nccl_ctx = nccl::CCLThreadLocalContext::Get(); ipc_memory->remote_data = data_comm_ptrs; ipc_memory->barrier_in = barrier_in_comm_ptrs; diff --git a/src/runtime/disco/distributed/socket_session.cc b/src/runtime/disco/distributed/socket_session.cc index 8e576fff227d..3fbe59a3c308 100644 --- a/src/runtime/disco/distributed/socket_session.cc +++ b/src/runtime/disco/distributed/socket_session.cc @@ -56,7 +56,7 @@ class DiscoSocketChannel : public DiscoChannel { class SocketSessionObj : public BcastSessionObj { public: explicit SocketSessionObj(int num_nodes, int num_workers_per_node, int num_groups, - const String& host, int port) + const ffi::String& host, int port) : num_nodes_(num_nodes), num_workers_per_node_(num_workers_per_node) { const auto f_create_local_session = tvm::ffi::Function::GetGlobal("runtime.disco.create_socket_session_local_workers"); @@ -209,7 +209,8 @@ class SocketSessionObj : public BcastSessionObj { class RemoteSocketSession { public: - explicit RemoteSocketSession(const String& server_host, int server_port, int num_local_workers) { + explicit RemoteSocketSession(const ffi::String& server_host, int server_port, + int num_local_workers) { socket_.Create(); socket_.SetKeepAlive(true); SockAddr server_addr{server_host.c_str(), server_port}; @@ -287,7 +288,7 @@ class RemoteSocketSession { int num_workers_per_node_{-1}; }; -void RemoteSocketSessionEntryPoint(const String& server_host, int server_port, +void RemoteSocketSessionEntryPoint(const ffi::String& server_host, int server_port, int num_local_workers) { RemoteSocketSession proxy(server_host, server_port, num_local_workers); proxy.MainLoop(); @@ -298,9 +299,10 @@ TVM_FFI_STATIC_INIT_BLOCK({ refl::GlobalDef().def("runtime.disco.RemoteSocketSession", RemoteSocketSessionEntryPoint); }); -Session SocketSession(int num_nodes, int num_workers_per_node, int num_groups, const String& host, - int port) { - auto n = make_object(num_nodes, num_workers_per_node, num_groups, host, port); +Session SocketSession(int num_nodes, int num_workers_per_node, int num_groups, + const ffi::String& host, int port) { + auto n = + ffi::make_object(num_nodes, num_workers_per_node, num_groups, host, port); return Session(n); } diff --git a/src/runtime/disco/loader.cc b/src/runtime/disco/loader.cc index fec50cd71118..87633c01b8c3 100644 --- a/src/runtime/disco/loader.cc +++ b/src/runtime/disco/loader.cc @@ -78,7 +78,8 @@ ShardInfo::TensorInfo LoadTensorInfoFromJSON(const picojson::array& json_tensor_ shape.push_back(AsType(shape_json[i])); } std::string dtype = AsType(json_tensor_info[1]); - return ShardInfo::TensorInfo{ffi::Shape(std::move(shape)), DataType(StringToDLDataType(dtype))}; + return ShardInfo::TensorInfo{ffi::Shape(std::move(shape)), + DataType(ffi::StringToDLDataType(dtype))}; } ShardInfo::ShardFunc 
LoadShardFuncFromJSON(const picojson::array& json_shard_func) { @@ -117,19 +118,19 @@ class ShardLoaderObj : public Object { public: /*! \brief Create a shard loader. */ static ObjectRef Create(const std::string& path_to_metadata, const std::string& metadata, - std::string shard_info, Optional mod); + std::string shard_info, ffi::Optional mod); /*! \brief Load the i-th parameter */ Tensor Load(int weight_index) const; Tensor LoadParamOnWorker0(int weight_index) const; /*! \brief Load all the parameters */ - Array LoadAll() const; + ffi::Array LoadAll() const; Tensor ApplyShardFunc(const ShardInfo::ShardFunc& shard_func, const Tensor& param) const; /*! \brief Load all the pre-sharded parameters */ - Array LoadAllPresharded() const; + ffi::Array LoadAllPresharded() const; /*! \brief Load the i-th parameter from presharded binaries */ Tensor LoadPresharded(int weight_index) const; @@ -175,13 +176,13 @@ class ShardLoaderObj : public Object { }; ObjectRef ShardLoaderObj::Create(const std::string& path_to_metadata, const std::string& metadata, - std::string shard_info, Optional mod) { + std::string shard_info, ffi::Optional mod) { if (shard_info.empty() && mod.has_value()) { if (auto get_shard_info = (*mod)->GetFunction("get_shard_info")) { - shard_info = (*get_shard_info)().cast(); + shard_info = (*get_shard_info)().cast(); } } - ObjectPtr n = make_object(); + ObjectPtr n = ffi::make_object(); n->metadata_ = TensorCacheMetadata::LoadFromStr(metadata, path_to_metadata); n->current_file_ = nullptr; n->param_info_.clear(); @@ -194,7 +195,7 @@ ObjectRef ShardLoaderObj::Create(const std::string& path_to_metadata, const std: ShardInfo& shard_info = shards[name]; for (const ShardInfo::ShardFunc& shard_func : shard_info.funcs) { const std::string& name = shard_func.name; - if (Optional f = + if (ffi::Optional f = mod.has_value() ? (*mod)->GetFunction(name, true) : std::nullopt) { n->shard_funcs_[name] = *f; } else if (const auto f = tvm::ffi::Function::GetGlobal(name)) { @@ -341,9 +342,9 @@ Tensor ShardLoaderObj::Load(int weight_index) const { } } -Array ShardLoaderObj::LoadAll() const { +ffi::Array ShardLoaderObj::LoadAll() const { int n = static_cast(param_info_.size()); - Array shards; + ffi::Array shards; shards.reserve(n); for (int i = 0; i < n; ++i) { std::string param_name = "param_" + std::to_string(i); @@ -380,13 +381,13 @@ Tensor ShardLoaderObj::LoadPresharded(int weight_index) const { return LoadDirect(index); } -Array ShardLoaderObj::LoadAllPresharded() const { +ffi::Array ShardLoaderObj::LoadAllPresharded() const { DiscoWorker* worker = DiscoWorker::ThreadLocal(); size_t worker_id = static_cast(worker->worker_id); size_t num_workers = static_cast(worker->num_workers); size_t num_params = param_info_.size() / num_workers; - Array params; + ffi::Array params; params.reserve(num_params); for (size_t i_param = 0; i_param < num_params; ++i_param) { std::string param_name = static_cast( diff --git a/src/runtime/disco/nccl/nccl.cc b/src/runtime/disco/nccl/nccl.cc index 86950eedad45..c9207d92d2d0 100644 --- a/src/runtime/disco/nccl/nccl.cc +++ b/src/runtime/disco/nccl/nccl.cc @@ -141,7 +141,7 @@ void AllGather(Tensor send, bool in_group, Tensor recv) { in_group ? 
ctx->group_comm : ctx->global_comm, stream)); } -void BroadcastFromWorker0(Optional send, bool in_group, Tensor recv) { +void BroadcastFromWorker0(ffi::Optional send, bool in_group, Tensor recv) { CCLThreadLocalContext* ctx = CCLThreadLocalContext::Get(); int worker_id = ctx->worker->worker_id; int group_size = ctx->worker->num_workers / ctx->worker->num_groups; @@ -164,7 +164,7 @@ void BroadcastFromWorker0(Optional send, bool in_group, Tensor recv) { /*root=*/0, in_group ? ctx->group_comm : ctx->global_comm, stream)); } -void ScatterFromWorker0(Optional send, bool in_group, Tensor recv) { +void ScatterFromWorker0(ffi::Optional send, bool in_group, Tensor recv) { CHECK(recv.defined()) << "ValueError: buffer `recv` must not be None"; CCLThreadLocalContext* ctx = CCLThreadLocalContext::Get(); int worker_id = ctx->worker->worker_id; @@ -211,7 +211,7 @@ void ScatterFromWorker0(Optional send, bool in_group, Tensor recv) { NCCL_CALL(ncclGroupEnd()); } -void GatherToWorker0(Tensor send, bool in_group, Optional recv) { +void GatherToWorker0(Tensor send, bool in_group, ffi::Optional recv) { CHECK(send.defined()) << "ValueError: buffer `send` must not be None"; CCLThreadLocalContext* ctx = CCLThreadLocalContext::Get(); int worker_id = ctx->worker->worker_id; @@ -330,7 +330,7 @@ void SyncWorker() { TVM_FFI_STATIC_INIT_BLOCK({ namespace refl = tvm::ffi::reflection; refl::GlobalDef() - .def("runtime.disco.compiled_ccl", []() -> String { return TVM_DISCO_CCL_NAME; }) + .def("runtime.disco.compiled_ccl", []() -> ffi::String { return TVM_DISCO_CCL_NAME; }) .def("runtime.disco." TVM_DISCO_CCL_NAME ".init_ccl", InitCCL) .def("runtime.disco." TVM_DISCO_CCL_NAME ".init_ccl_per_worker", InitCCLPerWorker) .def("runtime.disco." TVM_DISCO_CCL_NAME ".allreduce", diff --git a/src/runtime/disco/process_session.cc b/src/runtime/disco/process_session.cc index d901b3eae42c..04675db7ad98 100644 --- a/src/runtime/disco/process_session.cc +++ b/src/runtime/disco/process_session.cc @@ -173,15 +173,15 @@ class ProcessSessionObj final : public BcastSessionObj { TVM_DECLARE_FINAL_OBJECT_INFO(ProcessSessionObj, SessionObj); }; -Session Session::ProcessSession(int num_workers, int num_group, String process_pool_creator, - String entrypoint) { +Session Session::ProcessSession(int num_workers, int num_group, ffi::String process_pool_creator, + ffi::String entrypoint) { CHECK_EQ(num_workers % num_group, 0) << "The number of workers should be divisible by the number of worker group."; const auto pf = tvm::ffi::Function::GetGlobal(process_pool_creator); CHECK(pf) << "ValueError: Cannot find function " << process_pool_creator << " in the registry. Please check if it is registered."; auto process_pool = (*pf)(num_workers, num_group, entrypoint).cast(); - auto n = make_object(num_workers, num_group, process_pool); + auto n = ffi::make_object(num_workers, num_group, process_pool); return Session(n); } diff --git a/src/runtime/disco/protocol.h b/src/runtime/disco/protocol.h index 3c3193d31147..000e3482f1fe 100644 --- a/src/runtime/disco/protocol.h +++ b/src/runtime/disco/protocol.h @@ -96,7 +96,7 @@ struct DiscoDebugObject : public Object { /*! \brief Wrap a Tensor or reflection-capable TVM object into the debug extension.
*/ static ObjectRef Wrap(const ffi::Any& data) { - ObjectPtr n = make_object(); + ObjectPtr n = ffi::make_object(); n->data = data; return ObjectRef(n); } @@ -182,7 +182,7 @@ inline void DiscoProtocol::ReadFFIAny(TVMFFIAny* out) { uint32_t type_index; self->template Read(&type_index); if (type_index == TypeIndex::kRuntimeDiscoDRef) { - ObjectPtr dref = make_object(); + ObjectPtr dref = ffi::make_object(); self->template Read(&dref->reg_id); dref->session = Session{nullptr}; result = ObjectRef(std::move(dref)); @@ -191,7 +191,7 @@ inline void DiscoProtocol::ReadFFIAny(TVMFFIAny* out) { self->template Read(&size); std::string data(size, '\0'); self->template ReadArray(data.data(), size); - result = String(std::move(data)); + result = ffi::String(std::move(data)); } else if (type_index == ffi::TypeIndex::kTVMFFIBytes) { uint64_t size = 0; self->template Read(&size); @@ -247,7 +247,7 @@ inline ObjectPtr DiscoDebugObject::LoadFromStr(std::string jso ICHECK(!json_str.empty()); char control_bit = json_str.back(); json_str.pop_back(); - ObjectPtr result = make_object(); + ObjectPtr result = ffi::make_object(); if (control_bit == '0') { const auto f = tvm::ffi::Function::GetGlobal("node.LoadJSON"); CHECK(f.has_value()) << "ValueError: Cannot deserialize object in non-debugging mode"; diff --git a/src/runtime/disco/threaded_session.cc b/src/runtime/disco/threaded_session.cc index 7dba51e4900c..864ff442f694 100644 --- a/src/runtime/disco/threaded_session.cc +++ b/src/runtime/disco/threaded_session.cc @@ -190,7 +190,7 @@ class ThreadedSessionObj final : public BcastSessionObj { Session Session::ThreadedSession(int num_workers, int num_group) { CHECK_EQ(num_workers % num_group, 0) << "The number of workers should be divisible by the number of worker groups."; - ObjectPtr n = make_object(num_workers, num_group); + ObjectPtr n = ffi::make_object(num_workers, num_group); return Session(std::move(n)); } diff --git a/src/runtime/disco/utils.h b/src/runtime/disco/utils.h index f0a10b6093d4..fb68335d8c5e 100644 --- a/src/runtime/disco/utils.h +++ b/src/runtime/disco/utils.h @@ -27,7 +27,7 @@ namespace tvm { namespace runtime { -inline Device UseDefaultDeviceIfNone(Optional device) { +inline Device UseDefaultDeviceIfNone(ffi::Optional device) { return device.value_or(DiscoWorker::ThreadLocal()->default_device); } diff --git a/src/runtime/file_utils.cc b/src/runtime/file_utils.cc index 4a0a8044fd8e..63e02049bd82 100644 --- a/src/runtime/file_utils.cc +++ b/src/runtime/file_utils.cc @@ -196,12 +196,12 @@ void CopyFile(const std::string& src_file_name, const std::string& dest_file_nam << " dest='" << dest_file_name << "'"; } -Map LoadParams(const std::string& param_blob) { +ffi::Map LoadParams(const std::string& param_blob) { dmlc::MemoryStringStream strm(const_cast(&param_blob)); return LoadParams(&strm); } -Map LoadParams(dmlc::Stream* strm) { - Map params; +ffi::Map LoadParams(dmlc::Stream* strm) { + ffi::Map params; uint64_t header, reserved; ICHECK(strm->Read(&header)) << "Invalid parameters file format"; ICHECK(header == kTVMTensorListMagic) << "Invalid parameters file format"; @@ -222,7 +222,7 @@ Map LoadParams(dmlc::Stream* strm) { return params; } -void SaveParams(dmlc::Stream* strm, const Map& params) { +void SaveParams(dmlc::Stream* strm, const ffi::Map& params) { std::vector names; std::vector arrays; for (auto& p : params) { @@ -243,7 +243,7 @@ void SaveParams(dmlc::Stream* strm, const Map& params) { } } -std::string SaveParams(const Map& params) { +std::string SaveParams(const ffi::Map& params) {
std::string bytes; dmlc::MemoryStringStream strm(&bytes); dmlc::Stream* fo = &strm; @@ -255,17 +255,17 @@ TVM_FFI_STATIC_INIT_BLOCK({ namespace refl = tvm::ffi::reflection; refl::GlobalDef() .def("runtime.SaveParams", - [](const Map& params) { + [](const ffi::Map& params) { std::string s = ::tvm::runtime::SaveParams(params); return ffi::Bytes(std::move(s)); }) .def("runtime.SaveParamsToFile", - [](const Map& params, const String& path) { + [](const ffi::Map& params, const ffi::String& path) { tvm::runtime::SimpleBinaryFileStream strm(path, "wb"); SaveParams(&strm, params); }) .def("runtime.LoadParams", [](const ffi::Bytes& s) { return ::tvm::runtime::LoadParams(s); }) - .def("runtime.LoadParamsFromFile", [](const String& path) { + .def("runtime.LoadParamsFromFile", [](const ffi::String& path) { tvm::runtime::SimpleBinaryFileStream strm(path, "rb"); return LoadParams(&strm); }); diff --git a/src/runtime/file_utils.h b/src/runtime/file_utils.h index 43f4a8455f41..6f5487f7fab0 100644 --- a/src/runtime/file_utils.h +++ b/src/runtime/file_utils.h @@ -110,25 +110,25 @@ constexpr uint64_t kTVMTensorListMagic = 0xF7E58D4F05049CB7; * \param param_blob Serialized string of parameters. * \return Map of parameter name to parameter value. */ -Map LoadParams(const std::string& param_blob); +ffi::Map LoadParams(const std::string& param_blob); /*! * \brief Load parameters from a stream. * \param strm Stream to load parameters from. * \return Map of parameter name to parameter value. */ -Map LoadParams(dmlc::Stream* strm); +ffi::Map LoadParams(dmlc::Stream* strm); /*! * \brief Serialize parameters to a byte array. * \param params Parameters to save. - * \return String containing binary parameter data. + * \return std::string containing binary parameter data. */ -std::string SaveParams(const Map& params); +std::string SaveParams(const ffi::Map& params); /*! * \brief Serialize parameters to a stream. * \param strm Stream to write to. * \param params Parameters to save. */ -void SaveParams(dmlc::Stream* strm, const Map& params); +void SaveParams(dmlc::Stream* strm, const ffi::Map& params); /*! * \brief A dmlc stream which wraps standard file operations.
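
The file_utils changes above are call-site compatible: only the container spellings move under the ffi:: namespace, and SaveParams still returns a plain std::string blob. A minimal round-trip sketch, not taken from the patch, assuming ffi::Map<ffi::String, Tensor> for the elided template arguments and a previously created Tensor `w`:

// Sketch only: assumes the internal header src/runtime/file_utils.h is in scope.
ffi::Map<ffi::String, tvm::runtime::Tensor> params;
params.Set("fc1.weight", w);  // `w` is an existing Tensor (assumed)
std::string blob = tvm::runtime::SaveParams(params);  // binary blob led by kTVMTensorListMagic
ffi::Map<ffi::String, tvm::runtime::Tensor> reloaded = tvm::runtime::LoadParams(blob);
ICHECK_EQ(reloaded.size(), params.size());
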
diff --git a/src/runtime/hexagon/hexagon_buffer.cc b/src/runtime/hexagon/hexagon_buffer.cc index 48afa5770afd..c6dd9421fe63 100644 --- a/src/runtime/hexagon/hexagon_buffer.cc +++ b/src/runtime/hexagon/hexagon_buffer.cc @@ -109,7 +109,7 @@ std::unique_ptr Allocator(size_t return std::make_unique(nbytes, alignment); } -HexagonBuffer::HexagonBuffer(size_t nbytes, size_t alignment, Optional scope) +HexagonBuffer::HexagonBuffer(size_t nbytes, size_t alignment, ffi::Optional scope) : ndim_(1), nbytes_per_allocation_(nbytes) { SetStorageScope(scope); @@ -125,7 +125,7 @@ HexagonBuffer::HexagonBuffer(size_t nbytes, size_t alignment, Optional s } HexagonBuffer::HexagonBuffer(size_t nallocs, size_t nbytes, size_t alignment, - Optional scope) + ffi::Optional scope) : ndim_(2), nbytes_per_allocation_(nbytes) { SetStorageScope(scope); @@ -166,7 +166,7 @@ void* HexagonBuffer::GetPointer() { HexagonBuffer::StorageScope HexagonBuffer::GetStorageScope() const { return storage_scope_; } -void HexagonBuffer::SetStorageScope(Optional scope) { +void HexagonBuffer::SetStorageScope(ffi::Optional scope) { const std::string s = scope.value_or("global"); if (s == "global") { diff --git a/src/runtime/hexagon/hexagon_buffer.h b/src/runtime/hexagon/hexagon_buffer.h index b1bec270d4fe..2dd7c127e3ed 100644 --- a/src/runtime/hexagon/hexagon_buffer.h +++ b/src/runtime/hexagon/hexagon_buffer.h @@ -49,7 +49,7 @@ class HexagonBuffer { * space in which to allocate. Defaults to global system * memory (DDR). */ - HexagonBuffer(size_t nbytes, size_t alignment, Optional scope); + HexagonBuffer(size_t nbytes, size_t alignment, ffi::Optional scope); /* \brief Allocate 2d (discontiguous) memory within Hexagon accessible * memory scopes. @@ -65,7 +65,7 @@ class HexagonBuffer { * space in which to allocate. Defaults to global system * memory (DDR). */ - HexagonBuffer(size_t nallocs, size_t nbytes, size_t alignment, Optional scope); + HexagonBuffer(size_t nallocs, size_t nbytes, size_t alignment, ffi::Optional scope); //! \brief Destruction deallocates the underlying allocations. ~HexagonBuffer(); @@ -140,7 +140,7 @@ class HexagonBuffer { size_t TotalBytes() const { return nbytes_per_allocation_ * allocations_.size(); } //! \brief Assign a storage scope to the buffer. - void SetStorageScope(Optional scope); + void SetStorageScope(ffi::Optional scope); /*! \brief Array of raw pointer allocations required by the buffer. * * For 1d (contiguous) storage a single allocation will result. 
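
Both HexagonBuffer constructors above keep their shapes; only the scope parameter becomes ffi::Optional<ffi::String>. A hedged caller-side sketch (the sizes and the VTCM scope string are illustrative, not from the patch):

// 1-d contiguous allocation; std::nullopt converts to an empty ffi::Optional, so
// SetStorageScope falls back to "global" (DDR), as documented above.
HexagonBuffer ddr(/*nbytes=*/4096, /*alignment=*/128, std::nullopt);
// 2-d discontiguous allocation with an explicit scope string.
HexagonBuffer vtcm(/*nallocs=*/4, /*nbytes=*/2048, /*alignment=*/128, ffi::String("global.vtcm"));
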
diff --git a/src/runtime/hexagon/hexagon_common.cc b/src/runtime/hexagon/hexagon_common.cc index 491ded5730e6..64a79c0e5e99 100644 --- a/src/runtime/hexagon/hexagon_common.cc +++ b/src/runtime/hexagon/hexagon_common.cc @@ -57,7 +57,7 @@ class HexagonTimerNode : public TimerNode { TVM_FFI_STATIC_INIT_BLOCK({ namespace refl = tvm::ffi::reflection; refl::GlobalDef().def("profiling.timer.hexagon", - [](Device dev) { return Timer(make_object()); }); + [](Device dev) { return Timer(ffi::make_object()); }); }); } // namespace hexagon @@ -94,7 +94,7 @@ TVM_FFI_STATIC_INIT_BLOCK({ refl::GlobalDef().def_packed( "ffi.Module.load_from_file.hexagon", [](ffi::PackedArgs args, ffi::Any* rv) { auto floader = tvm::ffi::Function::GetGlobalRequired("ffi.Module.load_from_file.so"); - *rv = floader(args[0].cast(), "so"); + *rv = floader(args[0].cast(), "so"); }); }); diff --git a/src/runtime/hexagon/hexagon_device_api.cc b/src/runtime/hexagon/hexagon_device_api.cc index ec58946b64b1..cd6d55b3b66b 100644 --- a/src/runtime/hexagon/hexagon_device_api.cc +++ b/src/runtime/hexagon/hexagon_device_api.cc @@ -52,7 +52,7 @@ void HexagonDeviceAPI::GetAttr(Device dev, DeviceAttrKind kind, ffi::Any* rv) { // DataSpace: static allocations for Hexagon void* HexagonDeviceAPI::AllocDataSpace(Device dev, int ndim, const int64_t* shape, DLDataType dtype, - Optional mem_scope) { + ffi::Optional mem_scope) { CHECK(shape || ndim == 0) << "shape array is null for a non-scalar tensor, ndim = " << ndim; CHECK(IsValidDevice(dev)) << "dev.device_type: " << dev.device_type; @@ -122,7 +122,7 @@ void* HexagonDeviceAPI::AllocDataSpace(Device dev, size_t nbytes, size_t alignme CHECK(runtime_hexbuffs) << "Attempted to allocate Hexagon data with " << "HexagonDeviceAPI::AllocDataSpace before initializing resources. " << "Please call HexagonDeviceAPI::AcquireResources"; - return runtime_hexbuffs->AllocateHexagonBuffer(nbytes, alignment, String("global")); + return runtime_hexbuffs->AllocateHexagonBuffer(nbytes, alignment, ffi::String("global")); } void HexagonDeviceAPI::FreeDataSpace(Device dev, void* ptr) { @@ -272,7 +272,7 @@ TVM_FFI_STATIC_INIT_BLOCK({ type_hint.lanes = 1; HexagonDeviceAPI* hexapi = HexagonDeviceAPI::Global(); - *rv = hexapi->AllocDataSpace(dev, ndim, shape, type_hint, String(scope)); + *rv = hexapi->AllocDataSpace(dev, ndim, shape, type_hint, ffi::String(scope)); }) .def_packed("device_api.hexagon.free_nd", [](ffi::PackedArgs args, ffi::Any* rv) { diff --git a/src/runtime/hexagon/hexagon_device_api.h b/src/runtime/hexagon/hexagon_device_api.h index e77e681dd434..76439ef531ae 100644 --- a/src/runtime/hexagon/hexagon_device_api.h +++ b/src/runtime/hexagon/hexagon_device_api.h @@ -136,7 +136,7 @@ class HexagonDeviceAPI final : public DeviceAPI { * \return The allocated HexagonBuffer pointer. */ void* AllocDataSpace(Device dev, int ndim, const int64_t* shape, DLDataType dtype, - Optional mem_scope) final; + ffi::Optional mem_scope) final; /*! * \brief Copy data from one storage to another. 
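
The same ffi::Optional<ffi::String> memory-scope parameter threads through HexagonDeviceAPI::AllocDataSpace. A sketch of a caller under those assumptions (shape and dtype are arbitrary, and a Hexagon device is assumed to be present):

Device dev{kDLHexagon, 0};
int64_t shape[2] = {64, 64};
DLDataType f16{kDLFloat, 16, 1};
tvm::runtime::DeviceAPI* api = tvm::runtime::DeviceAPI::Get(dev);
// An explicit scope requests VTCM here; passing std::nullopt selects global DDR instead.
void* data = api->AllocDataSpace(dev, /*ndim=*/2, shape, f16, ffi::String("global.vtcm"));
api->FreeDataSpace(dev, data);
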
diff --git a/src/runtime/hexagon/hexagon_module.cc b/src/runtime/hexagon/hexagon_module.cc index 9db6a6680b06..5515c33e5f7d 100644 --- a/src/runtime/hexagon/hexagon_module.cc +++ b/src/runtime/hexagon/hexagon_module.cc @@ -42,11 +42,11 @@ HexagonModuleNode::HexagonModuleNode(std::string data, std::string fmt, std::string bc_str) : data_(data), fmt_(fmt), fmap_(fmap), asm_(asm_str), obj_(obj_str), ir_(ir_str), bc_(bc_str) {} -Optional HexagonModuleNode::GetFunction(const String& name) { +ffi::Optional HexagonModuleNode::GetFunction(const ffi::String& name) { LOG(FATAL) << "HexagonModuleNode::GetFunction is not implemented."; } -String HexagonModuleNode::InspectSource(const String& format) const { +ffi::String HexagonModuleNode::InspectSource(const ffi::String& format) const { if (format == "s" || format == "asm") { return asm_; } @@ -56,7 +56,7 @@ String HexagonModuleNode::InspectSource(const String& format) const { return ""; } -void HexagonModuleNode::WriteToFile(const String& file_name, const String& format) const { +void HexagonModuleNode::WriteToFile(const ffi::String& file_name, const ffi::String& format) const { std::string fmt = runtime::GetFileFormat(file_name, format); if (fmt == "so" || fmt == "dll" || fmt == "hexagon") { std::string meta_file = GetMetaFilePath(file_name); @@ -93,7 +93,7 @@ ffi::Module HexagonModuleCreate(std::string data, std::string fmt, std::unordered_map fmap, std::string asm_str, std::string obj_str, std::string ir_str, std::string bc_str) { - auto n = make_object(data, fmt, fmap, asm_str, obj_str, ir_str, bc_str); + auto n = ffi::make_object(data, fmt, fmap, asm_str, obj_str, ir_str, bc_str); return ffi::Module(n); } diff --git a/src/runtime/hexagon/hexagon_module.h b/src/runtime/hexagon/hexagon_module.h index ae7174236622..1f99c278b28b 100644 --- a/src/runtime/hexagon/hexagon_module.h +++ b/src/runtime/hexagon/hexagon_module.h @@ -39,10 +39,10 @@ namespace runtime { * \param data The module data. * \param fmt The format of the data, can be "obj". * \param fmap The function information map of each function. - * \param asm_str String with the generated assembly source. - * \param obj_str String with the object file data. - * \param ir_str String with the disassembled LLVM IR source. - * \param bc_str String with the bitcode LLVM IR. + * \param asm_str std::string with the generated assembly source. + * \param obj_str std::string with the object file data. + * \param ir_str std::string with the disassembled LLVM IR source. + * \param bc_str std::string with the bitcode LLVM IR. */ ffi::Module HexagonModuleCreate(std::string data, std::string fmt, std::unordered_map fmap, @@ -60,15 +60,15 @@ class HexagonModuleNode : public ffi::ModuleObj { HexagonModuleNode(std::string data, std::string fmt, std::unordered_map fmap, std::string asm_str, std::string obj_str, std::string ir_str, std::string bc_str); - Optional GetFunction(const String& name) final; - String InspectSource(const String& format) const final; + ffi::Optional GetFunction(const ffi::String& name) final; + ffi::String InspectSource(const ffi::String& format) const final; const char* kind() const final { return "hexagon"; } /*!
\brief Get the property of the runtime module .*/ int GetPropertyMask() const final { return ffi::Module::kBinarySerializable | ffi::Module::kCompilationExportable | ffi::Module::kRunnable; } - void WriteToFile(const String& file_name, const String& format) const final; + void WriteToFile(const ffi::String& file_name, const ffi::String& format) const final; ffi::Bytes SaveToBytes() const final; protected: diff --git a/src/runtime/hexagon/hexagon_thread_manager.cc b/src/runtime/hexagon/hexagon_thread_manager.cc index 4f8ddd156b9f..a6ae62e39fa5 100644 --- a/src/runtime/hexagon/hexagon_thread_manager.cc +++ b/src/runtime/hexagon/hexagon_thread_manager.cc @@ -140,11 +140,11 @@ void HexagonThreadManager::SpawnThreads(unsigned thread_stack_size_bytes, unsigned thread_pipe_size_words) { // allocate all stack space for threads stack_buffer_ = hexbuffs_.AllocateHexagonBuffer(thread_stack_size_bytes * nthreads_, - MEM_ALIGNMENT, String("global")); + MEM_ALIGNMENT, ffi::String("global")); // allocate space for pipe buffers (command queues) unsigned thread_pipe_size_bytes = thread_pipe_size_words * sizeof(qurt_pipe_data_t); pipe_buffer_ = hexbuffs_.AllocateHexagonBuffer(thread_pipe_size_bytes * nthreads_, MEM_ALIGNMENT, - String("global")); + ffi::String("global")); threads_.resize(nthreads_); pipes_.resize(nthreads_); diff --git a/src/runtime/memory/memory_manager.cc b/src/runtime/memory/memory_manager.cc index 4f810011e8aa..239d9e131ea6 100644 --- a/src/runtime/memory/memory_manager.cc +++ b/src/runtime/memory/memory_manager.cc @@ -36,7 +36,7 @@ namespace runtime { namespace memory { Storage::Storage(Buffer buffer, Allocator* allocator) { - auto n = make_object(); + auto n = ffi::make_object(); n->buffer = std::move(buffer); n->allocator = allocator; data_ = std::move(n); @@ -61,7 +61,7 @@ inline size_t GetDataAlignment(const DLDataType& dtype) { } Tensor StorageObj::AllocTensorScoped(int64_t offset, ffi::Shape shape, DLDataType dtype, - String scope) { + ffi::String scope) { if (scope == "global" || scope.empty()) { return AllocTensor(offset, shape, dtype); } @@ -71,7 +71,7 @@ Tensor StorageObj::AllocTensorScoped(int64_t offset, ffi::Shape shape, DLDataTyp public: explicit StorageScopedAlloc(Storage storage) : storage_(storage) {} - void AllocData(DLTensor* tensor, const ffi::Shape& shape, const String& scope, + void AllocData(DLTensor* tensor, const ffi::Shape& shape, const ffi::String& scope, int64_t byte_offset) { tensor->data = storage_->allocator->CreateView(storage_->buffer, shape, tensor->dtype, scope); tensor->byte_offset = byte_offset; @@ -87,7 +87,7 @@ Tensor StorageObj::AllocTensorScoped(int64_t offset, ffi::Shape shape, DLDataTyp << "storage allocation failure, attempted to allocate " << needed_size << " at offset " << offset << " in region that is " << this->buffer.size << "bytes"; - return Tensor::FromNDAlloc(StorageScopedAlloc(GetRef(this)), shape, dtype, + return Tensor::FromNDAlloc(StorageScopedAlloc(ffi::GetRef(this)), shape, dtype, this->buffer.device, shape, scope, offset); } @@ -120,8 +120,8 @@ Tensor StorageObj::AllocTensor(int64_t offset, ffi::Shape shape, DLDataType dtyp Storage storage_; }; - return Tensor::FromNDAlloc(StorageAlloc(GetRef(this)), shape, dtype, this->buffer.device, - offset); + return Tensor::FromNDAlloc(StorageAlloc(ffi::GetRef(this)), shape, dtype, + this->buffer.device, offset); } MemoryManager* MemoryManager::Global() { @@ -214,7 +214,7 @@ void MemoryManager::Clear() { } Tensor Allocator::Empty(ffi::Shape shape, DLDataType dtype, DLDevice dev, - Optional 
mem_scope) { + ffi::Optional mem_scope) { VerifyDataType(dtype); class BufferAlloc { diff --git a/src/runtime/memory/naive_allocator.h b/src/runtime/memory/naive_allocator.h index aed990d22c3b..6a968c86ef3b 100644 --- a/src/runtime/memory/naive_allocator.h +++ b/src/runtime/memory/naive_allocator.h @@ -67,7 +67,7 @@ class NaiveAllocator final : public Allocator { buf.size = nbytes; buf.data = DeviceAPI::Get(dev)->AllocDataSpace(dev, shape.size(), shape.data(), type_hint, - String(mem_scope)); + ffi::String(mem_scope)); used_memory_.fetch_add(nbytes, std::memory_order_relaxed); DLOG(INFO) << "allocate " << nbytes << " B, used memory " << used_memory_ << " B"; buf.alloc_type = kNaive; diff --git a/src/runtime/meta_data.h b/src/runtime/meta_data.h index bc88529ae19e..85b83289f4d3 100644 --- a/src/runtime/meta_data.h +++ b/src/runtime/meta_data.h @@ -38,7 +38,7 @@ namespace tvm { namespace runtime { -inline String get_name_mangled(const String& module_name, const String& name) { +inline ffi::String get_name_mangled(const ffi::String& module_name, const ffi::String& name) { std::stringstream ss; ss << module_name << "_" << name; return ss.str(); } diff --git a/src/runtime/metal/metal_device_api.mm b/src/runtime/metal/metal_device_api.mm index 2a8544f6f17c..c8a155ce387d 100644 --- a/src/runtime/metal/metal_device_api.mm +++ b/src/runtime/metal/metal_device_api.mm @@ -397,7 +397,7 @@ virtual void Stop() { TVM_FFI_STATIC_INIT_BLOCK({ namespace refl = tvm::ffi::reflection; refl::GlobalDef().def("profiling.timer.metal", - [](Device dev) { return Timer(make_object(dev)); }); + [](Device dev) { return Timer(ffi::make_object(dev)); }); }); } // namespace metal diff --git a/src/runtime/metal/metal_module.mm b/src/runtime/metal/metal_module.mm index 71c46504c4d4..0439ba47789a 100644 --- a/src/runtime/metal/metal_module.mm +++ b/src/runtime/metal/metal_module.mm @@ -58,9 +58,9 @@ int GetPropertyMask() const final { return ffi::Module::kBinarySerializable | ffi::Module::kRunnable; } - Optional GetFunction(const String& name) final; + ffi::Optional GetFunction(const ffi::String& name) final; - void WriteToFile(const String& file_name, const String& format) const final { + void WriteToFile(const ffi::String& file_name, const ffi::String& format) const final { LOG(FATAL) << "Do not support save to file, use save to binary and export instead"; } @@ -75,7 +75,7 @@ void WriteToFile(const String& file_name, const String& format) const final { stream->Write(fmt_); return ffi::Bytes(buffer); } - String InspectSource(const String& format) const final { + ffi::String InspectSource(const ffi::String& format) const final { // return text source if available.
return source_; } @@ -263,7 +263,7 @@ void operator()(ffi::PackedArgs args, ffi::Any* rv, const ArgUnion64* pack_args) LaunchParamConfig launch_param_config_; }; -Optional MetalModuleNode::GetFunction(const String& name) { +ffi::Optional MetalModuleNode::GetFunction(const ffi::String& name) { ffi::Function ret; AUTORELEASEPOOL { ObjectPtr sptr_to_self = ffi::GetObjectPtr(this); @@ -286,24 +286,24 @@ void operator()(ffi::PackedArgs args, ffi::Any* rv, const ArgUnion64* pack_args) std::unordered_map fmap, std::string fmt, std::string source) { ObjectPtr n; - AUTORELEASEPOOL { n = make_object(smap, fmap, fmt, source); }; + AUTORELEASEPOOL { n = ffi::make_object(smap, fmap, fmt, source); }; return ffi::Module(n); } TVM_FFI_STATIC_INIT_BLOCK({ namespace refl = tvm::ffi::reflection; - refl::GlobalDef().def( - "runtime.module.create_metal_module", - [](Map smap, std::string fmap_json, std::string fmt, std::string source) { - std::istringstream stream(fmap_json); - std::unordered_map fmap; - dmlc::JSONReader reader(&stream); - reader.Read(&fmap); + refl::GlobalDef().def("runtime.module.create_metal_module", + [](ffi::Map smap, std::string fmap_json, + std::string fmt, std::string source) { + std::istringstream stream(fmap_json); + std::unordered_map fmap; + dmlc::JSONReader reader(&stream); + reader.Read(&fmap); - return MetalModuleCreate( - std::unordered_map(smap.begin(), smap.end()), fmap, fmt, - source); - }); + return MetalModuleCreate(std::unordered_map( + smap.begin(), smap.end()), + fmap, fmt, source); + }); }); ffi::Module MetalModuleLoadFromBytes(const ffi::Bytes& bytes) { diff --git a/src/runtime/module.cc b/src/runtime/module.cc index 16c617ce3fcb..97238ec56b79 100644 --- a/src/runtime/module.cc +++ b/src/runtime/module.cc @@ -35,7 +35,7 @@ namespace tvm { namespace runtime { -bool RuntimeEnabled(const String& target_str) { +bool RuntimeEnabled(const ffi::String& target_str) { std::string target = target_str; std::string f_name; if (target == "cpu") { diff --git a/src/runtime/opencl/opencl_common.h b/src/runtime/opencl/opencl_common.h index 021dad3ca35a..62da1007f0ba 100644 --- a/src/runtime/opencl/opencl_common.h +++ b/src/runtime/opencl/opencl_common.h @@ -341,7 +341,7 @@ class OpenCLWorkspace : public DeviceAPI { } void* AllocDataSpaceView(Device dev, void* data, ffi::Shape shape, DLDataType dtype, - Optional mem_scope = std::nullopt); + ffi::Optional mem_scope = std::nullopt); void FreeDataSpaceView(Device dev, void* ptr); cl_device_id GetCLDeviceID(int device_id); @@ -350,9 +350,9 @@ class OpenCLWorkspace : public DeviceAPI { void GetAttr(Device dev, DeviceAttrKind kind, ffi::Any* rv) final; void* AllocDataSpace(Device dev, size_t size, size_t alignment, DLDataType type_hint) final; void* AllocDataSpace(Device dev, int ndim, const int64_t* shape, DLDataType dtype, - Optional mem_scope = std::nullopt) final; + ffi::Optional mem_scope = std::nullopt) final; void* AllocDataSpace(Device dev, size_t width, size_t height, DLDataType type_hint, - Optional mem_scope = std::nullopt); + ffi::Optional mem_scope = std::nullopt); void* GetNativePtr(const tvm::runtime::Tensor& narr); void SetNativePtr(const tvm::runtime::Tensor& narr, void* host_ptr, size_t buf_size); void SetPerfHint(Device dev, cl_uint perf_hint); @@ -360,12 +360,13 @@ class OpenCLWorkspace : public DeviceAPI { void StreamSync(Device dev, TVMStreamHandle stream) final; void* AllocWorkspace(Device dev, size_t size, DLDataType type_hint) final; void FreeWorkspace(Device dev, void* data) final; - size_t GetDataSize(const 
DLTensor& arr, Optional mem_scope = std::nullopt) final; + size_t GetDataSize(const DLTensor& arr, + ffi::Optional mem_scope = std::nullopt) final; // cl_mem alloc utils void* AllocCLBuffer(Device dev, size_t size, size_t alignment, DLDataType type_hint); void* AllocCLImage(Device dev, void* back_buffer, size_t width, size_t height, size_t row_pitch, - DLDataType type_hint, Optional mem_scope); + DLDataType type_hint, ffi::Optional mem_scope); /*! * \brief Get the thread local ThreadEntry @@ -436,9 +437,10 @@ struct BufferDescriptor { kImage2DNHWC, }; BufferDescriptor() = default; - explicit BufferDescriptor(Optional scope) : layout(MemoryLayoutFromScope(scope)) {} - static MemoryLayout MemoryLayoutFromScope(Optional mem_scope); - static String ScopeFromMemoryLayout(MemoryLayout mem_scope); + explicit BufferDescriptor(ffi::Optional scope) + : layout(MemoryLayoutFromScope(scope)) {} + static MemoryLayout MemoryLayoutFromScope(ffi::Optional mem_scope); + static ffi::String ScopeFromMemoryLayout(MemoryLayout mem_scope); /* clBuffer object */ // buffer should be the first element here @@ -479,7 +481,7 @@ class OpenCLModuleNodeBase : public ffi::ModuleObj { return ffi::Module::kBinarySerializable | ffi::Module::kRunnable; } - Optional GetFunction(const String& name) override; + ffi::Optional GetFunction(const ffi::String& name) override; // Initialize the programs virtual void Init() = 0; @@ -509,14 +511,14 @@ class OpenCLModuleNode : public OpenCLModuleNodeBase { std::unordered_map fmap, std::string source) : OpenCLModuleNodeBase(fmap), data_(data), fmt_(fmt), source_(source) {} - Optional GetFunction(const String& name) final; + ffi::Optional GetFunction(const ffi::String& name) final; // Return true if OpenCL program for the requested function and device was created bool IsProgramCreated(const std::string& func_name, int device_id); - void WriteToFile(const String& file_name, const String& format) const final; + void WriteToFile(const ffi::String& file_name, const ffi::String& format) const final; ffi::Bytes SaveToBytes() const final; void SetPreCompiledPrograms(const std::string& bytes); std::string GetPreCompiledPrograms(); - String InspectSource(const String& format) const final; + ffi::String InspectSource(const ffi::String& format) const final; // Initialize the programs void Init() override; diff --git a/src/runtime/opencl/opencl_device_api.cc b/src/runtime/opencl/opencl_device_api.cc index 1cc4e7936013..32ca168d314b 100644 --- a/src/runtime/opencl/opencl_device_api.cc +++ b/src/runtime/opencl/opencl_device_api.cc @@ -76,7 +76,7 @@ ImageInfo GetImageInfo(const cl::BufferDescriptor* desc, const DLTensor* tensor) } cl::BufferDescriptor::MemoryLayout cl::BufferDescriptor::MemoryLayoutFromScope( - Optional mem_scope) { + ffi::Optional mem_scope) { if (!mem_scope.has_value()) { return cl::BufferDescriptor::MemoryLayout::kBuffer1D; } else if (mem_scope.value() == "global.texture") { @@ -89,7 +89,7 @@ cl::BufferDescriptor::MemoryLayout cl::BufferDescriptor::MemoryLayoutFromScope( LOG(FATAL) << "No memory layout defined for memory of scope: " << mem_scope.value(); } -String cl::BufferDescriptor::ScopeFromMemoryLayout(cl::BufferDescriptor::MemoryLayout layout) { +ffi::String cl::BufferDescriptor::ScopeFromMemoryLayout(cl::BufferDescriptor::MemoryLayout layout) { switch (layout) { case cl::BufferDescriptor::MemoryLayout::kBuffer1D: return "global"; @@ -261,7 +261,7 @@ void* OpenCLWorkspace::AllocDataSpace(Device dev, size_t size, size_t alignment, } void* OpenCLWorkspace::AllocDataSpace(Device 
dev, size_t width, size_t height, DLDataType type_hint, - Optional mem_scope) { + ffi::Optional mem_scope) { // Texture allocation given width and height cl_uint row_align = GetImageAlignment(dev.device_id); size_t pixel_size = (type_hint.bits * type_hint.lanes + 7) / 8; @@ -278,13 +278,13 @@ void* OpenCLWorkspace::AllocDataSpace(Device dev, size_t width, size_t height, D } if (!mem_scope.has_value()) { - mem_scope = String("global.texture"); + mem_scope = ffi::String("global.texture"); } return AllocCLImage(dev, back_buffer, width, height, row_pitch, type_hint, mem_scope); } void* OpenCLWorkspace::AllocDataSpace(Device dev, int ndim, const int64_t* shape, DLDataType dtype, - Optional mem_scope) { + ffi::Optional mem_scope) { this->Init(); if (!mem_scope.has_value() || (*mem_scope).empty() || (*mem_scope) == "global") { size_t size = GetMemObjectSize(dev, ndim, shape, dtype); @@ -321,7 +321,7 @@ void* OpenCLWorkspace::AllocCLBuffer(Device dev, size_t size, size_t alignment, void* OpenCLWorkspace::AllocCLImage(Device dev, void* back_buffer, size_t width, size_t height, size_t row_pitch, DLDataType type_hint, - Optional mem_scope) { + ffi::Optional mem_scope) { this->Init(); ICHECK(std::string(mem_scope.value()).find("texture") != std::string::npos) << "Expect texture scope while creating an Image object"; @@ -348,7 +348,7 @@ void* OpenCLWorkspace::AllocCLImage(Device dev, void* back_buffer, size_t width, return desc; } -size_t OpenCLWorkspace::GetDataSize(const DLTensor& arr, Optional mem_scope) { +size_t OpenCLWorkspace::GetDataSize(const DLTensor& arr, ffi::Optional mem_scope) { if (!mem_scope.has_value() || (*mem_scope).empty() || (*mem_scope) == "global") { return DeviceAPI::GetDataSize(arr); } @@ -360,7 +360,7 @@ size_t OpenCLWorkspace::GetDataSize(const DLTensor& arr, Optional mem_sc } void* OpenCLWorkspace::AllocDataSpaceView(Device dev, void* data, ffi::Shape shape, - DLDataType dtype, Optional mem_scope) { + DLDataType dtype, ffi::Optional mem_scope) { cl::BufferDescriptor* desc = static_cast(data); // Fall back for devices w/o "cl_khr_image2d_from_buffer" @@ -630,7 +630,7 @@ std::string GetDeviceInfo(cl_device_id pid, cl_device_info param_name) { } std::string GetOpenCLVersion(cl_device_id pid) { // String returned is "OpenCL $MAJOR.$MINOR $VENDOR_INFO".
To // match other implementations, we want to return "$MAJOR.$MINOR" std::string ret = GetDeviceInfo(pid, CL_DEVICE_VERSION); @@ -789,7 +789,7 @@ TVM_FFI_STATIC_INIT_BLOCK({ *rv = OpenCLWorkspace::Global()->AllocDataSpace( dev, static_cast(width), static_cast(height), type_hint, - String("global.texture")); + ffi::String("global.texture")); }) .def_packed("device_api.opencl.free_nd", [](ffi::PackedArgs args, ffi::Any* rv) { @@ -814,7 +814,7 @@ TVM_FFI_STATIC_INIT_BLOCK({ TVM_FFI_STATIC_INIT_BLOCK({ namespace refl = tvm::ffi::reflection; refl::GlobalDef().def("profiling.timer.opencl", - [](Device dev) { return Timer(make_object(dev)); }); + [](Device dev) { return Timer(ffi::make_object(dev)); }); }); class OpenCLPooledAllocator final : public memory::PooledAllocator { @@ -863,7 +863,7 @@ class OpenCLPooledAllocator final : public memory::PooledAllocator { buf.size = size; buf.alloc_type = AllocatorType::kPooled; buf.data = DeviceAPI::Get(dev)->AllocDataSpace(dev, shape.size(), shape.data(), type_hint, - String(mem_scope)); + ffi::String(mem_scope)); if (mem_scope.find("texture") == std::string::npos) { // All textures are backed by buffers - don't count in total memory used_memory_.fetch_add(size, std::memory_order_relaxed); @@ -887,7 +887,8 @@ class OpenCLPooledAllocator final : public memory::PooledAllocator { void* CreateView(const Buffer& buffer, ffi::Shape shape, DLDataType type_hint, const std::string& mem_scope) final { OpenCLWorkspace* ws_ = OpenCLWorkspace::Global(); - return ws_->AllocDataSpaceView(buffer.device, buffer.data, shape, type_hint, String(mem_scope)); + return ws_->AllocDataSpaceView(buffer.device, buffer.data, shape, type_hint, + ffi::String(mem_scope)); } void FreeView(Device dev, void* data) final { diff --git a/src/runtime/opencl/opencl_module.cc b/src/runtime/opencl/opencl_module.cc index a8e3b6fc20b6..169f9408c38b 100644 --- a/src/runtime/opencl/opencl_module.cc +++ b/src/runtime/opencl/opencl_module.cc @@ -135,7 +135,7 @@ cl::OpenCLWorkspace* OpenCLModuleNodeBase::GetGlobalWorkspace() { return cl::OpenCLWorkspace::Global(); } -Optional OpenCLModuleNodeBase::GetFunction(const String& name) { +ffi::Optional OpenCLModuleNodeBase::GetFunction(const ffi::String& name) { ObjectPtr sptr_to_self = ffi::GetObjectPtr(this); ICHECK_EQ(sptr_to_self.get(), this); auto it = fmap_.find(name); @@ -160,7 +160,7 @@ Optional OpenCLModuleNodeBase::GetFunction(const String& name) { return PackFuncVoidAddr(f, info.arg_types); } -void OpenCLModuleNode::WriteToFile(const String& file_name, const String& format) const { +void OpenCLModuleNode::WriteToFile(const ffi::String& file_name, const ffi::String& format) const { std::string fmt = GetFileFormat(file_name, format); ICHECK_EQ(fmt, fmt_) << "Can only save to format=" << fmt_; std::string meta_file = GetMetaFilePath(file_name); @@ -178,7 +178,7 @@ ffi::Bytes OpenCLModuleNode::SaveToBytes() const { return ffi::Bytes(buffer); } -String OpenCLModuleNode::InspectSource(const String& format) const { +ffi::String OpenCLModuleNode::InspectSource(const ffi::String& format) const { if (format == fmt_) return data_; if (fmt_ == "cl") { return data_; @@ -349,7 +349,7 @@ std::string OpenCLModuleNode::GetPreCompiledPrograms() { return data; } -Optional OpenCLModuleNode::GetFunction(const String& name) { +ffi::Optional OpenCLModuleNode::GetFunction(const ffi::String& name) { ObjectPtr sptr_to_self = ffi::GetObjectPtr(this); ICHECK_EQ(sptr_to_self.get(), this); if (name == "opencl.GetPreCompiledPrograms") { @@ -367,13 +367,13 @@ Optional 
OpenCLModuleNode::GetFunction(const String& name) { ffi::Module OpenCLModuleCreate(std::string data, std::string fmt, std::unordered_map fmap, std::string source) { - auto n = make_object(data, fmt, fmap, source); + auto n = ffi::make_object(data, fmt, fmap, source); n->Init(); return ffi::Module(n); } // Load module from file. -ffi::Module OpenCLModuleLoadFile(const std::string& file_name, const String& format) { +ffi::Module OpenCLModuleLoadFile(const std::string& file_name, const ffi::String& format) { std::string data; std::unordered_map fmap; std::string fmt = GetFileFormat(file_name, format); diff --git a/src/runtime/opencl/opencl_module_spirv.cc b/src/runtime/opencl/opencl_module_spirv.cc index 5b90e0b566c7..096b05382379 100644 --- a/src/runtime/opencl/opencl_module_spirv.cc +++ b/src/runtime/opencl/opencl_module_spirv.cc @@ -39,9 +39,9 @@ class OpenCLSPIRVModuleNode : public OpenCLModuleNodeBase { std::unordered_map fmap) : OpenCLModuleNodeBase(fmap), shaders_(shaders), spirv_text_(spirv_text) {} - void WriteToFile(const String& file_name, const String& format) const final; + void WriteToFile(const ffi::String& file_name, const ffi::String& format) const final; ffi::Bytes SaveToBytes() const final; - String InspectSource(const String& format) const final { return spirv_text_; } + ffi::String InspectSource(const ffi::String& format) const final { return spirv_text_; } void Init() override; cl_kernel InstallKernel(cl::OpenCLWorkspace* w, cl::OpenCLThreadEntry* t, @@ -52,7 +52,8 @@ class OpenCLSPIRVModuleNode : public OpenCLModuleNodeBase { std::string spirv_text_; }; -void OpenCLSPIRVModuleNode::WriteToFile(const String& file_name, const String& format) const { +void OpenCLSPIRVModuleNode::WriteToFile(const ffi::String& file_name, + const ffi::String& format) const { // TODO(masahi): How should SPIRV binaries be saved to a file?
LOG(FATAL) << "Not implemented."; } @@ -132,7 +133,7 @@ cl_kernel OpenCLSPIRVModuleNode::InstallKernel(cl::OpenCLWorkspace* w, cl::OpenC ffi::Module OpenCLModuleCreate(const std::unordered_map& shaders, const std::string& spirv_text, std::unordered_map fmap) { - auto n = make_object(shaders, spirv_text, fmap); + auto n = ffi::make_object(shaders, spirv_text, fmap); n->Init(); return ffi::Module(n); } diff --git a/src/runtime/profiling.cc b/src/runtime/profiling.cc index d5ac8b9de06f..8ef62c652138 100644 --- a/src/runtime/profiling.cc +++ b/src/runtime/profiling.cc @@ -64,7 +64,7 @@ class DefaultTimerNode : public TimerNode { Device device_; }; -Timer DefaultTimer(Device dev) { return Timer(make_object(dev)); } +Timer DefaultTimer(Device dev) { return Timer(ffi::make_object(dev)); } class CPUTimerNode : public TimerNode { public: @@ -84,7 +84,7 @@ class CPUTimerNode : public TimerNode { TVM_FFI_STATIC_INIT_BLOCK({ namespace refl = tvm::ffi::reflection; refl::GlobalDef().def("profiling.timer.cpu", - [](Device dev) { return Timer(make_object()); }); + [](Device dev) { return Timer(ffi::make_object()); }); }); // keep track of which timers are not defined but we have already warned about @@ -122,12 +122,12 @@ TVM_FFI_STATIC_INIT_BLOCK({ namespace profiling { Profiler::Profiler(std::vector devs, std::vector metric_collectors, - std::unordered_map configuration) + std::unordered_map configuration) : devs_(devs), collectors_(metric_collectors), configuration_(configuration) { is_running_ = false; std::vector wrapped_devs; for (auto dev : devs) { - wrapped_devs.push_back(DeviceWrapper(make_object(dev))); + wrapped_devs.push_back(DeviceWrapper(ffi::make_object(dev))); } for (auto& x : collectors_) { x->Init(wrapped_devs); @@ -135,8 +135,8 @@ Profiler::Profiler(std::vector devs, std::vector metric // reset the thread pool so that PAPI eventset hooks are set in all threads. 
threading::ResetThreadPool(); - configuration_[String("Number of threads")] = - ObjectRef(make_object(threading::NumThreads())); + configuration_[ffi::String("Number of threads")] = + ObjectRef(ffi::make_object(threading::NumThreads())); } void Profiler::Start() { @@ -146,7 +146,7 @@ void Profiler::Start() { } } -void Profiler::StartCall(String name, Device dev, +void Profiler::StartCall(ffi::String name, Device dev, std::unordered_map extra_metrics) { std::vector> objs; for (auto& collector : collectors_) { @@ -212,9 +212,11 @@ std::vector ToShape(Tensor shape_tensor) { return shape; } -String ShapeString(Tensor shape, DLDataType dtype) { return ShapeString(ToShape(shape), dtype); } +ffi::String ShapeString(Tensor shape, DLDataType dtype) { + return ShapeString(ToShape(shape), dtype); +} -String ShapeString(const std::vector& shape, DLDataType dtype) { +ffi::String ShapeString(const std::vector& shape, DLDataType dtype) { std::stringstream sizes; sizes << dtype << "["; for (size_t i = 0; i < shape.size(); i++) { @@ -224,10 +226,10 @@ String ShapeString(const std::vector& shape, DLDataType dtype) { sizes << shape[i]; } sizes << "]"; - return String(sizes.str()); + return ffi::String(sizes.str()); } -String ShapeString(const std::vector& shapes) { +ffi::String ShapeString(const std::vector& shapes) { std::stringstream sizes; for (const Tensor& ary : shapes) { if (sizes.tellp() > 0) { @@ -243,10 +245,10 @@ String ShapeString(const std::vector& shapes) { } sizes << "]"; } - return String(sizes.str()); + return ffi::String(sizes.str()); } -String ReportNode::AsCSV() const { +ffi::String ReportNode::AsCSV() const { // get unique headers std::set unique_headers; @@ -300,7 +302,7 @@ String ReportNode::AsCSV() const { namespace { void metric_as_json(std::ostream& os, ffi::Any o) { - if (auto opt_str = o.as()) { + if (auto opt_str = o.as()) { os << "{\"string\":" << "\"" << *opt_str << "\"" << "}"; @@ -321,7 +323,7 @@ void metric_as_json(std::ostream& os, ffi::Any o) { } } // namespace -String ReportNode::AsJSON() const { +ffi::String ReportNode::AsJSON() const { std::ostringstream s; // DMLC's JSONWriter does not allow us to write a key value pair without // implementing Write for the value. We want a specific write for the value, @@ -395,29 +397,29 @@ Any AggregateMetric(const std::vector& metrics) { for (auto& metric : metrics) { sum += metric.as()->microseconds; } - return ObjectRef(make_object(sum)); + return ObjectRef(ffi::make_object(sum)); } else if (metrics[0].as()) { int64_t sum = 0; for (auto& metric : metrics) { sum += metric.as()->value; } - return ObjectRef(make_object(sum)); + return ObjectRef(ffi::make_object(sum)); } else if (metrics[0].as()) { double sum = 0; for (auto& metric : metrics) { sum += metric.as()->percent; } - return ObjectRef(make_object(sum)); + return ObjectRef(ffi::make_object(sum)); } else if (metrics[0].as()) { double sum = 0; for (auto& metric : metrics) { sum += metric.as()->ratio; } - return ObjectRef(make_object(sum / metrics.size())); + return ObjectRef(ffi::make_object(sum / metrics.size())); } else if (auto opt_str = metrics[0].as()) { for (auto& m : metrics) { if (*opt_str != m.as()) { - return String(""); + return ffi::String(""); } } // Assume all strings in metrics are the same. 
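
The values AggregateMetric folds over are ordinary objects wrapped in ObjectRef; the rename only re-qualifies make_object. A sketch of building one report row, assuming CountNode and DurationNode are the profiling metric types behind the elided template arguments:

using tvm::runtime::profiling::CountNode;
using tvm::runtime::profiling::DurationNode;
std::unordered_map<std::string, ffi::Any> row;  // same row type the profiler uses above
row["Count"] = ObjectRef(ffi::make_object<CountNode>(int64_t{42}));
row["Duration (us)"] = ObjectRef(ffi::make_object<DurationNode>(123.4));  // microseconds
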
@@ -442,7 +444,7 @@ static void set_locale_for_separators(std::stringstream& s) { } } -static String print_metric(ffi::Any metric) { +static ffi::String print_metric(ffi::Any metric) { std::string val; if (metric.as()) { std::stringstream s; @@ -471,23 +473,23 @@ static String print_metric(ffi::Any metric) { return val; } -String ReportNode::AsTable(bool sort, bool aggregate, bool compute_col_sums) const { +ffi::String ReportNode::AsTable(bool sort, bool aggregate, bool compute_col_sums) const { // aggregate calls by op hash (or op name if hash is not set) + argument shapes - std::vector> aggregated_calls; + std::vector> aggregated_calls; if (aggregate) { std::unordered_map> aggregates; for (size_t i = 0; i < calls.size(); i++) { auto& frame = calls[i]; auto it = frame.find("Hash"); - std::string name = frame["Name"].cast(); + std::string name = frame["Name"].cast(); if (it != frame.end()) { - name = (*it).second.cast(); + name = (*it).second.cast(); } if (frame.find("Argument Shapes") != frame.end()) { - name += frame["Argument Shapes"].cast(); + name += frame["Argument Shapes"].cast(); } if (frame.find("Device") != frame.end()) { - name += frame["Device"].cast(); + name += frame["Device"].cast(); } if (aggregates.find(name) == aggregates.end()) { @@ -497,7 +499,7 @@ String ReportNode::AsTable(bool sort, bool aggregate, bool compute_col_sums) con } } for (const auto& p : aggregates) { - std::unordered_map aggregated; + std::unordered_map aggregated; std::unordered_set metrics; for (auto& call : calls) { for (auto& metric : call) { @@ -509,7 +511,7 @@ String ReportNode::AsTable(bool sort, bool aggregate, bool compute_col_sums) con for (auto i : p.second) { auto& call = calls[i]; auto it = std::find_if(call.begin(), call.end(), - [&metric](const std::pair& call_metric) { + [&metric](const std::pair& call_metric) { return std::string(call_metric.first) == metric; }); if (it != call.end()) { @@ -530,16 +532,17 @@ String ReportNode::AsTable(bool sort, bool aggregate, bool compute_col_sums) con // sort rows by duration if (sort) { - std::sort(aggregated_calls.begin(), aggregated_calls.end(), - [&](const Map& a, const Map& b) { - return a.at("Duration (us)").as()->microseconds > - b.at("Duration (us)").as()->microseconds; - }); + std::sort( + aggregated_calls.begin(), aggregated_calls.end(), + [&](const ffi::Map& a, const ffi::Map& b) { + return a.at("Duration (us)").as()->microseconds > + b.at("Duration (us)").as()->microseconds; + }); } // compute columnwise sums if (compute_col_sums) { - std::unordered_map col_sums; + std::unordered_map col_sums; for (auto call : aggregated_calls) { for (auto p : call) { if (p.second.as()) { @@ -548,35 +551,35 @@ String ReportNode::AsTable(bool sort, bool aggregate, bool compute_col_sums) con if (it != col_sums.end()) { val += it->second.as()->value; } - col_sums[p.first] = ObjectRef(make_object(val)); + col_sums[p.first] = ObjectRef(ffi::make_object(val)); } else if (p.second.as()) { double val = p.second.as()->microseconds; auto it = col_sums.find(p.first); if (it != col_sums.end()) { val += it->second.as()->microseconds; } - col_sums[p.first] = ObjectRef(make_object(val)); + col_sums[p.first] = ObjectRef(ffi::make_object(val)); } else if (p.second.as()) { double val = p.second.as()->percent; auto it = col_sums.find(p.first); if (it != col_sums.end()) { val += it->second.as()->percent; } - col_sums[p.first] = ObjectRef(make_object(val)); + col_sums[p.first] = ObjectRef(ffi::make_object(val)); } else if (p.second.as()) { // It does not make sense to sum 
ratios } } } - col_sums["Name"] = String("Sum"); - aggregated_calls.push_back({{String("Name"), String("----------")}}); // separator + col_sums["Name"] = ffi::String("Sum"); + aggregated_calls.push_back({{ffi::String("Name"), ffi::String("----------")}}); // separator aggregated_calls.push_back(col_sums); } // per-device metrics for (auto p : device_metrics) { - Map metrics = p.second; - metrics.Set("Name", String("Total")); + ffi::Map metrics = p.second; + metrics.Set("Name", ffi::String("Total")); aggregated_calls.push_back(metrics); } @@ -660,14 +663,14 @@ std::string DeviceString(Device dev) { Report Profiler::Report() { // sync all timers and normalize rows - std::vector> rows; + std::vector> rows; for (auto& cf : calls_) { - std::unordered_map row; + std::unordered_map row; double us = cf.timer->SyncAndGetElapsedNanos() / 1e3; - row["Duration (us)"] = ObjectRef(make_object(us)); - row["Count"] = ObjectRef(make_object(1)); + row["Duration (us)"] = ObjectRef(ffi::make_object(us)); + row["Count"] = ObjectRef(ffi::make_object(1)); row["Name"] = cf.name; - row["Device"] = String(DeviceString(cf.dev)); + row["Device"] = ffi::String(DeviceString(cf.dev)); for (auto p : cf.extra_metrics) { row[p.first] = p.second; } @@ -676,23 +679,23 @@ Report Profiler::Report() { // the last frames are the overall times double overall_time_us = 0; - std::unordered_map> device_metrics; + std::unordered_map> device_metrics; for (size_t i = 0; i < devs_.size(); i++) { auto row = rows[rows.size() - 1]; rows.pop_back(); - device_metrics[row["Device"].cast()] = row; + device_metrics[row["Device"].cast()] = row; overall_time_us = std::max(overall_time_us, row["Duration (us)"].as()->microseconds); } // Calculate percentages for (auto& row : rows) { - row["Percent"] = ObjectRef(make_object( + row["Percent"] = ObjectRef(ffi::make_object( row["Duration (us)"].as()->microseconds / overall_time_us * 100)); } // convert to map - std::vector> converted_rows; + std::vector> converted_rows; for (const auto& row : rows) { converted_rows.push_back(row); } @@ -700,20 +703,20 @@ Report Profiler::Report() { return profiling::Report(converted_rows, device_metrics, configuration_); } -Report::Report(Array> calls, - Map> device_metrics, - Map configuration) { - auto node = make_object(); +Report::Report(ffi::Array> calls, + ffi::Map> device_metrics, + ffi::Map configuration) { + auto node = ffi::make_object(); node->calls = std::move(calls); node->device_metrics = std::move(device_metrics); node->configuration = std::move(configuration); data_ = std::move(node); } -Map parse_metrics(dmlc::JSONReader* reader) { +ffi::Map parse_metrics(dmlc::JSONReader* reader) { reader->BeginObject(); std::string metric_name, metric_value_name; - Map metrics; + ffi::Map metrics; while (reader->NextObjectItem(&metric_name)) { ffi::Any o; reader->BeginObject(); @@ -721,23 +724,23 @@ Map parse_metrics(dmlc::JSONReader* reader) { if (metric_value_name == "microseconds") { double microseconds; reader->Read(µseconds); - o = ObjectRef(make_object(microseconds)); + o = ObjectRef(ffi::make_object(microseconds)); } else if (metric_value_name == "percent") { double percent; reader->Read(&percent); - o = ObjectRef(make_object(percent)); + o = ObjectRef(ffi::make_object(percent)); } else if (metric_value_name == "count") { int64_t count; reader->Read(&count); - o = ObjectRef(make_object(count)); + o = ObjectRef(ffi::make_object(count)); } else if (metric_value_name == "ratio") { double ratio; reader->Read(&ratio); - o = ObjectRef(make_object(ratio)); + o = 
ObjectRef(ffi::make_object(ratio)); } else if (metric_value_name == "string") { std::string s; reader->Read(&s); - o = String(s); + o = ffi::String(s); } else { LOG(FATAL) << "Cannot parse metric of type " << metric_value_name << " valid types are microseconds, percent, count."; @@ -752,13 +755,13 @@ Map parse_metrics(dmlc::JSONReader* reader) { return metrics; } -Report Report::FromJSON(String json) { +Report Report::FromJSON(ffi::String json) { std::stringstream input(json.operator std::string()); dmlc::JSONReader reader(&input); std::string key; - Array> calls; - Map> device_metrics; - Map configuration; + ffi::Array> calls; + ffi::Map> device_metrics; + ffi::Map configuration; reader.BeginObject(); while (reader.NextObjectItem(&key)) { @@ -793,7 +796,8 @@ TVM_FFI_STATIC_INIT_BLOCK({ }); ffi::Function ProfileFunction(ffi::Module mod, std::string func_name, int device_type, - int device_id, int warmup_iters, Array collectors) { + int device_id, int warmup_iters, + ffi::Array collectors) { // Module::GetFunction is not const, so this lambda has to be mutable return ffi::Function::FromPacked([=](const ffi::AnyView* args, int32_t num_args, ffi::Any* ret) mutable { @@ -810,7 +814,7 @@ ffi::Function ProfileFunction(ffi::Module mod, std::string func_name, int device for (auto& collector : collectors) { collector->Init({DeviceWrapper(dev)}); } - std::vector> results; + std::vector> results; results.reserve(collectors.size()); std::vector> collector_data; collector_data.reserve(collectors.size()); @@ -828,7 +832,7 @@ ffi::Function ProfileFunction(ffi::Module mod, std::string func_name, int device for (auto& kv : collector_data) { results.push_back(kv.first->Stop(kv.second)); } - Map combined_results; + ffi::Map combined_results; for (auto m : results) { for (auto p : m) { // assume that there is no shared metric name between collectors @@ -843,8 +847,8 @@ TVM_FFI_STATIC_INIT_BLOCK({ namespace refl = tvm::ffi::reflection; refl::GlobalDef().def( "runtime.profiling.ProfileFunction", - [](ffi::Module mod, String func_name, int device_type, int device_id, int warmup_iters, - Array collectors) { + [](ffi::Module mod, ffi::String func_name, int device_type, int device_id, int warmup_iters, + ffi::Array collectors) { if (mod->kind() == std::string("rpc")) { LOG(FATAL) << "Profiling a module over RPC is not yet supported"; // because we can't send @@ -925,18 +929,19 @@ TVM_FFI_STATIC_INIT_BLOCK({ namespace refl = tvm::ffi::reflection; refl::GlobalDef() .def("runtime.profiling.Report", - [](Array> calls, Map> device_metrics, - Map configuration) { + [](ffi::Array> calls, + ffi::Map> device_metrics, + ffi::Map configuration) { return Report(calls, device_metrics, configuration); }) .def("runtime.profiling.Count", - [](int64_t count) { return ObjectRef(make_object(count)); }) + [](int64_t count) { return ObjectRef(ffi::make_object(count)); }) .def("runtime.profiling.Percent", - [](double percent) { return ObjectRef(make_object(percent)); }) + [](double percent) { return ObjectRef(ffi::make_object(percent)); }) .def("runtime.profiling.Duration", - [](double duration) { return ObjectRef(make_object(duration)); }) + [](double duration) { return ObjectRef(ffi::make_object(duration)); }) .def("runtime.profiling.Ratio", - [](double ratio) { return ObjectRef(make_object(ratio)); }); + [](double ratio) { return ObjectRef(ffi::make_object(ratio)); }); }); } // namespace profiling diff --git a/src/runtime/rocm/rocm_device_api.cc b/src/runtime/rocm/rocm_device_api.cc index 9692b811a40c..5b2287e61b5e 100644 --- 
a/src/runtime/rocm/rocm_device_api.cc +++ b/src/runtime/rocm/rocm_device_api.cc @@ -299,7 +299,8 @@ class ROCMTimerNode : public TimerNode { TVM_FFI_STATIC_INIT_BLOCK({ namespace refl = tvm::ffi::reflection; refl::GlobalDef() - .def("profiling.timer.rocm", [](Device dev) { return Timer(make_object()); }) + .def("profiling.timer.rocm", + [](Device dev) { return Timer(ffi::make_object()); }) .def("runtime.get_rocm_stream", []() { int device_id; ROCM_CALL(hipGetDevice(&device_id)); diff --git a/src/runtime/rocm/rocm_module.cc b/src/runtime/rocm/rocm_module.cc index f6beaca210bc..3ef9bf47a9b1 100644 --- a/src/runtime/rocm/rocm_module.cc +++ b/src/runtime/rocm/rocm_module.cc @@ -69,9 +69,9 @@ class ROCMModuleNode : public ffi::ModuleObj { int GetPropertyMask() const final { return ffi::Module::kBinarySerializable | ffi::Module::kRunnable; } - Optional GetFunction(const String& name) final; + ffi::Optional GetFunction(const ffi::String& name) final; - void WriteToFile(const String& file_name, const String& format) const final { + void WriteToFile(const ffi::String& file_name, const ffi::String& format) const final { std::string fmt = GetFileFormat(file_name, format); std::string meta_file = GetMetaFilePath(file_name); // note: llvm and asm formats are not loadable, so we don't save them @@ -90,7 +90,7 @@ class ROCMModuleNode : public ffi::ModuleObj { return ffi::Bytes(buffer); } - String InspectSource(const String& format) const final { + ffi::String InspectSource(const ffi::String& format) const final { if (format == fmt_) { return data_; } @@ -198,7 +198,7 @@ class ROCMWrappedFunc { LaunchParamConfig launch_param_config_; }; -Optional ROCMModuleNode::GetFunction(const String& name) { +ffi::Optional ROCMModuleNode::GetFunction(const ffi::String& name) { ObjectPtr sptr_to_self = ffi::GetObjectPtr(this); ICHECK_EQ(sptr_to_self.get(), this); auto it = fmap_.find(name); @@ -212,7 +212,7 @@ Optional ROCMModuleNode::GetFunction(const String& name) { ffi::Module ROCMModuleCreate(std::string data, std::string fmt, std::unordered_map fmap, std::string hip_source, std::string assembly) { - auto n = make_object(data, fmt, fmap, hip_source, assembly); + auto n = ffi::make_object(data, fmt, fmap, hip_source, assembly); return ffi::Module(n); } diff --git a/src/runtime/rpc/rpc_device_api.cc b/src/runtime/rpc/rpc_device_api.cc index a02acd9611e3..2bddaff1a504 100644 --- a/src/runtime/rpc/rpc_device_api.cc +++ b/src/runtime/rpc/rpc_device_api.cc @@ -45,7 +45,7 @@ class RPCDeviceAPI final : public DeviceAPI { } void* AllocDataSpace(Device dev, int ndim, const int64_t* shape, DLDataType dtype, - Optional mem_scope) final { + ffi::Optional mem_scope) final { auto sess = GetSess(dev); auto remote_dev = RemoveRPCSessionMask(dev); void* data = diff --git a/src/runtime/rpc/rpc_endpoint.cc b/src/runtime/rpc/rpc_endpoint.cc index e1282c17878a..c51484b2790f 100644 --- a/src/runtime/rpc/rpc_endpoint.cc +++ b/src/runtime/rpc/rpc_endpoint.cc @@ -261,7 +261,8 @@ class RPCEndpoint::EventHandler : public dmlc::Stream { // Always wrap things back in RPCObjectRef // this is because we want to enable multi-hop RPC // and next hop would also need to check the object index - RPCObjectRef rpc_obj(make_object(reinterpret_cast(handle), nullptr)); + RPCObjectRef rpc_obj( + ffi::make_object(reinterpret_cast(handle), nullptr)); // Legacy ABI translation // TODO(tqchen): remove this once we have upgraded to new ABI *reinterpret_cast(out) = rpc_obj; @@ -433,7 +434,7 @@ class RPCEndpoint::EventHandler : public dmlc::Stream { if (code ==
RPCCode::kException) { // switch to the state before sending exception. this->SwitchToState(kRecvPacketNumBytes); - String msg = args[0].cast(); + ffi::String msg = args[0].cast(); if (!support::StartsWith(msg, "RPCSessionTimeoutError: ")) { msg = "RPCError: Error caught from RPC call:\n" + msg; } @@ -962,7 +963,7 @@ void RPCDevAllocDataWithScope(RPCSession* handler, ffi::PackedArgs args, ffi::An int ndim = arr->ndim; int64_t* shape = arr->shape; DLDataType dtype = arr->dtype; - auto mem_scope = args[1].cast>(); + auto mem_scope = args[1].cast>(); void* data = handler->GetDeviceAPI(dev)->AllocDataSpace(dev, ndim, shape, dtype, mem_scope); *rv = data; } @@ -1154,7 +1155,7 @@ class RPCClientSession : public RPCSession, public DeviceAPI { } void* AllocDataSpace(Device dev, int ndim, const int64_t* shape, DLDataType dtype, - Optional mem_scope) final { + ffi::Optional mem_scope) final { DLTensor temp; temp.data = nullptr; temp.device = dev; diff --git a/src/runtime/rpc/rpc_module.cc b/src/runtime/rpc/rpc_module.cc index 97b90c25ac25..441c73989526 100644 --- a/src/runtime/rpc/rpc_module.cc +++ b/src/runtime/rpc/rpc_module.cc @@ -190,7 +190,7 @@ class RPCModuleNode final : public ffi::ModuleObj { /*! \brief Get the property of the runtime module .*/ int GetPropertyMask() const final { return ffi::Module::ModulePropertyMask::kRunnable; } - Optional GetFunction(const String& name) final { + ffi::Optional GetFunction(const ffi::String& name) final { if (name == "CloseRPCConnection") { return ffi::Function([this](ffi::PackedArgs, ffi::Any*) { sess_->Shutdown(); }); } @@ -199,7 +199,7 @@ class RPCModuleNode final : public ffi::ModuleObj { return WrapRemoteFunc(sess_->GetFunction(name)); } else { InitRemoteFunc(&remote_mod_get_function_, "tvm.rpc.server.ModuleGetFunction"); - return remote_mod_get_function_(GetRef(this), name, true); + return remote_mod_get_function_(ffi::GetRef(this), name, true); } } @@ -215,12 +215,12 @@ class RPCModuleNode final : public ffi::ModuleObj { if (module_handle_ != nullptr) { return remote_get_time_evaluator_( - GetRef(this), name, static_cast(dev.device_type), dev.device_id, number, - repeat, min_repeat_ms, limit_zero_time_iterations, cooldown_interval_ms, + ffi::GetRef(this), name, static_cast(dev.device_type), dev.device_id, + number, repeat, min_repeat_ms, limit_zero_time_iterations, cooldown_interval_ms, repeats_to_cooldown, cache_flush_bytes, f_preproc_name); } else { return remote_get_time_evaluator_( - Optional(std::nullopt), name, static_cast(dev.device_type), + ffi::Optional(std::nullopt), name, static_cast(dev.device_type), dev.device_id, number, repeat, min_repeat_ms, limit_zero_time_iterations, cooldown_interval_ms, repeats_to_cooldown, cache_flush_bytes, f_preproc_name); } @@ -233,7 +233,7 @@ class RPCModuleNode final : public ffi::ModuleObj { void ImportModule(const ffi::Module& other) final { InitRemoteFunc(&remote_import_module_, "tvm.rpc.server.ImportModule"); - remote_import_module_(GetRef(this), other); + remote_import_module_(ffi::GetRef(this), other); } const std::shared_ptr& sess() { return sess_; } @@ -261,8 +261,8 @@ class RPCModuleNode final : public ffi::ModuleObj { // The local channel std::shared_ptr sess_; // remote function to get time evaluator - ffi::TypedFunction, std::string, int, int, int, int, int, int, - int, int, int, std::string)> + ffi::TypedFunction, std::string, int, int, int, int, int, + int, int, int, int, std::string)> remote_get_time_evaluator_; // remote function getter for modules. 
ffi::TypedFunction remote_mod_get_function_; @@ -303,7 +303,7 @@ void RPCWrappedFunc::WrapRemoteReturnToValue(ffi::PackedArgs args, ffi::Any* rv) } else if (type_index == ffi::TypeIndex::kTVMFFIModule) { ICHECK_EQ(args.size(), 2); void* handle = args[1].cast(); - auto n = make_object(handle, sess_); + auto n = ffi::make_object(handle, sess_); *rv = ffi::Module(n); } else if (type_index == ffi::TypeIndex::kTVMFFITensor || type_index == ffi::TypeIndex::kTVMFFIDLTensorPtr) { @@ -322,7 +322,7 @@ void RPCWrappedFunc::WrapRemoteReturnToValue(ffi::PackedArgs args, ffi::Any* rv) } else if (type_index >= ffi::TypeIndex::kTVMFFIStaticObjectBegin) { ICHECK_EQ(args.size(), 2); void* handle = args[1].cast(); - auto n = make_object(handle, sess_); + auto n = ffi::make_object(handle, sess_); *rv = ObjectRef(n); } else { ICHECK_EQ(args.size(), 2); @@ -331,7 +331,7 @@ void RPCWrappedFunc::WrapRemoteReturnToValue(ffi::PackedArgs args, ffi::Any* rv) } ffi::Module CreateRPCSessionModule(std::shared_ptr sess) { - auto n = make_object(nullptr, sess); + auto n = ffi::make_object(nullptr, sess); RPCSession::InsertToSessionTable(sess); return ffi::Module(n); } @@ -397,7 +397,7 @@ TVM_FFI_STATIC_INIT_BLOCK({ namespace refl = tvm::ffi::reflection; refl::GlobalDef() .def("runtime.RPCTimeEvaluator", - [](Optional opt_mod, std::string name, int device_type, int device_id, + [](ffi::Optional opt_mod, std::string name, int device_type, int device_id, int number, int repeat, int min_repeat_ms, int limit_zero_time_iterations, int cooldown_interval_ms, int repeats_to_cooldown, int cache_flush_bytes, std::string f_preproc_name) { @@ -420,7 +420,7 @@ TVM_FFI_STATIC_INIT_BLOCK({ << "Cannot find " << f_preproc_name << " in the global function"; f_preproc = *pf_preproc; } - Optional pf = m->GetFunction(name); + ffi::Optional pf = m->GetFunction(name); CHECK(pf.has_value()) << "Cannot find `" << name << "` in the global registry"; return profiling::WrapTimeEvaluator( *pf, dev, number, repeat, min_repeat_ms, limit_zero_time_iterations, diff --git a/src/runtime/rpc/rpc_socket_impl.cc b/src/runtime/rpc/rpc_socket_impl.cc index d2f141ee21e0..91b3c01b6222 100644 --- a/src/runtime/rpc/rpc_socket_impl.cc +++ b/src/runtime/rpc/rpc_socket_impl.cc @@ -169,7 +169,7 @@ class SimpleSockHandler : public dmlc::Stream { TVM_FFI_STATIC_INIT_BLOCK({ namespace refl = tvm::ffi::reflection; - refl::GlobalDef().def("rpc.ReturnException", [](int sockfd, String msg) { + refl::GlobalDef().def("rpc.ReturnException", [](int sockfd, ffi::String msg) { auto handler = SimpleSockHandler(sockfd); RPCReference::ReturnException(msg.c_str(), &handler); return; diff --git a/src/runtime/static_library.cc b/src/runtime/static_library.cc index b816fb600e1e..790915b37b91 100644 --- a/src/runtime/static_library.cc +++ b/src/runtime/static_library.cc @@ -47,7 +47,7 @@ class StaticLibraryNode final : public ffi::ModuleObj { public: const char* kind() const final { return "static_library"; } - Optional GetFunction(const String& name) final { + ffi::Optional GetFunction(const ffi::String& name) final { const ObjectPtr& sptr_to_self = ffi::GetObjectPtr(this); if (name == "get_func_names") { return ffi::Function( @@ -65,13 +65,13 @@ class StaticLibraryNode final : public ffi::ModuleObj { std::vector func_names; for (const auto func_name : func_names_) func_names.push_back(func_name); stream->Write(func_names); - return Bytes(buffer); + return ffi::Bytes(buffer); } static ffi::Module LoadFromBytes(ffi::Bytes bytes) { dmlc::MemoryFixedSizeStream ms(const_cast(bytes.data()),
bytes.size()); dmlc::Stream* stream = &ms; - auto n = make_object(); + auto n = ffi::make_object(); // load data std::string data; ICHECK(stream->Read(&data)) << "Loading data failed"; @@ -80,12 +80,12 @@ class StaticLibraryNode final : public ffi::ModuleObj { // load func names std::vector func_names; ICHECK(stream->Read(&func_names)) << "Loading func names failed"; - for (auto func_name : func_names) n->func_names_.push_back(String(func_name)); + for (auto func_name : func_names) n->func_names_.push_back(ffi::String(func_name)); return ffi::Module(n); } - void WriteToFile(const String& file_name, const String& format) const final { + void WriteToFile(const ffi::String& file_name, const ffi::String& format) const final { VLOG(0) << "Saving static library of " << data_.size() << " bytes implementing " << FuncNames() << " to '" << file_name << "'"; SaveBinaryToFile(file_name, data_); @@ -96,7 +96,7 @@ class StaticLibraryNode final : public ffi::ModuleObj { return ffi::Module::kBinarySerializable | ffi::Module::kCompilationExportable; } - bool ImplementsFunction(const String& name) final { + bool ImplementsFunction(const ffi::String& name) final { return std::find(func_names_.begin(), func_names_.end(), name) != func_names_.end(); } @@ -119,13 +119,13 @@ class StaticLibraryNode final : public ffi::ModuleObj { /*! \brief Contents of the object file. */ std::string data_; /*! \brief Function names exported by the above. */ - Array func_names_; + ffi::Array func_names_; }; } // namespace -ffi::Module LoadStaticLibrary(const std::string& filename, Array func_names) { - auto node = make_object(); +ffi::Module LoadStaticLibrary(const std::string& filename, ffi::Array func_names) { + auto node = ffi::make_object(); LoadBinaryFromFile(filename, &node->data_); node->func_names_ = std::move(func_names); VLOG(0) << "Loaded static library from '" << filename << "' implementing " << node->FuncNames(); diff --git a/src/runtime/static_library.h b/src/runtime/static_library.h index 8a5600fc0588..2ebca2edd277 100644 --- a/src/runtime/static_library.h +++ b/src/runtime/static_library.h @@ -43,7 +43,7 @@ namespace runtime { * \brief Returns a static library with the contents loaded from filename which exports * func_names with the usual packed-func calling convention. 
*/ -ffi::Module LoadStaticLibrary(const std::string& filename, Array func_names); +ffi::Module LoadStaticLibrary(const std::string& filename, ffi::Array func_names); } // namespace runtime } // namespace tvm diff --git a/src/runtime/tensor.cc b/src/runtime/tensor.cc index 2e418304fa82..b655a5c611fc 100644 --- a/src/runtime/tensor.cc +++ b/src/runtime/tensor.cc @@ -97,7 +97,8 @@ void Tensor::CopyToBytes(const DLTensor* handle, void* data, size_t nbytes, DeviceAPI::Get(handle->device)->StreamSync(handle->device, stream); } -Tensor Tensor::Empty(ffi::Shape shape, DLDataType dtype, Device dev, Optional mem_scope) { +Tensor Tensor::Empty(ffi::Shape shape, DLDataType dtype, Device dev, + ffi::Optional mem_scope) { struct DeviceAPIAlloc { void AllocData(DLTensor* tensor, ffi::Optional mem_scope) { tensor->data = DeviceAPI::Get(tensor->device) @@ -180,7 +181,7 @@ void Tensor::CopyFromBytes(const void* data, size_t nbytes) { TensorCopyFromBytes(get_mutable(), data, nbytes); } -Tensor Tensor::CopyTo(const Device& dev, Optional mem_scope) const { +Tensor Tensor::CopyTo(const Device& dev, ffi::Optional mem_scope) const { ICHECK(data_ != nullptr); const DLTensor* dptr = operator->(); Tensor ret = diff --git a/src/runtime/thread_pool.cc b/src/runtime/thread_pool.cc index deaeec6ad3a0..443098e08369 100644 --- a/src/runtime/thread_pool.cc +++ b/src/runtime/thread_pool.cc @@ -141,7 +141,7 @@ class ParallelLauncher { // The counter page. std::atomic* sync_counter_{nullptr}; // The error message - std::vector> par_errors_; + std::vector> par_errors_; }; /*! \brief Lock-free single-producer-single-consumer queue for each thread */ @@ -389,7 +389,7 @@ TVM_FFI_STATIC_INIT_BLOCK({ int nthreads = args[1].cast(); std::vector cpus; if (args.size() >= 3) { - auto cpu_array = args[2].cast>(); + auto cpu_array = args[2].cast>(); for (auto cpu : cpu_array) { ICHECK(IsNumber(cpu)) << "The CPU core information '" << cpu << "' is not a number."; diff --git a/src/runtime/vm/attn_backend.cc b/src/runtime/vm/attn_backend.cc index c8fbd9082103..3b37d9810b1c 100644 --- a/src/runtime/vm/attn_backend.cc +++ b/src/runtime/vm/attn_backend.cc @@ -25,12 +25,12 @@ namespace tvm { namespace runtime { namespace vm { -std::unique_ptr ConvertPagedPrefillFunc(Array args, +std::unique_ptr ConvertPagedPrefillFunc(ffi::Array args, AttnKind attn_kind) { if (args.empty()) { return nullptr; } - String backend_name = args[0].cast(); + ffi::String backend_name = args[0].cast(); if (backend_name == "tir") { CHECK_EQ(args.size(), 2); ffi::Function attn_func = args[1].cast(); @@ -47,12 +47,12 @@ std::unique_ptr ConvertPagedPrefillFunc(Array args, throw; } -std::unique_ptr ConvertRaggedPrefillFunc(Array args, +std::unique_ptr ConvertRaggedPrefillFunc(ffi::Array args, AttnKind attn_kind) { if (args.empty()) { return nullptr; } - String backend_name = args[0].cast(); + ffi::String backend_name = args[0].cast(); if (backend_name == "tir") { CHECK_EQ(args.size(), 2); ffi::Function attn_func = args[1].cast(); @@ -69,11 +69,12 @@ std::unique_ptr ConvertRaggedPrefillFunc(Array args throw; } -std::unique_ptr ConvertPagedDecodeFunc(Array args, AttnKind attn_kind) { +std::unique_ptr ConvertPagedDecodeFunc(ffi::Array args, + AttnKind attn_kind) { if (args.empty()) { return nullptr; } - String backend_name = args[0].cast(); + ffi::String backend_name = args[0].cast(); if (backend_name == "tir") { CHECK_EQ(args.size(), 2); ffi::Function attn_func = args[1].cast(); @@ -90,12 +91,12 @@ std::unique_ptr ConvertPagedDecodeFunc(Array args, At throw; } -std::unique_ptr 
ConvertPagedPrefillTreeMaskFunc(Array args, +std::unique_ptr ConvertPagedPrefillTreeMaskFunc(ffi::Array args, AttnKind attn_kind) { if (args.empty()) { return nullptr; } - String backend_name = args[0].cast(); + ffi::String backend_name = args[0].cast(); if (backend_name == "tir") { CHECK_EQ(args.size(), 2); ffi::Function attn_func = args[1].cast(); @@ -105,12 +106,12 @@ std::unique_ptr ConvertPagedPrefillTreeMaskFunc(Array< throw; } -std::unique_ptr ConvertRaggedPrefillTreeMaskFunc(Array args, - AttnKind attn_kind) { +std::unique_ptr ConvertRaggedPrefillTreeMaskFunc( + ffi::Array args, AttnKind attn_kind) { if (args.empty()) { return nullptr; } - String backend_name = args[0].cast(); + ffi::String backend_name = args[0].cast(); if (backend_name == "tir") { CHECK_EQ(args.size(), 2); ffi::Function attn_func = args[1].cast(); diff --git a/src/runtime/vm/attn_backend.h b/src/runtime/vm/attn_backend.h index 4017738d6685..bc58d1c9e1d8 100644 --- a/src/runtime/vm/attn_backend.h +++ b/src/runtime/vm/attn_backend.h @@ -497,7 +497,8 @@ class TIRRaggedPrefillTreeMaskFunc : public RaggedPrefillTreeMaskFunc { * ffi::Functions. \param attn_kind The attention kind of the function. \return The created * PagedPrefillFunc pointer. */ -std::unique_ptr ConvertPagedPrefillFunc(Array args, AttnKind attn_kind); +std::unique_ptr ConvertPagedPrefillFunc(ffi::Array args, + AttnKind attn_kind); /*! * \brief Create a PagedDecodeFunc from the given arguments and the attention kind. @@ -505,7 +506,8 @@ std::unique_ptr ConvertPagedPrefillFunc(Array args, * ffi::Functions. \param attn_kind The attention kind of the function. \return The created * PagedDecodeFunc pointer. */ -std::unique_ptr ConvertPagedDecodeFunc(Array args, AttnKind attn_kind); +std::unique_ptr ConvertPagedDecodeFunc(ffi::Array args, + AttnKind attn_kind); /*! * \brief Create a RaggedPrefillFunc from the given arguments and the attention kind. @@ -513,7 +515,7 @@ std::unique_ptr ConvertPagedDecodeFunc(Array args, At * ffi::Functions. \param attn_kind The attention kind of the function. \return The created * RaggedPrefillFunc pointer. */ -std::unique_ptr ConvertRaggedPrefillFunc(Array args, +std::unique_ptr ConvertRaggedPrefillFunc(ffi::Array args, AttnKind attn_kind); /*! @@ -522,7 +524,7 @@ std::unique_ptr ConvertRaggedPrefillFunc(Array args * ffi::Functions. \param attn_kind The attention kind of the function. \return The created * PagedPrefillTreeMaskFunc pointer. */ -std::unique_ptr ConvertPagedPrefillTreeMaskFunc(Array args, +std::unique_ptr ConvertPagedPrefillTreeMaskFunc(ffi::Array args, AttnKind attn_kind); /*! @@ -531,8 +533,8 @@ std::unique_ptr ConvertPagedPrefillTreeMaskFunc(Array< * ffi::Functions. \param attn_kind The attention kind of the function. \return The created * RaggedPrefillTreeMaskFunc pointer. */ -std::unique_ptr ConvertRaggedPrefillTreeMaskFunc(Array args, - AttnKind attn_kind); +std::unique_ptr ConvertRaggedPrefillTreeMaskFunc( + ffi::Array args, AttnKind attn_kind); } // namespace vm } // namespace runtime diff --git a/src/runtime/vm/attn_utils.h b/src/runtime/vm/attn_utils.h index 5eff9452c5b9..09557a8f0a27 100644 --- a/src/runtime/vm/attn_utils.h +++ b/src/runtime/vm/attn_utils.h @@ -706,7 +706,7 @@ class PlainPagedKVCacheAuxDataManager : public PagedKVCacheAuxDataManager { * offset to the destination Tensor. 
*/ void CopyVecDataToArray(Tensor array, int32_t* vec_data, - Optional shape = std::nullopt, int dst_elem_offset = 0) { + ffi::Optional shape = std::nullopt, int dst_elem_offset = 0) { if (array->shape[0] == 0) { return; } diff --git a/src/runtime/vm/builtin.cc b/src/runtime/vm/builtin.cc index bb07cbe44255..1a0da132f522 100644 --- a/src/runtime/vm/builtin.cc +++ b/src/runtime/vm/builtin.cc @@ -88,7 +88,7 @@ TVM_FFI_STATIC_INIT_BLOCK({ * \sa MatchShape */ void MatchPrimValue(int64_t input_value, DLTensor* heap, int code_value, int64_t reg, - Optional err_ctx) { + ffi::Optional err_ctx) { int64_t* heap_data = heap == nullptr ? nullptr : static_cast(heap->data); MatchShapeCode code = static_cast(code_value); @@ -134,7 +134,7 @@ void MatchShape(ffi::PackedArgs args, ffi::Any* rv) { ICHECK_LE(kBeginCode + size * 2, args.size()); // a function that lazily get context for error reporting const int64_t kErrorContextOffset = kBeginCode + size * 2; - Optional err_ctx = args[kErrorContextOffset].cast(); + ffi::Optional err_ctx = args[kErrorContextOffset].cast(); CHECK_EQ(input_shape.size(), size) << "RuntimeError: " << err_ctx.value_or("") << " match_cast shape size mismatch."; @@ -238,14 +238,14 @@ void CheckTensorInfo(ffi::PackedArgs args, ffi::Any* rv) { ffi::AnyView arg = args[0]; int ndim = args[1].cast(); DataType dtype; - Optional err_ctx; + ffi::Optional err_ctx; if (args.size() == 3) { dtype = DataType::Void(); - err_ctx = args[2].cast>(); + err_ctx = args[2].cast>(); } else { dtype = args[2].cast(); - err_ctx = args[3].cast>(); + err_ctx = args[3].cast>(); } auto opt_ptr = arg.try_cast(); @@ -276,7 +276,7 @@ TVM_FFI_STATIC_INIT_BLOCK({ * \param ndim Expected size of the shape, can be -1 (indicate unknown). * \param err_ctx Additional context if error occurs. */ -void CheckShapeInfo(ObjectRef arg, int ndim, Optional err_ctx) { +void CheckShapeInfo(ObjectRef arg, int ndim, ffi::Optional err_ctx) { // a function that lazily get context for error reporting auto* ptr = arg.as(); CHECK(ptr != nullptr) << "TypeError: " << err_ctx.value_or("") << " expect a Shape but get " @@ -299,7 +299,7 @@ TVM_FFI_STATIC_INIT_BLOCK({ * \param dtype Expected dtype of the PrimValue. Can be DataType::Void() for unknown dtype. * \param err_ctx Additional context if error occurs. */ -void CheckPrimValueInfo(ffi::AnyView arg, DataType dtype, Optional err_ctx) { +void CheckPrimValueInfo(ffi::AnyView arg, DataType dtype, ffi::Optional err_ctx) { if (auto opt_obj = arg.as()) { LOG(FATAL) << "TypeError: " << err_ctx.value_or("") << ", expected dtype " << dtype << ", but received ObjectRef of type " << opt_obj.value()->GetTypeKey(); @@ -329,7 +329,7 @@ TVM_FFI_STATIC_INIT_BLOCK({ * \param size The expected size of the tuple. * \param err_ctx Additional context if error occurs. */ -void CheckTupleInfo(ObjectRef arg, int64_t size, Optional err_ctx) { +void CheckTupleInfo(ObjectRef arg, int64_t size, ffi::Optional err_ctx) { // a function that lazily get context for error reporting auto* ptr = arg.as(); CHECK(ptr != nullptr) << "TypeError: " << err_ctx.value_or("") << " expect a Tuple but get " @@ -349,7 +349,7 @@ TVM_FFI_STATIC_INIT_BLOCK({ * \param arg The input argument. * \param err_ctx Additional context if error occurs. 
*/ -void CheckFuncInfo(ObjectRef arg, Optional err_ctx) { +void CheckFuncInfo(ObjectRef arg, ffi::Optional err_ctx) { // a function that lazily get context for error reporting bool is_func = arg.as() || arg.as(); CHECK(is_func) << "TypeError: " << err_ctx.value_or("") << " expect a Function but get " @@ -365,7 +365,7 @@ TVM_FFI_STATIC_INIT_BLOCK({ // Storage management. //------------------------------------------------- Storage VMAllocStorage(void* ctx_ptr, ffi::Shape buffer_shape, Index device_index, - DLDataType dtype_hint, String mem_scope) { + DLDataType dtype_hint, ffi::String mem_scope) { VirtualMachine* vm = static_cast(ctx_ptr); ICHECK_LT(device_index, vm->devices.size()) @@ -508,12 +508,12 @@ TVM_FFI_STATIC_INIT_BLOCK({ int num_args = args.size() - 3; ObjectRef io_effect = args[0].cast(); ICHECK(!io_effect.defined()) << "ValueError: IOEffect is expected to be lowered to None."; - String debug_func_name = args[1].cast(); + ffi::String debug_func_name = args[1].cast(); const auto debug_func = tvm::ffi::Function::GetGlobal(debug_func_name); CHECK(debug_func.has_value()) << "ValueError: " << debug_func_name << " is not found. " << "Use the decorator `@tvm.register_global_func(\"" << debug_func_name << "\")` to register it."; - String line_info = args[2].cast(); + ffi::String line_info = args[2].cast(); std::vector call_args(num_args + 1); { call_args[0] = line_info; @@ -533,14 +533,14 @@ TVM_FFI_STATIC_INIT_BLOCK({ namespace refl = tvm::ffi::reflection; refl::GlobalDef() .def("vm.builtin.tuple_getitem", - [](Array arr, int64_t index) { return arr[index]; }) + [](ffi::Array arr, int64_t index) { return arr[index]; }) .def("vm.builtin.tuple_reset_item", [](const ffi::ArrayObj* arr, int64_t index) { const_cast(arr)->SetItem(index, nullptr); }) .def_packed("vm.builtin.make_tuple", [](ffi::PackedArgs args, ffi::Any* rv) { - Array arr; + ffi::Array arr; for (int i = 0; i < args.size(); ++i) { arr.push_back(args[i]); } diff --git a/src/runtime/vm/cuda/cuda_graph_builtin.cc b/src/runtime/vm/cuda/cuda_graph_builtin.cc index d7ccff66a046..ec841b5ed2d5 100644 --- a/src/runtime/vm/cuda/cuda_graph_builtin.cc +++ b/src/runtime/vm/cuda/cuda_graph_builtin.cc @@ -44,7 +44,7 @@ struct CUDAGraphCaptureKey { // identified by this shape tuple. This is default constructed as an empty tuple. ffi::Shape shape_expr; - CUDAGraphCaptureKey(int64_t index, const Optional& shape_expr) : index(index) { + CUDAGraphCaptureKey(int64_t index, const ffi::Optional& shape_expr) : index(index) { if (shape_expr) { this->shape_expr = shape_expr.value(); } @@ -153,7 +153,7 @@ class CUDAGraphExtensionNode : public VMExtensionNode { * \return The return value of the capture function. 
*/ ObjectRef RunOrCapture(VirtualMachine* vm, const ObjectRef& capture_func, Any args, - int64_t entry_index, Optional shape_expr) { + int64_t entry_index, ffi::Optional shape_expr) { CUDAGraphCaptureKey entry_key{entry_index, shape_expr}; if (auto it = capture_cache_.find(entry_key); it != capture_cache_.end()) { // Launch CUDA graph @@ -166,7 +166,7 @@ class CUDAGraphExtensionNode : public VMExtensionNode { } // Set up arguments for the graph execution - Array tuple_args = args.cast>(); + ffi::Array tuple_args = args.cast>(); int nargs = static_cast(tuple_args.size()); std::vector packed_args(nargs); @@ -242,7 +242,7 @@ class CUDAGraphExtension : public VMExtension { public: TVM_DEFINE_MUTABLE_OBJECT_REF_METHODS(CUDAGraphExtension, VMExtension, CUDAGraphExtensionNode); static CUDAGraphExtension Create() { - auto data_ = make_object(); + auto data_ = ffi::make_object(); return CUDAGraphExtension(std::move(data_)); } }; @@ -258,7 +258,7 @@ TVM_FFI_STATIC_INIT_BLOCK({ auto capture_func = args[1].cast(); Any func_args = args[2]; int64_t entry_index = args[3].cast(); - Optional shape_expr = std::nullopt; + ffi::Optional shape_expr = std::nullopt; if (args.size() == 5) { shape_expr = args[4].cast(); } diff --git a/src/runtime/vm/executable.cc b/src/runtime/vm/executable.cc index 287af83c6058..3d72afc42148 100644 --- a/src/runtime/vm/executable.cc +++ b/src/runtime/vm/executable.cc @@ -74,7 +74,7 @@ std::string VMExecutable::Stats() const { } oss.seekp(-2, oss.cur); oss << "], "; - } else if (auto opt_str = it.as()) { + } else if (auto opt_str = it.as()) { std::string f = opt_str.value(); oss << "\""; oss << f; @@ -181,7 +181,7 @@ ffi::Bytes VMExecutable::SaveToBytes() const { return ffi::Bytes(code); } -void VMExecutable::WriteToFile(const String& file_name, const String& format) const { +void VMExecutable::WriteToFile(const ffi::String& file_name, const ffi::String& format) const { runtime::SaveBinaryToFile(file_name, VMExecutable::SaveToBytes()); } @@ -189,7 +189,7 @@ ffi::Module VMExecutable::LoadFromBytes(const ffi::Bytes& bytes) { std::string code; dmlc::MemoryFixedSizeStream strm(const_cast(bytes.data()), bytes.size()); - ObjectPtr exec = make_object(); + ObjectPtr exec = ffi::make_object(); // Load header. 
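// (read back in the same order SaveToBytes wrote the stream: header first, then the remaining sections such as the constants.)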
LoadHeader(&strm); @@ -206,7 +206,7 @@ ffi::Module VMExecutable::LoadFromBytes(const ffi::Bytes& bytes) { return ffi::Module(exec); } -ffi::Module VMExecutable::LoadFromFile(const String& file_name) { +ffi::Module VMExecutable::LoadFromFile(const ffi::String& file_name) { std::string data; runtime::LoadBinaryFromFile(file_name, &data); return VMExecutable::LoadFromBytes(ffi::Bytes(data)); @@ -258,8 +258,8 @@ void VMExecutable::SaveConstantSection(dmlc::Stream* strm) const { for (size_t i = 0; i < shape.size(); ++i) { strm->Write(shape.at(i)); } - } else if (auto opt_str = it.as()) { - String str = opt_str.value(); + } else if (auto opt_str = it.as()) { + ffi::String str = opt_str.value(); strm->Write(ffi::TypeIndex::kTVMFFIStr); strm->Write(str.size()); for (size_t i = 0; i < str.size(); ++i) { @@ -333,7 +333,7 @@ void VMExecutable::LoadConstantSection(dmlc::Stream* strm) { strm->Read(&(data[i])); } ffi::Any cell; - cell = String(std::string(data.begin(), data.end())); + cell = ffi::String(std::string(data.begin(), data.end())); this->constants.push_back(cell); } else if (constant_type == ffi::TypeIndex::kTVMFFIInt) { int64_t value; @@ -395,9 +395,9 @@ ffi::Module VMExecutable::VMProfilerLoadExecutable() const { return ffi::Module(vm); } -bool VMExecutable::HasFunction(const String& name) const { return func_map.count(name); } +bool VMExecutable::HasFunction(const ffi::String& name) const { return func_map.count(name); } -String VMExecutable::AsText() const { +ffi::String VMExecutable::AsText() const { auto get_func_name = [&](Index index) -> std::string { if (static_cast(index) < func_table.size()) { return func_table[index].name; @@ -471,10 +471,10 @@ String VMExecutable::AsText() const { } os << "\n"; } - return String(os.str()); + return ffi::String(os.str()); } -String VMExecutable::AsPython() const { +ffi::String VMExecutable::AsPython() const { auto get_func_name = [&](Index index) -> std::string { if (static_cast(index) < func_table.size()) { return "\"" + func_table[index].name + "\""; @@ -549,7 +549,7 @@ String VMExecutable::AsPython() const { } } } - return String(os.str()); + return ffi::String(os.str()); } TVM_FFI_STATIC_INIT_BLOCK({ diff --git a/src/runtime/vm/kv_state.cc b/src/runtime/vm/kv_state.cc index 366e22c36baf..9958b01deb3d 100644 --- a/src/runtime/vm/kv_state.cc +++ b/src/runtime/vm/kv_state.cc @@ -45,9 +45,9 @@ TVM_FFI_STATIC_INIT_BLOCK({ KVState kv_state = args[0].cast(); ffi::Shape seq_ids = args[1].cast(); ffi::Shape append_lengths = args[2].cast(); - Optional token_tree_parent_ptr; + ffi::Optional token_tree_parent_ptr; if (args.size() == 4) { - token_tree_parent_ptr = args[3].cast>(); + token_tree_parent_ptr = args[3].cast>(); } kv_state->BeginForward(seq_ids, append_lengths, token_tree_parent_ptr); }) diff --git a/src/runtime/vm/kv_state.h b/src/runtime/vm/kv_state.h index de42488b7f40..fa56ff6426cd 100644 --- a/src/runtime/vm/kv_state.h +++ b/src/runtime/vm/kv_state.h @@ -94,8 +94,9 @@ class KVStateObj : public Object { * is the sum of "append_lengths". Nullptr means the token tree of each sequence * is a chain. */ - virtual void BeginForward(const IntTuple& seq_ids, const IntTuple& append_lengths, - const Optional& token_tree_parent_ptr = std::nullopt) = 0; + virtual void BeginForward( + const IntTuple& seq_ids, const IntTuple& append_lengths, + const ffi::Optional& token_tree_parent_ptr = std::nullopt) = 0; /*! * \brief Mark the start of the forward function. 
@@ -178,7 +179,7 @@ class AttentionKVCacheObj : public KVStateObj { * \param sm_scale The additional attention scaling factor. * \sa AttentionKVCache::Attention */ - virtual void AttentionWithFusedQKV(int64_t layer_id, Tensor qkv_data, Optional mask, + virtual void AttentionWithFusedQKV(int64_t layer_id, Tensor qkv_data, ffi::Optional mask, Tensor o_data, double sm_scale) = 0; /*! @@ -220,8 +221,8 @@ class AttentionKVCacheObj : public KVStateObj { * \param lse2_data The second source LSE data. * \return The merged O and LSE data. */ - virtual Array MergeAttnOutputInplace(Tensor o_self_attn, Tensor lse_self_attn, - Tensor o_cross_attn, Tensor lse_cross_attn) = 0; + virtual ffi::Array MergeAttnOutputInplace(Tensor o_self_attn, Tensor lse_self_attn, + Tensor o_cross_attn, Tensor lse_cross_attn) = 0; /*! * \brief Compute linear attention with Q/K/V data. diff --git a/src/runtime/vm/lm_support.cc b/src/runtime/vm/lm_support.cc index 416ece17b402..4ccacf7ab7ff 100644 --- a/src/runtime/vm/lm_support.cc +++ b/src/runtime/vm/lm_support.cc @@ -240,7 +240,7 @@ class AttentionKVCacheLegacy : public ObjectRef { */ static AttentionKVCacheLegacy Create(Tensor init_data, ffi::Shape reserve_shape, int init_fill_count) { - auto n = make_object(); + auto n = ffi::make_object(); n->data = Tensor::Empty(reserve_shape, init_data->dtype, init_data->device); n->fill_count = 0; n->Append(init_data); @@ -334,7 +334,7 @@ TVM_FFI_STATIC_INIT_BLOCK({ }); }); -void AttentionKVCacheArrayPopN(Array caches, int64_t n) { +void AttentionKVCacheArrayPopN(ffi::Array caches, int64_t n) { for (AttentionKVCacheLegacy cache : caches) { cache->PopN(static_cast(n)); } @@ -345,7 +345,7 @@ TVM_FFI_STATIC_INIT_BLOCK({ refl::GlobalDef().def("vm.builtin.attention_kv_cache_array_popn", AttentionKVCacheArrayPopN); }); -void AttentionKVCacheArrayClear(Array caches) { +void AttentionKVCacheArrayClear(ffi::Array caches) { for (AttentionKVCacheLegacy cache : caches) { cache->Clear(); } diff --git a/src/runtime/vm/paged_kv_cache.cc b/src/runtime/vm/paged_kv_cache.cc index 9ac3ab95ccf2..631d1c8be69d 100644 --- a/src/runtime/vm/paged_kv_cache.cc +++ b/src/runtime/vm/paged_kv_cache.cc @@ -111,7 +111,7 @@ class PagedAttentionKVCacheObj : public AttentionKVCacheObj { /*! \brief The RoPE theta. */ const double rotary_theta_; /*! \brief The optional RoPE extension factors for RoPE scaling. */ - const Optional rope_ext_factors_; + const ffi::Optional rope_ext_factors_; /*! \brief The KV cache dtype. */ const DataType kv_dtype_; @@ -251,10 +251,10 @@ class PagedAttentionKVCacheObj : public AttentionKVCacheObj { std::vector tree_attn_mask_view_; std::vector tree_attn_mn_indptr_view_; - Optional f_transpose_append_mha_; - Optional f_transpose_append_mla_; - Optional f_transfer_kv_; - Optional f_transfer_kv_page_to_page_ = std::nullopt; + ffi::Optional f_transpose_append_mha_; + ffi::Optional f_transpose_append_mla_; + ffi::Optional f_transfer_kv_; + ffi::Optional f_transfer_kv_page_to_page_ = std::nullopt; ffi::Function f_compact_copy_; std::unique_ptr f_attention_prefill_ragged_; std::unique_ptr f_attention_prefill_; @@ -264,10 +264,10 @@ class PagedAttentionKVCacheObj : public AttentionKVCacheObj { std::unique_ptr f_attention_prefill_with_tree_mask_paged_kv_; std::unique_ptr f_attention_prefill_with_tree_mask_; std::unique_ptr f_mla_prefill_; - Array f_merge_inplace_; + ffi::Array f_merge_inplace_; ffi::Function f_split_rotary_; ffi::Function f_copy_single_page_; - Optional f_debug_get_kv_; + ffi::Optional f_debug_get_kv_; /*! 
\brief The device this PagedKVCache runs on. */ Device device_; @@ -286,9 +286,9 @@ class PagedAttentionKVCacheObj : public AttentionKVCacheObj { int64_t v_head_dim, std::vector attn_kinds, int64_t reserved_num_seqs, int64_t num_total_pages, int64_t prefill_chunk_size, bool support_sliding_window, RoPEMode rope_mode, double rotary_scale, double rotary_theta, - Optional rope_ext_factors, bool enable_kv_transfer, DLDataType dtype, Device device, - Optional f_transpose_append_mha, - Optional f_transpose_append_mla, ffi::Function f_compact_copy, + ffi::Optional rope_ext_factors, bool enable_kv_transfer, DLDataType dtype, + Device device, ffi::Optional f_transpose_append_mha, + ffi::Optional f_transpose_append_mla, ffi::Function f_compact_copy, std::unique_ptr f_attention_prefill_ragged, std::unique_ptr f_attention_prefill, std::unique_ptr f_attention_decode, @@ -296,7 +296,7 @@ class PagedAttentionKVCacheObj : public AttentionKVCacheObj { std::unique_ptr f_attention_decode_sliding_window, std::unique_ptr f_attention_prefill_with_tree_mask_paged_kv, std::unique_ptr f_attention_prefill_with_tree_mask, - std::unique_ptr f_mla_prefill, Array f_merge_inplace, + std::unique_ptr f_mla_prefill, ffi::Array f_merge_inplace, ffi::Function f_split_rotary, ffi::Function f_copy_single_page, ffi::Function f_debug_get_kv) : page_size_(page_size), num_layers_(num_layers), @@ -849,7 +849,7 @@ class PagedAttentionKVCacheObj : public AttentionKVCacheObj { /************** Attention **************/ void BeginForward(const ffi::Shape& seq_ids, const ffi::Shape& append_lengths, - const Optional& opt_token_tree_parent_ptr) final { + const ffi::Optional& opt_token_tree_parent_ptr) final { // Note: MLA does not support tree attention for now. if (attn_kinds_[0] == AttnKind::kMLA) { CHECK(!opt_token_tree_parent_ptr.defined()) << "Tree attention is not supported yet for MLA"; } @@ -1271,7 +1271,7 @@ class PagedAttentionKVCacheObj : public AttentionKVCacheObj { sequence->kv_transfer_metadata.local_position_map.end()); } - void AttentionWithFusedQKV(int64_t layer_id, Tensor qkv_data, Optional mask, + void AttentionWithFusedQKV(int64_t layer_id, Tensor qkv_data, ffi::Optional mask, Tensor o_data, double sm_scale) final { // Part 1. Shape and dtype check. 
int64_t local_layer_id = layer_id - layer_id_begin_offset_; @@ -1481,8 +1481,8 @@ class PagedAttentionKVCacheObj : public AttentionKVCacheObj { f_transpose_append_mla_.value()(pages_[local_layer_id], kv_data, append_position_map_view_); } - Array MergeAttnOutputInplace(Tensor o_self_attn, Tensor lse_self_attn, - Tensor o_cross_attn, Tensor lse_cross_attn) final { + ffi::Array MergeAttnOutputInplace(Tensor o_self_attn, Tensor lse_self_attn, + Tensor o_cross_attn, Tensor lse_cross_attn) final { CHECK_GE(f_merge_inplace_.size(), 2) << "The general attention merge function is not defined."; f_merge_inplace_[1](o_self_attn, lse_self_attn, o_cross_attn, lse_cross_attn); return {o_self_attn, lse_self_attn}; @@ -2463,27 +2463,27 @@ TVM_FFI_STATIC_INIT_BLOCK({ int rope_mode = args[8].cast(); double rotary_scale = args[9].cast(); double rotary_theta = args[10].cast(); - Optional rope_ext_factors = std::nullopt; // args[11] + ffi::Optional rope_ext_factors = std::nullopt; // args[11] Tensor init = args[12].cast(); - Optional f_transpose_append_mha = std::nullopt; // args[13] - Optional f_transpose_append_mla = std::nullopt; // args[14] + ffi::Optional f_transpose_append_mha = std::nullopt; // args[13] + ffi::Optional f_transpose_append_mla = std::nullopt; // args[14] std::unique_ptr f_attention_prefill_ragged = - ConvertRaggedPrefillFunc(args[15].cast>(), AttnKind::kMHA); + ConvertRaggedPrefillFunc(args[15].cast>(), AttnKind::kMHA); std::unique_ptr f_attention_prefill = - ConvertPagedPrefillFunc(args[16].cast>(), AttnKind::kMHA); + ConvertPagedPrefillFunc(args[16].cast>(), AttnKind::kMHA); std::unique_ptr f_attention_decode = - ConvertPagedDecodeFunc(args[17].cast>(), AttnKind::kMHA); + ConvertPagedDecodeFunc(args[17].cast>(), AttnKind::kMHA); std::unique_ptr f_attention_prefill_sliding_window = - ConvertPagedPrefillFunc(args[18].cast>(), AttnKind::kMHA); + ConvertPagedPrefillFunc(args[18].cast>(), AttnKind::kMHA); std::unique_ptr f_attention_decode_sliding_window = - ConvertPagedDecodeFunc(args[19].cast>(), AttnKind::kMHA); + ConvertPagedDecodeFunc(args[19].cast>(), AttnKind::kMHA); std::unique_ptr f_attention_prefill_with_tree_mask_paged_kv = - ConvertPagedPrefillTreeMaskFunc(args[20].cast>(), AttnKind::kMHA); + ConvertPagedPrefillTreeMaskFunc(args[20].cast>(), AttnKind::kMHA); std::unique_ptr f_attention_prefill_with_tree_mask = - ConvertRaggedPrefillTreeMaskFunc(args[21].cast>(), AttnKind::kMHA); + ConvertRaggedPrefillTreeMaskFunc(args[21].cast>(), AttnKind::kMHA); std::unique_ptr f_mla_prefill = - ConvertPagedPrefillFunc(args[22].cast>(), AttnKind::kMLA); - Array f_merge_inplace = args[23].cast>(); + ConvertPagedPrefillFunc(args[22].cast>(), AttnKind::kMLA); + ffi::Array f_merge_inplace = args[23].cast>(); ffi::Function f_split_rotary = args[24].cast(); ffi::Function f_copy_single_page = args[25].cast(); ffi::Function f_debug_get_kv = args[26].cast(); @@ -2492,7 +2492,7 @@ TVM_FFI_STATIC_INIT_BLOCK({ if (auto opt_nd = args[11].as()) { rope_ext_factors = opt_nd.value(); } - auto f_convert_optional_packed_func = [&args](int arg_idx) -> Optional { + auto f_convert_optional_packed_func = [&args](int arg_idx) -> ffi::Optional { if (auto opt_func = args[arg_idx].as()) { return opt_func.value(); } @@ -2521,7 +2521,7 @@ TVM_FFI_STATIC_INIT_BLOCK({ } // NOTE: We will remove this legacy construction after finishing the transition phase. // Some `ffi::Function()` here are placeholders that will be filled. 
- ObjectPtr n = make_object( + ObjectPtr n = ffi::make_object( page_size, num_layers, layer_id_begin_offset, layer_id_end_offset, num_qo_heads, num_kv_heads, qk_head_dim, v_head_dim, attn_kinds_vec, reserved_num_seqs, num_total_pages, prefill_chunk_size, support_sliding_window, RoPEMode(rope_mode), diff --git a/src/runtime/vm/rnn_state.cc b/src/runtime/vm/rnn_state.cc index 76457dd0d113..f88b30b6ad9c 100644 --- a/src/runtime/vm/rnn_state.cc +++ b/src/runtime/vm/rnn_state.cc @@ -80,7 +80,7 @@ class RNNStateImpObj : public RNNStateObj { * \brief The init value for ALL layers in the storage. * The array has `num_states_per_layer_` Tensors */ - const Array init_layer_value_; + const ffi::Array init_layer_value_; /*! \brief We fix int32 to be the index dtype of auxiliary data. */ const DLDataType dtype_aux_ = DLDataType(DataType::Int(32, 1)); @@ -94,7 +94,7 @@ class RNNStateImpObj : public RNNStateObj { * \note As `num_states_per_layer_` may vary for different dtype and shape, * we use a 2D array to store the Tensors for each layer. */ - Array> storages_; + ffi::Array> storages_; /*! \brief The list of ids of released seq slots for reuse. */ std::vector free_slot_ids_; /*! \brief The mapping from sequence ids to sequences. */ @@ -140,7 +140,7 @@ class RNNStateImpObj : public RNNStateObj { * \note Each state data per layer may have different dtype and shape, so we use a * different function for each state data. */ - Array f_gets_; + ffi::Array f_gets_; /*! * \brief The function to set the state data to the storage. * The function signature is `f_set_(state, seq_slot_ids, history_slot_ids, data, max_history)`. @@ -151,17 +151,17 @@ class RNNStateImpObj : public RNNStateObj { * \note Each state data per layer may have different dtype and shape, so we use a * different function for each state data. */ - Array f_sets_; + ffi::Array f_sets_; public: /*! \brief Constructor. Takes the cache configuration and initializes the Tensors. */ - explicit RNNStateImpObj(int64_t num_layers, // - int64_t reserved_num_seqs, // - int64_t max_history, // - DLDevice device, // - Array f_gets, // - Array f_sets, // - Array init_layer_value) + explicit RNNStateImpObj(int64_t num_layers, // + int64_t reserved_num_seqs, // + int64_t max_history, // + DLDevice device, // + ffi::Array f_gets, // + ffi::Array f_sets, // + ffi::Array init_layer_value) : num_layers_(num_layers), reserved_num_seqs_(reserved_num_seqs), num_states_per_layer_(init_layer_value.size()), @@ -172,7 +172,7 @@ class RNNStateImpObj : public RNNStateObj { // Allocate the storage for the state space models. 
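// One storage Tensor is created per (layer, state) pair below, shaped from the matching entry of init_layer_value.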
storages_.reserve(num_layers_); for (int64_t layer_id = 0; layer_id < num_layers_; ++layer_id) { - Array layer_storages; + ffi::Array layer_storages; layer_storages.reserve(num_states_per_layer_); for (int64_t state_id = 0; state_id < num_states_per_layer_; ++state_id) { ffi::Shape state_shape = init_layer_value[state_id].Shape(); @@ -208,7 +208,7 @@ class RNNStateImpObj : public RNNStateObj { /************** Interaction **************/ void BeginForward(const ffi::Shape& seq_ids, const ffi::Shape& append_lengths, - const Optional& opt_token_tree_parent_ptr) final { + const ffi::Optional& opt_token_tree_parent_ptr) final { CHECK_EQ(seq_ids.size(), append_lengths.size()) << "The seq_ids size (" << seq_ids.size() << ") and append_lengths size (" << append_lengths.size() << ") mismatch."; @@ -468,12 +468,12 @@ class RNNStateImpObj : public RNNStateObj { TVM_FFI_STATIC_INIT_BLOCK({ namespace refl = tvm::ffi::reflection; - refl::GlobalDef().def("vm.builtin.rnn_state_create", [](int64_t num_layers, // - int64_t reserved_num_seqs, // - int64_t max_history, // - Array f_gets, // - Array f_sets, // - Array init_layer_value) { + refl::GlobalDef().def("vm.builtin.rnn_state_create", [](int64_t num_layers, // + int64_t reserved_num_seqs, // + int64_t max_history, // + ffi::Array f_gets, // + ffi::Array f_sets, // + ffi::Array init_layer_value) { CHECK_GT(num_layers, 0) << "The number of layers should be greater than 0."; CHECK_GT(reserved_num_seqs, 0) << "The number of reserved sequences should be greater than 0."; CHECK_GE(max_history, 0) << "The maximum history length should be greater than or equal to 0."; @@ -492,8 +492,8 @@ TVM_FFI_STATIC_INIT_BLOCK({ << "The number of state setters should be the same as the number of states per layer, " << "but got " << f_sets.size() << " and " << init_layer_value.size() << " respectively."; ObjectPtr n = - make_object(num_layers, reserved_num_seqs, max_history, device, - std::move(f_gets), std::move(f_sets), init_layer_value); + ffi::make_object(num_layers, reserved_num_seqs, max_history, device, + std::move(f_gets), std::move(f_sets), init_layer_value); return RNNState(std::move(n)); }); }); diff --git a/src/runtime/vm/tensor_cache_support.cc b/src/runtime/vm/tensor_cache_support.cc index cff92994e41f..2cc53c6d400f 100644 --- a/src/runtime/vm/tensor_cache_support.cc +++ b/src/runtime/vm/tensor_cache_support.cc @@ -77,7 +77,7 @@ TensorCacheMetadata::FileRecord::ParamRecord JSONAsParamRecord(const picojson::o TensorCacheMetadata::FileRecord::ParamRecord result; std::string dtype = GetValue(json, "dtype"); result.name = GetValue(json, "name"); - result.dtype = DataType(StringToDLDataType(dtype)); + result.dtype = DataType(ffi::StringToDLDataType(dtype)); result.format = GetValue(json, "format"); result.nbytes = GetValue(json, "nbytes"); result.byte_offset = GetValue(json, "byteOffset"); @@ -142,7 +142,7 @@ TVM_DLL TensorCacheMetadata TensorCacheMetadata::Load(const std::string& path) { } void CopyTensorFromBytes(Tensor param, const void* data, size_t nbytes, - Optional* staging_buffer) { + ffi::Optional* staging_buffer) { Device device = param->device; if (device.device_type != kDLOpenCL || staging_buffer == nullptr) { param.CopyFromBytes(data, nbytes); @@ -166,9 +166,8 @@ void CopyTensorFromBytes(Tensor param, const void* data, size_t nbytes, DeviceAPI::Get(device)->StreamSync(device, nullptr); } -Tensor TensorCacheMetadata::FileRecord::ParamRecord::Load(Device device, - const std::string* raw_data, - Optional* staging_buffer) const { +Tensor 
TensorCacheMetadata::FileRecord::ParamRecord::Load( + Device device, const std::string* raw_data, ffi::Optional* staging_buffer) const { Tensor arr = Tensor::Empty(shape, dtype, device); if (dtype == DataType::Float(32) && format == "f32-to-bf16") { // decode bf16 to f32 @@ -185,17 +184,17 @@ Tensor TensorCacheMetadata::FileRecord::ParamRecord::Load(Device device, return arr; } -TVM_DLL Array TensorCacheMetadata::FileRecord::Load( +TVM_DLL ffi::Array TensorCacheMetadata::FileRecord::Load( Device device, const std::string& path_prefix, // std::string* raw_data_buffer, // - Optional* staging_buffer) const { + ffi::Optional* staging_buffer) const { LoadBinaryFromFile(path_prefix + "/" + this->data_path, raw_data_buffer); CHECK_EQ(this->format, "raw-shard") << "ValueError: Only `raw-shard` format is supported"; CHECK_EQ(this->nbytes, raw_data_buffer->length()) << "ValueError: Encountered a corrupted parameter shard. It means it was not downloaded " "completely or downloading was interrupted. Please try to download again."; - Array result; + ffi::Array result; result.reserve(this->records.size()); for (const ParamRecord& nd_rec : this->records) { result.push_back(nd_rec.Load(device, raw_data_buffer, staging_buffer)); @@ -213,7 +212,7 @@ class TensorCache { return inst; } - static void Update(String name, Tensor arr, bool override) { + static void Update(ffi::String name, Tensor arr, bool override) { TensorCache* pool = Global(); if (!override) { ICHECK_EQ(pool->pool_.count(name), 0) << "Name " << name << " already exists in the cache"; @@ -221,7 +220,7 @@ pool->pool_.Set(name, arr); } - static Optional Get(String name) { + static ffi::Optional Get(ffi::String name) { TensorCache* pool = Global(); auto it = pool->pool_.find(name); if (it != pool->pool_.end()) { @@ -231,7 +230,7 @@ class TensorCache { } } - static void Remove(String name) { + static void Remove(ffi::String name) { TensorCache* pool = Global(); pool->pool_.erase(name); } @@ -247,9 +246,9 @@ class TensorCache { static void Load(const std::string& cache_path, int device_type, int device_id) { DLDevice device{static_cast(device_type), device_id}; TensorCacheMetadata metadata = TensorCacheMetadata::Load(cache_path); - Optional staging_buffer; + ffi::Optional staging_buffer; std::string raw_data; - Array params; + ffi::Array params; for (const TensorCacheMetadata::FileRecord& shard_rec : metadata.records) { try { params = shard_rec.Load(device, cache_path, &raw_data, &staging_buffer); @@ -265,7 +264,7 @@ class TensorCache { } private: - Map pool_; + ffi::Map pool_; }; TVM_FFI_STATIC_INIT_BLOCK({ @@ -275,7 +274,7 @@ TVM_FFI_STATIC_INIT_BLOCK({ .def_packed("vm.builtin.tensor_cache.update", [](ffi::PackedArgs args, ffi::Any* rv) { CHECK(args.size() == 2 || args.size() == 3); - String name = args[0].cast(); + ffi::String name = args[0].cast(); bool is_override = args.size() == 2 ? 
false : args[2].cast(); Tensor arr; @@ -307,7 +306,7 @@ class ParamModuleNode : public ffi::ModuleObj { public: const char* kind() const final { return "param_module"; } - Optional GetFunction(const String& name) final { + ffi::Optional GetFunction(const ffi::String& name) final { if (name == "get_params") { auto params = params_; return ffi::Function([params](ffi::PackedArgs args, ffi::Any* rv) { *rv = params; }); @@ -316,8 +315,8 @@ class ParamModuleNode : public ffi::ModuleObj { } } - static Array GetParams(const String& prefix, int num_params) { - Array params; + static ffi::Array GetParams(const ffi::String& prefix, int num_params) { + ffi::Array params; for (int i = 0; i < num_params || num_params == -1; ++i) { std::string name = prefix + "_" + std::to_string(i); auto opt = TensorCache::Get(name); @@ -331,11 +330,11 @@ class ParamModuleNode : public ffi::ModuleObj { return params; } - static Array GetParamByName(const Array& names) { - Array result; + static ffi::Array GetParamByName(const ffi::Array& names) { + ffi::Array result; result.reserve(names.size()); - for (const String& name : names) { - if (Optional opt = TensorCache::Get(name)) { + for (const ffi::String& name : names) { + if (ffi::Optional opt = TensorCache::Get(name)) { result.push_back(opt.value()); } else { LOG(FATAL) << "ValueError: Cannot find parameter in cache: " << name; @@ -345,19 +344,19 @@ class ParamModuleNode : public ffi::ModuleObj { } static ffi::Module Create(const std::string& prefix, int num_params) { - auto n = make_object(); + auto n = ffi::make_object(); n->params_ = GetParams(prefix, num_params); return ffi::Module(n); } - static ffi::Module CreateByName(const Array& names) { - auto n = make_object(); + static ffi::Module CreateByName(const ffi::Array& names) { + auto n = ffi::make_object(); n->params_ = GetParamByName(names); return ffi::Module(n); } private: - Array params_; + ffi::Array params_; }; TVM_FFI_STATIC_INIT_BLOCK({ @@ -369,14 +368,14 @@ TVM_FFI_STATIC_INIT_BLOCK({ .def("vm.builtin.param_array_from_cache_by_name", ParamModuleNode::GetParamByName) .def_packed("vm.builtin.param_array_from_cache_by_name_unpacked", [](ffi::PackedArgs args, ffi::Any* rv) { - Array names; + ffi::Array names; names.reserve(args.size()); for (int i = 0; i < args.size(); ++i) { - if (!args[i].try_cast()) { + if (!args[i].try_cast()) { LOG(FATAL) << "ValueError: Expect string as input, but get " << args[i].GetTypeKey() << " at " << i; } - names.push_back(args[i].cast()); + names.push_back(args[i].cast()); } *rv = ParamModuleNode::GetParamByName(names); }); diff --git a/src/runtime/vm/vm.cc b/src/runtime/vm/vm.cc index 149948fb0ecf..be981b205cbb 100644 --- a/src/runtime/vm/vm.cc +++ b/src/runtime/vm/vm.cc @@ -38,8 +38,8 @@ namespace vm { // VM Closure object //--------------------------------------------- -VMClosure::VMClosure(String func_name, ffi::Function impl) { - auto ptr = make_object(); +VMClosure::VMClosure(ffi::String func_name, ffi::Function impl) { + auto ptr = ffi::make_object(); ptr->func_name = func_name; ptr->impl = std::move(impl); data_ = std::move(ptr); @@ -103,7 +103,7 @@ Any ConvertObjectToDevice(Any src, const Device& dev, Allocator* alloc) { for (size_t i = 0; i < arr.size(); i++) { ret.push_back(ConvertObjectToDevice(arr[i], dev, alloc)); } - return Array(ret.begin(), ret.end()); + return ffi::Array(ret.begin(), ret.end()); } else { return src; } @@ -189,7 +189,7 @@ class VirtualMachineImpl : public VirtualMachine { void LoadExecutable(ObjectPtr exec) final; void Init(const std::vector& 
devices, const std::vector& alloc_types) final; - VMClosure GetClosure(const String& func_name) final { + VMClosure GetClosure(const ffi::String& func_name) final { return this->GetClosureInternal(func_name, false).value(); } void InvokeClosurePacked(const ObjectRef& closure_or_packedfunc, ffi::PackedArgs args, @@ -210,7 +210,7 @@ class VirtualMachineImpl : public VirtualMachine { void _SetInputWithParamModule(ffi::PackedArgs args, ffi::Any* rv); int _GetFunctionArity(std::string func_name); std::string _GetFunctionParamName(std::string func_name, int index); - ffi::Function _LookupFunction(const String& name); + ffi::Function _LookupFunction(const ffi::String& name); TVM_MODULE_VTABLE_BEGIN("relax.VirtualMachine"); TVM_MODULE_VTABLE_ENTRY_PACKED("vm_initialization", &VirtualMachineImpl::_Init); @@ -236,7 +236,7 @@ class VirtualMachineImpl : public VirtualMachine { * \param allow_missing Whether none is allowed. * \return The result */ - Optional GetClosureInternal(const String& func_name, bool allow_missing); + ffi::Optional GetClosureInternal(const ffi::String& func_name, bool allow_missing); /*! * \brief Set inputs to a function. @@ -276,7 +276,7 @@ class VirtualMachineImpl : public VirtualMachine { * \param args The arguments to bound to the function. * \note This function is used by RPC server to help benchmarking. */ - void SaveClosure(const String& func_name, const String& save_name, bool include_return, + void SaveClosure(const ffi::String& func_name, const ffi::String& save_name, bool include_return, ffi::PackedArgs args); /*! * \brief Internal function to invoke a closure. @@ -300,7 +300,7 @@ class VirtualMachineImpl : public VirtualMachine { * \param name The name of the function. * \return The result function, can return ffi::Function(nullptr) if nothing is found. */ - Optional GetFuncFromImports(const String& name) { + ffi::Optional GetFuncFromImports(const ffi::String& name) { for (auto& lib : this->imports_) { if (auto opt_func = lib.cast()->GetFunction(name, true)) { return *opt_func; @@ -572,7 +572,7 @@ RegType VirtualMachineImpl::InvokeClosureInternal(const ObjectRef& closure_or_pa return ret; } -void VirtualMachineImpl::SaveClosure(const String& func_name, const String& save_name, +void VirtualMachineImpl::SaveClosure(const ffi::String& func_name, const ffi::String& save_name, bool include_return, ffi::PackedArgs args) { VMClosure clo = this->GetClosure(func_name); std::vector inputs(args.size()); @@ -589,8 +589,8 @@ void VirtualMachineImpl::SaveClosure(const String& func_name, const String& save saved_closures_[save_name] = VMClosure(save_name, impl); } -Optional VirtualMachineImpl::GetClosureInternal(const String& func_name, - bool allow_missing) { +ffi::Optional VirtualMachineImpl::GetClosureInternal(const ffi::String& func_name, + bool allow_missing) { // look up saved closures. 
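// Closures registered via SaveClosure take precedence; only when none matches is the name resolved against the executable's function table below.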
auto saved_it = saved_closures_.find(func_name); if (saved_it != saved_closures_.end()) { @@ -621,7 +621,7 @@ Optional VirtualMachineImpl::GetClosureInternal(const String& func_na } else { ICHECK(finfo.kind == VMFuncInfo::FuncKind::kVMTIRFunc) << "Cannot support closure with function kind " << static_cast(finfo.kind); - Optional tir_func = GetFuncFromImports("__vmtir__" + finfo.name); + ffi::Optional tir_func = GetFuncFromImports("__vmtir__" + finfo.name); ICHECK(tir_func.has_value()) << "Cannot find underlying compiled tir function of VMTIRFunc " << finfo.name; auto impl = ffi::Function([this, finfo, tir_func](ffi::PackedArgs args, ffi::Any* rv) { @@ -697,7 +697,7 @@ void VirtualMachineImpl::InitFuncPool() { const VMFuncInfo& info = exec_->func_table[func_index]; if (info.kind == VMFuncInfo::FuncKind::kPackedFunc) { // only look through imports first - Optional func = GetFuncFromImports(info.name); + ffi::Optional func = GetFuncFromImports(info.name); if (!func.has_value()) { const auto p_func = tvm::ffi::Function::GetGlobal(info.name); if (p_func.has_value()) func = *p_func; @@ -846,7 +846,9 @@ void VirtualMachineImpl::RunLoop() { } } -ObjectPtr VirtualMachine::Create() { return make_object(); } +ObjectPtr VirtualMachine::Create() { + return ffi::make_object(); +} //-------------------------------------------------------------------- // FFI related code @@ -869,7 +871,7 @@ void VirtualMachineImpl::_Init(ffi::PackedArgs args, ffi::Any* rv) { void VirtualMachineImpl::_SaveClosure(ffi::PackedArgs args, ffi::Any* rv) { ICHECK_GE(args.size(), 3); std::string func_name = args[0].cast(); - this->SaveClosure(func_name, args[1].cast(), args[2].cast(), args.Slice(3)); + this->SaveClosure(func_name, args[1].cast(), args[2].cast(), args.Slice(3)); } void VirtualMachineImpl::_InvokeClosure(ffi::PackedArgs args, ffi::Any* rv) { @@ -894,7 +896,7 @@ void VirtualMachineImpl::_SetInstrument(ffi::PackedArgs args, ffi::Any* rv) { if (args[0].as()) { this->SetInstrument(args[0].cast()); } else { - String func_name = args[0].cast(); + ffi::String func_name = args[0].cast(); const auto factory = tvm::ffi::Function::GetGlobal(func_name); CHECK(factory.has_value()) << "Cannot find factory " << func_name; ffi::Any rv; @@ -950,9 +952,9 @@ std::string VirtualMachineImpl::_GetFunctionParamName(std::string func_name, int return vm_func.param_names[index]; } -ffi::Function VirtualMachineImpl::_LookupFunction(const String& name) { - if (Optional opt = this->GetClosureInternal(name, true)) { - return ffi::Function([clo = opt.value(), _self = GetRef(this)]( +ffi::Function VirtualMachineImpl::_LookupFunction(const ffi::String& name) { + if (ffi::Optional opt = this->GetClosureInternal(name, true)) { + return ffi::Function([clo = opt.value(), _self = ffi::GetRef(this)]( ffi::PackedArgs args, ffi::Any* rv) -> void { auto* self = const_cast(_self.as()); ICHECK(self); @@ -973,7 +975,7 @@ ffi::Function VirtualMachineImpl::_LookupFunction(const String& name) { */ class VirtualMachineProfiler : public VirtualMachineImpl { public: - Optional GetFunction(const String& name) override { + ffi::Optional GetFunction(const ffi::String& name) override { ObjectPtr sptr_to_self = ffi::GetObjectPtr(this); if (name == "profile") { return ffi::Function([sptr_to_self, this](ffi::PackedArgs args, ffi::Any* rv) { @@ -987,7 +989,7 @@ class VirtualMachineProfiler : public VirtualMachineImpl { } } - prof_ = profiling::Profiler(devices, {}, {{String("Executor"), String("VM")}}); + prof_ = profiling::Profiler(devices, {}, {{ffi::String("Executor"), 
ffi::String("VM")}}); auto inputs = GetInputsFor(f_name); @@ -1074,7 +1076,7 @@ class VirtualMachineProfiler : public VirtualMachineImpl { }; ObjectPtr VirtualMachine::CreateProfiler() { - return make_object(); + return ffi::make_object(); } #else diff --git a/src/runtime/vulkan/vulkan_module.cc b/src/runtime/vulkan/vulkan_module.cc index a5fb6c2293fa..7c25985b6f07 100644 --- a/src/runtime/vulkan/vulkan_module.cc +++ b/src/runtime/vulkan/vulkan_module.cc @@ -33,11 +33,11 @@ namespace vulkan { ffi::Module VulkanModuleCreate(std::unordered_map smap, std::unordered_map fmap, std::string source) { - auto n = make_object(smap, fmap, source); + auto n = ffi::make_object(smap, fmap, source); return ffi::Module(n); } -ffi::Module VulkanModuleLoadFile(const std::string& file_name, const String& format) { +ffi::Module VulkanModuleLoadFile(const std::string& file_name, const ffi::String& format) { std::string data; std::unordered_map smap; std::unordered_map fmap; diff --git a/src/runtime/vulkan/vulkan_wrapped_func.cc b/src/runtime/vulkan/vulkan_wrapped_func.cc index 2f50a0154658..007d6abdbadb 100644 --- a/src/runtime/vulkan/vulkan_wrapped_func.cc +++ b/src/runtime/vulkan/vulkan_wrapped_func.cc @@ -205,7 +205,7 @@ VulkanModuleNode::~VulkanModuleNode() { } } -Optional VulkanModuleNode::GetFunction(const String& name) { +ffi::Optional VulkanModuleNode::GetFunction(const ffi::String& name) { ObjectPtr sptr_to_self = ffi::GetObjectPtr(this); ICHECK_EQ(sptr_to_self.get(), this); auto it = fmap_.find(name); @@ -403,7 +403,7 @@ std::shared_ptr VulkanModuleNode::GetPipeline(size_t device_id, return pe; } -void VulkanModuleNode::WriteToFile(const String& file_name, const String& format) const { +void VulkanModuleNode::WriteToFile(const ffi::String& file_name, const ffi::String& format) const { std::string fmt = GetFileFormat(file_name, format); ICHECK_EQ(fmt, fmt_) << "Can only save to customized format vulkan"; std::string meta_file = GetMetaFilePath(file_name); @@ -427,7 +427,7 @@ ffi::Bytes VulkanModuleNode::SaveToBytes() const { return ffi::Bytes(buffer); } -String VulkanModuleNode::InspectSource(const String& format) const { +ffi::String VulkanModuleNode::InspectSource(const ffi::String& format) const { // can only return disassembly code. return source_; } diff --git a/src/runtime/vulkan/vulkan_wrapped_func.h b/src/runtime/vulkan/vulkan_wrapped_func.h index 2ff90568de9d..53ae3ac4ba82 100644 --- a/src/runtime/vulkan/vulkan_wrapped_func.h +++ b/src/runtime/vulkan/vulkan_wrapped_func.h @@ -94,15 +94,15 @@ class VulkanModuleNode final : public ffi::ModuleObj { return ffi::Module::kBinarySerializable | ffi::Module::kRunnable; } - Optional GetFunction(const String& name) final; + ffi::Optional GetFunction(const ffi::String& name) final; std::shared_ptr GetPipeline(size_t device_id, const std::string& func_name, size_t num_pack_args); - void WriteToFile(const String& file_name, const String& format) const final; + void WriteToFile(const ffi::String& file_name, const ffi::String& format) const final; ffi::Bytes SaveToBytes() const final; - String InspectSource(const String& format) const final; + ffi::String InspectSource(const ffi::String& format) const final; private: // function information table. 
diff --git a/src/script/ir_builder/base.cc b/src/script/ir_builder/base.cc index 1b02e7dfb8c0..003157572c36 100644 --- a/src/script/ir_builder/base.cc +++ b/src/script/ir_builder/base.cc @@ -31,7 +31,7 @@ TVM_FFI_STATIC_INIT_BLOCK({ }); void IRBuilderFrameNode::EnterWithScope() { - IRBuilder::Current()->frames.push_back(GetRef(this)); + IRBuilder::Current()->frames.push_back(ffi::GetRef(this)); } void IRBuilderFrameNode::ExitWithScope() { @@ -50,7 +50,7 @@ void IRBuilderFrameNode::AddCallback(ffi::TypedFunction callback) { } IRBuilder::IRBuilder() { - ObjectPtr n = make_object(); + ObjectPtr n = ffi::make_object(); n->frames.clear(); n->result = std::nullopt; data_ = n; @@ -95,7 +95,7 @@ Namer::FType& Namer::vtable() { return inst; } -void Namer::Name(ObjectRef node, String name) { +void Namer::Name(ObjectRef node, ffi::String name) { static const FType& f = vtable(); CHECK(node.defined()) << "ValueError: Cannot name nullptr with: " << name; CHECK(f.can_dispatch(node)) << "ValueError: Do not know how to name type \"" diff --git a/src/script/ir_builder/ir/frame.cc b/src/script/ir_builder/ir/frame.cc index 9a1e5cdd109c..d2bb5231a867 100644 --- a/src/script/ir_builder/ir/frame.cc +++ b/src/script/ir_builder/ir/frame.cc @@ -28,7 +28,7 @@ namespace ir { TVM_FFI_STATIC_INIT_BLOCK({ IRModuleFrameNode::RegisterReflection(); }); void IRModuleFrameNode::ExitWithScope() { - Map func_map; + ffi::Map func_map; CHECK_EQ(functions.size(), global_var_map.size()) << "All functions must be defined in the IRModule. Got " << global_var_map.size() << " declared function(s), but only " << functions.size() << " defined function(s)."; diff --git a/src/script/ir_builder/ir/ir.cc b/src/script/ir_builder/ir/ir.cc index 26af0e55c76d..b0c56e779a71 100644 --- a/src/script/ir_builder/ir/ir.cc +++ b/src/script/ir_builder/ir/ir.cc @@ -32,7 +32,7 @@ namespace ir_builder { namespace ir { IRModuleFrame IRModule() { - ObjectPtr n = make_object(); + ObjectPtr n = ffi::make_object(); n->global_var_map.clear(); n->functions.clear(); return IRModuleFrame(n); } @@ -49,14 +49,15 @@ inline relax::StructInfo GetGlobalVarStructInfo(const BaseFunc& func) { } } -GlobalVar DeclFunction(const String& func_name, const BaseFunc& func_signature) { +GlobalVar DeclFunction(const ffi::String& func_name, const BaseFunc& func_signature) { IRModuleFrame frame = FindModuleFrame(); CHECK(!frame->global_var_map.count(func_name)) << "ValueError: function " << func_name << " already exists"; auto gvar_type = [&]() -> Type { if (auto prim_func = func_signature.as()) { - Array arg_types = prim_func->params.Map([](const auto& var) { return GetType(var); }); + ffi::Array arg_types = + prim_func->params.Map([](const auto& var) { return GetType(var); }); return FuncType(arg_types, prim_func->ret_type); } @@ -72,7 +73,7 @@ GlobalVar DeclFunction(const String& func_name, const BaseFunc& func_signature) return gv; } -void DefFunction(const String& func_name, const BaseFunc& func) { +void DefFunction(const ffi::String& func_name, const BaseFunc& func) { IRModuleFrame frame = FindModuleFrame(); auto it = frame->global_var_map.find(func_name); CHECK(it != frame->global_var_map.end()) @@ -82,7 +83,7 @@ void DefFunction(const String& func_name, const BaseFunc& func) { gv->struct_info_ = GetGlobalVarStructInfo(func); } -void ModuleAttrs(Map attrs, bool allow_overwrite) { +void ModuleAttrs(ffi::Map attrs, bool allow_overwrite) { if (IRBuilder::IsInScope()) { // TODO(hongyi): add comments to explain why we need to check if the module frame is in scope IRModuleFrame frame = 
@@ -93,7 +94,7 @@ void ModuleAttrs(Map attrs, bool allow_overwrite) {
   }
 }
 
-Optional ModuleGetAttr(const String& key) {
+ffi::Optional ModuleGetAttr(const ffi::String& key) {
   if (IRBuilder::IsInScope()) {
     IRModuleFrame frame = FindModuleFrame();
     if (frame->attrs.find(key) != frame->attrs.end()) {
@@ -103,7 +104,8 @@ Optional ModuleGetAttr(const String& key) {
   return std::nullopt;
 }
 
-void ModuleSetAttr(const String& key, const Optional& value, bool allow_override) {
+void ModuleSetAttr(const ffi::String& key, const ffi::Optional& value,
+                   bool allow_override) {
   if (IRBuilder::IsInScope()) {
     IRModuleFrame frame = FindModuleFrame();
     if (!allow_override && frame->attrs.find(key) != frame->attrs.end() && value.defined()) {
@@ -119,7 +121,7 @@ void ModuleSetAttr(const String& key, const Optional& value, bool all
   }
 }
 
-void ModuleGlobalInfos(Map> global_infos) {
+void ModuleGlobalInfos(ffi::Map> global_infos) {
   if (IRBuilder::IsInScope()) {
     IRModuleFrame frame = FindModuleFrame("I.ModuleGlobalInfos");
     if (!frame->global_infos.empty()) {
@@ -130,13 +132,13 @@ void ModuleGlobalInfos(Map> global_infos) {
   }
 }
 
-VDevice LookupVDevice(String target_kind, int device_index) {
+VDevice LookupVDevice(ffi::String target_kind, int device_index) {
   if (IRBuilder::IsInScope()) {
     IRModuleFrame frame = FindModuleFrame();
     if (frame->global_infos.empty()) {
       LOG(FATAL) << "ValueError: The GlobalInfos in the IRModule is not defined.";
     }
-    Array vdevices = frame->global_infos["vdevice"];
+    ffi::Array vdevices = frame->global_infos["vdevice"];
     if (vdevices.empty() || device_index < 0 ||
         static_cast(device_index) >= vdevices.size()) {
       LOG(FATAL) << "ValueError: The target VDevice in the GlobalInfos was not found.";
diff --git a/src/script/ir_builder/ir/utils.h b/src/script/ir_builder/ir/utils.h
index b12e5e270d89..54ea6ce6ad92 100644
--- a/src/script/ir_builder/ir/utils.h
+++ b/src/script/ir_builder/ir/utils.h
@@ -26,10 +26,10 @@ namespace script {
 namespace ir_builder {
 namespace ir {
 
-inline IRModuleFrame FindModuleFrame(const String& method) {
+inline IRModuleFrame FindModuleFrame(const ffi::String& method) {
   IRBuilder builder = IRBuilder::Current();
-  if (Optional frame = builder->FindFrame()) {
-    const Optional& last_module_frame = builder->GetLastFrame();
+  if (ffi::Optional frame = builder->FindFrame()) {
+    const ffi::Optional& last_module_frame = builder->GetLastFrame();
     if (last_module_frame.defined() && last_module_frame.value() == frame) {
       return frame.value();
     }
@@ -43,7 +43,7 @@ inline IRModuleFrame FindModuleFrame(const String& method) {
 
 inline IRModuleFrame FindModuleFrame() {
   IRBuilder builder = IRBuilder::Current();
-  if (Optional frame = builder->FindFrame()) {
+  if (ffi::Optional frame = builder->FindFrame()) {
     return frame.value();
   } else {
     LOG(FATAL) << "ValueError: IRModule frame not find. Please ensure it"
diff --git a/src/script/ir_builder/relax/distributed.cc b/src/script/ir_builder/relax/distributed.cc
index 424d20980ad2..bab14f3b3fd2 100644
--- a/src/script/ir_builder/relax/distributed.cc
+++ b/src/script/ir_builder/relax/distributed.cc
@@ -28,8 +28,9 @@ namespace tvm {
 namespace relax {
 
-Expr MakeCallTIRDist(Expr func, Tuple args, Array out_sinfo_list,
-                     Optional packed_ints) {
+Expr MakeCallTIRDist(Expr func, Tuple args,
+                     ffi::Array out_sinfo_list,
+                     ffi::Optional packed_ints) {
   for (const distributed::DTensorStructInfo& sinfo : out_sinfo_list) {
     const auto* shape = sinfo->tensor_sinfo->shape.as();
     CHECK(shape != nullptr) << "out_sinfo of call_tir should have defined ShapeExpr as shape. "
diff --git a/src/script/ir_builder/relax/frame.cc b/src/script/ir_builder/relax/frame.cc
index c3c7ae6f4f88..d69547383a80 100644
--- a/src/script/ir_builder/relax/frame.cc
+++ b/src/script/ir_builder/relax/frame.cc
@@ -43,7 +43,7 @@ TVM_FFI_STATIC_INIT_BLOCK({
 void SeqExprFrameNode::ExitWithScope() {
   // At this moment, there should be at most one BlockFrame which hasn't ended. In this case, call
   // its `ExitBlockFrame` and check if there is any more unended BlockFrame.
-  if (Optional block_frame = IRBuilder::Current()->GetLastFrame()) {
+  if (ffi::Optional block_frame = IRBuilder::Current()->GetLastFrame()) {
     block_frame.value()->ExitWithScope();
     ICHECK(!IRBuilder::Current()->GetLastFrame().defined())
         << "ValueError: There is some remaining BlockFrame that is not properly popped out.";
@@ -87,12 +87,12 @@ void FunctionFrameNode::ExitWithScope() {
     // Case 0. No outer frame, return function directly
     ICHECK(!builder->result.defined()) << "ValueError: Builder.result has already been set";
     builder->result = func;
-  } else if (Optional opt_frame = builder->FindFrame()) {
+  } else if (ffi::Optional opt_frame = builder->FindFrame()) {
     // Case 1. A global function of an IRModule
     CHECK(name.has_value()) << "ValueError: The function name must be defined before exiting the "
                                "function scope, if it's defined in a Module";
     const IRModuleFrame& frame = opt_frame.value();
-    const String& func_name = name.value_or("");
+    const ffi::String& func_name = name.value_or("");
     if (!frame->global_var_map.count(func_name)) {
       // First time visiting the function.
       ir::DeclFunction(func_name, func);
@@ -108,7 +108,7 @@ void FunctionFrameNode::ExitWithScope() {
 void BlockFrameNode::EnterWithScope() {
   // Step 1. If the last frame is a block frame. The start of a new block frame marks the end of the
   // last block frame.
-  Optional block_frame = IRBuilder::Current()->GetLastFrame();
+  ffi::Optional block_frame = IRBuilder::Current()->GetLastFrame();
   if (block_frame.defined()) {
     block_frame.value()->ExitWithScope();
     // Block frames cannot appear consecutively.
@@ -116,7 +116,7 @@ void BlockFrameNode::EnterWithScope() {
   }
   // Step 2. Deal with the new block frame.
   RelaxFrameNode::EnterWithScope();
-  Optional func_frame = IRBuilder::Current()->FindFrame();
+  ffi::Optional func_frame = IRBuilder::Current()->FindFrame();
   CHECK(func_frame.defined())
       << "ValueError: Cannot find FunctionFrame when creating BindingBlocks, Please ensure "
         "creating the block under Relax function scope.";
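
`FindFrame<T>()` and `GetLastFrame<T>()` return `ffi::Optional`, so every lookup in these builders follows the same test-then-unwrap shape. A hedged sketch of that idiom in isolation (the `FindWidget` helper is hypothetical, used only to keep the example self-contained):

```cpp
#include <string>
#include <tvm/ffi/optional.h>
#include <tvm/ffi/string.h>

// Hypothetical lookup that may or may not produce a value.
tvm::ffi::Optional<tvm::ffi::String> FindWidget(bool found) {
  if (found) return tvm::ffi::String("widget");
  return std::nullopt;  // ffi::Optional interoperates with std::nullopt
}

std::string GetWidgetOrDefault() {
  tvm::ffi::Optional<tvm::ffi::String> w = FindWidget(false);
  if (w.defined()) {  // the same check the frame helpers perform
    return std::string(w.value());
  }
  return "fallback";  // equivalently: w.value_or("fallback")
}
```
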
@@ -162,7 +162,7 @@ void BlockFrameNode::ExitWithScope() {
   // Step 3. Rewrite the dataflow block.
   if (is_dataflow) {
     // Step 3.0. Define a map to replace variables
-    Array new_output_vars;
+    ffi::Array new_output_vars;
     std::unordered_map var_remap;
     for (const auto& output_var : output_vars) {
       tvm::relax::Var new_output_var(output_var->name_hint(), GetStructInfo(output_var));
@@ -185,7 +185,7 @@ void BlockFrameNode::ExitWithScope() {
   }
 
   // Step 3. Get the last frame from the IRBuilder frame stack.
-  Optional opt_last_frame = IRBuilder::Current()->GetLastFrame();
+  ffi::Optional opt_last_frame = IRBuilder::Current()->GetLastFrame();
   ICHECK(opt_last_frame.defined());
   RelaxFrame last_frame = opt_last_frame.value();
 
@@ -195,7 +195,7 @@ void BlockFrameNode::ExitWithScope() {
 
   // Step 5. Push the block frame into the corresponding field of the last frame.
   if (const auto* seq_frame = last_frame.as()) {
-    auto frame = GetRef(seq_frame);
+    auto frame = ffi::GetRef(seq_frame);
     frame->binding_blocks.push_back(block);
   } else {
     LOG(FATAL) << "ValueError: Currently the last frame is supposed to be either a function frame "
@@ -210,7 +210,7 @@ void BlockFrameNode::ExitWithScope() {
 }
 
 void IfFrameNode::EnterWithScope() {
-  const Array& frames = IRBuilder::Current()->frames;
+  const ffi::Array& frames = IRBuilder::Current()->frames;
   for (const IRBuilderFrame& frame : frames) {
     const auto* block_frame = frame.as();
     if (block_frame && block_frame->is_dataflow) {
@@ -241,8 +241,8 @@ void ThenFrameNode::EnterWithScope() {
 
 void ThenFrameNode::ExitWithScope() {
   SeqExprFrameNode::ExitWithScope();
-  String var_name;
-  output = GetSeqExprForBranch(GetRef(this), &var_name);
+  ffi::String var_name;
+  output = GetSeqExprForBranch(ffi::GetRef(this), &var_name);
   IfFrame frame = FindIfFrame("R.Then");
   frame->then_expr = output;
   frame->var_name = var_name;
@@ -259,8 +259,8 @@ void ElseFrameNode::EnterWithScope() {
 
 void ElseFrameNode::ExitWithScope() {
   SeqExprFrameNode::ExitWithScope();
-  String var_name;
-  output = GetSeqExprForBranch(GetRef(this), &var_name);
+  ffi::String var_name;
+  output = GetSeqExprForBranch(ffi::GetRef(this), &var_name);
   IfFrame frame = FindIfFrame("R.Else");
   frame->else_expr = output;
   CHECK(frame->var_name == var_name)
diff --git a/src/script/ir_builder/relax/ir.cc b/src/script/ir_builder/relax/ir.cc
index b845434e917b..8cab805a0433 100644
--- a/src/script/ir_builder/relax/ir.cc
+++ b/src/script/ir_builder/relax/ir.cc
@@ -34,7 +34,7 @@ namespace relax {
 using tvm::script::ir_builder::details::Namer;
 
 TVM_STATIC_IR_FUNCTOR(Namer, vtable)
-    .set_dispatch([](const ObjectRef& node, String name) -> void {
+    .set_dispatch([](const ObjectRef& node, ffi::String name) -> void {
       using tvm::relax::VarNode;
       using tvm::relax::IdNode;
       const VarNode* var = node.as();
@@ -43,7 +43,7 @@ TVM_STATIC_IR_FUNCTOR(Namer, vtable)
     });
 
 TVM_STATIC_IR_FUNCTOR(Namer, vtable)
-    .set_dispatch([](const ObjectRef& node, String name) -> void {
+    .set_dispatch([](const ObjectRef& node, ffi::String name) -> void {
       using tvm::relax::DataflowVarNode;
       using tvm::relax::IdNode;
       const DataflowVarNode* var = node.as();
@@ -54,10 +54,11 @@ TVM_STATIC_IR_FUNCTOR(Namer, vtable)
 
 /////////////////////////////// Function ////////////////////////////////
 
 FunctionFrame Function(const Bool& is_pure, const Bool& is_private) {
-  ObjectPtr n = make_object();
+  ObjectPtr n = ffi::make_object();
   const IRBuilder& ir_builder = IRBuilder::Current();
-  Optional mod = std::nullopt;
-  if (const Optional mod_frame = ir_builder->GetLastFrame()) {
+  ffi::Optional mod = std::nullopt;
+  if (const ffi::Optional mod_frame =
+          ir_builder->GetLastFrame()) {
     mod = tvm::IRModule(mod_frame.value()->functions);
   }
   n->block_builder = tvm::relax::BlockBuilder::Create(
@@ -67,7 +68,7 @@ FunctionFrame Function(const Bool& is_pure, const Bool& is_private) {
   return FunctionFrame(n);
 }
 
-tvm::relax::Var Arg(const String& name, const tvm::relax::StructInfo& struct_info) {
+tvm::relax::Var Arg(const ffi::String& name, const tvm::relax::StructInfo& struct_info) {
   FunctionFrame frame = FindFunctionFrame("R.Arg");
   tvm::relax::Var var(name, struct_info);
   frame->params.push_back(var);
@@ -76,7 +77,7 @@ tvm::relax::Var Arg(const String& name, const tvm::relax::StructInfo& struct_inf
   return var;
 }
 
-void FuncName(const String& name) {
+void FuncName(const ffi::String& name) {
   FunctionFrame frame = FindFunctionFrame("R.func_name");
   if (frame->name.has_value()) {
     LOG(FATAL) << "ValueError: Duplicate function name, previous one is: \"" << frame->name.value()
@@ -85,7 +86,7 @@ void FuncName(const String& name) {
   frame->name = name;
 }
 
-void FuncAttrs(Map attrs) {
+void FuncAttrs(ffi::Map attrs) {
   FunctionFrame frame = FindFunctionFrame("R.func_attr");
   for (const auto& [key, value] : attrs) {
     if (key == tvm::attr::kGlobalSymbol && frame->is_private.value_or(Bool(false))->value) {
@@ -159,22 +160,22 @@ TVM_FFI_STATIC_INIT_BLOCK({
 
 ///////////////////////////// BindingBlock //////////////////////////////
 
 BlockFrame Dataflow() {
-  ObjectPtr n = make_object();
+  ObjectPtr n = ffi::make_object();
   n->is_dataflow = true;
   n->block_ended = false;
   return BlockFrame(n);
 }
 
 BlockFrame BindingBlock() {
-  ObjectPtr n = make_object();
+  ObjectPtr n = ffi::make_object();
   n->is_dataflow = false;
   n->block_ended = false;
   return BlockFrame(n);
 }
 
-void DataflowBlockOutput(const Array& vars) {
+void DataflowBlockOutput(const ffi::Array& vars) {
   // Step 1. Check that we're in a Dataflow block that is not ended.
-  Optional block_frame = IRBuilder::Current()->GetLastFrame();
+  ffi::Optional block_frame = IRBuilder::Current()->GetLastFrame();
   CHECK(block_frame.defined() && block_frame.value()->is_dataflow)
       << "ValueError: `R.output` should appear inside a dataflow block. However, the current "
         "innermost block is not a dataflow block.";
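
`DataflowBlockOutput` below validates membership with `std::find` over `emitted_vars`; `ffi::Array` deliberately exposes random-access iterators, so the standard algorithms apply unchanged. A small self-contained sketch of the same check:

```cpp
#include <algorithm>
#include <tvm/ffi/container/array.h>
#include <tvm/ffi/string.h>

bool Contains(const tvm::ffi::Array<tvm::ffi::String>& emitted,
              const tvm::ffi::String& var) {
  // begin()/end() satisfy the standard iterator requirements, so the same
  // std::find call used by the dataflow-output check works for any element type.
  return std::find(emitted.begin(), emitted.end(), var) != emitted.end();
}
```
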
@@ -187,7 +188,7 @@ void DataflowBlockOutput(const Array& vars) {
 
   // Step 3. All the output variables must be global variables and must be emitted by this dataflow
   // block.
-  const Array& emitted_vars = block_frame.value()->emitted_vars;
+  const ffi::Array& emitted_vars = block_frame.value()->emitted_vars;
   for (const tvm::relax::Var& var : vars) {
     CHECK(std::find(emitted_vars.begin(), emitted_vars.end(), var) != emitted_vars.end())
         << "ValueError: An output variable is not emitted by this dataflow block. Please make sure "
@@ -207,7 +208,7 @@ TVM_FFI_STATIC_INIT_BLOCK({
 
 /////////////////////////////// Bindings ///////////////////////////////
 
 tvm::relax::Var Emit(const tvm::relax::Expr& expr,
-                     const Optional& annotate_struct_info) {
+                     const ffi::Optional& annotate_struct_info) {
   using tvm::relax::GetStructInfo;
   BlockFrame block_frame = CheckBlockFrameExistAndUnended();
   const tvm::relax::BlockBuilder& block_builder = GetBlockBuilder();
@@ -255,7 +256,7 @@ TVM_FFI_STATIC_INIT_BLOCK({
 
 /////////////////////////////// SeqExpr ///////////////////////////////
 
 SeqExprFrame SeqExpr() {
-  ObjectPtr n = make_object();
+  ObjectPtr n = ffi::make_object();
   return SeqExprFrame(n);
 }
 
@@ -267,7 +268,7 @@ TVM_FFI_STATIC_INIT_BLOCK({
 
 ///////////////////////////// If Then Else /////////////////////////////
 
 IfFrame If(tvm::relax::Expr condition) {
-  ObjectPtr n = make_object();
+  ObjectPtr n = ffi::make_object();
   n->condition = condition;
   n->then_expr = std::nullopt;
   n->else_expr = std::nullopt;
@@ -275,12 +276,12 @@ IfFrame If(tvm::relax::Expr condition) {
 }
 
 ThenFrame Then() {
-  ObjectPtr n = make_object();
+  ObjectPtr n = ffi::make_object();
   return ThenFrame(n);
 }
 
 ElseFrame Else() {
-  ObjectPtr n = make_object();
+  ObjectPtr n = ffi::make_object();
   return ElseFrame(n);
 }
 
diff --git a/src/script/ir_builder/relax/utils.h b/src/script/ir_builder/relax/utils.h
index 7fd7e21a6739..e24b4a27593d 100644
--- a/src/script/ir_builder/relax/utils.h
+++ b/src/script/ir_builder/relax/utils.h
@@ -31,8 +31,8 @@ namespace script {
 namespace ir_builder {
 namespace relax {
 
-inline FunctionFrame FindFunctionFrame(const String& method) {
-  if (Optional frame = IRBuilder::Current()->FindFrame()) {
+inline FunctionFrame FindFunctionFrame(const ffi::String& method) {
+  if (ffi::Optional frame = IRBuilder::Current()->FindFrame()) {
     return frame.value();
   }
   LOG(FATAL) << "ValueError: Function frame not find. Please ensure '" << method
@@ -40,8 +40,8 @@ inline FunctionFrame FindFunctionFrame(const String& method) {
   throw;
 }
 
-inline IfFrame FindIfFrame(const String& method) {
-  if (Optional frame = IRBuilder::Current()->GetLastFrame()) {
+inline IfFrame FindIfFrame(const ffi::String& method) {
+  if (ffi::Optional frame = IRBuilder::Current()->GetLastFrame()) {
     return frame.value();
   } else {
     LOG(FATAL) << "ValueError: IfThenElse frame not find. Please ensure '" << method
@@ -51,7 +51,7 @@ inline IfFrame FindIfFrame(const String& method) {
 }
 
 inline tvm::relax::BlockBuilder GetBlockBuilder() {
-  Optional frame = IRBuilder::Current()->FindFrame();
+  ffi::Optional frame = IRBuilder::Current()->FindFrame();
   CHECK(frame.defined()) << "ValueError: Relax Function frame not find. Please ensure "
                             "assignment is called under R.function()";
   return frame.value()->block_builder;
@@ -61,14 +61,14 @@ inline BlockFrame CheckBlockFrameExistAndUnended() {
   // We check if the current block is "ended" - if a block is ended, it is not allowed to emit new
   // bindings into this block, and we should throw exceptions.
-  Optional block_frame = IRBuilder::Current()->GetLastFrame();
+  ffi::Optional block_frame = IRBuilder::Current()->GetLastFrame();
   CHECK(block_frame.defined()) << "ValueError: Block frame not find";
   CHECK(!block_frame.value()->block_ended)
       << "ValueError: New binding is not allowed after dataflow block output.";
   return block_frame.value();
 }
 
-inline tvm::relax::SeqExpr GetSeqExprForBranch(const SeqExprFrame& frame, String* var_name) {
+inline tvm::relax::SeqExpr GetSeqExprForBranch(const SeqExprFrame& frame, ffi::String* var_name) {
   // Step 0. Check frame type
   std::string method;
   std::string output_var_suffix;
@@ -101,10 +101,10 @@ inline tvm::relax::SeqExpr GetSeqExprForBranch(const SeqExprFrame& frame, String
   *var_name = last_binding->var->name_hint();
 
   // Step 3. Re-collect binding blocks to replace the last binding.
-  Array new_blocks(frame->binding_blocks.begin(),
-                   frame->binding_blocks.end() - 1);
-  Array last_block_bindings(last_block->bindings.begin(),
-                            last_block->bindings.end() - 1);
+  ffi::Array new_blocks(frame->binding_blocks.begin(),
+                        frame->binding_blocks.end() - 1);
+  ffi::Array last_block_bindings(last_block->bindings.begin(),
+                                 last_block->bindings.end() - 1);
   tvm::relax::Var new_var = tvm::relax::Var(last_binding->var->name_hint() + output_var_suffix,
                                             GetStructInfo(last_binding->var));
 
diff --git a/src/script/ir_builder/tir/frame.cc b/src/script/ir_builder/tir/frame.cc
index b0d5bb337f35..2bfb9266eada 100644
--- a/src/script/ir_builder/tir/frame.cc
+++ b/src/script/ir_builder/tir/frame.cc
@@ -67,11 +67,11 @@ void PrimFuncFrameNode::ExitWithScope() {
   if (builder->frames.empty()) {
     ICHECK(!builder->result.defined()) << "ValueError: Builder.result has already been set";
     builder->result = func;
-  } else if (Optional opt_frame = builder->FindFrame()) {
+  } else if (ffi::Optional opt_frame = builder->FindFrame()) {
     CHECK(name.has_value()) << "ValueError: The function name must be defined before exiting the "
                                "function scope, if it's defined in a Module";
     const ir::IRModuleFrame& frame = opt_frame.value();
-    const String& func_name = name.value_or("");
+    const ffi::String& func_name = name.value_or("");
     if (!frame->global_var_map.count(func_name)) {
       // Case. First time visiting the function.
       ir::DeclFunction(func_name, func);
@@ -86,17 +86,17 @@ void PrimFuncFrameNode::ExitWithScope() {
 
 void BlockFrameNode::ExitWithScope() {
   TIRFrameNode::ExitWithScope();
-  Array tir_alloc_buffers;
+  ffi::Array tir_alloc_buffers;
   for (const tvm::tir::Buffer& buffer : alloc_buffers) {
     tir_alloc_buffers.push_back(buffer);
   }
-  Map attrs = annotations.value_or({});
+  ffi::Map attrs = annotations.value_or({});
   if (int detect_access = (!reads.defined()) | (!writes.defined() << 1)) {
     attrs.Set("tir.script_parsing_detect_access", tvm::IntImm(DataType::Int(64), detect_access));
   }
-  tvm::tir::Block block(iter_vars, reads.value_or(Array()),
-                        writes.value_or(Array()), name, AsStmt(stmts), init,
-                        tir_alloc_buffers, match_buffers, attrs);
+  tvm::tir::Block block(iter_vars, reads.value_or(ffi::Array()),
+                        writes.value_or(ffi::Array()), name, AsStmt(stmts),
+                        init, tir_alloc_buffers, match_buffers, attrs);
   if (no_realize) {
     CHECK(iter_values.empty())
         << "ValueError: Block bindings are not allowed when `no_realize=True`";
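
The frame code above and `BufferDecl` below repeatedly normalize optional arguments with `value_or`, materializing an empty container when the caller passed nothing. The shape of that normalization, sketched with simple types so it stands alone:

```cpp
#include <tvm/ffi/container/array.h>
#include <tvm/ffi/optional.h>
#include <tvm/ffi/string.h>

// Optional container parameters default to an empty instance, mirroring the
// strides/axis_separators handling in BufferDecl and the reads/writes
// handling in BlockFrameNode::ExitWithScope.
tvm::ffi::Array<tvm::ffi::String> Normalize(
    tvm::ffi::Optional<tvm::ffi::Array<tvm::ffi::String>> maybe_names) {
  return maybe_names.value_or(tvm::ffi::Array<tvm::ffi::String>());
}
```
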
<< "ValueError: `buffer_type` must be `auto` or `default` or empty"; Var buffer_data; @@ -50,14 +51,14 @@ Buffer BufferDecl(Array shape, DataType dtype, String buffer_name, Opt DataType shape_dtype = shape.empty() ? DataType::Int(32) : shape[0]->dtype; elem_offset = tvm::tir::Var("elem_offset", shape_dtype); } - return Buffer(buffer_data, dtype, shape, strides.value_or(Array()), + return Buffer(buffer_data, dtype, shape, strides.value_or(ffi::Array()), elem_offset.value_or(PrimExpr()), buffer_name, align, offset_factor, (buffer_type == "auto" ? tvm::tir::kAutoBroadcast : tvm::tir::kDefault), - axis_separators.value_or(Array())); + axis_separators.value_or(ffi::Array())); } PrimFuncFrame PrimFunc(bool is_private) { - ObjectPtr n = make_object(); + ObjectPtr n = ffi::make_object(); n->name = std::nullopt; n->is_private = is_private; n->args.clear(); @@ -69,14 +70,14 @@ PrimFuncFrame PrimFunc(bool is_private) { return PrimFuncFrame(n); } -Var Arg(String name, Var var) { +Var Arg(ffi::String name, Var var) { PrimFuncFrame frame = FindPrimFuncFrame("T.Arg"); details::Namer::Name(var, name); frame->args.push_back(var); return var; } -Buffer Arg(String name, Buffer buffer) { +Buffer Arg(ffi::String name, Buffer buffer) { PrimFuncFrame frame = FindPrimFuncFrame("T.Arg"); details::Namer::Name(buffer, name); Var handle(buffer->name + "_handle", DataType::Handle()); @@ -85,7 +86,7 @@ Buffer Arg(String name, Buffer buffer) { return buffer; } -void FuncName(String name) { +void FuncName(ffi::String name) { PrimFuncFrame frame = FindPrimFuncFrame("T.func_name"); if (frame->name.has_value()) { LOG(FATAL) << "ValueError: Duplicate prim func name, previous one is " << frame->name.value(); @@ -93,7 +94,7 @@ void FuncName(String name) { frame->name = name; } -void FuncAttrs(Map new_attrs) { +void FuncAttrs(ffi::Map new_attrs) { using namespace tvm::tir; PrimFuncFrame frame = FindPrimFuncFrame("T.func_attr"); for (const auto& [key, value] : new_attrs) { @@ -124,15 +125,15 @@ tvm::Type FuncRet(tvm::Type ret_type) { return ret_type; } -Buffer MatchBuffer(ObjectRef param, Array shape, DataType dtype, Optional data, - Array strides, PrimExpr elem_offset, String storage_scope, int align, - int offset_factor, String buffer_type_str, - Optional> axis_separators) { +Buffer MatchBuffer(ObjectRef param, ffi::Array shape, DataType dtype, + ffi::Optional data, ffi::Array strides, PrimExpr elem_offset, + ffi::String storage_scope, int align, int offset_factor, + ffi::String buffer_type_str, ffi::Optional> axis_separators) { Buffer buffer = BufferDecl(shape, dtype, "", data, strides, elem_offset, storage_scope, align, offset_factor, buffer_type_str, axis_separators); if (const auto* var = param.as()) { PrimFuncFrame frame = FindPrimFuncFrame("T.match_buffer"); - Var v = GetRef(var); + Var v = ffi::GetRef(var); for (auto const& arg : frame->args) { if (arg.same_as(v)) { frame->buffer_map.Set(v, buffer); @@ -143,19 +144,19 @@ Buffer MatchBuffer(ObjectRef param, Array shape, DataType dtype, Optio } else if (const auto* buffer_load = param.as()) { BlockFrame frame = FindBlockFrame("T.match_buffer"); frame->match_buffers.push_back(tvm::tir::MatchBufferRegion( - buffer, BufferRegionFromLoad(GetRef(buffer_load)))); + buffer, BufferRegionFromLoad(ffi::GetRef(buffer_load)))); } else if (const auto* buffer_region = param.as()) { BlockFrame frame = FindBlockFrame("T.match_buffer"); frame->match_buffers.push_back( - tvm::tir::MatchBufferRegion(buffer, GetRef(buffer_region))); + tvm::tir::MatchBufferRegion(buffer, 
ffi::GetRef(buffer_region))); } else { LOG(FATAL) << "ValueError: Unexpected type for TIR MatchBuffer."; } return buffer; } -BlockFrame Block(String name, bool no_realize) { - ObjectPtr n = make_object(); +BlockFrame Block(ffi::String name, bool no_realize) { + ObjectPtr n = ffi::make_object(); n->name = name; n->iter_vars.clear(); n->reads = std::nullopt; @@ -170,7 +171,7 @@ BlockFrame Block(String name, bool no_realize) { return BlockFrame(n); } -BlockInitFrame Init() { return BlockInitFrame(make_object()); } +BlockInitFrame Init() { return BlockInitFrame(ffi::make_object()); } void Where(PrimExpr predicate) { BlockFrame frame = FindBlockFrame("T.where"); @@ -181,13 +182,13 @@ void Where(PrimExpr predicate) { frame->predicate = predicate; } -void Reads(Array buffer_slices) { +void Reads(ffi::Array buffer_slices) { using namespace tvm::tir; BlockFrame frame = FindBlockFrame("T.reads"); if (frame->reads.defined()) { LOG(FATAL) << "ValueError: Duplicate read region declaration, previous one is " << frame->reads; } - Array reads; + ffi::Array reads; for (const ObjectRef& obj : buffer_slices) { if (auto buffer_region = obj.as()) { reads.push_back(buffer_region.value()); @@ -200,14 +201,14 @@ void Reads(Array buffer_slices) { frame->reads = reads; } -void Writes(Array buffer_slices) { +void Writes(ffi::Array buffer_slices) { using namespace tvm::tir; BlockFrame frame = FindBlockFrame("T.writes"); if (frame->writes.defined()) { LOG(FATAL) << "ValueError: Duplicate write region declaration, previous one is " << frame->writes; } - Array writes; + ffi::Array writes; for (const ObjectRef& obj : buffer_slices) { if (auto buffer_region = obj.as()) { writes.push_back(buffer_region.value()); @@ -221,9 +222,9 @@ void Writes(Array buffer_slices) { } /*! \brief Recursively merge two annotations, the new attrs will override the old ones */ -Map MergeAnnotations(const Map& new_attrs, - const Map& old_attrs) { - Map result = old_attrs; +ffi::Map MergeAnnotations(const ffi::Map& new_attrs, + const ffi::Map& old_attrs) { + ffi::Map result = old_attrs; for (const auto& [key, value] : new_attrs) { auto old_value = old_attrs.Get(key); // Case 1: the key is not in the old annotations, set the key to the new value @@ -234,8 +235,8 @@ Map MergeAnnotations(const Map& new_attrs, // Case 2: the key is in the old annotations // Case 2.1: both are dicts - auto old_dict = old_value->try_cast>(); - auto new_dict = value.try_cast>(); + auto old_dict = old_value->try_cast>(); + auto new_dict = value.try_cast>(); if (old_dict && new_dict) { // Recursively merge the two dicts auto merged_dict = MergeAnnotations(*old_dict, *new_dict); @@ -251,7 +252,7 @@ Map MergeAnnotations(const Map& new_attrs, return result; } -void BlockAttrs(Map attrs) { +void BlockAttrs(ffi::Map attrs) { BlockFrame frame = FindBlockFrame("T.block_attr"); // Case 1: the block has no annotations, set the new annotations if (!frame->annotations.defined()) { @@ -262,16 +263,16 @@ void BlockAttrs(Map attrs) { } } -Buffer AllocBuffer(Array shape, DataType dtype, Optional data, - Array strides, PrimExpr elem_offset, String storage_scope, int align, - int offset_factor, String buffer_type_str, - Optional> axis_separators) { +Buffer AllocBuffer(ffi::Array shape, DataType dtype, ffi::Optional data, + ffi::Array strides, PrimExpr elem_offset, ffi::String storage_scope, + int align, int offset_factor, ffi::String buffer_type_str, + ffi::Optional> axis_separators) { Buffer buffer = BufferDecl(shape, dtype, "", data, strides, elem_offset, storage_scope, align, 
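
`MergeAnnotations` above relies on `Any::try_cast<T>()`, which yields an engaged `std::optional` only when the erased value really holds a `T`; that is what lets the merge recurse into nested dicts and fall back to plain overwrite otherwise. A reduced, non-recursive sketch of the same probe-then-branch logic:

```cpp
#include <tvm/ffi/any.h>
#include <tvm/ffi/container/map.h>
#include <tvm/ffi/string.h>

using AttrMap = tvm::ffi::Map<tvm::ffi::String, tvm::ffi::Any>;

// Shallow variant of the recursive merge above.
AttrMap MergeOneLevel(const AttrMap& newer, const AttrMap& older) {
  AttrMap result = older;
  for (const auto& [key, value] : newer) {
    auto old_value = older.Get(key);
    if (old_value) {
      auto old_dict = old_value->try_cast<AttrMap>();
      auto new_dict = value.try_cast<AttrMap>();
      if (old_dict && new_dict) {
        continue;  // both sides are dicts; the full version recurses here
      }
    }
    result.Set(key, value);  // otherwise the new value overrides the old one
  }
  return result;
}
```
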
@@ -262,16 +263,16 @@ void BlockAttrs(Map attrs) {
   }
 }
 
-Buffer AllocBuffer(Array shape, DataType dtype, Optional data,
-                   Array strides, PrimExpr elem_offset, String storage_scope, int align,
-                   int offset_factor, String buffer_type_str,
-                   Optional> axis_separators) {
+Buffer AllocBuffer(ffi::Array shape, DataType dtype, ffi::Optional data,
+                   ffi::Array strides, PrimExpr elem_offset, ffi::String storage_scope,
+                   int align, int offset_factor, ffi::String buffer_type_str,
+                   ffi::Optional> axis_separators) {
   Buffer buffer = BufferDecl(shape, dtype, "", data, strides, elem_offset, storage_scope, align,
                              offset_factor, buffer_type_str, axis_separators);
   IRBuilder builder = IRBuilder::Current();
-  if (Optional frame = builder->FindFrame()) {
+  if (ffi::Optional frame = builder->FindFrame()) {
     frame.value()->alloc_buffers.push_back(buffer);
-  } else if (Optional frame = builder->GetLastFrame()) {
+  } else if (ffi::Optional frame = builder->GetLastFrame()) {
     frame.value()->root_alloc_buffers.push_back(buffer);
   } else {
     LOG(FATAL) << "ValueError: Block frame or PrimFunc frame not find. Please ensure "
@@ -282,7 +283,7 @@ Buffer AllocBuffer(Array shape, DataType dtype, Optional data,
 
 namespace axis {
 
 IterVar PushBlockVar(IterVar iter_var, PrimExpr binding) {
-  if (Optional opt_frame = IRBuilder::Current()->GetLastFrame()) {
+  if (ffi::Optional opt_frame = IRBuilder::Current()->GetLastFrame()) {
     BlockFrame frame = opt_frame.value();
     frame->iter_vars.push_back(iter_var);
     frame->iter_values.push_back(binding);
@@ -307,9 +308,9 @@ TVM_TIR_IR_BUILDER_AXIS(Scan, tvm::tir::IterVarType::kOrdered, "Scan");
 TVM_TIR_IR_BUILDER_AXIS(Opaque, tvm::tir::IterVarType::kOpaque, "Opaque");
 #undef TVM_TIR_IR_BUILDER_AXIS
 
-Array Remap(String kinds, Array bindings, DataType dtype) {
+ffi::Array Remap(ffi::String kinds, ffi::Array bindings, DataType dtype) {
   using namespace tvm::tir;
-  Array results;
+  ffi::Array results;
   ICHECK_EQ(kinds.size(), bindings.size());
   int n = bindings.size();
   results.reserve(n);
@@ -334,7 +335,7 @@ Array Remap(String kinds, Array bindings, DataType dtype) {
       }
     }
   }
-  ICHECK(dom.defined()) << "TypeError: Variable is not in the loop: " << GetRef(v);
+  ICHECK(dom.defined()) << "TypeError: Variable is not in the loop: " << ffi::GetRef(v);
   DataType dtype = v->dtype;
   if (c == 'S') {
     results.push_back(PushBlockVar(IterVar(/*dom=*/dom,
@@ -359,21 +360,23 @@ Array Remap(String kinds, Array bindings, DataType dtype) {
 
 }  // namespace axis
 
-#define TVM_TIR_IR_BUILDER_FOR_FRAME(Method, Kind)                                           \
-  ForFrame Method(PrimExpr start, PrimExpr stop, Optional> annotations) {                    \
-    PrimExpr min = start;                                                                    \
-    PrimExpr extent = arith::Analyzer().Simplify(stop - start);                              \
-    ObjectPtr n = make_object();                                                             \
-    int bits = std::max(min.dtype().bits(), extent.dtype().bits());                          \
-    n->vars = {Var("v", DataType(min.dtype().code(), bits, 1))};                             \
-    n->doms = {Range::FromMinExtent(min, extent)};                                           \
-    n->f_make_for_loop = [annotations](Array vars, Array doms, tvm::tir::Stmt body) {        \
-      ICHECK_EQ(vars.size(), 1);                                                             \
-      ICHECK_EQ(doms.size(), 1);                                                             \
-      return tvm::tir::For(vars[0], doms[0]->min, doms[0]->extent, Kind, body, std::nullopt, \
-                           annotations.value_or(Map()));                                     \
-    };                                                                                       \
-    return ForFrame(n);                                                                      \
+#define TVM_TIR_IR_BUILDER_FOR_FRAME(Method, Kind)                                           \
+  ForFrame Method(PrimExpr start, PrimExpr stop,                                             \
+                  ffi::Optional> annotations) {                                              \
+    PrimExpr min = start;                                                                    \
+    PrimExpr extent = arith::Analyzer().Simplify(stop - start);                              \
+    ObjectPtr n = ffi::make_object();                                                        \
+    int bits = std::max(min.dtype().bits(), extent.dtype().bits());                          \
+    n->vars = {Var("v", DataType(min.dtype().code(), bits, 1))};                             \
+    n->doms = {Range::FromMinExtent(min, extent)};                                           \
+    n->f_make_for_loop = [annotations](ffi::Array vars, ffi::Array doms,                     \
+                                       tvm::tir::Stmt body) {                                \
+      ICHECK_EQ(vars.size(), 1);                                                             \
+      ICHECK_EQ(doms.size(), 1);                                                             \
+      return tvm::tir::For(vars[0], doms[0]->min, doms[0]->extent, Kind, body, std::nullopt, \
+                           annotations.value_or(ffi::Map()));                                \
+    };                                                                                       \
+    return ForFrame(n);                                                                      \
   }
 
 TVM_TIR_IR_BUILDER_FOR_FRAME(Serial, tvm::tir::ForKind::kSerial);
@@ -383,30 +386,30 @@ TVM_TIR_IR_BUILDER_FOR_FRAME(Unroll, tvm::tir::ForKind::kUnrolled);
 
 #undef TVM_TIR_IR_BUILDER_FOR_FRAME
 
-ForFrame ThreadBinding(PrimExpr start, PrimExpr stop, String thread,
-                       Optional> annotations) {
+ForFrame ThreadBinding(PrimExpr start, PrimExpr stop, ffi::String thread,
+                       ffi::Optional> annotations) {
   using namespace tvm::tir;
   PrimExpr min = start;
   PrimExpr extent = arith::Analyzer().Simplify(stop - start);
-  ObjectPtr n = make_object();
+  ObjectPtr n = ffi::make_object();
   int bits = std::max(min.dtype().bits(), extent.dtype().bits());
   DataType dtype = DataType(min.dtype().code(), bits, 1);
   n->vars = {Var("v", dtype)};
   n->doms = {Range::FromMinExtent(min, extent)};
-  n->f_make_for_loop = [annotations, thread, dtype](Array vars, Array doms,
+  n->f_make_for_loop = [annotations, thread, dtype](ffi::Array vars, ffi::Array doms,
                                                     Stmt body) -> For {
     ICHECK_EQ(vars.size(), 1);
     ICHECK_EQ(doms.size(), 1);
     IterVar iter_var(Range(nullptr), Var("iter", dtype), IterVarType::kThreadIndex, thread);
     return For(vars[0], doms[0]->min, doms[0]->extent, ForKind::kThreadBinding, body, iter_var,
-               annotations.value_or(Map()));
+               annotations.value_or(ffi::Map()));
   };
   return ForFrame(n);
 }
 
-ForFrame Grid(Array extents) {
+ForFrame Grid(ffi::Array extents) {
   using namespace tvm::tir;
-  ObjectPtr n = make_object();
+  ObjectPtr n = ffi::make_object();
   n->vars.reserve(extents.size());
   n->doms.reserve(extents.size());
   for (const auto& extent : extents) {
@@ -414,7 +417,7 @@ ForFrame Grid(Array extents) {
     n->vars.push_back(Var("v", extent.dtype()));
     n->doms.push_back(Range(make_const(dtype, 0), extent));
   }
-  n->f_make_for_loop = [](Array vars, Array doms, Stmt body) -> Stmt {
+  n->f_make_for_loop = [](ffi::Array vars, ffi::Array doms, Stmt body) -> Stmt {
    ICHECK_EQ(vars.size(), doms.size());
     int n = vars.size();
     for (int i = n - 1; i >= 0; --i) {
@@ -428,15 +431,15 @@ ForFrame Grid(Array extents) {
   return ForFrame(n);
 }
 
-AssertFrame Assert(PrimExpr condition, String message) {
-  ObjectPtr n = make_object();
+AssertFrame Assert(PrimExpr condition, ffi::String message) {
+  ObjectPtr n = ffi::make_object();
   n->condition = condition;
   n->message = tvm::tir::StringImm(message);
   return AssertFrame(n);
 }
 
-LetFrame LetStmt(PrimExpr value, Optional type_annotation, Optional var) {
-  ObjectPtr n = make_object();
+LetFrame LetStmt(PrimExpr value, ffi::Optional type_annotation, ffi::Optional var) {
+  ObjectPtr n = ffi::make_object();
   if (var.defined()) {
     n->var = var.value();
   } else if (type_annotation.defined()) {
@@ -449,7 +452,7 @@ LetFrame LetStmt(PrimExpr value, Optional type_annotation, Optional v
 }
 
 LetFrame LegacyLetStmt(Var var, PrimExpr value) {
-  ObjectPtr n = make_object();
+  ObjectPtr n = ffi::make_object();
   n->var = var;
   n->value = value;
   return LetFrame(n);
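
Each `ForFrame` above carries an `f_make_for_loop` callback that receives the loop variables and domains as `ffi::Array`s and returns the finished statement; `Grid` builds its nest by folding from the innermost loop outward. The folding pattern, reduced to strings so the sketch stands alone:

```cpp
#include <string>
#include <tvm/ffi/container/array.h>
#include <tvm/ffi/string.h>

// Stand-in for f_make_for_loop: wrap `body` once per loop variable, innermost
// first, just as Grid's lambda composes tvm::tir::For nodes in reverse order.
std::string MakeLoopNest(tvm::ffi::Array<tvm::ffi::String> vars, std::string body) {
  for (int i = static_cast<int>(vars.size()) - 1; i >= 0; --i) {
    body = "for " + std::string(vars[i]) + " { " + body + " }";
  }
  return body;
}
```
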
@@ -458,8 +461,8 @@ LetFrame LegacyLetStmt(Var var, PrimExpr value) {
 
 LaunchThreadFrame LaunchThread(Var var, PrimExpr extent) {
   IterVar iter_var{nullptr};
-  if (Optional opt_frame = IRBuilder::Current()->FindFrame()) {
-    if (Optional opt_iter_var = opt_frame.value()->env_threads.Get(var)) {
+  if (ffi::Optional opt_frame = IRBuilder::Current()->FindFrame()) {
+    if (ffi::Optional opt_iter_var = opt_frame.value()->env_threads.Get(var)) {
       iter_var = opt_iter_var.value();
     } else {
       LOG(FATAL) << "ValueError: " << var->name_hint
@@ -468,7 +471,7 @@ LaunchThreadFrame LaunchThread(Var var, PrimExpr extent) {
   } else {
     LOG(FATAL) << "LaunchThread can only be used inside a PrimFunc";
   }
-  ObjectPtr n = make_object();
+  ObjectPtr n = ffi::make_object();
   if (!iter_var->dom.defined()) {
     const_cast(iter_var.get())->dom = Range(tvm::tir::make_zero(extent.dtype()), extent);
@@ -482,48 +485,50 @@ LaunchThreadFrame LaunchThread(Var var, PrimExpr extent) {
   return LaunchThreadFrame(n);
 }
 
-LaunchThreadFrame LaunchThread(String thread_tag, PrimExpr extent) {
+LaunchThreadFrame LaunchThread(ffi::String thread_tag, PrimExpr extent) {
   return LaunchThread(EnvThread(thread_tag, extent.dtype()), extent);
 }
 
-RealizeFrame Realize(tvm::tir::BufferRegion buffer_slice, String storage_scope,
+RealizeFrame Realize(tvm::tir::BufferRegion buffer_slice, ffi::String storage_scope,
                      PrimExpr condition) {
-  ObjectPtr n = make_object();
+  ObjectPtr n = ffi::make_object();
   n->buffer_slice = buffer_slice;
   n->storage_scope = storage_scope;
   n->condition = condition;
   return RealizeFrame(n);
 }
 
-AllocateFrame Allocate(Array extents, DataType dtype, String storage_scope,
-                       Optional condition, Optional> annotations) {
-  ObjectPtr n = make_object();
+AllocateFrame Allocate(ffi::Array extents, DataType dtype, ffi::String storage_scope,
+                       ffi::Optional condition,
+                       ffi::Optional> annotations) {
+  ObjectPtr n = ffi::make_object();
   n->extents = extents;
   n->dtype = dtype;
   n->storage_scope = storage_scope;
   n->condition = condition.value_or(tvm::Bool(true));
-  n->annotations = annotations.value_or(Map());
+  n->annotations = annotations.value_or(ffi::Map());
   n->buffer_var = Var("", tvm::PointerType(tvm::PrimType(dtype), storage_scope));
   return AllocateFrame(n);
 }
 
-AllocateConstFrame AllocateConst(tvm::runtime::Tensor data, DataType dtype, Array extents,
-                                 Optional> annotations) {
-  ObjectPtr n = make_object();
+AllocateConstFrame AllocateConst(tvm::runtime::Tensor data, DataType dtype,
+                                 ffi::Array extents,
+                                 ffi::Optional> annotations) {
+  ObjectPtr n = ffi::make_object();
   n->dtype = dtype;
   n->extents = extents;
   n->data = data;
-  n->annotations = annotations.value_or(Map());
+  n->annotations = annotations.value_or(ffi::Map());
   n->buffer_var = Var("", tvm::PointerType(tvm::PrimType(dtype)));
   return AllocateConstFrame(n);
 }
 
-AttrFrame Attr(ffi::Any node, String attr_key, PrimExpr value) {
+AttrFrame Attr(ffi::Any node, ffi::String attr_key, PrimExpr value) {
   // convert POD value to PrimExpr
   if (node.type_index() < ffi::TypeIndex::kTVMFFISmallStr) {
     node = node.cast();
   }
-  ObjectPtr n = make_object();
+  ObjectPtr n = ffi::make_object();
   n->node = std::move(node);
   n->attr_key = attr_key;
   n->value = value;
@@ -531,13 +536,13 @@ AttrFrame Attr(ffi::Any node, String attr_key, PrimExpr value) {
 }
 
 WhileFrame While(PrimExpr condition) {
-  ObjectPtr n = make_object();
+  ObjectPtr n = ffi::make_object();
   n->condition = condition;
   return WhileFrame(n);
 }
 
 IfFrame If(PrimExpr condition) {
-  ObjectPtr n = make_object();
+  ObjectPtr n = ffi::make_object();
   n->condition = condition;
   n->then_stmts = std::nullopt;
   n->else_stmts = std::nullopt;
@@ -545,19 +550,19 @@ IfFrame If(PrimExpr condition) {
 }
 
 ThenFrame Then() {
-  ObjectPtr n = make_object();
+  ObjectPtr n = ffi::make_object();
   return ThenFrame(n);
 }
 
 ElseFrame Else() {
-  ObjectPtr n = make_object();
+  ObjectPtr n = ffi::make_object();
   return ElseFrame(n);
 }
 
-Var EnvThread(String thread_tag, DataType dtype) {
+Var EnvThread(ffi::String thread_tag, DataType dtype) {
   IterVar iter_var(Range{nullptr}, Var("", dtype), tvm::tir::IterVarType::kThreadIndex,
                    thread_tag);
   Var var = iter_var->var;
-  if (Optional opt_frame = IRBuilder::Current()->FindFrame()) {
+  if (ffi::Optional opt_frame = IRBuilder::Current()->FindFrame()) {
     opt_frame.value()->env_threads.Set(var, iter_var);
   } else {
     LOG(FATAL) << "EnvThread can only be used inside a PrimFunc";
@@ -565,8 +570,8 @@ Var EnvThread(String thread_tag, DataType dtype) {
   return var;
 }
 
-void BufferStore(Buffer buffer, PrimExpr value, Array indices,
-                 Optional predicate = std::nullopt) {
+void BufferStore(Buffer buffer, PrimExpr value, ffi::Array indices,
+                 ffi::Optional predicate = std::nullopt) {
   runtime::DataType buffer_dtype = buffer->dtype;
   bool is_index_scalable = indices.empty() ? false : indices.back().dtype().is_scalable_vector();
   bool is_buffer_dtype_scalable = buffer_dtype.is_scalable_vector();
@@ -631,12 +636,12 @@ void BufferStore(Buffer buffer, PrimExpr value, Array indices,
   AddToParent(tvm::tir::BufferStore(buffer, value, indices, predicate));
 }
 
-DeclBufferFrame DeclBuffer(Array shape, DataType dtype, String buffer_name,
-                           Optional data, Optional> strides,
-                           Optional elem_offset, String storage_scope, int align,
-                           int offset_factor, String buffer_type,
-                           Optional> axis_separators) {
-  ObjectPtr n = make_object();
+DeclBufferFrame DeclBuffer(ffi::Array shape, DataType dtype, ffi::String buffer_name,
+                           ffi::Optional data, ffi::Optional> strides,
+                           ffi::Optional elem_offset, ffi::String storage_scope,
+                           int align, int offset_factor, ffi::String buffer_type,
+                           ffi::Optional> axis_separators) {
+  ObjectPtr n = ffi::make_object();
   n->buffer = BufferDecl(shape, dtype, buffer_name, data, strides, elem_offset, storage_scope,
                          align, offset_factor, buffer_type, axis_separators);
   n->allocated = data.defined();
@@ -645,7 +650,8 @@ DeclBufferFrame DeclBuffer(Array shape, DataType dtype, String buffer_
 
 void Evaluate(PrimExpr value) { AddToParent(tvm::tir::Evaluate(value)); }
 
-PrimExpr Ptr(runtime::DataType dtype, String storage_scope = "global", bool is_size_var = false) {
+PrimExpr Ptr(runtime::DataType dtype, ffi::String storage_scope = "global",
+             bool is_size_var = false) {
   PointerType type_annotation(PrimType(dtype), storage_scope);
   return is_size_var ? tvm::tir::SizeVar("", type_annotation) : tvm::tir::Var("", type_annotation);
 }
@@ -653,7 +659,7 @@ PrimExpr Ptr(runtime::DataType dtype, String storage_scope = "global", bool is_s
 
 using tvm::script::ir_builder::details::Namer;
 
 TVM_STATIC_IR_FUNCTOR(Namer, vtable)
-    .set_dispatch([](const ObjectRef& node, String name) -> void {
+    .set_dispatch([](const ObjectRef& node, ffi::String name) -> void {
      tvm::tir::BufferNode* buffer = const_cast(node.as());
      buffer->name = name;
@@ -668,21 +674,21 @@ TVM_STATIC_IR_FUNCTOR(Namer, vtable)
     });
 
 TVM_STATIC_IR_FUNCTOR(Namer, vtable)
-    .set_dispatch([](const ObjectRef& node, String name) -> void {
+    .set_dispatch([](const ObjectRef& node, ffi::String name) -> void {
       using namespace tvm::tir;
       SizeVarNode* var = const_cast(node.as());
       var->name_hint = name;
     });
 
 TVM_STATIC_IR_FUNCTOR(Namer, vtable)
-    .set_dispatch([](const ObjectRef& node, String name) -> void {
+    .set_dispatch([](const ObjectRef& node, ffi::String name) -> void {
       using namespace tvm::tir;
       VarNode* var = const_cast(node.as());
       var->name_hint = name;
     });
 
 TVM_STATIC_IR_FUNCTOR(Namer, vtable)
-    .set_dispatch([](const ObjectRef& node, String name) -> void {
+    .set_dispatch([](const ObjectRef& node, ffi::String name) -> void {
       using namespace tvm::tir;
       IterVarNode* var = const_cast(node.as());
       Namer::Name(var->var, name);
@@ -694,7 +700,7 @@ TVM_FFI_STATIC_INIT_BLOCK({
       .def("script.ir_builder.tir.Buffer", BufferDecl)
       .def("script.ir_builder.tir.PrimFunc", PrimFunc)
       .def("script.ir_builder.tir.Arg",
-           [](String name, ObjectRef obj) -> ObjectRef {
+           [](ffi::String name, ObjectRef obj) -> ObjectRef {
             using namespace tvm::tir;
             if (auto var = obj.as()) {
               return Arg(name, var.value());
@@ -740,10 +746,10 @@ TVM_FFI_STATIC_INIT_BLOCK({
       .def("script.ir_builder.tir.Else", Else)
       .def("script.ir_builder.tir.DeclBuffer", DeclBuffer)
       .def("script.ir_builder.tir.LaunchThread",
-           [](ffi::Variant thread_tag_or_var, PrimExpr extent) {
+           [](ffi::Variant thread_tag_or_var, PrimExpr extent) {
             if (auto var = thread_tag_or_var.as()) {
               return LaunchThread(var.value(), extent);
-            } else if (auto str = thread_tag_or_var.as()) {
+            } else if (auto str = thread_tag_or_var.as()) {
               return LaunchThread(str.value(), extent);
             } else {
               LOG(FATAL) << "ValueError: Unexpected type for TIR LaunchThread: "
diff --git a/src/script/ir_builder/tir/utils.h b/src/script/ir_builder/tir/utils.h
index 9703a2adc323..d7c272ae5138 100644
--- a/src/script/ir_builder/tir/utils.h
+++ b/src/script/ir_builder/tir/utils.h
@@ -39,7 +39,7 @@ inline void AddToParent(tvm::tir::Stmt stmt) {
     ICHECK(!builder->result.defined()) << "ValueError: Builder.result has already been set";
     builder->result = stmt;
   } else if (const auto* tir_frame = builder->frames.back().as()) {
-    GetRef(tir_frame)->stmts.push_back(stmt);
+    ffi::GetRef(tir_frame)->stmts.push_back(stmt);
   } else {
     LOG(FATAL) << "TypeError: Unsupported frame type: " << builder->frames.back();
   }
@@ -50,7 +50,7 @@ inline void AddToParent(tvm::tir::Stmt stmt) {
 * \param stmt The array of Stmt.
 * \return The SeqStmt.
 */
-inline tvm::tir::Stmt AsStmt(const Array& stmt) {
+inline tvm::tir::Stmt AsStmt(const ffi::Array& stmt) {
   return tvm::tir::SeqStmt::Flatten(stmt);
 }
 
@@ -59,10 +59,11 @@ inline tvm::tir::Stmt AsStmt(const Array& stmt) {
 * \param method The method name to be printed when throwing exception.
 * \return The top frame of PrimFuncFrame.
 */
-inline PrimFuncFrame FindPrimFuncFrame(const String& method) {
-  if (Optional frame = IRBuilder::Current()->GetLastFrame()) {
+inline PrimFuncFrame FindPrimFuncFrame(const ffi::String& method) {
+  if (ffi::Optional frame = IRBuilder::Current()->GetLastFrame()) {
     return frame.value();
-  } else if (Optional frame = IRBuilder::Current()->FindFrame()) {
+  } else if (ffi::Optional frame =
+                 IRBuilder::Current()->FindFrame()) {
     LOG(FATAL) << "ValueError: " << method << " must be called at the top of a PrimFunc. "
               << "While " << method << " did occur within the PrimFunc \"" << frame.value()->name
               << "\", other frames (e.g. block/if/else/let) had been introduced since the "
@@ -79,10 +80,10 @@ inline PrimFuncFrame FindPrimFuncFrame(const String& method) {
 * \param method The method name to be printed when throwing exception.
 * \return The top frame of BlockFrame.
 */
-inline BlockFrame FindBlockFrame(const String& method) {
-  if (Optional frame = IRBuilder::Current()->FindFrame()) {
+inline BlockFrame FindBlockFrame(const ffi::String& method) {
+  if (ffi::Optional frame = IRBuilder::Current()->FindFrame()) {
     return frame.value();
-  } else if (Optional frame = IRBuilder::Current()->FindFrame()) {
+  } else if (ffi::Optional frame = IRBuilder::Current()->FindFrame()) {
     LOG(FATAL) << "ValueError: " << method << " must be called at the top of a T.block(). "
               << "While " << method << " did occur within the block \"" << frame.value()->name
               << "\", other frames (e.g. if/else/let) had been introduced since the T.block(\""
@@ -99,10 +100,10 @@ inline BlockFrame FindBlockFrame(const String& method) {
 * \param method The method name to be printed when throwing exception.
 * \return The top frame of IfFrame.
 */
-inline IfFrame FindIfFrame(const String& method) {
-  if (Optional frame = IRBuilder::Current()->GetLastFrame()) {
+inline IfFrame FindIfFrame(const ffi::String& method) {
+  if (ffi::Optional frame = IRBuilder::Current()->GetLastFrame()) {
     return frame.value();
-  } else if (Optional frame = IRBuilder::Current()->FindFrame()) {
+  } else if (ffi::Optional frame = IRBuilder::Current()->FindFrame()) {
     LOG(FATAL) << "ValueError: " << method << " must be called at the top of a T.if_(). "
               << "While " << method << " did occur within the conditional based on ("
               << frame.value()->condition
@@ -121,7 +122,7 @@ inline IfFrame FindIfFrame(const String& method) {
 * \return The converted BufferRegion.
 */
 inline tvm::tir::BufferRegion BufferRegionFromLoad(tvm::tir::BufferLoad buffer_load) {
-  Array ranges;
+  ffi::Array ranges;
   for (const PrimExpr& index : buffer_load->indices) {
     ranges.push_back(Range::FromMinExtent(index, IntImm(index->dtype, 1)));
   }
diff --git a/src/script/printer/doc.cc b/src/script/printer/doc.cc
index aa7e0473488b..6f0d548bafca 100644
--- a/src/script/printer/doc.cc
+++ b/src/script/printer/doc.cc
@@ -56,31 +56,34 @@ TVM_FFI_STATIC_INIT_BLOCK({
   DocStringDocNode::RegisterReflection();
 });
 
-ExprDoc ExprDocNode::Attr(String attr) const { return AttrAccessDoc(GetRef(this), attr); }
+ExprDoc ExprDocNode::Attr(ffi::String attr) const {
+  return AttrAccessDoc(ffi::GetRef(this), attr);
+}
 
-ExprDoc ExprDocNode::operator[](Array indices) const {
-  return IndexDoc(GetRef(this), indices);
+ExprDoc ExprDocNode::operator[](ffi::Array indices) const {
+  return IndexDoc(ffi::GetRef(this), indices);
 }
 
-ExprDoc ExprDocNode::Call(Array args) const {
-  return CallDoc(GetRef(this), args, Array(), Array());
+ExprDoc ExprDocNode::Call(ffi::Array args) const {
+  return CallDoc(ffi::GetRef(this), args, ffi::Array(),
+                 ffi::Array());
 }
 
-ExprDoc ExprDocNode::Call(Array args, Array kwargs_keys,
-                          Array kwargs_values) const {
-  return CallDoc(GetRef(this), args, kwargs_keys, kwargs_values);
+ExprDoc ExprDocNode::Call(ffi::Array args, ffi::Array kwargs_keys,
+                          ffi::Array kwargs_values) const {
+  return CallDoc(ffi::GetRef(this), args, kwargs_keys, kwargs_values);
 }
 
-ExprDoc ExprDoc::operator[](Array indices) const { return (*get())[indices]; }
+ExprDoc ExprDoc::operator[](ffi::Array indices) const { return (*get())[indices]; }
 
-StmtBlockDoc::StmtBlockDoc(Array stmts) {
-  ObjectPtr n = make_object();
+StmtBlockDoc::StmtBlockDoc(ffi::Array stmts) {
+  ObjectPtr n = ffi::make_object();
   n->stmts = stmts;
   this->data_ = std::move(n);
 }
 
-LiteralDoc::LiteralDoc(ffi::Any value, const Optional& object_path) {
-  ObjectPtr n = make_object();
+LiteralDoc::LiteralDoc(ffi::Any value, const ffi::Optional& object_path) {
+  ObjectPtr n = ffi::make_object();
   n->value = value;
   if (object_path.defined()) {
     n->source_paths.push_back(object_path.value());
@@ -88,29 +91,29 @@ LiteralDoc::LiteralDoc(ffi::Any value, const Optional& object_path)
   this->data_ = std::move(n);
 }
 
-IdDoc::IdDoc(String name) {
-  ObjectPtr n = make_object();
+IdDoc::IdDoc(ffi::String name) {
+  ObjectPtr n = ffi::make_object();
   n->name = name;
   this->data_ = std::move(n);
 }
 
-AttrAccessDoc::AttrAccessDoc(ExprDoc value, String name) {
-  ObjectPtr n = make_object();
+AttrAccessDoc::AttrAccessDoc(ExprDoc value, ffi::String name) {
+  ObjectPtr n = ffi::make_object();
   n->value = value;
   n->name = name;
   this->data_ = std::move(n);
 }
 
-IndexDoc::IndexDoc(ExprDoc value, Array indices) {
-  ObjectPtr n = make_object();
+IndexDoc::IndexDoc(ExprDoc value, ffi::Array indices) {
+  ObjectPtr n = ffi::make_object();
   n->value = value;
   n->indices = indices;
   this->data_ = std::move(n);
 }
 
-CallDoc::CallDoc(ExprDoc callee, Array args, Array kwargs_keys,
-                 Array kwargs_values) {
-  ObjectPtr n = make_object();
+CallDoc::CallDoc(ExprDoc callee, ffi::Array args, ffi::Array kwargs_keys,
+                 ffi::Array kwargs_values) {
+  ObjectPtr n = ffi::make_object();
   n->callee = callee;
   n->args = args;
   n->kwargs_keys = kwargs_keys;
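
The constructors in this file all follow one idiom: allocate the node with `ffi::make_object`, fill its fields, then move the pointer into `data_`. From the outside they compose into doc trees. A sketch that builds roughly `x.reshape(n, m)` using constructors shown in this hunk (assumes the printer doc headers are available; the identifiers are illustrative):

```cpp
#include <tvm/script/printer/doc.h>

using namespace tvm::script::printer;

ExprDoc MakeReshapeCall() {
  // x.reshape -> AttrAccessDoc(IdDoc, name); CallDoc wraps it with arguments.
  ExprDoc callee = AttrAccessDoc(IdDoc("x"), "reshape");
  return CallDoc(callee, {IdDoc("n"), IdDoc("m")},
                 /*kwargs_keys=*/{}, /*kwargs_values=*/{});
}
```
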
@@ -118,96 +121,97 @@ CallDoc::CallDoc(ExprDoc callee, Array args, Array kwargs_keys,
   this->data_ = std::move(n);
 }
 
-OperationDoc::OperationDoc(OperationDocNode::Kind kind, Array operands) {
-  ObjectPtr n = make_object();
+OperationDoc::OperationDoc(OperationDocNode::Kind kind, ffi::Array operands) {
+  ObjectPtr n = ffi::make_object();
   n->kind = kind;
   n->operands = operands;
   this->data_ = std::move(n);
 }
 
-LambdaDoc::LambdaDoc(Array args, ExprDoc body) {
-  ObjectPtr n = make_object();
+LambdaDoc::LambdaDoc(ffi::Array args, ExprDoc body) {
+  ObjectPtr n = ffi::make_object();
   n->args = args;
   n->body = body;
   this->data_ = std::move(n);
 }
 
-TupleDoc::TupleDoc(Array elements) {
-  ObjectPtr n = make_object();
+TupleDoc::TupleDoc(ffi::Array elements) {
+  ObjectPtr n = ffi::make_object();
   n->elements = elements;
   this->data_ = std::move(n);
 }
 
-ListDoc::ListDoc(Array elements) {
-  ObjectPtr n = make_object();
+ListDoc::ListDoc(ffi::Array elements) {
+  ObjectPtr n = ffi::make_object();
   n->elements = elements;
   this->data_ = std::move(n);
 }
 
-DictDoc::DictDoc(Array keys, Array values) {
-  ObjectPtr n = make_object();
+DictDoc::DictDoc(ffi::Array keys, ffi::Array values) {
+  ObjectPtr n = ffi::make_object();
   n->keys = keys;
   n->values = values;
   this->data_ = std::move(n);
 }
 
-SliceDoc::SliceDoc(Optional start, Optional stop, Optional step) {
-  ObjectPtr n = make_object();
+SliceDoc::SliceDoc(ffi::Optional start, ffi::Optional stop,
+                   ffi::Optional step) {
+  ObjectPtr n = ffi::make_object();
   n->start = start;
   n->stop = stop;
   n->step = step;
   this->data_ = std::move(n);
 }
 
-AssignDoc::AssignDoc(ExprDoc lhs, Optional rhs, Optional annotation) {
+AssignDoc::AssignDoc(ExprDoc lhs, ffi::Optional rhs, ffi::Optional annotation) {
   CHECK(rhs.defined() || annotation.defined())
       << "ValueError: At least one of rhs and annotation needs to be non-null for AssignDoc.";
   CHECK(lhs->IsInstance() || annotation == nullptr)
       << "ValueError: annotation can only be nonnull if lhs is an identifier.";
-  ObjectPtr n = make_object();
+  ObjectPtr n = ffi::make_object();
   n->lhs = lhs;
   n->rhs = rhs;
   n->annotation = annotation;
   this->data_ = std::move(n);
 }
 
-IfDoc::IfDoc(ExprDoc predicate, Array then_branch, Array else_branch) {
+IfDoc::IfDoc(ExprDoc predicate, ffi::Array then_branch, ffi::Array else_branch) {
   CHECK(!then_branch.empty() || !else_branch.empty())
       << "ValueError: At least one of the then branch or else branch needs to be non-empty.";
-  ObjectPtr n = make_object();
+  ObjectPtr n = ffi::make_object();
   n->predicate = predicate;
   n->then_branch = then_branch;
   n->else_branch = else_branch;
   this->data_ = std::move(n);
 }
 
-WhileDoc::WhileDoc(ExprDoc predicate, Array body) {
-  ObjectPtr n = make_object();
+WhileDoc::WhileDoc(ExprDoc predicate, ffi::Array body) {
+  ObjectPtr n = ffi::make_object();
   n->predicate = predicate;
   n->body = body;
   this->data_ = std::move(n);
 }
 
-ForDoc::ForDoc(ExprDoc lhs, ExprDoc rhs, Array body) {
-  ObjectPtr n = make_object();
+ForDoc::ForDoc(ExprDoc lhs, ExprDoc rhs, ffi::Array body) {
+  ObjectPtr n = ffi::make_object();
   n->lhs = lhs;
   n->rhs = rhs;
   n->body = body;
   this->data_ = std::move(n);
 }
 
-ScopeDoc::ScopeDoc(Optional lhs, ExprDoc rhs, Array body) {
-  ObjectPtr n = make_object();
+ScopeDoc::ScopeDoc(ffi::Optional lhs, ExprDoc rhs, ffi::Array body) {
+  ObjectPtr n = ffi::make_object();
   n->lhs = lhs;
   n->rhs = rhs;
   n->body = body;
   this->data_ = std::move(n);
 }
 
-ScopeDoc::ScopeDoc(ExprDoc rhs, Array body) {
-  ObjectPtr n = make_object();
+ScopeDoc::ScopeDoc(ExprDoc rhs, ffi::Array body) {
+  ObjectPtr n = ffi::make_object();
   n->lhs = std::nullopt;
   n->rhs = rhs;
   n->body = body;
@@ -215,27 +219,27 @@ ScopeDoc::ScopeDoc(ExprDoc rhs, Array body) {
 }
 
 ExprStmtDoc::ExprStmtDoc(ExprDoc expr) {
-  ObjectPtr n = make_object();
+  ObjectPtr n = ffi::make_object();
   n->expr = expr;
   this->data_ = std::move(n);
 }
 
-AssertDoc::AssertDoc(ExprDoc test, Optional msg) {
-  ObjectPtr n = make_object();
+AssertDoc::AssertDoc(ExprDoc test, ffi::Optional msg) {
+  ObjectPtr n = ffi::make_object();
   n->test = test;
   n->msg = msg;
   this->data_ = std::move(n);
 }
 
 ReturnDoc::ReturnDoc(ExprDoc value) {
-  ObjectPtr n = make_object();
+  ObjectPtr n = ffi::make_object();
   n->value = value;
   this->data_ = std::move(n);
 }
 
-FunctionDoc::FunctionDoc(IdDoc name, Array args, Array decorators,
-                         Optional return_type, Array body) {
-  ObjectPtr n = make_object();
+FunctionDoc::FunctionDoc(IdDoc name, ffi::Array args, ffi::Array decorators,
+                         ffi::Optional return_type, ffi::Array body) {
+  ObjectPtr n = ffi::make_object();
   n->name = name;
   n->args = args;
   n->decorators = decorators;
@@ -244,22 +248,22 @@ FunctionDoc::FunctionDoc(IdDoc name, Array args, Array decor
   this->data_ = std::move(n);
 }
 
-ClassDoc::ClassDoc(IdDoc name, Array decorators, Array body) {
-  ObjectPtr n = make_object();
+ClassDoc::ClassDoc(IdDoc name, ffi::Array decorators, ffi::Array body) {
+  ObjectPtr n = ffi::make_object();
   n->name = name;
   n->decorators = decorators;
   n->body = body;
   this->data_ = std::move(n);
 }
 
-CommentDoc::CommentDoc(String comment) {
-  ObjectPtr n = make_object();
+CommentDoc::CommentDoc(ffi::String comment) {
+  ObjectPtr n = ffi::make_object();
   n->comment = comment;
   this->data_ = std::move(n);
 }
 
-DocStringDoc::DocStringDoc(String docs) {
-  ObjectPtr n = make_object();
+DocStringDoc::DocStringDoc(ffi::String docs) {
+  ObjectPtr n = ffi::make_object();
   n->comment = docs;
   this->data_ = std::move(n);
 }
@@ -268,7 +272,7 @@ TVM_FFI_STATIC_INIT_BLOCK({
   namespace refl = tvm::ffi::reflection;
   refl::GlobalDef().def(
      "script.printer.DocSetSourcePaths",
-      [](Doc doc, Array source_paths) { doc->source_paths = source_paths; });
+      [](Doc doc, ffi::Array source_paths) { doc->source_paths = source_paths; });
 });
 
 TVM_FFI_STATIC_INIT_BLOCK({
@@ -276,22 +280,24 @@ TVM_FFI_STATIC_INIT_BLOCK({
   refl::GlobalDef()
       .def_method("script.printer.ExprDocAttr", &ExprDocNode::Attr)
      .def_method("script.printer.ExprDocIndex", &ExprDocNode::operator[])
-      .def_method(
-          "script.printer.ExprDocCall",
-          [](ExprDoc doc, Array args, Array kwargs_keys,
-             Array kwargs_values) { return doc->Call(args, kwargs_keys, kwargs_values); });
+      .def_method("script.printer.ExprDocCall",
+                  [](ExprDoc doc, ffi::Array args, ffi::Array kwargs_keys,
+                     ffi::Array kwargs_values) {
+                    return doc->Call(args, kwargs_keys, kwargs_values);
+                  });
 });
 
 TVM_FFI_STATIC_INIT_BLOCK({
   namespace refl = tvm::ffi::reflection;
-  refl::GlobalDef().def("script.printer.StmtDocSetComment",
-                        [](StmtDoc doc, Optional comment) { doc->comment = comment; });
+  refl::GlobalDef().def(
+      "script.printer.StmtDocSetComment",
+      [](StmtDoc doc, ffi::Optional comment) { doc->comment = comment; });
 });
 
 TVM_FFI_STATIC_INIT_BLOCK({
   namespace refl = tvm::ffi::reflection;
   refl::GlobalDef().def("script.printer.StmtBlockDoc",
-                        [](Array stmts) { return StmtBlockDoc(stmts); });
+                        [](ffi::Array stmts) { return StmtBlockDoc(stmts); });
 });
 
@@ -306,104 +312,107 @@ TVM_FFI_STATIC_INIT_BLOCK({
 
 TVM_FFI_STATIC_INIT_BLOCK({
   namespace refl = tvm::ffi::reflection;
-  refl::GlobalDef().def("script.printer.IdDoc", [](String name) { return IdDoc(name); });
+  refl::GlobalDef().def("script.printer.IdDoc", [](ffi::String name) { return IdDoc(name); });
 });
 
 TVM_FFI_STATIC_INIT_BLOCK({
   namespace refl = tvm::ffi::reflection;
   refl::GlobalDef().def("script.printer.AttrAccessDoc",
-                        [](ExprDoc value, String attr) { return AttrAccessDoc(value, attr); });
+                        [](ExprDoc value, ffi::String attr) { return AttrAccessDoc(value, attr); });
 });
 
 TVM_FFI_STATIC_INIT_BLOCK({
   namespace refl = tvm::ffi::reflection;
-  refl::GlobalDef().def("script.printer.IndexDoc",
-                        [](ExprDoc value, Array indices) { return IndexDoc(value, indices); });
+  refl::GlobalDef().def("script.printer.IndexDoc", [](ExprDoc value, ffi::Array indices) {
+    return IndexDoc(value, indices);
+  });
 });
 
 TVM_FFI_STATIC_INIT_BLOCK({
   namespace refl = tvm::ffi::reflection;
-  refl::GlobalDef().def("script.printer.CallDoc", [](ExprDoc callee,        //
-                                                     Array args,            //
-                                                     Array kwargs_keys,     //
-                                                     Array kwargs_values) {
+  refl::GlobalDef().def("script.printer.CallDoc", [](ExprDoc callee,              //
+                                                     ffi::Array args,             //
+                                                     ffi::Array kwargs_keys,      //
+                                                     ffi::Array kwargs_values) {
     return CallDoc(callee, args, kwargs_keys, kwargs_values);
   });
 });
 
 TVM_FFI_STATIC_INIT_BLOCK({
   namespace refl = tvm::ffi::reflection;
-  refl::GlobalDef().def("script.printer.OperationDoc", [](int32_t kind, Array operands) {
-    return OperationDoc(OperationDocNode::Kind(kind), operands);
-  });
+  refl::GlobalDef().def("script.printer.OperationDoc",
+                        [](int32_t kind, ffi::Array operands) {
+                          return OperationDoc(OperationDocNode::Kind(kind), operands);
+                        });
 });
 
 TVM_FFI_STATIC_INIT_BLOCK({
   namespace refl = tvm::ffi::reflection;
   refl::GlobalDef().def("script.printer.LambdaDoc",
-                        [](Array args, ExprDoc body) { return LambdaDoc(args, body); });
+                        [](ffi::Array args, ExprDoc body) { return LambdaDoc(args, body); });
 });
 
 TVM_FFI_STATIC_INIT_BLOCK({
   namespace refl = tvm::ffi::reflection;
   refl::GlobalDef().def("script.printer.TupleDoc",
-                        [](Array elements) { return TupleDoc(elements); });
+                        [](ffi::Array elements) { return TupleDoc(elements); });
 });
 
 TVM_FFI_STATIC_INIT_BLOCK({
   namespace refl = tvm::ffi::reflection;
   refl::GlobalDef().def("script.printer.ListDoc",
-                        [](Array elements) { return ListDoc(elements); });
+                        [](ffi::Array elements) { return ListDoc(elements); });
 });
 
 TVM_FFI_STATIC_INIT_BLOCK({
   namespace refl = tvm::ffi::reflection;
-  refl::GlobalDef().def("script.printer.DictDoc", [](Array keys, Array values) {
-    return DictDoc(keys, values);
-  });
+  refl::GlobalDef().def(
+      "script.printer.DictDoc",
+      [](ffi::Array keys, ffi::Array values) { return DictDoc(keys, values); });
 });
 
 TVM_FFI_STATIC_INIT_BLOCK({
   namespace refl = tvm::ffi::reflection;
   refl::GlobalDef().def("script.printer.SliceDoc",
-                        [](Optional start, Optional stop,
-                           Optional step) { return SliceDoc(start, stop, step); });
+                        [](ffi::Optional start, ffi::Optional stop,
+                           ffi::Optional step) { return SliceDoc(start, stop, step); });
 });
 
 TVM_FFI_STATIC_INIT_BLOCK({
   namespace refl = tvm::ffi::reflection;
-  refl::GlobalDef().def("script.printer.AssignDoc",
-                        [](ExprDoc lhs, Optional rhs, Optional annotation) {
-                          return AssignDoc(lhs, rhs, annotation);
-                        });
+  refl::GlobalDef().def("script.printer.AssignDoc", [](ExprDoc lhs, ffi::Optional rhs,
+                                                       ffi::Optional annotation) {
+    return AssignDoc(lhs, rhs, annotation);
+  });
 });
 
 TVM_FFI_STATIC_INIT_BLOCK({
   namespace refl = tvm::ffi::reflection;
-  refl::GlobalDef().def("script.printer.IfDoc", [](ExprDoc predicate, Array then_branch,
-                                                   Array else_branch) {
-    return IfDoc(predicate, then_branch, else_branch);
-  });
+  refl::GlobalDef().def(
+      "script.printer.IfDoc",
+      [](ExprDoc predicate, ffi::Array then_branch, ffi::Array else_branch) {
+        return IfDoc(predicate, then_branch, else_branch);
+      });
 });
TVM_FFI_STATIC_INIT_BLOCK({ namespace refl = tvm::ffi::reflection; - refl::GlobalDef().def("script.printer.WhileDoc", [](ExprDoc predicate, Array body) { + refl::GlobalDef().def("script.printer.WhileDoc", [](ExprDoc predicate, ffi::Array body) { return WhileDoc(predicate, body); }); }); TVM_FFI_STATIC_INIT_BLOCK({ namespace refl = tvm::ffi::reflection; - refl::GlobalDef().def("script.printer.ForDoc", [](ExprDoc lhs, ExprDoc rhs, Array body) { - return ForDoc(lhs, rhs, body); - }); + refl::GlobalDef().def( + "script.printer.ForDoc", + [](ExprDoc lhs, ExprDoc rhs, ffi::Array body) { return ForDoc(lhs, rhs, body); }); }); TVM_FFI_STATIC_INIT_BLOCK({ namespace refl = tvm::ffi::reflection; refl::GlobalDef().def("script.printer.ScopeDoc", - [](Optional lhs, ExprDoc rhs, Array body) { + [](ffi::Optional lhs, ExprDoc rhs, ffi::Array body) { return ScopeDoc(lhs, rhs, body); }); }); @@ -418,7 +427,7 @@ TVM_FFI_STATIC_INIT_BLOCK({ namespace refl = tvm::ffi::reflection; refl::GlobalDef().def( "script.printer.AssertDoc", - [](ExprDoc test, Optional msg = std::nullopt) { return AssertDoc(test, msg); }); + [](ExprDoc test, ffi::Optional msg = std::nullopt) { return AssertDoc(test, msg); }); }); TVM_FFI_STATIC_INIT_BLOCK({ @@ -429,8 +438,8 @@ TVM_FFI_STATIC_INIT_BLOCK({ TVM_FFI_STATIC_INIT_BLOCK({ namespace refl = tvm::ffi::reflection; refl::GlobalDef().def("script.printer.FunctionDoc", - [](IdDoc name, Array args, Array decorators, - Optional return_type, Array body) { + [](IdDoc name, ffi::Array args, ffi::Array decorators, + ffi::Optional return_type, ffi::Array body) { return FunctionDoc(name, args, decorators, return_type, body); }); }); @@ -438,7 +447,7 @@ TVM_FFI_STATIC_INIT_BLOCK({ TVM_FFI_STATIC_INIT_BLOCK({ namespace refl = tvm::ffi::reflection; refl::GlobalDef().def("script.printer.ClassDoc", - [](IdDoc name, Array decorators, Array body) { + [](IdDoc name, ffi::Array decorators, ffi::Array body) { return ClassDoc(name, decorators, body); }); }); @@ -446,13 +455,13 @@ TVM_FFI_STATIC_INIT_BLOCK({ TVM_FFI_STATIC_INIT_BLOCK({ namespace refl = tvm::ffi::reflection; refl::GlobalDef().def("script.printer.CommentDoc", - [](String comment) { return CommentDoc(comment); }); + [](ffi::String comment) { return CommentDoc(comment); }); }); TVM_FFI_STATIC_INIT_BLOCK({ namespace refl = tvm::ffi::reflection; refl::GlobalDef().def("script.printer.DocStringDoc", - [](String docs) { return DocStringDoc(docs); }); + [](ffi::String docs) { return DocStringDoc(docs); }); }); } // namespace printer diff --git a/src/script/printer/doc_printer/base_doc_printer.cc b/src/script/printer/doc_printer/base_doc_printer.cc index 7e6d76c4bf9a..77990c8048c5 100644 --- a/src/script/printer/doc_printer/base_doc_printer.cc +++ b/src/script/printer/doc_printer/base_doc_printer.cc @@ -275,7 +275,7 @@ void DocPrinter::Append(const Doc& doc, const PrinterConfig& cfg) { } } -String DocPrinter::GetString() const { +ffi::String DocPrinter::GetString() const { std::string text = output_.str(); // Remove any trailing indentation diff --git a/src/script/printer/doc_printer/base_doc_printer.h b/src/script/printer/doc_printer/base_doc_printer.h index b92c9dbe7aa2..53c388f84a5b 100644 --- a/src/script/printer/doc_printer/base_doc_printer.h +++ b/src/script/printer/doc_printer/base_doc_printer.h @@ -81,7 +81,7 @@ class DocPrinter { * * \sa Append */ - String GetString() const; + ffi::String GetString() const; protected: /*! @@ -267,7 +267,7 @@ class DocPrinter { std::vector line_starts_; /*! 
\brief Path of the object that we would like to underline */ - Array path_to_underline_; + ffi::Array path_to_underline_; /*! * \brief Candidate spans to be underlined, until we find a better match. diff --git a/src/script/printer/doc_printer/python_doc_printer.cc b/src/script/printer/doc_printer/python_doc_printer.cc index 21f5e3301568..e576c5acb1bf 100644 --- a/src/script/printer/doc_printer/python_doc_printer.cc +++ b/src/script/printer/doc_printer/python_doc_printer.cc @@ -182,7 +182,7 @@ class PythonDocPrinter : public DocPrinter { } template - void PrintJoinedDocs(const Array& docs, const std::string& separator) { + void PrintJoinedDocs(const ffi::Array& docs, const std::string& separator) { bool is_first = true; for (auto& doc : docs) { if (is_first) { @@ -194,7 +194,7 @@ class PythonDocPrinter : public DocPrinter { } } - void PrintIndentedBlock(const Array& docs) { + void PrintIndentedBlock(const ffi::Array& docs) { IncreaseIndent(); for (const StmtDoc& d : docs) { NewLine(); @@ -207,7 +207,7 @@ class PythonDocPrinter : public DocPrinter { DecreaseIndent(); } - void PrintDecorators(const Array& decorators) { + void PrintDecorators(const ffi::Array& decorators) { for (const ExprDoc& decorator : decorators) { output_ << "@"; PrintDoc(decorator); @@ -285,7 +285,7 @@ class PythonDocPrinter : public DocPrinter { } } - void PrintDocString(const String& comment) { + void PrintDocString(const ffi::String& comment) { size_t start_pos = output_.tellp(); output_ << "\"\"\""; @@ -304,7 +304,7 @@ class PythonDocPrinter : public DocPrinter { underlines_exempted_.push_back({start_pos, end_pos}); } - void PrintBlockComment(const String& comment) { + void PrintBlockComment(const ffi::String& comment) { IncreaseIndent(); NewLine(); PrintDocString(comment); @@ -484,7 +484,7 @@ void PythonDocPrinter::PrintTypedDoc(const CallDoc& doc) { } else { output_ << ", "; } - const String& keyword = doc->kwargs_keys[i]; + const ffi::String& keyword = doc->kwargs_keys[i]; output_ << keyword; output_ << "="; PrintDoc(doc->kwargs_values[i]); @@ -714,7 +714,7 @@ void PythonDocPrinter::PrintTypedDoc(const DocStringDoc& doc) { } } -String DocToPythonScript(Doc doc, const PrinterConfig& cfg) { +ffi::String DocToPythonScript(Doc doc, const PrinterConfig& cfg) { if (cfg->num_context_lines < 0) { cfg->num_context_lines = std::numeric_limits::max(); } diff --git a/src/script/printer/ir/distributed.cc b/src/script/printer/ir/distributed.cc index fd478768bf32..62d4c3ad6132 100644 --- a/src/script/printer/ir/distributed.cc +++ b/src/script/printer/ir/distributed.cc @@ -28,7 +28,7 @@ namespace printer { TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable) .set_dispatch("", [](ffi::Shape n, AccessPath n_p, IRDocsifier d) -> Doc { int s = n.size(); - Array results; + ffi::Array results; results.reserve(s); for (int i = 0; i < s; ++i) { results.push_back(d->AsDoc(Integer(n[i]), n_p->ArrayItem(i))); diff --git a/src/script/printer/ir/ir.cc b/src/script/printer/ir/ir.cc index 70be98f4c425..0bca40948e3c 100644 --- a/src/script/printer/ir/ir.cc +++ b/src/script/printer/ir/ir.cc @@ -130,7 +130,7 @@ TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable) TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable) .set_dispatch("", [](VDevice vdev, AccessPath p, IRDocsifier d) -> Doc { d->AddGlobalInfo("vdevice", vdev); - Map config = vdev->target->Export(); + ffi::Map config = vdev->target->Export(); return IR(d, "vdevice") ->Call({d->AsDoc(config, p), LiteralDoc::Int(vdev->vdevice_id, p->Attr("vdevice_id")), diff --git a/src/script/printer/ir/misc.cc 
b/src/script/printer/ir/misc.cc index 5643ab4de43a..f33170577154 100644 --- a/src/script/printer/ir/misc.cc +++ b/src/script/printer/ir/misc.cc @@ -23,10 +23,10 @@ namespace script { namespace printer { TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable) - .set_dispatch>( // - "", [](Array array, AccessPath p, IRDocsifier d) -> Doc { + .set_dispatch>( // + "", [](ffi::Array array, AccessPath p, IRDocsifier d) -> Doc { int n = array.size(); - Array results; + ffi::Array results; results.reserve(n); for (int i = 0; i < n; ++i) { results.push_back(d->AsDoc(array[i], p->ArrayItem(i))); @@ -35,8 +35,8 @@ TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable) }); TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable) - .set_dispatch>( // - "", [](Map dict, AccessPath p, IRDocsifier d) -> Doc { + .set_dispatch>( // + "", [](ffi::Map dict, AccessPath p, IRDocsifier d) -> Doc { using POO = std::pair; std::vector items{dict.begin(), dict.end()}; bool is_str_map = true; @@ -48,12 +48,12 @@ TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable) } if (is_str_map) { std::sort(items.begin(), items.end(), [](const POO& lhs, const POO& rhs) { - return Downcast(lhs.first) < Downcast(rhs.first); + return Downcast(lhs.first) < Downcast(rhs.first); }); } int n = dict.size(); - Array ks; - Array vs; + ffi::Array ks; + ffi::Array vs; ks.reserve(n); vs.reserve(n); for (int i = 0; i < n; ++i) { diff --git a/src/script/printer/ir/utils.h b/src/script/printer/ir/utils.h index d79e5cd4565d..6b62bac3ec23 100644 --- a/src/script/printer/ir/utils.h +++ b/src/script/printer/ir/utils.h @@ -37,7 +37,7 @@ namespace printer { class IRFrameNode : public FrameNode { public: - Map>* global_infos = nullptr; + ffi::Map>* global_infos = nullptr; static void RegisterReflection() { namespace refl = tvm::ffi::reflection; @@ -51,7 +51,7 @@ class IRFrameNode : public FrameNode { class IRFrame : public Frame { public: explicit IRFrame(const IRDocsifier& d) { - ObjectPtr n = make_object(); + ObjectPtr n = ffi::make_object(); n->stmts.clear(); n->d = d.get(); n->global_infos = nullptr; diff --git a/src/script/printer/ir_docsifier.cc b/src/script/printer/ir_docsifier.cc index efe7bc2f937a..94d2a281e2fe 100644 --- a/src/script/printer/ir_docsifier.cc +++ b/src/script/printer/ir_docsifier.cc @@ -35,7 +35,8 @@ TVM_FFI_STATIC_INIT_BLOCK({ IRDocsifierNode::RegisterReflection(); }); -IdDoc IRDocsifierNode::Define(const ObjectRef& obj, const Frame& frame, const String& name_hint) { +IdDoc IRDocsifierNode::Define(const ObjectRef& obj, const Frame& frame, + const ffi::String& name_hint) { if (auto it = obj2info.find(obj); it != obj2info.end()) { // TVM's IR dialects do not allow multiple definitions of the same // variable within an IRModule. 
This branch can only be reached @@ -51,7 +52,7 @@ IdDoc IRDocsifierNode::Define(const ObjectRef& obj, const Frame& frame, const St return IdDoc(it->second.name.value()); } - String name = name_hint; + ffi::String name = name_hint; if (cfg->show_object_address) { std::stringstream stream; stream << name << "_" << obj.get(); @@ -72,7 +73,7 @@ void IRDocsifierNode::Define(const ObjectRef& obj, const Frame& frame, DocCreato frame->AddExitCallback([this, obj]() { this->RemoveVar(obj); }); } -Optional IRDocsifierNode::GetVarDoc(const ObjectRef& obj) const { +ffi::Optional IRDocsifierNode::GetVarDoc(const ObjectRef& obj) const { auto it = obj2info.find(obj); if (it == obj2info.end()) { return std::nullopt; @@ -82,8 +83,8 @@ Optional IRDocsifierNode::GetVarDoc(const ObjectRef& obj) const { ExprDoc IRDocsifierNode::AddMetadata(const ffi::Any& obj) { ICHECK(obj != nullptr) << "TypeError: Cannot add nullptr to metadata"; - String key = obj.GetTypeKey(); - Array& array = metadata[key]; + ffi::String key = obj.GetTypeKey(); + ffi::Array& array = metadata[key]; int index = std::find_if(array.begin(), array.end(), [&](const ffi::Any& a) { return ffi::AnyEqual()(a, obj); }) - array.begin(); @@ -94,9 +95,9 @@ ExprDoc IRDocsifierNode::AddMetadata(const ffi::Any& obj) { "metadata")[{LiteralDoc::Str(key, std::nullopt)}][{LiteralDoc::Int(index, std::nullopt)}]; } -void IRDocsifierNode::AddGlobalInfo(const String& name, const GlobalInfo& ginfo) { +void IRDocsifierNode::AddGlobalInfo(const ffi::String& name, const GlobalInfo& ginfo) { ICHECK(ginfo.defined()) << "TypeError: Cannot add nullptr to global_infos"; - Array& array = global_infos[name]; + ffi::Array& array = global_infos[name]; array.push_back(ginfo); } @@ -191,11 +192,11 @@ void IRDocsifierNode::SetCommonPrefix(const ObjectRef& root, } IRDocsifier::IRDocsifier(const PrinterConfig& cfg) { - auto n = make_object(); + auto n = ffi::make_object(); n->cfg = cfg; n->dispatch_tokens.push_back(""); // Define builtin keywords according to cfg. 
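// A sketch of the define/lookup round trip served by Define and GetVarDoc
// above (names illustrative; template arguments reconstructed, since this
// patch text strips angle brackets):
//
//   IdDoc lhs = d->Define(obj, frame, /*name_hint=*/"v");  // first sighting
//   ...
//   if (ffi::Optional<ExprDoc> doc = d->GetVarDoc(obj)) {
//     return doc.value();  // still bound in an enclosing frame
//   }
//
// AddMetadata and AddGlobalInfo also share an accumulator idiom: indexing
// the backing container default-constructs an empty ffi::Array on first
// access, which is why the push_back above needs no containment check.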
- for (const String& keyword : cfg->GetBuiltinKeywords()) { + for (const ffi::String& keyword : cfg->GetBuiltinKeywords()) { n->defined_names.insert(keyword); } data_ = std::move(n); diff --git a/src/script/printer/relax/binding.cc b/src/script/printer/relax/binding.cc index d4580af96891..19da2cd508aa 100644 --- a/src/script/printer/relax/binding.cc +++ b/src/script/printer/relax/binding.cc @@ -23,15 +23,15 @@ namespace script { namespace printer { IfDoc PrintIfExpr(const relax::If& n, const AccessPath& n_p, const IRDocsifier& d, // - const Optional& var, const Optional& ann) { + const ffi::Optional& var, const ffi::Optional& ann) { using relax::SeqExpr; ExprDoc cond = d->AsDoc(n->cond, n_p->Attr("cond")); - std::vector> branches{ + std::vector> branches{ PrintSeqExpr(n->true_branch, n_p->Attr("true_branch"), d, false), PrintSeqExpr(n->false_branch, n_p->Attr("false_branch"), d, false), }; if (var.defined()) { - for (Array& stmts : branches) { + for (ffi::Array& stmts : branches) { ExprDoc ret = Downcast(stmts.back())->expr; stmts.Set(stmts.size() - 1, AssignDoc(var.value(), ret, ann)); } @@ -44,7 +44,7 @@ TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable) "", [](relax::MatchCast n, AccessPath n_p, IRDocsifier d) -> Doc { using relax::StructInfo; using relax::MatchStructInfo; - Optional ann = std::nullopt; + ffi::Optional ann = std::nullopt; if (d->cfg->show_all_struct_info) { ann = StructInfoAsAnn(n->var, n_p->Attr("var"), d, n->value); } @@ -59,9 +59,9 @@ TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable) .set_dispatch( // "", [](relax::VarBinding n, AccessPath n_p, IRDocsifier d) -> Doc { if (const auto if_ = n->value.as()) { - Optional ann = StructInfoAsAnn(n->var, n_p->Attr("var"), d, n->value); + ffi::Optional ann = StructInfoAsAnn(n->var, n_p->Attr("var"), d, n->value); ExprDoc lhs = DefineVar(n->var, d->frames.back(), d); - return PrintIfExpr(GetRef(if_), n_p->Attr("value"), d, lhs, ann); + return PrintIfExpr(ffi::GetRef(if_), n_p->Attr("value"), d, lhs, ann); } else if (n->value->IsInstance() && !n->value->IsInstance()) { IdDoc lhs = DefineVar(n->var, d->frames.back(), d); @@ -75,7 +75,7 @@ TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable) return ExprStmtDoc(rhs); } else { ExprDoc rhs = d->AsDoc(n->value, n_p->Attr("value")); - Optional ann = StructInfoAsAnn(n->var, n_p->Attr("var"), d, n->value); + ffi::Optional ann = StructInfoAsAnn(n->var, n_p->Attr("var"), d, n->value); ExprDoc lhs = DefineVar(n->var, d->frames.back(), d); return AssignDoc(lhs, rhs, ann); } diff --git a/src/script/printer/relax/call.cc b/src/script/printer/relax/call.cc index e7e7e21380e4..9b0d2b966a4d 100644 --- a/src/script/printer/relax/call.cc +++ b/src/script/printer/relax/call.cc @@ -29,8 +29,8 @@ namespace printer { class AttrPrinter { public: - explicit AttrPrinter(AccessPath p, const IRDocsifier& d, Array* keys, - Array* values) + explicit AttrPrinter(AccessPath p, const IRDocsifier& d, ffi::Array* keys, + ffi::Array* values) : p(std::move(p)), d(d), keys(keys), values(values) {} void operator()(const tvm::Attrs& attrs) { @@ -46,7 +46,7 @@ class AttrPrinter { << "` misses reflection registration and do not support serialization"; // new printing mechanism using the new reflection ffi::reflection::ForEachFieldInfo(attrs_tinfo, [&](const TVMFFIFieldInfo* field_info) { - String field_name = String(field_info->name); + ffi::String field_name = ffi::String(field_info->name); Any field_value = ffi::reflection::FieldGetter(field_info)(attrs); keys->push_back(field_name); values->push_back(d->AsDoc(field_value, 
p->Attr(field_name))); @@ -56,8 +56,8 @@ class AttrPrinter { AccessPath p; const IRDocsifier& d; - Array* keys; - Array* values; + ffi::Array* keys; + ffi::Array* values; }; ExprDoc PrintCallee(const relax::Expr& n, const AccessPath& n_p, const IRDocsifier& d) { @@ -69,8 +69,8 @@ ExprDoc PrintCallee(const relax::Expr& n, const AccessPath& n_p, const IRDocsifi } } -Optional PrintCallTIRDPSPacked(const relax::Call& n, const AccessPath& n_p, - const IRDocsifier& d) { +ffi::Optional PrintCallTIRDPSPacked(const relax::Call& n, const AccessPath& n_p, + const IRDocsifier& d) { static const Op& call_tir_op = Op::Get("relax.call_tir"); static const Op& call_tir_inplace_op = Op::Get("relax.call_tir_inplace"); static const Op& call_dps_packed_op = Op::Get("relax.call_dps_packed"); @@ -83,9 +83,9 @@ Optional PrintCallTIRDPSPacked(const relax::Call& n, const AccessPath& } ICHECK(n->args.size() == 2 || n->args.size() == 3); ICHECK(n->sinfo_args.size() == 1); - Array args; - Array kwargs_keys; - Array kwargs_values; + ffi::Array args; + ffi::Array kwargs_keys; + ffi::Array kwargs_values; // Step 1. Print n->args[0], the callee args.push_back(PrintCallee(n->args[0], n_p->Attr("args")->ArrayItem(0), d)); // Step 2. Print n->args[1], the input arguments @@ -96,7 +96,7 @@ Optional PrintCallTIRDPSPacked(const relax::Call& n, const AccessPath& bool is_dtensor = false; kwargs_keys.push_back("out_sinfo"); if (const auto* o = o_sinfo.as()) { - Array fields; + ffi::Array fields; AccessPath fields_p = o_sinfo_p->Attr("fields"); for (int i = 0, l = o->fields.size(); i < l; ++i) { if (o->fields[i].as()) { @@ -115,7 +115,7 @@ Optional PrintCallTIRDPSPacked(const relax::Call& n, const AccessPath& // for call_tir_inplace, we also need to include the inplace args if (n->op.same_as(call_tir_inplace_op)) { kwargs_keys.push_back("inplace_indices"); - Array index_fields; + ffi::Array index_fields; if (auto* call_tir_inplace_attrs = n->attrs.as()) { for (auto inplace_index : call_tir_inplace_attrs->inplace_indices) { index_fields.push_back( @@ -160,7 +160,8 @@ Optional PrintCallTIRDPSPacked(const relax::Call& n, const AccessPath& } } -Optional PrintAssertOp(const relax::Call& n, const AccessPath& n_p, const IRDocsifier& d) { +ffi::Optional PrintAssertOp(const relax::Call& n, const AccessPath& n_p, + const IRDocsifier& d) { static const Op& assert_op = Op::Get("relax.assert_op"); if (!n->op.same_as(assert_op)) { return std::nullopt; @@ -170,7 +171,7 @@ Optional PrintAssertOp(const relax::Call& n, const AccessPath& n_p, con // is the _format_ string, or else roundtripping will fail // (the format string will be interpreted as an argument and there will be a new default format // string given) - Array args; + ffi::Array args; args.push_back(d->AsDoc(n->args[0], n_p->Attr("args")->ArrayItem(0))); ExprDoc second_arg = d->AsDoc(n->args[1], n_p->Attr("args")->ArrayItem(1)); for (size_t i = 2; i < n->args.size(); i++) { @@ -179,17 +180,17 @@ Optional PrintAssertOp(const relax::Call& n, const AccessPath& n_p, con return Relax(d, "assert_op")->Call(args, {"format"}, {second_arg}); } -Optional PrintHintOnDevice(const relax::Call& n, const AccessPath& n_p, - const IRDocsifier& d) { +ffi::Optional PrintHintOnDevice(const relax::Call& n, const AccessPath& n_p, + const IRDocsifier& d) { static const Op& hint_on_device_op = Op::Get("relax.hint_on_device"); if (!n->op.same_as(hint_on_device_op)) { return std::nullopt; } - Array args; + ffi::Array args; args.push_back(PrintCallee(n->args[0], n_p->Attr("args")->ArrayItem(0), d)); - Array 
kwargs_keys; - Array kwargs_values; + ffi::Array kwargs_keys; + ffi::Array kwargs_values; ICHECK(n->attrs.defined()); if (n->attrs.as()) { AttrPrinter(n_p->Attr("attrs"), d, &kwargs_keys, &kwargs_values)(n->attrs); @@ -198,17 +199,17 @@ Optional PrintHintOnDevice(const relax::Call& n, const AccessPath& n_p, return Relax(d, "hint_on_device")->Call(args); } -Optional PrintToVDevice(const relax::Call& n, const AccessPath& n_p, - const IRDocsifier& d) { +ffi::Optional PrintToVDevice(const relax::Call& n, const AccessPath& n_p, + const IRDocsifier& d) { static const Op& to_vdevice_op = Op::Get("relax.to_vdevice"); if (!n->op.same_as(to_vdevice_op)) { return std::nullopt; } - Array args; + ffi::Array args; args.push_back(PrintCallee(n->args[0], n_p->Attr("args")->ArrayItem(0), d)); - Array kwargs_keys; - Array kwargs_values; + ffi::Array kwargs_keys; + ffi::Array kwargs_values; ICHECK(n->attrs.defined()); if (const auto* attrs = n->attrs.as()) { VDevice vdev = attrs->dst_vdevice; @@ -221,8 +222,8 @@ Optional PrintToVDevice(const relax::Call& n, const AccessPath& n_p, return Relax(d, "to_vdevice")->Call(args, kwargs_keys, kwargs_values); } -Optional PrintRelaxPrint(const relax::Call& n, const AccessPath& n_p, - const IRDocsifier& d) { +ffi::Optional PrintRelaxPrint(const relax::Call& n, const AccessPath& n_p, + const IRDocsifier& d) { static const Op& print_op = Op::Get("relax.print"); if (!n->op.same_as(print_op)) { return std::nullopt; @@ -233,7 +234,7 @@ Optional PrintRelaxPrint(const relax::Call& n, const AccessPath& n_p, // (the format string will be interpreted as an argument and there will be a new default format // string given) ExprDoc first_arg = d->AsDoc(n->args[0], n_p->Attr("args")->ArrayItem(0)); - Array args; + ffi::Array args; for (size_t i = 1; i < n->args.size(); i++) { args.push_back(d->AsDoc(n->args[i], n_p->Attr("args")->ArrayItem(i))); } @@ -244,29 +245,29 @@ TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable) .set_dispatch( // "", [](relax::Call n, AccessPath n_p, IRDocsifier d) -> Doc { // Special case: call_tir, call_dps_packed, call_tir_with_grad - if (Optional doc = PrintCallTIRDPSPacked(n, n_p, d)) { + if (ffi::Optional doc = PrintCallTIRDPSPacked(n, n_p, d)) { return doc.value(); } // Special case: assert_op - if (Optional doc = PrintAssertOp(n, n_p, d)) { + if (ffi::Optional doc = PrintAssertOp(n, n_p, d)) { return doc.value(); } // Special case: hint_on_device - if (Optional doc = PrintHintOnDevice(n, n_p, d)) { + if (ffi::Optional doc = PrintHintOnDevice(n, n_p, d)) { return doc.value(); } // Special case: to_vdevice - if (Optional doc = PrintToVDevice(n, n_p, d)) { + if (ffi::Optional doc = PrintToVDevice(n, n_p, d)) { return doc.value(); } // Special case: print - if (Optional doc = PrintRelaxPrint(n, n_p, d)) { + if (ffi::Optional doc = PrintRelaxPrint(n, n_p, d)) { return doc.value(); } ExprDoc prefix{nullptr}; - Array args; - Array kwargs_keys; - Array kwargs_values; + ffi::Array args; + ffi::Array kwargs_keys; + ffi::Array kwargs_values; // Step 1. Print op if (const auto* op = n->op.as()) { prefix = Relax(d, "call_packed"); @@ -299,7 +300,7 @@ TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable) kwargs_values.push_back(LiteralDoc::Str(n->attrs->GetTypeKey(), n_p->Attr("attrs"))); } if (const auto* attrs = n->attrs.as()) { - std::vector> sorted; + std::vector> sorted; for (const auto& kv : attrs->dict) { sorted.push_back(kv); } @@ -317,7 +318,7 @@ TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable) // Step 4. 
Print type_args if (n->sinfo_args.size() > 0) { AccessPath sinfo_args_p = n_p->Attr("sinfo_args"); - Array sinfo_args; + ffi::Array sinfo_args; for (int i = 0, l = n->sinfo_args.size(); i < l; ++i) { sinfo_args.push_back(d->AsDoc(n->sinfo_args[i], sinfo_args_p->ArrayItem(i))); } diff --git a/src/script/printer/relax/distributed.cc b/src/script/printer/relax/distributed.cc index d8b3871b35bc..d1a29be24f5e 100644 --- a/src/script/printer/relax/distributed.cc +++ b/src/script/printer/relax/distributed.cc @@ -37,16 +37,16 @@ TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable) TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable) .set_dispatch( "", [](relax::distributed::DTensorStructInfo n, AccessPath n_p, IRDocsifier d) -> Doc { - Array args; - Array kwargs_keys; - Array kwargs_values; + ffi::Array args; + ffi::Array kwargs_keys; + ffi::Array kwargs_values; bool require_kwargs = false; if (n->tensor_sinfo->shape.defined()) { // Need to dig into ShapeExpr to preserve the `R.shape` prefix if (const auto* shape = n->tensor_sinfo->shape.value().as()) { - auto shape_expr = GetRef(shape); + auto shape_expr = ffi::GetRef(shape); AccessPath shape_p = n_p->Attr("shape")->Attr("values"); - Array shape_docs; + ffi::Array shape_docs; for (int i = 0, ndim = shape_expr->values.size(); i < ndim; ++i) { shape_docs.push_back( PrintShapeVar(shape_expr->values[i], shape_p->ArrayItem(i), d)); @@ -102,7 +102,7 @@ TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable) } } if (!has_relax_frame || !f) { - Array args; + ffi::Array args; args.push_back(d->AsDoc(n->shape, n_p->Attr("shape"))); if (n->device_range.defined()) { args.push_back(d->AsDoc(n->device_range, n_p->Attr("device_range"))); @@ -116,7 +116,7 @@ TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable) if (kv.second[i].same_as(n)) { std::stringstream ss; ss << kv.first << "[" << i << "]"; - return d->AsDoc(String(ss.str()), n_p); + return d->AsDoc(ffi::String(ss.str()), n_p); } } } diff --git a/src/script/printer/relax/expr.cc b/src/script/printer/relax/expr.cc index 903aef5a697e..0c8cd3c12371 100644 --- a/src/script/printer/relax/expr.cc +++ b/src/script/printer/relax/expr.cc @@ -53,7 +53,7 @@ TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable) if (n->fields.empty()) { return Relax(d, "tuple")->Call({}); } - Array fields_doc; + ffi::Array fields_doc; AccessPath fields_p = n_p->Attr("fields"); for (int i = 0, l = n->fields.size(); i < l; ++i) { fields_doc.push_back(d->AsDoc(n->fields[i], fields_p->ArrayItem(i))); @@ -71,7 +71,7 @@ TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable) TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable) .set_dispatch( // "", [](relax::ShapeExpr n, AccessPath n_p, IRDocsifier d) -> Doc { - Array values_doc; + ffi::Array values_doc; AccessPath values_p = n_p->Attr("values"); for (int i = 0, l = n->values.size(); i < l; ++i) { values_doc.push_back(PrintShapeVar(n->values[i], values_p->ArrayItem(i), d)); @@ -79,7 +79,7 @@ TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable) return Relax(d, "shape")->Call({ListDoc(values_doc)}); }); -Optional SpecialScalar(const runtime::Tensor& n, const AccessPath& p) { +ffi::Optional SpecialScalar(const runtime::Tensor& n, const AccessPath& p) { DataType dtype = n.DataType(); const void* data = n->data; if (n->ndim != 0 || n->device.device_type != kDLCPU) { @@ -135,7 +135,7 @@ Optional SpecialScalar(const runtime::Tensor& n, const AccessPath& p) { TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable) .set_dispatch( // "", [](relax::Constant n, AccessPath n_p, IRDocsifier d) -> Doc { - if (Optional s = SpecialScalar(n->data, n_p->Attr("data"))) { + if (ffi::Optional s = 
SpecialScalar(n->data, n_p->Attr("data"))) { if (n->struct_info_.as()) { ExprDoc ann = d->AsDoc(n->struct_info_, n_p->Attr("struct_info_")); return Relax(d, "dist.const")->Call({s.value(), ann}); diff --git a/src/script/printer/relax/function.cc b/src/script/printer/relax/function.cc index aa6182f189fe..1a1bf006995d 100644 --- a/src/script/printer/relax/function.cc +++ b/src/script/printer/relax/function.cc @@ -47,7 +47,7 @@ TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable) IdDoc func_name(""); // if we are binding a local definition, then calling d->Define // will result in a repeated definition and an incorrect displayed name - if (Optional name = GetBindingName(d)) { + if (ffi::Optional name = GetBindingName(d)) { func_name = IdDoc(name.value()); } else { func_name = IdDoc(FindFunctionName(d, n).value_or("main")); @@ -56,13 +56,13 @@ TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable) (*f)->is_func = true; (*f)->func_vars = &func_vars; // Step 1. Print the return type - Optional ret_type = std::nullopt; + ffi::Optional ret_type = std::nullopt; if (const auto& func_sinfo = relax::MatchStructInfo(n)) { ret_type = d->AsDoc(func_sinfo.value()->ret, // n_p->Attr("struct_info_")->Attr("ret")); } // Step 2. Print params - Array params; + ffi::Array params; { AccessPath params_p = n_p->Attr("params"); for (int i = 0, l = n->params.size(); i < l; ++i) { @@ -81,8 +81,8 @@ TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable) // For a function without an IR module whose global symbol // doesn't match the function name, we should still print the global symbol attribute. if (AtTopLevelFunction(d) && n->attrs->dict.count(tvm::attr::kGlobalSymbol) && - Downcast(n->attrs->dict.at(tvm::attr::kGlobalSymbol)) == func_name->name) { - Map new_attrs; + Downcast(n->attrs->dict.at(tvm::attr::kGlobalSymbol)) == func_name->name) { + ffi::Map new_attrs; for (auto kv : n->attrs->dict) { if (kv.first != tvm::attr::kGlobalSymbol) { new_attrs.Set(kv.first, kv.second); @@ -101,26 +101,26 @@ TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable) } // Step 5. Prepare the decorator (include purity if it's impure) ExprDoc decorator = Relax(d, "function"); - Array pos_args = {}; - Array dec_keys; - Array dec_values; + ffi::Array pos_args = {}; + ffi::Array dec_keys; + ffi::Array dec_values; if (!n->is_pure) { dec_keys.push_back("pure"); - dec_values.push_back(LiteralDoc::Boolean(false, Optional())); + dec_values.push_back(LiteralDoc::Boolean(false, ffi::Optional())); } // if the function is global or is not in a module and does not have a global symbol, // indicate that it's private if (AtTopLevelFunction(d) && (!n->attrs.defined() || !n->attrs->dict.count(tvm::attr::kGlobalSymbol))) { dec_keys.push_back("private"); - dec_values.push_back(LiteralDoc::Boolean(true, Optional())); + dec_values.push_back(LiteralDoc::Boolean(true, ffi::Optional())); } if (dec_keys.size()) { decorator = decorator->Call(pos_args, dec_keys, dec_values); } // Step 6. 
Print body - Array body = PrintSeqExpr(n->body, n_p->Attr("body"), d, /*use_ret=*/true); + ffi::Array body = PrintSeqExpr(n->body, n_p->Attr("body"), d, /*use_ret=*/true); (*f)->stmts.insert((*f)->stmts.end(), body.begin(), body.end()); return HeaderWrapper(d, FunctionDoc(func_name, params, {decorator}, ret_type, (*f)->stmts)); }); diff --git a/src/script/printer/relax/region.cc b/src/script/printer/relax/region.cc index 7cedc63c271c..a28967cb4194 100644 --- a/src/script/printer/relax/region.cc +++ b/src/script/printer/relax/region.cc @@ -22,18 +22,18 @@ namespace tvm { namespace script { namespace printer { -Array PrintSeqExpr(const relax::SeqExpr& n, const AccessPath& n_p, const IRDocsifier& d, - bool use_ret) { +ffi::Array PrintSeqExpr(const relax::SeqExpr& n, const AccessPath& n_p, + const IRDocsifier& d, bool use_ret) { With f(d); - const Array& blocks = n->blocks; + const ffi::Array& blocks = n->blocks; AccessPath blocks_p = n_p->Attr("blocks"); - Array* stmts = &(*f)->stmts; + ffi::Array* stmts = &(*f)->stmts; for (int i = 0, l = blocks.size(); i < l; ++i) { Doc block = d->AsDoc(blocks[i], blocks_p->ArrayItem(i)); if (const auto* stmt_block = block.as()) { stmts->insert(stmts->end(), stmt_block->stmts.begin(), stmt_block->stmts.end()); } else if (const auto* stmt = block.as()) { - stmts->push_back(GetRef(stmt)); + stmts->push_back(ffi::GetRef(stmt)); } else { LOG(FATAL) << "TypeError: Unknown type: " << block->GetTypeKey(); } @@ -52,18 +52,19 @@ TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable) return StmtBlockDoc(PrintSeqExpr(n, n_p, d, false)); }); -Array PrintBindingBlock(const relax::BindingBlock& n, const AccessPath& n_p, - const IRDocsifier& d, Array* non_dataflow_vars) { - const Array& bindings = n->bindings; +ffi::Array PrintBindingBlock(const relax::BindingBlock& n, const AccessPath& n_p, + const IRDocsifier& d, + ffi::Array* non_dataflow_vars) { + const ffi::Array& bindings = n->bindings; AccessPath bindings_p = n_p->Attr("bindings"); - Array stmts; + ffi::Array stmts; for (int i = 0, l = bindings.size(); i < l; ++i) { const relax::Binding& binding = bindings[i]; AccessPath binding_p = bindings_p->ArrayItem(i); ICHECK(binding->var.defined()); Doc binding_doc = d->AsDoc(binding, binding_p); if (const auto* stmt = binding_doc.as()) { - stmts.push_back(GetRef(stmt)); + stmts.push_back(ffi::GetRef(stmt)); } else if (const auto* stmt_block = binding_doc.as()) { stmts.insert(stmts.end(), stmt_block->stmts.begin(), stmt_block->stmts.end()); } else { @@ -85,8 +86,8 @@ TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable) TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable) .set_dispatch( // "", [](relax::DataflowBlock n, AccessPath n_p, IRDocsifier d) -> Doc { - Array non_dataflow_vars; - Array stmts = PrintBindingBlock(n, n_p, d, &non_dataflow_vars); + ffi::Array non_dataflow_vars; + ffi::Array stmts = PrintBindingBlock(n, n_p, d, &non_dataflow_vars); stmts.push_back(ExprStmtDoc(Relax(d, "output")->Call(non_dataflow_vars))); return ScopeDoc(std::nullopt, Relax(d, "dataflow")->Call({}), stmts); }); diff --git a/src/script/printer/relax/struct_info.cc b/src/script/printer/relax/struct_info.cc index 87de6a8335f5..d6e2ac0f13f5 100644 --- a/src/script/printer/relax/struct_info.cc +++ b/src/script/printer/relax/struct_info.cc @@ -63,9 +63,9 @@ ExprDoc PrintShapeVar(const PrimExpr& e, const AccessPath& e_p, const IRDocsifie TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable) .set_dispatch( "", [](relax::PrimStructInfo n, AccessPath n_p, IRDocsifier d) -> Doc { - Array args; - Array kwargs_keys; - Array kwargs_values; + 
ffi::Array args; + ffi::Array kwargs_keys; + ffi::Array kwargs_values; if (n->value.defined()) { kwargs_keys.push_back("value"); @@ -81,9 +81,9 @@ TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable) .set_dispatch( "", [](relax::ShapeStructInfo n, AccessPath n_p, IRDocsifier d) -> Doc { if (n->values.defined()) { - Array shape = n->values.value(); + ffi::Array shape = n->values.value(); AccessPath shape_p = n_p->Attr("values"); - Array shape_docs; + ffi::Array shape_docs; for (int i = 0, ndim = shape.size(); i < ndim; ++i) { shape_docs.push_back(PrintShapeVar(shape[i], shape_p->ArrayItem(i), d)); } @@ -96,15 +96,15 @@ TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable) TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable) .set_dispatch( // "", [](relax::TensorStructInfo n, AccessPath n_p, IRDocsifier d) -> Doc { - Array args; - Array kwargs_keys; - Array kwargs_values; + ffi::Array args; + ffi::Array kwargs_keys; + ffi::Array kwargs_values; if (n->shape.defined()) { // Need to dig into ShapeExpr to preserve the `R.shape` prefix if (const auto* shape = n->shape.value().as()) { - auto shape_expr = GetRef(shape); + auto shape_expr = ffi::GetRef(shape); AccessPath shape_p = n_p->Attr("shape")->Attr("values"); - Array shape_docs; + ffi::Array shape_docs; for (int i = 0, ndim = shape_expr->values.size(); i < ndim; ++i) { shape_docs.push_back( PrintShapeVar(shape_expr->values[i], shape_p->ArrayItem(i), d)); @@ -141,7 +141,7 @@ TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable) if (n->fields.empty()) { return Relax(d, "Tuple"); } - Array fields_doc; + ffi::Array fields_doc; AccessPath fields_p = n_p->Attr("fields"); for (int i = 0, l = n->fields.size(); i < l; ++i) { fields_doc.push_back(d->AsDoc(n->fields[i], fields_p->ArrayItem(i))); @@ -156,8 +156,8 @@ TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable) auto purity_doc = LiteralDoc::Boolean(n->purity, n_p->Attr("purity")); if (n->IsOpaque()) { - Array keys; - Array values; + ffi::Array keys; + ffi::Array values; if (!n->ret->IsInstance()) { keys.push_back("ret"); @@ -175,8 +175,8 @@ TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable) } } // TODO(@junrushao): track symbolic shape relation - Array params_doc; - Array params = n->params.value(); + ffi::Array params_doc; + ffi::Array params = n->params.value(); AccessPath params_p = n_p->Attr("params"); for (int i = 0, n_params = params.size(); i < n_params; ++i) { params_doc.push_back(d->AsDoc(params[i], params_p->ArrayItem(i))); diff --git a/src/script/printer/relax/tir.cc b/src/script/printer/relax/tir.cc index 67f39a6f6c45..0c1a2cd26035 100644 --- a/src/script/printer/relax/tir.cc +++ b/src/script/printer/relax/tir.cc @@ -58,11 +58,11 @@ Doc PrintTIRVar(tir::Var n, AccessPath n_p, IRDocsifier d) { ICHECK(f->is_func); f->func_vars->insert(n.get()); } - IdDoc var = d->Define(n, GetRef(f), n->name_hint.empty() ? "v" : n->name_hint); + IdDoc var = d->Define(n, ffi::GetRef(f), n->name_hint.empty() ? 
"v" : n->name_hint); var->source_paths.push_back(n_p); f->stmts.push_back(AssignDoc(var, PrintVarCreation(n, n_p, d), std::nullopt)); } - if (Optional doc = d->GetVarDoc(n)) { + if (ffi::Optional doc = d->GetVarDoc(n)) { return doc.value(); } LOG(FATAL) << "IndexError: Variable is not defined in the environment: " << n; @@ -86,7 +86,7 @@ TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable) TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable) .set_dispatch( // "relax", [](tvm::GlobalVar n, AccessPath n_p, IRDocsifier d) -> Doc { // - if (Optional doc = d->GetVarDoc(n)) { + if (ffi::Optional doc = d->GetVarDoc(n)) { return doc.value(); } else { IdDoc ret(n->name_hint); @@ -98,7 +98,7 @@ TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable) TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable) .set_dispatch( // "relax", [](tvm::IRModule mod, AccessPath n_p, IRDocsifier d) -> Doc { // - Optional doc = d->GetVarDoc(mod); + ffi::Optional doc = d->GetVarDoc(mod); ICHECK(doc) << "Unable to print IRModule before definition in Relax."; if (d->cfg->module_alias.empty()) { // Use Module Name directly diff --git a/src/script/printer/relax/type.cc b/src/script/printer/relax/type.cc index d4ad35a13ee5..893f4304342e 100644 --- a/src/script/printer/relax/type.cc +++ b/src/script/printer/relax/type.cc @@ -58,7 +58,7 @@ TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable) if (n->fields.empty()) { return Relax(d, "Tuple"); } - Array fields_doc; + ffi::Array fields_doc; AccessPath fields_p = n_p->Attr("fields"); for (int i = 0, l = n->fields.size(); i < l; ++i) { fields_doc.push_back(d->AsDoc(n->fields[i], fields_p->ArrayItem(i))); @@ -69,8 +69,8 @@ TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable) TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable) .set_dispatch( "relax", [](tvm::FuncType n, AccessPath n_p, IRDocsifier d) -> Doc { - Array arg_types_doc; - Array arg_types = n->arg_types; + ffi::Array arg_types_doc; + ffi::Array arg_types = n->arg_types; AccessPath arg_types_p = n_p->Attr("arg_types"); for (int i = 0, n_params = arg_types.size(); i < n_params; ++i) { arg_types_doc.push_back(d->AsDoc(arg_types[i], arg_types_p->ArrayItem(i))); diff --git a/src/script/printer/relax/utils.h b/src/script/printer/relax/utils.h index 37ae86220051..bdfce4cfc64e 100644 --- a/src/script/printer/relax/utils.h +++ b/src/script/printer/relax/utils.h @@ -58,7 +58,7 @@ class RelaxFrameNode : public FrameNode { class RelaxFrame : public Frame { public: explicit RelaxFrame(const IRDocsifier& d) { - ObjectPtr n = make_object(); + ObjectPtr n = ffi::make_object(); n->stmts.clear(); n->d = d.get(); n->is_func = false; @@ -81,8 +81,9 @@ inline IdDoc DefineVar(const relax::Var& var, const Frame& frame, const IRDocsif return d->Define(var, frame, var->name_hint().empty() ? 
"v" : var->name_hint()); } -inline Optional StructInfoAsAnn(const relax::Var& v, const AccessPath& v_p, - const IRDocsifier& d, const Optional& rhs) { +inline ffi::Optional StructInfoAsAnn(const relax::Var& v, const AccessPath& v_p, + const IRDocsifier& d, + const ffi::Optional& rhs) { if (!v->struct_info_.defined()) { return std::nullopt; } @@ -96,7 +97,7 @@ inline Optional StructInfoAsAnn(const relax::Var& v, const AccessPath& } } if (attempt_to_hide_struct_info) { - Optional inferred_sinfo = std::nullopt; + ffi::Optional inferred_sinfo = std::nullopt; if (auto opt = rhs.as()) { auto call = opt.value(); if (auto opt = call->op.as()) { @@ -133,13 +134,13 @@ inline Optional StructInfoAsAnn(const relax::Var& v, const AccessPath& return d->AsDoc(v->struct_info_, v_p->Attr("struct_info_")); } -Array PrintSeqExpr(const relax::SeqExpr& n, const AccessPath& n_p, const IRDocsifier& d, - bool use_ret); +ffi::Array PrintSeqExpr(const relax::SeqExpr& n, const AccessPath& n_p, + const IRDocsifier& d, bool use_ret); ExprDoc PrintShapeVar(const PrimExpr& e, const AccessPath& e_p, const IRDocsifier& d); inline int FindVDeviceIndexByTargetKind(const VDevice& vdevice, const IRDocsifier& d) { - Array vdevices = d->global_infos["vdevice"]; + ffi::Array vdevices = d->global_infos["vdevice"]; int kind_index = 0; for (size_t i = 0; i < vdevices.size(); ++i) { auto vdev = Downcast(vdevices[i]); diff --git a/src/script/printer/tir/block.cc b/src/script/printer/tir/block.cc index fb4f8a9d772b..587520d72fe5 100644 --- a/src/script/printer/tir/block.cc +++ b/src/script/printer/tir/block.cc @@ -23,7 +23,8 @@ namespace script { namespace printer { Doc PrintBlock(IRDocsifier d, tir::Block block, AccessPath block_p, // - Optional opt_realize, Optional opt_realize_p) { + ffi::Optional opt_realize, + ffi::Optional opt_realize_p) { With frame(d, block); ICHECK_EQ(opt_realize.defined(), opt_realize_p.defined()); const tir::BlockRealizeNode* realize = @@ -35,7 +36,8 @@ Doc PrintBlock(IRDocsifier d, tir::Block block, AccessPath block_p, // for (Frame f : d->frames) { if (const auto* tir_f = f.as()) { if (auto for_loop = tir_f->tir.as()) { - for (Optional loop = for_loop; loop; loop = loop.value()->body.as()) { + for (ffi::Optional loop = for_loop; loop; + loop = loop.value()->body.as()) { loop_vars.insert(std::make_pair(loop.value()->loop_var.get(), loop.value())); } } @@ -113,12 +115,12 @@ Doc PrintBlock(IRDocsifier d, tir::Block block, AccessPath block_p, // remap_vars_indices.clear(); return; } - Array lhs; - Array loop_var_doc; + ffi::Array lhs; + ffi::Array loop_var_doc; lhs.reserve(m); loop_var_doc.reserve(m); std::string binding_type = ""; - Array binding_paths; + ffi::Array binding_paths; for (int i : remap_vars_indices) { tir::IterVar iter_var = block->iter_vars[i]; AccessPath iter_var_p = block_p->Attr("iter_vars")->ArrayItem(i); @@ -158,12 +160,12 @@ Doc PrintBlock(IRDocsifier d, tir::Block block, AccessPath block_p, // } // Step 3. Handle block read/write regions { - Array reads; + ffi::Array reads; for (int i = 0, n = block->reads.size(); i < n; ++i) { reads.push_back(d->AsDoc(block->reads[i], block_p->Attr("reads")->ArrayItem(i))); } (*frame)->stmts.push_back(ExprStmtDoc(TIR(d, "reads")->Call(reads))); - Array writes; + ffi::Array writes; for (int i = 0, n = block->writes.size(); i < n; ++i) { writes.push_back(d->AsDoc(block->writes[i], block_p->Attr("writes")->ArrayItem(i))); } @@ -201,8 +203,8 @@ Doc PrintBlock(IRDocsifier d, tir::Block block, AccessPath block_p, // } // Step 8. 
Handle block body AsDocBody(block->body, block_p->Attr("body"), frame->get(), d); - Array kwargs_keys; - Array kwargs_values; + ffi::Array kwargs_keys; + ffi::Array kwargs_values; if (!realize) { kwargs_keys.push_back("no_realize"); kwargs_values.push_back(LiteralDoc::Boolean(true, std::nullopt)); diff --git a/src/script/printer/tir/buffer.cc b/src/script/printer/tir/buffer.cc index 0e7ae3a843cf..4057b1d09bfc 100644 --- a/src/script/printer/tir/buffer.cc +++ b/src/script/printer/tir/buffer.cc @@ -24,13 +24,14 @@ namespace tvm { namespace script { namespace printer { -Map BufferAttrs(tir::Buffer buffer, const AccessPath& buffer_p, const Frame& frame, - const IRDocsifier& d, BufferVarDefinition var_definitions) { +ffi::Map BufferAttrs(tir::Buffer buffer, const AccessPath& buffer_p, + const Frame& frame, const IRDocsifier& d, + BufferVarDefinition var_definitions) { using tvm::tir::Var; using tvm::tir::VarNode; - Map kwargs; - Array var_def_lhs; - Array var_def_rhs; + ffi::Map kwargs; + ffi::Array var_def_lhs; + ffi::Array var_def_rhs; // Step 0. Set up statistics std::unordered_map use_count; @@ -73,10 +74,10 @@ Map BufferAttrs(tir::Buffer buffer, const AccessPath& buffer_p, }; // Step 1. Handle `buffer.shape` { - const Array& shape = buffer->shape; + const ffi::Array& shape = buffer->shape; AccessPath shape_p = buffer_p->Attr("shape"); int n = shape.size(); - Array results; + ffi::Array results; results.reserve(n); for (int i = 0; i < n; ++i) { PrimExpr e = shape[i]; @@ -108,10 +109,10 @@ Map BufferAttrs(tir::Buffer buffer, const AccessPath& buffer_p, } // Step 4. Handle `buffer.strides` if (!buffer->strides.empty()) { - const Array& strides = buffer->strides; + const ffi::Array& strides = buffer->strides; AccessPath strides_p = buffer_p->Attr("strides"); int n = strides.size(); - Array results; + ffi::Array results; results.reserve(n); for (int i = 0; i < n; ++i) { PrimExpr e = strides[i]; @@ -148,7 +149,7 @@ Map BufferAttrs(tir::Buffer buffer, const AccessPath& buffer_p, } // Step 6. 
Handle `buffer.scope` { - String scope = buffer.scope(); + ffi::String scope = buffer.scope(); if (scope != "global") { kwargs.Set( "scope", @@ -182,17 +183,18 @@ Map BufferAttrs(tir::Buffer buffer, const AccessPath& buffer_p, return kwargs; } -ExprDoc BufferCall(const ExprDoc& prefix, const Map& attrs, Array args) { - Array kwargs_keys; - Array kwargs_values; - for (String s : {"shape", "dtype"}) { - if (Optional doc = attrs.Get(s)) { +ExprDoc BufferCall(const ExprDoc& prefix, const ffi::Map& attrs, + ffi::Array args) { + ffi::Array kwargs_keys; + ffi::Array kwargs_values; + for (ffi::String s : {"shape", "dtype"}) { + if (ffi::Optional doc = attrs.Get(s)) { args.push_back(doc.value()); } } - for (String s : {"data", "strides", "elem_offset", "scope", "align", "offset_factor", - "buffer_type", "axis_separators"}) { - if (Optional doc = attrs.Get(s)) { + for (ffi::String s : {"data", "strides", "elem_offset", "scope", "align", "offset_factor", + "buffer_type", "axis_separators"}) { + if (ffi::Optional doc = attrs.Get(s)) { kwargs_keys.push_back(s); kwargs_values.push_back(doc.value()); } @@ -200,9 +202,9 @@ ExprDoc BufferCall(const ExprDoc& prefix, const Map& attrs, Arr return prefix->Call(args, kwargs_keys, kwargs_values); } -ExprDoc BufferDecl(const tir::Buffer& buffer, const String& method, const Array& args, - const AccessPath& p, const Frame& frame, const IRDocsifier& d, - BufferVarDefinition var_definitions) { +ExprDoc BufferDecl(const tir::Buffer& buffer, const ffi::String& method, + const ffi::Array& args, const AccessPath& p, const Frame& frame, + const IRDocsifier& d, BufferVarDefinition var_definitions) { return BufferCall(/*prefix=*/TIR(d, method), /*attrs=*/BufferAttrs(buffer, p, frame, d, var_definitions), /*args=*/args); @@ -210,17 +212,18 @@ ExprDoc BufferDecl(const tir::Buffer& buffer, const String& method, const Array< ExprDoc BufferAttn(const tir::Buffer& buffer, const AccessPath& p, const Frame& frame, const IRDocsifier& d) { - Map attrs = BufferAttrs(buffer, p, frame, d, BufferVarDefinition::DataPointer); + ffi::Map attrs = + BufferAttrs(buffer, p, frame, d, BufferVarDefinition::DataPointer); ExprDoc shape = attrs.Get("shape").value(); ExprDoc dtype = attrs.Get("dtype").value_or(LiteralDoc::DataType(buffer->dtype, p->Attr("dtype"))); return TIR(d, "Buffer")->Call({shape, dtype}, {}, {}); } -Array BufferIndices(const Array& indices, const AccessPath& p, - const IRDocsifier& d) { +ffi::Array BufferIndices(const ffi::Array& indices, const AccessPath& p, + const IRDocsifier& d) { int n = indices.size(); - Array indices_doc; + ffi::Array indices_doc; indices_doc.reserve(n); for (int i = 0; i < n; ++i) { if (const auto* ramp = indices[i].as()) { @@ -231,7 +234,7 @@ Array BufferIndices(const Array& indices, const AccessPath& p, ramp_p->Attr("base")); ExprDoc stop = d->AsDoc(ramp->base + ramp->lanes * ramp->stride, // ramp_p->Attr("lanes")); - Optional step = std::nullopt; + ffi::Optional step = std::nullopt; if (stride->value != 1) { step = d->AsDoc(ramp->stride, ramp_p->Attr("stride")); } @@ -244,9 +247,10 @@ Array BufferIndices(const Array& indices, const AccessPath& p, return indices_doc; } -Array BufferSlices(const Array& region, const AccessPath& p, const IRDocsifier& d) { +ffi::Array BufferSlices(const ffi::Array& region, const AccessPath& p, + const IRDocsifier& d) { int n = region.size(); - Array indices; + ffi::Array indices; indices.reserve(n); for (int i = 0; i < n; ++i) { Range range = region[i]; @@ -306,14 +310,14 @@ TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable) 
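A compact restatement of the lookup idiom in BufferCall above, with the key list shortened for illustration: ffi::Map::Get returns an ffi::Optional, so a single probe both tests for presence and fetches the value, and visiting the keys in a fixed order keeps the printed kwargs stable.

void CollectBufferKwargs(const ffi::Map<ffi::String, ExprDoc>& attrs,
                         ffi::Array<ffi::String>* kwargs_keys,
                         ffi::Array<ExprDoc>* kwargs_values) {
  for (ffi::String s : {"data", "strides", "elem_offset", "scope"}) {
    if (ffi::Optional<ExprDoc> doc = attrs.Get(s)) {
      kwargs_keys->push_back(s);
      kwargs_values->push_back(doc.value());
    }
  }
}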
TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable) // .set_dispatch("", [](tir::Buffer buffer, AccessPath p, IRDocsifier d) -> Doc { if (!d->IsVarDefined(buffer)) { - if (Optional opt_f = FindLowestVarDef(buffer, d)) { + if (ffi::Optional opt_f = FindLowestVarDef(buffer, d)) { ExprDoc lhs = DefineBuffer(buffer, opt_f.value(), d); ExprDoc rhs = BufferDecl(buffer, "Buffer", {}, p, opt_f.value(), d, BufferVarDefinition::DataPointer); opt_f.value()->stmts.push_back(AssignDoc(lhs, rhs, std::nullopt)); } } - if (Optional doc = d->GetVarDoc(buffer)) { + if (ffi::Optional doc = d->GetVarDoc(buffer)) { return doc.value(); } LOG(FATAL) << "IndexError: Buffer is not defined in the environment: " << buffer; diff --git a/src/script/printer/tir/expr.cc b/src/script/printer/tir/expr.cc index 78b52edf859c..ddcf1b64f1a1 100644 --- a/src/script/printer/tir/expr.cc +++ b/src/script/printer/tir/expr.cc @@ -28,8 +28,8 @@ ExprDoc PrintVarCreation(const tir::Var& var, const AccessPath& var_p, const IRD Type type = var->type_annotation; AccessPath type_p = var_p->Attr("type_annotation"); ExprDoc rhs{nullptr}; - Array kwargs_keys; - Array kwargs_values; + ffi::Array kwargs_keys; + ffi::Array kwargs_values; if (var->IsInstance()) { kwargs_keys.push_back("is_size_var"); @@ -66,7 +66,7 @@ ExprDoc PrintVarCreation(const tir::Var& var, const AccessPath& var_p, const IRD Doc PrintVar(const tir::Var& var, const AccessPath& var_p, const IRDocsifier& d) { if (!d->IsVarDefined(var)) { - if (Optional opt_f = FindLowestVarDef(var, d)) { + if (ffi::Optional opt_f = FindLowestVarDef(var, d)) { ExprDoc lhs = DefineVar(var, opt_f.value(), d); ExprDoc rhs = PrintVarCreation(var, var_p, d); opt_f.value()->stmts.push_back(AssignDoc(lhs, rhs, std::nullopt)); @@ -74,7 +74,7 @@ Doc PrintVar(const tir::Var& var, const AccessPath& var_p, const IRDocsifier& d) LOG(WARNING) << "Didn't find variable definition for: " << var->name_hint; } } - if (Optional doc = d->GetVarDoc(var)) { + if (ffi::Optional doc = d->GetVarDoc(var)) { return doc.value(); } LOG(FATAL) << "IndexError: Variable is not defined in the environment: " << var->name_hint; @@ -173,7 +173,7 @@ TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable) { With f(d, r); int n_vars = r->lhs.size(); - Array vars; + ffi::Array vars; vars.reserve(n_vars + n_vars); for (int i = 0; i < n_vars; ++i) { vars.push_back(Downcast(DefineVar(r->lhs[i], *f, d))); @@ -182,7 +182,7 @@ TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable) vars.push_back(Downcast(DefineVar(r->rhs[i], *f, d))); } int n_results = r->result.size(); - Array results; + ffi::Array results; results.reserve(n_results); for (int i = 0; i < n_results; ++i) { results.push_back(d->AsDoc(r->result[i], p->Attr("result")->ArrayItem(i))); @@ -197,14 +197,15 @@ TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable) return TIR(d, "comm_reducer")->Call({lambda, id}); }); -LambdaDoc PrintIndexMap(const ObjectRef& map, const Array& vs, const AccessPath& vs_p, - const Array& es, const AccessPath& es_p, const IRDocsifier& d) { +LambdaDoc PrintIndexMap(const ObjectRef& map, const ffi::Array& vs, + const AccessPath& vs_p, const ffi::Array& es, + const AccessPath& es_p, const IRDocsifier& d) { With f(d, map); - Array vars; + ffi::Array vars; for (int i = 0, l = vs.size(); i < l; ++i) { vars.push_back(Downcast(DefineVar(vs[i], *f, d))); } - Array exprs; + ffi::Array exprs; for (int i = 0, l = es.size(); i < l; ++i) { exprs.push_back(d->AsDoc(es[i], es_p->ArrayItem(i))); } @@ -246,7 +247,7 @@ TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable) ExprDoc prefix{nullptr}; if (auto optional_op = 
call->op.as()) { auto op = optional_op.value(); - String name = op_names.get(op, op->name); + ffi::String name = op_names.get(op, op->name); if (op_names.count(op) == 0) { LOG(WARNING) << "No TScriptPrinterName attribute for " << op->name; } @@ -261,7 +262,7 @@ TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable) auto f_llvm_lookup_intrinsic_name = tvm::ffi::Function::GetGlobal("target.llvm_get_intrinsic_name"); - Array args; + ffi::Array args; args.reserve(n_args + 1); if (dtype_print_location == tir::ScriptDtypePrintLocation::kFirst) { args.push_back(LiteralDoc::DataType(call->dtype, call_p->Attr("dtype"))); @@ -269,7 +270,7 @@ TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable) for (int i = 0; i < n_args; ++i) { if ((i == 0) && (f_llvm_lookup_intrinsic_name)) { - String name = (*f_llvm_lookup_intrinsic_name)(id).cast(); + ffi::String name = (*f_llvm_lookup_intrinsic_name)(id).cast(); args.push_back(LiteralDoc::Str(name.c_str(), call_p->Attr("args")->ArrayItem(i))); } else { args.push_back(d->AsDoc(call->args[i], call_p->Attr("args")->ArrayItem(i))); @@ -285,7 +286,7 @@ TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable) } else { LOG(FATAL) << "call: " << call; } - Array args; + ffi::Array args; int n_args = call->args.size(); args.reserve(n_args + 1); if (dtype_print_location == tir::ScriptDtypePrintLocation::kFirst) { diff --git a/src/script/printer/tir/for_loop.cc b/src/script/printer/tir/for_loop.cc index bfdae3b14221..10bb6f756df2 100644 --- a/src/script/printer/tir/for_loop.cc +++ b/src/script/printer/tir/for_loop.cc @@ -50,8 +50,8 @@ TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable) // Step 2. Construct `T.grid` if (grid.size() > 1) { int n = grid.size(); - Array lhs; - Array rhs; + ffi::Array lhs; + ffi::Array rhs; lhs.reserve(n); rhs.reserve(n); for (int i = 0; i < n; ++i) { @@ -65,10 +65,10 @@ TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable) } // Step 3. If not `T.grid`, print loop kind accordingly ExprDoc lhs = DefineVar(loop->loop_var, *f, d); - Optional min = std::nullopt; - Optional max = std::nullopt; - Optional annotations = std::nullopt; - Optional thread = std::nullopt; + ffi::Optional min = std::nullopt; + ffi::Optional max = std::nullopt; + ffi::Optional annotations = std::nullopt; + ffi::Optional thread = std::nullopt; if (tir::is_zero(loop->min)) { max = d->AsDoc(loop->extent, loop_p->Attr("extent")); } else { @@ -98,9 +98,9 @@ TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable) } else { LOG(FATAL) << "ValueError: Unknown ForKind: " << tir::ForKind2String(loop->kind); } - Array args; - Array kwargs_keys; - Array kwargs_values; + ffi::Array args; + ffi::Array kwargs_keys; + ffi::Array kwargs_values; if (min.defined()) { args.push_back(min.value()); } diff --git a/src/script/printer/tir/function.cc b/src/script/printer/tir/function.cc index 688c58e6de09..c5083b57c2d0 100644 --- a/src/script/printer/tir/function.cc +++ b/src/script/printer/tir/function.cc @@ -82,7 +82,7 @@ TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable) ++buffer_data_counter.at(data_var); } // Step 1. 
Handle `func->params` - Array args; + ffi::Array args; args.reserve(n_args); std::unordered_set buffer_inlined; for (int i = 0; i < n_args; ++i) { @@ -107,8 +107,9 @@ TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable) if (func->attrs.defined() && !func->attrs->dict.empty()) { // for global symbol, don't display it if it matches the func name if (func->attrs->dict.count(tvm::attr::kGlobalSymbol) && - Downcast(func->attrs->dict.at(tvm::attr::kGlobalSymbol)) == func_name->name) { - Map new_attrs; + Downcast(func->attrs->dict.at(tvm::attr::kGlobalSymbol)) == + func_name->name) { + ffi::Map new_attrs; for (auto kv : func->attrs->dict) { if (kv.first != tvm::attr::kGlobalSymbol) { new_attrs.Set(kv.first, kv.second); @@ -142,7 +143,7 @@ TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable) } } // Step 4. Handle `func->body` - Optional implicit_root_block = [&]() -> Optional { + ffi::Optional implicit_root_block = [&]() -> ffi::Optional { const tir::BlockRealizeNode* root_block_realize = func->body.as(); if (root_block_realize && !root_block_realize->iter_values.size() && tir::is_one(root_block_realize->predicate)) { @@ -178,7 +179,7 @@ TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable) } else { AsDocBody(func->body, p->Attr("body"), f->get(), d); } - Optional ret_type = std::nullopt; + ffi::Optional ret_type = std::nullopt; if (func->ret_type.defined()) { const auto* as_tuple = func->ret_type.as(); if (!as_tuple || as_tuple->fields.size()) { @@ -189,9 +190,9 @@ TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable) ExprDoc decorator = TIR(d, "prim_func"); // mark private if there is no global symbol if (!func->attrs.defined() || !func->attrs->dict.count(tvm::attr::kGlobalSymbol)) { - Array pos_args; + ffi::Array pos_args; decorator = decorator->Call(pos_args, {"private"}, - {LiteralDoc::Boolean(true, Optional())}); + {LiteralDoc::Boolean(true, ffi::Optional())}); } return HeaderWrapper(d, FunctionDoc( @@ -207,7 +208,7 @@ TVM_SCRIPT_REPR(tir::PrimFuncNode, ReprPrintTIR); TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable) .set_dispatch( // "tir", [](tvm::GlobalVar n, AccessPath n_p, IRDocsifier d) -> Doc { // - if (Optional doc = d->GetVarDoc(n)) { + if (ffi::Optional doc = d->GetVarDoc(n)) { return doc.value(); } else { IdDoc ret(n->name_hint); @@ -219,7 +220,7 @@ TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable) TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable) .set_dispatch( // "tir", [](tvm::IRModule mod, AccessPath n_p, IRDocsifier d) -> Doc { // - Optional doc = d->GetVarDoc(mod); + ffi::Optional doc = d->GetVarDoc(mod); ICHECK(doc) << "Unable to print IRModule before definition in TIR."; return doc.value(); }); diff --git a/src/script/printer/tir/ir.cc b/src/script/printer/tir/ir.cc index a99d4236158f..0cd38d4c6a49 100644 --- a/src/script/printer/tir/ir.cc +++ b/src/script/printer/tir/ir.cc @@ -91,7 +91,7 @@ TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable) TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable) .set_dispatch("", [](Target target, AccessPath p, IRDocsifier d) -> Doc { - Map config = target->Export(); + ffi::Map config = target->Export(); return TIR(d, "target")->Call({d->AsDoc(config, p)}); }); diff --git a/src/script/printer/tir/stmt.cc b/src/script/printer/tir/stmt.cc index 14acff77bed8..228fbbc78556 100644 --- a/src/script/printer/tir/stmt.cc +++ b/src/script/printer/tir/stmt.cc @@ -23,8 +23,8 @@ namespace tvm { namespace script { namespace printer { -Doc DoConciseScoping(const Optional& lhs, const ExprDoc& rhs, Array* stmts, - bool concise_scoping) { +Doc DoConciseScoping(const ffi::Optional& lhs, const ExprDoc& rhs, + ffi::Array* stmts, bool 
concise_scoping) { if (concise_scoping) { if (lhs.defined()) { stmts->insert(stmts->begin(), AssignDoc(lhs.value(), rhs, std::nullopt)); @@ -64,7 +64,7 @@ bool IsAncestorOfAllVarUse(const tir::Stmt& node, const ObjectRef& var, const IR return false; } -Optional FindReturnValue(const tir::Stmt& node) { +ffi::Optional FindReturnValue(const tir::Stmt& node) { auto eval = node.as(); if (!eval) return std::nullopt; @@ -99,8 +99,8 @@ TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable) .set_dispatch("", [](tir::LetStmt stmt, AccessPath p, IRDocsifier d) -> Doc { bool concise = AllowConciseScoping(d, stmt); // Step 1. Type annotation - Optional type_doc = d->AsDoc(stmt->var->type_annotation, // - p->Attr("var")->Attr("type_annotation")); + ffi::Optional type_doc = d->AsDoc(stmt->var->type_annotation, // + p->Attr("var")->Attr("type_annotation")); if (const auto* tuple_type = stmt->var->type_annotation.as()) { if (tuple_type->fields.empty()) { type_doc = std::nullopt; @@ -110,7 +110,7 @@ TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable) ExprDoc rhs = d->AsDoc(stmt->value, p->Attr("value")); // Step 3. LHS and body With f(d, stmt); - Array* stmts = &(*f)->stmts; + ffi::Array* stmts = &(*f)->stmts; bool var_defined = d->IsVarDefined(stmt->var); if (!var_defined) { DefineVar(stmt->var, *f, d); @@ -139,7 +139,7 @@ TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable) With f(d, stmt); AsDocBody(stmt->body, p->Attr("body"), f->get(), d); if (concise) { - Array* stmts = &(*f)->stmts; + ffi::Array* stmts = &(*f)->stmts; stmts->insert(stmts->begin(), AssertDoc(cond, msg)); return StmtBlockDoc(*stmts); } @@ -177,8 +177,8 @@ TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable) .set_dispatch( // "", [](tir::IfThenElse stmt, AccessPath p, IRDocsifier d) -> Doc { ExprDoc cond = d->AsDoc(stmt->condition, p->Attr("condition")); - Array then_branch; - Array else_branch; + ffi::Array then_branch; + ffi::Array else_branch; if (stmt->then_case.defined()) { With f(d, stmt->then_case); AsDocBody(stmt->then_case, p->Attr("then_case"), f->get(), d); @@ -226,9 +226,9 @@ TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable) return DeclBufferDoc(Downcast(stmt->body), stmt_p->Attr("body"), d, BufferVarDefinition::DataPointer); } - Array args; - Array kwargs_keys; - Array kwargs_values; + ffi::Array args; + ffi::Array kwargs_keys; + ffi::Array kwargs_values; args.push_back(d->AsDoc(stmt->extents, stmt_p->Attr("extents"))); args.push_back(LiteralDoc::DataType(stmt->dtype, stmt_p->Attr("dtype"))); args.push_back(LiteralDoc::Str(tir::GetPtrStorageScope(stmt->buffer_var), @@ -260,7 +260,7 @@ ExprDoc PrintTensor(::tvm::runtime::Tensor arr) { for (int i = 0; i < ndim; i++) { tot_dim *= arr->shape[i]; } - Array result; + ffi::Array result; T* data_ptr = reinterpret_cast(arr->data); runtime::DataType dtype = arr.DataType(); for (int i = 0; i < tot_dim; i++) { @@ -280,10 +280,10 @@ TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable) .set_dispatch( "", [](tir::AllocateConst stmt, AccessPath stmt_p, IRDocsifier d) -> Doc { bool concise = AllowConciseScoping(d, stmt); - String storage_scope = tir::GetPtrStorageScope(stmt->buffer_var); - Array args; - Array kwargs_keys; - Array kwargs_values; + ffi::String storage_scope = tir::GetPtrStorageScope(stmt->buffer_var); + ffi::Array args; + ffi::Array kwargs_keys; + ffi::Array kwargs_values; ExprDoc data_doc{nullptr}; if (stmt->dtype.is_int()) { if (stmt->dtype.bits() == 8) { @@ -332,11 +332,11 @@ TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable) return DoConciseScoping(lhs, rhs, &(*f)->stmts, concise); }); -ExprDoc DocsifyBufferRealize(const 
tir::BufferRealizeNode* stmt, Optional value, // +ExprDoc DocsifyBufferRealize(const tir::BufferRealizeNode* stmt, ffi::Optional value, // AccessPath p, IRDocsifier d) { ExprDoc buffer = d->AsDoc(stmt->buffer, p->Attr("buffer")); { - Array bounds; + ffi::Array bounds; bounds.reserve(stmt->bounds.size()); for (int i = 0, n = stmt->bounds.size(); i < n; ++i) { Range range = stmt->bounds[i]; @@ -348,9 +348,9 @@ ExprDoc DocsifyBufferRealize(const tir::BufferRealizeNode* stmt, Optional args{buffer}; - Array kwargs_keys; - Array kwargs_values; + ffi::Array args{buffer}; + ffi::Array kwargs_keys; + ffi::Array kwargs_values; if (value.defined()) { args.push_back(value.value()); } @@ -373,7 +373,7 @@ void InsertEnvThread(const tir::IterVar& iter_var, const AccessPath& iter_var_p, } ExprDoc DocsifyLaunchThread(const tir::AttrStmt& attr_stmt, const AccessPath& attr_stmt_p, - Optional* define_var, const IRDocsifier& d) { + ffi::Optional* define_var, const IRDocsifier& d) { tir::IterVar iter_var = Downcast(attr_stmt->node); AccessPath iter_var_p = attr_stmt_p->Attr("node"); @@ -408,9 +408,9 @@ TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable) .set_dispatch( // "", [](tir::AttrStmt stmt, AccessPath stmt_p, IRDocsifier d) -> Doc { bool concise = AllowConciseScoping(d, stmt); - Optional lhs = std::nullopt; - Optional rhs = std::nullopt; - Optional define_var = std::nullopt; + ffi::Optional lhs = std::nullopt; + ffi::Optional rhs = std::nullopt; + ffi::Optional define_var = std::nullopt; tir::Stmt body = stmt->body; AccessPath body_p = stmt_p->Attr("body"); if (stmt->attr_key == "realize_scope") { diff --git a/src/script/printer/tir/utils.h b/src/script/printer/tir/utils.h index 4474a83ca8ff..1bbdf2e02d65 100644 --- a/src/script/printer/tir/utils.h +++ b/src/script/printer/tir/utils.h @@ -65,7 +65,7 @@ class TIRFrame : public Frame { public: /*! \brief Constructor */ explicit TIRFrame(const IRDocsifier& d, const ObjectRef& tir) { - ObjectPtr n = make_object(); + ObjectPtr n = ffi::make_object(); n->stmts.clear(); n->d = d.get(); n->tir = tir; @@ -84,7 +84,7 @@ class TIRFrame : public Frame { * \return The IdDoc corresponding to the variable */ inline ExprDoc DefineVar(const tir::Var& var, const Frame& frame, const IRDocsifier& d) { - if (Optional doc = d->GetVarDoc(var)) { + if (ffi::Optional doc = d->GetVarDoc(var)) { return doc.value(); } return d->Define(var, frame, var->name_hint.empty() ? 
"v" : var->name_hint); @@ -111,7 +111,7 @@ inline IdDoc DefineBuffer(const tir::Buffer& buffer, const Frame& frame, const I */ inline void AsDocBody(const tir::Stmt& stmt, AccessPath p, TIRFrameNode* f, const IRDocsifier& d) { if (const auto* seq_stmt = stmt.as()) { - Array body = seq_stmt->seq; + ffi::Array body = seq_stmt->seq; for (int i = 0, n = body.size(); i < n; ++i) { f->allow_concise_scoping = (i == n - 1); Doc doc = d->AsDoc(body[i], p->Attr("seq")->ArrayItem(i)); @@ -139,7 +139,7 @@ inline void AsDocBody(const tir::Stmt& stmt, AccessPath p, TIRFrameNode* f, cons * \param d The IRDocsifier * \return The frame that could place the var definition */ -inline Optional FindLowestVarDef(const ObjectRef& var, const IRDocsifier& d) { +inline ffi::Optional FindLowestVarDef(const ObjectRef& var, const IRDocsifier& d) { if (!d->common_prefix.count(var.get())) { return std::nullopt; } @@ -159,11 +159,11 @@ inline Optional FindLowestVarDef(const ObjectRef& var, const IRDocsifier& const std::vector& path = d->common_prefix.at(var.get()); for (auto it = path.rbegin(); it != path.rend(); ++it) { if (tir_to_frame.count(*it)) { - return GetRef(tir_to_frame.at(*it)); + return ffi::GetRef(tir_to_frame.at(*it)); } } if (fallback_frame != nullptr) { - return GetRef(fallback_frame); + return ffi::GetRef(fallback_frame); } return std::nullopt; } @@ -214,9 +214,9 @@ enum class BufferVarDefinition { * the buffer. * \return The ExprDoc corresponding to the buffer declaration */ -ExprDoc BufferDecl(const tir::Buffer& buffer, const String& method, const Array& args, - const AccessPath& p, const Frame& frame, const IRDocsifier& d, - BufferVarDefinition var_definitions); +ExprDoc BufferDecl(const tir::Buffer& buffer, const ffi::String& method, + const ffi::Array& args, const AccessPath& p, const Frame& frame, + const IRDocsifier& d, BufferVarDefinition var_definitions); /*! * \brief Declare and define a buffer as annotation diff --git a/src/script/printer/utils.h b/src/script/printer/utils.h index 1e3a258579a2..8e9b9cdf1049 100644 --- a/src/script/printer/utils.h +++ b/src/script/printer/utils.h @@ -56,9 +56,9 @@ inline std::string Docsify(const ObjectRef& obj, const IRDocsifier& d, const Fra if (!cfg->verbose_expr) { f->stmts.clear(); } - f->stmts.push_back(ExprStmtDoc(GetRef(expr_doc))); + f->stmts.push_back(ExprStmtDoc(ffi::GetRef(expr_doc))); } else if (const auto* stmt_doc = doc.as()) { - f->stmts.push_back(GetRef(stmt_doc)); + f->stmts.push_back(ffi::GetRef(stmt_doc)); } else if (const auto* stmt_block = doc.as()) { for (const StmtDoc& d : stmt_block->stmts) { f->stmts.push_back(d); @@ -72,8 +72,8 @@ inline std::string Docsify(const ObjectRef& obj, const IRDocsifier& d, const Fra if (d->cfg->show_meta) { os << "metadata = tvm.ir.load_json(\"\"\"" << support::StrEscape( - SaveJSON(Map(d->metadata.begin(), d->metadata.end())), false, - false) + SaveJSON(ffi::Map(d->metadata.begin(), d->metadata.end())), + false, false) << "\"\"\")\n"; } else { f->stmts.push_back( @@ -91,19 +91,19 @@ inline std::string Docsify(const ObjectRef& obj, const IRDocsifier& d, const Fra } /*! \brief Creates the IR common prefix, which is by default `I` */ -inline ExprDoc IR(const IRDocsifier& d, const String& attr) { +inline ExprDoc IR(const IRDocsifier& d, const ffi::String& attr) { d->ir_usage.insert("ir"); return IdDoc(d->cfg->ir_prefix)->Attr(attr); } /*! 
\brief Creates the TIR common prefix, which is by default `T` */ -inline ExprDoc TIR(const IRDocsifier& d, const String& attr) { +inline ExprDoc TIR(const IRDocsifier& d, const ffi::String& attr) { d->ir_usage.insert("tir"); return IdDoc(d->cfg->tir_prefix)->Attr(attr); } /*! \brief Creates the Relax common prefix, which is by default `R` */ -inline ExprDoc Relax(const IRDocsifier& d, const String& attr) { +inline ExprDoc Relax(const IRDocsifier& d, const ffi::String& attr) { d->ir_usage.insert("relax"); return IdDoc(d->cfg->relax_prefix)->Attr(attr); } @@ -115,7 +115,7 @@ inline std::string DType2Str(const runtime::DataType& dtype) { /*! \brief Add headers as comments to doc if needed */ inline Doc HeaderWrapper(const IRDocsifier& d, const Doc& doc) { if (d->ir_usage.size()) { - Array stmts; + ffi::Array stmts; if (d->ir_usage.count("ir")) { stmts.push_back(CommentDoc("from tvm.script import ir as " + d->cfg->ir_prefix)); } @@ -137,23 +137,23 @@ inline bool HasMultipleLines(const std::string& str) { return str.find_first_of('\n') != std::string::npos; } -inline Optional GetBindingName(const IRDocsifier& d) { - return d->cfg->binding_names.empty() ? Optional(std::nullopt) +inline ffi::Optional GetBindingName(const IRDocsifier& d) { + return d->cfg->binding_names.empty() ? ffi::Optional(std::nullopt) : d->cfg->binding_names.back(); } -inline Optional FindFunctionName(const IRDocsifier& d, const BaseFunc& f) { - if (Optional name = GetBindingName(d)) { +inline ffi::Optional FindFunctionName(const IRDocsifier& d, const BaseFunc& f) { + if (ffi::Optional name = GetBindingName(d)) { return name.value(); } - if (Optional sym = f->GetAttr(tvm::attr::kGlobalSymbol)) { + if (ffi::Optional sym = f->GetAttr(tvm::attr::kGlobalSymbol)) { return sym.value(); } return std::nullopt; } -inline String GenerateUniqueName(std::string name_hint, - const std::unordered_set& defined_names) { +inline ffi::String GenerateUniqueName(std::string name_hint, + const std::unordered_set& defined_names) { for (char& c : name_hint) { if (c != '_' && !std::isalnum(c)) { c = '_'; diff --git a/src/support/array.h b/src/support/array.h index f49439aeb3ff..6e2aeca3e11f 100644 --- a/src/support/array.h +++ b/src/support/array.h @@ -35,7 +35,7 @@ namespace support { * \return A boolean indicating if they are the same */ template -inline bool ArrayWithSameContent(const Array& a, const Array& b) { +inline bool ArrayWithSameContent(const ffi::Array& a, const ffi::Array& b) { if (a.size() != b.size()) { return false; } @@ -76,7 +76,7 @@ inline bool ArrayWithSameContent(const std::vector& a, const std::vector * \return The result vector */ template -inline std::vector AsVector(const Array& vec); +inline std::vector AsVector(const ffi::Array& vec); /*! * \brief Convert a std::vector to tvm::Array @@ -85,7 +85,7 @@ inline std::vector AsVector(const Array& vec); * \return The result Array */ template -inline Array AsArray(const std::vector& vec); +inline ffi::Array AsArray(const std::vector& vec); /*! 
* \brief Convert a tvm::Array to std::list @@ -93,7 +93,7 @@ inline Array AsArray(const std::vector& vec); * \return The result list */ template -inline std::list AsList(const Array& array) { +inline std::list AsList(const ffi::Array& array) { std::list list; for (const auto& v : array) list.push_back(v); return list; @@ -105,8 +105,8 @@ inline std::list AsList(const Array& array) { * \return The result list */ template -inline Array AsArray(const std::list& list) { - Array array; +inline ffi::Array AsArray(const std::list& list) { + ffi::Array array; for (const auto& v : list) array.push_back(v); return array; } @@ -116,8 +116,8 @@ inline Array AsArray(const std::list& list) { * \param shape The shape tuple * \return An array of the shape tuple */ -inline Array AsArray(const ffi::Shape& shape) { - Array result; +inline ffi::Array AsArray(const ffi::Shape& shape) { + ffi::Array result; result.reserve(shape->size); for (ffi::Shape::index_type i : shape) { result.push_back(Integer(i)); @@ -134,12 +134,12 @@ inline Array AsArray(const ffi::Shape& shape) { * \return The concatenated array */ template -inline Array ConcatArrayList(Iterator begin, Iterator end) { +inline ffi::Array ConcatArrayList(Iterator begin, Iterator end) { int size = 0; for (Iterator it = begin; it != end; ++it) { size += (*it).size(); } - Array result; + ffi::Array result; result.reserve(size); for (Iterator it = begin; it != end; ++it) { const auto& item = *it; @@ -157,17 +157,17 @@ struct AsVectorImpl {}; template struct AsVectorImpl { - inline std::vector operator()(const Array& vec) const { + inline std::vector operator()(const ffi::Array& vec) const { return std::vector(vec.begin(), vec.end()); } }; template struct AsVectorImpl { - inline std::vector operator()(const Array& array) const { + inline std::vector operator()(const ffi::Array& array) const { ffi::Any ret_value; ret_value = array; - Array as_int_vec = ret_value.cast>(); + ffi::Array as_int_vec = ret_value.cast>(); std::vector results; for (const auto& value : as_int_vec) { @@ -179,10 +179,10 @@ struct AsVectorImpl { template struct AsVectorImpl { - inline std::vector operator()(const Array& array) const { + inline std::vector operator()(const ffi::Array& array) const { ffi::Any ret_value; ret_value = array; - Array as_int_vec = ret_value.cast>(); + ffi::Array as_int_vec = ret_value.cast>(); std::vector results; for (const auto& value : as_int_vec) { @@ -194,10 +194,10 @@ struct AsVectorImpl { template struct AsVectorImpl { - inline std::vector operator()(const Array& array) const { + inline std::vector operator()(const ffi::Array& array) const { ffi::Any ret_value; ret_value = array; - Array as_int_vec = ret_value.cast>(); + ffi::Array as_int_vec = ret_value.cast>(); std::vector results; for (const auto& value : as_int_vec) { @@ -217,15 +217,15 @@ struct AsArrayImpl {}; template struct AsArrayImpl { - inline Array operator()(const std::vector& vec) const { - return Array(vec.begin(), vec.end()); + inline ffi::Array operator()(const std::vector& vec) const { + return ffi::Array(vec.begin(), vec.end()); } }; template struct AsArrayImpl { - inline Array operator()(const std::vector& vec) const { - Array result; + inline ffi::Array operator()(const std::vector& vec) const { + ffi::Array result; result.reserve(vec.size()); for (auto x : vec) { ffi::Any ret_value; @@ -238,8 +238,8 @@ struct AsArrayImpl { template struct AsArrayImpl { - inline Array operator()(const std::vector& vec) const { - Array result; + inline ffi::Array operator()(const std::vector& vec) 
const { + ffi::Array result; result.reserve(vec.size()); for (auto x : vec) { ffi::Any ret_value; @@ -252,8 +252,8 @@ struct AsArrayImpl { template struct AsArrayImpl { - inline Array operator()(const std::vector& vec) const { - Array result; + inline ffi::Array operator()(const std::vector& vec) const { + ffi::Array result; result.reserve(vec.size()); for (auto x : vec) { ffi::Any ret_value; @@ -267,12 +267,12 @@ struct AsArrayImpl { } // namespace details template -inline std::vector AsVector(const Array& vec) { +inline std::vector AsVector(const ffi::Array& vec) { return details::AsVectorImpl()(vec); } template -inline Array AsArray(const std::vector& vec) { +inline ffi::Array AsArray(const std::vector& vec) { return details::AsArrayImpl()(vec); } diff --git a/src/support/ffi_testing.cc b/src/support/ffi_testing.cc index 70c23c546bbb..9f4d03416332 100644 --- a/src/support/ffi_testing.cc +++ b/src/support/ffi_testing.cc @@ -37,8 +37,8 @@ namespace tvm { // Attrs used to python API struct TestAttrs : public AttrsNodeReflAdapter { int axis; - String name; - Array padding; + ffi::String name; + ffi::Array padding; TypedEnvFunc func; static void RegisterReflection() { @@ -47,7 +47,7 @@ struct TestAttrs : public AttrsNodeReflAdapter { .def_ro("axis", &TestAttrs::axis, "axis field", refl::DefaultValue(10)) .def_ro("name", &TestAttrs::name, "name") .def_ro("padding", &TestAttrs::padding, "padding of input", - refl::DefaultValue(Array({0, 0}))) + refl::DefaultValue(ffi::Array({0, 0}))) .def_ro("func", &TestAttrs::func, "some random env function", refl::DefaultValue(TypedEnvFunc(nullptr))); } @@ -129,7 +129,7 @@ class FrontendTestModuleNode : public ffi::ModuleObj { static constexpr const char* kAddFunctionName = "__add_function"; - virtual ffi::Optional GetFunction(const String& name); + virtual ffi::Optional GetFunction(const ffi::String& name); private: std::unordered_map functions_; @@ -137,8 +137,8 @@ class FrontendTestModuleNode : public ffi::ModuleObj { constexpr const char* FrontendTestModuleNode::kAddFunctionName; -ffi::Optional FrontendTestModuleNode::GetFunction(const String& name) { - ffi::Module self_strong_ref = GetRef(this); +ffi::Optional FrontendTestModuleNode::GetFunction(const ffi::String& name) { + ffi::Module self_strong_ref = ffi::GetRef(this); if (name == kAddFunctionName) { return ffi::Function::FromTyped( [this, self_strong_ref](std::string func_name, ffi::Function pf) { @@ -157,7 +157,7 @@ ffi::Optional FrontendTestModuleNode::GetFunction(const String& n } ffi::Module NewFrontendTestModule() { - auto n = make_object(); + auto n = ffi::make_object(); return ffi::Module(n); } @@ -172,16 +172,16 @@ TVM_FFI_STATIC_INIT_BLOCK({ std::this_thread::sleep_for(duration); }) .def("testing.ReturnsVariant", - [](int x) -> Variant { + [](int x) -> ffi::Variant { if (x % 2 == 0) { return IntImm(DataType::Int(64), x / 2); } else { - return String("argument was odd"); + return ffi::String("argument was odd"); } }) .def("testing.AcceptsVariant", - [](Variant arg) -> String { - if (auto opt_str = arg.as()) { + [](ffi::Variant arg) -> ffi::String { + if (auto opt_str = arg.as()) { return ffi::StaticTypeKey::kTVMFFIStr; } else { return arg.get().GetTypeKey(); @@ -189,13 +189,13 @@ TVM_FFI_STATIC_INIT_BLOCK({ }) .def("testing.AcceptsBool", [](bool arg) -> bool { return arg; }) .def("testing.AcceptsInt", [](int arg) -> int { return arg; }) - .def("testing.AcceptsObjectRefArray", [](Array arg) -> Any { return arg[0]; }) + .def("testing.AcceptsObjectRefArray", [](ffi::Array arg) -> Any { return 
arg[0]; }) .def("testing.AcceptsMapReturnsValue", - [](Map map, Any key) -> Any { return map[key]; }) - .def("testing.AcceptsMapReturnsMap", [](Map map) -> ObjectRef { return map; }) + [](ffi::Map map, Any key) -> Any { return map[key]; }) + .def("testing.AcceptsMapReturnsMap", [](ffi::Map map) -> ObjectRef { return map; }) .def("testing.AcceptsPrimExpr", [](PrimExpr expr) -> ObjectRef { return expr; }) .def("testing.AcceptsArrayOfPrimExpr", - [](Array arr) -> ObjectRef { + [](ffi::Array arr) -> ObjectRef { for (ObjectRef item : arr) { CHECK(item->IsInstance()) << "Array contained " << item->GetTypeKey() << " when it should contain PrimExpr"; @@ -203,14 +203,14 @@ TVM_FFI_STATIC_INIT_BLOCK({ return arr; }) .def("testing.AcceptsArrayOfVariant", - [](Array> arr) -> ObjectRef { + [](ffi::Array> arr) -> ObjectRef { for (auto item : arr) { CHECK(item.as() || item.as()) << "Array should contain either PrimExpr or ffi::Function"; } return arr; }) - .def("testing.AcceptsMapOfPrimExpr", [](Map map) -> ObjectRef { + .def("testing.AcceptsMapOfPrimExpr", [](ffi::Map map) -> ObjectRef { for (const auto& kv : map) { ObjectRef value = kv.second; CHECK(value->IsInstance()) @@ -226,7 +226,7 @@ TVM_FFI_STATIC_INIT_BLOCK({ class TestingEventLogger { public: struct Entry { - String event; + ffi::String event; double time_us; }; @@ -235,7 +235,7 @@ class TestingEventLogger { start_ = std::chrono::high_resolution_clock::now(); } - void Record(String event) { + void Record(ffi::String event) { auto tend = std::chrono::high_resolution_clock::now(); double time_us = static_cast((tend - start_).count()) / 1e3; entries_.emplace_back(Entry{event, time_us}); @@ -264,8 +264,8 @@ TVM_FFI_STATIC_INIT_BLOCK({ refl::GlobalDef() .def_packed("testing.record_event", [](ffi::PackedArgs args, ffi::Any* rv) { - if (args.size() != 0 && args[0].try_cast()) { - TestingEventLogger::ThreadLocal()->Record(args[0].cast()); + if (args.size() != 0 && args[0].try_cast()) { + TestingEventLogger::ThreadLocal()->Record(args[0].cast()); } else { TestingEventLogger::ThreadLocal()->Record("X"); } diff --git a/src/support/nd_int_set.h b/src/support/nd_int_set.h index ae4a0386d404..f63aaf92faca 100644 --- a/src/support/nd_int_set.h +++ b/src/support/nd_int_set.h @@ -50,7 +50,7 @@ inline NDIntSet NDIntSetFromRegion(const tir::Region& region) { * \param shape The shape which is an array of the length of each dimension. * \return The constructed set. */ -inline NDIntSet NDIntSetFromShape(const Array& shape) { +inline NDIntSet NDIntSetFromShape(const ffi::Array& shape) { PrimExpr zero = Integer(0); NDIntSet result; result.reserve(shape.size()); @@ -65,7 +65,7 @@ inline NDIntSet NDIntSetFromShape(const Array& shape) { * \param indices The N-dimensional indices representing the point. * \return The constructed set. 
*/ -inline NDIntSet NDIntSetFromPoint(const Array& indices) { +inline NDIntSet NDIntSetFromPoint(const ffi::Array& indices) { NDIntSet result; result.reserve(indices.size()); for (const PrimExpr& index : indices) { @@ -106,7 +106,7 @@ inline NDIntSet NDIntSetUnion(const std::vector& nd_int_sets) { } NDIntSet result; result.reserve(ndim); - Array int_sets(n, arith::IntSet{nullptr}); + ffi::Array int_sets(n, arith::IntSet{nullptr}); for (int dim = 0; dim < ndim; ++dim) { for (int i = 0; i < n; ++i) { int_sets.Set(i, nd_int_sets[i][dim]); diff --git a/src/target/build_common.h b/src/target/build_common.h index 9e52f6f8ffa6..cf1e3344fc3c 100644 --- a/src/target/build_common.h +++ b/src/target/build_common.h @@ -60,12 +60,12 @@ inline std::unordered_map ExtractFuncInfo(co ? runtime::FunctionInfo::ArgExtraTags::kTensorMap : runtime::FunctionInfo::ArgExtraTags::kNone); } - if (auto opt = f->GetAttr>(tir::attr::kKernelLaunchParams)) { + if (auto opt = f->GetAttr>(tir::attr::kKernelLaunchParams)) { for (const auto& tag : opt.value()) { info.launch_param_tags.push_back(tag); } } - auto global_symbol = f->GetAttr(tvm::attr::kGlobalSymbol); + auto global_symbol = f->GetAttr(tvm::attr::kGlobalSymbol); if (global_symbol) { fmap[static_cast(global_symbol.value())] = info; } diff --git a/src/target/intrin_rule.h b/src/target/intrin_rule.h index ac45476f7702..5b6b0e107c02 100644 --- a/src/target/intrin_rule.h +++ b/src/target/intrin_rule.h @@ -79,7 +79,7 @@ inline PrimExpr DispatchPureExtern(const PrimExpr& e) { name = T()(dtype, name.substr(4)); if (name.length() != 0) { - Array new_args = {StringImm(name)}; + ffi::Array new_args = {StringImm(name)}; for (auto arg : call->args) { new_args.push_back(arg); } diff --git a/src/target/llvm/codegen_aarch64.cc b/src/target/llvm/codegen_aarch64.cc index 7937f72bea43..545e90697c58 100644 --- a/src/target/llvm/codegen_aarch64.cc +++ b/src/target/llvm/codegen_aarch64.cc @@ -85,7 +85,7 @@ void CodeGenAArch64::VisitStmt_(const AttrStmtNode* op) { } const auto* attr_value = op->value.as(); - ICHECK(attr_value) << "Expect " << attr_key << " to have a String value but was " + ICHECK(attr_value) << "Expect " << attr_key << " to have a ffi::String value but was " << op->value->GetTypeKey(); std::string aarch64_attr_key = attr_key.substr(7); diff --git a/src/target/llvm/codegen_amdgpu.cc b/src/target/llvm/codegen_amdgpu.cc index 9439af440b82..8fd9dc210561 100644 --- a/src/target/llvm/codegen_amdgpu.cc +++ b/src/target/llvm/codegen_amdgpu.cc @@ -280,7 +280,7 @@ ffi::Module BuildAMDGPU(IRModule mod, Target target) { llvm::TargetMachine* tm = llvm_target->GetOrCreateTargetMachine(); auto fbitcode = tvm::ffi::Function::GetGlobalRequired("tvm_callback_rocm_bitcode_path"); - auto bitcode_files = fbitcode().cast>(); + auto bitcode_files = fbitcode().cast>(); for (auto& bitcode_path : bitcode_files) { std::unique_ptr mlib = llvm_instance.LoadIR(bitcode_path); diff --git a/src/target/llvm/codegen_arm.cc b/src/target/llvm/codegen_arm.cc index 3adcfc82bba8..c686e5fc38d4 100644 --- a/src/target/llvm/codegen_arm.cc +++ b/src/target/llvm/codegen_arm.cc @@ -75,7 +75,7 @@ PrimExpr CodeGenARM::ARMPopcount(const CallNode* call) { int total_size = call->dtype.bits() * call->dtype.lanes(); if (!call->dtype.is_fixed_length_vector() || call->dtype.bits() == 8 || (total_size != 128 && total_size != 64)) { - Array vcnt_args; + ffi::Array vcnt_args; vcnt_args.push_back(IntImm(DataType::UInt(32), ctpop_id)); vcnt_args.push_back(e); return tir::Call(call->dtype, builtin_call_llvm_pure_intrin_, 
vcnt_args); @@ -98,13 +98,13 @@ PrimExpr CodeGenARM::ARMPopcount(const CallNode* call) { // Popcount 8bit->8bit const CallNode* c0 = input8.as(); ICHECK(c0 != nullptr); - Array vcnt8_args; + ffi::Array vcnt8_args; vcnt8_args.push_back(IntImm(DataType::UInt(32), ctpop_id)); vcnt8_args.push_back(input8); PrimExpr vcnt8 = tir::Call(uint8_type, builtin_call_llvm_pure_intrin_, vcnt8_args); // Accumulation 8->16bit - Array vcnt16_args; + ffi::Array vcnt16_args; vcnt16_args.push_back(IntImm(DataType::UInt(32), vpaddlu_id)); vcnt16_args.push_back(vcnt8); PrimExpr vcnt16 = tir::Call(uint16_type, builtin_call_llvm_pure_intrin_, vcnt16_args); @@ -113,7 +113,7 @@ PrimExpr CodeGenARM::ARMPopcount(const CallNode* call) { } // Accumulation 16->32bit - Array vcnt32_args; + ffi::Array vcnt32_args; vcnt32_args.push_back(IntImm(DataType::UInt(32), vpaddlu_id)); vcnt32_args.push_back(vcnt16); PrimExpr vcnt32 = tir::Call(uint32_type, builtin_call_llvm_pure_intrin_, vcnt32_args); @@ -122,7 +122,7 @@ PrimExpr CodeGenARM::ARMPopcount(const CallNode* call) { } // Accumulation 32->64bit - Array vcnt64_args; + ffi::Array vcnt64_args; vcnt64_args.push_back(IntImm(DataType::UInt(32), vpaddlu_id)); vcnt64_args.push_back(vcnt32); return tir::Call(call->dtype, builtin_call_llvm_pure_intrin_, vcnt64_args); diff --git a/src/target/llvm/codegen_cpu.cc b/src/target/llvm/codegen_cpu.cc index 34e9e8381898..e9dbdeb0c23e 100644 --- a/src/target/llvm/codegen_cpu.cc +++ b/src/target/llvm/codegen_cpu.cc @@ -71,7 +71,7 @@ CodeGenCPU::CodeGenCPU() = default; CodeGenCPU::~CodeGenCPU() = default; void CodeGenCPU::Init(const std::string& module_name, LLVMTarget* llvm_target, - Optional system_lib_prefix, bool dynamic_lookup, + ffi::Optional system_lib_prefix, bool dynamic_lookup, bool target_c_runtime) { CodeGenLLVM::Init(module_name, llvm_target, system_lib_prefix, dynamic_lookup, target_c_runtime); system_lib_prefix_ = system_lib_prefix; @@ -175,7 +175,7 @@ void CodeGenCPU::Init(const std::string& module_name, LLVMTarget* llvm_target, } llvm::DISubprogram* CodeGenCPU::CreateDebugFunction(llvm::StringRef name, - const Array& param_types, + const ffi::Array& param_types, const Type& return_type) { #if TVM_LLVM_VERSION < 50 return nullptr; @@ -211,7 +211,7 @@ llvm::DISubprogram* CodeGenCPU::CreateDebugFunction(llvm::StringRef name, } llvm::DISubprogram* CodeGenCPU::CreateDebugFunction(const GlobalVar& gvar, const PrimFunc& func) { - std::string name = func->GetAttr(tvm::attr::kGlobalSymbol).value_or(gvar->name_hint); + std::string name = func->GetAttr(tvm::attr::kGlobalSymbol).value_or(gvar->name_hint); return CreateDebugFunction(name, func->params.Map(GetType), func->ret_type); } @@ -220,7 +220,7 @@ void CodeGenCPU::AddFunction(const GlobalVar& gvar, const PrimFunc& func) { EmitDebugLocation(func->span); CodeGenLLVM::AddFunction(gvar, func); if (f_tvm_register_system_symbol_ != nullptr) { - if (auto global_symbol = func->GetAttr(tvm::attr::kGlobalSymbol)) { + if (auto global_symbol = func->GetAttr(tvm::attr::kGlobalSymbol)) { export_system_symbols_.emplace_back( std::make_pair(global_symbol.value().operator std::string(), function_)); } @@ -390,8 +390,8 @@ CodeGenLLVM::TypedPointer CodeGenCPU::CreateStructRefPtr(DataType t, llvm::Value } } -llvm::Value* CodeGenCPU::CreateCallExtern(Type ret_type, String global_symbol, - const Array& args, bool skip_first_arg) { +llvm::Value* CodeGenCPU::CreateCallExtern(Type ret_type, ffi::String global_symbol, + const ffi::Array& args, bool skip_first_arg) { std::vector arg_values; for (size_t i = 
static_cast(skip_first_arg); i < args.size(); ++i) { arg_values.push_back(MakeValue(args[i])); @@ -531,7 +531,7 @@ void CodeGenCPU::CreateComputeScope(const AttrStmtNode* op) { // - Make sure the generated compute function is clearly separately(though it can get inlined) // - Set noalias on all the pointer arguments, some of them are loaded from ffi::PackedArgs. // This is easier than set the alias scope manually. - Array vargs = tir::UndefinedVars(op->body, {}); + ffi::Array vargs = tir::UndefinedVars(op->body, {}); std::vector arg_values; std::vector arg_types; for (Var v : vargs) { @@ -598,7 +598,7 @@ void CodeGenCPU::CreateComputeScope(const AttrStmtNode* op) { AddDebugInformation(fcompute, vargs.Map(GetType)); } -CodeGenLLVM::TypedPointer CodeGenCPU::PackClosureData(const Array& vfields, +CodeGenLLVM::TypedPointer CodeGenCPU::PackClosureData(const ffi::Array& vfields, uint64_t* num_bytes, std::string struct_name) { if (vfields.size() == 0) { @@ -624,7 +624,7 @@ CodeGenLLVM::TypedPointer CodeGenCPU::PackClosureData(const Array& vfields, return TypedPointer(ctype, cvalue); } -void CodeGenCPU::UnpackClosureData(TypedPointer cdata, const Array& vfields, +void CodeGenCPU::UnpackClosureData(TypedPointer cdata, const ffi::Array& vfields, std::unordered_map* vmap) { for (size_t i = 0; i < vfields.size(); ++i) { llvm::Type* field_type = cdata.type->getStructElementType(i); @@ -644,7 +644,7 @@ void CodeGenCPU::CreateParallelLaunch(const Stmt& body, int num_task, std::strin SetTargetAttributes(f); // allocate and setup the closure, call the closure. - Array vfields = tir::UndefinedVars(body, {}); + ffi::Array vfields = tir::UndefinedVars(body, {}); uint64_t nbytes; TypedPointer cdata = PackClosureData(vfields, &nbytes, "closure_" + name); #if TVM_LLVM_VERSION >= 90 @@ -720,7 +720,7 @@ void CodeGenCPU::CreateStaticInit(const std::string& init_fname, const Stmt& bod } // allocate and setup the closure, call the closure. 
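
The closure-packing hunks above repeat the patch's core mechanical change: container and string types that were previously reachable unqualified inside namespace `tvm` (`Array`, `String`, `Optional`, `Map`) are now spelled with an explicit `ffi::` qualifier. In isolation the idiom looks like the following minimal sketch, assuming only the public `tvm/ffi` container headers; the function and field names are invented for illustration and do not appear in the patch:

    #include <tvm/ffi/container/array.h>
    #include <tvm/ffi/string.h>

    // Hypothetical helper: collect field names for a packed closure.
    // Note the explicit tvm::ffi:: qualification on container and element alike.
    tvm::ffi::Array<tvm::ffi::String> ClosureFieldNames() {
      tvm::ffi::Array<tvm::ffi::String> fields;
      fields.reserve(2);             // reserve before repeated push_back, as the hunks do
      fields.push_back("handle");
      fields.push_back("num_bytes");
      return fields;                 // reference-counted handle, cheap to return by value
    }
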
uint64_t nbytes; - Array vfields = tir::UndefinedVars(body, {}); + ffi::Array vfields = tir::UndefinedVars(body, {}); TypedPointer cdata = PackClosureData(vfields, &nbytes); llvm::BasicBlock* init_end = CheckCallSuccess(builder_->CreateCall( finit, {gv, f, builder_->CreatePointerCast(cdata.addr, t_void_p_), ConstInt32(nbytes)})); @@ -830,7 +830,7 @@ llvm::Value* CodeGenCPU::GetPackedFuncHandle(const std::string& fname) { return phi; } -CodeGenCPU::PackedCall CodeGenCPU::MakeCallPackedLowered(const Array& args, +CodeGenCPU::PackedCall CodeGenCPU::MakeCallPackedLowered(const ffi::Array& args, const DataType& r_type, const int64_t begin, const int64_t end, bool use_env_lookup) { diff --git a/src/target/llvm/codegen_cpu.h b/src/target/llvm/codegen_cpu.h index f8c6b362badf..d5401b966220 100644 --- a/src/target/llvm/codegen_cpu.h +++ b/src/target/llvm/codegen_cpu.h @@ -65,7 +65,7 @@ class CodeGenCPU : public CodeGenLLVM { virtual ~CodeGenCPU(); void Init(const std::string& module_name, LLVMTarget* llvm_target, - Optional system_lib_prefix, bool dynamic_lookup, + ffi::Optional system_lib_prefix, bool dynamic_lookup, bool target_c_runtime) override; void AddFunction(const GlobalVar& gvar, const PrimFunc& f) override; void AddMainFunction(const std::string& entry_func_name) override; @@ -74,8 +74,8 @@ class CodeGenCPU : public CodeGenLLVM { void VisitStmt_(const AttrStmtNode* op) override; void VisitStmt_(const ForNode* op) override; llvm::Value* CreateIntrinsic(const CallNode* op) override; - llvm::Value* CreateCallExtern(Type ret_type, String global_symbol, const Array& args, - bool skip_first_arg) override; + llvm::Value* CreateCallExtern(Type ret_type, ffi::String global_symbol, + const ffi::Array& args, bool skip_first_arg) override; protected: void AddStartupFunction() final; @@ -122,10 +122,10 @@ class CodeGenCPU : public CodeGenLLVM { llvm::Value* RuntimeTVMParallelBarrier(); llvm::Value* CreateStaticHandle(); llvm::Value* GetPackedFuncHandle(const std::string& str); - TypedPointer PackClosureData(const Array& fields, uint64_t* num_bytes, + TypedPointer PackClosureData(const ffi::Array& fields, uint64_t* num_bytes, std::string struct_name = ""); TypedPointer CreateStructRefPtr(DataType t, llvm::Value* buffer, llvm::Value* index, int kind); - void UnpackClosureData(TypedPointer cdata, const Array& fields, + void UnpackClosureData(TypedPointer cdata, const ffi::Array& fields, std::unordered_map* vmap); // Make packed call. struct PackedCall { @@ -133,7 +133,7 @@ class CodeGenCPU : public CodeGenLLVM { llvm::Value* ret_type_index; llvm::BasicBlock* end_block; }; - PackedCall MakeCallPackedLowered(const Array& args, const DataType& r_type, + PackedCall MakeCallPackedLowered(const ffi::Array& args, const DataType& r_type, const int64_t begin, const int64_t end, bool use_string_lookup); // create call into tvm packed function. 
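
Several of the signatures in this header also move `Optional` under `ffi::`; the type keeps interoperating with `std::nullopt`, `has_value()`, and `value()` exactly as the call sites in this patch rely on. A minimal sketch of that behavior, assuming `<tvm/ffi/optional.h>` and `<tvm/ffi/string.h>` as the header paths; the default-prefix logic is illustrative rather than the codegen's:

    #include <optional>  // std::nullopt
    #include <string>

    #include <tvm/ffi/optional.h>
    #include <tvm/ffi/string.h>

    // ffi::Optional can be initialized from std::nullopt and queried with
    // has_value()/value(), mirroring the system_lib_prefix parameters above.
    std::string PrefixOrDefault(tvm::ffi::Optional<tvm::ffi::String> prefix) {
      if (!prefix.has_value()) return "__tvm_";    // invented fallback value
      return std::string(prefix.value());
    }

    int SmokeTest() {
      tvm::ffi::Optional<tvm::ffi::String> none = std::nullopt;
      return PrefixOrDefault(none) == "__tvm_" ? 0 : 1;
    }
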
llvm::Value* CreateCallPacked(const CallNode* op); @@ -151,7 +151,7 @@ class CodeGenCPU : public CodeGenLLVM { llvm::BasicBlock* CheckCallSuccess(llvm::Value* retcode); llvm::DISubprogram* CreateDebugFunction(const GlobalVar& gvar, const PrimFunc& f); - llvm::DISubprogram* CreateDebugFunction(llvm::StringRef name, const Array& param_types, + llvm::DISubprogram* CreateDebugFunction(llvm::StringRef name, const ffi::Array& param_types, const Type& return_type); // Context for injection lookup @@ -161,7 +161,7 @@ class CodeGenCPU : public CodeGenLLVM { llvm::GlobalVariable* gv_tvm_ffi_set_last_error_c_str_{nullptr}; llvm::GlobalVariable* gv_tvm_parallel_launch_{nullptr}; llvm::GlobalVariable* gv_tvm_parallel_barrier_{nullptr}; - std::unordered_map gv_func_map_; + std::unordered_map gv_func_map_; // context for direct dynamic lookup llvm::Function* f_tvm_ffi_func_call_{nullptr}; llvm::Function* f_tvm_get_func_from_env_{nullptr}; @@ -181,7 +181,7 @@ class CodeGenCPU : public CodeGenLLVM { bool target_c_runtime_; // The system lib prefix if it is not nullopt, then we should do // system lib registration with the given prefix. The prefix can be "" - Optional system_lib_prefix_; + ffi::Optional system_lib_prefix_; }; } // namespace codegen diff --git a/src/target/llvm/codegen_hexagon.cc b/src/target/llvm/codegen_hexagon.cc index 67fccd8b073a..55abd565ff99 100644 --- a/src/target/llvm/codegen_hexagon.cc +++ b/src/target/llvm/codegen_hexagon.cc @@ -71,7 +71,7 @@ namespace codegen { class CodeGenHexagon final : public CodeGenCPU { public: void Init(const std::string& module_name, LLVMTarget* llvm_target, - Optional system_lib_prefix, bool dynamic_lookup, + ffi::Optional system_lib_prefix, bool dynamic_lookup, bool target_c_runtime) override; void InitTarget() final; @@ -79,10 +79,10 @@ class CodeGenHexagon final : public CodeGenCPU { llvm::Value* VisitExpr_(const BufferLoadNode* op) override; llvm::Value* CreateIntrinsic(const CallNode* op) override; - llvm::Value* CreateCallExtern(Type ret_type, String global_symbol, const Array& args, - bool skip_first_arg) override; - llvm::Value* CreateCallExternQHL(Type ret_type, String global_symbol, const Array& args, - bool skip_first_arg); + llvm::Value* CreateCallExtern(Type ret_type, ffi::String global_symbol, + const ffi::Array& args, bool skip_first_arg) override; + llvm::Value* CreateCallExternQHL(Type ret_type, ffi::String global_symbol, + const ffi::Array& args, bool skip_first_arg); llvm::Module* GetModulePtr() const { return module_.get(); } @@ -105,7 +105,7 @@ class CodeGenHexagon final : public CodeGenCPU { bool IsQHLFunction(const std::string& func); - llvm::Value* VectorLookupLoad(Buffer buffer, DataType buffer_type, Array indices); + llvm::Value* VectorLookupLoad(Buffer buffer, DataType buffer_type, ffi::Array indices); llvm::Value* Intrinsic(llvm::Intrinsic::ID, llvm::ArrayRef args); std::vector fqhl_list_ = { "tvm_vect_qhmath_hvx_cos_ahf", "tvm_vect_qhmath_hvx_tanh_ahf", @@ -116,7 +116,7 @@ class CodeGenHexagon final : public CodeGenCPU { }; void CodeGenHexagon::Init(const std::string& module_name, LLVMTarget* llvm_target, - Optional system_lib_prefix, bool dynamic_lookup, + ffi::Optional system_lib_prefix, bool dynamic_lookup, bool target_c_runtime) { CodeGenCPU::Init(module_name, llvm_target, system_lib_prefix, dynamic_lookup, target_c_runtime); } @@ -149,8 +149,9 @@ void CodeGenHexagon::InitTarget() { CodeGenCPU::InitTarget(); } -llvm::Value* CodeGenHexagon::CreateCallExternQHL(Type ret_type, String global_symbol, - const Array& args, bool 
skip_first_arg) { +llvm::Value* CodeGenHexagon::CreateCallExternQHL(Type ret_type, ffi::String global_symbol, + const ffi::Array& args, + bool skip_first_arg) { int num_lanes = args[1].dtype().lanes(); int vector_length = native_vector_bits_ / args[1].dtype().bits(); num_lanes = ((num_lanes + vector_length - 1) / vector_length) * vector_length; @@ -184,8 +185,9 @@ bool CodeGenHexagon::IsQHLFunction(const std::string& func) { return std::find(fqhl_list_.begin(), fqhl_list_.end(), func) != fqhl_list_.end(); } -llvm::Value* CodeGenHexagon::CreateCallExtern(Type ret_type, String global_symbol, - const Array& args, bool skip_first_arg) { +llvm::Value* CodeGenHexagon::CreateCallExtern(Type ret_type, ffi::String global_symbol, + const ffi::Array& args, + bool skip_first_arg) { int num_lanes = args[1].dtype().lanes(); int vector_length = native_vector_bits_ / args[1].dtype().bits(); if (IsQHLFunction(global_symbol) && (num_lanes > vector_length)) @@ -328,7 +330,7 @@ llvm::Value* CodeGenHexagon::Intrinsic(llvm::Intrinsic::ID IntID, } llvm::Value* CodeGenHexagon::VectorLookupLoad(Buffer buffer, DataType buffer_type, - Array indices) { + ffi::Array indices) { PrimExpr index = indices[0]; if (!index.dtype().is_fixed_length_vector()) { return nullptr; @@ -453,8 +455,8 @@ ffi::Module BuildHexagon(IRModule mod, Target target) { return vec; }; std::string llvm_options_str = "llvm"; - if (const auto& llvm_options = target->GetAttr>("llvm-options")) { - for (const String& s : llvm_options.value()) llvm_options_str += "," + s; + if (const auto& llvm_options = target->GetAttr>("llvm-options")) { + for (const ffi::String& s : llvm_options.value()) llvm_options_str += "," + s; } // Postprocess the LLVM options string: replace '@' with '=', and ',' with ' '. for (int i = 0, e = llvm_options_str.size(); i != e; ++i) { @@ -494,7 +496,7 @@ ffi::Module BuildHexagon(IRModule mod, Target target) { } auto f = Downcast(kv.second); if (f->HasNonzeroAttr(tir::attr::kIsEntryFunc)) { - auto global_symbol = f->GetAttr(tvm::attr::kGlobalSymbol); + auto global_symbol = f->GetAttr(tvm::attr::kGlobalSymbol); ICHECK(global_symbol.has_value()); entry_func = global_symbol.value(); } @@ -572,10 +574,10 @@ ffi::Module BuildHexagon(IRModule mod, Target target) { ICHECK(f.has_value()) << "tvm.contrib.hexagon.link_shared does not to exist, " "do import tvm.contrib.hexagon"; - Array o_names = {StringImm(o_name)}; - Map extra_args; + ffi::Array o_names = {StringImm(o_name)}; + ffi::Map extra_args; if (target->attrs.count("mcpu")) { - std::string mcpu = Downcast(target->attrs.at("mcpu")); + std::string mcpu = Downcast(target->attrs.at("mcpu")); #if TVM_LLVM_VERSION >= 180 ICHECK(llvm::StringRef(mcpu).starts_with("hexagon")) #else diff --git a/src/target/llvm/codegen_llvm.cc b/src/target/llvm/codegen_llvm.cc index bb4a76bc19c9..ecbdf437608d 100644 --- a/src/target/llvm/codegen_llvm.cc +++ b/src/target/llvm/codegen_llvm.cc @@ -138,7 +138,7 @@ std::unique_ptr CodeGenLLVM::Create(LLVMTarget* llvm_target) { } void CodeGenLLVM::Init(const std::string& module_name, LLVMTarget* llvm_target, - Optional system_lib_prefix, bool dynamic_lookup, + ffi::Optional system_lib_prefix, bool dynamic_lookup, bool target_c_runtime) { llvm_target_ = llvm_target; llvm::LLVMContext* ctx = llvm_target_->GetContext(); @@ -240,7 +240,7 @@ void CodeGenLLVM::InitFuncState() { std::tuple CodeGenLLVM::GetLinkage( const GlobalVar& gvar, const PrimFunc& func) { - if (auto global_symbol = func->GetAttr(tvm::attr::kGlobalSymbol)) { + if (auto global_symbol = 
func->GetAttr(tvm::attr::kGlobalSymbol)) { return {global_symbol.value(), llvm::Function::ExternalLinkage}; } @@ -717,8 +717,8 @@ void CodeGenLLVM::GetAlignment(DataType t, const VarNode* buf_var, const PrimExp auto it = alloc_storage_info_.find(buf_var); if (it != alloc_storage_info_.end()) { const StorageInfo& info = it->second; - *p_native_bits = - NativeVectorBits(runtime::StorageScope::Create(GetPtrStorageScope(GetRef(buf_var)))); + *p_native_bits = NativeVectorBits( + runtime::StorageScope::Create(GetPtrStorageScope(ffi::GetRef(buf_var)))); max_align_bits = info.alignment * 8; } else { *p_native_bits = native_vector_bits_; @@ -1060,8 +1060,8 @@ llvm::Value* CodeGenLLVM::CreateLookupReturnAddress(unsigned int level) { return call; } -llvm::Value* CodeGenLLVM::CreateCallExtern(Type ret_type, String global_symbol, - const Array& args, bool skip_first_arg) { +llvm::Value* CodeGenLLVM::CreateCallExtern(Type ret_type, ffi::String global_symbol, + const ffi::Array& args, bool skip_first_arg) { std::vector arg_value; std::vector arg_type; for (size_t i = static_cast(skip_first_arg); i < args.size(); ++i) { @@ -1367,7 +1367,7 @@ llvm::Value* CodeGenLLVM::CreateIntrinsic(const CallNode* op) { arg_value.push_back(MakeValue(op->args[i])); arg_type.push_back(arg_value.back()->getType()); } - llvm::Type* return_type = GetLLVMType(GetRef(op)); + llvm::Type* return_type = GetLLVMType(ffi::GetRef(op)); llvm::Function* f = GetIntrinsicDecl(id, return_type, arg_type); ICHECK(f) << "Cannot find intrinsic declaration, possible type mismatch: " << llvmGetIntrinName(id); @@ -1406,7 +1406,7 @@ llvm::Value* CodeGenLLVM::CreateIntrinsic(const CallNode* op) { const BufferLoadNode* load = op->args[0].as(); ICHECK(op->args.size() == 1 && load); - Array indices = load->indices; + ffi::Array indices = load->indices; if (const RampNode* r = indices[indices.size() - 1].as()) { indices.Set(indices.size() - 1, r->base); } @@ -1697,7 +1697,8 @@ bool CodeGenLLVM::HasAlignmentPadding(DataType dtype) { } void CodeGenLLVM::BufferAccessHelper( - Buffer buffer, Array indices, Optional predicate, DataType value_dtype, + Buffer buffer, ffi::Array indices, ffi::Optional predicate, + DataType value_dtype, std::function make_instruction) { @@ -1855,20 +1856,20 @@ llvm::Value* CodeGenLLVM::VisitExpr_(const CallNode* op) { // call extern intrinsic ICHECK_GE(op->args.size(), 1U); auto global_symbol = Downcast(op->args[0]); - return this->CreateCallExtern(GetType(GetRef(op)), global_symbol->value, op->args, - true); + return this->CreateCallExtern(GetType(ffi::GetRef(op)), global_symbol->value, + op->args, true); } else if (op_attr_global_symbol_.count(call_op)) { // call extern if the op itself have a global symbol. 
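
The `GetRef` renames threaded through this visitor are the same change in miniature: the helper that wraps an unowned node pointer back into a strong reference is now always written with its owning `ffi::` namespace. As a standalone sketch, assuming `<tvm/ffi/cast.h>` and `<tvm/tir/expr.h>`; `RecoverExpr` is an invented name:

    #include <tvm/ffi/cast.h>
    #include <tvm/tir/expr.h>

    // Given the raw CallNode* a visitor receives, recover a strong PrimExpr
    // handle. ffi::GetRef only wraps the pointer; no copy of the node is made.
    tvm::PrimExpr RecoverExpr(const tvm::tir::CallNode* op) {
      return tvm::ffi::GetRef<tvm::PrimExpr>(op);
    }
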
- return this->CreateCallExtern(GetType(GetRef(op)), op_attr_global_symbol_[call_op], - op->args, false); + return this->CreateCallExtern(GetType(ffi::GetRef(op)), + op_attr_global_symbol_[call_op], op->args, false); } else { - VLOG(2) << "CreateIntrinsic: " << GetRef(op); + VLOG(2) << "CreateIntrinsic: " << ffi::GetRef(op); auto x = CreateIntrinsic(op); VLOG(2) << "CreateIntrinsic done"; return x; } } else if (auto* ptr_gvar = op->op.as()) { - auto gvar = GetRef(ptr_gvar); + auto gvar = ffi::GetRef(ptr_gvar); auto it = functions_.find(ptr_gvar); ICHECK(it != functions_.end()) << "Call to undefined GlobalVar \"" << gvar << "\""; llvm::Function* callee = it->second; @@ -2188,7 +2189,7 @@ void CodeGenLLVM::VisitStmt_(const EvaluateNode* op) { MakeValue(op->value); } -void CodeGenLLVM::EmitDebugLocation(const Optional& span) { +void CodeGenLLVM::EmitDebugLocation(const ffi::Optional& span) { #if TVM_LLVM_VERSION >= 50 if (di_subprogram_ == nullptr) { // debug info is not always generated outside of CPU codegen @@ -2213,7 +2214,8 @@ void CodeGenLLVM::EmitDebugLocation() { builder_->SetCurrentDebugLocation(nullpt void CodeGenLLVM::EmitDebugLocation(const StmtNode* op) { EmitDebugLocation(op->span); } // Following Glow |DebugInfo::generateFunctionDebugInfo|, https://git.io/fjadv -void CodeGenLLVM::AddDebugInformation(llvm::Function* f_llvm, const Array& tvm_param_types) { +void CodeGenLLVM::AddDebugInformation(llvm::Function* f_llvm, + const ffi::Array& tvm_param_types) { #if TVM_LLVM_VERSION >= 50 ICHECK(di_subprogram_); f_llvm->setSubprogram(di_subprogram_); @@ -2355,9 +2357,9 @@ static void CodegenLLVMRegisterReflection() { []() -> std::string { return llvm::sys::getProcessTriple(); }) .def("tvm.codegen.llvm.GetHostCPUName", []() -> std::string { return llvm::sys::getHostCPUName().str(); }) - .def("tvm.codegen.llvm.GetHostCPUFeatures", []() -> Map { + .def("tvm.codegen.llvm.GetHostCPUFeatures", []() -> ffi::Map { #if TVM_LLVM_VERSION >= 190 - Map ret; + ffi::Map ret; auto features = llvm::sys::getHostCPUFeatures(); for (auto it = features.begin(); it != features.end(); ++it) { std::string name = it->getKey().str(); @@ -2368,7 +2370,7 @@ static void CodegenLLVMRegisterReflection() { #else llvm::StringMap features; if (llvm::sys::getHostCPUFeatures(features)) { - Map ret; + ffi::Map ret; for (auto it = features.begin(); it != features.end(); ++it) { std::string name = it->getKey().str(); bool value = it->getValue(); diff --git a/src/target/llvm/codegen_llvm.h b/src/target/llvm/codegen_llvm.h index e1667b637578..cdaac859e430 100644 --- a/src/target/llvm/codegen_llvm.h +++ b/src/target/llvm/codegen_llvm.h @@ -125,7 +125,8 @@ class CodeGenLLVM : public ExprFunctor, * this option influences whether global ctors are used. */ virtual void Init(const std::string& module_name, LLVMTarget* llvm_target, - Optional system_lib_prefix, bool dynamic_lookup, bool target_c_runtime); + ffi::Optional system_lib_prefix, bool dynamic_lookup, + bool target_c_runtime); /*! * \brief Turn on fast math flags for floating point operations. @@ -266,7 +267,7 @@ class CodeGenLLVM : public ExprFunctor, /*! * \brief Convert tvm::ffi::String into llvm::StringRef */ - static llvm::StringRef MakeStringRef(const String& string) { + static llvm::StringRef MakeStringRef(const ffi::String& string) { return llvm::StringRef(string.c_str(), string.size()); } /*! 
@@ -293,8 +294,8 @@ class CodeGenLLVM : public ExprFunctor, virtual llvm::Value* CreateIntrinsic(const CallNode* op); // create extern function call // skip first arg mode used for call extern intrinsic. - virtual llvm::Value* CreateCallExtern(Type ret_type, String global_symbol, - const Array& args, bool skip_first_arg); + virtual llvm::Value* CreateCallExtern(Type ret_type, ffi::String global_symbol, + const ffi::Array& args, bool skip_first_arg); /*! \brief Insert a printf() call to the generated LLVM * @@ -359,7 +360,8 @@ class CodeGenLLVM : public ExprFunctor, * - Should return the generated expression. */ void BufferAccessHelper( - Buffer buffer, Array indices, Optional predicate, DataType value_dtype, + Buffer buffer, ffi::Array indices, ffi::Optional predicate, + DataType value_dtype, std::function make_instruction); @@ -585,7 +587,7 @@ class CodeGenLLVM : public ExprFunctor, const Op& builtin_tvm_call_cpacked_lowered_ = builtin::tvm_call_cpacked_lowered(); void EmitDebugLocation(); - void EmitDebugLocation(const Optional& span); + void EmitDebugLocation(const ffi::Optional& span); void EmitDebugLocation(const StmtNode* op); // Get the DWARF type corresponding to the LLVM type |ty|. The current API in practice only @@ -594,7 +596,7 @@ class CodeGenLLVM : public ExprFunctor, llvm::DIType* GetDebugType(const Type& ty_tir, llvm::Type* ty_llvm); // Adds the DWARF debug information for |function| to |dbg_info_|. - void AddDebugInformation(llvm::Function* f_llvm, const Array& tvm_param_types); + void AddDebugInformation(llvm::Function* f_llvm, const ffi::Array& tvm_param_types); // Adds the DWARF debug information for |tir_var| to |dbg_info_|. void AddDebugInformation(llvm::Value* llvm_value, const Var& tir_var, llvm::Instruction* insert_before = nullptr); diff --git a/src/target/llvm/codegen_nvptx.cc b/src/target/llvm/codegen_nvptx.cc index a1c967e644cb..054cfedb4b7c 100644 --- a/src/target/llvm/codegen_nvptx.cc +++ b/src/target/llvm/codegen_nvptx.cc @@ -316,7 +316,7 @@ llvm::Value* CodeGenNVPTX::CreateIntrinsic(const CallNode* op) { } int GetCUDAComputeVersion(const Target& target) { - Optional mcpu = target->GetAttr("mcpu"); + ffi::Optional mcpu = target->GetAttr("mcpu"); ICHECK(mcpu.has_value()) << "InternalError: \"-mcpu\" is undefined in the NVPTX target"; std::string sm_version = mcpu.value(); return std::stoi(sm_version.substr(3)); diff --git a/src/target/llvm/intrin_rule_hexagon.cc b/src/target/llvm/intrin_rule_hexagon.cc index b38ff0674943..bb78af0a8434 100644 --- a/src/target/llvm/intrin_rule_hexagon.cc +++ b/src/target/llvm/intrin_rule_hexagon.cc @@ -39,7 +39,7 @@ namespace llvm { using tir::FLowerIntrinsic; inline PrimExpr TVMExternCall(const tir::CallNode* call, const std::string& fname) { - Array new_args = {tir::StringImm(fname)}; + ffi::Array new_args = {tir::StringImm(fname)}; for (PrimExpr arg : call->args) { new_args.push_back(arg); } @@ -51,7 +51,7 @@ inline PrimExpr DispatchTVMQHLWrapperFp16(const PrimExpr& e) { using namespace tir; const CallNode* call = e.as(); ICHECK(call != nullptr); - Array new_args; + ffi::Array new_args; #if ENABLE_QHL // Check target for qfloat enablement const auto f = tvm::ffi::Function::GetGlobal("target.TargetCurrent"); @@ -183,7 +183,7 @@ TVM_REGISTER_OP("tir.sigmoid") const PrimExpr v1 = tir::Max(x, MinBound); const PrimExpr v2 = tir::Min(v1, MaxBound); - Array new_args = {v2}; + ffi::Array new_args = {v2}; const tir::Call new_call = tir::Call(call->dtype, call->op, new_args); // Enable QHL library for FP16 data type diff --git 
a/src/target/llvm/intrin_rule_llvm.cc b/src/target/llvm/intrin_rule_llvm.cc index 17de699e00b4..4ce7ce9f2291 100644 --- a/src/target/llvm/intrin_rule_llvm.cc +++ b/src/target/llvm/intrin_rule_llvm.cc @@ -264,7 +264,7 @@ TVM_REGISTER_OP("tir.clz").set_attr("llvm.FLegalize", [](const PrimEx const tir::CallNode* call = e.as(); ICHECK(call != nullptr); ICHECK_EQ(call->args.size(), 1); - Array cargs; + ffi::Array cargs; cargs.push_back(IntImm(DataType::UInt(32), ::llvm::Intrinsic::ctlz)); cargs.push_back(call->args[0]); cargs.push_back(IntImm(DataType::Int(1), 1)); // is_zero_undef diff --git a/src/target/llvm/intrin_rule_llvm.h b/src/target/llvm/intrin_rule_llvm.h index aa4f68d0b090..445d33522c7e 100644 --- a/src/target/llvm/intrin_rule_llvm.h +++ b/src/target/llvm/intrin_rule_llvm.h @@ -41,7 +41,7 @@ template inline PrimExpr DispatchLLVMPureIntrin(const PrimExpr& e) { const tir::CallNode* call = e.as(); ICHECK(call != nullptr); - Array cargs; + ffi::Array cargs; // intrin id. cargs.push_back(IntImm(DataType::UInt(32), id)); ICHECK_EQ(call->args.size(), num_signature) @@ -58,7 +58,7 @@ template inline PrimExpr DispatchLLVMIntrin(const PrimExpr& e) { const tir::CallNode* call = e.as(); ICHECK(call != nullptr); - Array cargs; + ffi::Array cargs; // intrin id. cargs.push_back(IntImm(DataType::UInt(32), id)); ICHECK_EQ(call->args.size(), num_signature) diff --git a/src/target/llvm/intrin_rule_nvptx.cc b/src/target/llvm/intrin_rule_nvptx.cc index 48fc64172215..a5fef4f5d411 100644 --- a/src/target/llvm/intrin_rule_nvptx.cc +++ b/src/target/llvm/intrin_rule_nvptx.cc @@ -49,7 +49,7 @@ inline PrimExpr DispatchPureExternLibDevice(const PrimExpr& e) { intrinsic_name << "__nv_" << name.substr(4); if (call->dtype.bits() == 32) intrinsic_name << "f"; - Array new_args = {StringImm(intrinsic_name.str())}; + ffi::Array new_args = {StringImm(intrinsic_name.str())}; for (auto arg : call->args) { new_args.push_back(arg); } diff --git a/src/target/llvm/intrin_rule_rocm.cc b/src/target/llvm/intrin_rule_rocm.cc index 30afcee92acc..d4c92a38d1ba 100644 --- a/src/target/llvm/intrin_rule_rocm.cc +++ b/src/target/llvm/intrin_rule_rocm.cc @@ -52,7 +52,7 @@ inline PrimExpr DispatchPureExternOCML(const PrimExpr& e) { std::ostringstream intrinsic_name; intrinsic_name << "__ocml_" << name.substr(4) << "_f" << call->dtype.bits(); - Array new_args = {StringImm(intrinsic_name.str())}; + ffi::Array new_args = {StringImm(intrinsic_name.str())}; for (auto arg : call->args) { new_args.push_back(arg); } diff --git a/src/target/llvm/llvm_instance.cc b/src/target/llvm/llvm_instance.cc index e494a2bbf9e9..32bada242ceb 100644 --- a/src/target/llvm/llvm_instance.cc +++ b/src/target/llvm/llvm_instance.cc @@ -203,19 +203,19 @@ LLVMTargetInfo::LLVMTargetInfo(LLVMInstance& instance, const Target& target) : LLVMTargetInfo(instance, target->Export()) {} LLVMTargetInfo::LLVMTargetInfo(LLVMInstance& instance, const TargetJSON& target) { - triple_ = Downcast(target.Get("mtriple").value_or(String("default"))); + triple_ = Downcast(target.Get("mtriple").value_or(ffi::String("default"))); if (triple_.empty() || triple_ == "default") { triple_ = llvm::sys::getDefaultTargetTriple(); } - cpu_ = Downcast(target.Get("mcpu").value_or(String(defaults::cpu))); + cpu_ = Downcast(target.Get("mcpu").value_or(ffi::String(defaults::cpu))); - if (const auto& v = Downcast>>(target.Get("mattr"))) { - for (const String& s : v.value()) { + if (const auto& v = Downcast>>(target.Get("mattr"))) { + for (const ffi::String& s : v.value()) { attrs_.push_back(s); } } // llvm 
module target - if (Downcast(target.Get("kind").value()) == "llvm") { + if (Downcast(target.Get("kind").value()) == "llvm") { // legalize -mcpu with the target -mtriple auto arches = GetAllLLVMTargetArches(); bool has_arch = @@ -225,16 +225,16 @@ LLVMTargetInfo::LLVMTargetInfo(LLVMInstance& instance, const TargetJSON& target) // give the code a chance to run with a less-specific target. LOG(ERROR) << "Using LLVM " << LLVM_VERSION_STRING << " with `-mcpu=" << cpu_ << "` is not valid in `-mtriple=" << triple_ << "`" - << ", using default `-mcpu=" << String(defaults::cpu) << "`"; + << ", using default `-mcpu=" << ffi::String(defaults::cpu) << "`"; // LLVM default cpu fallback - cpu_ = String(defaults::cpu); + cpu_ = ffi::String(defaults::cpu); } } - if (const auto& v = Downcast>>(target.Get("cl-opt"))) { + if (const auto& v = Downcast>>(target.Get("cl-opt"))) { llvm::StringMap& options = llvm::cl::getRegisteredOptions(); bool parse_error = false; - for (const String& s : v.value()) { + for (const ffi::String& s : v.value()) { Option opt = ParseOptionString(s); if (opt.type == Option::OptType::Invalid) { parse_error = true; @@ -252,8 +252,8 @@ LLVMTargetInfo::LLVMTargetInfo(LLVMInstance& instance, const TargetJSON& target) } llvm::FloatABI::ABIType float_abi = llvm::FloatABI::Default; - if (const auto& v = Downcast>(target.Get("mfloat-abi"))) { - String value = v.value(); + if (const auto& v = Downcast>(target.Get("mfloat-abi"))) { + ffi::String value = v.value(); if (value == "hard") { float_abi = llvm::FloatABI::Hard; } else if (value == "soft") { @@ -264,8 +264,8 @@ LLVMTargetInfo::LLVMTargetInfo(LLVMInstance& instance, const TargetJSON& target) } // LLVM JIT engine options - if (const auto& v = Downcast>(target.Get("jit").value_or(nullptr))) { - String value = v.value(); + if (const auto& v = Downcast>(target.Get("jit").value_or(nullptr))) { + ffi::String value = v.value(); if ((value == "mcjit") || (value == "orcjit")) { jit_engine_ = value; } else { @@ -274,7 +274,8 @@ LLVMTargetInfo::LLVMTargetInfo(LLVMInstance& instance, const TargetJSON& target) } // TVM & LLVM vector width options - if (const auto& w = Downcast>(target.Get("vector-width").value_or(nullptr))) { + if (const auto& w = + Downcast>(target.Get("vector-width").value_or(nullptr))) { vector_width_ = w.value(); if ((vector_width_ <= 0) || (vector_width_ > 65536)) { LOG(FATAL) << "Invalid -vector-width value: " << vector_width_; @@ -288,7 +289,7 @@ LLVMTargetInfo::LLVMTargetInfo(LLVMInstance& instance, const TargetJSON& target) code_model_ = llvm::CodeModel::Medium; #if TVM_LLVM_VERSION >= 140 // get VLEN from the LLVM backend (zvlXXXb) - Map features = GetAllLLVMCpuFeatures(); + ffi::Map features = GetAllLLVMCpuFeatures(); // check vector ISA if (features.count("v") > 0) { vector_width_ = 0; @@ -320,7 +321,7 @@ LLVMTargetInfo::LLVMTargetInfo(LLVMInstance& instance, const TargetJSON& target) target_options_.NoNaNsFPMath = true; target_options_.FloatABIType = float_abi; if (target.find("mabi") != target.end()) { - target_options_.MCOptions.ABIName = Downcast(target.Get("mabi").value()); + target_options_.MCOptions.ABIName = Downcast(target.Get("mabi").value()); } auto maybe_level = target.Get("opt-level"); @@ -833,8 +834,8 @@ void LLVMTargetInfo::GetOptionValue(LLVMTargetInfo::Option* opt) const { } } -const Array LLVMTargetInfo::GetAllLLVMTargets() const { - Array llvm_targets; +const ffi::Array LLVMTargetInfo::GetAllLLVMTargets() const { + ffi::Array llvm_targets; // iterate all archtypes for (auto a = 
llvm::Triple::ArchType(llvm::Triple::ArchType::UnknownArch + 1); a < llvm::Triple::ArchType::LastArchType; a = llvm::Triple::ArchType(a + 1)) { @@ -848,8 +849,8 @@ const Array LLVMTargetInfo::GetAllLLVMTargets() const { return llvm_targets; } -const Array LLVMTargetInfo::GetAllLLVMTargetArches() const { - Array cpu_arches; +const ffi::Array LLVMTargetInfo::GetAllLLVMTargetArches() const { + ffi::Array cpu_arches; // get the subtarget info module auto llvm_instance = CreateLLVMTargetInstance(triple_, true); std::unique_ptr target_machine = @@ -873,7 +874,7 @@ const Array LLVMTargetInfo::GetAllLLVMTargetArches() const { return cpu_arches; } -const Map LLVMTargetInfo::GetAllLLVMCpuFeatures() const { +const ffi::Map LLVMTargetInfo::GetAllLLVMCpuFeatures() const { std::string feats = ""; for (const auto& attr : attrs_) { feats += feats.empty() ? attr : ("," + attr); @@ -892,7 +893,7 @@ const Map LLVMTargetInfo::GetAllLLVMCpuFeatures() const { MCInfo->getAllProcessorFeatures(); #endif // TVM doesn't have an FFI friendly Set, so use a Map instead for now - Map cpu_features; + ffi::Map cpu_features; for (const auto& feat : llvm_features) { if (MCInfo->checkFeatures("+" + std::string(feat.Key))) { cpu_features.Set(feat.Key, ""); diff --git a/src/target/llvm/llvm_instance.h b/src/target/llvm/llvm_instance.h index a68637cc844e..a41c57d6fae6 100644 --- a/src/target/llvm/llvm_instance.h +++ b/src/target/llvm/llvm_instance.h @@ -324,14 +324,14 @@ class LLVMTargetInfo { * \brief Get all supported targets from the LLVM backend * \return list with all valid targets */ - const Array GetAllLLVMTargets() const; + const ffi::Array GetAllLLVMTargets() const; /*! * \brief Get all CPU arches from target * \return list with all valid cpu architectures * \note The arches are fetched from the LLVM backend using the target `-mtriple`. */ - const Array GetAllLLVMTargetArches() const; + const ffi::Array GetAllLLVMTargetArches() const; /*! * \brief Get all CPU features from target @@ -340,7 +340,7 @@ class LLVMTargetInfo { * \note The features are fetched from the LLVM backend using the target `-mtriple` * and the `-mcpu` architecture, but also consider the `-mattr` attributes. */ - const Map GetAllLLVMCpuFeatures() const; + const ffi::Map GetAllLLVMCpuFeatures() const; /*! * \brief Check the target if has a specific cpu feature diff --git a/src/target/llvm/llvm_module.cc b/src/target/llvm/llvm_module.cc index 6c88d6943423..c31e1f1a7811 100644 --- a/src/target/llvm/llvm_module.cc +++ b/src/target/llvm/llvm_module.cc @@ -95,7 +95,7 @@ class LLVMModuleNode final : public ffi::ModuleObj { const char* kind() const final { return "llvm"; } - Optional GetFunction(const String& name) final; + ffi::Optional GetFunction(const ffi::String& name) final; /*! 
\brief Get the property of the runtime module .*/ // TODO(tvm-team): Make it serializable @@ -103,15 +103,15 @@ class LLVMModuleNode final : public ffi::ModuleObj { return ffi::Module::kRunnable | ffi::Module::kCompilationExportable; } - void WriteToFile(const String& file_name, const String& format) const final; + void WriteToFile(const ffi::String& file_name, const ffi::String& format) const final; ffi::Bytes SaveToBytes() const final; - String InspectSource(const String& format) const final; + ffi::String InspectSource(const ffi::String& format) const final; void Init(const IRModule& mod, const Target& target); void Init(std::unique_ptr module, std::unique_ptr llvm_instance); void LoadIR(const std::string& file_name); - bool ImplementsFunction(const String& name) final; + bool ImplementsFunction(const ffi::String& name) final; void SetJITEngine(const std::string& jit_engine) { jit_engine_ = jit_engine; } @@ -135,7 +135,7 @@ class LLVMModuleNode final : public ffi::ModuleObj { // (EngineBuilder takes ownership of the module). std::unique_ptr module_owning_ptr_; /* \brief names of the external functions declared in this module */ - Array function_names_; + ffi::Array function_names_; std::string jit_engine_; }; @@ -155,7 +155,7 @@ LLVMModuleNode::~LLVMModuleNode() { module_owning_ptr_.reset(); } -Optional LLVMModuleNode::GetFunction(const String& name) { +ffi::Optional LLVMModuleNode::GetFunction(const ffi::String& name) { ObjectPtr sptr_to_self = ffi::GetObjectPtr(this); if (name == "__tvm_is_system_module") { bool flag = (module_->getFunction("__tvm_module_startup") != nullptr); @@ -189,10 +189,10 @@ Optional LLVMModuleNode::GetFunction(const String& name) { TVMFFISafeCallType faddr; With llvm_target(*llvm_instance_, LLVMTarget::GetTargetMetadata(*module_)); - String name_with_prefix = ffi::symbol::tvm_ffi_symbol_prefix + name; + ffi::String name_with_prefix = ffi::symbol::tvm_ffi_symbol_prefix + name; faddr = reinterpret_cast(GetFunctionAddr(name_with_prefix, *llvm_target)); if (faddr == nullptr) return std::nullopt; - ffi::Module self_strong_ref = GetRef(this); + ffi::Module self_strong_ref = ffi::GetRef(this); return ffi::Function::FromPacked([faddr, self_strong_ref](ffi::PackedArgs args, ffi::Any* rv) { TVM_FFI_ICHECK_LT(rv->type_index(), ffi::TypeIndex::kTVMFFIStaticObjectBegin); TVM_FFI_CHECK_SAFE_CALL((*faddr)(nullptr, reinterpret_cast(args.data()), @@ -236,7 +236,8 @@ bool LLVMAddPassesToEmitFile(llvm::TargetMachine* tm, llvm::legacy::PassManager* } // namespace -void LLVMModuleNode::WriteToFile(const String& file_name_str, const String& format) const { +void LLVMModuleNode::WriteToFile(const ffi::String& file_name_str, + const ffi::String& format) const { // CHECK(imports_.empty()) << "SaveToFile does not handle imported modules"; std::string file_name = file_name_str; std::string fmt = runtime::GetFileFormat(file_name, format); @@ -275,7 +276,7 @@ ffi::Bytes LLVMModuleNode::SaveToBytes() const { LOG(FATAL) << "LLVMModule: SaveToBytes not supported"; } -String LLVMModuleNode::InspectSource(const String& format) const { +ffi::String LLVMModuleNode::InspectSource(const ffi::String& format) const { std::string fmt = runtime::GetFileFormat("", format); std::string type_str; llvm::SmallString<256> str; @@ -325,7 +326,8 @@ void LLVMModuleNode::Init(const IRModule& mod, const Target& target) { std::string entry_func; - Optional system_lib_prefix = mod->GetAttr(tvm::attr::kSystemLibPrefix); + ffi::Optional system_lib_prefix = + mod->GetAttr(tvm::attr::kSystemLibPrefix); for (auto kv : 
mod->functions) { if (!kv.second->IsInstance()) { @@ -333,7 +335,7 @@ void LLVMModuleNode::Init(const IRModule& mod, const Target& target) { continue; } auto f = Downcast(kv.second); - auto global_symbol = f->GetAttr(tvm::attr::kGlobalSymbol); + auto global_symbol = f->GetAttr(tvm::attr::kGlobalSymbol); bool is_entry_func = f->HasNonzeroAttr(tir::attr::kIsEntryFunc); ICHECK(global_symbol || !is_entry_func) << "The entry func must be exposed externally."; @@ -386,7 +388,7 @@ void LLVMModuleNode::LoadIR(const std::string& file_name) { Init(std::move(module), std::move(llvm_instance)); } -bool LLVMModuleNode::ImplementsFunction(const String& name) { +bool LLVMModuleNode::ImplementsFunction(const ffi::String& name) { return std::find(function_names_.begin(), function_names_.end(), ffi::symbol::tvm_ffi_symbol_prefix + name) != function_names_.end(); } @@ -445,7 +447,7 @@ void LLVMModuleNode::InitMCJIT() { *ctx_addr = this; } - ffi::Module::VisitContextSymbols([this, &llvm_target](const String& name, void* symbol) { + ffi::Module::VisitContextSymbols([this, &llvm_target](const ffi::String& name, void* symbol) { if (void** ctx_addr = reinterpret_cast(GetGlobalAddr(name, *llvm_target))) { *ctx_addr = symbol; } @@ -493,7 +495,7 @@ void LLVMModuleNode::InitORCJIT() { } // data layout - String module_name = module_->getModuleIdentifier(); + ffi::String module_name = module_->getModuleIdentifier(); llvm::DataLayout layout(tm->createDataLayout()); ICHECK(layout == module_->getDataLayout()) << "Data layout mismatch between module(" @@ -595,7 +597,7 @@ void LLVMModuleNode::InitORCJIT() { reinterpret_cast(GetGlobalAddr(ffi::symbol::tvm_ffi_library_ctx, *llvm_target))) { *ctx_addr = this; } - ffi::Module::VisitContextSymbols([this, &llvm_target](const String& name, void* symbol) { + ffi::Module::VisitContextSymbols([this, &llvm_target](const ffi::String& name, void* symbol) { if (void** ctx_addr = reinterpret_cast(GetGlobalAddr(name, *llvm_target))) { *ctx_addr = symbol; } @@ -658,7 +660,7 @@ static void LLVMReflectionRegister() { refl::GlobalDef() .def("target.build.llvm", [](IRModule mod, Target target) -> ffi::Module { - auto n = make_object(); + auto n = ffi::make_object(); n->Init(mod, target); return ffi::Module(n); }) @@ -666,7 +668,7 @@ static void LLVMReflectionRegister() { [](std::string target_str, std::string module_name) -> ffi::Module { auto llvm_instance = std::make_unique(); With llvm_target(*llvm_instance, target_str); - auto n = make_object(); + auto n = ffi::make_object(); // Generate a LLVM module from an input target string auto module = std::make_unique(module_name, *llvm_target->GetContext()); llvm_target->SetTargetMetadata(module.get()); @@ -689,9 +691,9 @@ static void LLVMReflectionRegister() { #endif }) .def("target.llvm_get_intrinsic_name", - [](int64_t id) -> String { return llvmGetIntrinName(id); }) + [](int64_t id) -> ffi::String { return llvmGetIntrinName(id); }) .def("target.llvm_get_system_x86_vendor", - []() -> String { + []() -> ffi::String { #if TVM_LLVM_VERSION >= 120 #if defined(__i386__) || defined(_M_IX86) || defined(__x86_64__) || defined(_M_X64) using namespace llvm::sys::detail::x86; @@ -720,22 +722,22 @@ static void LLVMReflectionRegister() { return llvm_backend.GetVectorWidth(); }) .def("target.llvm_get_system_triple", - []() -> String { return llvm::sys::getDefaultTargetTriple(); }) + []() -> ffi::String { return llvm::sys::getDefaultTargetTriple(); }) .def("target.llvm_get_system_cpu", - []() -> String { return llvm::sys::getHostCPUName().str(); }) + []() -> 
ffi::String { return llvm::sys::getHostCPUName().str(); }) .def("target.llvm_get_targets", - []() -> Array { + []() -> ffi::Array { auto llvm_instance = std::make_unique(); LLVMTargetInfo llvm_backend(*llvm_instance, "llvm"); return llvm_backend.GetAllLLVMTargets(); }) .def("target.llvm_get_cpu_archlist", - [](const Target& target) -> Array { + [](const Target& target) -> ffi::Array { auto use_target = target.defined() ? target : Target::Current(false); // ignore non "llvm" target if (target.defined()) { if (target->kind->name != "llvm") { - return Array{}; + return ffi::Array{}; } } auto llvm_instance = std::make_unique(); @@ -743,7 +745,7 @@ static void LLVMReflectionRegister() { return llvm_backend.GetAllLLVMTargetArches(); }) .def("target.llvm_get_cpu_features", - [](const Target& target) -> Map { + [](const Target& target) -> ffi::Map { auto use_target = target.defined() ? target : Target::Current(false); // ignore non "llvm" target if (target.defined()) { @@ -756,7 +758,7 @@ static void LLVMReflectionRegister() { return llvm_backend.GetAllLLVMCpuFeatures(); }) .def("target.llvm_cpu_has_feature", - [](const String feature, const Target& target) -> bool { + [](const ffi::String feature, const Target& target) -> bool { auto use_target = target.defined() ? target : Target::Current(false); // ignore non "llvm" target if (target.defined()) { @@ -771,7 +773,7 @@ static void LLVMReflectionRegister() { return has_feature; }) .def("target.target_has_feature", - [](const String feature, const Target& target) -> bool { + [](const ffi::String feature, const Target& target) -> bool { auto use_target = target.defined() ? target : Target::Current(false); // ignore non "llvm" target if (target.defined()) { @@ -786,7 +788,7 @@ static void LLVMReflectionRegister() { .def("target.llvm_version_major", []() -> int { return TVM_LLVM_VERSION / 10; }) .def("ffi.Module.load_from_file.ll", [](std::string filename, std::string fmt) -> ffi::Module { - auto n = make_object(); + auto n = ffi::make_object(); n->SetJITEngine("orcjit"); n->LoadIR(filename); return ffi::Module(n); @@ -801,7 +803,7 @@ static void LLVMReflectionRegister() { .def("codegen.codegen_blob", [](std::string data, bool system_lib, std::string llvm_target_string, std::string c_symbol_prefix) -> ffi::Module { - auto n = make_object(); + auto n = ffi::make_object(); auto llvm_instance = std::make_unique(); With llvm_target(*llvm_instance, llvm_target_string); std::unique_ptr blob = diff --git a/src/target/opt/build_cuda_on.cc b/src/target/opt/build_cuda_on.cc index 6072a483877c..7b1356118d16 100644 --- a/src/target/opt/build_cuda_on.cc +++ b/src/target/opt/build_cuda_on.cc @@ -131,7 +131,7 @@ ffi::Module BuildCUDA(IRModule mod, Target target) { CodeGenCUDA cg; cg.Init(output_ssa); - Map functions; + ffi::Map functions; for (auto [gvar, base_func] : mod->functions) { ICHECK(base_func->IsInstance()) << "CodeGenCUDA: Can only take PrimFunc"; auto prim_func = Downcast(base_func); @@ -177,6 +177,6 @@ TVM_FFI_STATIC_INIT_BLOCK({ namespace refl = tvm::ffi::reflection; refl::GlobalDef().def("target.build.cuda", BuildCUDA); }); -TVM_REGISTER_PASS_CONFIG_OPTION("cuda.kernels_output_dir", String); +TVM_REGISTER_PASS_CONFIG_OPTION("cuda.kernels_output_dir", ffi::String); } // namespace codegen } // namespace tvm diff --git a/src/target/parsers/aprofile.cc b/src/target/parsers/aprofile.cc index 65bd6a66aedb..4edff94baeda 100644 --- a/src/target/parsers/aprofile.cc +++ b/src/target/parsers/aprofile.cc @@ -35,8 +35,8 @@ namespace target { namespace parsers { 
namespace aprofile { -double GetArchVersion(Array mattr) { - for (const String& attr : mattr) { +double GetArchVersion(ffi::Array mattr) { + for (const ffi::String& attr : mattr) { std::string attr_string = attr; size_t attr_len = attr_string.size(); if (attr_len >= 4 && attr_string.substr(0, 2) == "+v" && attr_string.back() == 'a') { @@ -47,14 +47,14 @@ double GetArchVersion(Array mattr) { return 0.0; } -double GetArchVersion(Optional> attr) { +double GetArchVersion(ffi::Optional> attr) { if (!attr) { return false; } return GetArchVersion(attr.value()); } -bool IsAArch32(Optional mtriple, Optional mcpu) { +bool IsAArch32(ffi::Optional mtriple, ffi::Optional mcpu) { if (mtriple) { bool is_mprofile = mcpu && support::StartsWith(mcpu.value(), "cortex-m"); return support::StartsWith(mtriple.value(), "arm") && !is_mprofile; @@ -62,7 +62,7 @@ bool IsAArch32(Optional mtriple, Optional mcpu) { return false; } -bool IsAArch64(Optional mtriple) { +bool IsAArch64(ffi::Optional mtriple) { if (mtriple) { return support::StartsWith(mtriple.value(), "aarch64"); } @@ -70,28 +70,32 @@ bool IsAArch64(Optional mtriple) { } bool IsArch(TargetJSON attrs) { - Optional mtriple = Downcast>(attrs.Get("mtriple").value_or(nullptr)); - Optional mcpu = Downcast>(attrs.Get("mcpu").value_or(nullptr)); + ffi::Optional mtriple = + Downcast>(attrs.Get("mtriple").value_or(nullptr)); + ffi::Optional mcpu = + Downcast>(attrs.Get("mcpu").value_or(nullptr)); return IsAArch32(mtriple, mcpu) || IsAArch64(mtriple); } -bool CheckContains(Array array, String predicate) { - return std::any_of(array.begin(), array.end(), [&](String var) { return var == predicate; }); +bool CheckContains(ffi::Array array, ffi::String predicate) { + return std::any_of(array.begin(), array.end(), [&](ffi::String var) { return var == predicate; }); } static TargetFeatures GetFeatures(TargetJSON target) { #ifdef TVM_LLVM_VERSION - String kind = Downcast(target.Get("kind").value()); + ffi::String kind = Downcast(target.Get("kind").value()); ICHECK_EQ(kind, "llvm") << "Expected target kind 'llvm', but got '" << kind << "'"; - Optional mtriple = Downcast>(target.Get("mtriple").value_or(nullptr)); - Optional mcpu = Downcast>(target.Get("mcpu").value_or(nullptr)); + ffi::Optional mtriple = + Downcast>(target.Get("mtriple").value_or(nullptr)); + ffi::Optional mcpu = + Downcast>(target.Get("mcpu").value_or(nullptr)); // Check that LLVM has been compiled with the correct target support auto llvm_instance = std::make_unique(); - codegen::LLVMTargetInfo llvm_backend(*llvm_instance, {{"kind", String("llvm")}}); - Array targets = llvm_backend.GetAllLLVMTargets(); + codegen::LLVMTargetInfo llvm_backend(*llvm_instance, {{"kind", ffi::String("llvm")}}); + ffi::Array targets = llvm_backend.GetAllLLVMTargets(); if ((IsAArch64(mtriple) && !CheckContains(targets, "aarch64")) || (IsAArch32(mtriple, mcpu) && !CheckContains(targets, "arm"))) { LOG(WARNING) << "Cannot parse target features for target: " << target @@ -100,9 +104,9 @@ static TargetFeatures GetFeatures(TargetJSON target) { } codegen::LLVMTargetInfo llvm_target(*llvm_instance, target); - Map features = llvm_target.GetAllLLVMCpuFeatures(); + ffi::Map features = llvm_target.GetAllLLVMCpuFeatures(); - auto has_feature = [features](const String& feature) { + auto has_feature = [features](const ffi::String& feature) { return features.find(feature) != features.end(); }; @@ -120,15 +124,15 @@ static TargetFeatures GetFeatures(TargetJSON target) { return {}; } -static Array MergeKeys(Optional> existing_keys) { - const 
Array kExtraKeys = {"arm_cpu", "cpu"}; +static ffi::Array MergeKeys(ffi::Optional> existing_keys) { + const ffi::Array kExtraKeys = {"arm_cpu", "cpu"}; if (!existing_keys) { return kExtraKeys; } - Array keys = existing_keys.value(); - for (String key : kExtraKeys) { + ffi::Array keys = existing_keys.value(); + for (ffi::String key : kExtraKeys) { if (std::find(keys.begin(), keys.end(), key) == keys.end()) { keys.push_back(key); } @@ -138,7 +142,8 @@ static Array MergeKeys(Optional> existing_keys) { TargetJSON ParseTarget(TargetJSON target) { target.Set("features", GetFeatures(target)); - target.Set("keys", MergeKeys(Downcast>>(target.Get("keys")))); + target.Set("keys", + MergeKeys(Downcast>>(target.Get("keys")))); return target; } diff --git a/src/target/parsers/cpu.cc b/src/target/parsers/cpu.cc index ee9bf814d323..ac187a03bbdc 100644 --- a/src/target/parsers/cpu.cc +++ b/src/target/parsers/cpu.cc @@ -28,24 +28,24 @@ namespace target { namespace parsers { namespace cpu { -Optional DetectSystemTriple() { +ffi::Optional DetectSystemTriple() { #ifdef TVM_LLVM_VERSION auto pf = tvm::ffi::Function::GetGlobal("target.llvm_get_system_triple"); ICHECK(pf.has_value()) << "The target llvm_get_system_triple was not found, " "please compile with USE_LLVM = ON"; - return (*pf)().cast(); + return (*pf)().cast(); #endif return {}; } TargetJSON ParseTarget(TargetJSON target) { - String kind = Downcast(target.Get("kind").value()); - Optional mtriple = Downcast>(target.Get("mtriple")); - Optional mcpu = Downcast>(target.Get("mcpu")); + ffi::String kind = Downcast(target.Get("kind").value()); + ffi::Optional mtriple = Downcast>(target.Get("mtriple")); + ffi::Optional mcpu = Downcast>(target.Get("mcpu")); // Try to fill in the blanks by detecting target information from the system if (kind == "llvm" && !mtriple.has_value() && !mcpu.has_value()) { - String system_triple = DetectSystemTriple().value_or(""); + ffi::String system_triple = DetectSystemTriple().value_or(""); target.Set("mtriple", system_triple); } diff --git a/src/target/parsers/mprofile.cc b/src/target/parsers/mprofile.cc index acd878c667c0..bd3bf5848a68 100644 --- a/src/target/parsers/mprofile.cc +++ b/src/target/parsers/mprofile.cc @@ -41,7 +41,7 @@ static const char* dspCPUs[] = {"cortex-m55", "cortex-m4", "cortex-m7", static const char* mveCPUs[] = {"cortex-m55", "cortex-m85"}; template -static inline bool MatchesCpu(Optional mcpu, const Container& cpus) { +static inline bool MatchesCpu(ffi::Optional mcpu, const Container& cpus) { if (!mcpu) { return false; } @@ -50,31 +50,32 @@ static inline bool MatchesCpu(Optional mcpu, const Container& cpus) { return std::find_if(std::begin(cpus), std::end(cpus), matches_cpu) != std::end(cpus); } -static inline bool HasFlag(String attr, std::string flag) { +static inline bool HasFlag(ffi::String attr, std::string flag) { std::string attr_str = attr; return attr_str.find(flag) != std::string::npos; } -static inline bool HasFlag(Optional attr, std::string flag) { +static inline bool HasFlag(ffi::Optional attr, std::string flag) { if (!attr) { return false; } return HasFlag(attr.value(), flag); } -static inline bool HasFlag(Optional> attr, std::string flag) { +static inline bool HasFlag(ffi::Optional> attr, std::string flag) { if (!attr) { return false; } - Array attr_array = attr.value(); + ffi::Array attr_array = attr.value(); - auto matching_attr = std::find_if(attr_array.begin(), attr_array.end(), - [flag](String attr_str) { return HasFlag(attr_str, flag); }); + auto matching_attr = + 
std::find_if(attr_array.begin(), attr_array.end(), + [flag](ffi::String attr_str) { return HasFlag(attr_str, flag); }); return matching_attr != attr_array.end(); } bool IsArch(TargetJSON attrs) { - Optional mcpu = Downcast>(attrs.Get("mcpu")); + ffi::Optional mcpu = Downcast>(attrs.Get("mcpu")); if (mcpu) { bool matches_base = MatchesCpu(mcpu, baseCPUs); bool matches_dsp = MatchesCpu(mcpu, dspCPUs); @@ -85,8 +86,9 @@ bool IsArch(TargetJSON attrs) { } static TargetFeatures GetFeatures(TargetJSON target) { - Optional mcpu = Downcast>(target.Get("mcpu")); - Optional> mattr = Downcast>>(target.Get("mattr")); + ffi::Optional mcpu = Downcast>(target.Get("mcpu")); + ffi::Optional> mattr = + Downcast>>(target.Get("mattr")); bool nomve = HasFlag(mcpu, "+nomve") || HasFlag(mattr, "+nomve"); bool nodsp = HasFlag(mcpu, "+nodsp") || HasFlag(mattr, "+nodsp"); @@ -104,15 +106,15 @@ static TargetFeatures GetFeatures(TargetJSON target) { return kNoExt; } -static Array MergeKeys(Optional> existing_keys) { - const Array kExtraKeys = {"arm_cpu", "cpu"}; +static ffi::Array MergeKeys(ffi::Optional> existing_keys) { + const ffi::Array kExtraKeys = {"arm_cpu", "cpu"}; if (!existing_keys) { return kExtraKeys; } - Array keys = existing_keys.value(); - for (String key : kExtraKeys) { + ffi::Array keys = existing_keys.value(); + for (ffi::String key : kExtraKeys) { if (std::find(keys.begin(), keys.end(), key) == keys.end()) { keys.push_back(key); } @@ -122,7 +124,8 @@ static Array MergeKeys(Optional> existing_keys) { TargetJSON ParseTarget(TargetJSON target) { target.Set("features", GetFeatures(target)); - target.Set("keys", MergeKeys(Downcast>>(target.Get("keys")))); + target.Set("keys", + MergeKeys(Downcast>>(target.Get("keys")))); return target; } diff --git a/src/target/source/codegen_c.cc b/src/target/source/codegen_c.cc index 49b444e49516..ddd904c555a2 100644 --- a/src/target/source/codegen_c.cc +++ b/src/target/source/codegen_c.cc @@ -76,7 +76,7 @@ void CodeGenC::ReserveKeywordsAsUnique() { name_supply_->ReserveName("return"); } -void CodeGenC::PrintFunctionSignature(const String& function_name, const PrimFunc& func, +void CodeGenC::PrintFunctionSignature(const ffi::String& function_name, const PrimFunc& func, std::ostream& os) { PrintFuncPrefix(os); PrintType(func->ret_type, os); @@ -136,8 +136,8 @@ void CodeGenC::DeclareFunction(const GlobalVar& gvar, const PrimFunc& func) { return; } - auto function_name = [&]() -> String { - if (auto global_symbol = func->GetAttr(tvm::attr::kGlobalSymbol)) { + auto function_name = [&]() -> ffi::String { + if (auto global_symbol = func->GetAttr(tvm::attr::kGlobalSymbol)) { auto name = global_symbol.value(); ICHECK(!func_name_supply_->ContainsName(name)) << "Function " << gvar << " must use global symbol " << name @@ -159,7 +159,7 @@ void CodeGenC::DeclareFunction(const GlobalVar& gvar, const PrimFunc& func) { fwd_decl_stream << ";\n"; } -String CodeGenC::GetFunctionName(const GlobalVar& gvar) { +ffi::String CodeGenC::GetFunctionName(const GlobalVar& gvar) { auto it = internal_functions_.find(gvar); ICHECK(it != internal_functions_.end()) << "Attempted to find name of " << gvar @@ -592,8 +592,9 @@ void CodeGenC::VisitExpr_(const NotNode* op, std::ostream& os) { // NOLINT(*) PrintExpr(op->a, os); } -void CodeGenC::PrintCallExtern(Type ret_type, String global_symbol, const Array& args, - bool skip_first_arg, std::ostream& os) { // NOLINT(*) +void CodeGenC::PrintCallExtern(Type ret_type, ffi::String global_symbol, + const ffi::Array& args, bool skip_first_arg, + std::ostream& 
os) { // NOLINT(*) os << global_symbol << "("; for (size_t i = static_cast(skip_first_arg); i < args.size(); ++i) { this->PrintExpr(args[i], os); @@ -614,12 +615,12 @@ void CodeGenC::VisitExpr_(const CallNode* op, std::ostream& os) { // NOLINT(*) } else if (op->op.same_as(builtin_call_extern_) || op->op.same_as(builtin_call_pure_extern_)) { ICHECK_GE(op->args.size(), 1U); auto func = Downcast(op->args[0]); - this->PrintCallExtern(GetType(GetRef(op)), func->value, op->args, true, os); + this->PrintCallExtern(GetType(ffi::GetRef(op)), func->value, op->args, true, os); // If the call_extern refers to an function within the IRModule, then // the forward declaration is already provided from DeclareFunction. if (!func_name_supply_->ContainsName(func->value)) { - Array arg_types; + ffi::Array arg_types; for (size_t i = 1; i < op->args.size(); i++) { arg_types.push_back(GetType(op->args[i])); } @@ -628,7 +629,7 @@ void CodeGenC::VisitExpr_(const CallNode* op, std::ostream& os) { // NOLINT(*) } } else if (op_attr_global_symbol_.count(call_op)) { // call extern if the op itself have a global symbol. - this->PrintCallExtern(GetType(GetRef(op)), op_attr_global_symbol_[call_op], + this->PrintCallExtern(GetType(ffi::GetRef(op)), op_attr_global_symbol_[call_op], op->args, false, os); } else if (op->op.same_as(builtin::bitwise_and())) { PrintBinaryIntrinsic(op, " & ", os, this); @@ -732,7 +733,7 @@ void CodeGenC::VisitExpr_(const CallNode* op, std::ostream& os) { // NOLINT(*) } else if (auto opt = op->op.as()) { auto gvar = opt.value(); auto callee_name = GetFunctionName(gvar); - PrintCallExtern(GetType(GetRef(op)), callee_name, op->args, false, os); + PrintCallExtern(GetType(ffi::GetRef(op)), callee_name, op->args, false, os); } else { LOG(FATAL) << "CodeGenC: Unknown operation " << op->op << " is neither a recognized built-in, " << "nor a GlobalVar reference to another function in the IRModule"; diff --git a/src/target/source/codegen_c.h b/src/target/source/codegen_c.h index 02cb4cd9a779..920e6a13a04e 100644 --- a/src/target/source/codegen_c.h +++ b/src/target/source/codegen_c.h @@ -90,7 +90,7 @@ class CodeGenC : public ExprFunctor, * \param gvar The GlobalVar of the function * \returns The string name of the function */ - String GetFunctionName(const GlobalVar& gvar); + ffi::String GetFunctionName(const GlobalVar& gvar); /*! * \brief Finalize the compilation and return the code. @@ -131,7 +131,7 @@ class CodeGenC : public ExprFunctor, * * \param os The output stream */ - virtual void PrintFunctionSignature(const String& function_name, const PrimFunc& func, + virtual void PrintFunctionSignature(const ffi::String& function_name, const PrimFunc& func, std::ostream& os); /*! @@ -271,8 +271,8 @@ class CodeGenC : public ExprFunctor, * \param ret_type The return type of the function * \param os The output stream. */ - virtual void GenerateForwardFunctionDeclarations(String global_symbol, - const Array& arg_types, + virtual void GenerateForwardFunctionDeclarations(ffi::String global_symbol, + const ffi::Array& arg_types, const Type& ret_type) {} /*! @@ -283,8 +283,9 @@ class CodeGenC : public ExprFunctor, * \param skip_first_arg Whether to skip the first arguments. * \param os The output stream. */ - virtual void PrintCallExtern(Type ret_type, String global_symbol, const Array& args, - bool skip_first_arg, std::ostream& os); // NOLINT(*) + virtual void PrintCallExtern(Type ret_type, ffi::String global_symbol, + const ffi::Array& args, bool skip_first_arg, + std::ostream& os); // NOLINT(*) /*! 
* \brief If buffer is allocated as type t. * \param buf_var The buffer variable. @@ -339,7 +340,7 @@ class CodeGenC : public ExprFunctor, * functions, this is the name of the function's GlobalVar, possibly * altered to prevent duplicate names. */ - std::unordered_map internal_functions_; + std::unordered_map internal_functions_; /* \brief Name supply to generate unique function names */ NameSupply func_name_supply_; diff --git a/src/target/source/codegen_c_host.cc b/src/target/source/codegen_c_host.cc index a4cbc46f0cca..6a27036d6e6c 100644 --- a/src/target/source/codegen_c_host.cc +++ b/src/target/source/codegen_c_host.cc @@ -67,7 +67,7 @@ void CodeGenCHost::AddFunction(const GlobalVar& gvar, const PrimFunc& func) { void CodeGenCHost::AddFunction(const GlobalVar& gvar, const PrimFunc& func, bool emit_fwd_func_decl) { - auto global_symbol = func->GetAttr(tvm::attr::kGlobalSymbol); + auto global_symbol = func->GetAttr(tvm::attr::kGlobalSymbol); if (global_symbol) { function_names_.push_back(global_symbol.value()); } @@ -90,8 +90,8 @@ void CodeGenCHost::AddFunction(const GlobalVar& gvar, const PrimFunc& func, } } -void CodeGenCHost::GenerateForwardFunctionDeclarations(String global_symbol, - const Array& arg_types, +void CodeGenCHost::GenerateForwardFunctionDeclarations(ffi::String global_symbol, + const ffi::Array& arg_types, const Type& ret_type) { if (!emit_fwd_func_decl_) { return; @@ -363,9 +363,9 @@ ffi::Module BuildCHost(IRModule mod, Target target) { bool emit_fwd_func_decl = true; std::unordered_set devices; - if (mod->GetAttr>("device_contexts") != nullptr) { - Map device_contexts = - mod->GetAttr>("device_contexts").value(); + if (mod->GetAttr>("device_contexts") != nullptr) { + ffi::Map device_contexts = + mod->GetAttr>("device_contexts").value(); for (auto const& context : device_contexts) { devices.insert(context.second.data()); } diff --git a/src/target/source/codegen_c_host.h b/src/target/source/codegen_c_host.h index 1c7e65b3b2cb..feb0f715d847 100644 --- a/src/target/source/codegen_c_host.h +++ b/src/target/source/codegen_c_host.h @@ -70,16 +70,17 @@ class CodeGenCHost : public CodeGenC { void VisitStmt_(const AssertStmtNode* op) final; // NOLINT(*) - void GenerateForwardFunctionDeclarations(String global_symbol, const Array& arg_types, + void GenerateForwardFunctionDeclarations(ffi::String global_symbol, + const ffi::Array& arg_types, const Type& ret_type) override; - Array GetFunctionNames() { return function_names_; } + ffi::Array GetFunctionNames() { return function_names_; } private: std::string module_name_; /* \brief mapping global packed func to the unique name */ std::unordered_map declared_globals_; /* \brief names of the functions declared in this module */ - Array function_names_; + ffi::Array function_names_; /*! \brief whether to emit asserts in the resulting C code */ bool emit_asserts_; /*! 
\brief whether to emit forwared function declarations in the resulting C code */ diff --git a/src/target/source/codegen_cuda.cc b/src/target/source/codegen_cuda.cc index 951415c3b353..4454dd319768 100644 --- a/src/target/source/codegen_cuda.cc +++ b/src/target/source/codegen_cuda.cc @@ -140,7 +140,7 @@ void CodeGenCUDA::Init(bool output_ssa) { ICHECK_EQ(vid_global_barrier_state_, runtime::symbol::tvm_global_barrier_state); } -void CodeGenCUDA::PrintFunctionSignature(const String& function_name, const PrimFunc& func, +void CodeGenCUDA::PrintFunctionSignature(const ffi::String& function_name, const PrimFunc& func, std::ostream& os) { auto calling_conv = func->GetAttr(tvm::attr::kCallingConv, Integer(tvm::CallingConv::kDefault)); @@ -866,8 +866,9 @@ void CodeGenCUDA::VisitExpr_(const CastNode* op, std::ostream& os) { os << sret; } -void CodeGenCUDA::PrintCallExtern(Type ret_type, String global_symbol, const Array& args, - bool skip_first_arg, std::ostream& os) { // NOLINT(*) +void CodeGenCUDA::PrintCallExtern(Type ret_type, ffi::String global_symbol, + const ffi::Array& args, bool skip_first_arg, + std::ostream& os) { // NOLINT(*) DataType ret_dtype = GetRuntimeDataType(ret_type); if (ret_dtype.is_fixed_length_vector()) { // diff --git a/src/target/source/codegen_cuda.h b/src/target/source/codegen_cuda.h index 6441f87909db..02fc0603a52f 100644 --- a/src/target/source/codegen_cuda.h +++ b/src/target/source/codegen_cuda.h @@ -46,7 +46,7 @@ class CodeGenCUDA final : public CodeGenC { enable_fp4_ || need_math_constants_h_ || need_mma_h_); } // override behavior - void PrintFunctionSignature(const String& function_name, const PrimFunc& func, + void PrintFunctionSignature(const ffi::String& function_name, const PrimFunc& func, std::ostream& os) final; void PrintExtraAttrs(const PrimFunc& f, std::ostream& os) final; // NOLINT(*) void VisitStmt_(const ForNode* op) final; @@ -74,7 +74,7 @@ class CodeGenCUDA final : public CodeGenC { void VisitStmt_(const AttrStmtNode* op) final; protected: - void PrintCallExtern(Type ret_type, String global_symbol, const Array& args, + void PrintCallExtern(Type ret_type, ffi::String global_symbol, const ffi::Array& args, bool skip_first_arg, std::ostream& os) final; // NOLINT(*) private: diff --git a/src/target/source/codegen_metal.cc b/src/target/source/codegen_metal.cc index dc019c28a7a0..eab7646ee53d 100644 --- a/src/target/source/codegen_metal.cc +++ b/src/target/source/codegen_metal.cc @@ -77,7 +77,7 @@ void CodeGenMetal::AddFunction(const GlobalVar& gvar, const PrimFunc& func) { name_supply_->FreshName("v_"); // add to alloc buffer type. 
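// [Editorial aside, not part of the patch] The global_symbol lookup just below
// is the attribute-getter pattern this PR applies throughout the codegen
// sources. A minimal sketch of the call site, with the template argument
// (elided by this diff rendering) assumed to be ffi::String:
//
//   auto global_symbol = func->GetAttr<tvm::ffi::String>(tvm::attr::kGlobalSymbol);
//   ICHECK(global_symbol.has_value()) << "PrimFunc must carry kGlobalSymbol";
//   std::string name = global_symbol.value();  // ffi::String converts to std::string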
- auto global_symbol = func->GetAttr(tvm::attr::kGlobalSymbol); + auto global_symbol = func->GetAttr(tvm::attr::kGlobalSymbol); ICHECK(global_symbol.has_value()) << "CodeGenC: Expect PrimFunc to have the global_symbol attribute"; @@ -149,7 +149,8 @@ void CodeGenMetal::AddFunction(const GlobalVar& gvar, const PrimFunc& func) { ICHECK_EQ(name_supply_->FreshName("threadIdx"), "threadIdx"); ICHECK_EQ(name_supply_->FreshName("blockIdx"), "blockIdx"); int work_dim = 0; - auto launch_params = func->GetAttr>(tir::attr::kKernelLaunchParams).value(); + auto launch_params = + func->GetAttr>(tir::attr::kKernelLaunchParams).value(); for (const auto& tag : launch_params) { if (tag != runtime::launch_param::kUseDynamicSharedMemoryTag) { runtime::ThreadScope scope = runtime::ThreadScope::Create(tag); @@ -359,7 +360,7 @@ void CodeGenMetal::VisitExpr_(const BroadcastNode* op, std::ostream& os) { // N void CodeGenMetal::VisitExpr_(const CallNode* op, std::ostream& os) { // NOLINT(*) CHECK(!op->op.as()) << "CodegenMetal does not support inter-function calls, " - << "but expression " << GetRef(op) << " calls PrimFunc " << op->op; + << "but expression " << ffi::GetRef(op) << " calls PrimFunc " << op->op; auto f_check_simdgroup_shape = [](PrimExpr col, PrimExpr row) { ICHECK(col->IsInstance() && row->IsInstance()) << "Only constant shape is supported for simdgroup matrix, but got " << col << "x" << row; @@ -442,7 +443,7 @@ ffi::Module BuildMetal(IRModule mod, Target target) { for (auto kv : mod->functions) { ICHECK(kv.second->IsInstance()) << "CodeGenMetal: Can only take PrimFunc"; - auto global_symbol = kv.second->GetAttr(tvm::attr::kGlobalSymbol); + auto global_symbol = kv.second->GetAttr(tvm::attr::kGlobalSymbol); ICHECK(global_symbol.has_value()); std::string func_name = global_symbol.value(); diff --git a/src/target/source/codegen_opencl.cc b/src/target/source/codegen_opencl.cc index 1342464665f3..4f4f763a74ae 100644 --- a/src/target/source/codegen_opencl.cc +++ b/src/target/source/codegen_opencl.cc @@ -475,10 +475,10 @@ void CodeGenOpenCL::VisitExpr_(const CallNode* op, std::ostream& os) { // Enable atomics extension if used. 
if (func->value == "atomic_add" && op->dtype.is_float()) { enable_atomics_ = true; - this->PrintCallExtern(GetType(GetRef(op)), "atomic_add_float_emu", op->args, true, - os); + this->PrintCallExtern(GetType(ffi::GetRef(op)), "atomic_add_float_emu", op->args, + true, os); } else if (func->value == "nearbyint") { - this->PrintCallExtern(GetType(GetRef(op)), "round", op->args, true, os); + this->PrintCallExtern(GetType(ffi::GetRef(op)), "round", op->args, true, os); } else { if (func->value == "atomic_add") { enable_atomics_ = true; @@ -635,7 +635,7 @@ void CodeGenOpenCL::SetTextureScope( ffi::Module BuildOpenCL(IRModule mod, Target target) { #if TVM_ENABLE_SPIRV - Optional device = target->GetAttr("device"); + ffi::Optional device = target->GetAttr("device"); if (device && device.value() == "spirv") { auto [smap, spirv_text] = LowerToSPIRV(mod, target); return runtime::OpenCLModuleCreate(smap, spirv_text, ExtractFuncInfo(mod)); @@ -644,7 +644,7 @@ ffi::Module BuildOpenCL(IRModule mod, Target target) { bool output_ssa = false; - Map functions; + ffi::Map functions; for (auto [gvar, base_func] : mod->functions) { ICHECK(base_func->IsInstance()) << "CodeGenOpenCL: Can only take PrimFunc"; auto prim_func = Downcast(base_func); @@ -679,12 +679,12 @@ TVM_FFI_STATIC_INIT_BLOCK({ refl::GlobalDef().def("target.build.opencl", BuildOpenCL); }); -String DeviceScopeCompatibilityFromTarget(Target target, String memory_scope) { +ffi::String DeviceScopeCompatibilityFromTarget(Target target, ffi::String memory_scope) { auto prototype_keys = target->GetKeys(); bool is_adreno = std::find(prototype_keys.begin(), prototype_keys.end(), "adreno") != prototype_keys.end(); if (is_adreno) { - return String("global"); + return ffi::String("global"); } return memory_scope; } diff --git a/src/target/source/codegen_source_base.h b/src/target/source/codegen_source_base.h index 97828249ce24..104bf2cbdc34 100644 --- a/src/target/source/codegen_source_base.h +++ b/src/target/source/codegen_source_base.h @@ -150,9 +150,9 @@ ffi::Module SourceModuleCreate(std::string code, std::string fmt); * \param const_vars. The constant variables that the c source module needs. * \return The created module. */ -ffi::Module CSourceModuleCreate(const String& code, const String& fmt, - const Array& func_names, - const Array& const_vars = {}); +ffi::Module CSourceModuleCreate(const ffi::String& code, const ffi::String& fmt, + const ffi::Array& func_names, + const ffi::Array& const_vars = {}); /*! * \brief Wrap the submodules in a metadata module. @@ -164,8 +164,8 @@ ffi::Module CSourceModuleCreate(const String& code, const String& fmt, * \return The wrapped module. */ ffi::Module CreateMetadataModule(const std::unordered_map& params, - ffi::Module target_module, const Array& ext_modules, - Target target); + ffi::Module target_module, + const ffi::Array& ext_modules, Target target); /*! * \brief Create a source module for viewing and limited saving for device. 
diff --git a/src/target/source/codegen_webgpu.cc b/src/target/source/codegen_webgpu.cc index 28d158c3c21e..374402742271 100644 --- a/src/target/source/codegen_webgpu.cc +++ b/src/target/source/codegen_webgpu.cc @@ -63,7 +63,7 @@ class WebGPUWorkgroupInfoCollector : public StmtExprVisitor { private: void VisitExpr_(const VarNode* op) final { StmtExprVisitor::VisitExpr_(op); - Var buffer_var = GetRef(op); + Var buffer_var = ffi::GetRef(op); if (buffer_var.dtype().is_handle()) { info_.write_access_set.insert(buffer_var); } @@ -137,7 +137,7 @@ runtime::FunctionInfo CodeGenWebGPU::AddFunction(const PrimFunc& f, bool skip_re ICHECK_EQ(name_supply_->FreshName("gridDim"), "gridDim"); // add to alloc buffer type. - auto global_symbol = f->GetAttr(tvm::attr::kGlobalSymbol); + auto global_symbol = f->GetAttr(tvm::attr::kGlobalSymbol); ICHECK(global_symbol.has_value()) << "CodeGenWebGPU: Expect PrimFunc to have the global_symbol attribute"; @@ -233,7 +233,7 @@ runtime::FunctionInfo CodeGenWebGPU::AddFunction(const PrimFunc& f, bool skip_re << "var " << val_pod_args << " : " << type_pod_args << ";\n\n"; // setup thread tags and param access in launch param tags; - if (auto opt = f->GetAttr>(tir::attr::kKernelLaunchParams)) { + if (auto opt = f->GetAttr>(tir::attr::kKernelLaunchParams)) { for (const auto& thread_tag : opt.value()) { func_info.launch_param_tags.push_back(thread_tag); } @@ -716,7 +716,7 @@ class WebGPUSourceModuleNode final : public ffi::ModuleObj { /*! \brief Get the property of the runtime module .*/ int GetPropertyMask() const final { return ffi::Module::kBinarySerializable; } - Optional GetFunction(const String& name) final { + ffi::Optional GetFunction(const ffi::String& name) final { LOG(FATAL) << "WebGPUSourceModule is not directly runnable, export and run through tvmjs"; } @@ -729,7 +729,7 @@ class WebGPUSourceModuleNode final : public ffi::ModuleObj { return ffi::Bytes(buffer); } - String InspectSource(const String& format) const final { + ffi::String InspectSource(const ffi::String& format) const final { if (format == "func_info") { std::ostringstream stream; dmlc::JSONWriter(&stream).Write(fmap_); @@ -770,7 +770,7 @@ ffi::Module BuildWebGPU(IRModule mod, Target target) { auto calling_conv = f->GetAttr(tvm::attr::kCallingConv); ICHECK(calling_conv == CallingConv::kDeviceKernelLaunch) << "CodeGenWebGPU: expect calling_conv equals CallingConv::kDeviceKernelLaunch"; - auto global_symbol = f->GetAttr(tvm::attr::kGlobalSymbol); + auto global_symbol = f->GetAttr(tvm::attr::kGlobalSymbol); ICHECK(global_symbol.has_value()) << "CodeGenWebGPU: Expect PrimFunc to have the global_symbol attribute"; std::string f_name = global_symbol.value(); @@ -780,7 +780,7 @@ ffi::Module BuildWebGPU(IRModule mod, Target target) { smap[f_name] = code; } - auto n = make_object(smap, fmap); + auto n = ffi::make_object(smap, fmap); return ffi::Module(n); } diff --git a/src/target/source/intrin_rule_cuda.cc b/src/target/source/intrin_rule_cuda.cc index e762bde69f4d..56b575cc6c38 100644 --- a/src/target/source/intrin_rule_cuda.cc +++ b/src/target/source/intrin_rule_cuda.cc @@ -144,7 +144,7 @@ static PrimExpr DispatchCUDAShuffle(const PrimExpr& e) { const CallNode* call = e.as(); ICHECK(call != nullptr); ICHECK_EQ(call->args.size(), 5); // mask, value, warp_id, width, warp_size - Array cuda_args{{call->args[0], call->args[1], call->args[2], call->args[3]}}; + ffi::Array cuda_args{{call->args[0], call->args[1], call->args[2], call->args[3]}}; return Call(call->dtype, T()(call->dtype, Downcast(call->op)), 
cuda_args); } diff --git a/src/target/source/intrin_rule_metal.cc b/src/target/source/intrin_rule_metal.cc index b7561e86715e..e74c63a79ba3 100644 --- a/src/target/source/intrin_rule_metal.cc +++ b/src/target/source/intrin_rule_metal.cc @@ -48,7 +48,7 @@ static PrimExpr DispatchMetalShuffle(const PrimExpr& e) { const CallNode* call = e.as(); ICHECK(call != nullptr); ICHECK_EQ(call->args.size(), 5); // mask, value, warp_id, width, warp_size - Array metal_args{{call->args[1], call->args[2]}}; + ffi::Array metal_args{{call->args[1], call->args[2]}}; return Call(call->dtype, T()(call->dtype, Downcast(call->op)), metal_args); } diff --git a/src/target/source/intrin_rule_opencl.cc b/src/target/source/intrin_rule_opencl.cc index bd9e148b187d..ea3a1c58bc3f 100644 --- a/src/target/source/intrin_rule_opencl.cc +++ b/src/target/source/intrin_rule_opencl.cc @@ -109,7 +109,8 @@ static PrimExpr DispatchIntelShuffle(const PrimExpr& e) { arith::Analyzer analyzer; ICHECK(analyzer.CanProve(call->args[3] == call->args[4])) << "Intel warp shuffle dose not support width != warp_size"; - Array opencl_args{{StringImm("intel_sub_group_shuffle"), call->args[1], call->args[2]}}; + ffi::Array opencl_args{ + {StringImm("intel_sub_group_shuffle"), call->args[1], call->args[2]}}; return Call(call->dtype, builtin::call_pure_extern(), opencl_args); } diff --git a/src/target/source/source_module.cc b/src/target/source/source_module.cc index 6638ed0e05a5..a0ae36691fa8 100644 --- a/src/target/source/source_module.cc +++ b/src/target/source/source_module.cc @@ -56,14 +56,14 @@ class SourceModuleNode : public ffi::ModuleObj { SourceModuleNode(std::string code, std::string fmt) : code_(code), fmt_(fmt) {} const char* kind() const final { return "source"; } - Optional GetFunction(const String& name) final { + ffi::Optional GetFunction(const ffi::String& name) final { LOG(FATAL) << "Source module cannot execute, to get executable module" << " build TVM with \'" << fmt_ << "\' runtime support"; } - String InspectSource(const String& format) const final { return code_; } + ffi::String InspectSource(const ffi::String& format) const final { return code_; } - Array GetWriteFormats() const override { return {fmt_}; } + ffi::Array GetWriteFormats() const override { return {fmt_}; } protected: std::string code_; @@ -71,7 +71,7 @@ class SourceModuleNode : public ffi::ModuleObj { }; ffi::Module SourceModuleCreate(std::string code, std::string fmt) { - auto n = make_object(code, fmt); + auto n = ffi::make_object(code, fmt); return ffi::Module(n); } @@ -79,14 +79,15 @@ ffi::Module SourceModuleCreate(std::string code, std::string fmt) { class CSourceModuleNode : public ffi::ModuleObj { public: CSourceModuleNode(const std::string& code, const std::string& fmt, - const Array& func_names, const Array& const_vars) + const ffi::Array& func_names, + const ffi::Array& const_vars) : code_(code), fmt_(fmt), const_vars_(const_vars), func_names_(func_names) { if (fmt_.empty()) fmt_ = "c"; } const char* kind() const final { return "c"; } - Optional GetFunction(const String& name) final { + ffi::Optional GetFunction(const ffi::String& name) final { ObjectPtr sptr_to_self = ffi::GetObjectPtr(this); // Currently c-source module is used as demonstration purposes with binary metadata module // that expects get_symbol interface. 
When c-source module is used as external module, it @@ -106,9 +107,9 @@ class CSourceModuleNode : public ffi::ModuleObj { } } - String InspectSource(const String& format) const final { return code_; } + ffi::String InspectSource(const ffi::String& format) const final { return code_; } - Array GetWriteFormats() const override { return {fmt_}; } + ffi::Array GetWriteFormats() const override { return {fmt_}; } ffi::Bytes SaveToBytes() const final { std::string buffer; @@ -138,17 +139,17 @@ class CSourceModuleNode : public ffi::ModuleObj { CHECK(stream->Read(&tmp_func_names)) << "Loading func names failed"; CHECK(stream->Read(&tmp_const_vars)) << "Loading const vars failed"; - Array func_names; - for (auto func_name : tmp_func_names) func_names.push_back(String(func_name)); + ffi::Array func_names; + for (auto func_name : tmp_func_names) func_names.push_back(ffi::String(func_name)); - Array const_vars; - for (auto const_var : tmp_const_vars) const_vars.push_back(String(const_var)); + ffi::Array const_vars; + for (auto const_var : tmp_const_vars) const_vars.push_back(ffi::String(const_var)); - auto n = make_object(code, fmt, func_names, const_vars); + auto n = ffi::make_object(code, fmt, func_names, const_vars); return ffi::Module(n); } - void WriteToFile(const String& file_name, const String& format) const final { + void WriteToFile(const ffi::String& file_name, const ffi::String& format) const final { std::string fmt = GetFileFormat(file_name, format); std::string meta_file = GetMetaFilePath(file_name); if (fmt == "c" || fmt == "cc" || fmt == "cpp" || fmt == "cu") { @@ -163,21 +164,22 @@ class CSourceModuleNode : public ffi::ModuleObj { return ffi::Module::kBinarySerializable | ffi::Module::kCompilationExportable; } - bool ImplementsFunction(const String& name) final { + bool ImplementsFunction(const ffi::String& name) final { return std::find(func_names_.begin(), func_names_.end(), name) != func_names_.end(); } protected: std::string code_; std::string fmt_; - Array const_vars_; - Array func_names_; + ffi::Array const_vars_; + ffi::Array func_names_; }; -ffi::Module CSourceModuleCreate(const String& code, const String& fmt, - const Array& func_names, const Array& const_vars) { - auto n = make_object(code.operator std::string(), fmt.operator std::string(), - func_names, const_vars); +ffi::Module CSourceModuleCreate(const ffi::String& code, const ffi::String& fmt, + const ffi::Array& func_names, + const ffi::Array& const_vars) { + auto n = ffi::make_object(code.operator std::string(), + fmt.operator std::string(), func_names, const_vars); return ffi::Module(n); } @@ -210,12 +212,12 @@ class DeviceSourceModuleNode final : public ffi::ModuleObj { std::function fget_source) : data_(data), fmt_(fmt), fmap_(fmap), type_key_(type_key), fget_source_(fget_source) {} - Optional GetFunction(const String& name) final { + ffi::Optional GetFunction(const ffi::String& name) final { LOG(FATAL) << "Source module cannot execute, to get executable module" << " build TVM with \'" << fmt_ << "\' runtime support"; } - String InspectSource(const String& format) const final { + ffi::String InspectSource(const ffi::String& format) const final { if (fget_source_ != nullptr) { return fget_source_(format); } else { @@ -227,7 +229,7 @@ class DeviceSourceModuleNode final : public ffi::ModuleObj { /*! 
\brief Get the property of the runtime module .*/ int GetPropertyMask() const final { return ffi::Module::kBinarySerializable; } - void WriteToFile(const String& file_name, const String& format) const final { + void WriteToFile(const ffi::String& file_name, const ffi::String& format) const final { std::string fmt = GetFileFormat(file_name, format); ICHECK_EQ(fmt, fmt_) << "Can only save to format=" << fmt_; std::string meta_file = GetMetaFilePath(file_name); @@ -257,7 +259,7 @@ ffi::Module DeviceSourceModuleCreate(std::string data, std::string fmt, std::unordered_map fmap, std::string type_key, std::function fget_source) { - auto n = make_object(data, fmt, fmap, type_key, fget_source); + auto n = ffi::make_object(data, fmt, fmap, type_key, fget_source); return ffi::Module(n); } @@ -265,9 +267,9 @@ TVM_FFI_STATIC_INIT_BLOCK({ namespace refl = tvm::ffi::reflection; refl::GlobalDef() .def("runtime.SourceModuleCreate", SourceModuleCreate) - .def("runtime.CSourceModuleCreate", [](String code, String fmt, - Optional> func_names, - Optional> const_vars) { + .def("runtime.CSourceModuleCreate", [](ffi::String code, ffi::String fmt, + ffi::Optional> func_names, + ffi::Optional> const_vars) { return CSourceModuleCreate(code, fmt, func_names.value_or({}), const_vars.value_or({})); }); }); diff --git a/src/target/spirv/intrin_rule_spirv.cc b/src/target/spirv/intrin_rule_spirv.cc index 3010b74dd976..a689a550c4aa 100644 --- a/src/target/spirv/intrin_rule_spirv.cc +++ b/src/target/spirv/intrin_rule_spirv.cc @@ -34,10 +34,10 @@ namespace codegen { namespace spirv { // num_signature means number of arguments used to query signature template -PrimExpr CallGLSLIntrin(PrimExpr e, const Array& args) { +PrimExpr CallGLSLIntrin(PrimExpr e, const ffi::Array& args) { const tir::CallNode* call = e.as(); ICHECK(call != nullptr); - Array cargs; + ffi::Array cargs; // intrin id. cargs.push_back(IntImm(DataType::UInt(32), id)); diff --git a/src/target/spirv/spirv_support.cc b/src/target/spirv/spirv_support.cc index a17a694da4dd..91b45b85bbd0 100644 --- a/src/target/spirv/spirv_support.cc +++ b/src/target/spirv/spirv_support.cc @@ -94,8 +94,9 @@ SPIRVSupport::SPIRVSupport(tvm::Target target) { supports_integer_dot_product = target->GetAttr("supports_integer_dot_product").value(); } // Check whether integer dot product is enabled in mattr. 
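// [Editorial aside, not part of the patch] The mattr probe just below is the
// recurring Optional<Array<String>> lookup. With the template arguments
// (elided by this diff rendering) assumed, the call-site shape is:
//
//   if (auto v = target->GetAttr<tvm::ffi::Array<tvm::ffi::String>>("mattr")) {
//     for (const tvm::ffi::String& s : v.value()) {
//       if (s == "+dotprod") { /* feature requested via -mattr */ }
//     }
//   }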
- if (const Optional>& v = target->GetAttr>("mattr")) { - for (const String& s : v.value()) { + if (const ffi::Optional>& v = + target->GetAttr>("mattr")) { + for (const ffi::String& s : v.value()) { if (s.compare("+dotprod") == 0) { supports_integer_dot_product = true; break; diff --git a/src/target/spirv/spirv_utils.cc b/src/target/spirv/spirv_utils.cc index f0226466f625..a4cec2c0fd65 100644 --- a/src/target/spirv/spirv_utils.cc +++ b/src/target/spirv/spirv_utils.cc @@ -129,7 +129,7 @@ std::pair, std::string> Lo auto calling_conv = f->GetAttr(tvm::attr::kCallingConv); ICHECK(calling_conv == CallingConv::kDeviceKernelLaunch) << "CodeGenSPIRV: expect calling_conv equals CallingConv::kDeviceKernelLaunch"; - auto global_symbol = f->GetAttr(tvm::attr::kGlobalSymbol); + auto global_symbol = f->GetAttr(tvm::attr::kGlobalSymbol); ICHECK(global_symbol.has_value()) << "CodeGenSPIRV: Expect PrimFunc to have the global_symbol attribute"; diff --git a/src/target/tag.cc b/src/target/tag.cc index f305c84e09a4..8835ea64c9a3 100644 --- a/src/target/tag.cc +++ b/src/target/tag.cc @@ -45,11 +45,11 @@ TVM_FFI_STATIC_INIT_BLOCK({ using TargetTagRegistry = AttrRegistry; -TargetTagRegEntry& TargetTagRegEntry::RegisterOrGet(const String& target_tag_name) { +TargetTagRegEntry& TargetTagRegEntry::RegisterOrGet(const ffi::String& target_tag_name) { return TargetTagRegistry::Global()->RegisterOrGet(target_tag_name); } -Optional TargetTag::Get(const String& target_tag_name) { +ffi::Optional TargetTag::Get(const ffi::String& target_tag_name) { const TargetTagRegEntry* reg = TargetTagRegistry::Global()->Get(target_tag_name); if (reg == nullptr) { return std::nullopt; @@ -57,15 +57,15 @@ Optional TargetTag::Get(const String& target_tag_name) { return Target(reg->tag_->config); } -Map TargetTag::ListTags() { - Map result; - for (const String& tag : TargetTagRegistry::Global()->ListAllNames()) { +ffi::Map TargetTag::ListTags() { + ffi::Map result; + for (const ffi::String& tag : TargetTagRegistry::Global()->ListAllNames()) { result.Set(tag, TargetTag::Get(tag).value()); } return result; } -Target TargetTag::AddTag(String name, Map config, bool override) { +Target TargetTag::AddTag(ffi::String name, ffi::Map config, bool override) { TargetTagRegEntry& tag = TargetTagRegEntry::RegisterOrGet(name).set_name(); ICHECK(override || tag.tag_->config.empty()) << "Tag \"" << name << "\" has been previously defined as: " << tag.tag_->config; @@ -77,73 +77,78 @@ Target TargetTag::AddTag(String name, Map config, bool overrid #if TVM_LLVM_HAS_AARCH64_TARGET TVM_REGISTER_TARGET_TAG("raspberry-pi/4b-aarch64") - .set_config({{"kind", String("llvm")}, - {"mtriple", String("aarch64-linux-gnu")}, - {"mcpu", String("cortex-a72")}, - {"mattr", Array{"+neon"}}, + .set_config({{"kind", ffi::String("llvm")}, + {"mtriple", ffi::String("aarch64-linux-gnu")}, + {"mcpu", ffi::String("cortex-a72")}, + {"mattr", ffi::Array{"+neon"}}, {"num-cores", 4}, - {"host", Map{{"kind", String("llvm")}, - {"mtriple", String("aarch64-linux-gnu")}, - {"mcpu", String("cortex-a72")}, - {"mattr", Array{"+neon"}}, - {"num-cores", 4}}}}); + {"host", + ffi::Map{{"kind", ffi::String("llvm")}, + {"mtriple", ffi::String("aarch64-linux-gnu")}, + {"mcpu", ffi::String("cortex-a72")}, + {"mattr", ffi::Array{"+neon"}}, + {"num-cores", 4}}}}); #if TVM_LLVM_VERSION >= 110 TVM_REGISTER_TARGET_TAG("nvidia/jetson-agx-xavier") - .set_config({{"kind", String("cuda")}, - {"arch", String("sm_72")}, + .set_config({{"kind", ffi::String("cuda")}, + {"arch", ffi::String("sm_72")}, 
{"max_shared_memory_per_block", 49152}, {"max_threads_per_block", 1024}, {"thread_warp_size", 32}, {"registers_per_block", 65536}, - {"host", Map{{"kind", String("llvm")}, - {"mtriple", String("aarch64-linux-gnu")}, - {"mcpu", String("carmel")}, - {"num-cores", 8}}}}); + {"host", + ffi::Map{{"kind", ffi::String("llvm")}, + {"mtriple", ffi::String("aarch64-linux-gnu")}, + {"mcpu", ffi::String("carmel")}, + {"num-cores", 8}}}}); TVM_REGISTER_TARGET_TAG("nvidia/jetson-orin-nano") - .set_config({{"kind", String("cuda")}, - {"arch", String("sm_87")}, + .set_config({{"kind", ffi::String("cuda")}, + {"arch", ffi::String("sm_87")}, {"max_shared_memory_per_block", 49152}, {"max_threads_per_block", 1024}, {"thread_warp_size", 32}, {"registers_per_block", 65536}, - {"host", Map{{"kind", String("llvm")}, - {"mtriple", String("aarch64-linux-gnu")}, - {"mcpu", String("carmel")}, - {"num-cores", 6}}}}); + {"host", + ffi::Map{{"kind", ffi::String("llvm")}, + {"mtriple", ffi::String("aarch64-linux-gnu")}, + {"mcpu", ffi::String("carmel")}, + {"num-cores", 6}}}}); TVM_REGISTER_TARGET_TAG("nvidia/jetson-agx-orin-32gb") - .set_config({{"kind", String("cuda")}, - {"arch", String("sm_87")}, + .set_config({{"kind", ffi::String("cuda")}, + {"arch", ffi::String("sm_87")}, {"max_shared_memory_per_block", 49152}, {"max_threads_per_block", 1024}, {"thread_warp_size", 32}, {"registers_per_block", 65536}, - {"host", Map{{"kind", String("llvm")}, - {"mtriple", String("aarch64-linux-gnu")}, - {"mcpu", String("cortex-a78")}, - {"num-cores", 8}}}}); + {"host", + ffi::Map{{"kind", ffi::String("llvm")}, + {"mtriple", ffi::String("aarch64-linux-gnu")}, + {"mcpu", ffi::String("cortex-a78")}, + {"num-cores", 8}}}}); TVM_REGISTER_TARGET_TAG("nvidia/jetson-agx-orin-64gb") - .set_config({{"kind", String("cuda")}, - {"arch", String("sm_87")}, + .set_config({{"kind", ffi::String("cuda")}, + {"arch", ffi::String("sm_87")}, {"max_shared_memory_per_block", 49152}, {"max_threads_per_block", 1024}, {"thread_warp_size", 32}, {"registers_per_block", 65536}, - {"host", Map{{"kind", String("llvm")}, - {"mtriple", String("aarch64-linux-gnu")}, - {"mcpu", String("cortex-a78")}, - {"num-cores", 12}}}}); + {"host", + ffi::Map{{"kind", ffi::String("llvm")}, + {"mtriple", ffi::String("aarch64-linux-gnu")}, + {"mcpu", ffi::String("cortex-a78")}, + {"num-cores", 12}}}}); #endif // TVM_LLVM_VERSION >= 110 #endif // TVM_LLVM_HAS_AARCH64_TARGET #define TVM_REGISTER_CUDA_TAG(Name, Arch, SharedMem, RegPerBlock) \ TVM_REGISTER_TARGET_TAG(Name).set_config({ \ - {"kind", String("cuda")}, \ - {"keys", Array{"cuda", "gpu"}}, \ - {"arch", String(Arch)}, \ + {"kind", ffi::String("cuda")}, \ + {"keys", ffi::Array{"cuda", "gpu"}}, \ + {"arch", ffi::String(Arch)}, \ {"max_shared_memory_per_block", SharedMem}, \ {"max_threads_per_block", 1024}, \ {"thread_warp_size", 32}, \ @@ -421,10 +426,10 @@ TVM_REGISTER_CUDA_TAG("nvidia/tegra-x1", "sm_53", 49152, 32768); #undef TVM_REGISTER_CUDA_TAG -#define TVM_REGISTER_TAG_AWS_C5(Name, Cores, Arch) \ - TVM_REGISTER_TARGET_TAG(Name).set_config({{"kind", String("llvm")}, \ - {"keys", Array{"x86", "cpu"}}, \ - {"mcpu", String(Arch)}, \ +#define TVM_REGISTER_TAG_AWS_C5(Name, Cores, Arch) \ + TVM_REGISTER_TARGET_TAG(Name).set_config({{"kind", ffi::String("llvm")}, \ + {"keys", ffi::Array{"x86", "cpu"}}, \ + {"mcpu", ffi::String(Arch)}, \ {"num-cores", Cores}}); TVM_REGISTER_TAG_AWS_C5("aws/cpu/c5.large", 1, "skylake-avx512"); @@ -439,25 +444,25 @@ TVM_REGISTER_TAG_AWS_C5("aws/cpu/c5.24xlarge", 48, "cascadelake"); #undef 
TVM_REGISTER_TAG_AWS_C5 #if TVM_LLVM_VERSION >= 190 -#define TVM_REGISTER_METAL_GPU_TAG(Name, ThreadsPerBlock, SharedMem, WarpSize) \ - TVM_REGISTER_TARGET_TAG(Name).set_config( \ - {{"kind", String("metal")}, \ - {"max_threads_per_block", ThreadsPerBlock}, \ - {"max_shared_memory_per_block", SharedMem}, \ - {"thread_warp_size", WarpSize}, \ - {"host", Map{{"kind", String("llvm")}, \ - {"mtriple", String("arm64-apple-macos")}, \ - {"mcpu", String("apple-m4")}}}}); +#define TVM_REGISTER_METAL_GPU_TAG(Name, ThreadsPerBlock, SharedMem, WarpSize) \ + TVM_REGISTER_TARGET_TAG(Name).set_config( \ + {{"kind", ffi::String("metal")}, \ + {"max_threads_per_block", ThreadsPerBlock}, \ + {"max_shared_memory_per_block", SharedMem}, \ + {"thread_warp_size", WarpSize}, \ + {"host", ffi::Map{{"kind", ffi::String("llvm")}, \ + {"mtriple", ffi::String("arm64-apple-macos")}, \ + {"mcpu", ffi::String("apple-m4")}}}}); #else -#define TVM_REGISTER_METAL_GPU_TAG(Name, ThreadsPerBlock, SharedMem, WarpSize) \ - TVM_REGISTER_TARGET_TAG(Name).set_config( \ - {{"kind", String("metal")}, \ - {"max_threads_per_block", ThreadsPerBlock}, \ - {"max_shared_memory_per_block", SharedMem}, \ - {"thread_warp_size", WarpSize}, \ - {"host", Map{{"kind", String("llvm")}, \ - {"mtriple", String("arm64-apple-macos")}, \ - {"mcpu", String("apple-latest")}}}}); +#define TVM_REGISTER_METAL_GPU_TAG(Name, ThreadsPerBlock, SharedMem, WarpSize) \ + TVM_REGISTER_TARGET_TAG(Name).set_config( \ + {{"kind", ffi::String("metal")}, \ + {"max_threads_per_block", ThreadsPerBlock}, \ + {"max_shared_memory_per_block", SharedMem}, \ + {"thread_warp_size", WarpSize}, \ + {"host", ffi::Map{{"kind", ffi::String("llvm")}, \ + {"mtriple", ffi::String("arm64-apple-macos")}, \ + {"mcpu", ffi::String("apple-latest")}}}}); #endif #if TVM_LLVM_HAS_AARCH64_TARGET diff --git a/src/target/target.cc b/src/target/target.cc index 1c56fa5bd210..b2c3e8fe8c1b 100644 --- a/src/target/target.cc +++ b/src/target/target.cc @@ -49,25 +49,27 @@ class TargetInternal { public: static void EnterScope(Target target) { target.EnterWithScope(); } static void ExitScope(Target target) { target.ExitWithScope(); } - static Map Export(Target target) { return target->Export(); } + static ffi::Map Export(Target target) { return target->Export(); } static const TargetKindNode::ValueTypeInfo& FindTypeInfo(const TargetKind& kind, const std::string& key); - static Optional StringifyAttrsToRaw(const Map& attrs); + static ffi::Optional StringifyAttrsToRaw( + const ffi::Map& attrs); static Any ParseType(const std::string& str, const TargetKindNode::ValueTypeInfo& info); static Any ParseType(const Any& obj, const TargetKindNode::ValueTypeInfo& info); - static ObjectPtr FromString(const String& tag_or_config_or_target_str); - static ObjectPtr FromConfigString(const String& config_str); - static ObjectPtr FromRawString(const String& target_str); - static ObjectPtr FromConfig(Map config); + static ObjectPtr FromString(const ffi::String& tag_or_config_or_target_str); + static ObjectPtr FromConfigString(const ffi::String& config_str); + static ObjectPtr FromRawString(const ffi::String& target_str); + static ObjectPtr FromConfig(ffi::Map config); static void ConstructorDispatcher(ffi::PackedArgs args, ffi::Any* rv); static Target WithHost(const Target& target, const Target& target_host) { - ObjectPtr n = make_object(*target.get()); + ObjectPtr n = ffi::make_object(*target.get()); n->host = target_host; return (Target)n; } private: - static std::unordered_map QueryDevice(int device_id, const 
TargetNode* target); + static std::unordered_map QueryDevice(int device_id, + const TargetNode* target); static bool IsQuoted(const std::string& str); static std::string Quote(const std::string& str); static std::string JoinString(const std::vector& array, char separator); @@ -91,8 +93,8 @@ void CheckAndUpdateHostConsistency(Target* target, Target* host) { *host = (*target)->GetHost().value_or(Target()); } -static std::vector DeduplicateKeys(const std::vector& keys) { - std::vector new_keys; +static std::vector DeduplicateKeys(const std::vector& keys) { + std::vector new_keys; for (size_t i = 0; i < keys.size(); ++i) { bool found = false; for (size_t j = 0; j < i; ++j) { @@ -118,8 +120,8 @@ static T ObjTypeCheck(const Any& obj, const std::string& expected_type) { return opt.value(); } -static TargetKind GetTargetKind(const String& name) { - Optional kind = TargetKind::Get(name); +static TargetKind GetTargetKind(const ffi::String& name) { + ffi::Optional kind = TargetKind::Get(name); if (!kind.defined()) { TVM_FFI_THROW(TypeError) << "Target kind \"" + name + "\" is not defined"; } @@ -228,7 +230,7 @@ std::vector TargetInternal::SplitString(const std::string& str, cha } std::string TargetInternal::Interpret(const std::string& str) { - // String interpretation deals with quotes (') and escapes(\). + // ffi::String interpretation deals with quotes (') and escapes(\). // - An escape character must be followed by another character forming an // "escape sequence". (Trailing escape is not allowed.) An escape prevents // interpretation of the character that follows. This happens regardless of @@ -386,9 +388,9 @@ Any TargetInternal::ParseType(const std::string& str, const TargetKindNode::Valu auto end = interp_str.find_last_not_of(' '); if (start == std::string::npos || end == std::string::npos) { // The whole string is made of spaces. 
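// ===== Editor's aside: an illustrative sketch, not part of this patch =====
// The Interpret() comment above describes a small state machine for raw
// target strings: a backslash escapes the next character (a trailing escape
// is an error), and single quotes group characters without being kept. The
// function below is a hypothetical re-implementation for intuition only;
// InterpretSketch is not a TVM API.
#include <stdexcept>
#include <string>

static std::string InterpretSketch(const std::string& str) {
  std::string out;
  bool in_quote = false;
  for (size_t i = 0; i < str.size(); ++i) {
    if (str[i] == '\\') {
      // An escape must form an "escape sequence" with a following character.
      if (i + 1 == str.size()) throw std::runtime_error("trailing escape");
      out.push_back(str[++i]);  // the escaped character is taken literally
    } else if (str[i] == '\'') {
      in_quote = !in_quote;  // quotes toggle grouping and are not emitted
    } else {
      out.push_back(str[i]);
    }
  }
  if (in_quote) throw std::runtime_error("unterminated quote");
  return out;
}
// e.g. InterpretSketch("'-mattr=+neon,+fp16'") yields "-mattr=+neon,+fp16".
// ===== End aside =====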
- return String(); + return ffi::String(); } - return String(interp_str.substr(start, (end - start + 1))); + return ffi::String(interp_str.substr(start, (end - start + 1))); } else if (info.type_index == Target::ContainerType::RuntimeTypeIndex()) { // Parsing target @@ -405,7 +407,7 @@ Any TargetInternal::ParseType(const std::string& str, const TargetKindNode::Valu throw Error(e.kind(), e.message() + index, e.traceback()); } } - return Array(result); + return ffi::Array(result); } TVM_FFI_THROW(TypeError) << "Unsupported type \"" + info.type_key << "\" for parsing from string: " + interp_str; @@ -420,12 +422,12 @@ Any TargetInternal::ParseType(const Any& obj, const TargetKindNode::ValueTypeInf return ObjTypeCheck(obj, "bool"); } else if (info.type_index == ffi::TypeIndex::kTVMFFIStr) { // Parsing string - return ObjTypeCheck(obj, "String"); + return ObjTypeCheck(obj, "String"); } else if (info.type_index == Target::ContainerType::RuntimeTypeIndex()) { // Parsing target if (auto opt = obj.as()) { return opt.value(); - } else if (auto str = obj.try_cast()) { + } else if (auto str = obj.try_cast()) { return Target(TargetInternal::FromString(str.value())); } else if (const auto* ptr = obj.as()) { for (const auto& kv : *ptr) { @@ -434,7 +436,7 @@ Any TargetInternal::ParseType(const Any& obj, const TargetKindNode::ValueTypeInf << "Target object requires key of dict to be str, but get: " << kv.first.GetTypeKey(); } } - Map config = GetRef>(ptr); + ffi::Map config = ffi::GetRef>(ptr); return Target(TargetInternal::FromConfig({config.begin(), config.end()})); } TVM_FFI_THROW(TypeError) << "Expect type 'dict' or 'str' to construct Target, but get: " + @@ -451,7 +453,7 @@ Any TargetInternal::ParseType(const Any& obj, const TargetKindNode::ValueTypeInf throw Error(e.kind(), index + e.message(), e.traceback()); } } - return Array(result); + return ffi::Array(result); } else if (info.type_index == ffi::MapObj::RuntimeTypeIndex()) { // Parsing map const auto* map = ObjTypeCheck(obj, "Map"); @@ -472,7 +474,7 @@ Any TargetInternal::ParseType(const Any& obj, const TargetKindNode::ValueTypeInf } result[key] = val; } - return Map(result); + return ffi::Map(result); } if (info.type_index != obj.type_index()) { TVM_FFI_THROW(TypeError) << "Parsing type \"" << info.type_key @@ -489,7 +491,7 @@ std::string TargetInternal::StringifyAtomicType(const Any& obj) { return std::to_string(obj.cast()); } else if (obj.type_index() == ffi::TypeIndex::kTVMFFIInt) { return std::to_string(obj.cast()); - } else if (auto opt_str = obj.as()) { + } else if (auto opt_str = obj.as()) { std::string s = opt_str.value(); auto u = Uninterpret(s); if (u.find_first_of(' ') != std::string::npos && !IsQuoted(u)) { @@ -516,9 +518,10 @@ std::string TargetInternal::StringifyArray(const ffi::ArrayObj& array) { return JoinString(elements, ','); } -Optional TargetInternal::StringifyAttrsToRaw(const Map& attrs) { +ffi::Optional TargetInternal::StringifyAttrsToRaw( + const ffi::Map& attrs) { std::ostringstream os; - std::vector keys; + std::vector keys; for (const auto& kv : attrs) { keys.push_back(kv.first); } @@ -531,7 +534,7 @@ Optional TargetInternal::StringifyAttrsToRaw(const Map // skip undefined attrs if (obj == nullptr) continue; if (const auto* array = obj.as()) { - value = String(StringifyArray(*array)); + value = ffi::String(StringifyArray(*array)); } else { value = StringifyAtomicType(obj); } @@ -539,7 +542,7 @@ Optional TargetInternal::StringifyAttrsToRaw(const Map result.push_back("-" + key + "=" + value); } } - return 
String(JoinString(result, ' ')); + return ffi::String(JoinString(result, ' ')); } const std::string& TargetNode::str() const { @@ -549,7 +552,7 @@ const std::string& TargetNode::str() const { if (!this->keys.empty()) { os << " -keys="; bool is_first = true; - for (const String& s : keys) { + for (const ffi::String& s : keys) { if (is_first) { is_first = false; } else { @@ -558,7 +561,7 @@ const std::string& TargetNode::str() const { os << s; } } - if (Optional attrs_str = TargetInternal::StringifyAttrsToRaw(attrs)) { + if (ffi::Optional attrs_str = TargetInternal::StringifyAttrsToRaw(attrs)) { os << ' ' << attrs_str.value(); } @@ -569,7 +572,7 @@ const std::string& TargetNode::str() const { /********** Small member methods **********/ -Target::Target(const String& tag_or_config_or_target_str) { +Target::Target(const ffi::String& tag_or_config_or_target_str) { ObjectPtr target; try { target = TargetInternal::FromString(tag_or_config_or_target_str); @@ -581,7 +584,7 @@ Target::Target(const String& tag_or_config_or_target_str) { data_ = std::move(target); } -Target::Target(const Map& config) { +Target::Target(const ffi::Map& config) { ObjectPtr target; try { target = TargetInternal::FromConfig({config.begin(), config.end()}); @@ -594,13 +597,13 @@ Target::Target(const Map& config) { } Target::Target(Target target, Target host) { - ObjectPtr n = make_object(*target.get()); + ObjectPtr n = ffi::make_object(*target.get()); n->host = std::move(host); data_ = std::move(n); } -Target::Target(TargetKind kind, Optional host, String tag, Array keys, - Map attrs) { +Target::Target(TargetKind kind, ffi::Optional host, ffi::String tag, + ffi::Array keys, ffi::Map attrs) { auto data = ffi::make_object(); data->kind = std::move(kind); data->host = std::move(host); @@ -619,7 +622,7 @@ std::vector TargetNode::GetKeys() const { } std::unordered_set TargetNode::GetLibs() const { - Optional> libs = this->GetAttr>("libs"); + ffi::Optional> libs = this->GetAttr>("libs"); if (!libs.defined()) { return {}; } @@ -630,8 +633,8 @@ std::unordered_set TargetNode::GetLibs() const { return result; } -Map TargetNode::Export() const { - Map result = { +ffi::Map TargetNode::Export() const { + ffi::Map result = { {"kind", this->kind->name}, {"tag", this->tag}, {"keys", this->keys}, @@ -645,11 +648,11 @@ Map TargetNode::Export() const { return result; } -Optional TargetNode::GetHost() const { return this->host.as(); } +ffi::Optional TargetNode::GetHost() const { return this->host.as(); } Target Target::WithoutHost() const { if ((*this)->GetHost()) { - auto output = make_object(*get()); + auto output = ffi::make_object(*get()); output->host = std::nullopt; return Target(output); } else { @@ -658,7 +661,7 @@ Target Target::WithoutHost() const { } int TargetNode::GetTargetDeviceType() const { - if (Optional device_type = GetAttr("target_device_type")) { + if (ffi::Optional device_type = GetAttr("target_device_type")) { return Downcast(device_type)->value; } return kind->default_device_type; @@ -669,7 +672,7 @@ bool TargetNode::HasKey(const std::string& query_key) const { [&query_key](const auto& key) { return key == query_key; }); } -String TargetNode::ToDebugString() const { +ffi::String TargetNode::ToDebugString() const { std::ostringstream os; os << "Target("; os << "id=" << std::hex << reinterpret_cast(this); @@ -747,9 +750,9 @@ void TargetInternal::ConstructorDispatcher(ffi::PackedArgs args, ffi::Any* rv) { const auto& arg = args[0]; if (auto opt_target = arg.as()) { *rv = Target(opt_target.value()); - } else if (auto 
opt_str = arg.try_cast()) { + } else if (auto opt_str = arg.try_cast()) { *rv = Target(opt_str.value()); - } else if (auto opt_map = arg.try_cast>()) { + } else if (auto opt_map = arg.try_cast>()) { *rv = Target(opt_map.value()); } else { LOG(FATAL) << "TypeError: Cannot create target with type: " << args[0].GetTypeKey(); @@ -768,8 +771,8 @@ void TargetInternal::ConstructorDispatcher(ffi::PackedArgs args, ffi::Any* rv) { LOG(FATAL) << "ValueError: Invalid number of arguments. Expect 1 or 2, but gets: " << args.size(); } -ObjectPtr TargetInternal::FromString(const String& tag_or_config_or_target_str) { - if (Optional target = TargetTag::Get(tag_or_config_or_target_str)) { +ObjectPtr TargetInternal::FromString(const ffi::String& tag_or_config_or_target_str) { + if (ffi::Optional target = TargetTag::Get(tag_or_config_or_target_str)) { Target value = target.value(); return ffi::details::ObjectUnsafe::ObjectPtrFromObjectRef(value); } @@ -779,25 +782,25 @@ ObjectPtr TargetInternal::FromString(const String& tag_or_config_or_targ return TargetInternal::FromRawString(tag_or_config_or_target_str); } -ObjectPtr TargetInternal::FromConfigString(const String& config_str) { +ObjectPtr TargetInternal::FromConfigString(const ffi::String& config_str) { const auto loader = tvm::ffi::Function::GetGlobal("target._load_config_dict"); ICHECK(loader.has_value()) << "AttributeError: \"target._load_config_dict\" is not registered. Please check " "if the python module is properly loaded"; - auto config = (*loader)(config_str).cast>>(); + auto config = (*loader)(config_str).cast>>(); if (!config.defined()) { TVM_FFI_THROW(ValueError) << "Cannot load config dict with python JSON loader"; } return TargetInternal::FromConfig({config.value().begin(), config.value().end()}); } -ObjectPtr TargetInternal::FromRawString(const String& target_str) { +ObjectPtr TargetInternal::FromRawString(const ffi::String& target_str) { ICHECK_GT(target_str.length(), 0) << "Cannot parse empty target string"; // Split the string by empty spaces std::vector options = SplitString(std::string(target_str), ' '); std::string name = options[0]; // Create the target config - std::unordered_map config = {{"kind", String(name)}}; + std::unordered_map config = {{"kind", ffi::String(name)}}; TargetKind kind = GetTargetKind(name); for (size_t iter = 1, end = options.size(); iter < end;) { std::string key, value; @@ -823,20 +826,20 @@ ObjectPtr TargetInternal::FromRawString(const String& target_str) { return TargetInternal::FromConfig(config); } -ObjectPtr TargetInternal::FromConfig(Map config) { - const String kKind = "kind"; - const String kTag = "tag"; - const String kKeys = "keys"; - const String kDeviceName = "device"; - const String kHost = "host"; - const String kFeatures = "features"; - ObjectPtr target = make_object(); +ObjectPtr TargetInternal::FromConfig(ffi::Map config) { + const ffi::String kKind = "kind"; + const ffi::String kTag = "tag"; + const ffi::String kKeys = "keys"; + const ffi::String kDeviceName = "device"; + const ffi::String kHost = "host"; + const ffi::String kFeatures = "features"; + ObjectPtr target = ffi::make_object(); ICHECK(!config.count(kFeatures)) << "Target Features should be generated by Target parser"; // parse 'kind' if (config.count(kKind)) { - if (auto kind = config[kKind].try_cast()) { + if (auto kind = config[kKind].try_cast()) { target->kind = GetTargetKind(kind.value()); ICHECK(!(target->kind->preprocessor != nullptr && target->kind->target_parser != nullptr)) << "Cannot use both set_attrs_preprocessor and 
set_target_parser"; @@ -846,7 +849,7 @@ ObjectPtr TargetInternal::FromConfig(Map config) { VLOG(9) << "TargetInternal::FromConfig - Running target_parser"; config = target->kind->target_parser(config); if (config.count(kFeatures)) { - target->features = Downcast>(config[kFeatures]); + target->features = Downcast>(config[kFeatures]); config.erase(kFeatures); } } @@ -861,7 +864,7 @@ ObjectPtr TargetInternal::FromConfig(Map config) { } // parse "tag" if (config.count(kTag)) { - if (auto tag = config[kTag].try_cast()) { + if (auto tag = config[kTag].try_cast()) { target->tag = tag.value(); config.erase(kTag); } else { @@ -873,13 +876,13 @@ ObjectPtr TargetInternal::FromConfig(Map config) { } // parse "keys" { - std::vector keys; + std::vector keys; bool has_user_keys = config.count(kKeys); if (has_user_keys) { // user provided keys if (const auto* cfg_keys = config[kKeys].as()) { for (const Any& e : *cfg_keys) { - if (auto key = e.try_cast()) { + if (auto key = e.try_cast()) { keys.push_back(key.value()); } else { TVM_FFI_THROW(TypeError) << "Expect 'keys' to be an array of strings, but it " @@ -893,7 +896,7 @@ ObjectPtr TargetInternal::FromConfig(Map config) { } // add device name if (config.count(kDeviceName)) { - if (auto device = config.at(kDeviceName).try_cast()) { + if (auto device = config.at(kDeviceName).try_cast()) { keys.push_back(device.value()); } } @@ -915,9 +918,9 @@ ObjectPtr TargetInternal::FromConfig(Map config) { target->host = std::nullopt; } // parse attrs - std::unordered_map attrs; + std::unordered_map attrs; for (const auto& cfg_kv : config) { - const String& key = cfg_kv.first; + const ffi::String& key = cfg_kv.first; const ffi::Any& value = cfg_kv.second; try { const TargetKindNode::ValueTypeInfo& info = TargetInternal::FindTypeInfo(target->kind, key); @@ -950,8 +953,8 @@ ObjectPtr TargetInternal::FromConfig(Map config) { } // do extra pre-processing if (target->kind->preprocessor != nullptr) { - target->attrs = - target->kind->preprocessor(Map(attrs)).cast>(); + target->attrs = target->kind->preprocessor(ffi::Map(attrs)) + .cast>(); } else { target->attrs = attrs; } @@ -959,9 +962,9 @@ ObjectPtr TargetInternal::FromConfig(Map config) { return target; } // namespace tvm -std::unordered_map TargetInternal::QueryDevice(int device_id, - const TargetNode* target) { - std::unordered_map output; +std::unordered_map TargetInternal::QueryDevice(int device_id, + const TargetNode* target) { + std::unordered_map output; Device device{static_cast(target->GetTargetDeviceType()), device_id}; @@ -984,7 +987,7 @@ std::unordered_map TargetInternal::QueryDevice(int device_id, } for (const auto& kv : target->kind->key2vtype_) { - const String& key = kv.first; + const ffi::String& key = kv.first; ffi::Any ret; api->GetTargetProperty(device, key, &ret); @@ -1007,13 +1010,14 @@ TVM_FFI_STATIC_INIT_BLOCK({ .def("target.WithHost", TargetInternal::WithHost) .def("target.TargetGetDeviceType", [](const Target& target) { return target->GetTargetDeviceType(); }) - .def("target.TargetGetFeature", [](const Target& target, const String& feature_key) -> Any { - if (auto opt_any = target->GetFeature(feature_key)) { - return opt_any.value(); - } else { - return Any(); - } - }); + .def("target.TargetGetFeature", + [](const Target& target, const ffi::String& feature_key) -> Any { + if (auto opt_any = target->GetFeature(feature_key)) { + return opt_any.value(); + } else { + return Any(); + } + }); }); TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) diff --git a/src/target/target_kind.cc 
b/src/target/target_kind.cc index e284a75fefc3..0c835fdca266 100644 --- a/src/target/target_kind.cc +++ b/src/target/target_kind.cc @@ -45,7 +45,7 @@ TVM_FFI_STATIC_INIT_BLOCK({ // simply save as the string return node->name; }) - .def("__data_from_json__", [](const String& name) { + .def("__data_from_json__", [](const ffi::String& name) { auto kind = TargetKind::Get(name); ICHECK(kind.has_value()) << "Cannot find target kind \'" << name << '\''; return kind.value(); @@ -62,32 +62,33 @@ TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) using TargetKindRegistry = AttrRegistry; -Array TargetKindRegEntry::ListTargetKinds() { +ffi::Array TargetKindRegEntry::ListTargetKinds() { return TargetKindRegistry::Global()->ListAllNames(); } -Map TargetKindRegEntry::ListTargetKindOptions(const TargetKind& target_kind) { - Map options; +ffi::Map TargetKindRegEntry::ListTargetKindOptions( + const TargetKind& target_kind) { + ffi::Map options; for (const auto& kv : target_kind->key2vtype_) { options.Set(kv.first, kv.second.type_key); } return options; } -TargetKindRegEntry& TargetKindRegEntry::RegisterOrGet(const String& target_kind_name) { +TargetKindRegEntry& TargetKindRegEntry::RegisterOrGet(const ffi::String& target_kind_name) { return TargetKindRegistry::Global()->RegisterOrGet(target_kind_name); } -void TargetKindRegEntry::UpdateAttr(const String& key, ffi::Any value, int plevel) { +void TargetKindRegEntry::UpdateAttr(const ffi::String& key, ffi::Any value, int plevel) { TargetKindRegistry::Global()->UpdateAttr(key, kind_, value, plevel); } const AttrRegistryMapContainerMap& TargetKind::GetAttrMapContainer( - const String& attr_name) { + const ffi::String& attr_name) { return TargetKindRegistry::Global()->GetAttrMap(attr_name); } -Optional TargetKind::Get(const String& target_kind_name) { +ffi::Optional TargetKind::Get(const ffi::String& target_kind_name) { const TargetKindRegEntry* reg = TargetKindRegistry::Global()->Get(target_kind_name); if (reg == nullptr) { return std::nullopt; @@ -140,12 +141,13 @@ static bool DetectDeviceFlag(Device device, runtime::DeviceAttrKind flag, ffi::A return true; } -void CheckOrSetAttr(Map* attrs, const String& name, const String& value) { +void CheckOrSetAttr(ffi::Map* attrs, const ffi::String& name, + const ffi::String& value) { auto iter = attrs->find(name); if (iter == attrs->end()) { attrs->Set(name, value); } else { - auto str = (*iter).second.try_cast(); + auto str = (*iter).second.try_cast(); ICHECK(str && str.value() == value) << "ValueError: Expects \"" << name << "\" to be \"" << value << "\", but gets: " << (*iter).second; } @@ -162,7 +164,7 @@ TargetJSON UpdateCUDAAttrs(TargetJSON target) { // Update -arch=sm_xx if (target.count("arch")) { // If -arch has been specified, validate the correctness - String archStr = Downcast(target.at("arch")); + ffi::String archStr = Downcast(target.at("arch")); ICHECK(support::StartsWith(archStr, "sm_")) << "ValueError: CUDA target gets an invalid CUDA arch: -arch=" << archStr; } else { @@ -175,7 +177,7 @@ TargetJSON UpdateCUDAAttrs(TargetJSON target) { } else { archInt = std::stod(version.cast()) * 10 + 0.1; } - target.Set("arch", String("sm_") + std::to_string(archInt)); + target.Set("arch", ffi::String("sm_") + std::to_string(archInt)); } return target; } @@ -190,7 +192,7 @@ TargetJSON UpdateNVPTXAttrs(TargetJSON target) { // Update -mcpu=sm_xx if (target.count("mcpu")) { // If -mcpu has been specified, validate the correctness - String mcpu = Downcast(target.at("mcpu")); + ffi::String mcpu = Downcast(target.at("mcpu")); 
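// ===== Editor's aside: an illustrative sketch, not part of this patch =====
// The arch derivation in UpdateCUDAAttrs/UpdateNVPTXAttrs scales a detected
// compute version such as "8.6" to an integer arch number; the "+ 0.1" guards
// against floating-point truncation. DeriveArchSketch is a hypothetical
// helper mirroring that pattern, not a TVM API.
#include <string>

inline std::string DeriveArchSketch(const std::string& version) {
  // std::stod("8.6") * 10 is 85.999... as a double; adding 0.1 before the
  // implicit int conversion truncates to the intended 86 rather than 85.
  int arch = std::stod(version) * 10 + 0.1;
  return "sm_" + std::to_string(arch);  // DeriveArchSketch("8.6") == "sm_86"
}
// ===== End aside =====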
ICHECK(support::StartsWith(mcpu, "sm_")) << "ValueError: NVPTX target gets an invalid CUDA arch: -mcpu=" << mcpu; } else { @@ -203,7 +205,7 @@ TargetJSON UpdateNVPTXAttrs(TargetJSON target) { } else { arch = std::stod(version.cast()) * 10 + 0.1; } - target.Set("mcpu", String("sm_") + std::to_string(arch)); + target.Set("mcpu", ffi::String("sm_") + std::to_string(arch)); } return target; } @@ -218,7 +220,7 @@ TargetJSON UpdateROCmAttrs(TargetJSON target) { // Update -mcpu=gfx std::string arch = "gfx900"; if (target.count("mcpu")) { - String mcpu = Downcast(target.at("mcpu")); + ffi::String mcpu = Downcast(target.at("mcpu")); arch = ExtractStringWithPrefix(mcpu, "gfx"); ICHECK(!arch.empty()) << "ValueError: ROCm target gets an invalid GFX version: -mcpu=" << mcpu; } else { @@ -226,7 +228,7 @@ TargetJSON UpdateROCmAttrs(TargetJSON target) { if (const auto f_get_rocm_arch = tvm::ffi::Function::GetGlobal("tvm_callback_rocm_get_arch")) { arch = (*f_get_rocm_arch)().cast(); } - target.Set("mcpu", String(arch)); + target.Set("mcpu", ffi::String(arch)); } // Update -mattr before ROCm 3.5: // Before ROCm 3.5 we needed code object v2, starting @@ -241,9 +243,9 @@ TargetJSON UpdateROCmAttrs(TargetJSON target) { version = val.cast(); } if (version < 305) { - Array mattr; + ffi::Array mattr; if (target.count("mattr")) { - mattr = Downcast>(target.at("mattr")); + mattr = Downcast>(target.at("mattr")); } mattr.push_back("-code-object-v3"); target.Set("mattr", mattr); @@ -257,7 +259,7 @@ TargetJSON UpdateROCmAttrs(TargetJSON target) { * \return The updated attributes */ TargetJSON TestTargetParser(TargetJSON target) { - Map features = {{"is_test", true}}; + ffi::Map features = {{"is_test", true}}; target.Set("features", features); return target; } @@ -265,11 +267,11 @@ TargetJSON TestTargetParser(TargetJSON target) { /********** Register Target kinds and attributes **********/ TVM_REGISTER_TARGET_KIND("llvm", kDLCPU) - .add_attr_option>("mattr") - .add_attr_option("mcpu") - .add_attr_option("mtriple") - .add_attr_option("mfloat-abi") - .add_attr_option("mabi") + .add_attr_option>("mattr") + .add_attr_option("mcpu") + .add_attr_option("mtriple") + .add_attr_option("mfloat-abi") + .add_attr_option("mabi") .add_attr_option("num-cores") // Fast math flags, see https://llvm.org/docs/LangRef.html#fast-math-flags .add_attr_option("fast-math") // implies all the below @@ -281,9 +283,9 @@ TVM_REGISTER_TARGET_KIND("llvm", kDLCPU) .add_attr_option("fast-math-reassoc") .add_attr_option("opt-level") // LLVM command line flags, see below - .add_attr_option>("cl-opt") + .add_attr_option>("cl-opt") // LLVM JIT engine mcjit/orcjit - .add_attr_option("jit") + .add_attr_option("jit") // TVM & LLVM custom vector bit width .add_attr_option("vector-width") .set_default_keys({"cpu"}) @@ -314,16 +316,16 @@ TVM_REGISTER_TARGET_KIND("llvm", kDLCPU) // Hence the type is "uint". 
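// ===== Editor's aside: an illustrative sketch, not part of this patch =====
// After this migration, downstream target-kind registrations spell container
// attribute types with the ffi:: prefix. The kind "my_accel", its parser, and
// the int64_t option type below are invented for illustration; only the
// registration pattern itself comes from this file.
TargetJSON MyAccelParser(TargetJSON target) {
  // Fill in a default clock if the user did not provide one.
  if (!target.count("clock-mhz")) {
    target.Set("clock-mhz", 800);
  }
  return target;
}

TVM_REGISTER_TARGET_KIND("my_accel", kDLExtDev)
    .add_attr_option<ffi::String>("mcpu")
    .add_attr_option<ffi::Array<ffi::String>>("mattr")
    .add_attr_option<int64_t>("clock-mhz")
    .set_default_keys({"my_accel"})
    .set_target_parser(MyAccelParser);
// ===== End aside =====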
TVM_REGISTER_TARGET_KIND("c", kDLCPU) - .add_attr_option("mcpu") - .add_attr_option("march") + .add_attr_option("mcpu") + .add_attr_option("march") .add_attr_option("workspace-byte-alignment") .add_attr_option("constants-byte-alignment") .set_default_keys({"cpu"}) .set_target_parser(tvm::target::parsers::cpu::ParseTarget); TVM_REGISTER_TARGET_KIND("cuda", kDLCUDA) - .add_attr_option("mcpu") - .add_attr_option("arch") + .add_attr_option("mcpu") + .add_attr_option("arch") .add_attr_option("max_shared_memory_per_block") .add_attr_option("max_threads_per_block") .add_attr_option("thread_warp_size", 32) @@ -334,17 +336,17 @@ TVM_REGISTER_TARGET_KIND("cuda", kDLCUDA) .set_target_parser(UpdateCUDAAttrs); TVM_REGISTER_TARGET_KIND("nvptx", kDLCUDA) - .add_attr_option("mcpu") - .add_attr_option("mtriple") + .add_attr_option("mcpu") + .add_attr_option("mtriple") .add_attr_option("max_num_threads", 1024) .add_attr_option("thread_warp_size", 32) .set_default_keys({"cuda", "gpu"}) .set_target_parser(UpdateNVPTXAttrs); TVM_REGISTER_TARGET_KIND("rocm", kDLROCM) - .add_attr_option("mcpu") - .add_attr_option("mtriple") - .add_attr_option>("mattr") + .add_attr_option("mcpu") + .add_attr_option("mtriple") + .add_attr_option>("mattr") // TODO(masahi): Support querying from a target device // On RDNA cards, thread_warp_size should be 32 .add_attr_option("max_num_threads", 256) @@ -382,7 +384,7 @@ TVM_REGISTER_TARGET_KIND("metal", kDLMetal) .set_default_keys({"metal", "gpu"}); TVM_REGISTER_TARGET_KIND("vulkan", kDLVulkan) - .add_attr_option>("mattr") + .add_attr_option>("mattr") // Feature support .add_attr_option("supports_float16") .add_attr_option("supports_float32", true) @@ -412,9 +414,9 @@ TVM_REGISTER_TARGET_KIND("vulkan", kDLVulkan) .add_attr_option("max_per_stage_descriptor_storage_buffer") .add_attr_option("max_shared_memory_per_block") // Other device properties - .add_attr_option("device_type") - .add_attr_option("device_name") - .add_attr_option("driver_name") + .add_attr_option("device_type") + .add_attr_option("device_name") + .add_attr_option("driver_name") .add_attr_option("driver_version") .add_attr_option("vulkan_api_version") .add_attr_option("max_spirv_version") @@ -426,10 +428,10 @@ TVM_REGISTER_TARGET_KIND("webgpu", kDLWebGPU) .set_default_keys({"webgpu", "gpu"}); TVM_REGISTER_TARGET_KIND("hexagon", kDLHexagon) - .add_attr_option>("mattr") - .add_attr_option("mcpu") - .add_attr_option("mtriple") - .add_attr_option>("llvm-options") + .add_attr_option>("mattr") + .add_attr_option("mcpu") + .add_attr_option("mtriple") + .add_attr_option>("llvm-options") .add_attr_option("num-cores") .add_attr_option("vtcm-capacity") .set_default_keys({"hexagon", "cpu"}); @@ -437,7 +439,7 @@ TVM_REGISTER_TARGET_KIND("hexagon", kDLHexagon) TVM_REGISTER_TARGET_KIND("ext_dev", kDLExtDev); TVM_REGISTER_TARGET_KIND("composite", kDLCPU) // line break - .add_attr_option>("devices"); + .add_attr_option>("devices"); TVM_REGISTER_TARGET_KIND("test", kDLCPU) // line break .set_target_parser(TestTargetParser); @@ -448,7 +450,7 @@ TVM_FFI_STATIC_INIT_BLOCK({ namespace refl = tvm::ffi::reflection; refl::GlobalDef() .def("target.TargetKindGetAttr", - [](TargetKind kind, String attr_name) -> ffi::Any { + [](TargetKind kind, ffi::String attr_name) -> ffi::Any { auto target_attr_map = TargetKind::GetAttrMap(attr_name); ffi::Any rv; if (target_attr_map.count(kind)) { @@ -458,7 +460,7 @@ TVM_FFI_STATIC_INIT_BLOCK({ }) .def("target.ListTargetKinds", TargetKindRegEntry::ListTargetKinds) .def("target.ListTargetKindOptions", 
TargetKindRegEntry::ListTargetKindOptions) - .def("target.ListTargetKindOptionsFromName", [](String target_kind_name) { + .def("target.ListTargetKindOptionsFromName", [](ffi::String target_kind_name) { TargetKind kind = TargetKind::Get(target_kind_name).value(); return TargetKindRegEntry::ListTargetKindOptions(kind); }); diff --git a/src/target/virtual_device.cc b/src/target/virtual_device.cc index ac67afcfafe5..dd1925aa3118 100644 --- a/src/target/virtual_device.cc +++ b/src/target/virtual_device.cc @@ -71,7 +71,7 @@ VirtualDevice::VirtualDevice(int device_type_int, int virtual_device_id, Target ICHECK(!target.defined() || device_type_int == target->GetTargetDeviceType()) << "target " << target->ToDebugString() << " has device type " << target->GetTargetDeviceType() << " but virtual device has device type " << device_type_int; - auto node = make_object(); + auto node = ffi::make_object(); node->device_type_int = device_type_int; node->virtual_device_id = virtual_device_id; node->target = std::move(target); @@ -85,7 +85,8 @@ VirtualDevice::VirtualDevice(int device_type_int, int virtual_device_id, Target } /* static */ -Optional VirtualDevice::Join(const VirtualDevice& lhs, const VirtualDevice& rhs) { +ffi::Optional VirtualDevice::Join(const VirtualDevice& lhs, + const VirtualDevice& rhs) { if (lhs == rhs) { return lhs; } diff --git a/src/te/operation/compute_op.cc b/src/te/operation/compute_op.cc index 01b80386e2c0..2b81e82da8b5 100644 --- a/src/te/operation/compute_op.cc +++ b/src/te/operation/compute_op.cc @@ -84,10 +84,10 @@ DataType ComputeOpNode::output_dtype(size_t idx) const { return body[idx].dtype(); } -Array BaseComputeOpNode::output_shape(size_t idx) const { +ffi::Array BaseComputeOpNode::output_shape(size_t idx) const { ICHECK_LT(idx, num_outputs()); // for now, all outputs of a BaseComputeOp have the same shape - Array shape; + ffi::Array shape; for (const auto& ivar : this->axis) { const Range& r = ivar->dom; shape.push_back(r->extent); @@ -95,8 +95,8 @@ Array BaseComputeOpNode::output_shape(size_t idx) const { return shape; } -Tensor compute(Array shape, FCompute fcompute, std::string name, std::string tag, - Map attrs) { +Tensor compute(ffi::Array shape, FCompute fcompute, std::string name, std::string tag, + ffi::Map attrs) { // compute dimension. size_t ndim = shape.size(); std::vector axis; @@ -112,8 +112,8 @@ Tensor compute(Array shape, FCompute fcompute, std::string name, std:: return ComputeOp(name, tag, attrs, axis, {fcompute(args)}).output(0); } -Array compute(Array shape, FBatchCompute fcompute, std::string name, - std::string tag, Map attrs) { +ffi::Array compute(ffi::Array shape, FBatchCompute fcompute, std::string name, + std::string tag, ffi::Map attrs) { // compute dimension. 
size_t ndim = shape.size(); std::vector axis; @@ -127,19 +127,19 @@ Array compute(Array shape, FBatchCompute fcompute, std::string } Operation op = ComputeOp(name, tag, attrs, axis, fcompute(args)); - Array outputs; + ffi::Array outputs; for (int idx = 0; idx < op->num_outputs(); ++idx) { outputs.push_back(op.output(idx)); } return outputs; } -ComputeOp::ComputeOp(std::string name, std::string tag, Map attrs, - Array axis, Array body) { +ComputeOp::ComputeOp(std::string name, std::string tag, ffi::Map attrs, + ffi::Array axis, ffi::Array body) { if (!attrs.defined()) { - attrs = Map(); + attrs = ffi::Map(); } - auto n = make_object(); + auto n = ffi::make_object(); n->name = std::move(name); n->tag = std::move(tag); n->attrs = std::move(attrs); @@ -155,16 +155,16 @@ ComputeOp::ComputeOp(std::string name, std::string tag, Map at TVM_FFI_STATIC_INIT_BLOCK({ namespace refl = tvm::ffi::reflection; - refl::GlobalDef().def("te.ComputeOp", - [](std::string name, std::string tag, Optional> attrs, - Array axis, Array body) { - return ComputeOp(name, tag, attrs.value_or({}), axis, body); - }); + refl::GlobalDef().def("te.ComputeOp", [](std::string name, std::string tag, + ffi::Optional> attrs, + ffi::Array axis, ffi::Array body) { + return ComputeOp(name, tag, attrs.value_or({}), axis, body); + }); }); // The schedule related logics -Array ComputeOpNode::InputTensors() const { - Array ret; +ffi::Array ComputeOpNode::InputTensors() const { + ffi::Array ret; std::unordered_set visited; for (auto& e : body) { tir::PostOrderVisit(e, [&ret, &visited](const ObjectRef& n) { diff --git a/src/te/operation/create_primfunc.cc b/src/te/operation/create_primfunc.cc index ce9a5846ddf8..2a46579a1aed 100644 --- a/src/te/operation/create_primfunc.cc +++ b/src/te/operation/create_primfunc.cc @@ -105,19 +105,19 @@ class BufferSubstituter : public StmtExprMutator { /*! \brief Helper data structure to store information. */ struct CreateFuncInfo { /*! \brief The Tensor arg_list. */ - Array arg_list; + ffi::Array arg_list; /*! \brief The map from each Tensor to its corresponding buffer. */ std::unordered_map tensor2buffers; /*! \brief The transformer from ProducerLoad to BufferLoad. */ ProducerToBufferTransformer transformer; /*! \brief The buffers should be allocated at function root. */ - Array root_alloc; + ffi::Array root_alloc; /*! \brief The NameSupply to make block name unique. 
*/ NameSupply name_supply; - String FreshName(String base_name) { return name_supply->FreshName(base_name); } + ffi::String FreshName(ffi::String base_name) { return name_supply->FreshName(base_name); } - explicit CreateFuncInfo(Array arg_list) + explicit CreateFuncInfo(ffi::Array arg_list) : arg_list(std::move(arg_list)), transformer(tensor2buffers) {} bool IsArg(const te::Tensor& tensor) const { @@ -131,7 +131,7 @@ class LayoutFreePlaceholdersNormalizer : public StmtMutator { PrimFunc Process(PrimFunc func) { for (int i = 0, n = func->params.size(); i < n; ++i) { if (auto v = func->params[i].as()) { - if (Optional buffer = func->buffer_map.Get(v.value())) { + if (ffi::Optional buffer = func->buffer_map.Get(v.value())) { buffer2index_[buffer.value()] = i; } } @@ -141,7 +141,7 @@ class LayoutFreePlaceholdersNormalizer : public StmtMutator { if (this->layout_free_buffer_indices_.empty()) { return func; } - Array indices; + ffi::Array indices; indices.reserve(this->layout_free_buffer_indices_.size()); for (int i : this->layout_free_buffer_indices_) { indices.push_back(i); @@ -153,8 +153,8 @@ class LayoutFreePlaceholdersNormalizer : public StmtMutator { Block block = Downcast(StmtMutator::VisitStmt_(_block)); BlockNode* n = block.CopyOnWrite(); if (auto opt_ann = n->annotations.Get(topi_attr)) { - Array new_buffers; - for (Buffer buffer : Downcast>(opt_ann.value())) { + ffi::Array new_buffers; + for (Buffer buffer : Downcast>(opt_ann.value())) { auto it = buffer2index_.find(buffer); if (it != buffer2index_.end()) { layout_free_buffer_indices_.insert(it->second); @@ -168,7 +168,7 @@ class LayoutFreePlaceholdersNormalizer : public StmtMutator { n->annotations.Set(topi_attr, new_buffers); } } - for (const String& attr : this->blocklist) { + for (const ffi::String& attr : this->blocklist) { auto it = n->annotations.find(attr); if (it != n->annotations.end()) { n->annotations.erase(attr); @@ -179,9 +179,9 @@ class LayoutFreePlaceholdersNormalizer : public StmtMutator { std::unordered_map buffer2index_; std::set layout_free_buffer_indices_; - String topi_attr = "layout_free_placeholders"; - std::vector blocklist = {"const_matrix", "auto_scheduler_simplify_const_tensor_indices", - "workload"}; + ffi::String topi_attr = "layout_free_placeholders"; + std::vector blocklist = {"const_matrix", + "auto_scheduler_simplify_const_tensor_indices", "workload"}; }; /**! @@ -191,7 +191,8 @@ class LayoutFreePlaceholdersNormalizer : public StmtMutator { **/ using NestedIterLevels = std::vector>; -NestedIterLevels GenerateNestedIterLevels(const Array& axes, arith::Analyzer* analyzer) { +NestedIterLevels GenerateNestedIterLevels(const ffi::Array& axes, + arith::Analyzer* analyzer) { int global_max_depth = 0; std::unordered_map depth; std::unordered_map var2iter; @@ -244,9 +245,9 @@ NestedIterLevels GenerateNestedIterLevels(const Array& axes, arith::Ana * \param info Generation context info. * \returns The output buffer objects, ordered by compute op's outputs. **/ -Array GenerateOutputBuffers(const te::ComputeOp& compute_op, CreateFuncInfo* info) { +ffi::Array GenerateOutputBuffers(const te::ComputeOp& compute_op, CreateFuncInfo* info) { // Step 1. Collect output tensors in TE operation. 
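// ===== Editor's aside: an illustrative sketch, not part of this patch =====
// The helpers in this file back tir::CreatePrimFunc, which lowers a TE
// compute graph into a schedulable TensorIR PrimFunc. A minimal usage sketch,
// assuming the post-migration ffi:: spellings (the shape, names, and "+1"
// body are invented for illustration):
te::Tensor a = te::placeholder({128}, DataType::Float(32), "A");
te::Tensor b = te::compute(
    {128}, [&](const ffi::Array<tir::Var>& i) { return a(i) + 1.0f; }, "B");
tir::PrimFunc f = tir::CreatePrimFunc({a, b});
// f now holds a PrimFunc with buffers bound for A and B and a single block
// computing B[i] = A[i] + 1.0f, ready for scheduling.
// ===== End aside =====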
- Array tensors; + ffi::Array tensors; if (compute_op->body[0]->IsInstance()) { auto f_reducer_equal = [](const ReduceNode* a, const ReduceNode* b) -> bool { StructuralEqual eq; @@ -265,8 +266,8 @@ Array GenerateOutputBuffers(const te::ComputeOp& compute_op, CreateFuncI ICHECK(reduce_); ICHECK(f_reducer_equal(reduce_, reduce)) << "The Reduce inputs of ComputeOp should have the same attribute except value_index, " - << "but the first argument has body " << GetRef(reduce_) << ", while the " << k - << "-th argument has body " << GetRef(reduce); + << "but the first argument has body " << ffi::GetRef(reduce_) << ", while the " + << k << "-th argument has body " << ffi::GetRef(reduce); tensors.push_back(compute_op.output(k)); } } else { @@ -278,7 +279,7 @@ Array GenerateOutputBuffers(const te::ComputeOp& compute_op, CreateFuncI // - Declare buffers // - Update `op2buffers` // - Add the non-argument tensors to `alloc_buffer` of the root block - Array buffers; + ffi::Array buffers; for (const te::Tensor& tensor : tensors) { Buffer buffer = decl_buffer(tensor->shape, tensor->dtype, tensor->GetNameHint(), "global"); info->tensor2buffers[tensor] = buffer; @@ -296,9 +297,9 @@ Array GenerateOutputBuffers(const te::ComputeOp& compute_op, CreateFuncI * \param info Generation context info. * \returns The block annotation dict. **/ -Map GenerateBlockAnnotations(const te::ComputeOp& compute_op, - CreateFuncInfo* info) { - Map annotations; +ffi::Map GenerateBlockAnnotations(const te::ComputeOp& compute_op, + CreateFuncInfo* info) { + ffi::Map annotations; auto mutate_attr = [&info](const ffi::Any& value) -> ffi::Any { if (auto tensor_value = value.try_cast()) { return info->tensor2buffers.at(tensor_value.value()); @@ -307,11 +308,11 @@ Map GenerateBlockAnnotations(const te::ComputeOp& compute_op, } }; for (const auto& pair : compute_op->attrs) { - const String& key = pair.first; + const ffi::String& key = pair.first; const Any& value = pair.second; // TensorIR will not allow Tensor data structure if (value.as()) { - const auto array_value = Downcast>(value); + const auto array_value = Downcast>(value); annotations.Set(key, array_value.Map(mutate_attr)); } else { annotations.Set(key, mutate_attr(value)); @@ -331,17 +332,17 @@ Map GenerateBlockAnnotations(const te::ComputeOp& compute_op, * \param info Generation context info. * \returns Init stmt. **/ -Stmt GenerateInitStmt(const Array& indices, const Array& buffers, - const ReduceNode* reduce, const Map& var_map, +Stmt GenerateInitStmt(const ffi::Array& indices, const ffi::Array& buffers, + const ReduceNode* reduce, const ffi::Map& var_map, CreateFuncInfo* info) { // helper to transform the expr and remap iters to the block domain auto f_transform_and_remap = [&](const PrimExpr& e) { return Substitute(info->transformer(e), var_map); }; - Optional init = std::nullopt; + ffi::Optional init = std::nullopt; Stmt body; int n_buffers = buffers.size(); - Array init_stmts; + ffi::Array init_stmts; init_stmts.reserve(n_buffers); for (int i = 0; i < n_buffers; ++i) { const Buffer& buffer = buffers[i]; @@ -361,9 +362,9 @@ Stmt GenerateInitStmt(const Array& indices, const Array& buffe * \param analyzer Arithmetic analyzer in context. * \returns Init stmt. 
**/ -Stmt GenerateBodyStmt(const Array& indices, const Array& buffers, - const Map& var_map, PrimExpr expr_body, CreateFuncInfo* info, - arith::Analyzer* analyzer) { +Stmt GenerateBodyStmt(const ffi::Array& indices, const ffi::Array& buffers, + const ffi::Map& var_map, PrimExpr expr_body, + CreateFuncInfo* info, arith::Analyzer* analyzer) { // helper to transform the expr and remap iters to the block domain auto f_transform_and_remap = [&](const PrimExpr& e) { return Substitute(info->transformer(e), var_map); @@ -373,8 +374,8 @@ Stmt GenerateBodyStmt(const Array& indices, const Array& buffe // Case 1. Reduce compute int n_buffers = buffers.size(); - Array lhs; - Array rhs; + ffi::Array lhs; + ffi::Array rhs; lhs.reserve(n_buffers); rhs.reserve(n_buffers); @@ -389,8 +390,8 @@ Stmt GenerateBodyStmt(const Array& indices, const Array& buffe ICHECK_EQ(left->dtype, right->dtype); } - Array temp_vars; - Array body_stmts; + ffi::Array temp_vars; + ffi::Array body_stmts; temp_vars.reserve(n_buffers); body_stmts.reserve(n_buffers); @@ -433,16 +434,16 @@ struct NestedScopeInfo { // loop var and range in the scope. std::vector> loop_vars; // block iters for current level's block. - Array block_iters; + ffi::Array block_iters; // block bindings for current level's block. - Array bindings; + ffi::Array bindings; // store indices for current level's block. - Array store_indices; + ffi::Array store_indices; // mapping from original TE compute axes to new block vars. - Map axes_remap; + ffi::Map axes_remap; // helper to add new block var - void AddBlockIter(const Optional& origin_axis, const IterVar& iter, + void AddBlockIter(const ffi::Optional& origin_axis, const IterVar& iter, const PrimExpr& value) { block_iters.push_back(iter); bindings.push_back(value); @@ -455,9 +456,9 @@ struct NestedScopeInfo { } // helper to renew leaf block var defs to ensure SSA. - void Renew(const Array& origin_axes) { + void Renew(const ffi::Array& origin_axes) { block_iters.MutateByApply([](const IterVar& itervar) { - auto n = make_object(*itervar.get()); + auto n = ffi::make_object(*itervar.get()); n->var = n->var.copy_with_suffix(""); return IterVar(n); }); @@ -474,7 +475,7 @@ struct NestedScopeInfo { Stmt GenerateStmtFromCompute(const te::ComputeOp& compute_op, CreateFuncInfo* info, arith::Analyzer* analyzer) { // Step 1. Collect all iter axes in original TE compute op - Array axes = compute_op->axis; + ffi::Array axes = compute_op->axis; axes.insert(axes.end(), compute_op->reduce_axis.begin(), compute_op->reduce_axis.end()); // Step 2. Prepare nested iteration scopes. @@ -528,12 +529,12 @@ Stmt GenerateStmtFromCompute(const te::ComputeOp& compute_op, CreateFuncInfo* in } // Step 3. Generate output buffers for each output tensor - Array buffers = GenerateOutputBuffers(compute_op, info); + ffi::Array buffers = GenerateOutputBuffers(compute_op, info); // Step 4. Generate leaf block stmts. 
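// ===== Editor's aside: an illustrative sketch, not part of this patch =====
// For a reduction compute such as B[i] = sum(A[i, k] over k), the
// GenerateInitStmt/GenerateBodyStmt split above produces a reduction block of
// roughly the following shape (TVMScript-style pseudocode, for intuition
// only; names are invented):
//
//   with T.block("B"):
//       i = T.axis.spatial(n, i0)
//       k = T.axis.reduce(m, k0)
//       with T.init():
//           B[i] = 0.0             // emitted by GenerateInitStmt
//       B[i] = B[i] + A[i, k]      // emitted by GenerateBodyStmt
// ===== End aside =====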
- Array seq_stmt; + ffi::Array seq_stmt; auto leaf = scopes.back(); - Map annotations = GenerateBlockAnnotations(compute_op, info); + ffi::Map annotations = GenerateBlockAnnotations(compute_op, info); const ReduceNode* reduce = compute_op->body[0].as(); if (reduce) { PrimExpr expr_body = compute_op->body[0]; @@ -585,7 +586,7 @@ Stmt GenerateStmtFromCompute(const te::ComputeOp& compute_op, CreateFuncInfo* in auto block_name = info->FreshName(compute_op->name + "_l" + std::to_string(i)); const auto& block_iters = cur.block_iters; - Optional init{std::nullopt}; + ffi::Optional init{std::nullopt}; if (reduce && std::any_of(block_iters.begin(), block_iters.end(), [](const IterVar& iter) { return iter->iter_type == IterVarType::kCommReduce; })) { @@ -666,13 +667,13 @@ Stmt GenerateStmtFromExternOp(const te::ExternOp& extern_op, CreateFuncInfo* inf /*annotations=*/extern_op->attrs)); } -Array CollectOrderedOps(const Array& arg_list) { - Array arg_ops; +ffi::Array CollectOrderedOps(const ffi::Array& arg_list) { + ffi::Array arg_ops; for (const te::Tensor& arg : arg_list) { arg_ops.push_back(arg->op); } te::ReadGraph g = te::CreateReadGraph(arg_ops); - Array order = te::PostDFSOrder(arg_ops, g); + ffi::Array order = te::PostDFSOrder(arg_ops, g); for (const te::Operation& op : order) { if (!(op->IsInstance() || op->IsInstance() || @@ -683,7 +684,7 @@ Array CollectOrderedOps(const Array& arg_list) { return order; } -void InitializeBufferBinds(const Array& ordered_ops, CreateFuncInfo* info) { +void InitializeBufferBinds(const ffi::Array& ordered_ops, CreateFuncInfo* info) { // Process any TE operations which contain user defined buffers for (const auto& op : ordered_ops) { // Initialize the tensor2buffer binds map with buffers defined by the te.extern @@ -698,8 +699,8 @@ void InitializeBufferBinds(const Array& ordered_ops, CreateFuncIn } } -void RewriteStageToBlock(const te::Operation& op, CreateFuncInfo* info, Array* root_stmts, - arith::Analyzer* analyzer) { +void RewriteStageToBlock(const te::Operation& op, CreateFuncInfo* info, + ffi::Array* root_stmts, arith::Analyzer* analyzer) { if (const auto* placeholder = op.as()) { // Case 1. PlaceholderOp (te.placeholder) ICHECK_EQ(op->num_outputs(), 1); @@ -727,10 +728,10 @@ void RewriteStageToBlock(const te::Operation& op, CreateFuncInfo* info, Array& arg_list, - const Array& root_stmts, CreateFuncInfo* info) { - Array parameters; - Map buffer_map; +PrimFunc GenerateAndCompletePrimFunc(const ffi::Array& arg_list, + const ffi::Array& root_stmts, CreateFuncInfo* info) { + ffi::Array parameters; + ffi::Map buffer_map; for (const te::Tensor& tensor : arg_list) { Var arg("var_" + tensor->GetNameHint(), PrimType(DataType::Handle())); parameters.push_back(arg); @@ -742,25 +743,25 @@ PrimFunc GenerateAndCompletePrimFunc(const Array& arg_list, /*body=*/SeqStmt::Flatten(root_stmts), /*ret_type=*/VoidType(), /*buffer_map=*/std::move(buffer_map)), - {{"global_symbol", String("main")}, {"tir.noalias", true}}); + {{"global_symbol", ffi::String("main")}, {"tir.noalias", true}}); const auto fcomplete = tvm::ffi::Function::GetGlobal("script.Complete"); ICHECK(fcomplete.has_value()); func = (*fcomplete)(std::move(func), info->root_alloc).cast(); return func; } -PrimFunc CreatePrimFuncWithConstants(const Array& arg_list, - const Array& constants, +PrimFunc CreatePrimFuncWithConstants(const ffi::Array& arg_list, + const ffi::Array& constants, std::optional index_dtype_override) { // Information used in CreatePrimFunc and its sub-functions. 
CreateFuncInfo info(arg_list); // Root body stmts. - Array root_stmts; + ffi::Array root_stmts; // Analyzer arith::Analyzer analyzer; // Step 1. Create ordered array of operations and validate they are supported. - Array order = CollectOrderedOps(arg_list); + ffi::Array order = CollectOrderedOps(arg_list); // Step 2. Initialize buffer binds map InitializeBufferBinds(order, &info); @@ -780,7 +781,7 @@ PrimFunc CreatePrimFuncWithConstants(const Array& arg_list, return result; } -PrimFunc CreatePrimFunc(const Array& arg_list, +PrimFunc CreatePrimFunc(const ffi::Array& arg_list, std::optional index_dtype_override) { return CreatePrimFuncWithConstants(arg_list, {}, index_dtype_override); } @@ -788,7 +789,7 @@ PrimFunc CreatePrimFunc(const Array& arg_list, TVM_FFI_STATIC_INIT_BLOCK({ namespace refl = tvm::ffi::reflection; refl::GlobalDef().def_packed("te.CreatePrimFunc", [](ffi::PackedArgs args, ffi::Any* ret) { - Array arg_list = args[0].cast>(); + ffi::Array arg_list = args[0].cast>(); std::optional index_dtype_override{std::nullopt}; // Add conversion to make std::optional compatible with FFI. if (args[1] != nullptr) { @@ -799,10 +800,10 @@ TVM_FFI_STATIC_INIT_BLOCK({ }); // Relax version impl -PrimFunc GenerateAndCompletePrimFunc(const Array& arg_tir_var_list, - const Array& root_stmts, CreateFuncInfo* info) { - Array parameters; - Map buffer_map; +PrimFunc GenerateAndCompletePrimFunc(const ffi::Array& arg_tir_var_list, + const ffi::Array& root_stmts, CreateFuncInfo* info) { + ffi::Array parameters; + ffi::Map buffer_map; for (const ObjectRef& arg : arg_tir_var_list) { if (auto opt_tensor = arg.as()) { te::Tensor tensor = opt_tensor.value(); @@ -819,32 +820,32 @@ PrimFunc GenerateAndCompletePrimFunc(const Array& arg_tir_var_list, /*body=*/SeqStmt::Flatten(root_stmts), /*ret_type=*/VoidType(), /*buffer_map=*/std::move(buffer_map)), - {{"global_symbol", String("main")}, {"tir.noalias", true}}); + {{"global_symbol", ffi::String("main")}, {"tir.noalias", true}}); const auto fcomplete = tvm::ffi::Function::GetGlobal("script.Complete"); ICHECK(fcomplete.has_value()); func = (*fcomplete)(std::move(func), info->root_alloc).cast(); return func; } -PrimFunc CreatePrimFuncWithConstants(const Array& arg_list, - const Array& constants, +PrimFunc CreatePrimFuncWithConstants(const ffi::Array& arg_list, + const ffi::Array& constants, std::optional index_dtype_override) { - Array tensor_arg_list; + ffi::Array tensor_arg_list; for (const ObjectRef& x : arg_list) { if (auto tensor_node = x.as()) { - te::Tensor tensor = GetRef(tensor_node); + te::Tensor tensor = ffi::GetRef(tensor_node); tensor_arg_list.push_back(tensor); } } // Information used in CreatePrimFunc and its sub-functions. CreateFuncInfo info(tensor_arg_list); // Root body stmts. - Array root_stmts; + ffi::Array root_stmts; // Analyzer arith::Analyzer analyzer; // Step 1. Create ordered array of operations and validate they are supported. - Array order = CollectOrderedOps(tensor_arg_list); // Step 2.
Initialize buffer binds map InitializeBufferBinds(order, &info); @@ -862,7 +863,7 @@ PrimFunc CreatePrimFuncWithConstants(const Array& arg_list, return result; } -PrimFunc CreatePrimFunc(const Array& arg_list, +PrimFunc CreatePrimFunc(const ffi::Array& arg_list, std::optional index_dtype_override) { return CreatePrimFuncWithConstants(arg_list, {}, index_dtype_override); } diff --git a/src/te/operation/create_primfunc.h b/src/te/operation/create_primfunc.h index 9e61d87ce332..f7ad7e0e1e0e 100644 --- a/src/te/operation/create_primfunc.h +++ b/src/te/operation/create_primfunc.h @@ -30,7 +30,7 @@ namespace tvm { namespace tir { /*! \brief Use Tensor Expression to create a schedulable TensorIR func. */ -PrimFunc CreatePrimFunc(const Array& arg_list, +PrimFunc CreatePrimFunc(const ffi::Array& arg_list, std::optional index_dtype_override = std::nullopt); /*! \brief The same as above but create a PrimFunc with AllocateConstNode. If the size of the @@ -38,12 +38,12 @@ PrimFunc CreatePrimFunc(const Array& arg_list, * Constant tensors will not be part of the parameters of the created PrimFunc, instead constants * will be embedded in the body as AllocateConstNode. */ -PrimFunc CreatePrimFuncWithConstants(const Array& arg_list, - const Array& constants, +PrimFunc CreatePrimFuncWithConstants(const ffi::Array& arg_list, + const ffi::Array& constants, std::optional index_dtype_override = std::nullopt); /*! \brief Use Tensor Expression to create a schedulable TensorIR func. */ -PrimFunc CreatePrimFunc(const Array& arg_list, +PrimFunc CreatePrimFunc(const ffi::Array& arg_list, std::optional index_dtype_override); /*! \brief The same as above but create a PrimFunc with AllocateConstNode. If the size of the @@ -51,8 +51,8 @@ PrimFunc CreatePrimFunc(const Array& arg_list, * Constant tensors will not be part of the parameters of the created PrimFunc, instead constants * will be embedded in the body as AllocateConstNode. 
*/ -PrimFunc CreatePrimFuncWithConstants(const Array& arg_list, - const Array& constants, +PrimFunc CreatePrimFuncWithConstants(const ffi::Array& arg_list, + const ffi::Array& constants, std::optional index_dtype_override); } // namespace tir diff --git a/src/te/operation/extern_op.cc b/src/te/operation/extern_op.cc index 23f43a99d8e6..ef18f26165ab 100644 --- a/src/te/operation/extern_op.cc +++ b/src/te/operation/extern_op.cc @@ -44,15 +44,17 @@ int ExternOpNode::num_outputs() const { return static_cast(output_placehold DataType ExternOpNode::output_dtype(size_t i) const { return output_placeholders[i]->dtype; } -Array ExternOpNode::output_shape(size_t i) const { return output_placeholders[i]->shape; } +ffi::Array ExternOpNode::output_shape(size_t i) const { + return output_placeholders[i]->shape; +} -ExternOp::ExternOp(std::string name, std::string tag, Map attrs, - Array inputs, Array input_placeholders, - Array output_placeholders, Stmt body) { +ExternOp::ExternOp(std::string name, std::string tag, ffi::Map attrs, + ffi::Array inputs, ffi::Array input_placeholders, + ffi::Array output_placeholders, Stmt body) { if (!attrs.defined()) { - attrs = Map(); + attrs = ffi::Map(); } - auto n = make_object(); + auto n = ffi::make_object(); n->name = std::move(name); n->tag = std::move(tag); n->attrs = std::move(attrs); @@ -74,16 +76,17 @@ ExternOp::ExternOp(std::string name, std::string tag, Map attr TVM_FFI_STATIC_INIT_BLOCK({ namespace refl = tvm::ffi::reflection; - refl::GlobalDef().def("te.ExternOp", - [](std::string name, std::string tag, Optional> attrs, - Array inputs, Array input_placeholders, - Array output_placeholders, Stmt body) { - return ExternOp(name, tag, attrs.value_or({}), inputs, input_placeholders, - output_placeholders, body); - }); + refl::GlobalDef().def( + "te.ExternOp", + [](std::string name, std::string tag, ffi::Optional> attrs, + ffi::Array inputs, ffi::Array input_placeholders, + ffi::Array output_placeholders, Stmt body) { + return ExternOp(name, tag, attrs.value_or({}), inputs, input_placeholders, + output_placeholders, body); + }); }); -Array ExternOpNode::InputTensors() const { return inputs; } +ffi::Array ExternOpNode::InputTensors() const { return inputs; } } // namespace te } // namespace tvm diff --git a/src/te/operation/graph.cc b/src/te/operation/graph.cc index f477f9129b2a..561ad6e6c43b 100644 --- a/src/te/operation/graph.cc +++ b/src/te/operation/graph.cc @@ -37,7 +37,7 @@ namespace te { // construct a read graph that gives readers of each operation // that the root depend on -ReadGraph CreateReadGraph(const Array& roots) { +ReadGraph CreateReadGraph(const ffi::Array& roots) { ReadGraph rmap; std::vector stack; std::unordered_set visited; @@ -50,7 +50,7 @@ ReadGraph CreateReadGraph(const Array& roots) { while (!stack.empty()) { Operation op = stack.back(); stack.pop_back(); - Array deps = op->InputTensors(); + ffi::Array deps = op->InputTensors(); rmap.Set(op, deps); for (Tensor t : deps) { if (t->op.defined() && visited.count(t->op.get()) == 0) { @@ -63,7 +63,7 @@ ReadGraph CreateReadGraph(const Array& roots) { } void PostDFSOrder(const Operation& op, const ReadGraph& g, std::unordered_set* visited, - Array* post_order) { + ffi::Array* post_order) { if (visited->count(op)) return; visited->insert(op); for (const auto& t : g.at(op)) { @@ -72,9 +72,9 @@ void PostDFSOrder(const Operation& op, const ReadGraph& g, std::unordered_setpush_back(op); } -Array PostDFSOrder(const Array& roots, const ReadGraph& g) { +ffi::Array PostDFSOrder(const ffi::Array& roots, 
const ReadGraph& g) { std::unordered_set visited; - Array post_order; + ffi::Array post_order; for (Operation op : roots) { PostDFSOrder(op, g, &visited, &post_order); } @@ -85,7 +85,7 @@ TVM_FFI_STATIC_INIT_BLOCK({ namespace refl = tvm::ffi::reflection; refl::GlobalDef() .def("schedule.CreateReadGraph", CreateReadGraph) - .def("schedule.PostDFSOrder", [](const Array& roots, const ReadGraph& g) { + .def("schedule.PostDFSOrder", [](const ffi::Array& roots, const ReadGraph& g) { return PostDFSOrder(roots, g); }); }); diff --git a/src/te/operation/graph.h b/src/te/operation/graph.h index 51ab8e1aa7bb..dc2b211cf3cb 100644 --- a/src/te/operation/graph.h +++ b/src/te/operation/graph.h @@ -33,7 +33,7 @@ namespace te { /*! * \brief data structure of Operation->Tensors it reads */ -using ReadGraph = Map>; +using ReadGraph = ffi::Map>; /*! * \brief Get read graph of each operation to all the @@ -43,7 +43,7 @@ using ReadGraph = Map>; * \param roots The root operation. * \return The result map. */ -ReadGraph CreateReadGraph(const Array& roots); +ReadGraph CreateReadGraph(const ffi::Array& roots); /*! * \brief Get a post-DFS order of operations in the graph. @@ -54,7 +54,7 @@ ReadGraph CreateReadGraph(const Array& roots); * \note PostDFSOrder is a special case of Topological order, * and can be used when topological order is needed. */ -Array PostDFSOrder(const Array& roots, const ReadGraph& g); +ffi::Array PostDFSOrder(const ffi::Array& roots, const ReadGraph& g); } // namespace te } // namespace tvm diff --git a/src/te/operation/placeholder_op.cc b/src/te/operation/placeholder_op.cc index 160f89f1eb84..d7acfb32ef23 100644 --- a/src/te/operation/placeholder_op.cc +++ b/src/te/operation/placeholder_op.cc @@ -45,31 +45,31 @@ DataType PlaceholderOpNode::output_dtype(size_t i) const { return dtype; } -Array PlaceholderOpNode::output_shape(size_t i) const { +ffi::Array PlaceholderOpNode::output_shape(size_t i) const { ICHECK_EQ(i, 0U); return shape; } -PlaceholderOp::PlaceholderOp(std::string name, Array shape, DataType dtype) { - auto n = make_object(); +PlaceholderOp::PlaceholderOp(std::string name, ffi::Array shape, DataType dtype) { + auto n = ffi::make_object(); n->name = name; n->shape = shape; n->dtype = dtype; data_ = std::move(n); } -Tensor placeholder(Array shape, DataType dtype, std::string name) { +Tensor placeholder(ffi::Array shape, DataType dtype, std::string name) { return PlaceholderOp(name, shape, dtype).output(0); } TVM_FFI_STATIC_INIT_BLOCK({ namespace refl = tvm::ffi::reflection; - refl::GlobalDef().def("te.Placeholder", [](Variant> shape_arg, + refl::GlobalDef().def("te.Placeholder", [](ffi::Variant> shape_arg, DataType dtype, std::string name) { - auto shape = [&]() -> Array { + auto shape = [&]() -> ffi::Array { if (auto arg_expr = shape_arg.as()) { return {arg_expr.value()}; - } else if (auto arg_array = shape_arg.as>()) { + } else if (auto arg_array = shape_arg.as>()) { return arg_array.value(); } else { LOG(FATAL) << "Variant did not contain either allowed type"; @@ -79,7 +79,7 @@ TVM_FFI_STATIC_INIT_BLOCK({ }); }); -Array PlaceholderOpNode::InputTensors() const { return {}; } +ffi::Array PlaceholderOpNode::InputTensors() const { return {}; } } // namespace te } // namespace tvm diff --git a/src/te/operation/scan_op.cc b/src/te/operation/scan_op.cc index cd621c11dfc7..dfddaa3d9b38 100644 --- a/src/te/operation/scan_op.cc +++ b/src/te/operation/scan_op.cc @@ -42,18 +42,19 @@ int ScanOpNode::num_outputs() const { return static_cast(update.size()); } DataType
ScanOpNode::output_dtype(size_t i) const { return update[i]->dtype; } -Array ScanOpNode::output_shape(size_t i) const { +ffi::Array ScanOpNode::output_shape(size_t i) const { ICHECK_LT(i, state_placeholder.size()); return state_placeholder[i]->shape; } -ScanOp::ScanOp(std::string name, std::string tag, Optional> attrs, - IterVar axis, Array init, Array update, - Array state_placeholder, Array inputs) { +ScanOp::ScanOp(std::string name, std::string tag, + ffi::Optional> attrs, IterVar axis, + ffi::Array init, ffi::Array update, + ffi::Array state_placeholder, ffi::Array inputs) { if (!attrs.defined()) { - attrs = Map(); + attrs = ffi::Map(); } - auto n = make_object(); + auto n = ffi::make_object(); ICHECK_EQ(init.size(), update.size()); ICHECK_EQ(init.size(), state_placeholder.size()); arith::Analyzer analyzer; @@ -102,29 +103,31 @@ ScanOp::ScanOp(std::string name, std::string tag, Optional TVM_FFI_STATIC_INIT_BLOCK({ namespace refl = tvm::ffi::reflection; refl::GlobalDef().def( - "te.ScanOp", [](std::string name, std::string tag, Optional> attrs, - IterVar axis, Array init, Array update, - Array state_placeholder, Array inputs) { + "te.ScanOp", + [](std::string name, std::string tag, ffi::Optional> attrs, + IterVar axis, ffi::Array init, ffi::Array update, + ffi::Array state_placeholder, ffi::Array inputs) { return ScanOp(name, tag, attrs, axis, init, update, state_placeholder, inputs); }); }); -Array scan(Array init, Array update, Array state_placeholder, - Array inputs, std::string name, std::string tag, - Optional> attrs) { +ffi::Array scan(ffi::Array init, ffi::Array update, + ffi::Array state_placeholder, ffi::Array inputs, + std::string name, std::string tag, + ffi::Optional> attrs) { IterVar scan_axis = IterVar(Range::FromMinExtent(init[0]->shape[0], update[0]->shape[0] - init[0]->shape[0]), Var(name + ".idx"), kOrdered); Operation op = ScanOp(name, tag, attrs, scan_axis, init, update, state_placeholder, inputs); - Array res; + ffi::Array res; for (int i = 0; i < op->num_outputs(); ++i) { res.push_back(op.output(i)); } return res; } -Array ScanOpNode::InputTensors() const { - Array ret; +ffi::Array ScanOpNode::InputTensors() const { + ffi::Array ret; for (Tensor t : init) { ret.push_back(t); } diff --git a/src/te/tensor.cc b/src/te/tensor.cc index 06dc0ccbc92c..027607e504ec 100644 --- a/src/te/tensor.cc +++ b/src/te/tensor.cc @@ -51,8 +51,9 @@ IterVar reduce_axis(Range dom, std::string name) { Var var(std::string name_hint, DataType t) { return Var(name_hint, t); } // Tensor -inline PrimExpr Tensor::IndexTensor(Array indices, bool support_negative_indices) const { - Array shape = (*this)->shape; +inline PrimExpr Tensor::IndexTensor(ffi::Array indices, + bool support_negative_indices) const { + ffi::Array shape = (*this)->shape; if (shape.size() != 0) { ICHECK_EQ(shape.size(), indices.size()) @@ -70,30 +71,32 @@ inline PrimExpr Tensor::IndexTensor(Array indices, bool support_negati return ProducerLoad((*this), indices); } -PrimExpr Tensor::operator()(Array indices) const { - Array arr(indices.begin(), indices.end()); +PrimExpr Tensor::operator()(ffi::Array indices) const { + ffi::Array arr(indices.begin(), indices.end()); return operator()(arr); } -PrimExpr Tensor::operator()(Array indices) const { return IndexTensor(indices, false); } +PrimExpr Tensor::operator()(ffi::Array indices) const { + return IndexTensor(indices, false); +} -PrimExpr Tensor::IndexWithNegativeIndices(Array indices) const { - Array arr(indices.begin(), indices.end()); +PrimExpr 
Tensor::IndexWithNegativeIndices(ffi::Array indices) const { + ffi::Array arr(indices.begin(), indices.end()); return IndexWithNegativeIndices(arr); } -PrimExpr Tensor::IndexWithNegativeIndices(Array indices) const { +PrimExpr Tensor::IndexWithNegativeIndices(ffi::Array indices) const { return IndexTensor(indices, true); } -String TensorNode::GetNameHint() const { +ffi::String TensorNode::GetNameHint() const { return op->num_outputs() == 1 ? op->name : (op->name + ".v" + std::to_string(value_index)); } -PrimExpr TensorNode::ToPrimExpr() const { return GetRef(this)(); } +PrimExpr TensorNode::ToPrimExpr() const { return ffi::GetRef(this)(); } Tensor Operation::output(size_t i) const { - auto node = make_object(); + auto node = ffi::make_object(); node->op = *this; node->value_index = i; node->dtype = (*this)->output_dtype(i); @@ -101,8 +104,8 @@ Tensor Operation::output(size_t i) const { return Tensor(node); } -Tensor::Tensor(Array shape, DataType dtype, Operation op, int value_index) { - auto n = make_object(); +Tensor::Tensor(ffi::Array shape, DataType dtype, Operation op, int value_index) { + auto n = ffi::make_object(); n->shape = std::move(shape); n->dtype = dtype; n->op = op; @@ -112,10 +115,10 @@ Tensor::Tensor(Array shape, DataType dtype, Operation op, int value_in TVM_FFI_STATIC_INIT_BLOCK({ namespace refl = tvm::ffi::reflection; - refl::GlobalDef().def("te.Tensor", - [](Array shape, DataType dtype, Operation op, int value_index) { - return Tensor(shape, dtype, op, value_index); - }); + refl::GlobalDef().def( + "te.Tensor", [](ffi::Array shape, DataType dtype, Operation op, int value_index) { + return Tensor(shape, dtype, op, value_index); + }); }); TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) diff --git a/src/tir/analysis/block_access_region_detector.cc b/src/tir/analysis/block_access_region_detector.cc index 2503d12df195..d0fd976a4fcb 100644 --- a/src/tir/analysis/block_access_region_detector.cc +++ b/src/tir/analysis/block_access_region_detector.cc @@ -40,21 +40,21 @@ namespace tir { */ class BlockReadWriteDetector : public StmtExprVisitor { public: - explicit BlockReadWriteDetector(const Map& buffer_var_map) + explicit BlockReadWriteDetector(const ffi::Map& buffer_var_map) : buffer_var_map_(buffer_var_map) {} /*! \brief Return read regions of the block */ - Array CollectReads( + ffi::Array CollectReads( const std::unordered_set* excluded_buffers = nullptr); /*! \brief Return write regions of the block */ - Array CollectWrites( + ffi::Array CollectWrites( const std::unordered_set* excluded_buffers = nullptr); /*! * \brief Return opaque buffer regions of the block * \note The buffer accessed by load/store or call with buffer.data will * be marked as opaque. */ - Array CollectOpaques(); + ffi::Array CollectOpaques(); /*! \brief overload operator() to make sure it accepts a block node */ void operator()(const Stmt& stmt); @@ -78,7 +78,7 @@ class BlockReadWriteDetector : public StmtExprVisitor { /*! \brief The opaque regions of the current block */ std::vector> opaque_regions_; /*! \brief The outside buffer data mapping to its buffer */ - Map buffer_var_map_; + ffi::Map buffer_var_map_; /*! \brief The target buffer var mapping to its matching */ std::unordered_map match_buffers_; /*! \brief let bindings inside the block */ @@ -97,7 +97,7 @@ class BlockReadWriteDetector : public StmtExprVisitor { Buffer buffer, std::vector region); /*! \brief Helper function to collect access regions. 
*/ - Array CollectRegions( + ffi::Array CollectRegions( const std::vector& buffers, const std::vector>& regions, const std::unordered_set* excluded_buffers = nullptr); @@ -136,21 +136,21 @@ void BlockReadWriteDetector::operator()(const Stmt& stmt) { StmtExprVisitor::operator()(stmt); } -Array BlockReadWriteDetector::CollectReads( +ffi::Array BlockReadWriteDetector::CollectReads( const std::unordered_set* excluded_buffers) { return CollectRegions(read_buffers_, read_regions_, excluded_buffers); } -Array BlockReadWriteDetector::CollectWrites( +ffi::Array BlockReadWriteDetector::CollectWrites( const std::unordered_set* excluded_buffers) { return CollectRegions(writes_buffers_, write_regions_, excluded_buffers); } -Array BlockReadWriteDetector::CollectOpaques() { +ffi::Array BlockReadWriteDetector::CollectOpaques() { return CollectRegions(opaque_buffers_, opaque_regions_); } -void BlockReadWriteDetector::VisitExpr_(const VarNode* op) { UpdateOpaque(GetRef(op)); } +void BlockReadWriteDetector::VisitExpr_(const VarNode* op) { UpdateOpaque(ffi::GetRef(op)); } void BlockReadWriteDetector::VisitExpr_(const BufferLoadNode* op) { std::vector relaxed_region; @@ -198,7 +198,7 @@ void BlockReadWriteDetector::VisitExpr_(const CallNode* op) { const VarNode* buffer_var = op->args[1].as(); const IntImmNode* access_mask = op->args[4].as(); if (buffer_var && access_mask) { - auto it = buffer_var_map_.find(GetRef(buffer_var)); + auto it = buffer_var_map_.find(ffi::GetRef(buffer_var)); if (it != buffer_var_map_.end()) { const Buffer& buffer = (*it).second; const BufferRegion buffer_region = BufferRegion::FullRegion(buffer); @@ -329,17 +329,17 @@ void BlockReadWriteDetector::Update(std::vector* buffers, regions->push_back(std::move(region)); } -Array BlockReadWriteDetector::CollectRegions( +ffi::Array BlockReadWriteDetector::CollectRegions( const std::vector& buffers, const std::vector>& regions, const std::unordered_set* excluded_buffers) { ICHECK_EQ(buffers.size(), regions.size()); - Array res; + ffi::Array res; res.reserve(buffers.size()); for (size_t i = 0; i < regions.size(); ++i) { if (excluded_buffers != nullptr && excluded_buffers->count(buffers[i].get())) { continue; } - Array region; + ffi::Array region; region.reserve(regions[i].size()); ICHECK_EQ(buffers[i]->shape.size(), regions[i].size()); for (size_t j = 0; j < regions[i].size(); j++) { @@ -371,11 +371,11 @@ void BlockReadWriteDetector::UpdateOpaque(const Var& buffer_var) { } } -Array> GetBlockAccessRegion(const Block& block, - const Map& buffer_var_map) { +ffi::Array> GetBlockAccessRegion( + const Block& block, const ffi::Map& buffer_var_map) { BlockReadWriteDetector detector(buffer_var_map); detector(block); - Array writes = detector.CollectWrites(); + ffi::Array writes = detector.CollectWrites(); std::unordered_set excluded_buffers; // exclude write buffers from read regions for reductions if init block is defined. 
if (block->init.defined()) { @@ -383,27 +383,27 @@ Array> GetBlockAccessRegion(const Block& block, excluded_buffers.insert(write_access->buffer.get()); } } - Array reads = detector.CollectReads(&excluded_buffers); - Array opaques = detector.CollectOpaques(); + ffi::Array reads = detector.CollectReads(&excluded_buffers); + ffi::Array opaques = detector.CollectOpaques(); return {reads, writes, opaques}; } -Array> GetBlockReadWriteRegion(const Block& block, - const Map& buffer_var_map) { +ffi::Array> GetBlockReadWriteRegion( + const Block& block, const ffi::Map& buffer_var_map) { BlockReadWriteDetector detector(buffer_var_map); detector(block); - Array opaques = detector.CollectOpaques(); + ffi::Array opaques = detector.CollectOpaques(); std::unordered_set excluded_buffers; for (const BufferRegion& opaque_access : opaques) { excluded_buffers.insert(opaque_access->buffer.get()); } - Array writes = detector.CollectWrites(&excluded_buffers); + ffi::Array writes = detector.CollectWrites(&excluded_buffers); if (block->init.defined()) { for (const BufferRegion& write_access : writes) { excluded_buffers.insert(write_access->buffer.get()); } } - Array reads = detector.CollectReads(&excluded_buffers); + ffi::Array reads = detector.CollectReads(&excluded_buffers); for (const BufferRegion& opaque_access : opaques) { reads.push_back(opaque_access); writes.push_back(opaque_access); diff --git a/src/tir/analysis/buffer_access_lca_detector.cc b/src/tir/analysis/buffer_access_lca_detector.cc index 2ecd32b65a2e..07da2240a6da 100644 --- a/src/tir/analysis/buffer_access_lca_detector.cc +++ b/src/tir/analysis/buffer_access_lca_detector.cc @@ -42,7 +42,7 @@ namespace tir { */ class LCADetector : public StmtExprVisitor { public: - static Map> Detect(const PrimFunc& func) { + static ffi::Map> Detect(const PrimFunc& func) { LCADetector detector; for (const auto& kv : func->buffer_map) { const Buffer& buffer = kv.second; @@ -60,11 +60,11 @@ class LCADetector : public StmtExprVisitor { detector.UpdateWithBlockidx(); // Prepare the return - Map> buffer_lca; + ffi::Map> buffer_lca; for (const auto& kv : detector.buffer_lca_) { - const Buffer& buffer = GetRef(kv.first); - const Optional stmt = - kv.second ? GetRef>(kv.second->stmt) : std::nullopt; + const Buffer& buffer = ffi::GetRef(kv.first); + const ffi::Optional stmt = + kv.second ? 
ffi::GetRef>(kv.second->stmt) : std::nullopt; buffer_lca.Set(buffer, stmt); } return buffer_lca; @@ -289,7 +289,7 @@ class LCADetector : public StmtExprVisitor { void UpdateWithBlockidx() { for (const auto& it : buffer_lca_) { const runtime::StorageScope& scope = - runtime::StorageScope::Create(GetRef(it.first).scope()); + runtime::StorageScope::Create(ffi::GetRef(it.first).scope()); if (scope.rank == runtime::StorageRank::kGlobal) { const ScopeInfo*& lca = buffer_lca_[it.first]; for (const ScopeInfo* blockidx_scope : blockidx_scopes_) { @@ -343,7 +343,7 @@ class LCADetector : public StmtExprVisitor { support::Arena arena_; }; -Map> DetectBufferAccessLCA(const PrimFunc& func) { +ffi::Map> DetectBufferAccessLCA(const PrimFunc& func) { return LCADetector::Detect(func); } diff --git a/src/tir/analysis/calculate_allocated_memory.cc b/src/tir/analysis/calculate_allocated_memory.cc index feaa491cc8a2..3a944273664c 100644 --- a/src/tir/analysis/calculate_allocated_memory.cc +++ b/src/tir/analysis/calculate_allocated_memory.cc @@ -41,7 +41,7 @@ template class AllocationCalculator : public StmtExprVisitor { public: AllocationCalculator() = default; - tvm::Map operator()(const PrimFunc& func); + tvm::ffi::Map operator()(const PrimFunc& func); private: void VisitStmt_(const T* op) override; @@ -50,11 +50,11 @@ class AllocationCalculator : public StmtExprVisitor { }; template -tvm::Map AllocationCalculator::operator()(const PrimFunc& func) { +tvm::ffi::Map AllocationCalculator::operator()(const PrimFunc& func) { this->VisitStmt(func->body); - tvm::Map res; + tvm::ffi::Map res; for (auto [k, v] : _max_size) { - res.Set(String(k), Integer(v)); + res.Set(ffi::String(k), Integer(v)); } return res; } @@ -80,17 +80,19 @@ void AllocationCalculator::VisitStmt_(const T* op) { _current_size[storage_scope] -= size; } -tvm::Map > CalculateAllocatedBytes(const PrimFunc& func) { - tvm::Map > results; +tvm::ffi::Map > CalculateAllocatedBytes( + const PrimFunc& func) { + tvm::ffi::Map > results; results.Set("main", AllocationCalculator()(func)); return results; } -tvm::Map > CalculateAllocatedBytes(const IRModule& mod) { - tvm::Map > results; +tvm::ffi::Map > CalculateAllocatedBytes( + const IRModule& mod) { + tvm::ffi::Map > results; for (const auto& kv : mod->functions) { if (auto prim_func = kv.second.as()) { - String func_name = kv.first->name_hint; + ffi::String func_name = kv.first->name_hint; results.Set(func_name, AllocationCalculator()(prim_func.value())); } } @@ -101,7 +103,7 @@ TVM_FFI_STATIC_INIT_BLOCK({ namespace refl = tvm::ffi::reflection; refl::GlobalDef().def( "tir.analysis.calculate_allocated_bytes", - [](ObjectRef obj) -> tvm::Map > { + [](ObjectRef obj) -> tvm::ffi::Map > { if (auto func = obj.as()) { return CalculateAllocatedBytes(func.value()); } else if (auto mod = obj.as()) { @@ -144,8 +146,8 @@ int64_t GetVTCMCapacity(Target target, const transform::PassContext& pass_ctx) { return pass_ctx->GetConfig("tir.vtcm_capacity", Integer(0)).value()->value; } -Array GetVTCMCompactionPasses() { - auto pass_list = Array(); +ffi::Array GetVTCMCompactionPasses() { + auto pass_list = ffi::Array(); pass_list.push_back(tir::transform::LowerInitBlock()); pass_list.push_back(tir::transform::PlanAndUpdateBufferAllocationLocation()); pass_list.push_back(tir::transform::ConvertBlocksToOpaque()); @@ -168,7 +170,7 @@ TVM_FFI_STATIC_INIT_BLOCK({ namespace transform { -Pass VerifyVTCMLimit(Optional default_target) { +Pass VerifyVTCMLimit(ffi::Optional default_target) { auto pass_func = [=](IRModule mod, PassContext 
ctx) { for (auto kv : mod->functions) { if (auto opt = kv.second.as()) { diff --git a/src/tir/analysis/control_flow_graph.cc b/src/tir/analysis/control_flow_graph.cc index a9c2b9ecc609..8d001dd1e459 100644 --- a/src/tir/analysis/control_flow_graph.cc +++ b/src/tir/analysis/control_flow_graph.cc @@ -63,14 +63,14 @@ bool HasBufferLoad(PrimExpr expr) { return visitor.found_buffer_load; } -Optional SubstituteParamValues(const Array& param_vars, - const Array& param_values, - const PrimExpr& expr) { +ffi::Optional SubstituteParamValues(const ffi::Array& param_vars, + const ffi::Array& param_values, + const PrimExpr& expr) { ICHECK_EQ(param_vars.size(), param_values.size()) << "Expression was defined as having " << param_vars.size() << " parameters, but received " << param_values.size() << " arguments."; - Map var_map; + ffi::Map var_map; for (size_t i = 0; i < param_values.size(); i++) { var_map.Set(param_vars[i], param_values[i]); } @@ -151,7 +151,7 @@ class BufferConstraintApply : public IRMutatorWithAnalyzer { public: using Parent = IRMutatorWithAnalyzer; - BufferConstraintApply(const Map>& axis_var_lookup, + BufferConstraintApply(const ffi::Map>& axis_var_lookup, const std::vector& knowns, Analyzer* analyzer) : Parent(analyzer), axis_var_lookup_(axis_var_lookup), knowns_(knowns) {} @@ -163,10 +163,10 @@ class BufferConstraintApply : public IRMutatorWithAnalyzer { continue; } - Optional lane_var = std::nullopt; + ffi::Optional lane_var = std::nullopt; IntImm num_lanes; - Array indices = op->indices.Map([&](const auto& index) { + ffi::Array indices = op->indices.Map([&](const auto& index) { if (index.dtype().lanes() == 1) { return index; } else { @@ -192,11 +192,11 @@ class BufferConstraintApply : public IRMutatorWithAnalyzer { } } - return GetRef(op); + return ffi::GetRef(op); } private: - const Map>& axis_var_lookup_; + const ffi::Map>& axis_var_lookup_; const std::vector& knowns_; }; @@ -339,13 +339,13 @@ class ControlFlowGraphBuilder final : public IRVisitorWithAnalyzer { void VisitExpr_(const BufferLoadNode* op) override { Parent::VisitExpr_(op); - BufferLoad load = GetRef(op); + BufferLoad load = ffi::GetRef(op); VisitAccess(load, BufferTouch::AccessType::Read, load); } void VisitStmt_(const BufferStoreNode* op) override { Parent::VisitStmt_(op); - VisitAccess(GetRef(op), BufferTouch::AccessType::Write, op->value); + VisitAccess(ffi::GetRef(op), BufferTouch::AccessType::Write, op->value); // Appending a control block ensures that all control blocks have // at most one statement that changes the buffer contents. auto prev_block = CurrentControlBlock(); @@ -554,7 +554,7 @@ class ControlFlowGraphBuilder final : public IRVisitorWithAnalyzer { With analyzer_context; size_t old_num_constraints{0}; size_t new_num_constraints{0}; - Optional assume{std::nullopt}; + ffi::Optional assume{std::nullopt}; // Disable default-generated copy/move assignment and constructors InternalConstraintContext(const InternalConstraintContext&) = delete; @@ -623,7 +623,7 @@ class ControlFlowGraphBuilder final : public IRVisitorWithAnalyzer { // binding. When making a predicate in terms of the buffer indices, // these need to be substituted out. 
// std::unordered_map let_bindings_using_loop_; - Map let_bindings_using_loop_; + ffi::Map let_bindings_using_loop_; // Track in order to know what conditions limit the buffer access std::vector conditions_; @@ -635,17 +635,17 @@ class ControlFlowGraphBuilder final : public IRVisitorWithAnalyzer { ControlFlowGraph* out_; }; -std::pair> ControlFlowGraph::ControlFlowBlock::MakeBufferTouch( - const tir::Buffer& buf, Array index_variables, Array indices, +std::pair> ControlFlowGraph::ControlFlowBlock::MakeBufferTouch( + const tir::Buffer& buf, ffi::Array index_variables, ffi::Array indices, BufferTouch::AccessType touch_type, PrimExpr known_value_expr) const { const auto& current_block = *this; Analyzer local_analyzer; - Optional lane_var = std::nullopt; + ffi::Optional lane_var = std::nullopt; IntImm num_lanes; - Array index_expressions = indices.Map([&](const auto& index) { + ffi::Array index_expressions = indices.Map([&](const auto& index) { if (index.dtype().lanes() == 1) { return index; } else { @@ -656,9 +656,9 @@ std::pair> ControlFlowGraph::ControlFlowBlock::Make } }); - Array loop_vars; + ffi::Array loop_vars; - Map loop_ranges; + ffi::Map loop_ranges; for (const auto& loop_entry : current_block.active_loop_iterators) { loop_vars.push_back(loop_entry.loop_var); loop_ranges.Set(loop_entry.loop_var, loop_entry.loop_range); @@ -675,7 +675,7 @@ std::pair> ControlFlowGraph::ControlFlowBlock::Make IntConstraintsTransform transform = [&]() { ICHECK_EQ(index_variables.size(), index_expressions.size()); - Array relations; + ffi::Array relations; for (size_t i = 0; i < index_expressions.size(); i++) { PrimExpr expr = index_expressions[i]; @@ -689,16 +689,16 @@ std::pair> ControlFlowGraph::ControlFlowBlock::Make return arith::SolveLinearEquations(system); }(); - Map loop_var_to_axis_var = transform->src_to_dst; - Map free_params = transform->dst->ranges; + ffi::Map loop_var_to_axis_var = transform->src_to_dst; + ffi::Map free_params = transform->dst->ranges; PrimExpr transform_predicate = std::accumulate(transform->dst->relations.begin(), transform->dst->relations.end(), PrimExpr(Bool(true)), [](PrimExpr a, PrimExpr b) { return a && b; }); transform_predicate = SimplifyAsAndOfOrs(transform_predicate, &local_analyzer); - auto find_removable_params = [&]() -> Map { - Map removable_params; + auto find_removable_params = [&]() -> ffi::Map { + ffi::Map removable_params; // The arith::SolveLinearEquations is more general than the // utilities in iter_affine_map.h, but can introduce free @@ -712,13 +712,13 @@ std::pair> ControlFlowGraph::ControlFlowBlock::Make return; } - Var var = GetRef(var_ptr); + Var var = ffi::GetRef(var_ptr); if (free_params.count(var) == 0) { return; } - bool uses_free_param = - UsesVar(b, [&](const VarNode* v) { return free_params.count(GetRef(v)) > 0; }); + bool uses_free_param = UsesVar( + b, [&](const VarNode* v) { return free_params.count(ffi::GetRef(v)) > 0; }); if (uses_free_param) { return; } @@ -746,7 +746,7 @@ std::pair> ControlFlowGraph::ControlFlowBlock::Make return local_analyzer.Simplify(Substitute(expr, removable_params)); }; - Map new_map; + ffi::Map new_map; for (const auto [loop_var, expr] : loop_var_to_axis_var) { static_cast(expr); // gcc 7.x bug, https://gcc.gnu.org/bugzilla/show_bug.cgi?id=81767 new_map.Set(loop_var, update(expr)); @@ -808,7 +808,7 @@ std::pair> ControlFlowGraph::ControlFlowBlock::Make BufferTouch ControlFlowGraph::ControlFlowBlock::MakeBufferTouch(ControlFlowGraph* graph, const tir::Buffer& buf, - const Array& indices, + const ffi::Array& 
indices, BufferTouch::AccessType touch_type, PrimExpr known_value_expr) const { ICHECK(graph); @@ -949,7 +949,7 @@ std::ostream& operator<<(std::ostream& os, const BufferState& state) { } PrimExpr BufferState::SubstituteKnownBufferValues( - PrimExpr expr, const Map>& axis_var_lookup, + PrimExpr expr, const ffi::Map>& axis_var_lookup, Analyzer* analyzer) const { BufferConstraintApply mutator(axis_var_lookup, constraints_, analyzer); return mutator(std::move(expr)); @@ -961,7 +961,7 @@ void BufferState::AddCondition(const PrimExpr& condition) { } } -void BufferState::Substitute(const Map& var_remap, Analyzer* analyzer) { +void BufferState::Substitute(const ffi::Map& var_remap, Analyzer* analyzer) { if (var_remap.size()) { for (auto& prior : constraints_) { PrimExpr updated = tvm::tir::Substitute(prior.predicate, var_remap); @@ -1026,12 +1026,12 @@ class BufferRegionCollector : public ExprVisitor { public: struct Region { PrimExpr region_predicate; - std::unordered_map> known_values; + std::unordered_map> known_values; }; - static std::vector Collect(const Map>& axis_var_lookup, + static std::vector Collect(const ffi::Map>& axis_var_lookup, const std::vector& knowns, - const std::vector>& exprs, + const std::vector>& exprs, Analyzer* analyzer) { BufferRegionCollector collector(axis_var_lookup, knowns, analyzer); for (const auto& expr : exprs) { @@ -1046,7 +1046,7 @@ class BufferRegionCollector : public ExprVisitor { private: using Parent = ExprVisitor; - BufferRegionCollector(const Map>& axis_var_lookup, + BufferRegionCollector(const ffi::Map>& axis_var_lookup, const std::vector& knowns, Analyzer* analyzer) : analyzer_(analyzer), axis_var_lookup_(axis_var_lookup), knowns_(knowns) { regions_.push_back(Region{Bool(true), {}}); @@ -1058,7 +1058,7 @@ class BufferRegionCollector : public ExprVisitor { // Helper struct for the known values of this BufferLoad struct Known { PrimExpr predicate; - Optional value; + ffi::Optional value; }; std::vector new_regions; @@ -1077,7 +1077,7 @@ class BufferRegionCollector : public ExprVisitor { touch_predicate = SimplifyAsAndOfOrs(touch_predicate, analyzer_); if (!is_zero(touch_predicate)) { - Optional known_value = + ffi::Optional known_value = SubstituteParamValues(axis_vars, op->indices, constraint.value); new_regions.push_back(Known{touch_predicate, known_value}); @@ -1112,14 +1112,14 @@ class BufferRegionCollector : public ExprVisitor { Analyzer* analyzer_; std::vector regions_; - const Map>& axis_var_lookup_; + const ffi::Map>& axis_var_lookup_; const std::vector& knowns_; }; class BufferRegionValueReplacer : public IRMutatorWithAnalyzer { public: static PrimExpr Apply( - const std::unordered_map>& known_values, + const std::unordered_map>& known_values, PrimExpr expr, Analyzer* analyzer) { BufferRegionValueReplacer mutator(known_values, analyzer); PrimExpr result = mutator(expr); @@ -1134,7 +1134,7 @@ class BufferRegionValueReplacer : public IRMutatorWithAnalyzer { using Parent = IRMutatorWithAnalyzer; BufferRegionValueReplacer( - const std::unordered_map>& known_values, + const std::unordered_map>& known_values, Analyzer* analyzer) : Parent(analyzer), known_values_(known_values) {} @@ -1145,17 +1145,17 @@ class BufferRegionValueReplacer : public IRMutatorWithAnalyzer { if (it != known_values_.end() && it->second) { return it->second.value(); } else { - return GetRef(op); + return ffi::GetRef(op); } } - const std::unordered_map>& known_values_; + const std::unordered_map>& known_values_; }; -void BufferState::ApplyTouches(const Map>& axis_var_lookup, 
+void BufferState::ApplyTouches(const ffi::Map>& axis_var_lookup, const std::vector& touch_points, Analyzer* analyzer) { std::vector new_knowns; - Map keep_prior_known_at; + ffi::Map keep_prior_known_at; for (auto& touch : touch_points) { if (touch.touch_type == BufferTouch::AccessType::Read) { @@ -1209,7 +1209,7 @@ void BufferState::ApplyTouches(const Map>& axis_var_lookup, for (size_t i = 0; i < new_knowns.size(); i++) { if (new_knowns[i].buffer.same_as(constraint.buffer)) { - Optional overwritten_with = new_knowns[i].value; + ffi::Optional overwritten_with = new_knowns[i].value; if (overwritten_with && analyzer->CanProveEqual(prev_value, overwritten_with.value())) { expand_known_at = SimplifyAsAndOfOrs(expand_known_at || new_knowns[i].predicate, analyzer); @@ -1237,18 +1237,18 @@ void BufferState::ApplyTouches(const Map>& axis_var_lookup, constraints_.end()); } -void BufferState::BackpropUnusedIndices(const Map>& axis_var_lookup, +void BufferState::BackpropUnusedIndices(const ffi::Map>& axis_var_lookup, const std::vector& touch_points, Analyzer* analyzer) { std::vector new_knowns; - Map keep_prior_known_at; + ffi::Map keep_prior_known_at; - Map regions_written; - Map regions_read; + ffi::Map regions_written; + ffi::Map regions_read; for (auto it = touch_points.rbegin(); it != touch_points.rend(); it++) { const auto& touch = *it; - Map* to_update{nullptr}; + ffi::Map* to_update{nullptr}; if (touch.touch_type == BufferTouch::AccessType::Write) { to_update = ®ions_written; @@ -1264,7 +1264,7 @@ void BufferState::BackpropUnusedIndices(const Map>& axis_var_ } auto update_map = [&](auto& map) { - Map new_map; + ffi::Map new_map; for (auto [buffer, predicate] : map) { new_map.Set(buffer, SimplifyAsAndOfOrs(predicate, analyzer)); } @@ -1303,7 +1303,7 @@ void BufferState::BackpropUnusedIndices(const Map>& axis_var_ constraints_.end()); } -void BufferState::RemoveFreeParameters(const Map& free_predicate_parameters, +void BufferState::RemoveFreeParameters(const ffi::Map& free_predicate_parameters, Analyzer* analyzer) { for (auto& known : constraints_) { known.predicate = NarrowPredicateExpression(known.predicate, free_predicate_parameters); @@ -1325,7 +1325,7 @@ bool BufferState::IsEquivalentTo(const BufferState& other, Analyzer* analyzer) c return true; } -Optional> ControlFlowGraph::GetIndexVariables(const Buffer& buf) const { +ffi::Optional> ControlFlowGraph::GetIndexVariables(const Buffer& buf) const { if (auto it = axis_var_lookup_.find(buf); it != axis_var_lookup_.end()) { return (*it).second; } else { @@ -1333,12 +1333,13 @@ Optional> ControlFlowGraph::GetIndexVariables(const Buffer& buf) cons } } -Array ControlFlowGraph::GetIndexVariables(const Buffer& buf, const Array& indices) { +ffi::Array ControlFlowGraph::GetIndexVariables(const Buffer& buf, + const ffi::Array& indices) { if (auto it = axis_var_lookup_.find(buf); it != axis_var_lookup_.end()) { return (*it).second; } - Array vars; + ffi::Array vars; for (size_t i = 0; i < indices.size(); i++) { std::stringstream ss; ss << buf->name << "_axis_" << i; @@ -1620,7 +1621,7 @@ void ControlFlowGraph::BackwardPropagateUnusedValues(std::optional flow_ bool ControlFlowGraph::IsOverwrittenWithoutEffect(const tir::BufferStore& store, const Stmt& context) const { - Optional> index_variables = GetIndexVariables(store->buffer); + ffi::Optional> index_variables = GetIndexVariables(store->buffer); if (!index_variables) { return false; } diff --git a/src/tir/analysis/control_flow_graph.h b/src/tir/analysis/control_flow_graph.h index 
f4babffbb74c..7bde341c38fa 100644 --- a/src/tir/analysis/control_flow_graph.h +++ b/src/tir/analysis/control_flow_graph.h @@ -186,7 +186,7 @@ class BufferState { * the original expression is returned. */ PrimExpr SubstituteKnownBufferValues(PrimExpr expr, - const Map>& axis_var_lookup, + const ffi::Map>& axis_var_lookup, arith::Analyzer* analyzer) const; /*! \brief Apply a condition to all known constraints @@ -205,7 +205,7 @@ class BufferState { * * \param var_remap The variable remapping to apply. */ - void Substitute(const Map& var_remap, arith::Analyzer* analyzer); + void Substitute(const ffi::Map& var_remap, arith::Analyzer* analyzer); /*! \brief Simplify the predicate of all constraints * @@ -226,7 +226,7 @@ class BufferState { * * \param analyzer The analyzer to use for simplifications */ - void ApplyTouches(const Map>& axis_var_lookup, + void ApplyTouches(const ffi::Map>& axis_var_lookup, const std::vector& touch_points, arith::Analyzer* analyzer); /*! \brief Update unused buffer locations based on buffer touches @@ -245,7 +245,7 @@ class BufferState { * * \param analyzer The analyzer to use for simplifications */ - void BackpropUnusedIndices(const Map>& axis_var_lookup, + void BackpropUnusedIndices(const ffi::Map>& axis_var_lookup, const std::vector& touch_points, arith::Analyzer* analyzer); @@ -255,7 +255,7 @@ class BufferState { * * \param analyzer The analyzer with which to simplify after removal */ - void RemoveFreeParameters(const Map& free_predicate_parameters, + void RemoveFreeParameters(const ffi::Map& free_predicate_parameters, arith::Analyzer* analyzer); /*! \brief Check if two buffer states are equivalent @@ -462,7 +462,7 @@ class ControlFlowGraph { * * \returns Variables representing a position along the buffer's axis. */ - Array GetIndexVariables(const Buffer& buf, const Array& indices); + ffi::Array GetIndexVariables(const Buffer& buf, const ffi::Array& indices); /*! \brief Return index variables representing locations within a * buffer, if they have been generated before. @@ -473,7 +473,7 @@ class ControlFlowGraph { * * \returns Variables representing a position along the buffer's axis. */ - Optional> GetIndexVariables(const Buffer& buf) const; + ffi::Optional> GetIndexVariables(const Buffer& buf) const; /*! \brief Propagate known values from known BufferStore/assume * subsequent control flow blocks @@ -501,7 +501,7 @@ class ControlFlowGraph { * e.g. Replacing loop iterator `i` with `i-1` when following an * edge from the end of a loop to the beginning of the loop. */ - Map var_remap; + ffi::Map var_remap; /*! \brief Condition that must be true after following this edge * * @@ -509,7 +509,7 @@ * loop_min` when following an edge from the end of a loop to * the beginning of the loop. */ - Optional post_condition; + ffi::Optional post_condition; }; friend std::ostream& operator<<(std::ostream& os, const ControlFlowEdge& edge); @@ -525,7 +525,7 @@ class ControlFlowGraph { std::vector active_loop_iterators; /*! \brief Loop-dependent Let bindings that may appear within the block */ - Map let_bindings_using_loop; + ffi::Map let_bindings_using_loop; /*! 
\brief Predicate that must be true to have reached this block */ PrimExpr scope_predicate{Bool(true)}; @@ -577,7 +577,8 @@ class ControlFlowGraph { * \returns The newly generated BufferTouch */ BufferTouch MakeBufferTouch(ControlFlowGraph* graph, const Buffer& buf, - const Array& indices, BufferTouch::AccessType touch_type, + const ffi::Array& indices, + BufferTouch::AccessType touch_type, PrimExpr known_value_expr) const; /* \brief Construct a BufferTouch instance as if it occurred in @@ -602,11 +603,11 @@ class ControlFlowGraph { * all free parameters that may occur in the BufferTouch's * predicate. */ - std::pair> MakeBufferTouch(const Buffer& buf, - Array index_variables, - Array indices, - BufferTouch::AccessType touch_type, - PrimExpr known_value_expr) const; + std::pair> MakeBufferTouch(const Buffer& buf, + ffi::Array index_variables, + ffi::Array indices, + BufferTouch::AccessType touch_type, + PrimExpr known_value_expr) const; }; friend std::ostream& operator<<(std::ostream& os, const ControlFlowBlock& pattern); @@ -629,10 +630,10 @@ class ControlFlowGraph { * the free parameters allows them to be removed later, by requiring * a predicate to be true for all values of the free parameters. */ - Map free_predicate_parameters_; + ffi::Map free_predicate_parameters_; /*! \brief Ranges of iterators found in the analyzed statement */ - Map iterator_ranges_; + ffi::Map iterator_ranges_; /* \brief A map from buffer to the variables representing positions * along the buffer's axes. @@ -642,7 +643,7 @@ class ControlFlowGraph { * variables to represent the buffer's axes, reducing the amount of * variable substitution required. */ - Map> axis_var_lookup_; + ffi::Map> axis_var_lookup_; /* \brief Assumptions that do not depend on buffer values * diff --git a/src/tir/analysis/deep_equal.cc b/src/tir/analysis/deep_equal.cc index 5d85ef31e88e..9c2ea0f8442c 100644 --- a/src/tir/analysis/deep_equal.cc +++ b/src/tir/analysis/deep_equal.cc @@ -66,7 +66,7 @@ class ExprDeepEqualChecker : private ExprFunctor& lhs, const Array& rhs) { + bool ArrayDeepEqual(const ffi::Array& lhs, const ffi::Array& rhs) { if (lhs.size() != rhs.size()) return false; for (size_t i = 0; i < lhs.size(); i++) { if (!VisitExpr(lhs[i], rhs[i])) return false; @@ -74,7 +74,7 @@ class ExprDeepEqualChecker : private ExprFunctor& lhs, const Array& rhs) { + bool ArrayDeepEqual(const ffi::Array& lhs, const ffi::Array& rhs) { // for iter var, we require pointer equality if (lhs.size() != rhs.size()) return false; for (size_t i = 0; i < lhs.size(); i++) { @@ -83,7 +83,7 @@ class ExprDeepEqualChecker : private ExprFunctor& lhs, const Optional& rhs) { + bool OptionalDeepEqual(const ffi::Optional& lhs, const ffi::Optional& rhs) { if (lhs.same_as(rhs)) return true; if (!lhs.defined() && rhs.defined()) return false; if (lhs.defined() && !rhs.defined()) return false; diff --git a/src/tir/analysis/estimate_flops.cc b/src/tir/analysis/estimate_flops.cc index 3f012d5f15af..300e3afcd6b1 100644 --- a/src/tir/analysis/estimate_flops.cc +++ b/src/tir/analysis/estimate_flops.cc @@ -37,7 +37,7 @@ int32_t DataType2Int(const tvm::DataType& dtype) { return converter.dst; } -String Int2DataTypeStr(int32_t dtype) { +ffi::String Int2DataTypeStr(int32_t dtype) { union { DLDataType dst; int32_t src; diff --git a/src/tir/analysis/identify_memcpy.cc b/src/tir/analysis/identify_memcpy.cc index c23eed2da997..76fbd75ba488 100644 --- a/src/tir/analysis/identify_memcpy.cc +++ b/src/tir/analysis/identify_memcpy.cc @@ -42,8 +42,8 @@ namespace tir { std::variant 
IdentifyMemCpyImpl(const For& loop, arith::Analyzer* analyzer) { - Map loop_intervals; - Map loop_ranges; + ffi::Map loop_intervals; + ffi::Map loop_ranges; PrimExpr total_loop_iterations = 1; // Walk through the loop nest, stopping at the first loop whose body @@ -82,8 +82,8 @@ std::variant IdentifyMemCpyImpl(const For& loop, // Now, we have a BufferStore whose value is a BufferLoad. Because // non-flat physical indices are target-dependent, only handle cases // where the buffer will be flattened to a 1-d physical buffer. - Array flattened_dst = store->buffer.OffsetOf(store->indices); - Array flattened_src = load->buffer.OffsetOf(load->indices); + ffi::Array flattened_dst = store->buffer.OffsetOf(store->indices); + ffi::Array flattened_src = load->buffer.OffsetOf(load->indices); if (flattened_dst.size() != 1 || flattened_src.size() != 1) { return static_cast( @@ -286,19 +286,19 @@ std::optional IdentifyMemCpy(const For& loop, arith::Analyzer* an TVM_FFI_STATIC_INIT_BLOCK({ namespace refl = tvm::ffi::reflection; refl::GlobalDef().def("tir.analysis._identify_memcpy", [](const Stmt& stmt) { - Array output; + ffi::Array output; struct Visitor : arith::IRVisitorWithAnalyzer { - explicit Visitor(Array* output) : output(output) {} - Array* output; + explicit Visitor(ffi::Array* output) : output(output) {} + ffi::Array* output; private: using IRVisitorWithAnalyzer::VisitStmt_; void VisitStmt_(const ForNode* op) override { - For loop = GetRef(op); + For loop = ffi::GetRef(op); auto result = IdentifyMemCpyImpl(loop, &(Visitor::analyzer_)); if (auto* ptr = std::get_if(&result)) { - output->push_back(Array{ptr->source, ptr->dest}); + output->push_back(ffi::Array{ptr->source, ptr->dest}); } else if (auto* ptr = std::get_if(&result)) { output->push_back(StringImm(*ptr)); } else { diff --git a/src/tir/analysis/is_pure_function.cc b/src/tir/analysis/is_pure_function.cc index 9e85e4cc86c7..f5c47a7cae00 100644 --- a/src/tir/analysis/is_pure_function.cc +++ b/src/tir/analysis/is_pure_function.cc @@ -79,7 +79,7 @@ class PurityChecker : TIRVisitorWithPath { LOG_IF(FATAL, assert_on_error_) << "AssertionError: " << "Pure functions must not contain calls to impure operators, " - << "but " << GetRef(call) << " calls operator " << call->op + << "but " << ffi::GetRef(call) << " calls operator " << call->op << ", which has side effect " << effect; } } diff --git a/src/tir/analysis/oob_checker.cc b/src/tir/analysis/oob_checker.cc index 72626d27188d..fd08786efa5f 100644 --- a/src/tir/analysis/oob_checker.cc +++ b/src/tir/analysis/oob_checker.cc @@ -41,9 +41,9 @@ struct OOBLocation { class OOBError : public ScheduleError { public: OOBError(IRModule mod, std::vector locations) : mod_(mod), locations_(locations) {} - String FastErrorString() const final { return "Out of bound memory access"; } + ffi::String FastErrorString() const final { return "Out of bound memory access"; } - String DetailRenderTemplate() const final { + ffi::String DetailRenderTemplate() const final { std::stringstream s; for (const auto& oob : locations_) { s << "Out of bounds memory access on buffer " << oob.buf->name << " dimension " @@ -56,7 +56,7 @@ class OOBError : public ScheduleError { return s.str(); } IRModule mod() const final { return mod_; } - Array LocationsOfInterest() const final { + ffi::Array LocationsOfInterest() const final { std::vector locs; for (auto loc : locations_) { locs.push_back(loc.index); diff --git a/src/tir/analysis/stmt_finding.cc b/src/tir/analysis/stmt_finding.cc index 2fe2ce5235a7..779c96ccb1b8 100644 --- 
a/src/tir/analysis/stmt_finding.cc +++ b/src/tir/analysis/stmt_finding.cc @@ -98,7 +98,7 @@ Stmt GetEnclosingLoop(const BlockNode* block, Stmt func_body) { } } - LOG(FATAL) << "Enclosing loop not found for a block " << GetRef(block); + LOG(FATAL) << "Enclosing loop not found for a block " << ffi::GetRef(block); TVM_FFI_UNREACHABLE(); } @@ -145,9 +145,9 @@ TVM_FFI_STATIC_INIT_BLOCK({ refl::GlobalDef().def("tir.analysis.find_anchor_block", [](const IRModule& mod) { auto ret = FindAnchorBlock(mod); if (ret) { - return Optional(GetRef(ret)); + return ffi::Optional(ffi::GetRef(ret)); } - return Optional(std::nullopt); + return ffi::Optional(std::nullopt); }); }); diff --git a/src/tir/analysis/var_use_def_analysis.cc b/src/tir/analysis/var_use_def_analysis.cc index 95da50204b97..0ce0402a8dff 100644 --- a/src/tir/analysis/var_use_def_analysis.cc +++ b/src/tir/analysis/var_use_def_analysis.cc @@ -27,7 +27,7 @@ namespace tvm { namespace tir { -VarUseDefAnalyzer::VarUseDefAnalyzer(const Array& defined_vars, bool visit_thread_extent) +VarUseDefAnalyzer::VarUseDefAnalyzer(const ffi::Array& defined_vars, bool visit_thread_extent) : visit_thread_extent_(visit_thread_extent) { for (const Var v : defined_vars) { use_count_[v.get()] = 0; @@ -104,7 +104,7 @@ void VarUseDefAnalyzer::VisitExpr_(const LetNode* op) { } void VarUseDefAnalyzer::VisitExpr_(const VarNode* op) { - this->HandleUse(GetRef(op)); + this->HandleUse(ffi::GetRef(op)); StmtExprVisitor::VisitExpr_(op); } @@ -123,7 +123,7 @@ void VarUseDefAnalyzer::VisitExpr_(const BufferLoadNode* op) { void VarUseDefAnalyzer::VisitBuffer(const Buffer& buffer) { this->HandleUse(buffer->data); - auto visit_arr = [&](Array arr) { + auto visit_arr = [&](ffi::Array arr) { for (const auto& element : arr) { this->VisitExpr(element); } @@ -151,7 +151,7 @@ void VarUseDefAnalyzer::HandleUse(const Var& var) { ++it->second; } } else { - undefined_.push_back(GetRef(v)); + undefined_.push_back(ffi::GetRef(v)); use_count_[v] = -1; } } @@ -176,26 +176,26 @@ void VarUseDefAnalyzer::HandleUse(const Buffer& buf) { ++it->second; } } else { - undefined_buffers_.push_back(GetRef(ptr)); + undefined_buffers_.push_back(ffi::GetRef(ptr)); buffer_use_count_[ptr] = -1; } VisitBuffer(buf); } -Array UndefinedVars(const Stmt& stmt, const Array& args) { +ffi::Array UndefinedVars(const Stmt& stmt, const ffi::Array& args) { VarUseDefAnalyzer m(args); m(stmt); return m.undefined_; } -Array UndefinedVars(const PrimExpr& expr) { +ffi::Array UndefinedVars(const PrimExpr& expr) { VarUseDefAnalyzer m({}); m(expr); return m.undefined_; } -Array UndefinedVars(const PrimExpr& expr, const Array& args) { +ffi::Array UndefinedVars(const PrimExpr& expr, const ffi::Array& args) { VarUseDefAnalyzer m(args); m(expr); return m.undefined_; @@ -206,9 +206,9 @@ TVM_FFI_STATIC_INIT_BLOCK({ refl::GlobalDef().def_packed( "tir.analysis.UndefinedVars", [](ffi::PackedArgs args, ffi::Any* rv) { if (auto opt_stmt = args[0].as()) { - *rv = UndefinedVars(opt_stmt.value(), args[1].cast>()); + *rv = UndefinedVars(opt_stmt.value(), args[1].cast>()); } else if (auto opt_expr = args[0].as()) { - *rv = UndefinedVars(opt_expr.value(), args[1].cast>()); + *rv = UndefinedVars(opt_expr.value(), args[1].cast>()); } else { LOG(FATAL) << "either UndefinedVars(stmt, args) or UndefinedVars(expr, args) is expected"; } diff --git a/src/tir/analysis/var_use_def_analysis.h b/src/tir/analysis/var_use_def_analysis.h index 64985b11a9fa..51323d65d5b2 100644 --- a/src/tir/analysis/var_use_def_analysis.h +++ 
b/src/tir/analysis/var_use_def_analysis.h @@ -40,12 +40,12 @@ namespace tir { */ class VarUseDefAnalyzer : public StmtExprVisitor { public: - explicit VarUseDefAnalyzer(const Array& defined_vars, bool visit_thread_extent = true); + explicit VarUseDefAnalyzer(const ffi::Array& defined_vars, bool visit_thread_extent = true); // The fields are publicly readable to // be accessible to the users. bool visit_thread_extent_{true}; - Array undefined_; - Array undefined_buffers_; + ffi::Array undefined_; + ffi::Array undefined_buffers_; std::unordered_map use_count_; std::unordered_map def_count_; diff --git a/src/tir/analysis/verify_gpu_code.cc b/src/tir/analysis/verify_gpu_code.cc index c1f8b327ecea..3b7ca0b080b5 100644 --- a/src/tir/analysis/verify_gpu_code.cc +++ b/src/tir/analysis/verify_gpu_code.cc @@ -39,10 +39,11 @@ namespace tir { class GPUCodeVerifier : public StmtExprVisitor { public: - std::vector Verify(Stmt stmt, int64_t max_local_memory_per_block, - int64_t max_shared_memory_per_block, int64_t max_threads_per_block, - int64_t max_thread_x, int64_t max_thread_y, int64_t max_thread_z, - int64_t max_vthread, int64_t max_vector_bytes, int64_t max_kernels) { + std::vector Verify(Stmt stmt, int64_t max_local_memory_per_block, + int64_t max_shared_memory_per_block, + int64_t max_threads_per_block, int64_t max_thread_x, + int64_t max_thread_y, int64_t max_thread_z, int64_t max_vthread, + int64_t max_vector_bytes, int64_t max_kernels) { max_local_memory_per_block_ = static_cast(max_local_memory_per_block); max_shared_memory_per_block_ = static_cast(max_shared_memory_per_block); max_threads_per_block_ = static_cast(max_threads_per_block); @@ -187,7 +188,7 @@ class GPUCodeVerifier : public StmtExprVisitor { StmtVisitor::VisitStmt_(op); } - void CheckBufferIndicesVectorizable(const Array indices) { + void CheckBufferIndicesVectorizable(const ffi::Array indices) { for (const auto index : indices) { if (const auto* ramp = index.as()) { if (!is_one(ramp->stride) && @@ -263,7 +264,7 @@ class GPUCodeVerifier : public StmtExprVisitor { size_t max_vector_bytes_; size_t max_kernels_; - std::vector errors_; + std::vector errors_; void Reset_() { local_memory_per_block_ = 0; @@ -274,7 +275,8 @@ class GPUCodeVerifier : public StmtExprVisitor { } }; -std::vector VerifyGPUCode_(const PrimFunc& func, Map constraints) { +std::vector VerifyGPUCode_(const PrimFunc& func, + ffi::Map constraints) { GPUCodeVerifier verifier; int64_t max_local_memory_per_block = INT64_MAX; @@ -317,7 +319,7 @@ std::vector VerifyGPUCode_(const PrimFunc& func, Map c max_vthread, max_vector_bytes, max_kernels); } -bool VerifyGPUCode(const PrimFunc& func, Map constraints) { +bool VerifyGPUCode(const PrimFunc& func, ffi::Map constraints) { auto errs = VerifyGPUCode_(func, constraints); return errs.size() == 0; } @@ -329,7 +331,7 @@ TVM_FFI_STATIC_INIT_BLOCK({ namespace refl = tvm::ffi::reflection; namespace transform { -Pass VerifyGPUCode(Map constraints) { +Pass VerifyGPUCode(ffi::Map constraints) { auto pass_func = [=](IRModule mod, PassContext ctx) { for (auto kv : mod->functions) { if (auto func = kv.second.as()) { diff --git a/src/tir/analysis/verify_memory.cc b/src/tir/analysis/verify_memory.cc index 6a93fa0206d4..68b5e5c4e92d 100644 --- a/src/tir/analysis/verify_memory.cc +++ b/src/tir/analysis/verify_memory.cc @@ -63,7 +63,7 @@ class MemoryAccessVerifier final : protected StmtExprVisitor { } /// Verification result - std::vector Errors() const { return errs_; } + std::vector Errors() const { return errs_; } protected: /// Visitor implementation @@ -158,7 +158,7 @@ 
class MemoryAccessVerifier final : protected StmtExprVisitor { /// Status of visitor //@{ bool in_thread_env_{false}; - std::vector errs_; + std::vector errs_; //@} tir::PrimFunc func_{nullptr}; ///< Function to be verified. int dev_type_{kDLCPU}; ///< Device type @@ -167,7 +167,7 @@ class MemoryAccessVerifier final : protected StmtExprVisitor { } // namespace /// Interface of VerifyMemory pass -std::vector VerifyMemory_(const PrimFunc& func) { +std::vector VerifyMemory_(const PrimFunc& func) { auto target = func->GetAttr(tvm::attr::kTarget); ICHECK(target.defined()) << "VerifyMemory: Require the target attribute"; diff --git a/src/tir/analysis/verify_ssa.cc b/src/tir/analysis/verify_ssa.cc index 0d5f3f6cb491..85d5ed057279 100644 --- a/src/tir/analysis/verify_ssa.cc +++ b/src/tir/analysis/verify_ssa.cc @@ -81,7 +81,7 @@ class SSAVerifier final : public StmtExprVisitor { } void VisitExpr_(const VarNode* node) final { - auto var = GetRef(node); + auto var = ffi::GetRef(node); if (match_scope_) { MarkDef(var, var, true); } diff --git a/src/tir/analysis/verify_well_formed.cc b/src/tir/analysis/verify_well_formed.cc index 2efd3648a5bb..d9fd0831904c 100644 --- a/src/tir/analysis/verify_well_formed.cc +++ b/src/tir/analysis/verify_well_formed.cc @@ -275,7 +275,7 @@ class UndefinedVarVerifier : public Verifier { } void VisitExpr_(const VarNode* op, AccessPath path) override { - auto var = GetRef(op); + auto var = ffi::GetRef(op); auto active_def = currently_defined_.find(var); auto verify = Verify(active_def != currently_defined_.end()); @@ -342,7 +342,7 @@ class SingleEnvThreadVerifier : public Verifier { } } - std::unordered_map> env_thread_vars_; + std::unordered_map> env_thread_vars_; }; bool VerifyWellFormed(const PrimFunc& func, bool assert_mode) { diff --git a/src/tir/ir/block_dependence_info.cc b/src/tir/ir/block_dependence_info.cc index 87847aed2d88..7626a1dcc496 100644 --- a/src/tir/ir/block_dependence_info.cc +++ b/src/tir/ir/block_dependence_info.cc @@ -42,7 +42,7 @@ class BlockDependenceInfoCollector : private StmtVisitor { } void MakeBlockScope(StmtSRef scope) { - Array child_block_srefs = std::move(block_frames_.back()); + ffi::Array child_block_srefs = std::move(block_frames_.back()); self_->sref2scope[scope] = BlockScope(child_block_srefs); } @@ -67,13 +67,13 @@ class BlockDependenceInfoCollector : private StmtVisitor { BlockDependenceInfoNode* self_; /*! \brief The stack frames of blocks in the DFS visit. */ - std::vector> block_frames_; + std::vector> block_frames_; }; -BlockDependenceInfo::BlockDependenceInfo() { data_ = make_object(); } +BlockDependenceInfo::BlockDependenceInfo() { data_ = ffi::make_object(); } BlockDependenceInfo::BlockDependenceInfo(IRModule mod) { - ObjectPtr n = make_object(); + ObjectPtr n = ffi::make_object(); BlockDependenceInfoNode* self = n.get(); n->stmt2ref = SRefTreeCreator::Create(mod, /* include_loops */ false); @@ -94,9 +94,9 @@ TVM_FFI_STATIC_INIT_BLOCK({ [](IRModule mod) -> BlockDependenceInfo { return BlockDependenceInfo(mod); }) .def_method("tir.BlockDependenceInfoGetBlockScope", &BlockDependenceInfoNode::GetBlockScope) .def("tir.BlockDependenceInfoGetSRef", - [](BlockDependenceInfo self, Stmt stmt) -> Optional { + [](BlockDependenceInfo self, Stmt stmt) -> ffi::Optional { auto it = self->stmt2ref.find(stmt.get()); - return it != self->stmt2ref.end() ? it->second : Optional(std::nullopt); + return it != self->stmt2ref.end() ? 
it->second : ffi::Optional(std::nullopt); }); }); diff --git a/src/tir/ir/block_scope.cc b/src/tir/ir/block_scope.cc index ba651b953acc..8caec68b49d0 100644 --- a/src/tir/ir/block_scope.cc +++ b/src/tir/ir/block_scope.cc @@ -52,7 +52,7 @@ void AddDependency(BlockScopeNode* self, const StmtSRef& src, const StmtSRef& ds /******** Constructors ********/ StmtSRef::StmtSRef(const StmtNode* stmt, StmtSRefNode* parent, int64_t seq_index) { - ObjectPtr n = make_object(); + ObjectPtr n = ffi::make_object(); n->stmt = stmt; n->parent = parent; n->seq_index = seq_index; @@ -70,19 +70,19 @@ StmtSRef StmtSRef::RootMark() { } Dependency::Dependency(StmtSRef src, StmtSRef dst, DepKind kind) { - ObjectPtr node = make_object(); + ObjectPtr node = ffi::make_object(); node->src = std::move(src); node->dst = std::move(dst); node->kind = kind; data_ = std::move(node); } -BlockScope::BlockScope() { data_ = make_object(); } +BlockScope::BlockScope() { data_ = ffi::make_object(); } -BlockScope::BlockScope(const Array& child_block_srefs) { - ObjectPtr n = make_object(); - SMap> buffer_readers; - SMap>& buffer_writers = n->buffer_writers; +BlockScope::BlockScope(const ffi::Array& child_block_srefs) { + ObjectPtr n = ffi::make_object(); + SMap> buffer_readers; + SMap>& buffer_writers = n->buffer_writers; for (const StmtSRef& child_block_sref : child_block_srefs) { const BlockNode* child_block = TVM_SREF_TO_BLOCK(child_block_sref); // Step 1. Update `buffer_readers` and `buffer_writers` for each buffer @@ -125,7 +125,7 @@ BlockScope::BlockScope(const Array& child_block_srefs) { /******** Dependency ********/ -Array BlockScopeNode::GetDepsBySrc(const StmtSRef& block_sref) const { +ffi::Array BlockScopeNode::GetDepsBySrc(const StmtSRef& block_sref) const { auto iter = this->src2deps.find(block_sref); if (iter != this->src2deps.end()) { return iter->second; @@ -134,7 +134,7 @@ Array BlockScopeNode::GetDepsBySrc(const StmtSRef& block_sref) const } } -Array BlockScopeNode::GetDepsByDst(const StmtSRef& block_sref) const { +ffi::Array BlockScopeNode::GetDepsByDst(const StmtSRef& block_sref) const { auto iter = this->dst2deps.find(block_sref); if (iter != this->dst2deps.end()) { return iter->second; @@ -197,10 +197,12 @@ TVM_FFI_STATIC_INIT_BLOCK({ namespace refl = tvm::ffi::reflection; refl::GlobalDef() .def("tir.StmtSRefStmt", - [](StmtSRef sref) -> Optional { return GetRef>(sref->stmt); }) + [](StmtSRef sref) -> ffi::Optional { + return ffi::GetRef>(sref->stmt); + }) .def("tir.StmtSRefParent", - [](StmtSRef sref) -> Optional { - return GetRef>(sref->parent); + [](StmtSRef sref) -> ffi::Optional { + return ffi::GetRef>(sref->parent); }) .def("tir.StmtSRefRootMark", StmtSRef::RootMark) .def("tir.StmtSRefInlineMark", StmtSRef::InlineMark) diff --git a/src/tir/ir/buffer.cc b/src/tir/ir/buffer.cc index 1cac41ff3ce5..7376ff1f1249 100644 --- a/src/tir/ir/buffer.cc +++ b/src/tir/ir/buffer.cc @@ -43,19 +43,20 @@ TVM_FFI_STATIC_INIT_BLOCK({ BufferNode::RegisterReflection(); }); using IndexMod = tir::FloorModNode; using IndexDiv = tir::FloorDivNode; -Array SimplifyArray(arith::Analyzer* ana, Array array) { +ffi::Array SimplifyArray(arith::Analyzer* ana, ffi::Array array) { for (size_t i = 0; i < array.size(); ++i) { array.Set(i, ana->Simplify(array[i])); } return array; } -Buffer decl_buffer(Array shape, DataType dtype, String name, String storage_scope, - Optional> axis_separators, Span span) { +Buffer decl_buffer(ffi::Array shape, DataType dtype, ffi::String name, + ffi::String storage_scope, ffi::Optional> axis_separators, + Span 
span) { DataType storage_dtype = (dtype == DataType::Bool() ? DataType::Int(8) : dtype); return Buffer(Var(name, PointerType(PrimType(storage_dtype), storage_scope), span), dtype, shape, - Array(), PrimExpr(), name, 0, 0, kDefault, - axis_separators.value_or(Array()), span); + ffi::Array(), PrimExpr(), name, 0, 0, kDefault, + axis_separators.value_or(ffi::Array()), span); } // Split the given expression w.r.t the add operator @@ -250,14 +251,14 @@ inline PrimExpr MergeMulMod(arith::Analyzer* analyzer, const PrimExpr& base) { return no_opt_sum; } -Array Buffer::OffsetOf(Array input_indices) const { +ffi::Array Buffer::OffsetOf(ffi::Array input_indices) const { return (*this)->ElemOffset(std::move(input_indices)); } // The buffer offset in convention of number of elements of // original data ignoring number of lanes. // We also perform optimization to simplify the indexing expression. -Array BufferNode::ElemOffset(Array input_indices) const { +ffi::Array BufferNode::ElemOffset(ffi::Array input_indices) const { ICHECK_EQ(shape.size(), input_indices.size()) << "Buffer " << this->name << " is " << shape.size() << "-dimensional, cannot be indexed with the " << input_indices.size() @@ -272,7 +273,7 @@ Array BufferNode::ElemOffset(Array input_indices) const { // TODO(Lunderberg): Better handling for cases where there is more // than one output index. Currently, this only allows elem_offset // to be non-zero for flat memory allocations. - Array elem_offsets = {}; + ffi::Array elem_offsets = {}; if (elem_offset.defined() && !is_zero(elem_offset)) { elem_offsets = {elem_offset}; } @@ -283,7 +284,7 @@ Array BufferNode::ElemOffset(Array input_indices) const { << "there must be one element offset for each output index."; } - Array output_indices(axis_separators.size() + 1, 0); + ffi::Array output_indices(axis_separators.size() + 1, 0); size_t current_output_axis = 0; @@ -318,8 +319,9 @@ Array BufferNode::ElemOffset(Array input_indices) const { return SimplifyArray(&ana, output_indices); } -inline Array BufferOffset(const BufferNode* n, Array index, DataType dtype) { - Array offsets = n->ElemOffset(index); +inline ffi::Array BufferOffset(const BufferNode* n, ffi::Array index, + DataType dtype) { + ffi::Array offsets = n->ElemOffset(index); // If the Buffer has element type with more than one lane, scale to // get the offset in number of scalars. if (n->dtype.lanes() != 1) { @@ -338,7 +340,7 @@ inline Array BufferOffset(const BufferNode* n, Array index, return offsets; } -static void ValidateAxisSeparators(const Array& axis_separators, size_t buffer_dim) { +static void ValidateAxisSeparators(const ffi::Array& axis_separators, size_t buffer_dim) { // These checks ensure that all output axes contain at least one // input axis. for (size_t i = 0; (i + 1) < axis_separators.size(); i++) { @@ -370,7 +372,7 @@ Buffer Buffer::GetFlattenedBuffer() const { ValidateAxisSeparators(self->axis_separators, self->shape.size()); - Array output_shape; + ffi::Array output_shape; if (self->strides.size()) { // If strides are defined, then the extent of each flattened // buffer is the stride*size for the first input axis used for @@ -386,7 +388,7 @@ Buffer Buffer::GetFlattenedBuffer() const { // of the extents of each input axis used to generate that output // axis. This also "flattens" rank-0 tensors to a rank-1 buffer // of shape [1]. 
- output_shape = Array(self->axis_separators.size() + 1, 1); + output_shape = ffi::Array(self->axis_separators.size() + 1, 1); size_t current_output_index = 0; for (size_t i = 0; i < self->shape.size(); i++) { if ((current_output_index < self->axis_separators.size()) && @@ -398,7 +400,7 @@ Buffer Buffer::GetFlattenedBuffer() const { } // The axis_separators for the output buffer. - Array output_axis_separators; + ffi::Array output_axis_separators; for (size_t i = 0; i < self->axis_separators.size(); i++) { auto dtype = self->axis_separators[i]->dtype; output_axis_separators.push_back(IntImm(dtype, i + 1)); @@ -416,8 +418,8 @@ Buffer Buffer::GetFlattenedBuffer() const { } } -PrimExpr Buffer::vload(Array begin, DataType value_dtype, - Optional predicate) const { +PrimExpr Buffer::vload(ffi::Array begin, DataType value_dtype, + ffi::Optional predicate) const { // specially handle bool, stored as DataType::Int(8) const BufferNode* n = operator->(); ICHECK(n != nullptr); @@ -425,7 +427,7 @@ PrimExpr Buffer::vload(Array begin, DataType value_dtype, value_dtype.get_lanes_or_vscale_factor() % n->dtype.lanes() == 0) << "Cannot load " << value_dtype << " from buffer of " << n->dtype; - Array indices = begin; + ffi::Array indices = begin; PrimExpr base = indices[indices.size() - 1]; if (value_dtype.is_fixed_length_vector()) { int factor = value_dtype.lanes() / n->dtype.lanes(); @@ -436,7 +438,8 @@ PrimExpr Buffer::vload(Array begin, DataType value_dtype, return BufferLoad(*this, indices, predicate); } -Stmt Buffer::vstore(Array begin, PrimExpr value, Optional predicate) const { +Stmt Buffer::vstore(ffi::Array begin, PrimExpr value, + ffi::Optional predicate) const { // specially handle bool, stored as DataType::Int(8) const BufferNode* n = operator->(); ICHECK(n != nullptr); @@ -445,7 +448,7 @@ Stmt Buffer::vstore(Array begin, PrimExpr value, Optional pr value_dtype.get_lanes_or_vscale_factor() % n->dtype.lanes() == 0) << "Cannot store " << value_dtype << " to buffer of " << n->dtype; - Array indices = begin; + ffi::Array indices = begin; PrimExpr base = indices[indices.size() - 1]; if (value_dtype.is_fixed_length_vector()) { int factor = value_dtype.lanes() / n->dtype.lanes(); @@ -456,7 +459,7 @@ Stmt Buffer::vstore(Array begin, PrimExpr value, Optional pr return BufferStore(*this, value, indices, predicate); } -String Buffer::scope() const { +ffi::String Buffer::scope() const { const auto* ptr_type = (*this)->data->type_annotation.as(); ICHECK(ptr_type) << "Buffer variable is not of pointer type"; if (ptr_type->storage_scope.empty()) { @@ -471,7 +474,7 @@ Buffer Buffer::MakeStrideView() const { std::vector temp; const BufferNode* self = operator->(); ICHECK(self != nullptr); - auto n = make_object(*self); + auto n = ffi::make_object(*self); PrimExpr acc = make_const(n->DefaultIndexType(), 1); for (size_t i = n->shape.size(); i != 0; --i) { temp.push_back(acc); @@ -483,15 +486,15 @@ Buffer Buffer::MakeStrideView() const { return Buffer(n); } -Buffer Buffer::MakeSlice(Array begins, Array extents) const { +Buffer Buffer::MakeSlice(ffi::Array begins, ffi::Array extents) const { const BufferNode* n = operator->(); ICHECK(n != nullptr); arith::Analyzer ana; begins = SimplifyArray(&ana, begins); - Array elem_offset = + ffi::Array elem_offset = n->ElemOffset(begins).Map([&](const PrimExpr& expr) { return ana.Simplify(expr); }); - Array strides = n->strides; + ffi::Array strides = n->strides; if (strides.size() == 0) { bool can_relax = true; bool need_stride = false; @@ -526,7 +529,7 @@ Buffer 
Buffer::MakeSlice(Array begins, Array extents) const } PrimExpr Buffer::access_ptr(int access_mask, DataType ptr_type, int content_lanes, PrimExpr offset, - Optional input_extent) const { + ffi::Optional input_extent) const { const BufferNode* self = operator->(); ICHECK(self != nullptr); PrimExpr e_dtype; @@ -553,14 +556,14 @@ PrimExpr Buffer::access_ptr(int access_mask, DataType ptr_type, int content_lane if (input_extent.defined()) { extent = input_extent.value(); } - Array acc_args{e_dtype, self->data, elem_offset, extent, - make_const(DataType::Int(32), access_mask)}; + ffi::Array acc_args{e_dtype, self->data, elem_offset, extent, + make_const(DataType::Int(32), access_mask)}; return tir::Call(ptr_type, tir::builtin::tvm_access_ptr(), acc_args); } -Buffer::Buffer(Var data, DataType dtype, Array shape, Array strides, - PrimExpr elem_offset, String name, int data_alignment, int offset_factor, - BufferType buffer_type, Array axis_separators, Span span) { +Buffer::Buffer(Var data, DataType dtype, ffi::Array shape, ffi::Array strides, + PrimExpr elem_offset, ffi::String name, int data_alignment, int offset_factor, + BufferType buffer_type, ffi::Array axis_separators, Span span) { DataType storage_dtype = dtype; // specially handle bool if (storage_dtype == DataType::Bool()) { @@ -584,7 +587,7 @@ Buffer::Buffer(Var data, DataType dtype, Array shape, Array ValidateAxisSeparators(axis_separators, shape.size()); - auto n = make_object(); + auto n = ffi::make_object(); n->data = std::move(data); n->dtype = dtype; @@ -614,7 +617,7 @@ Buffer::Buffer(Var data, DataType dtype, Array shape, Array data_ = std::move(n); } -tir::Buffer BufferWithOffsetAlignment(Array shape, DataType dtype, std::string name, +tir::Buffer BufferWithOffsetAlignment(ffi::Array shape, DataType dtype, std::string name, int data_alignment, int offset_factor, bool compact, std::string memory_scope) { DataType storage_dtype = (dtype == DataType::Bool() ? DataType::Int(8) : dtype); @@ -637,7 +640,7 @@ tir::Buffer BufferWithOffsetAlignment(Array shape, DataType dtype, std elem_offset = PrimExpr(); } - return tir::Buffer(data, dtype, shape, Array(), elem_offset, name, data_alignment, + return tir::Buffer(data, dtype, shape, ffi::Array(), elem_offset, name, data_alignment, offset_factor, buffer_type); } @@ -647,17 +650,17 @@ TVM_FFI_STATIC_INIT_BLOCK({ .def_packed("tir.Buffer", [](ffi::PackedArgs args, ffi::Any* ret) { ICHECK_EQ(args.size(), 11); - auto buffer_type = args[8].cast(); + auto buffer_type = args[8].cast(); BufferType type = (buffer_type == "auto_broadcast") ? 
kAutoBroadcast : kDefault; auto data = args[0].cast(); auto dtype = args[1].cast(); - auto shape = args[2].cast>(); - auto strides = args[3].cast>(); + auto shape = args[2].cast>(); + auto strides = args[3].cast>(); auto elem_offset = args[4].cast(); - auto name = args[5].cast(); + auto name = args[5].cast(); auto data_alignment = args[6].cast(); auto offset_factor = args[7].cast(); - auto axis_separators = args[9].cast>(); + auto axis_separators = args[9].cast>(); auto span = args[10].cast(); *ret = Buffer(data, dtype, shape, strides, elem_offset, name, data_alignment, offset_factor, type, axis_separators, span); diff --git a/src/tir/ir/data_layout.cc b/src/tir/ir/data_layout.cc index c1fd75d44efd..18fea3c45c12 100644 --- a/src/tir/ir/data_layout.cc +++ b/src/tir/ir/data_layout.cc @@ -74,8 +74,8 @@ const LayoutAxis& LayoutAxis::Get(const std::string& name) { return LayoutAxis::Get(name[0]); } -Layout::Layout(const Array& axes) { - auto node = make_object(); +Layout::Layout(const ffi::Array& axes) { + auto node = ffi::make_object(); node->axes = axes; std::ostringstream repr; for (const IterVar& axis : axes) { @@ -97,7 +97,7 @@ Layout::Layout(const std::string& name, DataType dtype) { // NOLINT(*) CHECK(dtype.is_int()) << "TypeError: The input dtype should be integer type"; if (name == "__undef__") return; - auto node = make_object(); + auto node = ffi::make_object(); node->name = name; if (name.empty()) return; // scalar @@ -149,9 +149,9 @@ Layout::Layout(const std::string& name, DataType dtype) { // NOLINT(*) Layout Layout::SubLayout(size_t pos, size_t len) const { if (!defined() || pos > ndim()) return Layout::Undef(); - if (len == 0) return Layout(Array()); + if (len == 0) return Layout(ffi::Array()); if (pos + len > ndim()) len = ndim() - pos; - Array new_layout; + ffi::Array new_layout; const auto axes = operator->()->axes; for (size_t i = pos; i < pos + len; ++i) { new_layout.push_back(axes[i]); @@ -170,7 +170,7 @@ Layout Layout::Split(const LayoutAxis& axis, size_t target_pos, int32_t factor) ICHECK(!this->Contains(axis.ToSubordinate())) << "Axis " << axis << " has already been split in " << name; ICHECK(factor > 0) << "Invalid split size " << factor; - Array new_layout; + ffi::Array new_layout; for (size_t i = 0; i <= this->ndim(); ++i) { if (i == target_pos) { new_layout.push_back(IterVar(Range(PrimExpr(0), PrimExpr(factor)), @@ -207,7 +207,7 @@ TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) p->stream << "Layout(" << l->name << ")"; }); -inline bool GetStoreRule(Array* index_rule, Array* shape_rule, +inline bool GetStoreRule(ffi::Array* index_rule, ffi::Array* shape_rule, const Layout& src_layout, const Layout& dst_layout) { if (!src_layout.defined() || src_layout.name().empty()) { LOG(WARNING) << "src layout '" << src_layout.name() << "' is invalid."; @@ -294,11 +294,11 @@ inline bool GetStoreRule(Array* index_rule, Array* shape_rul return true; } -inline Array TransformIndex(const Array& src_index, - const Array& src_axis, - const Array& transform_rule) { +inline ffi::Array TransformIndex(const ffi::Array& src_index, + const ffi::Array& src_axis, + const ffi::Array& transform_rule) { arith::Analyzer ana; - Array result; + ffi::Array result; std::unordered_map bind_map; for (size_t i = 0; i < src_index.size(); ++i) { bind_map[src_axis[i]->var.get()] = src_index[i]; @@ -309,7 +309,7 @@ inline Array TransformIndex(const Array& src_index, return result; } -Array BijectiveLayout::ForwardIndex(const Array& src_index) const { +ffi::Array BijectiveLayout::ForwardIndex(const ffi::Array& 
src_index) const { ICHECK(defined()) << "Cannot operate on an undefined bijective layout."; const BijectiveLayoutNode* self = operator->(); ICHECK_EQ(src_index.size(), self->src_layout->axes.size()) @@ -317,7 +317,7 @@ Array BijectiveLayout::ForwardIndex(const Array& src_index) return TransformIndex(src_index, self->src_layout->axes, self->index_forward_rule); } -Array BijectiveLayout::BackwardIndex(const Array& dst_index) const { +ffi::Array BijectiveLayout::BackwardIndex(const ffi::Array& dst_index) const { ICHECK(defined()) << "Cannot operate on an undefined bijective layout."; const BijectiveLayoutNode* self = operator->(); ICHECK_EQ(dst_index.size(), self->dst_layout->axes.size()) @@ -325,10 +325,10 @@ Array BijectiveLayout::BackwardIndex(const Array& dst_index) return TransformIndex(dst_index, self->dst_layout->axes, self->index_backward_rule); } -inline Array TransformShape(const Array& src_shape, - const Array& src_axis, - const Array& target_axis, - const Array& transform_rule) { +inline ffi::Array TransformShape(const ffi::Array& src_shape, + const ffi::Array& src_axis, + const ffi::Array& target_axis, + const ffi::Array& transform_rule) { arith::Analyzer ana; ICHECK_EQ(src_shape.size(), src_axis.size()) << "Input shape size " << src_shape.size() << " mismatch with the expected shape size " @@ -361,7 +361,7 @@ inline Array TransformShape(const Array& src_shape, // infer the target shape, // for major-axis, use the forward/backward_rule directly, // for minor-axis, simply use the extent. - Array result; + ffi::Array result; ICHECK_EQ(transform_rule.size(), target_axis.size()); for (size_t i = 0; i < transform_rule.size(); ++i) { PrimExpr rule = transform_rule[i]; @@ -395,14 +395,14 @@ inline Array TransformShape(const Array& src_shape, return result; } -Array BijectiveLayout::ForwardShape(const Array& shape) const { +ffi::Array BijectiveLayout::ForwardShape(const ffi::Array& shape) const { ICHECK(defined()) << "Cannot operate on an undefined bijective layout."; const BijectiveLayoutNode* self = operator->(); return TransformShape(shape, self->src_layout->axes, self->dst_layout->axes, self->shape_forward_rule); } -Array BijectiveLayout::BackwardShape(const Array& shape) const { +ffi::Array BijectiveLayout::BackwardShape(const ffi::Array& shape) const { ICHECK(defined()) << "Cannot operate on an undefined bijective layout."; const BijectiveLayoutNode* self = operator->(); return TransformShape(shape, self->dst_layout->axes, self->src_layout->axes, @@ -410,7 +410,7 @@ Array BijectiveLayout::BackwardShape(const Array& shape) con } BijectiveLayout::BijectiveLayout(Layout src_layout, Layout dst_layout) { - auto n = make_object(); + auto n = ffi::make_object(); n->src_layout = std::move(src_layout); n->dst_layout = std::move(dst_layout); diff --git a/src/tir/ir/data_type_rewriter.cc b/src/tir/ir/data_type_rewriter.cc index 346f1ab63250..d6dcae6540ba 100644 --- a/src/tir/ir/data_type_rewriter.cc +++ b/src/tir/ir/data_type_rewriter.cc @@ -47,7 +47,7 @@ Stmt DataTypeLegalizer::VisitStmt_(const ForNode* op) { Stmt DataTypeLegalizer::VisitStmt_(const BlockRealizeNode* op) { BlockRealize realize = Downcast(StmtExprMutator::VisitStmt_(op)); - Array new_iter_values; + ffi::Array new_iter_values; bool changed = false; for (int i = 0; i < static_cast(op->iter_values.size()); ++i) { auto dtype = realize->block->iter_vars[i]->var->dtype; @@ -66,17 +66,18 @@ Stmt DataTypeLegalizer::VisitStmt_(const BlockRealizeNode* op) { Stmt DataTypeLegalizer::VisitStmt_(const BlockNode* op) { Block new_block = 
Downcast(StmtExprMutator::VisitStmt_(op)); - Array new_iter_vars = MutateArray(new_block->iter_vars, [/*this*/](const IterVar& iter) { - auto dtype = iter->var.dtype(); - if (iter->dom->min->dtype != dtype || iter->dom->extent->dtype != dtype) { - IterVar new_iter = iter; - new_iter.CopyOnWrite()->dom = - Range(cast(dtype, iter->dom->min), cast(dtype, iter->dom->extent)); - return new_iter; - } else { - return iter; - } - }); + ffi::Array new_iter_vars = + MutateArray(new_block->iter_vars, [/*this*/](const IterVar& iter) { + auto dtype = iter->var.dtype(); + if (iter->dom->min->dtype != dtype || iter->dom->extent->dtype != dtype) { + IterVar new_iter = iter; + new_iter.CopyOnWrite()->dom = + Range(cast(dtype, iter->dom->min), cast(dtype, iter->dom->extent)); + return new_iter; + } else { + return iter; + } + }); if (!op->iter_vars.same_as(new_iter_vars)) { new_block.CopyOnWrite()->iter_vars = std::move(new_iter_vars); } @@ -123,7 +124,7 @@ PrimExpr DataTypeLegalizer::VisitExpr_(const LetNode* op) { PrimExpr new_body = this->VisitExpr(op->body); if (value.same_as(op->value) && new_body.same_as(op->body)) { - return GetRef(op); + return ffi::GetRef(op); } else { return Let(var, value, new_body, op->span); } @@ -141,7 +142,7 @@ Stmt DataTypeLegalizer::VisitStmt_(const LetStmtNode* op) { Stmt new_body = this->VisitStmt(op->body); if (value.same_as(op->value) && new_body.same_as(op->body)) { - return GetRef(op); + return ffi::GetRef(op); } else { return LetStmt(var, value, new_body, op->span); } @@ -151,7 +152,7 @@ PrimExpr DataTypeLegalizer::VisitExpr_(const VarNode* op) { if (auto it = var_remap_.find(op); it != var_remap_.end()) { return it->second; } - return GetRef(op); + return ffi::GetRef(op); } PrimExpr DataTypeLegalizer::VisitExpr_(const SelectNode* op) { @@ -160,7 +161,7 @@ PrimExpr DataTypeLegalizer::VisitExpr_(const SelectNode* op) { PrimExpr false_value = this->VisitExpr(op->false_value); if (condition.same_as(op->condition) && true_value.same_as(op->true_value) && false_value.same_as(op->false_value) && true_value.dtype() == false_value.dtype()) { - return GetRef(op); + return ffi::GetRef(op); } else { int bits = std::max(true_value.dtype().bits(), false_value.dtype().bits()); DataType dtype = true_value.dtype().with_bits(bits); @@ -174,7 +175,7 @@ PrimExpr DataTypeLegalizer::VisitExpr_(const RampNode* op) { PrimExpr base = VisitExpr(op->base); PrimExpr stride = VisitExpr(op->stride); if (base.same_as(op->base) && stride.same_as(op->stride) && base.dtype() == stride.dtype()) { - return GetRef(op); + return ffi::GetRef(op); } else { ICHECK(base.dtype().is_int() && stride.dtype().is_int()); int bits = std::max(base.dtype().bits(), stride.dtype().bits()); @@ -194,7 +195,7 @@ PrimExpr DataTypeLegalizer::VisitExpr_(const CastNode* op) { PrimExpr a = this->VisitExpr(op->a); \ PrimExpr b = this->VisitExpr(op->b); \ if (op->a.same_as(a) && op->b.same_as(b) && a.dtype() == b.dtype()) { \ - return GetRef(op); \ + return ffi::GetRef(op); \ } else { \ return FUNC(a, b); \ } \ @@ -219,7 +220,7 @@ TVM_DEFINE_BIOP_EXPR_MUTATE_WITH_TYPE_MATCH(GENode, operator>=); #undef TVM_DEFINE_BIOP_EXPR_MUTATE_WITH_TYPE_MATCH PrimExpr DataTypeLegalizer::VisitExpr_(const CallNode* op) { - Call before = GetRef(op); + Call before = ffi::GetRef(op); PrimExpr e = StmtExprMutator::VisitExpr_(op); op = e.as(); static const Op& builtin_pow_ = Op::Get("tir.pow"); @@ -264,7 +265,7 @@ Stmt IndexDataTypeRewriter::VisitStmt_(const AllocateNode* op) { auto new_body = this->VisitStmt(op->body); if 
(!new_extents.same_as(op->extents) || !new_cond.same_as(op->condition) || !new_body.same_as(op->body)) { - Allocate new_allocate = GetRef(op); + Allocate new_allocate = ffi::GetRef(op); auto* n = new_allocate.CopyOnWrite(); n->extents = std::move(new_extents); n->condition = std::move(new_cond); @@ -272,7 +273,7 @@ Stmt IndexDataTypeRewriter::VisitStmt_(const AllocateNode* op) { return new_allocate; } else { - return GetRef(op); + return ffi::GetRef(op); } } @@ -310,7 +311,7 @@ Stmt IndexDataTypeRewriter::VisitStmt_(const BlockRealizeNode* op) { Block new_body = Downcast(this->VisitStmt(op->block)); if (!new_predicate.same_as(op->predicate) || !new_iter_values.same_as(op->iter_values) || !new_body.same_as(op->block)) { - BlockRealize new_block_realize = GetRef(op); + BlockRealize new_block_realize = ffi::GetRef(op); auto* n = new_block_realize.CopyOnWrite(); n->predicate = std::move(new_predicate); n->iter_values = std::move(new_iter_values); @@ -318,14 +319,14 @@ Stmt IndexDataTypeRewriter::VisitStmt_(const BlockRealizeNode* op) { return new_block_realize; } else { - return GetRef(op); + return ffi::GetRef(op); } } Stmt IndexDataTypeRewriter::VisitStmt_(const BlockNode* op) { - Array new_alloc_buffers = + ffi::Array new_alloc_buffers = op->alloc_buffers.Map([this](const Buffer& buffer) { return this->VisitBuffer(buffer); }); - Array new_match_buffers = + ffi::Array new_match_buffers = op->match_buffers.Map([this](const MatchBufferRegion& match_buffer_region) { Buffer new_buffer = this->VisitBuffer(match_buffer_region->buffer); BufferRegion new_buffer_region = this->VisitBufferRegion(match_buffer_region->source); @@ -336,17 +337,17 @@ Stmt IndexDataTypeRewriter::VisitStmt_(const BlockNode* op) { return match_buffer_region; } }); - Array new_reads = op->reads.Map( + ffi::Array new_reads = op->reads.Map( [this](const BufferRegion& buffer_region) { return this->VisitBufferRegion(buffer_region); }); - Array new_writes = op->writes.Map( + ffi::Array new_writes = op->writes.Map( [this](const BufferRegion& buffer_region) { return this->VisitBufferRegion(buffer_region); }); - Array new_iter_vars = + ffi::Array new_iter_vars = op->iter_vars.Map([this](const IterVar& iter_var) { return this->VisitIterVar(iter_var); }); - Optional new_init = std::nullopt; + ffi::Optional new_init = std::nullopt; if (op->init.defined()) { new_init = this->VisitStmt(op->init.value()); } - Map new_annotations = VisitBlockAnnotations(op->annotations); + ffi::Map new_annotations = VisitBlockAnnotations(op->annotations); Stmt new_body = this->VisitStmt(op->body); if (!new_init.same_as(op->init) || !new_body.same_as(op->body) || @@ -354,7 +355,7 @@ Stmt IndexDataTypeRewriter::VisitStmt_(const BlockNode* op) { !new_match_buffers.same_as(op->match_buffers) || !new_reads.same_as(op->reads) || !new_writes.same_as(op->writes) || !new_iter_vars.same_as(op->iter_vars) || !new_annotations.same_as(op->annotations)) { - Block new_block = GetRef(op); + Block new_block = ffi::GetRef(op); BlockNode* n = new_block.CopyOnWrite(); n->alloc_buffers = std::move(new_alloc_buffers); n->match_buffers = std::move(new_match_buffers); @@ -366,11 +367,11 @@ Stmt IndexDataTypeRewriter::VisitStmt_(const BlockNode* op) { n->body = std::move(new_body); return new_block; } - return GetRef(op); + return ffi::GetRef(op); } -Map IndexDataTypeRewriter::VisitBlockAnnotations( - const Map& annotations) { +ffi::Map IndexDataTypeRewriter::VisitBlockAnnotations( + const ffi::Map& annotations) { auto new_annotations = annotations; std::function f_mutate_obj =
[this, &f_mutate_obj](const Any& obj) -> Any { @@ -383,7 +384,7 @@ Map IndexDataTypeRewriter::VisitBlockAnnotations( return new_buffer; } } else if (obj.as()) { - return Downcast>(obj).Map(f_mutate_obj); + return Downcast>(obj).Map(f_mutate_obj); } return obj; }; @@ -427,9 +428,9 @@ Buffer IndexDataTypeRewriter::VisitBuffer(const Buffer& buffer) { bool is_enabled = is_enabled_; is_enabled_ = true; - Array new_shape = + ffi::Array new_shape = buffer->shape.Map([&](const PrimExpr& e) { return this->VisitExpr(e); }); - Array new_strides = + ffi::Array new_strides = buffer->strides.Map([&](const PrimExpr& e) { return this->VisitExpr(e); }); auto new_elem_offset = VisitExpr(buffer->elem_offset); is_enabled_ = is_enabled; @@ -467,7 +468,7 @@ BufferRegion IndexDataTypeRewriter::VisitBufferRegion(const BufferRegion& buffer } Stmt IndexDataTypeRewriter::VisitStmt_(const BufferStoreNode* op) { - BufferStore store = GetRef(op); + BufferStore store = ffi::GetRef(op); Buffer new_buffer = GetRemappedBuffer(op->buffer); auto value = this->VisitExpr(op->value); @@ -488,7 +489,7 @@ Stmt IndexDataTypeRewriter::VisitStmt_(const BufferStoreNode* op) { } PrimExpr IndexDataTypeRewriter::VisitExpr_(const BufferLoadNode* op) { - BufferLoad load = GetRef(op); + BufferLoad load = ffi::GetRef(op); Buffer new_buffer = GetRemappedBuffer(op->buffer); auto indices = VisitIndices(op->indices); @@ -502,7 +503,7 @@ PrimExpr IndexDataTypeRewriter::VisitExpr_(const BufferLoadNode* op) { return load; } -Array IndexDataTypeRewriter::VisitIndices(Array indices) { +ffi::Array IndexDataTypeRewriter::VisitIndices(ffi::Array indices) { bool is_enabled = is_enabled_; is_enabled_ = true; @@ -521,18 +522,19 @@ Stmt IndexDataTypeRewriter::VisitStmt_(const IfThenElseNode* op) { is_condition_ = is_condition; Stmt then_case = VisitStmt(op->then_case); - Optional else_case = - op->else_case.defined() ? Optional{VisitStmt(op->else_case.value())} : std::nullopt; + ffi::Optional else_case = op->else_case.defined() + ? 
ffi::Optional{VisitStmt(op->else_case.value())} + : std::nullopt; if (!cond.same_as(op->condition) || !then_case.same_as(op->then_case) || !else_case.same_as(op->else_case)) { - IfThenElse new_stmt = GetRef(op); + IfThenElse new_stmt = ffi::GetRef(op); auto* n = new_stmt.CopyOnWrite(); n->condition = std::move(cond); n->then_case = std::move(then_case); n->else_case = std::move(else_case); return new_stmt; } - return GetRef(op); + return ffi::GetRef(op); } Stmt IndexDataTypeRewriter::VisitStmt_(const ForNode* op) { @@ -547,7 +549,7 @@ Stmt IndexDataTypeRewriter::VisitStmt_(const ForNode* op) { if (!new_loop_var.same_as(op->loop_var) || !min.same_as(op->min) || !extent.same_as(op->extent) || !new_body.same_as(op->body)) { - For new_for = GetRef(op); + For new_for = ffi::GetRef(op); auto* n = new_for.CopyOnWrite(); n->loop_var = new_loop_var; n->min = cast(new_loop_var.dtype(), min); @@ -556,13 +558,13 @@ Stmt IndexDataTypeRewriter::VisitStmt_(const ForNode* op) { auto old_thread_binding = op->thread_binding.value(); auto* ptr = old_thread_binding.CopyOnWrite(); ptr->var = old_thread_binding->var.copy_with_dtype(new_loop_var.dtype()); - n->thread_binding = Optional(std::move(old_thread_binding)); + n->thread_binding = ffi::Optional(std::move(old_thread_binding)); } n->body = new_body; return new_for; } else { - return GetRef(op); + return ffi::GetRef(op); } } @@ -619,7 +621,7 @@ PrimExpr IndexDataTypeRewriter::VisitExpr_(const SelectNode* op) { if (condition.same_as(op->condition) && true_value.same_as(op->true_value) && false_value.same_as(op->false_value) && true_value.dtype() == false_value.dtype()) { - return GetRef(op); + return ffi::GetRef(op); } else { int bits = std::max(true_value.dtype().bits(), false_value.dtype().bits()); DataType dtype = true_value.dtype().with_bits(bits); @@ -640,14 +642,14 @@ PrimFunc IndexDataTypeNormalizer::Rewrite(PrimFunc func) { buffer_remap_.clear(); ivmap_.clear(); // start rewrite - Map new_buffer_map = func->buffer_map; + ffi::Map new_buffer_map = func->buffer_map; for (const auto& [var, buffer] : func->buffer_map) { new_buffer_map.Set(var, VisitBuffer(buffer)); } // remap params bool is_enabled = true; std::swap(is_enabled_, is_enabled); - Array params = func->params.Map([this](Var param) { + ffi::Array params = func->params.Map([this](Var param) { if (param.dtype().is_int()) { return Downcast(this->VisitExpr(param)); } else { @@ -670,15 +672,15 @@ bool IndexDataTypeNormalizer::CanRewriteDType(DataType dtype) const { PrimExpr IndexDataTypeNormalizer::VisitExpr_(const IntImmNode* op) { if (is_enabled_ && CanRewriteDType(op->dtype)) { ICHECK_LE(op->value, Downcast(max_value(target_data_type_))->value); - return cast(target_data_type_, GetRef(op)); + return cast(target_data_type_, ffi::GetRef(op)); } - return GetRef(op); + return ffi::GetRef(op); } PrimExpr IndexDataTypeNormalizer::VisitExpr_(const VarNode* op) { if (is_enabled_ && CanRewriteDType(op->dtype) && op->dtype != target_data_type_ && !var_remap_.count(op)) { - var_remap_[op] = GetRef(op).copy_with_dtype(target_data_type_); + var_remap_[op] = ffi::GetRef(op).copy_with_dtype(target_data_type_); } return DataTypeLegalizer::VisitExpr_(op); } diff --git a/src/tir/ir/expr.cc b/src/tir/ir/expr.cc index 4d787015cb19..646f2fd3fa08 100644 --- a/src/tir/ir/expr.cc +++ b/src/tir/ir/expr.cc @@ -83,7 +83,7 @@ TVM_FFI_STATIC_INIT_BLOCK({ TVM_FFI_STATIC_INIT_BLOCK({ namespace refl = tvm::ffi::reflection; refl::GlobalDef().def("tir.convert", - [](Variant> expr) { return expr; }); + [](ffi::Variant> expr) { 
return expr; }); }); #define TVM_DEFINE_BINOP_CONSTRUCTOR(Name) \ @@ -93,7 +93,7 @@ TVM_FFI_STATIC_INIT_BLOCK({ ICHECK(b.defined()) << "ValueError: b is undefined\n"; \ CHECK(a.dtype() == b.dtype()) << "TypeError: mismatched types. " << a.dtype() << " vs. " \ << b.dtype() << "\n"; \ - ObjectPtr node = make_object(); \ + ObjectPtr node = ffi::make_object(); \ node->dtype = a.dtype(); \ node->a = std::move(a); \ node->b = std::move(b); \ @@ -108,7 +108,7 @@ TVM_FFI_STATIC_INIT_BLOCK({ ICHECK(b.defined()) << "ValueError: b is undefined\n"; \ CHECK(a.dtype() == b.dtype()) << "TypeError: mismatched types. " << a.dtype() << " vs. " \ << b.dtype() << "\n"; \ - ObjectPtr node = make_object(); \ + ObjectPtr node = ffi::make_object(); \ DataType a_dtype = a.dtype(); \ node->dtype = \ DataType::Bool(a_dtype.get_lanes_or_vscale_factor(), a_dtype.is_scalable_vector()); \ @@ -119,8 +119,8 @@ TVM_FFI_STATIC_INIT_BLOCK({ } // Var -Var::Var(String name_hint, DataType dtype, Span span) { - auto n = make_object(); +Var::Var(ffi::String name_hint, DataType dtype, Span span) { + auto n = ffi::make_object(); n->name_hint = std::move(name_hint); n->type_annotation = GetTypeFromRuntimeDataType(dtype); n->dtype = std::move(dtype); @@ -128,8 +128,8 @@ Var::Var(String name_hint, DataType dtype, Span span) { data_ = std::move(n); } -Var::Var(String name_hint, Type type_annotation, Span span) { - auto n = make_object(); +Var::Var(ffi::String name_hint, Type type_annotation, Span span) { + auto n = ffi::make_object(); n->name_hint = std::move(name_hint); n->dtype = GetRuntimeDataType(type_annotation); n->type_annotation = std::move(type_annotation); @@ -137,19 +137,19 @@ Var::Var(String name_hint, Type type_annotation, Span span) { data_ = std::move(n); } -Var Var::copy_with_name(const String& name) const { +Var Var::copy_with_name(const ffi::String& name) const { const VarNode* node = get(); ObjectPtr new_ptr; if (auto* ptr = this->as()) { - new_ptr = make_object(*ptr); + new_ptr = ffi::make_object(*ptr); } else { - new_ptr = make_object(*node); + new_ptr = ffi::make_object(*node); } new_ptr->name_hint = name; return Var(new_ptr); } -Var Var::copy_with_suffix(const String& suffix) const { +Var Var::copy_with_suffix(const ffi::String& suffix) const { return this->copy_with_name(get()->name_hint + suffix); } @@ -157,9 +157,9 @@ Var Var::copy_with_dtype(DataType dtype) const { const VarNode* node = get(); ObjectPtr new_ptr; if (auto* ptr = this->as()) { - new_ptr = make_object(*ptr); + new_ptr = ffi::make_object(*ptr); } else { - new_ptr = make_object(*node); + new_ptr = ffi::make_object(*node); } new_ptr->type_annotation = GetTypeFromRuntimeDataType(dtype); new_ptr->dtype = std::move(dtype); @@ -168,7 +168,7 @@ Var Var::copy_with_dtype(DataType dtype) const { TVM_FFI_STATIC_INIT_BLOCK({ namespace refl = tvm::ffi::reflection; - refl::GlobalDef().def("tir.Var", [](String name_hint, ffi::AnyView type, Span span) { + refl::GlobalDef().def("tir.Var", [](ffi::String name_hint, ffi::AnyView type, Span span) { if (type.as()) { return Var(name_hint, type.cast(), span); } else { @@ -178,8 +178,8 @@ TVM_FFI_STATIC_INIT_BLOCK({ }); // SizeVar -SizeVar::SizeVar(String name_hint, DataType dtype, Span span) { - auto n = make_object(); +SizeVar::SizeVar(ffi::String name_hint, DataType dtype, Span span) { + auto n = ffi::make_object(); n->name_hint = std::move(name_hint); n->type_annotation = GetTypeFromRuntimeDataType(dtype); n->dtype = std::move(dtype); @@ -187,8 +187,8 @@ SizeVar::SizeVar(String name_hint, DataType dtype, Span span) 
{ data_ = std::move(n); } -SizeVar::SizeVar(String name_hint, Type type_annotation, Span span) { - auto n = make_object(); +SizeVar::SizeVar(ffi::String name_hint, Type type_annotation, Span span) { + auto n = ffi::make_object(); n->name_hint = std::move(name_hint); n->dtype = GetRuntimeDataType(type_annotation); n->type_annotation = std::move(type_annotation); @@ -199,12 +199,12 @@ SizeVar::SizeVar(String name_hint, Type type_annotation, Span span) { TVM_FFI_STATIC_INIT_BLOCK({ namespace refl = tvm::ffi::reflection; refl::GlobalDef().def("tir.SizeVar", - [](String s, DataType t, Span span) { return SizeVar(s, t, span); }); + [](ffi::String s, DataType t, Span span) { return SizeVar(s, t, span); }); }); // IterVar -IterVar::IterVar(Range dom, Var var, IterVarType t, String thread_tag, Span span) { - ObjectPtr n = make_object(); +IterVar::IterVar(Range dom, Var var, IterVarType t, ffi::String thread_tag, Span span) { + ObjectPtr n = ffi::make_object(); if (dom.defined() && dom->extent.defined()) { CHECK(dom->extent.dtype().is_int()) << "The dtype of the domain of an IterVar must be an integer type. However, the domain's " @@ -225,14 +225,14 @@ IterVar::IterVar(Range dom, Var var, IterVarType t, String thread_tag, Span span TVM_FFI_STATIC_INIT_BLOCK({ namespace refl = tvm::ffi::reflection; refl::GlobalDef().def( - "tir.IterVar", [](Range dom, Var var, int iter_type, String thread_tag, Span span) { + "tir.IterVar", [](Range dom, Var var, int iter_type, ffi::String thread_tag, Span span) { return IterVar(dom, var, static_cast(iter_type), thread_tag, span); }); }); // StringImm -StringImm::StringImm(String value, Span span) { - ObjectPtr node = make_object(); +StringImm::StringImm(ffi::String value, Span span) { + ObjectPtr node = ffi::make_object(); node->dtype = DataType::Handle(); node->value = std::move(value); node->span = std::move(span); @@ -242,7 +242,7 @@ StringImm::StringImm(String value, Span span) { TVM_FFI_STATIC_INIT_BLOCK({ namespace refl = tvm::ffi::reflection; refl::GlobalDef().def("tir.StringImm", - [](String value, Span span) { return StringImm(value, span); }); + [](ffi::String value, Span span) { return StringImm(value, span); }); }); // Cast @@ -250,7 +250,7 @@ Cast::Cast(DataType t, PrimExpr value, Span span) { ICHECK(value.defined()); ICHECK_EQ(t.get_lanes_or_vscale_factor(), value.dtype().get_lanes_or_vscale_factor()); ICHECK(t.is_scalable_vector() == value.dtype().is_scalable_vector()); - ObjectPtr node = make_object(); + ObjectPtr node = ffi::make_object(); node->dtype = t; node->value = std::move(value); node->span = std::move(span); @@ -401,7 +401,7 @@ And::And(PrimExpr a, PrimExpr b, Span span) { ICHECK(b.dtype().is_bool()); ICHECK(a.dtype() == b.dtype()) << "TypeError: mismatched types"; - ObjectPtr node = make_object(); + ObjectPtr node = ffi::make_object(); node->dtype = DataType::Bool(a.dtype().get_lanes_or_vscale_factor(), a.dtype().is_scalable_vector()); node->a = std::move(a); @@ -424,7 +424,7 @@ Or::Or(PrimExpr a, PrimExpr b, Span span) { ICHECK(b.dtype().is_bool()); ICHECK(a.dtype() == b.dtype()) << "TypeError: mismatched types"; - ObjectPtr node = make_object(); + ObjectPtr node = ffi::make_object(); node->dtype = DataType::Bool(a.dtype().get_lanes_or_vscale_factor(), a.dtype().is_scalable_vector()); node->a = std::move(a); @@ -443,7 +443,7 @@ Not::Not(PrimExpr a, Span span) { ICHECK(a.defined()) << "ValueError: a is undefined"; ICHECK(a.dtype().is_bool()); - ObjectPtr node = make_object(); + ObjectPtr node = ffi::make_object(); DataType a_dtype = 
a.dtype(); node->dtype = DataType::Bool(a_dtype.get_lanes_or_vscale_factor(), a_dtype.is_scalable_vector()); node->a = std::move(a); @@ -469,7 +469,7 @@ Select::Select(PrimExpr condition, PrimExpr true_value, PrimExpr false_value, Sp << "TypeError: mismatched types. " << "False type: " << false_value.dtype() << "; True type: " << true_value.dtype(); - ObjectPtr node = make_object(); + ObjectPtr node = ffi::make_object(); node->dtype = true_value.dtype(); node->condition = std::move(condition); node->true_value = std::move(true_value); @@ -496,7 +496,7 @@ Ramp::Ramp(PrimExpr base, PrimExpr stride, PrimExpr lanes, Span span) { stride = cast(base.dtype(), stride); } - ObjectPtr node = make_object(); + ObjectPtr node = ffi::make_object(); auto* lanes_as_int = lanes.as(); if (lanes_as_int) { int lanes = static_cast(lanes_as_int->value); @@ -530,7 +530,7 @@ Broadcast::Broadcast(PrimExpr value, PrimExpr lanes, Span span) { ICHECK(value.defined()); ICHECK(value.dtype().is_scalar()); - ObjectPtr node = make_object(); + ObjectPtr node = ffi::make_object(); auto* lanes_int = lanes.as(); if (lanes_int) { int lanes = static_cast(lanes_int->value); @@ -564,7 +564,7 @@ Let::Let(Var var, PrimExpr value, PrimExpr body, Span span) { ICHECK(body.defined()); ICHECK_EQ(value.dtype(), var.dtype()); - ObjectPtr node = make_object(); + ObjectPtr node = ffi::make_object(); node->dtype = body.dtype(); node->var = std::move(var); node->value = std::move(value); @@ -581,12 +581,12 @@ TVM_FFI_STATIC_INIT_BLOCK({ }); // Call -Call::Call(DataType dtype, RelaxExpr op, Array args, Span span) { +Call::Call(DataType dtype, RelaxExpr op, ffi::Array args, Span span) { for (size_t i = 0; i < args.size(); ++i) { ICHECK(args[i].defined()) << "arg " << i << " is not defined()"; } - ObjectPtr node = make_object(); + ObjectPtr node = ffi::make_object(); node->dtype = dtype; node->op = std::move(op); node->args = std::move(args); @@ -598,18 +598,19 @@ TVM_FFI_STATIC_INIT_BLOCK({ namespace refl = tvm::ffi::reflection; refl::GlobalDef().def( "tir.Call", - [](Optional dtype, RelaxExpr op, - Array> args, Span span) { - Array prim_expr_args; + [](ffi::Optional dtype, RelaxExpr op, + ffi::Array> args, + Span span) { + ffi::Array prim_expr_args; for (const auto& it : args) { - if (auto opt_str = it.as()) { + if (auto opt_str = it.as()) { prim_expr_args.push_back(StringImm(opt_str.value())); } else if (auto opt_dtype = it.as()) { prim_expr_args.push_back(StringImm(ffi::DLDataTypeToString(opt_dtype.value()))); } else if (const auto* iter_var = it.as()) { prim_expr_args.push_back(iter_var->var); } else if (const auto* br = it.as()) { - Array indices; + ffi::Array indices; for (Range r : br->region) { if (is_one(r->extent)) { indices.push_back(r->min); @@ -617,7 +618,7 @@ TVM_FFI_STATIC_INIT_BLOCK({ indices.push_back(tir::Ramp(r->min, make_const(r->min->dtype, 1), r->extent)); } else { LOG(FATAL) << "ValueError: Cannot convert to BufferLoad: " - << GetRef(br); + << ffi::GetRef(br); } } prim_expr_args.push_back(BufferLoad(br->buffer, indices)); @@ -630,7 +631,7 @@ TVM_FFI_STATIC_INIT_BLOCK({ }); // Shuffle -Shuffle::Shuffle(Array vectors, Array indices, Span span) { +Shuffle::Shuffle(ffi::Array vectors, ffi::Array indices, Span span) { ICHECK_NE(vectors.size(), 0U); ICHECK_NE(indices.size(), 0U); @@ -643,7 +644,7 @@ Shuffle::Shuffle(Array vectors, Array indices, Span span) { } ICHECK_LE(indices.size(), static_cast(total_lanes)); - ObjectPtr node = make_object(); + ObjectPtr node = ffi::make_object(); node->dtype = 
base_type.with_lanes(static_cast(indices.size())); node->vectors = std::move(vectors); node->indices = std::move(indices); @@ -651,12 +652,12 @@ Shuffle::Shuffle(Array vectors, Array indices, Span span) { data_ = node; } -PrimExpr Shuffle::Concat(Array vectors, Span span) { +PrimExpr Shuffle::Concat(ffi::Array vectors, Span span) { ICHECK_NE(vectors.size(), 0); if (vectors.size() == 1) { return vectors[0]; } - Array indices; + ffi::Array indices; int index = 0; for (const PrimExpr& e : vectors) { for (int i = 0; i < e.dtype().lanes(); ++i) { @@ -672,13 +673,15 @@ PrimExpr Shuffle::ExtractElement(PrimExpr vector, int index, Span span) { TVM_FFI_STATIC_INIT_BLOCK({ namespace refl = tvm::ffi::reflection; - refl::GlobalDef().def("tir.Shuffle", [](Array vectors, Array indices, - Span span) { return Shuffle(vectors, indices, span); }); + refl::GlobalDef().def("tir.Shuffle", + [](ffi::Array vectors, ffi::Array indices, Span span) { + return Shuffle(vectors, indices, span); + }); }); // CommReducer -CommReducer::CommReducer(Array lhs, Array rhs, Array result, - Array identity_element, Span span) { +CommReducer::CommReducer(ffi::Array lhs, ffi::Array rhs, ffi::Array result, + ffi::Array identity_element, Span span) { size_t n_group = result.size(); CHECK_EQ(lhs.size(), n_group) << "ValueError: The number of vars in `lhs` must equal to the " "number of elements in `results`"; @@ -708,7 +711,7 @@ CommReducer::CommReducer(Array lhs, Array rhs, Array result, p_result->SetItem(i, Substitute(result[i], var_map)); } - auto node = make_object(); + auto node = ffi::make_object(); node->lhs = lhs; node->rhs = rhs; node->result = result; @@ -717,11 +720,12 @@ CommReducer::CommReducer(Array lhs, Array rhs, Array result, data_ = std::move(node); } -Array CommReducerNode::operator()(Array a, Array b) const { +ffi::Array CommReducerNode::operator()(ffi::Array a, + ffi::Array b) const { ICHECK_EQ(a.size(), b.size()); ICHECK_EQ(lhs.size(), a.size()); ICHECK_EQ(rhs.size(), b.size()); - Map value_map; + ffi::Map value_map; for (size_t i = 0; i < a.size(); ++i) { value_map.Set(lhs[i], a[i]); value_map.Set(rhs[i], b[i]); @@ -733,22 +737,22 @@ TVM_FFI_STATIC_INIT_BLOCK({ namespace refl = tvm::ffi::reflection; refl::GlobalDef() .def("tir.CommReducer", - [](Array lhs, Array rhs, Array result, - Array identity_element, + [](ffi::Array lhs, ffi::Array rhs, ffi::Array result, + ffi::Array identity_element, Span span) { return CommReducer(lhs, rhs, result, identity_element, span); }) .def_method("tir.CommReducerCombine", &tir::CommReducerNode::operator()); }); // Reduce -Reduce::Reduce(CommReducer combiner, Array source, Array axis, - PrimExpr condition, int value_index, Array init, Span span) { +Reduce::Reduce(CommReducer combiner, ffi::Array source, ffi::Array axis, + PrimExpr condition, int value_index, ffi::Array init, Span span) { for (size_t i = 0; i < axis.size(); ++i) { ICHECK_EQ(axis[i]->iter_type, kCommReduce) << "Can only take axis created by reduce_axis"; } if (!condition.defined()) { condition = const_true(); } - auto n = make_object(); + auto n = ffi::make_object(); ICHECK(source.defined()); for (size_t i = 0; i < axis.size(); ++i) { ICHECK(axis[i].defined()); @@ -776,11 +780,11 @@ Reduce::Reduce(CommReducer combiner, Array source, Array axis TVM_FFI_STATIC_INIT_BLOCK({ namespace refl = tvm::ffi::reflection; - refl::GlobalDef().def("tir.Reduce", - [](CommReducer combiner, Array source, Array axis, - PrimExpr condition, int value_index, Array init, Span span) { - return Reduce(combiner, source, axis, condition, 
value_index, init, span); - }); + refl::GlobalDef().def( + "tir.Reduce", [](CommReducer combiner, ffi::Array source, ffi::Array axis, + PrimExpr condition, int value_index, ffi::Array init, Span span) { + return Reduce(combiner, source, axis, condition, value_index, init, span); + }); }); // BufferLoad @@ -812,8 +816,8 @@ void BufferLoadNode::LegalizeDType() { } } -BufferLoad::BufferLoad(Buffer buffer, Array indices, Optional predicate, - Span span) { +BufferLoad::BufferLoad(Buffer buffer, ffi::Array indices, + ffi::Optional predicate, Span span) { ICHECK_EQ(buffer->shape.size(), indices.size()) << "Buffer " << buffer->name << " is " << buffer->shape.size() << "-dimensional, cannot be indexed with the " << indices.size() @@ -841,7 +845,7 @@ BufferLoad::BufferLoad(Buffer buffer, Array indices, Optional node = make_object(); + ObjectPtr node = ffi::make_object(); node->buffer = std::move(buffer); node->indices = std::move(indices); node->predicate = std::move(predicate); @@ -852,14 +856,15 @@ BufferLoad::BufferLoad(Buffer buffer, Array indices, Optional indices, Optional predicate, - Span span) { return BufferLoad(buffer, indices, predicate, span); }); + refl::GlobalDef().def("tir.BufferLoad", [](Buffer buffer, ffi::Array indices, + ffi::Optional predicate, Span span) { + return BufferLoad(buffer, indices, predicate, span); + }); }); // ProducerLoad -ProducerLoad::ProducerLoad(DataProducer producer, Array indices, Span span) { - ObjectPtr node = make_object(); +ProducerLoad::ProducerLoad(DataProducer producer, ffi::Array indices, Span span) { + ObjectPtr node = ffi::make_object(); node->dtype = producer->GetDataType(); node->producer = std::move(producer); node->indices = std::move(indices); @@ -870,7 +875,7 @@ ProducerLoad::ProducerLoad(DataProducer producer, Array indices, Span TVM_FFI_STATIC_INIT_BLOCK({ namespace refl = tvm::ffi::reflection; refl::GlobalDef().def("tir.ProducerLoad", - [](DataProducer producer, Array indices, Span span) { + [](DataProducer producer, ffi::Array indices, Span span) { return ProducerLoad(producer, indices, span); }); }); diff --git a/src/tir/ir/expr_functor.cc b/src/tir/ir/expr_functor.cc index 05e333b78ac6..19277d1013c1 100644 --- a/src/tir/ir/expr_functor.cc +++ b/src/tir/ir/expr_functor.cc @@ -111,7 +111,7 @@ void ExprVisitor::VisitExpr_(const ShuffleNode* op) { void ExprVisitor::VisitExpr_(const BroadcastNode* op) { this->VisitExpr(op->value); } -PrimExpr ExprMutator::VisitExpr_(const VarNode* op) { return GetRef(op); } +PrimExpr ExprMutator::VisitExpr_(const VarNode* op) { return ffi::GetRef(op); } PrimExpr ExprMutator::VisitExpr_(const SizeVarNode* op) { return this->VisitExpr_(static_cast(op)); @@ -119,9 +119,9 @@ PrimExpr ExprMutator::VisitExpr_(const SizeVarNode* op) { PrimExpr ExprMutator::VisitExpr_(const BufferLoadNode* op) { auto fmutate = [this](const PrimExpr& e) { return this->VisitExpr(e); }; - Array indices = op->indices.Map(fmutate); + ffi::Array indices = op->indices.Map(fmutate); if (indices.same_as(op->indices)) { - return GetRef(op); + return ffi::GetRef(op); } else { return BufferLoad(op->buffer, indices, op->predicate); } @@ -129,9 +129,9 @@ PrimExpr ExprMutator::VisitExpr_(const BufferLoadNode* op) { PrimExpr ExprMutator::VisitExpr_(const ProducerLoadNode* op) { auto fmutate = [this](const PrimExpr& e) { return this->VisitExpr(e); }; - Array indices = op->indices.Map(fmutate); + ffi::Array indices = op->indices.Map(fmutate); if (indices.same_as(op->indices)) { - return GetRef(op); + return ffi::GetRef(op); } else { return 
ProducerLoad(op->producer, indices); } @@ -141,7 +141,7 @@ PrimExpr ExprMutator::VisitExpr_(const LetNode* op) { PrimExpr value = this->VisitExpr(op->value); PrimExpr body = this->VisitExpr(op->body); if (value.same_as(op->value) && body.same_as(op->body)) { - return GetRef(op); + return ffi::GetRef(op); } else { return Let(op->var, value, body); } @@ -149,17 +149,17 @@ PrimExpr ExprMutator::VisitExpr_(const LetNode* op) { PrimExpr ExprMutator::VisitExpr_(const CallNode* op) { auto fmutate = [this](const PrimExpr& e) { return this->VisitExpr(e); }; - Array args = op->args.Map(fmutate); + ffi::Array args = op->args.Map(fmutate); if (args.same_as(op->args)) { - return GetRef(op); + return ffi::GetRef(op); } else { return Call(op->dtype, op->op, args); } } #define DEFINE_OP_RETURN_SELF_EXPR_MUTATE_(OP) \ - PrimExpr ExprMutator::VisitExpr_(const OP* op) { return GetRef(op); } + PrimExpr ExprMutator::VisitExpr_(const OP* op) { return ffi::GetRef(op); } DEFINE_OP_RETURN_SELF_EXPR_MUTATE_(IntImmNode) DEFINE_OP_RETURN_SELF_EXPR_MUTATE_(FloatImmNode) @@ -170,7 +170,7 @@ DEFINE_OP_RETURN_SELF_EXPR_MUTATE_(StringImmNode) PrimExpr a = this->VisitExpr(op->a); \ PrimExpr b = this->VisitExpr(op->b); \ if (a.same_as(op->a) && b.same_as(op->b)) { \ - return GetRef(op); \ + return ffi::GetRef(op); \ } else { \ return OP(a, b); \ } \ @@ -205,17 +205,17 @@ PrimExpr ExprMutator::VisitExpr_(const ReduceNode* op) { return IterVar(Range::FromMinExtent(min, extent), v->var, v->iter_type, v->thread_tag); } }; - Array axis = op->axis.Map(fitervar); + ffi::Array axis = op->axis.Map(fitervar); auto fexpr = [this](const PrimExpr& e) { return this->VisitExpr(e); }; - Array source = op->source.Map(fexpr); - Array init = op->init.Map(fexpr); + ffi::Array source = op->source.Map(fexpr); + ffi::Array init = op->init.Map(fexpr); PrimExpr condition = this->VisitExpr(op->condition); if (axis.same_as(op->axis) && source.same_as(op->source) && condition.same_as(op->condition) && init.same_as(op->init)) { - return GetRef(op); + return ffi::GetRef(op); } else { return Reduce(op->combiner, source, axis, condition, op->value_index, init); } @@ -224,7 +224,7 @@ PrimExpr ExprMutator::VisitExpr_(const ReduceNode* op) { PrimExpr ExprMutator::VisitExpr_(const CastNode* op) { PrimExpr value = this->VisitExpr(op->value); if (value.same_as(op->value)) { - return GetRef(op); + return ffi::GetRef(op); } else { return Cast(op->dtype, value); } @@ -233,7 +233,7 @@ PrimExpr ExprMutator::VisitExpr_(const CastNode* op) { PrimExpr ExprMutator::VisitExpr_(const NotNode* op) { PrimExpr a = this->VisitExpr(op->a); if (a.same_as(op->a)) { - return GetRef(op); + return ffi::GetRef(op); } else { return Not(a); } @@ -245,7 +245,7 @@ PrimExpr ExprMutator::VisitExpr_(const SelectNode* op) { PrimExpr false_value = this->VisitExpr(op->false_value); if (condition.same_as(op->condition) && true_value.same_as(op->true_value) && false_value.same_as(op->false_value)) { - return GetRef(op); + return ffi::GetRef(op); } else { return Select(condition, true_value, false_value); } @@ -256,7 +256,7 @@ PrimExpr ExprMutator::VisitExpr_(const RampNode* op) { PrimExpr stride = this->VisitExpr(op->stride); PrimExpr lanes = this->VisitExpr(op->lanes); if (base.same_as(op->base) && stride.same_as(op->stride) && lanes.same_as(op->lanes)) { - return GetRef(op); + return ffi::GetRef(op); } else { return Ramp(base, stride, lanes); } @@ -266,7 +266,7 @@ PrimExpr ExprMutator::VisitExpr_(const BroadcastNode* op) { PrimExpr value = this->VisitExpr(op->value); PrimExpr lanes = 
this->VisitExpr(op->lanes); if (value.same_as(op->value) && lanes.same_as(op->lanes)) { - return GetRef(op); + return ffi::GetRef(op); } else { return Broadcast(value, lanes); } @@ -277,7 +277,7 @@ PrimExpr ExprMutator::VisitExpr_(const ShuffleNode* op) { auto vectors = op->vectors.Map(fexpr); auto indices = op->indices.Map(fexpr); if (vectors.same_as(op->vectors) && indices.same_as(op->indices)) { - return GetRef(op); + return ffi::GetRef(op); } else { return Shuffle(vectors, indices); } diff --git a/src/tir/ir/function.cc b/src/tir/ir/function.cc index c8769222e02d..9b4f559fd0a8 100644 --- a/src/tir/ir/function.cc +++ b/src/tir/ir/function.cc @@ -38,7 +38,7 @@ TVM_FFI_STATIC_INIT_BLOCK({ namespace { relax::StructInfo InferStructInfo(const PrimFunc& prim_func) { - Array params; + ffi::Array params; for (const auto& param : prim_func->params) { relax::StructInfo param_sinfo = [&]() -> relax::StructInfo { if (auto opt_buf = prim_func->buffer_map.Get(param)) { @@ -62,7 +62,7 @@ relax::StructInfo InferStructInfo(const PrimFunc& prim_func) { if (const auto* prim = prim_func->ret_type.as()) { return relax::PrimStructInfo(prim->dtype); } else if (IsVoidType(prim_func->ret_type)) { - return relax::TupleStructInfo(Array{}); + return relax::TupleStructInfo(ffi::Array{}); } else { return relax::ObjectStructInfo(); } @@ -75,8 +75,8 @@ relax::StructInfo InferStructInfo(const PrimFunc& prim_func) { } // namespace // Get the function type of a PrimFunc -PrimFunc::PrimFunc(Array params, Stmt body, Type ret_type, - Map buffer_map, DictAttrs attrs, Span span) { +PrimFunc::PrimFunc(ffi::Array params, Stmt body, Type ret_type, + ffi::Map buffer_map, DictAttrs attrs, Span span) { if (!attrs.defined()) { attrs = DictAttrs(); } @@ -85,7 +85,7 @@ PrimFunc::PrimFunc(Array params, Stmt body, Type ret_type, ret_type = VoidType(); } - auto n = make_object(); + auto n = ffi::make_object(); n->params = std::move(params); n->body = std::move(body); n->ret_type = std::move(ret_type); @@ -99,7 +99,7 @@ PrimFunc::PrimFunc(Array params, Stmt body, Type ret_type, } FuncType PrimFuncNode::func_type_annotation() const { - Array param_types; + ffi::Array param_types; for (auto param : this->params) { param_types.push_back(GetType(param)); } @@ -108,7 +108,7 @@ FuncType PrimFuncNode::func_type_annotation() const { class TensorIntrinManager { public: - Map reg; + ffi::Map reg; static TensorIntrinManager* Global() { static TensorIntrinManager* inst = new TensorIntrinManager(); @@ -129,13 +129,13 @@ TensorIntrin::TensorIntrin(PrimFunc desc, PrimFunc impl) { } ICHECK_EQ(desc->buffer_map.size(), impl->buffer_map.size()); - ObjectPtr n = make_object(); + ObjectPtr n = ffi::make_object(); n->desc = std::move(desc); n->impl = std::move(impl); data_ = std::move(n); } -void TensorIntrin::Register(String name, TensorIntrin intrin, bool override) { +void TensorIntrin::Register(ffi::String name, TensorIntrin intrin, bool override) { TensorIntrinManager* manager = TensorIntrinManager::Global(); if (!override) { CHECK_EQ(manager->reg.count(name), 0) @@ -144,7 +144,7 @@ void TensorIntrin::Register(String name, TensorIntrin intrin, bool override) { manager->reg.Set(name, intrin); } -Optional TensorIntrin::Get(String name, bool allow_missing) { +ffi::Optional TensorIntrin::Get(ffi::String name, bool allow_missing) { const TensorIntrinManager* manager = TensorIntrinManager::Global(); auto it = manager->reg.find(name); if (it == manager->reg.end()) { @@ -161,8 +161,8 @@ TVM_FFI_STATIC_INIT_BLOCK({ namespace refl = tvm::ffi::reflection; 
refl::GlobalDef() .def("tir.PrimFunc", - [](Array params, Stmt body, Type ret_type, Map buffer_map, - DictAttrs attrs, + [](ffi::Array params, Stmt body, Type ret_type, + ffi::Map buffer_map, DictAttrs attrs, Span span) { return PrimFunc(params, body, ret_type, buffer_map, attrs, span); }) .def("tir.TensorIntrin", [](PrimFunc desc_func, PrimFunc intrin_func) { diff --git a/src/tir/ir/functor_common.h b/src/tir/ir/functor_common.h index 901a5d5234ca..c9f21b1b38ec 100644 --- a/src/tir/ir/functor_common.h +++ b/src/tir/ir/functor_common.h @@ -30,14 +30,14 @@ namespace tir { // Implementation of Visitors template -inline void VisitArray(const Array& arr, F fvisit) { +inline void VisitArray(const ffi::Array& arr, F fvisit) { for (size_t i = 0; i < arr.size(); i++) { fvisit(arr[i]); } } template -inline Array MutateArray(Array arr, F fmutate) { +inline ffi::Array MutateArray(ffi::Array arr, F fmutate) { return arr.Map(fmutate); } diff --git a/src/tir/ir/index_map.cc b/src/tir/ir/index_map.cc index 5c2541b10b1e..0ac6a9ab341b 100644 --- a/src/tir/ir/index_map.cc +++ b/src/tir/ir/index_map.cc @@ -37,18 +37,19 @@ namespace tir { TVM_FFI_STATIC_INIT_BLOCK({ IndexMapNode::RegisterReflection(); }); -IndexMap::IndexMap(Array initial_indices, Array final_indices, - Optional inverse_index_map) { - auto n = make_object(); +IndexMap::IndexMap(ffi::Array initial_indices, ffi::Array final_indices, + ffi::Optional inverse_index_map) { + auto n = ffi::make_object(); n->initial_indices = std::move(initial_indices); n->final_indices = std::move(final_indices); n->inverse_index_map = std::move(inverse_index_map); data_ = std::move(n); } -IndexMap IndexMap::FromFunc(int ndim, ffi::TypedFunction(Array)> func, - Optional inverse_index_map) { - Array initial_indices; +IndexMap IndexMap::FromFunc(int ndim, + ffi::TypedFunction(ffi::Array)> func, + ffi::Optional inverse_index_map) { + ffi::Array initial_indices; initial_indices.reserve(ndim); for (int i = 0; i < ndim; ++i) { initial_indices.push_back(Var("i" + std::to_string(i), DataType::Int(32))); @@ -57,7 +58,7 @@ IndexMap IndexMap::FromFunc(int ndim, ffi::TypedFunction(Array IndexMapInverseImpl(const IndexMap& self, - const Array& initial_ranges, + const ffi::Array& initial_ranges, arith::IterMapLevel check_level, arith::Analyzer* analyzer) { ICHECK(analyzer != nullptr); @@ -70,7 +71,7 @@ std::pair IndexMapInverseImpl(const IndexMap& self, } // Dummy variables to represent the inverse's inputs. - Array output_vars; + ffi::Array output_vars; for (size_t i = 0; i < self->final_indices.size(); i++) { PrimExpr index = self->final_indices[i]; // TODO(Lunderberg): Better names for these variables. A variable @@ -85,7 +86,7 @@ std::pair IndexMapInverseImpl(const IndexMap& self, } // Dummy ranges for the extent of each input. - Map input_iters; + ffi::Map input_iters; ICHECK_EQ(self->initial_indices.size(), initial_ranges.size()); for (size_t i = 0; i < initial_ranges.size(); i++) { input_iters.Set(self->initial_indices[i], initial_ranges[i]); @@ -101,11 +102,11 @@ std::pair IndexMapInverseImpl(const IndexMap& self, // Determine expressions for the input variables, in terms of the // output variables. - Map inverse_exprs_map = InverseAffineIterMap( - padded_iter_map->indices, Array(output_vars.begin(), output_vars.end())); + ffi::Map inverse_exprs_map = InverseAffineIterMap( + padded_iter_map->indices, ffi::Array(output_vars.begin(), output_vars.end())); // Unpack the map to an array, maintaining the same parameter order. 
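// ---------------------------------------------------------------------------
// [Reviewer sketch -- not part of the patch] Caller-side view of the inverse
// machinery being migrated in this hunk, using the new ffi:: signatures. The
// transpose map and the 4 x 8 domain are illustrative assumptions.
#include <tvm/arith/analyzer.h>
#include <tvm/tir/index_map.h>

namespace example {
using namespace tvm;
using namespace tvm::tir;

IndexMap InverseSketch() {
  // (i, j) -> (j, i); inverting it over a 4 x 8 domain recovers (i, j).
  IndexMap fwd = IndexMap::FromFunc(
      2, [](ffi::Array<Var> idx) -> ffi::Array<PrimExpr> {
        return {idx[1], idx[0]};
      });
  ffi::Array<Range> dom = {Range(0, 4), Range(0, 8)};
  arith::Analyzer analyzer;
  return fwd.Inverse(dom, &analyzer);
}
}  // namespace example
// ---------------------------------------------------------------------------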
- Array inverse_exprs; + ffi::Array inverse_exprs; for (int i = 0, n = self->initial_indices.size(); i < n; ++i) { Var index = self->initial_indices[i]; PrimExpr expr; @@ -137,13 +138,13 @@ std::pair IndexMapInverseImpl(const IndexMap& self, return {IndexMap(output_vars, inverse_exprs), padding_predicate}; } -std::pair IndexMap::NonSurjectiveInverse(Array initial_ranges, +std::pair IndexMap::NonSurjectiveInverse(ffi::Array initial_ranges, arith::Analyzer* analyzer) const { ICHECK(analyzer != nullptr); return IndexMapInverseImpl(*this, initial_ranges, arith::IterMapLevel::NoCheck, analyzer); } -IndexMap IndexMap::Inverse(Array initial_ranges, arith::Analyzer* analyzer) const { +IndexMap IndexMap::Inverse(ffi::Array initial_ranges, arith::Analyzer* analyzer) const { ICHECK(analyzer != nullptr); auto [inverse, padding_predicate] = IndexMapInverseImpl(*this, initial_ranges, arith::IterMapLevel::Bijective, analyzer); @@ -153,18 +154,18 @@ IndexMap IndexMap::Inverse(Array initial_ranges, arith::Analyzer* analyze return inverse; } -Array IndexMapNode::MapIndices(const Array& indices, - arith::Analyzer* analyzer) const { +ffi::Array IndexMapNode::MapIndices(const ffi::Array& indices, + arith::Analyzer* analyzer) const { ICHECK(analyzer != nullptr); ICHECK_EQ(indices.size(), initial_indices.size()); - Map vmap; + ffi::Map vmap; for (size_t i = 0; i < initial_indices.size(); i++) { vmap.Set(initial_indices[i], indices[i]); } - Array output = final_indices.Map([&](PrimExpr index) { + ffi::Array output = final_indices.Map([&](PrimExpr index) { PrimExpr result = SubstituteWithDataTypeLegalization( std::move(index), [&](const Var& var) { return vmap.Get(var); }); return analyzer->Simplify(result); @@ -172,24 +173,25 @@ Array IndexMapNode::MapIndices(const Array& indices, return output; } -Array IndexMapNode::MapRanges(const Array& ranges, arith::Analyzer* analyzer) const { +ffi::Array IndexMapNode::MapRanges(const ffi::Array& ranges, + arith::Analyzer* analyzer) const { ICHECK(analyzer != nullptr); ICHECK_EQ(ranges.size(), initial_indices.size()); - Map input_iters; + ffi::Map input_iters; for (size_t i = 0; i < initial_indices.size(); i++) { input_iters.Set(initial_indices[i], ranges[i]); } auto iter_map = DetectIterMap(final_indices, input_iters, /* predicate = */ 1, /*check_level=*/arith::IterMapLevel::NoCheck, analyzer, /*simplify_trivial_iterators=*/false); - Array output; + ffi::Array output; if (iter_map->indices.size()) { // Preferred route, requires the map to be expressible as an // affine sum. Since the terms are orthogonal, the extent of the // sum is the extent of the largest term. 
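// ---------------------------------------------------------------------------
// [Reviewer sketch -- not part of the patch] MapShape in caller form, as
// migrated above. The concrete map and shape are illustrative assumptions;
// the expected result follows the affine-sum reasoning in this hunk.
#include <tvm/arith/analyzer.h>
#include <tvm/tir/index_map.h>

namespace example {
using namespace tvm;
using namespace tvm::tir;

ffi::Array<PrimExpr> MapShapeSketch(const IndexMap& layout_map) {
  arith::Analyzer analyzer;
  // For a map such as (i, j) -> (j floordiv 8, i, j floormod 8), a (16, 64)
  // input shape maps to (8, 16, 8): each output extent is the extent of the
  // affine term that produces that output index.
  ffi::Array<PrimExpr> shape = {16, 64};
  return layout_map->MapShape(shape, &analyzer);
}
}  // namespace example
// ---------------------------------------------------------------------------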
for (const auto& index : iter_map->indices) { - Optional extent = std::nullopt; + ffi::Optional extent = std::nullopt; for (const auto& term : index->args) { PrimExpr term_extent = term->extent * term->scale; if (extent.defined()) { @@ -235,18 +237,18 @@ Array IndexMapNode::MapRanges(const Array& ranges, arith::Analyzer return output; } -Array IndexMapNode::MapShape(const Array& shape, - arith::Analyzer* analyzer) const { +ffi::Array IndexMapNode::MapShape(const ffi::Array& shape, + arith::Analyzer* analyzer) const { ICHECK(analyzer != nullptr); ICHECK_EQ(shape.size(), initial_indices.size()); - Array ranges; + ffi::Array ranges; for (auto& dim : shape) { ranges.push_back(Range(make_zero(dim.dtype()), dim)); } - Array mapped = MapRanges(std::move(ranges), analyzer); + ffi::Array mapped = MapRanges(std::move(ranges), analyzer); - Array output; + ffi::Array output; for (auto& range : mapped) { ICHECK(is_zero(range->min)); output.push_back(range->extent); @@ -262,7 +264,7 @@ runtime::Tensor IndexMapNode::MapTensor(runtime::Tensor arr_src) const { << "The rank of the input array should be " << initial_indices.size() << " but got " << shape.size(); size_t size_1d = 1; - Array orig_shape; + ffi::Array orig_shape; for (size_t i = 0; i < shape.size(); ++i) { size_1d *= shape[i]; orig_shape.push_back(PrimExpr(static_cast((shape[i])))); @@ -283,7 +285,7 @@ runtime::Tensor IndexMapNode::MapTensor(runtime::Tensor arr_src) const { for (size_t i = 0; i < size_1d; ++i) { // Convert a linear coordinate to an N-d coordinate tuple // z * height * width + y * width + x -> (z, y, x) - Array src_indices; + ffi::Array src_indices; auto div_factor = size_1d; auto src_linear_index = i; for (auto s : shape) { @@ -311,9 +313,9 @@ runtime::Tensor IndexMapNode::MapTensor(runtime::Tensor arr_src) const { } IndexMap IndexMap::RenameVariables( - const std::function(const Var& var)>& f_name_map) const { + const std::function(const Var& var)>& f_name_map) const { std::unordered_set used_names; - Map var_remap; + ffi::Map var_remap; NameSupply name_supply; const IndexMapNode* n = this->get(); if (f_name_map != nullptr) { @@ -329,8 +331,8 @@ IndexMap IndexMap::RenameVariables( } visited.emplace(obj.get()); Var var = Downcast(obj); - if (Optional opt_name = f_name_map(var); opt_name.has_value()) { - String name = opt_name.value(); + if (ffi::Optional opt_name = f_name_map(var); opt_name.has_value()) { + ffi::String name = opt_name.value(); ICHECK(!name_supply->ContainsName(name, /*add_prefix=*/false)); name_supply->ReserveName(name, /*add_prefix=*/false); var_remap.Set(var, Var(name, var->dtype)); @@ -344,7 +346,8 @@ IndexMap IndexMap::RenameVariables( // The name of the variable is pre-defined. 
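// ---------------------------------------------------------------------------
// [Reviewer sketch -- not part of the patch] The RenameVariables callback in
// this hunk, caller-side. The pinned name "row" and the "i0" hint are
// illustrative assumptions; indices the callback declines fall through to
// NameSupply for fresh names, as in the code above.
#include <tvm/tir/index_map.h>

namespace example {
using namespace tvm;
using namespace tvm::tir;

IndexMap RenameSketch(const IndexMap& map) {
  return map.RenameVariables([](const Var& v) -> ffi::Optional<ffi::String> {
    if (v->name_hint == "i0") return ffi::String("row");  // reserve this name
    return std::nullopt;  // let NameSupply freshen the rest
  });
}
}  // namespace example
// ---------------------------------------------------------------------------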
continue; } - String unique_name = name_supply->FreshName(initial_index->name_hint, /*add_prefix=*/false); + ffi::String unique_name = + name_supply->FreshName(initial_index->name_hint, /*add_prefix=*/false); if (unique_name != initial_index->name_hint) { var_remap.Set(initial_index, Var(unique_name)); } @@ -354,7 +357,7 @@ IndexMap IndexMap::RenameVariables( [&](const Var& var) { return Downcast(Substitute(var, var_remap)); }); auto new_final_indices = n->final_indices.Map([&](const PrimExpr& expr) { return Substitute(expr, var_remap); }); - Optional new_inverse_index_map = std::nullopt; + ffi::Optional new_inverse_index_map = std::nullopt; if (n->inverse_index_map.defined()) { new_inverse_index_map = Downcast(n->inverse_index_map).RenameVariables(f_name_map); } @@ -367,10 +370,10 @@ IndexMap IndexMap::RenameVariables( * \param final_indices The final indices in the index map. * \return The lambda expression string. */ -std::string IndexMap2PythonLambdaExpr(const Array& initial_indices, - const Array& final_indices) { +std::string IndexMap2PythonLambdaExpr(const ffi::Array& initial_indices, + const ffi::Array& final_indices) { std::unordered_set used_names; - Map var_remap; + ffi::Map var_remap; std::ostringstream oss; oss << "lambda "; for (size_t i = 0; i < initial_indices.size(); ++i) { @@ -391,13 +394,13 @@ std::string IndexMap2PythonLambdaExpr(const Array& initial_indices, return oss.str(); } -String IndexMapNode::ToPythonString( - const std::function(const Var& var)>& f_name_map) const { - auto index_map = GetRef(this).RenameVariables(f_name_map); +ffi::String IndexMapNode::ToPythonString( + const std::function(const Var& var)>& f_name_map) const { + auto index_map = ffi::GetRef(this).RenameVariables(f_name_map); std::string lambda_expr = IndexMap2PythonLambdaExpr(index_map->initial_indices, index_map->final_indices); if (!index_map->inverse_index_map.defined()) { - return String(lambda_expr); + return ffi::String(lambda_expr); } // Also convert the inverse index map. 
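// Sketch of the string assembled below (hypothetical map; the exact
// rendering of each index expression depends on the printer): for a map
// (i, j) -> (i // 4, j, i % 4) that carries a recorded inverse, the result
// looks like
//   tvm.tir.IndexMap.from_func(lambda i, j: (i // 4, j, i % 4),
//       inverse_index_map=lambda a, b, c: (a * 4 + c, b))
// whereas a map without a recorded inverse returns the bare lambda string,
// as the early return above shows.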
IndexMap inverse = Downcast(index_map->inverse_index_map.value()); @@ -406,14 +409,14 @@ String IndexMapNode::ToPythonString( std::ostringstream oss; oss << "tvm.tir.IndexMap.from_func(" << lambda_expr << ", inverse_index_map=" << inverse_lambda_expr << ")"; - return String(oss.str()); + return ffi::String(oss.str()); } IndexMap Substitute(const IndexMap& index_map, - std::function(const Var& var)> f_subst) { - Array new_output = + std::function(const Var& var)> f_subst) { + ffi::Array new_output = index_map->final_indices.Map([&](const PrimExpr& expr) { return Substitute(expr, f_subst); }); - Optional new_inverse_map = std::nullopt; + ffi::Optional new_inverse_map = std::nullopt; if (index_map->inverse_index_map.defined()) { new_inverse_map = Substitute(Downcast(index_map->inverse_index_map.value()), f_subst); } @@ -424,32 +427,33 @@ TVM_FFI_STATIC_INIT_BLOCK({ namespace refl = tvm::ffi::reflection; refl::GlobalDef() .def("tir.IndexMap", - [](Array initial_indices, Array final_indices, - Optional inverse_index_map) { + [](ffi::Array initial_indices, ffi::Array final_indices, + ffi::Optional inverse_index_map) { return IndexMap(initial_indices, final_indices, inverse_index_map); }) .def("tir.IndexMapMapIndices", - [](IndexMap map, Array indices) { + [](IndexMap map, ffi::Array indices) { arith::Analyzer analyzer; return map->MapIndices(indices, &analyzer); }) .def("tir.IndexMapMapShape", - [](IndexMap map, Array shape) { + [](IndexMap map, ffi::Array shape) { arith::Analyzer analyzer; return map->MapShape(shape, &analyzer); }) .def("tir.IndexMapInverse", - [](IndexMap map, Array initial_ranges) { + [](IndexMap map, ffi::Array initial_ranges) { arith::Analyzer analyzer; return map.Inverse(initial_ranges, &analyzer); }) .def("tir.IndexMapMapTensor", [](IndexMap map, runtime::Tensor arr) { return map->MapTensor(arr); }) - .def("tir.IndexMapNonSurjectiveInverse", [](IndexMap forward, Array initial_ranges) { - arith::Analyzer analyzer; - auto result = forward.NonSurjectiveInverse(initial_ranges, &analyzer); - return Array{result.first, result.second}; - }); + .def("tir.IndexMapNonSurjectiveInverse", + [](IndexMap forward, ffi::Array initial_ranges) { + arith::Analyzer analyzer; + auto result = forward.NonSurjectiveInverse(initial_ranges, &analyzer); + return ffi::Array{result.first, result.second}; + }); }); } // namespace tir diff --git a/src/tir/ir/py_functor.cc b/src/tir/ir/py_functor.cc index cf5e7e80a893..871452aeb946 100644 --- a/src/tir/ir/py_functor.cc +++ b/src/tir/ir/py_functor.cc @@ -392,7 +392,7 @@ class PyStmtExprVisitor : public ObjectRef { ffi::Function f_visit_int_imm, // ffi::Function f_visit_float_imm, // ffi::Function f_visit_string_imm) { - ObjectPtr n = make_object(); + ObjectPtr n = ffi::make_object(); n->f_visit_stmt = std::move(f_visit_stmt); n->f_visit_expr = std::move(f_visit_expr); // Set statement functions @@ -756,7 +756,7 @@ class PyStmtExprMutator : public ObjectRef { ffi::Function f_visit_int_imm, // ffi::Function f_visit_float_imm, // ffi::Function f_visit_string_imm) { - ObjectPtr n = make_object(); + ObjectPtr n = ffi::make_object(); n->f_visit_stmt = std::move(f_visit_stmt); n->f_visit_expr = std::move(f_visit_expr); // Statement functions diff --git a/src/tir/ir/script/script_complete.cc b/src/tir/ir/script/script_complete.cc index d18bda77fab6..e94a3bfd9b82 100644 --- a/src/tir/ir/script/script_complete.cc +++ b/src/tir/ir/script/script_complete.cc @@ -36,10 +36,11 @@ namespace tir { /*! 
\brief Generate surrounding loops automatically */ class ScriptCompleter : public StmtMutator { public: - explicit ScriptCompleter(Map* buffer_var_map) : buffer_var_map_(buffer_var_map) {} + explicit ScriptCompleter(ffi::Map* buffer_var_map) + : buffer_var_map_(buffer_var_map) {} private: - Map* buffer_var_map_; + ffi::Map* buffer_var_map_; Stmt VisitStmt_(const BlockRealizeNode* op) final { for (const PrimExpr& value : op->iter_values) { CHECK(value.dtype().is_int()) @@ -81,9 +82,9 @@ class ScriptCompleter : public StmtMutator { // ignore root block or blocks which already has reads/writes regions if (mask != 0) { auto access_region = GetBlockAccessRegion(block, *buffer_var_map_); - const Array& reads = access_region[0]; - const Array& writes = access_region[1]; - const Array& opaque = access_region[2]; + const ffi::Array& reads = access_region[0]; + const ffi::Array& writes = access_region[1]; + const ffi::Array& opaque = access_region[2]; CHECK(opaque.empty()) << "ValueError: Can not auto detect buffer access region from tir.Load, tir.Store or " "direct access by buffer data. Please annotation the access region manually"; @@ -114,8 +115,8 @@ class ScriptCompleter : public StmtMutator { bool is_root_block_ = true; }; -PrimFunc ScriptComplete(PrimFunc func, const Array& root_allocates) { - Map buffer_var_map; +PrimFunc ScriptComplete(PrimFunc func, const ffi::Array& root_allocates) { + ffi::Map buffer_var_map; for (const auto& pair : func->buffer_map) { const Buffer& buffer = pair.second; buffer_var_map.Set(buffer->data, buffer); diff --git a/src/tir/ir/script/script_complete.h b/src/tir/ir/script/script_complete.h index 273ca946a7ff..1facab664346 100644 --- a/src/tir/ir/script/script_complete.h +++ b/src/tir/ir/script/script_complete.h @@ -30,7 +30,7 @@ namespace tvm { namespace tir { -PrimFunc ScriptComplete(PrimFunc func, const Array& root_allocates); +PrimFunc ScriptComplete(PrimFunc func, const ffi::Array& root_allocates); } // namespace tir } // namespace tvm diff --git a/src/tir/ir/specialize.cc b/src/tir/ir/specialize.cc index 69a7c293b19f..7e92cc4e6983 100644 --- a/src/tir/ir/specialize.cc +++ b/src/tir/ir/specialize.cc @@ -54,7 +54,7 @@ inline bool IsParam(const PrimFunc& func, const Var& param) { PrimExpr a = VisitExpr(op->a); \ PrimExpr b = VisitExpr(op->b); \ if (a.same_as(op->a) && b.same_as(op->b)) { \ - return GetRef(op); \ + return ffi::GetRef(op); \ } else { \ return BinaryFunc(a, b); \ } \ @@ -63,7 +63,7 @@ inline bool IsParam(const PrimFunc& func, const Var& param) { PrimExpr VisitExpr_(const UnaryNode* op) final { \ PrimExpr a = VisitExpr(op->a); \ if (a.same_as(op->a)) { \ - return GetRef(op); \ + return ffi::GetRef(op); \ } else { \ return UnaryFunc(a); \ } \ @@ -77,7 +77,7 @@ class PrimFuncSpecializer : public StmtExprMutator { static PrimFunc Specialize(PrimFunc f, const VarMap& var_map) { PrimFuncSpecializer specializer(var_map); // Updating Buffer map - Map buffer_map; + ffi::Map buffer_map; bool buffer_map_updated = false; for (const auto& it : f->buffer_map) { const Var& var = it.first; @@ -91,7 +91,7 @@ class PrimFuncSpecializer : public StmtExprMutator { } // Updating parmeters - Array params; + ffi::Array params; bool param_updated = false; for (const auto& var : f->params) { // Remove parmeters which has been specialized. @@ -115,7 +115,7 @@ class PrimFuncSpecializer : public StmtExprMutator { private: Stmt VisitStmt_(const BlockNode* op) final { // Step.0. 
Define buffer mappings which is allocated inside the block - Array alloc_buffers = + ffi::Array alloc_buffers = op->alloc_buffers.Map([this](const auto& buf) { return MutateAllocBuffer(buf); }); // Step.1. Recursively visit block body @@ -123,14 +123,14 @@ class PrimFuncSpecializer : public StmtExprMutator { op = stmt.as(); ICHECK(op != nullptr); - Array reads = + ffi::Array reads = op->reads.Map([this](const auto& region) { return MutateBufferRegion(region); }); - Array writes = + ffi::Array writes = op->writes.Map([this](const auto& region) { return MutateBufferRegion(region); }); if (alloc_buffers.same_as(op->alloc_buffers) && reads.same_as(op->reads) && writes.same_as(op->writes)) { - return GetRef(op); + return ffi::GetRef(op); } else { ObjectPtr n = CopyOnWrite(op); n->alloc_buffers = std::move(alloc_buffers); @@ -184,7 +184,7 @@ class PrimFuncSpecializer : public StmtExprMutator { auto new_buf = GetNewBuffer(op->buffer); if (new_buf.same_as(op->buffer)) { - return GetRef(op); + return ffi::GetRef(op); } else { auto n = CopyOnWrite(op); n->buffer = new_buf; @@ -199,18 +199,18 @@ class PrimFuncSpecializer : public StmtExprMutator { auto new_buf = GetNewBuffer(op->buffer); if (new_buf.same_as(op->buffer)) { - return GetRef(op); + return ffi::GetRef(op); } else { - auto n = make_object(*op); + auto n = ffi::make_object(*op); n->buffer = new_buf; return PrimExpr(n); } } PrimExpr VisitExpr_(const VarNode* op) final { - auto it = var_map_.find(GetRef(op)); + auto it = var_map_.find(ffi::GetRef(op)); if (it == var_map_.end()) { - return GetRef(op); + return ffi::GetRef(op); } else { return it->second; } @@ -242,8 +242,9 @@ class PrimFuncSpecializer : public StmtExprMutator { // of Var-to-PrimExpr remapping. Var data = VisitExpr(buffer->data).as().value_or(buffer->data); - Array shape = buffer->shape.Map([this](const PrimExpr& e) { return VisitExpr(e); }); - Array strides = + ffi::Array shape = + buffer->shape.Map([this](const PrimExpr& e) { return VisitExpr(e); }); + ffi::Array strides = buffer->strides.Map([this](const PrimExpr& e) { return VisitExpr(e); }); PrimExpr elem_offset = VisitExpr(buffer->elem_offset); @@ -252,7 +253,7 @@ class PrimFuncSpecializer : public StmtExprMutator { buffer->shape.same_as(shape) && buffer->strides.same_as(strides)) { return buffer; } else { - auto n = make_object(*buffer.get()); + auto n = ffi::make_object(*buffer.get()); n->data = std::move(data); n->elem_offset = std::move(elem_offset); n->shape = std::move(shape); @@ -304,7 +305,7 @@ class PrimFuncSpecializer : public StmtExprMutator { BufferRegion MutateBufferRegion(const BufferRegion& buffer_region) { auto it = buffer_map_.find(buffer_region->buffer); const Buffer& buffer = it != buffer_map_.end() ? 
it->second : buffer_region->buffer; - Array region = buffer_region->region.Map( + ffi::Array region = buffer_region->region.Map( std::bind(&PrimFuncSpecializer::MutateRange, this, std::placeholders::_1)); if (it == buffer_map_.end() && region.same_as(buffer_region->region)) { return buffer_region; @@ -415,11 +416,11 @@ void UpdateSpecializeVarMap(const PrimFunc& func, const Var& param, const PrimEx /**************** Implementation ****************/ -PrimFunc Specialize(PrimFunc func, const Map>& param_map) { +PrimFunc Specialize(PrimFunc func, const ffi::Map>& param_map) { VarMap var_map; for (const auto& kv : param_map) { const Var& param = kv.first; - const Variant& instance = kv.second; + const ffi::Variant& instance = kv.second; if (auto opt_buffer = instance.as()) { UpdateSpecializeVarMap(func, param, opt_buffer.value(), &var_map); } else if (auto opt_expr = instance.as()) { diff --git a/src/tir/ir/stmt.cc b/src/tir/ir/stmt.cc index 305dd5ec9af6..0f50d5336af6 100644 --- a/src/tir/ir/stmt.cc +++ b/src/tir/ir/stmt.cc @@ -66,7 +66,7 @@ LetStmt::LetStmt(Var var, PrimExpr value, Stmt body, Span span) { ICHECK_EQ(value.dtype(), var.dtype()); } - ObjectPtr node = make_object(); + ObjectPtr node = ffi::make_object(); node->var = std::move(var); node->value = std::move(value); node->body = std::move(body); @@ -82,8 +82,8 @@ TVM_FFI_STATIC_INIT_BLOCK({ }); // AttrStmt -AttrStmt::AttrStmt(ffi::Any node, String attr_key, PrimExpr value, Stmt body, Span span) { - auto n = make_object(); +AttrStmt::AttrStmt(ffi::Any node, ffi::String attr_key, PrimExpr value, Stmt body, Span span) { + auto n = ffi::make_object(); n->node = node; n->attr_key = std::move(attr_key); n->value = std::move(value); @@ -95,7 +95,7 @@ AttrStmt::AttrStmt(ffi::Any node, String attr_key, PrimExpr value, Stmt body, Sp TVM_FFI_STATIC_INIT_BLOCK({ namespace refl = tvm::ffi::reflection; refl::GlobalDef().def("tir.AttrStmt", - [](Any node, String attr_key, PrimExpr value, Stmt body, Span span) { + [](Any node, ffi::String attr_key, PrimExpr value, Stmt body, Span span) { // when node is a POD data type like int or bool, first convert to // primexpr. 
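// FFI-boundary sketch (illustrative, not part of this patch; the attr key
// is hypothetical): a call such as tir.AttrStmt(0, "some_attr", value, body)
// from Python delivers the first argument as an ffi::Any holding a plain
// int, whose type index sits below kTVMFFISmallStr. The branch below
// rewraps such POD values as a PrimExpr (e.g. an IntImm), so
// AttrStmtNode::node always stores an object rather than a raw scalar.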
if (node.type_index() < ffi::TypeIndex::kTVMFFISmallStr) { @@ -114,7 +114,7 @@ AssertStmt::AssertStmt(PrimExpr condition, PrimExpr message, Stmt body, Span spa ICHECK(message.dtype() == DataType::Int(32) || message.as()) << "TypeError: AssertStmt message must be an int or string:" << message << "\n"; - ObjectPtr node = make_object(); + ObjectPtr node = ffi::make_object(); node->condition = std::move(condition); node->message = std::move(message); node->body = std::move(body); @@ -132,7 +132,7 @@ TVM_FFI_STATIC_INIT_BLOCK({ // For For::For(Var loop_var, PrimExpr min, PrimExpr extent, ForKind kind, Stmt body, - Optional thread_binding, Map annotations, Span span) { + ffi::Optional thread_binding, ffi::Map annotations, Span span) { ICHECK(loop_var.defined()); ICHECK(min.defined()); ICHECK(extent.defined()); @@ -168,7 +168,7 @@ For::For(Var loop_var, PrimExpr min, PrimExpr extent, ForKind kind, Stmt body, ICHECK(loop_var.dtype() == min.dtype()) << loop_var.dtype() << " vs " << min.dtype(); ICHECK(loop_var.dtype() == extent.dtype()) << loop_var.dtype() << " vs " << extent.dtype(); - ObjectPtr node = make_object(); + ObjectPtr node = ffi::make_object(); node->loop_var = std::move(loop_var); node->min = std::move(min); node->extent = std::move(extent); @@ -182,12 +182,13 @@ For::For(Var loop_var, PrimExpr min, PrimExpr extent, ForKind kind, Stmt body, TVM_FFI_STATIC_INIT_BLOCK({ namespace refl = tvm::ffi::reflection; - refl::GlobalDef().def("tir.For", [](Var loop_var, PrimExpr min, PrimExpr extent, int kind, - Stmt body, Optional thread_binding, - Optional> annotations, Span span) { - return For(loop_var, min, extent, static_cast(kind), body, thread_binding, - annotations.value_or(Map()), span); - }); + refl::GlobalDef().def( + "tir.For", [](Var loop_var, PrimExpr min, PrimExpr extent, int kind, Stmt body, + ffi::Optional thread_binding, + ffi::Optional> annotations, Span span) { + return For(loop_var, min, extent, static_cast(kind), body, thread_binding, + annotations.value_or(ffi::Map()), span); + }); }); std::ostream& operator<<(std::ostream& out, ForKind type) { // NOLINT(*) @@ -218,7 +219,7 @@ While::While(PrimExpr condition, Stmt body, Span span) { ICHECK(condition.as() == nullptr) << "The condition should not be trivial."; ICHECK(body.defined()); - ObjectPtr node = make_object(); + ObjectPtr node = ffi::make_object(); node->condition = std::move(condition); node->body = std::move(body); node->span = std::move(span); @@ -233,8 +234,8 @@ TVM_FFI_STATIC_INIT_BLOCK({ }); // Allocate -Allocate::Allocate(Var buffer_var, DataType dtype, Array extents, PrimExpr condition, - Stmt body, Map annotations, Span span) { +Allocate::Allocate(Var buffer_var, DataType dtype, ffi::Array extents, PrimExpr condition, + Stmt body, ffi::Map annotations, Span span) { CHECK(IsPointerType(buffer_var->type_annotation, dtype) || (dtype.is_bool() && IsPointerType(buffer_var->type_annotation, DataType::Int(8)))) << "The allocated data type (" << dtype @@ -250,7 +251,7 @@ Allocate::Allocate(Var buffer_var, DataType dtype, Array extents, Prim ICHECK(condition.defined()); ICHECK(condition.dtype().is_bool()); - ObjectPtr node = make_object(); + ObjectPtr node = ffi::make_object(); node->buffer_var = std::move(buffer_var); node->dtype = dtype; node->extents = std::move(extents); @@ -261,7 +262,7 @@ Allocate::Allocate(Var buffer_var, DataType dtype, Array extents, Prim data_ = std::move(node); } -int64_t AllocateNode::ConstantAllocationSize(const Array& extents) { +int64_t AllocateNode::ConstantAllocationSize(const 
ffi::Array& extents) { int64_t result = 1; for (size_t i = 0; i < extents.size(); ++i) { if (const IntImmNode* int_size = extents[i].as()) { @@ -279,8 +280,9 @@ int64_t AllocateNode::ConstantAllocationSize(const Array& extents) { TVM_FFI_STATIC_INIT_BLOCK({ namespace refl = tvm::ffi::reflection; refl::GlobalDef().def( - "tir.Allocate", [](Var buffer_var, DataType type, Array extents, PrimExpr condition, - Stmt body, Map annotations, Span span) { + "tir.Allocate", + [](Var buffer_var, DataType type, ffi::Array extents, PrimExpr condition, Stmt body, + ffi::Map annotations, Span span) { return Allocate(buffer_var, type, extents, condition, body, annotations, span); }); }); @@ -289,9 +291,9 @@ TVM_FFI_STATIC_INIT_BLOCK({ // The constructor to create a IRNode with constant data // depending on the type of ObjectRef, it will either // create AllocateConstNode with irmod_storage_idx or data -AllocateConst::AllocateConst(Var buffer_var, DataType dtype, Array extents, - ObjectRef data_or_idx, Stmt body, Map annotations, - Span span) { +AllocateConst::AllocateConst(Var buffer_var, DataType dtype, ffi::Array extents, + ObjectRef data_or_idx, Stmt body, + ffi::Map annotations, Span span) { ICHECK(IsPointerType(buffer_var->type_annotation, dtype)) << "The allocated data type (" << dtype << ") does not match the type annotation of the buffer " << buffer_var << " (" @@ -305,7 +307,7 @@ AllocateConst::AllocateConst(Var buffer_var, DataType dtype, Array ext ICHECK(body.defined()); ICHECK(data_or_idx.defined()); - ObjectPtr node = make_object(); + ObjectPtr node = ffi::make_object(); node->buffer_var = std::move(buffer_var); node->dtype = dtype; node->extents = std::move(extents); @@ -313,18 +315,18 @@ AllocateConst::AllocateConst(Var buffer_var, DataType dtype, Array ext node->annotations = annotations; node->span = std::move(span); if (data_or_idx->IsInstance()) { - node->data = Optional(Downcast(data_or_idx)); - node->irmod_storage_idx = Optional(); + node->data = ffi::Optional(Downcast(data_or_idx)); + node->irmod_storage_idx = ffi::Optional(); } else if (data_or_idx->IsInstance()) { - node->data = Optional(); - node->irmod_storage_idx = Optional(Downcast(data_or_idx)); + node->data = ffi::Optional(); + node->irmod_storage_idx = ffi::Optional(Downcast(data_or_idx)); } else { LOG(FATAL) << "Data type not supported: " << data_or_idx->GetTypeKey(); } data_ = std::move(node); } -int64_t AllocateConstNode::ConstantAllocationSize(const Array& extents) { +int64_t AllocateConstNode::ConstantAllocationSize(const ffi::Array& extents) { int64_t result = 1; for (size_t i = 0; i < extents.size(); ++i) { if (const IntImmNode* int_size = extents[i].as()) { @@ -342,8 +344,8 @@ TVM_FFI_STATIC_INIT_BLOCK({ namespace refl = tvm::ffi::reflection; refl::GlobalDef().def( "tir.AllocateConst", - [](Var buffer_var, DataType dtype, Array extents, ObjectRef data_or_idx, Stmt body, - Optional> annotations, Span span) { + [](Var buffer_var, DataType dtype, ffi::Array extents, ObjectRef data_or_idx, + Stmt body, ffi::Optional> annotations, Span span) { return AllocateConst(buffer_var, dtype, extents, data_or_idx, body, annotations.value_or({}), span); }); @@ -351,7 +353,7 @@ TVM_FFI_STATIC_INIT_BLOCK({ // DeclBuffer DeclBuffer::DeclBuffer(Buffer buffer, Stmt body, Span span) { - ObjectPtr node = make_object(); + ObjectPtr node = ffi::make_object(); node->buffer = std::move(buffer); node->body = std::move(body); node->span = std::move(span); @@ -366,7 +368,7 @@ TVM_FFI_STATIC_INIT_BLOCK({ }); // SeqStmt -SeqStmt::SeqStmt(Array seq, 
Span span) { +SeqStmt::SeqStmt(ffi::Array seq, Span span) { bool requires_flattening = std::any_of( seq.begin(), seq.end(), [](const Stmt& stmt) { return stmt->IsInstance(); }); @@ -386,7 +388,7 @@ SeqStmt::SeqStmt(Array seq, Span span) { << "Use the node " << seq[0] << "directly, " << "or for dynamic usage, normalize using SeqStmt::Flatten()"; - auto node = make_object(); + auto node = ffi::make_object(); node->seq = std::move(seq); node->span = std::move(span); data_ = std::move(node); @@ -394,16 +396,17 @@ SeqStmt::SeqStmt(Array seq, Span span) { TVM_FFI_STATIC_INIT_BLOCK({ namespace refl = tvm::ffi::reflection; - refl::GlobalDef().def("tir.SeqStmt", - [](Array seq, Span span) { return SeqStmt(std::move(seq), span); }); + refl::GlobalDef().def( + "tir.SeqStmt", [](ffi::Array seq, Span span) { return SeqStmt(std::move(seq), span); }); }); // IfThenElse -IfThenElse::IfThenElse(PrimExpr condition, Stmt then_case, Optional else_case, Span span) { +IfThenElse::IfThenElse(PrimExpr condition, Stmt then_case, ffi::Optional else_case, + Span span) { ICHECK(condition.defined()); ICHECK(then_case.defined()); // else_case may be null. - ObjectPtr node = make_object(); + ObjectPtr node = ffi::make_object(); node->condition = std::move(condition); node->then_case = std::move(then_case); node->else_case = std::move(else_case); @@ -423,7 +426,7 @@ TVM_FFI_STATIC_INIT_BLOCK({ Evaluate::Evaluate(PrimExpr value, Span span) { ICHECK(value.defined()); - ObjectPtr node = make_object(); + ObjectPtr node = ffi::make_object(); node->value = std::move(value); node->span = std::move(span); data_ = std::move(node); @@ -436,8 +439,8 @@ TVM_FFI_STATIC_INIT_BLOCK({ }); // BufferStore -BufferStore::BufferStore(Buffer buffer, PrimExpr value, Array indices, - Optional predicate, Span span) { +BufferStore::BufferStore(Buffer buffer, PrimExpr value, ffi::Array indices, + ffi::Optional predicate, Span span) { ICHECK_EQ(buffer->shape.size(), indices.size()) << "Buffer " << buffer->name << " is " << buffer->shape.size() << "-dimensional, cannot be indexed with the " << indices.size() @@ -502,7 +505,7 @@ BufferStore::BufferStore(Buffer buffer, PrimExpr value, Array indices, << "`, but RHS's dtype is `" << value.dtype() << "`"; } - ObjectPtr node = make_object(); + ObjectPtr node = ffi::make_object(); node->buffer = std::move(buffer); node->value = std::move(value); node->indices = std::move(indices); @@ -513,21 +516,22 @@ BufferStore::BufferStore(Buffer buffer, PrimExpr value, Array indices, TVM_FFI_STATIC_INIT_BLOCK({ namespace refl = tvm::ffi::reflection; - refl::GlobalDef().def( - "tir.BufferStore", - [](Buffer buffer, PrimExpr value, Array indices, Optional predicate, - Span span) { return BufferStore(buffer, value, indices, predicate, span); }); + refl::GlobalDef().def("tir.BufferStore", + [](Buffer buffer, PrimExpr value, ffi::Array indices, + ffi::Optional predicate, Span span) { + return BufferStore(buffer, value, indices, predicate, span); + }); }); // BufferRealize -BufferRealize::BufferRealize(Buffer buffer, Array bounds, PrimExpr condition, Stmt body, +BufferRealize::BufferRealize(Buffer buffer, ffi::Array bounds, PrimExpr condition, Stmt body, Span span) { - data_ = make_object(buffer, bounds, condition, body, span); + data_ = ffi::make_object(buffer, bounds, condition, body, span); } TVM_FFI_STATIC_INIT_BLOCK({ namespace refl = tvm::ffi::reflection; - refl::GlobalDef().def("tir.BufferRealize", [](Buffer buffer, Array bounds, + refl::GlobalDef().def("tir.BufferRealize", [](Buffer buffer, ffi::Array bounds, 
PrimExpr condition, Stmt body, Span span) { return BufferRealize(buffer, bounds, condition, body, span); }); @@ -536,7 +540,7 @@ TVM_FFI_STATIC_INIT_BLOCK({ // BufferRegion PrimExpr BufferRegionNode::ToPrimExpr() const { // Auto convert to PrimExpr if it is a single point load - Array indices; + ffi::Array indices; indices.reserve(this->region.size()); for (const Range& r : this->region) { if (tvm::tir::is_one(r->extent)) { @@ -544,32 +548,32 @@ PrimExpr BufferRegionNode::ToPrimExpr() const { } else if (r->extent.as()) { indices.push_back(tir::Ramp(r->min, tvm::tir::make_const(r->min->dtype, 1), r->extent)); } else { - LOG(FATAL) << "ValueError: Cannot convert to BufferLoad: " << GetRef(this); + LOG(FATAL) << "ValueError: Cannot convert to BufferLoad: " << ffi::GetRef(this); } } return tir::BufferLoad(this->buffer, indices); } -BufferRegion::BufferRegion(Buffer buffer, Array region) { +BufferRegion::BufferRegion(Buffer buffer, ffi::Array region) { CHECK_EQ(buffer->shape.size(), region.size()) << "The dimension between " << buffer << " and region " << region << " mismatched, the buffer is " << buffer; - ObjectPtr node = make_object(); + ObjectPtr node = ffi::make_object(); node->buffer = std::move(buffer); node->region = std::move(region); data_ = std::move(node); } BufferRegion BufferRegion::FullRegion(Buffer buffer) { - Array region; + ffi::Array region; for (PrimExpr extent : buffer->shape) { region.push_back(Range::FromMinExtent(0, extent)); } return BufferRegion(buffer, region); } -BufferRegion BufferRegion::FromPoint(Buffer buffer, Array indices) { - Array region; +BufferRegion BufferRegion::FromPoint(Buffer buffer, ffi::Array indices) { + ffi::Array region; for (const PrimExpr& index : indices) { if (const RampNode* ramp_index = index.as()) { region.push_back( @@ -583,7 +587,7 @@ BufferRegion BufferRegion::FromPoint(Buffer buffer, Array indices) { TVM_FFI_STATIC_INIT_BLOCK({ namespace refl = tvm::ffi::reflection; - refl::GlobalDef().def("tir.BufferRegion", [](Buffer buffer, Array region) { + refl::GlobalDef().def("tir.BufferRegion", [](Buffer buffer, ffi::Array region) { return BufferRegion(buffer, region); }); }); @@ -633,7 +637,7 @@ MatchBufferRegion::MatchBufferRegion(Buffer buffer, BufferRegion source) { // Note that we do not check elem_offset and strides in this function // Construction - ObjectPtr node = make_object(); + ObjectPtr node = ffi::make_object(); node->buffer = std::move(buffer); node->source = std::move(source); data_ = std::move(node); @@ -647,10 +651,12 @@ TVM_FFI_STATIC_INIT_BLOCK({ }); // Block -Block::Block(Array iter_vars, Array reads, Array writes, - String name_hint, Stmt body, Optional init, Array alloc_buffers, - Array match_buffers, Map annotations, Span span) { - ObjectPtr node = make_object(); +Block::Block(ffi::Array iter_vars, ffi::Array reads, + ffi::Array writes, ffi::String name_hint, Stmt body, + ffi::Optional init, ffi::Array alloc_buffers, + ffi::Array match_buffers, ffi::Map annotations, + Span span) { + ObjectPtr node = ffi::make_object(); node->iter_vars = std::move(iter_vars); node->reads = std::move(reads); node->writes = std::move(writes); @@ -666,22 +672,24 @@ Block::Block(Array iter_vars, Array reads, Array iter_vars, Array reads, Array writes, - String name_hint, Stmt body, Optional init, Array alloc_buffers, - Array match_buffers, Map annotations, Span span) { - return Block(iter_vars, reads, writes, name_hint, body, init, alloc_buffers, match_buffers, - annotations, span); - }); + refl::GlobalDef().def("tir.Block", + [](ffi::Array 
iter_vars, ffi::Array reads, + ffi::Array writes, ffi::String name_hint, Stmt body, + ffi::Optional init, ffi::Array alloc_buffers, + ffi::Array match_buffers, + ffi::Map annotations, Span span) { + return Block(iter_vars, reads, writes, name_hint, body, init, + alloc_buffers, match_buffers, annotations, span); + }); }); // BlockRealize -BlockRealize::BlockRealize(Array values, PrimExpr predicate, Block block, Span span) { +BlockRealize::BlockRealize(ffi::Array values, PrimExpr predicate, Block block, + Span span) { CHECK_EQ(block->iter_vars.size(), values.size()) << "ValueError: BlockRealize needs to have the same number of iter_vars and binding values"; CHECK(predicate.dtype().is_bool()) << "TypeError: Expect Block.predicate to be a bool expression"; - ObjectPtr node = make_object(); + ObjectPtr node = ffi::make_object(); node->iter_values = std::move(values); node->predicate = std::move(predicate); node->block = std::move(block); @@ -691,7 +699,7 @@ BlockRealize::BlockRealize(Array values, PrimExpr predicate, Block blo TVM_FFI_STATIC_INIT_BLOCK({ namespace refl = tvm::ffi::reflection; - refl::GlobalDef().def("tir.BlockRealize", [](Array iter_values, PrimExpr predicate, + refl::GlobalDef().def("tir.BlockRealize", [](ffi::Array iter_values, PrimExpr predicate, Block block, Span span) { return BlockRealize(iter_values, predicate, block, span); }); diff --git a/src/tir/ir/stmt_functor.cc b/src/tir/ir/stmt_functor.cc index e580f22f6b7f..0e2759f3c4a4 100644 --- a/src/tir/ir/stmt_functor.cc +++ b/src/tir/ir/stmt_functor.cc @@ -152,22 +152,22 @@ class StmtMutator::Internal { * \return The mutated array, a new copy can be created. */ template - static Array MutateArray(StmtMutator* self, const Array& arr, F fmutate) { + static ffi::Array MutateArray(StmtMutator* self, const ffi::Array& arr, F fmutate) { if (self->allow_copy_on_write_ && arr.unique()) { // if we allow copy on write, we can directly // call the inplace mutate function. 
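// The copy-on-write contract, restated as a standalone sketch (assumed
// ffi::Array API; not part of this patch):
//   ffi::Array<tir::Stmt> arr = ...;
//   auto fmutate = [](tir::Stmt s) { return s; };  // identity placeholder
//   if (arr.unique()) {
//     arr.MutateByApply(fmutate);   // sole owner: rewrite storage in place
//   } else {
//     arr = arr.Map(fmutate);       // shared: build a fresh array instead
//   }
// Note the swap of allow_copy_on_write_ around Map() below: it is turned
// off so that nested visits cannot in-place mutate children that are still
// reachable from the shared original array.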
- const_cast&>(arr).MutateByApply(fmutate); + const_cast&>(arr).MutateByApply(fmutate); return arr; } else { bool allow_cow = false; std::swap(allow_cow, self->allow_copy_on_write_); - Array copy = arr.Map(fmutate); + ffi::Array copy = arr.Map(fmutate); std::swap(allow_cow, self->allow_copy_on_write_); return copy; } } - static Array Mutate(StmtMutator* self, const Array& arr) { + static ffi::Array Mutate(StmtMutator* self, const ffi::Array& arr) { auto fmutate = [self](const IterVar& iter_var) { PrimExpr min = self->VisitExpr(iter_var->dom->min); PrimExpr extent = self->VisitExpr(iter_var->dom->extent); @@ -181,17 +181,17 @@ class StmtMutator::Internal { return MutateArray(self, arr, fmutate); } - static Array Mutate(StmtMutator* self, const Array& arr) { + static ffi::Array Mutate(StmtMutator* self, const ffi::Array& arr) { auto fmutate = [self](const PrimExpr& e) { return self->VisitExpr(e); }; return MutateArray(self, arr, fmutate); } - static Array Mutate(StmtMutator* self, const Array& arr) { + static ffi::Array Mutate(StmtMutator* self, const ffi::Array& arr) { auto fmutate = [self](const Stmt& s) { return self->VisitStmt(s); }; return MutateArray(self, arr, fmutate); } - static Array Mutate(StmtMutator* self, const Array& arr) { + static ffi::Array Mutate(StmtMutator* self, const ffi::Array& arr) { auto fmutate = [self](const Range& r) { PrimExpr min = self->VisitExpr(r->min); PrimExpr extent = self->VisitExpr(r->extent); @@ -204,9 +204,9 @@ class StmtMutator::Internal { return MutateArray(self, arr, fmutate); } - static Array Mutate(StmtMutator* self, const Array& arr) { + static ffi::Array Mutate(StmtMutator* self, const ffi::Array& arr) { auto fmutate = [self](const BufferRegion& buffer_region) { - Array region = Mutate(self, buffer_region->region); + ffi::Array region = Mutate(self, buffer_region->region); if (region.same_as(buffer_region->region)) { return buffer_region; } else { @@ -216,9 +216,10 @@ class StmtMutator::Internal { return MutateArray(self, arr, fmutate); } - static Array Mutate(StmtMutator* self, const Array& arr) { + static ffi::Array Mutate(StmtMutator* self, + const ffi::Array& arr) { auto fmutate = [self](const MatchBufferRegion& match_buffer_region) { - Array region = Mutate(self, match_buffer_region->source->region); + ffi::Array region = Mutate(self, match_buffer_region->source->region); if (region.same_as(match_buffer_region->source->region)) { return match_buffer_region; } else { @@ -234,7 +235,7 @@ Stmt StmtMutator::VisitStmt_(const AttrStmtNode* op) { PrimExpr value = this->VisitExpr(op->value); Stmt body = this->VisitStmt(op->body); if (value.same_as(op->value) && body.same_as(op->body)) { - return GetRef(op); + return ffi::GetRef(op); } else { auto n = CopyOnWrite(op); n->value = std::move(value); @@ -247,7 +248,7 @@ Stmt StmtMutator::VisitStmt_(const LetStmtNode* op) { PrimExpr value = this->VisitExpr(op->value); Stmt body = this->VisitStmt(op->body); if (value.same_as(op->value) && body.same_as(op->body)) { - return GetRef(op); + return ffi::GetRef(op); } else { auto n = CopyOnWrite(op); n->value = std::move(value); @@ -261,7 +262,7 @@ Stmt StmtMutator::VisitStmt_(const ForNode* op) { PrimExpr extent = this->VisitExpr(op->extent); Stmt body = this->VisitStmt(op->body); if (min.same_as(op->min) && extent.same_as(op->extent) && body.same_as(op->body)) { - return GetRef(op); + return ffi::GetRef(op); } else { auto n = CopyOnWrite(op); n->min = std::move(min); @@ -275,7 +276,7 @@ Stmt StmtMutator::VisitStmt_(const WhileNode* op) { PrimExpr condition = 
this->VisitExpr(op->condition); Stmt body = this->VisitStmt(op->body); if (condition.same_as(op->condition) && body.same_as(op->body)) { - return GetRef(op); + return ffi::GetRef(op); } else { auto n = CopyOnWrite(op); n->condition = std::move(condition); @@ -285,12 +286,12 @@ Stmt StmtMutator::VisitStmt_(const WhileNode* op) { } Stmt StmtMutator::VisitStmt_(const AllocateNode* op) { - Array extents = Internal::Mutate(this, op->extents); + ffi::Array extents = Internal::Mutate(this, op->extents); Stmt body = this->VisitStmt(op->body); PrimExpr condition = this->VisitExpr(op->condition); if (extents.same_as(op->extents) && body.same_as(op->body) && condition.same_as(op->condition)) { - return GetRef(op); + return ffi::GetRef(op); } else { auto n = CopyOnWrite(op); n->extents = std::move(extents); @@ -301,11 +302,11 @@ Stmt StmtMutator::VisitStmt_(const AllocateNode* op) { } Stmt StmtMutator::VisitStmt_(const AllocateConstNode* op) { - Array extents = Internal::Mutate(this, op->extents); + ffi::Array extents = Internal::Mutate(this, op->extents); Stmt body = this->VisitStmt(op->body); if (extents.same_as(op->extents) && body.same_as(op->body)) { - return GetRef(op); + return ffi::GetRef(op); } else { auto n = CopyOnWrite(op); n->extents = std::move(extents); @@ -318,7 +319,7 @@ Stmt StmtMutator::VisitStmt_(const DeclBufferNode* op) { Stmt body = this->VisitStmt(op->body); if (body.same_as(op->body)) { - return GetRef(op); + return ffi::GetRef(op); } else { auto n = CopyOnWrite(op); n->body = std::move(body); @@ -329,13 +330,13 @@ Stmt StmtMutator::VisitStmt_(const DeclBufferNode* op) { Stmt StmtMutator::VisitStmt_(const IfThenElseNode* op) { PrimExpr condition = this->VisitExpr(op->condition); Stmt then_case = this->VisitStmt(op->then_case); - Optional else_case = std::nullopt; + ffi::Optional else_case = std::nullopt; if (op->else_case) { else_case = this->VisitStmt(op->else_case.value()); } if (condition.same_as(op->condition) && then_case.same_as(op->then_case) && else_case.same_as(op->else_case)) { - return GetRef(op); + return ffi::GetRef(op); } else { auto n = CopyOnWrite(op); n->condition = std::move(condition); @@ -347,10 +348,10 @@ Stmt StmtMutator::VisitStmt_(const IfThenElseNode* op) { Stmt StmtMutator::VisitStmt_(const BufferStoreNode* op) { PrimExpr value = this->VisitExpr(op->value); - Array indices = Internal::Mutate(this, op->indices); + ffi::Array indices = Internal::Mutate(this, op->indices); if (value.same_as(op->value) && indices.same_as(op->indices)) { - return GetRef(op); + return ffi::GetRef(op); } else { auto n = CopyOnWrite(op); n->value = std::move(value); @@ -365,7 +366,7 @@ Stmt StmtMutator::VisitStmt_(const BufferRealizeNode* op) { Stmt body = this->VisitStmt(op->body); if (bounds.same_as(op->bounds) && condition.same_as(op->condition) && body.same_as(op->body)) { - return GetRef(op); + return ffi::GetRef(op); } else { auto n = CopyOnWrite(op); n->bounds = std::move(bounds); @@ -376,9 +377,9 @@ Stmt StmtMutator::VisitStmt_(const BufferRealizeNode* op) { } Stmt StmtMutator::VisitStmt_(const SeqStmtNode* op) { - Array seq = Internal::Mutate(this, op->seq); + ffi::Array seq = Internal::Mutate(this, op->seq); if (seq.same_as(op->seq)) { - return SeqStmt::Flatten(GetRef(op)); + return SeqStmt::Flatten(ffi::GetRef(op)); } else { auto node = CopyOnWrite(op); node->seq = std::move(seq); @@ -400,10 +401,10 @@ Stmt StmtMutator::VisitSeqStmt_(const SeqStmtNode* op, bool flatten_before_visit } // function to run the visit. 
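// Flattening, illustrated (hypothetical statements a, b, c; not part of
// this patch):
//   Stmt flat = SeqStmt::Flatten(a, SeqStmt({b, c}));
// collapses the nesting into the single sequence {a, b, c}. When
// flatten_before_visit is set, the Flattener below performs this pass
// first, so fmutate observes one flat sequence instead of nested SeqStmt
// nodes.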
auto frunvisit = [&](const SeqStmtNode* op) { - Array seq = fmutate != nullptr ? Internal::MutateArray(this, op->seq, fmutate) - : Internal::Mutate(this, op->seq); + ffi::Array seq = fmutate != nullptr ? Internal::MutateArray(this, op->seq, fmutate) + : Internal::Mutate(this, op->seq); if (seq.same_as(op->seq)) { - return GetRef(op); + return ffi::GetRef(op); } else { auto n = CopyOnWrite(op); n->seq = std::move(seq); @@ -411,7 +412,7 @@ Stmt StmtMutator::VisitSeqStmt_(const SeqStmtNode* op, bool flatten_before_visit } }; if (flatten_before_visit) { - Array seq; + ffi::Array seq; SeqStmt::Flattener flattener(&seq); flattener(0, op->seq); // NOTE: If copy on write is allowed @@ -435,7 +436,7 @@ Stmt StmtMutator::VisitStmt_(const AssertStmtNode* op) { Stmt body = this->VisitStmt(op->body); if (condition.same_as(op->condition) && message.same_as(op->message) && body.same_as(op->body)) { - return GetRef(op); + return ffi::GetRef(op); } else { auto n = CopyOnWrite(op); n->condition = std::move(condition); @@ -448,7 +449,7 @@ Stmt StmtMutator::VisitStmt_(const AssertStmtNode* op) { Stmt StmtMutator::VisitStmt_(const EvaluateNode* op) { PrimExpr value = this->VisitExpr(op->value); if (value.same_as(op->value)) { - return GetRef(op); + return ffi::GetRef(op); } else { auto n = CopyOnWrite(op); n->value = std::move(value); @@ -457,11 +458,11 @@ Stmt StmtMutator::VisitStmt_(const EvaluateNode* op) { } Stmt StmtMutator::VisitStmt_(const BlockNode* op) { - Array iter_vars = Internal::Mutate(this, op->iter_vars); - Array reads = Internal::Mutate(this, op->reads); - Array writes = Internal::Mutate(this, op->writes); - Array match_buffers = Internal::Mutate(this, op->match_buffers); - Optional init = std::nullopt; + ffi::Array iter_vars = Internal::Mutate(this, op->iter_vars); + ffi::Array reads = Internal::Mutate(this, op->reads); + ffi::Array writes = Internal::Mutate(this, op->writes); + ffi::Array match_buffers = Internal::Mutate(this, op->match_buffers); + ffi::Optional init = std::nullopt; if (op->init.defined()) { init = VisitStmt(op->init.value()); } @@ -469,7 +470,7 @@ Stmt StmtMutator::VisitStmt_(const BlockNode* op) { if (iter_vars.same_as(op->iter_vars) && reads.same_as(op->reads) && writes.same_as(op->writes) && body.same_as(op->body) && init.same_as(op->init) && match_buffers.same_as(op->match_buffers)) { - return GetRef(op); + return ffi::GetRef(op); } else { auto n = CopyOnWrite(op); n->iter_vars = std::move(iter_vars); @@ -483,11 +484,11 @@ Stmt StmtMutator::VisitStmt_(const BlockNode* op) { } Stmt StmtMutator::VisitStmt_(const BlockRealizeNode* op) { - Array v = Internal::Mutate(this, op->iter_values); + ffi::Array v = Internal::Mutate(this, op->iter_values); PrimExpr pred = this->VisitExpr(op->predicate); Stmt block = this->VisitStmt(op->block); if (v.same_as(op->iter_values) && pred.same_as(op->predicate) && block.same_as(op->block)) { - return GetRef(op); + return ffi::GetRef(op); } else { auto n = CopyOnWrite(op); n->iter_values = std::move(v); @@ -575,7 +576,7 @@ class IRTransformer final : public StmtExprMutator { }; Stmt IRTransform(Stmt ir_node, const ffi::Function& f_preorder, const ffi::Function& f_postorder, - Optional> only_enable) { + ffi::Optional> only_enable) { std::unordered_set only_type_index; if (only_enable.defined()) { for (auto s : only_enable.value()) { @@ -588,10 +589,10 @@ Stmt IRTransform(Stmt ir_node, const ffi::Function& f_preorder, const ffi::Funct class IRSubstitute : public StmtExprMutator { public: - explicit IRSubstitute(std::function(const Var&)> vmap) 
: vmap_(vmap) {} + explicit IRSubstitute(std::function(const Var&)> vmap) : vmap_(vmap) {} PrimExpr VisitExpr_(const VarNode* op) final { - Var var = GetRef(op); + Var var = ffi::GetRef(op); auto ret = vmap_(var); if (ret.defined()) { // Allow substitution of void variables with any expression. The TVM script parser @@ -679,7 +680,7 @@ class IRSubstitute : public StmtExprMutator { private: // Caller provided function that defines the variables to be remapped. - std::function(const Var&)> vmap_; + std::function(const Var&)> vmap_; /* \brief Generated map to track buffers being remapped. * @@ -691,11 +692,11 @@ class IRSubstitute : public StmtExprMutator { std::unordered_map buf_remap_; }; -Stmt Substitute(Stmt stmt, std::function(const Var&)> vmap) { +Stmt Substitute(Stmt stmt, std::function(const Var&)> vmap) { return IRSubstitute(vmap)(std::move(stmt)); } -PrimExpr Substitute(PrimExpr expr, std::function(const Var&)> vmap) { +PrimExpr Substitute(PrimExpr expr, std::function(const Var&)> vmap) { return IRSubstitute(vmap)(std::move(expr)); } @@ -743,14 +744,15 @@ void PreOrderVisit(const ObjectRef& stmt_or_expr, class IRSubstituteWithDataTypeLegalization : public DataTypeLegalizer { public: - explicit IRSubstituteWithDataTypeLegalization(std::function(const Var&)> vmap) + explicit IRSubstituteWithDataTypeLegalization( + std::function(const Var&)> vmap) : vmap_(vmap) {} using DataTypeLegalizer::VisitExpr_; using DataTypeLegalizer::VisitStmt_; PrimExpr VisitExpr_(const VarNode* op) final { - Var var = GetRef(op); + Var var = ffi::GetRef(op); auto ret = vmap_(var); if (ret.defined()) { return ret.value(); @@ -811,7 +813,7 @@ class IRSubstituteWithDataTypeLegalization : public DataTypeLegalizer { private: // Caller provided function that defines the variables to be remapped. - std::function(const Var&)> vmap_; + std::function(const Var&)> vmap_; /* \brief Generated map to track buffers being remapped. 
* @@ -824,12 +826,12 @@ class IRSubstituteWithDataTypeLegalization : public DataTypeLegalizer { }; Stmt SubstituteWithDataTypeLegalization(Stmt stmt, - std::function(const Var&)> vmap) { + std::function(const Var&)> vmap) { return IRSubstituteWithDataTypeLegalization(vmap)(std::move(stmt)); } -PrimExpr SubstituteWithDataTypeLegalization(PrimExpr expr, - std::function(const Var&)> vmap) { +PrimExpr SubstituteWithDataTypeLegalization( + PrimExpr expr, std::function(const Var&)> vmap) { return IRSubstituteWithDataTypeLegalization(vmap)(std::move(expr)); } @@ -845,7 +847,7 @@ TVM_FFI_STATIC_INIT_BLOCK({ [](ObjectRef node, ffi::Function f) { tir::PreOrderVisit(node, [f](const ObjectRef& n) { return f(n).cast(); }); }) - .def("tir.Substitute", [](ObjectRef node, Map vmap) -> ObjectRef { + .def("tir.Substitute", [](ObjectRef node, ffi::Map vmap) -> ObjectRef { if (node->IsInstance()) { return Substitute(Downcast(node), vmap); } else { diff --git a/src/tir/ir/tir_visitor_with_path.cc b/src/tir/ir/tir_visitor_with_path.cc index aa3ca1959c5d..638340e0bd2f 100644 --- a/src/tir/ir/tir_visitor_with_path.cc +++ b/src/tir/ir/tir_visitor_with_path.cc @@ -43,7 +43,7 @@ void TIRVisitorWithPath::Visit(const IRModule& mod, AccessPath path) { std::unordered_set externally_exposed; for (const auto& [gvar, func] : mod->functions) { gvars.push_back(gvar); - if (func->GetAttr(tvm::attr::kGlobalSymbol).has_value()) { + if (func->GetAttr(tvm::attr::kGlobalSymbol).has_value()) { externally_exposed.insert(gvar); } } @@ -193,7 +193,7 @@ void TIRVisitorWithPath::VisitStmt_(const AttrStmtNode* op, AccessPath path) { // `tir::Buffer buffer_view`, its `tir::Var` data pointer, and any // symbolic shapes used within `buffer_view that are not already // defined. - Array arr = Downcast>(op->node); + ffi::Array arr = Downcast>(op->node); ICHECK_EQ(arr.size(), 2U); Buffer buffer_view = Downcast(arr[0]); Buffer orig_buffer = Downcast(arr[1]); diff --git a/src/tir/ir/tir_visitor_with_path.h b/src/tir/ir/tir_visitor_with_path.h index 0ff9da33eb6d..65673d1f2b34 100644 --- a/src/tir/ir/tir_visitor_with_path.h +++ b/src/tir/ir/tir_visitor_with_path.h @@ -85,7 +85,7 @@ class TIRVisitorWithPath // Utility to visit an array of nodes template - inline void Visit(const Array& arr, ffi::reflection::AccessPath path) { + inline void Visit(const ffi::Array& arr, ffi::reflection::AccessPath path) { for (size_t i = 0; i < arr.size(); i++) { Visit(arr[i], path->ArrayItem(i)); } @@ -93,7 +93,7 @@ class TIRVisitorWithPath // Utility to visit an optional node nodes template - inline void Visit(const Optional& opt, ffi::reflection::AccessPath path) { + inline void Visit(const ffi::Optional& opt, ffi::reflection::AccessPath path) { if (opt) { Visit(opt.value(), path); } @@ -229,7 +229,7 @@ class TIRVisitorWithPath } }; auto try_visit_implicit_var_def_array = [&try_visit_implicit_var_def]( - const Array& arr, + const ffi::Array& arr, ffi::reflection::AccessPath path) { for (size_t i = 0; i < arr.size(); i++) { try_visit_implicit_var_def(arr[i], path->ArrayItem(i)); diff --git a/src/tir/ir/transform.cc b/src/tir/ir/transform.cc index aafe6277e24d..f52baa989728 100644 --- a/src/tir/ir/transform.cc +++ b/src/tir/ir/transform.cc @@ -43,7 +43,7 @@ TVM_REGISTER_PASS_CONFIG_OPTION("tir.enable_debug", Bool); TVM_REGISTER_PASS_CONFIG_OPTION("tir.enable_equiv_terms_in_cse_tir", Bool); TVM_REGISTER_PASS_CONFIG_OPTION("tir.disable_storage_rewrite", Bool); TVM_REGISTER_PASS_CONFIG_OPTION("tir.is_entry_func", Bool); 
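// Consumption-side sketch (assumed PassContext API; not part of this
// patch): a key registered here is read back inside a pass, e.g.
//   auto ctx = tvm::transform::PassContext::Current();
//   Bool debug = ctx->GetConfig<Bool>("tir.enable_debug", Bool(false)).value();
// The registration itself only records the key/value type pair so that
// PassContext can type-check user-supplied config entries against it.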
-TVM_REGISTER_PASS_CONFIG_OPTION("tir.add_lower_pass", Array>); +TVM_REGISTER_PASS_CONFIG_OPTION("tir.add_lower_pass", ffi::Array>); TVM_REGISTER_PASS_CONFIG_OPTION("tir.debug_keep_trivial_loop", Bool); TVM_REGISTER_PASS_CONFIG_OPTION("tir.use_async_copy", Bool); TVM_REGISTER_PASS_CONFIG_OPTION("tir.merge_static_smem", Bool); @@ -102,7 +102,7 @@ class PrimFuncPass : public Pass { PrimFuncPass::PrimFuncPass(std::function pass_func, PassInfo pass_info) { - auto n = make_object(); + auto n = ffi::make_object(); n->pass_func = std::move(pass_func); n->pass_info = std::move(pass_info); data_ = std::move(n); @@ -141,7 +141,8 @@ IRModule PrimFuncPassNode::operator()(IRModule mod, const PassContext& pass_ctx) } Pass CreatePrimFuncPass(std::function pass_func, - int opt_level, String name, tvm::Array required, bool traceable) { + int opt_level, ffi::String name, tvm::ffi::Array required, + bool traceable) { PassInfo pass_info = PassInfo(opt_level, name, required, traceable); return PrimFuncPass(std::move(pass_func), pass_info); } diff --git a/src/tir/op/builtin.cc b/src/tir/op/builtin.cc index 12c7c8d33c7f..fe095dbaa593 100644 --- a/src/tir/op/builtin.cc +++ b/src/tir/op/builtin.cc @@ -203,11 +203,11 @@ TIR_DEFINE_BUILTIN_FUNC(tvm_stack_make_array) // When num_inputs are not set, the function is assumed to be variable length. TIR_DEFINE_BUILTIN_FUNC(tvm_call_packed) .set_attr("TCallEffectKind", Integer(CallEffectKind::kOpaque)) - .set_attr("TScriptPrinterName", String("call_packed"), /*plevel=*/20); + .set_attr("TScriptPrinterName", ffi::String("call_packed"), /*plevel=*/20); TIR_DEFINE_BUILTIN_FUNC(tvm_call_cpacked) .set_attr("TCallEffectKind", Integer(CallEffectKind::kOpaque)) - .set_attr("TScriptPrinterName", String("call_cpacked"), /*plevel=*/20); + .set_attr("TScriptPrinterName", ffi::String("call_cpacked"), /*plevel=*/20); TIR_DEFINE_BUILTIN_FUNC(tvm_call_trace_packed) .set_attr("TCallEffectKind", Integer(CallEffectKind::kOpaque)); @@ -222,12 +222,12 @@ TIR_DEFINE_BUILTIN_FUNC(tvm_thread_invariant) TIR_DEFINE_BUILTIN_FUNC(tvm_call_packed_lowered) .set_attr("TCallEffectKind", Integer(CallEffectKind::kOpaque)) - .set_attr("TScriptPrinterName", String("call_packed_lowered"), + .set_attr("TScriptPrinterName", ffi::String("call_packed_lowered"), /*plevel=*/20); TIR_DEFINE_BUILTIN_FUNC(tvm_call_cpacked_lowered) .set_attr("TCallEffectKind", Integer(CallEffectKind::kOpaque)) - .set_attr("TScriptPrinterName", String("call_cpacked_lowered"), + .set_attr("TScriptPrinterName", ffi::String("call_cpacked_lowered"), /*plevel=*/20); TIR_DEFINE_BUILTIN_FUNC(tvm_call_trace_packed_lowered) diff --git a/src/tir/op/op.cc b/src/tir/op/op.cc index 9ced6f556cb0..ea6f91002182 100644 --- a/src/tir/op/op.cc +++ b/src/tir/op/op.cc @@ -923,7 +923,7 @@ PrimExpr isinf(PrimExpr x, Span span) { // isfinite PrimExpr isfinite(PrimExpr x, Span span) { return !isinf(x, span) && !isnan(x, span); } -PrimExpr sum(PrimExpr source, Array rdom, Array init, Span span) { +PrimExpr sum(PrimExpr source, ffi::Array rdom, ffi::Array init, Span span) { Var x("x", source.dtype(), span), y("y", source.dtype(), span); PrimExpr result = tir::Add(x, y, span); PrimExpr identity_element = make_zero(source.dtype(), span); @@ -931,7 +931,7 @@ PrimExpr sum(PrimExpr source, Array rdom, Array init, Span sp return tir::Reduce(combiner, {source}, rdom, make_const(DataType::Bool(1), true), 0, init, span); } -PrimExpr all(PrimExpr source, Array rdom, Array init, Span span) { +PrimExpr all(PrimExpr source, ffi::Array rdom, ffi::Array init, Span span) { 
type_check_boolean_args(source, "tvm::all"); Var x("x", source.dtype(), span), y("y", source.dtype()); PrimExpr result = tir::And(x, y, span); @@ -940,7 +940,7 @@ PrimExpr all(PrimExpr source, Array rdom, Array init, Span sp return tir::Reduce(combiner, {source}, rdom, make_const(DataType::Bool(1), true), 0, init, span); } -PrimExpr any(PrimExpr source, Array rdom, Array init, Span span) { +PrimExpr any(PrimExpr source, ffi::Array rdom, ffi::Array init, Span span) { type_check_boolean_args(source, "tvm::any"); Var x("x", source.dtype(), span), y("y", source.dtype(), span); PrimExpr result = tir::Or(x, y, span); @@ -949,7 +949,7 @@ PrimExpr any(PrimExpr source, Array rdom, Array init, Span sp return tir::Reduce(combiner, {source}, rdom, make_const(DataType::Bool(1), true), 0, init, span); } -PrimExpr max(PrimExpr source, Array rdom, Array init, Span span) { +PrimExpr max(PrimExpr source, ffi::Array rdom, ffi::Array init, Span span) { Var x("x", source.dtype(), span), y("y", source.dtype(), span); PrimExpr result = tir::Max(x, y, span); PrimExpr identity_element = min_value(source.dtype(), span); @@ -957,7 +957,7 @@ PrimExpr max(PrimExpr source, Array rdom, Array init, Span sp return tir::Reduce(combiner, {source}, rdom, make_const(DataType::Bool(1), true), 0, init, span); } -PrimExpr min(PrimExpr source, Array rdom, Array init, Span span) { +PrimExpr min(PrimExpr source, ffi::Array rdom, ffi::Array init, Span span) { Var x("x", source.dtype(), span), y("y", source.dtype(), span); PrimExpr result = tir::Min(x, y, span); PrimExpr identity_element = max_value(source.dtype(), span); @@ -965,7 +965,7 @@ PrimExpr min(PrimExpr source, Array rdom, Array init, Span sp return tir::Reduce(combiner, {source}, rdom, make_const(DataType::Bool(1), true), 0, init, span); } -PrimExpr prod(PrimExpr source, Array rdom, Array init, Span span) { +PrimExpr prod(PrimExpr source, ffi::Array rdom, ffi::Array init, Span span) { Var x("x", source.dtype(), span), y("y", source.dtype(), span); PrimExpr result = tir::Mul(x, y, span); PrimExpr identity_element = make_const(source.dtype(), 1, span); diff --git a/src/tir/schedule/analysis.h b/src/tir/schedule/analysis.h index 25d09ff931ea..8f3372b0ca17 100644 --- a/src/tir/schedule/analysis.h +++ b/src/tir/schedule/analysis.h @@ -230,7 +230,7 @@ bool IsWriteCache(const StmtSRef& block_sref); * \param analyzer The analyzer * \return A boolean flag indicating if the binding is affine */ -bool IsAffineBinding(const BlockRealize& realize, const Map& loop_var_ranges, +bool IsAffineBinding(const BlockRealize& realize, const ffi::Map& loop_var_ranges, arith::Analyzer* analyzer); /*! @@ -251,7 +251,7 @@ void CheckAffineBinding(const ScheduleState& self, Block block); * \throw ScheduleError If the input block does not have an affine binding */ void CheckPartialAffineBinding(const ScheduleState& self, Block block, - const Optional& high_exclusive); + const ffi::Optional& high_exclusive); /*! 
* \brief Extracts the ranges of loop variables in a path of the sref tree @@ -263,17 +263,17 @@ void CheckPartialAffineBinding(const ScheduleState& self, Block block, * - if the storage scope is shared, it will look for threadIdx.x/y/z * \return The loop domain */ -Map LoopDomainOfSRefTreePath(const StmtSRef& low_inclusive, - const Optional& high_exclusive = std::nullopt, - const runtime::StorageScope& extra_relax_scope = // - runtime::StorageScope{runtime::StorageRank::kGlobal, ""}); +ffi::Map LoopDomainOfSRefTreePath( + const StmtSRef& low_inclusive, const ffi::Optional& high_exclusive = std::nullopt, + const runtime::StorageScope& extra_relax_scope = // + runtime::StorageScope{runtime::StorageRank::kGlobal, ""}); /*! * \brief Returns the block var binding * \param realize The BlockRealize to be analyzed * \return The block var binding */ -Map GetBindings(const BlockRealize& realize); +ffi::Map GetBindings(const BlockRealize& realize); /*! * \brief Get the vars involved in the bindings of data parallel block vars and reduction block @@ -316,14 +316,15 @@ void CheckBlockHasTrivialBinding(const ScheduleState& self, const StmtSRef& bloc * \param parent_sref The StmtSRef that points to the parent block/loop * \return A list of StmtSRefs of leaf block */ -Array GetChildBlockSRefOnSRefTree(const ScheduleState& self, const StmtSRef& parent_sref); +ffi::Array GetChildBlockSRefOnSRefTree(const ScheduleState& self, + const StmtSRef& parent_sref); /*! * \brief Gets the BlockRealize of the leaf blocks of a scope where a specific block/loop is in * \param parent_sref The StmtSRef that points to the parent block/loop * \return A list of leaf BlockRealize */ -Array GetChildBlockRealizeOnSRefTree(const StmtSRef& parent_sref); +ffi::Array GetChildBlockRealizeOnSRefTree(const StmtSRef& parent_sref); /*! * \brief Get the BlockRealize of the single child block of the block or loop specified by @@ -357,7 +358,7 @@ IterVarType GetLoopIterType(const StmtSRef& loop_sref); * \return The lowest common ancestor of the input block srefs or loop srefs * \note The input array is required to have at least one sref */ -StmtSRef GetSRefLowestCommonAncestor(const Array& srefs); +StmtSRef GetSRefLowestCommonAncestor(const ffi::Array& srefs); /*! * \brief Checks if the given block has been applied by multi-level tiling. We check this by @@ -374,8 +375,8 @@ bool HasBeenMultiLevelTiled(const StmtSRef& block_sref); * \return All the feasible compute-at locations of the input block, given as an array of loop srefs * and an array of their indices among the outer loops of the input block */ -std::pair, std::vector> CollectComputeLocation(const ScheduleState& self, - const StmtSRef& block_sref); +std::pair, std::vector> CollectComputeLocation( + const ScheduleState& self, const StmtSRef& block_sref); /******** Producer-consumer relation ********/ @@ -385,7 +386,7 @@ std::pair, std::vector> CollectComputeLocation(const Schedu * \param scope The block scope where the given block is in * \return The producer blocks of the specified block */ -Array GetProducers(const StmtSRef& block_sref, const BlockScope& scope); +ffi::Array GetProducers(const StmtSRef& block_sref, const BlockScope& scope); /*! 
* \brief Get the consumer blocks to the given block under the given scope @@ -393,7 +394,7 @@ Array GetProducers(const StmtSRef& block_sref, const BlockScope& scope * \param scope The block scope where the given block is in * \return The consumer blocks of the specified block */ -Array GetConsumers(const StmtSRef& block_sref, const BlockScope& scope); +ffi::Array GetConsumers(const StmtSRef& block_sref, const BlockScope& scope); /*! * \brief Get the list of output blocks within the given scope @@ -403,7 +404,7 @@ Array GetConsumers(const StmtSRef& block_sref, const BlockScope& scope * \return A list of all blocks that write to some output buffer * block */ -Array GetOutputBlocks(const ScheduleState& self, const BlockNode* scope_block); +ffi::Array GetOutputBlocks(const ScheduleState& self, const BlockNode* scope_block); /*! * \brief A solution to split a ordered list of subtrees into two parts, @@ -431,8 +432,9 @@ struct ProducerConsumerSplit { * \throw ScheduleError is not valid split is found */ static ProducerConsumerSplit Find( - const ScheduleState& state, const Array& subtrees, - const Array& producer_block_srefs, const Array& consumer_block_srefs, + const ScheduleState& state, const ffi::Array& subtrees, + const ffi::Array& producer_block_srefs, + const ffi::Array& consumer_block_srefs, std::unordered_map* block2realize); }; @@ -469,8 +471,8 @@ BufferRegion GetNthAccessBufferRegion(const ScheduleState& self, const Block& bl * \return The defining site of the buffer and whether the buffer is allocated (otherwise the * buffer is from match_buffer). */ -std::pair, bool> GetBufferDefiningSite(const StmtSRef& block_sref, - const Buffer& buffer); +std::pair, bool> GetBufferDefiningSite(const StmtSRef& block_sref, + const Buffer& buffer); /******** Reduction Block Related ********/ @@ -481,8 +483,8 @@ std::pair, bool> GetBufferDefiningSite(const StmtSRef& block_ * \return The extracted init values and BufferStore updates * \throw ScheduleError If rfactor or cross-thread reduction cannot be applied to the block */ -std::pair, Array> GetInitValuesAndUpdatesFromReductionBlock( - const Optional& self, Block block); +std::pair, ffi::Array> GetInitValuesAndUpdatesFromReductionBlock( + const ffi::Optional& self, Block block); /*! * \brief Check whether the input array of IterVars only contains data-parallel and reduction block @@ -491,7 +493,7 @@ std::pair, Array> GetInitValuesAndUpdatesFromReduct * \return A boolean indicating whether the input array of IterVars only contains data-parallel and * reduction block iters */ -bool ContainsOnlyDataParAndReductionBlockIter(const Array& iters); +bool ContainsOnlyDataParAndReductionBlockIter(const ffi::Array& iters); /*! 
* \brief Check whether the block's reduction block iters are not used to index the block's output @@ -511,9 +513,9 @@ bool ReductionIterNotIndexOutputBuffer(const Block& block); * \return The corresponding CommReducer, combiner LHS values and combiner RHS values * \throw ScheduleError If no corresponding commutative reducer can be matched */ -std::tuple, Array> GetReducerAndCombinerLhsRhs( - const Optional& self, const Array& identities, - const Array& combiners); +std::tuple, ffi::Array> GetReducerAndCombinerLhsRhs( + const ffi::Optional& self, const ffi::Array& identities, + const ffi::Array& combiners); /******** Commutative Reducer ********/ @@ -522,7 +524,8 @@ std::tuple, Array> GetReducerAndCombinerL * \return The list of the registered reducer-getter functions * \sa ReducerRegistry */ -std::vector(Array)>> GetReducerGetters(); +std::vector(ffi::Array)>> +GetReducerGetters(); /*! * \brief Given the input identities and the combiner BufferStores of a reduction, extract the @@ -534,8 +537,9 @@ std::vector(Array)>> GetReduc * \param rhs The extracted RHS values of the reducer * \return A boolean indicating whether a corresponding commutative reducer is found */ -bool FromIdentityCombiner(const Array& identities, const Array& combiners, - CommReducer* result_reducer, Array* lhs, Array* rhs); +bool FromIdentityCombiner(const ffi::Array& identities, + const ffi::Array& combiners, CommReducer* result_reducer, + ffi::Array* lhs, ffi::Array* rhs); /******** Misc ********/ @@ -545,7 +549,7 @@ bool FromIdentityCombiner(const Array& identities, const Array SuggestIndexMap(const Buffer& buffer, const Array& indices, - const Array& loops, const PrimExpr& predicate, - arith::Analyzer* analyzer); +ffi::Optional SuggestIndexMap(const Buffer& buffer, const ffi::Array& indices, + const ffi::Array& loops, const PrimExpr& predicate, + arith::Analyzer* analyzer); /*! * \brief Checks if the given AST contains the specific operators @@ -605,7 +609,7 @@ Optional SuggestIndexMap(const Buffer& buffer, const Array& * \param ops The list of operators to be checked * \return A boolean indicating whether the AST contains the specific operators */ -bool HasOp(const Stmt& stmt, const Array& ops); +bool HasOp(const Stmt& stmt, const ffi::Array& ops); /*! * \brief Checks if the given AST statement contains if-then-else, including @@ -697,10 +701,11 @@ bool NeedsRFactorOrCrossThreadReduction(const tir::ScheduleState& self, // * \param dom_high_exclusive The highest node in the sref tree path * \return An n-dimensional integer set */ -Array AnalyzeRegionUpperBound(const BufferRegion& region, const PrimExpr& predicate, - const StmtSRef& dom_low_inclusive, - const StmtSRef& dom_high_exclusive, - arith::Analyzer* analyzer); +ffi::Array AnalyzeRegionUpperBound(const BufferRegion& region, + const PrimExpr& predicate, + const StmtSRef& dom_low_inclusive, + const StmtSRef& dom_high_exclusive, + arith::Analyzer* analyzer); /*! 
* \brief Analyze the buffer region under the sref tree path [dom_low_inclusive, dom_high_exclusive) @@ -712,10 +717,11 @@ Array AnalyzeRegionUpperBound(const BufferRegion& region, const P * \param analyzer The analyzer * \return An n-dimensional integer set */ -Array AnalyzeRegionLowerBound(const BufferRegion& region, const PrimExpr& predicate, - const StmtSRef& dom_low_inclusive, - const StmtSRef& dom_high_exclusive, - arith::Analyzer* analyzer); +ffi::Array AnalyzeRegionLowerBound(const BufferRegion& region, + const PrimExpr& predicate, + const StmtSRef& dom_low_inclusive, + const StmtSRef& dom_high_exclusive, + arith::Analyzer* analyzer); /*! * \brief Simplify non-trivial expressions @@ -733,13 +739,13 @@ PrimExpr SimplifyNonTrivialExpr(const PrimExpr& expr, arith::Analyzer* analyzer) class TensorizeInfoNode : public Object { public: /*! \brief Maps loops in a target block to the ones in an intrinsic description */ - Map loop_map; + ffi::Map loop_map; /*! \brief Maps loops in an intrinsic description to its index, outer to inner */ - Map desc_loop_indexer; + ffi::Map desc_loop_indexer; /*! \brief Optional padded extents of the block iters when padding is needed to match the * intrinsic description */ - Optional> block_iter_paddings; + ffi::Optional> block_iter_paddings; static void RegisterReflection() { namespace refl = tvm::ffi::reflection; @@ -766,26 +772,27 @@ class TensorizeInfo : public ObjectRef { * \param allow_padding Whether to allow padding the block iters to match the intrinsic description * \return TensorizeInfo structure if a valid mapping is found, std::nullopt otherwise */ -Optional GetTensorizeLoopMapping(const tir::ScheduleState& self, - const tir::StmtSRef& block_sref, - const tir::PrimFunc& desc_func, bool allow_padding); +ffi::Optional GetTensorizeLoopMapping(const tir::ScheduleState& self, + const tir::StmtSRef& block_sref, + const tir::PrimFunc& desc_func, + bool allow_padding); /*!\brief Necessary information used to perform transformations for tensorization */ class AutoTensorizeMappingInfoNode : public Object { public: /*! \brief Possible mappings to apply to block iters */ - Array mappings; + ffi::Array mappings; /* Additional information from AutoTensorizeComparator */ /*! \brief Mapping from LHS buffer to RHS buffer */ - Map lhs_buffer_map; + ffi::Map lhs_buffer_map; /*! \brief Buffer indices on RHS */ - Map> rhs_buffer_indices; + ffi::Map> rhs_buffer_indices; /*! \brief Block iters on LHS */ - Array lhs_iters; + ffi::Array lhs_iters; /*! \brief Block iters on RHS */ - Array rhs_iters; + ffi::Array rhs_iters; static void RegisterReflection() { namespace refl = tvm::ffi::reflection; @@ -818,9 +825,9 @@ class AutoTensorizeMappingInfo : public ObjectRef { * tensorized. We will need to apply the suggested layout transformations and then match against the * tensor intrinsics. */ -Optional GetAutoTensorizeMappingInfo(const ScheduleState& self, - const StmtSRef& block_sref, - const PrimFunc& desc_func); +ffi::Optional GetAutoTensorizeMappingInfo(const ScheduleState& self, + const StmtSRef& block_sref, + const PrimFunc& desc_func); /*! 
* \brief Perform basic checks for auto tensorization applicability, such as the structure of diff --git a/src/tir/schedule/analysis/analysis.cc b/src/tir/schedule/analysis/analysis.cc index 91c63c3469bb..9607f02f1048 100644 --- a/src/tir/schedule/analysis/analysis.cc +++ b/src/tir/schedule/analysis/analysis.cc @@ -49,7 +49,7 @@ const PrimFuncNode* GetRootPrimFunc(const IRModule& mod, const StmtNode* root_bl } LOG(FATAL) << "IndexError: Could not get the corresponding function in the schedule state of the " "statement:\n" - << GetRef(root_block); + << ffi::GetRef(root_block); throw; } @@ -61,13 +61,13 @@ StmtSRef GetScopeRoot(const ScheduleState& self, const StmtSRef& sref, public: explicit RootBlockError(IRModule mod) : mod_(mod) {} IRModule mod() const final { return mod_; } - String FastErrorString() const final { + ffi::String FastErrorString() const final { return "ScheduleError: The primitive does not operate on the root block"; } - String DetailRenderTemplate() const final { + ffi::String DetailRenderTemplate() const final { return "The primitive does not operate on the root block"; } - Array LocationsOfInterest() const final { return {}; } + ffi::Array LocationsOfInterest() const final { return {}; } IRModule mod_; }; @@ -75,10 +75,10 @@ StmtSRef GetScopeRoot(const ScheduleState& self, const StmtSRef& sref, public: explicit NotStagePipelineError(IRModule mod, Block block) : mod_(mod), block_(block) {} IRModule mod() const final { return mod_; } - String FastErrorString() const final { + ffi::String FastErrorString() const final { return "ScheduleError: The scope root is not a stage pipeline"; } - String DetailRenderTemplate() const final { + ffi::String DetailRenderTemplate() const final { return R"(The scope {0} is not a stage pipeline. Definition of a scope that is a stage pipeline: - The region cover property holds for each of its child blocks @@ -87,7 +87,7 @@ Definition of a scope that is a stage pipeline: - All the statements in the scope are schedulable statements, i.e.
Block and For )"; } - Array LocationsOfInterest() const final { return {block_}; } + ffi::Array LocationsOfInterest() const final { return {block_}; } IRModule mod_; Block block_; }; @@ -100,8 +100,8 @@ Definition of a scope that is a stage pipeline: const StmtSRefNode* subtree = sref.get(); for (; p != nullptr; subtree = p, p = p->parent) { if (p->stmt->IsInstance()) { - scope_root_sref = GetRef(p); - scope_root_subtree = GetRef(subtree); + scope_root_sref = ffi::GetRef(p); + scope_root_subtree = ffi::GetRef(subtree); break; } } @@ -114,7 +114,7 @@ Definition of a scope that is a stage pipeline: bool stage_pipeline = self->GetBlockInfo(scope_root_sref).stage_pipeline; if (stage_pipeline == false) { const BlockNode* block = TVM_SREF_TO_BLOCK(scope_root_sref); - throw NotStagePipelineError(self->mod, GetRef(block)); + throw NotStagePipelineError(self->mod, ffi::GetRef(block)); } } return scope_root_sref; @@ -123,9 +123,9 @@ Definition of a scope that is a stage pipeline: ScopeBlockLoopInfo GetScopeBlockLoopInfo(const Block& scope_block) { struct Collector : public StmtVisitor { void VisitStmt_(const BlockRealizeNode* realize) final { - result.realizes.push_back(GetRef(realize)); - const Array& iter_vars = realize->block->iter_vars; - const Array& iter_values = realize->iter_values; + result.realizes.push_back(ffi::GetRef(realize)); + const ffi::Array& iter_vars = realize->block->iter_vars; + const ffi::Array& iter_values = realize->iter_values; ICHECK_EQ(iter_vars.size(), iter_values.size()); int n = realize->iter_values.size(); for (int i = 0; i < n; ++i) { @@ -175,7 +175,7 @@ void CheckSRefHigherOrEqual(const StmtSRef& sref_a, const StmtSRef& sref_b) { */ bool IsDominantBlock(const ScheduleState& self, const StmtSRef& scope_root_sref, const StmtSRef& block_sref) { - std::unordered_map, ObjectPtrHash, ObjectPtrEqual> buffer_writers; + std::unordered_map, ObjectPtrHash, ObjectPtrEqual> buffer_writers; CheckSRefHigherOrEqual(scope_root_sref, block_sref); const BlockNode* maybe_root_block = scope_root_sref->StmtAs(); if (maybe_root_block) { @@ -183,7 +183,7 @@ bool IsDominantBlock(const ScheduleState& self, const StmtSRef& scope_root_sref, buffer_writers = scope->buffer_writers; } else { // Collect all child blocks of root sub-tree, and merge their buffer writers. 
- Array child_block_srefs = GetChildBlockSRefOnSRefTree(self, scope_root_sref); + ffi::Array child_block_srefs = GetChildBlockSRefOnSRefTree(self, scope_root_sref); for (const StmtSRef& child_block_sref : child_block_srefs) { BlockScope child_scope = self->GetBlockScope(child_block_sref); for (const auto& it : child_scope->buffer_writers) { @@ -275,15 +275,15 @@ void CheckCompleteBlock(const ScheduleState& self, const StmtSRef& block_sref, public: explicit IncompleteBlockError(IRModule mod, Block block, int violated_cond) : mod_(std::move(mod)), block_(std::move(block)), violated_cond_(violated_cond) {} - String FastErrorString() const final { return "ScheduleError: Incomplete block"; } - String DetailRenderTemplate() const final { + ffi::String FastErrorString() const final { return "ScheduleError: Incomplete block"; } + ffi::String DetailRenderTemplate() const final { std::ostringstream os; os << "The block {0} is not a complete block - it violates condition #" << violated_cond_; os << ".\n" << kCompleteBlockDefinition; return os.str(); } IRModule mod() const final { return mod_; } - Array LocationsOfInterest() const final { return {block_}; } + ffi::Array LocationsOfInterest() const final { return {block_}; } IRModule mod_; Block block_; int violated_cond_; @@ -292,7 +292,7 @@ void CheckCompleteBlock(const ScheduleState& self, const StmtSRef& block_sref, int error_code = CheckCompleteBlockErrorCode(self, block_sref, scope_root_sref); if (error_code != 0) { const BlockNode* block = TVM_SREF_TO_BLOCK(block_sref); - throw IncompleteBlockError(self->mod, GetRef(block), error_code); + throw IncompleteBlockError(self->mod, ffi::GetRef(block), error_code); } } @@ -327,7 +327,7 @@ int CheckReductionBlockErrorCode(const ScheduleState& self, const StmtSRef& bloc return 4; } // Cond 5. The reduction block vars are not used to index the output buffers. - return ReductionIterNotIndexOutputBuffer(GetRef(block)) ? 0 : 5; + return ReductionIterNotIndexOutputBuffer(ffi::GetRef(block)) ? 
0 : 5; } bool IsReductionBlock(const ScheduleState& self, const StmtSRef& block_sref, @@ -349,15 +349,15 @@ void CheckReductionBlock(const ScheduleState& self, const StmtSRef& block_sref, public: explicit NotReductionBlockError(IRModule mod, Block block, int violated_cond) : mod_(std::move(mod)), block_(std::move(block)), violated_cond_(violated_cond) {} - String FastErrorString() const final { return "ScheduleError: Not a reduction block"; } - String DetailRenderTemplate() const final { + ffi::String FastErrorString() const final { return "ScheduleError: Not a reduction block"; } + ffi::String DetailRenderTemplate() const final { std::ostringstream os; os << "The block {0} is not a reduction block - it violates condition #" << violated_cond_; os << ".\n" << kReductionBlockDefinition; return os.str(); } IRModule mod() const final { return mod_; } - Array LocationsOfInterest() const final { return {block_}; } + ffi::Array LocationsOfInterest() const final { return {block_}; } IRModule mod_; Block block_; int violated_cond_; @@ -366,7 +366,7 @@ void CheckReductionBlock(const ScheduleState& self, const StmtSRef& block_sref, int error_code = CheckReductionBlockErrorCode(self, block_sref, scope_root_sref); if (error_code != 0) { const BlockNode* block = TVM_SREF_TO_BLOCK(block_sref); - throw NotReductionBlockError(self->mod, GetRef(block), error_code); + throw NotReductionBlockError(self->mod, ffi::GetRef(block), error_code); } } @@ -382,10 +382,10 @@ void CheckCompleteOrReductionBlock(const ScheduleState& self, const StmtSRef& bl complete_block_error_code_(complete_block_error_code), reduction_block_error_code_(reduction_block_error_code) {} - String FastErrorString() const final { + ffi::String FastErrorString() const final { return "ScheduleError: Not a complete or reduction block"; } - String DetailRenderTemplate() const final { + ffi::String DetailRenderTemplate() const final { std::ostringstream os; os << "The block {0} is not a complete block - it violates condition #" << complete_block_error_code_; @@ -396,7 +396,7 @@ void CheckCompleteOrReductionBlock(const ScheduleState& self, const StmtSRef& bl return os.str(); } IRModule mod() const final { return mod_; } - Array LocationsOfInterest() const final { return {block_}; } + ffi::Array LocationsOfInterest() const final { return {block_}; } IRModule mod_; Block block_; @@ -413,8 +413,8 @@ void CheckCompleteOrReductionBlock(const ScheduleState& self, const StmtSRef& bl return; } const BlockNode* block = TVM_SREF_TO_BLOCK(block_sref); - throw NotCompleteOrReductionBlockError(self->mod, GetRef(block), complete_block_error_code, - reduction_block_error_code); + throw NotCompleteOrReductionBlockError(self->mod, ffi::GetRef(block), + complete_block_error_code, reduction_block_error_code); } void CheckSubtreeCompactDataflow(const ScheduleState& self, const StmtSRef& subtree_root) { @@ -429,12 +429,12 @@ void CheckSubtreeCompactDataflow(const ScheduleState& self, const StmtSRef& subt local_reduction_block_code_(local_reduction_block_code) { ICHECK(subtree_root_->IsInstance() || subtree_root_->IsInstance()); } - String FastErrorString() const final { + ffi::String FastErrorString() const final { return "ScheduleError: The queried subtree root in SRef tree does not have compact dataflow, " "because some of its child block on SRef tree is neither a local complete block nor a " "local reduction block."; } - String DetailRenderTemplate() const final { + ffi::String DetailRenderTemplate() const final { std::ostringstream os; os << "The queried subtree 
root {0} in SRef tree does not have compact dataflow, because " "its child block {1} on SRef tree is neither a local complete block nor a local " @@ -448,7 +448,9 @@ void CheckSubtreeCompactDataflow(const ScheduleState& self, const StmtSRef& subt return os.str(); } IRModule mod() const final { return mod_; } - Array LocationsOfInterest() const final { return {subtree_root_, violate_block_}; } + ffi::Array LocationsOfInterest() const final { + return {subtree_root_, violate_block_}; + } IRModule mod_; Stmt subtree_root_; @@ -457,14 +459,14 @@ void CheckSubtreeCompactDataflow(const ScheduleState& self, const StmtSRef& subt int local_reduction_block_code_; }; - Array child_block_srefs = GetChildBlockSRefOnSRefTree(self, subtree_root); + ffi::Array child_block_srefs = GetChildBlockSRefOnSRefTree(self, subtree_root); for (const StmtSRef& block_sref : child_block_srefs) { int local_complete_block_code = CheckCompleteBlockErrorCode(self, block_sref, subtree_root), local_reduction_block_code = CheckReductionBlockErrorCode(self, block_sref, subtree_root); if (local_complete_block_code != 0 && local_reduction_block_code != 0) { const BlockNode* block = TVM_SREF_TO_BLOCK(block_sref); - throw NotCompactDataFlowError(self->mod, GetRef(subtree_root->stmt), - GetRef(block), local_complete_block_code, + throw NotCompactDataFlowError(self->mod, ffi::GetRef(subtree_root->stmt), + ffi::GetRef(block), local_complete_block_code, local_reduction_block_code); } } @@ -492,19 +494,19 @@ void CheckNotOutputBlock(const ScheduleState& self, const StmtSRef& block_sref, class OutputBlockError : public ScheduleError { public: explicit OutputBlockError(IRModule mod, Block block) : mod_(mod), block_(block) {} - String FastErrorString() const final { + ffi::String FastErrorString() const final { return "ScheduleError: Cannot operate on an output block"; } - String DetailRenderTemplate() const final { return "The block {0} is an output block"; } + ffi::String DetailRenderTemplate() const final { return "The block {0} is an output block"; } IRModule mod() const final { return mod_; } - Array LocationsOfInterest() const final { return {block_}; } + ffi::Array LocationsOfInterest() const final { return {block_}; } IRModule mod_; Block block_; }; if (IsOutputBlock(self, block_sref, scope_root_sref)) { const BlockNode* block = TVM_SREF_TO_BLOCK(block_sref); - throw OutputBlockError(self->mod, GetRef(block)); + throw OutputBlockError(self->mod, ffi::GetRef(block)); } } @@ -545,7 +547,7 @@ bool IsWriteCache(const StmtSRef& block_sref) { /******** Binding ********/ -bool IsAffineBinding(const BlockRealize& realize, const Map& loop_var_ranges, +bool IsAffineBinding(const BlockRealize& realize, const ffi::Map& loop_var_ranges, arith::Analyzer* analyzer) { if (loop_var_ranges.empty()) { return true; @@ -561,7 +563,7 @@ bool IsAffineBinding(const BlockRealize& realize, const Map& loop_va return false; } for (const arith::IterSumExpr& sum_expr : res->indices) { - const Array& args = sum_expr->args; + const ffi::Array& args = sum_expr->args; if (!args.empty() && !is_one(args[0]->scale)) { return false; } @@ -570,16 +572,17 @@ bool IsAffineBinding(const BlockRealize& realize, const Map& loop_va } void CheckPartialAffineBinding(const ScheduleState& self, Block block, - const Optional& high_exclusive) { + const ffi::Optional& high_exclusive) { class NotAffineBindingError : public ScheduleError { public: - explicit NotAffineBindingError(IRModule mod, Block block, Optional high_exclusive) + explicit NotAffineBindingError(IRModule mod, Block 
block, + ffi::Optional high_exclusive) : mod_(std::move(mod)), block_(std::move(block)) { if (high_exclusive.defined()) { high_exclusive_loop_ = high_exclusive.value()->StmtAs(); } } - String FastErrorString() const final { + ffi::String FastErrorString() const final { std::ostringstream ss; if (high_exclusive_loop_) { ss << "ScheduleError: The block is required to have a partial affine binding under " @@ -589,7 +592,7 @@ void CheckPartialAffineBinding(const ScheduleState& self, Block block, } return ss.str(); } - String DetailRenderTemplate() const final { + ffi::String DetailRenderTemplate() const final { std::ostringstream ss; if (high_exclusive_loop_) { ss << "The block {0} is required to have a partial affine binding under " @@ -600,7 +603,7 @@ void CheckPartialAffineBinding(const ScheduleState& self, Block block, return ss.str(); } IRModule mod() const final { return mod_; } - Array LocationsOfInterest() const final { return {block_}; } + ffi::Array LocationsOfInterest() const final { return {block_}; } IRModule mod_; Block block_; const ForNode* high_exclusive_loop_{nullptr}; @@ -614,8 +617,8 @@ void CheckPartialAffineBinding(const ScheduleState& self, Block block, if (block_sref->parent && high_exclusive.defined()) { // if it is not of global affine binding, check affineness under high_exclusive, arith::Analyzer analyzer; - Map dom_map = - LoopDomainOfSRefTreePath(GetRef(block_sref->parent), high_exclusive); + ffi::Map dom_map = + LoopDomainOfSRefTreePath(ffi::GetRef(block_sref->parent), high_exclusive); if (IsAffineBinding(GetBlockRealize(self, block_sref), dom_map, &analyzer)) { return; } @@ -633,18 +636,18 @@ void CheckBlockHasTrivialBinding(const ScheduleState& self, const StmtSRef& bloc explicit NotTrivialBindingError(IRModule mod, Block block) : mod_(std::move(mod)), block_(std::move(block)) {} - String FastErrorString() const final { + ffi::String FastErrorString() const final { return "ScheduleError: The binding values of the block are not variables of outer loops."; } - String DetailRenderTemplate() const final { + ffi::String DetailRenderTemplate() const final { std::ostringstream os; os << "The binding values of the {0} are not variables of outer loops."; return os.str(); } IRModule mod() const final { return mod_; } - Array LocationsOfInterest() const final { return {block_}; } + ffi::Array LocationsOfInterest() const final { return {block_}; } private: IRModule mod_; @@ -652,14 +655,14 @@ void CheckBlockHasTrivialBinding(const ScheduleState& self, const StmtSRef& bloc }; if (!IsTrivialBinding(self, block_sref)) { - throw NotTrivialBindingError(self->mod, GetRef(block_sref->StmtAs())); + throw NotTrivialBindingError(self->mod, ffi::GetRef(block_sref->StmtAs())); } } -Map LoopDomainOfSRefTreePath(const StmtSRef& low_inclusive, - const Optional& high_exclusive, - const runtime::StorageScope& extra_relax_scope) { - Map result; +ffi::Map LoopDomainOfSRefTreePath(const StmtSRef& low_inclusive, + const ffi::Optional& high_exclusive, + const runtime::StorageScope& extra_relax_scope) { + ffi::Map result; const StmtSRefNode* p = low_inclusive.get(); const StmtSRefNode* limit = static_cast(high_exclusive.get()); for (; p != limit; p = p->parent) { @@ -673,7 +676,7 @@ Map LoopDomainOfSRefTreePath(const StmtSRef& low_inclusive, for (; p; p = p->parent) { if (const ForNode* loop = p->StmtAs()) { if (loop->kind == ForKind::kThreadBinding) { - const String& thread_tag = loop->thread_binding.value()->thread_tag; + const ffi::String& thread_tag =
loop->thread_binding.value()->thread_tag; if (CanRelaxStorageUnderThread(extra_relax_scope, runtime::ThreadScope::Create(thread_tag))) { result.Set(loop->loop_var, Range::FromMinExtent(loop->min, loop->extent)); @@ -685,12 +688,12 @@ Map LoopDomainOfSRefTreePath(const StmtSRef& low_inclusive, return result; } -Map GetBindings(const BlockRealize& realize) { +ffi::Map GetBindings(const BlockRealize& realize) { const BlockNode* block = realize->block.get(); - const Array& all_lhs = block->iter_vars; - const Array& all_rhs = realize->iter_values; + const ffi::Array& all_lhs = block->iter_vars; + const ffi::Array& all_rhs = realize->iter_values; ICHECK_EQ(all_lhs.size(), all_rhs.size()); - Map result; + ffi::Map result; for (int i = 0, n = all_lhs.size(); i < n; ++i) { const IterVar& lhs = all_lhs[i]; const PrimExpr& rhs = all_rhs[i]; @@ -724,7 +727,7 @@ bool GetVarsTouchedByBlockIters(const BlockRealize& block_realize, if (set == nullptr) { continue; } - Array vars_in_binding = UndefinedVars(iter_value); + ffi::Array vars_in_binding = UndefinedVars(iter_value); for (const Var& var : vars_in_binding) { set->insert(var.get()); } @@ -742,32 +745,32 @@ void CheckLoopStartsWithZero(const ScheduleState& self, const StmtSRef& loop_sre explicit LoopNotStartWithZeroError(IRModule mod, For loop) : mod_(mod), loop_(std::move(loop)) {} - String FastErrorString() const final { + ffi::String FastErrorString() const final { return "ScheduleError: The primitive only supports loop starting with 0"; } - String DetailRenderTemplate() const final { + ffi::String DetailRenderTemplate() const final { return "The loop {0} does not start with 0, which is not supported"; } IRModule mod() const final { return mod_; } - Array LocationsOfInterest() const final { return {loop_}; } + ffi::Array LocationsOfInterest() const final { return {loop_}; } IRModule mod_; For loop_; }; const ForNode* loop = TVM_SREF_TO_FOR(loop_sref); if (!analyzer->CanProve(loop->min == 0)) { - throw LoopNotStartWithZeroError(self->mod, GetRef(loop)); + throw LoopNotStartWithZeroError(self->mod, ffi::GetRef(loop)); } } /******** Block-loop relation ********/ -Array GetChildBlockSRefOnSRefTree(const ScheduleState& self, - const StmtSRef& parent_sref) { - Array child_block_realize = GetChildBlockRealizeOnSRefTree(parent_sref); - Array child_block_srefs; +ffi::Array GetChildBlockSRefOnSRefTree(const ScheduleState& self, + const StmtSRef& parent_sref) { + ffi::Array child_block_realize = GetChildBlockRealizeOnSRefTree(parent_sref); + ffi::Array child_block_srefs; child_block_srefs.reserve(child_block_realize.size()); for (BlockRealize realize : child_block_realize) { @@ -776,19 +779,19 @@ Array GetChildBlockSRefOnSRefTree(const ScheduleState& self, return child_block_srefs; } -Array GetChildBlockRealizeOnSRefTree(const StmtSRef& parent_sref) { +ffi::Array GetChildBlockRealizeOnSRefTree(const StmtSRef& parent_sref) { struct Collector : public StmtVisitor { - static Array Collect(const Stmt& stmt) { + static ffi::Array Collect(const Stmt& stmt) { Collector collector; collector(stmt); return std::move(collector.result_); } void VisitStmt_(const BlockRealizeNode* block_realize) final { - result_.push_back(GetRef(block_realize)); + result_.push_back(ffi::GetRef(block_realize)); } - Array result_; + ffi::Array result_; }; if (parent_sref->stmt->IsInstance()) { @@ -807,31 +810,31 @@ BlockRealize CheckGetSingleChildBlockRealizeOnSRefTree(const ScheduleState& self class NonSingleChildBlockError : public ScheduleError { public: explicit 
NonSingleChildBlockError(IRModule mod, const StmtSRef& sref) - : mod_(std::move(mod)), stmt_(GetRef(sref->stmt)) { + : mod_(std::move(mod)), stmt_(ffi::GetRef(sref->stmt)) { sref_type_ = stmt_.as() != nullptr ? "block" : "loop"; } - String FastErrorString() const final { + ffi::String FastErrorString() const final { std::ostringstream os; os << "ScheduleError: The " << sref_type_ << " is required to have only one child block"; return os.str(); } - String DetailRenderTemplate() const final { + ffi::String DetailRenderTemplate() const final { std::ostringstream os; os << "The " << sref_type_ << " {0} is required to have only one child block"; return os.str(); } IRModule mod() const final { return mod_; } - Array LocationsOfInterest() const final { return {stmt_}; } + ffi::Array LocationsOfInterest() const final { return {stmt_}; } IRModule mod_; Stmt stmt_; - String sref_type_; + ffi::String sref_type_; }; - Array child_block_realize = GetChildBlockRealizeOnSRefTree(parent_sref); + ffi::Array child_block_realize = GetChildBlockRealizeOnSRefTree(parent_sref); if (child_block_realize.size() != 1) { throw NonSingleChildBlockError(self->mod, parent_sref); } @@ -867,10 +870,10 @@ BlockRealize GetBlockRealize(const ScheduleState& self, const StmtSRef& block_sr return Downcast(func->body); } else { BlockRealizeFinder finder(block); - finder(GetRef(block_sref->parent->stmt)); + finder(ffi::GetRef(block_sref->parent->stmt)); ICHECK(finder.result != nullptr) - << "InternalError: Cannot find the BlockRealize of block " << GetRef(block); - return GetRef(finder.result); + << "InternalError: Cannot find the BlockRealize of block " << ffi::GetRef(block); + return ffi::GetRef(finder.result); } } @@ -928,7 +931,7 @@ IterVarType GetLoopIterType(const StmtSRef& loop_sref) { } } -StmtSRef GetSRefLowestCommonAncestor(const Array& srefs) { +StmtSRef GetSRefLowestCommonAncestor(const ffi::Array& srefs) { CHECK(!srefs.empty()) << "ValueError: The input array is required to have at least one sref"; std::unordered_map sref_visited_cnt; @@ -945,16 +948,17 @@ StmtSRef GetSRefLowestCommonAncestor(const Array& srefs) { p = p->parent; } ICHECK(p != nullptr); - return GetRef(p); + return ffi::GetRef(p); } bool HasBeenMultiLevelTiled(const StmtSRef& block_sref) { - return tir::GetAnn(block_sref, tir::attr::meta_schedule_tiling_structure).has_value(); + return tir::GetAnn(block_sref, tir::attr::meta_schedule_tiling_structure) + .has_value(); } -std::pair, std::vector> CollectComputeLocation(const ScheduleState& self, - const StmtSRef& block_sref) { - Array location_srefs; +std::pair, std::vector> CollectComputeLocation( + const ScheduleState& self, const StmtSRef& block_sref) { + ffi::Array location_srefs; std::vector location_indices; // Step 1. Add the "compute-root" candidate. Add the "compute-inline" candidate if the block can @@ -967,7 +971,7 @@ std::pair, std::vector> CollectComputeLocation(const Schedu location_indices.push_back(-1); // Step 2. If the block has no consumer, there is no more candidate. - Array consumers = GetConsumers(self, block_sref); + ffi::Array consumers = GetConsumers(self, block_sref); if (consumers.empty()) { return std::make_pair(location_srefs, location_indices); } @@ -975,14 +979,14 @@ std::pair, std::vector> CollectComputeLocation(const Schedu // Step 3. Get the deepest loop that the input block can be computed at (namely "boundary"). If // such a loop cannot be found, there is no more candidate and we just return. StmtSRef loop_boundary = consumers.size() > 1 ? 
GetSRefLowestCommonAncestor(consumers) - : GetRef(consumers[0]->parent); + : ffi::GetRef(consumers[0]->parent); if (loop_boundary->StmtAs() == nullptr) { return std::make_pair(location_srefs, location_indices); } // Step 4. Collect the loops outside the first consumer and locate the boundary loop. The position // of the boundary loop reveals the number of possible additional candidates. - Array loop_srefs = GetLoops(consumers[0]); + ffi::Array loop_srefs = GetLoops(consumers[0]); size_t lca_pos = std::find(loop_srefs.begin(), loop_srefs.end(), loop_boundary) - loop_srefs.begin(); ICHECK_LT(lca_pos, loop_srefs.size()); @@ -1035,9 +1039,9 @@ std::pair, std::vector> CollectComputeLocation(const Schedu /******** Producer-consumer relation ********/ -Array GetProducers(const StmtSRef& block_sref, const BlockScope& scope) { - Array edges = scope->GetDepsByDst(block_sref); - Array results; +ffi::Array GetProducers(const StmtSRef& block_sref, const BlockScope& scope) { + ffi::Array edges = scope->GetDepsByDst(block_sref); + ffi::Array results; std::unordered_set result_set; results.reserve(edges.size()); for (const Dependency& edge : edges) { @@ -1050,9 +1054,9 @@ Array GetProducers(const StmtSRef& block_sref, const BlockScope& scope return results; } -Array GetConsumers(const StmtSRef& block_sref, const BlockScope& scope) { - Array edges = scope->GetDepsBySrc(block_sref); - Array results; +ffi::Array GetConsumers(const StmtSRef& block_sref, const BlockScope& scope) { + ffi::Array edges = scope->GetDepsBySrc(block_sref); + ffi::Array results; std::unordered_set result_set; results.reserve(edges.size()); for (const Dependency& edge : edges) { @@ -1065,7 +1069,7 @@ Array GetConsumers(const StmtSRef& block_sref, const BlockScope& scope return results; } -Array GetOutputBlocks(const ScheduleState& self, const BlockNode* scope_block) { +ffi::Array GetOutputBlocks(const ScheduleState& self, const BlockNode* scope_block) { struct OutputBlockCollector : public StmtVisitor { explicit OutputBlockCollector(const ScheduleState& self) : self_(self) {} @@ -1084,7 +1088,7 @@ Array GetOutputBlocks(const ScheduleState& self, const BlockNode* scop } const ScheduleState& self_; - Array results_; + ffi::Array results_; }; OutputBlockCollector collector(self); collector(scope_block->body); @@ -1093,8 +1097,9 @@ Array GetOutputBlocks(const ScheduleState& self, const BlockNode* scop } ProducerConsumerSplit ProducerConsumerSplit::Find( - const ScheduleState& self, const Array& subtrees, - const Array& producer_block_srefs, const Array& consumer_block_srefs, + const ScheduleState& self, const ffi::Array& subtrees, + const ffi::Array& producer_block_srefs, + const ffi::Array& consumer_block_srefs, std::unordered_map* block2realize) { class InsertionPointNotFoundError : public ScheduleError { public: @@ -1104,12 +1109,12 @@ ProducerConsumerSplit ProducerConsumerSplit::Find( last_producer_position_(last_producer_position), first_consumer_position_(first_consumer_position) {} - String FastErrorString() const final { + ffi::String FastErrorString() const final { return "ScheduleError: Cannot find the insertion point that satisfies the producer-consumer " "constraint"; } - String DetailRenderTemplate() const final { + ffi::String DetailRenderTemplate() const final { return "Cannot find the insertion point that satisfies the producer-consumer constraint. 
In " "0-based indexing, the last producer appears in subtree " + std::to_string(last_producer_position_) + @@ -1119,7 +1124,7 @@ ProducerConsumerSplit ProducerConsumerSplit::Find( IRModule mod() const final { return mod_; } - Array LocationsOfInterest() const final { return {}; } + ffi::Array LocationsOfInterest() const final { return {}; } private: IRModule mod_; @@ -1202,7 +1207,7 @@ BufferRegion GetNthAccessBufferRegion(const ScheduleState& self, const Block& bl buffer_index_(buffer_index), index_type_(index_type) {} - String FastErrorString() const final { + ffi::String FastErrorString() const final { if (index_type_ == BufferIndexType::kWrite) { return "ScheduleError: The input `buffer_index` is out of range. It is required to be in " "range " @@ -1216,7 +1221,7 @@ BufferRegion GetNthAccessBufferRegion(const ScheduleState& self, const Block& bl } } - String DetailRenderTemplate() const final { + ffi::String DetailRenderTemplate() const final { std::ostringstream os; size_t num = index_type_ == BufferIndexType::kWrite ? block_->writes.size() : block_->reads.size(); @@ -1228,7 +1233,7 @@ BufferRegion GetNthAccessBufferRegion(const ScheduleState& self, const Block& bl } IRModule mod() const final { return mod_; } - Array LocationsOfInterest() const final { return {block_}; } + ffi::Array LocationsOfInterest() const final { return {block_}; } private: IRModule mod_; @@ -1237,7 +1242,7 @@ BufferRegion GetNthAccessBufferRegion(const ScheduleState& self, const Block& bl BufferIndexType index_type_; }; - const Array& access_region = + const ffi::Array& access_region = index_type == BufferIndexType::kWrite ? block->writes : block->reads; if (n < 0 || static_cast(access_region.size()) <= n) { @@ -1251,8 +1256,8 @@ Buffer GetNthAccessBuffer(const ScheduleState& self, const Block& block, int n, return GetNthAccessBufferRegion(self, block, n, index_type)->buffer; } -std::pair, bool> GetBufferDefiningSite(const StmtSRef& block_sref, - const Buffer& buffer) { +std::pair, bool> GetBufferDefiningSite(const StmtSRef& block_sref, + const Buffer& buffer) { // Climb up along the sref tree, and find the block where `buffer` is in alloc_buffers or // match_buffers. const StmtSRefNode* defining_site_sref = block_sref.get(); @@ -1266,13 +1271,13 @@ std::pair, bool> GetBufferDefiningSite(const StmtSRef& block_ // Try to find the buffer in `allloc_buffers` for (const Buffer& alloc_buffer : block->alloc_buffers) { if (buffer.same_as(alloc_buffer)) { - return {GetRef(defining_site_sref), true}; + return {ffi::GetRef(defining_site_sref), true}; } } // We do not allow the buffer being defined in `match_buffer`. 
for (const MatchBufferRegion match_buffer : block->match_buffers) { if (buffer.same_as(match_buffer)) { - return {GetRef(defining_site_sref), false}; + return {ffi::GetRef(defining_site_sref), false}; } } defining_site_sref = defining_site_sref->parent; @@ -1288,7 +1293,7 @@ StmtSRef GetSRefTreeRoot(const StmtSRef& sref) { const StmtSRefNode* p = sref.get(); for (; p->parent != nullptr; p = p->parent) { } - return GetRef(p); + return ffi::GetRef(p); } void AddShapeVarBounds(const ScheduleState& state, const StmtSRefNode* sref, @@ -1307,7 +1312,7 @@ void AddShapeVarBounds(const ScheduleState& state, const StmtSRefNode* sref, /******** Misc ********/ -bool HasOp(const Stmt& stmt, const Array& ops) { +bool HasOp(const Stmt& stmt, const ffi::Array& ops) { std::unordered_set op_set; op_set.reserve(ops.size()); for (const Op& op : ops) { @@ -1397,7 +1402,7 @@ AnalyzeReadWritePattern(const BufferRegion& read_region, const BufferRegion& wri } // Case 2. Read index cannot be recognized as `var +/- const` // where `var` is a write index and `const` is an optional constant shift - Optional opt_const = std::nullopt; + ffi::Optional opt_const = std::nullopt; const VarNode* var = static_cast(AnalyzeVarWithShift(dom->min, &opt_const).get()); if (var == nullptr || !var2idx.count(var)) { @@ -1440,26 +1445,26 @@ AnalyzeReadWritePattern(const BufferRegion& read_region, const BufferRegion& wri /******** Storage Scope ********/ -void CheckStorageScope(const ScheduleState& self, String storage_scope) { +void CheckStorageScope(const ScheduleState& self, ffi::String storage_scope) { class InvalidStorageScopeError : public ScheduleError { public: - explicit InvalidStorageScopeError(IRModule mod, String storage_scope) + explicit InvalidStorageScopeError(IRModule mod, ffi::String storage_scope) : mod_(std::move(mod)), storage_scope_(std::move(storage_scope)) {} - String FastErrorString() const final { + ffi::String FastErrorString() const final { return "ScheduleError: The input storage scope is invalid"; } - String DetailRenderTemplate() const final { + ffi::String DetailRenderTemplate() const final { return "The input storage scope \"" + storage_scope_ + "\" is invalid."; } - Array LocationsOfInterest() const final { return {}; } + ffi::Array LocationsOfInterest() const final { return {}; } IRModule mod() const final { return mod_; } private: IRModule mod_; - String storage_scope_; + ffi::String storage_scope_; }; try { @@ -1481,8 +1486,8 @@ bool IsSpatial(const StmtSRef& block_sref) { bool IsTrivialBinding(const ScheduleState& self, const StmtSRef& block_sref) { TVM_SREF_TO_BLOCK(block_sref); - Array loops = GetLoops(block_sref); - Array binds = GetBlockRealize(self, block_sref)->iter_values; + ffi::Array loops = GetLoops(block_sref); + ffi::Array binds = GetBlockRealize(self, block_sref)->iter_values; if (loops.size() != binds.size()) { return false; } @@ -1532,7 +1537,7 @@ bool NeedsMultiLevelTiling(const ScheduleState& self, const StmtSRef& block_sref read_buffers.reserve(block->reads.size()); for (const BufferRegion& buffer_region : block->reads) { const BufferNode* buffer = buffer_region->buffer.get(); - const Array& regions = buffer_region->region; + const ffi::Array& regions = buffer_region->region; // Step 2.1. 
Duplication of read buffers is not allowed if (read_buffers.insert(buffer).second == false) { return false; } @@ -1584,7 +1589,7 @@ bool IsSpatialPrimFunc(const PrimFunc& func) { std::pair GetCumulativeSpaceAndReductionLength(const tir::ScheduleState& self, const tir::StmtSRef& block_sref) { - Array loops = tir::GetLoops(block_sref); + ffi::Array loops = tir::GetLoops(block_sref); int64_t cum_space_len = 1, cum_reduce_len = 1; /* * Return (-1, -1) if @@ -1619,7 +1624,7 @@ bool NeedsRFactorOrCrossThreadReduction(const tir::ScheduleState& self, // int64_t max_parallel_extent, // int64_t max_parallel_basic) { const BlockNode* block = TVM_SREF_TO_BLOCK(block_sref); - Array loops = tir::GetLoops(block_sref); + ffi::Array loops = tir::GetLoops(block_sref); // Cond 1. The block must have at least one write buffer if (block->writes.size() == 0) { @@ -1742,10 +1747,10 @@ TensorIntrinDescInfo ExtractTensorIntrinDescInfo(arith::Analyzer* analyzer, return info; } -Optional GetTensorizeLoopMapping(const tir::ScheduleState& self, - const tir::StmtSRef& block_sref, - const tir::PrimFunc& desc_func, - bool allow_padding) { +ffi::Optional GetTensorizeLoopMapping(const tir::ScheduleState& self, + const tir::StmtSRef& block_sref, + const tir::PrimFunc& desc_func, + bool allow_padding) { arith::Analyzer analyzer; const tir::BlockRealize& block = tir::GetBlockRealize(self, block_sref); // Step 1. Analyze desc_func, extract its block, loops and loop vars @@ -1773,7 +1778,7 @@ Optional GetTensorizeLoopMapping(const tir::ScheduleState& self, const std::vector& desc_loops = desc_info.desc_loops; const std::unordered_set& desc_loop_vars = desc_info.desc_loop_vars; const BlockRealizeNode* desc_block = desc_info.desc_block; - ObjectPtr ret = make_object(); + ObjectPtr ret = ffi::make_object(); const int n_block_vars = block->iter_values.size(); const int n_desc_vars = desc_block->iter_values.size(); const int offset = n_block_vars - n_desc_vars; @@ -1876,19 +1881,19 @@ Optional GetTensorizeLoopMapping(const tir::ScheduleState& self, } } - ret->loop_map.Set(block_loop_sref, GetRef(desc_loop)); + ret->loop_map.Set(block_loop_sref, ffi::GetRef(desc_loop)); break; } } for (int i = 0, n = desc_loops.size(); i < n; ++i) { - ret->desc_loop_indexer.Set(GetRef(desc_loops[i]), Integer(i)); + ret->desc_loop_indexer.Set(ffi::GetRef(desc_loops[i]), Integer(i)); } if (!block_index_to_padding.empty()) { if (!allow_padding) { return std::nullopt; } - Array paddings; + ffi::Array paddings; for (int i = 0, n = block->block->iter_vars.size(); i < n; ++i) { const IterVar& iter_var = block->block->iter_vars[i]; if (auto it = block_index_to_padding.find(i); it != block_index_to_padding.end()) { @@ -1918,8 +1923,8 @@ TVM_FFI_STATIC_INIT_BLOCK({ /*! \brief IndexMap proposer for layout transformation in auto tensorization.
*/ class AutoTensorizeMappingProposer { public: - static Array ProposeMappings(const AutoTensorizeComparator* extractor, - arith::Analyzer* analyzer) { + static ffi::Array ProposeMappings(const AutoTensorizeComparator* extractor, + arith::Analyzer* analyzer) { AutoTensorizeMappingProposer proposer(extractor, analyzer); proposer.CollectFeasibleSet(); return proposer.ProposeAllFuseMapping(); } @@ -2013,7 +2018,7 @@ class AutoTensorizeMappingProposer { for (const auto& kv : rhs_buffer_masks) { const VarNode* rhs_var = kv.first; const BufferMask& mask = kv.second; - mask_to_rhs_vars[mask].insert(GetRef(rhs_var)); + mask_to_rhs_vars[mask].insert(ffi::GetRef(rhs_var)); } std::unordered_map rhs_var_iter_type; for (const auto& iter : extractor_->rhs_iters_) { @@ -2029,7 +2034,7 @@ class AutoTensorizeMappingProposer { } } - Array ProposeAllFuseMapping() { + ffi::Array ProposeAllFuseMapping() { // Now we have calculated potential mapping for each iter var on LHS. For iters on LHS mapped to // the same iter on RHS, they will be fused in the original order in LHS block iters. We will // generate IndexMap to represent such fusion on LHS. For example, if n, h, w on LHS are mapped @@ -2037,12 +2042,12 @@ class AutoTensorizeMappingProposer { // fuse(v0, .., vn) = ((v0 * v1_extent + v1) + ... ) * vn_extent + vn // the parameters of the result index map, each parameter corresponds to a LHS iter - Array index_map_src; + ffi::Array index_map_src; // the outputs of the result index map - Array index_map_tgt; + ffi::Array index_map_tgt; // Step 1: Collect extents of LHS iters and prepare the initial indices of the IndexMap - Map lhs_iter_extents; + ffi::Map lhs_iter_extents; for (const auto& iter : extractor_->lhs_iters_) { lhs_iter_extents.Set(iter->var, iter->dom->extent); index_map_src.push_back(iter->var.copy_with_suffix("")); } // Step 2: Each iter on RHS has a group of corresponding iters on LHS. Initialize the fusion // result for each group of iters on LHS.
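 // Editor's note: an illustrative worked example, not part of the patch.
 // Instantiating the fuse rule quoted above for three LHS iters (n, h, w)
 // that all map to one RHS iter gives
 //   fused(n, h, w) = (n * h_extent + h) * w_extent + w
 // so with hypothetical extents h_extent = 7 and w_extent = 7, the triple
 // (n, h, w) = (1, 2, 3) flattens to (1 * 7 + 2) * 7 + 3 = 66.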
- Map fused_lhs_iters; + ffi::Map fused_lhs_iters; for (const auto& iter : extractor_->rhs_iters_) { fused_lhs_iters.Set(iter->var, 0); } @@ -2114,19 +2119,20 @@ bool CheckAutoTensorizeApplicable(const tir::Schedule& sch, const tir::BlockRV& return CheckAutoTensorizeApplicable(sch->state(), sch->GetSRef(block_rv), desc_func, &extractor); } -Optional GetAutoTensorizeMappingInfo(const tir::ScheduleState& self, - const tir::StmtSRef& block_sref, - const tir::PrimFunc& desc_func) { +ffi::Optional GetAutoTensorizeMappingInfo( + const tir::ScheduleState& self, const tir::StmtSRef& block_sref, + const tir::PrimFunc& desc_func) { AutoTensorizeComparator extractor(self->mod); if (!CheckAutoTensorizeApplicable(self, block_sref, desc_func, &extractor)) { return std::nullopt; } arith::Analyzer analyzer; - Array mappings = AutoTensorizeMappingProposer::ProposeMappings(&extractor, &analyzer); + ffi::Array mappings = + AutoTensorizeMappingProposer::ProposeMappings(&extractor, &analyzer); if (mappings.empty()) { return std::nullopt; } - ObjectPtr ret = make_object(); + ObjectPtr ret = ffi::make_object(); ret->mappings = std::move(mappings); ret->lhs_buffer_map = std::move(extractor.lhs_buffer_map_); ret->rhs_buffer_indices = std::move(extractor.rhs_buffer_indices_map_); @@ -2149,7 +2155,7 @@ TVM_FFI_STATIC_INIT_BLOCK({ auto block_sref = sch->GetSRef(block); return IsOutputBlock(state, block_sref, GetScopeRoot(state, block_sref, false)); }) - .def("tir.schedule.GetLoopIterType", [](Schedule sch, LoopRV loop) -> String { + .def("tir.schedule.GetLoopIterType", [](Schedule sch, LoopRV loop) -> ffi::String { IterVarType kind = GetLoopIterType(sch->GetSRef(loop)); if (kind == kDataPar) { return "S"; diff --git a/src/tir/schedule/analysis/layout.cc b/src/tir/schedule/analysis/layout.cc index f6dc0067a800..eedf32ba06e8 100644 --- a/src/tir/schedule/analysis/layout.cc +++ b/src/tir/schedule/analysis/layout.cc @@ -28,7 +28,7 @@ namespace tir { * \param buffer The buffer * \return The strides */ -Array GetStrides(const Buffer& buffer) { +ffi::Array GetStrides(const Buffer& buffer) { if (!buffer->strides.empty()) { ICHECK_EQ(buffer->strides.size(), buffer->shape.size()); return buffer->strides; @@ -37,7 +37,7 @@ Array GetStrides(const Buffer& buffer) { if (ndim == 0) { return {}; } - Array strides(ndim, PrimExpr{nullptr}); + ffi::Array strides(ndim, PrimExpr{nullptr}); PrimExpr stride = make_const(buffer->DefaultIndexType(), 1); for (int i = ndim - 1; i >= 0; --i) { strides.Set(i, stride); @@ -75,9 +75,9 @@ class SplitExprCollector { * \return The collected split expressions */ static std::vector Collect(const PrimExpr& index, - const Map& input_iters, // - const PrimExpr& predicate, // - arith::IterMapLevel check_level, // + const ffi::Map& input_iters, // + const PrimExpr& predicate, // + arith::IterMapLevel check_level, // arith::Analyzer* analyzer) { arith::IterMapResult res = arith::DetectIterMap({analyzer->Simplify(index)}, input_iters, predicate, check_level, analyzer); @@ -106,7 +106,7 @@ class SplitExprCollector { failed_ = true; return; } - exprs_.push_back(SplitExpr{GetRef(var), *lower_factor, *extent}); + exprs_.push_back(SplitExpr{ffi::GetRef(var), *lower_factor, *extent}); } else if (auto iter_sum_expr = expr->source->source.as()) { Visit(iter_sum_expr.value()); } else { @@ -126,13 +126,13 @@ class SplitExprCollector { std::vector exprs_; }; -Optional SuggestIndexMap(const Buffer& buffer, const Array& indices, - const Array& loops, const PrimExpr& predicate, - arith::Analyzer* analyzer) { +ffi::Optional 
SuggestIndexMap(const Buffer& buffer, const ffi::Array& indices, + const ffi::Array& loops, const PrimExpr& predicate, + arith::Analyzer* analyzer) { int ndim = buffer->shape.size(); int n_loops = loops.size(); // Step 1. Collect the domains and indices of loop variables - Map input_iters; + ffi::Map input_iters; std::unordered_map var2id; var2id.reserve(n_loops); for (int i = 0; i < n_loops; ++i) { @@ -142,7 +142,7 @@ Optional SuggestIndexMap(const Buffer& buffer, const Array& } // Step 2. Calculate a functor that flattens a multi-dimensional index auto f_flatten_index = [ndim, strides = GetStrides(buffer), dtype = buffer->DefaultIndexType()]( - const Array& indices) -> PrimExpr { + const ffi::Array& indices) -> PrimExpr { PrimExpr flatten_index = make_const(dtype, 0); for (int i = 0; i < ndim; ++i) { flatten_index = flatten_index + strides[i] * indices[i]; @@ -179,7 +179,7 @@ Optional SuggestIndexMap(const Buffer& buffer, const Array& &order, // & shape = buffer->shape, // analyzer // - ](Array indices) -> Array { + ](ffi::Array indices) -> ffi::Array { ICHECK_EQ(indices.size(), shape.size()); for (int i = 0, n = indices.size(); i < n; ++i) { analyzer->Bind(indices[i], Range::FromMinExtent(0, shape[i])); @@ -198,7 +198,7 @@ Optional SuggestIndexMap(const Buffer& buffer, const Array& } std::reverse(split.begin(), split.end()); // Step 5.3. Reorder the indexing pattern according to `order` - Array results; + ffi::Array results; results.reserve(ndim); for (int i = 0; i < ndim; ++i) { results.push_back(split[order[i]]); @@ -207,11 +207,11 @@ Optional SuggestIndexMap(const Buffer& buffer, const Array& }; // Step 6: Create the inverse index mapping. auto f_inverse = [&inverse_order, &split_exprs, &shape = buffer->shape, - analyzer](Array indices) -> Array { + analyzer](ffi::Array indices) -> ffi::Array { ICHECK_EQ(indices.size(), split_exprs.size()); // Step 6.1: Reorder the indices according to `inverse_order`. This is the inverse of Step 5.3. // After the inverse permutation, indices[i] corresponds to split_exprs[i] - Array inv_permuted_indices; + ffi::Array inv_permuted_indices; inv_permuted_indices.reserve(indices.size()); for (int i = 0, n = indices.size(); i < n; ++i) { const Var& index = indices[inverse_order[i]]; @@ -227,14 +227,14 @@ Optional SuggestIndexMap(const Buffer& buffer, const Array& stride *= split_exprs[i].extent; } // Step 6.3: Split the flattened index into multiple indices. This is the inverse of Step 5.1. 
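 // Editor's note: an illustrative worked example, not part of the patch.
 // For a hypothetical buffer of shape (4, 8), Step 6.2 flattens (i, j) into
 // f = i * 8 + j, and this step recovers the indices innermost-first:
 //   j = floormod(f, 8);  f = floordiv(f, 8);  i = floormod(f, 4);
 // e.g. f = 19 yields (i, j) = (2, 3), consistent with 2 * 8 + 3 = 19.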
- Array result; + ffi::Array result; result.reserve(shape.size()); for (int i = static_cast(shape.size()) - 1; i >= 0; --i) { PrimExpr index = analyzer->Simplify(floormod(flattened_index, shape[i])); flattened_index = floordiv(flattened_index, shape[i]); result.push_back(index); } - return Array(result.rbegin(), result.rend()); + return ffi::Array(result.rbegin(), result.rend()); }; IndexMap inverse_index_map = IndexMap::FromFunc(split_exprs.size(), f_inverse); return IndexMap::FromFunc(ndim, f_alter_layout, inverse_index_map); @@ -242,11 +242,12 @@ Optional SuggestIndexMap(const Buffer& buffer, const Array& TVM_FFI_STATIC_INIT_BLOCK({ namespace refl = tvm::ffi::reflection; - refl::GlobalDef().def("tir.schedule.SuggestIndexMap", [](Buffer buffer, Array indices, - Array loops, PrimExpr predicate) { - arith::Analyzer analyzer; - return SuggestIndexMap(buffer, indices, loops, predicate, &analyzer); - }); + refl::GlobalDef().def( + "tir.schedule.SuggestIndexMap", + [](Buffer buffer, ffi::Array indices, ffi::Array loops, PrimExpr predicate) { + arith::Analyzer analyzer; + return SuggestIndexMap(buffer, indices, loops, predicate, &analyzer); + }); }); } // namespace tir diff --git a/src/tir/schedule/analysis/reducer.cc b/src/tir/schedule/analysis/reducer.cc index d85be933820c..085a4a33de87 100644 --- a/src/tir/schedule/analysis/reducer.cc +++ b/src/tir/schedule/analysis/reducer.cc @@ -49,7 +49,7 @@ namespace tir { */ class PatternMatcher : public ExprVisitor { public: - explicit PatternMatcher(Array pattern) : pattern_(std::move(pattern)) {} + explicit PatternMatcher(ffi::Array pattern) : pattern_(std::move(pattern)) {} void VisitExpr_(const VarNode* op) final { auto it = filled_map_.find(op); @@ -258,7 +258,7 @@ class PatternMatcher : public ExprVisitor { } } - void Match(const Array& exprs_to_match) { + void Match(const ffi::Array& exprs_to_match) { this->match_success_ = true; this->filled_map_.clear(); @@ -281,7 +281,7 @@ class PatternMatcher : public ExprVisitor { private: bool match_success_{true}; - Array pattern_; + ffi::Array pattern_; PrimExpr expr_to_match_; std::unordered_map filled_map_; }; @@ -303,19 +303,19 @@ static const char* kRFactorCrossThreadReductionApplicableBlockDef = 11) The buffers written by the block should have same shape 12) The indices of all BufferStores in the reduction block should be the same)"; -void ErrorRFactorCrossThreadReductionNotApplicable(const Optional& self, Block block, - int violated_cond) { +void ErrorRFactorCrossThreadReductionNotApplicable(const ffi::Optional& self, + Block block, int violated_cond) { class RFactorNotApplicableError : public ScheduleError { public: explicit RFactorNotApplicableError(IRModule mod, Block block, int violated_cond) : mod_(std::move(mod)), block_(std::move(block)), violated_cond_(violated_cond) {} - String FastErrorString() const final { + ffi::String FastErrorString() const final { return "ScheduleError: RFactor cannot be applied to the block since the block does not meet " "the requirements"; } - String DetailRenderTemplate() const final { + ffi::String DetailRenderTemplate() const final { std::ostringstream os; os << "RFactor cannot be applied to block {0}, because the block violates condition #" << violated_cond_ << ".\n" @@ -324,7 +324,7 @@ void ErrorRFactorCrossThreadReductionNotApplicable(const Optional } IRModule mod() const final { return mod_; } - Array LocationsOfInterest() const final { return {block_}; } + ffi::Array LocationsOfInterest() const final { return {block_}; } IRModule mod_; Block block_; @@ 
-352,11 +352,12 @@ void ErrorRFactorCrossThreadReductionNotApplicable(const Optional * \param buf2index A mapping from reduction buffers to their indices of the reduction order * \throw ScheduleError If rfactor or cross-thread reduction cannot be applied to the block */ -void ExtractReductionUpdates(const Optional& self, Block block, - const LetStmtNode* let, int n_buffers, Array* updates, +void ExtractReductionUpdates(const ffi::Optional& self, Block block, + const LetStmtNode* let, int n_buffers, + ffi::Array* updates, std::unordered_map* buf2index) { std::unordered_map var2index; - Array let_values; + ffi::Array let_values; let_values.reserve(n_buffers); updates->resize(n_buffers); @@ -390,7 +391,8 @@ void ExtractReductionUpdates(const Optional& self, Block block, if (p_seq == nullptr && p_buf_store == nullptr) { ErrorRFactorCrossThreadReductionNotApplicable(self, std::move(block), /*violated_cond=*/5); } - Array seq = p_seq != nullptr ? p_seq->seq : Array{GetRef(p_buf_store)}; + ffi::Array seq = + p_seq != nullptr ? p_seq->seq : ffi::Array{ffi::GetRef(p_buf_store)}; if (static_cast(seq.size()) != n_buffers) { ErrorRFactorCrossThreadReductionNotApplicable(self, std::move(block), /*violated_cond=*/6); } @@ -426,10 +428,10 @@ void ExtractReductionUpdates(const Optional& self, Block block, } } -std::pair, Array> GetInitValuesAndUpdatesFromReductionBlock( - const Optional& self, Block block) { - Array inits; - Array updates; +std::pair, ffi::Array> GetInitValuesAndUpdatesFromReductionBlock( + const ffi::Optional& self, Block block) { + ffi::Array inits; + ffi::Array updates; // Step 1. Extract the BufferStores serving as block inits. if (auto init = block->init.as()) { @@ -455,7 +457,7 @@ std::pair, Array> GetInitValuesAndUpdatesFromReduct int n_buffers = inits.size(); std::unordered_map buf2index; if (const auto* update = block->body.as()) { - updates.push_back(GetRef(update)); + updates.push_back(ffi::GetRef(update)); buf2index[update->buffer.get()] = 0; } else { const auto* let = block->body.as(); @@ -465,15 +467,15 @@ std::pair, Array> GetInitValuesAndUpdatesFromReduct // Step 3. Set the init values according to the buffer order in `updates`, with the help of the // mapping `buf2index`. 
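 // Editor's note: an illustrative example, not part of the patch. For a
 // typical reduction block written in TVMScript as
 //   with T.init():
 //     C[vi] = T.float32(0)
 //   C[vi] = C[vi] + A[vi, vk]
 // the extracted init value is the constant 0 and the extracted update is the
 // BufferStore `C[vi] = C[vi] + A[vi, vk]`; `buf2index` then maps buffer C to
 // its position (0) in the reduction order.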
-  Array init_values;
+  ffi::Array init_values;
   init_values.resize(n_buffers);
   // - Check all buffers have the same shape
   // - Check all indices of the BufferStores are the same
   // - Check buffers written in the block init and the block body can match
   // - Check buffers do not duplicate
-  const Array& expected_shape = updates[0]->buffer->shape;
-  const Array& expected_indices = updates[0]->indices;
+  const ffi::Array& expected_shape = updates[0]->buffer->shape;
+  const ffi::Array& expected_indices = updates[0]->indices;
   ICHECK_EQ(expected_shape.size(), expected_indices.size());
   int n_dim = expected_indices.size();
   arith::Analyzer ana;
@@ -511,7 +513,7 @@ std::pair, Array> GetInitValuesAndUpdatesFromReduct
   return std::make_pair(init_values, updates);
 }
-bool ContainsOnlyDataParAndReductionBlockIter(const Array& iters) {
+bool ContainsOnlyDataParAndReductionBlockIter(const ffi::Array& iters) {
   for (const IterVar& iter_var : iters) {
     if (iter_var->iter_type != kDataPar && iter_var->iter_type != kCommReduce) {
       return false;
     }
@@ -589,18 +591,18 @@ bool ReductionIterNotIndexOutputBuffer(const Block& block) {
 class NoMatchedReducerError : public ScheduleError {
  public:
-  explicit NoMatchedReducerError(IRModule mod, Array identities,
-                                 Array combiners)
+  explicit NoMatchedReducerError(IRModule mod, ffi::Array identities,
+                                 ffi::Array combiners)
       : mod_(std::move(mod)), identities_(std::move(identities)), combiners_(std::move(combiners)) {}
-  String FastErrorString() const final {
+  ffi::String FastErrorString() const final {
     return "ScheduleError: No matched reducer for the identity and the combiner of this reduction "
            "block. So rfactor and cross-thread reduction cannot be applied.";
   }
-  String DetailRenderTemplate() const final {
+  ffi::String DetailRenderTemplate() const final {
     std::ostringstream os;
     os << "No matched reducer for identity " << identities_ << " and combiner " << combiners_
        << "In this case rfactor cannot be applied. You can check tvm::tir::ReducerRegistry for "
@@ -609,18 +611,18 @@ class NoMatchedReducerError : public ScheduleError {
   }
   IRModule mod() const final { return mod_; }
-  Array LocationsOfInterest() const final { return {}; }
+  ffi::Array LocationsOfInterest() const final { return {}; }
   IRModule mod_;
-  Array identities_;
-  Array combiners_;
+  ffi::Array identities_;
+  ffi::Array combiners_;
 };
-std::tuple, Array> GetReducerAndCombinerLhsRhs(
-    const Optional& self, const Array& identities,
-    const Array& combiners) {
+std::tuple, ffi::Array> GetReducerAndCombinerLhsRhs(
+    const ffi::Optional& self, const ffi::Array& identities,
+    const ffi::Array& combiners) {
   CommReducer reducer{nullptr};
-  Array combiner_lhs, combiner_rhs;
+  ffi::Array combiner_lhs, combiner_rhs;
   bool matched =
       FromIdentityCombiner(identities, combiners, &reducer, &combiner_lhs, &combiner_rhs);
   if (!matched) {
@@ -636,9 +638,10 @@ std::tuple, Array> GetReducerAndCombinerL
 /******** Commutative Reducer ********/
-bool MatchReducer(const CommReducer& reducer, const Array& identities,
-                  const Array& combined_values, const Array& buf_loads,
-                  Array* lhs, Array* rhs) {
+bool MatchReducer(const CommReducer& reducer, const ffi::Array& identities,
+                  const ffi::Array& combined_values,
+                  const ffi::Array& buf_loads, ffi::Array* lhs,
+                  ffi::Array* rhs) {
   ExprDeepEqual equal;
   ICHECK_EQ(identities.size(), combined_values.size());
   int n_buffers = identities.size();
@@ -650,7 +653,7 @@ bool MatchReducer(const CommReducer& reducer, const Array& identities,
   PatternMatcher pattern_matcher(reducer->result);
   pattern_matcher.Match(combined_values);
-  Array lhs_tmp, rhs_tmp;
+  ffi::Array lhs_tmp, rhs_tmp;
   lhs_tmp.reserve(n_buffers);
   rhs_tmp.reserve(n_buffers);
   if (!pattern_matcher.Success()) {
@@ -671,11 +674,12 @@ bool MatchReducer(const CommReducer& reducer, const Array& identities,
   return true;
 }
-bool FromIdentityCombiner(const Array& identities, const Array& combiners,
-                          CommReducer* result_reducer, Array* lhs, Array* rhs) {
+bool FromIdentityCombiner(const ffi::Array& identities,
+                          const ffi::Array& combiners, CommReducer* result_reducer,
+                          ffi::Array* lhs, ffi::Array* rhs) {
   int n = identities.size();
-  Array buf_loads;
-  Array stored_values;
+  ffi::Array buf_loads;
+  ffi::Array stored_values;
   buf_loads.reserve(n);
   stored_values.reserve(n);
@@ -685,9 +689,9 @@ bool FromIdentityCombiner(const Array& identities, const Array
-  for (const TypedFunction(Array)>& reducer_getter :
+  for (const ffi::TypedFunction(ffi::Array)>& reducer_getter :
        GetReducerGetters()) {
-    Optional reducer = reducer_getter(identities);
+    ffi::Optional reducer = reducer_getter(identities);
     if (!reducer.defined()) {
       continue;
     }
diff --git a/src/tir/schedule/analysis/verify.cc b/src/tir/schedule/analysis/verify.cc
index 4e3f04e0f389..f9a09552c21c 100644
--- a/src/tir/schedule/analysis/verify.cc
+++ b/src/tir/schedule/analysis/verify.cc
@@ -56,19 +56,20 @@ class SRefTreeVerifier : public StmtVisitor {
     }
     ICHECK(self_->stmt2ref.count(block))
         << "InternalError: A BlockNode should appear in sref map, but it didn't\n"
-        << GetRef(block);
+        << ffi::GetRef(block);
     ++n_sref_visited_;
     ++n_block_sref_visited_;
     const StmtSRef& sref = self_->stmt2ref.at(block);
     ICHECK(self_->block_info.count(sref))
         << "InternalError: Cannot find scope information of the BlockNode:\n"
-        << GetRef(block);
+        << ffi::GetRef(block);
     ICHECK(sref->parent == ancestors_.back())
         << "InternalError: Parent information mismatch for BlockNode:\n"
-        << GetRef(block) << "\nIts parent is supposed to be:\n"
-        << GetRef(ancestors_.back()->stmt) << "\nHowever, its parent is incorrect and is:\n"
-        << (sref->parent ? Optional(GetRef(sref->parent->stmt))
-                         : Optional(std::nullopt));
+        << ffi::GetRef(block) << "\nIts parent is supposed to be:\n"
+        << ffi::GetRef(ancestors_.back()->stmt)
+        << "\nHowever, its parent is incorrect and is:\n"
+        << (sref->parent ? ffi::Optional(ffi::GetRef(sref->parent->stmt))
+                         : ffi::Optional(std::nullopt));
     ancestors_.push_back(sref.operator->());
     if (block->init.defined()) {
       ++init_block_depth_;
@@ -88,16 +89,17 @@ class SRefTreeVerifier : public StmtVisitor {
     }
     ICHECK(self_->stmt2ref.count(loop))
         << "InternalError: A ForNode should appear in sref map, but it didn't\n"
-        << GetRef(loop);
+        << ffi::GetRef(loop);
     ++n_sref_visited_;
     const StmtSRef& sref = self_->stmt2ref.at(loop);
-    Optional stmt = std::nullopt;
+    ffi::Optional stmt = std::nullopt;
     ICHECK(sref->parent == ancestors_.back())
         << "InternalError: Parent information mismatch for ForNode:\n"
-        << GetRef(loop) << "\nIts parent is supposed to be:\n"
-        << GetRef(ancestors_.back()->stmt) << "\nHowever, its parent is incorrect and is:\n"
-        << (sref->parent ? Optional(GetRef(sref->parent->stmt))
-                         : Optional(std::nullopt));
+        << ffi::GetRef(loop) << "\nIts parent is supposed to be:\n"
+        << ffi::GetRef(ancestors_.back()->stmt)
+        << "\nHowever, its parent is incorrect and is:\n"
+        << (sref->parent ? ffi::Optional(ffi::GetRef(sref->parent->stmt))
+                         : ffi::Optional(std::nullopt));
     ancestors_.push_back(sref.operator->());
     StmtVisitor::VisitStmt_(loop);
     ancestors_.pop_back();
diff --git a/src/tir/schedule/concrete_schedule.cc b/src/tir/schedule/concrete_schedule.cc
index 6f7e682d6c7a..b33333177816 100644
--- a/src/tir/schedule/concrete_schedule.cc
+++ b/src/tir/schedule/concrete_schedule.cc
@@ -26,7 +26,7 @@ namespace tir {
 Schedule Schedule::Concrete(IRModule mod, support::LinearCongruentialEngine::TRandState seed,
                             int debug_mask, ScheduleErrorRenderLevel error_render_level,
                             bool enable_check) {
-  ObjectPtr n = make_object();
+  ObjectPtr n = ffi::make_object();
   n->state_ = ScheduleState(mod, debug_mask, enable_check);
   n->error_render_level_ = error_render_level;
   n->symbol_table_ = {};
@@ -56,7 +56,7 @@ class ScheduleCopier {
                        TSymbolTable* new_symbol_table) {
     const ScheduleState& src_state = self->state_;
     ScheduleCopier copier(src_state);
-    ObjectPtr n = make_object();
+    ObjectPtr n = ffi::make_object();
     n->mod = src_state->mod;
     n->block_info = copier.Copy(src_state->block_info);
     n->stmt2ref = copier.Copy(src_state->stmt2ref);
@@ -98,9 +98,9 @@ class ScheduleCopier {
     return old2new_[sref] = StmtSRef(nullptr, nullptr, -1);
   }
-  /*! \brief Copy Array */
-  Array Copy(const Array& list) {
-    Array result;
+  /*! \brief Copy ffi::Array */
+  ffi::Array Copy(const ffi::Array& list) {
+    ffi::Array result;
     result.reserve(list.size());
     for (const StmtSRef& elem : list) {
       result.push_back(Copy(elem));
@@ -108,9 +108,9 @@ class ScheduleCopier {
     return result;
   }
-  /*! \brief Copy Array */
-  Array Copy(const Array& list) {
-    Array result;
+  /*! \brief Copy ffi::Array */
+  ffi::Array Copy(const ffi::Array& list) {
+    ffi::Array result;
     result.reserve(list.size());
     for (const Dependency& elem : list) {
       result.push_back(Dependency(Copy(elem->src), Copy(elem->dst), elem->kind));
@@ -118,9 +118,9 @@ class ScheduleCopier {
     return result;
   }
-  /*! \brief Copy SMap> */
-  SMap> Copy(const SMap>& map) {
-    SMap> result;
+  /*! \brief Copy SMap> */
+  SMap> Copy(const SMap>& map) {
+    SMap> result;
     result.reserve(map.size());
     for (const auto& kv : map) {
       result[Copy(kv.first)] = Copy(kv.second);
@@ -128,9 +128,9 @@ class ScheduleCopier {
     return result;
   }
-  /*! \brief Copy SMap> */
-  SMap> Copy(const SMap>& map) {
-    SMap> result;
+  /*! \brief Copy SMap> */
+  SMap> Copy(const SMap>& map) {
+    SMap> result;
     result.reserve(map.size());
     for (const auto& kv : map) {
       result[kv.first] = Copy(kv.second);
@@ -145,7 +145,7 @@ class ScheduleCopier {
       const StmtSRef& old_sref = kv.first;
       const BlockInfo& old_info = kv.second;
       BlockInfo new_info = old_info;
-      ObjectPtr scope = make_object();
+      ObjectPtr scope = ffi::make_object();
       scope->src2deps = Copy(old_info.scope->src2deps);
       scope->dst2deps = Copy(old_info.scope->dst2deps);
       scope->buffer_writers = Copy(old_info.scope->buffer_writers);
@@ -184,7 +184,7 @@ class ScheduleCopier {
   std::unordered_map old2new_;
 };
-void ConcreteScheduleNode::WorkOn(const String& func_name) {
+void ConcreteScheduleNode::WorkOn(const ffi::String& func_name) {
   this->func_working_on_ = this->state_->mod->GetGlobalVar(func_name);
 }
@@ -194,7 +194,7 @@ void ConcreteScheduleNode::Copy(ScheduleState* new_state, TSymbolTable* new_symb
 }
 Schedule ConcreteScheduleNode::Copy() {
-  ObjectPtr n = make_object();
+  ObjectPtr n = ffi::make_object();
   n->func_working_on_ = this->func_working_on_;
   n->error_render_level_ = this->error_render_level_;
   ConcreteScheduleNode::Copy(&n->state_, &n->symbol_table_);
@@ -233,18 +233,18 @@ support::LinearCongruentialEngine::TRandState ConcreteScheduleNode::ForkSeed() {
   return support::LinearCongruentialEngine(&rand_state_).ForkSeed();
 }
-ExprRV ConcreteScheduleNode::SampleCategorical(const Array& candidates,
-                                               const Array& probs,
-                                               Optional decision) {
+ExprRV ConcreteScheduleNode::SampleCategorical(const ffi::Array& candidates,
+                                               const ffi::Array& probs,
+                                               ffi::Optional decision) {
   TVM_TIR_SCHEDULE_BEGIN();
   return CreateRV(tir::SampleCategorical(&this->rand_state_, candidates, probs, &decision));
   TVM_TIR_SCHEDULE_END("sample-categorical", this->error_render_level_);
   throw;
 }
-Array ConcreteScheduleNode::SamplePerfectTile(const LoopRV& loop_rv, int n,
-                                              int max_innermost_factor,
-                                              Optional> decision) {
+ffi::Array ConcreteScheduleNode::SamplePerfectTile(
+    const LoopRV& loop_rv, int n, int max_innermost_factor,
+    ffi::Optional> decision) {
   TVM_TIR_SCHEDULE_BEGIN();
   // use None RV object to denotes auto-infer tile factors.
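
Editor's aside: the defaulted `decision` parameters in these sampling signatures follow one pattern throughout. A hedged sketch of that pattern with a hypothetical helper (the optional/string header paths are assumptions; `has_value`/`value` follow the usage visible in this patch):

    #include <tvm/ffi/optional.h>  // assumed header path
    #include <tvm/ffi/string.h>    // assumed header path

    // Hypothetical helper: absent by default, testable with has_value(),
    // unwrapped with value(), mirroring the decision parameters above.
    tvm::ffi::String DescribeDecision(
        tvm::ffi::Optional<tvm::ffi::String> decision = std::nullopt) {
      if (!decision.has_value()) {
        return tvm::ffi::String("auto-inferred");
      }
      return decision.value();
    }
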
   return CreateRV(tir::SamplePerfectTile(&this->rand_state_, this->GetSRef(loop_rv), n,
@@ -254,9 +254,9 @@ Array ConcreteScheduleNode::SamplePerfectTile(const LoopRV& loop_rv, int
   throw;
 }
-Array ConcreteScheduleNode::SamplePartitionedTile(const LoopRV& loop_rv, int n,
-                                                  int partition_pos, int innerpart_factor,
-                                                  Optional> decision) {
+ffi::Array ConcreteScheduleNode::SamplePartitionedTile(
+    const LoopRV& loop_rv, int n, int partition_pos, int innerpart_factor,
+    ffi::Optional> decision) {
   TVM_TIR_SCHEDULE_BEGIN();
   return CreateRV(tir::SamplePartitionedTile(&this->rand_state_, this->GetSRef(loop_rv), n,
                                              partition_pos, innerpart_factor, &decision));
@@ -265,7 +265,7 @@ Array ConcreteScheduleNode::SamplePartitionedTile(const LoopRV& loop_rv,
 }
 LoopRV ConcreteScheduleNode::SampleComputeLocation(const BlockRV& block_rv,
-                                                   Optional decision) {
+                                                   ffi::Optional decision) {
   TVM_TIR_SCHEDULE_BEGIN();
   return CreateRV(
       tir::SampleComputeLocation(state_, &this->rand_state_, this->GetSRef(block_rv), &decision));
@@ -275,22 +275,25 @@ LoopRV ConcreteScheduleNode::SampleComputeLocation(const BlockRV& block_rv,
 /******** Schedule: Get blocks & loops ********/
-BlockRV ConcreteScheduleNode::GetBlock(const String& name, const Optional& func_name) {
+BlockRV ConcreteScheduleNode::GetBlock(const ffi::String& name,
+                                       const ffi::Optional& func_name) {
   class NotSingleResult : public ScheduleError {
    public:
-    explicit NotSingleResult(String name, IRModule mod, const Array& blocks)
+    explicit NotSingleResult(ffi::String name, IRModule mod, const ffi::Array& blocks)
        : name_(name), mod_(mod), blocks_{} {
      blocks_.reserve(blocks.size());
      for (const StmtSRef& block_sref : blocks) {
        const BlockNode* block = TVM_SREF_TO_BLOCK(block_sref);
-        blocks_.push_back(GetRef(block));
+        blocks_.push_back(ffi::GetRef(block));
      }
    }
    IRModule mod() const final { return mod_; }
-    Array LocationsOfInterest() const final { return {blocks_.begin(), blocks_.end()}; }
+    ffi::Array LocationsOfInterest() const final {
+      return {blocks_.begin(), blocks_.end()};
+    }
-    String DetailRenderTemplate() const final {
+    ffi::String DetailRenderTemplate() const final {
      if (blocks_.empty()) {
        return "Cannot find a block with the name: " + name_;
      } else {
@@ -298,7 +301,7 @@ BlockRV ConcreteScheduleNode::GetBlock(const String& name, const Optional
-    Array blocks_;
+    ffi::Array blocks_;
   };
   GlobalVar gv = NullValue();
   if (func_name.has_value()) {
@@ -320,7 +323,7 @@ BlockRV ConcreteScheduleNode::GetBlock(const String& name, const Optional
-  Array blocks = tir::GetBlocks(this->state_, name, gv);
+  ffi::Array blocks = tir::GetBlocks(this->state_, name, gv);
   if (blocks.size() != 1) {
     TVM_TIR_SCHEDULE_BEGIN();
     throw NotSingleResult(name, this->state_->mod, blocks);
@@ -329,12 +332,12 @@ BlockRV ConcreteScheduleNode::GetBlock(const String& name, const Optional
   return CreateRV(blocks[0]);
 }
-Array ConcreteScheduleNode::GetLoops(const BlockRV& block_rv) {
+ffi::Array ConcreteScheduleNode::GetLoops(const BlockRV& block_rv) {
   return CreateRV(tir::GetLoops(this->GetSRef(block_rv)));
 }
-Array ConcreteScheduleNode::GetChildBlocks(const BlockRV& block_rv) {
-  Array result;
+ffi::Array ConcreteScheduleNode::GetChildBlocks(const BlockRV& block_rv) {
+  ffi::Array result;
   TVM_TIR_SCHEDULE_BEGIN();
   result = CreateRV(tir::GetChildBlocks(state_, this->GetSRef(block_rv)));
   TVM_TIR_SCHEDULE_END("get-child-blocks", this->error_render_level_);
@@ -342,8 +345,8 @@ Array ConcreteScheduleNode::GetChildBlocks(const BlockRV& block_rv) {
   return result;
 }
-Array ConcreteScheduleNode::GetChildBlocks(const LoopRV& loop_rv) {
-  Array result;
+ffi::Array ConcreteScheduleNode::GetChildBlocks(const LoopRV& loop_rv) {
+  ffi::Array result;
   TVM_TIR_SCHEDULE_BEGIN();
   result = CreateRV(tir::GetChildBlocks(state_, this->GetSRef(loop_rv)));
   TVM_TIR_SCHEDULE_END("get-child-blocks", this->error_render_level_);
@@ -351,21 +354,21 @@ Array ConcreteScheduleNode::GetChildBlocks(const LoopRV& loop_rv) {
   return result;
 }
-Array ConcreteScheduleNode::GetProducers(const BlockRV& block_rv) {
+ffi::Array ConcreteScheduleNode::GetProducers(const BlockRV& block_rv) {
   TVM_TIR_SCHEDULE_BEGIN();
   return CreateRV(tir::GetProducers(state_, this->GetSRef(block_rv)));
   TVM_TIR_SCHEDULE_END("get-producers", this->error_render_level_);
   throw;
 }
-Array ConcreteScheduleNode::GetConsumers(const BlockRV& block_rv) {
+ffi::Array ConcreteScheduleNode::GetConsumers(const BlockRV& block_rv) {
   TVM_TIR_SCHEDULE_BEGIN();
   return CreateRV(tir::GetConsumers(state_, this->GetSRef(block_rv)));
   TVM_TIR_SCHEDULE_END("get-consumers", this->error_render_level_);
   throw;
 }
-Array ConcreteScheduleNode::GetOutputBlocks(const BlockRV& scope_block_rv) {
+ffi::Array ConcreteScheduleNode::GetOutputBlocks(const BlockRV& scope_block_rv) {
   TVM_TIR_SCHEDULE_BEGIN();
   return CreateRV(tir::GetOutputBlocks(state_, this->GetSRef(scope_block_rv)));
   TVM_TIR_SCHEDULE_END("get-output-blocks", this->error_render_level_);
@@ -374,9 +377,9 @@ Array ConcreteScheduleNode::GetOutputBlocks(const BlockRV& scope_block_
 /******** Schedule: Transform loops ********/
-LoopRV ConcreteScheduleNode::Merge(const Array& loop_rvs) {
+LoopRV ConcreteScheduleNode::Merge(const ffi::Array& loop_rvs) {
   CHECK(loop_rvs.size() > 1) << "ValueError: 'merge' requires at least 2 loop(s)";
-  Array loop_srefs = this->GetSRefs(loop_rvs);
+  ffi::Array loop_srefs = this->GetSRefs(loop_rvs);
   StmtSRef result{nullptr};
   TVM_TIR_SCHEDULE_BEGIN();
   result = tir::Merge(state_, loop_srefs);
@@ -385,9 +388,9 @@ LoopRV ConcreteScheduleNode::Merge(const Array& loop_rvs) {
   return CreateRV(result);
 }
-LoopRV ConcreteScheduleNode::Fuse(const Array& loop_rvs, bool preserve_unit_iters) {
+LoopRV ConcreteScheduleNode::Fuse(const ffi::Array& loop_rvs, bool preserve_unit_iters) {
   CHECK(!loop_rvs.empty()) << "ValueError: 'fuse' requires at least 1 loop(s)";
-  Array loop_srefs = this->GetSRefs(loop_rvs);
+  ffi::Array loop_srefs = this->GetSRefs(loop_rvs);
   StmtSRef result{nullptr};
   TVM_TIR_SCHEDULE_BEGIN();
   result = tir::Fuse(state_, loop_srefs, preserve_unit_iters);
@@ -400,16 +403,16 @@ class NotSingleInferFactorError : public ScheduleError {
  public:
  explicit NotSingleInferFactorError(IRModule mod) : mod_(mod) {}
-  String FastErrorString() const final {
+  ffi::String FastErrorString() const final {
    return "ScheduleError: only one factor can be specified as -1 or none";
  }
-  String DetailRenderTemplate() const final {
+  ffi::String DetailRenderTemplate() const final {
    return "Only one factor can be specified as -1 or none";
  }
  IRModule mod() const final { return mod_; }
-  Array LocationsOfInterest() const final { return {}; }
+  ffi::Array LocationsOfInterest() const final { return {}; }
  IRModule mod_;
 };
@@ -419,7 +422,7 @@ class WrongFactorError : public ScheduleError {
  explicit WrongFactorError(IRModule mod, For loop, bool product)
      : mod_(mod), loop_(std::move(loop)), product_(product) {}
-  String FastErrorString() const final {
+  ffi::String FastErrorString() const final {
    if (product_)
      return "ScheduleError: The product of factors is not larger than or equal to the extent of "
             "loop";
@@ -427,7 +430,7 @@ class WrongFactorError : public ScheduleError {
      return "ScheduleError: The sum of factors is larger than or equal to the extent of loop";
  }
-  String DetailRenderTemplate() const final {
+  ffi::String DetailRenderTemplate() const final {
    if (product_)
      return "The product of factors is not larger than or equal to the extent of loop {0}";
    else
@@ -435,7 +438,7 @@ class WrongFactorError : public ScheduleError {
  }
  IRModule mod() const final { return mod_; }
-  Array LocationsOfInterest() const final { return {loop_}; }
+  ffi::Array LocationsOfInterest() const final { return {loop_}; }
  IRModule mod_;
  For loop_;
@@ -447,18 +450,18 @@ class NonPositiveFactorError : public ScheduleError {
  explicit NonPositiveFactorError(IRModule mod, int64_t factor, size_t idx)
      : mod_(std::move(mod)), factor_(factor), idx_(idx) {}
-  String FastErrorString() const final {
+  ffi::String FastErrorString() const final {
    return "ScheduleError: All the constant factors are required to be positive. However, some "
           "constant input factor is zero or negative.";
  }
-  String DetailRenderTemplate() const final {
+  ffi::String DetailRenderTemplate() const final {
    std::ostringstream os;
    os << "All the constant factors are required to be positive. However, the factor at position "
       << idx_ << " is " << factor_;
    return os.str();
  }
  IRModule mod() const final { return mod_; }
-  Array LocationsOfInterest() const final { return {}; }
+  ffi::Array LocationsOfInterest() const final { return {}; }
 private:
  IRModule mod_;
@@ -466,17 +469,17 @@ class NonPositiveFactorError : public ScheduleError {
  size_t idx_;
 };
-Array ConcreteScheduleNode::Split(const LoopRV& loop_rv,
-                                  const Array>& factor_rvs,
-                                  bool preserve_unit_iters, bool disable_predication) {
+ffi::Array ConcreteScheduleNode::Split(const LoopRV& loop_rv,
+                                       const ffi::Array>& factor_rvs,
+                                       bool preserve_unit_iters, bool disable_predication) {
   // Prepare for the splitting
   StmtSRef loop_sref = this->GetSRef(loop_rv);
   const ForNode* loop = TVM_SREF_TO_FOR(loop_sref);
-  Array factors;
+  ffi::Array factors;
   factors.reserve(factor_rvs.size());
   int infer_index = -1;
   PrimExpr tot_length = 1;
-  Array results;
+  ffi::Array results;
   TVM_TIR_SCHEDULE_BEGIN();
   // infer factor if needed and check validity of factors
   for (size_t i = 0; i < factor_rvs.size(); i++) {
@@ -502,7 +505,7 @@ Array ConcreteScheduleNode::Split(const LoopRV& loop_rv,
     factors.Set(infer_index,
                 this->analyzer_->Simplify(floordiv(loop->extent + tot_length - 1, tot_length)));
   } else if (!this->analyzer_->CanProve(tot_length >= loop->extent)) {
-    throw WrongFactorError(state_->mod, GetRef(loop), true);
+    throw WrongFactorError(state_->mod, ffi::GetRef(loop), true);
   }
   results = tir::Split(state_, loop_sref, factors, preserve_unit_iters, disable_predication);
   TVM_TIR_SCHEDULE_END("split", this->error_render_level_);
@@ -510,24 +513,24 @@ Array ConcreteScheduleNode::Split(const LoopRV& loop_rv,
   return CreateRV(results);
 }
-Array ConcreteScheduleNode::LoopPartition(const LoopRV& loop_rv,
-                                          const Array>& factor_rvs,
-                                          bool preserve_unit_iters) {
+ffi::Array ConcreteScheduleNode::LoopPartition(
+    const LoopRV& loop_rv, const ffi::Array>& factor_rvs,
+    bool preserve_unit_iters) {
   class SymbolicShapeError : public ScheduleError {
    public:
    explicit SymbolicShapeError(IRModule mod, For loop) : mod_(mod), loop_(std::move(loop)) {}
-    String FastErrorString() const final {
+    ffi::String FastErrorString() const final {
      return "ScheduleError: The min and extent values of the loop are required to be known at "
             "compile time. However, dynamic shape has been detected.";
    }
-    String DetailRenderTemplate() const final {
+    ffi::String DetailRenderTemplate() const final {
      return "Detected dynamic shape in either min or extent of a loop {0}";
    }
    IRModule mod() const final { return mod_; }
-    Array LocationsOfInterest() const final { return {loop_}; }
+    ffi::Array LocationsOfInterest() const final { return {loop_}; }
    IRModule mod_;
    For loop_;
@@ -536,14 +539,14 @@ Array ConcreteScheduleNode::LoopPartition(const LoopRV& loop_rv,
   // Prepare for the loop_partitioning
   StmtSRef loop_sref = this->GetSRef(loop_rv);
   const ForNode* loop = TVM_SREF_TO_FOR(loop_sref);
-  Array factors;
+  ffi::Array factors;
   factors.reserve(factor_rvs.size());
   int infer_index = -1;
   PrimExpr tot_length = 0;
-  Array results;
+  ffi::Array results;
   TVM_TIR_SCHEDULE_BEGIN();
   if (!is_const_number(loop->min) || !is_const_number(loop->extent)) {
-    throw SymbolicShapeError(state_->mod, GetRef(loop));
+    throw SymbolicShapeError(state_->mod, ffi::GetRef(loop));
   }
   // infer factor if needed and check validity of factors
   for (size_t i = 0; i < factor_rvs.size(); i++) {
@@ -566,7 +569,7 @@ Array ConcreteScheduleNode::LoopPartition(const LoopRV& loop_rv,
     }
   }
   if (this->analyzer_->CanProve(tot_length >= loop->extent)) {
-    throw WrongFactorError(state_->mod, GetRef(loop), false);
+    throw WrongFactorError(state_->mod, ffi::GetRef(loop), false);
   }
   if (infer_index != -1) {
     // if there is a 'None' in the factor list, 'None' becomes the difference between the extent and
@@ -585,7 +588,7 @@ Array ConcreteScheduleNode::LoopPartition(const LoopRV& loop_rv,
   return CreateRV(results);
 }
-void ConcreteScheduleNode::Reorder(const Array& ordered_loop_rvs) {
+void ConcreteScheduleNode::Reorder(const ffi::Array& ordered_loop_rvs) {
   TVM_TIR_SCHEDULE_BEGIN();
   tir::Reorder(state_, GetSRefs(ordered_loop_rvs));
   TVM_TIR_SCHEDULE_END("reorder", this->error_render_level_);
@@ -593,7 +596,7 @@ void ConcreteScheduleNode::Reorder(const Array& ordered_loop_rvs) {
 }
 void ConcreteScheduleNode::ReorderBlockIterVar(const BlockRV& block_rv,
-                                               const Array new_order) {
+                                               const ffi::Array new_order) {
   TVM_TIR_SCHEDULE_BEGIN();
   tir::ReorderBlockIterVar(state_, GetSRef(block_rv), new_order);
   TVM_TIR_SCHEDULE_END("reorder_block_iter_var", this->error_render_level_);
@@ -634,7 +637,7 @@ void ConcreteScheduleNode::Vectorize(const LoopRV& loop_rv) {
   TVM_TIR_SCHEDULE_END("vectorize", this->error_render_level_);
 }
-void ConcreteScheduleNode::Bind(const LoopRV& loop_rv, const String& thread_axis) {
+void ConcreteScheduleNode::Bind(const LoopRV& loop_rv, const ffi::String& thread_axis) {
   if (thread_axis == "vthread") {
     LOG(WARNING) << "`vthread` is legacy behavior and is going to be deprecated. Please use "
                     "`vthread.x`, `vthread.y` and `vthread.z` instead";
@@ -655,11 +658,11 @@ void ConcreteScheduleNode::Unroll(const LoopRV& loop_rv) {
 /******** Schedule: Insert cache stages ********/
 BlockRV ConcreteScheduleNode::CacheRead(const BlockRV& block_rv, int read_buffer_index,
-                                        const String& storage_scope,
-                                        const Array consumer_blocks) {
+                                        const ffi::String& storage_scope,
+                                        const ffi::Array consumer_blocks) {
   StmtSRef result{nullptr};
   // Create a new array of SRefs from the consumer block list.
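
Editor's aside: the `TSymbolTable` alias changed below to `ffi::Map`, and the map API is otherwise untouched. A hedged sketch of that container under the new spelling (header paths are assumptions; `Set`/`Get` follow the usage visible elsewhere in this patch):

    #include <tvm/ffi/container/map.h>  // assumed header path
    #include <tvm/ffi/string.h>         // assumed header path

    // Hypothetical lookup in the shape of the schedule's symbol table:
    // Set writes through copy-on-write; Get returns an optional-like handle.
    void SymbolTableDemo() {
      tvm::ffi::Map<tvm::ffi::String, tvm::ffi::String> table;
      table.Set("v0", "loop_outer");
      if (auto entry = table.Get("v0")) {
        tvm::ffi::String bound = entry.value();  // present on a hit
        (void)bound;
      }
    }
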
-  Array consumer_block_refs = {};
+  ffi::Array consumer_block_refs = {};
   for (BlockRV block : consumer_blocks) {
     consumer_block_refs.push_back(this->GetSRef(block));
   }
@@ -672,11 +675,11 @@ BlockRV ConcreteScheduleNode::CacheRead(const BlockRV& block_rv, int read_buffer
 }
 BlockRV ConcreteScheduleNode::CacheWrite(const BlockRV& block_rv, int write_buffer_index,
-                                         const String& storage_scope,
-                                         const Array consumer_blocks) {
+                                         const ffi::String& storage_scope,
+                                         const ffi::Array consumer_blocks) {
   StmtSRef result{nullptr};
   // Create a new array of SRefs from the consumer block list.
-  Array consumer_block_refs = {};
+  ffi::Array consumer_block_refs = {};
   for (BlockRV block : consumer_blocks) {
     consumer_block_refs.push_back(this->GetSRef(block));
   }
@@ -689,7 +692,7 @@ BlockRV ConcreteScheduleNode::CacheWrite(const BlockRV& block_rv, int write_buff
 }
 BlockRV ConcreteScheduleNode::ReindexCacheRead(const BlockRV& block_rv, int read_buffer_index,
-                                               const String& storage_scope,
+                                               const ffi::String& storage_scope,
                                                const IndexMap& index_map) {
   StmtSRef result{nullptr};
   TVM_TIR_SCHEDULE_BEGIN();
@@ -701,7 +704,7 @@ BlockRV ConcreteScheduleNode::ReindexCacheRead(const BlockRV& block_rv, int read
 }
 BlockRV ConcreteScheduleNode::ReindexCacheWrite(const BlockRV& block_rv, int write_buffer_index,
-                                                const String& storage_scope,
+                                                const ffi::String& storage_scope,
                                                 const IndexMap& index_map) {
   StmtSRef result{nullptr};
   TVM_TIR_SCHEDULE_BEGIN();
@@ -712,27 +715,29 @@ BlockRV ConcreteScheduleNode::ReindexCacheWrite(const BlockRV& block_rv, int wri
   return CreateRV(result);
 }
-Array ConcreteScheduleNode::CacheInplace(const BlockRV& block_rv, int write_buffer_index,
-                                         const String& storage_scope) {
-  Array results;
+ffi::Array ConcreteScheduleNode::CacheInplace(const BlockRV& block_rv,
+                                              int write_buffer_index,
+                                              const ffi::String& storage_scope) {
+  ffi::Array results;
   TVM_TIR_SCHEDULE_BEGIN();
   results = tir::CacheInplace(state_, this->GetSRef(block_rv), write_buffer_index, storage_scope);
   TVM_TIR_SCHEDULE_END("cache-buffer", this->error_render_level_);
   this->state_->DebugVerify();
-  Array return_blocks;
+  ffi::Array return_blocks;
   return_blocks.push_back(CreateRV(results[0]));
   return_blocks.push_back(CreateRV(results[1]));
   return return_blocks;
 }
-Array ConcreteScheduleNode::CacheIndex(const BlockRV& block_rv,
-                                       const String& storage_scope, int cse_thresh) {
-  Array result;
+ffi::Array ConcreteScheduleNode::CacheIndex(const BlockRV& block_rv,
+                                            const ffi::String& storage_scope,
+                                            int cse_thresh) {
+  ffi::Array result;
   TVM_TIR_SCHEDULE_BEGIN();
   result = tir::CacheIndex(state_, this->GetSRef(block_rv), storage_scope, cse_thresh);
   TVM_TIR_SCHEDULE_END("cache-index", this->error_render_level_);
   this->state_->DebugVerify();
-  Array return_blocks;
+  ffi::Array return_blocks;
   for (const StmtSRef& blockrv : result) {
     return_blocks.push_back(CreateRV(blockrv));
   }
@@ -752,7 +757,7 @@ BlockRV ConcreteScheduleNode::ReIndex(const BlockRV& block_rv, int buffer_index,
 /******** Schedule: Data movement ********/
 BlockRV ConcreteScheduleNode::ReadAt(const LoopRV& loop_rv, const BlockRV& block_rv,
-                                     int read_buffer_index, const String& storage_scope) {
+                                     int read_buffer_index, const ffi::String& storage_scope) {
   StmtSRef result{nullptr};
   TVM_TIR_SCHEDULE_BEGIN();
   result = tir::ReadAt(state_, this->GetSRef(loop_rv), this->GetSRef(block_rv), read_buffer_index,
@@ -763,7 +768,7 @@ BlockRV ConcreteScheduleNode::ReadAt(const LoopRV& loop_rv, const BlockRV& block
 }
 BlockRV ConcreteScheduleNode::WriteAt(const LoopRV& loop_rv, const BlockRV& block_rv,
-                                      int write_buffer_index, const String& storage_scope) {
+                                      int write_buffer_index, const ffi::String& storage_scope) {
   StmtSRef result{nullptr};
   TVM_TIR_SCHEDULE_BEGIN();
   result = tir::WriteAt(state_, this->GetSRef(loop_rv), this->GetSRef(block_rv), write_buffer_index,
@@ -838,7 +843,7 @@ void ConcreteScheduleNode::StorageAlign(const BlockRV& block_rv, int buffer_inde
 }
 void ConcreteScheduleNode::SetScope(const BlockRV& block_rv, int buffer_index,
-                                    const String& storage_scope) {
+                                    const ffi::String& storage_scope) {
   TVM_TIR_SCHEDULE_BEGIN();
   tir::SetScope(state_, this->GetSRef(block_rv), buffer_index, storage_scope);
   TVM_TIR_SCHEDULE_END("set-scope", this->error_render_level_);
@@ -846,7 +851,7 @@ void ConcreteScheduleNode::SetScope(const BlockRV& block_rv, int buffer_index,
 }
 void ConcreteScheduleNode::UnsafeSetDType(const BlockRV& block_rv, int buffer_index,
-                                          const String& dtype) {
+                                          const ffi::String& dtype) {
   TVM_TIR_SCHEDULE_BEGIN();
   tir::UnsafeSetDType(state_, this->GetSRef(block_rv), buffer_index, dtype);
   TVM_TIR_SCHEDULE_END("set-dtype", this->error_render_level_);
@@ -883,7 +888,8 @@ BlockRV ConcreteScheduleNode::Blockize(const LoopRV& loop_rv, bool preserve_unit
   return CreateRV(result);
 }
-BlockRV ConcreteScheduleNode::Blockize(const Array& blocks, bool preserve_unit_iters) {
+BlockRV ConcreteScheduleNode::Blockize(const ffi::Array& blocks,
+                                       bool preserve_unit_iters) {
   StmtSRef result{nullptr};
   TVM_TIR_SCHEDULE_BEGIN();
   result = tir::Blockize(state_, this->GetSRefs(blocks), preserve_unit_iters);
@@ -892,7 +898,7 @@ BlockRV ConcreteScheduleNode::Blockize(const Array& blocks, bool preser
   return CreateRV(result);
 }
-void ConcreteScheduleNode::Tensorize(const LoopRV& loop_rv, const String& intrin,
+void ConcreteScheduleNode::Tensorize(const LoopRV& loop_rv, const ffi::String& intrin,
                                      bool preserve_unit_iters) {
   TVM_TIR_SCHEDULE_BEGIN();
   tir::Tensorize(state_, this->GetSRef(loop_rv), tir::TensorIntrin::Get(intrin).value(),
@@ -901,7 +907,7 @@ void ConcreteScheduleNode::Tensorize(const LoopRV& loop_rv, const String& intrin
   TVM_TIR_SCHEDULE_END("tensorize", this->error_render_level_);
 }
-void ConcreteScheduleNode::Tensorize(const BlockRV& block_rv, const String& intrin,
+void ConcreteScheduleNode::Tensorize(const BlockRV& block_rv, const ffi::String& intrin,
                                      bool preserve_unit_iters) {
   TVM_TIR_SCHEDULE_BEGIN();
   tir::Tensorize(state_, this->GetSRef(block_rv), tir::TensorIntrin::Get(intrin).value(),
@@ -929,8 +935,8 @@ Any ConcreteScheduleNode::CheckAndGetAnnotationValue(const ffi::Any& ann_val) {
   if (const auto* expr = ann_val.as()) {
     ICHECK(!expr->IsInstance())
-        << "TypeError: String is expected, but gets StringImm";
-    auto res_expr = this->Get(GetRef(expr));
+        << "TypeError: ffi::String is expected, but got StringImm";
+    auto res_expr = this->Get(ffi::GetRef(expr));
     // prefer to return int/float literals for annotations
     if (auto opt_intimm = res_expr.as()) {
       return (*std::move(opt_intimm))->value;
@@ -941,7 +947,7 @@ Any ConcreteScheduleNode::CheckAndGetAnnotationValue(const ffi::Any& ann_val) {
     return res_expr;
   }
   if (const auto* arr = ann_val.as()) {
-    Array result;
+    ffi::Array result;
     result.reserve(arr->size());
     for (size_t i = 0; i < arr->size(); i++) {
       result.push_back(CheckAndGetAnnotationValue(arr->at(i)));
@@ -949,7 +955,7 @@ Any ConcreteScheduleNode::CheckAndGetAnnotationValue(const ffi::Any& ann_val) {
     return result;
   }
   if (const auto* dict = ann_val.as()) {
-    Map result;
+    ffi::Map result;
     for (auto it = dict->begin(); it != dict->end(); ++it) {
       const auto& key = it->first;
       auto value = CheckAndGetAnnotationValue(it->second);
@@ -958,7 +964,7 @@ Any ConcreteScheduleNode::CheckAndGetAnnotationValue(const ffi::Any& ann_val) {
       } else if (auto opt_str = key.try_cast()) {
         result.Set(opt_str.value(), value);
       } else {
-        LOG(FATAL) << "TypeError: annotation dict key expect to be String or StringImm";
+        LOG(FATAL) << "TypeError: annotation dict key is expected to be ffi::String or StringImm";
       }
     }
     return result;
@@ -969,7 +975,7 @@ Any ConcreteScheduleNode::CheckAndGetAnnotationValue(const ffi::Any& ann_val) {
   TVM_FFI_UNREACHABLE();
 }
-void ConcreteScheduleNode::Annotate(const LoopRV& loop_rv, const String& ann_key,
+void ConcreteScheduleNode::Annotate(const LoopRV& loop_rv, const ffi::String& ann_key,
                                     const Any& ann_val) {
   TVM_TIR_SCHEDULE_BEGIN();
   tir::Annotate(state_, this->GetSRef(loop_rv), ann_key, this->CheckAndGetAnnotationValue(ann_val));
@@ -977,14 +983,14 @@ void ConcreteScheduleNode::Annotate(const LoopRV& loop_rv, const String& ann_key
   TVM_TIR_SCHEDULE_END("annotate", this->error_render_level_);
 }
-void ConcreteScheduleNode::Unannotate(const LoopRV& loop_rv, const String& ann_key) {
+void ConcreteScheduleNode::Unannotate(const LoopRV& loop_rv, const ffi::String& ann_key) {
   TVM_TIR_SCHEDULE_BEGIN();
   tir::Unannotate(state_, this->GetSRef(loop_rv), ann_key);
   this->state_->DebugVerify();
   TVM_TIR_SCHEDULE_END("unannotate", this->error_render_level_);
 }
-void ConcreteScheduleNode::Annotate(const BlockRV& block_rv, const String& ann_key,
+void ConcreteScheduleNode::Annotate(const BlockRV& block_rv, const ffi::String& ann_key,
                                     const Any& ann_val) {
   TVM_TIR_SCHEDULE_BEGIN();
   tir::Annotate(state_, this->GetSRef(block_rv), ann_key,
@@ -993,7 +999,7 @@ void ConcreteScheduleNode::Annotate(const BlockRV& block_rv, const String& ann_k
   TVM_TIR_SCHEDULE_END("annotate", this->error_render_level_);
 }
-void ConcreteScheduleNode::Unannotate(const BlockRV& block_rv, const String& ann_key) {
+void ConcreteScheduleNode::Unannotate(const BlockRV& block_rv, const ffi::String& ann_key) {
   TVM_TIR_SCHEDULE_BEGIN();
   tir::Unannotate(state_, this->GetSRef(block_rv), ann_key);
   this->state_->DebugVerify();
@@ -1004,10 +1010,10 @@ void ConcreteScheduleNode::Unannotate(const BlockRV& block_rv, const String& ann
 void ConcreteScheduleNode::TransformLayout(const BlockRV& block_rv, int buffer_index,
                                            BufferIndexType buffer_index_type,
                                            const IndexMap& index_map,
-                                           const Optional& pad_value,
+                                           const ffi::Optional& pad_value,
                                            bool assume_injective_transform) {
   TVM_TIR_SCHEDULE_BEGIN();
-  auto f_subst = [&](const Var& var) -> Optional {
+  auto f_subst = [&](const Var& var) -> ffi::Optional {
     if (auto opt_expr = symbol_table_.Get(var)) {
       return Downcast(opt_expr.value());
     } else {
@@ -1031,7 +1037,7 @@ void ConcreteScheduleNode::TransformBlockLayout(const BlockRV& block_rv,
 void ConcreteScheduleNode::SetAxisSeparator(const BlockRV& block_rv, int buffer_index,
                                             BufferIndexType buffer_index_type,
-                                            const Array& axis_separators) {
+                                            const ffi::Array& axis_separators) {
   TVM_TIR_SCHEDULE_BEGIN();
   tir::SetAxisSeparator(state_, this->GetSRef(block_rv), buffer_index, buffer_index_type,
                         axis_separators);
@@ -1050,7 +1056,7 @@ BlockRV ConcreteScheduleNode::DecomposePadding(const BlockRV& block_rv, const Lo
   return CreateRV(result);
 }
-void ConcreteScheduleNode::PadEinsum(const BlockRV& block_rv, const Array& padding) {
+void ConcreteScheduleNode::PadEinsum(const BlockRV& block_rv, const ffi::Array& padding) {
   TVM_TIR_SCHEDULE_BEGIN();
   tir::PadEinsum(state_, this->GetSRef(block_rv), padding);
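
Editor's aside: `CheckAndGetAnnotationValue` above dispatches on the dynamic type held by an `ffi::Any`. A hedged sketch of that dispatch pattern (header paths are assumptions; `try_cast` on string keys is shown above, and its availability for integer payloads is an assumption here):

    #include <tvm/ffi/any.h>     // assumed header path
    #include <tvm/ffi/string.h>  // assumed header path

    // Hypothetical classifier mirroring the key handling above: try_cast
    // yields an optional-like value, engaged only when the types match.
    const char* KindOf(const tvm::ffi::Any& ann_val) {
      if (auto opt_str = ann_val.try_cast<tvm::ffi::String>()) {
        return "string";
      }
      if (auto opt_int = ann_val.try_cast<int64_t>()) {  // assumed overload
        return "int";
      }
      return "other";
    }
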
TVM_TIR_SCHEDULE_END("pad-einsum", this->error_render_level_); @@ -1068,8 +1074,9 @@ void ConcreteScheduleNode::RollingBuffer(const BlockRV& block_rv, int write_buff /******** Schedule: Misc ********/ -void ConcreteScheduleNode::UnsafeHideBufferAccess(const BlockRV& block_rv, const String& buf_type, - const Array& buf_index_array) { +void ConcreteScheduleNode::UnsafeHideBufferAccess(const BlockRV& block_rv, + const ffi::String& buf_type, + const ffi::Array& buf_index_array) { TVM_TIR_SCHEDULE_BEGIN(); tir::UnsafeHideBufferAccess(state_, this->GetSRef(block_rv), buf_type, buf_index_array); TVM_TIR_SCHEDULE_END("hide-buffer-access", this->error_render_level_); diff --git a/src/tir/schedule/concrete_schedule.h b/src/tir/schedule/concrete_schedule.h index 5f3f0c8b61f1..f19fb3143e8a 100644 --- a/src/tir/schedule/concrete_schedule.h +++ b/src/tir/schedule/concrete_schedule.h @@ -33,13 +33,13 @@ class ConcreteScheduleNode : public ScheduleNode { friend class ScheduleCopier; public: - using TSymbolTable = Map; + using TSymbolTable = ffi::Map; protected: /*! \brief The internal state of scheduling */ ScheduleState state_; /*! \brief The function to be worked on. */ - Optional func_working_on_; + ffi::Optional func_working_on_; /*! \brief The level of error rendering */ ScheduleErrorRenderLevel error_render_level_; /*! \brief A symbol table that maps random variables to concrete StmtSRef/Integers */ @@ -58,9 +58,9 @@ class ConcreteScheduleNode : public ScheduleNode { public: ScheduleState state() const final { return state_; } - Optional trace() const override { return std::nullopt; } - Optional func_working_on() const final { return func_working_on_; } - void WorkOn(const String& func_name) final; + ffi::Optional trace() const override { return std::nullopt; } + ffi::Optional func_working_on() const final { return func_working_on_; } + void WorkOn(const ffi::String& func_name) final; Schedule Copy() override; void Seed(support::LinearCongruentialEngine::TRandState seed) final; support::LinearCongruentialEngine::TRandState ForkSeed() final; @@ -73,8 +73,8 @@ class ConcreteScheduleNode : public ScheduleNode { inline StmtSRef GetSRef(const BlockRV& block_rv) const final; inline StmtSRef GetSRef(const LoopRV& loop_rv) const final; inline bool HasBlock(const BlockRV& block_rv) const final; - inline Array GetSRefs(const Array& rvs) const; - inline Array GetSRefs(const Array& rvs) const; + inline ffi::Array GetSRefs(const ffi::Array& rvs) const; + inline ffi::Array GetSRefs(const ffi::Array& rvs) const; void RemoveRV(const BlockRV& block_rv) final { RemoveFromSymbolTable(block_rv); } void RemoveRV(const LoopRV& loop_rv) final { RemoveFromSymbolTable(loop_rv); } void RemoveRV(const ExprRV& expr_rv) final { RemoveFromSymbolTable(expr_rv); } @@ -82,59 +82,63 @@ class ConcreteScheduleNode : public ScheduleNode { public: /******** Schedule: Sampling ********/ - ExprRV SampleCategorical(const Array& candidates, const Array& probs, - Optional decision = std::nullopt) override; - Array SamplePerfectTile(const LoopRV& loop_rv, int n, int max_innermost_factor, - Optional> decision = std::nullopt) override; - Array SamplePartitionedTile(const LoopRV& loop_rv, int n, int partition_pos, - int innerpart_factor, - Optional> decision = std::nullopt) override; + ExprRV SampleCategorical(const ffi::Array& candidates, const ffi::Array& probs, + ffi::Optional decision = std::nullopt) override; + ffi::Array SamplePerfectTile( + const LoopRV& loop_rv, int n, int max_innermost_factor, + ffi::Optional> decision = std::nullopt) 
override; + ffi::Array SamplePartitionedTile( + const LoopRV& loop_rv, int n, int partition_pos, int innerpart_factor, + ffi::Optional> decision = std::nullopt) override; LoopRV SampleComputeLocation(const BlockRV& block_rv, - Optional decision = std::nullopt) override; + ffi::Optional decision = std::nullopt) override; /******** Schedule: Get blocks & loops ********/ - BlockRV GetBlock(const String& name, const Optional& func_name) override; - Array GetLoops(const BlockRV& block_rv) override; - Array GetChildBlocks(const BlockRV& block_rv) override; - Array GetChildBlocks(const LoopRV& loop_rv) override; - Array GetProducers(const BlockRV& block_rv) override; - Array GetConsumers(const BlockRV& block_rv) override; - Array GetOutputBlocks(const BlockRV& scope_block_rv) override; + BlockRV GetBlock(const ffi::String& name, const ffi::Optional& func_name) override; + ffi::Array GetLoops(const BlockRV& block_rv) override; + ffi::Array GetChildBlocks(const BlockRV& block_rv) override; + ffi::Array GetChildBlocks(const LoopRV& loop_rv) override; + ffi::Array GetProducers(const BlockRV& block_rv) override; + ffi::Array GetConsumers(const BlockRV& block_rv) override; + ffi::Array GetOutputBlocks(const BlockRV& scope_block_rv) override; /******** Schedule: Transform loops ********/ - LoopRV Fuse(const Array& loop_rvs, bool preserve_unit_iters) override; - LoopRV Merge(const Array& loop_rvs) override; - Array Split(const LoopRV& loop_rv, const Array>& factors, - bool preserve_unit_iters, bool disable_predication) override; - Array LoopPartition(const LoopRV& loop_rv, const Array>& factors, - bool preserve_unit_iters) override; - void Reorder(const Array& ordered_loop_rvs) override; - void ReorderBlockIterVar(const BlockRV& block_rv, const Array new_order) override; + LoopRV Fuse(const ffi::Array& loop_rvs, bool preserve_unit_iters) override; + LoopRV Merge(const ffi::Array& loop_rvs) override; + ffi::Array Split(const LoopRV& loop_rv, const ffi::Array>& factors, + bool preserve_unit_iters, bool disable_predication) override; + ffi::Array LoopPartition(const LoopRV& loop_rv, + const ffi::Array>& factors, + bool preserve_unit_iters) override; + void Reorder(const ffi::Array& ordered_loop_rvs) override; + void ReorderBlockIterVar(const BlockRV& block_rv, const ffi::Array new_order) override; LoopRV AddUnitLoop(const BlockRV& block_rv) override; LoopRV AddUnitLoop(const LoopRV& loop_rv) override; /******** Schedule: Manipulate ForKind ********/ void Parallel(const LoopRV& loop_rv) override; void Vectorize(const LoopRV& loop_rv) override; - void Bind(const LoopRV& loop_rv, const String& thread_axis) override; + void Bind(const LoopRV& loop_rv, const ffi::String& thread_axis) override; void Unroll(const LoopRV& loop_rv) override; /******** Schedule: Insert cache stages ********/ - BlockRV CacheRead(const BlockRV& block_rv, int read_buffer_index, const String& storage_scope, - const Array consumer_blocks = {}) override; - BlockRV CacheWrite(const BlockRV& block_rv, int write_buffer_index, const String& storage_scope, - const Array consumer_blocks = {}) override; + BlockRV CacheRead(const BlockRV& block_rv, int read_buffer_index, + const ffi::String& storage_scope, + const ffi::Array consumer_blocks = {}) override; + BlockRV CacheWrite(const BlockRV& block_rv, int write_buffer_index, + const ffi::String& storage_scope, + const ffi::Array consumer_blocks = {}) override; BlockRV ReindexCacheRead(const BlockRV& block_rv, int read_buffer_index, - const String& storage_scope, const IndexMap& index_map) 
override; + const ffi::String& storage_scope, const IndexMap& index_map) override; BlockRV ReindexCacheWrite(const BlockRV& block_rv, int write_buffer_index, - const String& storage_scope, const IndexMap& index_map) override; - Array CacheInplace(const BlockRV& block_rv, int read_buffer_index, - const String& storage_scope) override; - Array CacheIndex(const BlockRV& block_rv, const String& storage_scope, - int cse_thresh) override; + const ffi::String& storage_scope, const IndexMap& index_map) override; + ffi::Array CacheInplace(const BlockRV& block_rv, int read_buffer_index, + const ffi::String& storage_scope) override; + ffi::Array CacheIndex(const BlockRV& block_rv, const ffi::String& storage_scope, + int cse_thresh) override; BlockRV ReIndex(const BlockRV& block_rv, int buffer_index, BufferIndexType buffer_index_type) override; /******** Schedule: Data movement ********/ BlockRV ReadAt(const LoopRV& loop_rv, const BlockRV& block_rv, int read_buffer_index, - const String& storage_scope) override; + const ffi::String& storage_scope) override; BlockRV WriteAt(const LoopRV& loop_rv, const BlockRV& block_rv, int write_buffer_index, - const String& storage_scope) override; + const ffi::String& storage_scope) override; /******** Schedule: Compute location ********/ void ComputeAt(const BlockRV& block_rv, const LoopRV& loop_rv, bool preserve_unit_loops, int index = -1) override; @@ -145,38 +149,41 @@ class ConcreteScheduleNode : public ScheduleNode { /******** Schedule: Reduction ********/ BlockRV RFactor(const LoopRV& loop_rv, int factor_axis) override; BlockRV DecomposeReduction(const BlockRV& block_rv, const LoopRV& loop_rv) override; - void PadEinsum(const BlockRV& block_rv, const Array& padding) override; + void PadEinsum(const BlockRV& block_rv, const ffi::Array& padding) override; /******** Schedule: Block annotation ********/ void StorageAlign(const BlockRV& block_rv, int buffer_index, int axis, int factor, int offset) override; - void SetScope(const BlockRV& block_rv, int buffer_index, const String& storage_scope) override; - void UnsafeSetDType(const BlockRV& block_rv, int buffer_index, const String& dtype) override; + void SetScope(const BlockRV& block_rv, int buffer_index, + const ffi::String& storage_scope) override; + void UnsafeSetDType(const BlockRV& block_rv, int buffer_index, const ffi::String& dtype) override; /******** Schedule: Blockize & Tensorize ********/ BlockRV Blockize(const LoopRV& loop_rv, bool preserve_unit_iters) override; - BlockRV Blockize(const Array& blocks, bool preserve_unit_iters) override; - void Tensorize(const BlockRV& block_rv, const String& intrin, bool preserve_unit_iters) override; - void Tensorize(const LoopRV& loop_rv, const String& intrin, bool preserve_unit_iters) override; + BlockRV Blockize(const ffi::Array& blocks, bool preserve_unit_iters) override; + void Tensorize(const BlockRV& block_rv, const ffi::String& intrin, + bool preserve_unit_iters) override; + void Tensorize(const LoopRV& loop_rv, const ffi::String& intrin, + bool preserve_unit_iters) override; /******** Schedule: Annotation ********/ - void Annotate(const LoopRV& loop_rv, const String& ann_key, const Any& ann_val) override; - void Unannotate(const LoopRV& loop_rv, const String& ann_key) override; - void Annotate(const BlockRV& block_rv, const String& ann_key, const Any& ann_val) override; - void Unannotate(const BlockRV& block_rv, const String& ann_key) override; + void Annotate(const LoopRV& loop_rv, const ffi::String& ann_key, const Any& ann_val) override; + void 
Unannotate(const LoopRV& loop_rv, const ffi::String& ann_key) override; + void Annotate(const BlockRV& block_rv, const ffi::String& ann_key, const Any& ann_val) override; + void Unannotate(const BlockRV& block_rv, const ffi::String& ann_key) override; /******** Schedule: Layout transformation ********/ void TransformLayout(const BlockRV& block_rv, int buffer_index, BufferIndexType buffer_index_type, - const IndexMap& index_map, const Optional& pad_value, + const IndexMap& index_map, const ffi::Optional& pad_value, bool assume_injective_transform = false) override; void TransformBlockLayout(const BlockRV& block_rv, const IndexMap& index_map) override; void SetAxisSeparator(const BlockRV& block_rv, int buffer_index, BufferIndexType buffer_index_type, - const Array& axis_separators) override; + const ffi::Array& axis_separators) override; /******** Schedule: Padding decomposition ********/ BlockRV DecomposePadding(const BlockRV& block_rv, const LoopRV& loop_rv) override; /******** Schedule: Buffer transformation ********/ void RollingBuffer(const BlockRV& block_rv, int write_buffer_index) override; /******** Schedule: Misc ********/ void EnterPostproc() override {} - void UnsafeHideBufferAccess(const BlockRV& block_rv, const String& buf_type, - const Array& buf_index_array) override; + void UnsafeHideBufferAccess(const BlockRV& block_rv, const ffi::String& buf_type, + const ffi::Array& buf_index_array) override; void AnnotateBufferAccess(const BlockRV& block_rv, int buffer_index, BufferIndexType buffer_index_type, const IndexMap& index_map) override; @@ -195,7 +202,7 @@ class ConcreteScheduleNode : public ScheduleNode { * \return The new random variables created */ template - inline Array CreateRV(const Array& srefs); + inline ffi::Array CreateRV(const ffi::Array& srefs); /*! * \brief Add an sref as a random variable into the symbol table * \tparam T The type of the random variable @@ -217,8 +224,8 @@ class ConcreteScheduleNode : public ScheduleNode { * Which is convention of certain primitives. * \return The new random variables created */ - inline Array CreateRV(const std::vector& value, - bool convert_negone_to_none = false); + inline ffi::Array CreateRV(const std::vector& value, + bool convert_negone_to_none = false); /*! \brief Remove a random variable from the symbol table */ inline void RemoveFromSymbolTable(const ObjectRef& rv); /*! 
@@ -237,17 +244,17 @@ class ConcreteScheduleNode : public ScheduleNode { inline Block ConcreteScheduleNode::Get(const BlockRV& block_rv) const { StmtSRef sref = this->GetSRef(block_rv); const BlockNode* block = TVM_SREF_TO_BLOCK(sref); - return GetRef(block); + return ffi::GetRef(block); } inline For ConcreteScheduleNode::Get(const LoopRV& loop_rv) const { StmtSRef sref = this->GetSRef(loop_rv); const ForNode* loop = TVM_SREF_TO_FOR(sref); - return GetRef(loop); + return ffi::GetRef(loop); } inline PrimExpr ConcreteScheduleNode::Get(const ExprRV& expr_rv) const { - PrimExpr transformed = Substitute(expr_rv, [this](const Var& var) -> Optional { + PrimExpr transformed = Substitute(expr_rv, [this](const Var& var) -> ffi::Optional { auto it = this->symbol_table_.find(var); if (it == this->symbol_table_.end()) { LOG(FATAL) << "IndexError: Cannot find corresponding ExprRV: " << var; @@ -286,7 +293,7 @@ inline StmtSRef ConcreteScheduleNode::GetSRef(const BlockRV& block_rv) const { if (sref->stmt == nullptr) { LOG(FATAL) << "ValueError: The block no longer exists in the IRModule"; } - return GetRef(sref); + return ffi::GetRef(sref); } inline StmtSRef ConcreteScheduleNode::GetSRef(const LoopRV& loop_rv) const { @@ -311,12 +318,13 @@ inline StmtSRef ConcreteScheduleNode::GetSRef(const LoopRV& loop_rv) const { if (sref->stmt == nullptr) { LOG(FATAL) << "ValueError: The loop no longer exists in the IRModule"; } - return GetRef(sref); + return ffi::GetRef(sref); } template -inline Array GetSRefsHelper(const ConcreteScheduleNode* sch, const Array& rvs) { - Array result; +inline ffi::Array GetSRefsHelper(const ConcreteScheduleNode* sch, + const ffi::Array& rvs) { + ffi::Array result; result.reserve(rvs.size()); for (const T& rv : rvs) { result.push_back(sch->GetSRef(rv)); @@ -324,19 +332,19 @@ inline Array GetSRefsHelper(const ConcreteScheduleNode* sch, const Arr return result; } -inline Array ConcreteScheduleNode::GetSRefs(const Array& rvs) const { +inline ffi::Array ConcreteScheduleNode::GetSRefs(const ffi::Array& rvs) const { return GetSRefsHelper(this, rvs); } -inline Array ConcreteScheduleNode::GetSRefs(const Array& rvs) const { +inline ffi::Array ConcreteScheduleNode::GetSRefs(const ffi::Array& rvs) const { return GetSRefsHelper(this, rvs); } /******** Adding/Removing elements in the symbol table ********/ template -inline Array ConcreteScheduleNode::CreateRV(const Array& srefs) { - Array result; +inline ffi::Array ConcreteScheduleNode::CreateRV(const ffi::Array& srefs) { + ffi::Array result; result.reserve(srefs.size()); for (const StmtSRef& sref : srefs) { T rv; @@ -359,9 +367,9 @@ inline ExprRV ConcreteScheduleNode::CreateRV(int64_t value) { return rv; } -inline Array ConcreteScheduleNode::CreateRV(const std::vector& value, - bool convert_negone_to_none) { - Array results; +inline ffi::Array ConcreteScheduleNode::CreateRV(const std::vector& value, + bool convert_negone_to_none) { + ffi::Array results; results.reserve(value.size()); for (int64_t v : value) { if (convert_negone_to_none && v == -1) { diff --git a/src/tir/schedule/error.cc b/src/tir/schedule/error.cc index 479fc34c75af..ce882ebbc9c7 100644 --- a/src/tir/schedule/error.cc +++ b/src/tir/schedule/error.cc @@ -21,13 +21,13 @@ namespace tvm { namespace tir { -String ScheduleError::RenderReport(const String& primitive) const { +ffi::String ScheduleError::RenderReport(const ffi::String& primitive) const { IRModule mod = this->mod(); std::ostringstream os; // get locations of interest - Array locs = LocationsOfInterest(); - 
std::unordered_map loc_obj_to_name; + ffi::Array locs = LocationsOfInterest(); + std::unordered_map loc_obj_to_name; int n_locs = locs.size(); std::string msg = DetailRenderTemplate(); PrinterConfig cfg; diff --git a/src/tir/schedule/error.h b/src/tir/schedule/error.h index 8ddffce3ce61..093e5519dbd7 100644 --- a/src/tir/schedule/error.h +++ b/src/tir/schedule/error.h @@ -35,7 +35,7 @@ class ScheduleError : public tvm::runtime::Error { /*! \brief The error occurred in this IRModule */ virtual IRModule mod() const = 0; /*! \brief The locations of interest that we want to point out */ - virtual Array LocationsOfInterest() const = 0; + virtual ffi::Array LocationsOfInterest() const = 0; /*! * \brief Returns an error string template for rendering, corresponds to the "detail" mode. * \sa ScheduleErrorRenderLevel @@ -45,14 +45,14 @@ class ScheduleError : public tvm::runtime::Error { * now it only printed out all the locations in plain text, but in the future, we may want to mark * the IR with underscores and attach names to each location of interest. */ - virtual String DetailRenderTemplate() const = 0; + virtual ffi::String DetailRenderTemplate() const = 0; /*! * \brief Returns an error string without needing to render, corresponds to the "fast" mode * \sa ScheduleErrorRenderLevel */ - virtual String FastErrorString() const = 0; + virtual ffi::String FastErrorString() const = 0; /*! \brief Render the ScheduleError with the template provided by `DetailRenderTemplate` */ - String RenderReport(const String& primitive) const; + ffi::String RenderReport(const ffi::String& primitive) const; }; class LoopPositionError : public ScheduleError { @@ -63,11 +63,11 @@ class LoopPositionError : public ScheduleError { block_(std::move(block)), primitive_(primitive) {} - String FastErrorString() const final { + ffi::String FastErrorString() const final { return "ScheduleError: " + primitive_ + " expect the loop to be an ancestor of block"; } - String DetailRenderTemplate() const final { + ffi::String DetailRenderTemplate() const final { std::ostringstream os; os << "ScheduleError: The input loop {0} of " << primitive_ << " is required to be an ancestor of block {1}."; @@ -75,7 +75,7 @@ class LoopPositionError : public ScheduleError { } IRModule mod() const final { return mod_; } - Array LocationsOfInterest() const final { return {loop_, block_}; } + ffi::Array LocationsOfInterest() const final { return {loop_, block_}; } IRModule mod_; For loop_; diff --git a/src/tir/schedule/instruction.cc b/src/tir/schedule/instruction.cc index 0e911580338a..f2100930c977 100644 --- a/src/tir/schedule/instruction.cc +++ b/src/tir/schedule/instruction.cc @@ -33,9 +33,9 @@ bool InstructionKindNode::IsPostproc() const { return this == inst_enter_postproc.get(); } -Instruction::Instruction(InstructionKind kind, Array inputs, Array attrs, - Array outputs) { - ObjectPtr n = make_object(); +Instruction::Instruction(InstructionKind kind, ffi::Array inputs, ffi::Array attrs, + ffi::Array outputs) { + ObjectPtr n = ffi::make_object(); n->kind = std::move(kind); n->inputs = std::move(inputs); n->attrs = std::move(attrs); @@ -45,17 +45,17 @@ Instruction::Instruction(InstructionKind kind, Array inputs, Array att using InstructionKindRegistry = AttrRegistry; -InstructionKind InstructionKind::Get(const String& name) { +InstructionKind InstructionKind::Get(const ffi::String& name) { const InstructionKindRegEntry* reg = InstructionKindRegistry::Global()->Get(name); ICHECK(reg != nullptr) << "AttributeError: Instruction kind " << name << " is 
not registered"; return reg->inst_kind_; } InstructionKindRegEntry::InstructionKindRegEntry(uint32_t reg_index) { - this->inst_kind_ = InstructionKind(make_object()); + this->inst_kind_ = InstructionKind(ffi::make_object()); } -InstructionKindRegEntry& InstructionKindRegEntry::RegisterOrGet(const String& name) { +InstructionKindRegEntry& InstructionKindRegEntry::RegisterOrGet(const ffi::String& name) { return InstructionKindRegistry::Global()->RegisterOrGet(name); } @@ -65,29 +65,29 @@ TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) .set_dispatch([](const ObjectRef& obj, ReprPrinter* p) { const auto* self = obj.as(); ICHECK_NOTNULL(self); - Array inputs; + ffi::Array inputs; inputs.reserve(self->inputs.size()); for (const Any& obj : self->inputs) { if (obj == nullptr) { - inputs.push_back(String("None")); + inputs.push_back(ffi::String("None")); } else if (auto opt_str = obj.as()) { - inputs.push_back(String('"' + (*opt_str).operator std::string() + '"')); + inputs.push_back(ffi::String('"' + (*opt_str).operator std::string() + '"')); } else if (obj.as() || obj.as()) { - inputs.push_back(String("_")); + inputs.push_back(ffi::String("_")); } else if (obj.type_index() < ffi::TypeIndex::kTVMFFISmallStr) { inputs.push_back(obj); } else if (obj.as() || obj.as()) { inputs.push_back(obj); } else if (const auto* expr = obj.as()) { - PrimExpr new_expr = - Substitute(GetRef(expr), [](const Var& var) -> Optional { - ObjectPtr new_var = make_object(*var.get()); + PrimExpr new_expr = Substitute( + ffi::GetRef(expr), [](const Var& var) -> ffi::Optional { + ObjectPtr new_var = ffi::make_object(*var.get()); new_var->name_hint = "_"; return Var(new_var); }); std::ostringstream os; os << new_expr; - inputs.push_back(String(os.str())); + inputs.push_back(ffi::String(os.str())); } else if (obj.as()) { inputs.push_back(obj); } else { @@ -99,7 +99,7 @@ TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) /*inputs=*/inputs, /*attrs=*/self->attrs, /*decision=*/Any(nullptr), - /*outputs=*/Array(self->outputs.size(), String("_"))); + /*outputs=*/ffi::Array(self->outputs.size(), ffi::String("_"))); }); /**************** FFI ****************/ @@ -109,8 +109,10 @@ TVM_FFI_STATIC_INIT_BLOCK({ refl::GlobalDef() .def("tir.schedule.InstructionKindGet", InstructionKind::Get) .def("tir.schedule.Instruction", - [](InstructionKind kind, Array inputs, Array attrs, Array outputs) - -> Instruction { return Instruction(kind, inputs, attrs, outputs); }); + [](InstructionKind kind, ffi::Array inputs, ffi::Array attrs, + ffi::Array outputs) -> Instruction { + return Instruction(kind, inputs, attrs, outputs); + }); }); } // namespace tir diff --git a/src/tir/schedule/instruction_traits.h b/src/tir/schedule/instruction_traits.h index bff619ca49cc..93a1dd77ab64 100644 --- a/src/tir/schedule/instruction_traits.h +++ b/src/tir/schedule/instruction_traits.h @@ -44,25 +44,25 @@ namespace tir { * static constexpr bool kIsPure = false; * * // Convertible to `InstructionKindNode::FInstructionApply` - * static Array ApplyToSchedule( + * static ffi::Array ApplyToSchedule( * const tir::Schedule& sch, - * const Array& inputs, - * const Array& attrs, - * const Optional& decision); + * const ffi::Array& inputs, + * const ffi::Array& attrs, + * const ffi::Optional& decision); * * // Convertible to `InstructionKindNode::FInstructionAsPython` - * static String AsPython( - * const Array& inputs, - * const Array& attrs, - * const Optional& decision, - * const Array& outputs); + * static ffi::String AsPython( + * const ffi::Array& inputs, + * const ffi::Array& attrs, + 
* const ffi::Optional& decision, + * const ffi::Array& outputs); * * // Convertible to `InstructionKindNode::FInstructionAttrsAsJSON` * static ObjectRef AttrsAsJSON( - * const Array& attrs); + * const ffi::Array& attrs); * * // Convertible to `InstructionKindNode::FInstructionAttrsFromJSON` - * static Array AttrsFromJSON( + * static ffi::Array AttrsFromJSON( * const ObjectRef& attrs_record); * }; * @@ -108,12 +108,12 @@ namespace tir { * // - The next `kNumInputs` arguments are input random variables * // - The next `kNumAttrs` arguments are attributes * // - The next argument is decision, if `kNumDecisions == 1` - * static Array UnpackedApplyToSchedule( + * static ffi::Array UnpackedApplyToSchedule( * Schedule sch, * LoopRV loop_rv, * Integer n, * Integer max_innermost_factor, - * Optional> decision) { + * ffi::Optional> decision) { * return sch->SamplePerfectTile(loop_rv, n->value, max_innermost_factor->value, decision); * } * @@ -123,12 +123,12 @@ namespace tir { * // - The next `kNumInputs` arguments are names of input random variables * // - The next `kNumAttrs` arguments are attributes * // - The next argument is decision, if `kNumDecisions == 1` - * static String UnpackedAsPython( - * Array outputs, - * String loop_rv, + * static ffi::String UnpackedAsPython( + * ffi::Array outputs, + * ffi::String loop_rv, * Integer n, * Integer max_innermost_factor, - * Optional> decision) { + * ffi::Optional> decision) { * PythonAPICall py("sample_perfect_tile"); * py.Input("loop", loop_rv); * py.Input("n", n->value); @@ -152,16 +152,16 @@ struct UnpackedInstTraits { * `TTraits::UnpackedApplyToSchedule` * \sa InstructionKindNode::f_apply_to_schedule */ - static Array ApplyToSchedule(const Schedule& sch, const Array& inputs, - const Array& attrs, const Any& decision); + static ffi::Array ApplyToSchedule(const Schedule& sch, const ffi::Array& inputs, + const ffi::Array& attrs, const Any& decision); /*! * \brief Unpack the arguments in the calling convention, and feed them into * `TTraits::UnpackedAsPython` * \sa InstructionKindNode::f_as_python */ - static String AsPython(const Array& inputs, const Array& attrs, const Any& decision, - const Array& outputs); + static ffi::String AsPython(const ffi::Array& inputs, const ffi::Array& attrs, + const Any& decision, const ffi::Array& outputs); /*! \brief No customized serializer by default */ static constexpr std::nullptr_t AttrsAsJSON = nullptr; @@ -171,12 +171,12 @@ struct UnpackedInstTraits { protected: template - static TVM_ALWAYS_INLINE void _SetInputs(AnyView* packed_args, const Array& inputs); + static TVM_ALWAYS_INLINE void _SetInputs(AnyView* packed_args, const ffi::Array& inputs); template - static TVM_ALWAYS_INLINE void _SetAttrs(AnyView* packed_args, const Array& attrs); + static TVM_ALWAYS_INLINE void _SetAttrs(AnyView* packed_args, const ffi::Array& attrs); template static TVM_ALWAYS_INLINE void _SetDecision(AnyView* packed_args, const Any& decision); - static TVM_ALWAYS_INLINE Array _ConvertOutputs(const ffi::Any& rv); + static TVM_ALWAYS_INLINE ffi::Array _ConvertOutputs(const ffi::Any& rv); }; /*! @@ -190,32 +190,33 @@ class PythonAPICall { * \brief Constructor * \param method_name The name of the schedule API to be called */ - explicit PythonAPICall(String method_name) : method_name_(method_name), output_(std::nullopt) {} + explicit PythonAPICall(ffi::String method_name) + : method_name_(method_name), output_(std::nullopt) {} /*! 
\brief Add an integer input */ - inline void Input(String arg_name, int arg); + inline void Input(ffi::String arg_name, int arg); /*! \brief Add an integer input */ - inline void Input(String arg_name, int64_t arg); + inline void Input(ffi::String arg_name, int64_t arg); /*! \brief Add a bool input */ - inline void Input(String arg_name, bool arg); + inline void Input(ffi::String arg_name, bool arg); /*! \brief Add a double input */ - inline void Input(String arg_name, double arg); + inline void Input(ffi::String arg_name, double arg); /*! \brief Add an input random variable */ - inline void Input(String arg_name, String arg); + inline void Input(ffi::String arg_name, ffi::String arg); /*! \brief Add an input random variable */ - inline void Input(String arg_name, std::string arg); + inline void Input(ffi::String arg_name, std::string arg); /*! \brief Add an input, dispatched to different implementations according to the object's type */ - inline void Input(String arg_name, Any arg); + inline void Input(ffi::String arg_name, Any arg); /*! \brief Add the decision */ inline void Decision(Any decision); /*! * \brief Add a single output random variable * \param unit_array An array containing only one element */ - inline void SingleOutput(Array unit_array); + inline void SingleOutput(ffi::Array unit_array); /*! \brief Add a list of output random variables */ - inline void OutputList(Array outputs); + inline void OutputList(ffi::Array outputs); /*! \returns The schedule API call in python syntax */ - inline String Str() const; + inline ffi::String Str() const; private: /*! \brief Converts a TVM object to python string and print to the output stream */ @@ -223,13 +224,13 @@ class PythonAPICall { private: /*! \brief The name of the API to call */ - String method_name_; + ffi::String method_name_; /*! \brief The output of the instruction */ - Optional output_; + ffi::Optional output_; /*! \brief The names of input arguments */ - std::vector arg_names_; + std::vector arg_names_; /*! 
\brief The values of input arguments */ - std::vector args_; + std::vector args_; }; /********** implementation details **********/ @@ -272,7 +273,7 @@ template struct _IsTVMArray : std::false_type {}; template -struct _IsTVMArray> : std::true_type {}; +struct _IsTVMArray> : std::true_type {}; template struct _IsSingleObject @@ -297,10 +298,10 @@ static constexpr int IsSingleObject = _IsSingleObject>::valu }; // namespace details template -Array UnpackedInstTraits::ApplyToSchedule(const Schedule& sch, - const Array& inputs, - const Array& attrs, - const Any& decision) { +ffi::Array UnpackedInstTraits::ApplyToSchedule(const Schedule& sch, + const ffi::Array& inputs, + const ffi::Array& attrs, + const Any& decision) { using method_type = decltype(TTraits::UnpackedApplyToSchedule); using return_type = details::ReturnType; // static_assert(details::ArgumentAreAllObjects, @@ -329,8 +330,9 @@ Array UnpackedInstTraits::ApplyToSchedule(const Schedule& sch, } template -String UnpackedInstTraits::AsPython(const Array& inputs, const Array& attrs, - const Any& decision, const Array& outputs) { +ffi::String UnpackedInstTraits::AsPython(const ffi::Array& inputs, + const ffi::Array& attrs, const Any& decision, + const ffi::Array& outputs) { using method_type = decltype(TTraits::UnpackedAsPython); using return_type = details::ReturnType; // static_assert(details::ArgumentAreAllObjects, @@ -355,13 +357,13 @@ String UnpackedInstTraits::AsPython(const Array& inputs, const Arr }); ffi::Any rv; pf.CallPacked(ffi::PackedArgs(packed_args, kNumArgs), &rv); - return rv.cast(); + return rv.cast(); } template template TVM_ALWAYS_INLINE void UnpackedInstTraits::_SetInputs(AnyView* packed_args, - const Array& inputs) { + const ffi::Array& inputs) { constexpr size_t kNumInputs = TTraits::kNumInputs; ICHECK_EQ(kNumInputs, inputs.size()) << "ValueError: Incorrect kNumInputs for instruction: " << TTraits::kName; @@ -373,7 +375,7 @@ TVM_ALWAYS_INLINE void UnpackedInstTraits::_SetInputs(AnyView* packed_a template template TVM_ALWAYS_INLINE void UnpackedInstTraits::_SetAttrs(AnyView* packed_args, - const Array& attrs) { + const ffi::Array& attrs) { constexpr size_t kNumAttrs = TTraits::kNumAttrs; ICHECK_EQ(kNumAttrs, attrs.size()) << "ValueError: Incorrect kNumAttrs for instruction: " << TTraits::kName; @@ -396,7 +398,7 @@ TVM_ALWAYS_INLINE void UnpackedInstTraits::_SetDecision(AnyView* packed } template -TVM_ALWAYS_INLINE Array UnpackedInstTraits::_ConvertOutputs(const ffi::Any& rv) { +TVM_ALWAYS_INLINE ffi::Array UnpackedInstTraits::_ConvertOutputs(const ffi::Any& rv) { using method_type = decltype(TTraits::UnpackedApplyToSchedule); using return_type = details::ReturnType; constexpr int is_array = details::IsTVMArray; @@ -409,7 +411,7 @@ TVM_ALWAYS_INLINE Array UnpackedInstTraits::_ConvertOutputs(const } else if (is_single_obj) { return {rv}; } else if (is_array) { - return rv.cast>(); + return rv.cast>(); } } @@ -466,17 +468,17 @@ inline void PythonAPICall::AsPythonString(const Any& obj, std::ostream& os) { } } -void PythonAPICall::Input(String arg_name, int arg) { +void PythonAPICall::Input(ffi::String arg_name, int arg) { arg_names_.emplace_back(std::move(arg_name)); args_.push_back(std::to_string(arg)); } -void PythonAPICall::Input(String arg_name, int64_t arg) { +void PythonAPICall::Input(ffi::String arg_name, int64_t arg) { arg_names_.emplace_back(std::move(arg_name)); args_.push_back(std::to_string(arg)); } -void PythonAPICall::Input(String arg_name, bool arg) { +void PythonAPICall::Input(ffi::String arg_name, bool 
arg) { static const char* true_str = "True"; static const char* false_str = "False"; arg_names_.emplace_back(std::move(arg_name)); @@ -487,7 +489,7 @@ void PythonAPICall::Input(String arg_name, bool arg) { } } -void PythonAPICall::Input(String arg_name, double arg) { +void PythonAPICall::Input(ffi::String arg_name, double arg) { arg_names_.emplace_back(std::move(arg_name)); std::ostringstream os; os.precision(17); @@ -495,17 +497,17 @@ void PythonAPICall::Input(String arg_name, double arg) { args_.push_back(os.str()); } -void PythonAPICall::Input(String arg_name, String arg) { +void PythonAPICall::Input(ffi::String arg_name, ffi::String arg) { arg_names_.emplace_back(std::move(arg_name)); args_.emplace_back(std::move(arg)); } -void PythonAPICall::Input(String arg_name, std::string arg) { +void PythonAPICall::Input(ffi::String arg_name, std::string arg) { arg_names_.emplace_back(std::move(arg_name)); args_.emplace_back(std::move(arg)); } -void PythonAPICall::Input(String arg_name, Any arg) { +void PythonAPICall::Input(ffi::String arg_name, Any arg) { arg_names_.emplace_back(std::move(arg_name)); std::ostringstream os; AsPythonString(arg, os); @@ -518,12 +520,12 @@ void PythonAPICall::Decision(Any decision) { } } -void PythonAPICall::SingleOutput(Array unit_array) { +void PythonAPICall::SingleOutput(ffi::Array unit_array) { ICHECK_EQ(unit_array.size(), 1); this->output_ = unit_array[0]; } -void PythonAPICall::OutputList(Array outputs) { +void PythonAPICall::OutputList(ffi::Array outputs) { if (outputs.empty()) { return; } @@ -539,7 +541,7 @@ void PythonAPICall::OutputList(Array outputs) { this->output_ = os.str(); } -String PythonAPICall::Str() const { +ffi::String PythonAPICall::Str() const { std::ostringstream os; if (output_.has_value()) { os << output_.value() << " = "; diff --git a/src/tir/schedule/ir_comparator.cc b/src/tir/schedule/ir_comparator.cc index 71b646855d50..bef35387cbaa 100644 --- a/src/tir/schedule/ir_comparator.cc +++ b/src/tir/schedule/ir_comparator.cc @@ -37,11 +37,11 @@ class TensorIntrinMismatchError : public ScheduleError { ICHECK(lhs_stmt_->IsInstance() || lhs_stmt_->IsInstance()); } - String FastErrorString() const final { + ffi::String FastErrorString() const final { return "ScheduleError: The stmt doesn't match the tensor intrin."; } - String DetailRenderTemplate() const final { + ffi::String DetailRenderTemplate() const final { std::ostringstream os; os << "The stmt {0} doesn't match the tensor intrin\nThe pattern attempting to be matched:\n" << lhs_stmt_ << "\nDoes not match the tensorize description:\n" @@ -54,7 +54,7 @@ class TensorIntrinMismatchError : public ScheduleError { IRModule mod() const final { return lhs_mod_; } - Array LocationsOfInterest() const final { return {lhs_stmt_}; } + ffi::Array LocationsOfInterest() const final { return {lhs_stmt_}; } private: IRModule lhs_mod_; @@ -309,7 +309,7 @@ bool TensorizeComparator::VisitExpr_(const CastNode* op, const PrimExpr& other) bool TensorizeComparator::VisitExpr_(const VarNode* op, const PrimExpr& other) { const auto* rhs = other.as(); - auto lhs = GetRef(op); + auto lhs = ffi::GetRef(op); if (lhs.same_as(other)) return true; if (op->dtype.code() != rhs->dtype.code()) { if (assert_mode_) { @@ -348,8 +348,8 @@ bool TensorizeComparator::DefEqual(const Var& lhs, const Var& rhs) { return true; } -bool TensorizeComparator::CompareAnnotation(const std::pair& lhs, - const std::pair& rhs) { +bool TensorizeComparator::CompareAnnotation(const std::pair& lhs, + const std::pair& rhs) { if (lhs.first != rhs.first) { 
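      // Editorial note (not part of this diff): annotations are compared as
      // (key, value) pairs, key first. As with the value comparison below, a
      // key mismatch is only reported through EmitError when assert_mode_ is
      // set; in either mode the comparator then returns false.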
if (assert_mode_) { std::ostringstream os; @@ -376,8 +376,8 @@ bool TensorizeComparator::CompareAnnotation(const std::pair& l return true; } -bool TensorizeComparator::CompareAnnotationMap(const Map& lhs, - const Map& rhs) { +bool TensorizeComparator::CompareAnnotationMap(const ffi::Map& lhs, + const ffi::Map& rhs) { if (lhs.same_as(rhs)) return true; if (lhs.size() != rhs.size()) { if (assert_mode_) { @@ -389,14 +389,15 @@ bool TensorizeComparator::CompareAnnotationMap(const Map& lhs, return false; } - auto sort_map = [](const Map& map) -> std::vector> { - std::vector> ret(map.begin(), map.end()); + auto sort_map = [](const ffi::Map& map) + -> std::vector> { + std::vector> ret(map.begin(), map.end()); sort(ret.begin(), ret.end(), [](const auto& a, const auto& b) { return a.first < b.first; }); return ret; }; - std::vector> lhs_array = sort_map(lhs); - std::vector> rhs_array = sort_map(rhs); + std::vector> lhs_array = sort_map(lhs); + std::vector> rhs_array = sort_map(rhs); for (size_t i = 0; i < lhs.size(); ++i) { if (!CompareAnnotation(lhs_array[i], rhs_array[i])) { @@ -582,7 +583,8 @@ bool TensorizeComparator::CompareBufferAccess(const T* lhs, const T* rhs) { } template -bool TensorizeComparator::CompareArray(const Array& lhs, const Array& rhs, F Self::*cmp) { +bool TensorizeComparator::CompareArray(const ffi::Array& lhs, const ffi::Array& rhs, + F Self::*cmp) { if (lhs.same_as(rhs)) return true; if (lhs.size() != rhs.size()) { if (assert_mode_) { @@ -704,7 +706,7 @@ bool AutoTensorizeComparator::CompareBufferAccess(const T* lhs, const T* rhs) { lhs_indices.push_back(SimplifyNonTrivialExpr(index, &analyzer_)); } - auto is_scalar_access = [](const Array& indices, PrimExpr index) { + auto is_scalar_access = [](const ffi::Array& indices, PrimExpr index) { // Check if the indexing is of the form C[0] if (indices.size() > 1) return false; auto int_imm = index.template as(); @@ -722,8 +724,8 @@ bool AutoTensorizeComparator::CompareBufferAccess(const T* lhs, const T* rhs) { if (it_rhs == rhs_buffer_indices_map_.end()) { return false; } - auto indices_check = [&](const Array& indices, - const Array& old_indices) -> bool { + auto indices_check = [&](const ffi::Array& indices, + const ffi::Array& old_indices) -> bool { if (indices.size() != old_indices.size()) { return false; } diff --git a/src/tir/schedule/ir_comparator.h b/src/tir/schedule/ir_comparator.h index a15de7b97a91..665d093b2fa4 100644 --- a/src/tir/schedule/ir_comparator.h +++ b/src/tir/schedule/ir_comparator.h @@ -86,13 +86,14 @@ class TensorizeComparator : public ExprComparator, public StmtComparator { bool DefEqual(const Var& lhs, const Var& rhs); virtual bool CompareBuffer(const Buffer& lhs, const Buffer& rhs); bool CompareBufferRegion(const BufferRegion& lhs, const BufferRegion& rhs); - bool CompareAnnotation(const std::pair& lhs, - const std::pair& rhs); - bool CompareAnnotationMap(const Map& lhs, const Map& rhs); + bool CompareAnnotation(const std::pair& lhs, + const std::pair& rhs); + bool CompareAnnotationMap(const ffi::Map& lhs, + const ffi::Map& rhs); template bool CompareBufferAccess(const T* lhs, const T* rhs); template - bool CompareArray(const Array& lhs, const Array& rhs, F Self::*cmp); + bool CompareArray(const ffi::Array& lhs, const ffi::Array& rhs, F Self::*cmp); bool CompareRange(const Range& lhs, const Range& rhs); bool CompareIterVar(const IterVar& lhs, const IterVar& rhs); void EmitError(const std::string& error_message); @@ -151,17 +152,17 @@ class AutoTensorizeComparator : public TensorizeComparator { /*! 
\brief Block iters in the RHS stmt. */ std::vector rhs_iters_; /*! \brief The buffer and its access indices in the LHS stmt. */ - std::unordered_map, ObjectPtrHash, ObjectPtrEqual> + std::unordered_map, ObjectPtrHash, ObjectPtrEqual> lhs_buffer_indices_map_; /*! \brief The buffer and its access indices in the RHS stmt. */ - std::unordered_map, ObjectPtrHash, ObjectPtrEqual> + std::unordered_map, ObjectPtrHash, ObjectPtrEqual> rhs_buffer_indices_map_; /*! \brief Map from LHS buffer to RHS buffer */ std::unordered_map lhs_buffer_map_; private: /*! \brief The domain of the inner block iters. */ - Map inner_iter_dom_map_; + ffi::Map inner_iter_dom_map_; }; } // namespace tir diff --git a/src/tir/schedule/primitive.h b/src/tir/schedule/primitive.h index de8fe7238ea7..0c3e5a0efd21 100644 --- a/src/tir/schedule/primitive.h +++ b/src/tir/schedule/primitive.h @@ -55,8 +55,9 @@ std::vector SampleWithoutReplacement( * \return The random variable sampled from candidates */ TVM_DLL int64_t SampleCategorical(support::LinearCongruentialEngine::TRandState* rand_state, - const Array& candidates, const Array& probs, - Optional* decision); + const ffi::Array& candidates, + const ffi::Array& probs, + ffi::Optional* decision); /*! * \brief Create a sampling function that does multinomial sampling. * \param rand_state The random state. @@ -98,7 +99,7 @@ TVM_DLL std::vector SamplePerfectTile( TVM_DLL std::vector SamplePerfectTile( support::LinearCongruentialEngine::TRandState* rand_state, // const tir::StmtSRef& loop_sref, int32_t n_split, int32_t max_innermost_factor, - Optional>* decision); + ffi::Optional>* decision); /*! * \brief Sample the factors to a partitioned tile for a specific loop * @@ -136,7 +137,7 @@ TVM_DLL std::vector SamplePartitionedTile( TVM_DLL std::vector SamplePartitionedTile( support::LinearCongruentialEngine::TRandState* rand_state, // const tir::StmtSRef& loop_sref, int32_t n_split, int32_t partition_pos, - int32_t innerpart_factor, Optional>* decision); + int32_t innerpart_factor, ffi::Optional>* decision); /*! * \brief Sample a compute-at location of the given block * \param self The schedule state @@ -147,7 +148,7 @@ TVM_DLL std::vector SamplePartitionedTile( */ TVM_DLL tir::StmtSRef SampleComputeLocation( tir::ScheduleState self, support::LinearCongruentialEngine::TRandState* rand_state, - const tir::StmtSRef& block_sref, Optional* decision); + const tir::StmtSRef& block_sref, ffi::Optional* decision); /******** Schedule: Get blocks & loops ********/ /*! @@ -157,35 +158,36 @@ TVM_DLL tir::StmtSRef SampleComputeLocation( * \param gvar The function to be retrieved * \return A list of blocks with the specific name */ -Array GetBlocks(const ScheduleState& self, const String& name, const GlobalVar& gv); +ffi::Array GetBlocks(const ScheduleState& self, const ffi::String& name, + const GlobalVar& gv); /*! * \brief Gets the parent loops of the block in its scope, from outer to inner * \param self The schedule state * \param block_sref The query block * \return A list of loops above the given block in its scope, from outer to inner */ -Array GetLoops(const StmtSRef& block_sref); +ffi::Array GetLoops(const StmtSRef& block_sref); /*! * \brief Get the leaf blocks of a specific block/loop * \param self The schedule state * \param parent_sref The query block/loop * \return A list of leaf blocks inside a specific block/loop */ -Array GetChildBlocks(const ScheduleState& self, const StmtSRef& parent_sref); +ffi::Array GetChildBlocks(const ScheduleState& self, const StmtSRef& parent_sref); /*! 
 * \brief Get the producers of a specific block
 * \param self The schedule state
 * \param block_sref The block in the query
 * \return A list of blocks, the producers of the given block
 */
-Array<StmtSRef> GetProducers(const ScheduleState& self, const StmtSRef& block_sref);
+ffi::Array<StmtSRef> GetProducers(const ScheduleState& self, const StmtSRef& block_sref);
 /*!
 * \brief Get the consumers of a specific block
 * \param self The schedule state
 * \param block_sref The block in the query
 * \return A list of blocks, the consumers of the given block
 */
-Array<StmtSRef> GetConsumers(const ScheduleState& self, const StmtSRef& block_sref);
+ffi::Array<StmtSRef> GetConsumers(const ScheduleState& self, const StmtSRef& block_sref);
 /*!
 * \brief Get the list of output blocks within the given scope
 * An output block is a block which has at least one buffer being written
 * \param self The schedule state
 * \return A list of all blocks that write to some output buffer
 */
-Array<StmtSRef> GetOutputBlocks(const ScheduleState& self, const StmtSRef& scope_sref);
+ffi::Array<StmtSRef> GetOutputBlocks(const ScheduleState& self, const StmtSRef& scope_sref);
 /******** Schedule: Transform loops ********/
 /*!
 * Split a loop into a list of consecutive loops. It requires:
 * 1) The loop can't have annotation or thread binding.
 * 2) The loop must start with 0.
 * \param self The state of the schedule
 * \param loop_sref The sref to the loop being split
 * \param factors The splitting factors
 * \param preserve_unit_iters Whether or not to preserve unit iterators in block bindings
 * \param disable_predication If enabled, don't create a predicate for guarding the
 * loop. This can be useful when splitting with scalable factors that the schedule writer
 * knows are divisible by the loop bound.
 * Warning: enabling this feature may result in incorrect code generation if not used
 * carefully.
 * \return An array of srefs to the loops after splitting
 */
-TVM_DLL Array<StmtSRef> Split(ScheduleState self, const StmtSRef& loop_sref,
-                              const Array<PrimExpr>& factors, bool preserve_unit_iters,
-                              bool disable_predication);
+TVM_DLL ffi::Array<StmtSRef> Split(ScheduleState self, const StmtSRef& loop_sref,
+                                   const ffi::Array<PrimExpr>& factors, bool preserve_unit_iters,
+                                   bool disable_predication);
 /*!
 * Partition a loop into a list of consecutive loops. It requires:
@@ -223,8 +225,9 @@ TVM_DLL Array<StmtSRef> Split(ScheduleState self, const StmtSRef& loop_sref,
 * \param preserve_unit_iters Whether or not to preserve unit iterators in block bindings
 * \return An array of srefs to the loops after partitioning
 */
-TVM_DLL Array<StmtSRef> LoopPartition(ScheduleState self, const StmtSRef& loop_sref,
-                                      const Array<PrimExpr>& factors, bool preserve_unit_iters);
+TVM_DLL ffi::Array<StmtSRef> LoopPartition(ScheduleState self, const StmtSRef& loop_sref,
+                                           const ffi::Array<PrimExpr>& factors,
+                                           bool preserve_unit_iters);
 /*!
 * \brief Merge a list of loops into one. The loops under their LCA requires:
@@ -236,7 +239,7 @@ TVM_DLL Array<StmtSRef> LoopPartition(ScheduleState self, const StmtSRef& loop_s
 * \param loop_srefs An array of srefs to the loops to be merged
 * \return The new loop after merge
 */
-TVM_DLL StmtSRef Merge(ScheduleState self, const Array<StmtSRef>& loop_srefs);
+TVM_DLL StmtSRef Merge(ScheduleState self, const ffi::Array<StmtSRef>& loop_srefs);
 /*!
 * \brief Fuse a list of consecutive loops into one. It requires:
@@ -249,7 +252,7 @@ TVM_DLL StmtSRef Merge(ScheduleState self, const Array<StmtSRef>& loop_srefs);
 * \param preserve_unit_loops Whether or not to preserve unit iterators in block bindings
 * \return The sref to the fused loop
 */
-TVM_DLL StmtSRef Fuse(ScheduleState self, const Array<StmtSRef>& loop_srefs,
+TVM_DLL StmtSRef Fuse(ScheduleState self, const ffi::Array<StmtSRef>& loop_srefs,
                       bool preserve_unit_loops);
 /*!
 * \brief Reorder a list of loops. It doesn't require the loops to be consecutive.
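 *
 * \note Editorial sketch, not part of this diff: the transform-loop
 * primitives above now trade exclusively in ffi:: containers; their
 * semantics are unchanged by the rename. Assuming a valid ScheduleState
 * `self` and a loop sref `loop_sref`:
 * \code
 *   bool preserve_unit_iters = true;
 *   bool disable_predication = false;
 *   ffi::Array<StmtSRef> parts = Split(self, loop_sref, {PrimExpr(4), PrimExpr(8)},
 *                                      preserve_unit_iters, disable_predication);
 *   bool preserve_unit_loops = true;
 *   StmtSRef fused = Fuse(self, parts, preserve_unit_loops);
 * \endcode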
@@ -264,7 +267,7 @@ TVM_DLL StmtSRef Fuse(ScheduleState self, const Array& loop_srefs, * \param self The state of the schedule * \param ordered_loop_srefs An array of srefs which indicates the new order of loops */ -TVM_DLL void Reorder(ScheduleState self, const Array& ordered_loop_srefs); +TVM_DLL void Reorder(ScheduleState self, const ffi::Array& ordered_loop_srefs); /*! * \brief Reorder itervars inside a block. @@ -273,7 +276,7 @@ TVM_DLL void Reorder(ScheduleState self, const Array& ordered_loop_sre * \param new_order The new itervar order. */ TVM_DLL void ReorderBlockIterVar(ScheduleState self, const StmtSRef& block_sref, - const Array& new_order); + const ffi::Array& new_order); /*! * \brief Create a new unit loop on top of the specific block or loop. @@ -320,7 +323,7 @@ TVM_DLL void Vectorize(ScheduleState self, const StmtSRef& loop_sref); * \param loop_sref The sref of the loop to be bound to the thread axis * \param thread_axis The thread axis to be bound to the loop */ -TVM_DLL void Bind(ScheduleState self, const StmtSRef& loop_sref, const String& thread_axis); +TVM_DLL void Bind(ScheduleState self, const StmtSRef& loop_sref, const ffi::String& thread_axis); /*! * \brief Unroll the input loop. It requires nothing * \param self The state of the schedule @@ -340,7 +343,8 @@ TVM_DLL void Unroll(ScheduleState self, const StmtSRef& loop_sref); * \return The cache stage block. */ TVM_DLL StmtSRef CacheRead(ScheduleState self, const StmtSRef& block_sref, int read_buffer_index, - const String& storage_scope, const Array consumer_blocks = {}); + const ffi::String& storage_scope, + const ffi::Array consumer_blocks = {}); /*! * \brief Create a block that writes a buffer region into a write cache. It requires: * 1) There is only one block that writes the target buffer. @@ -353,8 +357,8 @@ TVM_DLL StmtSRef CacheRead(ScheduleState self, const StmtSRef& block_sref, int r * \return The cache stage block. */ TVM_DLL StmtSRef CacheWrite(ScheduleState self, const StmtSRef& block_sref, int write_buffer_index, - const String& storage_scope, - const Array consumer_blocks = {}); + const ffi::String& storage_scope, + const ffi::Array consumer_blocks = {}); /*! * \brief Create a block that reads a buffer region into a read cache. It requires: * 1) There is at most one block who writes the buffer in the scope. @@ -369,7 +373,7 @@ TVM_DLL StmtSRef CacheWrite(ScheduleState self, const StmtSRef& block_sref, int * \return The cache stage block. */ TVM_DLL StmtSRef ReindexCacheRead(ScheduleState self, const StmtSRef& block_sref, - int read_buffer_index, const String& storage_scope, + int read_buffer_index, const ffi::String& storage_scope, const IndexMap& index_map); /*! * \brief Create a block that writes a buffer region into a write cache. It requires: @@ -385,7 +389,7 @@ TVM_DLL StmtSRef ReindexCacheRead(ScheduleState self, const StmtSRef& block_sref * \return The cache stage block. */ TVM_DLL StmtSRef ReindexCacheWrite(ScheduleState self, const StmtSRef& block_sref, - int write_buffer_index, const String& storage_scope, + int write_buffer_index, const ffi::String& storage_scope, const IndexMap& index_map); /*! @@ -398,8 +402,8 @@ TVM_DLL StmtSRef ReindexCacheWrite(ScheduleState self, const StmtSRef& block_sre * \param storage_scope The target storage scope * \return The cache stage blocks, cache read block together with cache write block. 
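 *
 * \note Editorial sketch, not part of this diff: assuming the documented
 * ordering (cache read stage first, then cache write stage), a caller
 * unpacks the result as
 * \code
 *   int read_buffer_index = 0;
 *   ffi::Array<StmtSRef> stages = CacheInplace(self, block_sref, read_buffer_index, "shared");
 *   StmtSRef cache_read = stages[0];
 *   StmtSRef cache_write = stages[1];
 * \endcode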
*/ -TVM_DLL Array CacheInplace(ScheduleState self, const StmtSRef& block_sref, - int read_buffer_index, const String& storage_scope); +TVM_DLL ffi::Array CacheInplace(ScheduleState self, const StmtSRef& block_sref, + int read_buffer_index, const ffi::String& storage_scope); /*! * \brief Create a block to cache precomputed index for later use. * if there is no index computation, keep unchanged. @@ -408,8 +412,8 @@ TVM_DLL Array CacheInplace(ScheduleState self, const StmtSRef& block_s * \param cse_thresh The repeat threshold that determines a common sub expr * \return The cache stage block. */ -TVM_DLL Array CacheIndex(ScheduleState self, const StmtSRef& block_sref, - const String& storage_scope, int cse_thresh); +TVM_DLL ffi::Array CacheIndex(ScheduleState self, const StmtSRef& block_sref, + const ffi::String& storage_scope, int cse_thresh); /*! *! * \brief Create a block that read/write a buffer region into a read/write cache with reindexing. @@ -429,10 +433,10 @@ TVM_DLL StmtSRef ReIndex(ScheduleState self, const StmtSRef& block_sref, int buf /******** Schedule: Data movement ********/ TVM_DLL StmtSRef ReadAt(ScheduleState self, const StmtSRef& loop_sref, const StmtSRef& block_sref, - int read_buffer_index, const String& storage_scope); + int read_buffer_index, const ffi::String& storage_scope); TVM_DLL StmtSRef WriteAt(ScheduleState self, const StmtSRef& loop_sref, const StmtSRef& block_sref, - int write_buffer_index, const String& storage_scope); + int write_buffer_index, const ffi::String& storage_scope); /******** Schedule: Compute location ********/ /*! @@ -561,7 +565,7 @@ TVM_DLL void StorageAlign(ScheduleState self, const StmtSRef& block_sref, int bu * \param storage_scope The storage scope to be set */ TVM_DLL void SetScope(ScheduleState self, const StmtSRef& block_sref, int buffer_index, - const String& storage_scope); + const ffi::String& storage_scope); /*! * \brief Set the data type of a buffer, where the buffer is specified by a block and a * write-index @@ -573,7 +577,7 @@ TVM_DLL void SetScope(ScheduleState self, const StmtSRef& block_sref, int buffer * \param dtype The data type to be set */ TVM_DLL void UnsafeSetDType(ScheduleState self, const StmtSRef& block_sref, int buffer_index, - const String& dtype); + const ffi::String& dtype); /*! * \brief Set the axis separator of a buffer, where the buffer is specified by a block and a read * or write index @@ -584,7 +588,7 @@ TVM_DLL void UnsafeSetDType(ScheduleState self, const StmtSRef& block_sref, int */ TVM_DLL void SetAxisSeparator(ScheduleState self, const StmtSRef& block_sref, int buffer_index, BufferIndexType buffer_index_type, - const Array& axis_separators); + const ffi::Array& axis_separators); /******** Schedule: Blockize & Tensorize ********/ @@ -604,7 +608,7 @@ TVM_DLL StmtSRef Blockize(ScheduleState self, const StmtSRef& loop_sref, bool pr * \param preserve_unit_iters Whether or not to preserve unit iterators in block bindings * \return The new block */ -TVM_DLL StmtSRef Blockize(ScheduleState self, const Array& blocks, +TVM_DLL StmtSRef Blockize(ScheduleState self, const ffi::Array& blocks, bool preserve_unit_iters); /*! @@ -625,7 +629,7 @@ TVM_DLL void Tensorize(ScheduleState self, const StmtSRef& block_or_loop_sref, * \param ann_key The annotation key * \param ann_val The annotation value */ -TVM_DLL void Annotate(ScheduleState self, const StmtSRef& sref, const String& ann_key, +TVM_DLL void Annotate(ScheduleState self, const StmtSRef& sref, const ffi::String& ann_key, const Any& ann_val); /*! 
* \brief Unannotate a block/loop's annotation with key ann_key @@ -633,7 +637,7 @@ TVM_DLL void Annotate(ScheduleState self, const StmtSRef& sref, const String& an * \param sref The block/loop to be unannotated * \param ann_key The annotation key */ -TVM_DLL void Unannotate(ScheduleState self, const StmtSRef& sref, const String& ann_key); +TVM_DLL void Unannotate(ScheduleState self, const StmtSRef& sref, const ffi::String& ann_key); /******** Schedule: Layout transformation ********/ /*! @@ -656,7 +660,8 @@ TVM_DLL void Unannotate(ScheduleState self, const StmtSRef& sref, const String& */ TVM_DLL void TransformLayout(ScheduleState self, const StmtSRef& block_sref, int buffer_index, BufferIndexType buffer_index_type, const IndexMap& index_map, - const Optional& pad_value, bool assume_injective_transform); + const ffi::Optional& pad_value, + bool assume_injective_transform); /*! * \brief Apply a transformation represented by IndexMap to block @@ -688,7 +693,7 @@ TVM_DLL StmtSRef DecomposePadding(ScheduleState self, const StmtSRef& block_sref * \param padding The padding for each block iter. */ TVM_DLL void PadEinsum(ScheduleState self, const StmtSRef& block_sref, - const Array& padding); + const ffi::Array& padding); /******** Schedule: Buffer transformation ********/ /*! * \brief Compute the target buffer via rolling buffering. @@ -715,7 +720,8 @@ TVM_DLL void RollingBuffer(ScheduleState self, const StmtSRef& block_sref, int w * \param buf_index_array The array of buffer indices we hide access. */ TVM_DLL void UnsafeHideBufferAccess(ScheduleState self, const StmtSRef& block_sref, - const String& buf_type, const Array& buf_index_array); + const ffi::String& buf_type, + const ffi::Array& buf_index_array); /*! * \brief Annotate the read or write region of a specific buffer in a block diff --git a/src/tir/schedule/primitive/annotate.cc b/src/tir/schedule/primitive/annotate.cc index e00ac2a5bba9..c398a46418a6 100644 --- a/src/tir/schedule/primitive/annotate.cc +++ b/src/tir/schedule/primitive/annotate.cc @@ -21,9 +21,10 @@ namespace tvm { namespace tir { -void Annotate(ScheduleState self, const StmtSRef& sref, const String& ann_key, const Any& ann_val) { +void Annotate(ScheduleState self, const StmtSRef& sref, const ffi::String& ann_key, + const Any& ann_val) { // Extract annotation - const Map* annotations = nullptr; + const ffi::Map* annotations = nullptr; if (const auto* loop = sref->StmtAs()) { annotations = &loop->annotations; } else if (const auto* block = sref->StmtAs()) { @@ -36,27 +37,27 @@ void Annotate(ScheduleState self, const StmtSRef& sref, const String& ann_key, c return; } // Add the new annotation - Map new_ann(*annotations); + ffi::Map new_ann(*annotations); new_ann.Set(ann_key, ann_val); // Create the new stmt if (const auto* loop = sref->StmtAs()) { - ObjectPtr n = make_object(*loop); + ObjectPtr n = ffi::make_object(*loop); n->annotations = std::move(new_ann); self->Replace(sref, For(n), {}); } else if (const auto* block = sref->StmtAs()) { - ObjectPtr n = make_object(*block); + ObjectPtr n = ffi::make_object(*block); n->annotations = std::move(new_ann); Block p(n); - self->Replace(sref, p, {{GetRef(block), p}}); + self->Replace(sref, p, {{ffi::GetRef(block), p}}); } else { LOG(FATAL) << "TypeError: Unknown type of sref: " << sref->stmt->GetTypeKey(); throw; } } -void Unannotate(ScheduleState self, const StmtSRef& sref, const String& ann_key) { +void Unannotate(ScheduleState self, const StmtSRef& sref, const ffi::String& ann_key) { // Extract annotation - const Map* 
annotations = nullptr; + const ffi::Map* annotations = nullptr; if (const auto* loop = sref->StmtAs()) { annotations = &loop->annotations; } else if (const auto* block = sref->StmtAs()) { @@ -67,18 +68,18 @@ void Unannotate(ScheduleState self, const StmtSRef& sref, const String& ann_key) // Remove the annotation ICHECK(annotations->find(ann_key) != annotations->end()) << "IndexError: Cannot find annotation key: " << ann_key; - Map new_ann(*annotations); + ffi::Map new_ann(*annotations); new_ann.erase(ann_key); // Create the new stmt if (const auto* loop = sref->StmtAs()) { - ObjectPtr n = make_object(*loop); + ObjectPtr n = ffi::make_object(*loop); n->annotations = std::move(new_ann); self->Replace(sref, For(n), {}); } else if (const auto* block = sref->StmtAs()) { - ObjectPtr n = make_object(*block); + ObjectPtr n = ffi::make_object(*block); n->annotations = std::move(new_ann); Block p(n); - self->Replace(sref, p, {{GetRef(block), p}}); + self->Replace(sref, p, {{ffi::GetRef(block), p}}); } else { LOG(FATAL) << "TypeError: Unknown type of sref: " << sref->stmt->GetTypeKey(); throw; @@ -95,7 +96,7 @@ struct AnnotateTraits : public UnpackedInstTraits { static constexpr size_t kNumDecisions = 0; static void UnpackedApplyToSchedule(Schedule sch, ObjectRef block_or_loop_rv, Any ann_val, - String ann_key) { + ffi::String ann_key) { if (auto block = block_or_loop_rv.as()) { return sch->Annotate(block.value(), ann_key, ann_val); } @@ -106,8 +107,8 @@ struct AnnotateTraits : public UnpackedInstTraits { throw; } - static String UnpackedAsPython(Array outputs, ObjectRef block_or_loop_rv, Any ann_val, - String ann_key) { + static ffi::String UnpackedAsPython(ffi::Array outputs, ObjectRef block_or_loop_rv, + Any ann_val, ffi::String ann_key) { PythonAPICall py("annotate"); py.Input("block_or_loop", block_or_loop_rv); py.Input("ann_key", ann_key); @@ -128,7 +129,8 @@ struct UnannotateTraits : public UnpackedInstTraits { static constexpr size_t kNumAttrs = 1; static constexpr size_t kNumDecisions = 0; - static void UnpackedApplyToSchedule(Schedule sch, ObjectRef block_or_loop_rv, String ann_key) { + static void UnpackedApplyToSchedule(Schedule sch, ObjectRef block_or_loop_rv, + ffi::String ann_key) { if (auto block = block_or_loop_rv.as()) { return sch->Unannotate(block.value(), ann_key); } @@ -139,8 +141,8 @@ struct UnannotateTraits : public UnpackedInstTraits { throw; } - static String UnpackedAsPython(Array outputs, ObjectRef block_or_loop_rv, - String ann_key) { + static ffi::String UnpackedAsPython(ffi::Array outputs, ObjectRef block_or_loop_rv, + ffi::String ann_key) { PythonAPICall py("unannotate"); py.Input("block_or_loop", block_or_loop_rv); py.Input("ann_key", ann_key); diff --git a/src/tir/schedule/primitive/annotate_buffer_access.cc b/src/tir/schedule/primitive/annotate_buffer_access.cc index ce767339ee50..84672dede70d 100644 --- a/src/tir/schedule/primitive/annotate_buffer_access.cc +++ b/src/tir/schedule/primitive/annotate_buffer_access.cc @@ -33,7 +33,7 @@ class AnnotateRegionRewriter : public StmtExprMutator { Stmt VisitStmt_(const BlockNode* op) final { Block block = Downcast(StmtExprMutator::VisitStmt_(op)); - Array regions = + ffi::Array regions = buffer_index_type_ == BufferIndexType::kWrite ? 
block->writes : block->reads; ICHECK_GE(buffer_index_, 0) << "Buffer index must be non-negative"; ICHECK_LT(buffer_index_, static_cast(regions.size())) << "Buffer index out of range"; @@ -47,12 +47,13 @@ class AnnotateRegionRewriter : public StmtExprMutator { } // Annotate the block with explicit_read_region or explicit_write_region - Map new_annotations = n->annotations; - String annotation_key = buffer_index_type_ == BufferIndexType::kWrite - ? attr::explicit_write_region - : attr::explicit_read_region; + ffi::Map new_annotations = n->annotations; + ffi::String annotation_key = buffer_index_type_ == BufferIndexType::kWrite + ? attr::explicit_write_region + : attr::explicit_read_region; if (new_annotations.count(annotation_key)) { - Array buffer_indices = Downcast>(new_annotations[annotation_key]); + ffi::Array buffer_indices = + Downcast>(new_annotations[annotation_key]); bool found = false; for (const Integer& index : buffer_indices) { if (index->value == buffer_index_) { @@ -65,7 +66,7 @@ class AnnotateRegionRewriter : public StmtExprMutator { new_annotations.Set(annotation_key, buffer_indices); } } else { - new_annotations.Set(annotation_key, Array{Integer(buffer_index_)}); + new_annotations.Set(annotation_key, ffi::Array{Integer(buffer_index_)}); } n->annotations = std::move(new_annotations); @@ -82,16 +83,17 @@ class AnnotateRegionRewriter : public StmtExprMutator { void AnnotateBufferAccess(ScheduleState self, const StmtSRef& block_sref, int buffer_index, BufferIndexType buffer_index_type, const IndexMap& index_map) { const BlockNode* block = TVM_SREF_TO_BLOCK(block_sref); - Buffer buffer = GetNthAccessBuffer(self, GetRef(block), buffer_index, buffer_index_type); + Buffer buffer = + GetNthAccessBuffer(self, ffi::GetRef(block), buffer_index, buffer_index_type); arith::Analyzer analyzer; - Array block_iter_vars; + ffi::Array block_iter_vars; for (const IterVar& iter_var : block->iter_vars) { block_iter_vars.push_back(iter_var->var); } - Array new_indices = index_map->MapIndices(block_iter_vars, &analyzer); + ffi::Array new_indices = index_map->MapIndices(block_iter_vars, &analyzer); ICHECK_EQ(new_indices.size() % 2, 0) << "The size of new_indices should be even."; - Array new_ranges; + ffi::Array new_ranges; for (size_t i = 0; i < new_indices.size(); i += 2) { // (begin, end) represents a region new_ranges.push_back(Range::FromMinExtent( @@ -101,9 +103,9 @@ void AnnotateBufferAccess(ScheduleState self, const StmtSRef& block_sref, int bu BufferRegion new_region(buffer, new_ranges); AnnotateRegionRewriter mutator(buffer, buffer_index, new_region, buffer_index_type); - Stmt new_stmt = mutator(GetRef(block_sref->stmt)); + Stmt new_stmt = mutator(ffi::GetRef(block_sref->stmt)); - self->Replace(block_sref, new_stmt, {{GetRef(block), Downcast(new_stmt)}}); + self->Replace(block_sref, new_stmt, {{ffi::GetRef(block), Downcast(new_stmt)}}); } struct AnnotateBufferAccessTraits : public UnpackedInstTraits { @@ -122,7 +124,7 @@ struct AnnotateBufferAccessTraits : public UnpackedInstTraitsinitial_indices.size(); ++i) { @@ -139,11 +141,12 @@ struct AnnotateBufferAccessTraits : public UnpackedInstTraits outputs, String block, Integer buffer_index, - Integer buffer_index_type, IndexMap index_map) { + static ffi::String UnpackedAsPython(ffi::Array outputs, ffi::String block, + Integer buffer_index, Integer buffer_index_type, + IndexMap index_map) { PythonAPICall py("annotate_buffer_access"); py.Input("block", block); py.Input("buffer_index", buffer_index->value); @@ -151,7 +154,7 @@ struct 
AnnotateBufferAccessTraits : public UnpackedInstTraits(buffer_index_type->value)) << "\""; - py.Input("buf_type", String(os.str())); + py.Input("buf_type", ffi::String(os.str())); py.Input("gen_new_ranges", IndexMap2GenNewRangesLambda(index_map)); return py.Str(); diff --git a/src/tir/schedule/primitive/block_annotate.cc b/src/tir/schedule/primitive/block_annotate.cc index 0e2a055d7afe..2bf62d409e2d 100644 --- a/src/tir/schedule/primitive/block_annotate.cc +++ b/src/tir/schedule/primitive/block_annotate.cc @@ -30,13 +30,13 @@ class StorageAlignAxisOutOfRangeError : public ScheduleError { explicit StorageAlignAxisOutOfRangeError(IRModule mod, Buffer buffer, int axis) : mod_(std::move(mod)), buffer_(std::move(buffer)), axis_(axis) {} - String FastErrorString() const final { + ffi::String FastErrorString() const final { return "ScheduleError: The input `axis` is out of range. It is required to be in range " "[-ndim, ndim) where `ndim` is the number of dimensions of the buffer to set " "storage alignment."; } - String DetailRenderTemplate() const final { + ffi::String DetailRenderTemplate() const final { std::ostringstream os; int ndim = static_cast(buffer_->shape.size()); os << "The buffer to set storage alignment of, " << buffer_->name << ", has " << ndim @@ -47,7 +47,7 @@ class StorageAlignAxisOutOfRangeError : public ScheduleError { } IRModule mod() const final { return mod_; } - Array LocationsOfInterest() const final { return {}; } + ffi::Array LocationsOfInterest() const final { return {}; } static int CheckAndUpdate(const IRModule& mod, const Buffer& buffer, int axis) { int ndim = static_cast(buffer->shape.size()); @@ -71,12 +71,12 @@ class NonAllocatedBufferError : public ScheduleError { public: explicit NonAllocatedBufferError(IRModule mod, Buffer buffer) : mod_(mod), buffer_(buffer) {} - String FastErrorString() const final { + ffi::String FastErrorString() const final { return "ScheduleError: The input buffer is not allocated by a block. This means the buffer is " " either a function parameter or defined in `match_buffer` of a block."; } - String DetailRenderTemplate() const final { + ffi::String DetailRenderTemplate() const final { std::ostringstream os; os << "The input buffer " << buffer_->name << " is not allocated by a block. This means the buffer is either a function parameter or " @@ -94,7 +94,7 @@ class NonAllocatedBufferError : public ScheduleError { return defining_site_sref.value(); } - Array LocationsOfInterest() const final { return {}; } + ffi::Array LocationsOfInterest() const final { return {}; } IRModule mod() const final { return mod_; } private: @@ -107,12 +107,12 @@ class StorageAlignInvalidFactorError : public ScheduleError { explicit StorageAlignInvalidFactorError(IRModule mod, int factor) : mod_(std::move(mod)), factor_(factor) {} - String FastErrorString() const final { + ffi::String FastErrorString() const final { return "ScheduleError: The input `factor` of storage_align is expected to be a positive " "number."; } - String DetailRenderTemplate() const final { + ffi::String DetailRenderTemplate() const final { std::ostringstream os; os << "The input `factor` of storage_align is expected to be a positive number. 
However, the " "input `factor` is " @@ -126,7 +126,7 @@ class StorageAlignInvalidFactorError : public ScheduleError { } } - Array LocationsOfInterest() const final { return {}; } + ffi::Array LocationsOfInterest() const final { return {}; } IRModule mod() const final { return mod_; } private: @@ -139,12 +139,12 @@ class StorageAlignInvalidAnnotationError : public ScheduleError { explicit StorageAlignInvalidAnnotationError(IRModule mod, Block block) : mod_(std::move(mod)), block_(std::move(block)) {} - String FastErrorString() const final { + ffi::String FastErrorString() const final { return "ScheduleError: The block annotation for storage align is expected to be an array of " "4-integer-tuples (buffer_index, axis, factor, offset)."; } - String DetailRenderTemplate() const final { + ffi::String DetailRenderTemplate() const final { std::ostringstream os; os << "The block annotation for storage align is expected to be an array of 4-integer-tuples " "(buffer_index, axis, factor, offset). However, the block annotation with key " @@ -168,7 +168,7 @@ class StorageAlignInvalidAnnotationError : public ScheduleError { return storage_align_annotation; } - Array LocationsOfInterest() const final { return {block_}; } + ffi::Array LocationsOfInterest() const final { return {block_}; } IRModule mod() const final { return mod_; } private: @@ -194,7 +194,7 @@ class StorageScopeMutator : private ReplaceBufferMutator { * \return The new block after the mutation */ static Block Mutate(const Block& allocate_site, const Buffer& old_buffer, - const String& storage_scope, Map* block_sref_reuse) { + const ffi::String& storage_scope, ffi::Map* block_sref_reuse) { Buffer new_buffer = WithScope(old_buffer, storage_scope); StorageScopeMutator mutator(old_buffer, new_buffer, storage_scope, block_sref_reuse); Stmt new_block = mutator.VisitStmt(allocate_site); @@ -202,8 +202,8 @@ class StorageScopeMutator : private ReplaceBufferMutator { } private: - StorageScopeMutator(const Buffer& old_buffer, Buffer new_buffer, String storage_scope, - Map* block_sref_reuse) + StorageScopeMutator(const Buffer& old_buffer, Buffer new_buffer, ffi::String storage_scope, + ffi::Map* block_sref_reuse) : ReplaceBufferMutator(old_buffer, std::move(new_buffer), block_sref_reuse) {} MatchBufferRegion VisitMatchBufferRegion(const MatchBufferRegion& match_buffer) final { @@ -222,8 +222,8 @@ class StorageScopeMutator : private ReplaceBufferMutator { void StorageAlign(ScheduleState self, const StmtSRef& block_sref, int buffer_index, int axis, int factor, int offset) { const BlockNode* block_ptr = TVM_SREF_TO_BLOCK(block_sref); - Buffer buffer = - GetNthAccessBuffer(self, GetRef(block_ptr), buffer_index, BufferIndexType::kWrite); + Buffer buffer = GetNthAccessBuffer(self, ffi::GetRef(block_ptr), buffer_index, + BufferIndexType::kWrite); StorageAlignInvalidFactorError::Check(self->mod, factor); axis = StorageAlignAxisOutOfRangeError::CheckAndUpdate(self->mod, buffer, axis); NonAllocatedBufferError::CheckAndGetBufferAllocationSite(self->mod, block_sref, buffer); @@ -231,7 +231,7 @@ void StorageAlign(ScheduleState self, const StmtSRef& block_sref, int buffer_ind // Step 1: Get existing or create new annotation value. 
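  // Editorial note (not part of this diff): the annotation handled here is
  // the attr::buffer_dim_align array of 4-integer tuples documented above;
  // e.g. StorageAlign(self, block_sref, /*buffer_index=*/0, /*axis=*/1,
  // /*factor=*/32, /*offset=*/1) leaves the tuple (0, 1, 32, 1) on the block.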
StorageAlignAnnotation storage_align_annotation = StorageAlignInvalidAnnotationError::CheckAndGetAnnotation(self->mod, - GetRef(block_ptr)); + ffi::GetRef(block_ptr)); // Step 2: Update the annotation value bool found = false; @@ -250,14 +250,14 @@ void StorageAlign(ScheduleState self, const StmtSRef& block_sref, int buffer_ind // Step 3: Replace the block with the new annotation Block new_block = WithAnnotation(block_ptr, attr::buffer_dim_align, storage_align_annotation); - self->Replace(block_sref, new_block, {{GetRef(block_ptr), new_block}}); + self->Replace(block_sref, new_block, {{ffi::GetRef(block_ptr), new_block}}); } void SetScope(ScheduleState self, const StmtSRef& block_sref, int buffer_index, - const String& storage_scope) { + const ffi::String& storage_scope) { const BlockNode* block = TVM_SREF_TO_BLOCK(block_sref); Buffer buffer = - GetNthAccessBuffer(self, GetRef(block), buffer_index, BufferIndexType::kWrite); + GetNthAccessBuffer(self, ffi::GetRef(block), buffer_index, BufferIndexType::kWrite); // Step 1. If `storage_scope` equals the original storage scope of the buffer, just return. if (buffer.scope() == storage_scope) { @@ -274,9 +274,9 @@ void SetScope(ScheduleState self, const StmtSRef& block_sref, int buffer_index, // Step 4. Recursively replace the old buffer to a new buffer, where the new buffer has the given // storage scope. In the meanwhile, collect the block sref reuse information. - Map block_reuse_map; - Block new_block = StorageScopeMutator::Mutate(GetRef(alloc_site), buffer, storage_scope, - &block_reuse_map); + ffi::Map block_reuse_map; + Block new_block = StorageScopeMutator::Mutate(ffi::GetRef(alloc_site), buffer, + storage_scope, &block_reuse_map); self->Replace(alloc_site_sref, new_block, block_reuse_map); } @@ -294,7 +294,7 @@ class DTypeMutator : private ReplaceBufferMutator { * \return The new block after the mutation */ static Block Mutate(const Block& allocate_site, const Buffer& old_buffer, const DataType& dtype, - Map* block_sref_reuse) { + ffi::Map* block_sref_reuse) { Buffer new_buffer = WithDType(old_buffer, dtype); DTypeMutator mutator(old_buffer, new_buffer, dtype, block_sref_reuse); Stmt new_block = mutator.VisitStmt(allocate_site); @@ -303,7 +303,7 @@ class DTypeMutator : private ReplaceBufferMutator { private: DTypeMutator(const Buffer& old_buffer, Buffer new_buffer, const DataType& dtype, - Map* block_sref_reuse) + ffi::Map* block_sref_reuse) : ReplaceBufferMutator(old_buffer, std::move(new_buffer), block_sref_reuse), src_dtype_(old_buffer->dtype), tgt_dtype_(dtype) {} @@ -343,11 +343,11 @@ class DTypeMutator : private ReplaceBufferMutator { }; void UnsafeSetDType(ScheduleState self, const StmtSRef& block_sref, int buffer_index, - const String& dtype) { + const ffi::String& dtype) { const BlockNode* block = TVM_SREF_TO_BLOCK(block_sref); Buffer buffer = - GetNthAccessBuffer(self, GetRef(block), buffer_index, BufferIndexType::kWrite); - DataType target_dtype(StringToDLDataType(dtype)); + GetNthAccessBuffer(self, ffi::GetRef(block), buffer_index, BufferIndexType::kWrite); + DataType target_dtype(ffi::StringToDLDataType(dtype)); // Step 1. If `dtype` equals the original data type, just return. if (buffer->dtype == target_dtype) { @@ -361,9 +361,9 @@ void UnsafeSetDType(ScheduleState self, const StmtSRef& block_sref, int buffer_i // Step 3. Recursively replace old buffer to a new buffer, where the new buffer has the given // dtype, and insert data type conversions. 
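  // Editorial note (not part of this diff): conceptually, DTypeMutator wraps
  // every store to the buffer in a cast to the target dtype (B[i] = v becomes
  // B[i] = Cast(tgt_dtype, v)) and loads get cast back to the surrounding
  // expression's original dtype; the possible precision loss in that round
  // trip is what makes the primitive "unsafe".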
- Map block_reuse_map; + ffi::Map block_reuse_map; Block new_block = - DTypeMutator::Mutate(GetRef(alloc_site), buffer, target_dtype, &block_reuse_map); + DTypeMutator::Mutate(ffi::GetRef(alloc_site), buffer, target_dtype, &block_reuse_map); self->Replace(alloc_site_sref, new_block, block_reuse_map); } @@ -384,8 +384,9 @@ struct StorageAlignTraits : public UnpackedInstTraits { offset->value); } - static String UnpackedAsPython(Array outputs, String block_rv, Integer buffer_index, - Integer axis, Integer factor, Integer offset) { + static ffi::String UnpackedAsPython(ffi::Array outputs, ffi::String block_rv, + Integer buffer_index, Integer axis, Integer factor, + Integer offset) { PythonAPICall py("storage_align"); py.Input("block", block_rv); py.Input("buffer_index", buffer_index); @@ -409,12 +410,12 @@ struct SetScopeTraits : public UnpackedInstTraits { static constexpr size_t kNumDecisions = 0; static void UnpackedApplyToSchedule(Schedule sch, BlockRV block_rv, Integer buffer_index, - String storage_scope) { + ffi::String storage_scope) { return sch->SetScope(block_rv, buffer_index->value, storage_scope); } - static String UnpackedAsPython(Array outputs, String block_rv, Integer buffer_index, - String storage_scope) { + static ffi::String UnpackedAsPython(ffi::Array outputs, ffi::String block_rv, + Integer buffer_index, ffi::String storage_scope) { PythonAPICall py("set_scope"); py.Input("block", block_rv); py.Input("buffer_index", buffer_index); @@ -436,12 +437,12 @@ struct UnsafeSetDTypeTraits : public UnpackedInstTraits { static constexpr size_t kNumDecisions = 0; static void UnpackedApplyToSchedule(Schedule sch, BlockRV block_rv, Integer buffer_index, - String dtype) { + ffi::String dtype) { return sch->UnsafeSetDType(block_rv, buffer_index->value, dtype); } - static String UnpackedAsPython(Array outputs, String block_rv, Integer buffer_index, - String dtype) { + static ffi::String UnpackedAsPython(ffi::Array outputs, ffi::String block_rv, + Integer buffer_index, ffi::String dtype) { PythonAPICall py("unsafe_set_dtype"); py.Input("block", block_rv); py.Input("buffer_index", buffer_index); diff --git a/src/tir/schedule/primitive/blockize_tensorize.cc b/src/tir/schedule/primitive/blockize_tensorize.cc index 4828701bb571..fbc569ece689 100644 --- a/src/tir/schedule/primitive/blockize_tensorize.cc +++ b/src/tir/schedule/primitive/blockize_tensorize.cc @@ -52,18 +52,18 @@ class SubspaceNotDivisibleError : public ScheduleError { scope_loop_(std::move(scope_loop)), inner_block_(std::move(inner_block)) {} - String FastErrorString() const final { + ffi::String FastErrorString() const final { return "ScheduleError: The bindings of the inner block can not be blockized."; } - String DetailRenderTemplate() const final { + ffi::String DetailRenderTemplate() const final { return "ScheduleError: The bindings of the inner block {0} can not be blockized by the loops " "starting at {1}."; } IRModule mod() const final { return mod_; } - Array LocationsOfInterest() const final { return {inner_block_, scope_loop_}; } + ffi::Array LocationsOfInterest() const final { return {inner_block_, scope_loop_}; } private: IRModule mod_; @@ -86,17 +86,17 @@ class SubspaceNotDivisibleError : public ScheduleError { * \param inner_iters The iters of the inner space * \return The result of the subspace division. 
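 *
 * \note Editorial note, not part of this diff: the trivial division succeeds
 * only when the predicate is 1 and every binding uses exclusively outer-loop
 * vars or exclusively inner-loop vars. With outer iters {i} and inner iters
 * {j}, bindings [i, j] divide trivially, while a mixed binding such as
 * [i * 4 + j] makes this helper give up and return an empty array.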
*/ -Array> TrivialSubspaceDivision(const Array& iter_vars, - const Array& bindings, - const PrimExpr& predicate, - const Array& outer_iters, - const Array& inner_iters) { +ffi::Array> TrivialSubspaceDivision( + const ffi::Array& iter_vars, const ffi::Array& bindings, + const PrimExpr& predicate, const ffi::Array& outer_iters, + const ffi::Array& inner_iters) { if (!is_one(predicate)) return {}; - Array> res; + ffi::Array> res; std::unordered_set outer_loop_vars; std::unordered_set inner_loop_vars; - auto make_uses_var = [](const Array& vars) -> std::function { + auto make_uses_var = + [](const ffi::Array& vars) -> std::function { std::unordered_set var_set; var_set.reserve(vars.size()); for (const Var& var : vars) { @@ -154,15 +154,16 @@ Array> TrivialSubspaceDivision(const Array& iter * \param preserve_unit_iters Whether or not to preserve unit iterators in block bindings * \param loop_sref_as_outer Whether loop_sref is divided into outer or inner */ -Array> SubspaceDivide(const BlockRealize& realize, - const StmtSRef& block_sref, // - const StmtSRef& loop_sref, // - std::vector* loops, - arith::Analyzer* analyzer, bool preserve_unit_iters, - bool loop_sref_as_outer = false) { - Array inner_vars; - Array outer_vars; - Map loop_var_domain; +ffi::Array> SubspaceDivide(const BlockRealize& realize, + const StmtSRef& block_sref, // + const StmtSRef& loop_sref, // + std::vector* loops, + arith::Analyzer* analyzer, + bool preserve_unit_iters, + bool loop_sref_as_outer = false) { + ffi::Array inner_vars; + ffi::Array outer_vars; + ffi::Map loop_var_domain; bool inner = true; for (StmtSRefNode* sref = block_sref->parent; // sref && sref->stmt->IsInstance(); // @@ -179,7 +180,7 @@ Array> SubspaceDivide(const BlockRealize& realize, inner = false; } } - Array> result = + ffi::Array> result = arith::SubspaceDivide(realize->iter_values, loop_var_domain, inner_vars, realize->predicate, arith::IterMapLevel::Surjective, analyzer, /*simplify_trivial_iterators=*/!preserve_unit_iters); @@ -203,17 +204,18 @@ Array> SubspaceDivide(const BlockRealize& realize, * \param preserve_unit_iters Whether or not to preserve unit iterators in block bindings * \return A substitution plan to the iterators in the original inner block. */ -Map DeriveBlockBinding(const Array& iter_vars, // - const Array>& division, // - Array* outer_iter_vars, // - Array* outer_bindings, // - Array* inner_iter_vars, // - Array* inner_bindings, // - bool preserve_unit_iters, bool reuse_outer = false) { +ffi::Map DeriveBlockBinding( + const ffi::Array& iter_vars, // + const ffi::Array>& division, // + ffi::Array* outer_iter_vars, // + ffi::Array* outer_bindings, // + ffi::Array* inner_iter_vars, // + ffi::Array* inner_bindings, // + bool preserve_unit_iters, bool reuse_outer = false) { using arith::IterMapExpr; using arith::IterMapExprNode; using arith::NormalizeIterMapToExpr; - Map block_var_subst; + ffi::Map block_var_subst; ICHECK_EQ(iter_vars.size() + 1, division.size()); arith::Analyzer ana; for (int i = 0, n = iter_vars.size(); i < n; ++i) { @@ -282,15 +284,15 @@ Map DeriveBlockBinding(const Array& iter_vars, * \return The inner block created. 
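 *
 * \note Editorial note, not part of this diff: when is_write_reduction is
 * true the inner block must also read what it writes, so the write regions
 * are prepended to the read regions; a reduction update such as
 * C[i] = C[i] + A[i] then lists C among its reads as well as its writes.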
*/ BlockRealize GenerateInner(bool is_write_reduction, - const Array& iter_vars, // - const Array& iter_values, // - const PrimExpr& predicate, // + const ffi::Array& iter_vars, // + const ffi::Array& iter_values, // + const PrimExpr& predicate, // Block block) { BlockNode* n = block.CopyOnWrite(); n->iter_vars = iter_vars; n->init = std::nullopt; if (is_write_reduction) { - Array reads; + ffi::Array reads; reads.reserve(block->writes.size() + block->reads.size()); reads.insert(reads.end(), block->writes.begin(), block->writes.end()); reads.insert(reads.end(), block->reads.begin(), block->reads.end()); @@ -308,15 +310,15 @@ BlockRealize GenerateInner(bool is_write_reduction, * \return The subtree of the init block and its outer loops. */ Stmt GenerateOuterInit(const Stmt& block_init, const BlockRealize& inner_realize, - const std::vector& loops, String block_name) { + const std::vector& loops, ffi::String block_name) { const Block& inner_block = inner_realize->block; - Map subst_map; + ffi::Map subst_map; // Step 1: Create new block vars for the block inside the init stmt of outer block // A iter is used in the block if // 1) It is data parallel // 2) It is used in the original init block - Array iter_vars; - Array iter_values; + ffi::Array iter_vars; + ffi::Array iter_values; ICHECK_EQ(inner_block->iter_vars.size(), inner_realize->iter_values.size()); int n = inner_block->iter_vars.size(); iter_vars.reserve(n); @@ -326,7 +328,7 @@ Stmt GenerateOuterInit(const Stmt& block_init, const BlockRealize& inner_realize const PrimExpr& iter_value = inner_realize->iter_values[i]; if (old_iter_var->iter_type == IterVarType::kDataPar && UsesVar(block_init, old_iter_var->var)) { - ObjectPtr new_iter_var = make_object(*old_iter_var.get()); + ObjectPtr new_iter_var = ffi::make_object(*old_iter_var.get()); new_iter_var->var = new_iter_var->var.copy_with_suffix("_init"); subst_map.Set(old_iter_var->var, new_iter_var->var); iter_vars.push_back(IterVar(new_iter_var)); @@ -354,7 +356,7 @@ Stmt GenerateOuterInit(const Stmt& block_init, const BlockRealize& inner_realize } } if (is_init_loop) { - ObjectPtr new_loop = make_object(*loop); + ObjectPtr new_loop = ffi::make_object(*loop); new_loop->loop_var = loop->loop_var.copy_with_suffix(""); new_loop->body = std::move(stmt); subst_map.Set(loop->loop_var, new_loop->loop_var); @@ -373,10 +375,10 @@ Stmt GenerateOuterInit(const Stmt& block_init, const BlockRealize& inner_realize * \param analyzer The analyzer for arithmetic simplification. * \return The substituted stmt. 
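 *
 * \note Editorial sketch, not part of this diff; assuming Vars x and y and a
 * Stmt `body` in scope:
 * \code
 *   ffi::Map<Var, PrimExpr> sub;
 *   sub.Set(x, y + 1);
 *   ffi::Map<Block, Block> reuse;
 *   arith::Analyzer ana;
 *   Stmt rewritten = Substitute(body, sub, &reuse, &ana);
 *   // `reuse` now maps every Block rewritten along the way to its replacement.
 * \endcode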
*/ -Stmt Substitute(const Stmt& stmt, const Map& sub, - Map* block_sref_reuse, arith::Analyzer* analyzer) { +Stmt Substitute(const Stmt& stmt, const ffi::Map& sub, + ffi::Map* block_sref_reuse, arith::Analyzer* analyzer) { struct Replacer : public StmtExprMutator { - explicit Replacer(const Map& sub, Map* block_sref_reuse, + explicit Replacer(const ffi::Map& sub, ffi::Map* block_sref_reuse, arith::Analyzer* analyzer) : sub_(sub), block_sref_reuse_(block_sref_reuse), analyzer_(analyzer) {} @@ -389,14 +391,14 @@ Stmt Substitute(const Stmt& stmt, const Map& sub, } PrimExpr VisitExpr_(const VarNode* op) final { - if (Optional e = sub_.Get(GetRef(op))) { + if (ffi::Optional e = sub_.Get(ffi::GetRef(op))) { return e.value(); } return StmtExprMutator::VisitExpr_(op); } Stmt VisitStmt_(const BlockNode* op) final { - Block src = GetRef(op); + Block src = ffi::GetRef(op); Block tgt = Downcast(StmtExprMutator::VisitStmt_(op)); if (!src.same_as(tgt)) { block_sref_reuse_->Set(src, tgt); @@ -404,8 +406,8 @@ Stmt Substitute(const Stmt& stmt, const Map& sub, return tgt; } - const Map& sub_; - Map* block_sref_reuse_; + const ffi::Map& sub_; + ffi::Map* block_sref_reuse_; arith::Analyzer* analyzer_; }; return Replacer(sub, block_sref_reuse, analyzer)(stmt); @@ -417,16 +419,16 @@ Stmt Substitute(const Stmt& stmt, const Map& sub, * \param dom_map The variables to be relaxed * \return The relaxed regions */ -Array EvalSetRegions(const Array& regions, - const Map& dom_map) { - Array results; +ffi::Array EvalSetRegions(const ffi::Array& regions, + const ffi::Map& dom_map) { + ffi::Array results; results.reserve(regions.size()); for (const BufferRegion& buffer_region : regions) { const Buffer& buffer = buffer_region->buffer; - Array relaxed = arith::EvalSet(buffer_region->region, dom_map); + ffi::Array relaxed = arith::EvalSet(buffer_region->region, dom_map); ICHECK_EQ(relaxed.size(), buffer->shape.size()); int ndim = buffer->shape.size(); - Array new_region; + ffi::Array new_region; new_region.reserve(ndim); for (int i = 0; i < ndim; ++i) { new_region.push_back(relaxed[i].CoverRange(RangeFromExtent(buffer->shape[i]))); @@ -441,23 +443,24 @@ Array EvalSetRegions(const Array& regions, * \param regions The input regions for the union. 
* \return The union regions */ -Array UnionRegions(const Array& regions) { - typedef std::vector> ranges_t; +ffi::Array UnionRegions(const ffi::Array& regions) { + typedef std::vector> ranges_t; std::unordered_map intset_map; for (const BufferRegion& buffer_region : regions) { const Buffer& buffer = buffer_region->buffer; if (intset_map.find(buffer) == intset_map.end()) { - intset_map[buffer] = {buffer->shape.size(), Array()}; + intset_map[buffer] = {buffer->shape.size(), ffi::Array()}; } - std::vector> dim_range(buffer->shape.size(), Array()); + std::vector> dim_range(buffer->shape.size(), + ffi::Array()); for (size_t dim = 0; dim < buffer->shape.size(); ++dim) { intset_map[buffer][dim].push_back(arith::IntSet::FromRange(buffer_region->region[dim])); } } - Array results; + ffi::Array results; for (const auto& it : intset_map) { const Buffer& buffer = it.first; - Array regions; + ffi::Array regions; for (size_t dim = 0; dim < buffer->shape.size(); ++dim) { const arith::IntSet intset = arith::Union(it.second[dim]); regions.push_back({intset.min(), intset.max() + 1}); @@ -475,7 +478,7 @@ Array UnionRegions(const Array& regions) { */ Stmt MakeLoopNest(Stmt stmt, const std::vector& loops) { for (const ForNode* loop : loops) { - ObjectPtr new_loop = make_object(*loop); + ObjectPtr new_loop = ffi::make_object(*loop); new_loop->body = std::move(stmt); stmt = For(new_loop); } @@ -483,7 +486,7 @@ Stmt MakeLoopNest(Stmt stmt, const std::vector& loops) { } BlockRealize BlockizeImpl(const ScheduleState& self, const StmtSRef& loop_sref, - Map* block_sref_reuse, arith::Analyzer* analyzer, + ffi::Map* block_sref_reuse, arith::Analyzer* analyzer, bool preserve_unit_iters) { TVM_SREF_TO_FOR(loop_sref); // Step 1: Check and get the only block under `loop`. @@ -492,25 +495,25 @@ BlockRealize BlockizeImpl(const ScheduleState& self, const StmtSRef& loop_sref, StmtSRef block_sref = self->stmt2ref.at(block.get()); // Step 2: Derive subspace division std::vector loops; - Array> division = + ffi::Array> division = SubspaceDivide(block_realize, block_sref, loop_sref, &loops, analyzer, preserve_unit_iters); if (division.empty()) { - throw SubspaceNotDivisibleError(self->mod, GetRef(loops.back()), block); + throw SubspaceNotDivisibleError(self->mod, ffi::GetRef(loops.back()), block); } PrimExpr outer_predicate = division.back()[0]->extent; PrimExpr inner_predicate = division.back()[1]->extent; // Step 3. Derive block bindings for both outer and inner block. - Array outer_iter_vars; - Array inner_iter_vars; - Array outer_bindings; - Array inner_bindings; - Map block_var_subst = // + ffi::Array outer_iter_vars; + ffi::Array inner_iter_vars; + ffi::Array outer_bindings; + ffi::Array inner_bindings; + ffi::Map block_var_subst = // DeriveBlockBinding(block->iter_vars, division, // &outer_iter_vars, &outer_bindings, // &inner_iter_vars, &inner_bindings, // preserve_unit_iters); // Step 4: Do var substitution to adjust to the new block bindings - Map inner_iter_dom; + ffi::Map inner_iter_dom; for (const IterVar& iter : inner_iter_vars) { inner_iter_dom.Set(iter->var, arith::IntSet::FromRange(iter->dom)); analyzer->Bind(iter->var, iter->dom); @@ -549,12 +552,12 @@ BlockRealize BlockizeImpl(const ScheduleState& self, const StmtSRef& loop_sref, block_subst->init.defined() // ? 
GenerateOuterInit(block_subst->init.value(), inner_realize, loops, block_subst->name_hint + "_init") - : Optional(std::nullopt))); + : ffi::Optional(std::nullopt))); } StmtSRef Blockize(ScheduleState self, const StmtSRef& loop_sref, bool preserve_unit_iters) { arith::Analyzer analyzer; - Map block_sref_reuse; + ffi::Map block_sref_reuse; BlockRealize blockized = BlockizeImpl(self, loop_sref, &block_sref_reuse, &analyzer, preserve_unit_iters); self->Replace(loop_sref, blockized, block_sref_reuse); @@ -566,34 +569,34 @@ StmtSRef Blockize(ScheduleState self, const StmtSRef& loop_sref, bool preserve_u return result; } -BlockRealize BlockizeBlocks(const ScheduleState& self, const Array& block_srefs, - const StmtSRef& lca, Map* block_sref_reuse, +BlockRealize BlockizeBlocks(const ScheduleState& self, const ffi::Array& block_srefs, + const StmtSRef& lca, ffi::Map* block_sref_reuse, bool preserve_unit_iters) { - Array seq_body; + ffi::Array seq_body; PrimExpr outer_predicate{nullptr}; - Array outer_iter_vars{nullptr}; - Array outer_bindings{nullptr}; - Array read_regions; - Array write_regions; + ffi::Array outer_iter_vars{nullptr}; + ffi::Array outer_bindings{nullptr}; + ffi::Array read_regions; + ffi::Array write_regions; std::string outer_block_name = "outer_"; - Map loop_var_subst; + ffi::Map loop_var_subst; arith::Analyzer analyzer; for (const auto& block_sref : block_srefs) { auto block_realize = GetBlockRealize(self, block_sref); auto block = block_realize->block; // Step 1: Derive subspace division std::vector loops; - Array> division = SubspaceDivide(block_realize, block_sref, lca, &loops, - &analyzer, preserve_unit_iters, true); + ffi::Array> division = SubspaceDivide( + block_realize, block_sref, lca, &loops, &analyzer, preserve_unit_iters, true); if (division.empty()) { - throw SubspaceNotDivisibleError(self->mod, GetRef(loops.back()), block); + throw SubspaceNotDivisibleError(self->mod, ffi::GetRef(loops.back()), block); } outer_predicate = division.back()[0]->extent; PrimExpr inner_predicate = division.back()[1]->extent; // Step 2. Derive block bindings for both outer and inner block. 
- Array inner_iter_vars; - Array inner_bindings; - Map block_var_subst = // + ffi::Array inner_iter_vars; + ffi::Array inner_bindings; + ffi::Map block_var_subst = // DeriveBlockBinding(block->iter_vars, division, // &outer_iter_vars, &outer_bindings, // &inner_iter_vars, &inner_bindings, // @@ -604,7 +607,7 @@ BlockRealize BlockizeBlocks(const ScheduleState& self, const Array& bl loop_var_subst.Set(Downcast(outer_bindings[i]), outer_iter_vars[i]->var); } } - Map inner_iter_dom; + ffi::Map inner_iter_dom; for (const IterVar& iter : inner_iter_vars) { Range dom = Substitute(iter->dom, loop_var_subst); inner_iter_dom.Set(iter->var, arith::IntSet::FromRange(dom)); @@ -637,7 +640,7 @@ BlockRealize BlockizeBlocks(const ScheduleState& self, const Array& bl block_sref_reuse->Set(block, inner_realize->block); Stmt stmt = inner_realize; for (const ForNode* loop : loops) { - ObjectPtr new_loop = make_object(*loop); + ObjectPtr new_loop = ffi::make_object(*loop); new_loop->body = std::move(stmt); new_loop->extent = Substitute(new_loop->extent, loop_var_subst); stmt = For(new_loop); @@ -654,19 +657,19 @@ BlockRealize BlockizeBlocks(const ScheduleState& self, const Array& bl /*writes=*/UnionRegions(write_regions), /*name_hint=*/outer_block_name, /*body=*/SeqStmt(seq_body), - /*init=*/Optional(std::nullopt))); + /*init=*/ffi::Optional(std::nullopt))); } class BlockizeRewriter : public StmtMutator { public: - static Stmt Rewrite(const StmtSRef& lca, const Array& blocks, + static Stmt Rewrite(const StmtSRef& lca, const ffi::Array& blocks, const BlockRealize& blockized) { BlockizeRewriter rewriter(lca, blocks, blockized); - return rewriter(GetRef(lca->stmt)); + return rewriter(ffi::GetRef(lca->stmt)); } private: - explicit BlockizeRewriter(const StmtSRef& lca, const Array& blocks, + explicit BlockizeRewriter(const StmtSRef& lca, const ffi::Array& blocks, const BlockRealize& blockized) : lca_(lca), blocks_(blocks), blockized_(blockized) {} @@ -676,7 +679,7 @@ class BlockizeRewriter : public StmtMutator { int idx_start = -1; int last_found_idx = -1; size_t cur_idx = 0; - Array new_seq; + ffi::Array new_seq; for (const Stmt& it : seq->seq) { target_in_ = false; Stmt stmt = StmtMutator::VisitStmt(it); @@ -717,17 +720,18 @@ class BlockizeRewriter : public StmtMutator { break; } } - return GetRef(block); + return ffi::GetRef(block); } StmtSRef lca_; - Array blocks_; + ffi::Array blocks_; BlockRealize blockized_; bool target_in_ = false; }; -StmtSRef Blockize(ScheduleState self, const Array& blocks, bool preserve_unit_iters) { - Map block_sref_reuse; +StmtSRef Blockize(ScheduleState self, const ffi::Array& blocks, + bool preserve_unit_iters) { + ffi::Map block_sref_reuse; auto lca = GetSRefLowestCommonAncestor(blocks); BlockRealize blockized = BlockizeBlocks(self, blocks, lca, &block_sref_reuse, preserve_unit_iters); @@ -743,17 +747,17 @@ void Tensorize(ScheduleState self, const StmtSRef& sref, const TensorIntrin& int bool preserve_unit_iters) { // Step 1: Blockize the subtree rooted at the given loop if needed BlockRealize block_realize{nullptr}; - Optional old_block = std::nullopt; + ffi::Optional old_block = std::nullopt; if (sref->stmt->IsInstance()) { block_realize = GetBlockRealize(self, sref); old_block = block_realize->block; } else if (sref->stmt->IsInstance()) { arith::Analyzer analyzer; - Map block_sref_reuse; + ffi::Map block_sref_reuse; block_realize = BlockizeImpl(self, sref, &block_sref_reuse, &analyzer, preserve_unit_iters); } else { LOG(FATAL) << "TypeError: Tensorize only support For or Block, 
but gets: " - << GetRef(sref->stmt); + << ffi::GetRef(sref->stmt); throw; } @@ -762,7 +766,7 @@ void Tensorize(ScheduleState self, const StmtSRef& sref, const TensorIntrin& int PrimFunc intrin_impl = DeepCopy(intrin->impl); int index_dtype_bits = -1; - auto f_update_max_dtype_bits_from_region = [&](const Array& buffer_regions) { + auto f_update_max_dtype_bits_from_region = [&](const ffi::Array& buffer_regions) { for (const BufferRegion& buffer_region : buffer_regions) { for (const auto& range : buffer_region->region) { index_dtype_bits = std::max(index_dtype_bits, range->min.dtype().bits()); @@ -794,7 +798,7 @@ void Tensorize(ScheduleState self, const StmtSRef& sref, const TensorIntrin& int ICHECK(comparator.rhs_buffer_map_.count(desc)); impl2cur[impl] = comparator.rhs_buffer_map_[desc]; } - std::unordered_map, ObjectPtrHash, ObjectPtrEqual> impl2region; + std::unordered_map, ObjectPtrHash, ObjectPtrEqual> impl2region; Block impl_block = Downcast(intrin_impl->body)->block; for (const BufferRegion& read : impl_block->reads) { impl2region.emplace(read->buffer, read->region); @@ -804,16 +808,16 @@ void Tensorize(ScheduleState self, const StmtSRef& sref, const TensorIntrin& int } // Step 4: Create MatchBufferRegion for the params of the impl function of the tensor // intrin to make them subregions of the buffer in the original IR. - Array match_buffer_regions; + ffi::Array match_buffer_regions; match_buffer_regions.reserve(intrin_impl->params.size()); for (int i = 0, n = intrin_impl->params.size(); i < n; ++i) { const Buffer& impl = intrin_impl->buffer_map.at(intrin_impl->params[i]); const Buffer& cur = impl2cur.at(impl); - const Array& old_region = impl2region.at(impl); + const ffi::Array& old_region = impl2region.at(impl); const std::vector& indices_base = comparator.buffer_indices_.at(cur); int offset = static_cast(indices_base.size()) - static_cast(old_region.size()); ICHECK(offset >= 0); - Array new_region; + ffi::Array new_region; new_region.reserve(cur->shape.size()); for (int i = 0; i < offset; i++) { PrimExpr min = indices_base[i]; @@ -867,14 +871,14 @@ struct BlockizeTraits : public UnpackedInstTraits { static BlockRV UnpackedApplyToSchedule(Schedule sch, ObjectRef target, Bool preserve_unit_iters) { if (auto loop = target.as()) { return sch->Blockize(loop.value(), preserve_unit_iters.operator bool()); - } else if (auto blocks = target.as>()) { + } else if (auto blocks = target.as>()) { return sch->Blockize(blocks.value(), preserve_unit_iters.operator bool()); } LOG(FATAL) << "TypeError: expect Loop or list of Blocks, but gets:" << target->GetTypeKey(); } - static String UnpackedAsPython(Array outputs, ObjectRef target, - Bool preserve_unit_iters) { + static ffi::String UnpackedAsPython(ffi::Array outputs, ObjectRef target, + Bool preserve_unit_iters) { PythonAPICall py("blockize"); py.Input("target", target); py.Input("preserve_unit_iters", preserve_unit_iters.operator bool()); @@ -895,7 +899,7 @@ struct TensorizeTraits : public UnpackedInstTraits { static constexpr size_t kNumAttrs = 2; static constexpr size_t kNumDecisions = 0; - static void UnpackedApplyToSchedule(Schedule sch, ObjectRef block_or_loop_rv, String intrin, + static void UnpackedApplyToSchedule(Schedule sch, ObjectRef block_or_loop_rv, ffi::String intrin, Bool preserve_unit_iters) { if (auto block = block_or_loop_rv.as()) { sch->Tensorize(block.value(), intrin, preserve_unit_iters.operator bool()); @@ -907,8 +911,8 @@ struct TensorizeTraits : public UnpackedInstTraits { } } - static String UnpackedAsPython(Array 
outputs, String block_or_loop_rv, String intrin, - Bool preserve_unit_iters) { + static ffi::String UnpackedAsPython(ffi::Array outputs, ffi::String block_or_loop_rv, + ffi::String intrin, Bool preserve_unit_iters) { PythonAPICall py("tensorize"); py.Input("block_or_loop", block_or_loop_rv); py.Input("tensor_intrin", intrin); diff --git a/src/tir/schedule/primitive/cache_index.cc b/src/tir/schedule/primitive/cache_index.cc index 9ea47def4c31..156f2ae4c59c 100644 --- a/src/tir/schedule/primitive/cache_index.cc +++ b/src/tir/schedule/primitive/cache_index.cc @@ -38,17 +38,17 @@ struct IndexInfo { /*! \brief The expr to be precomputed */ std::vector index_exprs; /*! \brief The range of the loop vars relating to index computation */ - Map range_map; + ffi::Map range_map; /*! \brief The binding table of the block var and the loop var */ - Map var_binding; + ffi::Map var_binding; /*! \brief The block var of the target block */ - std::vector> origin_block_vars; + std::vector> origin_block_vars; /*! \brief The index to insert the cache stage. */ size_t loc_pos; /*! \brief The cache stage to be inserted. */ Stmt cache_stage; /*! \brief The map used for ScheduleStateNode::Replace. */ - Map block_reuse; + ffi::Map block_reuse; }; /*! @@ -79,7 +79,7 @@ class IndexInfoCollector : public StmtExprVisitor { static void Collect(const ScheduleState& self, const StmtSRef& block_sref, const StmtSRef& scope_sref, IndexInfo* info) { IndexInfoCollector collector(self, block_sref, scope_sref, info->cse_thresh); - collector(GetRef(scope_sref->stmt)); + collector(ffi::GetRef(scope_sref->stmt)); info->loc_pos = collector.loc_pos_; info->index_exprs = collector.exprs_; info->range_map = collector.range_map_; @@ -150,7 +150,7 @@ class IndexInfoCollector : public StmtExprVisitor { // Analyze sub expr candidates ComputationTable table_syntactic_comp_done_by_stmt = - ComputationsDoneBy::GetComputationsDoneBy(GetRef(store), IsEligibleComputation, + ComputationsDoneBy::GetComputationsDoneBy(ffi::GetRef(store), IsEligibleComputation, [](const PrimExpr& expr) { return true; }); std::vector> semantic_comp_done_by_stmt = SyntacticToSemanticComputations(table_syntactic_comp_done_by_stmt, true); @@ -211,7 +211,7 @@ class IndexInfoCollector : public StmtExprVisitor { /*! \brief The flag indicating the right scope to update seq pos */ bool update_seq_pos_{false}; /*! \brief Record the ranges of iter vars */ - Map range_map_; + ffi::Map range_map_; }; /*! @@ -220,9 +220,9 @@ class IndexInfoCollector : public StmtExprVisitor { * \param storage_scope The storage scope of the cached buffer (only used in naming here) * \returns A block indicating the body of the loop nesting. 
*/ -Array MakeIndexCacheStage(IndexInfo* info, const String& storage_scope) { - Array blocks; - Array bodies; +ffi::Array MakeIndexCacheStage(IndexInfo* info, const ffi::String& storage_scope) { + ffi::Array blocks; + ffi::Array bodies; bodies.reserve(info->index_exprs.size()); info->cache_buffer.reserve(info->index_exprs.size()); @@ -235,7 +235,7 @@ Array MakeIndexCacheStage(IndexInfo* info, const String& storage_scope) { PostOrderVisit(index_expr, [&info, &expr_index](const ObjectRef& node) { if (node->IsInstance()) { Var iter_var = Downcast(node); - const Array& origin_block_var = info->origin_block_vars[expr_index]; + const ffi::Array& origin_block_var = info->origin_block_vars[expr_index]; auto find_result = std::find_if(origin_block_var.begin(), origin_block_var.end(), [&](Var it) { return it.get() == iter_var.get(); }); if (find_result == origin_block_var.end()) { @@ -262,7 +262,7 @@ Array MakeIndexCacheStage(IndexInfo* info, const String& storage_scope) { DataType data_type = index_expr.dtype(); Var index_buffer_var("index_var_" + std::to_string(expr_index), PointerType(PrimType(data_type), storage_scope)); - Array buffer_shape; + ffi::Array buffer_shape; for (const Var& it : info->origin_block_vars[expr_index]) { buffer_shape.push_back( arith::EvalSet(info->var_binding.at(it), arith::AsIntSet(info->range_map)).max() + 1); @@ -272,7 +272,7 @@ Array MakeIndexCacheStage(IndexInfo* info, const String& storage_scope) { // Create loop vars and block vars' binding_value std::vector loop_vars; - Map replace_table; + ffi::Map replace_table; for (const Var& it : iter_vars) { DataType data_type = DetermineDatatype(arith::IntSet::FromRange(info->range_map.at(it))); Var loop_var("ax" + std::to_string(replace_table.size()), data_type); @@ -285,12 +285,12 @@ Array MakeIndexCacheStage(IndexInfo* info, const String& storage_scope) { iter_values.push_back(Substitute(info->var_binding.at(it), replace_table)); } // block variables - Array block_vars; + ffi::Array block_vars; // block access region for write buffers Region access_region; // indices used in block body - Array access_indices; - Map block_var_map; + ffi::Array access_indices; + ffi::Map block_var_map; // Create block vars, block's accessed region and accessing indices for (size_t i = 0; i < info->origin_block_vars[expr_index].size(); i++) { const Var& block_var = info->origin_block_vars[expr_index][i]; @@ -348,15 +348,15 @@ Array MakeIndexCacheStage(IndexInfo* info, const String& storage_scope) { */ Stmt InsertIndexStage(const Stmt& stmt, int pos, const Stmt& stage) { if (const auto* seq_stmt = stmt.as()) { - ObjectPtr result = make_object(*seq_stmt); + ObjectPtr result = ffi::make_object(*seq_stmt); result->seq.insert(result->seq.begin() + pos, stage); return SeqStmt(result); } if (pos == 0) { - return SeqStmt::Flatten>({stage, stmt}); + return SeqStmt::Flatten>({stage, stmt}); } ICHECK_EQ(pos, 1); - return SeqStmt::Flatten>({stmt, stage}); + return SeqStmt::Flatten>({stmt, stage}); } /*! \brief Mutator for CacheIndex. 
*/ @@ -370,14 +370,14 @@ class CacheIndexRewriter : public StmtExprMutator { */ static Stmt Rewrite(const StmtSRef& scope_sref, IndexInfo* info) { CacheIndexRewriter rewriter(scope_sref, info); - return rewriter(GetRef(scope_sref->stmt)); + return rewriter(ffi::GetRef(scope_sref->stmt)); } private: explicit CacheIndexRewriter(const StmtSRef& scope_sref, IndexInfo* info) : scope_sref_(scope_sref), info_(info) { cache_indices_.reserve(info_->origin_block_vars.size()); - for (const Array& group_it : info_->origin_block_vars) { + for (const ffi::Array& group_it : info_->origin_block_vars) { cache_indices_.push_back({}); for (const Var& it : group_it) { cache_indices_.back().push_back(it); @@ -386,7 +386,7 @@ class CacheIndexRewriter : public StmtExprMutator { } Stmt VisitStmt_(const BlockNode* block) final { - Block old_stmt = GetRef(block); + Block old_stmt = ffi::GetRef(block); // Mutate the body visiting_target_block = static_cast(block == info_->target_block->stmt); Block stmt = Downcast(StmtMutator::VisitStmt_(block)); @@ -395,7 +395,7 @@ class CacheIndexRewriter : public StmtExprMutator { // Check if it is the block corresponding to the parent scope if (block == scope_sref_->stmt) { // If so, put buffer allocation and insert cache stages on the parent scope - ObjectPtr n = make_object(*stmt.as()); + ObjectPtr n = ffi::make_object(*stmt.as()); n->body = InsertIndexStage(n->body, info_->loc_pos, info_->cache_stage); for (const Buffer& it : info_->cache_buffer) { n->alloc_buffers.push_back(it); @@ -431,13 +431,13 @@ class CacheIndexRewriter : public StmtExprMutator { /*! \brief The info for inserting cache stage */ IndexInfo* info_; /*! \brief The indices for the cache buffer */ - std::vector> cache_indices_; + std::vector> cache_indices_; /*! \brief Indicating whether cache stage is inserted, only do index replacement afterwards*/ bool visiting_target_block{false}; }; -Array CacheIndex(ScheduleState self, const StmtSRef& block_sref, - const String& storage_scope, int cse_thresh) { +ffi::Array CacheIndex(ScheduleState self, const StmtSRef& block_sref, + const ffi::String& storage_scope, int cse_thresh) { /*! * Check: * - The index is in the array of block reading region @@ -460,14 +460,14 @@ Array CacheIndex(ScheduleState self, const StmtSRef& block_sref, // Step 2. Create cache stages and rewrite the stmt. BlockRealize realize = GetBlockRealize(self, block_sref); info.var_binding = GetBindings(realize); - Array cache_stages = MakeIndexCacheStage(&info, storage_scope); + ffi::Array cache_stages = MakeIndexCacheStage(&info, storage_scope); Stmt new_scope = CacheIndexRewriter::Rewrite(/*scope_sref=*/scope_sref, /*info=*/&info); bool old_stage_pipeline = self->block_info[block_sref].stage_pipeline; // Step 3. Replacing and updating flags. 
   self->Replace(scope_sref, new_scope, info.block_reuse);
-  Array<StmtSRef> result_block_srefs;
+  ffi::Array<StmtSRef> result_block_srefs;
   for (const Block& it : cache_stages) {
     StmtSRef result_block_sref = self->stmt2ref.at(it.get());
     result_block_srefs.push_back(result_block_sref);
@@ -478,7 +478,7 @@ Array<StmtSRef> CacheIndex(ScheduleState self, const StmtSRef& block_sref,
       affine_binding = true;
     } else {
       arith::Analyzer analyzer;
-      StmtSRef parent_sref = GetRef<StmtSRef>(result_block_sref->parent);
+      StmtSRef parent_sref = ffi::GetRef<StmtSRef>(result_block_sref->parent);
       affine_binding = IsAffineBinding(/*realize=*/GetBlockRealize(self, result_block_sref),
                                        /*loop_var_ranges=*/LoopDomainOfSRefTreePath(parent_sref),
                                        /*analyzer=*/&analyzer);
@@ -503,13 +503,14 @@ struct CacheIndexTraits : public UnpackedInstTraits<CacheIndexTraits> {
   static constexpr size_t kNumAttrs = 2;
   static constexpr size_t kNumDecisions = 0;

-  static Array<BlockRV> UnpackedApplyToSchedule(Schedule sch, BlockRV block, String storage_scope,
-                                                Integer cse_thresh) {
+  static ffi::Array<BlockRV> UnpackedApplyToSchedule(Schedule sch, BlockRV block,
+                                                     ffi::String storage_scope,
+                                                     Integer cse_thresh) {
     return sch->CacheIndex(block, storage_scope, cse_thresh->value);
   }

-  static String UnpackedAsPython(Array<String> outputs, String block, String storage_scope,
-                                 Integer cse_thresh) {
+  static ffi::String UnpackedAsPython(ffi::Array<ffi::String> outputs, ffi::String block,
+                                      ffi::String storage_scope, Integer cse_thresh) {
     PythonAPICall py("cache_index");
     py.Input("block", block);
     py.Input("storage_scope", storage_scope);
diff --git a/src/tir/schedule/primitive/cache_read_write.cc b/src/tir/schedule/primitive/cache_read_write.cc
index 38cafbe1515e..a2479a0d28ff 100644
--- a/src/tir/schedule/primitive/cache_read_write.cc
+++ b/src/tir/schedule/primitive/cache_read_write.cc
@@ -30,35 +30,35 @@ namespace tir {

 class NotSingleWriteBlock : public ScheduleError {
  public:
-  explicit NotSingleWriteBlock(IRModule mod, Buffer buffer, Array<StmtSRef> write_blocks)
+  explicit NotSingleWriteBlock(IRModule mod, Buffer buffer, ffi::Array<StmtSRef> write_blocks)
       : mod_(std::move(mod)), buffer_(std::move(buffer)) {
     ICHECK_GT(write_blocks.size(), 1);
     write_blocks_.reserve(write_blocks.size());
     for (const StmtSRef& block_sref : write_blocks) {
       const BlockNode* block = TVM_SREF_TO_BLOCK(block_sref);
-      write_blocks_.push_back(GetRef<Block>(block));
+      write_blocks_.push_back(ffi::GetRef<Block>(block));
     }
   }

-  String FastErrorString() const final {
+  ffi::String FastErrorString() const final {
     return "ScheduleError: The buffer is allowed to be written by single block.";
   }

-  String DetailRenderTemplate() const final {
+  ffi::String DetailRenderTemplate() const final {
     size_t k = write_blocks_.size();
     return "The buffer " + buffer_->name + " is expected to be written by single block, but got " +
            std::to_string(k) + " blocks who write it.";
   }

   IRModule mod() const final { return mod_; }
-  Array<ObjectRef> LocationsOfInterest() const final {
+  ffi::Array<ObjectRef> LocationsOfInterest() const final {
     return {write_blocks_.begin(), write_blocks_.end()};
   }

  private:
   IRModule mod_;
   Buffer buffer_;
-  Array<Block> write_blocks_;
+  ffi::Array<Block> write_blocks_;
 };

 /******** Helper Functions/Classes ********/
@@ -70,7 +70,7 @@ struct CacheStageInfo {
   /*! \brief The buffer to be written. */
   Buffer write_buffer;
   /*! \brief The buffer allocation to be inserted into the block signature. */
-  Optional<Buffer> alloc;
+  ffi::Optional<Buffer> alloc;
   /*! \brief The AST node whose body is where the cache stage should be inserted. */
   StmtSRef loc_sref;
   /*! \brief The index to insert the cache_read/cache_write stage.
*/ @@ -78,7 +78,7 @@ struct CacheStageInfo { /*! \brief The cache_read/cache_write stage to be inserted. */ Stmt cache_stage; /*! \brief The map used for ScheduleStateNode::Replace. */ - Map block_reuse; + ffi::Map block_reuse; /*! \brief A set of blocks that will consume the new cache. */ std::unordered_set consumer_blocks; /*! \brief cache region for the buffer to be cached */ @@ -86,9 +86,9 @@ struct CacheStageInfo { }; /*! \brief Return the buffer region related with the buffer */ -Optional GetBufferRegionFromBuffer(const Array& buffer_regions, - const Buffer& buffer) { - Optional res = std::nullopt; +ffi::Optional GetBufferRegionFromBuffer( + const ffi::Array& buffer_regions, const Buffer& buffer) { + ffi::Optional res = std::nullopt; for (const auto& region : buffer_regions) { if (region->buffer.same_as(buffer)) { ICHECK(!res.defined()); @@ -100,13 +100,13 @@ Optional GetBufferRegionFromBuffer(const Array& buff struct ReindexCacheStageInfo : CacheStageInfo { /* Indices used to access the allocated cache buffer. */ - Array indices; + ffi::Array indices; /* Touched loop variable related information. */ - Array loop_vars; - Array loop_ranges; + ffi::Array loop_vars; + ffi::Array loop_ranges; /* Touched block variable related information. */ - Array block_iter_vars; - Array block_iter_values; + ffi::Array block_iter_vars; + ffi::Array block_iter_values; }; /* \brief The schedule error that accessed buffer region is not a single point for @@ -119,26 +119,26 @@ class NotSinglePointAccess : public ScheduleError { primitive_name_ = is_cache_read ? "reindex_cache_read" : "reindex_cache_write"; } - String FastErrorString() const final { + ffi::String FastErrorString() const final { return "ScheduleError: The buffer region accessed inside the block is not a single point."; } - String DetailRenderTemplate() const final { + ffi::String DetailRenderTemplate() const final { std::ostringstream os; os << "The buffer region " << cache_region_ << " accessed inside block {0} is not a single point, which violates" << " the prerequisite of " << primitive_name_ << " primitive."; - return String(os.str()); + return ffi::String(os.str()); } IRModule mod() const final { return mod_; } - Array LocationsOfInterest() const final { return {block_}; } + ffi::Array LocationsOfInterest() const final { return {block_}; } private: IRModule mod_; Block block_; BufferRegion cache_region_; - String primitive_name_; + ffi::String primitive_name_; }; /*! @@ -151,15 +151,15 @@ class NotSinglePointAccess : public ScheduleError { */ template Block MakeReindexCacheStage(const BufferRegion& cache_region, ReindexCacheStageInfo* info, - const String& storage_scope) { + const ffi::String& storage_scope) { // loop variables std::vector loop_vars; // block variables - Array block_vars; + ffi::Array block_vars; // bindings in block realize std::vector iter_values; // Create loop vars and block vars' binding_value - Map var_map; + ffi::Map var_map; for (size_t i = 0; i < info->loop_vars.size(); ++i) { Var original_var = info->loop_vars[i]; Var loop_var(original_var->name_hint, original_var.dtype()); @@ -180,15 +180,15 @@ Block MakeReindexCacheStage(const BufferRegion& cache_region, ReindexCacheStageI // block access region for read/write buffers Region read_access_region, write_access_region; - Array read_access_indices, write_access_indices; + ffi::Array read_access_indices, write_access_indices; // Compute read/write region and read/write access indices. - Array& old_indices = (is_cache_read) ? 
read_access_indices : write_access_indices; + ffi::Array& old_indices = (is_cache_read) ? read_access_indices : write_access_indices; Region& old_region = (is_cache_read) ? read_access_region : write_access_region; for (const Range& range : cache_region->region) { old_indices.push_back(Substitute(range->min, var_map)); old_region.push_back(Range::FromMinExtent(old_indices.back(), Integer(1))); } - Array& new_indices = (is_cache_read) ? write_access_indices : read_access_indices; + ffi::Array& new_indices = (is_cache_read) ? write_access_indices : read_access_indices; Region& new_region = (is_cache_read) ? write_access_region : read_access_region; for (const PrimExpr& idx : info->indices) { new_indices.push_back(Substitute((idx), var_map)); @@ -237,7 +237,7 @@ Block MakeReindexCacheStage(const BufferRegion& cache_region, ReindexCacheStageI * \returns A block indicating the body of the loop nesting. */ Block MakeCacheStage(const BufferRegion& cache_region, CacheStageInfo* info, - const String& storage_scope, bool cache_full_region = true) { + const ffi::String& storage_scope, bool cache_full_region = true) { // loop variables std::vector loop_vars; // bindings in block realize @@ -249,13 +249,13 @@ Block MakeCacheStage(const BufferRegion& cache_region, CacheStageInfo* info, iter_values.push_back(cache_full_region ? (axis_range->min + loop_var) : loop_var); } // block variables - Array block_vars; + ffi::Array block_vars; // block access region for read/write buffers Region read_access_region; Region write_access_region; // indices used in block body - Array read_access_indices; - Array write_access_indices; + ffi::Array read_access_indices; + ffi::Array write_access_indices; // Create block vars, block's accessed region and accessing indices for (int i = 0; i < static_cast(cache_region->buffer->shape.size()); ++i) { Range axis_range = cache_region->region[i]; @@ -344,14 +344,14 @@ Block MakeCacheStage(const BufferRegion& cache_region, CacheStageInfo* info, */ Block MakeReIndexStage(const Block& block, CacheStageInfo* info, const std::unordered_set& covered, - const Array& original_indices, int buffer_index, + const ffi::Array& original_indices, int buffer_index, BufferIndexType buffer_index_type) { // iters of the reindex block - Array new_block_iters; + ffi::Array new_block_iters; // the substitution map from the original block iter to the iters of the reindex block std::unordered_map block_var_replace_map; // indices to access the reindex buffer and the target buffer - Array reindex_indices, target_indices; + ffi::Array reindex_indices, target_indices; // Step 1: Create block iters, access regions of the reindex block, and accessing indices to the // reindex buffer. 
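The stage builders above (MakeCacheStage, MakeReindexCacheStage, MakeReIndexStage) all repeat one idiom: build an `ffi::Map<Var, PrimExpr>` from old variables to new expressions, then push substituted indices into an `ffi::Array<PrimExpr>`. A minimal sketch of that idiom follows; `RemapIndices` is a hypothetical helper written for illustration, not part of this patch, and the header paths are assumed from the usual TVM layout.

.. code:: c++

   // Sketch only: the ffi container idiom used by the cache-stage builders.
   #include <tvm/ffi/container/array.h>
   #include <tvm/ffi/container/map.h>
   #include <tvm/tir/expr.h>
   #include <tvm/tir/stmt_functor.h>  // for tir::Substitute

   namespace tvm {
   namespace tir {

   // Build a Var -> PrimExpr substitution, then remap every index through it.
   ffi::Array<PrimExpr> RemapIndices(const ffi::Array<PrimExpr>& indices,
                                     const ffi::Array<Var>& old_vars,
                                     const ffi::Array<Var>& new_vars) {
     ICHECK_EQ(old_vars.size(), new_vars.size());
     ffi::Map<Var, PrimExpr> var_map;
     for (size_t i = 0; i < old_vars.size(); ++i) {
       var_map.Set(old_vars[i], new_vars[i]);  // Set copies on write, like Array
     }
     ffi::Array<PrimExpr> result;
     result.reserve(indices.size());
     for (const PrimExpr& index : indices) {
       result.push_back(Substitute(index, var_map));
     }
     return result;
   }

   }  // namespace tir
   }  // namespace tvm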
@@ -383,8 +383,8 @@ Block MakeReIndexStage(const Block& block, CacheStageInfo* info, // The src and the dst region and indices of the data copy Region src_region{nullptr}; Region dst_region{nullptr}; - Array src_indices{nullptr}; - Array dst_indices{nullptr}; + ffi::Array src_indices{nullptr}; + ffi::Array dst_indices{nullptr}; if (buffer_index_type == BufferIndexType::kWrite) { src_indices = reindex_indices; @@ -444,7 +444,7 @@ bool CalculateAffineFlag(const ScheduleState& self, const StmtSRef& block_sref) return true; } arith::Analyzer analyzer; - StmtSRef parent_sref = GetRef(block_sref->parent); + StmtSRef parent_sref = ffi::GetRef(block_sref->parent); return IsAffineBinding(/*realize=*/GetBlockRealize(self, block_sref), /*loop_var_ranges=*/LoopDomainOfSRefTreePath(parent_sref), /*analyzer=*/&analyzer); @@ -477,7 +477,7 @@ Stmt InsertCacheStage(const Stmt& stmt, int pos, const Stmt& stage) { } if (const auto* seq_stmt = body.as()) { - Array seq = seq_stmt->seq; + ffi::Array seq = seq_stmt->seq; ICHECK_LE(pos, seq.size()) << "Cannot insert at position " << pos << " into sequence of length " << seq.size(); seq.insert(seq.begin() + pos, stage); @@ -506,14 +506,14 @@ Stmt InsertCacheStage(const Stmt& stmt, int pos, const Stmt& stage) { * or `std::nullopt` if no block writes it in the scope. * \throw NotSingleWriteBlock if there are more than one interested block. */ -Optional GetOnlyWriteBlock(ScheduleState self, const StmtSRef& scope_sref, - const Buffer& buffer) { +ffi::Optional GetOnlyWriteBlock(ScheduleState self, const StmtSRef& scope_sref, + const Buffer& buffer) { BlockScope scope = self->GetBlockScope(scope_sref); auto it = scope->buffer_writers.find(buffer); if (it == scope->buffer_writers.end()) { return std::nullopt; } else { - const Array& block_srefs = it->second; + const ffi::Array& block_srefs = it->second; ICHECK(!block_srefs.empty()); if (block_srefs.size() > 1) { throw NotSingleWriteBlock(self->mod, buffer, block_srefs); @@ -570,11 +570,11 @@ BufferRegion RelaxBufferRegion(ScheduleState self, const BufferRegion& buffer_re const StmtSRef& block_sref, const StmtSRef& dom_low_inclusive, const StmtSRef& dom_high_exclusive) { BlockRealize realize = GetBlockRealize(self, block_sref); - Map binding = GetBindings(realize); + ffi::Map binding = GetBindings(realize); const Buffer& buffer = buffer_region->buffer; arith::Analyzer analyzer; BufferRegion subst_region = BufferRegion(buffer, Substitute(buffer_region->region, binding)); - Array int_sets = AnalyzeRegionUpperBound( + ffi::Array int_sets = AnalyzeRegionUpperBound( /*region=*/subst_region, /*predicate=*/realize->predicate, /*dom_low_inclusive=*/dom_low_inclusive, @@ -632,7 +632,7 @@ class CacheLocDetector : public StmtVisitor { if (!related_blocks.empty()) { CacheLocDetector detector(self, block_sref, scope_sref, related_blocks); - detector(GetRef(scope_sref->stmt)); + detector(ffi::GetRef(scope_sref->stmt)); info->loc_sref = detector.loc_sref_; info->loc_pos = detector.loc_pos_; } else { @@ -761,7 +761,7 @@ class CacheInplaceLocDetector : public StmtVisitor { static void Detect(const ScheduleState& self, const StmtSRef& block_sref, const StmtSRef& scope_sref, CacheStageInfo* info) { CacheInplaceLocDetector detector(self, block_sref, scope_sref); - detector(GetRef(scope_sref->stmt)); + detector(ffi::GetRef(scope_sref->stmt)); info->loc_sref = detector.loc_sref_; info->loc_pos = detector.loc_pos_; } @@ -851,7 +851,7 @@ class CacheReadRewriter : public StmtExprMutator { static Stmt Rewrite(const StmtSRef& scope_sref, 
CacheStageInfo* info, bool cache_full_region = true) { CacheReadRewriter rewriter(scope_sref, info, cache_full_region); - return rewriter(GetRef(scope_sref->stmt)); + return rewriter(ffi::GetRef(scope_sref->stmt)); } private: @@ -868,12 +868,12 @@ class CacheReadRewriter : public StmtExprMutator { return ret; }; - update_access_regions = [this, update_region](Array regions) { + update_access_regions = [this, update_region](ffi::Array regions) { if (cache_full_region_) { return ReplaceBuffer(std::move(regions), info_->read_buffer, info_->write_buffer); } - Array ret; + ffi::Array ret; for (const BufferRegion& region : regions) { if (region->buffer.same_as(info_->read_buffer)) { ret.push_back(BufferRegion(info_->write_buffer, @@ -884,12 +884,12 @@ class CacheReadRewriter : public StmtExprMutator { } return ret; }; - update_match_buffers = [this, update_region](Array match_buffers) { + update_match_buffers = [this, update_region](ffi::Array match_buffers) { if (cache_full_region_) { return ReplaceBuffer(std::move(match_buffers), info_->read_buffer, info_->write_buffer); } - Array ret; + ffi::Array ret; for (const MatchBufferRegion& match_buffer : match_buffers) { if (match_buffer->source->buffer.same_as(info_->read_buffer)) { ret.push_back(MatchBufferRegion( @@ -909,7 +909,7 @@ class CacheReadRewriter : public StmtExprMutator { // Check the insertion point if (loop == info_->loc_sref->stmt) { // Insert cache stage into the loop if it is the right place - ObjectPtr n = make_object(*stmt.as()); + ObjectPtr n = ffi::make_object(*stmt.as()); n->body = InsertCacheStage(n->body, info_->loc_pos, info_->cache_stage); stmt = Stmt(n); } @@ -917,14 +917,14 @@ class CacheReadRewriter : public StmtExprMutator { } Stmt VisitStmt_(const BlockNode* block) override { - Block old_stmt = GetRef(block); + Block old_stmt = ffi::GetRef(block); // Check if this block is one of the specified consumers. // If no consumer blocks are specified, all blocks should be considered consumers. bool is_consumer = info_->consumer_blocks.empty(); // Otherwise check if this is one of the specified blocks. for (StmtSRef consumer_sref : info_->consumer_blocks) { const BlockNode* consumer_node = TVM_SREF_TO_BLOCK(consumer_sref); - Block consumer_block = GetRef(consumer_node); + Block consumer_block = ffi::GetRef(consumer_node); if (old_stmt.same_as(consumer_block)) { is_consumer = true; } @@ -941,14 +941,14 @@ class CacheReadRewriter : public StmtExprMutator { // Check the insertion point if (block == info_->loc_sref->stmt) { // Insert cache stage into the block if it is the right place - ObjectPtr n = make_object(*stmt.as()); + ObjectPtr n = ffi::make_object(*stmt.as()); n->body = InsertCacheStage(n->body, info_->loc_pos, info_->cache_stage); stmt = Block(n); } // Check if it is the block corresponding to the parent scope if (block == scope_sref_->stmt) { // If so, put buffer allocation on the parent scope - ObjectPtr n = make_object(*stmt.as()); + ObjectPtr n = ffi::make_object(*stmt.as()); // In cache_inplace case, alloc_buffer may be already exits. if (info_->alloc.defined()) { n->alloc_buffers.push_back(info_->alloc.value()); @@ -959,10 +959,10 @@ class CacheReadRewriter : public StmtExprMutator { // Only make this change if the block is one of the specified consumers. 
if (is_consumer) { // Use the updated block stmt - Array reads = update_access_regions(stmt->reads); - Array match_buffers = update_match_buffers(stmt->match_buffers); + ffi::Array reads = update_access_regions(stmt->reads); + ffi::Array match_buffers = update_match_buffers(stmt->match_buffers); if (!reads.same_as(stmt->reads) || !match_buffers.same_as(stmt->match_buffers)) { - ObjectPtr n = make_object(*stmt.as()); + ObjectPtr n = ffi::make_object(*stmt.as()); n->reads = std::move(reads); n->match_buffers = std::move(match_buffers); stmt = Block(n); @@ -973,7 +973,7 @@ class CacheReadRewriter : public StmtExprMutator { return stmt; } - Array RewriteIndices(const Array& indices) { + ffi::Array RewriteIndices(const ffi::Array& indices) { std::vector ret; for (size_t i = 0; i < indices.size(); ++i) { ret.push_back(ana_.Simplify(indices[i] - info_->cache_region->region[i]->min)); @@ -983,7 +983,7 @@ class CacheReadRewriter : public StmtExprMutator { PrimExpr VisitExpr_(const BufferLoadNode* load) override { if (load->buffer.same_as(info_->read_buffer) && current_block_consumes) { - ObjectPtr n = make_object(*load); + ObjectPtr n = ffi::make_object(*load); n->buffer = info_->write_buffer; if (!cache_full_region_) { n->indices = RewriteIndices(load->indices); @@ -997,7 +997,7 @@ class CacheReadRewriter : public StmtExprMutator { if (op == info_->read_buffer->data.get()) { return info_->write_buffer->data; } - return GetRef(op); + return ffi::GetRef(op); } private: @@ -1008,9 +1008,9 @@ class CacheReadRewriter : public StmtExprMutator { /*! \brief Whether the most recently visited block is a specified consumer. */ bool current_block_consumes; /*! \brief function to update read/write region of block being cache read.*/ - std::function(Array)> update_access_regions; + std::function(ffi::Array)> update_access_regions; /*! \brief function to update match buffers of block being cache read.*/ - std::function(Array)> update_match_buffers; + std::function(ffi::Array)> update_match_buffers; /*! * \brief A boolean indicating if the cache buffer is allocated with * full region or compact region. 
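CacheReadRewriter above follows a copy-on-write discipline throughout: updated `ffi::Array` fields are compared against the originals with `same_as()`, and the node is only copied via `ffi::make_object` when something actually changed. A minimal standalone sketch of that discipline; `WithReads` is illustrative, not a TVM API.

.. code:: c++

   // Sketch only: the copy-on-write update pattern used by the rewriters.
   #include <utility>

   #include <tvm/ffi/memory.h>
   #include <tvm/tir/stmt.h>

   namespace tvm {
   namespace tir {

   // Replace the read regions of `block`, copying the node only when needed.
   Block WithReads(Block block, ffi::Array<BufferRegion> new_reads) {
     if (new_reads.same_as(block->reads)) {
       return block;  // reference-equal: reuse the existing node untouched
     }
     // Shallow-copy the node, mutate the copy, and rewrap it as a Block.
     ObjectPtr<BlockNode> n = ffi::make_object<BlockNode>(*block.get());
     n->reads = std::move(new_reads);
     return Block(n);
   }

   }  // namespace tir
   }  // namespace tvm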
@@ -1033,18 +1033,18 @@ class ReindexCacheReadRewriter : public CacheReadRewriter { */ static Stmt Rewrite(const StmtSRef& scope_sref, ReindexCacheStageInfo* info) { ReindexCacheReadRewriter rewriter(scope_sref, info); - return rewriter(GetRef(scope_sref->stmt)); + return rewriter(ffi::GetRef(scope_sref->stmt)); } private: explicit ReindexCacheReadRewriter(const StmtSRef& scope_sref, ReindexCacheStageInfo* info) : CacheReadRewriter(scope_sref, info) { new_indices_ = info->indices; - update_access_regions = [&](Array reads) { - Array new_reads; + update_access_regions = [&](ffi::Array reads) { + ffi::Array new_reads; for (const BufferRegion& buf_region : reads) { if (buf_region->buffer.same_as(info_->read_buffer)) { - Array region; + ffi::Array region; for (const PrimExpr index : new_indices_) { region.push_back(Range::FromMinExtent(index, Integer(1))); } @@ -1055,12 +1055,12 @@ class ReindexCacheReadRewriter : public CacheReadRewriter { } return new_reads; }; - update_match_buffers = [&](const Array match_buffers) { - Array new_match_buffers; + update_match_buffers = [&](const ffi::Array match_buffers) { + ffi::Array new_match_buffers; for (const MatchBufferRegion& match_buffer_region : match_buffers) { BufferRegion source = match_buffer_region->source; if (source->buffer.same_as(info_->read_buffer)) { - Array region; + ffi::Array region; for (const PrimExpr index : new_indices_) { region.push_back(Range::FromMinExtent(index, Integer(1))); } @@ -1076,7 +1076,7 @@ class ReindexCacheReadRewriter : public CacheReadRewriter { PrimExpr VisitExpr_(const BufferLoadNode* load) final { if (load->buffer.same_as(info_->read_buffer) && current_block_consumes) { - ObjectPtr n = make_object(*load); + ObjectPtr n = ffi::make_object(*load); n->buffer = info_->write_buffer; n->indices = new_indices_; return PrimExpr(n); @@ -1085,7 +1085,7 @@ class ReindexCacheReadRewriter : public CacheReadRewriter { } /*! \brief The indices to use for new buffer. 
*/ - Array new_indices_; + ffi::Array new_indices_; }; class ReindexCacheWriteRewriter; @@ -1105,7 +1105,7 @@ class CacheWriteRewriter : public StmtExprMutator { static Stmt Rewrite(const StmtSRef& scope_sref, const StmtSRef& writer_block_sref, CacheStageInfo* info, bool cache_full_region = true) { CacheWriteRewriter rewriter(scope_sref, writer_block_sref, info, cache_full_region); - return rewriter(GetRef(scope_sref->stmt)); + return rewriter(ffi::GetRef(scope_sref->stmt)); } private: @@ -1125,12 +1125,12 @@ class CacheWriteRewriter : public StmtExprMutator { return ret; }; - update_access_regions = [this, update_region](Array regions) { + update_access_regions = [this, update_region](ffi::Array regions) { if (cache_full_region_) { return ReplaceBuffer(regions, info_->write_buffer, info_->read_buffer); } - Array ret; + ffi::Array ret; for (const BufferRegion& region : regions) { if (region->buffer.same_as(info_->write_buffer)) { ret.push_back(BufferRegion(info_->read_buffer, @@ -1141,12 +1141,12 @@ class CacheWriteRewriter : public StmtExprMutator { } return ret; }; - update_match_buffers = [this, update_region](Array match_buffers) { + update_match_buffers = [this, update_region](ffi::Array match_buffers) { if (cache_full_region_) { return ReplaceBuffer(match_buffers, info_->write_buffer, info_->read_buffer); } - Array ret; + ffi::Array ret; for (const MatchBufferRegion& match_buffer : match_buffers) { if (match_buffer->source->buffer.same_as(info_->write_buffer)) { ret.push_back(MatchBufferRegion( @@ -1166,7 +1166,7 @@ class CacheWriteRewriter : public StmtExprMutator { // Check the insertion point if (loop == info_->loc_sref->stmt) { // Insert cache stage into the loop if it is the right place - ObjectPtr n = make_object(*stmt.as()); + ObjectPtr n = ffi::make_object(*stmt.as()); n->body = InsertCacheStage(n->body, info_->loc_pos, info_->cache_stage); stmt = Stmt(n); } @@ -1174,17 +1174,17 @@ class CacheWriteRewriter : public StmtExprMutator { } Stmt VisitStmt_(const BlockNode* block) override { - Block old_stmt = GetRef(block); + Block old_stmt = ffi::GetRef(block); // Check if this block is one of the specified cache consumers. // update the read buffer to the cache. for (StmtSRef consumer_sref : info_->consumer_blocks) { const BlockNode* consumer_node = TVM_SREF_TO_BLOCK(consumer_sref); - Block consumer_block = GetRef(consumer_node); + Block consumer_block = ffi::GetRef(consumer_node); if (old_stmt.same_as(consumer_block)) { - Array writes = update_access_regions(block->writes); - Array reads = update_access_regions(block->reads); - Array match_buffers = update_match_buffers(block->match_buffers); + ffi::Array writes = update_access_regions(block->writes); + ffi::Array reads = update_access_regions(block->reads); + ffi::Array match_buffers = update_match_buffers(block->match_buffers); if (!writes.same_as(block->writes) || !reads.same_as(block->reads) || !match_buffers.same_as(block->match_buffers)) { auto n = CopyOnWrite(block); @@ -1213,13 +1213,13 @@ class CacheWriteRewriter : public StmtExprMutator { // Find the insertion point if (block == info_->loc_sref->stmt) { - ObjectPtr n = make_object(*stmt.as()); + ObjectPtr n = ffi::make_object(*stmt.as()); n->body = InsertCacheStage(n->body, info_->loc_pos, info_->cache_stage); stmt = Block(n); } // Put buffer allocation on the parent scope if (block == scope_sref_->stmt) { - ObjectPtr n = make_object(*stmt.as()); + ObjectPtr n = ffi::make_object(*stmt.as()); // In cache_inplace case, alloc_buffer may be already exits. 
if (info_->alloc.defined()) { n->alloc_buffers.push_back(info_->alloc.value()); @@ -1232,7 +1232,7 @@ class CacheWriteRewriter : public StmtExprMutator { auto match_buffers = update_match_buffers(block->match_buffers); if (!writes.same_as(block->writes) || !reads.same_as(block->reads) || !match_buffers.same_as(block->match_buffers)) { - ObjectPtr n = make_object(*stmt.as()); + ObjectPtr n = ffi::make_object(*stmt.as()); n->writes = std::move(writes); n->reads = std::move(reads); n->match_buffers = std::move(match_buffers); @@ -1243,7 +1243,7 @@ class CacheWriteRewriter : public StmtExprMutator { return stmt; } - Array RewriteIndices(const Array& indices) { + ffi::Array RewriteIndices(const ffi::Array& indices) { std::vector ret; for (size_t i = 0; i < indices.size(); ++i) { ret.push_back(ana_.Simplify(indices[i] - info_->cache_region->region[i]->min)); @@ -1267,7 +1267,7 @@ class CacheWriteRewriter : public StmtExprMutator { PrimExpr VisitExpr_(const BufferLoadNode* load) override { if (load->buffer.same_as(info_->write_buffer)) { - ObjectPtr n = make_object(*load); + ObjectPtr n = ffi::make_object(*load); n->buffer = info_->read_buffer; if (!cache_full_region_) { n->indices = RewriteIndices(n->indices); @@ -1281,7 +1281,7 @@ class CacheWriteRewriter : public StmtExprMutator { if (op == info_->write_buffer->data.get()) { return info_->read_buffer->data; } - return GetRef(op); + return ffi::GetRef(op); } private: @@ -1294,9 +1294,9 @@ class CacheWriteRewriter : public StmtExprMutator { /*! \brief Whether the current node is under the given block. */ bool under_writer_block_{false}; /*! \brief function to update read/write region of block being cache write.*/ - std::function(Array)> update_access_regions; + std::function(ffi::Array)> update_access_regions; /*! \brief function to update match buffers of block being cache write.*/ - std::function(Array)> update_match_buffers; + std::function(ffi::Array)> update_match_buffers; /*! * \brief A boolean indicating if the cache buffer is allocated with * full region or compact region. 
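The `VisitExpr_` overrides above depend on the now fully qualified helpers this patch migrates to: `ffi::GetRef` recovers a handle from an unowned node pointer, and `ffi::Map::Get` returns an `ffi::Optional` instead of throwing on a missing key. A self-contained sketch of that lookup idiom, modeled on the Replacer earlier in this file; `VarReplacer` is a hypothetical mutator, not part of the patch.

.. code:: c++

   // Sketch only: ffi::GetRef + ffi::Map::Get returning ffi::Optional.
   #include <tvm/ffi/cast.h>
   #include <tvm/tir/expr.h>
   #include <tvm/tir/stmt_functor.h>

   namespace tvm {
   namespace tir {

   class VarReplacer : public StmtExprMutator {
    public:
     explicit VarReplacer(ffi::Map<Var, PrimExpr> sub) : sub_(std::move(sub)) {}

     PrimExpr VisitExpr_(const VarNode* op) final {
       // GetRef wraps the unowned VarNode* back into a Var handle for lookup.
       if (ffi::Optional<PrimExpr> e = sub_.Get(ffi::GetRef<Var>(op))) {
         return e.value();
       }
       return StmtExprMutator::VisitExpr_(op);
     }

    private:
     ffi::Map<Var, PrimExpr> sub_;
   };

   }  // namespace tir
   }  // namespace tvm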
@@ -1321,7 +1321,7 @@ class ReindexCacheWriteRewriter : public CacheWriteRewriter { static Stmt Rewrite(const StmtSRef& scope_sref, const StmtSRef& writer_block_sref, ReindexCacheStageInfo* info) { ReindexCacheWriteRewriter rewriter(scope_sref, writer_block_sref, info); - return rewriter(GetRef(scope_sref->stmt)); + return rewriter(ffi::GetRef(scope_sref->stmt)); } private: @@ -1329,11 +1329,11 @@ class ReindexCacheWriteRewriter : public CacheWriteRewriter { ReindexCacheStageInfo* info) : CacheWriteRewriter(scope_sref, writer_block_sref, info) { new_indices_ = info->indices; - update_access_regions = [&](Array reads) { - Array new_reads; + update_access_regions = [&](ffi::Array reads) { + ffi::Array new_reads; for (const BufferRegion& buf_region : reads) { if (buf_region->buffer.same_as(info_->write_buffer)) { - Array region; + ffi::Array region; for (const PrimExpr index : new_indices_) { region.push_back(Range::FromMinExtent(index, Integer(1))); } @@ -1344,12 +1344,12 @@ class ReindexCacheWriteRewriter : public CacheWriteRewriter { } return new_reads; }; - update_match_buffers = [&](const Array match_buffers) { - Array new_match_buffers; + update_match_buffers = [&](const ffi::Array match_buffers) { + ffi::Array new_match_buffers; for (const MatchBufferRegion& match_buffer_region : match_buffers) { BufferRegion source = match_buffer_region->source; if (source->buffer.same_as(info_->write_buffer)) { - Array region; + ffi::Array region; for (const PrimExpr index : new_indices_) { region.push_back(Range::FromMinExtent(index, Integer(1))); } @@ -1377,7 +1377,7 @@ class ReindexCacheWriteRewriter : public CacheWriteRewriter { PrimExpr VisitExpr_(const BufferLoadNode* load) final { if (load->buffer.same_as(info_->write_buffer)) { - ObjectPtr n = make_object(*load); + ObjectPtr n = ffi::make_object(*load); n->buffer = info_->read_buffer; n->indices = new_indices_; return PrimExpr(n); @@ -1386,7 +1386,7 @@ class ReindexCacheWriteRewriter : public CacheWriteRewriter { } /*! \brief The indices to use for new buffer. */ - Array new_indices_; + ffi::Array new_indices_; }; /*! @@ -1396,10 +1396,10 @@ class ReindexCacheWriteRewriter : public CacheWriteRewriter { * \param covered Set of block iter vars covered by the buffer access indices * \return The new buffer with target shape. 
 */
-Buffer CreateReindexBuffer(const Buffer& buffer, const Array<IterVar>& block_iters,
+Buffer CreateReindexBuffer(const Buffer& buffer, const ffi::Array<IterVar>& block_iters,
                            const std::unordered_set<Var, ObjectPtrHash, ObjectPtrEqual>& covered) {
-  ObjectPtr<BufferNode> new_buffer = make_object<BufferNode>(*buffer.get());
-  ObjectPtr<VarNode> new_var = make_object<VarNode>(*buffer->data.get());
+  ObjectPtr<BufferNode> new_buffer = ffi::make_object<BufferNode>(*buffer.get());
+  ObjectPtr<VarNode> new_var = ffi::make_object<VarNode>(*buffer->data.get());
   std::vector<PrimExpr> new_shape;
   std::vector<PrimExpr> new_strides;
   for (const auto& iter : block_iters) {
@@ -1421,14 +1421,16 @@ Buffer CreateReindexBuffer(const Buffer& buffer, const Array<IterVar>& block_ite
 class NotLeafBlockError : public ScheduleError {
  public:
   NotLeafBlockError(IRModule mod, Block block) : mod_(std::move(mod)), block_(std::move(block)) {}
-  String FastErrorString() const final {
+  ffi::String FastErrorString() const final {
     return "ScheduleError: The target block is not a leaf block.";
   }

-  String DetailRenderTemplate() const final { return "The target block {0} is not a leaf block."; }
+  ffi::String DetailRenderTemplate() const final {
+    return "The target block {0} is not a leaf block.";
+  }

   IRModule mod() const final { return mod_; }
-  Array<ObjectRef> LocationsOfInterest() const final { return {block_}; }
+  ffi::Array<ObjectRef> LocationsOfInterest() const final { return {block_}; }
   IRModule mod_;
   Block block_;
 };
@@ -1444,12 +1446,12 @@ class InvalidBufferAccessError : public ScheduleError {
   InvalidBufferAccessError(IRModule mod, Buffer buffer, Block block, ErrorKind kind)
       : mod_(std::move(mod)), buffer_(std::move(buffer)), block_(std::move(block)), kind_(kind) {}

-  String FastErrorString() const final {
+  ffi::String FastErrorString() const final {
     return "ScheduleError: The target buffer should be accessed via BufferLoad or BufferStore. The "
            "indices should be the same if there are multiple accesses to the target buffer.";
   }

-  String DetailRenderTemplate() const final {
+  ffi::String DetailRenderTemplate() const final {
     std::ostringstream os;
     os << "The target buffer " << buffer_->name
        << " should be accessed in the leaf block {0} via BufferLoad or BufferStore. The indices "
@@ -1464,7 +1466,7 @@ class InvalidBufferAccessError : public ScheduleError {
     return os.str();
   }
   IRModule mod() const final { return mod_; }
-  Array<ObjectRef> LocationsOfInterest() const final { return {block_}; }
+  ffi::Array<ObjectRef> LocationsOfInterest() const final { return {block_}; }

  private:
   IRModule mod_;
   Buffer buffer_;
@@ -1476,7 +1478,8 @@ class InvalidBufferAccessError : public ScheduleError {
 /*! \brief Collect the related Load/Store to reindex */
 class ReIndexCollector : public StmtExprVisitor {
  public:
-  static Array<PrimExpr> Collect(const IRModule& mod, const Buffer& buffer, const Block& block) {
+  static ffi::Array<PrimExpr> Collect(const IRModule& mod, const Buffer& buffer,
+                                      const Block& block) {
     ReIndexCollector collector(mod, buffer, block);
     collector(block->body);
     if (!collector.buffer_access_indices_.defined()) {
@@ -1509,7 +1512,7 @@ class ReIndexCollector : public StmtExprVisitor {
     }
   }

-  void CheckAndUpdateBufferAccessIndices(const Array<PrimExpr> indices) {
+  void CheckAndUpdateBufferAccessIndices(const ffi::Array<PrimExpr> indices) {
     if (!buffer_access_indices_.defined()) {
       buffer_access_indices_ = indices;
       return;
@@ -1534,7 +1537,7 @@ class ReIndexCollector : public StmtExprVisitor {
   /*! \brief The block to visit */
   Block block_;
   /*! \brief The indices of buffer access to rewrite */
-  Optional<Array<PrimExpr>> buffer_access_indices_;
+  ffi::Optional<ffi::Array<PrimExpr>> buffer_access_indices_;
 };

 /*!
\brief Mutator of ReIndex */ @@ -1543,7 +1546,7 @@ class ReIndexRewriter : public StmtExprMutator { static Stmt Rewrite(const StmtSRef& scope_sref, const StmtSRef& block_sref, CacheStageInfo* info, const std::unordered_set& covered) { ReIndexRewriter rewriter(block_sref, info, covered); - return rewriter(GetRef(scope_sref->stmt)); + return rewriter(ffi::GetRef(scope_sref->stmt)); } private: @@ -1555,12 +1558,12 @@ class ReIndexRewriter : public StmtExprMutator { } Stmt VisitStmt_(const BlockNode* block) final { - Block old_stmt = GetRef(block); + Block old_stmt = ffi::GetRef(block); if (is_scope_) { is_scope_ = false; Block stmt = Downcast(StmtExprMutator::VisitStmt_(block)); // Insert cache stage into the loop - ObjectPtr n = make_object(*stmt.as()); + ObjectPtr n = ffi::make_object(*stmt.as()); n->body = InsertCacheStage(n->body, info_->loc_pos, info_->cache_stage); n->alloc_buffers.push_back(info_->alloc.value()); stmt = Block(n); @@ -1587,7 +1590,7 @@ class ReIndexRewriter : public StmtExprMutator { BufferRegion{new_buffer_, region_}); if (!writes.same_as(block->writes) || !reads.same_as(block->reads) || !match_buffers.same_as(block->match_buffers)) { - ObjectPtr n = make_object(*stmt.as()); + ObjectPtr n = ffi::make_object(*stmt.as()); n->writes = std::move(writes); n->reads = std::move(reads); n->match_buffers = std::move(match_buffers); @@ -1632,7 +1635,7 @@ class ReIndexRewriter : public StmtExprMutator { /*! \brief The reindex buffer */ Buffer new_buffer_; /*! \brief The new indices */ - Array indices_; + ffi::Array indices_; /*! \brief The new region */ Region region_; }; @@ -1642,15 +1645,15 @@ void CheckRegionCover(const ScheduleState& self, StmtSRef scope_root, Buffer rea public: explicit NotRegionCoverError(IRModule mod, Block block) : mod_(mod), block_(block) {} IRModule mod() const final { return mod_; } - String FastErrorString() const final { + ffi::String FastErrorString() const final { return "ScheduleError: The scope root's region cover is not complete."; } - String DetailRenderTemplate() const final { + ffi::String DetailRenderTemplate() const final { return R"(The scope {0} 's region cover is not complete. The region cover property require to hold for every of its child blocks )"; } - Array LocationsOfInterest() const final { return {block_}; } + ffi::Array LocationsOfInterest() const final { return {block_}; } IRModule mod_; Block block_; }; @@ -1661,7 +1664,7 @@ The region cover property require to hold for every of its child blocks if (region->buffer.same_as(read_buffer)) { if (!self->block_info.at(child_block_sref).region_cover) { const BlockNode* block = TVM_SREF_TO_BLOCK(scope_root); - throw NotRegionCoverError(self->mod, GetRef(block)); + throw NotRegionCoverError(self->mod, ffi::GetRef(block)); } } } @@ -1671,7 +1674,7 @@ The region cover property require to hold for every of its child blocks /******** Implementation ********/ StmtSRef CacheRead(ScheduleState self, const StmtSRef& block_sref, int read_buffer_index, - const String& storage_scope, const Array consumer_blocks) { + const ffi::String& storage_scope, const ffi::Array consumer_blocks) { /*! * Check: * - The index is in the array of block reading region @@ -1688,8 +1691,8 @@ StmtSRef CacheRead(ScheduleState self, const StmtSRef& block_sref, int read_buff // Step 1. 
Check index, getting the target buffer and the parent scope const BlockNode* block = TVM_SREF_TO_BLOCK(block_sref); - Buffer read_buffer = - GetNthAccessBuffer(self, GetRef(block), read_buffer_index, BufferIndexType::kRead); + Buffer read_buffer = GetNthAccessBuffer(self, ffi::GetRef(block), read_buffer_index, + BufferIndexType::kRead); StmtSRef scope_sref = GetScopeRoot(self, block_sref, /*require_stage_pipeline=*/false); // Check required region cover for cache_read CheckRegionCover(self, scope_sref, read_buffer); @@ -1709,13 +1712,14 @@ StmtSRef CacheRead(ScheduleState self, const StmtSRef& block_sref, int read_buff // Step 3. Update cache stage info. BufferRegion cache_region{nullptr}; - if (Optional _write_block_sref = GetOnlyWriteBlock(self, scope_sref, read_buffer)) { + if (ffi::Optional _write_block_sref = + GetOnlyWriteBlock(self, scope_sref, read_buffer)) { // Case 1. The buffer is written inside the block. StmtSRef write_block_sref = _write_block_sref.value(); const BlockNode* write_block = TVM_SREF_TO_BLOCK(write_block_sref); // Find the producing region BufferRegion region = GetBufferRegionFromBuffer(write_block->writes, read_buffer).value(); - StmtSRef parent_sref = GetRef(write_block_sref->parent); + StmtSRef parent_sref = ffi::GetRef(write_block_sref->parent); // Detect insert position CacheLocDetector::Detect(self, write_block_sref, scope_sref, &info); @@ -1724,7 +1728,7 @@ StmtSRef CacheRead(ScheduleState self, const StmtSRef& block_sref, int read_buff // Case 2. The buffer is the input block for the scope. info.loc_sref = scope_sref; info.loc_pos = 0; - if (Optional region = + if (ffi::Optional region = GetBufferRegionFromBuffer(scope_block->reads, read_buffer)) { cache_region = region.value(); } else { @@ -1764,7 +1768,7 @@ StmtSRef CacheRead(ScheduleState self, const StmtSRef& block_sref, int read_buff } StmtSRef CacheWrite(ScheduleState self, const StmtSRef& block_sref, int write_buffer_index, - const String& storage_scope, const Array consumer_blocks) { + const ffi::String& storage_scope, const ffi::Array consumer_blocks) { /*! * Check: * - The index is in the array of block reading region @@ -1781,8 +1785,8 @@ StmtSRef CacheWrite(ScheduleState self, const StmtSRef& block_sref, int write_bu // Step 1. Checking index, getting the target buffer and the parent scope const BlockNode* block = TVM_SREF_TO_BLOCK(block_sref); - Buffer write_buffer = - GetNthAccessBuffer(self, GetRef(block), write_buffer_index, BufferIndexType::kWrite); + Buffer write_buffer = GetNthAccessBuffer(self, ffi::GetRef(block), write_buffer_index, + BufferIndexType::kWrite); StmtSRef scope_sref = GetScopeRoot(self, block_sref, /*require_stage_pipeline=*/false); // Step 2. Creating CacheStageInfo @@ -1803,7 +1807,7 @@ StmtSRef CacheWrite(ScheduleState self, const StmtSRef& block_sref, int write_bu // Step 4. 
Find the producing region and insert position BufferRegion region = GetBufferRegionFromBuffer(block->writes, write_buffer).value(); - StmtSRef parent_sref = GetRef(block_sref->parent); + StmtSRef parent_sref = ffi::GetRef(block_sref->parent); // Detect insert position CacheLocDetector::Detect(self, block_sref, scope_sref, &info); BufferRegion cache_region = @@ -1841,12 +1845,12 @@ StmtSRef CacheWrite(ScheduleState self, const StmtSRef& block_sref, int write_bu return result_block_sref; } -Array GetLoopsUnderScope(const StmtSRef& block_sref, const StmtSRef& top_sref) { +ffi::Array GetLoopsUnderScope(const StmtSRef& block_sref, const StmtSRef& top_sref) { std::vector result; for (StmtSRefNode* parent = block_sref->parent; parent && parent->stmt->IsInstance(); parent = parent->parent) { if (parent == top_sref.get()) break; - result.push_back(GetRef(parent)); + result.push_back(ffi::GetRef(parent)); } return {result.rbegin(), result.rend()}; } @@ -1858,8 +1862,9 @@ Array GetLoopsUnderScope(const StmtSRef& block_sref, const StmtSRef& t class ReindexCacheReadWriteNotMatchError : public ScheduleError { public: ReindexCacheReadWriteNotMatchError(IRModule mod, Block block, Var var, - Array old_indices, Array new_indices, - bool is_cache_read, bool appears_in_old) + ffi::Array old_indices, + ffi::Array new_indices, bool is_cache_read, + bool appears_in_old) : mod_(std::move(mod)), block_(std::move(block)), var_(std::move(var)) { primitive_name_ = is_cache_read ? "reindex_cache_read" : "reindex_cache_write"; if (appears_in_old) { @@ -1870,26 +1875,26 @@ class ReindexCacheReadWriteNotMatchError : public ScheduleError { other_indices_ = std::move(old_indices); } } - String FastErrorString() const final { + ffi::String FastErrorString() const final { return "ScheduleError: the block itervars appeared in lhs and rhs of reindex cache stage do " "not match."; } - String DetailRenderTemplate() const final { + ffi::String DetailRenderTemplate() const final { std::stringstream s; s << "Error when applying " << primitive_name_ << " on block {0}, the block itervar " << var_ << " appears in " << appears_indices_ << ", but not in " << other_indices_ << "."; - return String(s.str()); + return ffi::String(s.str()); } IRModule mod() const final { return mod_; } - Array LocationsOfInterest() const final { return {block_}; } + ffi::Array LocationsOfInterest() const final { return {block_}; } IRModule mod_; - String primitive_name_; + ffi::String primitive_name_; Block block_; Var var_; - Array appears_indices_; - Array other_indices_; + ffi::Array appears_indices_; + ffi::Array other_indices_; }; /*! 
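// ---- Editor's note (illustrative sketch, not part of the diff): when a detail
// message is assembled with a stream, the migrated code returns ffi::String
// explicitly, as ReindexCacheReadWriteNotMatchError above does via
// ffi::String(s.str()). A minimal standalone sketch (the helper name RenderMessage
// is hypothetical):
ffi::String RenderMessage(const ffi::String& primitive_name, const Var& var) {
  std::stringstream s;
  s << "Error when applying " << primitive_name << ": the block itervar " << var
    << " appears on only one side of the reindex mapping.";
  return ffi::String(s.str());  // construct an owned ffi::String from the std::string
}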
@@ -1908,21 +1913,21 @@ class ReindexCacheReadWriteNotMatchError : public ScheduleError { template void CollectReindexCacheStageInfoAndCreateBuffer( ReindexCacheStageInfo* info, const IRModule& mod, const StmtSRef& block_sref, - const String& storage_scope, const IndexMap& index_map, const Block& block, + const ffi::String& storage_scope, const IndexMap& index_map, const Block& block, const BlockRealize& realize, const Buffer& old_buffer, const BufferRegion& cache_region) { arith::Analyzer analyzer; - Array block_iter_vars, block_shape; + ffi::Array block_iter_vars, block_shape; for (const IterVar& iter_var : block->iter_vars) { block_iter_vars.push_back(iter_var); block_shape.push_back(iter_var->dom->extent); } - Array new_indices = index_map->MapIndices(block_iter_vars, &analyzer); - Array new_shape = index_map->MapShape(block_shape, &analyzer); + ffi::Array new_indices = index_map->MapIndices(block_iter_vars, &analyzer); + ffi::Array new_shape = index_map->MapShape(block_shape, &analyzer); info->indices = new_indices; // Step 5. Update CacheTouchedInfo VarUseDefAnalyzer collector_old(/*defined_vars=*/{}); - Array old_indices; + ffi::Array old_indices; for (const Range& range : cache_region->region) { collector_old(range->min); old_indices.push_back(range->min); @@ -1959,8 +1964,8 @@ void CollectReindexCacheStageInfoAndCreateBuffer( } // Create new buffer - ObjectPtr new_buffer = make_object(*old_buffer.get()); - ObjectPtr new_var = make_object(*old_buffer->data.get()); + ObjectPtr new_buffer = ffi::make_object(*old_buffer.get()); + ObjectPtr new_var = ffi::make_object(*old_buffer->data.get()); const auto* ptr_type = TVM_TYPE_AS(old_buffer->data->type_annotation, PointerTypeNode); new_var->type_annotation = PointerType(ptr_type->element_type, storage_scope); new_buffer->data = Var(new_var->name_hint + "_" + storage_scope, new_var->type_annotation); @@ -1992,7 +1997,7 @@ void CheckSinglePoint(ScheduleState self, const Block& block, const BufferRegion } StmtSRef ReindexCacheRead(ScheduleState self, const StmtSRef& block_sref, int read_buffer_index, - const String& storage_scope, const IndexMap& index_map) { + const ffi::String& storage_scope, const IndexMap& index_map) { /*! * Check: * - The index is in the array of block reading region @@ -2008,7 +2013,7 @@ StmtSRef ReindexCacheRead(ScheduleState self, const StmtSRef& block_sref, int re CheckStorageScope(self, storage_scope); // Step 1. Check index, getting the target buffer and the parent scope - Block block = GetRef(TVM_SREF_TO_BLOCK(block_sref)); + Block block = ffi::GetRef(TVM_SREF_TO_BLOCK(block_sref)); BlockRealize realize = GetBlockRealize(self, block_sref); Buffer read_buffer = GetNthAccessBuffer(self, block, read_buffer_index, BufferIndexType::kRead); StmtSRef scope_sref = GetScopeRoot(self, block_sref, /*require_stage_pipeline=*/true); @@ -2019,15 +2024,16 @@ StmtSRef ReindexCacheRead(ScheduleState self, const StmtSRef& block_sref, int re info.consumer_blocks.insert(block_sref); // Step 3. Update cache stage info. - Optional maybe_region = GetBufferRegionFromBuffer(block->reads, read_buffer); + ffi::Optional maybe_region = GetBufferRegionFromBuffer(block->reads, read_buffer); ICHECK(maybe_region.defined()) << read_buffer << " should appear in the block's read region: " << block->reads; BufferRegion cache_region = maybe_region.value(); - if (Optional _write_block_sref = GetOnlyWriteBlock(self, scope_sref, read_buffer)) { + if (ffi::Optional _write_block_sref = + GetOnlyWriteBlock(self, scope_sref, read_buffer)) { // Case 1. 
The buffer is written inside the block. StmtSRef write_block_sref = _write_block_sref.value(); // Find the producing region - StmtSRef parent_sref = GetRef(write_block_sref->parent); + StmtSRef parent_sref = ffi::GetRef(write_block_sref->parent); // Detect insert position CacheLocDetector::Detect(self, write_block_sref, scope_sref, &info); } else { @@ -2062,7 +2068,7 @@ StmtSRef ReindexCacheRead(ScheduleState self, const StmtSRef& block_sref, int re } StmtSRef ReindexCacheWrite(ScheduleState self, const StmtSRef& block_sref, int write_buffer_index, - const String& storage_scope, const IndexMap& index_map) { + const ffi::String& storage_scope, const IndexMap& index_map) { /*! * Check: * - The index is in the array of block reading region @@ -2078,7 +2084,7 @@ StmtSRef ReindexCacheWrite(ScheduleState self, const StmtSRef& block_sref, int w CheckStorageScope(self, storage_scope); // Step 1. Checking index, getting the target buffer and the parent scope - Block block = GetRef(TVM_SREF_TO_BLOCK(block_sref)); + Block block = ffi::GetRef(TVM_SREF_TO_BLOCK(block_sref)); BlockRealize realize = GetBlockRealize(self, block_sref); Buffer write_buffer = GetNthAccessBuffer(self, block, write_buffer_index, BufferIndexType::kWrite); @@ -2092,9 +2098,9 @@ StmtSRef ReindexCacheWrite(ScheduleState self, const StmtSRef& block_sref, int w ICHECK_EQ(block_sref.get(), GetOnlyWriteBlock(self, scope_sref, write_buffer).get()); // Step 4. Find the producing region and insert position - Optional maybe_region = GetBufferRegionFromBuffer(block->writes, write_buffer); + ffi::Optional maybe_region = GetBufferRegionFromBuffer(block->writes, write_buffer); ICHECK(maybe_region.defined()) << write_buffer << " should appear in the block's write region"; - StmtSRef parent_sref = GetRef(block_sref->parent); + StmtSRef parent_sref = ffi::GetRef(block_sref->parent); // Detect insert position CacheLocDetector::Detect(self, block_sref, scope_sref, &info); BufferRegion cache_region = maybe_region.value(); @@ -2130,23 +2136,23 @@ class NotReadWriteError : public ScheduleError { public: NotReadWriteError(IRModule mod, Block block, Buffer buffer) : mod_(std::move(mod)), block_(std::move(block)), buffer_(std::move(buffer)) {} - String FastErrorString() const final { + ffi::String FastErrorString() const final { return "ScheduleError: The target block does not both read & write target buffer."; } - String DetailRenderTemplate() const final { + ffi::String DetailRenderTemplate() const final { return "The target block {0} does not both read & write target buffer {1}."; } IRModule mod() const final { return mod_; } - Array LocationsOfInterest() const final { return {block_, buffer_}; } + ffi::Array LocationsOfInterest() const final { return {block_, buffer_}; } IRModule mod_; Block block_; Buffer buffer_; }; -Array CacheInplace(ScheduleState self, const StmtSRef& block_sref, int read_buffer_index, - const String& storage_scope) { +ffi::Array CacheInplace(ScheduleState self, const StmtSRef& block_sref, + int read_buffer_index, const ffi::String& storage_scope) { /*! * Do cache read then cache write */ @@ -2156,8 +2162,8 @@ Array CacheInplace(ScheduleState self, const StmtSRef& block_sref, int // Check 1. 
Check index, get the target buffer and the parent scope const BlockNode* block = TVM_SREF_TO_BLOCK(block_sref); - Buffer buffer = - GetNthAccessBuffer(self, GetRef(block), read_buffer_index, BufferIndexType::kRead); + Buffer buffer = GetNthAccessBuffer(self, ffi::GetRef(block), read_buffer_index, + BufferIndexType::kRead); StmtSRef scope_sref = GetScopeRoot(self, block_sref, /*require_stage_pipeline=*/false); // Check 3. Check required region cover for cache_read @@ -2165,13 +2171,13 @@ Array CacheInplace(ScheduleState self, const StmtSRef& block_sref, int // Check 4. Check if target block both read & write target buffer. const BlockNode* rw_block = TVM_SREF_TO_BLOCK(block_sref); - Optional read_region = GetBufferRegionFromBuffer(rw_block->reads, buffer); - Optional write_region = GetBufferRegionFromBuffer(rw_block->writes, buffer); + ffi::Optional read_region = GetBufferRegionFromBuffer(rw_block->reads, buffer); + ffi::Optional write_region = GetBufferRegionFromBuffer(rw_block->writes, buffer); if (!read_region.defined() || !write_region.defined()) { - throw NotReadWriteError(self->mod, GetRef(rw_block), buffer); + throw NotReadWriteError(self->mod, ffi::GetRef(rw_block), buffer); } - Array results_block_sref; + ffi::Array results_block_sref; Buffer new_buffer = WithScope(buffer, storage_scope); // Do cache read @@ -2237,14 +2243,14 @@ Array CacheInplace(ScheduleState self, const StmtSRef& block_sref, int StmtSRef ReIndex(ScheduleState self, const StmtSRef& block_sref, int buffer_index, BufferIndexType buffer_index_type) { const BlockNode* block_ptr = TVM_SREF_TO_BLOCK(block_sref); - Block block = GetRef(block_ptr); + Block block = ffi::GetRef(block_ptr); Buffer buffer = GetNthAccessBuffer(self, block, buffer_index, buffer_index_type); StmtSRef scope_sref = GetScopeRoot(self, block_sref, /*require_stage_pipeline=*/true); arith::Analyzer analyzer; // Step 1. 
Collect the original indices and check there's only single pattern of related // Load/Store and the buffer is not accessed opaquely - Array original_indices = ReIndexCollector::Collect(self->mod, buffer, block); + ffi::Array original_indices = ReIndexCollector::Collect(self->mod, buffer, block); // Simplify the indices if possible for (const IterVar& iter : block->iter_vars) { analyzer.Bind(iter->var, iter->dom); @@ -2319,13 +2325,14 @@ struct CacheReadTraits : public UnpackedInstTraits { static constexpr size_t kNumDecisions = 0; static BlockRV UnpackedApplyToSchedule(Schedule sch, BlockRV block, - Array consumer_blocks, Integer read_buffer_index, - String storage_scope) { + ffi::Array consumer_blocks, + Integer read_buffer_index, ffi::String storage_scope) { return sch->CacheRead(block, read_buffer_index->value, storage_scope, consumer_blocks); } - static String UnpackedAsPython(Array outputs, String block, Array consumer_blocks, - Integer read_buffer_index, String storage_scope) { + static ffi::String UnpackedAsPython(ffi::Array outputs, ffi::String block, + ffi::Array consumer_blocks, + Integer read_buffer_index, ffi::String storage_scope) { PythonAPICall py("cache_read"); py.Input("block", block); py.Input("read_buffer_index", read_buffer_index->value); @@ -2352,13 +2359,14 @@ struct CacheWriteTraits : public UnpackedInstTraits { static constexpr size_t kNumDecisions = 0; static BlockRV UnpackedApplyToSchedule(Schedule sch, BlockRV block, - Array consumer_blocks, Integer write_buffer_index, - String storage_scope) { + ffi::Array consumer_blocks, + Integer write_buffer_index, ffi::String storage_scope) { return sch->CacheWrite(block, write_buffer_index->value, storage_scope, consumer_blocks); } - static String UnpackedAsPython(Array outputs, String block, Array consumer_blocks, - Integer write_buffer_index, String storage_scope) { + static ffi::String UnpackedAsPython(ffi::Array outputs, ffi::String block, + ffi::Array consumer_blocks, + Integer write_buffer_index, ffi::String storage_scope) { PythonAPICall py("cache_write"); py.Input("block", block); py.Input("write_buffer_index", write_buffer_index->value); @@ -2384,13 +2392,14 @@ struct CacheInplaceTraits : public UnpackedInstTraits { static constexpr size_t kNumAttrs = 2; static constexpr size_t kNumDecisions = 0; - static Array UnpackedApplyToSchedule(Schedule sch, BlockRV block, - Integer read_buffer_index, String storage_scope) { + static ffi::Array UnpackedApplyToSchedule(Schedule sch, BlockRV block, + Integer read_buffer_index, + ffi::String storage_scope) { return sch->CacheInplace(block, read_buffer_index->value, storage_scope); } - static String UnpackedAsPython(Array outputs, String block, Integer read_buffer_index, - String storage_scope) { + static ffi::String UnpackedAsPython(ffi::Array outputs, ffi::String block, + Integer read_buffer_index, ffi::String storage_scope) { PythonAPICall py("cache_inplace"); py.Input("block", block); py.Input("read_buffer_index", read_buffer_index->value); @@ -2418,14 +2427,14 @@ struct ReIndexTraits : public UnpackedInstTraits { static_cast(buffer_index_type->value)); } - static String UnpackedAsPython(Array outputs, String block, Integer buffer_index, - Integer buffer_index_type) { + static ffi::String UnpackedAsPython(ffi::Array outputs, ffi::String block, + Integer buffer_index, Integer buffer_index_type) { PythonAPICall py("reindex"); py.Input("block", block); std::ostringstream os; os << "(\"" << BufferIndexType2Str(static_cast(buffer_index_type->value)) << "\", " << buffer_index << 
")"; - py.Input("buffer", String(os.str())); + py.Input("buffer", ffi::String(os.str())); py.SingleOutput(outputs); return py.Str(); } @@ -2444,12 +2453,13 @@ struct ReindexCacheReadTraits : public UnpackedInstTraitsReindexCacheRead(block, read_buffer_index->value, storage_scope, index_map); } - static String UnpackedAsPython(Array outputs, String block, IndexMap index_map, - Integer read_buffer_index, String storage_scope) { + static ffi::String UnpackedAsPython(ffi::Array outputs, ffi::String block, + IndexMap index_map, Integer read_buffer_index, + ffi::String storage_scope) { PythonAPICall py("reindex_cache_read"); py.Input("block", block); py.Input("read_buffer_index", read_buffer_index->value); @@ -2473,12 +2483,13 @@ struct ReindexCacheWriteTraits : public UnpackedInstTraitsReindexCacheWrite(block, write_buffer_index->value, storage_scope, index_map); } - static String UnpackedAsPython(Array outputs, String block, IndexMap index_map, - Integer write_buffer_index, String storage_scope) { + static ffi::String UnpackedAsPython(ffi::Array outputs, ffi::String block, + IndexMap index_map, Integer write_buffer_index, + ffi::String storage_scope) { PythonAPICall py("reindex_cache_write"); py.Input("block", block); py.Input("write_buffer_index", write_buffer_index->value); diff --git a/src/tir/schedule/primitive/compute_at.cc b/src/tir/schedule/primitive/compute_at.cc index 0075fee18f4c..cd56ff8b9ddf 100644 --- a/src/tir/schedule/primitive/compute_at.cc +++ b/src/tir/schedule/primitive/compute_at.cc @@ -33,21 +33,21 @@ template class NotAllRequiredBlocksAreVisitedError : public ScheduleError { public: explicit NotAllRequiredBlocksAreVisitedError(IRModule mod, int num_not_visited, - const Array& required) + const ffi::Array& required) : mod_(mod), num_not_visited_(num_not_visited) { required_.reserve(required.size()); for (const StmtSRef& block_sref : required) { const BlockNode* block = TVM_SREF_TO_BLOCK(block_sref); - required_.push_back(GetRef(block)); + required_.push_back(ffi::GetRef(block)); } } - String FastErrorString() const final { + ffi::String FastErrorString() const final { return "ScheduleError: Not all required blocks are under the loop scope"; } - String DetailRenderTemplate() const final { - String relation = is_consumer ? "consumer(s)" : "producer(s)"; + ffi::String DetailRenderTemplate() const final { + ffi::String relation = is_consumer ? "consumer(s)" : "producer(s)"; std::ostringstream os; os << "The primitive requires all the " << relation << " of the given block to be present under the target loop. However, there are " @@ -61,14 +61,14 @@ class NotAllRequiredBlocksAreVisitedError : public ScheduleError { IRModule mod() const final { return mod_; } - Array LocationsOfInterest() const final { + ffi::Array LocationsOfInterest() const final { return {required_.begin(), required_.end()}; } private: IRModule mod_; int num_not_visited_; - Array required_; + ffi::Array required_; }; /*! 
@@ -96,22 +96,22 @@ class NotInSameScopeError : public ScheduleError { } } - String FastErrorString() const final { + ffi::String FastErrorString() const final { return "ScheduleError: Expected the block and loop to be under the same block scope, and loop " "not to be the ancestor of block"; } - String DetailRenderTemplate() const final { + ffi::String DetailRenderTemplate() const final { return "ScheduleError: Expected the block {0} and loop {1} to be under the same block scope, " "and loop not to be the ancestor of block"; } IRModule mod() const final { return mod_; } - Array LocationsOfInterest() const final { return {block_, loop_}; } + ffi::Array LocationsOfInterest() const final { return {block_, loop_}; } private: explicit NotInSameScopeError(IRModule mod, const StmtSRef& block_sref, const StmtSRef& loop_sref) : mod_(mod), - block_(GetRef(block_sref->StmtAs())), - loop_(GetRef(loop_sref->StmtAs())) {} + block_(ffi::GetRef(block_sref->StmtAs())), + loop_(ffi::GetRef(loop_sref->StmtAs())) {} IRModule mod_; Block block_; @@ -138,8 +138,9 @@ class NotInSameScopeError : public ScheduleError { * \throws ScheduleError if there is no such insertion point found */ template -int FindInsertionPoint(const ScheduleState& self, const Array& subtrees, - const Array& producer_srefs, const Array& consumer_srefs, +int FindInsertionPoint(const ScheduleState& self, const ffi::Array& subtrees, + const ffi::Array& producer_srefs, + const ffi::Array& consumer_srefs, std::unordered_map* block2realize, int index) { ProducerConsumerSplit split = @@ -254,9 +255,9 @@ class ScopeReconstructor : private StmtMutator { void MakeNewLoop(int insert_position, std::vector iter_doms, arith::Analyzer* analyzer, bool preserve_unit_loops) { int n_iters = iter_doms.size(); - Array loop_vars; - Array loop_extents; - Array iter_values; + ffi::Array loop_vars; + ffi::Array loop_extents; + ffi::Array iter_values; loop_vars.reserve(n_iters); loop_extents.reserve(n_iters); iter_values.reserve(n_iters); @@ -302,9 +303,9 @@ class ScopeReconstructor : private StmtMutator { /*ForKind=*/ForKind::kSerial, /*body=*/std::move(new_subtree)); } - Array subtrees = AsArray(loop_->body); + ffi::Array subtrees = AsArray(loop_->body); subtrees.insert(subtrees.begin() + insert_position, std::move(new_subtree)); - ObjectPtr new_loop = make_object(*loop_.get()); + ObjectPtr new_loop = ffi::make_object(*loop_.get()); new_loop->body = SeqStmt(std::move(subtrees)); this->new_loop_ = For(std::move(new_loop)); } @@ -312,7 +313,7 @@ class ScopeReconstructor : private StmtMutator { private: Stmt VisitStmt_(const BlockNode* block) final { if (block != scope_root_.get()) { - return GetRef(block); + return ffi::GetRef(block); } if (block == rm_src_stmt_.get()) { block = TVM_TYPE_AS(rm_tgt_stmt_, BlockNode); @@ -358,19 +359,19 @@ class ScopeReconstructor : private StmtMutator { * \param relaxed Where the calculation result is stored */ template -void RelaxBufferRegions(const Map& binding, - const Array& buffer_regions, +void RelaxBufferRegions(const ffi::Map& binding, + const ffi::Array& buffer_regions, const StmtSRef& relax_path_low_inclusive, const StmtSRef& relax_path_high_exclusive, std::unordered_map>* relaxed) { runtime::StorageScope global_scope{runtime::StorageRank::kGlobal, ""}; // We cache the variable domains runtime::StorageRank previous_rank = runtime::StorageRank::kGlobal; - Optional> var_dom = std::nullopt; + ffi::Optional> var_dom = std::nullopt; // Enumerate every buffer region for (const BufferRegion& buffer_region : buffer_regions) { 
const Buffer& buffer = buffer_region->buffer; - const Array& region = buffer_region->region; + const ffi::Array& region = buffer_region->region; // Skip the buffer regions we are not interested in auto it = relaxed->find(buffer.get()); if (it == relaxed->end()) { @@ -389,7 +390,7 @@ void RelaxBufferRegions(const Map& binding, /*extra_relax_scope=*/scope)); } // Relax the region - Array relaxed_region = + ffi::Array relaxed_region = arith::EvalSet(Substitute(region, binding), var_dom.value()); relaxed_regions.push_back({relaxed_region.begin(), relaxed_region.end()}); } @@ -412,7 +413,7 @@ std::pair SolveBlockVarDomain(const arith::IntSet& prov PrimExpr required_min = analyzer->Simplify(required.min()); PrimExpr required_max = analyzer->Simplify(required.max()); arith::IntSet var_dom, var_bound; - Optional var; + ffi::Optional var; arith::PVar p_v; arith::PVar p_e; if ((p_v * p_e).Match(provided_min) || (p_e * p_v).Match(provided_min)) { @@ -506,9 +507,10 @@ void UpdateBlockVarDomainDimwise( } /*! \brief Helper function to implement intset version of `InverseAffineIterMap`. */ -Map InverseAffineIterMap(const Array& iter_map, - const NDIntSet& outputs, arith::Analyzer* analyzer) { - Array min_point, max_point; +ffi::Map InverseAffineIterMap(const ffi::Array& iter_map, + const NDIntSet& outputs, + arith::Analyzer* analyzer) { + ffi::Array min_point, max_point; min_point.reserve(outputs.size()); max_point.reserve(outputs.size()); for (const auto& intset : outputs) { @@ -518,7 +520,7 @@ Map InverseAffineIterMap(const Array& it } auto rev_min = InverseAffineIterMap(iter_map, min_point); auto rev_max = InverseAffineIterMap(iter_map, max_point); - Map dom_map; + ffi::Map dom_map; for (const auto& kv : rev_min) { const Var& var = kv.first; auto it = rev_max.find(var); @@ -543,7 +545,7 @@ Map InverseAffineIterMap(const Array& it * \param iter_doms The result iteration domains to be updated * \returns bool. 
Denotes whether the update succeeded */
-bool UpdateBlockVarDomainAffine(const BufferNode* buffer, const Array& iter_vars,
+bool UpdateBlockVarDomainAffine(const BufferNode* buffer, const ffi::Array& iter_vars,
                                 const NDIntSet& provided_region, const NDIntSet& required_region,
                                 arith::Analyzer* analyzer,
                                 std::unordered_map* iter_doms) {
@@ -552,12 +554,12 @@ bool UpdateBlockVarDomainAffine(const BufferNode* buffer, const Array&
     if (!intset.CanProveSinglePoint(analyzer)) return false;
   }
   // calculate forward mapping (block vars -> provided region point)
-  Map dom_map;
+  ffi::Map dom_map;
   for (const IterVar& iter_var : iter_vars) {
     dom_map.Set(iter_var->var, iter_var->dom);
   }
   size_t ndim = buffer->shape.size();
-  Array provide_indices;
+  ffi::Array provide_indices;
   provide_indices.reserve(ndim);
   for (size_t i = 0; i < ndim; ++i) {
     provide_indices.push_back(provided_region[i].min());
@@ -573,8 +575,10 @@ bool UpdateBlockVarDomainAffine(const BufferNode* buffer, const Array&
     required_bound.push_back(
         arith::IntSet::Interval(make_zero(buffer->shape[i]->dtype), max(buffer->shape[i] - 1, 0)));
   }
-  Map var_dom = InverseAffineIterMap(res->indices, required_region, analyzer);
-  Map var_bound = InverseAffineIterMap(res->indices, required_bound, analyzer);
+  ffi::Map var_dom =
+      InverseAffineIterMap(res->indices, required_region, analyzer);
+  ffi::Map var_bound =
+      InverseAffineIterMap(res->indices, required_bound, analyzer);
   for (const auto& kv : var_dom) {
     const Var& var = kv.first;
     auto it = var_bound.find(var);
@@ -593,7 +597,7 @@ bool UpdateBlockVarDomainAffine(const BufferNode* buffer, const Array&
  * \return A list of iteration domain info corresponding to the given list of block vars
  */
 std::vector CalculateBlockVarDomain(
-    const Array& iter_vars,
+    const ffi::Array& iter_vars,
     std::unordered_map> provided_regions,
     std::unordered_map> required_regions,
     arith::Analyzer* analyzer) {
@@ -657,16 +661,16 @@ template
 void CalculateProvidedRequiredRegions(
     const BlockNode* block, const StmtSRef& loop_sref,
     std::unordered_map block2realize,
-    Array producer_srefs, Array consumer_srefs,
+    ffi::Array producer_srefs, ffi::Array consumer_srefs,
     std::unordered_map>* provided_regions,
     std::unordered_map>* required_regions) {
   // Step 1. Calculate the region provided by a single execution instance of `block`
-  const Array& provided_buffers = is_compute_at ? block->writes : block->reads;
+  const ffi::Array& provided_buffers = is_compute_at ? block->writes : block->reads;
   provided_regions->reserve(provided_buffers.size());
   required_regions->reserve(provided_buffers.size());
   for (const BufferRegion& provided_buffer_region : provided_buffers) {
     const BufferNode* buffer = provided_buffer_region->buffer.get();
-    const Array& region = provided_buffer_region->region;
+    const ffi::Array& region = provided_buffer_region->region;
     (*provided_regions)[buffer].push_back(support::NDIntSetFromRegion(region));
     (*required_regions)[buffer].clear();
   }
@@ -675,9 +679,9 @@ void CalculateProvidedRequiredRegions(
     const BlockNode* required_block = TVM_SREF_TO_BLOCK(required_block_sref);
     ICHECK(block2realize.count(required_block));
     RelaxBufferRegions(
-        /*binding=*/GetBindings(GetRef(block2realize.at(required_block))),
+        /*binding=*/GetBindings(ffi::GetRef(block2realize.at(required_block))),
         /*buffer_regions=*/is_compute_at ?
required_block->reads : required_block->writes, - /*relax_path_low_inclusive=*/GetRef(required_block_sref->parent), + /*relax_path_low_inclusive=*/ffi::GetRef(required_block_sref->parent), /*relax_path_high_exclusive=*/loop_sref, /*relaxed=*/required_regions); } } @@ -695,11 +699,11 @@ void ComputeAtOrReverseComputeAtImpl(ScheduleState self, const StmtSRef& block_s // Check condition 1) : scope stage pipeline StmtSRef scope_root_sref = GetScopeRoot(self, block_sref, /*require_stage_pipeline=*/true); - Block scope_root = GetRef(scope_root_sref->StmtAs()); + Block scope_root = ffi::GetRef(scope_root_sref->StmtAs()); AddShapeVarBounds(self, scope_root_sref.get(), analyzer); BlockScope scope = self->GetBlockScope(scope_root_sref); - Array producer_srefs = GetProducers(block_sref, scope); - Array consumer_srefs = GetConsumers(block_sref, scope); + ffi::Array producer_srefs = GetProducers(block_sref, scope); + ffi::Array consumer_srefs = GetConsumers(block_sref, scope); // Check condition 2) : `block` is a complete or reduction block CheckCompleteOrReductionBlock(self, block_sref, scope_root_sref); // Check condition 3): `block` and `loop` are under the same scope, @@ -711,7 +715,7 @@ void ComputeAtOrReverseComputeAtImpl(ScheduleState self, const StmtSRef& block_s CheckNotOutputBlock(self, block_sref, scope_root_sref); } // Step 2. Plan for the removal of `block` - ScopeReconstructor reconstructor(scope_root, GetRef(block), GetRef(loop)); + ScopeReconstructor reconstructor(scope_root, ffi::GetRef(block), ffi::GetRef(loop)); LeafBlockRemovalPlan(self, block_sref, &reconstructor.rm_src_stmt_, &reconstructor.rm_tgt_stmt_); // Step 3. Find the insertion point under `loop` // Check condition 5): all the required block are under the given loop @@ -755,7 +759,7 @@ void ComputeAtOrReverseComputeAtImpl(ScheduleState self, const StmtSRef& block_s BlockInfo& block_info = self->block_info[block_sref]; block_info.affine_binding = IsAffineBinding( /*realize=*/reconstructor.new_block_realize_, - /*loop_var_ranges=*/LoopDomainOfSRefTreePath(GetRef(block_sref->parent)), + /*loop_var_ranges=*/LoopDomainOfSRefTreePath(ffi::GetRef(block_sref->parent)), /*analyzer=*/analyzer); } @@ -813,8 +817,8 @@ struct ComputeAtTraits : public UnpackedInstTraits { return sch->ComputeAt(block_rv, loop_rv, preserve_unit_loops.operator bool(), index->value); } - static String UnpackedAsPython(Array outputs, String block_rv, String loop_rv, - Bool preserve_unit_loops, IntImm index) { + static ffi::String UnpackedAsPython(ffi::Array outputs, ffi::String block_rv, + ffi::String loop_rv, Bool preserve_unit_loops, IntImm index) { PythonAPICall py("compute_at"); py.Input("block", block_rv); py.Input("loop", loop_rv); @@ -842,8 +846,8 @@ struct ReverseComputeAtTraits : public UnpackedInstTraitsvalue); } - static String UnpackedAsPython(Array outputs, String block_rv, String loop_rv, - Bool preserve_unit_loops, IntImm index) { + static ffi::String UnpackedAsPython(ffi::Array outputs, ffi::String block_rv, + ffi::String loop_rv, Bool preserve_unit_loops, IntImm index) { PythonAPICall py("reverse_compute_at"); py.Input("block", block_rv); py.Input("loop", loop_rv); diff --git a/src/tir/schedule/primitive/compute_inline.cc b/src/tir/schedule/primitive/compute_inline.cc index 4e037158d98a..e480c68ff4ad 100644 --- a/src/tir/schedule/primitive/compute_inline.cc +++ b/src/tir/schedule/primitive/compute_inline.cc @@ -36,14 +36,16 @@ class HasInitBlock : public ScheduleError { public: explicit HasInitBlock(IRModule mod, Block block) : mod_(mod), 
block_(block) {} - String FastErrorString() const final { return "ScheduleError: The block has init statement"; } + ffi::String FastErrorString() const final { + return "ScheduleError: The block has init statement"; + } - String DetailRenderTemplate() const final { + ffi::String DetailRenderTemplate() const final { return "ScheduleError: The block has init statement: {0}"; } IRModule mod() const final { return mod_; } - Array LocationsOfInterest() const final { return {block_}; } + ffi::Array LocationsOfInterest() const final { return {block_}; } static void Check(const IRModule& mod, const Block& block) { if (block->init.defined()) { @@ -61,12 +63,12 @@ class NotSingleReadWriteBuffer : public ScheduleError { explicit NotSingleReadWriteBuffer(IRModule mod, bool is_read, Block block) : mod_(mod), is_read_(is_read), block_(std::move(block)) {} - String FastErrorString() const final { + ffi::String FastErrorString() const final { return is_read_ ? "ScheduleError: The block is allowed to read only a single buffer region" : "ScheduleError: The block is allowed to write only a single buffer region"; } - String DetailRenderTemplate() const final { + ffi::String DetailRenderTemplate() const final { if (is_read_) { int k = block_->reads.size(); return "The block is only allowed to read a single buffer region, but it reads " + @@ -79,7 +81,7 @@ class NotSingleReadWriteBuffer : public ScheduleError { } IRModule mod() const final { return mod_; } - Array LocationsOfInterest() const final { return {block_}; } + ffi::Array LocationsOfInterest() const final { return {block_}; } IRModule mod_; bool is_read_; @@ -87,7 +89,7 @@ class NotSingleReadWriteBuffer : public ScheduleError { static Buffer GetSingleRead(const ScheduleState& self, const Block& block, const StmtSRef& scope_root_sref) { - const std::unordered_map, ObjectPtrHash, ObjectPtrEqual>& + const std::unordered_map, ObjectPtrHash, ObjectPtrEqual>& buffer_writers = self->block_info.at(scope_root_sref).scope->buffer_writers; const BufferNode* read_buffer = nullptr; for (const BufferRegion& read_region : block->reads) { @@ -95,7 +97,7 @@ class NotSingleReadWriteBuffer : public ScheduleError { if (buffer == read_buffer) { continue; } - if (buffer_writers.count(GetRef(buffer)) > 0) { + if (buffer_writers.count(ffi::GetRef(buffer)) > 0) { if (read_buffer != nullptr) { throw NotSingleReadWriteBuffer(self->mod, true, block); } @@ -105,7 +107,7 @@ class NotSingleReadWriteBuffer : public ScheduleError { if (read_buffer == nullptr) { throw NotSingleReadWriteBuffer(self->mod, true, block); } - return GetRef(read_buffer); + return ffi::GetRef(read_buffer); } static Buffer GetSingleWrite(const ScheduleState& self, const Block& block) { @@ -121,17 +123,17 @@ class BodyAnalysisError : public ScheduleError { explicit BodyAnalysisError(bool is_reverse, IRModule mod, Block block) : is_reverse_(is_reverse), mod_(mod), block_(std::move(block)) {} - String FastErrorString() const final { + ffi::String FastErrorString() const final { return "ScheduleError: The block cannot be inlined because its body pattern does not meet the " "condition for inlining"; } - String DetailRenderTemplate() const final { + ffi::String DetailRenderTemplate() const final { return is_reverse_ ? 
kErrBodyReverseInline : kErrBodyInline; } IRModule mod() const final { return mod_; } - Array LocationsOfInterest() const final { return {block_}; } + ffi::Array LocationsOfInterest() const final { return {block_}; } bool is_reverse_; IRModule mod_; @@ -143,20 +145,20 @@ class NonSingleProducerError : public ScheduleError { explicit NonSingleProducerError(IRModule mod, Block block) : mod_(mod), block_(std::move(block)) {} - String FastErrorString() const final { + ffi::String FastErrorString() const final { return "ScheduleError: The consumer block to be inlined is required to have only a single " "producer block, and the producer block should be a complete block who has only a " "single consumer"; } - String DetailRenderTemplate() const final { + ffi::String DetailRenderTemplate() const final { return "The consumer block {0} to be inlined is required to have only a single " "producer block, and the producer block should be a complete block who has only a " "single consumer"; } IRModule mod() const final { return mod_; } - Array LocationsOfInterest() const final { return {block_}; } + ffi::Array LocationsOfInterest() const final { return {block_}; } IRModule mod_; Block block_; @@ -174,7 +176,7 @@ class NonSingleProducerError : public ScheduleError { const BlockNode* scope_block = TVM_SREF_TO_BLOCK(scope_root_sref); const BlockNode* consumer_block = TVM_SREF_TO_BLOCK(consumer_block_sref); Buffer consumer_buffer = NotSingleReadWriteBuffer::GetSingleRead( - self, GetRef(consumer_block), scope_root_sref); + self, ffi::GetRef(consumer_block), scope_root_sref); class ProducerFinder : public StmtVisitor { public: static std::vector GetProducer(const ScheduleState& self, @@ -211,9 +213,9 @@ class NonSingleProducerError : public ScheduleError { // Check if the producer block is a complete block StmtSRef producer_block_sref = self_->stmt2ref.at(node); if (!IsCompleteBlock(self_, producer_block_sref, scope_root_sref_)) { - throw NonSingleProducerError(self_->mod, GetRef(node)); + throw NonSingleProducerError(self_->mod, ffi::GetRef(node)); } - producer_across_scope_.back().push_back(GetRef(node)); + producer_across_scope_.back().push_back(ffi::GetRef(node)); break; } } @@ -224,9 +226,9 @@ class NonSingleProducerError : public ScheduleError { std::vector> producer_across_scope_; }; std::vector producer_across_scope = ProducerFinder::GetProducer( - self, scope_root_sref, consumer_buffer, GetRef(scope_block)); + self, scope_root_sref, consumer_buffer, ffi::GetRef(scope_block)); if (producer_across_scope.size() != 1) { - throw NonSingleProducerError(self->mod, GetRef(consumer_block)); + throw NonSingleProducerError(self->mod, ffi::GetRef(consumer_block)); } return self->stmt2ref.at(producer_across_scope[0].get()); } @@ -237,21 +239,21 @@ class OpaqueAccessError : public ScheduleError { explicit OpaqueAccessError(IRModule mod, StmtSRef scope_root_sref) : mod_(mod), scope_root_(nullptr) { const BlockNode* scope_root = TVM_SREF_TO_BLOCK(scope_root_sref); - this->scope_root_ = GetRef(scope_root); + this->scope_root_ = ffi::GetRef(scope_root); } - String FastErrorString() const final { + ffi::String FastErrorString() const final { return "ScheduleError: The buffer to be inlined has opaque access (e.g. `B.data`), or its " "subregion is matched into other blocks"; } - String DetailRenderTemplate() const final { + ffi::String DetailRenderTemplate() const final { return "The buffer to be inlined has opaque access (e.g. 
`B.data`), or its " "subregion is matched into other blocks: {0}"; } IRModule mod() const final { return mod_; } - Array LocationsOfInterest() const final { return {scope_root_}; } + ffi::Array LocationsOfInterest() const final { return {scope_root_}; } IRModule mod_; Block scope_root_; @@ -263,11 +265,11 @@ class ProducerHasNonTrivialPredicateError : public ScheduleError { PrimExpr new_predicate) : mod_(mod), producer_(producer), new_predicate_(new_predicate) {} - String FastErrorString() const final { + ffi::String FastErrorString() const final { return "ScheduleError: The producer block has a non-trivial predicate."; } - String DetailRenderTemplate() const final { + ffi::String DetailRenderTemplate() const final { std::ostringstream os; os << "ScheduleError: The producer block {0} has a non-trivial predicate " << producer_->predicate << " that cannot be implied by the synthesized predicate " @@ -276,7 +278,7 @@ class ProducerHasNonTrivialPredicateError : public ScheduleError { } IRModule mod() const final { return mod_; } - Array LocationsOfInterest() const final { return {producer_}; } + ffi::Array LocationsOfInterest() const final { return {producer_}; } IRModule mod_; BlockRealize producer_; @@ -315,7 +317,7 @@ class BaseInliner : public StmtExprMutator { Stmt VisitStmt_(const BlockNode* block) { CheckMatchBufferRegion(block); AddBuffersInBlockSignature(block); - Block src_block = GetRef(block); + Block src_block = ffi::GetRef(block); if (src_block.same_as(src_stmt)) { block = tgt_stmt.as(); ICHECK(block != nullptr); @@ -358,7 +360,7 @@ class BaseInliner : public StmtExprMutator { */ Block UpdateBuffersInBlockSignature(Block block, bool is_scope_root) { // Step 1. Update `BlockNode::alloc_buffers` - Array alloc_buffers; + ffi::Array alloc_buffers; if (is_scope_root) { alloc_buffers.reserve(block->alloc_buffers.size()); for (const Buffer& alloc_buffer : block->alloc_buffers) { @@ -370,14 +372,15 @@ class BaseInliner : public StmtExprMutator { alloc_buffers = std::move(block->alloc_buffers); } // Step 2. Update `BlockNode::reads` and `BlockNode::writes` - Array reads = std::move(block->reads); - Array writes = std::move(block->writes); + ffi::Array reads = std::move(block->reads); + ffi::Array writes = std::move(block->writes); auto f_access_inline_buffer = [this](const BufferRegion& access) { return access->buffer.same_as(this->inlined_buffer_); }; if (!is_scope_root && (std::any_of(reads.begin(), reads.end(), f_access_inline_buffer) || std::any_of(writes.begin(), writes.end(), f_access_inline_buffer))) { - Array> inspected = GetBlockReadWriteRegion(block, buffer_var_map_); + ffi::Array> inspected = + GetBlockReadWriteRegion(block, buffer_var_map_); reads = inspected[0]; writes = inspected[1]; } @@ -422,7 +425,7 @@ class BaseInliner : public StmtExprMutator { /*! \brief The scope root */ StmtSRef scope_root_sref_{nullptr}; /*! \brief Maps a buffer's data field to itself */ - Map buffer_var_map_; + ffi::Map buffer_var_map_; /*! \brief The indices used for indexing the buffer to be inlined */ std::vector idx_vars_; /*! \brief The mapping to substitute index variables to PrimExprs */ @@ -438,7 +441,7 @@ class BaseInliner : public StmtExprMutator { /*! \brief The Stmt to be replaced to when removing the leaf block */ Stmt tgt_stmt{nullptr}; /*! \brief The reuse mapping of block srefs */ - Map block_reuse; + ffi::Map block_reuse; /*! 
\brief Indicates if there is any opaque access of the inlined buffer */
   bool has_opaque_access{false};
 };
@@ -489,7 +492,7 @@ class ComputeInliner : public BaseInliner {

     // If the mapping for store indices is non-trivial
     // check bijective mapping from producer iter var to store indices
-    Map producer_iter_doms;
+    ffi::Map producer_iter_doms;
     for (const auto& iter : producer_block->iter_vars) {
       producer_iter_doms.Set(iter->var, iter->dom);
     }
@@ -509,7 +512,7 @@ class ComputeInliner : public BaseInliner {
       idx_vars_[i] = Var("ph_" + std::to_string(i), inlined_store_->indices[i].dtype());
     }
     auto inverse_iter_map = arith::InverseAffineIterMap(
-        res->indices, Array(idx_vars_.begin(), idx_vars_.end()));
+        res->indices, ffi::Array(idx_vars_.begin(), idx_vars_.end()));
     for (const auto& iter : producer_block->iter_vars) {
       if (is_const_int(iter->dom->min) && analyzer_.CanProveEqual(iter->dom->extent, 1)) {
         // fallback mapping for constant iters
@@ -541,7 +544,7 @@ class ComputeInliner : public BaseInliner {
    * \brief Set the mapping of index substitution `self->idx_sub_`
    * \param indices The expressions that the corresponding index variables are replaced to
    */
-  void SetIndexSubstitution(const Array& indices) {
+  void SetIndexSubstitution(const ffi::Array& indices) {
     ICHECK_EQ(indices.size(), idx_vars_.size());
     int n = idx_vars_.size();
     for (int i = 0; i < n; ++i) {
@@ -573,7 +576,7 @@ class ReverseComputeInliner : public BaseInliner {
     PrimExpr VisitExpr_(const VarNode* var) final {
       auto it = self_->idx_sub_.find(var);
       if (it == self_->idx_sub_.end()) {
-        return GetRef(var);
+        return ffi::GetRef(var);
       }
       return (*it).second;
     }
@@ -594,7 +597,7 @@ class ReverseComputeInliner : public BaseInliner {
     PrimExpr VisitExpr_(const VarNode* var) final {
       auto it = self_->idx_sub_.find(var);
       if (it == self_->idx_sub_.end()) {
-        return GetRef(var);
+        return ffi::GetRef(var);
       }
       return (*it).second;
     }
@@ -644,7 +647,7 @@ class ReverseComputeInliner : public BaseInliner {
     }

     // Collect block iter domains and update the substitution map
-    Map consumer_iter_doms;
+    ffi::Map consumer_iter_doms;
     for (const auto& iter_var : consumer_block->iter_vars) {
       consumer_iter_doms.Set(iter_var->var, iter_var->dom);
       // Set default mapping for unit iters
@@ -708,7 +711,7 @@ class ReverseComputeInliner : public BaseInliner {
   /*! \brief Generate the predicate after inlining based on the consumer predicate */
   BlockRealize BuildInlinedConsumerPredicate(BlockRealize producer_block_realize) {
     // Bind the producer block iter domains for simplification
-    Map subst_map;
+    ffi::Map subst_map;
     Block producer_block = producer_block_realize->block;
     for (int i = 0, n = producer_block->iter_vars.size(); i < n; ++i) {
       const IterVar& iter = producer_block->iter_vars[i];
@@ -748,7 +751,7 @@ class ReverseComputeInliner : public BaseInliner {
     auto n = producer_block_realize.CopyOnWrite();
     n->block = producer_block;
     n->predicate = analyzer_.Simplify(outer_predicate);
-    return GetRef(n);
+    return ffi::GetRef(n);
   }

   Stmt VisitStmt_(const BlockRealizeNode* op) final {
@@ -774,7 +777,7 @@ class ReverseComputeInliner : public BaseInliner {
    * \return Whether the consumer block iter domains are covered
    */
   bool CheckConsumerCovered() {
-    Map producer_iter_doms;
+    ffi::Map producer_iter_doms;
     for (const IterVar& iter_var : producer_block_->iter_vars) {
       producer_iter_doms.Set(iter_var, arith::IntSet::FromRange(iter_var->dom));
     }
@@ -800,7 +803,7 @@ class ReverseComputeInliner : public BaseInliner {
    * the result.
It will be later used to transform the BufferStore indices of the producer. * \param producer_indices The BufferStore indices of the producer. */ - void CreateInverseMapping(const Array producer_indices) { + void CreateInverseMapping(const ffi::Array producer_indices) { auto inverse_iter_map = arith::InverseAffineIterMap(buffer_load_iter_map_, producer_indices); for (const auto& pair : inverse_iter_map) { idx_sub_[pair.first.get()] = pair.second; @@ -811,7 +814,7 @@ class ReverseComputeInliner : public BaseInliner { // "producer->value" may contain the buffer that is inlined in cases of reduction, // so we need to resolve the recursion first producer_rhs_ = RecursionResolver(this)(producer->value); - return Substituter(this)(GetRef(inlined_store_)); + return Substituter(this)(ffi::GetRef(inlined_store_)); } /*! @@ -847,7 +850,7 @@ class ReverseComputeInliner : public BaseInliner { * \param expected_ndim The expected ndim of the access * \return A boolean flag indicating if the check is successful */ - bool UpdateAndCheckIndexExprs(const Array& indices) { + bool UpdateAndCheckIndexExprs(const ffi::Array& indices) { if (buffer_load_indices_.empty()) { buffer_load_indices_ = indices; } else if (!std::equal(buffer_load_indices_.begin(), buffer_load_indices_.end(), @@ -861,9 +864,9 @@ class ReverseComputeInliner : public BaseInliner { /*! \brief The RHS value of the producer's BufferStore statement */ PrimExpr producer_rhs_{nullptr}; /*! \brief The indices of the consumer's BufferLoad */ - Array buffer_load_indices_; + ffi::Array buffer_load_indices_; /*! \brief The IterMap representing the indices of the consumer's BufferLoad */ - Array buffer_load_iter_map_{nullptr}; + ffi::Array buffer_load_iter_map_{nullptr}; /*! \brief The producer block */ const BlockNode* producer_block_{nullptr}; /* \brief The consumer block */ @@ -879,7 +882,7 @@ class ReverseComputeInliner : public BaseInliner { void ComputeInlineImpl(ScheduleState self, const StmtSRef& producer_block_sref, bool check_only = false) { const BlockNode* _producer_block = TVM_SREF_TO_BLOCK(producer_block_sref); - Block producer_block = GetRef(_producer_block); + Block producer_block = ffi::GetRef(_producer_block); HasInitBlock::Check(self->mod, producer_block); Buffer inlined_buffer = NotSingleReadWriteBuffer::GetSingleWrite(self, producer_block); // Step 1. Get the scope block @@ -897,7 +900,7 @@ void ComputeInlineImpl(ScheduleState self, const StmtSRef& producer_block_sref, LeafBlockRemovalPlan(self, producer_block_sref, &inliner.src_stmt, &inliner.tgt_stmt); // Step 5. Create an AST where the leaf `producer_block_sref` points to is removed, // and update other blocks who read from the removed block - Stmt tgt_stmt = inliner(GetRef(scope_root_sref->stmt)); + Stmt tgt_stmt = inliner(ffi::GetRef(scope_root_sref->stmt)); if (inliner.has_opaque_access) { throw OpaqueAccessError(self->mod, scope_root_sref); } @@ -924,7 +927,7 @@ bool CanComputeInline(const ScheduleState& self, const StmtSRef& producer_block_ void ReverseComputeInlineImpl(ScheduleState self, const StmtSRef& consumer_block_sref, bool check_only = false) { const BlockNode* _consumer_block = TVM_SREF_TO_BLOCK(consumer_block_sref); - Block consumer_block = GetRef(_consumer_block); + Block consumer_block = ffi::GetRef(_consumer_block); BlockRealize consumer_block_realize = GetBlockRealize(self, consumer_block_sref); HasInitBlock::Check(self->mod, consumer_block); // Step 1. 
Get the scope block @@ -949,7 +952,7 @@ void ReverseComputeInlineImpl(ScheduleState self, const StmtSRef& consumer_block LeafBlockRemovalPlan(self, consumer_block_sref, &inliner.src_stmt, &inliner.tgt_stmt); // Step 6. Create an AST where the leaf `consumer_block_sref` points to is removed, // and update other blocks who read from the removed block - Stmt tgt_stmt = inliner(GetRef(scope_root_sref->stmt)); + Stmt tgt_stmt = inliner(ffi::GetRef(scope_root_sref->stmt)); if (inliner.has_opaque_access) { throw OpaqueAccessError(self->mod, scope_root_sref); } @@ -963,7 +966,8 @@ void ReverseComputeInlineImpl(ScheduleState self, const StmtSRef& consumer_block BlockInfo& block_info = self->block_info[producer_block_sref]; block_info.affine_binding = IsAffineBinding( /*realize=*/GetBlockRealize(self, producer_block_sref), - /*loop_var_ranges=*/LoopDomainOfSRefTreePath(GetRef(producer_block_sref->parent)), + /*loop_var_ranges=*/ + LoopDomainOfSRefTreePath(ffi::GetRef(producer_block_sref->parent)), /*analyzer=*/&analyzer); } @@ -995,7 +999,7 @@ struct ComputeInlineTraits : public UnpackedInstTraits { return sch->ComputeInline(block_rv); } - static String UnpackedAsPython(Array outputs, String block_rv) { + static ffi::String UnpackedAsPython(ffi::Array outputs, ffi::String block_rv) { PythonAPICall py("compute_inline"); py.Input("block", block_rv); return py.Str(); @@ -1018,7 +1022,7 @@ struct ReverseComputeInlineTraits : public UnpackedInstTraitsReverseComputeInline(block_rv); } - static String UnpackedAsPython(Array outputs, String block_rv) { + static ffi::String UnpackedAsPython(ffi::Array outputs, ffi::String block_rv) { PythonAPICall py("reverse_compute_inline"); py.Input("block", block_rv); return py.Str(); diff --git a/src/tir/schedule/primitive/decompose_padding.cc b/src/tir/schedule/primitive/decompose_padding.cc index d848dad28f27..fe76823b8972 100644 --- a/src/tir/schedule/primitive/decompose_padding.cc +++ b/src/tir/schedule/primitive/decompose_padding.cc @@ -27,7 +27,7 @@ namespace tir { /*! \brief Information used to create new padding block */ struct PaddingBlockInfo { /*! \brief In-bound block iter regions, wrt loop vars. */ - Array in_bound_region; + ffi::Array in_bound_region; /*! \brief In-bound value, wrt block iter vars. */ PrimExpr in_bound_value; /*! \brief Condition of in-bound write, wrt loop vars. 
*/ @@ -41,12 +41,12 @@ class PaddingPatternMatchError : public ScheduleError { PaddingPatternMatchError(IRModule mod, Block block, const std::string& error_msg) : mod_(std::move(mod)), block_(std::move(block)), error_msg_(error_msg) {} - String FastErrorString() const final { + ffi::String FastErrorString() const final { return "ScheduleError: decompose_padding expect the block to match padding pattern\n " + error_msg_; } - String DetailRenderTemplate() const final { + ffi::String DetailRenderTemplate() const final { std::ostringstream os; os << "ScheduleError: decompose_padding expect the block {0} to match padding pattern\n " << error_msg_; @@ -54,7 +54,7 @@ class PaddingPatternMatchError : public ScheduleError { } IRModule mod() const final { return mod_; } - Array LocationsOfInterest() const final { return {block_}; } + ffi::Array LocationsOfInterest() const final { return {block_}; } IRModule mod_; Block block_; @@ -68,7 +68,7 @@ class PaddingPatternMatchError : public ScheduleError { class PaddingInfoAnalyzer { public: static PaddingBlockInfo CheckAndGetPaddingInfo(IRModule mod, const BlockRealizeNode* realize, - const Map& dom_map, + const ffi::Map& dom_map, arith::Analyzer* analyzer) { PaddingInfoAnalyzer padding_analyzer(analyzer); if (!padding_analyzer.MatchPadding(realize, dom_map)) { @@ -81,7 +81,7 @@ class PaddingInfoAnalyzer { explicit PaddingInfoAnalyzer(arith::Analyzer* analyzer) : analyzer_(analyzer) {} /*! \brief Detect padding pattern and update result. */ - bool MatchPadding(const BlockRealizeNode* realize, const Map& dom_map) { + bool MatchPadding(const BlockRealizeNode* realize, const ffi::Map& dom_map) { // Step 1. Check match padding computation pattern. // A[...] = T.if_then_else(predicate, B[...], imm) Block block = realize->block; @@ -120,7 +120,7 @@ class PaddingInfoAnalyzer { SetError("The in-bound predicate is trivial"); return false; } - Array in_bound_region = this->EstimateInBoundRegion( + ffi::Array in_bound_region = this->EstimateInBoundRegion( /*iter_values=*/realize->iter_values, /*dom_map=*/dom_map, /*in_bound_predicate=*/in_bound_predicate); if (in_bound_region.empty()) { @@ -157,10 +157,10 @@ class PaddingInfoAnalyzer { } /*! \brief Return iteration region of block vars where the padding predicate evals to true. */ - Array EstimateInBoundRegion(const Array& iter_values, - const Map& dom_map, - const PrimExpr& in_bound_predicate) { - Array region; + ffi::Array EstimateInBoundRegion(const ffi::Array& iter_values, + const ffi::Map& dom_map, + const PrimExpr& in_bound_predicate) { + ffi::Array region; auto res = arith::DetectIterMap(iter_values, dom_map, in_bound_predicate, arith::IterMapLevel::Surjective, analyzer_); @@ -196,12 +196,12 @@ class PaddingInfoAnalyzer { /*! 
\brief Create block to fill constant pad values into full region */ static std::pair CreateConstBlock(const BlockRealizeNode* realize, const PaddingBlockInfo& info, - const Array& loops, + const ffi::Array& loops, const Stmt& highest_pos_inclusive, arith::Analyzer* analyzer) { const Block& block = realize->block; - Array new_iter_vars; - Map repl_dict; + ffi::Array new_iter_vars; + ffi::Map repl_dict; // create new block itervars for (size_t i = 0; i < block->iter_vars.size(); ++i) { @@ -231,7 +231,7 @@ static std::pair CreateConstBlock(const BlockRealizeNode* re /*name_hint=*/block->name_hint + "_pad_const", /*body=*/std::move(store)); // create new loop vars - Array new_loop_vars; + ffi::Array new_loop_vars; for (const For& loop : loops) { Var new_var = loop->loop_var.copy_with_suffix(""); new_loop_vars.push_back(new_var); @@ -242,7 +242,7 @@ static std::pair CreateConstBlock(const BlockRealizeNode* re } // create new block realize node - Array new_iter_values; + ffi::Array new_iter_values; for (size_t i = 0; i < realize->iter_values.size(); ++i) { new_iter_values.push_back(rewrite_expr(realize->iter_values[i])); } @@ -265,15 +265,15 @@ static std::pair CreateConstBlock(const BlockRealizeNode* re static std::pair CreateInBoundBlock(const BlockRealizeNode* realize, const PaddingBlockInfo& info, - const Array& loops, + const ffi::Array& loops, const Stmt& highest_pos_inclusive, arith::Analyzer* analyzer) { const Block& block = realize->block; - Array new_iter_vars; - Map repl_dict; + ffi::Array new_iter_vars; + ffi::Map repl_dict; // record loop ranges to be mutated - Map new_loop_ranges; + ffi::Map new_loop_ranges; for (const For& loop : loops) { new_loop_ranges.Set(loop->loop_var, Range::FromMinExtent(loop->min, loop->extent)); if (loop.same_as(highest_pos_inclusive)) { @@ -282,7 +282,7 @@ static std::pair CreateInBoundBlock(const BlockRealizeNode* } // create new block iter vars and iter bindings - Array new_iter_binding; + ffi::Array new_iter_binding; for (size_t i = 0; i < info.in_bound_region.size(); ++i) { // add new block itervar const IterVar& origin_itervar = block->iter_vars[i]; @@ -318,7 +318,7 @@ static std::pair CreateInBoundBlock(const BlockRealizeNode* }; // create new read/write region for in-bound accesses - Array reads, writes; + ffi::Array reads, writes; for (const BufferRegion& read : block->reads) { reads.push_back(BufferRegion(read->buffer, rewrite_region(read->region))); } @@ -413,7 +413,7 @@ StmtSRef DecomposePaddingImpl(ScheduleState self, const StmtSRef& block_sref, // Condition Checks and Information Collection const BlockNode* block = TVM_SREF_TO_BLOCK(block_sref); const BlockRealizeNode* realize = GetBlockRealize(self, block_sref).get(); - Map dom_map; + ffi::Map dom_map; arith::Analyzer analyzer; // Check 1. check the block is complete. @@ -423,14 +423,14 @@ StmtSRef DecomposePaddingImpl(ScheduleState self, const StmtSRef& block_sref, // Check 2. Check loop_sref is an ancestor of block_sref. Also collect // - the highest loop position (inclusive) to insert const pad value filling code above. // - the highest loop position (inclusive) to replace with in-bound value filling code. 
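// ---- Editor's note (illustrative sketch, not part of the diff): the
// decompose_padding hunks above rebuild loop/block vars and rewrite bindings through
// a replacement map that is now an ffi::Map. The core remapping step, as a sketch
// (variable names are hypothetical; Substitute and copy_with_suffix are the same TVM
// utilities used above):
ffi::Map<Var, PrimExpr> repl_dict;
Var old_var("i");
Var new_var = old_var.copy_with_suffix("_pad");
repl_dict.Set(old_var, new_var);
PrimExpr rewritten = Substitute(old_var + 1, repl_dict);  // yields new_var + 1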
- Array loop_srefs = GetLoops(block_sref); - Array loops; + ffi::Array loop_srefs = GetLoops(block_sref); + ffi::Array loops; bool found_const_filling_pos = false; bool found_in_bound_filling_pos = false; - For const_filling_pos = GetRef(loop_sref->StmtAs()); + For const_filling_pos = ffi::GetRef(loop_sref->StmtAs()); For in_bound_filling_pos{nullptr}; for (auto it = loop_srefs.rbegin(); it != loop_srefs.rend(); ++it) { - For cur_loop = GetRef((*it)->StmtAs()); + For cur_loop = ffi::GetRef((*it)->StmtAs()); Range range = Range::FromMinExtent(cur_loop->min, cur_loop->extent); dom_map.Set(cur_loop->loop_var, range); analyzer.Bind(cur_loop->loop_var, range); @@ -454,7 +454,7 @@ StmtSRef DecomposePaddingImpl(ScheduleState self, const StmtSRef& block_sref, } ICHECK(in_bound_filling_pos.defined()); if (!found_const_filling_pos) { - throw LoopPositionError(self->mod, const_filling_pos, GetRef(block), + throw LoopPositionError(self->mod, const_filling_pos, ffi::GetRef(block), "decompose_padding"); } @@ -473,7 +473,7 @@ StmtSRef DecomposePaddingImpl(ScheduleState self, const StmtSRef& block_sref, CreateInBoundBlock(realize, info, loops, in_bound_filling_pos, &analyzer); // Step 2. Execute IR replacement. - Block old_scope_root_block = GetRef(scope_root_sref->StmtAs()); + Block old_scope_root_block = ffi::GetRef(scope_root_sref->StmtAs()); Block new_scope_root = DecomposePaddingBlockReplacer::Replace(old_scope_root_block, replace_desc); if (check_only) { return block_sref; @@ -482,7 +482,7 @@ StmtSRef DecomposePaddingImpl(ScheduleState self, const StmtSRef& block_sref, // Step 3. Update schedule states. self->Replace(scope_root_sref, new_scope_root, {{old_scope_root_block, new_scope_root}, - {GetRef(block), replace_desc.in_bound_filling_block->block}}); + {ffi::GetRef(block), replace_desc.in_bound_filling_block->block}}); auto new_block_sref = self->stmt2ref.at(replace_desc.const_filling_block->block.get()); // Set block info of created const pad value filling block @@ -556,7 +556,8 @@ struct DecomposPaddingTraits : public UnpackedInstTraits return sch->DecomposePadding(block_rv, loop_rv); } - static String UnpackedAsPython(Array outputs, String block_rv, LoopRV loop_rv) { + static ffi::String UnpackedAsPython(ffi::Array outputs, ffi::String block_rv, + LoopRV loop_rv) { PythonAPICall py("decompose_padding"); py.Input("block", block_rv); py.Input("loop", loop_rv); diff --git a/src/tir/schedule/primitive/for_kind.cc b/src/tir/schedule/primitive/for_kind.cc index 6dd1eafcc076..de550979c18f 100644 --- a/src/tir/schedule/primitive/for_kind.cc +++ b/src/tir/schedule/primitive/for_kind.cc @@ -29,13 +29,13 @@ class WrongBlockIterTypeError : public ScheduleError { ? "parallel" : (for_kind == ForKind::kVectorized ? 
"vectorize" : "bind"); } - String FastErrorString() const final { + ffi::String FastErrorString() const final { std::ostringstream os; os << "ScheduleError: The \"" << op_str_ << "\" cannot be fulfilled with regard to some of its underlying block"; return os.str(); } - String DetailRenderTemplate() const final { + ffi::String DetailRenderTemplate() const final { std::ostringstream os; if (op_str_ != "bind") { os << "The \"" << op_str_ @@ -52,7 +52,7 @@ class WrongBlockIterTypeError : public ScheduleError { return os.str(); } IRModule mod() const final { return mod_; } - Array LocationsOfInterest() const final { return {block_}; } + ffi::Array LocationsOfInterest() const final { return {block_}; } IRModule mod_; std::string op_str_; Var loop_var_; @@ -127,8 +127,8 @@ void CheckParallelizability(const ScheduleState& self, const For& loop, ForKind if (!self->stmt2ref.count(realize->block.get())) { return false; } - CheckLoopParallelizableInBlock(self, for_kind, loop->loop_var, GetRef(realize), - thread_scope); + CheckLoopParallelizableInBlock(self, for_kind, loop->loop_var, + ffi::GetRef(realize), thread_scope); } return true; }); @@ -144,7 +144,7 @@ void CheckParallelizability(const ScheduleState& self, const For& loop, ForKind * `for_kind` is `kThreadBinding` */ void ParallelizeComputation(const ScheduleState& self, const StmtSRef& loop_sref, ForKind for_kind, - Optional thread_axis) { + ffi::Optional thread_axis) { const ForNode* loop = TVM_SREF_TO_FOR(loop_sref); /* @@ -163,12 +163,12 @@ void ParallelizeComputation(const ScheduleState& self, const StmtSRef& loop_sref // Step 2. Check whether the loop can be parallelized/vectorized/bound with regard to each // underlying block. - CheckParallelizability(self, GetRef(loop), for_kind, + CheckParallelizability(self, ffi::GetRef(loop), for_kind, thread_axis.has_value() ? runtime::ThreadScope::Create(thread_axis.value()) : runtime::ThreadScope{-1, -1}); // Step 3. 
Loop update and IR replacement - ObjectPtr new_loop = make_object(*loop); + ObjectPtr new_loop = ffi::make_object(*loop); new_loop->kind = for_kind; if (thread_axis.has_value()) { new_loop->thread_binding = IterVar(/*dom=*/Range(nullptr), // @@ -189,13 +189,13 @@ void Vectorize(ScheduleState self, const StmtSRef& loop_sref) { ParallelizeComputation(self, loop_sref, ForKind::kVectorized, std::nullopt); } -void Bind(ScheduleState self, const StmtSRef& loop_sref, const String& thread_axis) { +void Bind(ScheduleState self, const StmtSRef& loop_sref, const ffi::String& thread_axis) { ParallelizeComputation(self, loop_sref, ForKind::kThreadBinding, thread_axis); } void Unroll(ScheduleState self, const StmtSRef& loop_sref) { const ForNode* loop = TVM_SREF_TO_FOR(loop_sref); - ObjectPtr new_loop = make_object(*loop); + ObjectPtr new_loop = ffi::make_object(*loop); new_loop->kind = ForKind::kUnrolled; new_loop->thread_binding = std::nullopt; self->Replace(loop_sref, For(new_loop), {}); @@ -216,7 +216,7 @@ struct ParallelTraits : public UnpackedInstTraits { return sch->Parallel(loop_rv); } - static String UnpackedAsPython(Array outputs, String loop_rv) { + static ffi::String UnpackedAsPython(ffi::Array outputs, ffi::String loop_rv) { PythonAPICall py("parallel"); py.Input("loop", loop_rv); return py.Str(); @@ -239,7 +239,7 @@ struct VectorizeTraits : public UnpackedInstTraits { return sch->Vectorize(loop_rv); } - static String UnpackedAsPython(Array outputs, String loop_rv) { + static ffi::String UnpackedAsPython(ffi::Array outputs, ffi::String loop_rv) { PythonAPICall py("vectorize"); py.Input("loop", loop_rv); return py.Str(); @@ -258,11 +258,12 @@ struct BindTraits : public UnpackedInstTraits { static constexpr size_t kNumAttrs = 1; static constexpr size_t kNumDecisions = 0; - static void UnpackedApplyToSchedule(Schedule sch, LoopRV loop_rv, String thread) { + static void UnpackedApplyToSchedule(Schedule sch, LoopRV loop_rv, ffi::String thread) { return sch->Bind(loop_rv, thread); } - static String UnpackedAsPython(Array outputs, String loop_rv, String thread) { + static ffi::String UnpackedAsPython(ffi::Array outputs, ffi::String loop_rv, + ffi::String thread) { PythonAPICall py("bind"); py.Input("loop", loop_rv); py.Input("thread_axis", thread); @@ -284,7 +285,7 @@ struct UnrollTraits : public UnpackedInstTraits { static void UnpackedApplyToSchedule(Schedule sch, LoopRV loop_rv) { return sch->Unroll(loop_rv); } - static String UnpackedAsPython(Array outputs, String loop_rv) { + static ffi::String UnpackedAsPython(ffi::Array outputs, ffi::String loop_rv) { PythonAPICall py("unroll"); py.Input("loop", loop_rv); return py.Str(); diff --git a/src/tir/schedule/primitive/get_block_loop.cc b/src/tir/schedule/primitive/get_block_loop.cc index 588770d968ef..0ad1d82ee0df 100644 --- a/src/tir/schedule/primitive/get_block_loop.cc +++ b/src/tir/schedule/primitive/get_block_loop.cc @@ -22,9 +22,11 @@ namespace tvm { namespace tir { -Array GetBlocks(const ScheduleState& self, const String& name, const GlobalVar& gv) { +ffi::Array GetBlocks(const ScheduleState& self, const ffi::String& name, + const GlobalVar& gv) { struct Finder : public StmtVisitor { - explicit Finder(const ScheduleState& self, const String& name) : self_(self), name_(name) {} + explicit Finder(const ScheduleState& self, const ffi::String& name) + : self_(self), name_(name) {} void VisitStmt_(const BlockNode* block) override { if (block->name_hint == name_) { @@ -36,8 +38,8 @@ Array GetBlocks(const ScheduleState& self, const String& name, 
const G } const ScheduleState& self_; - const String& name_; - Array<StmtSRef> results_; + const ffi::String& name_; + ffi::Array<StmtSRef> results_; }; BaseFunc func = self->mod->Lookup(gv); @@ -47,16 +49,16 @@ Array<StmtSRef> GetBlocks(const ScheduleState& self, const String& name, const G return std::move(finder.results_); } -Array<StmtSRef> GetLoops(const StmtSRef& block_sref) { +ffi::Array<StmtSRef> GetLoops(const StmtSRef& block_sref) { std::vector<StmtSRef> result; for (StmtSRefNode* parent = block_sref->parent; parent && parent->stmt->IsInstance<ForNode>(); parent = parent->parent) { - result.push_back(GetRef<StmtSRef>(parent)); + result.push_back(ffi::GetRef<StmtSRef>(parent)); } return {result.rbegin(), result.rend()}; } -Array<StmtSRef> GetChildBlocks(const ScheduleState& self, const StmtSRef& parent_sref) { +ffi::Array<StmtSRef> GetChildBlocks(const ScheduleState& self, const StmtSRef& parent_sref) { struct Collector : public StmtVisitor { private: void VisitStmt_(const BlockNode* block) final { result.push_back(self->stmt2ref.at(block)); } @@ -65,7 +67,7 @@ Array<StmtSRef> GetChildBlocks(const ScheduleState& self, const StmtSRef& parent explicit Collector(const ScheduleState& self) : self(self) {} const ScheduleState& self; - Array<StmtSRef> result; + ffi::Array<StmtSRef> result; }; Collector collector(self); if (parent_sref->stmt->IsInstance<BlockNode>()) { @@ -78,17 +80,17 @@ Array<StmtSRef> GetChildBlocks(const ScheduleState& self, const StmtSRef& parent return std::move(collector.result); } -Array<StmtSRef> GetProducers(const ScheduleState& self, const StmtSRef& block_sref) { +ffi::Array<StmtSRef> GetProducers(const ScheduleState& self, const StmtSRef& block_sref) { StmtSRef scope_root = GetScopeRoot(self, block_sref, /*require_stage_pipeline=*/false); return tir::GetProducers(block_sref, self->GetBlockScope(scope_root)); } -Array<StmtSRef> GetConsumers(const ScheduleState& self, const StmtSRef& block_sref) { +ffi::Array<StmtSRef> GetConsumers(const ScheduleState& self, const StmtSRef& block_sref) { StmtSRef scope_root = GetScopeRoot(self, block_sref, /*require_stage_pipeline=*/false); return tir::GetConsumers(block_sref, self->GetBlockScope(scope_root)); } -Array<StmtSRef> GetOutputBlocks(const ScheduleState& self, const StmtSRef& scope_sref) { +ffi::Array<StmtSRef> GetOutputBlocks(const ScheduleState& self, const StmtSRef& scope_sref) { const auto* scope_block = TVM_SREF_TO_BLOCK(scope_sref); return tir::GetOutputBlocks(self, scope_block); } @@ -104,11 +106,12 @@ struct GetBlockTraits : public UnpackedInstTraits<GetBlockTraits> { static constexpr size_t kNumAttrs = 2; static constexpr size_t kNumDecisions = 0; - static BlockRV UnpackedApplyToSchedule(Schedule sch, String name, String func_name) { + static BlockRV UnpackedApplyToSchedule(Schedule sch, ffi::String name, ffi::String func_name) { return sch->GetBlock(name, func_name); } - static String UnpackedAsPython(Array<String> outputs, String name, String func_name) { + static ffi::String UnpackedAsPython(ffi::Array<ffi::String> outputs, ffi::String name, + ffi::String func_name) { PythonAPICall py("get_block"); py.Input("name", name); py.Input("func_name", func_name); @@ -129,11 +132,11 @@ struct GetLoopsTraits : public UnpackedInstTraits<GetLoopsTraits> { static constexpr size_t kNumAttrs = 0; static constexpr size_t kNumDecisions = 0; - static Array<LoopRV> UnpackedApplyToSchedule(Schedule sch, BlockRV block_rv) { + static ffi::Array<LoopRV> UnpackedApplyToSchedule(Schedule sch, BlockRV block_rv) { return sch->GetLoops(block_rv); } - static String UnpackedAsPython(Array<String> outputs, String block_rv) { + static ffi::String UnpackedAsPython(ffi::Array<ffi::String> outputs, ffi::String block_rv) { PythonAPICall py("get_loops"); py.Input("block", block_rv); py.OutputList(outputs); @@ -153,7 +156,7 @@ struct 
GetChildBlocksTraits : public UnpackedInstTraits { static constexpr size_t kNumAttrs = 0; static constexpr size_t kNumDecisions = 0; - static Array UnpackedApplyToSchedule(Schedule sch, ObjectRef block_or_loop_rv) { + static ffi::Array UnpackedApplyToSchedule(Schedule sch, ObjectRef block_or_loop_rv) { if (auto block = block_or_loop_rv.as()) { return sch->GetChildBlocks(block.value()); } @@ -164,7 +167,8 @@ struct GetChildBlocksTraits : public UnpackedInstTraits { throw; } - static String UnpackedAsPython(Array outputs, String block_or_loop_rv) { + static ffi::String UnpackedAsPython(ffi::Array outputs, + ffi::String block_or_loop_rv) { PythonAPICall py("get_child_blocks"); py.Input("", block_or_loop_rv); py.OutputList(outputs); @@ -184,11 +188,11 @@ struct GetProducersTraits : public UnpackedInstTraits { static constexpr size_t kNumAttrs = 0; static constexpr size_t kNumDecisions = 0; - static Array UnpackedApplyToSchedule(Schedule sch, BlockRV block_rv) { + static ffi::Array UnpackedApplyToSchedule(Schedule sch, BlockRV block_rv) { return sch->GetProducers(block_rv); } - static String UnpackedAsPython(Array outputs, String block_rv) { + static ffi::String UnpackedAsPython(ffi::Array outputs, ffi::String block_rv) { PythonAPICall py("get_producers"); py.Input("block", block_rv); py.OutputList(outputs); @@ -208,11 +212,11 @@ struct GetConsumersTraits : public UnpackedInstTraits { static constexpr size_t kNumAttrs = 0; static constexpr size_t kNumDecisions = 0; - static Array UnpackedApplyToSchedule(Schedule sch, BlockRV block_rv) { + static ffi::Array UnpackedApplyToSchedule(Schedule sch, BlockRV block_rv) { return sch->GetConsumers(block_rv); } - static String UnpackedAsPython(Array outputs, String block_rv) { + static ffi::String UnpackedAsPython(ffi::Array outputs, ffi::String block_rv) { PythonAPICall py("get_consumers"); py.Input("block", block_rv); py.OutputList(outputs); @@ -232,11 +236,11 @@ struct GetOutputBlocksTraits : public UnpackedInstTraits static constexpr size_t kNumAttrs = 0; static constexpr size_t kNumDecisions = 0; - static Array UnpackedApplyToSchedule(Schedule sch, BlockRV block_rv) { + static ffi::Array UnpackedApplyToSchedule(Schedule sch, BlockRV block_rv) { return sch->GetOutputBlocks(block_rv); } - static String UnpackedAsPython(Array outputs, String block_rv) { + static ffi::String UnpackedAsPython(ffi::Array outputs, ffi::String block_rv) { PythonAPICall py("get_output_blocks"); py.Input("block", block_rv); py.OutputList(outputs); diff --git a/src/tir/schedule/primitive/hide_buffer_access.cc b/src/tir/schedule/primitive/hide_buffer_access.cc index 469dc278e503..f5e92b8ba50b 100644 --- a/src/tir/schedule/primitive/hide_buffer_access.cc +++ b/src/tir/schedule/primitive/hide_buffer_access.cc @@ -27,25 +27,25 @@ namespace tir { namespace { class BufTypeError : public ScheduleError { public: - explicit BufTypeError(IRModule mod, const String& buf_type) + explicit BufTypeError(IRModule mod, const ffi::String& buf_type) : mod_(std::move(mod)), buf_type_(buf_type) {} - String FastErrorString() const final { + ffi::String FastErrorString() const final { return "ScheduleError: Invalid buffer type for hide_buffer_access schedule."; } - String DetailRenderTemplate() const final { + ffi::String DetailRenderTemplate() const final { return "The buffer type for hide_buffer_access schedule should either be 'read'" " or 'write', got " + buf_type_ + " instead."; } IRModule mod() const final { return mod_; } - Array LocationsOfInterest() const final { return {}; } + ffi::Array 
LocationsOfInterest() const final { return {}; } private: IRModule mod_; - String buf_type_; + ffi::String buf_type_; }; class InvalidIndexError : public ScheduleError { @@ -53,11 +53,11 @@ class InvalidIndexError : public ScheduleError { explicit InvalidIndexError(IRModule mod, int num_access_regions, int buf_idx) : mod_(std::move(mod)), num_access_regions_(num_access_regions), buf_idx_(buf_idx) {} - String FastErrorString() const final { + ffi::String FastErrorString() const final { return "ScheduleError: Invalid buffer index array for hide_buffer_access schedule."; } - String DetailRenderTemplate() const final { + ffi::String DetailRenderTemplate() const final { return "The buffer index array for hide_buffer_access schedule should be a list of integers" " between 0 and " + std::to_string(num_access_regions_ - 1) + ", got " + std::to_string(buf_idx_) + @@ -66,7 +66,7 @@ class InvalidIndexError : public ScheduleError { IRModule mod() const final { return mod_; } - Array LocationsOfInterest() const final { return {}; } + ffi::Array LocationsOfInterest() const final { return {}; } private: IRModule mod_; @@ -78,8 +78,9 @@ class InvalidIndexError : public ScheduleError { /******** Implementation ********/ -void UnsafeHideBufferAccess(ScheduleState self, const StmtSRef& block_sref, const String& buf_type, - const Array& buf_index_array) { +void UnsafeHideBufferAccess(ScheduleState self, const StmtSRef& block_sref, + const ffi::String& buf_type, + const ffi::Array& buf_index_array) { /*! * Check: * - validity of buf_index_array @@ -107,7 +108,7 @@ void UnsafeHideBufferAccess(ScheduleState self, const StmtSRef& block_sref, cons /* Step 0: Collect new buffer access regions. */ - Array reads, writes; + ffi::Array reads, writes; if (buf_type == "read") { for (size_t i = 0; i < block->reads.size(); ++i) { @@ -129,12 +130,12 @@ void UnsafeHideBufferAccess(ScheduleState self, const StmtSRef& block_sref, cons /* Step 1: Replace old block with the new block */ - auto n = make_object(*block); + auto n = ffi::make_object(*block); n->reads = reads; n->writes = writes; Block new_block = Block(n); - Map blk_map; - blk_map.Set(GetRef(block), new_block); + ffi::Map blk_map; + blk_map.Set(ffi::GetRef(block), new_block); self->Replace(block_sref, new_block, blk_map); } @@ -147,13 +148,13 @@ struct UnsafeHideBufferAccessTraits : public UnpackedInstTraits buf_index_array) { + static void UnpackedApplyToSchedule(Schedule sch, BlockRV block, ffi::String buf_type, + ffi::Array buf_index_array) { sch->UnsafeHideBufferAccess(block, buf_type, buf_index_array); } - static String UnpackedAsPython(Array outputs, String block, String buf_type, - Array buf_index_array) { + static ffi::String UnpackedAsPython(ffi::Array outputs, ffi::String block, + ffi::String buf_type, ffi::Array buf_index_array) { PythonAPICall py("unsafe_hide_buffer_access"); py.Input("block", block); py.Input("buf_type", buf_type); diff --git a/src/tir/schedule/primitive/layout_transformation.cc b/src/tir/schedule/primitive/layout_transformation.cc index 8931c0e71c11..c625d8c153cf 100644 --- a/src/tir/schedule/primitive/layout_transformation.cc +++ b/src/tir/schedule/primitive/layout_transformation.cc @@ -75,8 +75,8 @@ class TransformLayoutPlanner : private StmtExprVisitor { // Loops within the analyzed block that should be replaced struct ReplacementPlan { - Map replacements; - Map new_block_to_old; + ffi::Map replacements; + ffi::Map new_block_to_old; }; // The block to be inserted, along with the location at which it @@ -94,7 +94,7 @@ class 
TransformLayoutPlanner : private StmtExprVisitor { static TransformPlan Plan(Block block, Buffer old_buffer, Buffer new_buffer, IndexMap index_map, IndexMap inverse, PrimExpr padding_predicate, - Optional pad_value, arith::Analyzer* analyzer) { + ffi::Optional pad_value, arith::Analyzer* analyzer) { ICHECK(!pad_value.defined() || pad_value.value()->final_indices.size() == 1) << "Internal error: Should be caught by ScheduleError checks prior to this point"; TransformLayoutPlanner visitor(old_buffer); @@ -108,7 +108,7 @@ class TransformLayoutPlanner : private StmtExprVisitor { BufferStore store; // The block realize that contains the store, if any. - Optional innermost_block_realize; + ffi::Optional innermost_block_realize; // The nested loops whose values contribute to the indices used in // the store. Not all loop variables in the loopnest need to @@ -125,7 +125,7 @@ class TransformLayoutPlanner : private StmtExprVisitor { explicit TransformLayoutPlanner(Buffer old_buffer) : old_buffer_(old_buffer) {} void VisitStmt_(const ForNode* op) override { - BindLoopVar context(this, GetRef(op)); + BindLoopVar context(this, ffi::GetRef(op)); StmtExprVisitor::VisitStmt_(op); } @@ -135,7 +135,7 @@ class TransformLayoutPlanner : private StmtExprVisitor { } void VisitStmt_(const BlockRealizeNode* op) override { - BindBlockRealize context(this, GetRef(op)); + BindBlockRealize context(this, ffi::GetRef(op)); StmtExprVisitor::VisitStmt_(op); } @@ -158,7 +158,7 @@ class TransformLayoutPlanner : private StmtExprVisitor { } WriteInfo write_info; - write_info.store = GetRef(op); + write_info.store = ffi::GetRef(op); if (loop_dependency_range) { size_t i = loop_dependency_range.value().first; size_t j = loop_dependency_range.value().second; @@ -220,8 +220,8 @@ class TransformLayoutPlanner : private StmtExprVisitor { class BufferStoreReplacer : public StmtExprMutator { public: BufferStoreReplacer(const WriteInfo& info, const Buffer& new_buffer, PrimExpr padding_predicate, - const IndexMap& inverse, const Optional& pad_value, - Map* new_block_to_old, arith::Analyzer* analyzer) + const IndexMap& inverse, const ffi::Optional& pad_value, + ffi::Map* new_block_to_old, arith::Analyzer* analyzer) : info(info), new_buffer(new_buffer), new_indices(inverse->initial_indices), @@ -250,7 +250,7 @@ class TransformLayoutPlanner : private StmtExprVisitor { BlockRealize block_realize = info.innermost_block_realize.value(); const auto& block = block_realize->block; - const Array& old_indices = info.store->indices; + const ffi::Array& old_indices = info.store->indices; const auto& old_iter_vars = block->iter_vars; this->new_iter_vars = old_iter_vars; @@ -294,10 +294,10 @@ class TransformLayoutPlanner : private StmtExprVisitor { return Var(ss.str(), var.dtype()); }); - Map + ffi::Map loop_var_to_virtual_var; // For updating padding_predicate in terms of the new indices - Array new_iter_values; // For BlockRealize - Array new_iter_vars; // For Block + ffi::Array new_iter_values; // For BlockRealize + ffi::Array new_iter_vars; // For Block for (size_t i = 0; i < block_index_start; i++) { new_iter_vars.push_back(old_iter_vars[i]); @@ -339,7 +339,7 @@ class TransformLayoutPlanner : private StmtExprVisitor { return false; } - const Array& old_indices = info.store->indices; + const ffi::Array& old_indices = info.store->indices; ICHECK_EQ(old_indices.size(), op->indices.size()); ExprDeepEqual expr_equal; @@ -351,9 +351,9 @@ class TransformLayoutPlanner : private StmtExprVisitor { return true; }(); - BufferStore store = GetRef(op); + 
BufferStore store = ffi::GetRef(op); if (can_replace) { - Array new_index_exprs = + ffi::Array new_index_exprs = new_indices.Map([](const auto& var) -> PrimExpr { return var; }); PrimExpr pad_value_at_index = pad_value.value()->MapIndices(new_index_exprs, analyzer)[0]; store = @@ -387,7 +387,7 @@ class TransformLayoutPlanner : private StmtExprVisitor { } Stmt VisitStmt_(const BlockNode* op) final { - Block orig = GetRef(op); + Block orig = ffi::GetRef(op); Block mutated = Downcast(StmtExprMutator::VisitStmt_(op)); RecordReplacement(orig, mutated); @@ -395,7 +395,7 @@ class TransformLayoutPlanner : private StmtExprVisitor { } PrimExpr VisitExpr_(const VarNode* op) final { - Var var = GetRef(op); + Var var = ffi::GetRef(op); if (auto opt = var_remap.Get(var)) { return opt.value(); } else { @@ -423,21 +423,21 @@ class TransformLayoutPlanner : private StmtExprVisitor { const WriteInfo& info; const Buffer& new_buffer; - Array new_indices; - Array new_iter_vars; - Array new_iter_values; + ffi::Array new_indices; + ffi::Array new_iter_vars; + ffi::Array new_iter_values; PrimExpr padding_predicate; const IndexMap& inverse; - const Optional& pad_value; - Map& new_block_to_old; + const ffi::Optional& pad_value; + ffi::Map& new_block_to_old; bool all_stores_replaced{true}; arith::Analyzer* analyzer; - Map var_remap; + ffi::Map var_remap; }; TransformPlan Finalize(Buffer new_buffer, IndexMap index_map, IndexMap inverse, - PrimExpr padding_predicate, Optional pad_value, + PrimExpr padding_predicate, ffi::Optional pad_value, arith::Analyzer* analyzer) const { if (auto prologue_plan = FinalizeProloguePlan(new_buffer, index_map, inverse, padding_predicate, pad_value, analyzer); @@ -458,16 +458,16 @@ class TransformLayoutPlanner : private StmtExprVisitor { std::optional FinalizeProloguePlan(Buffer new_buffer, IndexMap index_map, IndexMap inverse, PrimExpr padding_predicate, - Optional pad_value, + ffi::Optional pad_value, arith::Analyzer* analyzer) const { if (write_info_.size() || is_zero(padding_predicate) || !pad_value.defined()) { return std::nullopt; } - Array iter_vars; - Array iter_values; - Array indices; - Map loop_indices_to_block_indices; + ffi::Array iter_vars; + ffi::Array iter_values; + ffi::Array indices; + ffi::Map loop_indices_to_block_indices; ICHECK_EQ(inverse->initial_indices.size(), new_buffer->shape.size()); for (size_t i = 0; i < inverse->initial_indices.size(); i++) { const auto& loop_var = inverse->initial_indices[i]; @@ -503,14 +503,14 @@ class TransformLayoutPlanner : private StmtExprVisitor { std::optional FinalizeReplacementPlan(Buffer new_buffer, IndexMap index_map, IndexMap inverse, PrimExpr padding_predicate, - Optional pad_value, + ffi::Optional pad_value, arith::Analyzer* analyzer) const { if (write_info_.empty() || is_zero(padding_predicate) || !pad_value.defined()) { return std::nullopt; } - Map new_block_to_old; - auto generate_if_then_else_block = [&](const WriteInfo& info) -> Optional { + ffi::Map new_block_to_old; + auto generate_if_then_else_block = [&](const WriteInfo& info) -> ffi::Optional { if (!info.contains_row_major_traversal || !pad_value.defined() || is_zero(padding_predicate)) { return std::nullopt; @@ -534,7 +534,7 @@ class TransformLayoutPlanner : private StmtExprVisitor { return stmt; }; - Map loop_replacements; + ffi::Map loop_replacements; for (const auto& info : write_info_) { if (info.dependent_loopnest.size()) { @@ -553,15 +553,15 @@ class TransformLayoutPlanner : private StmtExprVisitor { std::optional FinalizeEpiloguePlan(Buffer new_buffer, 
IndexMap index_map, IndexMap inverse, PrimExpr padding_predicate, - Optional pad_value, + ffi::Optional pad_value, arith::Analyzer* analyzer) const { if (write_info_.empty() || is_zero(padding_predicate) || !pad_value.defined()) { return std::nullopt; } - Array iter_vars; - Array iter_values; - Array indices; + ffi::Array iter_vars; + ffi::Array iter_values; + ffi::Array indices; ICHECK_EQ(inverse->initial_indices.size(), new_buffer->shape.size()); for (size_t i = 0; i < inverse->initial_indices.size(); i++) { const auto& loop_var = inverse->initial_indices[i]; @@ -673,7 +673,7 @@ class TransformLayoutPlanner : private StmtExprVisitor { BindBlockRealize& operator=(BindBlockRealize&&) = delete; TransformLayoutPlanner* self_{nullptr}; - Optional cache_; + ffi::Optional cache_; std::vector bound_vars_; }; @@ -707,7 +707,7 @@ class TransformLayoutPlanner : private StmtExprVisitor { * * Used to fill the `WriteInfo::innermost_block_realize` field.. */ - Optional innermost_block_realize_{std::nullopt}; + ffi::Optional innermost_block_realize_{std::nullopt}; /*! \brief The buffer to be replaced */ Buffer old_buffer_; @@ -719,23 +719,23 @@ class TransformLayoutPlanner : private StmtExprVisitor { */ class ReuseBlocksCollector : public tir::StmtVisitor { public: - static Map Collect(Block result, Map new_block_to_old) { + static ffi::Map Collect(Block result, ffi::Map new_block_to_old) { return ReuseBlocksCollector(new_block_to_old).Run(result); } private: /*! \brief Entry point */ - Map Run(const Block result) { + ffi::Map Run(const Block result) { VisitStmt(result); return block_sref_reuse_; } /*! \brief Constructor */ - explicit ReuseBlocksCollector(Map new_block_to_old) + explicit ReuseBlocksCollector(ffi::Map new_block_to_old) : new_block_to_old_(new_block_to_old) {} /*! \brief Override the Stmt visiting behaviour */ void VisitStmt_(const tir::BlockNode* block) override { - Block block_ref = GetRef(block); + Block block_ref = ffi::GetRef(block); auto it = new_block_to_old_.find(block_ref); if (it != new_block_to_old_.end()) { block_sref_reuse_.Set((*it).second, (*it).first); @@ -744,9 +744,9 @@ class ReuseBlocksCollector : public tir::StmtVisitor { } /*! \brief New map to be filled with just blocks from scope block */ - Map block_sref_reuse_; + ffi::Map block_sref_reuse_; /*! \brief All block replacements collected so far */ - Map new_block_to_old_; + ffi::Map new_block_to_old_; }; class TransformLayoutRewriter : private arith::IRMutatorWithAnalyzer { @@ -760,10 +760,10 @@ class TransformLayoutRewriter : private arith::IRMutatorWithAnalyzer { * \return The new AST rooting at the original parent scope and the map from the old block to the * new block */ - static std::pair> Rewrite( + static std::pair> Rewrite( const Block& scope_stmt, const Buffer& old_buffer, const Buffer& new_buffer, - const IndexMap& index_map, const Optional& opt_inverse, - const PrimExpr& padding_predicate, const Optional& pad_value) { + const IndexMap& index_map, const ffi::Optional& opt_inverse, + const PrimExpr& padding_predicate, const ffi::Optional& pad_value) { arith::Analyzer analyzer; auto plan = pad_value.defined() ? 
TransformLayoutPlanner::Plan(scope_stmt, old_buffer, new_buffer, index_map, @@ -778,7 +778,7 @@ class TransformLayoutRewriter : private arith::IRMutatorWithAnalyzer { write_ptr->body = SeqStmt({plan_ptr->prologue, write_ptr->body}); } - Map block_sref_reuse = + ffi::Map block_sref_reuse = ReuseBlocksCollector::Collect(result, rewriter.new_block_to_old_); return {result, block_sref_reuse}; @@ -800,7 +800,7 @@ class TransformLayoutRewriter : private arith::IRMutatorWithAnalyzer { } } - void RewriteBufferAccess(Buffer* buffer, Array* indices) { + void RewriteBufferAccess(Buffer* buffer, ffi::Array* indices) { *buffer = new_buffer_; *indices = index_map_->MapIndices(*indices, &index_simplifier_); *indices = this->IterMapSimplifyWithContext(*indices, true); @@ -825,7 +825,7 @@ class TransformLayoutRewriter : private arith::IRMutatorWithAnalyzer { // replacing `loop` with `{loop, post_proc}`. In this case, avoid // infinite recursion. - For node = GetRef(op); + For node = ffi::GetRef(op); if (auto plan_ptr = std::get_if(&plan_)) { auto it = plan_ptr->replacements.find(node); if (it != plan_ptr->replacements.end()) { @@ -853,8 +853,8 @@ class TransformLayoutRewriter : private arith::IRMutatorWithAnalyzer { return buffer_store; } - void RewriteAccessRegion(Array* old_access_regions, - const Array& infered_access_regions) { + void RewriteAccessRegion(ffi::Array* old_access_regions, + const ffi::Array& infered_access_regions) { auto fmutate = [this, &infered_access_regions](const BufferRegion& buffer_region) { if (buffer_region->buffer.same_as(old_buffer_)) { ICHECK(infered_access_regions.size() == 1); @@ -867,7 +867,7 @@ class TransformLayoutRewriter : private arith::IRMutatorWithAnalyzer { Stmt VisitStmt_(const BlockNode* op) final { Block orig = [&]() { - Block block = GetRef(op); + Block block = ffi::GetRef(op); while (true) { if (auto it = new_block_to_old_.find(block); it != new_block_to_old_.end()) { block = (*it).second; @@ -918,8 +918,8 @@ class TransformLayoutRewriter : private arith::IRMutatorWithAnalyzer { const Buffer& new_buffer_; const IndexMap& index_map_; const TransformLayoutPlanner::TransformPlan& plan_; - Map buffer_data_to_buffer_; - Map new_block_to_old_; + ffi::Map buffer_data_to_buffer_; + ffi::Map new_block_to_old_; arith::Analyzer index_simplifier_; }; @@ -927,19 +927,19 @@ class BufferIsSubregionError : public ScheduleError { public: explicit BufferIsSubregionError(IRModule mod, Buffer buffer) : mod_(mod), buffer_(buffer) {} - String FastErrorString() const final { + ffi::String FastErrorString() const final { return "ScheduleError: The input buffer is defined in `match_buffer` of a block, it is expected" " to be a function parameter or allocated by a block"; } - String DetailRenderTemplate() const final { + ffi::String DetailRenderTemplate() const final { std::ostringstream os; os << "ScheduleError: The input buffer " << buffer_->name << " is defined in `match_buffer` of " << "a block, it is expected to be a function parameter or allocated by a block."; return os.str(); } - Array LocationsOfInterest() const final { return {}; } + ffi::Array LocationsOfInterest() const final { return {}; } IRModule mod() const final { return mod_; } private: @@ -952,14 +952,14 @@ class TransformationPaddingIndexMapError : public ScheduleError { TransformationPaddingIndexMapError(IRModule mod, IndexMap pad_value) : mod_(mod), pad_value_(pad_value) {} - String FastErrorString() const final { + ffi::String FastErrorString() const final { std::ostringstream ss; ss << "ScheduleError: The 
IndexMap specifying pad_value has " << pad_value_->final_indices.size() << " outputs, should only have one output"; return ss.str(); } - String DetailRenderTemplate() const final { + ffi::String DetailRenderTemplate() const final { std::ostringstream ss; ss << "ScheduleError: Pad value is specified as " << pad_value_ << " which has " << pad_value_->final_indices.size() << " outputs, but should only have one output"; @@ -967,7 +967,7 @@ class TransformationPaddingIndexMapError : public ScheduleError { } IRModule mod() const final { return mod_; } - Array LocationsOfInterest() const final { return {}; } + ffi::Array LocationsOfInterest() const final { return {}; } private: IRModule mod_; @@ -982,13 +982,13 @@ class TransformationPaddingTypeError : public ScheduleError { pad_value_dtype_ = pad_value_->final_indices[0].dtype(); } - String FastErrorString() const final { + ffi::String FastErrorString() const final { std::ostringstream ss; ss << "ScheduleError: Type mismatch " << buffer_->dtype << " vs " << pad_value_dtype_; return ss.str(); } - String DetailRenderTemplate() const final { + ffi::String DetailRenderTemplate() const final { std::ostringstream ss; ss << "ScheduleError: Buffer " << buffer_->name << " has elements of type " << buffer_->dtype << ", but the transformation fills padding with " << pad_value_ << ", which is of type " @@ -997,7 +997,7 @@ class TransformationPaddingTypeError : public ScheduleError { } IRModule mod() const final { return mod_; } - Array LocationsOfInterest() const final { return {}; } + ffi::Array LocationsOfInterest() const final { return {}; } private: IRModule mod_; @@ -1025,26 +1025,26 @@ class TransformationPaddingExpressionError : public ScheduleError { void VisitExpr_(const BufferLoadNode* op) final { if (!op->buffer.same_as(buffer_)) { - illegal_load = GetRef(op); + illegal_load = ffi::GetRef(op); } ExprVisitor::VisitExpr_(op); } const Buffer& buffer_; - Optional illegal_load; + ffi::Optional illegal_load; }; TransformationPaddingExpressionError(IRModule mod, Buffer buffer, IndexMap pad_value, BufferLoad illegal_load) : mod_(mod), buffer_(buffer), pad_value_(pad_value), illegal_load_(illegal_load) {} - String FastErrorString() const final { + ffi::String FastErrorString() const final { std::ostringstream ss; ss << "ScheduleError: Pad value may not contain load from " << illegal_load_->buffer->name; return ss.str(); } - String DetailRenderTemplate() const final { + ffi::String DetailRenderTemplate() const final { std::ostringstream ss; ss << "ScheduleError: Pad value may only contain BufferLoad from the transformed buffer " << buffer_->name << ", but pad_value " << pad_value_ << " contains expression " @@ -1053,7 +1053,7 @@ class TransformationPaddingExpressionError : public ScheduleError { } IRModule mod() const final { return mod_; } - Array LocationsOfInterest() const final { return {}; } + ffi::Array LocationsOfInterest() const final { return {}; } IRModule mod_; Buffer buffer_; @@ -1070,13 +1070,13 @@ class TransformationIntroducesPaddingError : public ScheduleError { index_map_(std::move(index_map)), padding_predicate_(std::move(padding_predicate)) {} - String FastErrorString() const final { + ffi::String FastErrorString() const final { std::ostringstream ss; ss << "ScheduleError: Transformation would introduce padding at " << padding_predicate_ << "."; return ss.str(); } - String DetailRenderTemplate() const final { + ffi::String DetailRenderTemplate() const final { arith::Analyzer analyzer; auto new_shape = index_map_->MapShape(buffer_->shape, 
&analyzer); std::ostringstream os; @@ -1087,7 +1087,7 @@ class TransformationIntroducesPaddingError : public ScheduleError { } IRModule mod() const final { return mod_; } - Array LocationsOfInterest() const final { return {}; } + ffi::Array LocationsOfInterest() const final { return {}; } private: IRModule mod_; @@ -1098,12 +1098,12 @@ class TransformationIntroducesPaddingError : public ScheduleError { // Make the dtypes of indices in IndexMap be the same as the dtype of the buffer shape, to avoid // dtype-mismatch issues later. -IndexMap LegalizeIndexMapDType(const IndexMap& index_map, const Array& args) { +IndexMap LegalizeIndexMapDType(const IndexMap& index_map, const ffi::Array& args) { const auto& initial_indices_orig = index_map->initial_indices; ICHECK(args.size() == initial_indices_orig.size()); - Array initial_indices; - Map var_map; + ffi::Array initial_indices; + ffi::Map var_map; std::optional index_dtype = std::nullopt; for (size_t i = 0; i < args.size(); ++i) { @@ -1134,8 +1134,8 @@ IndexMap LegalizeIndexMapDType(const IndexMap& index_map, const Array& [&](const Var& var) { return var_map.Get(var); }); } }); - Optional opt_inverse_index_map = - Downcast>(index_map->inverse_index_map); + ffi::Optional opt_inverse_index_map = + Downcast>(index_map->inverse_index_map); if (opt_inverse_index_map.defined()) { opt_inverse_index_map = LegalizeIndexMapDType(opt_inverse_index_map.value(), final_indices); } @@ -1146,13 +1146,13 @@ IndexMap LegalizeIndexMapDType(const IndexMap& index_map, const Array& void TransformLayout(ScheduleState self, const StmtSRef& block_sref, int buffer_index, BufferIndexType buffer_index_type, const IndexMap& index_map_orig, - const Optional& pad_value, bool assume_injective_transform) { + const ffi::Optional& pad_value, bool assume_injective_transform) { arith::Analyzer analyzer; AddShapeVarBounds(self, block_sref.get(), &analyzer); // Step 1: Input handling and error checking const BlockNode* block_ptr = TVM_SREF_TO_BLOCK(block_sref); Buffer old_buffer = - GetNthAccessBuffer(self, GetRef(block_ptr), buffer_index, buffer_index_type); + GetNthAccessBuffer(self, ffi::GetRef(block_ptr), buffer_index, buffer_index_type); auto index_map = LegalizeIndexMapDType(index_map_orig, old_buffer->shape); @@ -1176,11 +1176,11 @@ void TransformLayout(ScheduleState self, const StmtSRef& block_sref, int buffer_ : GetScopeRoot(self, block_sref, /*require_stage_pipeline=*/false); const BlockNode* scope_block = TVM_SREF_TO_BLOCK(scope_sref); - Optional opt_inverse = std::nullopt; + ffi::Optional opt_inverse = std::nullopt; PrimExpr padding_predicate = Bool(false); if (!assume_injective_transform) { std::tie(opt_inverse, padding_predicate) = [&]() { - Array region; + ffi::Array region; for (const auto& dim : old_buffer->shape) { region.push_back(Range::FromMinExtent(make_zero(dim.dtype()), dim)); } @@ -1200,7 +1200,7 @@ void TransformLayout(ScheduleState self, const StmtSRef& block_sref, int buffer_ // Step 3: Rewrite BufferLoad/BufferStore access indices, block read/write regions, and block // alloc_buffers. 
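The rewrite entry point below receives the scope block via ffi::GetRef(scope_block): GetRef recovers a strong, reference-counted handle from a raw node pointer without copying the node, and is now spelled with its ffi:: qualification. A minimal sketch of the pattern; the helper name is illustrative, not from this patch:

    #include <tvm/ffi/cast.h>
    #include <tvm/tir/stmt.h>

    // Turn a borrowed BlockNode* back into a reference-counted Block.
    tvm::tir::Block AsBlockRef(const tvm::tir::BlockNode* node) {
      return tvm::ffi::GetRef<tvm::tir::Block>(node);
    }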
auto [new_stmt, block_sref_reuse] = - TransformLayoutRewriter::Rewrite(GetRef(scope_block), old_buffer, new_buffer, + TransformLayoutRewriter::Rewrite(ffi::GetRef(scope_block), old_buffer, new_buffer, index_map, opt_inverse, padding_predicate, pad_value); Block new_scope_block = Downcast(new_stmt); @@ -1211,7 +1211,7 @@ void TransformLayout(ScheduleState self, const StmtSRef& block_sref, int buffer_ IRModuleNode* new_mod = self->mod.CopyOnWrite(); ffi::MapObj* new_map = new_mod->functions.CopyOnWrite(); - Map new_buffer_map; + ffi::Map new_buffer_map; for (auto [var, buffer] : old_func->buffer_map) { if (buffer.same_as(old_buffer)) { buffer = new_buffer; @@ -1266,11 +1266,11 @@ class NotBijectiveAffineIndexMapError : public ScheduleError { public: NotBijectiveAffineIndexMapError(IRModule mod, IndexMap index_map) : mod_(std::move(mod)), index_map_(std::move(index_map)) {} - String FastErrorString() const final { + ffi::String FastErrorString() const final { return "ScheduleError: The index map is not bijective affine."; } - String DetailRenderTemplate() const final { + ffi::String DetailRenderTemplate() const final { std::ostringstream os; os << "The index map " << index_map_->ToPythonString() << " is not bijective affine."; return os.str(); @@ -1278,7 +1278,7 @@ class NotBijectiveAffineIndexMapError : public ScheduleError { IRModule mod() const final { return mod_; } - Array LocationsOfInterest() const final { return {}; } + ffi::Array LocationsOfInterest() const final { return {}; } private: IRModule mod_; @@ -1295,12 +1295,12 @@ class IndexMapNotApplicableToBlockIterError : public ScheduleError { explicit IndexMapNotApplicableToBlockIterError(IRModule mod, Block block, IndexMap index_map) : mod_(std::move(mod)), block_(std::move(block)), index_map_(std::move(index_map)) {} - String FastErrorString() const final { + ffi::String FastErrorString() const final { return "ScheduleError: The index map can't be applied to block iters because the number of " "parameters mismatch."; } - String DetailRenderTemplate() const final { + ffi::String DetailRenderTemplate() const final { std::ostringstream os; os << "The index map " << index_map_->ToPythonString() << " can't be applied to block iters of {0} because the number of parameters mismatch. 
" @@ -1311,7 +1311,7 @@ class IndexMapNotApplicableToBlockIterError : public ScheduleError { IRModule mod() const final { return mod_; } - Array LocationsOfInterest() const final { return {block_}; } + ffi::Array LocationsOfInterest() const final { return {block_}; } private: IRModule mod_; @@ -1324,12 +1324,12 @@ class OpaqueNewIterTypeError : public ScheduleError { explicit OpaqueNewIterTypeError(IRModule mod, Block block, PrimExpr iter_value) : mod_(std::move(mod)), block_(std::move(block)), iter_value_(std::move(iter_value)) {} - String FastErrorString() const final { + ffi::String FastErrorString() const final { return "ScheduleError: Cannot detect the new block iter type because it contains more than one " "type of original iter vars."; } - String DetailRenderTemplate() const final { + ffi::String DetailRenderTemplate() const final { std::ostringstream os; os << "Cannot detect the block iter type for new iter value " << iter_value_ << " in {0} because it contains more than one type of original iter vars."; @@ -1337,7 +1337,7 @@ class OpaqueNewIterTypeError : public ScheduleError { } IRModule mod() const final { return mod_; } - Array LocationsOfInterest() const final { return {block_}; } + ffi::Array LocationsOfInterest() const final { return {block_}; } private: IRModule mod_; @@ -1348,13 +1348,13 @@ class OpaqueNewIterTypeError : public ScheduleError { void TransformBlockLayout(ScheduleState self, const StmtSRef& block_sref, const IndexMap& index_map) { const BlockNode* block_ptr = TVM_SREF_TO_BLOCK(block_sref); - const Block& block = GetRef(block_ptr); + const Block& block = ffi::GetRef(block_ptr); arith::Analyzer analyzer; AddShapeVarBounds(self, block_sref.get(), &analyzer); // Step 1: Collect outer loops and loop vars - Array loops = GetLoops(block_sref); // outer loops of the block - std::unordered_set loop_vars; // loop vars of the outer loops + ffi::Array loops = GetLoops(block_sref); // outer loops of the block + std::unordered_set loop_vars; // loop vars of the outer loops for (const StmtSRef& loop_sref : loops) { CheckLoopStartsWithZero(self, loop_sref, &analyzer); loop_vars.emplace(loop_sref->StmtAs()->loop_var.get()); @@ -1374,11 +1374,11 @@ void TransformBlockLayout(ScheduleState self, const StmtSRef& block_sref, CheckBlockHasTrivialBinding(self, block_sref); // Step 3: Collect information of block iter vars - Array block_vars; // iter_var->var of each block iter - Map block_iter_dom; // domain of block iter + ffi::Array block_vars; // iter_var->var of each block iter + ffi::Map block_iter_dom; // domain of block iter std::unordered_map block_iter_type; // iter type of block iter - Array + ffi::Array block_iter_range_array; // array of block iter extents in the same order as block iters for (const auto& iter_var : block->iter_vars) { block_vars.push_back(iter_var->var); @@ -1390,15 +1390,16 @@ void TransformBlockLayout(ScheduleState self, const StmtSRef& block_sref, // Step 4: Apply the IndexMap to block iters. IndexMapNotApplicableToBlockIterError::Check(self->mod, block, index_map); - Array transformed_block_iters = index_map->MapIndices(block_vars, &analyzer); - Array new_block_iter_range = index_map->MapShape(block_iter_range_array, &analyzer); + ffi::Array transformed_block_iters = index_map->MapIndices(block_vars, &analyzer); + ffi::Array new_block_iter_range = + index_map->MapShape(block_iter_range_array, &analyzer); // Step 5: Create the new block after transformation. // Step 5.1: Create new block iters. 
After applying the IndexMap f to block iters ax_0, ..., ax_n, // create block iter each expression in f(ax_0, ..., ax_n). - Array new_block_iters; // new block iters - Array new_block_vars; // iter_var->var of new block iters + ffi::Array new_block_iters; // new block iters + ffi::Array new_block_vars; // iter_var->var of new block iters for (size_t i = 0; i < transformed_block_iters.size(); ++i) { Var new_block_var{"v" + std::to_string(i), transformed_block_iters[i]->dtype}; new_block_vars.push_back(new_block_var); @@ -1409,7 +1410,8 @@ void TransformBlockLayout(ScheduleState self, const StmtSRef& block_sref, iter_type = DetectNewBlockIterType(transformed_block_iters[i], block_iter_type); } if (iter_type == kOpaque) { - throw OpaqueNewIterTypeError(self->mod, GetRef(block_ptr), transformed_block_iters[i]); + throw OpaqueNewIterTypeError(self->mod, ffi::GetRef(block_ptr), + transformed_block_iters[i]); } auto dtype = new_block_var.dtype(); new_block_iters.push_back(IterVar( @@ -1419,10 +1421,10 @@ void TransformBlockLayout(ScheduleState self, const StmtSRef& block_sref, // Step 5.2: Update the block body. Use the inverse map f^{-1} to replace the original block iters // in the body. - Map inverse_subst_map; + ffi::Map inverse_subst_map; // Construct the inverse map { - Array initial_ranges; + ffi::Array initial_ranges; for (const PrimExpr& extent : block_iter_range_array) { initial_ranges.push_back(Range::FromMinExtent(make_const(extent.dtype(), 0), extent)); } @@ -1433,20 +1435,20 @@ void TransformBlockLayout(ScheduleState self, const StmtSRef& block_sref, throw NotBijectiveAffineIndexMapError(self->mod, index_map); } // old block vars written in terms of new block vars - Array inversed_new_block_vars = + ffi::Array inversed_new_block_vars = inverse_index_map->MapIndices(new_block_vars, &analyzer); for (int i = 0, n = block_vars.size(); i < n; ++i) { inverse_subst_map.Set(Downcast(block_vars[i]), inversed_new_block_vars[i]); } } - Block new_block = Downcast(Substitute(GetRef(block_ptr), inverse_subst_map)); + Block new_block = Downcast(Substitute(ffi::GetRef(block_ptr), inverse_subst_map)); new_block.CopyOnWrite()->iter_vars = new_block_iters; new_block = Downcast(BlockBufferAccessSimplifier::Simplify(new_block, &analyzer)); // Step 5.3: Create outer loops for each new block iter. 
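The loop generation below follows the usual bottom-up pattern: wrap the new block realize in serial For nodes from the innermost iter outward. A standalone sketch of that pattern under the migrated container names; WrapLoopNest and its parameters are illustrative, not part of the patch:

    #include <tvm/ffi/container/array.h>
    #include <tvm/tir/stmt.h>

    // Wrap `body` in a serial loop nest; vars[0] becomes the outermost loop.
    tvm::tir::Stmt WrapLoopNest(tvm::tir::Stmt body,
                                const tvm::ffi::Array<tvm::tir::Var>& vars,
                                const tvm::ffi::Array<tvm::PrimExpr>& extents) {
      for (int i = static_cast<int>(vars.size()) - 1; i >= 0; --i) {
        body = tvm::tir::For(vars[i], 0, extents[i], tvm::tir::ForKind::kSerial,
                             std::move(body));
      }
      return body;
    }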
// Make new loop vars - Array new_loop_vars; + ffi::Array new_loop_vars; for (int i = 0; i < static_cast(new_block_iters.size()); ++i) { new_loop_vars.push_back(Var("ax" + std::to_string(i), new_block_iters[i]->var.dtype())); } @@ -1457,7 +1459,7 @@ void TransformBlockLayout(ScheduleState self, const StmtSRef& block_sref, new_block_realize->block = new_block; // Generate outer loops - Stmt body = GetRef(new_block_realize); + Stmt body = ffi::GetRef(new_block_realize); for (int i = static_cast(new_loop_vars.size()) - 1; i >= 0; --i) { body = For(Downcast(new_loop_vars[i]), 0, new_block_iter_range[i], ForKind::kSerial, std::move(body)); @@ -1474,14 +1476,14 @@ void TransformBlockLayout(ScheduleState self, const StmtSRef& block_sref, class BufferAxisSeparatorMutator : private ReplaceBufferMutator { public: static Block Mutate(const Block& scope_block, const Buffer& old_buffer, Buffer new_buffer, - Map* block_sref_reuse) { + ffi::Map* block_sref_reuse) { BufferAxisSeparatorMutator mutator(old_buffer, std::move(new_buffer), block_sref_reuse); return Downcast(mutator.VisitStmt(scope_block)); } private: BufferAxisSeparatorMutator(const Buffer& old_buffer, Buffer new_buffer, - Map* block_sref_reuse) + ffi::Map* block_sref_reuse) : ReplaceBufferMutator(old_buffer, new_buffer, block_sref_reuse) {} MatchBufferRegion VisitMatchBufferRegion(const MatchBufferRegion& match_buffer) final { @@ -1493,8 +1495,8 @@ class BufferAxisSeparatorMutator : private ReplaceBufferMutator { if (new_target_buffer->shape.size() == new_source_buffer->shape.size()) { new_target_buffer.CopyOnWrite()->axis_separators = new_source_buffer->axis_separators; } else { - new_target_buffer.CopyOnWrite()->axis_separators = - Array(new_source_buffer->axis_separators.size(), IntImm(DataType::Int(32), 0)); + new_target_buffer.CopyOnWrite()->axis_separators = ffi::Array( + new_source_buffer->axis_separators.size(), IntImm(DataType::Int(32), 0)); LOG(WARNING) << "Buffer view " << new_target_buffer << " has different dimensionality than backing buffer " << new_source_buffer << ". The `axis_separators` for " << new_target_buffer << "." @@ -1509,10 +1511,11 @@ class BufferAxisSeparatorMutator : private ReplaceBufferMutator { }; void SetAxisSeparator(ScheduleState self, const StmtSRef& block_sref, int buffer_index, - BufferIndexType buffer_index_type, const Array& axis_separators) { + BufferIndexType buffer_index_type, + const ffi::Array& axis_separators) { const BlockNode* block_ptr = TVM_SREF_TO_BLOCK(block_sref); Buffer old_buffer = - GetNthAccessBuffer(self, GetRef(block_ptr), buffer_index, buffer_index_type); + GetNthAccessBuffer(self, ffi::GetRef(block_ptr), buffer_index, buffer_index_type); auto [defining_site_sref, is_alloc] = GetBufferDefiningSite(block_sref, old_buffer); if (defining_site_sref.defined() && !is_alloc) { throw BufferIsSubregionError(self->mod, old_buffer); @@ -1527,11 +1530,11 @@ void SetAxisSeparator(ScheduleState self, const StmtSRef& block_sref, int buffer Buffer new_buffer = old_buffer; new_buffer.CopyOnWrite()->axis_separators = axis_separators; - Map block_sref_reuse; + ffi::Map block_sref_reuse; // Step 2: Rewrite alloc_buffer of the block or buffer_map of the PrimFunc. 
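Step 2 below relies on the clone-and-edit idiom that this patch re-spells as ffi::make_object: copy the node, mutate only the copy, rebuild the reference, and record the old-to-new block mapping for sref reuse. A minimal sketch of the idiom on a For node; the helper name is illustrative:

    #include <tvm/ffi/memory.h>
    #include <tvm/tir/stmt.h>

    // The original `loop` is never mutated in place; only the fresh copy changes.
    tvm::tir::For WithUnrolledKind(const tvm::tir::For& loop) {
      auto n = tvm::ffi::make_object<tvm::tir::ForNode>(*loop.get());
      n->kind = tvm::tir::ForKind::kUnrolled;
      return tvm::tir::For(n);
    }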
- Block new_scope_block = BufferAxisSeparatorMutator::Mutate(GetRef(scope_block), old_buffer, - new_buffer, &block_sref_reuse); + Block new_scope_block = BufferAxisSeparatorMutator::Mutate( + ffi::GetRef(scope_block), old_buffer, new_buffer, &block_sref_reuse); if (!defining_site_sref.defined()) { // mutate buffer_map of the PrimFunc GlobalVar g_var; @@ -1566,16 +1569,17 @@ struct TransformLayoutTraits : public UnpackedInstTraits static void UnpackedApplyToSchedule(Schedule sch, BlockRV block_rv, IndexMap index_map, Integer buffer_index, Integer buffer_index_type, - Optional pad_value, + ffi::Optional pad_value, Bool assume_injective_transform) { return sch->TransformLayout(block_rv, buffer_index.IntValue(), static_cast(buffer_index_type->value), index_map, pad_value, assume_injective_transform.operator bool()); } - static String UnpackedAsPython(Array outputs, String block_rv, IndexMap index_map, - Integer buffer_index, Integer buffer_index_type, - Optional pad_value, Bool assume_injective_transform) { + static ffi::String UnpackedAsPython(ffi::Array outputs, ffi::String block_rv, + IndexMap index_map, Integer buffer_index, + Integer buffer_index_type, ffi::Optional pad_value, + Bool assume_injective_transform) { PythonAPICall py("transform_layout"); py.Input("block", block_rv); @@ -1591,13 +1595,13 @@ struct TransformLayoutTraits : public UnpackedInstTraits } public: - static ObjectRef AttrsAsJSON(const Array& attrs) { - Array attrs_record; + static ObjectRef AttrsAsJSON(const ffi::Array& attrs) { + ffi::Array attrs_record; attrs_record.reserve(kNumAttrs); attrs_record.push_back(attrs[0]); attrs_record.push_back(attrs[1]); if (attrs[2] != nullptr) { - attrs_record.push_back(String(::tvm::SaveJSON(attrs[2]))); + attrs_record.push_back(ffi::String(::tvm::SaveJSON(attrs[2]))); } else { attrs_record.push_back(attrs[2]); } @@ -1605,13 +1609,13 @@ struct TransformLayoutTraits : public UnpackedInstTraits return attrs_record; } - static Array AttrsFromJSON(const ObjectRef& attrs_record_) { - Array attrs_record = Downcast>(attrs_record_); - Array attrs; + static ffi::Array AttrsFromJSON(const ObjectRef& attrs_record_) { + ffi::Array attrs_record = Downcast>(attrs_record_); + ffi::Array attrs; attrs.push_back(attrs_record[0]); attrs.push_back(attrs_record[1]); if (attrs_record[2] != nullptr) { - attrs.push_back(::tvm::LoadJSON(Downcast(attrs_record[2]))); + attrs.push_back(::tvm::LoadJSON(Downcast(attrs_record[2]))); } else { attrs.push_back(attrs_record[2]); } @@ -1636,7 +1640,8 @@ struct TransformBlockLayoutTraits : public UnpackedInstTraitsTransformBlockLayout(block_rv, index_map); } - static String UnpackedAsPython(Array outputs, String block_rv, IndexMap index_map) { + static ffi::String UnpackedAsPython(ffi::Array outputs, ffi::String block_rv, + IndexMap index_map) { PythonAPICall py("transform_block_layout"); py.Input("block", block_rv); py.Input("index_map", index_map->ToPythonString()); @@ -1644,17 +1649,17 @@ struct TransformBlockLayoutTraits : public UnpackedInstTraits& attrs) { - Array attrs_record; + static ObjectRef AttrsAsJSON(const ffi::Array& attrs) { + ffi::Array attrs_record; attrs_record.reserve(kNumAttrs); - attrs_record.push_back(String(::tvm::SaveJSON(attrs[0]))); + attrs_record.push_back(ffi::String(::tvm::SaveJSON(attrs[0]))); return attrs_record; } - static Array AttrsFromJSON(const ObjectRef& attrs_record_) { - Array attrs_record = Downcast>(attrs_record_); - Array attrs; - attrs.push_back(::tvm::LoadJSON(Downcast(attrs_record[0]))); + static ffi::Array 
AttrsFromJSON(const ObjectRef& attrs_record_) { + ffi::Array attrs_record = Downcast>(attrs_record_); + ffi::Array attrs; + attrs.push_back(::tvm::LoadJSON(Downcast(attrs_record[0]))); return attrs; } @@ -1672,14 +1677,16 @@ struct SetAxisSeparatorTraits : public UnpackedInstTraits axis_separators) { + Integer buffer_index_type, + ffi::Array axis_separators) { return sch->SetAxisSeparator(block_rv, buffer_index.IntValue(), static_cast(buffer_index_type->value), axis_separators); } - static String UnpackedAsPython(Array outputs, String block_rv, Integer buffer_index, - Integer buffer_index_type, Array axis_separators) { + static ffi::String UnpackedAsPython(ffi::Array outputs, ffi::String block_rv, + Integer buffer_index, Integer buffer_index_type, + ffi::Array axis_separators) { PythonAPICall py("set_axis_separator"); py.Input("block", block_rv); diff --git a/src/tir/schedule/primitive/loop_transformation.cc b/src/tir/schedule/primitive/loop_transformation.cc index 7baf4e98b775..b2c64e65e568 100644 --- a/src/tir/schedule/primitive/loop_transformation.cc +++ b/src/tir/schedule/primitive/loop_transformation.cc @@ -46,14 +46,15 @@ class BlockPredicateAppender : public StmtMutator { /*! \brief Substitute vars and collect the reuse mapping of opaque blocks */ class SubstituteVarAndCollectOpaqueBlock : public StmtExprMutator { public: - explicit SubstituteVarAndCollectOpaqueBlock(std::function(const Var&)> vmap, - Map* opaque_blocks) + explicit SubstituteVarAndCollectOpaqueBlock( + std::function(const Var&)> vmap, + ffi::Map* opaque_blocks) : vmap_(vmap), opaque_blocks_(opaque_blocks) {} private: PrimExpr VisitExpr_(const VarNode* op) final { - Var var = GetRef(op); - if (Optional ret = vmap_(var)) { + Var var = ffi::GetRef(op); + if (ffi::Optional ret = vmap_(var)) { return tvm::cast(var.dtype(), ret.value()); } else { return var; @@ -69,23 +70,24 @@ class SubstituteVarAndCollectOpaqueBlock : public StmtExprMutator { } /*! \brief The substitute function */ - std::function(const Var&)> vmap_; + std::function(const Var&)> vmap_; /*! \brief The reuse mapping of opaque blocks */ - Map* opaque_blocks_; + ffi::Map* opaque_blocks_; }; /*! 
\brief Simplify the binding of block realize and update the opaque block reuse mapping */ class IterMapSimplifyBlockBinding : public StmtExprMutator { public: - explicit IterMapSimplifyBlockBinding(ffi::MapObj* opaque_blocks, Map loop_var2extent, + explicit IterMapSimplifyBlockBinding(ffi::MapObj* opaque_blocks, + ffi::Map loop_var2extent, bool preserve_unit_iters) : opaque_blocks_(opaque_blocks), loop_var2extent_(loop_var2extent), preserve_unit_iters_(preserve_unit_iters) {} - static For SimplifyBindings(Stmt stmt, const Array& loop_srefs, + static For SimplifyBindings(Stmt stmt, const ffi::Array& loop_srefs, ffi::MapObj* opaque_blocks, bool preserve_unit_iters) { - Map loop_var2extent; + ffi::Map loop_var2extent; for (const StmtSRef& sref : loop_srefs) { const ForNode* loop = TVM_SREF_TO_FOR(sref); loop_var2extent.Set(loop->loop_var, Range::FromMinExtent(loop->min, loop->extent)); @@ -115,7 +117,7 @@ class IterMapSimplifyBlockBinding : public StmtExprMutator { } return realize; } - Array v = + ffi::Array v = arith::IterMapSimplify(/*indices=*/op->iter_values, /*input_iters=*/loop_var2extent_, /*input_pred=*/op->predicate, @@ -123,7 +125,7 @@ class IterMapSimplifyBlockBinding : public StmtExprMutator { /*analyzer=*/&analzyer_, /*simplify_trivial_iterators=*/!preserve_unit_iters_); if (v.same_as(op->iter_values)) { - return GetRef(op); + return ffi::GetRef(op); } else { ObjectPtr n = CopyOnWrite(op); n->iter_values = std::move(v); @@ -134,7 +136,7 @@ class IterMapSimplifyBlockBinding : public StmtExprMutator { /*! \brief The reuse mapping */ ffi::MapObj* opaque_blocks_; /*! \brief The range of loops */ - Map loop_var2extent_; + ffi::Map loop_var2extent_; /*! \brief Internal analyzer */ arith::Analyzer analzyer_; /*! \brief Whether or not to simplify unit iterators */ @@ -161,11 +163,12 @@ class BlockPropertyError : public ScheduleError { void VisitStmt_(const BlockNode* op) final { for (const IterVar& iter_var : op->iter_vars) { if (iter_var->iter_type != kDataPar && iter_var->iter_type != kCommReduce) { - throw BlockPropertyError(state_->mod, GetRef(op)); + throw BlockPropertyError(state_->mod, ffi::GetRef(op)); } - Optional high_exclusive = - top_->parent ? GetRef(top_->parent) : Optional(std::nullopt); - CheckPartialAffineBinding(state_, GetRef(op), high_exclusive); + ffi::Optional high_exclusive = top_->parent + ? 
+                                                     ? ffi::GetRef<StmtSRef>(top_->parent)
+                                                     : ffi::Optional<StmtSRef>(std::nullopt);
+        CheckPartialAffineBinding(state_, ffi::GetRef<Block>(op), high_exclusive);
       }
     }
     const ScheduleState& state_;
@@ -173,23 +176,23 @@ class BlockPropertyError : public ScheduleError {
     };
     BlockIterTypeAndAffineBindingChecker checker(self, top);
-    checker(GetRef<Stmt>(sref->stmt));
+    checker(ffi::GetRef<Stmt>(sref->stmt));
   }

   explicit BlockPropertyError(IRModule mod, Block block) : mod_(mod), block_(std::move(block)) {}

-  String FastErrorString() const final {
+  ffi::String FastErrorString() const final {
     return "ScheduleError: The block under the loops to be reordered have block iter type other "
            "than data-parallel or reduction";
   }

-  String DetailRenderTemplate() const final {
+  ffi::String DetailRenderTemplate() const final {
     return "The block {0} under the loops to be reordered have block iter type other than "
            "data-parallel or reduction";
   }

   IRModule mod() const final { return mod_; }
-  Array<ObjectRef> LocationsOfInterest() const final { return {block_}; }
+  ffi::Array<ObjectRef> LocationsOfInterest() const final { return {block_}; }

   IRModule mod_;
   Block block_;
@@ -200,17 +203,17 @@ class HasAnnotationOrThreadBindingError : public ScheduleError {
   explicit HasAnnotationOrThreadBindingError(IRModule mod, For loop)
       : mod_(mod), loop_(std::move(loop)) {}

-  String FastErrorString() const final {
+  ffi::String FastErrorString() const final {
     return "ScheduleError: The primitive can't be applied because the loop has annotation or "
            "thread binding";
   }

-  String DetailRenderTemplate() const final {
+  ffi::String DetailRenderTemplate() const final {
     return "The primitive can't be applied because the loop {0} has annotation or thread binding";
   }

   IRModule mod() const final { return mod_; }
-  Array<ObjectRef> LocationsOfInterest() const final { return {loop_}; }
+  ffi::Array<ObjectRef> LocationsOfInterest() const final { return {loop_}; }

   IRModule mod_;
   For loop_;
@@ -221,17 +224,17 @@ class OuterNotInnerParent : public ScheduleError {
   explicit OuterNotInnerParent(IRModule mod, For outer, For inner)
       : mod_(mod), outer_(std::move(outer)), inner_(std::move(inner)) {}

-  String FastErrorString() const final {
+  ffi::String FastErrorString() const final {
     return "ScheduleError: The outer loop is not the parent of the inner loop";
   }

-  String DetailRenderTemplate() const final {
+  ffi::String DetailRenderTemplate() const final {
     return "The loops can't be fused because the outer loop {0} is not the parent of the inner "
            "loop {1}";
   }

   IRModule mod() const final { return mod_; }
-  Array<ObjectRef> LocationsOfInterest() const final { return {outer_, inner_}; }
+  ffi::Array<ObjectRef> LocationsOfInterest() const final { return {outer_, inner_}; }

   IRModule mod_;
   For outer_;
@@ -243,17 +246,17 @@ class NotOnlyChildError : public ScheduleError {
   explicit NotOnlyChildError(IRModule mod, For outer, For inner)
       : mod_(mod), outer_(std::move(outer)), inner_(std::move(inner)) {}

-  String FastErrorString() const final {
+  ffi::String FastErrorString() const final {
     return "ScheduleError: The inner loop is not the only child of outer loop";
   }

-  String DetailRenderTemplate() const final {
+  ffi::String DetailRenderTemplate() const final {
     return "The loops can't be fused because the inner loop {1} is not the only child of outer "
            "loop {0}.";
   }

   IRModule mod() const final { return mod_; }
-  Array<ObjectRef> LocationsOfInterest() const final { return {outer_, inner_}; }
+  ffi::Array<ObjectRef> LocationsOfInterest() const final { return {outer_, inner_}; }

   IRModule mod_;
   For outer_;
@@ -264,16 +267,16 @@ class NotSingleInferFactorError : public ScheduleError {
  public:
   explicit NotSingleInferFactorError(IRModule mod) : mod_(mod) {}

-  String FastErrorString() const final {
+  ffi::String FastErrorString() const final {
     return "ScheduleError: only one factor can be specified as -1 or none";
   }

-  String DetailRenderTemplate() const final {
+  ffi::String DetailRenderTemplate() const final {
     return "Only one factor can be specified as -1 or none";
   }

   IRModule mod() const final { return mod_; }
-  Array<ObjectRef> LocationsOfInterest() const final { return {}; }
+  ffi::Array<ObjectRef> LocationsOfInterest() const final { return {}; }

   IRModule mod_;
 };
@@ -282,17 +285,17 @@ class WrongFactorProductError : public ScheduleError {
  public:
   explicit WrongFactorProductError(IRModule mod, For loop) : mod_(mod), loop_(std::move(loop)) {}

-  String FastErrorString() const final {
+  ffi::String FastErrorString() const final {
     return "ScheduleError: The product of factors is not larger than or equal to the extent of "
            "loop";
   }

-  String DetailRenderTemplate() const final {
+  ffi::String DetailRenderTemplate() const final {
     return "The product of factors is not larger than or equal to the extent of loop {0}";
   }

   IRModule mod() const final { return mod_; }
-  Array<ObjectRef> LocationsOfInterest() const final { return {loop_}; }
+  ffi::Array<ObjectRef> LocationsOfInterest() const final { return {loop_}; }

   IRModule mod_;
   For loop_;
@@ -302,16 +305,16 @@ class LoopMultiAppearanceError : public ScheduleError {
  public:
   explicit LoopMultiAppearanceError(IRModule mod, For loop) : mod_(mod), loop_(std::move(loop)) {}

-  String FastErrorString() const final {
+  ffi::String FastErrorString() const final {
     return "ScheduleError: Some loop appears in the input array for multiple times.";
   }

-  String DetailRenderTemplate() const final {
+  ffi::String DetailRenderTemplate() const final {
     return "Loop {0} appears in the input array for multiple times.";
   }

   IRModule mod() const final { return mod_; }
-  Array<ObjectRef> LocationsOfInterest() const final { return {loop_}; }
+  ffi::Array<ObjectRef> LocationsOfInterest() const final { return {loop_}; }

   IRModule mod_;
   For loop_;
@@ -321,12 +324,14 @@ class LoopsNotAChainError : public ScheduleError {
  public:
   enum class ProblemKind { kNotUnderAScope, kHaveNonSingleBranchStmt };

-  explicit LoopsNotAChainError(IRModule mod, Optional<Stmt> problematic_loop, ProblemKind kind)
+  explicit LoopsNotAChainError(IRModule mod, ffi::Optional<Stmt> problematic_loop, ProblemKind kind)
       : mod_(mod), problematic_loop_(std::move(problematic_loop)), kind_(kind) {}

-  String FastErrorString() const final { return "ScheduleError: the loops are not in a chain"; }
+  ffi::String FastErrorString() const final {
+    return "ScheduleError: the loops are not in a chain";
+  }

-  String DetailRenderTemplate() const final {
+  ffi::String DetailRenderTemplate() const final {
     std::stringstream ss;
     ss << "The loops are not in a chain because";
     if (kind_ == ProblemKind::kNotUnderAScope) {
@@ -338,7 +343,7 @@ class LoopsNotAChainError : public ScheduleError {
   }

   IRModule mod() const final { return mod_; }
-  Array<ObjectRef> LocationsOfInterest() const final {
+  ffi::Array<ObjectRef> LocationsOfInterest() const final {
     if (kind_ == ProblemKind::kNotUnderAScope) {
       return {};
     } else {
@@ -348,17 +353,17 @@ class LoopsNotAChainError : public ScheduleError {
   }

   IRModule mod_;
-  Optional<Stmt> problematic_loop_;
+  ffi::Optional<Stmt> problematic_loop_;
   ProblemKind kind_;
 };

 class DependentLoopError : public ScheduleError {
  public:
   enum class PrimitiveKind { kFuse, kReorder };
-  explicit DependentLoopError(IRModule mod, For loop, String inner_var, PrimitiveKind kind)
+  explicit DependentLoopError(IRModule mod, For loop, ffi::String inner_var, PrimitiveKind kind)
       : mod_(mod), loop_(std::move(loop)), inner_var_(std::move(inner_var)), kind_(kind) {}

-  String FastErrorString() const final {
+  ffi::String FastErrorString() const final {
     if (kind_ == PrimitiveKind::kReorder) {
       return "ScheduleError: An outer loop's `min` or `extent` is dependent on an inner loop "
              "in the new order";
     } else {
     }
   }
@@ -367,7 +372,7 @@ class DependentLoopError : public ScheduleError {
-  String DetailRenderTemplate() const final {
+  ffi::String DetailRenderTemplate() const final {
     if (kind_ == PrimitiveKind::kReorder) {
       return "Outer Loop {0}'s `min` or `extent` is dependent on an inner loop " + inner_var_ +
              " in the new order";
@@ -377,16 +382,17 @@ class DependentLoopError : public ScheduleError {
   }

   IRModule mod() const final { return mod_; }
-  Array<ObjectRef> LocationsOfInterest() const final { return {loop_}; }
+  ffi::Array<ObjectRef> LocationsOfInterest() const final { return {loop_}; }

   IRModule mod_;
   For loop_;
-  String inner_var_;
+  ffi::String inner_var_;
   PrimitiveKind kind_;
 };

-Array<StmtSRef> Split(ScheduleState self, const StmtSRef& loop_sref, const Array<PrimExpr>& factors,
-                      bool preserve_unit_iters, bool disable_predication) {
+ffi::Array<StmtSRef> Split(ScheduleState self, const StmtSRef& loop_sref,
+                           const ffi::Array<PrimExpr>& factors, bool preserve_unit_iters,
+                           bool disable_predication) {
   // Invariance
   // - The total repeat number has not changed for each direct child block with updating predicate.
   // - The execution order has not changed. (The block executes with the same args and the same
@@ -394,7 +400,7 @@ Array<StmtSRef> Split(ScheduleState self, const StmtSRef& loop_sref, const Array
   // Step 1. Check correctness
   const ForNode* loop = TVM_SREF_TO_FOR(loop_sref);
   if (!loop->annotations.empty() || loop->thread_binding.defined()) {
-    throw HasAnnotationOrThreadBindingError(self->mod, GetRef<For>(loop));
+    throw HasAnnotationOrThreadBindingError(self->mod, ffi::GetRef<For>(loop));
   }
   // Currently, loops not starting with 0 are not supported
   arith::Analyzer analyzer;
@@ -420,10 +426,10 @@ Array<StmtSRef> Split(ScheduleState self, const StmtSRef& loop_sref, const Array
     analyzer.Bind(var, Range::FromMinExtent(make_const(dtype, 0), tvm::cast(dtype, factor)));
     new_loop_vars.emplace_back(std::move(var));
   }
-  Map<Block, Block> opaque_block_reuse;
+  ffi::Map<Block, Block> opaque_block_reuse;
   Stmt new_stmt = loop->body;
   new_stmt = SubstituteVarAndCollectOpaqueBlock(
-      [&](const Var& v) -> Optional<PrimExpr> {
+      [&](const Var& v) -> ffi::Optional<PrimExpr> {
         if (v.same_as(loop->loop_var)) {
           return substitute_value;
         } else {
@@ -444,7 +450,7 @@ Array<StmtSRef> Split(ScheduleState self, const StmtSRef& loop_sref, const Array
                                                            opaque_block_reuse.CopyOnWrite(),
                                                            preserve_unit_iters);
   self->Replace(loop_sref, new_stmt, opaque_block_reuse);
-  Array<StmtSRef> result_srefs;
+  ffi::Array<StmtSRef> result_srefs;
   result_srefs.reserve(n);
   for (int i = 0; i < n; i++) {
     result_srefs.push_back(self->stmt2ref.at(new_stmt.get()));
@@ -458,7 +464,7 @@ class BufferIndicesMapExtractor : public StmtExprVisitor {
  public:
   explicit BufferIndicesMapExtractor(Var loop_var) : loop_var_(loop_var) {}

-  static Map<String, Array<String>> Extract(Var loop_var, Block& block) {
+  static ffi::Map<ffi::String, ffi::Array<ffi::String>> Extract(Var loop_var, Block& block) {
     BufferIndicesMapExtractor extractor(loop_var);
     extractor(std::move(block->body));
     return extractor.buffer_indices_map;
@@ -466,7 +472,7 @@ class BufferIndicesMapExtractor : public StmtExprVisitor {
  private:
   void VisitStmt_(const BufferStoreNode* store) final {
-    Array<String> indices;
+    ffi::Array<ffi::String> indices;
     bool check_ = false;
     for (size_t i = 0; i < store->indices.size(); i++) {
       const VarNode* var_node = store->indices[i].as<VarNode>();
@@ -482,7 +488,7 @@ class BufferIndicesMapExtractor : public StmtExprVisitor {
   }

   void VisitExpr_(const BufferLoadNode* load) final {
-    Array<String> indices;
+    ffi::Array<ffi::String> indices;
     bool check_ = false;
     for (size_t i = 0; i < load->indices.size(); i++) {
       const VarNode* var_node = load->indices[i].as<VarNode>();
@@ -500,21 +506,21 @@ class BufferIndicesMapExtractor : public StmtExprVisitor {
   void VisitStmt_(const BlockNode* op) final { StmtVisitor::VisitStmt_(op); }

   Var loop_var_;
-  Map<String, Array<String>> buffer_indices_map;
+  ffi::Map<ffi::String, ffi::Array<ffi::String>> buffer_indices_map;
 };

-Array<BufferRegion> MutateBufferRegion(Map<String, Array<String>> buffer_indices_map,
-                                       Map<String, Range> index_range_map,
-                                       Array<BufferRegion> region_arr) {
+ffi::Array<BufferRegion> MutateBufferRegion(
+    ffi::Map<ffi::String, ffi::Array<ffi::String>> buffer_indices_map,
+    ffi::Map<ffi::String, Range> index_range_map, ffi::Array<BufferRegion> region_arr) {
   // Update the region with new Ranges and return new BufferRegion
-  Array<BufferRegion> new_region_arr =
+  ffi::Array<BufferRegion> new_region_arr =
       MutateArray(region_arr, [&buffer_indices_map, &index_range_map](const BufferRegion& region) {
         BufferRegion new_region = region;
         auto it = buffer_indices_map.find(new_region->buffer->name);
         if (it == buffer_indices_map.end()) return new_region;
-        Array<String> old_indices = buffer_indices_map[new_region->buffer->name];
-        Array<Range> new_ranges;
+        ffi::Array<ffi::String> old_indices = buffer_indices_map[new_region->buffer->name];
+        ffi::Array<Range> new_ranges;
         for (size_t i = 0; i < old_indices.size(); i++) {
           new_ranges.push_back(index_range_map[old_indices[i]]);
         }
@@ -543,7 +549,7 @@ class BlockMutator : public StmtExprMutator {
     Var iter_var_ = new_block->iter_vars[inner_iter_var_index]->var;
     inner_iter_var_index = -1;
     // As we are working on cloned block, we need to create new instances of iter_var
-    Array<IterVar> new_iter_vars =
+    ffi::Array<IterVar> new_iter_vars =
         MutateArray(new_block->iter_vars, [this, &iter_var_](const IterVar& iter) {
           auto dtype = iter->var.dtype();
           // Create new Var instance for each IterVar
@@ -565,29 +571,29 @@ class BlockMutator : public StmtExprMutator {
     }

     // Get the (iter_var, new Range) map
-    Map<String, Range> index_range_map;
+    ffi::Map<ffi::String, Range> index_range_map;
     for (size_t i = 0; i < new_block->iter_vars.size(); i++) {
       IterVar iter = new_block->iter_vars[i];
       index_range_map.Set(iter->var->name_hint, iter->dom);
     }

     // Get the (Buffer, indices) map
-    Map<String, Array<String>> buffer_indices_map =
+    ffi::Map<ffi::String, ffi::Array<ffi::String>> buffer_indices_map =
         BufferIndicesMapExtractor::Extract(new_loop_var_, new_block);
-    Array<BufferRegion> new_writes =
+    ffi::Array<BufferRegion> new_writes =
         MutateBufferRegion(buffer_indices_map, index_range_map, new_block->writes);
     if (!new_block->writes.same_as(new_writes)) {
       // Update the writes with new_writes
       new_block.CopyOnWrite()->writes = std::move(new_writes);
     }

-    Array<BufferRegion> new_reads =
+    ffi::Array<BufferRegion> new_reads =
         MutateBufferRegion(buffer_indices_map, index_range_map, new_block->reads);
     if (!new_block->reads.same_as(new_reads)) {
       // Update the reads with new_reads
       new_block.CopyOnWrite()->reads = std::move(new_reads);
     }

-    Map<Var, Var> var_map;
+    ffi::Map<Var, Var> var_map;
     for (size_t i = 0; i < new_block->iter_vars.size(); i++) {
       var_map.Set(_op->iter_vars[i]->var, new_block->iter_vars[i]->var);
     }
@@ -598,7 +604,7 @@ class BlockMutator : public StmtExprMutator {
   }

   Stmt VisitStmt_(const BlockRealizeNode* realize) final {
-    Array<PrimExpr> iter_values = realize->iter_values;
+    ffi::Array<PrimExpr> iter_values = realize->iter_values;
     for (size_t i = 0; i < iter_values.size(); i++) {
       if (new_loop_var_.same_as(iter_values[i])) {
         // Get the iter_var index corresponding to loop_var iter_value index
@@ -627,7 +633,7 @@ class BlockMutator : public StmtExprMutator {
   int inner_iter_var_index = -1;
 };

-const String get_block_name(Stmt loop_body) {
+const ffi::String get_block_name(Stmt loop_body) {
   const BlockRealizeNode* blk_realize = loop_body.as<BlockRealizeNode>();
   if (blk_realize == nullptr) {
     return get_block_name(loop_body.as<ForNode>()->body);
@@ -635,11 +641,11 @@ const String get_block_name(Stmt loop_body) {
   return blk_realize->block->name_hint;
 }

-Array<StmtSRef> LoopPartition(ScheduleState self, const StmtSRef& loop_sref,
-                              const Array<PrimExpr>& factors, bool preserve_unit_iters) {
+ffi::Array<StmtSRef> LoopPartition(ScheduleState self, const StmtSRef& loop_sref,
+                                   const ffi::Array<PrimExpr>& factors, bool preserve_unit_iters) {
   const ForNode* loop = TVM_SREF_TO_FOR(loop_sref);
   if (!loop->annotations.empty() || loop->thread_binding.defined()) {
-    throw HasAnnotationOrThreadBindingError(self->mod, GetRef<For>(loop));
+    throw HasAnnotationOrThreadBindingError(self->mod, ffi::GetRef<For>(loop));
   }

   arith::Analyzer analyzer;
@@ -653,12 +659,12 @@ Array<StmtSRef> LoopPartition(ScheduleState self, const StmtSRef& loop_sref,
     dtype = DataType::Int(bits);
   }

-  String block_name = get_block_name(loop->body) + "_" + loop->loop_var->name_hint;
+  ffi::String block_name = get_block_name(loop->body) + "_" + loop->loop_var->name_hint;
   int n = factors.size();
   PrimExpr min_value = loop->min;
   PrimExpr extent_value;

-  Array<Stmt> block_partitions;
+  ffi::Array<Stmt> block_partitions;
   block_partitions.reserve(n);

   // Iterate over each pair of factors and create partition
@@ -696,7 +702,7 @@ Array<StmtSRef> LoopPartition(ScheduleState self, const StmtSRef& loop_sref,
   self->block_info[scope_root].affine_binding = scope_block_affine_binding;

   // Collect the SRef for each partitioned loop and return
-  Array<StmtSRef> partition_srefs;
+  ffi::Array<StmtSRef> partition_srefs;
   partition_srefs.reserve(n);
   for (int i = 0; i < n; i++) {
     StmtSRef partition_loop_sref =
@@ -717,11 +723,11 @@ class LoopReconstructor : private StmtMutator {
   /*!
    * \brief Create the new nest loops induced by the given loops
    */
   void MakeNewLoop() {
-    Array<Var> new_loop_vars;
-    Array<PrimExpr> new_loop_extents;
-    Array<Stmt> new_stmts;
+    ffi::Array<Var> new_loop_vars;
+    ffi::Array<PrimExpr> new_loop_extents;
+    ffi::Array<Stmt> new_stmts;
     for (size_t i = 0; i < loops_.size(); i++) {
-      Map<Var, PrimExpr> var_map;
+      ffi::Map<Var, PrimExpr> var_map;
       for (size_t j = 0; j < loops_[i].size(); j++) {
         if (i == 0) {
           Var merged_var = loops_[i][j]->loop_var.copy_with_suffix("_m");
@@ -748,15 +754,16 @@ class LoopReconstructor : private StmtMutator {

  private:
   Stmt VisitStmt_(const BlockNode* block) final {
     if (block != scope_root_.get()) {
-      return GetRef<Block>(block);
+      return ffi::GetRef<Block>(block);
     }
     return StmtMutator::VisitStmt_(block);
   }

   Stmt VisitStmt_(const ForNode* loop) final {
-    if (GetRef<For>(loop) == need_remove_loop_.back()) {
+    if (ffi::GetRef<For>(loop) == need_remove_loop_.back()) {
       return new_outer_loop_;
-    } else if (std::count(need_remove_loop_.begin(), need_remove_loop_.end(), GetRef<For>(loop))) {
+    } else if (std::count(need_remove_loop_.begin(), need_remove_loop_.end(),
+                          ffi::GetRef<For>(loop))) {
       return Evaluate(0);
     }
     return StmtMutator::VisitStmt_(loop);
@@ -764,7 +771,7 @@ class LoopReconstructor : private StmtMutator {

   Stmt VisitStmt_(const SeqStmtNode* seq_stmt) final {
     auto ret = Downcast<SeqStmt>(StmtMutator::VisitSeqStmt_(seq_stmt, true));
-    Array<Stmt> filtered;
+    ffi::Array<Stmt> filtered;
     for (Stmt stmt : ret->seq) {
       if (!is_no_op(stmt)) {
         filtered.push_back(std::move(stmt));
@@ -793,7 +800,7 @@ class LoopReconstructor : private StmtMutator {
   std::vector<For> need_remove_loop_;
 };

-StmtSRef Merge(ScheduleState self, const Array<StmtSRef>& loop_srefs) {
+StmtSRef Merge(ScheduleState self, const ffi::Array<StmtSRef>& loop_srefs) {
   // Invariance
   // - The total repeat number has not changed for each direct child block.
   // - The execution order has not changed. (The block executes with the same
@@ -813,10 +820,10 @@ StmtSRef Merge(ScheduleState self, const Array<StmtSRef>& loop_srefs) {
     for (auto p = sref.get(); p != lca.get(); p = p->parent) {
       if (auto loop = p->StmtAs<ForNode>()) {
         if (!loop->annotations.empty() || loop->thread_binding.defined()) {
-          throw HasAnnotationOrThreadBindingError(self->mod, GetRef<For>(loop));
+          throw HasAnnotationOrThreadBindingError(self->mod, ffi::GetRef<For>(loop));
         }
-        CheckLoopStartsWithZero(self, GetRef<StmtSRef>(p), &analyzer);
-        nest_loop_i_loops.push_back(GetRef<For>(loop));
+        CheckLoopStartsWithZero(self, ffi::GetRef<StmtSRef>(p), &analyzer);
+        nest_loop_i_loops.push_back(ffi::GetRef<For>(loop));
         nest_loop_i_extents.push_back(loop->extent);
       }
     }
@@ -824,7 +831,7 @@ StmtSRef Merge(ScheduleState self, const Array<StmtSRef>& loop_srefs) {
     const ForNode* outer_loop = nullptr;
     for (auto iter = nest_loop_i_loops.rbegin(); iter != nest_loop_i_loops.rend(); ++iter) {
       if (outer_loop && !outer_loop->body.same_as(*iter)) {
-        throw NotOnlyChildError(self->mod, GetRef<For>(outer_loop), *iter);
+        throw NotOnlyChildError(self->mod, ffi::GetRef<For>(outer_loop), *iter);
       }
       outer_loop = (*iter).get();
     }
@@ -853,7 +860,7 @@ StmtSRef Merge(ScheduleState self, const Array<StmtSRef>& loop_srefs) {
     }
   }
   // Step 2. Create merged loops and replace the original loops
-  Block scope_root = GetRef<Block>(scope_root_sref->StmtAs<BlockNode>());
+  Block scope_root = ffi::GetRef<Block>(scope_root_sref->StmtAs<BlockNode>());
   LoopReconstructor reconstructor(scope_root, lca_nest_loops);
   reconstructor.MakeNewLoop();
   Block new_scope_root = Downcast<Block>(reconstructor(scope_root));
@@ -862,7 +869,8 @@ StmtSRef Merge(ScheduleState self, const Array<StmtSRef>& loop_srefs) {
   return self->stmt2ref.at(reconstructor.new_inner_loop_.get());
 }

-StmtSRef Fuse(ScheduleState self, const Array<StmtSRef>& loop_srefs, bool preserve_unit_iters) {
+StmtSRef Fuse(ScheduleState self, const ffi::Array<StmtSRef>& loop_srefs,
+              bool preserve_unit_iters) {
   // Invariance
   // - The total repeat number has not changed for each direct child block.
   // - The execution order has not changed. (The block executes with the same
@@ -877,14 +885,14 @@ StmtSRef Fuse(ScheduleState self, const Array<StmtSRef>& loop_srefs, bool preser
   for (const StmtSRef& sref : loop_srefs) {
     const ForNode* loop = TVM_SREF_TO_FOR(sref);
     if (!loop->annotations.empty() || loop->thread_binding.defined()) {
-      throw HasAnnotationOrThreadBindingError(self->mod, GetRef<For>(loop));
+      throw HasAnnotationOrThreadBindingError(self->mod, ffi::GetRef<For>(loop));
     }
     if (outer_loop_sref.defined()) {
       if (sref->parent != outer_loop_sref.get()) {
-        throw OuterNotInnerParent(self->mod, GetRef<For>(outer_loop), GetRef<For>(loop));
+        throw OuterNotInnerParent(self->mod, ffi::GetRef<For>(outer_loop), ffi::GetRef<For>(loop));
       }
-      if (!outer_loop->body.same_as(GetRef<For>(loop))) {
-        throw NotOnlyChildError(self->mod, GetRef<For>(outer_loop), GetRef<For>(loop));
+      if (!outer_loop->body.same_as(ffi::GetRef<For>(loop))) {
+        throw NotOnlyChildError(self->mod, ffi::GetRef<For>(outer_loop), ffi::GetRef<For>(loop));
       }
     }
     outer_loop_sref = sref;
@@ -899,7 +907,7 @@ StmtSRef Fuse(ScheduleState self, const Array<StmtSRef>& loop_srefs, bool preser
       return false;
     };
     if (UsesVar(loop->extent, f_contain)) {
-      throw DependentLoopError(self->mod, GetRef<For>(loop), used_var->name_hint,
+      throw DependentLoopError(self->mod, ffi::GetRef<For>(loop), used_var->name_hint,
                                DependentLoopError::PrimitiveKind::kFuse);
     }
     outer_loop_vars.insert(loop->loop_var.get());
@@ -915,7 +923,7 @@ StmtSRef Fuse(ScheduleState self, const Array<StmtSRef>& loop_srefs, bool preser
   }
   suffix += "_fused";
   Var fused_var = loops[0]->loop_var.copy_with_suffix(suffix).copy_with_dtype(DataType::Int(bits));
-  Array<PrimExpr> substitute_value;
+  ffi::Array<PrimExpr> substitute_value;
   substitute_value.resize(loops.size());
   PrimExpr lower = 1;
   for (int i = static_cast<int>(loops.size()) - 1; i > 0; i--) {
@@ -926,8 +934,8 @@ StmtSRef Fuse(ScheduleState self, const Array<StmtSRef>& loop_srefs, bool preser
   }
   substitute_value.Set(0, is_one(loops[0]->extent) ? 0 : floordiv(fused_var, lower));
   Stmt new_stmt = loops.back()->body;
-  Map<Block, Block> opaque_block_reuse;
-  auto f_substitute = [&](const Var& v) -> Optional<PrimExpr> {
+  ffi::Map<Block, Block> opaque_block_reuse;
+  auto f_substitute = [&](const Var& v) -> ffi::Optional<PrimExpr> {
    for (int i = 0; i < n; i++) {
      if (v.same_as(loops[i]->loop_var)) {
        return substitute_value[i];
@@ -959,14 +967,14 @@ StmtSRef Fuse(ScheduleState self, const Array<StmtSRef>& loop_srefs, bool preser
  * \throws ScheduleError If there are duplicate loops in the array
  */
 std::unordered_set<const StmtSRefNode*> CollectLoopsIntoSet(
-    const ScheduleState& self, const Array<StmtSRef>& ordered_loop_srefs) {
+    const ScheduleState& self, const ffi::Array<StmtSRef>& ordered_loop_srefs) {
   std::unordered_set<const StmtSRefNode*> loop_srefs;
   loop_srefs.reserve(ordered_loop_srefs.size());
   for (const StmtSRef& loop_sref : ordered_loop_srefs) {
     auto inserted = loop_srefs.insert(loop_sref.get());
     if (!inserted.second) {
       const ForNode* loop = TVM_SREF_TO_FOR(loop_sref);
-      throw LoopMultiAppearanceError(self->mod, GetRef<For>(loop));
+      throw LoopMultiAppearanceError(self->mod, ffi::GetRef<For>(loop));
     }
   }
   return loop_srefs;
@@ -1004,7 +1012,7 @@ std::pair<StmtSRef, StmtSRef> GetBoundaryOfReorderRange(
     // `bottom`.
     if (visited.count(v)) {
       if (v != bottom) {
-        throw LoopsNotAChainError(self->mod, GetRef<Stmt>(v->stmt),
+        throw LoopsNotAChainError(self->mod, ffi::GetRef<Stmt>(v->stmt),
                                   LoopsNotAChainError::ProblemKind::kHaveNonSingleBranchStmt);
       }
       bottom = loop_sref;
@@ -1041,7 +1049,7 @@ std::vector<const StmtSRefNode*> GetLoopsInReorderRange(const ScheduleState& sel
     const ForNode* inner = loop_sref->StmtAs<ForNode>();
     ICHECK(outer != nullptr && inner != nullptr);
     if (outer->body.get() != inner) {
-      throw LoopsNotAChainError(self->mod, GetRef<Stmt>(outer),
+      throw LoopsNotAChainError(self->mod, ffi::GetRef<Stmt>(outer),
                                 LoopsNotAChainError::ProblemKind::kHaveNonSingleBranchStmt);
     }
     chain.push_back(loop_sref);
@@ -1062,7 +1070,7 @@ std::vector<const StmtSRefNode*> GetLoopsInReorderRange(const ScheduleState& sel
  * reordering
  */
 For ConstructNewLoopChain(const ScheduleState& self, std::vector<const StmtSRefNode*> chain,
-                          const Array<StmtSRef>& ordered_loop_srefs,
+                          const ffi::Array<StmtSRef>& ordered_loop_srefs,
                           const std::unordered_set<const StmtSRefNode*>& loop_srefs) {
   std::unordered_set<const VarNode*> inner_vars;
   inner_vars.reserve(chain.size());
@@ -1077,7 +1085,7 @@ For ConstructNewLoopChain(const ScheduleState& self, std::vector<const StmtSRefN
       copy = loop_sref->StmtAs<ForNode>();
     }
     ICHECK(copy != nullptr);
-    ObjectPtr<ForNode> n = make_object<ForNode>(*copy);
+    ObjectPtr<ForNode> n = ffi::make_object<ForNode>(*copy);
     if (new_loop.defined()) {
       n->body = new_loop;
     } else {
@@ -1092,7 +1100,7 @@ For ConstructNewLoopChain(const ScheduleState& self, std::vector<const StmtSRefN
     }
     if (UsesVar(copy->min, f_contain) || UsesVar(copy->extent, f_contain)) {
-      throw DependentLoopError(self->mod, GetRef<For>(copy), used_var->name_hint,
+      throw DependentLoopError(self->mod, ffi::GetRef<For>(copy), used_var->name_hint,
                                DependentLoopError::PrimitiveKind::kReorder);
     }
     inner_vars.insert(copy->loop_var.get());
@@ -1101,7 +1109,7 @@ For ConstructNewLoopChain(const ScheduleState& self, std::vector<const StmtSRefN
   return new_loop;
 }

-void Reorder(ScheduleState self, const Array<StmtSRef>& ordered_loop_srefs) {
+void Reorder(ScheduleState self, const ffi::Array<StmtSRef>& ordered_loop_srefs) {
   if (ordered_loop_srefs.size() <= 1) {
     return;
   }
@@ -1124,12 +1132,13 @@ void Reorder(ScheduleState self, const Array<StmtSRef>& ordered_loop_srefs) {
   // Step 5. Replace the original loops with the reordered loops and check that outer loop is
   // not dependent on inner loop
   For new_loop = ConstructNewLoopChain(self, std::move(chain), ordered_loop_srefs, loop_srefs);
-  self->Replace(GetRef<StmtSRef>(top), new_loop, {});
+  self->Replace(ffi::GetRef<StmtSRef>(top), new_loop, {});
 }

 StmtSRef AddUnitLoop(ScheduleState self, StmtSRef sref) {
   if (sref->stmt->IsInstance<ForNode>()) {
-    For new_loop(Var("u", DataType::Int(32)), 0, 1, ForKind::kSerial, GetRef<Stmt>(sref->stmt));
+    For new_loop(Var("u", DataType::Int(32)), 0, 1, ForKind::kSerial,
+                 ffi::GetRef<Stmt>(sref->stmt));
     self->Replace(sref, new_loop, {});
     return self->stmt2ref.at(new_loop.get());
   }
@@ -1139,8 +1148,8 @@ StmtSRef AddUnitLoop(ScheduleState self, StmtSRef sref) {

     Stmt VisitStmt_(const BlockRealizeNode* realize) final {
       if (realize->block.get() == src_block_) {
-        new_loop_ =
-            For(Var("u", DataType::Int(32)), 0, 1, ForKind::kSerial, GetRef<Stmt>(realize));
+        new_loop_ = For(Var("u", DataType::Int(32)), 0, 1, ForKind::kSerial,
+                        ffi::GetRef<Stmt>(realize));
         return new_loop_;
       }
       return StmtMutator::VisitStmt_(realize);
@@ -1151,13 +1160,13 @@ StmtSRef AddUnitLoop(ScheduleState self, StmtSRef sref) {
   };

   CHECK(sref->parent != nullptr) << "ValueError: Cannot add loops on top of the root block";
-  StmtSRef parent_sref = GetRef<StmtSRef>(sref->parent);
+  StmtSRef parent_sref = ffi::GetRef<StmtSRef>(sref->parent);
   NewLoopCreator creator(sref->stmt);
-  Stmt new_stmt = creator(GetRef<Stmt>(parent_sref->stmt));
+  Stmt new_stmt = creator(ffi::GetRef<Stmt>(parent_sref->stmt));
   if (new_stmt->IsInstance<ForNode>()) {
     self->Replace(parent_sref, std::move(new_stmt), {});
   } else {
-    Block old_parent_block = GetRef<Block>(parent_sref->StmtAs<BlockNode>());
+    Block old_parent_block = ffi::GetRef<Block>(parent_sref->StmtAs<BlockNode>());
     Block new_parent_block = Downcast<Block>(new_stmt);
     self->Replace(parent_sref, new_stmt, {{old_parent_block, new_parent_block}});
   }
@@ -1176,24 +1185,26 @@ struct SplitTraits : public UnpackedInstTraits<SplitTraits> {
   static constexpr size_t kNumDecisions = 0;

   template <size_t delta>
-  static TVM_ALWAYS_INLINE void _SetInputs(AnyView* packed_args, const Array<Any>& inputs) {
+  static TVM_ALWAYS_INLINE void _SetInputs(AnyView* packed_args, const ffi::Array<Any>& inputs) {
     thread_local Any loop_rv{nullptr};
-    thread_local Array<Any> factors{nullptr};
+    thread_local ffi::Array<Any> factors{nullptr};
     loop_rv = inputs[0];
-    factors = Array<Any>{inputs.begin() + 1, inputs.end()};
+    factors = ffi::Array<Any>{inputs.begin() + 1, inputs.end()};
     packed_args[delta] = loop_rv;
     packed_args[delta + 1] = factors;
   }

-  static Array<LoopRV> UnpackedApplyToSchedule(Schedule sch, LoopRV loop_rv,
-                                               Array<Optional<PrimExpr>> factors,
-                                               Bool preserve_unit_iters, Bool disable_predication) {
+  static ffi::Array<LoopRV> UnpackedApplyToSchedule(Schedule sch, LoopRV loop_rv,
+                                                    ffi::Array<ffi::Optional<PrimExpr>> factors,
+                                                    Bool preserve_unit_iters,
+                                                    Bool disable_predication) {
     return sch->Split(loop_rv, factors, preserve_unit_iters.operator bool(),
                       disable_predication.operator bool());
   }

-  static String UnpackedAsPython(Array<String> outputs, String loop_rv, Array<ObjectRef> factors,
-                                 Bool preserve_unit_iters, Bool disable_predication) {
+  static ffi::String UnpackedAsPython(ffi::Array<ffi::String> outputs, ffi::String loop_rv,
+                                      ffi::Array<ObjectRef> factors, Bool preserve_unit_iters,
+                                      Bool disable_predication) {
     PythonAPICall py("split");
     py.Input("loop", loop_rv);
     py.Input("factors", factors);
@@ -1217,23 +1228,23 @@ struct LoopPartitionTraits : public UnpackedInstTraits<LoopPartitionTraits> {
   static constexpr size_t kNumDecisions = 0;

   template <size_t delta>
-  static TVM_ALWAYS_INLINE void _SetInputs(AnyView* packed_args, const Array<Any>& inputs) {
+  static TVM_ALWAYS_INLINE void _SetInputs(AnyView* packed_args, const ffi::Array<Any>& inputs) {
     thread_local Any loop_rv{nullptr};
-    thread_local Array<Any> factors{nullptr};
+    thread_local ffi::Array<Any> factors{nullptr};
     loop_rv = inputs[0];
-    factors = Array<Any>{inputs.begin() + 1, inputs.end()};
+    factors = ffi::Array<Any>{inputs.begin() + 1, inputs.end()};
     packed_args[delta] = loop_rv;
     packed_args[delta + 1] = factors;
   }

-  static Array<LoopRV> UnpackedApplyToSchedule(Schedule sch, LoopRV loop_rv,
-                                               Array<Optional<PrimExpr>> factors,
-                                               Bool preserve_unit_iters) {
+  static ffi::Array<LoopRV> UnpackedApplyToSchedule(Schedule sch, LoopRV loop_rv,
+                                                    ffi::Array<ffi::Optional<PrimExpr>> factors,
+                                                    Bool preserve_unit_iters) {
     return sch->LoopPartition(loop_rv, factors, preserve_unit_iters.operator bool());
   }

-  static String UnpackedAsPython(Array<String> outputs, String loop_rv, Array<ObjectRef> factors,
-                                 Bool preserve_unit_iters) {
+  static ffi::String UnpackedAsPython(ffi::Array<ffi::String> outputs, ffi::String loop_rv,
+                                      ffi::Array<ObjectRef> factors, Bool preserve_unit_iters) {
     PythonAPICall py("loop_partition");
     py.Input("loop", loop_rv);
     py.Input("factors", factors);
@@ -1256,17 +1267,18 @@ struct MergeTraits : public UnpackedInstTraits<MergeTraits> {
   static constexpr size_t kNumDecisions = 0;

   template <size_t delta>
-  static TVM_ALWAYS_INLINE void _SetInputs(AnyView* packed_args, const Array<Any>& inputs) {
+  static TVM_ALWAYS_INLINE void _SetInputs(AnyView* packed_args, const ffi::Array<Any>& inputs) {
     packed_args[delta] = inputs;
   }

-  static LoopRV UnpackedApplyToSchedule(Schedule sch, Array<LoopRV> loop_rvs) {
+  static LoopRV UnpackedApplyToSchedule(Schedule sch, ffi::Array<LoopRV> loop_rvs) {
     return sch->Merge(loop_rvs);
   }

-  static String UnpackedAsPython(Array<String> outputs, Array<String> loop_rvs) {
+  static ffi::String UnpackedAsPython(ffi::Array<ffi::String> outputs,
+                                      ffi::Array<ffi::String> loop_rvs) {
     PythonAPICall py("merge");
-    for (const String& loop_rv : loop_rvs) {
+    for (const ffi::String& loop_rv : loop_rvs) {
       py.Input("", loop_rv);
     }
     py.SingleOutput(outputs);
@@ -1287,19 +1299,19 @@ struct FuseTraits : public UnpackedInstTraits<FuseTraits> {
   static constexpr size_t kNumDecisions = 0;

   template <size_t delta>
-  static TVM_ALWAYS_INLINE void _SetInputs(AnyView* packed_args, const Array<Any>& inputs) {
+  static TVM_ALWAYS_INLINE void _SetInputs(AnyView* packed_args, const ffi::Array<Any>& inputs) {
     packed_args[delta] = inputs;
   }

-  static LoopRV UnpackedApplyToSchedule(Schedule sch, Array<LoopRV> loop_rvs,
+  static LoopRV UnpackedApplyToSchedule(Schedule sch, ffi::Array<LoopRV> loop_rvs,
                                         Bool preserve_unit_iters) {
     return sch->Fuse(loop_rvs, preserve_unit_iters.operator bool());
   }

-  static String UnpackedAsPython(Array<String> outputs, Array<String> loop_rvs,
-                                 Bool preserve_unit_iters) {
+  static ffi::String UnpackedAsPython(ffi::Array<ffi::String> outputs,
+                                      ffi::Array<ffi::String> loop_rvs, Bool preserve_unit_iters) {
     PythonAPICall py("fuse");
-    for (const String& loop_rv : loop_rvs) {
+    for (const ffi::String& loop_rv : loop_rvs) {
       py.Input("", loop_rv);
     }
     py.Input("preserve_unit_iters", preserve_unit_iters.operator bool());
@@ -1321,17 +1333,18 @@ struct ReorderTraits : public UnpackedInstTraits<ReorderTraits> {
   static constexpr size_t kNumDecisions = 0;

   template <size_t delta>
-  static TVM_ALWAYS_INLINE void _SetInputs(AnyView* packed_args, const Array<Any>& inputs) {
+  static TVM_ALWAYS_INLINE void _SetInputs(AnyView* packed_args, const ffi::Array<Any>& inputs) {
     packed_args[delta] = inputs;
   }

-  static void UnpackedApplyToSchedule(Schedule sch, Array<LoopRV> loop_rvs) {
+  static void UnpackedApplyToSchedule(Schedule sch, ffi::Array<LoopRV> loop_rvs) {
     return sch->Reorder(loop_rvs);
   }

-  static String UnpackedAsPython(Array<String> outputs, Array<String> loop_rvs) {
+  static ffi::String UnpackedAsPython(ffi::Array<ffi::String> outputs,
+                                      ffi::Array<ffi::String> loop_rvs) {
     PythonAPICall py("reorder");
-    for (const String& loop_rv : loop_rvs) {
+    for (const ffi::String& loop_rv : loop_rvs) {
       py.Input("", loop_rv);
     }
     return py.Str();
@@ -1361,7 +1374,7 @@ struct AddUnitLoopTraits : public UnpackedInstTraits<AddUnitLoopTraits> {
     }
   }

-  static String UnpackedAsPython(Array<String> outputs, String rv) {
+  static ffi::String UnpackedAsPython(ffi::Array<ffi::String> outputs, ffi::String rv) {
     PythonAPICall py("add_unit_loop");
     py.Input("block_or_loop", rv);
     py.SingleOutput(outputs);
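Much of the churn in loop_transformation.cc sits in `ScheduleError` subclasses, whose overrides now return `ffi::String` and `ffi::Array<ObjectRef>`. As an illustration only (not part of the patch), a hedged sketch of the post-migration shape; the error class is hypothetical, and `ScheduleError` itself is declared in the internal schedule headers:

```cpp
// Sketch only, not part of the patch: a hypothetical ScheduleError subclass
// written against the migrated override signatures.
#include <tvm/ffi/container/array.h>
#include <tvm/ffi/string.h>
#include <tvm/ir/module.h>
#include <tvm/tir/stmt.h>

#include "../error.h"  // assumption: internal header declaring ScheduleError

namespace tvm {
namespace tir {

class ExampleLoopError : public ScheduleError {
 public:
  explicit ExampleLoopError(IRModule mod, For loop)
      : mod_(std::move(mod)), loop_(std::move(loop)) {}

  // Both renderers now return ffi::String instead of the old alias.
  ffi::String FastErrorString() const final {
    return "ScheduleError: example diagnostic attached to a loop";
  }
  ffi::String DetailRenderTemplate() const final {
    return "The loop {0} triggered the example diagnostic";
  }
  IRModule mod() const final { return mod_; }
  // Locations are reported as an ffi::Array of ObjectRef.
  ffi::Array<ObjectRef> LocationsOfInterest() const final { return {loop_}; }

 private:
  IRModule mod_;
  For loop_;
};

}  // namespace tir
}  // namespace tvm
```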
diff --git a/src/tir/schedule/primitive/pad_einsum.cc b/src/tir/schedule/primitive/pad_einsum.cc
index 5b724b6bd295..f66ee2f63e33 100644
--- a/src/tir/schedule/primitive/pad_einsum.cc
+++ b/src/tir/schedule/primitive/pad_einsum.cc
@@ -29,8 +29,9 @@ namespace tir {
  * \param buffer_access The BufferLoad or BufferStore
  * \return The indices if the indices are all Vars, otherwise std::nullopt
  */
-Optional<Array<Var>> CheckTrivialBufferIndices(const Array<PrimExpr>& buffer_access) {
-  Array<Var> indices;
+ffi::Optional<ffi::Array<Var>> CheckTrivialBufferIndices(
+    const ffi::Array<PrimExpr>& buffer_access) {
+  ffi::Array<Var> indices;
   for (const PrimExpr& index : buffer_access) {
     if (index->IsInstance<IntImmNode>()) {
       continue;
@@ -39,13 +40,13 @@ Optional<Array<Var>> CheckTrivialBufferIndices(const Array<PrimExpr>& buffer_acc
     if (var == nullptr) {
       return std::nullopt;
     }
-    indices.push_back(GetRef<Var>(var));
+    indices.push_back(ffi::GetRef<Var>(var));
   }
   return indices;
 }

-Optional<Array<Var>> CheckTrivialBufferAccess(const BufferRegion& buffer_region) {
-  Array<Var> indices;
+ffi::Optional<ffi::Array<Var>> CheckTrivialBufferAccess(const BufferRegion& buffer_region) {
+  ffi::Array<Var> indices;
   indices.reserve(buffer_region->region.size());
   for (const Range& range : buffer_region->region) {
     if (!tir::is_one(range->extent)) {
@@ -55,7 +56,7 @@ Optional<Array<Var>> CheckTrivialBufferAccess(const BufferRegion& buffer_region)
       continue;
     }
     if (const auto* var = range->min.as<VarNode>()) {
-      indices.push_back(GetRef<Var>(var));
+      indices.push_back(ffi::GetRef<Var>(var));
     } else {
       return std::nullopt;
     }
@@ -66,21 +67,21 @@ Optional<Array<Var>> CheckTrivialBufferAccess(const BufferRegion& buffer_region)
 /*! \brief The schedule error class when the padding size is invalid. */
 class InvalidPaddingError : public ScheduleError {
  public:
-  InvalidPaddingError(IRModule mod, Block block, Array<Integer> padding)
+  InvalidPaddingError(IRModule mod, Block block, ffi::Array<Integer> padding)
       : mod_(std::move(mod)), block_(std::move(block)), padding_(std::move(padding)) {}
   IRModule mod() const final { return mod_; }
-  Array<ObjectRef> LocationsOfInterest() const final { return {block_}; }
-  String FastErrorString() const final {
+  ffi::Array<ObjectRef> LocationsOfInterest() const final { return {block_}; }
+  ffi::String FastErrorString() const final {
     return "ScheduleError: The padding size for the block is invalid.";
   }
-  String DetailRenderTemplate() const final {
+  ffi::String DetailRenderTemplate() const final {
     std::ostringstream os;
     os << "The padding for the block {0} are invalid. It should be a list of "
        << block_->iter_vars.size() << " positive integers. Got " << padding_;
     return os.str();
   }

-  static void Check(const ScheduleState& self, const Block& block, Array<Integer> padding) {
+  static void Check(const ScheduleState& self, const Block& block, ffi::Array<Integer> padding) {
     if (padding.size() != block->iter_vars.size()) {
       throw InvalidPaddingError(self->mod, block, padding);
     }
@@ -94,7 +95,7 @@ class InvalidPaddingError : public ScheduleError {
  private:
   IRModule mod_;
   Block block_;
-  Array<Integer> padding_;
+  ffi::Array<Integer> padding_;
 };
 /*! \brief The schedule error class when the block body is not an Einsum pattern. */
@@ -104,11 +105,11 @@ class NonEinsumError : public ScheduleError {
       : mod_(std::move(mod)), block_(std::move(block)) {}

   IRModule mod() const final { return mod_; }
-  Array<ObjectRef> LocationsOfInterest() const final { return {block_}; }
-  String FastErrorString() const final {
+  ffi::Array<ObjectRef> LocationsOfInterest() const final { return {block_}; }
+  ffi::String FastErrorString() const final {
     return "ScheduleError: The block is not a computation of Einsum pattern.";
   }
-  String DetailRenderTemplate() const final {
+  ffi::String DetailRenderTemplate() const final {
     return "The block {0} not a computation of Einsum pattern.";
   }

@@ -120,13 +121,13 @@ class NonEinsumError : public ScheduleError {
 /*! \brief Data structure that represents a Einsum computation. */
 struct Einsum {
   // The output buffer
-  Array<Buffer> output_buffers;
+  ffi::Array<Buffer> output_buffers;
   // The indices of the output buffer
-  Map<Buffer, Array<Var>> output_indices;
+  ffi::Map<Buffer, ffi::Array<Var>> output_indices;
   // The input buffers
-  Array<Buffer> input_buffers;
+  ffi::Array<Buffer> input_buffers;
   // The indices of the input buffers
-  Map<Buffer, Array<Var>> input_indices;
+  ffi::Map<Buffer, ffi::Array<Var>> input_indices;
 };

 struct BufferPadding {
@@ -134,10 +135,10 @@ struct BufferPadding {
   Buffer padded_buffer;

   static BufferPadding FromBufferRegion(const BufferRegion& buffer_region,
-                                        const Map<Var, PrimExpr>& iter_extents) {
+                                        const ffi::Map<Var, PrimExpr>& iter_extents) {
     BufferPadding result;
     result.buffer = buffer_region->buffer;
-    Array<PrimExpr> shape;
+    ffi::Array<PrimExpr> shape;
     shape.reserve(buffer_region->region.size());
     int ndim = buffer_region->region.size();
     for (int i = 0; i < ndim; ++i) {
@@ -145,7 +146,7 @@ struct BufferPadding {
       ICHECK(pos->IsInstance<VarNode>() || pos->IsInstance<IntImmNode>());
       if (pos->IsInstance<IntImmNode>()) {
         shape.push_back(IntImm(pos->dtype, 1));
-      } else if (Optional<PrimExpr> extent = iter_extents.Get(Downcast<Var>(pos))) {
+      } else if (ffi::Optional<PrimExpr> extent = iter_extents.Get(Downcast<Var>(pos))) {
         shape.push_back(extent.value());
       } else {
         shape.push_back(buffer_region->buffer->shape[i]);
@@ -156,12 +157,12 @@ struct BufferPadding {
     return result;
   }

-  Stmt MakeCopyBlock(bool is_read, Array<Block>* blocks, arith::Analyzer* analyzer) {
-    Array<Var> loop_vars;
-    Array<Range> loop_doms;
-    Array<IterVar> iter_vars;
-    Array<Range> instance_dom;
-    Array<PrimExpr> indices;
+  Stmt MakeCopyBlock(bool is_read, ffi::Array<Block>* blocks, arith::Analyzer* analyzer) {
+    ffi::Array<Var> loop_vars;
+    ffi::Array<Range> loop_doms;
+    ffi::Array<IterVar> iter_vars;
+    ffi::Array<Range> instance_dom;
+    ffi::Array<PrimExpr> indices;
     int ndim = buffer->shape.size();
     for (int i = 0; i < ndim; ++i) {
       PrimExpr dim{nullptr};
@@ -199,7 +200,8 @@ struct BufferPadding {
     }
     Block new_block(iter_vars, {read_region}, {write_region}, padded_buffer->name, std::move(body));
     blocks->push_back(new_block);
-    body = BlockRealize(Array<PrimExpr>{loop_vars.begin(), loop_vars.end()}, Bool(true), new_block);
+    body = BlockRealize(ffi::Array<PrimExpr>{loop_vars.begin(), loop_vars.end()}, Bool(true),
+                        new_block);
     for (int i = ndim - 1; i >= 0; --i) {
       body = For(loop_vars[i], loop_doms[i]->min, loop_doms[i]->extent, ForKind::kSerial,
                  std::move(body));
@@ -218,7 +220,7 @@ Einsum ExtractEinsum(const ScheduleState& self, const Block& block) {
       throw NonEinsumError(self->mod, block);
     }
     buffer_used.insert(buffer.get());
-    if (Optional<Array<Var>> opt_indices = CheckTrivialBufferAccess(block->reads[i])) {
+    if (ffi::Optional<ffi::Array<Var>> opt_indices = CheckTrivialBufferAccess(block->reads[i])) {
      result.input_buffers.push_back(buffer);
      result.input_indices.Set(buffer, opt_indices.value());
    } else {
@@ -232,7 +234,7 @@ Einsum ExtractEinsum(const ScheduleState& self, const Block& block) {
       throw NonEinsumError(self->mod, block);
     }
     buffer_used.insert(buffer.get());
-    if (Optional<Array<Var>> opt_indices = CheckTrivialBufferAccess(block->writes[i])) {
+    if (ffi::Optional<ffi::Array<Var>> opt_indices = CheckTrivialBufferAccess(block->writes[i])) {
       result.output_buffers.push_back(buffer);
       result.output_indices.Set(buffer, opt_indices.value());
     } else {
@@ -247,12 +249,12 @@ class BufferNotAllocatedInScopeError : public ScheduleError {
   explicit BufferNotAllocatedInScopeError(IRModule mod, Buffer buffer)
       : mod_(std::move(mod)), buffer_(std::move(buffer)) {}

-  String FastErrorString() const final {
+  ffi::String FastErrorString() const final {
     return "ScheduleError: The buffer is not allocated as an intermediate buffer in current "
            "PrimFunc.";
   }

-  String DetailRenderTemplate() const final {
+  ffi::String DetailRenderTemplate() const final {
     std::ostringstream os;
     os << "The buffer " << buffer_->name
        << " is not allocated as an intermediate buffer in current PrimFunc.";
@@ -260,7 +262,7 @@ class BufferNotAllocatedInScopeError : public ScheduleError {
   }

   IRModule mod() const final { return mod_; }
-  Array<ObjectRef> LocationsOfInterest() const final { return {}; }
+  ffi::Array<ObjectRef> LocationsOfInterest() const final { return {}; }

  private:
   IRModule mod_;
@@ -273,11 +275,11 @@ class InvalidProducerError : public ScheduleError {
   explicit InvalidProducerError(IRModule mod, Block producer)
       : mod_(std::move(mod)), producer_(std::move(producer)) {}

-  String FastErrorString() const final {
+  ffi::String FastErrorString() const final {
     return "ScheduleError: The producer block cannot be padded.";
   }

-  String DetailRenderTemplate() const final {
+  ffi::String DetailRenderTemplate() const final {
     std::ostringstream os;
     os << "The producer block {0} cannot be padded. It should write to a single buffer and the "
           "body should be a BufferStore.";
@@ -285,7 +287,7 @@ class InvalidProducerError : public ScheduleError {
   }

   IRModule mod() const final { return mod_; }
-  Array<ObjectRef> LocationsOfInterest() const final { return {producer_}; }
+  ffi::Array<ObjectRef> LocationsOfInterest() const final { return {producer_}; }

  private:
   IRModule mod_;
@@ -296,32 +298,32 @@ class PadEinsumBufferReplacer : public StmtExprMutator {
  public:
   Stmt VisitStmt_(const BlockNode* old_block_ptr) final {
-    Block old_block = GetRef<Block>(old_block_ptr);
+    Block old_block = ffi::GetRef<Block>(old_block_ptr);
     Block block = Downcast<Block>(StmtMutator::VisitStmt_(old_block_ptr));
-    Array<IterVar> iter_vars;
+    ffi::Array<IterVar> iter_vars;
     iter_vars.reserve(block->iter_vars.size());
     for (const IterVar& iter_var : block->iter_vars) {
-      if (Optional<PrimExpr> new_dom = iter2padded_extents.Get(iter_var->var)) {
-        ObjectPtr<IterVarNode> new_iter_var = make_object<IterVarNode>(*iter_var.get());
+      if (ffi::Optional<PrimExpr> new_dom = iter2padded_extents.Get(iter_var->var)) {
+        ObjectPtr<IterVarNode> new_iter_var = ffi::make_object<IterVarNode>(*iter_var.get());
         new_iter_var->dom = Range::FromMinExtent(iter_var->dom->min, new_dom.value());
         iter_vars.push_back(IterVar(new_iter_var));
       } else {
         iter_vars.push_back(iter_var);
       }
     }
-    Array<BufferRegion> reads;
+    ffi::Array<BufferRegion> reads;
     reads.reserve(block->reads.size());
     for (const BufferRegion& read : block->reads) {
-      if (Optional<Buffer> buffer = buffer_map_.Get(read->buffer)) {
+      if (ffi::Optional<Buffer> buffer = buffer_map_.Get(read->buffer)) {
         reads.push_back(BufferRegion(buffer.value(), read->region));
       } else {
         reads.push_back(read);
       }
     }
-    Array<BufferRegion> writes;
+    ffi::Array<BufferRegion> writes;
     writes.reserve(block->writes.size());
     for (const BufferRegion& write : block->writes) {
-      if (Optional<Buffer> buffer = buffer_map_.Get(write->buffer)) {
+      if (ffi::Optional<Buffer> buffer = buffer_map_.Get(write->buffer)) {
         writes.push_back(BufferRegion(buffer.value(), write->region));
       } else {
         writes.push_back(write);
@@ -335,10 +337,10 @@ class PadEinsumBufferReplacer : public StmtExprMutator {
   }

   Stmt VisitStmt_(const ForNode* old_for_ptr) final {
-    For old_for = GetRef<For>(old_for_ptr);
+    For old_for = ffi::GetRef<For>(old_for_ptr);
     For new_for = Downcast<For>(StmtMutator::VisitStmt_(old_for_ptr));
-    if (Optional<PrimExpr> new_extent = loop_var2padded_extent.Get(new_for->loop_var)) {
-      ObjectPtr<ForNode> new_for_ptr = make_object<ForNode>(*new_for.get());
+    if (ffi::Optional<PrimExpr> new_extent = loop_var2padded_extent.Get(new_for->loop_var)) {
+      ObjectPtr<ForNode> new_for_ptr = ffi::make_object<ForNode>(*new_for.get());
       new_for_ptr->extent = new_extent.value();
       new_for = For(new_for_ptr);
     }
@@ -347,7 +349,7 @@ class PadEinsumBufferReplacer : public StmtExprMutator {

   Stmt VisitStmt_(const BufferStoreNode* old_store_ptr) final {
     BufferStore store = Downcast<BufferStore>(StmtMutator::VisitStmt_(old_store_ptr));
-    if (Optional<Buffer> buffer = buffer_map_.Get(store->buffer)) {
+    if (ffi::Optional<Buffer> buffer = buffer_map_.Get(store->buffer)) {
       return BufferStore(buffer.value(), store->value, store->indices);
     } else {
       return store;
@@ -356,29 +358,29 @@ class PadEinsumBufferReplacer : public StmtExprMutator {

   PrimExpr VisitExpr_(const BufferLoadNode* old_load_ptr) final {
     BufferLoad load = Downcast<BufferLoad>(ExprMutator::VisitExpr_(old_load_ptr));
-    if (Optional<Buffer> buffer = buffer_map_.Get(load->buffer)) {
+    if (ffi::Optional<Buffer> buffer = buffer_map_.Get(load->buffer)) {
       return BufferLoad(buffer.value(), load->indices);
     } else {
       return load;
     }
   }

-  Map<Var, PrimExpr> iter2padded_extents;
-  Map<Var, PrimExpr> loop_var2padded_extent;
-  Map<Buffer, Buffer> buffer_map_;
-  Map<Block, Block> block_sref_reuse_;
+  ffi::Map<Var, PrimExpr> iter2padded_extents;
+  ffi::Map<Var, PrimExpr> loop_var2padded_extent;
+  ffi::Map<Buffer, Buffer> buffer_map_;
+  ffi::Map<Block, Block> block_sref_reuse_;
 };

-void PadEinsum(ScheduleState self, const StmtSRef& block_sref, const Array<Integer>& padding) {
+void PadEinsum(ScheduleState self, const StmtSRef& block_sref, const ffi::Array<Integer>& padding) {
   arith::Analyzer analyzer;
   // Step 1: Input checking and error handling
   const BlockNode* block = TVM_SREF_TO_BLOCK(block_sref);
   BlockRealize realize = GetBlockRealize(self, block_sref);
   StmtSRef scope_sref = GetScopeRoot(self, block_sref, /*require_stage_pipeline=*/true);
   const BlockNode* scope_block = TVM_SREF_TO_BLOCK(scope_sref);
-  InvalidPaddingError::Check(self, GetRef<Block>(block), padding);
+  InvalidPaddingError::Check(self, ffi::GetRef<Block>(block), padding);
   // Step 2. Extract the Einsum pattern
-  ExtractEinsum(self, GetRef<Block>(block));
+  ExtractEinsum(self, ffi::GetRef<Block>(block));
   // Step 3. Figure out the padding needed
   PadEinsumBufferReplacer replacer;
   for (int i = 0, n = padding.size(); i < n; ++i) {
@@ -388,15 +390,15 @@ void PadEinsum(ScheduleState self, const StmtSRef& block_sref, const Array<Integ
       replacer.iter2padded_extents.Set(iter_var->var, new_dom);
       if (const auto* loop_var = realize->iter_values[i].as<VarNode>()) {
-        replacer.iter2padded_extents.Set(GetRef<Var>(loop_var), new_dom);
-        replacer.loop_var2padded_extent.Set(GetRef<Var>(loop_var), new_dom);
+        replacer.iter2padded_extents.Set(ffi::GetRef<Var>(loop_var), new_dom);
+        replacer.loop_var2padded_extent.Set(ffi::GetRef<Var>(loop_var), new_dom);
       }
     }
   }
-  auto f_needs_padding = [&replacer](const Array<Range>& region) {
+  auto f_needs_padding = [&replacer](const ffi::Array<Range>& region) {
     for (const Range& range : region) {
       if (const auto* var = range->min.as<VarNode>()) {
-        if (replacer.iter2padded_extents.count(GetRef<Var>(var))) {
+        if (replacer.iter2padded_extents.count(ffi::GetRef<Var>(var))) {
           return true;
         }
       }
@@ -404,7 +406,7 @@ void PadEinsum(ScheduleState self, const StmtSRef& block_sref, const Array<Integ
-  Array<Stmt> scope_body;
+  ffi::Array<Stmt> scope_body;
   if (const auto* seq_stmt = scope_block->body.as<SeqStmtNode>()) {
     scope_body = seq_stmt->seq;
   } else {
@@ -426,10 +428,10 @@ void PadEinsum(ScheduleState self, const StmtSRef& block_sref, const Array<Integ
-  Array<Stmt> read_blocks;
-  Array<Stmt> write_blocks;
-  Array<Block> new_copy_blocks;
-  Array<Buffer> alloc_buffers;
+  ffi::Array<Stmt> read_blocks;
+  ffi::Array<Stmt> write_blocks;
+  ffi::Array<Block> new_copy_blocks;
+  ffi::Array<Buffer> alloc_buffers;
   for (const BufferRegion& buffer_region : block->reads) {
     if (f_needs_padding(buffer_region->region)) {
       BufferPadding bp =
@@ -449,7 +451,7 @@ void PadEinsum(ScheduleState self, const StmtSRef& block_sref, const Array<Integ
-  Array<Stmt> new_scope_body;
+  ffi::Array<Stmt> new_scope_body;
   for (int i = 0; i < static_cast<int>(scope_body.size()); ++i) {
     if (i != pos) {
       new_scope_body.push_back(scope_body[i]);
@@ -462,12 +464,12 @@ void PadEinsum(ScheduleState self, const StmtSRef& block_sref, const Array<Integ
-    ObjectPtr<BlockNode> n = make_object<BlockNode>(*scope_block);
+    ObjectPtr<BlockNode> n = ffi::make_object<BlockNode>(*scope_block);
     n->body = SeqStmt::Flatten(new_scope_body);
     n->alloc_buffers.insert(n->alloc_buffers.end(), alloc_buffers.begin(), alloc_buffers.end());
     new_scope_block = Block(n);
   }
-  replacer.block_sref_reuse_.Set(GetRef<Block>(scope_block), new_scope_block);
+  replacer.block_sref_reuse_.Set(ffi::GetRef<Block>(scope_block), new_scope_block);
   // Step 8. Do replacement and update flags
   self->Replace(scope_sref, new_scope_block, replacer.block_sref_reuse_);
   for (const Block& block : new_copy_blocks) {
@@ -490,11 +492,12 @@ struct PadEinsumTraits : public UnpackedInstTraits<PadEinsumTraits> {
   static constexpr size_t kNumAttrs = 1;
   static constexpr size_t kNumDecisions = 0;

-  static void UnpackedApplyToSchedule(Schedule sch, BlockRV block, Array<Integer> padding) {
+  static void UnpackedApplyToSchedule(Schedule sch, BlockRV block, ffi::Array<Integer> padding) {
     sch->PadEinsum(block, padding);
   }

-  static String UnpackedAsPython(Array<String> outputs, String block, Array<Integer> padding) {
+  static ffi::String UnpackedAsPython(ffi::Array<ffi::String> outputs, ffi::String block,
+                                      ffi::Array<Integer> padding) {
     PythonAPICall py("pad_einsum");
     py.Input("block", block);
     py.Input("padding", padding);
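The rewritten pad_einsum code leans heavily on one idiom: `ffi::Map::Get` returns an `ffi::Optional` that is engaged only when the key is present, so replacements can fall back to the original object. A minimal sketch of that idiom (not part of the patch; the helper function and the optional include are assumptions):

```cpp
// Sketch only, not part of the patch: the ffi::Optional lookup idiom.
#include <tvm/ffi/container/map.h>
#include <tvm/ffi/optional.h>  // assumption: header providing ffi::Optional
#include <tvm/tir/buffer.h>

namespace tvm {
namespace tir {

// Return the padded replacement for `buf` if one was registered, else `buf`.
Buffer ResolveBuffer(const ffi::Map<Buffer, Buffer>& buffer_map, const Buffer& buf) {
  if (ffi::Optional<Buffer> mapped = buffer_map.Get(buf)) {
    return mapped.value();  // key present: use the replacement
  }
  return buf;  // key absent: keep the original buffer
}

}  // namespace tir
}  // namespace tvm
```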
diff --git a/src/tir/schedule/primitive/read_write_at.cc b/src/tir/schedule/primitive/read_write_at.cc
index 9fdb322a4996..44a0f9bbe284 100644
--- a/src/tir/schedule/primitive/read_write_at.cc
+++ b/src/tir/schedule/primitive/read_write_at.cc
@@ -26,7 +26,7 @@ namespace tir {

 using support::NDIntSet;

-bool HasBuffer(const Array<BufferRegion>& buffer_regions, const Buffer& buffer) {
+bool HasBuffer(const ffi::Array<BufferRegion>& buffer_regions, const Buffer& buffer) {
   for (const BufferRegion& buffer_region : buffer_regions) {
     if (buffer_region->buffer.same_as(buffer)) {
       return true;
@@ -35,14 +35,14 @@ bool HasBuffer(const Array<BufferRegion>& buffer_regions, const Buffer& buffer)
   return false;
 }

-void RelaxBufferRegions(const Array<BufferRegion>& buffer_regions,
-                        const Buffer& buffer,                    //
-                        const Map<Var, arith::IntSet>& var_dom,  //
-                        const Map<Var, PrimExpr>& bindings,      //
+void RelaxBufferRegions(const ffi::Array<BufferRegion>& buffer_regions,
+                        const Buffer& buffer,                         //
+                        const ffi::Map<Var, arith::IntSet>& var_dom,  //
+                        const ffi::Map<Var, PrimExpr>& bindings,      //
                         std::vector<NDIntSet>* relaxed_regions) {
   for (const BufferRegion& buffer_region : buffer_regions) {
     if (buffer_region->buffer.same_as(buffer)) {
-      Array<arith::IntSet> relaxed_region =
+      ffi::Array<arith::IntSet> relaxed_region =
           arith::EvalSet(Substitute(buffer_region->region, bindings), var_dom);
       relaxed_regions->push_back({relaxed_region.begin(), relaxed_region.end()});
     }
@@ -53,7 +53,7 @@ class ScopeReplacer : public StmtMutator {
  public:
   static Block Replace(const BlockNode* scope_block, const Buffer& dst, const ForNode* old_loop,
                        const ForNode* new_loop) {
-    ObjectPtr<BlockNode> new_scope_block = make_object<BlockNode>(*scope_block);
+    ObjectPtr<BlockNode> new_scope_block = ffi::make_object<BlockNode>(*scope_block);
     new_scope_block->body = ScopeReplacer(old_loop, new_loop)(std::move(new_scope_block->body));
     new_scope_block->alloc_buffers.push_back(dst);
     return Block(new_scope_block);
@@ -64,11 +64,11 @@ class ScopeReplacer : public StmtMutator {
       : old_loop_(old_loop), new_loop_(new_loop), found_(false) {}
   Stmt VisitStmt(const Stmt& stmt) final { return found_ ? stmt : StmtMutator::VisitStmt(stmt); }
-  Stmt VisitStmt_(const BlockNode* block) final { return GetRef<Block>(block); }
+  Stmt VisitStmt_(const BlockNode* block) final { return ffi::GetRef<Block>(block); }
   Stmt VisitStmt_(const ForNode* loop) final {
     if (loop == old_loop_) {
       found_ = true;
-      return GetRef<For>(new_loop_);
+      return ffi::GetRef<For>(new_loop_);
     }
     return StmtMutator::VisitStmt_(loop);
   }
@@ -81,14 +81,14 @@ class ScopeReplacer : public StmtMutator {
 class ReadWriteAtBufferReplacer : public StmtExprMutator {
  public:
   explicit ReadWriteAtBufferReplacer(const Buffer& src, const Buffer& dst,
-                                     Map<Block, Block>* block_sref_reuse)
+                                     ffi::Map<Block, Block>* block_sref_reuse)
       : src_(src), dst_(dst), block_sref_reuse_(block_sref_reuse) {}

  private:
   Stmt VisitStmt_(const BufferStoreNode* _store) final {
     BufferStore store = Downcast<BufferStore>(StmtExprMutator::VisitStmt_(_store));
     if (store->buffer.same_as(src_)) {
-      ObjectPtr<BufferStoreNode> new_store = make_object<BufferStoreNode>(*store.get());
+      ObjectPtr<BufferStoreNode> new_store = ffi::make_object<BufferStoreNode>(*store.get());
       new_store->buffer = dst_;
       return BufferStore(new_store);
     }
@@ -98,7 +98,7 @@ class ReadWriteAtBufferReplacer : public StmtExprMutator {
   PrimExpr VisitExpr_(const BufferLoadNode* _load) final {
     BufferLoad load = Downcast<BufferLoad>(StmtExprMutator::VisitExpr_(_load));
     if (load->buffer.same_as(src_)) {
-      ObjectPtr<BufferLoadNode> new_load = make_object<BufferLoadNode>(*load.get());
+      ObjectPtr<BufferLoadNode> new_load = ffi::make_object<BufferLoadNode>(*load.get());
       new_load->buffer = dst_;
       return BufferLoad(new_load);
     }
@@ -106,9 +106,9 @@ class ReadWriteAtBufferReplacer : public StmtExprMutator {
   }

   Stmt VisitStmt_(const BlockNode* _block) final {
-    Block old_block = GetRef<Block>(_block);
+    Block old_block = ffi::GetRef<Block>(_block);
     Block block = Downcast<Block>(StmtExprMutator::VisitStmt_(_block));
-    ObjectPtr<BlockNode> new_block = make_object<BlockNode>(*block.get());
+    ObjectPtr<BlockNode> new_block = ffi::make_object<BlockNode>(*block.get());
     new_block->reads = ReplaceBuffer(new_block->reads, src_, dst_);
     new_block->writes = ReplaceBuffer(new_block->writes, src_, dst_);
     block_sref_reuse_->Set(old_block, Block(new_block));
@@ -117,16 +117,16 @@ class ReadWriteAtBufferReplacer : public StmtExprMutator {

   const Buffer& src_;
   const Buffer& dst_;
-  Map<Block, Block>* block_sref_reuse_;
+  ffi::Map<Block, Block>* block_sref_reuse_;
 };

 struct ReadWriteAtImpl {
   template <bool is_read>
   static StmtSRef Main(ScheduleState self, const StmtSRef& loop_sref, const StmtSRef& block_sref,
-                       int buffer_index, const String& storage_scope,
-                       Map<String, ObjectRef> annotations) {
+                       int buffer_index, const ffi::String& storage_scope,
+                       ffi::Map<ffi::String, ObjectRef> annotations) {
     const BlockNode* block = TVM_SREF_TO_BLOCK(block_sref);
-    Buffer src = GetNthAccessBuffer(self, GetRef<Block>(block), buffer_index,
+    Buffer src = GetNthAccessBuffer(self, ffi::GetRef<Block>(block), buffer_index,
                                     is_read ? BufferIndexType::kRead : BufferIndexType::kWrite);
     Buffer dst = WithScope(src, storage_scope);
     ReadWriteAtImpl impl(self, loop_sref, src, dst, annotations);
@@ -139,8 +139,8 @@ struct ReadWriteAtImpl {
   }

  private:
-  static Map<Var, Range> GetLoopDomain(const StmtSRefNode* loop_sref) {
-    Map<Var, Range> result;
+  static ffi::Map<Var, Range> GetLoopDomain(const StmtSRefNode* loop_sref) {
+    ffi::Map<Var, Range> result;
     for (const ForNode* loop; (loop = loop_sref->StmtAs<ForNode>()) != nullptr;
          loop_sref = loop_sref->parent) {
       result.Set(loop->loop_var, Range::FromMinExtent(loop->min, loop->extent));
@@ -153,7 +153,7 @@ struct ReadWriteAtImpl {
                                               /*require_stage_pipeline=*/true);
     const BlockNode* scope_block = TVM_SREF_TO_BLOCK(scope_root_sref);
     Block new_scope_block = ScopeReplacer::Replace(scope_block, dst_, loop_, new_loop);
-    block_sref_reuse_.Set(GetRef<Block>(scope_block), new_scope_block);
+    block_sref_reuse_.Set(ffi::GetRef<Block>(scope_block), new_scope_block);
     self_->Replace(scope_root_sref, new_scope_block, block_sref_reuse_);
     return self_->stmt2ref.at(new_block);
   }
@@ -166,8 +166,8 @@ struct ReadWriteAtImpl {
   }

   template <bool is_read>
-  std::pair<For, BlockRealize> MakeLoopAndBlock(const String& new_block_name_hint) {
-    Array<Stmt> subtrees = AsArray(loop_->body);
+  std::pair<For, BlockRealize> MakeLoopAndBlock(const ffi::String& new_block_name_hint) {
+    ffi::Array<Stmt> subtrees = AsArray(loop_->body);
     int n_subtrees = subtrees.size();
     runtime::StorageScope scope = runtime::StorageScope::Create(dst_.scope());
     std::vector<NDIntSet> relaxed_regions;
@@ -197,10 +197,10 @@ struct ReadWriteAtImpl {
               /*buffer=*/src_,
               /*var_dom=*/
               arith::AsIntSet(LoopDomainOfSRefTreePath(
-                  /*low_inclusive=*/GetRef<StmtSRef>(self_->stmt2ref.at(block)->parent),
+                  /*low_inclusive=*/ffi::GetRef<StmtSRef>(self_->stmt2ref.at(block)->parent),
                   /*high_exclusive=*/loop_sref_,
                   /*extra_relax_scope=*/scope)),
-              /*bindings=*/GetBindings(GetRef<BlockRealize>(realize)),
+              /*bindings=*/GetBindings(ffi::GetRef<BlockRealize>(realize)),
               /*relaxed_regions=*/&relaxed_regions);
         }
         return false;
@@ -236,7 +236,7 @@ struct ReadWriteAtImpl {
     // Step 3. Calculate `domain`, the domain of buffer access
     NDIntSet relaxed = support::NDIntSetUnion(relaxed_regions);
     int ndim = relaxed.size();
-    Array<Range> domain;
+    ffi::Array<Range> domain;
     domain.reserve(ndim);
     for (int i = 0; i < ndim; ++i) {
       const arith::IntSet& int_set = relaxed[i];
@@ -256,42 +256,43 @@ struct ReadWriteAtImpl {
             ? MakeBlock(src_, dst_, new_block_name_hint, GetLoopDomain(loop_sref_.get()), domain)
             : MakeBlock(dst_, src_, new_block_name_hint, GetLoopDomain(loop_sref_.get()), domain);
     subtrees.insert(subtrees.begin() + insert_pos, realize);
-    ObjectPtr<ForNode> new_loop = make_object<ForNode>(*loop_);
+    ObjectPtr<ForNode> new_loop = ffi::make_object<ForNode>(*loop_);
     new_loop->body = SeqStmt(std::move(subtrees));
     return {For(new_loop), realize};
   }

-  BlockRealize MakeBlock(const Buffer& copy_from, const Buffer& copy_to, const String& name_hint,
-                         const Map<Var, Range>& loop_domain, Array<Range> domain) const {
+  BlockRealize MakeBlock(const Buffer& copy_from, const Buffer& copy_to,
+                         const ffi::String& name_hint, const ffi::Map<Var, Range>& loop_domain,
+                         ffi::Array<Range> domain) const {
     int n = domain.size();
     std::vector<Var> loop_vars;
     loop_vars.reserve(n);
     for (int i = 0; i < n; ++i) {
       loop_vars.push_back(Var("ax" + std::to_string(i)));
     }
-    Map<Var, Var> bindings;
-    Array<IterVar> iter_vars;
-    Array<PrimExpr> iter_values;
-    Array<PrimExpr> indices;
+    ffi::Map<Var, Var> bindings;
+    ffi::Array<IterVar> iter_vars;
+    ffi::Array<PrimExpr> iter_values;
+    ffi::Array<PrimExpr> indices;
     iter_vars.reserve(n);
     iter_values.reserve(n);
     indices.reserve(n);
     for (int i = 0; i < n; ++i) {
       auto f_substitute = [&loop_domain, &bindings, &iter_vars,
-                           &iter_values](const Var& var) -> Optional<PrimExpr> {
+                           &iter_values](const Var& var) -> ffi::Optional<PrimExpr> {
         auto it = bindings.find(var);
         if (it != bindings.end()) {
           return (*it).second;
         }
         Range range = loop_domain.at(var);
-        ObjectPtr<VarNode> v = make_object<VarNode>(*var.get());
+        ObjectPtr<VarNode> v = ffi::make_object<VarNode>(*var.get());
         v->name_hint = "v" + std::to_string(iter_vars.size());
         bindings.Set(var, Var(v));
         iter_values.push_back(var);
         iter_vars.push_back(IterVar(range, Var(v), IterVarType::kDataPar));
         return Var(v);
       };
-      ObjectPtr<RangeNode> dom = make_object<RangeNode>(*domain[i].get());
+      ObjectPtr<RangeNode> dom = ffi::make_object<RangeNode>(*domain[i].get());
       dom->min = Substitute(std::move(dom->min), f_substitute);
       dom->extent = Substitute(std::move(dom->extent), f_substitute);
       domain.Set(i, Range(dom));
@@ -318,7 +319,7 @@ struct ReadWriteAtImpl {
   }

   explicit ReadWriteAtImpl(ScheduleState self, const StmtSRef& loop_sref, const Buffer& src,
-                           const Buffer& dst, Map<String, ObjectRef> annotations)
+                           const Buffer& dst, ffi::Map<ffi::String, ObjectRef> annotations)
       : self_(self),
         loop_sref_(loop_sref),
         loop_(nullptr),
@@ -335,19 +336,19 @@ struct ReadWriteAtImpl {
   const ForNode* loop_;
   const Buffer& src_;
   const Buffer& dst_;
-  Map<String, ObjectRef> annotations_;
-  Map<Block, Block> block_sref_reuse_;
+  ffi::Map<ffi::String, ObjectRef> annotations_;
+  ffi::Map<Block, Block> block_sref_reuse_;
   std::unique_ptr<arith::Analyzer> analyzer_;
 };

 StmtSRef ReadAt(ScheduleState self, const StmtSRef& loop_sref, const StmtSRef& block_sref,
-                int read_buffer_index, const String& storage_scope) {
+                int read_buffer_index, const ffi::String& storage_scope) {
   return ReadWriteAtImpl::Main<true>(self, loop_sref, block_sref, read_buffer_index, storage_scope,
                                      {{tir::attr::auto_copy, true}});
 }

 StmtSRef WriteAt(ScheduleState self, const StmtSRef& loop_sref, const StmtSRef& block_sref,
-                 int write_buffer_index, const String& storage_scope) {
+                 int write_buffer_index, const ffi::String& storage_scope) {
   return ReadWriteAtImpl::Main<false>(self, loop_sref, block_sref, write_buffer_index,
                                       storage_scope, {{tir::attr::auto_copy, true}});
 }
@@ -364,14 +365,15 @@ struct ReadAtTraits : public UnpackedInstTraits<ReadAtTraits> {
   static constexpr size_t kNumDecisions = 0;

   StmtSRef ReadAt(ScheduleState self, const StmtSRef& loop_sref, const StmtSRef& block_sref,
-                  int buffer_index, const String& storage_scope);
+                  int buffer_index, const ffi::String& storage_scope);
   static BlockRV UnpackedApplyToSchedule(Schedule sch, LoopRV loop, BlockRV block,
-                                         Integer read_buffer_index, String storage_scope) {
+                                         Integer read_buffer_index, ffi::String storage_scope) {
     return sch->ReadAt(loop, block, read_buffer_index->value, storage_scope);
   }

-  static String UnpackedAsPython(Array<String> outputs, String loop, String block,
-                                 Integer read_buffer_index, String storage_scope) {
+  static ffi::String UnpackedAsPython(ffi::Array<ffi::String> outputs, ffi::String loop,
+                                      ffi::String block, Integer read_buffer_index,
+                                      ffi::String storage_scope) {
     PythonAPICall py("read_at");
     py.Input("loop", loop);
     py.Input("block", block);
@@ -395,12 +397,13 @@ struct WriteAtTraits : public UnpackedInstTraits<WriteAtTraits> {
   static constexpr size_t kNumDecisions = 0;

   static BlockRV UnpackedApplyToSchedule(Schedule sch, LoopRV loop, BlockRV block,
-                                         Integer write_buffer_index, String storage_scope) {
+                                         Integer write_buffer_index, ffi::String storage_scope) {
     return sch->WriteAt(loop, block, write_buffer_index->value, storage_scope);
   }

-  static String UnpackedAsPython(Array<String> outputs, String loop, String block,
-                                 Integer write_buffer_index, String storage_scope) {
+  static ffi::String UnpackedAsPython(ffi::Array<ffi::String> outputs, ffi::String loop,
+                                      ffi::String block, Integer write_buffer_index,
+                                      ffi::String storage_scope) {
     PythonAPICall py("write_at");
     py.Input("loop", loop);
     py.Input("block", block);
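A second recurring pattern in these files is clone-and-patch: copy-construct an IR node through `ffi::make_object`, edit a field, and re-wrap it in its reference type, as `ScopeReplacer` and `ReadWriteAtBufferReplacer` do above. A hedged sketch (not part of the patch; the helper name is hypothetical):

```cpp
// Sketch only, not part of the patch: the ffi::make_object clone-and-patch
// pattern used by the rewritten mutators.
#include <tvm/ffi/memory.h>
#include <tvm/tir/stmt.h>

#include <utility>

namespace tvm {
namespace tir {

// Copy a For node, replace its extent, and re-wrap it in a For reference.
For WithExtent(const For& loop, PrimExpr new_extent) {
  ObjectPtr<ForNode> n = ffi::make_object<ForNode>(*loop.get());
  n->extent = std::move(new_extent);
  return For(n);
}

}  // namespace tir
}  // namespace tvm
```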
storage_scope) { + Integer read_buffer_index, ffi::String storage_scope) { return sch->ReadAt(loop, block, read_buffer_index->value, storage_scope); } - static String UnpackedAsPython(Array outputs, String loop, String block, - Integer read_buffer_index, String storage_scope) { + static ffi::String UnpackedAsPython(ffi::Array outputs, ffi::String loop, + ffi::String block, Integer read_buffer_index, + ffi::String storage_scope) { PythonAPICall py("read_at"); py.Input("loop", loop); py.Input("block", block); @@ -395,12 +397,12 @@ struct WriteAtTraits : public UnpackedInstTraits { static constexpr size_t kNumDecisions = 0; static BlockRV UnpackedApplyToSchedule(Schedule sch, LoopRV loop, BlockRV block, - Integer write_buffer_index, String storage_scope) { + Integer write_buffer_index, ffi::String storage_scope) { return sch->WriteAt(loop, block, write_buffer_index->value, storage_scope); } - static String UnpackedAsPython(Array outputs, String loop, String block, - Integer write_buffer_index, String storage_scope) { + static ffi::String UnpackedAsPython(ffi::Array outputs, ffi::String loop, + ffi::String block, Integer write_buffer_index, + ffi::String storage_scope) { PythonAPICall py("write_at"); py.Input("loop", loop); py.Input("block", block); diff --git a/src/tir/schedule/primitive/reduction.cc b/src/tir/schedule/primitive/reduction.cc index b46801a0684d..f2b5613abbb5 100644 --- a/src/tir/schedule/primitive/reduction.cc +++ b/src/tir/schedule/primitive/reduction.cc @@ -67,7 +67,7 @@ class DecomposeReductionBlockReplacer : public StmtMutator { p_new_block->name_hint = p_new_block->name_hint + "_update"; p_new_block->init = std::nullopt; // Add write regions back to read regions in update block. - Array new_reads; + ffi::Array new_reads; std::unordered_set read_bufs; for (const BufferRegion& read_access : block->reads) { read_bufs.insert(read_access->buffer.get()); @@ -89,7 +89,7 @@ class DecomposeReductionBlockReplacer : public StmtMutator { } Stmt VisitStmt_(const SeqStmtNode* seq) final { - Array new_stmts; + ffi::Array new_stmts; new_stmts.reserve(seq->seq.size()); for (const Stmt& old_stmt : seq->seq) { new_stmts.push_back(VisitStmt(old_stmt)); @@ -108,7 +108,7 @@ class LoopHeightError : public ScheduleError { public: static void CheckLoopHigherThanReduceLoops(const IRModule& mod, const BlockNode* block, const BlockRealizeNode* realize, - const Array& loops, + const ffi::Array& loops, const StmtSRef& loop_sref) { for (int i = 0, n = block->iter_vars.size(); i < n; ++i) { // For each block var of type kCommReduce, check its binding @@ -126,7 +126,7 @@ class LoopHeightError : public ScheduleError { const Var& loop_var = higher_loop->StmtAs()->loop_var; if (UsesVar(binding, [v = loop_var.get()](const VarNode* var) { return var == v; })) { const ForNode* loop = TVM_SREF_TO_FOR(loop_sref); - throw LoopHeightError(mod, GetRef(loop), GetRef(block)); + throw LoopHeightError(mod, ffi::GetRef(loop), ffi::GetRef(block)); } } } @@ -135,12 +135,12 @@ class LoopHeightError : public ScheduleError { explicit LoopHeightError(IRModule mod, For loop, Block block) : mod_(std::move(mod)), loop_(std::move(loop)), block_(std::move(block)) {} - String FastErrorString() const final { + ffi::String FastErrorString() const final { return "ScheduleError: decompose_reduction expects the loop to be higher than all the loops " "related to reduce block var"; } - String DetailRenderTemplate() const final { + ffi::String DetailRenderTemplate() const final { std::ostringstream os; os << "ScheduleError: 
decompose_reduction expects the loop {0} to be higher than all the loops " "related to reduce block var of block {1}"; @@ -148,7 +148,7 @@ class LoopHeightError : public ScheduleError { } IRModule mod() const final { return mod_; } - Array LocationsOfInterest() const final { return {loop_, block_}; } + ffi::Array LocationsOfInterest() const final { return {loop_, block_}; } IRModule mod_; For loop_; @@ -188,14 +188,14 @@ StmtSRef DecomposeReduction(ScheduleState self, const StmtSRef& block_sref, const BlockNode* block = TVM_SREF_TO_BLOCK(block_sref); const ForNode* loop = TVM_SREF_TO_FOR(loop_sref); // Get the outer loops from high to low - Array loops = GetLoops(block_sref); + ffi::Array loops = GetLoops(block_sref); const BlockRealizeNode* realize = GetBlockRealize(self, block_sref).get(); StmtSRef scope_root_sref = GetScopeRoot(self, block_sref, /*require_stage_pipeline=*/false); if (self->enable_check) { // Cond 0. Check loop_sref is an ancestor of block_sref if (std::find(loops.begin(), loops.end(), loop_sref) == loops.end()) { - throw LoopPositionError(self->mod, GetRef(loop), GetRef(block), + throw LoopPositionError(self->mod, ffi::GetRef(loop), ffi::GetRef(block), "decompose_reduction"); } // Cond 1. Check block is reduction @@ -204,8 +204,8 @@ StmtSRef DecomposeReduction(ScheduleState self, const StmtSRef& block_sref, LoopHeightError::CheckLoopHigherThanReduceLoops(self->mod, block, realize, loops, loop_sref); } // IR Manipulation - ObjectPtr init_block = make_object(); - ObjectPtr init_realize = make_object(); + ObjectPtr init_block = ffi::make_object(); + ObjectPtr init_realize = ffi::make_object(); init_block->name_hint = block->name_hint + "_init"; init_block->annotations = block->annotations; init_realize->iter_values = {}; @@ -273,7 +273,7 @@ StmtSRef DecomposeReduction(ScheduleState self, const StmtSRef& block_sref, Var old_loop_var = old_loop->loop_var; Var new_loop_var = old_loop_var.copy_with_suffix("_init"); loop_var_map[old_loop_var] = new_loop_var; - Optional opt_thread_binding = old_loop->thread_binding; + ffi::Optional opt_thread_binding = old_loop->thread_binding; if (opt_thread_binding) { auto thread_binding = opt_thread_binding.value(); auto new_var = thread_binding->var.copy_with_suffix(""); @@ -291,10 +291,10 @@ StmtSRef DecomposeReduction(ScheduleState self, const StmtSRef& block_sref, // Step 6. 
Mutate IR const BlockNode* old_scope_root = TVM_SREF_TO_BLOCK(scope_root_sref); auto [new_scope_root, new_reduction_block] = DecomposeReductionBlockReplacer::Replace( - GetRef(old_scope_root), GetRef(loop), body, GetRef(block)); + ffi::GetRef(old_scope_root), ffi::GetRef(loop), body, ffi::GetRef(block)); self->Replace(scope_root_sref, new_scope_root, - {{GetRef(old_scope_root), new_scope_root}, - {GetRef(block), new_reduction_block}}); + {{ffi::GetRef(old_scope_root), new_scope_root}, + {ffi::GetRef(block), new_reduction_block}}); self->UpdateScopeBlockInfo(new_scope_root); return self->stmt2ref.at(init_block.get()); } @@ -312,112 +312,114 @@ struct ReducerRegistry { : reducer_getters{ CreateReducerGetter( /*n_buffers=*/1, - [](const Array& x, const Array& y) { - return Array{x[0] + y[0]}; + [](const ffi::Array& x, const ffi::Array& y) { + return ffi::Array{x[0] + y[0]}; }, - [](const Array& values) { - return Array{make_const(values[0]->dtype, 0)}; + [](const ffi::Array& values) { + return ffi::Array{make_const(values[0]->dtype, 0)}; }), CreateReducerGetter( /*n_buffers=*/1, - [](const Array& x, const Array& y) { - return Array{x[0] * y[0]}; + [](const ffi::Array& x, const ffi::Array& y) { + return ffi::Array{x[0] * y[0]}; }, - [](const Array& values) { - return Array{make_const(values[0]->dtype, 1)}; + [](const ffi::Array& values) { + return ffi::Array{make_const(values[0]->dtype, 1)}; }), CreateReducerGetter( /*n_buffers=*/1, - [](const Array& x, const Array& y) { - return Array{min(x[0], y[0])}; + [](const ffi::Array& x, const ffi::Array& y) { + return ffi::Array{min(x[0], y[0])}; }, - [](const Array& values) { - return Array{max_value(values[0]->dtype)}; + [](const ffi::Array& values) { + return ffi::Array{max_value(values[0]->dtype)}; }), CreateReducerGetter( /*n_buffers=*/1, - [](const Array& x, const Array& y) { - return Array{max(x[0], y[0])}; + [](const ffi::Array& x, const ffi::Array& y) { + return ffi::Array{max(x[0], y[0])}; }, - [](const Array& values) { - return Array{min_value(values[0]->dtype)}; + [](const ffi::Array& values) { + return ffi::Array{min_value(values[0]->dtype)}; }), CreateReducerGetter( /*n_buffers=*/2, - [](const Array& x, const Array& y) { - return Array{x[0] + y[0], x[1] + y[1]}; + [](const ffi::Array& x, const ffi::Array& y) { + return ffi::Array{x[0] + y[0], x[1] + y[1]}; }, - [](const Array& values) { - return Array{make_const(values[0]->dtype, 0), - make_const(values[1]->dtype, 0)}; + [](const ffi::Array& values) { + return ffi::Array{make_const(values[0]->dtype, 0), + make_const(values[1]->dtype, 0)}; }), CreateReducerGetter( /*n_buffers=*/2, - [](const Array& x, const Array& y) { + [](const ffi::Array& x, const ffi::Array& y) { PrimExpr idx = Select(x[1] >= y[1], x[0], y[0]); PrimExpr val = Select(x[1] >= y[1], x[1], y[1]); - return Array{idx, val}; + return ffi::Array{idx, val}; }, - [](const Array& values) { - return Array{make_const(values[0]->dtype, -1), - min_value(values[1]->dtype)}; + [](const ffi::Array& values) { + return ffi::Array{make_const(values[0]->dtype, -1), + min_value(values[1]->dtype)}; }), CreateReducerGetter( /*n_buffers=*/2, - [](const Array& x, const Array& y) { + [](const ffi::Array& x, const ffi::Array& y) { PrimExpr idx = Select(Or(greater(x[1], y[1]), And(equal(x[1], y[1]), less(x[0], y[0]))), x[0], y[0]); PrimExpr val = Select(greater(x[1], y[1]), x[1], y[1]); - return Array{idx, val}; + return ffi::Array{idx, val}; }, - [](const Array& values) { - return Array{make_const(values[0]->dtype, -1), - 
min_value(values[1]->dtype)}; + [](const ffi::Array& values) { + return ffi::Array{make_const(values[0]->dtype, -1), + min_value(values[1]->dtype)}; }), CreateReducerGetter( /*n_buffers=*/2, - [](const Array& x, const Array& y) { + [](const ffi::Array& x, const ffi::Array& y) { PrimExpr idx = Select(x[1] <= y[1], x[0], y[0]); PrimExpr val = Select(x[1] <= y[1], x[1], y[1]); - return Array{idx, val}; + return ffi::Array{idx, val}; }, - [](const Array& values) { - return Array{make_const(values[0]->dtype, -1), - max_value(values[1]->dtype)}; + [](const ffi::Array& values) { + return ffi::Array{make_const(values[0]->dtype, -1), + max_value(values[1]->dtype)}; }), CreateReducerGetter( /*n_buffers=*/2, - [](const Array& x, const Array& y) { + [](const ffi::Array& x, const ffi::Array& y) { PrimExpr idx = Select( Or(less(x[1], y[1]), And(equal(x[1], y[1]), less(x[0], y[0]))), x[0], y[0]); PrimExpr val = Select(less(x[1], y[1]), x[1], y[1]); - return Array{idx, val}; + return ffi::Array{idx, val}; }, - [](const Array& values) { - return Array{make_const(values[0]->dtype, -1), - max_value(values[1]->dtype)}; + [](const ffi::Array& values) { + return ffi::Array{make_const(values[0]->dtype, -1), + max_value(values[1]->dtype)}; })} {} static void RegisterReducer( - int n_buffers, ffi::TypedFunction(Array, Array)> combiner_getter, - ffi::TypedFunction(Array)> identity_getter) { + int n_buffers, + ffi::TypedFunction(ffi::Array, ffi::Array)> combiner_getter, + ffi::TypedFunction(ffi::Array)> identity_getter) { ReducerRegistry::Global()->reducer_getters.push_back(ReducerRegistry::CreateReducerGetter( n_buffers, std::move(combiner_getter), std::move(identity_getter))); } - static ffi::TypedFunction(Array)> CreateReducerGetter( - int n_buffers, ffi::TypedFunction(Array, Array)> combiner_getter, - ffi::TypedFunction(Array)> identity_getter) { + static ffi::TypedFunction(ffi::Array)> CreateReducerGetter( + int n_buffers, + ffi::TypedFunction(ffi::Array, ffi::Array)> combiner_getter, + ffi::TypedFunction(ffi::Array)> identity_getter) { return [n_buffers, // combiner_getter = std::move(combiner_getter), // identity_getter = std::move(identity_getter) // - ](Array values) -> Optional { + ](ffi::Array values) -> ffi::Optional { if (static_cast(values.size()) != n_buffers) { return std::nullopt; } - Array lhs; - Array rhs; + ffi::Array lhs; + ffi::Array rhs; for (int i = 0; i < n_buffers; ++i) { lhs.push_back(Var("x" + std::to_string(i), values[i]->dtype)); rhs.push_back(Var("y" + std::to_string(i), values[i]->dtype)); @@ -431,10 +433,11 @@ struct ReducerRegistry { return &instance; } - std::vector(Array)>> reducer_getters; + std::vector(ffi::Array)>> reducer_getters; }; -std::vector(Array)>> GetReducerGetters() { +std::vector(ffi::Array)>> +GetReducerGetters() { return ReducerRegistry::Global()->reducer_getters; } @@ -443,12 +446,12 @@ class NotSerialLoopKindError : public ScheduleError { explicit NotSerialLoopKindError(IRModule mod, For loop) : mod_(std::move(mod)), loop_(std::move(loop)) {} - String FastErrorString() const final { + ffi::String FastErrorString() const final { return "ScheduleError: The input loop of rfactor is required to be `kSerial`"; } - String DetailRenderTemplate() const final { - String str_kind = ForKind2String(loop_->kind); + ffi::String DetailRenderTemplate() const final { + ffi::String str_kind = ForKind2String(loop_->kind); std::ostringstream os; os << "ScheduleError: The input loop {0} of rfactor is required to be `Serial`. 
However, the " "kind of {0} is `" @@ -457,7 +460,7 @@ class NotSerialLoopKindError : public ScheduleError { } IRModule mod() const final { return mod_; } - Array LocationsOfInterest() const final { return {loop_}; } + ffi::Array LocationsOfInterest() const final { return {loop_}; } IRModule mod_; For loop_; @@ -468,12 +471,12 @@ class FactorAxisOutOfRangeError : public ScheduleError { explicit FactorAxisOutOfRangeError(IRModule mod, Buffer buffer, int factor_axis) : mod_(std::move(mod)), buffer_(std::move(buffer)), factor_axis_(factor_axis) {} - String FastErrorString() const final { + ffi::String FastErrorString() const final { return "ScheduleError: The input `factor_axis` is out of range. It is required to be in range " "[-(ndim + 1), ndim] where `ndim` is the number of dimensions of the write buffer"; } - String DetailRenderTemplate() const final { + ffi::String DetailRenderTemplate() const final { std::ostringstream os; int ndim = static_cast(buffer_->shape.size()); os << "The write buffer " << buffer_->name << " has " << ndim @@ -484,7 +487,7 @@ class FactorAxisOutOfRangeError : public ScheduleError { } IRModule mod() const final { return mod_; } - Array LocationsOfInterest() const final { return {}; } + ffi::Array LocationsOfInterest() const final { return {}; } static int CheckAndUpdate(const IRModule& mod, const Buffer& buffer, int factor_axis) { int ndim = static_cast(buffer->shape.size()); @@ -515,7 +518,7 @@ class LoopPropertyError : public ScheduleError { explicit LoopPropertyError(IRModule mod, For loop, ErrorType error_type) : mod_(std::move(mod)), loop_(std::move(loop)), error_type_(error_type) {} - String FastErrorString() const final { + ffi::String FastErrorString() const final { switch (error_type_) { case kDataParIterTouchRFactorLoop: return "ScheduleError: The loop to be applied rfactor is required not to be touched by any " @@ -534,7 +537,7 @@ class LoopPropertyError : public ScheduleError { throw; } - String DetailRenderTemplate() const final { + ffi::String DetailRenderTemplate() const final { switch (error_type_) { case kDataParIterTouchRFactorLoop: return "The loop to be applied rfactor is {0}, which is required not to be touched by any " @@ -554,13 +557,13 @@ class LoopPropertyError : public ScheduleError { } IRModule mod() const final { return mod_; } - Array LocationsOfInterest() const final { return {loop_}; } + ffi::Array LocationsOfInterest() const final { return {loop_}; } - static void CheckLoopProperty(const ScheduleState& self, const Array& loops, + static void CheckLoopProperty(const ScheduleState& self, const ffi::Array& loops, const ForNode* rf_loop, const Block& block, const std::unordered_set& data_par_loop_vars, const std::unordered_set& reduce_loop_vars) { - Array children_of_outermost_loop = + ffi::Array children_of_outermost_loop = GetChildBlockRealizeOnSRefTree(self->stmt2ref.at(loops[0].get())); if (!children_of_outermost_loop[0]->block.same_as(block)) { throw LoopPropertyError(self->mod, loops[0], kNotFirstChildBlockOfOutermostLoop); @@ -601,7 +604,7 @@ class LoopPropertyError : public ScheduleError { * \param loops The loops to be analyzed * \return A mapping from loops to their corresponding loop vars */ -std::unordered_map GetLoopVar2LoopMap(const Array& loops) { +std::unordered_map GetLoopVar2LoopMap(const ffi::Array& loops) { std::unordered_map loop_vars2loop; loop_vars2loop.reserve(loops.size()); for (const For& loop : loops) { @@ -619,16 +622,16 @@ std::unordered_map GetLoopVar2LoopMap(const Array& loo * \param rf_loop The rfactor 
loop * \return The new created intermediate rfactor buffer */ -Array CreateRFactorBuffers(const Array& buf_stores, int factor_axis, - const ForNode* rf_loop) { - Array rf_buffers; +ffi::Array CreateRFactorBuffers(const ffi::Array& buf_stores, int factor_axis, + const ForNode* rf_loop) { + ffi::Array rf_buffers; rf_buffers.reserve(buf_stores.size()); for (const BufferStore& buf_store : buf_stores) { Buffer buffer = buf_store->buffer; - Array rf_shape = buffer->shape; + ffi::Array rf_shape = buffer->shape; rf_shape.insert(rf_shape.begin() + factor_axis, rf_loop->extent); - ObjectPtr n = make_object(*buffer.get()); + ObjectPtr n = ffi::make_object(*buffer.get()); n->shape = rf_shape; n->name = buffer->name + ".rf"; n->data = buffer->data.copy_with_suffix(".rf"); @@ -648,8 +651,8 @@ Array CreateRFactorBuffers(const Array& buf_stores, int fac class BaseBlockCreator { public: explicit BaseBlockCreator(BlockRealize old_block_realize, For rf_loop, - Array old_reduction_updates, CommReducer reducer, - Array rf_buffers, bool is_rf_block) + ffi::Array old_reduction_updates, CommReducer reducer, + ffi::Array rf_buffers, bool is_rf_block) : old_block_realize_(std::move(old_block_realize)), rf_loop_(std::move(rf_loop)), old_reduction_updates_(std::move(old_reduction_updates)), @@ -681,13 +684,13 @@ class BaseBlockCreator { // accesses, and the reduction LHS and RHS of the stored values. PreProcess(); Stmt block_body = Substitute(CreateBlockBody(has_reduce_iter), var_map_); - Optional block_init = CreateBlockInit(has_reduce_iter); + ffi::Optional block_init = CreateBlockInit(has_reduce_iter); if (block_init.defined()) { block_init = Substitute(block_init.value(), var_map_); } CreateReadWriteRegions(); - String new_block_name = old_block_realize_->block->name_hint; + ffi::String new_block_name = old_block_realize_->block->name_hint; PrimExpr predicate = const_true(); if (is_rf_block_) { new_block_name = new_block_name + "_rf"; @@ -713,7 +716,7 @@ class BaseBlockCreator { virtual void CreateReadWriteRegions() = 0; Stmt CreateBlockBody(bool has_reduce_iter) { - Array buf_stores; + ffi::Array buf_stores; buf_stores.reserve(n_buffers_); // Case 1. If the block has no reduction iterator, we just store the RHS values into the @@ -726,14 +729,14 @@ class BaseBlockCreator { } // Case 2. If the reduction is for single buffer, the block body is a single BufferStore. - Array stored_values = (*reducer_.get())(update_lhs_, update_rhs_); + ffi::Array stored_values = (*reducer_.get())(update_lhs_, update_rhs_); if (n_buffers_ == 1) { return BufferStore(update_buffers_[0], stored_values[0], update_indices_[0]); } // Case 3. In case the reduction is for multiple buffers, we should create the reduction with // LetStmt so that the reduction execution generates correct results. - Array let_vars; + ffi::Array let_vars; let_vars.reserve(n_buffers_); for (int i = 0; i < n_buffers_; ++i) { Var var("v_" + update_buffers_[i]->name, PrimType(stored_values[i]->dtype)); @@ -747,12 +750,12 @@ class BaseBlockCreator { return body; } - Optional CreateBlockInit(bool has_reduce_iter) { + ffi::Optional CreateBlockInit(bool has_reduce_iter) { if (!has_reduce_iter) { return std::nullopt; } - Array inits; + ffi::Array inits; inits.reserve(n_buffers_); for (int i = 0; i < n_buffers_; ++i) { inits.push_back( @@ -767,7 +770,7 @@ class BaseBlockCreator { /*! \brief The new created block-realize */ BlockRealize new_block_realize_; /*! 
\brief The indices used to access the intermediate rfactor buffer */ - Array rf_buf_access_indices_; + ffi::Array rf_buf_access_indices_; protected: /*! \brief The old block-realize */ @@ -777,18 +780,18 @@ /*! \brief The rfactor loop */ For rf_loop_; /*! \brief The update BufferStores of the old block */ - Array old_reduction_updates_; + ffi::Array old_reduction_updates_; /*! \brief The matched commutative reducer */ CommReducer reducer_; /*! \brief The intermediate rfactor buffers */ - Array rf_buffers_; + ffi::Array rf_buffers_; /*! \brief The number of rfactor buffers. */ const int n_buffers_; /*! * \brief A mapping which maps old block iters to new expressions. The old iters will be replaced * by the expressions in future substitution for the two blocks */ - Map var_map_; + ffi::Map var_map_; /*! \brief Whether we are creating the rfactor block or the write-back block */ bool is_rf_block_; @@ -797,17 +800,17 @@ /*! \brief The new block iter bindings of the new created block-realize */ std::vector iter_values_; /*! \brief The buffers updated in this block */ - Array update_buffers_; + ffi::Array update_buffers_; /*! \brief The indices of the buffers updated in this block, respectively */ - Array> update_indices_; + ffi::Array> update_indices_; /*! \brief The LHS values of the reduction in this block */ - Array update_lhs_; + ffi::Array update_lhs_; /*! \brief The RHS values of the reduction in this block */ - Array update_rhs_; + ffi::Array update_rhs_; /*! \brief The read regions of the new created block */ - Array read_regions_; + ffi::Array read_regions_; /*! \brief The write regions of the new created block */ - Array write_regions_; + ffi::Array write_regions_; }; /*! @@ -835,10 +838,10 @@ class RFactorBlockCreator : public BaseBlockCreator { public: explicit RFactorBlockCreator(BlockRealize old_block_realize, For rf_loop, - Array old_reduction_updates, CommReducer reducer, - Array rf_buffers, + ffi::Array old_reduction_updates, CommReducer reducer, + ffi::Array rf_buffers, std::unordered_map loop_vars2loop, - int factor_axis, Array combiner_rhs) + int factor_axis, ffi::Array combiner_rhs) : BaseBlockCreator(std::move(old_block_realize), std::move(rf_loop), std::move(old_reduction_updates), std::move(reducer), std::move(rf_buffers), true), @@ -872,7 +875,7 @@ class RFactorBlockCreator : public BaseBlockCreator { ICHECK(old_iter->iter_type == kCommReduce); // This block iter is a reduction block iter that touches the rfactor loop. So next we try to // create a new block iter for all loop vars that appear in the old binding. 
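/*!
 * Editor's note, not part of the diff: a minimal sketch of the remap-and-substitute
 * idiom used around `var_map_` above. Template arguments are elided throughout this
 * diff; ffi::Array<Var> and ffi::Map<Var, PrimExpr> are assumed here, and the
 * replacement expression is hypothetical.
 * \code
 *   ffi::Map<Var, PrimExpr> var_map;
 *   for (const Var& v : UndefinedVars(old_binding)) {
 *     var_map.Set(v, v * 2);  // record a replacement for each free var
 *   }
 *   PrimExpr new_binding = Substitute(old_binding, var_map);
 * \endcode
 * The same map is later handed to Substitute over the whole block body, so every
 * recorded iter is rewritten in one pass.
 */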
- Array vars_in_old_binding = UndefinedVars(old_binding); + ffi::Array vars_in_old_binding = UndefinedVars(old_binding); for (const Var& var : vars_in_old_binding) { auto it = loop_vars2loop_.find(var.get()); if (it == loop_vars2loop_.end()) { @@ -909,7 +912,7 @@ class RFactorBlockCreator : public BaseBlockCreator { } void CreateReadWriteRegions() final { - Map buffer_map; + ffi::Map buffer_map; for (int i = 0; i < n_buffers_; ++i) { buffer_map.Set(old_reduction_updates_[i]->buffer, rf_buffers_[i]); } @@ -921,11 +924,11 @@ class RFactorBlockCreator : public BaseBlockCreator { } write_regions_.reserve(old_block->writes.size()); for (const BufferRegion& write_region : old_block->writes) { - Array region = write_region->region; + ffi::Array region = write_region->region; region.insert(region.begin() + factor_axis_, Range::FromMinExtent(additional_iter_->var, make_const(additional_iter_->var.dtype(), 1))); - Optional rf_buffer = buffer_map.Get(write_region->buffer); + ffi::Optional rf_buffer = buffer_map.Get(write_region->buffer); ICHECK(rf_buffer.defined()); write_regions_.push_back(BufferRegion(rf_buffer.value(), Substitute(region, var_map_))); } @@ -944,7 +947,7 @@ class RFactorBlockCreator : public BaseBlockCreator { /*! \brief The factor_axis specified for rfactor */ int factor_axis_; /*! \brief The RHS values of the reduction in the old block */ - Array combiner_rhs_; + ffi::Array combiner_rhs_; /*! * \brief A mapping which maps loop vars to new created block iters. This map is used to * substitute the loop vars which appear in the bindings of some old block iters with the new @@ -960,10 +963,10 @@ class RFactorBlockCreator : public BaseBlockCreator { class WriteBackBlockCreator : public BaseBlockCreator { public: explicit WriteBackBlockCreator(BlockRealize old_block_realize, For rf_loop, - Array old_reduction_updates, CommReducer reducer, - Array rf_buffers, IterVar rf_additional_iter, - Array combiner_lhs, - Array rf_buf_access_indices) + ffi::Array old_reduction_updates, CommReducer reducer, + ffi::Array rf_buffers, IterVar rf_additional_iter, + ffi::Array combiner_lhs, + ffi::Array rf_buf_access_indices) : BaseBlockCreator(std::move(old_block_realize), std::move(rf_loop), std::move(old_reduction_updates), std::move(reducer), std::move(rf_buffers), false), @@ -1009,12 +1012,12 @@ class WriteBackBlockCreator : public BaseBlockCreator { CreateRegion(update_lhs_, false); } - void CreateRegion(const Array& buf_loads, bool is_read) { - Array& buf_regions = is_read ? read_regions_ : write_regions_; + void CreateRegion(const ffi::Array& buf_loads, bool is_read) { + ffi::Array& buf_regions = is_read ? read_regions_ : write_regions_; for (const PrimExpr& expr : buf_loads) { const auto* buf_load = expr.as(); ICHECK(buf_load != nullptr); - Array region; + ffi::Array region; region.reserve(buf_load->indices.size()); for (const PrimExpr& index : buf_load->indices) { region.push_back(Range::FromMinExtent(index, make_const(index.dtype(), 1))); @@ -1027,7 +1030,7 @@ class WriteBackBlockCreator : public BaseBlockCreator { /*! \brief The new created additional block iter of the rfactor block */ IterVar rf_additional_iter_; /*! \brief The LHS values of the reduction in the old block */ - Array combiner_lhs_; + ffi::Array combiner_lhs_; }; /*! 
@@ -1037,11 +1040,11 @@ class WriteBackBlockCreator : public BaseBlockCreator { * \param loops The loops to be wrapped over the rfactor block * \return A Stmt which is the wrapping result */ -Stmt CreateLoopOutsideRfactorBlock(BlockRealize rf_block_realize, const Array& loops) { +Stmt CreateLoopOutsideRfactorBlock(BlockRealize rf_block_realize, const ffi::Array& loops) { int n_loops = static_cast(loops.size()); // Step 1. Create new loop vars. - Array new_loops; + ffi::Array new_loops; std::unordered_map new_loop_var_map; new_loops.reserve(n_loops); new_loop_var_map.reserve(n_loops); @@ -1051,7 +1054,7 @@ Stmt CreateLoopOutsideRfactorBlock(BlockRealize rf_block_realize, const Array new_bindings; + ffi::Array new_bindings; new_bindings.reserve(rf_block_realize->iter_values.size()); for (const PrimExpr& old_binding : rf_block_realize->iter_values) { new_bindings.push_back(Substitute(old_binding, new_loop_var_map)); @@ -1065,7 +1068,7 @@ Stmt CreateLoopOutsideRfactorBlock(BlockRealize rf_block_realize, const Array= 0; --i) { - ObjectPtr p_loop = make_object(*loops[i].get()); + ObjectPtr p_loop = ffi::make_object(*loops[i].get()); p_loop->loop_var = Downcast(new_loop_var_map[loops[i]->loop_var.get()]); p_loop->body = rf_body; rf_body = For(std::move(p_loop)); @@ -1102,7 +1105,7 @@ class BlockReplacer : public StmtMutator { BlockRealize wb_block_realize, BlockRealize old_block_realize, For rf_loop, std::unordered_set reduce_loop_vars, std::unordered_map loop_vars2loop, - const Array& rf_buffers) { + const ffi::Array& rf_buffers) { BlockReplacer replacer(std::move(rf_body), std::move(outermost_loop), std::move(wb_block_realize), std::move(old_block_realize), std::move(rf_loop), std::move(reduce_loop_vars), @@ -1133,7 +1136,7 @@ class BlockReplacer : public StmtMutator { // that the scope root block has stage-pipeline property, if this loop is not outside the // reduction block, there's no need to recursively mutate. if (!loop_vars2loop_.count(loop->loop_var.get())) { - return GetRef(loop); + return ffi::GetRef(loop); } // Step 2. Recursively mutate. @@ -1160,7 +1163,7 @@ class BlockReplacer : public StmtMutator { } Stmt VisitStmt_(const SeqStmtNode* seq) final { - Array new_stmts; + ffi::Array new_stmts; new_stmts.reserve(static_cast(seq->seq.size())); for (const Stmt old_stmt : seq->seq) { @@ -1195,7 +1198,7 @@ StmtSRef RFactor(ScheduleState self, const StmtSRef& rf_loop_sref, int factor_ax } const ForNode* rf_loop = TVM_SREF_TO_FOR(rf_loop_sref); if (rf_loop->kind != ForKind::kSerial) { - throw NotSerialLoopKindError(self->mod, GetRef(rf_loop)); + throw NotSerialLoopKindError(self->mod, ffi::GetRef(rf_loop)); } // Step 2. Collect loop vars that are touched by data parallel block iters and reduction block @@ -1206,7 +1209,7 @@ StmtSRef RFactor(ScheduleState self, const StmtSRef& rf_loop_sref, int factor_ax // Step 3. Collect the loops of the reduction block. Construct a mapping from loops to // corresponding loop vars. - Array loops = LoopSRefs2Loops(GetLoops(block_sref)); + ffi::Array loops = LoopSRefs2Loops(GetLoops(block_sref)); std::unordered_map loop_vars2loop = GetLoopVar2LoopMap(loops); // Step 4. Check four properties that the loops should have: @@ -1224,11 +1227,11 @@ StmtSRef RFactor(ScheduleState self, const StmtSRef& rf_loop_sref, int factor_ax // commutative reducer, combiner lhs and combiner rhs from the reduction identity and the // reduction combiner. The lhs will be used when constructing the write-back block, and the rhs // will be used when constructing the rfactor block. 
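/*!
 * Editor's note, not part of the diff: each getter registered in ReducerRegistry
 * above probes whether a reduction's stored values match one known pattern and
 * returns the matched reducer or nothing. A hedged sketch, assuming the elided
 * types are ffi::Array<PrimExpr> and ffi::Optional<CommReducer>:
 * \code
 *   PrimExpr value = Var("v", DataType::Float(32));  // hypothetical stored value
 *   for (const auto& getter : GetReducerGetters()) {
 *     if (ffi::Optional<CommReducer> reducer = getter(ffi::Array<PrimExpr>{value})) {
 *       // reducer.value() carries the combiner (e.g. x + y) and its identity (0).
 *       break;
 *     }
 *   }
 * \endcode
 */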
- Array init_values{nullptr}; - Array updates{nullptr}; + ffi::Array init_values{nullptr}; + ffi::Array updates{nullptr}; CommReducer reducer{nullptr}; - Array combiner_lhs{nullptr}; - Array combiner_rhs{nullptr}; + ffi::Array combiner_lhs{nullptr}; + ffi::Array combiner_rhs{nullptr}; std::tie(init_values, updates) = GetInitValuesAndUpdatesFromReductionBlock(self, block); std::tie(reducer, combiner_lhs, combiner_rhs) = GetReducerAndCombinerLhsRhs(self, init_values, updates); @@ -1246,16 +1249,16 @@ StmtSRef RFactor(ScheduleState self, const StmtSRef& rf_loop_sref, int factor_ax // Step 1. Create the intermediate buffer (a.k.a. rfactor buffer), which has an additional // dimension specified by `factor_axis` and `rf_loop`. - Array rf_buffers = CreateRFactorBuffers(updates, factor_axis, rf_loop); + ffi::Array rf_buffers = CreateRFactorBuffers(updates, factor_axis, rf_loop); // Step 2. Create the rfactor block. - RFactorBlockCreator rf_block_creator(block_realize, GetRef(rf_loop), updates, reducer, + RFactorBlockCreator rf_block_creator(block_realize, ffi::GetRef(rf_loop), updates, reducer, rf_buffers, loop_vars2loop, factor_axis, std::move(combiner_rhs)); rf_block_creator.CreateBlock(); // Step 3. Create the write-back block. - WriteBackBlockCreator wb_block_creator(block_realize, GetRef(rf_loop), updates, reducer, + WriteBackBlockCreator wb_block_creator(block_realize, ffi::GetRef(rf_loop), updates, reducer, rf_buffers, std::move(rf_block_creator.additional_iter_), std::move(combiner_lhs), std::move(rf_block_creator.rf_buf_access_indices_)); @@ -1269,10 +1272,10 @@ StmtSRef RFactor(ScheduleState self, const StmtSRef& rf_loop_sref, int factor_ax // ***************************************************** // Step 1. Substitute the old scope root block with the new scope root block. 
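/*!
 * Editor's note, not part of the diff: a hedged sketch of driving the rfactor and
 * decompose_reduction primitives above through the user-facing schedule class;
 * `mod` and the block name "C" are hypothetical.
 * \code
 *   tir::Schedule sch = tir::Schedule::Concrete(
 *       mod, -1, 0, tir::ScheduleErrorRenderLevel::kDetail);  // seed=-1, debug_mask=0
 *   tir::BlockRV block = sch->GetBlock("C");
 *   tir::LoopRV rf_loop = sch->GetLoops(block)[0];
 *   tir::BlockRV rf_block = sch->RFactor(rf_loop, 0);  // factor_axis = 0
 *   sch->DecomposeReduction(rf_block, sch->GetLoops(rf_block)[0]);
 * \endcode
 */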
- Block old_scope_root_block = GetRef(scope_root->StmtAs()); + Block old_scope_root_block = ffi::GetRef(scope_root->StmtAs()); Block new_scope_root_block = BlockReplacer::Replace( old_scope_root_block, rf_body, loops[0], wb_block_creator.new_block_realize_, block_realize, - GetRef(rf_loop), reduce_loop_vars, loop_vars2loop, rf_buffers); + ffi::GetRef(rf_loop), reduce_loop_vars, loop_vars2loop, rf_buffers); self->Replace( scope_root, new_scope_root_block, {{old_scope_root_block, new_scope_root_block}, {block, wb_block_creator.new_block_}}); @@ -1304,7 +1307,8 @@ struct DecomposeReductionTraits : public UnpackedInstTraitsDecomposeReduction(block_rv, loop_rv); } - static String UnpackedAsPython(Array outputs, String block_rv, String loop_rv) { + static ffi::String UnpackedAsPython(ffi::Array outputs, ffi::String block_rv, + ffi::String loop_rv) { PythonAPICall py("decompose_reduction"); py.Input("block", block_rv); py.Input("loop", loop_rv); @@ -1329,7 +1333,8 @@ struct RFactorTraits : public UnpackedInstTraits { return sch->RFactor(loop_rv, factor_axis->value); } - static String UnpackedAsPython(Array outputs, String loop_rv, Integer factor_axis) { + static ffi::String UnpackedAsPython(ffi::Array outputs, ffi::String loop_rv, + Integer factor_axis) { PythonAPICall py("rfactor"); py.Input("loop", loop_rv); py.Input("factor_axis", factor_axis->value); diff --git a/src/tir/schedule/primitive/reorder_block_iter_var.cc b/src/tir/schedule/primitive/reorder_block_iter_var.cc index c7967a3ee904..6acc5fa2d924 100644 --- a/src/tir/schedule/primitive/reorder_block_iter_var.cc +++ b/src/tir/schedule/primitive/reorder_block_iter_var.cc @@ -27,29 +27,29 @@ namespace tir { */ class InvalidReorderIndex : public ScheduleError { public: - explicit InvalidReorderIndex(IRModule mod, Block block, Array new_order) + explicit InvalidReorderIndex(IRModule mod, Block block, ffi::Array new_order) : mod_(mod), block_(block), new_order_(new_order) {} IRModule mod() const final { return mod_; } - String FastErrorString() const final { + ffi::String FastErrorString() const final { return "ScheduleError: The specified reorder indices are invalid."; } - String DetailRenderTemplate() const final { + ffi::String DetailRenderTemplate() const final { std::ostringstream os; os << "The user provided block itervar index order " << new_order_ << " is not a valid permutation of [0, 1, ..., num_block_iter_vars-1] in block {0}."; - return String(os.str()); + return ffi::String(os.str()); } - Array LocationsOfInterest() const final { return {block_}; } + ffi::Array LocationsOfInterest() const final { return {block_}; } private: IRModule mod_; Block block_; - Array new_order_; + ffi::Array new_order_; }; class BlockIterVarRewriter : public StmtMutator { public: - Map block_map; + ffi::Map block_map; explicit BlockIterVarRewriter(const BlockNode* block_n, std::vector order) : order_(std::move(order)), block_to_rewrite(block_n) {} @@ -60,8 +60,8 @@ class BlockIterVarRewriter : public StmtMutator { if (op->block.get() == block_to_rewrite) { auto block_n = CopyOnWrite(op->block.get()); Block block = op->block; - Array new_iter_vars; - Array new_iter_values; + ffi::Array new_iter_vars; + ffi::Array new_iter_values; for (int idx : order_) { new_iter_vars.push_back(block->iter_vars[idx]); new_iter_values.push_back(op->iter_values[idx]); @@ -80,7 +80,7 @@ class BlockIterVarRewriter : public StmtMutator { }; void ReorderBlockIterVar(ScheduleState self, const StmtSRef& block_sref, - const Array& new_order) { + const ffi::Array& new_order) { 
const BlockNode* block_n = TVM_SREF_TO_BLOCK(block_sref); std::vector new_order_vec; for (const Integer& x : new_order) { @@ -95,7 +95,7 @@ void ReorderBlockIterVar(ScheduleState self, const StmtSRef& block_sref, return x >= 0 && x < static_cast(num_block_itervars); }); if (!is_full || !is_unique || !is_within_boundary) { - throw InvalidReorderIndex(self->mod, GetRef(block_n), new_order); + throw InvalidReorderIndex(self->mod, ffi::GetRef(block_n), new_order); } // find parent block @@ -103,13 +103,13 @@ void ReorderBlockIterVar(ScheduleState self, const StmtSRef& block_sref, const StmtSRefNode* p = block_sref.get()->parent; while (p != nullptr) { if (p->stmt->IsInstance()) { - parent_block_n = TVM_SREF_TO_BLOCK(GetRef(p)); + parent_block_n = TVM_SREF_TO_BLOCK(ffi::GetRef(p)); break; } p = p->parent; } - const StmtSRef parent_block_sref = GetRef(p); - const Block& parent_block = GetRef(parent_block_n); + const StmtSRef parent_block_sref = ffi::GetRef(p); + const Block& parent_block = ffi::GetRef(parent_block_n); // rewrite block and blockrealize BlockIterVarRewriter rewriter(block_n, std::move(new_order_vec)); @@ -127,11 +127,12 @@ struct ReorderBlockIterVarTraits : public UnpackedInstTraits new_order) { + static void UnpackedApplyToSchedule(Schedule sch, BlockRV block, ffi::Array new_order) { sch->ReorderBlockIterVar(block, new_order); } - static String UnpackedAsPython(Array outputs, String block, Array new_order) { + static ffi::String UnpackedAsPython(ffi::Array outputs, ffi::String block, + ffi::Array new_order) { PythonAPICall py("reorder_block_iter_var"); py.Input("block", block); py.Input("new_order", new_order); diff --git a/src/tir/schedule/primitive/rolling_buffer.cc b/src/tir/schedule/primitive/rolling_buffer.cc index bef5faf92b67..ff030bbef7a2 100644 --- a/src/tir/schedule/primitive/rolling_buffer.cc +++ b/src/tir/schedule/primitive/rolling_buffer.cc @@ -32,14 +32,14 @@ struct RollingBufferInfo { int rolling_axis; PrimExpr rolling_extent; std::vector axis_overlaps; - std::vector> axis_iter_vars; + std::vector> axis_iter_vars; /*! \brief The map used for ScheduleStateNode::Replace. */ - Map block_reuse; + ffi::Map block_reuse; }; BufferRegion GetRelaxedBufferRegion(const BlockRealize& realize, const BufferRegion& buffer_region, - const Map& dom_map) { - Array relaxed_intsets = + const ffi::Map& dom_map) { + ffi::Array relaxed_intsets = arith::EvalSet(Substitute(buffer_region->region, GetBindings(realize)), dom_map); Region relaxed_region; relaxed_region.reserve(relaxed_intsets.size()); @@ -55,16 +55,16 @@ class RollingBufferDependencyError : public ScheduleError { explicit RollingBufferDependencyError(IRModule mod, Block block) : mod_(mod), block_(std::move(block)) {} - String FastErrorString() const final { + ffi::String FastErrorString() const final { return "ScheduleError: The target block is required to have only RAW dependencies"; } - String DetailRenderTemplate() const final { + ffi::String DetailRenderTemplate() const final { return "The target block {0} is required to have only RAW dependencies"; } IRModule mod() const final { return mod_; } - Array LocationsOfInterest() const final { return {block_}; } + ffi::Array LocationsOfInterest() const final { return {block_}; } /*! * \brief Check if the block has only RAW dependencies. 
@@ -79,13 +79,13 @@ class RollingBufferDependencyError : public ScheduleError { for (const Dependency& producers : scope->GetDepsByDst(block_sref)) { if (!(producers->kind == DepKind::kRAW)) { const BlockNode* block = TVM_SREF_TO_BLOCK(block_sref); - throw RollingBufferDependencyError(self->mod, GetRef(block)); + throw RollingBufferDependencyError(self->mod, ffi::GetRef(block)); } } for (const Dependency& consumers : scope->GetDepsBySrc(block_sref)) { if (!(consumers->kind == DepKind::kRAW)) { const BlockNode* block = TVM_SREF_TO_BLOCK(block_sref); - throw RollingBufferDependencyError(self->mod, GetRef(block)); + throw RollingBufferDependencyError(self->mod, ffi::GetRef(block)); } } } @@ -99,11 +99,11 @@ class RollingBufferMatchError : public ScheduleError { public: RollingBufferMatchError(IRModule mod, Block block, BufferRegion buffer_region) : mod_(mod), block_(block), buffer_region_(buffer_region) {} - String FastErrorString() const final { + ffi::String FastErrorString() const final { return "ScheduleError: rolling_buffer expects the buffer region to have at least one dimension " "matching the rolling pattern such as: hh.outer * stride + hh.inner"; } - String DetailRenderTemplate() const final { + ffi::String DetailRenderTemplate() const final { std::ostringstream os; os << "The target buffer " << buffer_region_->buffer->name << " with region " << buffer_region_->region @@ -113,7 +113,7 @@ class RollingBufferMatchError : public ScheduleError { } IRModule mod() const final { return mod_; } - Array LocationsOfInterest() const final { return {block_}; } + ffi::Array LocationsOfInterest() const final { return {block_}; } private: IRModule mod_; @@ -125,12 +125,12 @@ class RollingBufferInsertionError : public ScheduleError { public: RollingBufferInsertionError(IRModule mod, Buffer buffer, Block block) : mod_(mod), buffer_(std::move(buffer)), block_(block) {} - String FastErrorString() const final { + ffi::String FastErrorString() const final { return "ScheduleError: rolling_buffer injection is invalid, the lca of the access " "location of the target buffer is not a for loop. "; } - String DetailRenderTemplate() const final { + ffi::String DetailRenderTemplate() const final { std::ostringstream os; os << "rolling_buffer injection is invalid. 
The block {0} should be tiled so that " << "the lca of the access location of the target buffer " << buffer_->name @@ -138,7 +138,7 @@ class RollingBufferInsertionError : public ScheduleError { return os.str(); } IRModule mod() const final { return mod_; } - Array LocationsOfInterest() const final { return {block_}; } + ffi::Array LocationsOfInterest() const final { return {block_}; } private: IRModule mod_; @@ -154,7 +154,7 @@ class RollingBufferInfoCollector { RollingBufferInfoCollector collector; if (!collector.MatchRollingBuffer(block_sref, buffer_region)) { const BlockNode* block = TVM_SREF_TO_BLOCK(block_sref); - throw RollingBufferMatchError(mod, GetRef(block), buffer_region); + throw RollingBufferMatchError(mod, ffi::GetRef(block), buffer_region); } return collector.info_; } @@ -164,7 +164,7 @@ class RollingBufferInfoCollector { const Buffer& buffer = buffer_region->buffer; const Region& region = buffer_region->region; - std::vector> bound_iter_vars; + std::vector> bound_iter_vars; std::vector bound_overlaps; arith::PVar p_var; @@ -173,7 +173,7 @@ class RollingBufferInfoCollector { auto stride = 0; auto divisor = 1; - Optional iter_var; + ffi::Optional iter_var; if (floordiv((p_var * p_stride), p_divisor).Match(bound->min)) { // Handle the case of fractional strides // They take this form: floordiv(hh.outer, 2) @@ -211,17 +211,17 @@ class RollingBufferInfoCollector { bound_overlaps.push_back(bound_overlap); } - Array loop_srefs = GetLoops(block_sref); + ffi::Array loop_srefs = GetLoops(block_sref); // Pick the outermost iter_var that's mentioned in the bounds // to be the rolling axis - Optional roll_iter_var; + ffi::Optional roll_iter_var; int roll_axis = 0; for (const tir::StmtSRef& loop_sref : loop_srefs) { auto loop_var = loop_sref->StmtAs()->loop_var; - auto it{std::find_if(bound_iter_vars.begin(), bound_iter_vars.end(), [&](Optional var) { - return var && (var.get() == loop_var.get()); - })}; + auto it{std::find_if( + bound_iter_vars.begin(), bound_iter_vars.end(), + [&](ffi::Optional var) { return var && (var.get() == loop_var.get()); })}; if (it != bound_iter_vars.end()) { auto i = std::distance(bound_iter_vars.begin(), it); roll_iter_var = loop_var; @@ -233,7 +233,7 @@ class RollingBufferInfoCollector { if (!roll_iter_var.defined()) { return false; } - Array new_shape = buffer->shape; + ffi::Array new_shape = buffer->shape; new_shape.Set(roll_axis, region[roll_axis]->extent); Buffer new_buffer = buffer; new_buffer.CopyOnWrite()->shape = new_shape; @@ -255,15 +255,15 @@ class RollingBufferRewriter : public StmtExprMutator { public: static Stmt Rewrite(const StmtSRef& scope_sref, RollingBufferInfo* info) { RollingBufferRewriter rewriter(scope_sref, info); - return rewriter(GetRef(scope_sref->stmt)); + return rewriter(ffi::GetRef(scope_sref->stmt)); } private: explicit RollingBufferRewriter(const StmtSRef& scope_sref, RollingBufferInfo* info) : scope_sref_(scope_sref), info_(info) {} - void RewriteAccessRegion(Array* old_access_regions, - const Array& infered_access_regions) { + void RewriteAccessRegion(ffi::Array* old_access_regions, + const ffi::Array& infered_access_regions) { auto fmutate = [this, &infered_access_regions](const BufferRegion& buffer_region) { if (buffer_region->buffer.same_as(info_->old_buffer)) { ICHECK(infered_access_regions.size() == 1); @@ -274,8 +274,8 @@ class RollingBufferRewriter : public StmtExprMutator { (*old_access_regions).MutateByApply(fmutate); } - void RewriteBufferAccess(Buffer* buffer, Array* indices) const { - Array new_indices; + void 
RewriteBufferAccess(Buffer* buffer, ffi::Array* indices) const { + ffi::Array new_indices; new_indices.reserve(indices->size()); // First modify the access indices to use modulo arithmetic // for the rolling axis @@ -292,11 +292,11 @@ class RollingBufferRewriter : public StmtExprMutator { } Stmt VisitStmt_(const BlockNode* block) final { - Block old_stmt = GetRef(block); + Block old_stmt = ffi::GetRef(block); Block stmt = Downcast(StmtExprMutator::VisitStmt_(block)); BlockNode* n = stmt.CopyOnWrite(); if (block == scope_sref_->stmt) { - Array new_alloc_buffers; + ffi::Array new_alloc_buffers; for (const Buffer& buffer : stmt->alloc_buffers) { if (buffer != info_->old_buffer) { new_alloc_buffers.push_back(buffer); @@ -306,7 +306,7 @@ class RollingBufferRewriter : public StmtExprMutator { } n->alloc_buffers = std::move(new_alloc_buffers); } else { - Array new_iter_vars; + ffi::Array new_iter_vars; for (size_t i = 0; i < stmt->iter_vars.size(); ++i) { auto old_iter_var = stmt->iter_vars[i]; if (static_cast(i) == info_->rolling_axis) { @@ -323,7 +323,7 @@ class RollingBufferRewriter : public StmtExprMutator { new_iter_vars.push_back(old_iter_var); } } - Map buffer_data_to_buffer = {{info_->new_buffer->data, info_->new_buffer}}; + ffi::Map buffer_data_to_buffer = {{info_->new_buffer->data, info_->new_buffer}}; auto infered_access_regions = GetBlockReadWriteRegion(stmt, buffer_data_to_buffer); n->iter_vars = std::move(new_iter_vars); @@ -344,7 +344,8 @@ class RollingBufferRewriter : public StmtExprMutator { auto iter_var = info_->axis_iter_vars[i]; if (iter_var && info_->axis_overlaps[i] > 0) { Var var = iter_var.value(); - const Map dmap = {std::make_pair(var, arith::IntSet::Interval(0, 0))}; + const ffi::Map dmap = { + std::make_pair(var, arith::IntSet::Interval(0, 0))}; auto iter_value = realize->iter_values[i]; arith::Analyzer analyzer; auto term_2 = analyzer.int_set(iter_value, dmap).min(); @@ -399,7 +400,7 @@ void RollingBuffer(ScheduleState self, const StmtSRef& block_sref, int write_buf * indices to circularize the buffer along the rolling dimension. * - Append block predicate to avoid recomputing overlapping elements. */ - Map dom_map; + ffi::Map dom_map; const BlockRealize& realize = GetBlockRealize(self, block_sref); const Block& block = realize->block; @@ -412,8 +413,8 @@ void RollingBuffer(ScheduleState self, const StmtSRef& block_sref, int write_buf RollingBufferDependencyError::Check(self, block_sref, scope_root_sref); // Step 3. 
Find the lca of the access location of the target buffer and relax the buffer - Array loop_srefs = GetLoops(block_sref); - Array consumers_sref = GetConsumers(self, block_sref); + ffi::Array loop_srefs = GetLoops(block_sref); + ffi::Array consumers_sref = GetConsumers(self, block_sref); consumers_sref.push_back(block_sref); StmtSRef lca = GetSRefLowestCommonAncestor(consumers_sref); if (!lca->StmtAs()) { @@ -426,7 +427,7 @@ void RollingBuffer(ScheduleState self, const StmtSRef& block_sref, int write_buf if (stmt == lca) { break; } - For cur_loop = GetRef(stmt->StmtAs()); + For cur_loop = ffi::GetRef(stmt->StmtAs()); Range range = Range::FromMinExtent(cur_loop->min, cur_loop->extent); dom_map.Set(cur_loop->loop_var, arith::IntSet::FromRange(range)); } @@ -458,7 +459,8 @@ struct RollingBufferTraits : public UnpackedInstTraits { return sch->RollingBuffer(block, write_buffer_index.IntValue()); } - static String UnpackedAsPython(Array outputs, String block, Integer write_buffer_index) { + static ffi::String UnpackedAsPython(ffi::Array outputs, ffi::String block, + Integer write_buffer_index) { PythonAPICall py("rolling_buffer"); py.Input("block", block); py.Input("write_buffer_index", write_buffer_index); diff --git a/src/tir/schedule/primitive/sampling.cc b/src/tir/schedule/primitive/sampling.cc index 1d3cabee1dd6..a8042e0c37eb 100644 --- a/src/tir/schedule/primitive/sampling.cc +++ b/src/tir/schedule/primitive/sampling.cc @@ -163,8 +163,8 @@ std::vector SampleWithoutReplacement( } int64_t SampleCategorical(support::LinearCongruentialEngine::TRandState* rand_state, - const Array& candidates, const Array& probs, - Optional* decision) { + const ffi::Array& candidates, const ffi::Array& probs, + ffi::Optional* decision) { CHECK(candidates.size() == probs.size()) << "ValueError: number of candidates does not match number of probabilities."; int32_t i = -1; @@ -309,7 +309,7 @@ std::vector SamplePerfectTile(support::LinearCongruentialEngine::TRandS std::vector SamplePerfectTile( support::LinearCongruentialEngine::TRandState* rand_state, // const tir::StmtSRef& loop_sref, int32_t n_splits, int32_t max_innermost_factor, - Optional>* decision) { + ffi::Optional>* decision) { const ForNode* loop = TVM_SREF_TO_FOR(loop_sref); const int64_t* extent = GetLoopIntExtent(loop); std::vector result; @@ -370,7 +370,7 @@ TVM_DLL std::vector SamplePartitionedTile( std::vector SamplePartitionedTile( support::LinearCongruentialEngine::TRandState* rand_state, // const tir::StmtSRef& loop_sref, int32_t n_splits, int32_t partition_pos, - int32_t innerpart_factor, Optional>* decision) { + int32_t innerpart_factor, ffi::Optional>* decision) { const ForNode* loop = TVM_SREF_TO_FOR(loop_sref); const int64_t* extent = GetLoopIntExtent(loop); std::vector result; @@ -419,7 +419,7 @@ std::vector SamplePartitionedTile( tir::StmtSRef SampleComputeLocation(tir::ScheduleState self, support::LinearCongruentialEngine::TRandState* rand_state, - const StmtSRef& block_sref, Optional* decision) { + const StmtSRef& block_sref, ffi::Optional* decision) { // Step 1. Collect all possible compute-at locations. 
auto [location_srefs, location_indices] = CollectComputeLocation(self, block_sref); ICHECK_EQ(location_srefs.size(), location_indices.size()); @@ -460,17 +460,17 @@ struct SampleCategoricalTraits : public UnpackedInstTraits candidates, // - Array probs, // - Optional decision) { + static ExprRV UnpackedApplyToSchedule(Schedule sch, // + ffi::Array candidates, // + ffi::Array probs, // + ffi::Optional decision) { return sch->SampleCategorical(candidates, probs, decision); } - static String UnpackedAsPython(Array outputs, // - Array candidates, // - Array probs, // - Optional decision) { + static ffi::String UnpackedAsPython(ffi::Array outputs, // + ffi::Array candidates, // + ffi::Array probs, // + ffi::Optional decision) { PythonAPICall py("sample_categorical"); py.Input("candidates", candidates); py.Input("probs", probs); @@ -492,14 +492,15 @@ struct SamplePerfectTileTraits : public UnpackedInstTraits UnpackedApplyToSchedule(Schedule sch, LoopRV loop_rv, Integer n, - Integer max_innermost_factor, - Optional> decision) { + static ffi::Array UnpackedApplyToSchedule(Schedule sch, LoopRV loop_rv, Integer n, + Integer max_innermost_factor, + ffi::Optional> decision) { return sch->SamplePerfectTile(loop_rv, n->value, max_innermost_factor->value, decision); } - static String UnpackedAsPython(Array outputs, String loop_rv, Integer n, - Integer max_innermost_factor, Optional> decision) { + static ffi::String UnpackedAsPython(ffi::Array outputs, ffi::String loop_rv, + Integer n, Integer max_innermost_factor, + ffi::Optional> decision) { PythonAPICall py("sample_perfect_tile"); py.Input("loop", loop_rv); py.Input("n", n->value); @@ -522,16 +523,16 @@ struct SamplePartitionedTileTraits : public UnpackedInstTraits UnpackedApplyToSchedule(Schedule sch, LoopRV loop_rv, Integer n, - Integer partition_pos, Integer innerpart_factor, - Optional> decision) { + static ffi::Array UnpackedApplyToSchedule(Schedule sch, LoopRV loop_rv, Integer n, + Integer partition_pos, Integer innerpart_factor, + ffi::Optional> decision) { return sch->SamplePartitionedTile(loop_rv, n->value, partition_pos->value, innerpart_factor->value, decision); } - static String UnpackedAsPython(Array outputs, String loop_rv, Integer n, - Integer partition_pos, Integer innerpart_factor, - Optional> decision) { + static ffi::String UnpackedAsPython(ffi::Array outputs, ffi::String loop_rv, + Integer n, Integer partition_pos, Integer innerpart_factor, + ffi::Optional> decision) { PythonAPICall py("sample_partitioned_tile"); py.Input("loop", loop_rv); py.Input("n", n->value); @@ -557,13 +558,13 @@ struct SampleComputeLocationTraits : public UnpackedInstTraits decision) { + ffi::Optional decision) { return sch->SampleComputeLocation(block_rv, decision); } - static String UnpackedAsPython(Array outputs, // - String block_rv, // - Optional decision) { + static ffi::String UnpackedAsPython(ffi::Array outputs, // + ffi::String block_rv, // + ffi::Optional decision) { PythonAPICall py("sample_compute_location"); py.Input("block", block_rv); py.Decision(decision); diff --git a/src/tir/schedule/schedule.cc b/src/tir/schedule/schedule.cc index 86b8675dbf56..006a6e081755 100644 --- a/src/tir/schedule/schedule.cc +++ b/src/tir/schedule/schedule.cc @@ -29,9 +29,9 @@ TVM_FFI_STATIC_INIT_BLOCK({ /**************** Constructor ****************/ -BlockRV::BlockRV() { this->data_ = make_object(); } +BlockRV::BlockRV() { this->data_ = ffi::make_object(); } -LoopRV::LoopRV() { this->data_ = make_object(); } +LoopRV::LoopRV() { this->data_ = ffi::make_object(); } 
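/*!
 * Editor's note, not part of the diff: the registration pattern used by the
 * TVM_FFI_STATIC_INIT_BLOCK sections in this file, sketched with a hypothetical
 * global name and a trivial body.
 * \code
 *   TVM_FFI_STATIC_INIT_BLOCK({
 *     namespace refl = tvm::ffi::reflection;
 *     refl::GlobalDef().def("tir.schedule.ScheduleExampleQuery",  // hypothetical name
 *                           [](Schedule self) -> IRModule { return self->mod(); });
 *   });
 * \endcode
 * Once registered, the callable is reachable from any FFI frontend under that
 * global name, which is how the Python schedule API binds to these entry points.
 */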
/**************** GetSRef ****************/ @@ -103,7 +103,7 @@ TVM_FFI_STATIC_INIT_BLOCK({ throw; }) .def("tir.schedule.ScheduleGetSRef", - [](Schedule self, ObjectRef obj) -> Optional { + [](Schedule self, ObjectRef obj) -> ffi::Optional { if (auto loop_rv = obj.as()) { return self->GetSRef(loop_rv.value()); } @@ -250,13 +250,13 @@ TVM_FFI_STATIC_INIT_BLOCK({ [](Schedule self, ObjectRef target, bool preserve_unit_iters) { if (auto loop_rv = target.as()) { return self->Blockize(loop_rv.value(), preserve_unit_iters); - } else if (auto blocks = target.as>()) { + } else if (auto blocks = target.as>()) { return self->Blockize(blocks.value(), preserve_unit_iters); } LOG(FATAL) << "Unsupported target type: " << target->GetTypeKey(); }) .def("tir.schedule.ScheduleTensorize", - [](Schedule self, ObjectRef rv, String intrin, bool preserve_unit_iters) { + [](Schedule self, ObjectRef rv, ffi::String intrin, bool preserve_unit_iters) { if (auto block_rv = rv.as()) { self->Tensorize(block_rv.value(), intrin, preserve_unit_iters); } else if (auto loop_rv = rv.as()) { @@ -273,7 +273,7 @@ TVM_FFI_STATIC_INIT_BLOCK({ namespace refl = tvm::ffi::reflection; refl::GlobalDef() .def("tir.schedule.ScheduleAnnotate", - [](Schedule self, ObjectRef rv, const String& ann_key, const Any& ann_val) { + [](Schedule self, ObjectRef rv, const ffi::String& ann_key, const Any& ann_val) { if (auto block_rv = rv.as()) { return self->Annotate(block_rv.value(), ann_key, ann_val); } @@ -285,7 +285,7 @@ TVM_FFI_STATIC_INIT_BLOCK({ throw; }) .def("tir.schedule.ScheduleUnannotate", [](Schedule self, ObjectRef rv, - const String& ann_key) { + const ffi::String& ann_key) { if (auto block_rv = rv.as()) { return self->Unannotate(block_rv.value(), ann_key); } @@ -304,7 +304,7 @@ TVM_FFI_STATIC_INIT_BLOCK({ refl::GlobalDef() .def("tir.schedule.ScheduleTransformLayout", [](Schedule self, const BlockRV& block_rv, int buffer_index, int buffer_index_type, - const IndexMap& index_map, const Optional& pad_value, + const IndexMap& index_map, const ffi::Optional& pad_value, bool assume_injective_transform) { return self->TransformLayout(block_rv, buffer_index, static_cast(buffer_index_type), @@ -313,7 +313,7 @@ TVM_FFI_STATIC_INIT_BLOCK({ .def_method("tir.schedule.ScheduleTransformBlockLayout", &ScheduleNode::TransformBlockLayout) .def("tir.schedule.ScheduleSetAxisSeparator", [](Schedule self, const BlockRV& block_rv, int buffer_index, int buffer_index_type, - const Array& axis_separators) { + const ffi::Array& axis_separators) { return self->SetAxisSeparator(block_rv, buffer_index, static_cast(buffer_index_type), axis_separators); diff --git a/src/tir/schedule/state.cc b/src/tir/schedule/state.cc index ff653502ccaa..d6d787e83650 100644 --- a/src/tir/schedule/state.cc +++ b/src/tir/schedule/state.cc @@ -39,12 +39,12 @@ using SMap = std::unordered_map; * \param dom_high_exclusive The highest node in the sref tree path * \return An n-dimensional integer set */ -Array AnalyzeRegionUpperBound(const BufferRegion& region, // - const PrimExpr& predicate, // - const StmtSRef& dom_low_inclusive, // - const StmtSRef& dom_high_exclusive, // - arith::Analyzer* analyzer) { - Map var_dom = LoopDomainOfSRefTreePath( +ffi::Array AnalyzeRegionUpperBound(const BufferRegion& region, // + const PrimExpr& predicate, // + const StmtSRef& dom_low_inclusive, // + const StmtSRef& dom_high_exclusive, // + arith::Analyzer* analyzer) { + ffi::Map var_dom = LoopDomainOfSRefTreePath( /*low_inclusive=*/dom_low_inclusive, /*high_exclusive=*/dom_high_exclusive, 
/*extra_relax_scope=*/runtime::StorageScope::Create(region->buffer.scope())); @@ -64,22 +64,22 @@ Array AnalyzeRegionUpperBound(const BufferRegion& region, * \param analyzer The analyzer * \return An n-dimensional integer set */ -Array AnalyzeRegionLowerBound(const BufferRegion& region, // - const PrimExpr& predicate, // - const StmtSRef& dom_low_inclusive, // - const StmtSRef& dom_high_exclusive, // - arith::Analyzer* analyzer) { - Map var_dom = LoopDomainOfSRefTreePath( +ffi::Array AnalyzeRegionLowerBound(const BufferRegion& region, // + const PrimExpr& predicate, // + const StmtSRef& dom_low_inclusive, // + const StmtSRef& dom_high_exclusive, // + arith::Analyzer* analyzer) { + ffi::Map var_dom = LoopDomainOfSRefTreePath( /*low_inclusive=*/dom_low_inclusive, /*high_exclusive=*/dom_high_exclusive, /*extra_relax_scope=*/runtime::StorageScope::Create(region->buffer.scope())); - if (Optional> result = EstimateRegionLowerBound( + if (ffi::Optional> result = EstimateRegionLowerBound( /*region=*/region->region, /*var_dom=*/var_dom, /*predicate=*/predicate, /*analyzer=*/analyzer)) { return result.value(); } - return Array(region->buffer->shape.size(), arith::IntSet::Nothing()); + return ffi::Array(region->buffer->shape.size(), arith::IntSet::Nothing()); } /*! @@ -90,9 +90,9 @@ Array AnalyzeRegionLowerBound(const BufferRegion& region, * \param analyzer The analyzer * \return A boolean indicating if the produced region could cover the consumed region */ -bool ProducerCoversConsumer(const Array& buffer_shape, - const Array& produced_region, - const Array& consumed_region, +bool ProducerCoversConsumer(const ffi::Array& buffer_shape, + const ffi::Array& produced_region, + const ffi::Array& consumed_region, arith::Analyzer* analyzer) { ICHECK_EQ(buffer_shape.size(), consumed_region.size()); ICHECK_EQ(produced_region.size(), consumed_region.size()); @@ -140,7 +140,7 @@ void UpdateSRef(ScheduleStateNode* self, StmtSRefNode* sref, const StmtNode* new ICHECK(new_stmt->IsInstance() || new_stmt->IsInstance()); const StmtNode* old_stmt = sref->stmt; ICHECK_NE(new_stmt, old_stmt); - self->stmt2ref[new_stmt] = GetRef(sref); + self->stmt2ref[new_stmt] = ffi::GetRef(sref); self->stmt2ref.erase(sref->stmt); sref->stmt = new_stmt; } @@ -177,7 +177,7 @@ class BlockInfoCollector : private StmtVisitor { void MakeBlockInfo(StmtSRef scope_root) { bool is_root_block = srefs_.empty(); // Calculate `BlockInfo::scope` - Array child_block_srefs = std::move(block_frames_.back()); + ffi::Array child_block_srefs = std::move(block_frames_.back()); BlockInfo& info = self_->block_info[scope_root] = BlockInfo(BlockScope(child_block_srefs)); // Set `affine_binding` if (is_root_block) { @@ -198,26 +198,26 @@ class BlockInfoCollector : private StmtVisitor { } bool CheckRegionCoverAndStagePipeline(const BlockInfo& info, const StmtSRef& scope_root, - const Array& child_block_srefs) { + const ffi::Array& child_block_srefs) { const StmtSRefNode* limit = scope_root->parent; bool stage_pipeline = true; // Step 1. 
Unbind the read/write regions of each child block - std::unordered_map> block_reads_unbound; - std::unordered_map> block_writes_unbound; + std::unordered_map> block_reads_unbound; + std::unordered_map> block_writes_unbound; block_reads_unbound.reserve(child_block_srefs.size()); block_writes_unbound.reserve(child_block_srefs.size()); for (const StmtSRef& block_sref : child_block_srefs) { const BlockNode* block = TVM_SREF_TO_BLOCK(block_sref); - Map binding = GetBindings(block2realize_.at(block)); + ffi::Map binding = GetBindings(block2realize_.at(block)); // Step 1.1. Unbind read regions - Array reads; + ffi::Array reads; reads.reserve(block->reads.size()); for (const BufferRegion& region : block->reads) { reads.push_back(BufferRegion(region->buffer, Substitute(region->region, binding))); } block_reads_unbound.emplace(block_sref.get(), std::move(reads)); // Step 1.2. Unbind write regions - Array writes; + ffi::Array writes; writes.reserve(block->writes.size()); for (const BufferRegion& region : block->writes) { writes.push_back(BufferRegion(region->buffer, Substitute(region->region, binding))); @@ -227,7 +227,7 @@ class BlockInfoCollector : private StmtVisitor { // Step 2. For each consumer, check the region cover property for (const auto& kv : info.scope->dst2deps) { const StmtSRef& consumer_block_sref = kv.first; - const Array& deps = kv.second; + const ffi::Array& deps = kv.second; const BlockNode* consumer_block = TVM_SREF_TO_BLOCK(consumer_block_sref); const BlockRealize& consumer_realize = block2realize_.at(consumer_block); bool& region_cover = self_->block_info.at(consumer_block_sref).region_cover = true; @@ -261,14 +261,15 @@ class BlockInfoCollector : private StmtVisitor { // Step 2.3. For each LCA, gather the produced regions, // then check if it could cover the consumed region for (StmtSRef lca = consumer_block_sref; region_cover && lca.get() != limit; - lca = GetRef(lca->parent)) { + lca = ffi::GetRef(lca->parent)) { const std::vector& producer_block_srefs = lca_loc.at(lca.get()); // Skip empty LCA positions if (producer_block_srefs.empty()) { continue; } // For each buffer, record the regions generated under this loop - std::unordered_map>> touched_regions; + std::unordered_map>> + touched_regions; // Step 2.3.1. Find all the regions read by the consumer that we care about for (const BufferRegion& region : block_reads_unbound.at(consumer_block_sref.get())) { const BufferNode* buffer = region->buffer.get(); @@ -277,13 +278,13 @@ class BlockInfoCollector : private StmtVisitor { // Step 2.3.2. 
Find all the regions written by each producer for (const StmtSRefNode* producer_block_sref : producer_block_srefs) { const BlockRealize& producer_realize = block2realize_.at(producer_block_sref->stmt); - StmtSRef parent_sref = GetRef(producer_block_sref->parent); + StmtSRef parent_sref = ffi::GetRef(producer_block_sref->parent); for (const BufferRegion& region : block_writes_unbound.at(producer_block_sref)) { const BufferNode* buffer = region->buffer.get(); auto it = touched_regions.find(buffer); // Skip the regions that are not read by the consumer if (it != touched_regions.end()) { - std::vector>& touched_region = it->second; + std::vector>& touched_region = it->second; // The analysis here is trying to be conservative to rule out false positive cases, // and to make sure the region cover property is satisfied once the flag is on // Therefore, we use lower-bound analysis for producers and upper-bound analysis for @@ -299,14 +300,15 @@ } // Step 2.3.3. For each buffer, check the region cover property { - StmtSRef parent_sref = GetRef(consumer_block_sref->parent); + StmtSRef parent_sref = ffi::GetRef(consumer_block_sref->parent); for (const BufferRegion& region : block_reads_unbound.at(consumer_block_sref.get())) { const BufferNode* buffer = region->buffer.get(); - const std::vector>& touched_region = touched_regions.at(buffer); + const std::vector>& touched_region = + touched_regions.at(buffer); if (!touched_region.empty()) { - Array produced_region = + ffi::Array produced_region = arith::UnionRegionLowerBound({touched_region.begin(), touched_region.end()}); - Array consumed_region = AnalyzeRegionUpperBound( + ffi::Array consumed_region = AnalyzeRegionUpperBound( /*region=*/region, /*predicate=*/consumer_realize->predicate, /*dom_low_inclusive=*/parent_sref, @@ -337,7 +339,7 @@ class BlockInfoCollector : private StmtVisitor { void VisitStmt_(const BlockRealizeNode* realize) final { block_frames_.emplace_back(); const BlockNode* block = realize->block.get(); - block2realize_.emplace(block, GetRef(realize)); + block2realize_.emplace(block, ffi::GetRef(realize)); // Recursive visit PushSRef(block); VisitStmt(block->body); // `block->init` is not visited @@ -362,7 +364,7 @@ class BlockInfoCollector : private StmtVisitor { /*! \brief The BlockRealize corresponding to blocks */ std::unordered_map block2realize_; /*! \brief The stack frames of blocks in the DFS visit. */ - std::vector> block_frames_; + std::vector> block_frames_; /*!
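The `block_frames_` member above drives a standard collector idiom: entering a block scope pushes an empty frame, children record themselves into the top frame, and leaving the scope pops and consumes the frame. A self-contained sketch of the idiom on a toy tree (hypothetical Node/Collector names, not TVM types):

#include <iostream>
#include <string>
#include <vector>

struct Node {
  std::string name;
  bool is_block;  // blocks open a new scope; loops do not
  std::vector<Node> children;
};

class Collector {
 public:
  void Visit(const Node& node) {
    if (node.is_block) {
      frames_.emplace_back();  // open a frame for this scope
      for (const Node& c : node.children) Visit(c);
      std::vector<std::string> kids = std::move(frames_.back());
      frames_.pop_back();      // close the scope, consuming its frame
      std::cout << node.name << " has " << kids.size() << " child block(s)\n";
      if (!frames_.empty()) frames_.back().push_back(node.name);  // report to parent scope
    } else {
      for (const Node& c : node.children) Visit(c);
    }
  }

 private:
  std::vector<std::vector<std::string>> frames_;  // one frame per open scope
};

int main() {
  Node root{"root", true, {{"loop", false, {{"inner", true, {}}}}}};
  Collector().Visit(root);  // prints: inner has 0 ..., then root has 1 ...
}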
\brief The auxiliary analyzer */ arith::Analyzer analyzer_; }; @@ -371,7 +373,7 @@ class BlockInfoCollector : private StmtVisitor { ScheduleState::ScheduleState(IRModule mod, int debug_mask, bool enable_check) { CHECK_GE(debug_mask, -1) << "ValueError: negative `debug_mask` other than -1 is not supported"; - ObjectPtr n = make_object(); + ObjectPtr n = ffi::make_object(); ScheduleStateNode* self = n.get(); // Set `n->mod` n->mod = std::move(mod); @@ -544,7 +546,7 @@ class SRefTreePruner : public StmtVisitor { auto it = self_->stmt2ref.find(op); ICHECK(it != self_->stmt2ref.end()) << "IndexError: Cannot find corresponding StmtSRef for the loop:\n" - << GetRef(op); + << ffi::GetRef(op); StmtSRef& sref = it->second; // Detect reuse const VarNode* loop_var = op->loop_var.get(); @@ -567,7 +569,7 @@ class SRefTreePruner : public StmtVisitor { auto it = self_->stmt2ref.find(op); ICHECK(it != self_->stmt2ref.end()) << "IndexError: Cannot find corresponding StmtSRef for the block:\n" - << GetRef(op); + << ffi::GetRef(op); StmtSRef& sref = it->second; // Detect reuse const auto& sref_reuse = reuse_info_.block_sref_reuse; @@ -617,7 +619,7 @@ class SRefUpdater : public StmtVisitor { private: explicit SRefUpdater(ScheduleStateNode* self, StmtSRefNode* src_stmt_parent, const std::unordered_map& reused_srefs) - : self_(GetRef(self)), + : self_(ffi::GetRef(self)), ancestors_{src_stmt_parent}, reused_srefs_(reused_srefs) {} @@ -745,15 +747,15 @@ class ChildReplacer : private StmtMutator { } // Skipping sibling blocks and loops other than `src_stmt_` - Stmt VisitStmt_(const BlockNode* op) final { return GetRef(op); } - Stmt VisitStmt_(const ForNode* op) final { return GetRef(op); } + Stmt VisitStmt_(const BlockNode* op) final { return ffi::GetRef(op); } + Stmt VisitStmt_(const ForNode* op) final { return ffi::GetRef(op); } Stmt VisitStmt_(const SeqStmtNode* op) final { int i = this->seq_index_; int n = static_cast(op->seq.size()); if (0 <= i && i < n) { const Stmt& stmt = op->seq[i]; - Optional new_stmt = std::nullopt; + ffi::Optional new_stmt = std::nullopt; const StmtNode* src_stmt = this->src_stmt_; // `stmt` can be For or BlockRealize // `src_stmt` can be For or Block @@ -767,8 +769,8 @@ class ChildReplacer : private StmtMutator { // Case 2. stmt is BlockRealize, src_stmt is Block if (realize->block.get() == src_stmt) { const auto* tgt_block = TVM_TYPE_AS(tgt_stmt_, BlockNode); - ObjectPtr new_realize = make_object(*realize); - new_realize->block = GetRef(tgt_block); + ObjectPtr new_realize = ffi::make_object(*realize); + new_realize->block = ffi::GetRef(tgt_block); new_stmt = BlockRealize(std::move(new_realize)); } } @@ -814,7 +816,7 @@ class ChildReplacer : private StmtMutator { }; void ScheduleStateNode::Replace(const tir::StmtSRef& _src_sref, const Stmt& tgt_stmt, - const Map& _block_sref_reuse) { + const ffi::Map& _block_sref_reuse) { if (this->debug_mask != 0) { const StmtNode* src_stmt = _src_sref->stmt; bool input_correct = @@ -824,7 +826,7 @@ void ScheduleStateNode::Replace(const tir::StmtSRef& _src_sref, const Stmt& tgt_ if (!input_correct) { LOG(FATAL) << "TypeError: src_stmt has type: " << src_stmt->GetTypeKey() << ". 
tgt_stmt has type: " << tgt_stmt->GetTypeKey() << ".\nsrc_stmt:\n" - << GetRef(src_stmt) << "\ntgt_stmt:\n" + << ffi::GetRef(src_stmt) << "\ntgt_stmt:\n" << tgt_stmt; } } @@ -834,7 +836,7 @@ void ScheduleStateNode::Replace(const tir::StmtSRef& _src_sref, const Stmt& tgt_ } // Reset sref as a new sref so that its content won't be affected by subsequent changes StmtSRef src_sref(_src_sref->stmt, _src_sref->parent, _src_sref->seq_index); - Stmt src_stmt = GetRef(src_sref->stmt); + Stmt src_stmt = ffi::GetRef(src_sref->stmt); // Step 1. Create all the nodes needed for the new sref tree. // After this step // 1) all `parent`s are correct @@ -962,18 +964,18 @@ void ScheduleStateNode::Replace(const tir::StmtSRef& _src_sref, const Stmt& tgt_ const auto* realize = TVM_TYPE_AS(g_func->body, BlockRealizeNode); // Make `child_tgt_stmt` the root block const auto* child_block = TVM_TYPE_AS(child_tgt_stmt, BlockNode); - ObjectPtr new_realize = make_object(*realize); - new_realize->block = GetRef(child_block); + ObjectPtr new_realize = ffi::make_object(*realize); + new_realize->block = ffi::GetRef(child_block); new_func->body = BlockRealize(std::move(new_realize)); // Finally, move the `ref_new_func` back and update `this->mod` new_map->at(g_var) = std::move(ref_new_func); - this->mod = GetRef(new_mod); + this->mod = ffi::GetRef(new_mod); } uint32_t flag = (debug_mask != -1) // ? static_cast(debug_mask) // : std::numeric_limits::max(); if (flag & ScheduleDebugMask::kVerifySRefTree) { - VerifySRefTree(GetRef(this)); + VerifySRefTree(ffi::GetRef(this)); } } @@ -983,10 +985,10 @@ void ScheduleStateNode::DebugVerify() const { ? static_cast(debug_mask) // : std::numeric_limits::max(); if (flag & ScheduleDebugMask::kVerifySRefTree) { - VerifySRefTree(GetRef(this)); + VerifySRefTree(ffi::GetRef(this)); } if (flag & ScheduleDebugMask::kVerifyCachedFlags) { - VerifyCachedFlags(GetRef(this)); + VerifyCachedFlags(ffi::GetRef(this)); } } @@ -997,7 +999,7 @@ BlockInfo ScheduleStateNode::GetBlockInfo(const StmtSRef& block_sref) const { auto it = this->block_info.find(block_sref); CHECK(it != this->block_info.end()) << "IndexError: Cannot find the corresponding BlockScope to the block sref:\n" - << GetRef(block_sref->stmt); + << ffi::GetRef(block_sref->stmt); return it->second; } @@ -1005,7 +1007,7 @@ void ScheduleStateNode::UpdateScopeBlockInfo(const Stmt& stmt) { BlockInfoCollector::Collect(this, stmt); } -TVM_DLL Array GetCachedFlags(const ScheduleState& self, const StmtSRef& block_sref) { +TVM_DLL ffi::Array GetCachedFlags(const ScheduleState& self, const StmtSRef& block_sref) { const BlockInfo& info = self->GetBlockInfo(block_sref); return {Bool(info.affine_binding), // Bool(info.region_cover), // @@ -1024,9 +1026,9 @@ TVM_FFI_STATIC_INIT_BLOCK({ .def_method("tir.schedule.ScheduleStateGetBlockScope", &ScheduleStateNode::GetBlockScope) .def_method("tir.schedule.ScheduleStateReplace", &ScheduleStateNode::Replace) .def("tir.schedule.ScheduleStateGetSRef", - [](ScheduleState self, Stmt stmt) -> Optional { + [](ScheduleState self, Stmt stmt) -> ffi::Optional { auto it = self->stmt2ref.find(stmt.get()); - return it != self->stmt2ref.end() ? it->second : Optional(std::nullopt); + return it != self->stmt2ref.end() ? 
it->second : ffi::Optional(std::nullopt); }) .def("tir.schedule.ScheduleStateGetCachedFlags", GetCachedFlags); }); diff --git a/src/tir/schedule/trace.cc b/src/tir/schedule/trace.cc index 5322f85ac1b4..02f99ddfd2a9 100644 --- a/src/tir/schedule/trace.cc +++ b/src/tir/schedule/trace.cc @@ -27,10 +27,10 @@ TVM_FFI_STATIC_INIT_BLOCK({ TraceNode::RegisterReflection(); }); /**************** Constructors ****************/ -Trace::Trace() { data_ = make_object(); } +Trace::Trace() { data_ = ffi::make_object(); } -Trace::Trace(Array insts, Map decisions) { - ObjectPtr n = make_object(); +Trace::Trace(ffi::Array insts, ffi::Map decisions) { + ObjectPtr n = ffi::make_object(); n->insts = std::move(insts); n->decisions = std::move(decisions); data_ = std::move(n); @@ -38,7 +38,7 @@ Trace::Trace(Array insts, Map decisions) { /**************** Utilities ****************/ -int GetNumValidInstructions(const Array& insts, bool remove_postproc) { +int GetNumValidInstructions(const ffi::Array& insts, bool remove_postproc) { if (!remove_postproc) { return insts.size(); } @@ -55,11 +55,11 @@ int GetNumValidInstructions(const Array& insts, bool remove_postpro /**************** TranslateInputRVs ****************/ -Array TranslateInputRVs(const Array& inputs, - const std::unordered_map& rv_map) { - Array result; +ffi::Array TranslateInputRVs(const ffi::Array& inputs, + const std::unordered_map& rv_map) { + ffi::Array result; result.reserve(inputs.size()); - auto f_subst_with_rv_map = [&rv_map](const Var& var) -> Optional { + auto f_subst_with_rv_map = [&rv_map](const Var& var) -> ffi::Optional { auto it = rv_map.find(var.get()); if (it == rv_map.end()) { return std::nullopt; @@ -67,7 +67,7 @@ Array TranslateInputRVs(const Array& inputs, const Object* dst = it->second; ICHECK(dst->IsInstance()) << "TypeError: Expect 'tir.Var', but gets: " << dst->GetTypeKey(); - return GetRef(static_cast(dst)); + return ffi::GetRef(static_cast(dst)); }; for (const Any& input : inputs) { @@ -81,12 +81,12 @@ Array TranslateInputRVs(const Array& inputs, input.as()) { // RV: var auto it = rv_map.find(input.as()); ICHECK(it != rv_map.end()) << "IndexError: Random variable doesn't exist: " << input; - result.push_back(GetRef(it->second)); + result.push_back(ffi::GetRef(it->second)); } else if (auto expr = input.try_cast()) { // RV: Expr result.push_back(Substitute(expr.value(), f_subst_with_rv_map)); } else if (auto index_map = input.as()) { result.push_back(Substitute(index_map.value(), f_subst_with_rv_map)); - } else if (auto arr = input.as>()) { + } else if (auto arr = input.as>()) { // Recursively convert elements of the array into a new list of ObjectRefs. result.push_back(TranslateInputRVs(arr.value(), rv_map)); } else { @@ -99,20 +99,20 @@ Array TranslateInputRVs(const Array& inputs, } // translate rv to string -Array TranslateInputRVs( - const Array& inputs, - const std::unordered_map& rv_names) { - Array results; +ffi::Array TranslateInputRVs( + const ffi::Array& inputs, + const std::unordered_map& rv_names) { + ffi::Array results; results.reserve(inputs.size()); for (const Any& input : inputs) { if (input == nullptr) { // Case 0. 
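The first `TranslateInputRVs` overload above is the replay path: every random variable recorded in the trace is looked up in `rv_map` and swapped for the counterpart object that the new run produced. A minimal sketch of that pointer-keyed remapping, with a toy RV type standing in for TVM's block/loop/expr random variables (illustrative names only):

#include <cassert>
#include <unordered_map>
#include <vector>

struct RV { int id; };  // toy random variable

// Rewrite each recorded input to the object it corresponds to in the new run;
// non-RV inputs (PODs, strings, ...) would pass through unchanged in the real code.
std::vector<const RV*> TranslateInputs(
    const std::vector<const RV*>& inputs,
    const std::unordered_map<const RV*, const RV*>& rv_map) {
  std::vector<const RV*> out;
  out.reserve(inputs.size());
  for (const RV* in : inputs) {
    auto it = rv_map.find(in);
    assert(it != rv_map.end() && "random variable doesn't exist");
    out.push_back(it->second);
  }
  return out;
}

int main() {
  RV old0{0}, new0{0};
  std::unordered_map<const RV*, const RV*> rv_map{{&old0, &new0}};
  assert(TranslateInputs({&old0}, rv_map)[0] == &new0);
}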
nullptr => None - results.push_back(String("None")); + results.push_back(ffi::String("None")); continue; } // string => "content" if (auto opt_str = input.as()) { - results.push_back(String('"' + (*opt_str).operator std::string() + '"')); + results.push_back(ffi::String('"' + (*opt_str).operator std::string() + '"')); } else if (input.type_index() < ffi::TypeIndex::kTVMFFISmallStr) { // directly put back POD types that are not strings results.push_back(input); @@ -132,19 +132,20 @@ Array TranslateInputRVs( results.push_back(input); } else if (input.as()) { // Case 4: array - results.push_back(TranslateInputRVs(Downcast>(Any(input)), rv_names)); + results.push_back(TranslateInputRVs(Downcast>(Any(input)), rv_names)); } else if (input.as()) { // Case 5: dict results.push_back(input); } else if (input.as()) { // Case 6: IndexMap IndexMap index_map = Downcast(input); - index_map = index_map.RenameVariables([&rv_names](const Var& var) -> Optional { - if (auto it = rv_names.find(var); it != rv_names.end()) { - return it->second; - } - return std::nullopt; - }); + index_map = + index_map.RenameVariables([&rv_names](const Var& var) -> ffi::Optional { + if (auto it = rv_names.find(var); it != rv_names.end()) { + return it->second; + } + return std::nullopt; + }); results.push_back(index_map); } else { LOG(FATAL) << "TypeError: Stringifying is not supported for type: " << input.GetTypeKey(); @@ -154,9 +155,9 @@ Array TranslateInputRVs( return results; } -Array TranslateInputRVs(const Array& inputs, - const std::unordered_map& named_rvs) { - Array results; +ffi::Array TranslateInputRVs(const ffi::Array& inputs, + const std::unordered_map& named_rvs) { + ffi::Array results; results.reserve(inputs.size()); for (const Any& input : inputs) { if (input.type_index() < ffi::TypeIndex::kTVMFFISmallStr) { @@ -171,7 +172,7 @@ Array TranslateInputRVs(const Array& inputs, } // Case 4. array if (input.as()) { - results.push_back(TranslateInputRVs(Downcast>(input), named_rvs)); + results.push_back(TranslateInputRVs(Downcast>(input), named_rvs)); continue; } // Case 5. dict @@ -189,7 +190,7 @@ Array TranslateInputRVs(const Array& inputs, // Case 6. IndexMap if (obj.as()) { IndexMap index_map = Downcast(obj); - index_map = Substitute(index_map, [&named_rvs](const Var& var) -> Optional { + index_map = Substitute(index_map, [&named_rvs](const Var& var) -> ffi::Optional { auto it = named_rvs.find(var->name_hint); if (it != named_rvs.end()) { return Downcast(it->second); @@ -205,7 +206,7 @@ } // Case 2. string if (size >= 2 && name[0] == '"' && name[size - 1] == '"') { - results.push_back(String(std::string(name + 1, size - 2))); + results.push_back(ffi::String(std::string(name + 1, size - 2))); continue; } // Case 0 & 1.
None, BlockRV, LoopRV, VarRV @@ -218,7 +219,7 @@ Array TranslateInputRVs(const Array& inputs, /**************** TranslateAddOutputRVs ****************/ -void TranslateAddOutputRVs(const Array& old_outputs, const Array& new_outputs, +void TranslateAddOutputRVs(const ffi::Array& old_outputs, const ffi::Array& new_outputs, std::unordered_map* rv_map) { ICHECK_EQ(old_outputs.size(), new_outputs.size()); int n = old_outputs.size(); @@ -230,17 +231,17 @@ void TranslateAddOutputRVs(const Array& old_outputs, const Array& new_ } } -Array TranslateAddOutputRVs( - const Array& outputs, - std::unordered_map* rv_names) { - Array results; +ffi::Array TranslateAddOutputRVs( + const ffi::Array& outputs, + std::unordered_map* rv_names) { + ffi::Array results; results.reserve(outputs.size()); for (const Any& output : outputs) { int i = rv_names->size(); ICHECK(!rv_names->count(output.cast())) << "ValueError: The random variable has been produced once: " << rv_names->at(output.cast()); - String result; + ffi::String result; if (output == nullptr) { result = "_"; } else if (output.as()) { @@ -260,12 +261,13 @@ Array TranslateAddOutputRVs( return results; } -void TranslateAddOutputRVs(const Array& old_outputs, const Array& new_outputs, +void TranslateAddOutputRVs(const ffi::Array& old_outputs, + const ffi::Array& new_outputs, std::unordered_map* named_rvs) { ICHECK_EQ(old_outputs.size(), new_outputs.size()); int n = old_outputs.size(); for (int i = 0; i < n; ++i) { - named_rvs->emplace(Downcast(old_outputs[i]), new_outputs[i].cast()); + named_rvs->emplace(Downcast(old_outputs[i]), new_outputs[i].cast()); } } @@ -282,7 +284,7 @@ void TraceNode::Append(Instruction inst, Any decision) { insts.push_back(std::move(inst)); } -Optional TraceNode::Pop() { +ffi::Optional TraceNode::Pop() { if (insts.empty()) { return std::nullopt; } @@ -298,8 +300,8 @@ Optional TraceNode::Pop() { void TraceNode::ApplyToSchedule( Schedule sch, bool remove_postproc, - ffi::TypedFunction& inputs, // - const Array& attrs, // + ffi::TypedFunction& inputs, // + const ffi::Array& attrs, // const Any& decision)> decision_provider) const { std::unordered_map rv_map; @@ -307,21 +309,21 @@ void TraceNode::ApplyToSchedule( if (remove_postproc && inst->kind->IsPostproc()) { break; } - Array inputs = TranslateInputRVs(inst->inputs, rv_map); - Array attrs = inst->attrs; + ffi::Array inputs = TranslateInputRVs(inst->inputs, rv_map); + ffi::Array attrs = inst->attrs; Any decision = this->GetDecision(inst); if (decision_provider != nullptr) { decision = decision_provider(inst, inputs, attrs, decision); } - Array outputs = inst->kind->f_apply_to_schedule(sch, inputs, attrs, decision); + ffi::Array outputs = inst->kind->f_apply_to_schedule(sch, inputs, attrs, decision); TranslateAddOutputRVs(inst->outputs, outputs, &rv_map); } } ObjectRef TraceNode::AsJSON(bool remove_postproc) const { - std::unordered_map rv_names; - Array json_insts; - Array json_decisions; + std::unordered_map rv_names; + ffi::Array json_insts; + ffi::Array json_decisions; json_insts.reserve(this->insts.size()); json_decisions.reserve(this->insts.size()); @@ -331,40 +333,40 @@ ObjectRef TraceNode::AsJSON(bool remove_postproc) const { if (remove_postproc && kind->IsPostproc()) { break; } - json_insts.push_back(Array{ + json_insts.push_back(ffi::Array{ /* 0: inst name */ kind->name, /* 1: inputs */ TranslateInputRVs(inst->inputs, rv_names), /* 2: attrs */ kind->f_attrs_as_json != nullptr ? 
kind->f_attrs_as_json(inst->attrs) : ObjectRef(inst->attrs), /* 3: outputs */ TranslateAddOutputRVs(inst->outputs, &rv_names), }); - if (auto decision = this->GetDecision(inst).cast>()) { - json_decisions.push_back(Array{ + if (auto decision = this->GetDecision(inst).cast>()) { + json_decisions.push_back(ffi::Array{ /* 0: index */ Integer(i), /* 1: decision */ decision.value(), }); } ++i; } - return Array{ + return ffi::Array{ /* 0: trace */ std::move(json_insts), /* 1: decision */ std::move(json_decisions), }; } -Array TraceNode::AsPython(bool remove_postproc) const { - std::unordered_map rv_names; - Array py_trace; +ffi::Array TraceNode::AsPython(bool remove_postproc) const { + std::unordered_map rv_names; + ffi::Array py_trace; py_trace.reserve(this->insts.size()); for (const Instruction& inst : this->insts) { if (remove_postproc && inst->kind->IsPostproc()) { break; } - Array attrs; + ffi::Array attrs; attrs.reserve(inst->attrs.size()); for (const Any& obj : inst->attrs) { if (auto opt_str = obj.as()) { - attrs.push_back(String('"' + (*opt_str).operator std::string() + '"')); + attrs.push_back(ffi::String('"' + (*opt_str).operator std::string() + '"')); } else { attrs.push_back(obj); } @@ -379,8 +381,8 @@ Array TraceNode::AsPython(bool remove_postproc) const { } void Trace::ApplyJSONToSchedule(ObjectRef json, Schedule sch) { - Array json_insts{nullptr}; - Array json_decisions{nullptr}; + ffi::Array json_insts{nullptr}; + ffi::Array json_decisions{nullptr}; // Parse `json` into `json_insts` and `json_decisions` try { const ffi::ArrayObj* arr = json.as(); @@ -388,8 +390,8 @@ void Trace::ApplyJSONToSchedule(ObjectRef json, Schedule sch) { const auto* arr0 = arr->at(0).as(); const auto* arr1 = arr->at(1).as(); ICHECK(arr0 && arr1); - json_insts = GetRef>(arr0); - json_decisions = GetRef>(arr1); + json_insts = ffi::GetRef>(arr0); + json_decisions = ffi::GetRef>(arr1); } catch (const tvm::Error& e) { LOG(FATAL) << "ValueError: The json entry of a trace should contain two arrays, an array of " "instructions and an array of decisions, but gets: " @@ -421,18 +423,18 @@ void Trace::ApplyJSONToSchedule(ObjectRef json, Schedule sch) { int i = 0; for (const Any& inst_entry : json_insts) { InstructionKind kind{nullptr}; - Array inputs{nullptr}; - Array attrs{nullptr}; - Array outputs{ObjectPtr{nullptr}}; + ffi::Array inputs{nullptr}; + ffi::Array attrs{nullptr}; + ffi::Array outputs{ObjectPtr{nullptr}}; // Parse the entry try { const auto* arr = inst_entry.as(); ICHECK(arr && arr->size() == 4); ffi::String arr0 = arr->at(0).cast(); kind = InstructionKind::Get(arr0); - inputs = arr->at(1).cast>(); - attrs = arr->at(2).cast>(); - outputs = arr->at(3).cast>(); + inputs = arr->at(1).cast>(); + attrs = arr->at(2).cast>(); + outputs = arr->at(3).cast>(); } catch (const tvm::Error& e) { LOG(FATAL) << "ValueError: Each entry of a json instruction should be a tuple [inst_name, " "inputs, attrs, outputs], but gets: " @@ -446,7 +448,7 @@ void Trace::ApplyJSONToSchedule(ObjectRef json, Schedule sch) { attrs = kind->f_attrs_from_json(attrs); } // Apply to the schedule - Array new_outputs = kind->f_apply_to_schedule(sch, inputs, attrs, decisions[i]); + ffi::Array new_outputs = kind->f_apply_to_schedule(sch, inputs, attrs, decisions[i]); // Parse outputs TranslateAddOutputRVs(outputs, new_outputs, &named_rvs); ++i; @@ -457,9 +459,9 @@ void Trace::ApplyJSONToSchedule(ObjectRef json, Schedule sch) { Trace TraceNode::WithDecision(Instruction inst, Any decision, bool remove_postproc) const { int n_insts = 
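From the construction above, `AsJSON` emits a two-element array: the instruction list, where each entry is [name, inputs, attrs, outputs], and the decision list, where each entry is [instruction index, decision]. A schematic of the serialized shape, written as C++ comments; the field values are made up for illustration and are not taken from a real trace:

// [ /* 0: trace */
//   [ ["GetBlock",          ["None"], ["\"main\""], ["b0"]],
//     ["GetLoops",          ["b0"],   [],           ["l0", "l1"]],
//     ["SamplePerfectTile", ["l0"],   [2, 16],      ["v0", "v1"]] ],
//   /* 1: decisions */
//   [ [2, [4, 8]] ]  // instruction #2 recorded the sampled decision [4, 8]
// ]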
GetNumValidInstructions(this->insts, remove_postproc); - Array new_insts = - Array{this->insts.begin(), this->insts.begin() + n_insts}; - Map new_decisions{this->decisions.begin(), this->decisions.end()}; + ffi::Array new_insts = + ffi::Array{this->insts.begin(), this->insts.begin() + n_insts}; + ffi::Map new_decisions{this->decisions.begin(), this->decisions.end()}; new_decisions.Set(std::move(inst), std::move(decision)); return Trace(new_insts, new_decisions); } @@ -512,8 +514,8 @@ Trace TraceNode::Simplified(bool remove_postproc) const { } } } - return Trace(Array(new_insts.rbegin(), new_insts.rend()), - Map(new_decisions)); + return Trace(ffi::Array(new_insts.rbegin(), new_insts.rend()), + ffi::Map(new_decisions)); } /**************** Repr ****************/ @@ -524,9 +526,9 @@ TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) ICHECK_NOTNULL(self); p->stream << "# from tvm import tir\n"; p->stream << "def apply_trace(sch: tir.Schedule) -> None:\n"; - Array repr = self->AsPython(/*remove_postproc=*/false); + ffi::Array repr = self->AsPython(/*remove_postproc=*/false); bool is_first = true; - for (const String& line : repr) { + for (const ffi::String& line : repr) { if (is_first) { is_first = false; } else { @@ -553,7 +555,7 @@ struct EnterPostprocTraits : public UnpackedInstTraits { static void UnpackedApplyToSchedule(Schedule sch) { return sch->EnterPostproc(); } - static String UnpackedAsPython(Array outputs) { + static ffi::String UnpackedAsPython(ffi::Array outputs) { PythonAPICall py("enter_postproc"); return py.Str(); } @@ -570,12 +572,13 @@ TVM_FFI_STATIC_INIT_BLOCK({ namespace refl = tvm::ffi::reflection; refl::GlobalDef() .def("tir.schedule.Trace", - [](Optional> insts, Optional> decisions) { - return Trace(insts.value_or(Array()), decisions.value_or({})); + [](ffi::Optional> insts, + ffi::Optional> decisions) { + return Trace(insts.value_or(ffi::Array()), decisions.value_or({})); }) .def_method("tir.schedule.TraceGetDecision", &TraceNode::GetDecision) .def("tir.schedule.TraceAppend", - [](Trace self, Instruction inst, Optional decision) { + [](Trace self, Instruction inst, ffi::Optional decision) { if (decision.defined()) { return self->Append(inst, decision.value()); } else { diff --git a/src/tir/schedule/traced_schedule.cc b/src/tir/schedule/traced_schedule.cc index b9718c1a5f9c..8129f43833c4 100644 --- a/src/tir/schedule/traced_schedule.cc +++ b/src/tir/schedule/traced_schedule.cc @@ -24,7 +24,7 @@ namespace tir { Schedule Schedule::Traced(IRModule mod, support::LinearCongruentialEngine::TRandState seed, int debug_mask, ScheduleErrorRenderLevel error_render_level, bool enable_check) { - ObjectPtr n = make_object(); + ObjectPtr n = ffi::make_object(); n->state_ = ScheduleState(mod, debug_mask, enable_check); n->error_render_level_ = error_render_level; n->symbol_table_ = {}; @@ -41,7 +41,7 @@ Schedule Schedule::Traced(IRModule mod, support::LinearCongruentialEngine::TRand } Schedule TracedScheduleNode::Copy() { - ObjectPtr n = make_object(); + ObjectPtr n = ffi::make_object(); n->error_render_level_ = this->error_render_level_; ConcreteScheduleNode::Copy(&n->state_, &n->symbol_table_); n->func_working_on_ = this->func_working_on_; @@ -53,9 +53,9 @@ Schedule TracedScheduleNode::Copy() { /******** Schedule: Sampling ********/ -ExprRV TracedScheduleNode::SampleCategorical(const Array& candidates, - const Array& probs, - Optional decision) { +ExprRV TracedScheduleNode::SampleCategorical(const ffi::Array& candidates, + const ffi::Array& probs, + ffi::Optional decision) { ExprRV result = 
CreateRV(tir::SampleCategorical(&this->rand_state_, candidates, probs, &decision)); static const InstructionKind& kind = InstructionKind::Get("SampleCategorical"); @@ -67,11 +67,11 @@ ExprRV TracedScheduleNode::SampleCategorical(const Array& candidates, return result; } -Array TracedScheduleNode::SamplePerfectTile(const LoopRV& loop_rv, int n, - int max_innermost_factor, - Optional> decision) { +ffi::Array TracedScheduleNode::SamplePerfectTile( + const LoopRV& loop_rv, int n, int max_innermost_factor, + ffi::Optional> decision) { // use a None RV object to denote auto-inferred tile factors. - Array results = + ffi::Array results = CreateRV(tir::SamplePerfectTile(&this->rand_state_, this->GetSRef(loop_rv), n, max_innermost_factor, &decision), /*convert_negone_to_none=*/true); @@ -84,10 +84,10 @@ Array TracedScheduleNode::SamplePerfectTile(const LoopRV& loop_rv, int n -Array TracedScheduleNode::SamplePartitionedTile(const LoopRV& loop_rv, int n, - int partition_pos, int innerpart_factor, - Optional> decision) { - Array results = CreateRV(tir::SamplePartitionedTile( +ffi::Array TracedScheduleNode::SamplePartitionedTile( + const LoopRV& loop_rv, int n, int partition_pos, int innerpart_factor, + ffi::Optional> decision) { + ffi::Array results = CreateRV(tir::SamplePartitionedTile( &this->rand_state_, this->GetSRef(loop_rv), n, partition_pos, innerpart_factor, &decision)); static const InstructionKind& kind = InstructionKind::Get("SamplePartitionedTile"); @@ -101,7 +101,7 @@ Array TracedScheduleNode::SamplePartitionedTile(const LoopRV& loop_rv, i } LoopRV TracedScheduleNode::SampleComputeLocation(const BlockRV& block_rv, - Optional decision) { + ffi::Optional decision) { LoopRV result = CreateRV(tir::SampleComputeLocation(this->state_, &this->rand_state_, this->GetSRef(block_rv), &decision)); @@ -116,7 +116,8 @@ LoopRV TracedScheduleNode::SampleComputeLocation(const BlockRV& block_rv, /******** Schedule: Get blocks & loops ********/ -BlockRV TracedScheduleNode::GetBlock(const String& name, const Optional& func_name) { +BlockRV TracedScheduleNode::GetBlock(const ffi::String& name, + const ffi::Optional& func_name) { GlobalVar gv = NullValue(); if (func_name.has_value()) { gv = state_->mod->GetGlobalVar(func_name.value()); @@ -137,8 +138,8 @@ BlockRV TracedScheduleNode::GetBlock(const String& name, const Optional& return result; } -Array TracedScheduleNode::GetLoops(const BlockRV& block_rv) { - Array results = ConcreteScheduleNode::GetLoops(block_rv); +ffi::Array TracedScheduleNode::GetLoops(const BlockRV& block_rv) { + ffi::Array results = ConcreteScheduleNode::GetLoops(block_rv); static const InstructionKind& kind = InstructionKind::Get("GetLoops"); trace_->Append(/*inst=*/Instruction(/*kind=*/kind, // @@ -148,8 +149,8 @@ Array TracedScheduleNode::GetLoops(const BlockRV& block_rv) { return results; } -Array TracedScheduleNode::GetChildBlocks(const BlockRV& block_rv) { - Array results = ConcreteScheduleNode::GetChildBlocks(block_rv); +ffi::Array TracedScheduleNode::GetChildBlocks(const BlockRV& block_rv) { + ffi::Array results = ConcreteScheduleNode::GetChildBlocks(block_rv); static const InstructionKind& kind = InstructionKind::Get("GetChildBlocks"); trace_->Append(/*inst=*/Instruction(/*kind=*/kind, // @@ -159,8 +160,8 @@ Array TracedScheduleNode::GetChildBlocks(const BlockRV& block_rv) { return results; } -Array TracedScheduleNode::GetChildBlocks(const LoopRV& loop_rv) { - Array results = ConcreteScheduleNode::GetChildBlocks(loop_rv); +ffi::Array
TracedScheduleNode::GetChildBlocks(const LoopRV& loop_rv) { + ffi::Array results = ConcreteScheduleNode::GetChildBlocks(loop_rv); static const InstructionKind& kind = InstructionKind::Get("GetChildBlocks"); trace_->Append(/*inst=*/Instruction(/*kind=*/kind, // @@ -170,8 +171,8 @@ Array TracedScheduleNode::GetChildBlocks(const LoopRV& loop_rv) { return results; } -Array TracedScheduleNode::GetProducers(const BlockRV& block_rv) { - Array results = ConcreteScheduleNode::GetProducers(block_rv); +ffi::Array TracedScheduleNode::GetProducers(const BlockRV& block_rv) { + ffi::Array results = ConcreteScheduleNode::GetProducers(block_rv); static const InstructionKind& kind = InstructionKind::Get("GetProducers"); trace_->Append(/*inst=*/Instruction(/*kind=*/kind, // @@ -181,8 +182,8 @@ Array TracedScheduleNode::GetProducers(const BlockRV& block_rv) { return results; } -Array TracedScheduleNode::GetConsumers(const BlockRV& block_rv) { - Array results = ConcreteScheduleNode::GetConsumers(block_rv); +ffi::Array TracedScheduleNode::GetConsumers(const BlockRV& block_rv) { + ffi::Array results = ConcreteScheduleNode::GetConsumers(block_rv); static const InstructionKind& kind = InstructionKind::Get("GetConsumers"); trace_->Append(/*inst=*/Instruction(/*kind=*/kind, // @@ -192,8 +193,8 @@ Array TracedScheduleNode::GetConsumers(const BlockRV& block_rv) { return results; } -Array TracedScheduleNode::GetOutputBlocks(const BlockRV& scope_block_rv) { - Array results = ConcreteScheduleNode::GetOutputBlocks(scope_block_rv); +ffi::Array TracedScheduleNode::GetOutputBlocks(const BlockRV& scope_block_rv) { + ffi::Array results = ConcreteScheduleNode::GetOutputBlocks(scope_block_rv); static const InstructionKind& kind = InstructionKind::Get("GetOutputBlocks"); trace_->Append(/*inst=*/Instruction(/*kind=*/kind, // @@ -205,7 +206,7 @@ Array TracedScheduleNode::GetOutputBlocks(const BlockRV& scope_block_rv /******** Schedule: Transform loops ********/ -LoopRV TracedScheduleNode::Merge(const Array& loop_rvs) { +LoopRV TracedScheduleNode::Merge(const ffi::Array& loop_rvs) { LoopRV result = ConcreteScheduleNode::Merge(loop_rvs); static const InstructionKind& kind = InstructionKind::Get("Merge"); trace_->Append(/*inst=*/Instruction(/*kind=*/kind, @@ -215,7 +216,7 @@ LoopRV TracedScheduleNode::Merge(const Array& loop_rvs) { return result; } -LoopRV TracedScheduleNode::Fuse(const Array& loop_rvs, bool preserve_unit_loops) { +LoopRV TracedScheduleNode::Fuse(const ffi::Array& loop_rvs, bool preserve_unit_loops) { LoopRV result = ConcreteScheduleNode::Fuse(loop_rvs, preserve_unit_loops); static const InstructionKind& kind = InstructionKind::Get("Fuse"); @@ -226,13 +227,13 @@ LoopRV TracedScheduleNode::Fuse(const Array& loop_rvs, bool preserve_uni return result; } -Array TracedScheduleNode::Split(const LoopRV& loop_rv, - const Array>& factor_rvs, - bool preserve_unit_iters, bool disable_predication) { - Array results = +ffi::Array TracedScheduleNode::Split(const LoopRV& loop_rv, + const ffi::Array>& factor_rvs, + bool preserve_unit_iters, bool disable_predication) { + ffi::Array results = ConcreteScheduleNode::Split(loop_rv, factor_rvs, preserve_unit_iters, disable_predication); - Array inputs; + ffi::Array inputs; inputs.reserve(1 + factor_rvs.size()); inputs.push_back(loop_rv); for (const auto& obj : factor_rvs) { @@ -243,18 +244,18 @@ Array TracedScheduleNode::Split(const LoopRV& loop_rv, trace_->Append( /*inst=*/Instruction(/*kind=*/kind, /*inputs=*/inputs, - /*attrs=*/Array({preserve_unit_iters, disable_predication}), + 
/*attrs=*/ffi::Array({preserve_unit_iters, disable_predication}), /*outputs=*/results)); return results; } -Array TracedScheduleNode::LoopPartition(const LoopRV& loop_rv, - const Array>& factor_rvs, - bool preserve_unit_iters) { - Array results = +ffi::Array TracedScheduleNode::LoopPartition( + const LoopRV& loop_rv, const ffi::Array>& factor_rvs, + bool preserve_unit_iters) { + ffi::Array results = ConcreteScheduleNode::LoopPartition(loop_rv, factor_rvs, preserve_unit_iters); - Array inputs; + ffi::Array inputs; inputs.reserve(1 + factor_rvs.size()); inputs.push_back(loop_rv); for (const auto& obj : factor_rvs) { @@ -269,7 +270,7 @@ Array TracedScheduleNode::LoopPartition(const LoopRV& loop_rv, return results; } -void TracedScheduleNode::Reorder(const Array& ordered_loop_rvs) { +void TracedScheduleNode::Reorder(const ffi::Array& ordered_loop_rvs) { ConcreteScheduleNode::Reorder(ordered_loop_rvs); static const InstructionKind& kind = InstructionKind::Get("Reorder"); @@ -280,7 +281,7 @@ void TracedScheduleNode::Reorder(const Array& ordered_loop_rvs) { } void TracedScheduleNode::ReorderBlockIterVar(const BlockRV& block_rv, - const Array new_order) { + const ffi::Array new_order) { ConcreteScheduleNode::ReorderBlockIterVar(block_rv, new_order); static const InstructionKind& kind = InstructionKind::Get("ReorderBlockIterVar"); trace_->Append(/*inst=*/Instruction(/*kind=*/kind, @@ -332,7 +333,7 @@ void TracedScheduleNode::Vectorize(const LoopRV& loop_rv) { /*outputs=*/{})); } -void TracedScheduleNode::Bind(const LoopRV& loop_rv, const String& thread_axis) { +void TracedScheduleNode::Bind(const LoopRV& loop_rv, const ffi::String& thread_axis) { ConcreteScheduleNode::Bind(loop_rv, thread_axis); static const InstructionKind& kind = InstructionKind::Get("Bind"); @@ -354,8 +355,8 @@ void TracedScheduleNode::Unroll(const LoopRV& loop_rv) { /******** Schedule: Insert cache stages ********/ BlockRV TracedScheduleNode::CacheRead(const BlockRV& block_rv, int read_buffer_index, - const String& storage_scope, - const Array consumer_blocks) { + const ffi::String& storage_scope, + const ffi::Array consumer_blocks) { BlockRV result = ConcreteScheduleNode::CacheRead(block_rv, read_buffer_index, storage_scope, consumer_blocks); @@ -368,8 +369,8 @@ BlockRV TracedScheduleNode::CacheRead(const BlockRV& block_rv, int read_buffer_i } BlockRV TracedScheduleNode::CacheWrite(const BlockRV& block_rv, int write_buffer_index, - const String& storage_scope, - const Array consumer_blocks) { + const ffi::String& storage_scope, + const ffi::Array consumer_blocks) { BlockRV result = ConcreteScheduleNode::CacheWrite(block_rv, write_buffer_index, storage_scope, consumer_blocks); @@ -382,7 +383,7 @@ BlockRV TracedScheduleNode::CacheWrite(const BlockRV& block_rv, int write_buffer } BlockRV TracedScheduleNode::ReindexCacheRead(const BlockRV& block_rv, int read_buffer_index, - const String& storage_scope, + const ffi::String& storage_scope, const IndexMap& index_map) { BlockRV result = ConcreteScheduleNode::ReindexCacheRead(block_rv, read_buffer_index, storage_scope, index_map); @@ -398,7 +399,7 @@ BlockRV TracedScheduleNode::ReindexCacheRead(const BlockRV& block_rv, int read_b } BlockRV TracedScheduleNode::ReindexCacheWrite(const BlockRV& block_rv, int write_buffer_index, - const String& storage_scope, + const ffi::String& storage_scope, const IndexMap& index_map) { BlockRV result = ConcreteScheduleNode::ReindexCacheWrite(block_rv, write_buffer_index, storage_scope, index_map); @@ -413,11 +414,11 @@ BlockRV 
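Every primitive in traced_schedule.cc follows the same record-and-delegate shape: perform the transformation via the concrete base class, then append an Instruction naming the kind, inputs, attrs, and outputs. A stripped-down sketch of that pattern with toy classes (hypothetical names, not TVM's API):

#include <iostream>
#include <string>
#include <vector>

struct Instruction {
  std::string kind;
  std::vector<std::string> inputs, attrs, outputs;
};

class ConcreteSchedule {
 public:
  virtual ~ConcreteSchedule() = default;
  virtual std::string Fuse(const std::vector<std::string>& loops) {
    return "fused(" + std::to_string(loops.size()) + ")";  // the real transformation
  }
};

class TracedSchedule : public ConcreteSchedule {
 public:
  std::string Fuse(const std::vector<std::string>& loops) override {
    std::string result = ConcreteSchedule::Fuse(loops);  // 1. delegate to the base
    trace_.push_back({"Fuse", loops, {}, {result}});     // 2. record the instruction
    return result;
  }
  const std::vector<Instruction>& trace() const { return trace_; }

 private:
  std::vector<Instruction> trace_;
};

int main() {
  TracedSchedule sch;
  sch.Fuse({"l0", "l1"});
  for (const Instruction& inst : sch.trace()) std::cout << inst.kind << "\n";  // prints "Fuse"
}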
TracedScheduleNode::ReindexCacheWrite(const BlockRV& block_rv, int write return result; } -Array TracedScheduleNode::CacheInplace(const BlockRV& block_rv, int read_buffer_index, - const String& storage_scope) { - Array result = +ffi::Array TracedScheduleNode::CacheInplace(const BlockRV& block_rv, int read_buffer_index, + const ffi::String& storage_scope) { + ffi::Array result = ConcreteScheduleNode::CacheInplace(block_rv, read_buffer_index, storage_scope); - Array results; + ffi::Array results; for (const BlockRV& r : result) { results.push_back(r); } @@ -429,10 +430,12 @@ Array TracedScheduleNode::CacheInplace(const BlockRV& block_rv, int rea return result; } -Array TracedScheduleNode::CacheIndex(const BlockRV& block_rv, const String& storage_scope, - int cse_thresh) { - Array result = ConcreteScheduleNode::CacheIndex(block_rv, storage_scope, cse_thresh); - Array outputs; +ffi::Array TracedScheduleNode::CacheIndex(const BlockRV& block_rv, + const ffi::String& storage_scope, + int cse_thresh) { + ffi::Array result = + ConcreteScheduleNode::CacheIndex(block_rv, storage_scope, cse_thresh); + ffi::Array outputs; for (const BlockRV& r : result) { outputs.push_back(r); } @@ -459,7 +462,7 @@ BlockRV TracedScheduleNode::ReIndex(const BlockRV& block_rv, int buffer_index, /******** Schedule: Data movement ********/ BlockRV TracedScheduleNode::ReadAt(const LoopRV& loop_rv, const BlockRV& block_rv, - int read_buffer_index, const String& storage_scope) { + int read_buffer_index, const ffi::String& storage_scope) { BlockRV result = ConcreteScheduleNode::ReadAt(loop_rv, block_rv, read_buffer_index, storage_scope); @@ -472,7 +475,7 @@ BlockRV TracedScheduleNode::ReadAt(const LoopRV& loop_rv, const BlockRV& block_r } BlockRV TracedScheduleNode::WriteAt(const LoopRV& loop_rv, const BlockRV& block_rv, - int write_buffer_index, const String& storage_scope) { + int write_buffer_index, const ffi::String& storage_scope) { BlockRV result = ConcreteScheduleNode::WriteAt(loop_rv, block_rv, write_buffer_index, storage_scope); @@ -565,7 +568,7 @@ void TracedScheduleNode::StorageAlign(const BlockRV& block_rv, int buffer_index, } void TracedScheduleNode::SetScope(const BlockRV& block_rv, int buffer_index, - const String& storage_scope) { + const ffi::String& storage_scope) { ConcreteScheduleNode::SetScope(block_rv, buffer_index, storage_scope); static const InstructionKind& kind = InstructionKind::Get("SetScope"); trace_->Append(/*inst=*/Instruction( @@ -576,7 +579,7 @@ void TracedScheduleNode::SetScope(const BlockRV& block_rv, int buffer_index, } void TracedScheduleNode::UnsafeSetDType(const BlockRV& block_rv, int buffer_index, - const String& dtype) { + const ffi::String& dtype) { ConcreteScheduleNode::UnsafeSetDType(block_rv, buffer_index, dtype); static const InstructionKind& kind = InstructionKind::Get("UnsafeSetDType"); trace_->Append(/*inst=*/Instruction( @@ -599,7 +602,7 @@ BlockRV TracedScheduleNode::Blockize(const LoopRV& loop_rv, bool preserve_unit_i return new_block; } -BlockRV TracedScheduleNode::Blockize(const Array& blocks, bool preserve_unit_iters) { +BlockRV TracedScheduleNode::Blockize(const ffi::Array& blocks, bool preserve_unit_iters) { BlockRV new_block = ConcreteScheduleNode::Blockize(blocks, preserve_unit_iters); static const InstructionKind& kind = InstructionKind::Get("Blockize"); trace_->Append(/*inst=*/Instruction( @@ -610,7 +613,7 @@ BlockRV TracedScheduleNode::Blockize(const Array& blocks, bool preserve return new_block; } -void TracedScheduleNode::Tensorize(const LoopRV& loop_rv, const 
String& intrin, +void TracedScheduleNode::Tensorize(const LoopRV& loop_rv, const ffi::String& intrin, bool preserve_unit_iters) { ConcreteScheduleNode::Tensorize(loop_rv, intrin, preserve_unit_iters); static const InstructionKind& kind = InstructionKind::Get("Tensorize"); @@ -621,7 +624,7 @@ void TracedScheduleNode::Tensorize(const LoopRV& loop_rv, const String& intrin, /*outputs=*/{})); } -void TracedScheduleNode::Tensorize(const BlockRV& block_rv, const String& intrin, +void TracedScheduleNode::Tensorize(const BlockRV& block_rv, const ffi::String& intrin, bool preserve_unit_iters) { ConcreteScheduleNode::Tensorize(block_rv, intrin, preserve_unit_iters); static const InstructionKind& kind = InstructionKind::Get("Tensorize"); @@ -634,7 +637,7 @@ void TracedScheduleNode::Tensorize(const BlockRV& block_rv, const String& intrin /******** Schedule: Annotation ********/ -void TracedScheduleNode::Annotate(const LoopRV& loop_rv, const String& ann_key, +void TracedScheduleNode::Annotate(const LoopRV& loop_rv, const ffi::String& ann_key, const Any& ann_val) { ConcreteScheduleNode::Annotate(loop_rv, ann_key, ann_val); static const InstructionKind& kind = InstructionKind::Get("Annotate"); @@ -644,7 +647,7 @@ void TracedScheduleNode::Annotate(const LoopRV& loop_rv, const String& ann_key, /*outputs=*/{})); } -void TracedScheduleNode::Annotate(const BlockRV& block_rv, const String& ann_key, +void TracedScheduleNode::Annotate(const BlockRV& block_rv, const ffi::String& ann_key, const Any& ann_val) { ConcreteScheduleNode::Annotate(block_rv, ann_key, ann_val); static const InstructionKind& kind = InstructionKind::Get("Annotate"); @@ -654,7 +657,7 @@ void TracedScheduleNode::Annotate(const BlockRV& block_rv, const String& ann_key /*outputs=*/{})); } -void TracedScheduleNode::Unannotate(const LoopRV& loop_rv, const String& ann_key) { +void TracedScheduleNode::Unannotate(const LoopRV& loop_rv, const ffi::String& ann_key) { ConcreteScheduleNode::Unannotate(loop_rv, ann_key); static const InstructionKind& kind = InstructionKind::Get("Unannotate"); trace_->Append(/*inst=*/Instruction(/*kind=*/kind, @@ -663,7 +666,7 @@ void TracedScheduleNode::Unannotate(const LoopRV& loop_rv, const String& ann_key /*outputs=*/{})); } -void TracedScheduleNode::Unannotate(const BlockRV& block_rv, const String& ann_key) { +void TracedScheduleNode::Unannotate(const BlockRV& block_rv, const ffi::String& ann_key) { ConcreteScheduleNode::Unannotate(block_rv, ann_key); static const InstructionKind& kind = InstructionKind::Get("Unannotate"); trace_->Append(/*inst=*/Instruction(/*kind=*/kind, @@ -677,7 +680,7 @@ void TracedScheduleNode::Unannotate(const BlockRV& block_rv, const String& ann_k void TracedScheduleNode::TransformLayout(const BlockRV& block_rv, int buffer_index, BufferIndexType buffer_index_type, const IndexMap& index_map, - const Optional& pad_value, + const ffi::Optional& pad_value, bool assume_injective_transform) { ConcreteScheduleNode::TransformLayout(block_rv, buffer_index, buffer_index_type, index_map, pad_value, assume_injective_transform); @@ -704,7 +707,7 @@ void TracedScheduleNode::TransformBlockLayout(const BlockRV& block_rv, const Ind void TracedScheduleNode::SetAxisSeparator(const BlockRV& block_rv, int buffer_index, BufferIndexType buffer_index_type, - const Array& axis_separators) { + const ffi::Array& axis_separators) { ConcreteScheduleNode::SetAxisSeparator(block_rv, buffer_index, buffer_index_type, axis_separators); static const InstructionKind& kind = InstructionKind::Get("SetAxisSeparator"); @@ -727,7 
+730,7 @@ BlockRV TracedScheduleNode::DecomposePadding(const BlockRV& block_rv, const Loop return new_block; } -void TracedScheduleNode::PadEinsum(const BlockRV& block_rv, const Array& padding) { +void TracedScheduleNode::PadEinsum(const BlockRV& block_rv, const ffi::Array& padding) { ConcreteScheduleNode::PadEinsum(block_rv, padding); static const InstructionKind& kind = InstructionKind::Get("PadEinsum"); trace_->Append(/*inst=*/Instruction( @@ -760,8 +763,9 @@ void TracedScheduleNode::EnterPostproc() { /*outputs=*/{})); } -void TracedScheduleNode::UnsafeHideBufferAccess(const BlockRV& block_rv, const String& buf_type, - const Array& buf_index_array) { +void TracedScheduleNode::UnsafeHideBufferAccess(const BlockRV& block_rv, + const ffi::String& buf_type, + const ffi::Array& buf_index_array) { ConcreteScheduleNode::UnsafeHideBufferAccess(block_rv, buf_type, buf_index_array); static const InstructionKind& kind = InstructionKind::Get("UnsafeHideBufferAccess"); trace_->Append(/*inst=*/Instruction( diff --git a/src/tir/schedule/traced_schedule.h b/src/tir/schedule/traced_schedule.h index 024c3fb873f2..cf9e53a3a78d 100644 --- a/src/tir/schedule/traced_schedule.h +++ b/src/tir/schedule/traced_schedule.h @@ -38,64 +38,69 @@ class TracedScheduleNode : public ConcreteScheduleNode { ~TracedScheduleNode() = default; public: - Optional trace() const final { return trace_; } + ffi::Optional trace() const final { return trace_; } Schedule Copy() final; public: /******** Schedule: Sampling ********/ - ExprRV SampleCategorical(const Array& candidates, const Array& probs, - Optional decision = std::nullopt) final; - Array SamplePerfectTile(const LoopRV& loop_rv, int n, int max_innermost_factor, - Optional> decision = std::nullopt) final; - Array SamplePartitionedTile(const LoopRV& loop_rv, int n, int partition_pos, - int innerpart_factor, - Optional> decision = std::nullopt) final; + ExprRV SampleCategorical(const ffi::Array& candidates, const ffi::Array& probs, + ffi::Optional decision = std::nullopt) final; + ffi::Array SamplePerfectTile( + const LoopRV& loop_rv, int n, int max_innermost_factor, + ffi::Optional> decision = std::nullopt) final; + ffi::Array SamplePartitionedTile( + const LoopRV& loop_rv, int n, int partition_pos, int innerpart_factor, + ffi::Optional> decision = std::nullopt) final; LoopRV SampleComputeLocation(const BlockRV& block_rv, - Optional decision = std::nullopt) final; + ffi::Optional decision = std::nullopt) final; /******** Schedule: Get blocks & loops ********/ - BlockRV GetBlock(const String& name, const Optional& func_name) final; - Array GetLoops(const BlockRV& block_rv) final; - Array GetChildBlocks(const BlockRV& block_rv) final; - Array GetChildBlocks(const LoopRV& loop_rv) final; - Array GetProducers(const BlockRV& block_rv) final; - Array GetConsumers(const BlockRV& block_rv) final; - Array GetOutputBlocks(const BlockRV& scope_block_rv) final; + BlockRV GetBlock(const ffi::String& name, const ffi::Optional& func_name) final; + ffi::Array GetLoops(const BlockRV& block_rv) final; + ffi::Array GetChildBlocks(const BlockRV& block_rv) final; + ffi::Array GetChildBlocks(const LoopRV& loop_rv) final; + ffi::Array GetProducers(const BlockRV& block_rv) final; + ffi::Array GetConsumers(const BlockRV& block_rv) final; + ffi::Array GetOutputBlocks(const BlockRV& scope_block_rv) final; /******** Schedule: Transform loops ********/ - LoopRV Fuse(const Array& loop_rvs, bool preserve_unit_iters) final; - LoopRV Merge(const Array& loop_rvs) final; - Array Split(const LoopRV& loop_rv, 
const Array>& factor_rvs, - bool preserve_unit_iters, bool disable_predication) final; - Array LoopPartition(const LoopRV& loop_rv, const Array>& factor_rvs, - bool preserve_unit_iters) final; - void Reorder(const Array& ordered_loop_rvs) final; - void ReorderBlockIterVar(const BlockRV& block_rv, const Array new_order) final; + LoopRV Fuse(const ffi::Array& loop_rvs, bool preserve_unit_iters) final; + LoopRV Merge(const ffi::Array& loop_rvs) final; + ffi::Array Split(const LoopRV& loop_rv, + const ffi::Array>& factor_rvs, + bool preserve_unit_iters, bool disable_predication) final; + ffi::Array LoopPartition(const LoopRV& loop_rv, + const ffi::Array>& factor_rvs, + bool preserve_unit_iters) final; + void Reorder(const ffi::Array& ordered_loop_rvs) final; + void ReorderBlockIterVar(const BlockRV& block_rv, const ffi::Array new_order) final; LoopRV AddUnitLoop(const BlockRV& block_rv) final; LoopRV AddUnitLoop(const LoopRV& loop_rv) final; /******** Schedule: Manipulate ForKind ********/ void Parallel(const LoopRV& loop_rv) final; void Vectorize(const LoopRV& loop_rv) final; - void Bind(const LoopRV& loop_rv, const String& thread_axis) final; + void Bind(const LoopRV& loop_rv, const ffi::String& thread_axis) final; void Unroll(const LoopRV& loop_rv) final; /******** Schedule: Insert cache stages ********/ - BlockRV CacheRead(const BlockRV& block_rv, int read_buffer_index, const String& storage_scope, - const Array consumer_blocks = {}) final; - BlockRV CacheWrite(const BlockRV& block_rv, int write_buffer_index, const String& storage_scope, - const Array consumer_blocks = {}) final; + BlockRV CacheRead(const BlockRV& block_rv, int read_buffer_index, + const ffi::String& storage_scope, + const ffi::Array consumer_blocks = {}) final; + BlockRV CacheWrite(const BlockRV& block_rv, int write_buffer_index, + const ffi::String& storage_scope, + const ffi::Array consumer_blocks = {}) final; BlockRV ReindexCacheRead(const BlockRV& block_rv, int read_buffer_index, - const String& storage_scope, const IndexMap& index_map) final; + const ffi::String& storage_scope, const IndexMap& index_map) final; BlockRV ReindexCacheWrite(const BlockRV& block_rv, int write_buffer_index, - const String& storage_scope, const IndexMap& index_map) final; - Array CacheInplace(const BlockRV& block_rv, int read_buffer_index, - const String& storage_scope) final; + const ffi::String& storage_scope, const IndexMap& index_map) final; + ffi::Array CacheInplace(const BlockRV& block_rv, int read_buffer_index, + const ffi::String& storage_scope) final; BlockRV ReIndex(const BlockRV& block_rv, int buffer_index, BufferIndexType buffer_index_type) final; - Array CacheIndex(const BlockRV& block_rv, const String& storage_scope, - int cse_thresh) final; + ffi::Array CacheIndex(const BlockRV& block_rv, const ffi::String& storage_scope, + int cse_thresh) final; /******** Schedule: Data movement ********/ BlockRV ReadAt(const LoopRV& loop_rv, const BlockRV& block_rv, int read_buffer_index, - const String& storage_scope) final; + const ffi::String& storage_scope) final; BlockRV WriteAt(const LoopRV& loop_rv, const BlockRV& block_rv, int write_buffer_index, - const String& storage_scope) final; + const ffi::String& storage_scope) final; /******** Schedule: Compute location ********/ void ComputeAt(const BlockRV& block_rv, const LoopRV& loop_rv, bool preserve_unit_loops, int index = -1) final; @@ -109,35 +114,36 @@ class TracedScheduleNode : public ConcreteScheduleNode { /******** Schedule: Block annotation ********/ void StorageAlign(const 
BlockRV& block_rv, int buffer_index, int axis, int factor, int offset) final; - void SetScope(const BlockRV& block_rv, int buffer_index, const String& storage_scope) final; - void UnsafeSetDType(const BlockRV& block_rv, int buffer_index, const String& dtype) final; + void SetScope(const BlockRV& block_rv, int buffer_index, const ffi::String& storage_scope) final; + void UnsafeSetDType(const BlockRV& block_rv, int buffer_index, const ffi::String& dtype) final; /******** Schedule: Blockize & Tensorize ********/ BlockRV Blockize(const LoopRV& loop_rv, bool preserve_unit_iters) final; - BlockRV Blockize(const Array& blocks, bool preserve_unit_iters) final; - void Tensorize(const BlockRV& block_rv, const String& intrin, bool preserve_unit_iters) final; - void Tensorize(const LoopRV& loop_rv, const String& intrin, bool preserve_unit_iters) final; + BlockRV Blockize(const ffi::Array& blocks, bool preserve_unit_iters) final; + void Tensorize(const BlockRV& block_rv, const ffi::String& intrin, + bool preserve_unit_iters) final; + void Tensorize(const LoopRV& loop_rv, const ffi::String& intrin, bool preserve_unit_iters) final; /******** Schedule: Annotation ********/ - void Annotate(const LoopRV& loop_rv, const String& ann_key, const Any& ann_val) override; - void Unannotate(const LoopRV& loop_rv, const String& ann_key) override; - void Annotate(const BlockRV& block_rv, const String& ann_key, const Any& ann_val) override; - void Unannotate(const BlockRV& block_rv, const String& ann_key) override; + void Annotate(const LoopRV& loop_rv, const ffi::String& ann_key, const Any& ann_val) override; + void Unannotate(const LoopRV& loop_rv, const ffi::String& ann_key) override; + void Annotate(const BlockRV& block_rv, const ffi::String& ann_key, const Any& ann_val) override; + void Unannotate(const BlockRV& block_rv, const ffi::String& ann_key) override; /******** Schedule: Layout transformation ********/ void TransformLayout(const BlockRV& block_rv, int buffer_index, BufferIndexType buffer_index_type, - const IndexMap& index_map, const Optional& pad_value, + const IndexMap& index_map, const ffi::Optional& pad_value, bool assume_injective_transform) override; void TransformBlockLayout(const BlockRV& block_rv, const IndexMap& index_map) override; void SetAxisSeparator(const BlockRV& block_rv, int buffer_index, BufferIndexType buffer_index_type, - const Array& axis_separators) final; + const ffi::Array& axis_separators) final; /******** Schedule: Padding ********/ BlockRV DecomposePadding(const BlockRV& block_rv, const LoopRV& loop_rv) final; - void PadEinsum(const BlockRV& block_rv, const Array& padding) final; + void PadEinsum(const BlockRV& block_rv, const ffi::Array& padding) final; /******** Schedule: Buffer transformation ********/ void RollingBuffer(const BlockRV& block_rv, int write_buffer_index) final; /******** Schedule: Misc ********/ void EnterPostproc() final; - void UnsafeHideBufferAccess(const BlockRV& block_rv, const String& buf_type, - const Array& buf_index_array) final; + void UnsafeHideBufferAccess(const BlockRV& block_rv, const ffi::String& buf_type, + const ffi::Array& buf_index_array) final; void AnnotateBufferAccess(const BlockRV& block_rv, int buffer_index, BufferIndexType buffer_index_type, const IndexMap& index_map) final; }; diff --git a/src/tir/schedule/transform.cc b/src/tir/schedule/transform.cc index 256f44e14894..032365e9f592 100644 --- a/src/tir/schedule/transform.cc +++ b/src/tir/schedule/transform.cc @@ -27,18 +27,19 @@ namespace tir { /******** Annotation ********/ -Block 
WithAnnotation(const BlockNode* block, const String& attr_key, const ObjectRef& attr_value) { - Map annotations = block->annotations; +Block WithAnnotation(const BlockNode* block, const ffi::String& attr_key, + const ObjectRef& attr_value) { + ffi::Map annotations = block->annotations; annotations.Set(attr_key, attr_value); - ObjectPtr new_block = make_object(*block); + ObjectPtr new_block = ffi::make_object(*block); new_block->annotations = std::move(annotations); return Block(new_block); } /******** Buffer Related ********/ -Buffer WithScope(const Buffer& buffer, const String& scope) { - ObjectPtr new_buffer = make_object(*buffer.get()); - ObjectPtr new_var = make_object(*buffer->data.get()); +Buffer WithScope(const Buffer& buffer, const ffi::String& scope) { + ObjectPtr new_buffer = ffi::make_object(*buffer.get()); + ObjectPtr new_var = ffi::make_object(*buffer->data.get()); const auto* ptr_type = TVM_TYPE_AS(buffer->data->type_annotation, PointerTypeNode); new_var->type_annotation = PointerType(ptr_type->element_type, scope); new_buffer->data = Var(new_var->name_hint + "_" + scope, new_var->type_annotation); @@ -47,7 +48,7 @@ Buffer WithScope(const Buffer& buffer, const String& scope) { } Buffer WithDType(const Buffer& buffer, const DataType& dtype) { - ObjectPtr new_buffer = make_object(*buffer.get()); + ObjectPtr new_buffer = ffi::make_object(*buffer.get()); new_buffer->dtype = dtype; const auto* ptr_type = TVM_TYPE_AS(buffer->data->type_annotation, PointerTypeNode); new_buffer->data = @@ -56,11 +57,11 @@ Buffer WithDType(const Buffer& buffer, const DataType& dtype) { return Buffer(new_buffer); } -Array ReplaceBuffer(Array regions, const Buffer& source, - const Buffer& target) { +ffi::Array ReplaceBuffer(ffi::Array regions, const Buffer& source, + const Buffer& target) { regions.MutateByApply([&source, &target](BufferRegion region) -> BufferRegion { if (region->buffer.same_as(source)) { - ObjectPtr n = make_object(*region.get()); + ObjectPtr n = ffi::make_object(*region.get()); n->buffer = target; return BufferRegion(n); } @@ -69,11 +70,11 @@ Array ReplaceBuffer(Array regions, const Buffer& sou return regions; } -Array ReplaceBuffer(Array regions, - const Map& buffer_map) { +ffi::Array ReplaceBuffer(ffi::Array regions, + const ffi::Map& buffer_map) { regions.MutateByApply([&buffer_map](BufferRegion region) -> BufferRegion { if (buffer_map.count(region->buffer)) { - ObjectPtr n = make_object(*region.get()); + ObjectPtr n = ffi::make_object(*region.get()); n->buffer = buffer_map[region->buffer]; return BufferRegion(n); } @@ -82,22 +83,24 @@ Array ReplaceBuffer(Array regions, return regions; } -Array ReplaceBuffer(Array match_buffers, const Buffer& source, - const Buffer& target) { - match_buffers.MutateByApply([&source, - &target](MatchBufferRegion match_buffer) -> MatchBufferRegion { - if (match_buffer->source->buffer.same_as(source)) { - ObjectPtr n = make_object(*match_buffer.get()); - n->source = BufferRegion(target, n->source->region); - return MatchBufferRegion(n); - } - return match_buffer; - }); +ffi::Array ReplaceBuffer(ffi::Array match_buffers, + const Buffer& source, const Buffer& target) { + match_buffers.MutateByApply( + [&source, &target](MatchBufferRegion match_buffer) -> MatchBufferRegion { + if (match_buffer->source->buffer.same_as(source)) { + ObjectPtr n = + ffi::make_object(*match_buffer.get()); + n->source = BufferRegion(target, n->source->region); + return MatchBufferRegion(n); + } + return match_buffer; + }); return match_buffers; } -Array 
ReplaceBufferRegion(Array regions, const Buffer& source_buffer, - const BufferRegion& target) { +ffi::Array ReplaceBufferRegion(ffi::Array regions, + const Buffer& source_buffer, + const BufferRegion& target) { regions.MutateByApply([&source_buffer, &target](const BufferRegion& region) -> BufferRegion { if (region->buffer.same_as(source_buffer)) { return target; @@ -107,30 +110,31 @@ Array ReplaceBufferRegion(Array regions, const Buffe return regions; } -Array ReplaceBufferRegion(Array match_buffers, - const Buffer& source_buffer, - const BufferRegion& target) { - match_buffers.MutateByApply([&source_buffer, &target]( - const MatchBufferRegion& match_buffer) -> MatchBufferRegion { - if (match_buffer->source->buffer.same_as(source_buffer)) { - ObjectPtr n = make_object(*match_buffer.get()); - n->source = target; - return MatchBufferRegion(n); - } - return match_buffer; - }); +ffi::Array ReplaceBufferRegion(ffi::Array match_buffers, + const Buffer& source_buffer, + const BufferRegion& target) { + match_buffers.MutateByApply( + [&source_buffer, &target](const MatchBufferRegion& match_buffer) -> MatchBufferRegion { + if (match_buffer->source->buffer.same_as(source_buffer)) { + ObjectPtr n = + ffi::make_object(*match_buffer.get()); + n->source = target; + return MatchBufferRegion(n); + } + return match_buffer; + }); return match_buffers; } /******** ReplaceBufferMutator ********/ ReplaceBufferMutator::ReplaceBufferMutator(const Buffer& old_buffer, Buffer new_buffer, - Map* block_sref_reuse) + ffi::Map* block_sref_reuse) : block_sref_reuse_(block_sref_reuse) { buffer_var_map_[old_buffer->data.get()] = std::move(new_buffer); } -ReplaceBufferMutator::ReplaceBufferMutator(const Map& buffer_map, - Map* block_sref_reuse) +ReplaceBufferMutator::ReplaceBufferMutator(const ffi::Map& buffer_map, + ffi::Map* block_sref_reuse) : block_sref_reuse_(block_sref_reuse) { for (const auto& [old_buffer, new_buffer] : buffer_map) { buffer_var_map_[old_buffer->data.get()] = new_buffer; @@ -139,7 +143,7 @@ ReplaceBufferMutator::ReplaceBufferMutator(const Map& buffer_map PrimExpr ReplaceBufferMutator::VisitExpr_(const VarNode* var) { auto it = buffer_var_map_.find(var); - return it != buffer_var_map_.end() ? it->second->data : GetRef(var); + return it != buffer_var_map_.end() ? it->second->data : ffi::GetRef(var); } Stmt ReplaceBufferMutator::VisitStmt_(const BufferStoreNode* op) { @@ -203,12 +207,12 @@ Stmt ReplaceBufferMutator::VisitStmt_(const BlockNode* block) { }; // Step 1. Mutate `match_buffers`. If an old buffer appears as a source of MatchBufferRegion, - Array match_buffers = block->match_buffers.Map(f_mutate_match_buffer); + ffi::Array match_buffers = block->match_buffers.Map(f_mutate_match_buffer); // Step 2. Mutate the read/write region. - Array reads = block->reads.Map(f_mutate_read_write_region); - Array writes = block->writes.Map(f_mutate_read_write_region); + ffi::Array reads = block->reads.Map(f_mutate_read_write_region); + ffi::Array writes = block->writes.Map(f_mutate_read_write_region); // Step 3. Mutate `alloc_buffers` for the old buffer allocated in this block. - Array alloc_buffers = block->alloc_buffers.Map(f_mutate_alloc_buffers); + ffi::Array alloc_buffers = block->alloc_buffers.Map(f_mutate_alloc_buffers); // Step 4. Recursively mutate the block. 
Block mutated_block = Downcast(StmtMutator::VisitStmt_(block)); @@ -216,7 +220,7 @@ Stmt ReplaceBufferMutator::VisitStmt_(const BlockNode* block) { writes.same_as(mutated_block->writes) && alloc_buffers.same_as(mutated_block->alloc_buffers) && match_buffers.same_as(mutated_block->match_buffers)) { - return GetRef(block); + return ffi::GetRef(block); } else { ObjectPtr n = CopyOnWrite(mutated_block.get()); n->reads = std::move(reads); @@ -226,7 +230,7 @@ Stmt ReplaceBufferMutator::VisitStmt_(const BlockNode* block) { Block new_block(n); if (block_sref_reuse_ != nullptr) { - block_sref_reuse_->Set(GetRef(block), new_block); + block_sref_reuse_->Set(ffi::GetRef(block), new_block); } return new_block; } @@ -241,17 +245,17 @@ void LeafBlockRemovalPlan(const ScheduleState& self, const StmtSRef& leaf_block_ explicit OnlyLeafError(IRModule mod, Block leaf_block, Block scope_root) : mod_(mod), leaf_block_(leaf_block), scope_root_(scope_root) {} - String FastErrorString() const final { + ffi::String FastErrorString() const final { return "ScheduleError: Cannot remove the only leaf in the scope"; } - String DetailRenderTemplate() const final { + ffi::String DetailRenderTemplate() const final { return "Block {0} is the only leaf in the scope {1}, which cannot be removed; Otherwise the " "scope will be empty."; } IRModule mod() const final { return mod_; } - Array LocationsOfInterest() const final { return {leaf_block_, scope_root_}; } + ffi::Array LocationsOfInterest() const final { return {leaf_block_, scope_root_}; } IRModule mod_; Block leaf_block_; @@ -295,21 +299,21 @@ void LeafBlockRemovalPlan(const ScheduleState& self, const StmtSRef& leaf_block_ } if (const auto* seq = body.as()) { - ObjectPtr n = make_object(*block); - auto new_seq = RemoveFromSeqStmt(GetRef(seq), GetRef(last_stmt)); + ObjectPtr n = ffi::make_object(*block); + auto new_seq = RemoveFromSeqStmt(ffi::GetRef(seq), ffi::GetRef(last_stmt)); // Re-attach AllocateConst nodes auto new_body = MergeNest(allocs, new_seq); n->body = new_body; - *src_stmt = GetRef(block); + *src_stmt = ffi::GetRef(block); *tgt_stmt = Stmt(std::move(n)); return; } } if (const auto* loop = sref->StmtAs()) { if (const auto* seq = loop->body.as()) { - ObjectPtr n = make_object(*loop); - n->body = RemoveFromSeqStmt(GetRef(seq), GetRef(last_stmt)); - *src_stmt = GetRef(loop); + ObjectPtr n = ffi::make_object(*loop); + n->body = RemoveFromSeqStmt(ffi::GetRef(seq), ffi::GetRef(last_stmt)); + *src_stmt = ffi::GetRef(loop); *tgt_stmt = Stmt(std::move(n)); return; } @@ -317,12 +321,12 @@ void LeafBlockRemovalPlan(const ScheduleState& self, const StmtSRef& leaf_block_ ICHECK(sref != nullptr && sref->stmt != nullptr); const auto* leaf_block = TVM_SREF_TO_BLOCK(leaf_block_sref); const auto* scope_block = TVM_SREF_TO_BLOCK(sref); - throw OnlyLeafError(self->mod, GetRef(leaf_block), GetRef(scope_block)); + throw OnlyLeafError(self->mod, ffi::GetRef(leaf_block), ffi::GetRef(scope_block)); } -Optional TileWithTensorIntrin(const tir::Schedule& sch, const tir::BlockRV& block_rv, - const String& intrin_name, bool allow_padding) { - Optional opt_tensorize_info = +ffi::Optional TileWithTensorIntrin(const tir::Schedule& sch, const tir::BlockRV& block_rv, + const ffi::String& intrin_name, bool allow_padding) { + ffi::Optional opt_tensorize_info = GetTensorizeLoopMapping(sch->state(), sch->GetSRef(block_rv), tir::TensorIntrin::Get(intrin_name).value()->desc, allow_padding); if (!opt_tensorize_info) return std::nullopt; @@ -342,7 +346,7 @@ Optional TileWithTensorIntrin(const 
tir::Schedule& sch, const tir::Block sch->PadEinsum(block_rv, info->block_iter_paddings.value()); // Now we need to find out all the padded Block's. - Array inlined_producers, inlined_consumers; + ffi::Array inlined_producers, inlined_consumers; for (const auto& producer : sch->GetProducers(block_rv)) { // PadEinsum will not modify the producer if it does not need padding. if (original_producers.count(sch->GetSRef(producer).get())) { @@ -387,9 +391,9 @@ Optional TileWithTensorIntrin(const tir::Schedule& sch, const tir::Block } } // Construct a mapping from tir loops back to LoopRVs - Map loop2rv; + ffi::Map loop2rv; { - Array loop_rvs = sch->GetLoops(block_rv); + ffi::Array loop_rvs = sch->GetLoops(block_rv); for (const LoopRV& loop_rv : loop_rvs) { loop2rv.Set(sch->GetSRef(loop_rv), loop_rv); } @@ -417,17 +421,18 @@ Optional TileWithTensorIntrin(const tir::Schedule& sch, const tir::Block ICHECK_EQ(total % inner, 0); // Do the split. Leave the outer extent as std::nullopt (unspecified) so that the split factors // can be used for different extents (needed during tuning). - Array split = sch->Split(loop2rv.at(block_loop_sref), {std::nullopt, Integer(inner)}); + ffi::Array split = + sch->Split(loop2rv.at(block_loop_sref), {std::nullopt, Integer(inner)}); ICHECK_EQ(split.size(), 2); inner_loops.insert(sch->GetSRef(split[1]).operator->()); // The inner split will be reordered to the loop domain that is tensorized - int desc_loop_index = info->desc_loop_indexer.at(GetRef(desc_loop)).IntValue(); + int desc_loop_index = info->desc_loop_indexer.at(ffi::GetRef(desc_loop)).IntValue(); reorder_suffix[desc_loop_index] = split[1]; } // Reorder the loops std::vector reorder_list; bool meet = false; - Array all_loops = sch->GetLoops(block_rv); + ffi::Array all_loops = sch->GetLoops(block_rv); for (const LoopRV& loop : all_loops) { if (inner_loops.count(sch->GetSRef(loop).operator->())) { meet = true; @@ -447,10 +452,11 @@ TVM_FFI_STATIC_INIT_BLOCK({ }); /******** BlockBufferAccessSimplifier ********/ -void BlockBufferAccessSimplifier::SimplifyAccessRegion(Array* old_access_regions) { +void BlockBufferAccessSimplifier::SimplifyAccessRegion( + ffi::Array* old_access_regions) { auto fmutate = [this](const BufferRegion& buffer_region) { - Array new_buffer_region; - Array simplified_min; + ffi::Array new_buffer_region; + ffi::Array simplified_min; for (const auto& range : buffer_region->region) { simplified_min.push_back(range->min); } @@ -466,7 +472,7 @@ void BlockBufferAccessSimplifier::SimplifyAccessRegion(Array* old_ (*old_access_regions).MutateByApply(fmutate); } -void BlockBufferAccessSimplifier::SimplifyBufferIndices(Array* indices) { +void BlockBufferAccessSimplifier::SimplifyBufferIndices(ffi::Array* indices) { *indices = this->IterMapSimplifyWithContext(*indices, true); } @@ -492,8 +498,8 @@ PrimExpr BlockBufferAccessSimplifier::VisitExpr_(const BufferLoadNode* op) { /******** PrimFunc-level analysis and transformation ********/ -void GetLeafBlocksHelper(Schedule sch, BlockRV cur_block_rv, Array* leaf_blocks) { - Array blocks = sch->GetChildBlocks(cur_block_rv); +void GetLeafBlocksHelper(Schedule sch, BlockRV cur_block_rv, ffi::Array* leaf_blocks) { + ffi::Array blocks = sch->GetChildBlocks(cur_block_rv); if (blocks.empty()) { leaf_blocks->push_back(cur_block_rv); } else { @@ -503,14 +509,14 @@ void GetLeafBlocksHelper(Schedule sch, BlockRV cur_block_rv, Array* lea } } -Optional NormalizePrimFunc(Schedule sch) { +ffi::Optional NormalizePrimFunc(Schedule sch) { BlockRV root_block = 
sch->GetBlock("root"); - Array leaf_blocks; + ffi::Array leaf_blocks; GetLeafBlocksHelper(sch, root_block, &leaf_blocks); for (const BlockRV& block : leaf_blocks) { StmtSRef block_sref = sch->GetSRef(block); - Array loops = GetLoops(block_sref); - Array binds = GetBlockRealize(sch->state(), block_sref)->iter_values; + ffi::Array loops = GetLoops(block_sref); + ffi::Array binds = GetBlockRealize(sch->state(), block_sref)->iter_values; if (loops.size() == 0) continue; if (loops.size() != binds.size()) { return std::nullopt; @@ -526,14 +532,14 @@ Optional NormalizePrimFunc(Schedule sch) { } } - Array> block_loops; - Array> block_iters; - Array block_is_reduction; + ffi::Array> block_loops; + ffi::Array> block_iters; + ffi::Array block_is_reduction; for (const BlockRV& block : leaf_blocks) { - Array iters = sch->Get(block)->iter_vars; + ffi::Array iters = sch->Get(block)->iter_vars; bool has_spatial_iter = false; - Array index_map_inputs; - Array index_map_outputs; + ffi::Array index_map_inputs; + ffi::Array index_map_outputs; for (const IterVar& iter : sch->Get(block)->iter_vars) { Var var = iter->var.copy_with_suffix(""); index_map_inputs.push_back(var); @@ -559,7 +565,7 @@ Optional NormalizePrimFunc(Schedule sch) { sch->GetSRef(root_block)); block_is_reduction.push_back(Bool(is_reduction)); } - return Array{leaf_blocks, block_loops, block_iters, block_is_reduction}; + return ffi::Array{leaf_blocks, block_loops, block_iters, block_is_reduction}; } TVM_FFI_STATIC_INIT_BLOCK({ diff --git a/src/tir/schedule/transform.h b/src/tir/schedule/transform.h index 73d6a0d85371..6e26f48320db 100644 --- a/src/tir/schedule/transform.h +++ b/src/tir/schedule/transform.h @@ -41,7 +41,8 @@ namespace tir { * \param attr_value The annotation value to be added * \return A new block with the given annotation as its last annotation */ -Block WithAnnotation(const BlockNode* block, const String& attr_key, const ObjectRef& attr_value); +Block WithAnnotation(const BlockNode* block, const ffi::String& attr_key, + const ObjectRef& attr_value); /******** Buffer Related ********/ @@ -51,7 +52,7 @@ Block WithAnnotation(const BlockNode* block, const String& attr_key, const Objec * \param scope The target storage scope. * \return The new buffer with target storage scope. */ -Buffer WithScope(const Buffer& buffer, const String& scope); +Buffer WithScope(const Buffer& buffer, const ffi::String& scope); /*! * \brief Create a new buffer by changint the data type. @@ -68,8 +69,8 @@ Buffer WithDType(const Buffer& buffer, const DataType& dtype); * \param target The buffer to be replaced to * \return The new sequence of regions after replacement */ -Array ReplaceBuffer(Array regions, const Buffer& source, - const Buffer& target); +ffi::Array ReplaceBuffer(ffi::Array regions, const Buffer& source, + const Buffer& target); /*! * \brief Replaces the buffer within the specific sequence of regions @@ -77,8 +78,8 @@ Array ReplaceBuffer(Array regions, const Buffer& sou * \param buffer_map The mapping from old buffers to new buffers * \return The new sequence of regions after replacement */ -Array ReplaceBuffer(Array regions, - const Map& buffer_map); +ffi::Array ReplaceBuffer(ffi::Array regions, + const ffi::Map& buffer_map); /*! 
* \brief Replaces the buffer within the specific sequence of match_buffers @@ -87,8 +88,8 @@ Array ReplaceBuffer(Array regions, * \param target The buffer to be replaced to * \return The new sequence of match_buffers after replacement */ -Array ReplaceBuffer(Array match_buffers, const Buffer& source, - const Buffer& target); +ffi::Array ReplaceBuffer(ffi::Array match_buffers, + const Buffer& source, const Buffer& target); /*! * \brief Replaces the buffer region within the specific sequence of regions @@ -97,8 +98,9 @@ Array ReplaceBuffer(Array match_buffers, c * \param target The buffer region to be replaced to * \return The new sequence of regions after replacement */ -Array ReplaceBufferRegion(Array regions, const Buffer& source_buffer, - const BufferRegion& target); +ffi::Array ReplaceBufferRegion(ffi::Array regions, + const Buffer& source_buffer, + const BufferRegion& target); /*! * \brief Replaces the buffer region within the specific sequence of match_buffers @@ -107,9 +109,9 @@ Array ReplaceBufferRegion(Array regions, const Buffe * \param target The buffer region to be replaced to * \return The new sequence of match_buffers after replacement */ -Array ReplaceBufferRegion(Array match_buffers, - const Buffer& source_buffer, - const BufferRegion& target); +ffi::Array ReplaceBufferRegion(ffi::Array match_buffers, + const Buffer& source_buffer, + const BufferRegion& target); /*! * \brief A helper mutator which recursively replaces the old buffer with the new buffer and @@ -129,9 +131,10 @@ class ReplaceBufferMutator : public StmtExprMutator { * sref. */ ReplaceBufferMutator(const Buffer& old_buffer, Buffer new_buffer, - Map* block_sref_reuse); + ffi::Map* block_sref_reuse); - ReplaceBufferMutator(const Map& buffer_map, Map* block_sref_reuse); + ReplaceBufferMutator(const ffi::Map& buffer_map, + ffi::Map* block_sref_reuse); protected: using StmtExprMutator::VisitExpr_; @@ -162,7 +165,7 @@ class ReplaceBufferMutator : public StmtExprMutator { */ std::unordered_map buffer_var_map_; /*! 
\brief The block sref reuse map for the following replacement */ - Map* block_sref_reuse_; + ffi::Map* block_sref_reuse_; }; /******** Block Removal ********/ @@ -214,8 +217,10 @@ void LeafBlockRemovalPlan(const ScheduleState& self, const StmtSRef& leaf_block_ * \return LoopRV corresponding to the outermost loop of a * block tiled according to the given intrin, std::nullopt if a valid loop mapping is not found */ -Optional TileWithTensorIntrin(const tir::Schedule& sch, const tir::BlockRV& block_rv, - const String& intrin_name, bool allow_padding = false); +ffi::Optional TileWithTensorIntrin(const tir::Schedule& sch, + const tir::BlockRV& block_rv, + const ffi::String& intrin_name, + bool allow_padding = false); /******** Block mutation ********/ @@ -242,8 +247,8 @@ class BlockBufferAccessSimplifier : public arith::IRMutatorWithAnalyzer { using IRMutatorWithAnalyzer::VisitExpr_; using IRMutatorWithAnalyzer::VisitStmt_; - void SimplifyAccessRegion(Array* old_access_regions); - void SimplifyBufferIndices(Array* indices); + void SimplifyAccessRegion(ffi::Array* old_access_regions); + void SimplifyBufferIndices(ffi::Array* indices); Stmt VisitStmt_(const BlockNode* op) final; Stmt VisitStmt_(const BufferStoreNode* op) final; diff --git a/src/tir/schedule/utils.h b/src/tir/schedule/utils.h index 0c35c5f043a2..cd48cb13d5aa 100644 --- a/src/tir/schedule/utils.h +++ b/src/tir/schedule/utils.h @@ -56,12 +56,12 @@ namespace tir { * \param loop_srefs The loop StmtSRefs to be converted * \return The conversion result loops */ -inline Array LoopSRefs2Loops(const Array& loop_srefs) { - Array loops; +inline ffi::Array LoopSRefs2Loops(const ffi::Array& loop_srefs) { + ffi::Array loops; loops.reserve(loop_srefs.size()); for (StmtSRef loop_sref : loop_srefs) { const ForNode* loop = TVM_SREF_TO_FOR(loop_sref); - loops.push_back(GetRef(loop)); + loops.push_back(ffi::GetRef(loop)); } return loops; } @@ -72,8 +72,9 @@ inline Array LoopSRefs2Loops(const Array& loop_srefs) { * \param block_rvs The random variables to be converted * \return The conversion result srefs */ -inline Array BlockRVs2StmtSRefs(const Schedule& sch, const Array& block_rvs) { - Array block_srefs; +inline ffi::Array BlockRVs2StmtSRefs(const Schedule& sch, + const ffi::Array& block_rvs) { + ffi::Array block_srefs; block_srefs.reserve(block_rvs.size()); for (const BlockRV& block_rv : block_rvs) { block_srefs.push_back(sch->GetSRef(block_rv)); @@ -110,7 +111,7 @@ inline bool CanRelaxStorageUnderThread(const runtime::StorageScope& storage_scop */ inline Stmt RemoveFromSeqStmt(const SeqStmt& seq, const Stmt& to_remove) { ICHECK_GT(seq->size(), 1); - Array new_stmts; + ffi::Array new_stmts; new_stmts.reserve(seq->size()); for (const Stmt& stmt : seq->seq) { if (to_remove.same_as(stmt)) { @@ -132,7 +133,7 @@ inline Stmt RemoveFromSeqStmt(const SeqStmt& seq, const Stmt& to_remove) { * \return If the Stmt is SeqStmt, then returns the sequence; * Otherwise, returns a single-element Array with the Stmt inside. 
*/ -inline Array AsArray(const Stmt& stmt) { +inline ffi::Array AsArray(const Stmt& stmt) { if (const auto* seq_stmt = stmt.as()) { return seq_stmt->seq; } @@ -160,7 +161,7 @@ inline bool IsSingleStmt(const Stmt& stmt) { * \param iter_var_type The type of the new IterVar * \return The newly created IterVar */ -inline IterVar IterVarFromLoop(const For& loop, String name, IterVarType iter_var_type) { +inline IterVar IterVarFromLoop(const For& loop, ffi::String name, IterVarType iter_var_type) { return IterVar(Range::FromMinExtent(loop->min, loop->extent), Var(std::move(name), loop->loop_var.dtype()), iter_var_type); } @@ -221,10 +222,11 @@ inline const int64_t* GetLoopIntExtent(const StmtSRef& loop_sref) { * \return The single variable in the expression, or std::nullopt if the expression is neither a * variable or a constant shift from a variable */ -inline Optional AnalyzeVarWithShift(const PrimExpr& expr, Optional* constant) { +inline ffi::Optional AnalyzeVarWithShift(const PrimExpr& expr, + ffi::Optional* constant) { if (const auto* var = expr.as()) { *constant = std::nullopt; - return GetRef(var); + return ffi::GetRef(var); } arith::PVar var; arith::PVar shift; @@ -252,8 +254,8 @@ inline Optional AnalyzeVarWithShift(const PrimExpr& expr, Optional* * \return std::nullopt if not found; otherwise the annotation value */ template -inline Optional GetAnn(const TStmtNode* stmt, const String& ann_key) { - const Map* annotations = &stmt->annotations; +inline ffi::Optional GetAnn(const TStmtNode* stmt, const ffi::String& ann_key) { + const ffi::Map* annotations = &stmt->annotations; for (const auto& ann : *annotations) { if (ann.first == ann_key) { return Downcast(ann.second); @@ -270,7 +272,7 @@ inline Optional GetAnn(const TStmtNode* stmt, const String& ann_key) * \return std::nullopt if not found; otherwise the annotation value */ template -inline Optional GetAnn(const StmtSRef& sref, const String& ann_key) { +inline ffi::Optional GetAnn(const StmtSRef& sref, const ffi::String& ann_key) { if (const auto* loop = sref->StmtAs()) { return GetAnn(loop, ann_key); } else if (const auto* block = sref->StmtAs()) { @@ -288,8 +290,8 @@ inline Optional GetAnn(const StmtSRef& sref, const String& ann_key) * \param ann_val The annotation value to be checked * \return Whether a Block/For has a specific pair of annotation key and values */ -inline bool HasAnn(const StmtSRef& sref, const String& ann_key, const String& ann_val) { - Optional result = GetAnn(sref, ann_key); +inline bool HasAnn(const StmtSRef& sref, const ffi::String& ann_key, const ffi::String& ann_val) { + ffi::Optional result = GetAnn(sref, ann_key); return result.has_value() && result.value() == ann_val; } @@ -300,8 +302,8 @@ inline bool HasAnn(const StmtSRef& sref, const String& ann_key, const String& an * \param ann_val The boolean annotation value to be checked * \return Whether a Block/For has a specific pair of annotation key and values */ -inline bool HasAnn(const StmtSRef& sref, const String& ann_key, bool ann_val) { - Optional result = GetAnn(sref, ann_key); +inline bool HasAnn(const StmtSRef& sref, const ffi::String& ann_key, bool ann_val) { + ffi::Optional result = GetAnn(sref, ann_key); return result.defined() && result.value() == ann_val; } @@ -319,13 +321,13 @@ inline bool HasAnn(const StmtSRef& sref, const String& ann_key, bool ann_val) { inline void ReorderAndFuseReductionLoops(const tir::Schedule& sch, const tir::BlockRV& block_rv, tir::LoopRV* fused_reduce_loop, size_t* num_spatial_loops) { - Array loops = 
sch->GetLoops(block_rv); - Array loop_srefs; + ffi::Array loops = sch->GetLoops(block_rv); + ffi::Array loop_srefs; for (const tir::LoopRV& loop_rv : loops) { loop_srefs.push_back(sch->GetSRef(loop_rv)); } - Array new_order; + ffi::Array new_order; // Step 1. Add spatial loops. *num_spatial_loops = 0; for (size_t i = 0; i < loops.size(); ++i) { @@ -335,7 +337,7 @@ inline void ReorderAndFuseReductionLoops(const tir::Schedule& sch, const tir::Bl } } // Step 2. Add reduction loops. - Array reduction_loops; + ffi::Array reduction_loops; for (size_t i = 0; i < loops.size(); ++i) { if (GetLoopIterType(loop_srefs[i]) == tir::kCommReduce) { new_order.push_back(loops[i]); @@ -366,7 +368,7 @@ inline void ReorderAndFuseReductionLoops(const tir::Schedule& sch, const tir::Bl * \param buffer_index_type The BufferIndexType value to convert * \return The string representation of BufferIndexType */ -inline String BufferIndexType2Str(BufferIndexType buffer_index_type) { +inline ffi::String BufferIndexType2Str(BufferIndexType buffer_index_type) { if (buffer_index_type == BufferIndexType::kRead) { return "read"; } else { @@ -409,8 +411,8 @@ inline bool HasBlock(const Schedule& sch, const std::string& block_name) { * \param rv_map The substitution map for variables. * \return The transformed objects. */ -Array TranslateInputRVs(const Array& inputs, - const std::unordered_map& rv_map); +ffi::Array TranslateInputRVs(const ffi::Array& inputs, + const std::unordered_map& rv_map); /*! * \brief Update the variable substitution map according to the new outputs. @@ -418,7 +420,7 @@ Array TranslateInputRVs(const Array& inputs, * \param new_outputs The new outputs of the same schedule instruction. * \param rv_map The substitution map for variables. */ -void TranslateAddOutputRVs(const Array& old_outputs, const Array& new_outputs, +void TranslateAddOutputRVs(const ffi::Array& old_outputs, const ffi::Array& new_outputs, std::unordered_map* rv_map); /*! @@ -427,7 +429,7 @@ void TranslateAddOutputRVs(const Array& old_outputs, const Array& new_ * \param remove_postproc If postprocessing instructions are removed. * \return Number of instructions. */ -int GetNumValidInstructions(const Array& insts, bool remove_postproc); +int GetNumValidInstructions(const ffi::Array& insts, bool remove_postproc); } // namespace tir } // namespace tvm diff --git a/src/tir/transforms/annotate_device_regions.cc b/src/tir/transforms/annotate_device_regions.cc index 5cd2d6556572..310cb74e4ee6 100644 --- a/src/tir/transforms/annotate_device_regions.cc +++ b/src/tir/transforms/annotate_device_regions.cc @@ -40,12 +40,12 @@ class DeviceRegionAnnotater : public StmtMutator { Stmt VisitStmt_(const AttrStmtNode* op) final { if (op->attr_key == tvm::attr::kTarget) { // If a target attribute already exists, use it as-is. - return GetRef(op); + return ffi::GetRef(op); } else if (op->attr_key == attr::thread_extent || op->attr_key == attr::pipeline_exec_scope || op->attr_key == attr::device_scope) { // These attributes are only allowed in device-side code, so // they should be annotated with the function's default target. 
- Stmt body = GetRef(op); + Stmt body = ffi::GetRef(op); return AttrStmt(device_target_, tvm::attr::kTarget, 0, body); } else { // All other annotations are ignored diff --git a/src/tir/transforms/arg_binder.cc b/src/tir/transforms/arg_binder.cc index 5b9e005b7ea3..15365802e0c9 100644 --- a/src/tir/transforms/arg_binder.cc +++ b/src/tir/transforms/arg_binder.cc @@ -76,7 +76,7 @@ void ArgBinder::Bind(const PrimExpr& arg, const PrimExpr& value, const std::stri Bind_(arg, value, arg_name, with_let); } -void ArgBinder::BindArray(const Array& arg, const Array& value, +void ArgBinder::BindArray(const ffi::Array& arg, const ffi::Array& value, const std::string& arg_name) { ICHECK_EQ(arg.size(), value.size()) << "Argument " << arg_name << " array size mismatch"; for (size_t i = 0; i < arg.size(); ++i) { @@ -223,7 +223,7 @@ void ArgBinder::BindDLTensor(const Buffer& buffer, const PrimExpr& device_type, // Assert the buffer is compact DataType stype = buffer->DefaultIndexType(); PrimExpr expect_stride = make_const(stype, 1); - Array conds; + ffi::Array conds; for (size_t i = buffer->shape.size(); i != 0; --i) { size_t k = i - 1; PrimExpr svalue = cast(stype, BufferLoad(buf_strides, {IntImm(DataType::Int(32), k)})); diff --git a/src/tir/transforms/arg_binder.h b/src/tir/transforms/arg_binder.h index 68cbbb677311..fad5e4d70222 100644 --- a/src/tir/transforms/arg_binder.h +++ b/src/tir/transforms/arg_binder.h @@ -79,7 +79,7 @@ class ArgBinder { * \param value The target expression value * \param arg_name argument name. */ - void BindArray(const Array& arg, const Array& value, + void BindArray(const ffi::Array& arg, const ffi::Array& value, const std::string& arg_name); /*! * \brief Bind symbolic buffer to another symbolic buffer @@ -145,7 +145,7 @@ class ArgBinder { */ const std::vector& init_nest() const { return init_nest_; } /*! \return Handle data type of the data */ - const Map& def_handle_dtype() const { return def_handle_dtype_; } + const ffi::Map& def_handle_dtype() const { return def_handle_dtype_; } private: // Internal bind function @@ -158,7 +158,7 @@ class ArgBinder { /*! \brief Initialize nest */ std::vector init_nest_; /*! \brief handle data type in the definitions */ - Map def_handle_dtype_; + ffi::Map def_handle_dtype_; /*! \brief asserts generated */ std::vector asserts_; /*! \brief internal analyzer. 
*/ diff --git a/src/tir/transforms/bind_params.cc b/src/tir/transforms/bind_params.cc index 520f6e871200..2b4598a99fa7 100644 --- a/src/tir/transforms/bind_params.cc +++ b/src/tir/transforms/bind_params.cc @@ -40,7 +40,7 @@ namespace tir { class ParamsCollector : public StmtExprVisitor { public: - explicit ParamsCollector(const Map& constant_map) + explicit ParamsCollector(const ffi::Map& constant_map) : constant_map_(constant_map) {} std::vector CollectParams(tir::Stmt body) { this->VisitStmt(body); @@ -75,16 +75,16 @@ class ParamsCollector : public StmtExprVisitor { private: std::vector constant_list_; - Map constant_map_; + ffi::Map constant_map_; }; -PrimFunc BindParams(PrimFunc f, const Array& constants) { - Map constant_map; +PrimFunc BindParams(PrimFunc f, const ffi::Array& constants) { + ffi::Map constant_map; // Remove constants from the primfunc signature size_t num_constants = constants.size(); size_t start = f->params.size() - num_constants; - Array params; + ffi::Array params; for (unsigned i = 0; i < start; i++) { params.push_back(f->params[i]); } @@ -101,9 +101,9 @@ PrimFunc BindParams(PrimFunc f, const Array& constants) { // Allocate constants within the primfunc for (auto i : constant_list) { - auto var = GetRef(i); + auto var = ffi::GetRef(i); int ndim = constant_map[var]->ndim; - Array extents; + ffi::Array extents; for (int i = 0; i < ndim; i++) { int shape = constant_map[var]->shape[i]; @@ -126,7 +126,7 @@ PrimFunc BindParams(PrimFunc f, const Array& constants) { namespace transform { -Pass BindParams(const Array& constants) { +Pass BindParams(const ffi::Array& constants) { auto pass_func = [=](PrimFunc f, IRModule m, PassContext ctx) { return BindParams(f, constants); }; diff --git a/src/tir/transforms/bind_target.cc b/src/tir/transforms/bind_target.cc index 46a40228eaa1..6e3b9ff853a4 100644 --- a/src/tir/transforms/bind_target.cc +++ b/src/tir/transforms/bind_target.cc @@ -71,7 +71,7 @@ class FunctionClassifierVisitor : public StmtExprVisitor { // Only analyze externally exposed functions as potential callers // since they represent the entry points where host/device calls originate for (const auto& [gvar, func] : mod->functions) { - bool is_externally_exposed = func->GetAttr(tvm::attr::kGlobalSymbol).has_value(); + bool is_externally_exposed = func->GetAttr(tvm::attr::kGlobalSymbol).has_value(); const auto* prim_func = func.as(); if (is_externally_exposed && prim_func != nullptr) { @@ -144,7 +144,7 @@ class CallSubstitutor : public StmtExprMutator { * \brief Constructor with function replacement mapping. * \param replacements Map from original GlobalVar to host-specific GlobalVar */ - explicit CallSubstitutor(const Map& replacements) + explicit CallSubstitutor(const ffi::Map& replacements) : replacements_(replacements) {} /*! @@ -212,7 +212,7 @@ class CallSubstitutor : public StmtExprMutator { /*! \brief Whether the current statement is under a GPU scope */ bool is_under_gpu_scope_ = false; /*! \brief Mapping from original functions to host-specific duplicates */ - Map replacements_; + ffi::Map replacements_; }; /*! 
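Aside (not part of the patch): the hunks in this region all apply one mechanical rename, in which helpers such as `GetRef`/`make_object` and the `String`/`Array`/`Map`/`Optional` containers are written with an explicit `ffi::` qualifier. A minimal sketch of the pattern, using `tir::Block` as the example type; the function name is illustrative only, and the template arguments (elided in the diff text above) are restored for readability:

    #include <tvm/ffi/cast.h>
    #include <tvm/ffi/memory.h>
    #include <tvm/tir/stmt.h>

    namespace tvm {
    namespace tir {

    // Either reuse the visited node or clone it for mutation -- the same
    // GetRef/make_object pairing the qualified calls above rely on.
    Stmt RewriteOrReuse(const BlockNode* block, bool needs_copy) {
      if (!needs_copy) {
        // Raw node pointer -> owned reference, now spelled ffi::GetRef.
        return ffi::GetRef<Stmt>(block);
      }
      // Copy-on-write: clone the node, mutate the private copy, rewrap it.
      ObjectPtr<BlockNode> n = ffi::make_object<BlockNode>(*block);
      n->name_hint = "rewritten_block";
      return Block(n);
    }

    }  // namespace tir
    }  // namespace tvm
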
@@ -238,7 +238,7 @@ IRModule BindTarget(IRModule mod, const Target& target) { auto target_without_host = target.WithoutHost(); auto mod_copy_on_write = mod.CopyOnWrite(); - auto new_mod = GetRef(mod_copy_on_write); + auto new_mod = ffi::GetRef(mod_copy_on_write); // Step 1: Analyze function call patterns auto [host_called_global_vars, device_called_global_vars] = @@ -257,7 +257,7 @@ IRModule BindTarget(IRModule mod, const Target& target) { // 2.4 If the function is not called by any host or device, skip binding // Track duplicated functions for call replacement - Map host_function_replacements; + ffi::Map host_function_replacements; GlobalVarSupply gvar_supply(new_mod); for (auto [gvar, func] : mod->functions) { @@ -266,9 +266,10 @@ IRModule BindTarget(IRModule mod, const Target& target) { // Skip non-PrimFunc entries continue; } - auto prim_func = GetRef(prim_func_node); + auto prim_func = ffi::GetRef(prim_func_node); - bool is_externally_exposed = prim_func->GetAttr(tvm::attr::kGlobalSymbol).has_value(); + bool is_externally_exposed = + prim_func->GetAttr(tvm::attr::kGlobalSymbol).has_value(); if (auto func_target = func->GetAttr(tvm::attr::kTarget)) { // Rule 1: If the function has a target, and the target has a host, and the function does not @@ -308,7 +309,7 @@ IRModule BindTarget(IRModule mod, const Target& target) { // Create duplicate with host target for host callers host_func = WithAttr(std::move(host_func), tvm::attr::kTarget, target_host); - String host_func_name = gvar->name_hint + "_host"; + ffi::String host_func_name = gvar->name_hint + "_host"; GlobalVar host_gvar = gvar_supply->FreshGlobal(host_func_name, false); new_mod->Add(host_gvar, host_func); @@ -341,7 +342,8 @@ IRModule BindTarget(IRModule mod, const Target& target) { continue; } - bool is_externally_exposed = prim_func->GetAttr(tvm::attr::kGlobalSymbol).has_value(); + bool is_externally_exposed = + prim_func->GetAttr(tvm::attr::kGlobalSymbol).has_value(); if (is_externally_exposed) { // Update calls in externally exposed functions to use host duplicates PrimFunc new_func = substitutor.Substitute(Downcast(func)); diff --git a/src/tir/transforms/bound_checker.cc b/src/tir/transforms/bound_checker.cc index 6d5537e7756e..c9ad70bf807a 100644 --- a/src/tir/transforms/bound_checker.cc +++ b/src/tir/transforms/bound_checker.cc @@ -58,12 +58,13 @@ class BoundCollector : public StmtVisitor { StmtVisitor::VisitStmt_(op); } // Hashtable which maps buffer_var to shape. 
- std::unordered_map> mem_to_shape; + std::unordered_map> mem_to_shape; }; class BoundChecker : public StmtExprMutator { public: - explicit BoundChecker(const std::unordered_map>& mem_to_shape) + explicit BoundChecker( + const std::unordered_map>& mem_to_shape) : mem_to_shape_(mem_to_shape) {} Stmt VisitStmt_(const AllocateNode* op) final { @@ -95,13 +96,13 @@ class BoundChecker : public StmtExprMutator { PrimExpr condition = MakeCondition(); if (!condition.as()) { Stmt nop = Evaluate(1); - Stmt then_case = GetRef(op); + Stmt then_case = ffi::GetRef(op); Stmt else_case = AssertStmt(condition, StringImm(error_message_), nop); Stmt body = IfThenElse(condition, then_case, else_case); return body; } } - return GetRef(op); + return ffi::GetRef(op); } PrimExpr VisitExpr_(const BufferLoadNode* op) final { @@ -116,7 +117,7 @@ class BoundChecker : public StmtExprMutator { return (buffer_var.defined() && mem_to_shape_.count(buffer_var.get())); } - void Update(const Var& buffer_var, Array new_shape, const DataType& type) { + void Update(const Var& buffer_var, ffi::Array new_shape, const DataType& type) { // Sanity check at first. if (!ShapeIsValid(new_shape)) { return; @@ -129,7 +130,7 @@ class BoundChecker : public StmtExprMutator { mem_to_shape_[buffer_var.get()] = new_shape; } - bool ShapeIsValid(const Array& shape) const { + bool ShapeIsValid(const ffi::Array& shape) const { if (!shape.defined()) { return false; } @@ -142,7 +143,7 @@ class BoundChecker : public StmtExprMutator { return true; } - bool IndicesAreValid(const Array& indices) const { + bool IndicesAreValid(const ffi::Array& indices) const { if (!indices.defined()) { return false; } @@ -176,12 +177,12 @@ class BoundChecker : public StmtExprMutator { return expr.defined() && expr.dtype().is_scalar(); } - bool CanInstrument(const Array& indices, const Var& buffer_var) const { + bool CanInstrument(const ffi::Array& indices, const Var& buffer_var) const { return buffer_var.defined() && mem_to_shape_.count(buffer_var.get()) && IndicesAreValid(indices) && !unsafe_rewritten_; } - void Collect(Array indices, Var buffer_var) { + void Collect(ffi::Array indices, Var buffer_var) { store_scope_bound_collector_.push_back( std::make_pair(indices, mem_to_shape_[buffer_var.get()])); } @@ -189,8 +190,8 @@ class BoundChecker : public StmtExprMutator { PrimExpr MakeCondition() { PrimExpr condition; for (const auto& pair : store_scope_bound_collector_) { - Array indices = pair.first; - Array shape = pair.second; + ffi::Array indices = pair.first; + ffi::Array shape = pair.second; ICHECK_EQ(indices.size(), shape.size()) << "Mismatch between dimension of physical shape and physical indices"; @@ -200,7 +201,7 @@ class BoundChecker : public StmtExprMutator { PrimExpr upper_bound = shape[i]; if (const RampNode* ramp_index = index.as()) { - index = arith::UnwrapVectorExpr(GetRef(ramp_index), ramp_index->lanes); + index = arith::UnwrapVectorExpr(ffi::GetRef(ramp_index), ramp_index->lanes); } // Try to simplify index and bound. @@ -226,11 +227,11 @@ class BoundChecker : public StmtExprMutator { // Whether we face tvm_if_then_else intrinsic. bool unsafe_rewritten_{false}; // Pool which collects the pair of index and shape for specific store/load. - std::vector, Array>> store_scope_bound_collector_; + std::vector, ffi::Array>> store_scope_bound_collector_; // Error message. const char* const error_message_ = "OUT OF THE BOUNDS"; // Hashtable which maps buffer_var to shape. 
- std::unordered_map> mem_to_shape_; + std::unordered_map> mem_to_shape_; // internal analyzer arith::Analyzer analyzer_; }; diff --git a/src/tir/transforms/common_subexpr_elim.cc b/src/tir/transforms/common_subexpr_elim.cc index 23c7d88d47c9..71f425c25048 100644 --- a/src/tir/transforms/common_subexpr_elim.cc +++ b/src/tir/transforms/common_subexpr_elim.cc @@ -150,8 +150,8 @@ Var CommonSubexpressionEliminator::GenerateNewVar(DataType type_annotation) { // Builds the variable name, which is cse_vi where i will go up from 1 std::string prefix = "cse_v"; std::string name = prefix.append(std::to_string(num_last_try_)); - // Builds a String using the std::string - String string_name(name); + // Builds a ffi::String using the std::string + ffi::String string_name(name); // Check that the name that we want to use for the new variable isn't already being used // (names don't really have to be unique as they are just hints, and having the same name @@ -280,11 +280,11 @@ PrimExpr CommonSubexpressionEliminator::VisitExpr(const PrimExpr& expr) { [](const std::pair& pair) { return pair.first; }; std::vector vector_vars_known = VectorMap(context_, forget_value); // 2.2 - Transform the std::vector into an Array - Array array_vars_known = Array(vector_vars_known); + ffi::Array array_vars_known = ffi::Array(vector_vars_known); // --- End of chunk needed for reusing the UndefinedVars() analysis --- // We use the UndefinedVars() analysis to get the undefined vars of the computation - Array vars_undefined = UndefinedVars(computation_wrapped_in_stmt, array_vars_known); + ffi::Array vars_undefined = UndefinedVars(computation_wrapped_in_stmt, array_vars_known); // Check if we can introduce it : if it contains no undefined variables and if we want // to introduce it according to the predicate @@ -375,7 +375,7 @@ PrimExpr CommonSubexpressionEliminator::VisitExpr_(const LetNode* op) { // If the `value` and the `body` of the let-in have been rewritten to the same thing if (value_new.same_as(op->value) && body_new.same_as(op->body)) { // then return a reference to the same node - return GetRef(op); + return ffi::GetRef(op); } else { // Otherwise return a let-in built with the new `value_new` and the new `body_new` that // have just been obtained @@ -460,11 +460,11 @@ Stmt CommonSubexpressionEliminator::VisitStmt(const Stmt& stmt) { [](const std::pair& pair) { return pair.first; }; std::vector vector_vars_known = VectorMap(context_, forget_value); // 2.2 - Transform the std::vector into an Array - Array array_vars_known = Array(vector_vars_known); + ffi::Array array_vars_known = ffi::Array(vector_vars_known); // --- End of chunk needed for reusing the UndefinedVars() analysis --- // We use the UndefinedVars() analysis to get the undefined vars of the computation - Array vars_undefined = UndefinedVars(computation_wrapped_in_stmt, array_vars_known); + ffi::Array vars_undefined = UndefinedVars(computation_wrapped_in_stmt, array_vars_known); // Check if we can introduce it : if it contains no undefined variables and if we want // to introduce it according to the predicate @@ -556,7 +556,7 @@ Stmt CommonSubexpressionEliminator::VisitStmt_(const LetStmtNode* op) { // If the `value` and the `body` of the let-in have been rewritten to the same thing if (value_new.same_as(op->value) && body_new.same_as(op->body)) { // Return a reference to the same node - return GetRef(op); + return ffi::GetRef(op); } else { // Otherwise return a let-in built with the new `value_new` and the new `body_new` that // have just been obtained @@ 
-597,7 +597,7 @@ Stmt CommonSubexpressionEliminator::VisitStmt_(const ForNode* op) { // If the `min`, `extent` and `body` of the for loop have been rewritten to the same thing if (min_new.same_as(op->min) && extent_new.same_as(op->extent) && body_new.same_as(op->body)) { // Return a reference to the same node - return GetRef(op); + return ffi::GetRef(op); } else { // Otherwise return a for node built with the new `min_new`, `extent_new` and `body_new` // that have just been obtained diff --git a/src/tir/transforms/common_subexpr_elim_tools.cc b/src/tir/transforms/common_subexpr_elim_tools.cc index f71d2cf42a02..1c52c6f97f5d 100644 --- a/src/tir/transforms/common_subexpr_elim_tools.cc +++ b/src/tir/transforms/common_subexpr_elim_tools.cc @@ -447,7 +447,7 @@ void ComputationsDoneBy::VisitStmt_(const IfThenElseNode* op) { // Copy the `table_of_computations_` into the cache // for the future queries - Stmt ref_to_op = GetRef(op); + Stmt ref_to_op = ffi::GetRef(op); cache_.cache_stmt_table_computations_[ref_to_op] = table_of_computations_; } @@ -482,7 +482,7 @@ void ComputationsDoneBy::VisitStmt_(const ForNode* op) { // Copy the `table_of_computations_` into the cache // for the future queries - Stmt ref_to_op = GetRef(op); + Stmt ref_to_op = ffi::GetRef(op); cache_.cache_stmt_table_computations_[ref_to_op] = table_of_computations_; } @@ -512,7 +512,7 @@ void ComputationsDoneBy::VisitStmt_(const WhileNode* op) { // Copy the `table_of_computations_` into the cache // for the future queries - Stmt ref_to_op = GetRef(op); + Stmt ref_to_op = ffi::GetRef(op); cache_.cache_stmt_table_computations_[ref_to_op] = table_of_computations_; } @@ -646,7 +646,7 @@ void DirectSubexpr::VisitExpr(const PrimExpr& expr) { * \param var_name The variable name to check for * \return A boolean telling if `expr` uses `var_name` */ -bool UsesVarName::ExprUsesVarName(const PrimExpr& expr, String var_name) { +bool UsesVarName::ExprUsesVarName(const PrimExpr& expr, ffi::String var_name) { UsesVarName uses_var_name(var_name); uses_var_name.VisitExpr(expr); @@ -659,7 +659,7 @@ bool UsesVarName::ExprUsesVarName(const PrimExpr& expr, String var_name) { * \param var_name The variable name to check for * \return A boolean telling if `stmt` uses `var_name` */ -bool UsesVarName::StmtUsesVarName(const Stmt& stmt, String var_name) { +bool UsesVarName::StmtUsesVarName(const Stmt& stmt, ffi::String var_name) { UsesVarName uses_var_name(var_name); uses_var_name.VisitStmt(stmt); @@ -668,9 +668,9 @@ bool UsesVarName::StmtUsesVarName(const Stmt& stmt, String var_name) { /*! * \brief Protected constructor of UsesVarName. - * \param var_name The String that we are looking for + * \param var_name The ffi::String that we are looking for */ -UsesVarName::UsesVarName(String var_name) : var_name_(var_name) {} +UsesVarName::UsesVarName(ffi::String var_name) : var_name_(var_name) {} /*! * \brief The method which overrides the generic dispatcher of StmtExprVisitor for expressions. 
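Aside (not part of the patch): the common-subexpression hunks above lean on the interop between the renamed ffi containers and the standard library. A short sketch assuming only public `ffi::Array`/`ffi::String` behavior; the variable names are illustrative, and template arguments (stripped in the diff text) are written out:

    #include <vector>

    #include <tvm/ffi/container/array.h>
    #include <tvm/ffi/string.h>

    void ContainerInterop() {
      // Build an ffi::Array from standard containers, as in
      // `ffi::Array(vector_vars_known)` in the hunks above.
      std::vector<tvm::ffi::String> names = {"cse_v1", "cse_v2"};
      tvm::ffi::Array<tvm::ffi::String> arr(names.begin(), names.end());

      // The arrays are copy-on-write: `copy` shares storage with `arr`
      // until one of them is mutated.
      tvm::ffi::Array<tvm::ffi::String> copy = arr;
      copy.push_back("cse_v3");  // detaches; `arr` still has two elements
    }
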
diff --git a/src/tir/transforms/common_subexpr_elim_tools.h b/src/tir/transforms/common_subexpr_elim_tools.h index 31a81dabdbf2..ab1e76592a90 100644 --- a/src/tir/transforms/common_subexpr_elim_tools.h +++ b/src/tir/transforms/common_subexpr_elim_tools.h @@ -158,18 +158,18 @@ class DirectSubexpr : public ExprVisitor { class UsesVarName : public StmtExprVisitor { public: // Toplevel (static) methods - static bool ExprUsesVarName(const PrimExpr& expr, String var_name); - static bool StmtUsesVarName(const Stmt& stmt, String var_name); + static bool ExprUsesVarName(const PrimExpr& expr, ffi::String var_name); + static bool StmtUsesVarName(const Stmt& stmt, ffi::String var_name); protected: // Constructor - explicit UsesVarName(String var_name); + explicit UsesVarName(ffi::String var_name); void VisitExpr(const PrimExpr& expr) override; void VisitStmt(const Stmt& stmt) override; private: - String var_name_; + ffi::String var_name_; bool uses_var_name_ = false; }; diff --git a/src/tir/transforms/compact_buffer_region.cc b/src/tir/transforms/compact_buffer_region.cc index a1e99313b663..713ddcad298c 100644 --- a/src/tir/transforms/compact_buffer_region.cc +++ b/src/tir/transforms/compact_buffer_region.cc @@ -49,9 +49,9 @@ NDIntSet NDIntSetEval(Region region, PrimExpr predicate, arith::Analyzer* analyzer) { std::unordered_map var_dom; for (const auto& it : dom_map) { - var_dom[GetRef(it.first)] = it.second.CoverRange(Range::FromMinExtent(0, 0)); + var_dom[ffi::GetRef(it.first)] = it.second.CoverRange(Range::FromMinExtent(0, 0)); } - Optional> eval_res = + ffi::Optional> eval_res = arith::EstimateRegionUpperBound(region, var_dom, predicate, analyzer); if (eval_res.defined()) { @@ -146,7 +146,7 @@ class BufferAccessRegionCollector : public StmtExprVisitor { StmtExprVisitor::VisitExpr_(op); } - void VisitExpr_(const VarNode* op) final { VisitBufferVar(GetRef(op)); } + void VisitExpr_(const VarNode* op) final { VisitBufferVar(ffi::GetRef(op)); } void VisitStmt_(const ForNode* op) final { Range loop_range = Range::FromMinExtent(op->min, op->extent); @@ -243,10 +243,10 @@ class BufferAccessRegionCollector : public StmtExprVisitor { } // Step 2. Record explicit read/write region annotations - auto record_explicit_region = [&](const String& attr_key, BufferIndexType index_type) { + auto record_explicit_region = [&](const ffi::String& attr_key, BufferIndexType index_type) { auto it = op->annotations.find(attr_key); if (it != op->annotations.end()) { - Array buffer_indices = Downcast>((*it).second); + ffi::Array buffer_indices = Downcast>((*it).second); for (const auto& index : buffer_indices) { int buffer_index = index->value; if (buffer_index >= 0 && buffer_index < static_cast(op->reads.size())) { @@ -430,9 +430,9 @@ class BufferAccessRegionCollector : public StmtExprVisitor { ICHECK(it != relaxed_accesses_.end()) << buffer << " is allocated but not accessed within block scope"; - const Array& original_shape = buffer->shape; + const ffi::Array& original_shape = buffer->shape; const NDIntSet& nd_int_set = it->second; - Array& result_region = buffer_access_region_[buffer]; + ffi::Array& result_region = buffer_access_region_[buffer]; result_region.resize(nd_int_set.size()); for (size_t i = 0; i < nd_int_set.size(); ++i) { @@ -566,7 +566,7 @@ class BufferCompactor : public StmtExprMutator { // Step 0. Check there is no Init part. ICHECK(!op->init.defined()); // Step 1. Reallocate and rewrite alloc_buffers, also update BufferAllocInfo. 
- Array alloc_buffers = + ffi::Array alloc_buffers = op->alloc_buffers.Map([this](const Buffer& buf) { return RewriteAllocBuffer(buf); }); // Step 2. Recursively rewrite BufferLoad/BufferStore. Block block = Downcast(StmtExprMutator::VisitStmt_(op)); @@ -600,7 +600,7 @@ class BufferCompactor : public StmtExprMutator { if (op->dtype != new_buffer->dtype) { return allocate; } - Array new_shape = GetBufferAllocationShape(new_buffer); + ffi::Array new_shape = GetBufferAllocationShape(new_buffer); auto n = allocate.CopyOnWrite(); ICHECK(n->buffer_var.same_as(new_buffer->data)); n->extents = new_shape; @@ -615,7 +615,7 @@ class BufferCompactor : public StmtExprMutator { return buffer; } - void RewriteBufferAccess(Buffer* buffer, Array* indices) const { + void RewriteBufferAccess(Buffer* buffer, ffi::Array* indices) const { auto it = buffer_info_.find((*buffer)->data); if (it == buffer_info_.end()) { return; @@ -623,7 +623,7 @@ class BufferCompactor : public StmtExprMutator { const BufferAllocInfo& info = it->second; ICHECK_EQ(indices->size(), info.region.size()); int ndim = info.region.size(); - Array new_indices; + ffi::Array new_indices; new_indices.reserve(ndim); for (int i = 0; i < ndim; ++i) { new_indices.push_back((*indices)[i] - info.region[i]->min); @@ -650,8 +650,8 @@ class BufferCompactor : public StmtExprMutator { *region = std::move(new_region); } - void RewriteBufferRegions(Array* regions) const { - Array new_regions; + void RewriteBufferRegions(ffi::Array* regions) const { + ffi::Array new_regions; new_regions.reserve(regions->size()); for (const auto& region : *regions) { BufferRegion buffer_region = region; @@ -662,12 +662,12 @@ class BufferCompactor : public StmtExprMutator { *regions = std::move(new_regions); } - void RewriteMatchBuffers(Array* match_buffers) const { - Array result; + void RewriteMatchBuffers(ffi::Array* match_buffers) const { + ffi::Array result; result.reserve(match_buffers->size()); for (const auto& match_buffer : *match_buffers) { const BufferRegion& buffer_region = match_buffer->source; - auto p = make_object(*buffer_region.get()); + auto p = ffi::make_object(*buffer_region.get()); RewriteBufferRegion(&p->buffer, &p->region); result.push_back(MatchBufferRegion(match_buffer->buffer, BufferRegion(p))); } @@ -678,7 +678,8 @@ class BufferCompactor : public StmtExprMutator { std::unordered_map buffer_info_; }; -Array CalcStrides(const BufferAllocInfo& alloc_info, const Array& shape) { +ffi::Array CalcStrides(const BufferAllocInfo& alloc_info, + const ffi::Array& shape) { std::vector strides; if (alloc_info.dim_aligns.size()) { ICHECK(alloc_info.dim_aligns.size() == shape.size()); @@ -725,9 +726,9 @@ Stmt BufferCompactorCompact( } // prepare new buffer - Array shape = region.Map([](const Range& range) { return range->extent; }); - Array strides = CalcStrides(alloc_info, shape); - ObjectPtr n = make_object(*buffer.get()); + ffi::Array shape = region.Map([](const Range& range) { return range->extent; }); + ffi::Array strides = CalcStrides(alloc_info, shape); + ObjectPtr n = ffi::make_object(*buffer.get()); n->shape = std::move(shape); n->strides = std::move(strides); alloc_info.new_buffer = Buffer(std::move(n)); diff --git a/src/tir/transforms/convert_blocks_to_opaque.cc b/src/tir/transforms/convert_blocks_to_opaque.cc index bd340df97e61..a359367ee70b 100644 --- a/src/tir/transforms/convert_blocks_to_opaque.cc +++ b/src/tir/transforms/convert_blocks_to_opaque.cc @@ -54,7 +54,7 @@ class OpaqueBlockConverter : public StmtExprMutator { if (it != 
var_substitutes_.end()) { return it->second; } - return GetRef(var); + return ffi::GetRef(var); } Stmt VisitStmt_(const BlockNode* block) final { @@ -74,7 +74,7 @@ class OpaqueBlockConverter : public StmtExprMutator { // Step 1. Visit the predicate and iter_values, without any variable bindings for (const auto& iter : block_op->iter_vars) forbidden_iter_vars_.insert(iter->var.get()); PrimExpr predicate = VisitExpr(realize->predicate); - Array iter_values = realize->iter_values; + ffi::Array iter_values = realize->iter_values; iter_values.MutateByApply([this](PrimExpr expr) { return VisitExpr(std::move(expr)); }); for (const auto& iter : block_op->iter_vars) forbidden_iter_vars_.erase(iter->var.get()); @@ -96,7 +96,7 @@ class OpaqueBlockConverter : public StmtExprMutator { // Step 5. Return if (predicate.same_as(realize->predicate) && iter_values.same_as(realize->iter_values) && new_block.same_as(realize->block) && realize->iter_values.size() == 0) { - return GetRef(realize); + return ffi::GetRef(realize); } else { return BlockRealize({}, predicate, new_block); } diff --git a/src/tir/transforms/default_gpu_schedule.cc b/src/tir/transforms/default_gpu_schedule.cc index 5e1e5efa0e4c..2113136cf4cd 100644 --- a/src/tir/transforms/default_gpu_schedule.cc +++ b/src/tir/transforms/default_gpu_schedule.cc @@ -34,20 +34,20 @@ namespace transform { void ThreadBind(tir::Schedule sch, const tir::BlockRV& block, int64_t max_thread_per_block, int64_t max_threadblocks = 256) { // fetch the loops - Array loops = sch->GetLoops(block); + ffi::Array loops = sch->GetLoops(block); for (const tir::LoopRV& loop : loops) { // skip block if already scheduled if (sch->Get(loop)->thread_binding.defined()) { return; } } - Array iters = sch->Get(block)->iter_vars; + ffi::Array iters = sch->Get(block)->iter_vars; // when there is no loops, tir will add a dummy iter var for the block // so loops.size() == 0 && iters.size() == 1 ICHECK(loops.size() == iters.size() || (loops.size() == 0 && iters.size() == 1)); - Array data_parallel_loops; + ffi::Array data_parallel_loops; // only fuse data parallel loops for (size_t i = 0; i < loops.size(); ++i) { if (iters[i]->iter_type == tir::IterVarType::kDataPar) { @@ -68,14 +68,14 @@ void ThreadBind(tir::Schedule sch, const tir::BlockRV& block, int64_t max_thread } // schedule the fused loop if (product > max_thread_per_block * max_threadblocks) { - Array splits = sch->Split( + ffi::Array splits = sch->Split( fused, /*factors=*/{std::nullopt, Integer(max_threadblocks), Integer(max_thread_per_block)}); sch->Reorder(/*ordered_loop_rvs=*/{splits[1], splits[2], splits[0]}); sch->Bind(splits[1], "blockIdx.x"); sch->Bind(splits[2], "threadIdx.x"); } else { - Array splits = sch->Split( + ffi::Array splits = sch->Split( fused, /*factors=*/{std::nullopt, Integer(std::min(product, max_thread_per_block))}); sch->Bind(splits[0], "blockIdx.x"); sch->Bind(splits[1], "threadIdx.x"); @@ -83,11 +83,11 @@ void ThreadBind(tir::Schedule sch, const tir::BlockRV& block, int64_t max_thread } IRModule MarkScheduled(const IRModule& mod) { - Map result; + ffi::Map result; for (const auto& [gv, base_func] : mod->functions) { if (const auto* prim_func_node = base_func.as()) { - tir::PrimFunc prim_func = GetRef(prim_func_node); + tir::PrimFunc prim_func = ffi::GetRef(prim_func_node); tir::PrimFunc new_prim_func = WithAttr(std::move(prim_func), tir::attr::kIsScheduled, true); result.Set(gv, new_prim_func); } else { @@ -105,7 +105,7 @@ bool IsScheduledOnGPU(const BaseFunc& func) { // the target from context. 
tvm::Target target = tvm::Target::Current(); // the Target in kTarget attribute of PrimFunc - Optional func_target = func->attrs.GetAttr(tvm::attr::kTarget); + ffi::Optional func_target = func->attrs.GetAttr(tvm::attr::kTarget); if (func_target.defined()) { target = func_target.value(); } @@ -131,7 +131,7 @@ Pass DefaultGPUSchedule() { // get the target from context. tvm::Target target = tvm::Target::Current(); // get the target from kTarget attribute - Optional func_target = + ffi::Optional func_target = func->attrs.GetAttr(tvm::attr::kTarget); if (func_target.defined()) { target = func_target.value(); @@ -139,14 +139,14 @@ Pass DefaultGPUSchedule() { ICHECK(target.defined()) << "The target is missing either in the current context or in " "the prim_func's attribute."; // get the max thread per block from target. - Optional opt_max_thread_per_block = + ffi::Optional opt_max_thread_per_block = target->GetAttr("max_num_threads"); ICHECK(opt_max_thread_per_block.defined()) << "max_num_threads is not set for target " << target; int64_t max_thread_per_block = opt_max_thread_per_block.value().IntValue(); sch->WorkOn(gv->name_hint); - Array blocks = meta_schedule::BlockCollector::Collect(sch); + ffi::Array blocks = meta_schedule::BlockCollector::Collect(sch); for (const tir::BlockRV& block : blocks) { auto childs = sch->GetChildBlocks(block); if (!childs.empty()) { diff --git a/src/tir/transforms/extract_constants.cc b/src/tir/transforms/extract_constants.cc index 51cd08c7a877..404a16fadf05 100644 --- a/src/tir/transforms/extract_constants.cc +++ b/src/tir/transforms/extract_constants.cc @@ -36,7 +36,7 @@ namespace tvm { namespace tir { -using ConstArrayType = Array; +using ConstArrayType = ffi::Array; class Applicator : public tir::StmtMutator { protected: // returns index of the a in constant_array_, if not found - appends @@ -62,7 +62,7 @@ class Applicator : public tir::StmtMutator { // and add array index. ICHECK(acn->data) << "data field should be defined"; auto node = CopyOnWrite(acn); - node->irmod_storage_idx = Optional(Integer(DeDup(node->data.value()))); + node->irmod_storage_idx = ffi::Optional(Integer(DeDup(node->data.value()))); return Stmt(node); } @@ -75,7 +75,7 @@ tvm::transform::Pass ExtractPrimFuncConstants() { auto prim_func_pass = [=](PrimFunc foo, IRModule m, tvm::transform::PassContext ctx) { auto* func = foo.CopyOnWrite(); if (!m->attrs.defined()) { - m->attrs = DictAttrs(Map()); + m->attrs = DictAttrs(ffi::Map()); } auto* attrs = m->attrs.CopyOnWrite(); ConstArrayType constant_array_ = @@ -88,11 +88,11 @@ tvm::transform::Pass ExtractPrimFuncConstants() { if (constant_list.size()) { attrs->dict.Set(tvm::attr::kConstants, constant_list); } - return GetRef(func); + return ffi::GetRef(func); }; auto pass_func = [=](IRModule module, tvm::transform::PassContext pc) { - auto m = GetRef(module.CopyOnWrite()); + auto m = ffi::GetRef(module.CopyOnWrite()); for (const auto& kv : m->functions) { if (auto func = kv.second.as()) { m->Update(kv.first, prim_func_pass(func.value(), m, pc)); diff --git a/src/tir/transforms/flatten_buffer.cc b/src/tir/transforms/flatten_buffer.cc index 1515bfadb59a..ffaa274e2871 100644 --- a/src/tir/transforms/flatten_buffer.cc +++ b/src/tir/transforms/flatten_buffer.cc @@ -65,21 +65,21 @@ class BufferFlattener : public arith::IRMutatorWithAnalyzer { << "Unexpected MatchBufferRegion found during tir.transform.FlattenBuffer. 
" << "All MatchBufferRegion should be removed in tir.transform.LowerMatchBuffer."; - Block block = GetRef(op); + Block block = ffi::GetRef(op); - Array alloc_buffers = op->alloc_buffers; + ffi::Array alloc_buffers = op->alloc_buffers; alloc_buffers.MutateByApply([this](Buffer buf) { return GetFlattenedBuffer(buf); }); if (!alloc_buffers.same_as(op->alloc_buffers)) { block.CopyOnWrite()->alloc_buffers = alloc_buffers; } - Array reads = op->reads; + ffi::Array reads = op->reads; reads.MutateByApply([this](BufferRegion region) { return MutateBufferRegion(region); }); if (!reads.same_as(op->reads)) { block.CopyOnWrite()->reads = reads; } - Array writes = op->writes; + ffi::Array writes = op->writes; writes.MutateByApply([this](BufferRegion region) { return MutateBufferRegion(region); }); if (!writes.same_as(op->writes)) { block.CopyOnWrite()->writes = writes; @@ -91,7 +91,7 @@ class BufferFlattener : public arith::IRMutatorWithAnalyzer { Stmt VisitStmt_(const AllocateNode* op) final { // Determine the flattened extents first, before stripping of // DeclBuffer. - auto new_extents = [&]() -> Array { + auto new_extents = [&]() -> ffi::Array { if (op->extents.size() == 1) { // No flattening required for buffers that are already flat return op->extents; @@ -219,7 +219,8 @@ class BufferFlattener : public arith::IRMutatorWithAnalyzer { } } - Array GetSimplifiedElemOffset(const Buffer& buffer, const Array& indices) { + ffi::Array GetSimplifiedElemOffset(const Buffer& buffer, + const ffi::Array& indices) { auto flattened_indices = buffer->ElemOffset(indices); return this->IterMapSimplifyWithContext(flattened_indices, false); } @@ -243,17 +244,17 @@ class BufferFlattener : public arith::IRMutatorWithAnalyzer { return region; } - Array min_values; - Array max_values; + ffi::Array min_values; + ffi::Array max_values; for (const auto& range : region->region) { min_values.push_back(range->min); max_values.push_back(range->min + range->extent - 1); } - Array flattened_min = GetSimplifiedElemOffset(orig_buf, min_values); - Array flattened_max = GetSimplifiedElemOffset(orig_buf, max_values); + ffi::Array flattened_min = GetSimplifiedElemOffset(orig_buf, min_values); + ffi::Array flattened_max = GetSimplifiedElemOffset(orig_buf, max_values); - Array flattened_ranges; + ffi::Array flattened_ranges; ICHECK_EQ(flattened_min.size(), flattened_max.size()); for (size_t i = 0; i < flattened_min.size(); i++) { flattened_ranges.push_back(Range(flattened_min[i], flattened_max[i] + 1)); @@ -266,7 +267,7 @@ class BufferFlattener : public arith::IRMutatorWithAnalyzer { std::unordered_map buffer_remap_; /*! \brief The updated external buffer map. 
*/ - Map updated_extern_buffer_map_; + ffi::Map updated_extern_buffer_map_; }; PrimFunc FlattenBuffer(PrimFunc f) { return BufferFlattener::Flatten(f); } diff --git a/src/tir/transforms/force_narrow_index_to_i32.cc b/src/tir/transforms/force_narrow_index_to_i32.cc index d291e40f3c31..52d68460e8e3 100644 --- a/src/tir/transforms/force_narrow_index_to_i32.cc +++ b/src/tir/transforms/force_narrow_index_to_i32.cc @@ -56,7 +56,7 @@ class Int32DTypeNarrower : public IndexDataTypeNormalizer { ICHECK_LE(op->value, Downcast(max_value(target_data_type_))->value); return IntImm(DataType::Int(32), op->value); } - return GetRef(op); + return ffi::GetRef(op); } Stmt VisitStmt_(const BlockNode* block) final { diff --git a/src/tir/transforms/hoist_expression.cc b/src/tir/transforms/hoist_expression.cc index 1548ea1da625..1c9b5893ab69 100644 --- a/src/tir/transforms/hoist_expression.cc +++ b/src/tir/transforms/hoist_expression.cc @@ -89,7 +89,7 @@ struct HoistExpressionConfigNode : public AttrsNodeReflAdapter(); + auto node = ffi::make_object(); node->hoisted_conditionals = hoisted_conditionals; node->hoisted_let_bindings = hoisted_let_bindings; data_ = std::move(node); @@ -250,7 +250,7 @@ class HoistInfoCollector : public StmtExprVisitor { } void VisitStmt_(const ForNode* op) final { - active_loops.push_back({op->loop_var, GetRef(op)}); + active_loops.push_back({op->loop_var, ffi::GetRef(op)}); active_loop_vars.insert(op->loop_var.get()); Parent::VisitStmt_(op); @@ -272,7 +272,7 @@ class HoistInfoCollector : public StmtExprVisitor { active_block_vars.insert(var.get()); active_loop_vars.insert(var.get()); - active_loops.push_back({var, GetRef(op)}); + active_loops.push_back({var, ffi::GetRef(op)}); Parent::VisitStmt_(op); diff --git a/src/tir/transforms/inject_double_buffer.cc b/src/tir/transforms/inject_double_buffer.cc index 8ced9c82253d..50bbbcc6b2b3 100644 --- a/src/tir/transforms/inject_double_buffer.cc +++ b/src/tir/transforms/inject_double_buffer.cc @@ -123,7 +123,7 @@ class DoubleBufferInjector : public StmtExprMutator { Stmt stmt = StmtExprMutator::VisitStmt_(op); op = stmt.as(); - Array new_extents = {op->extents[0] * make_const(op->extents[0].dtype(), 2)}; + ffi::Array new_extents = {op->extents[0] * make_const(op->extents[0].dtype(), 2)}; ICHECK(entry.loop != nullptr); auto& alloc_nest = loop_allocs_[entry.loop]; alloc_nest.emplace_back(Allocate(op->buffer_var, op->dtype, new_extents, op->condition, @@ -249,7 +249,7 @@ class DoubleBufferInjector : public StmtExprMutator { PrimExpr VisitExpr_(const VarNode* op) final { ICHECK(!dbuffer_info_.count(op)); - return GetRef(op); + return ffi::GetRef(op); } private: diff --git a/src/tir/transforms/inject_permuted_layout.cc b/src/tir/transforms/inject_permuted_layout.cc index f90752e26418..b2433ee70a35 100644 --- a/src/tir/transforms/inject_permuted_layout.cc +++ b/src/tir/transforms/inject_permuted_layout.cc @@ -59,7 +59,7 @@ class PermutedLayoutInjector : private IRMutatorWithAnalyzer { using IRMutatorWithAnalyzer::VisitExpr_; using IRMutatorWithAnalyzer::VisitStmt_; - Array PermuteIndices(PrimExpr row_idx, PrimExpr col_idx, int row_size) { + ffi::Array PermuteIndices(PrimExpr row_idx, PrimExpr col_idx, int row_size) { ICHECK(permute_); // Index after vectorizing by 8 PrimExpr col_idx_outer = floordiv(col_idx, VECTORIZE_FACTOR), @@ -104,7 +104,7 @@ class PermutedLayoutInjector : private IRMutatorWithAnalyzer { } static bool CheckAnnotation(const Any& annotation) { - if (auto opt_str = annotation.as()) { + if (auto opt_str = annotation.as()) { // 
Support string annotation for backward compatibility return *opt_str != ""; } else if (auto* node = annotation.as()) { @@ -165,7 +165,7 @@ class PermutedLayoutInjector : private IRMutatorWithAnalyzer { return buffer_row_size; } - Array HandleBufferIndices(Buffer buffer, Array indices) { + ffi::Array HandleBufferIndices(Buffer buffer, ffi::Array indices) { auto buffer_row_size = CheckAndGetBufferRowSize(buffer); // Mutate the last two indices @@ -216,7 +216,8 @@ class PermutedLayoutInjector : private IRMutatorWithAnalyzer { return load; } - PrimExpr HandleAccessPtrAndOffset(PrimExpr access_ptr, Optional offset = std::nullopt) { + PrimExpr HandleAccessPtrAndOffset(PrimExpr access_ptr, + ffi::Optional offset = std::nullopt) { // The 2th arg of T.tvm_access_ptr call is offset, we set it to 0 and accumulate it to // smem_offset CHECK(access_ptr->IsInstance()) diff --git a/src/tir/transforms/inject_ptx_async_copy.cc b/src/tir/transforms/inject_ptx_async_copy.cc index f0a88ba98192..8abcabae4048 100644 --- a/src/tir/transforms/inject_ptx_async_copy.cc +++ b/src/tir/transforms/inject_ptx_async_copy.cc @@ -81,8 +81,8 @@ class PTXAsyncCopyInjector : public StmtMutator { if (indices_lanes == 1) { auto src_offset = load->indices[0]; auto dst_offset = store->indices[0]; - Array args = {store->buffer->data, mul(dst_offset, PrimExpr(index_factor)), - load->buffer->data, src_offset, PrimExpr(bytes)}; + ffi::Array args = {store->buffer->data, mul(dst_offset, PrimExpr(index_factor)), + load->buffer->data, src_offset, PrimExpr(bytes)}; // use arguments size to indicate whether or not to use predicated cp.async if (predicated) { args.push_back(predicate_value); diff --git a/src/tir/transforms/inject_ptx_ldg32.cc b/src/tir/transforms/inject_ptx_ldg32.cc index 848e8491945f..3713531cfa37 100644 --- a/src/tir/transforms/inject_ptx_ldg32.cc +++ b/src/tir/transforms/inject_ptx_ldg32.cc @@ -95,8 +95,8 @@ class PTXRewriter : public StmtMutator { BufferStore value_store(store->buffer, imm_value, {new_indice}); Evaluate ptx_load(Call(store->buffer->dtype, tvm::tir::builtin::ptx_ldg32(), {store->buffer->data, new_predicate, new_lhs, new_indice})); - Array tmp_seq = {addr_store, local_addr_store, predicate_store, value_store, - ptx_load}; + ffi::Array tmp_seq = {addr_store, local_addr_store, predicate_store, value_store, + ptx_load}; SeqStmt seq_stmt = SeqStmt(tmp_seq); return seq_stmt; } diff --git a/src/tir/transforms/inject_rolling_buffer.cc b/src/tir/transforms/inject_rolling_buffer.cc index a68308261a19..6fb4b94fdb0e 100644 --- a/src/tir/transforms/inject_rolling_buffer.cc +++ b/src/tir/transforms/inject_rolling_buffer.cc @@ -50,7 +50,7 @@ struct RollingBufferInfo { int rolling_axis; int rolling_extent; std::vector axis_overlaps; - std::vector> axis_iter_vars; + std::vector> axis_iter_vars; }; class RollingBufferInjector : public StmtExprMutator { @@ -70,7 +70,7 @@ class RollingBufferInjector : public StmtExprMutator { Stmt VisitStmt_(const ForNode* op) final { // Manage the stack of iter_vars - for_loops.push_back(GetRef(op)); + for_loops.push_back(ffi::GetRef(op)); auto stmt{StmtExprMutator::VisitStmt_(op)}; op = stmt.as(); @@ -82,7 +82,7 @@ class RollingBufferInjector : public StmtExprMutator { if (it != hoist_buffer_to_for.end()) { // If the loop corresponds to an iter_var that needs a BufferRealize // hoisting to its scope, perform the hoisting - Stmt body{GetRef(op)}; + Stmt body{ffi::GetRef(op)}; for (auto realise : it->second) { auto attrs{buffer_to_attrs[realise->buffer]}; Stmt 
new_realize{BufferRealize(realise->buffer, realise->bounds, realise->condition, body, @@ -108,7 +108,7 @@ class RollingBufferInjector : public StmtExprMutator { // Keep a dictionary associating attribute statements with the buffers // they reference. We'll need this if the buffer gets hoisted and we // need to hoist all of its attributes at the same time. - buffer_to_attrs[buffer].push_back(GetRef(op)); + buffer_to_attrs[buffer].push_back(ffi::GetRef(op)); if (op->attr_key == attr::rolling_buffer_scope && Downcast(op->value)->value) { // If the attribute is indicating that a buffer should be a rolling @@ -122,13 +122,13 @@ class RollingBufferInjector : public StmtExprMutator { // If a BufferRealize has been identified as needing to be made into // a rolling buffer, begin the analysis. - std::vector> bound_iter_vars{}; + std::vector> bound_iter_vars{}; std::vector bound_overlaps{}; // We use the bound information of the BufferRealize to calculate // how we can legally roll auto stride{0}; auto divisor{1}; - Optional iter_var{}; + ffi::Optional iter_var{}; for (auto bound : buffer_realize->bounds) { divisor = 1; if (auto floor_div = bound->min.as()) { @@ -143,7 +143,7 @@ class RollingBufferInjector : public StmtExprMutator { iter_var = nullptr; } else if (auto var = bound->min.as()) { // If the bound is just a Var, that implies the stride is 1 - iter_var = GetRef(var); + iter_var = ffi::GetRef(var); stride = 1; } else { // Otherwise, it's the iter var multiplied by the stride @@ -154,7 +154,7 @@ class RollingBufferInjector : public StmtExprMutator { ICHECK(a) << "Rolling buffer injection failed: the buffer striding is unsupported"; auto b = mul->b.as(); ICHECK(b) << "Rolling buffer injection failed: the buffer striding is unsupported"; - iter_var = GetRef(a); + iter_var = ffi::GetRef(a); stride = b->value; } stride = std::ceil(static_cast(stride) / divisor); @@ -167,7 +167,7 @@ class RollingBufferInjector : public StmtExprMutator { } // Pick the outermost iter_var that's mentioned in the bounds // to be the rolling axis - Optional roll_iter_var{}; + ffi::Optional roll_iter_var{}; int roll_axis{1}; for (auto loop : for_loops) { auto loop_var{loop->loop_var}; @@ -175,7 +175,7 @@ class RollingBufferInjector : public StmtExprMutator { auto it{std::find_if( bound_iter_vars.begin(), bound_iter_vars.end(), - [&](Optional var) { return var && (var.get() == loop_var.get()); })}; + [&](ffi::Optional var) { return var && (var.get() == loop_var.get()); })}; if (it != bound_iter_vars.end()) { auto i{std::distance(bound_iter_vars.begin(), it)}; @@ -195,7 +195,7 @@ class RollingBufferInjector : public StmtExprMutator { bound_iter_vars, }; rolling_buffer_to_info[buffer] = rolling_buffer_info; - Array new_bounds{}; + ffi::Array new_bounds{}; auto shape{buffer->shape}; for (size_t i{0}; i < shape.size(); ++i) { auto extent{shape[i]}; @@ -225,7 +225,7 @@ class RollingBufferInjector : public StmtExprMutator { } Stmt VisitStmt_(const BufferRealizeNode* op) final { - buffer_to_buffer_realize.insert({op->buffer, GetRef(op)}); + buffer_to_buffer_realize.insert({op->buffer, ffi::GetRef(op)}); auto stmt{StmtExprMutator::VisitStmt_(op)}; op = stmt.as(); @@ -266,7 +266,7 @@ class RollingBufferInjector : public StmtExprMutator { auto iter_var{rolling_buffer_info.axis_iter_vars[i]}; if (iter_var && rolling_buffer_info.axis_overlaps[i] > 0) { Var var{iter_var.value()}; - const Map dmap{std::make_pair(var, IntSet::Interval(0, 0))}; + const ffi::Map dmap{std::make_pair(var, IntSet::Interval(0, 0))}; auto 
term_2{arith::Analyzer{}.int_set(op->indices[i], dmap).min()}; auto condition = Or(LT(var, 1), GE(term_2, rolling_buffer_info.axis_overlaps[i])); buffer_store = IfThenElse(likely(condition), buffer_store); diff --git a/src/tir/transforms/inject_software_pipeline.cc b/src/tir/transforms/inject_software_pipeline.cc index b89d3b89fa82..340c21140253 100644 --- a/src/tir/transforms/inject_software_pipeline.cc +++ b/src/tir/transforms/inject_software_pipeline.cc @@ -48,7 +48,7 @@ namespace software_pipeline { * \param buffer_data_to_buffer The map from buffer data to buffer. * \return The result block. */ -Block MakeBlock(const Stmt& body, const Map& buffer_data_to_buffer) { +Block MakeBlock(const Stmt& body, const ffi::Map& buffer_data_to_buffer) { if (const BlockRealizeNode* block_realize = body.as()) { if (is_one(block_realize->predicate)) { // no need to create a new block @@ -56,7 +56,8 @@ Block MakeBlock(const Stmt& body, const Map& buffer_data_to_buffer) } } Block block(/*iter_vars=*/{}, /*reads=*/{}, /*writes=*/{}, /*name_hint=*/"", /*body*/ body); - Array> access = GetBlockReadWriteRegion(block, buffer_data_to_buffer); + ffi::Array> access = + GetBlockReadWriteRegion(block, buffer_data_to_buffer); BlockNode* n = block.CopyOnWrite(); n->reads = access[0]; n->writes = access[1]; @@ -88,8 +89,8 @@ class PipelineOpaqueAccessRewriter { * \param fragment_info Information about tensor core fragment */ PipelineOpaqueAccessRewriter( - const Map& buffer_data_to_buffer, const Map& buffer_remap, - const For& pipeline_loop, + const ffi::Map& buffer_data_to_buffer, + const ffi::Map& buffer_remap, const For& pipeline_loop, const std::unordered_map& fragment_info) : buffer_data_to_buffer_(buffer_data_to_buffer), buffer_remap_(buffer_remap), @@ -109,13 +110,13 @@ class PipelineOpaqueAccessRewriter { const Buffer& buffer = buffer_data_to_buffer_.at(Downcast(call->args[0])); auto it = buffer_remap_.find(buffer); if (it != buffer_remap_.end()) { - Array new_args = call->args; + ffi::Array new_args = call->args; const Buffer& new_buffer = (*it).second; new_args.Set(4, RewriteWmmaFragmentIndex(buffer, new_buffer, call->args[4])); return Call(call->dtype, call->op, new_args, call->span); } } else if (call->op.same_as(mma_sync)) { - Array new_args = call->args; + ffi::Array new_args = call->args; for (int i = 0; i < 4; i++) { const Var& buffer_var = Downcast(call->args[i * 2]); const PrimExpr& index = call->args[i * 2 + 1]; @@ -160,11 +161,11 @@ class PipelineOpaqueAccessRewriter { } PrimExpr RewriteBufferAccess(const Call& call, const std::vector arg_indices) { - auto product = [](const Array& input) { + auto product = [](const ffi::Array& input) { return foldl([](PrimExpr a, PrimExpr b, Span span) { return mul(a, b, span); }, make_const(DataType::Int(32), 1), input); }; - Array new_args = call->args; + ffi::Array new_args = call->args; for (int i : arg_indices) { const Buffer& buffer = buffer_data_to_buffer_.at(Downcast(call->args[i])); auto it = buffer_remap_.find(buffer); @@ -192,8 +193,8 @@ class PipelineOpaqueAccessRewriter { return Call(call->dtype, call->op, new_args, call->span); } - const Map& buffer_data_to_buffer_; - const Map& buffer_remap_; + const ffi::Map& buffer_data_to_buffer_; + const ffi::Map& buffer_remap_; const For& pipeline_loop_; const std::unordered_map& fragment_info_; }; @@ -215,8 +216,8 @@ class PipelineBodyRewriter : public StmtExprMutator { * of a two-stage software pipeline, only one version of these buffers are accessed. 
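 * For example, in the prologue of a two-stage pipeline only a single
 * version of a double-buffered buffer is live per iteration, so the block
 * access regions need not cover the buffer's full versioned extent.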
* \param fragment_info Information about tensor core fragment */ - PipelineBodyRewriter(const Map& buffer_data_to_buffer, - const Map& buffer_remap, For pipeline_loop, + PipelineBodyRewriter(const ffi::Map& buffer_data_to_buffer, + const ffi::Map& buffer_remap, For pipeline_loop, bool access_all_versions, const std::unordered_map& fragment_info) : buffer_data_to_buffer_(buffer_data_to_buffer), @@ -299,8 +300,8 @@ class PipelineBodyRewriter : public StmtExprMutator { return opaque_access_rewriter_.Rewrite(call); } - Map buffer_data_to_buffer_; - Map buffer_remap_; + ffi::Map buffer_data_to_buffer_; + ffi::Map buffer_remap_; For pipeline_loop_; bool access_all_versions_; PipelineOpaqueAccessRewriter opaque_access_rewriter_; @@ -312,24 +313,24 @@ class PipelineBodyRewriter : public StmtExprMutator { class PipelineRewriter : public StmtExprMutator { public: static Stmt Rewrite( - Map buffer_data_to_buffer, + ffi::Map buffer_data_to_buffer, const std::unordered_set& double_buffers, - const Array pipeline_allocs, const For& pipeline_loop, + const ffi::Array pipeline_allocs, const For& pipeline_loop, const PipelineInfo& pipeline_info, const std::unordered_map& fragment_info, - const Map preserved_annotations) { + const ffi::Map preserved_annotations) { PipelineRewriter rewriter(buffer_data_to_buffer, double_buffers, pipeline_allocs, pipeline_loop, pipeline_info, fragment_info, preserved_annotations); return rewriter.BuildPipeline(); } private: - PipelineRewriter(Map buffer_data_to_buffer, + PipelineRewriter(ffi::Map buffer_data_to_buffer, const std::unordered_set& double_buffers, - const Array& pipeline_allocs, const For& pipeline_loop, + const ffi::Array& pipeline_allocs, const For& pipeline_loop, const PipelineInfo& pipeline_info, const std::unordered_map& fragment_info, - const Map preserved_annotations) + const ffi::Map preserved_annotations) : buffer_data_to_buffer_(std::move(buffer_data_to_buffer)), double_buffers_(double_buffers), @@ -365,7 +366,7 @@ class PipelineRewriter : public StmtExprMutator { // introduce extra lowerbound when the loop length is smaller than num stages // to ensure the epilogue interval do not overlap the prologue interval. PrimExpr epigogue_start = pipeline_loop_->min + pipeline_loop_->extent; - Optional extra_epilogue_lower_bound = std::nullopt; + ffi::Optional extra_epilogue_lower_bound = std::nullopt; if (max_stage_ > 1 && !analyzer_.CanProveGreaterEqual(pipeline_loop_->extent, max_stage_)) { if (is_const_int(epigogue_start)) { epigogue_start = max(epigogue_start, pipeline_loop_->min + max_stage_); @@ -382,7 +383,7 @@ class PipelineRewriter : public StmtExprMutator { SeqStmt stmt = SeqStmt({prologue, body, epilogue}); // Step 3: Make a new block that contains new buffer allocations after pipeline rewriting. - Array alloc_buffers; + ffi::Array alloc_buffers; for (const auto& alloc : pipeline_allocs_) { alloc_buffers.push_back(buffer_remap_.Get(alloc).value_or(alloc)); buffer_data_to_buffer_.erase(alloc->data); @@ -527,7 +528,7 @@ class PipelineRewriter : public StmtExprMutator { * \return The resized buffer. 
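 *
 * For example, versioning a buffer of shape [128] with num_versions = 2
 * produces shape [2, 128]; the new leading axis selects the pipeline version.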
*/ Buffer RewriteAllocBuffer(const Buffer& buffer, int num_versions) { - ObjectPtr new_buffer = make_object(*(buffer.get())); + ObjectPtr new_buffer = ffi::make_object(*(buffer.get())); new_buffer->shape.insert(new_buffer->shape.begin(), PrimExpr(num_versions)); if (new_buffer->strides.size()) { ICHECK(new_buffer->strides.size() + 1 == new_buffer->shape.size()); @@ -546,7 +547,7 @@ class PipelineRewriter : public StmtExprMutator { // async invocations exactly. When it is valid, it is the "sum of extents of loops that have // been executed" - 1, e.g. for epilogue it is prologue extent + body extent - 1. This // is only needed to compute wait count for epilogue without async producers. - Optional producer_head{PrimExpr(-1)}; + ffi::Optional producer_head{PrimExpr(-1)}; bool writes(Buffer buf) const { return dst_buffers.count(buf.get()) > 0; } }; @@ -578,9 +579,9 @@ class PipelineRewriter : public StmtExprMutator { // A symbolic expression representing the index the latest async operation associated with this // stage has written into, at the "current" iteration. - Optional producer_head; + ffi::Optional producer_head; // The predicate of BlockRealize containing the async operation of this stage. - Optional predicate; + ffi::Optional predicate; // Indices into a list of blocks, where async_commit_queue scope should be attached. // If multiple async producers are interleaved with their consumer in between, we need separate // async_commit_queue for each producer. Thus, we need multiple sets of indices. @@ -670,7 +671,7 @@ class PipelineRewriter : public StmtExprMutator { auto& dep_local_state = (*async_states_local)[producer_stage_idx]; const auto num_commit_group = dep_local_state.commit_groups.size(); - std::vector> producer_head_per_commit; + std::vector> producer_head_per_commit; if (num_commit_group == 0) { // Epilogue, no async producer. Since "local" producer_head is not available, use @@ -728,7 +729,7 @@ class PipelineRewriter : public StmtExprMutator { // Given pipelined blocks and async-related information, generate final loop statements with async // scopes (if any). - Array CompletePipelineLoopStatements( + ffi::Array CompletePipelineLoopStatements( const std::vector& blocks, const std::map& async_states_local, arith::Analyzer* ana_normalized) const { @@ -768,7 +769,7 @@ class PipelineRewriter : public StmtExprMutator { } } - Array stmts; + ffi::Array stmts; for (size_t i = 0; i < new_blocks.size();) { if (commit_group_indices[i] == -1) { @@ -776,7 +777,7 @@ class PipelineRewriter : public StmtExprMutator { stmts.push_back(BlockRealize({}, new_blocks[i].predicate, new_blocks[i].block)); ++i; } else { - Array group_bodies; + ffi::Array group_bodies; auto stage_id = commit_group_indices[i]; auto predicate = new_blocks[i].predicate; for (; i < commit_group_indices.size() && commit_group_indices[i] == stage_id; ++i) { @@ -812,7 +813,7 @@ class PipelineRewriter : public StmtExprMutator { * \return The result loop. 
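 *
 * BuildPipeline invokes this three times: roughly [min, min + max_stage_)
 * for the prologue, [min + max_stage_, min + extent) for the body, and
 * [min + extent, min + extent + max_stage_) for the epilogue.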
*/ Stmt EmitImpl(PrimExpr start, PrimExpr end, bool unroll_loop, - Optional extra_loop_lower_bound = std::nullopt) { + ffi::Optional extra_loop_lower_bound = std::nullopt) { PrimExpr new_loop_var; PrimExpr extent = end - start; @@ -966,17 +967,17 @@ class PipelineRewriter : public StmtExprMutator { } arith::Analyzer analyzer_; - Map buffer_data_to_buffer_; + ffi::Map buffer_data_to_buffer_; const std::unordered_set& double_buffers_; - Array pipeline_allocs_; + ffi::Array pipeline_allocs_; For pipeline_loop_; PipelineInfo pipeline_info_; const std::unordered_map& fragment_info_; int max_stage_ = -1; - Map buffer_remap_; - Array ordered_stmts_; + ffi::Map buffer_remap_; + ffi::Array ordered_stmts_; std::map async_states; - Map preserved_annotations_; + ffi::Map preserved_annotations_; }; /*! @@ -988,10 +989,10 @@ class PipelineRewriter : public StmtExprMutator { * destination to the source. */ void BuildDependencyGraph( - const Array& blocks, - std::unordered_map, ObjectPtrHash, ObjectPtrEqual>* dep_src2dst, - std::unordered_map, ObjectPtrHash, ObjectPtrEqual>* dep_dst2src) { - std::unordered_map> buffer_writers; + const ffi::Array& blocks, + std::unordered_map, ObjectPtrHash, ObjectPtrEqual>* dep_src2dst, + std::unordered_map, ObjectPtrHash, ObjectPtrEqual>* dep_dst2src) { + std::unordered_map> buffer_writers; for (const Block& block : blocks) { for (const BufferRegion& read : block->reads) { @@ -1016,7 +1017,7 @@ void BuildDependencyGraph( class PipelineInjector : private StmtExprMutator { public: static Stmt Inject(const PrimFunc& func) { - auto global_symbol = func->GetAttr(tvm::attr::kGlobalSymbol); + auto global_symbol = func->GetAttr(tvm::attr::kGlobalSymbol); PipelineInjector injector(global_symbol); for (const auto& kv : func->buffer_map) { const Buffer& buffer = kv.second; @@ -1027,7 +1028,8 @@ class PipelineInjector : private StmtExprMutator { } private: - explicit PipelineInjector(Optional global_symbol) : global_symbol_(global_symbol) {} + explicit PipelineInjector(ffi::Optional global_symbol) + : global_symbol_(global_symbol) {} /*! * \brief Check the pipeline satisfies the following conditions: @@ -1037,7 +1039,8 @@ class PipelineInjector : private StmtExprMutator { * case 1: stage(A) < stage(B) * case 2: stage(A) == stage(B) and order(A) < order(B) */ - void ValidatePipelineBody(const PipelineInfo& pipeline_info, const Array& original_order) { + void ValidatePipelineBody(const PipelineInfo& pipeline_info, + const ffi::Array& original_order) { std::unordered_set used_orders; std::unordered_map stage_max_order; std::unordered_map order_to_block; @@ -1050,13 +1053,13 @@ class PipelineInjector : private StmtExprMutator { used_orders.insert(order); } - std::unordered_map, ObjectPtrHash, ObjectPtrEqual> dep_src2dst; + std::unordered_map, ObjectPtrHash, ObjectPtrEqual> dep_src2dst; BuildDependencyGraph(original_order, &dep_src2dst, nullptr); for (const auto& pair : dep_src2dst) { const Block& src = pair.first; const auto& src_info = pipeline_info.at(src); - const Array& dsts = pair.second; + const ffi::Array& dsts = pair.second; for (const Block& dst : dsts) { const auto& dst_info = pipeline_info.at(dst); CHECK_LE(src_info.stage, dst_info.stage) @@ -1081,7 +1084,7 @@ class PipelineInjector : private StmtExprMutator { // the for-loop. If the for-loop has BlockRealize as its child, the pipeline body will be the // child of the block. 
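// Illustrative trigger (TVMScript syntax, assumed here for exposition): a
// serial loop annotated with
//   annotations={"software_pipeline_stage": [0, 0, 1],
//                "software_pipeline_order": [0, 1, 2]}
// assigns each of its three children a pipeline stage and execution order.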
Stmt pipeline_body{nullptr}; - Array pipeline_allocs; + ffi::Array pipeline_allocs; if (const auto* realize = for_node->body.as()) { const auto& block = realize->block; for (const auto& buffer : block->alloc_buffers) { @@ -1102,7 +1105,7 @@ class PipelineInjector : private StmtExprMutator { // Step 3: Blockize the components of the pipeline. Each child of the pipelined loop will be // converted into a block. PipelineInfo pipeline_info; - Array original_order; // pipeline body blocks in the original order + ffi::Array original_order; // pipeline body blocks in the original order auto f_add_child = [&](const Stmt& child) { original_order.push_back(MakeBlock(child, buffer_data_to_buffer_)); @@ -1128,9 +1131,9 @@ class PipelineInjector : private StmtExprMutator { } auto pipeline_stages = - Downcast>(op->annotations.at(attr::software_pipeline_stage)); + Downcast>(op->annotations.at(attr::software_pipeline_stage)); auto pipeline_orders = - Downcast>(op->annotations.at(attr::software_pipeline_order)); + Downcast>(op->annotations.at(attr::software_pipeline_order)); CHECK_EQ(pipeline_stages.size(), original_order.size()) << "PrimFunc " << global_symbol_ << " has original order " << original_order.Map([](const auto& block) { return block->name_hint; }) @@ -1142,14 +1145,14 @@ class PipelineInjector : private StmtExprMutator { std::unordered_set pipeline_async_stages; if (auto annot = op->annotations.Get(attr::software_pipeline_async_stages)) { - for (auto s : Downcast>(annot.value())) { + for (auto s : Downcast>(annot.value())) { pipeline_async_stages.insert(s->value); } } - Map preserved_annotations; + ffi::Map preserved_annotations; for (const auto& kv : op->annotations) { - const String& key = kv.first; + const ffi::String& key = kv.first; if (kv.first != attr::software_pipeline_stage && kv.first != attr::software_pipeline_order && kv.first != attr::software_pipeline_async_stages) { preserved_annotations.Set(key, kv.second); @@ -1169,7 +1172,7 @@ class PipelineInjector : private StmtExprMutator { // Step 4: Rewrite the pipeline body. Stmt pipeline = PipelineRewriter::Rewrite(buffer_data_to_buffer_, double_buffers, - pipeline_allocs, GetRef(op), pipeline_info, + pipeline_allocs, ffi::GetRef(op), pipeline_info, fragment_info_, preserved_annotations); if (const auto* realize = op->body.as()) { @@ -1186,7 +1189,7 @@ class PipelineInjector : private StmtExprMutator { * \param n The block pointer to which the buffer allocations are added. * \param alloc_buffers The buffer allocations to be added. 
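 *
 * After pipeline rewriting, this re-attaches the (possibly multi-versioned)
 * pipeline buffers to the enclosing block so that every allocation stays
 * scoped to a block.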
*/ - void AddAllocBuffers(BlockNode* n, const Array alloc_buffers) { + void AddAllocBuffers(BlockNode* n, const ffi::Array alloc_buffers) { for (const Buffer& alloc_buffer : alloc_buffers) { n->alloc_buffers.push_back(alloc_buffer); Region region; @@ -1236,10 +1239,10 @@ class PipelineInjector : private StmtExprMutator { return false; } - Map buffer_data_to_buffer_; + ffi::Map buffer_data_to_buffer_; std::unordered_map fragment_info_; std::unordered_set double_buffers; - Optional global_symbol_; + ffi::Optional global_symbol_; }; } // namespace software_pipeline diff --git a/src/tir/transforms/inject_virtual_thread.cc b/src/tir/transforms/inject_virtual_thread.cc index d0f84842a4fe..9016ffdbf9fe 100644 --- a/src/tir/transforms/inject_virtual_thread.cc +++ b/src/tir/transforms/inject_virtual_thread.cc @@ -208,7 +208,7 @@ class VTInjector : public arith::IRMutatorWithAnalyzer { if (touched_var_.count(op)) { visit_touched_var_ = true; } - return GetRef(op); + return ffi::GetRef(op); } PrimExpr RewriteIndex(PrimExpr index, PrimExpr alloc_extent) const { return analyzer_->Simplify(index + var_ * alloc_extent); @@ -229,7 +229,7 @@ class VTInjector : public arith::IRMutatorWithAnalyzer { return Call(op->dtype, op->op, {op->args[0], op->args[1], offset, extent, op->args[4]}); } else if (op->op.same_as(builtin::tvm_context_id())) { - return allow_share_ ? GetRef(op) : var_; + return allow_share_ ? ffi::GetRef(op) : var_; } else { return StmtExprMutator::VisitExpr_(op); } @@ -287,14 +287,14 @@ class VTInjector : public arith::IRMutatorWithAnalyzer { Stmt VisitStmt_(const AttrStmtNode* op) final { PrimExpr value = this->VisitExpr(op->value); if (visit_touched_var_ && !vt_loop_injected_) { - return InjectVTLoop(GetRef(op), true); + return InjectVTLoop(ffi::GetRef(op), true); } else if (!allow_share_ && !vt_loop_injected_ && (op->attr_key == attr::coproc_uop_scope || op->attr_key == attr::coproc_scope)) { - return InjectVTLoop(GetRef(op), true); + return InjectVTLoop(ffi::GetRef(op), true); } else { Stmt body = this->VisitStmt(op->body); if (value.same_as(op->value) && body.same_as(op->body)) { - return GetRef(op); + return ffi::GetRef(op); } else { return AttrStmt(op->node, op->attr_key, value, body); } @@ -304,12 +304,12 @@ class VTInjector : public arith::IRMutatorWithAnalyzer { Stmt VisitStmt_(const LetStmtNode* op) final { PrimExpr value = this->VisitExpr(op->value); if (visit_touched_var_ && !vt_loop_injected_) { - return InjectVTLoop(GetRef(op), true); + return InjectVTLoop(ffi::GetRef(op), true); } visit_touched_var_ = false; Stmt body = this->VisitStmt(op->body); if (value.same_as(op->value) && body.same_as(op->body)) { - return GetRef(op); + return ffi::GetRef(op); } else { return LetStmt(op->var, value, body); } @@ -319,7 +319,7 @@ class VTInjector : public arith::IRMutatorWithAnalyzer { ICHECK(is_zero(op->min)); PrimExpr extent = this->VisitExpr(op->extent); if (visit_touched_var_ && !vt_loop_injected_) { - Stmt stmt = InjectVTLoop(GetRef(op), true); + Stmt stmt = InjectVTLoop(ffi::GetRef(op), true); ++max_loop_depth_; return stmt; } @@ -327,7 +327,7 @@ class VTInjector : public arith::IRMutatorWithAnalyzer { Stmt body = this->VisitStmt(op->body); ++max_loop_depth_; if (extent.same_as(op->extent) && body.same_as(op->body)) { - return GetRef(op); + return ffi::GetRef(op); } else { auto n = CopyOnWrite(op); n->extent = std::move(extent); @@ -339,12 +339,12 @@ class VTInjector : public arith::IRMutatorWithAnalyzer { Stmt VisitStmt_(const IfThenElseNode* op) final { PrimExpr condition = 
this->VisitExpr(op->condition); if (visit_touched_var_ && !vt_loop_injected_) { - return InjectVTLoop(GetRef(op), true); + return InjectVTLoop(ffi::GetRef(op), true); } visit_touched_var_ = false; ICHECK_EQ(max_loop_depth_, 0); Stmt then_case = this->VisitStmt(op->then_case); - Optional else_case = std::nullopt; + ffi::Optional else_case = std::nullopt; if (op->else_case) { int temp = max_loop_depth_; max_loop_depth_ = 0; @@ -353,7 +353,7 @@ class VTInjector : public arith::IRMutatorWithAnalyzer { } if (condition.same_as(op->condition) && then_case.same_as(op->then_case) && else_case.same_as(op->else_case)) { - return GetRef(op); + return ffi::GetRef(op); } else { return IfThenElse(condition, then_case, else_case); } @@ -379,15 +379,15 @@ class VTInjector : public arith::IRMutatorWithAnalyzer { } // Allocate Stmt VisitStmt_(const AllocateNode* op) final { - Allocate node = GetRef(op); + Allocate node = ffi::GetRef(op); PrimExpr condition = this->VisitExpr(op->condition); - Array extents = + ffi::Array extents = op->extents.Map([this](const PrimExpr& extent) { return this->VisitExpr(extent); }); if (visit_touched_var_ && !vt_loop_injected_) { - return InjectVTLoop(GetRef(op), true); + return InjectVTLoop(ffi::GetRef(op), true); } visit_touched_var_ = false; @@ -417,7 +417,7 @@ class VTInjector : public arith::IRMutatorWithAnalyzer { if (extents.same_as(op->extents) && body.same_as(op->body) && condition.same_as(op->condition)) { - return GetRef(op); + return ffi::GetRef(op); } else { return Allocate(op->buffer_var, op->dtype, extents, condition, body); } @@ -439,7 +439,7 @@ class VTInjector : public arith::IRMutatorWithAnalyzer { // only unroll if number of vthreads are small if (max_loop_depth_ == 0 && num_threads_ < 16) { // do unrolling if it is inside innermost content. - Array seq; + ffi::Array seq; for (int i = 0; i < num_threads_; ++i) { seq.push_back(Substitute(stmt, {{var_, make_const(var_.dtype(), i)}})); } diff --git a/src/tir/transforms/inline_private_functions.cc b/src/tir/transforms/inline_private_functions.cc index 8521607f893e..03d814333ca4 100644 --- a/src/tir/transforms/inline_private_functions.cc +++ b/src/tir/transforms/inline_private_functions.cc @@ -103,7 +103,7 @@ bool IsInlinablePrimFunc(const GlobalVar& gvar, const PrimFunc& prim_func, // Only inline private functions. Externally-exposed functions // must be preserved so to avoid breaking callsites outside of // the IRModule. 
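// Here "exposed" means the PrimFunc carries a kGlobalSymbol attribute
// ("global_symbol"), which is what keeps its symbol visible in the
// compiled artifact.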
- bool is_exposed = prim_func->GetAttr(tvm::attr::kGlobalSymbol).has_value(); + bool is_exposed = prim_func->GetAttr(tvm::attr::kGlobalSymbol).has_value(); if (is_exposed) return false; // We do not currently implement any analysis for termination of @@ -128,10 +128,10 @@ bool IsInlinablePrimFunc(const GlobalVar& gvar, const PrimFunc& prim_func, return true; } -Map CollectInlinablePrimFuncs(const IRModule& mod) { +ffi::Map CollectInlinablePrimFuncs(const IRModule& mod) { auto recursive_functions = CollectRecursiveFunctions(mod); - Map output; + ffi::Map output; for (const auto& [gvar, base_func] : mod->functions) { if (auto opt = base_func.as()) { auto prim_func = opt.value(); @@ -146,7 +146,7 @@ Map CollectInlinablePrimFuncs(const IRModule& mod) { class PrimFuncInliner : StmtExprMutator { public: - explicit PrimFuncInliner(Map inlinable_funcs) + explicit PrimFuncInliner(ffi::Map inlinable_funcs) : inlinable_funcs_(inlinable_funcs) { for (const auto& [gvar, callee] : inlinable_funcs_) { removable_funcs_.insert(gvar); @@ -176,7 +176,7 @@ class PrimFuncInliner : StmtExprMutator { } } - Optional GetInlinedFunction(const EvaluateNode* eval) { + ffi::Optional GetInlinedFunction(const EvaluateNode* eval) { auto call = eval->value.as(); if (!call) return std::nullopt; @@ -222,7 +222,8 @@ class PrimFuncInliner : StmtExprMutator { return StmtExprMutator::VisitExpr_(call); } - Stmt InlineArguments(const GlobalVar& gvar, PrimFunc callee, const Array& args) const { + Stmt InlineArguments(const GlobalVar& gvar, PrimFunc callee, + const ffi::Array& args) const { CHECK_EQ(callee->params.size(), args.size()) << "Callee " << gvar << " accepts " << callee->params.size() << " parameters (" << callee->params << "), but is called with " << args.size() << " arguments (" << args @@ -232,7 +233,7 @@ class PrimFuncInliner : StmtExprMutator { << "Inlining of PrimFuncs with buffer arguments is not yet supported, " << "but callee " << gvar << " has non-empty buffer map " << callee->buffer_map; - Map> param_map; + ffi::Map> param_map; for (size_t i = 0; i < callee->params.size(); i++) { param_map.Set(callee->params[i], args[i]); } @@ -243,7 +244,7 @@ class PrimFuncInliner : StmtExprMutator { } // Map from GlobalVar to PrimFuncs which may be inlined. 
- Map inlinable_funcs_; + ffi::Map inlinable_funcs_; /* \brief Set of callees that may be removed * @@ -253,7 +254,7 @@ class PrimFuncInliner : StmtExprMutator { */ PSet removable_funcs_; - Optional current_target_ = std::nullopt; + ffi::Optional current_target_ = std::nullopt; }; } // namespace diff --git a/src/tir/transforms/ir_utils.cc b/src/tir/transforms/ir_utils.cc index 3f94fb0cfc6e..cdebfcfcfa7a 100644 --- a/src/tir/transforms/ir_utils.cc +++ b/src/tir/transforms/ir_utils.cc @@ -41,48 +41,48 @@ Stmt MergeNest(const std::vector& nest, Stmt body) { for (auto ri = nest.rbegin(); ri != nest.rend(); ++ri) { Stmt s = *ri; if (const auto* for_ = s.as()) { - auto n = make_object(*for_); + auto n = ffi::make_object(*for_); ICHECK(is_no_op(n->body)); n->body = body; body = Stmt(n); } else if (const auto* let = s.as()) { - auto n = make_object(*let); + auto n = ffi::make_object(*let); ICHECK(is_no_op(n->body)); n->body = body; body = Stmt(n); } else if (const auto* attr = s.as()) { - auto n = make_object(*attr); + auto n = ffi::make_object(*attr); ICHECK(is_no_op(n->body)); n->body = body; body = Stmt(n); } else if (const auto* ite = s.as()) { - auto n = make_object(*ite); + auto n = ffi::make_object(*ite); ICHECK(is_no_op(n->then_case)); ICHECK(!n->else_case); n->then_case = body; body = Stmt(n); } else if (const auto* seq = s.as()) { - auto n = make_object(*seq); + auto n = ffi::make_object(*seq); ICHECK(n->size() != 0 && is_no_op(n->seq[n->size() - 1])); n->seq.Set(n->size() - 1, body); body = Stmt(n); } else if (const auto* assert_ = s.as()) { - auto n = make_object(*assert_); + auto n = ffi::make_object(*assert_); ICHECK(is_no_op(n->body)); n->body = body; body = Stmt(n); } else if (const auto* alloc = s.as()) { - auto n = make_object(*alloc); + auto n = ffi::make_object(*alloc); ICHECK(is_no_op(n->body)); n->body = body; body = Stmt(n); } else if (const auto* alloc = s.as()) { - auto n = make_object(*alloc); + auto n = ffi::make_object(*alloc); ICHECK(is_no_op(n->body)); n->body = body; body = Stmt(n); } else if (const auto* decl_buffer = s.as()) { - auto n = make_object(*decl_buffer); + auto n = ffi::make_object(*decl_buffer); ICHECK(is_no_op(n->body)); n->body = body; body = Stmt(n); @@ -130,7 +130,7 @@ class IRConvertSSA final : public StmtExprMutator { if (defined_params.count(var_ptr)) return; if (defined_.count(var_ptr)) { - auto var = GetRef(var_ptr); + auto var = ffi::GetRef(var_ptr); redefines.emplace_back(this, var); } else { defined_.insert(var_ptr); @@ -148,7 +148,7 @@ class IRConvertSSA final : public StmtExprMutator { // Update the buffer map, based on the redefined parameters auto buffer_map = [&]() { - Map buffer_map; + ffi::Map buffer_map; bool made_change = false; for (const auto& [var, buffer] : func->buffer_map) { auto new_var = GetRemappedVar(var); @@ -174,15 +174,15 @@ class IRConvertSSA final : public StmtExprMutator { return DictAttrs(); } - Map dict; + ffi::Map dict; bool made_change = false; for (const auto& [key, old_value] : func->attrs->dict) { auto value = old_value; if (auto* expr = value.as()) { - value = VisitExpr(GetRef(expr)); + value = VisitExpr(ffi::GetRef(expr)); } else if (auto* stmt = value.as()) { - value = VisitStmt(GetRef(stmt)); + value = VisitStmt(ffi::GetRef(stmt)); } made_change = made_change || !value.same_as(old_value); @@ -212,7 +212,7 @@ class IRConvertSSA final : public StmtExprMutator { return func; } - PrimExpr VisitExpr_(const VarNode* op) final { return GetRemappedVar(GetRef(op)); } + PrimExpr VisitExpr_(const VarNode* op) final { 
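// Every variable use is routed through the remap table, so references to a
// re-defined variable are rewritten to its fresh SSA replacement.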
return GetRemappedVar(ffi::GetRef(op)); } PrimExpr VisitExpr_(const LetNode* op) final { const Var& v = op->var; if (defined_.count(v.get())) { @@ -248,13 +248,13 @@ class IRConvertSSA final : public StmtExprMutator { } Stmt VisitStmt_(const BlockNode* op) final { - Block block = GetRef(op); + Block block = ffi::GetRef(op); // The BlockNode is the point of definition for the IterVar // instances. These re-defines must be present before visiting // the body of the BlockNode. std::vector redefines; - Array iter_vars = op->iter_vars.Map([&](IterVar iter_var) { + ffi::Array iter_vars = op->iter_vars.Map([&](IterVar iter_var) { if (defined_.count(iter_var->var.get())) { redefines.emplace_back(this, iter_var->var); iter_var.CopyOnWrite()->var = redefines.back().new_var; @@ -263,9 +263,9 @@ class IRConvertSSA final : public StmtExprMutator { } return iter_var; }); - Array reads = + ffi::Array reads = block->reads.Map([&](const auto& region) { return VisitBufferAccess(region); }); - Array writes = + ffi::Array writes = block->writes.Map([&](const auto& region) { return VisitBufferAccess(region); }); if (!reads.same_as(block->reads) || !writes.same_as(block->writes) || @@ -312,8 +312,8 @@ class IRConvertSSA final : public StmtExprMutator { Var new_buffer_var = GetRemappedVar(buf->data); PrimExpr elem_offset = VisitExpr(buf->elem_offset); auto visit_expr = [this](const PrimExpr& expr) { return VisitExpr(expr); }; - Array shape = buf->shape.Map(visit_expr); - Array strides = buf->strides.Map(visit_expr); + ffi::Array shape = buf->shape.Map(visit_expr); + ffi::Array strides = buf->strides.Map(visit_expr); // If no mapping is required, return the original buffer. if (new_buffer_var.same_as(buf->data) && elem_offset.same_as(buf->elem_offset) && @@ -432,7 +432,7 @@ class IRConvertSSA final : public StmtExprMutator { IterVar new_iter_var; if (dom.same_as(iter_var->dom) && var.same_as(iter_var->var)) { - new_iter_var = GetRef(iter_var); + new_iter_var = ffi::GetRef(iter_var); } else { new_iter_var = IterVar(dom, var, iter_var->iter_type, iter_var->thread_tag, iter_var->span); } @@ -442,7 +442,7 @@ class IRConvertSSA final : public StmtExprMutator { Stmt output; if (new_iter_var.get() == iter_var && body.same_as(op->body) && value.same_as(op->value)) { - output = GetRef(op); + output = ffi::GetRef(op); } else { output = AttrStmt(new_iter_var, op->attr_key, value, body, iter_var->span); } @@ -530,14 +530,14 @@ class IRConvertSSA final : public StmtExprMutator { Stmt ConvertSSA(Stmt stmt) { return IRConvertSSA()(std::move(stmt)); } -String GetPtrStorageScope(Var buffer_var) { +ffi::String GetPtrStorageScope(Var buffer_var) { const auto* ptr_type = buffer_var->type_annotation.as(); ICHECK(ptr_type) << "The provided variable is not of pointer type"; return ptr_type->storage_scope; } -Array GetBufferAllocationShape(const Buffer& buffer) { - Array alloc_shape = buffer->shape; +ffi::Array GetBufferAllocationShape(const Buffer& buffer) { + ffi::Array alloc_shape = buffer->shape; if (buffer->strides.size()) { ICHECK_EQ(buffer->shape.size(), buffer->strides.size()); for (size_t i = buffer->strides.size() - 1; i > 0; --i) { @@ -549,14 +549,14 @@ Array GetBufferAllocationShape(const Buffer& buffer) { return alloc_shape; } -Array ConvertIndices(const MatchBufferRegion& match_buffer, - const Array& indices) { +ffi::Array ConvertIndices(const MatchBufferRegion& match_buffer, + const ffi::Array& indices) { const Buffer& target = match_buffer->buffer; const BufferRegion& source = match_buffer->source; 
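// Worked example (indices are illustrative): if the match buffer binds a
// 4x4 target view to the source region A[2:6, 3:7], then target indices
// (i, j) convert to source indices (2 + i, 3 + j).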
ICHECK_EQ(indices.size(), target->shape.size()); arith::Analyzer analyzer; - Array result; + ffi::Array result; result.reserve(source->region.size()); size_t offset = source->region.size() - indices.size(); for (size_t i = 0; i < offset; ++i) { @@ -595,7 +595,7 @@ Region ConvertRegion(const MatchBufferRegion& match_buffer, const Region& region return result; } -Optional ConditionalBoundsContext::TrySolveCondition() { +ffi::Optional ConditionalBoundsContext::TrySolveCondition() { // extract equations and related vars from condition expression. // currently only extract simple integral equations which could be solvable. arith::Analyzer analyzer; @@ -603,8 +603,8 @@ Optional ConditionalBoundsContext::TrySolveCondition() { if (is_const_int(condition)) { return std::nullopt; } - Array equations; - Array vars; + ffi::Array equations; + ffi::Array vars; std::function fvisit = [&equations, &vars, &fvisit](const PrimExpr& e) { if (e->IsInstance() || e->IsInstance() || e->IsInstance() || e->IsInstance() || e->IsInstance() || e->IsInstance()) { @@ -615,7 +615,7 @@ Optional ConditionalBoundsContext::TrySolveCondition() { return; } else if (const VarNode* var = obj.as()) { if (var->dtype.is_int() || var->dtype.is_uint()) { - cand_vars.push_back(GetRef(var)); + cand_vars.push_back(ffi::GetRef(var)); } } else { is_simple &= obj->IsInstance() || obj->IsInstance() || @@ -648,7 +648,7 @@ Optional ConditionalBoundsContext::TrySolveCondition() { return std::nullopt; } // build dom ranges for related vars - Map ranges; + ffi::Map ranges; for (const Var& v : vars) { arith::IntSet dom; auto relax_it = relax_map_->find(v.get()); @@ -684,7 +684,7 @@ ConditionalBoundsContext::ConditionalBoundsContext( origin_pending_conditions_num_(pending_conditions->size()) {} void ConditionalBoundsContext::EnterWithScope() { - Optional constraints = TrySolveCondition(); + ffi::Optional constraints = TrySolveCondition(); if (!constraints.defined()) { // fail to process the condition, add to unresolved pending_conditions_->push_back(condition_); @@ -831,11 +831,11 @@ namespace transform { Pass ConvertSSA() { auto pass_func = [](IRModule mod, PassContext ctx) { tir::IRConvertSSA converter; - Map functions; + ffi::Map functions; bool made_change = false; for (auto [gvar, base_func] : mod->functions) { if (auto* ptr = base_func.as()) { - auto updated = converter.VisitPrimFunc(GetRef(ptr)); + auto updated = converter.VisitPrimFunc(ffi::GetRef(ptr)); if (!updated.same_as(base_func)) { made_change = true; base_func = updated; diff --git a/src/tir/transforms/ir_utils.h b/src/tir/transforms/ir_utils.h index b77213bdf10a..fdf4def699ec 100644 --- a/src/tir/transforms/ir_utils.h +++ b/src/tir/transforms/ir_utils.h @@ -69,7 +69,7 @@ Stmt MergeNest(const std::vector>& nest, Stmt body); * original array */ template -inline Array UpdateArray(Array arr, F fupdate) { +inline ffi::Array UpdateArray(ffi::Array arr, F fupdate) { std::vector new_arr(arr.size()); bool changed = false; for (size_t i = 0; i < arr.size(); ++i) { @@ -81,7 +81,7 @@ inline Array UpdateArray(Array arr, F fupdate) { if (!changed) { return arr; } else { - return Array(new_arr); + return ffi::Array(new_arr); } } @@ -95,8 +95,8 @@ inline Array UpdateArray(Array arr, F fupdate) { */ inline PrimExpr TVMStructGet(DataType dtype, Var handle, int index, builtin::TVMStructFieldKind kind) { - Array args = {handle, make_const(DataType::Int(32), index), - make_const(DataType::Int(32), static_cast(kind))}; + ffi::Array args = {handle, make_const(DataType::Int(32), index), + 
make_const(DataType::Int(32), static_cast(kind))}; return Call(dtype, builtin::tvm_struct_get(), args); } @@ -142,8 +142,8 @@ inline PrimExpr AddressOffset(Var handle, DataType dtype, PrimExpr offset) { * \return the set stmt. */ inline Stmt TVMStructSet(Var handle, int index, builtin::TVMStructFieldKind kind, PrimExpr value) { - Array args = {handle, make_const(DataType::Int(32), index), - make_const(DataType::Int(32), static_cast(kind)), value}; + ffi::Array args = {handle, make_const(DataType::Int(32), index), + make_const(DataType::Int(32), static_cast(kind)), value}; return Evaluate(Call(DataType::Int(32), builtin::tvm_struct_set(), args)); } @@ -195,7 +195,7 @@ inline PrimExpr ConstInt32(size_t index) { * \return PrimExpr representing the TVMValue */ inline PrimExpr StackAlloca(std::string type, size_t num) { - Array args = {StringImm(type), ConstInt32(num)}; + ffi::Array args = {StringImm(type), ConstInt32(num)}; return Call(DataType::Handle(), builtin::tvm_stack_alloca(), args); } @@ -211,15 +211,15 @@ Stmt ConvertSSA(Stmt stmt); * \param buffer_var The input buffer variable. * \return A string representing the storage scope of this buffer variable. */ -String GetPtrStorageScope(Var buffer_var); +ffi::String GetPtrStorageScope(Var buffer_var); /*! * \brief Convert match buffer target buffer access indices to original one. * \param indices The indices of the target buffer * \return The indices of source buffer. */ -Array ConvertIndices(const MatchBufferRegion& match_buffer, - const Array& indices); +ffi::Array ConvertIndices(const MatchBufferRegion& match_buffer, + const ffi::Array& indices); /*! * \brief Convert match buffer target buffer region to original one. @@ -233,7 +233,7 @@ Region ConvertRegion(const MatchBufferRegion& match_buffer, const Region& region * \param buffer The buffer object. * \return shape The shape considering buffer strides. */ -Array GetBufferAllocationShape(const Buffer& buffer); +ffi::Array GetBufferAllocationShape(const Buffer& buffer); /*! * \brief Context helper to update domain map within conditional scope. @@ -261,7 +261,7 @@ class ConditionalBoundsContext { void ExitWithScope(); /*! \brief Helper to solve related variable's bound within conditional scope.*/ - Optional TrySolveCondition(); + ffi::Optional TrySolveCondition(); /*! \brief the condition holds on true branch. */ const PrimExpr& condition_; @@ -322,12 +322,12 @@ std::pair GetAsyncWaitAttributes(const AttrStmtNode* op); * function body. * \return The updated function. */ -PrimFunc BindParams(PrimFunc f, const Array& constants); +PrimFunc BindParams(PrimFunc f, const ffi::Array& constants); /*! \brief The quad used by StorageAlign for (buffer_idx, axis, factor, offset) */ using StorageAlignTuple = ffi::Tuple; /*! \brief A list of StorageAlignTuple, used by StorageAlign */ -using StorageAlignAnnotation = Array; +using StorageAlignAnnotation = ffi::Array; /*! * \brief Collect storage alignment annotations for all buffer vars within body. * \param body The stmt to collect. 
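// Usage sketch for the UpdateArray helper above (illustrative, not part of
// the patch). The helper allocates a new array only when fupdate changed at
// least one element; otherwise the input is returned unchanged, so callers
// can detect rewrites by reference equality:
//
//   // `analyzer` is an arith::Analyzer* and GetSomeExtents() a hypothetical
//   // source of expressions, both assumed in scope for the sketch.
//   ffi::Array<PrimExpr> extents = GetSomeExtents();
//   ffi::Array<PrimExpr> simplified = UpdateArray(
//       extents, [&](const PrimExpr& e) { return analyzer->Simplify(e); });
//   if (simplified.same_as(extents)) {
//     // nothing changed; the enclosing node can be reused as-is
//   }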
diff --git a/src/tir/transforms/lift_thread_binding.cc b/src/tir/transforms/lift_thread_binding.cc index 8995beb2ce9e..0f643e5e18cb 100644 --- a/src/tir/transforms/lift_thread_binding.cc +++ b/src/tir/transforms/lift_thread_binding.cc @@ -32,14 +32,14 @@ namespace tvm { namespace tir { -std::pair>>, +std::pair>>, ObjectPtrHash, ObjectPtrEqual>, - Map> + ffi::Map> FindLoopLCA(const Stmt& root) { class LCAFinder : public StmtVisitor { public: void VisitStmt_(const ForNode* op) final { - stack.push_back(GetRef(op)); + stack.push_back(ffi::GetRef(op)); StmtVisitor::VisitStmt_(op); if (op->kind == ForKind::kThreadBinding) { UpdateLCA(op); @@ -50,7 +50,7 @@ FindLoopLCA(const Stmt& root) { void UpdateLCA(const ForNode* loop) { std::string thread_tag = loop->thread_binding.value()->thread_tag; { - Map* tgt = &annotations[thread_tag]; + ffi::Map* tgt = &annotations[thread_tag]; for (const auto& kv : loop->annotations) { tgt->Set(kv.first, kv.second); } @@ -78,14 +78,14 @@ FindLoopLCA(const Stmt& root) { std::unordered_map> lca; std::unordered_map iters; - std::unordered_map> annotations; - Map var_subst; + std::unordered_map> annotations; + ffi::Map var_subst; std::vector stack; }; LCAFinder finder; finder(root); - std::unordered_map>>, ObjectPtrHash, - ObjectPtrEqual> + std::unordered_map>>, + ObjectPtrHash, ObjectPtrEqual> result; std::vector sorted_thread_tags; for (const auto& kv : finder.lca) { @@ -104,7 +104,7 @@ FindLoopLCA(const Stmt& root) { for (const auto& thread_tag : sorted_thread_tags) { Stmt lca = finder.lca[thread_tag].back(); const IterVar& iter = finder.iters[thread_tag]; - const Map& annotations = finder.annotations[thread_tag]; + const ffi::Map& annotations = finder.annotations[thread_tag]; result[lca].emplace_back(iter, annotations); } return {result, finder.var_subst}; @@ -117,7 +117,7 @@ FindLoopLCA(const Stmt& root) { class ThreadBindingLifter : public StmtExprMutator { public: Stmt VisitStmt_(const ForNode* _op) final { - For op = GetRef(_op); + For op = ffi::GetRef(_op); bool is_kernel_root = false; if (op->kind == ForKind::kThreadBinding) { if (iter_lca.empty()) { @@ -149,24 +149,24 @@ class ThreadBindingLifter : public StmtExprMutator { } void SetKernelRoot(const ForNode* op) { - auto result = FindLoopLCA(GetRef(op)); + auto result = FindLoopLCA(ffi::GetRef(op)); this->iter_lca = std::move(result.first); this->var_subst = std::move(result.second); } PrimExpr VisitExpr_(const VarNode* op) final { - auto it = var_subst.find(GetRef(op)); + auto it = var_subst.find(ffi::GetRef(op)); if (it != var_subst.end()) { return (*it).second; } else { - return GetRef(op); + return ffi::GetRef(op); } } - std::unordered_map>>, ObjectPtrHash, - ObjectPtrEqual> + std::unordered_map>>, + ObjectPtrHash, ObjectPtrEqual> iter_lca; - Map var_subst; + ffi::Map var_subst; }; PrimFunc LiftThreadBinding(PrimFunc f) { diff --git a/src/tir/transforms/loop_partition.cc b/src/tir/transforms/loop_partition.cc index f083a9d6d4df..1a78536dbaf4 100644 --- a/src/tir/transforms/loop_partition.cc +++ b/src/tir/transforms/loop_partition.cc @@ -113,7 +113,7 @@ class CandidateSelector final : public StmtExprVisitor { // always treat var with hint to be partitioned const VarNode* var = op->loop_var.get(); if (partition_hint_vars.count(var)) { - candidates.insert(GetRef(op)); + candidates.insert(ffi::GetRef(op)); StmtExprVisitor::VisitStmt_(op); return; } @@ -122,7 +122,7 @@ class CandidateSelector final : public StmtExprVisitor { record_.insert({var, false}); StmtExprVisitor::VisitStmt_(op); if (record_.at(var) 
&& !no_split_) { - candidates.insert(GetRef(op)); + candidates.insert(ffi::GetRef(op)); } record_.erase(var); } else { @@ -137,7 +137,7 @@ class CandidateSelector final : public StmtExprVisitor { Var var = iv->var; // always treat var with hint to be partitioned if (partition_hint_vars.count(var.get())) { - candidates.insert(GetRef(op)); + candidates.insert(ffi::GetRef(op)); StmtExprVisitor::VisitStmt_(op); return; } @@ -146,7 +146,7 @@ class CandidateSelector final : public StmtExprVisitor { record_.insert({var.get(), false}); StmtExprVisitor::VisitStmt_(op); if (record_.at(var.get()) && !no_split_) { - candidates.insert(GetRef(op)); + candidates.insert(ffi::GetRef(op)); } record_.erase(var.get()); return; @@ -213,7 +213,7 @@ class CandidateSelector final : public StmtExprVisitor { #define DEFINE_PARTITION_FINDER_VISIT_CMP_OP(OpNodeT) \ void VisitExpr_(const OpNodeT* op) final { \ if (has_partition_hint_) { \ - DeduceCondition(GetRef(op)); \ + DeduceCondition(ffi::GetRef(op)); \ return; \ } \ StmtExprVisitor::VisitExpr_(op); \ @@ -421,7 +421,7 @@ class LoopPartitioner : public StmtMutator { Stmt VisitStmt_(const ForNode* op) final { analyzer_.Bind(op->loop_var, Range::FromMinExtent(op->min, op->extent), true); - auto fs = GetRef(op); + auto fs = ffi::GetRef(op); if (selector.candidates.count(fs)) { Stmt s = TryPartition(fs, op->loop_var, op->min, op->min + op->extent - 1, op->body, false); if (s.defined()) return s; @@ -443,7 +443,7 @@ class LoopPartitioner : public StmtMutator { const IterVarNode* iv = op->node.as(); ICHECK(iv); Var var = iv->var; - auto as = GetRef(op); + auto as = ffi::GetRef(op); if (selector.candidates.count(as)) { Stmt s = TryPartition(as, var, 0, op->value - 1, op->body, true); if (s.defined()) return s; @@ -489,7 +489,7 @@ class LoopPartitioner : public StmtMutator { std::pair LoopPartitioner::GetIntervalAndCondset( const Partition& partitions, const arith::IntervalSet& for_interval, bool cond_value, bool has_partition_hint) { - Array sets; + ffi::Array sets; ExpressionSet cond_set; for (const auto& kv : partitions) { diff --git a/src/tir/transforms/lower_async_dma.cc b/src/tir/transforms/lower_async_dma.cc index 71c6c945e8f3..e5510664bea8 100644 --- a/src/tir/transforms/lower_async_dma.cc +++ b/src/tir/transforms/lower_async_dma.cc @@ -52,7 +52,8 @@ class AsyncDMALowerer : public arith::IRMutatorWithAnalyzer { } // if for loop is not a memcpy of a contiguous region, it might be a cuda cp.async behavior - std::optional mem_copy = IdentifyMemCpy(GetRef(loop), analyzer_); + std::optional mem_copy = + IdentifyMemCpy(ffi::GetRef(loop), analyzer_); if (!mem_copy.has_value() || mem_copy->dest->region.size() != 1 || mem_copy->source->region.size() != 1) { return arith::IRMutatorWithAnalyzer::VisitStmt_(loop); @@ -159,7 +160,7 @@ class AsyncDMALowerer : public arith::IRMutatorWithAnalyzer { std::set queue_ids_; std::optional async_queue_id_ = std::nullopt; bool dma_bypass_cache_; - Map input_iters = Map(); + ffi::Map input_iters = ffi::Map(); }; namespace transform { diff --git a/src/tir/transforms/lower_cross_thread_reduction.cc b/src/tir/transforms/lower_cross_thread_reduction.cc index 00cc2f226a60..ae81a9e6c5bc 100644 --- a/src/tir/transforms/lower_cross_thread_reduction.cc +++ b/src/tir/transforms/lower_cross_thread_reduction.cc @@ -105,7 +105,7 @@ bool IsDominantBlock(const Block& scope_block, const Block& block) { * based on `tir.Schedule`. Here we have no schedule information, and thus we must implement the * check again. 
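 *
 * Illustrative example: a block computing B[i] = B[i] + A[i, k], with an
 * init statement B[i] = 0 and k declared as a kCommReduce block var that
 * does not index the output buffer B, satisfies all conditions below.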
*/ -bool IsReductionBlock(const BlockRealize& realize, const Map& loop_range_map, +bool IsReductionBlock(const BlockRealize& realize, const ffi::Map& loop_range_map, const Block& scope_block, arith::Analyzer* analyzer) { const auto* block = realize->block.as(); // Cond 1. The block has the `init` statement. @@ -123,11 +123,11 @@ bool IsReductionBlock(const BlockRealize& realize, const Map& loop_r } // Cond 4. Dominant: the block is the only writer of its output, dominating the reader of its // output buffers. - if (!IsDominantBlock(scope_block, GetRef(block))) { + if (!IsDominantBlock(scope_block, ffi::GetRef(block))) { return false; } // Cond 5. The reduction block vars are not used to index the output buffers. - return ReductionIterNotIndexOutputBuffer(GetRef(block)); + return ReductionIterNotIndexOutputBuffer(ffi::GetRef(block)); } /*! @@ -137,11 +137,12 @@ bool IsReductionBlock(const BlockRealize& realize, const Map& loop_r * computation results or not, which is used for determine the buffer name prefix * \return The created buffers */ -Array MakeScratchpads(const Array& reduction_buffers, bool is_cross_thread_buffer) { - Array new_buffers; +ffi::Array MakeScratchpads(const ffi::Array& reduction_buffers, + bool is_cross_thread_buffer) { + ffi::Array new_buffers; new_buffers.reserve(reduction_buffers.size()); for (const Buffer& buffer : reduction_buffers) { - String name = is_cross_thread_buffer ? "cross" : "in"; + ffi::String name = is_cross_thread_buffer ? "cross" : "in"; name = name + "_thread_" + buffer->name; new_buffers.push_back(Buffer(/*ptr=*/Var(name, PointerType(PrimType(buffer->dtype), "local")), /*dtype=*/buffer->dtype, @@ -162,8 +163,8 @@ Array MakeScratchpads(const Array& reduction_buffers, bool is_cr */ class BufferReplacer : private StmtExprMutator { public: - static Stmt Run(Array src_buffers, Array tgt_buffers, Stmt stmt) { - Map buffer_map; + static Stmt Run(ffi::Array src_buffers, ffi::Array tgt_buffers, Stmt stmt) { + ffi::Map buffer_map; ICHECK_EQ(src_buffers.size(), tgt_buffers.size()); int n_buffers = src_buffers.size(); for (int i = 0; i < n_buffers; ++i) { @@ -173,11 +174,12 @@ class BufferReplacer : private StmtExprMutator { } private: - explicit BufferReplacer(Map buffer_map) : buffer_map_(std::move(buffer_map)) {} + explicit BufferReplacer(ffi::Map buffer_map) + : buffer_map_(std::move(buffer_map)) {} PrimExpr VisitExpr_(const BufferLoadNode* load) final { auto it = buffer_map_.find(load->buffer); - return it != buffer_map_.end() ? BufferLoad((*it).second, {0}) : GetRef(load); + return it != buffer_map_.end() ? BufferLoad((*it).second, {0}) : ffi::GetRef(load); } Stmt VisitStmt_(const BufferStoreNode* store) final { @@ -190,7 +192,7 @@ class BufferReplacer : private StmtExprMutator { } } - Map buffer_map_; + ffi::Map buffer_map_; }; /*! @@ -217,7 +219,7 @@ class InThreadReducerMaker : private StmtMutator { private: void VisitStmt_(const BlockNode* block) final { - Array iter_vars = block->iter_vars; + ffi::Array iter_vars = block->iter_vars; for (const IterVar& iter_var : block->iter_vars) { if (iter_var->iter_type == kCommReduce) { reduction_block_vars_.push_back(iter_var); @@ -227,17 +229,17 @@ class InThreadReducerMaker : private StmtMutator { } /*! 
\brief the map from thread tag to its extent */ - Array reduction_block_vars_; + ffi::Array reduction_block_vars_; }; - static Optional Make(const BlockRealizeNode* src_realize, - Optional tgt_realize, Stmt stmt) { + static ffi::Optional Make(const BlockRealizeNode* src_realize, + ffi::Optional tgt_realize, Stmt stmt) { return InThreadReducerMaker(src_realize, std::move(tgt_realize))(std::move(stmt)); } private: explicit InThreadReducerMaker(const BlockRealizeNode* src_realize, - Optional tgt_realize) + ffi::Optional tgt_realize) : src_realize_(src_realize), tgt_realize_(tgt_realize) {} Stmt VisitStmt_(const BlockRealizeNode* realize) final { if (realize == src_realize_) { @@ -245,11 +247,11 @@ class InThreadReducerMaker : private StmtMutator { ? tgt_realize_.value() : Stmt{nullptr}; } - return GetRef(realize); + return ffi::GetRef(realize); } Stmt VisitStmt_(const ForNode* loop) final { - if (Optional opt_res = Downcast>(StmtMutator::VisitStmt_(loop))) { + if (ffi::Optional opt_res = Downcast>(StmtMutator::VisitStmt_(loop))) { For res = opt_res.value(); if (res->thread_binding.defined()) { UnderLoopReductionBlockVarCollector collector; @@ -267,10 +269,10 @@ class InThreadReducerMaker : private StmtMutator { } Stmt VisitStmt_(const SeqStmtNode* seq) final { - Array stmts; + ffi::Array stmts; stmts.reserve(seq->size()); for (const Stmt& stmt : seq->seq) { - if (Optional opt_res = VisitStmt(stmt)) { + if (ffi::Optional opt_res = VisitStmt(stmt)) { stmts.push_back(opt_res.value()); } } @@ -278,7 +280,7 @@ class InThreadReducerMaker : private StmtMutator { } const BlockRealizeNode* src_realize_; - Optional tgt_realize_; + ffi::Optional tgt_realize_; }; /*! @@ -293,19 +295,19 @@ class InThreadReducerMaker : private StmtMutator { * \param combiner_rhs The RHS values of the combiner * \param reduction_loops The reduction loops */ -Stmt TransformReductionBlock(const BlockRealizeNode* realize, // - const Optional>& it_buffers, // - const Array& ct_buffers, // - const Array& wb_buffers, // - const Array& old_wb_indices, // - const CommReducer& reducer, // - const Array& combiner_rhs, // +Stmt TransformReductionBlock(const BlockRealizeNode* realize, // + const ffi::Optional>& it_buffers, // + const ffi::Array& ct_buffers, // + const ffi::Array& wb_buffers, // + const ffi::Array& old_wb_indices, // + const CommReducer& reducer, // + const ffi::Array& combiner_rhs, // const std::vector& reduction_loops) { int n_buffers = wb_buffers.size(); const BlockNode* block = realize->block.get(); - auto f_create_buffer_regions = [](Array buffers) { - Array regions; + auto f_create_buffer_regions = [](ffi::Array buffers) { + ffi::Array regions; regions.reserve(buffers.size()); for (const Buffer& buffer : buffers) { regions.push_back(BufferRegion(buffer, {Range::FromMinExtent(0, 1)})); @@ -313,8 +315,8 @@ Stmt TransformReductionBlock(const BlockRealizeNode* realize, // return regions; }; - Array ct_buffer_regions = f_create_buffer_regions(ct_buffers); - Optional> it_buffer_regions = std::nullopt; + ffi::Array ct_buffer_regions = f_create_buffer_regions(ct_buffers); + ffi::Optional> it_buffer_regions = std::nullopt; if (it_buffers.defined()) { it_buffer_regions = f_create_buffer_regions(it_buffers.value()); } @@ -323,11 +325,11 @@ Stmt TransformReductionBlock(const BlockRealizeNode* realize, // // - Stmt 2: do in-thread reduction // - Stmt 3: do cross-thread reduction // - Stmt 4: write cross-thread reduction result to the original buffer - Array stmts; + ffi::Array stmts; stmts.reserve(4); // Stmt 1: initialize the 
buffer for in-thread reduction if (it_buffers.defined()) { - Array inits; + ffi::Array inits; inits.reserve(n_buffers); for (int i = 0; i < n_buffers; ++i) { inits.push_back( @@ -344,31 +346,32 @@ Stmt TransformReductionBlock(const BlockRealizeNode* realize, // } // Stmt 2: do in-thread reduction { - Optional new_realize = std::nullopt; + ffi::Optional new_realize = std::nullopt; // If need to generate in-thread reduction, // then replace `wb_buffers` with `it_buffers` accordingly in given BlockRealize // otherwise, directly remove given BlockRealize if (it_buffers.defined()) { - ObjectPtr new_block = make_object(*block); + ObjectPtr new_block = ffi::make_object(*block); new_block->reads = std::move(new_block->reads); new_block->writes = it_buffer_regions.value(); new_block->name_hint = new_block->name_hint + "_in_thread"; new_block->body = BufferReplacer::Run(wb_buffers, it_buffers.value(), std::move(new_block->body)); new_block->init = std::nullopt; - ObjectPtr n = make_object(*realize); + ObjectPtr n = ffi::make_object(*realize); n->block = Block(new_block); new_realize = BlockRealize(n); } - For loop = GetRef(reduction_loops[0]); - if (Optional stmt = InThreadReducerMaker::Make(realize, new_realize, std::move(loop))) { + For loop = ffi::GetRef(reduction_loops[0]); + if (ffi::Optional stmt = + InThreadReducerMaker::Make(realize, new_realize, std::move(loop))) { stmts.push_back(stmt.value()); } } // Stmt 3: do cross-thread reduction { // Step 3.1. Create the parameters to the intrinsic - Array parameters; + ffi::Array parameters; parameters.reserve(reduction_loops.size() + 4); // 1-st argument: number of buffers parameters.push_back(make_const(DataType::UInt(32), n_buffers)); @@ -393,12 +396,12 @@ Stmt TransformReductionBlock(const BlockRealizeNode* realize, // } } // Step 3.2. Create the block and the block-realize. 
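// For a single reduction buffer, the intrinsic argument layout assembled in
// Step 3.1 above is roughly the following (a sketch; exact predicate handling
// is elided):
//   { make_const(DataType::UInt(32), 1),   // number of buffers
//     combiner_rhs[0],                     // the value each thread contributes
//     <condition>,                         // predicate guarding the reduction
//     ct_buffers[0]->data,                 // destination written by the allreduce
//     <thread-bound reduction loop vars> }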
- Array iter_vars{nullptr}; - Array bindings{nullptr}; - Array reads{nullptr}; + ffi::Array iter_vars{nullptr}; + ffi::Array bindings{nullptr}; + ffi::Array reads{nullptr}; if (it_buffers.defined()) { - iter_vars = Array{}; - bindings = Array{}; + iter_vars = ffi::Array{}; + bindings = ffi::Array{}; reads = it_buffer_regions.value(); } else { iter_vars = block->iter_vars; @@ -426,9 +429,9 @@ Stmt TransformReductionBlock(const BlockRealizeNode* realize, // { ICHECK_EQ(block->iter_vars.size(), realize->iter_values.size()); int n_iter = static_cast(block->iter_vars.size()); - Array iter_vars; - Array bindings; - Map var_map; + ffi::Array iter_vars; + ffi::Array bindings; + ffi::Map var_map; iter_vars.reserve(n_iter); bindings.reserve(n_iter); for (int i = 0; i < n_iter; ++i) { @@ -437,8 +440,8 @@ Stmt TransformReductionBlock(const BlockRealizeNode* realize, // if (iter_var->iter_type != kCommReduce) { IterVar new_iter_var{nullptr}; { - ObjectPtr n = make_object(*iter_var.get()); - ObjectPtr v = make_object(*iter_var->var.get()); + ObjectPtr n = ffi::make_object(*iter_var.get()); + ObjectPtr v = ffi::make_object(*iter_var->var.get()); n->var = Var(v); new_iter_var = IterVar(n); } @@ -447,13 +450,13 @@ Stmt TransformReductionBlock(const BlockRealizeNode* realize, // var_map.Set(iter_var->var, new_iter_var->var); } } - Array wb_updates; - Array wb_regions; + ffi::Array wb_updates; + ffi::Array wb_regions; wb_updates.reserve(n_buffers); wb_regions.reserve(n_buffers); int n_dim = static_cast(old_wb_indices.size()); - Array region = Substitute(block->writes[0]->region, var_map); - Array wb_indices; + ffi::Array region = Substitute(block->writes[0]->region, var_map); + ffi::Array wb_indices; wb_indices.reserve(n_dim); for (int d = 0; d < n_dim; ++d) { wb_indices.push_back(Substitute(old_wb_indices[d], var_map)); @@ -475,13 +478,13 @@ Stmt TransformReductionBlock(const BlockRealizeNode* realize, // } PostOrderVisit(realize->predicate, [&wb_predicate, &reduction_loop_vars](const ObjectRef& obj) { if (const auto* and_node = obj.as()) { - Array sub_exprs = {and_node->a, and_node->b}; + ffi::Array sub_exprs = {and_node->a, and_node->b}; for (PrimExpr sub_expr : sub_exprs) { if (sub_expr->IsInstance()) { continue; } bool is_reduction = [sub_expr, &reduction_loop_vars]() { - Array vars = UndefinedVars(sub_expr); + ffi::Array vars = UndefinedVars(sub_expr); for (Var var : vars) { if (reduction_loop_vars.find(var.get()) != reduction_loop_vars.end()) { return true; @@ -520,7 +523,7 @@ Stmt TransformReductionBlock(const BlockRealizeNode* realize, // for (auto rit = reduction_loops.rbegin(); rit != reduction_loops.rend(); ++rit) { const ForNode* loop = *rit; if (loop->thread_binding.defined()) { - ObjectPtr n = make_object(*loop); + ObjectPtr n = ffi::make_object(*loop); n->body = std::move(new_stmt); new_stmt = For(n); } @@ -541,14 +544,14 @@ class CrossThreadReductionTransformer : public StmtMutator { } // Step 1. If the block is not a reduction block, cross-thread reduction is not needed. - if (!IsReductionBlock(GetRef(realize), loop_range_map_, - GetRef(block_stack_.back()), &analyzer_)) { + if (!IsReductionBlock(ffi::GetRef(realize), loop_range_map_, + ffi::GetRef(block_stack_.back()), &analyzer_)) { return {}; } // Step 2. Collect all the vars that appear in the bindings of reduction block iters. std::unordered_set reduction_vars; - GetVarsTouchedByBlockIters(GetRef(realize), nullptr, &reduction_vars); + GetVarsTouchedByBlockIters(ffi::GetRef(realize), nullptr, &reduction_vars); // Step 3. 
Collect the loops whose loop vars appear in the bindings of reduction block iters. // We call these loops "reduction-related". @@ -628,7 +631,7 @@ class CrossThreadReductionTransformer : public StmtMutator { * - the RHS values of the reduction updates, * - the indices which is used to access the reduction buffers when storing the reduction results */ - std::tuple, Array, Array> + std::tuple, ffi::Array, ffi::Array> CheckCanApplyCrossThreadReduction(const BlockNode* block, const std::vector& reduction_loops) const { // Condition 1. All the reduction-related loops should be the deepest among all statements @@ -669,19 +672,19 @@ class CrossThreadReductionTransformer : public StmtMutator { // Condition 3. Get the identity values of the block init and the BufferStore block combiner // updates of the reduction. Extract the commutative reducer, combiner lhs and combiner rhs from // the reduction identities and the reduction combiner. - Array init_values{nullptr}; - Array updates{nullptr}; + ffi::Array init_values{nullptr}; + ffi::Array updates{nullptr}; CommReducer reducer{nullptr}; - Array combiner_lhs{nullptr}; - Array combiner_rhs{nullptr}; + ffi::Array combiner_lhs{nullptr}; + ffi::Array combiner_rhs{nullptr}; std::tie(init_values, updates) = - GetInitValuesAndUpdatesFromReductionBlock(std::nullopt, GetRef(block)); + GetInitValuesAndUpdatesFromReductionBlock(std::nullopt, ffi::GetRef(block)); std::tie(reducer, combiner_lhs, combiner_rhs) = GetReducerAndCombinerLhsRhs(std::nullopt, init_values, updates); // Condition 4. All reduction buffers should be all local or all non-local. int is_local_buf = -1; - Array reduction_buffers; + ffi::Array reduction_buffers; reduction_buffers.reserve(updates.size()); for (const BufferStore& buf_store : updates) { reduction_buffers.push_back(buf_store->buffer); @@ -702,7 +705,7 @@ class CrossThreadReductionTransformer : public StmtMutator { // Condition 5. The block should be the last block under the first reduction-related loop. bool visit = false; - PreOrderVisit(GetRef(reduction_loops[0]), [block, &visit](const ObjectRef& obj) { + PreOrderVisit(ffi::GetRef(reduction_loops[0]), [block, &visit](const ObjectRef& obj) { if (const auto* realize = obj.as()) { CHECK(!visit) << "ValueError: Cross-thread reduction cannot be applied when the reduction " "block isn't the last block under its first reduction-related loop"; @@ -772,7 +775,7 @@ class CrossThreadReductionTransformer : public StmtMutator { } Stmt VisitStmt_(const BlockNode* block) final { - Map old_loop_range_map; + ffi::Map old_loop_range_map; block_stack_.push_back(block); std::swap(old_loop_range_map, loop_range_map_); @@ -801,9 +804,9 @@ class CrossThreadReductionTransformer : public StmtMutator { // which condition the block violates. int n_bound_reduction_loops = 0; CommReducer reducer{nullptr}; - Array reduction_buffers{nullptr}; - Array combiner_rhs{nullptr}; - Array wb_indices{nullptr}; + ffi::Array reduction_buffers{nullptr}; + ffi::Array combiner_rhs{nullptr}; + ffi::Array wb_indices{nullptr}; std::tie(n_bound_reduction_loops, reducer, reduction_buffers, combiner_rhs, wb_indices) = CheckCanApplyCrossThreadReduction(block, reduction_loops); // Step 2. Before doing the cross-thread reduction, in-thread reduction is needed when @@ -814,10 +817,11 @@ class CrossThreadReductionTransformer : public StmtMutator { !is_one(realize->predicate); // Step 3. Create intermediate buffers, storing them in `ct_buffers` and // `it_buffers`. Let the scope block allocate these new buffers. 
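// Dataflow of the scratchpads created in Step 3 below, for one f32 reduction
// buffer named "red" (a sketch; the names follow MakeScratchpads):
//   in_thread_red[0]     per-thread partial result, combined serially in-thread
//   cross_thread_red[0]  per-group result produced by tvm_thread_allreduce
//   red[...]             final write-back from cross_thread_red[0]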
- Array& new_buffers = block2new_buffers_[block_stack_.back()]; - Array ct_buffers = MakeScratchpads(reduction_buffers, /*is_cross_thread_buffer=*/true); + ffi::Array& new_buffers = block2new_buffers_[block_stack_.back()]; + ffi::Array ct_buffers = + MakeScratchpads(reduction_buffers, /*is_cross_thread_buffer=*/true); new_buffers.insert(new_buffers.end(), ct_buffers.begin(), ct_buffers.end()); - Optional> it_buffers = std::nullopt; + ffi::Optional> it_buffers = std::nullopt; if (need_in_thread_reduction) { it_buffers = MakeScratchpads(reduction_buffers, /*is_cross_thread_buffer=*/false); new_buffers.insert(new_buffers.end(), it_buffers.value().begin(), it_buffers.value().end()); @@ -849,7 +853,7 @@ class CrossThreadReductionTransformer : public StmtMutator { // Step 1. Generate loop var for each unbound thread. // Update the block predicate with clauses of `thread_var == min`. PrimExpr predicate = realize->predicate; - Array loop_vars; + ffi::Array loop_vars; loop_vars.reserve(unbound_thread2range.size()); for (auto [scope, range] : unbound_thread2range) { std::string dim_index(1, static_cast(scope.dim_index + 'x')); @@ -859,7 +863,7 @@ class CrossThreadReductionTransformer : public StmtMutator { } // Step 2. Update the BlockRealize with the new predicate. - ObjectPtr p_realize = make_object(*realize); + ObjectPtr p_realize = ffi::make_object(*realize); p_realize->predicate = std::move(predicate); // Step 3. Wrap the updated BlockRealize with the new loops. @@ -910,9 +914,9 @@ class CrossThreadReductionTransformer : public StmtMutator { std::vector statement_stack_; std::vector loop_stack_; std::vector block_stack_; - std::unordered_map> block2new_buffers_; + std::unordered_map> block2new_buffers_; std::unordered_map loop2new_stmt_; - Map loop_range_map_; + ffi::Map loop_range_map_; arith::Analyzer analyzer_; int block_idx_depth = 0; diff --git a/src/tir/transforms/lower_custom_datatypes.cc b/src/tir/transforms/lower_custom_datatypes.cc index f77276e1553c..1f15643ad89f 100644 --- a/src/tir/transforms/lower_custom_datatypes.cc +++ b/src/tir/transforms/lower_custom_datatypes.cc @@ -64,7 +64,7 @@ class CustomDatatypesLowerer : public StmtExprMutator { PrimExpr VisitExpr_(const FloatImmNode* imm) final { auto type_code = imm->dtype.code(); - auto e = GetRef(imm); + auto e = ffi::GetRef(imm); if (datatype::Registry::Global()->GetTypeRegistered(type_code)) { auto lower = datatype::GetFloatImmLowerFunc(target_, type_code); ICHECK(lower) << "FloatImm lowering function for target " << target_ << " type " @@ -75,7 +75,7 @@ class CustomDatatypesLowerer : public StmtExprMutator { } PrimExpr VisitExpr_(const VarNode* op) final { - Var var = GetRef(op); + Var var = ffi::GetRef(op); auto itr = var_remap_.find(var); if (itr != var_remap_.end()) { diff --git a/src/tir/transforms/lower_device_kernel_launch.cc b/src/tir/transforms/lower_device_kernel_launch.cc index 529956d372f3..496c4374e203 100644 --- a/src/tir/transforms/lower_device_kernel_launch.cc +++ b/src/tir/transforms/lower_device_kernel_launch.cc @@ -43,21 +43,21 @@ struct KernelInfo { // The externally visible symbol which may refer to the PrimFunc // when launching a device kernel. - String global_symbol; + ffi::String global_symbol; // The parameters accepted by the PrimFunc. Used to rewrite // `launch_args` to be in terms of the calling scope. - Array params; + ffi::Array params; // The launch parameters that should annotate the PrimFunc, if the // kernel is ever called from the host. 
- Array launch_params; + ffi::Array launch_params; // Additional arguments which must be provided to the host-side // ffi::Function. These may be in terms of the function's parameters // (e.g. a function that computes the average of `N` elements, and // which must be launched with `N` CUDA threads). - Array launch_args; + ffi::Array launch_args; }; /*! @@ -80,7 +80,7 @@ class DeviceInfoCollector : public StmtVisitor { } collector.info_.global_symbol = - func->GetAttr(tvm::attr::kGlobalSymbol).value_or(gvar->name_hint); + func->GetAttr(tvm::attr::kGlobalSymbol).value_or(gvar->name_hint); collector.info_.launch_args = collector.info_.launch_params.Map( [&](const auto& param) { return collector.GetArgument(param); }); @@ -89,7 +89,7 @@ class DeviceInfoCollector : public StmtVisitor { } private: - PrimExpr GetArgument(const String& launch_param) const { + PrimExpr GetArgument(const ffi::String& launch_param) const { if (launch_param == tvm::runtime::launch_param::kUseDynamicSharedMemoryTag) { CHECK(dyn_shmem_size.defined()) << "Compute kernel requires launch parameter \"" << launch_param @@ -142,9 +142,9 @@ class DeviceInfoCollector : public StmtVisitor { // recording what thread axis have been visited. std::unordered_set defined_thread; // The extent of each thread - Map thread_extent; + ffi::Map thread_extent; // The amount of dynamic shared memory used - Optional dyn_shmem_size{std::nullopt}; + ffi::Optional dyn_shmem_size{std::nullopt}; }; class ReturnRemover : public StmtExprMutator { @@ -229,7 +229,7 @@ class DeviceKernelMutator : public StmtExprMutator { {tvm::tir::attr::kKernelLaunchParams, info.launch_params}, {tvm::attr::kGlobalSymbol, info.global_symbol}}); - } else if (is_call_extern && !func->GetAttr(tvm::attr::kGlobalSymbol)) { + } else if (is_call_extern && !func->GetAttr(tvm::attr::kGlobalSymbol)) { func = WithAttr(func, tvm::attr::kGlobalSymbol, gvar->name_hint); } @@ -266,7 +266,7 @@ class DeviceKernelMutator : public StmtExprMutator { // calling a custom TIRToRuntime target) do not require a kernel // launch, but need to be replaced with call_extern. extern_function_call_.insert(gvar); - Array args; + ffi::Array args; args.push_back(StringImm(gvar->name_hint)); for (const auto& arg : node->args) { args.push_back(arg); @@ -285,8 +285,8 @@ class DeviceKernelMutator : public StmtExprMutator { // caller's parameters. The param_map allows substitution of // parameter values into the thread extents, to generate // expressions that are valid within the caller. 
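// A minimal sketch of the substitution described above; `RewriteExtentDemo`,
// the parameter names, and the literal 1024 are illustrative assumptions.
PrimExpr RewriteExtentDemo(const Var& callee_n, const PrimExpr& callee_extent) {
  ffi::Map<Var, PrimExpr> param_map;
  // Bind the kernel's parameter `n` to the caller-side argument value 1024,
  // so an extent such as ceildiv(n, 256) becomes ceildiv(1024, 256).
  param_map.Set(callee_n, IntImm(DataType::Int(32), 1024));
  return Substitute(callee_extent, param_map);
}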
- Map param_map = [&]() { - Map param_map; + ffi::Map param_map = [&]() { + ffi::Map param_map; CHECK_EQ(node->args.size(), dev_info.params.size()) << "Function " << gvar->name_hint << " accepts " << dev_info.params.size() << " arguments as input, but is called using " << node->args.size() << " arguments"; @@ -298,7 +298,7 @@ class DeviceKernelMutator : public StmtExprMutator { device_kernel_launch_.insert(gvar); - Array call_args; + ffi::Array call_args; call_args.push_back(StringImm(dev_info.global_symbol)); for (PrimExpr arg : node->args) { call_args.push_back(arg); @@ -312,7 +312,7 @@ class DeviceKernelMutator : public StmtExprMutator { return Call(dtype, builtin::tvm_call_packed(), call_args); } - Optional current_target_; + ffi::Optional current_target_; std::unordered_map device_info_map_; std::unordered_set device_kernel_launch_; std::unordered_set extern_function_call_; @@ -336,7 +336,7 @@ Pass LowerDeviceKernelLaunch() { IRModule updates; for (const auto& [gvar, base_func] : mod->functions) { if (auto* ptr = base_func.as()) { - auto prim_func = mutator.RewriteKernelLaunchSite(gvar, GetRef(ptr)); + auto prim_func = mutator.RewriteKernelLaunchSite(gvar, ffi::GetRef(ptr)); if (!prim_func.same_as(base_func)) { updates->Add(gvar, prim_func); } @@ -352,7 +352,7 @@ Pass LowerDeviceKernelLaunch() { IRModule updates; for (const auto& [gvar, base_func] : mod->functions) { if (auto* ptr = base_func.as()) { - auto prim_func = mutator.UpdateKernelAttributes(gvar, GetRef(ptr)); + auto prim_func = mutator.UpdateKernelAttributes(gvar, ffi::GetRef(ptr)); if (!prim_func.same_as(base_func)) { updates->Add(gvar, prim_func); } diff --git a/src/tir/transforms/lower_init_block.cc b/src/tir/transforms/lower_init_block.cc index d3994b066dbc..304855da60ca 100644 --- a/src/tir/transforms/lower_init_block.cc +++ b/src/tir/transforms/lower_init_block.cc @@ -45,7 +45,7 @@ class InitBlockLower : public StmtMutator { return Block(n); } - static Stmt DoLowering(const Stmt& init, const Array& iter_vars) { + static Stmt DoLowering(const Stmt& init, const ffi::Array& iter_vars) { std::vector conditions; for (const IterVar& var : iter_vars) { if (var->iter_type == IterVarType::kCommReduce) { diff --git a/src/tir/transforms/lower_intrin.cc b/src/tir/transforms/lower_intrin.cc index 2915a741e80e..0ad827333941 100644 --- a/src/tir/transforms/lower_intrin.cc +++ b/src/tir/transforms/lower_intrin.cc @@ -68,9 +68,9 @@ class IntrinInjecter : public tvm::arith::IRMutatorWithAnalyzer { PrimExpr VisitExpr_(const CallNode* op) final { if (auto* ptr_op = op->op.as()) { for (const auto& f_attr_map : attr_maps_) { - FLowerGeneral f = f_attr_map.get(GetRef(ptr_op), nullptr); + FLowerGeneral f = f_attr_map.get(ffi::GetRef(ptr_op), nullptr); if (f != nullptr) { - PrimExpr e = GetRef(op); + PrimExpr e = ffi::GetRef(op); PrimExpr r = f(e); ICHECK(r.defined()) << "intrinsic rule must always return valid Expr"; if (!r.same_as(e)) { @@ -97,7 +97,7 @@ class IntrinInjecter : public tvm::arith::IRMutatorWithAnalyzer { // We use floordiv for integer analysis, // but will need to lower them to native truncdiv instructions PrimExpr VisitExpr_(const FloorDivNode* op) final { - auto e = GetRef(op); + auto e = ffi::GetRef(op); PrimExpr ret = IRMutatorWithAnalyzer::VisitExpr_(op); op = ret.as(); if (op == nullptr) return ret; @@ -290,7 +290,7 @@ class IntrinInjecter : public tvm::arith::IRMutatorWithAnalyzer { using namespace arith; PVar x, y; PVar c; - auto e = GetRef(op); + auto e = ffi::GetRef(op); if (max(floordiv(x, y), c).Match(e) && 
c.Eval()->value >= 0 && analyzer_->CanProveGreaterEqual(y.Eval(), 0)) { return max(VisitExpr(truncdiv(x, y).Eval()), c.Eval()); @@ -301,7 +301,7 @@ class IntrinInjecter : public tvm::arith::IRMutatorWithAnalyzer { PrimExpr VisitExpr_(const EQNode* op) final { using namespace arith; PVar x, y; - auto e = GetRef(op); + auto e = ffi::GetRef(op); if ((floormod(x, y) == 0).Match(e)) { return VisitExpr((truncmod(x, y) == 0).Eval()); } @@ -311,7 +311,7 @@ class IntrinInjecter : public tvm::arith::IRMutatorWithAnalyzer { PrimExpr VisitExpr_(const NENode* op) final { using namespace arith; PVar x, y; - auto e = GetRef(op); + auto e = ffi::GetRef(op); if ((floormod(x, y) != 0).Match(e)) { return VisitExpr((truncmod(x, y) != 0).Eval()); } @@ -387,7 +387,7 @@ Pass LowerIntrin() { auto target = f->GetAttr(tvm::attr::kTarget); ICHECK(target.defined()) << "LowerIntrin: Require the target attribute"; arith::Analyzer analyzer; - auto mtriple = target.value()->GetAttr("mtriple", ""); + auto mtriple = target.value()->GetAttr("mtriple", ""); n->body = IntrinInjecter(&analyzer, target.value()->kind->name, mtriple.value())(std::move(n->body)); return f; diff --git a/src/tir/transforms/lower_match_buffer.cc b/src/tir/transforms/lower_match_buffer.cc index d301e910f922..e7c3b6485fc9 100644 --- a/src/tir/transforms/lower_match_buffer.cc +++ b/src/tir/transforms/lower_match_buffer.cc @@ -52,9 +52,9 @@ class MatchBufferLower : public StmtExprMutator { Stmt stmt = StmtExprMutator ::VisitStmt_(op); op = stmt.as(); ICHECK(op != nullptr); - Array reads = + ffi::Array reads = op->reads.Map(std::bind(&MatchBufferLower::VisitBufferRegion, this, std::placeholders::_1)); - Array writes = op->writes.Map( + ffi::Array writes = op->writes.Map( std::bind(&MatchBufferLower::VisitBufferRegion, this, std::placeholders::_1)); if (reads.same_as(op->reads) && writes.same_as(op->writes) && op->match_buffers.empty()) { @@ -74,7 +74,7 @@ class MatchBufferLower : public StmtExprMutator { } PrimExpr VisitExpr_(const VarNode* op) final { - Var v = GetRef(op); + Var v = ffi::GetRef(op); auto it = var_map_.find(v); if (it != var_map_.end()) { return (*it).second; @@ -115,7 +115,7 @@ class MatchBufferLower : public StmtExprMutator { } else { const Buffer& buffer = (*it).first; const BufferRegion& source = (*it).second; - Array indices = ConvertIndices(MatchBufferRegion(buffer, source), op->indices); + ffi::Array indices = ConvertIndices(MatchBufferRegion(buffer, source), op->indices); ICHECK(!op->predicate.defined()) << "Predicated buffer load is not currently supported in lower match buffer pass."; return BufferLoad(source->buffer, indices); @@ -170,13 +170,13 @@ class MatchBufferLower : public StmtExprMutator { // Step.2.2. Update element offset // We use the ElemOffset method to avoid duplicating the index calculation. { - Array indices; + ffi::Array indices; indices.reserve(source->region.size()); for (const Range& range : source->region) { indices.push_back(range->min); } - Array buffer_start_indices = source_buffer->ElemOffset(indices); + ffi::Array buffer_start_indices = source_buffer->ElemOffset(indices); if (buffer_start_indices.size() == 1) { Bind(buffer->elem_offset, buffer_start_indices[0], buffer->name + ".elem_offset"); CHECK(analyzer_.CanProve(truncmod(buffer->elem_offset, buffer->offset_factor) == 0)) @@ -184,7 +184,7 @@ class MatchBufferLower : public StmtExprMutator { << " does not satisfy the offset_factor " << buffer->offset_factor << "."; } else { // Non-zero elem_offset is ill-defined for non-flat memory. 
- // If needed in the future, will require `Array + // If needed in the future, will require `ffi::Array // elem_offsets`, with one offset for each flattened index. Bind(buffer->elem_offset, make_const(buffer->elem_offset.dtype(), 0)); } @@ -246,9 +246,9 @@ class MatchBufferLower : public StmtExprMutator { private: /*! \brief Buffer region mapping. */ - Map match_buffers_; + ffi::Map match_buffers_; /*! \brief Var mapping for buffer signature (data, strides, element_offset, etc.) */ - Map var_map_; + ffi::Map var_map_; /*! \brief The analyzer */ arith::Analyzer analyzer_; }; diff --git a/src/tir/transforms/lower_opaque_block.cc b/src/tir/transforms/lower_opaque_block.cc index 75bfece625d8..9154c5c3c6e8 100644 --- a/src/tir/transforms/lower_opaque_block.cc +++ b/src/tir/transforms/lower_opaque_block.cc @@ -57,9 +57,9 @@ class OpaqueBlockLower : public StmtExprMutator { // Step 3. Handle allocations in reverse order for (size_t i = new_block->alloc_buffers.size(); i > 0; --i) { const Buffer& buffer = new_block->alloc_buffers[i - 1]; - Array allocation_shape = GetBufferAllocationShape(buffer); + ffi::Array allocation_shape = GetBufferAllocationShape(buffer); body = DeclBuffer(buffer, std::move(body)); - Map allocate_annotations; + ffi::Map allocate_annotations; auto it = storage_align_.find(buffer->data); if (it != storage_align_.end()) { StorageAlignAnnotation allocate_aligns; @@ -94,13 +94,13 @@ class OpaqueBlockLower : public StmtExprMutator { Stmt body = this->VisitStmt(op->body); // Step 3. Handle annotations std::vector> pragma_attrs; - Map new_annotations = + ffi::Map new_annotations = HandleAnnotations(op->annotations, &pragma_attrs, /*is_block=*/false); // Step 4. Create new For loop accordingly if (op->kind == ForKind::kThreadBinding) { // Case 1. Thread binding ICHECK(op->thread_binding.defined()); - String thread_tag = op->thread_binding.value()->thread_tag; + ffi::String thread_tag = op->thread_binding.value()->thread_tag; body = MakeLaunchThread(min, extent, op->loop_var, thread_tag, body); } else if (is_one(extent) && op->annotations.empty()) { // Case 2. Unit loop @@ -118,7 +118,7 @@ class OpaqueBlockLower : public StmtExprMutator { } PrimExpr VisitExpr_(const VarNode* op) final { - Var var = GetRef(op); + Var var = ffi::GetRef(op); auto it = unit_loop_vars_.find(var); if (it == unit_loop_vars_.end()) { return var; @@ -132,16 +132,16 @@ class OpaqueBlockLower : public StmtExprMutator { } } - static Stmt MakeLaunchThread(PrimExpr min, PrimExpr extent, Var var, String thread_tag, + static Stmt MakeLaunchThread(PrimExpr min, PrimExpr extent, Var var, ffi::String thread_tag, Stmt body) { IterVar iter_var(/*dom=*/Range::FromMinExtent(min, extent), /*var=*/std::move(var), /*iter_type=*/IterVarType::kThreadIndex, /*thread_tag=*/thread_tag); - String attr_key = (thread_tag == "vthread" || thread_tag == "vthread.x" || - thread_tag == "vthread.y" || thread_tag == "vthread.z") - ? attr::virtual_thread - : attr::thread_extent; + ffi::String attr_key = (thread_tag == "vthread" || thread_tag == "vthread.x" || + thread_tag == "vthread.y" || thread_tag == "vthread.z") + ? attr::virtual_thread + : attr::thread_extent; return AttrStmt(/*node=*/std::move(iter_var), /*attr_key=*/std::move(attr_key), /*value=*/std::move(extent), @@ -149,12 +149,12 @@ class OpaqueBlockLower : public StmtExprMutator { } /*! \brief Convert attr value from annotation map into PrimExpr. 
*/ - PrimExpr ConvertAttrValue(const String& key, const Any& obj) { + PrimExpr ConvertAttrValue(const ffi::String& key, const Any& obj) { if (obj == nullptr) { return PrimExpr(); } else if (auto expr = obj.try_cast()) { return expr.value(); - } else if (auto str = obj.try_cast()) { + } else if (auto str = obj.try_cast()) { return std::move(StringImm(str.value())); } else { LOG(FATAL) << "Illegal attribute of key " << key << ", value type " << obj.GetTypeKey() @@ -171,13 +171,13 @@ class OpaqueBlockLower : public StmtExprMutator { * (3) the non-pragma block annotations are dropped * \return New annotation dict with preserved keys. Also update pragma attr pairs ordered by key. */ - Map HandleAnnotations( - const Map& annotations, + ffi::Map HandleAnnotations( + const ffi::Map& annotations, std::vector>* pragma_attrs, bool is_block) { - Map preserved_annotations; + ffi::Map preserved_annotations; pragma_attrs->clear(); for (const auto& kv : annotations) { - const String& key = kv.first; + const ffi::String& key = kv.first; if (attr::IsPragmaKey(key)) { pragma_attrs->emplace_back(key, ConvertAttrValue(key, kv.second)); } else if (!is_block) { diff --git a/src/tir/transforms/lower_thread_allreduce.cc b/src/tir/transforms/lower_thread_allreduce.cc index 3b972482b728..37c652f0b356 100644 --- a/src/tir/transforms/lower_thread_allreduce.cc +++ b/src/tir/transforms/lower_thread_allreduce.cc @@ -92,7 +92,7 @@ class ThreadAllreduceBuilder final : public StmtExprMutator { return node; } - Optional GetRemappedBuffer(const Buffer& buf) { + ffi::Optional GetRemappedBuffer(const Buffer& buf) { if (auto it = buf_remap_.find(buf.get()); it != buf_remap_.end()) { return it->second; } @@ -162,7 +162,7 @@ class ThreadAllreduceBuilder final : public StmtExprMutator { const IntImmNode* size_of_args = call->args[0].as(); ICHECK(size_of_args) << call->args[0]->GetTypeKey(); ICHECK_EQ(size, size_of_args->value); - Array inits = combiner->identity_element; + ffi::Array inits = combiner->identity_element; std::vector values(size); std::vector types(size); PrimExpr cond = call->args[size + 1]; @@ -433,12 +433,12 @@ class ThreadAllreduceBuilder final : public StmtExprMutator { } std::pair, std::vector> MakeWarpAllreduce( - std::vector src_values, // - std::vector dtypes, // - const CommReducerNode* combiner, // - PrimExpr reduce_index, int reduce_extent, // - PrimExpr group_index, // - PrimExpr mask, Optional predicate, // + std::vector src_values, // + std::vector dtypes, // + const CommReducerNode* combiner, // + PrimExpr reduce_index, int reduce_extent, // + PrimExpr group_index, // + PrimExpr mask, ffi::Optional predicate, // std::vector* seq) { int n_buffers = src_values.size(); @@ -449,8 +449,8 @@ class ThreadAllreduceBuilder final : public StmtExprMutator { // This is the index to the reduction variable, one reduction // variable per warp. Local scope seems easier to reason without // relying on a pattern match pass to fix it later. - Array zero_indices = {0}; - Array shape = {1}; + ffi::Array zero_indices = {0}; + ffi::Array shape = {1}; std::vector load_values; load_values.reserve(n_buffers); @@ -473,7 +473,7 @@ class ThreadAllreduceBuilder final : public StmtExprMutator { // The mask for this reducer, as this reducer may sit inside // a divergent control flow. Here it uses a variable to cache the current // active channels. 
- Optional mask_buffer; + ffi::Optional mask_buffer; if (need_warp_shuffle_mask_) { mask_buffer = decl_buffer(shape, mask->dtype, "mask", "local"); seq->emplace_back(BufferStore(mask_buffer.value(), mask, zero_indices)); @@ -489,7 +489,7 @@ class ThreadAllreduceBuilder final : public StmtExprMutator { } for (int offset = start_offset; offset > 0; offset /= 2) { // Load reduction values, no synchronization needed. - Array a, b; + ffi::Array a, b; for (int i = 0; i < n_buffers; ++i) { Buffer shared_buf = shared_bufs[i]; BufferLoad val(shared_buf, zero_indices); @@ -519,7 +519,7 @@ class ThreadAllreduceBuilder final : public StmtExprMutator { } // Do reductions. - Array ret = (*combiner)(a, b); + ffi::Array ret = (*combiner)(a, b); // Store the reduction result to itself. std::vector stores; @@ -554,7 +554,7 @@ class ThreadAllreduceBuilder final : public StmtExprMutator { // make allreduce. Stmt MakeBufAllreduce(const CommReducerNode* combiner, const std::vector& types, - const Array& shared_bufs, PrimExpr reduce_index, + const ffi::Array& shared_bufs, PrimExpr reduce_index, PrimExpr group_index, int reduce_extent, int group_extent, int contiguous_reduce_extent) { // Get next power of two @@ -569,7 +569,7 @@ class ThreadAllreduceBuilder final : public StmtExprMutator { PrimExpr buf_index = BufIndex(reduce_index, group_index, reduce_extent); // make reduction auto fload = [&](int offset) { - Array a, b; + ffi::Array a, b; for (size_t i = 0; i < size; ++i) { BufferLoad b_load(shared_bufs[i], {BufIndex(reduce_index + offset, group_index, reduce_extent)}); @@ -580,10 +580,10 @@ class ThreadAllreduceBuilder final : public StmtExprMutator { ICHECK_EQ(a_load->dtype, types[i]); a.push_back(a_load); } - Array ret = (*combiner)(a, b); + ffi::Array ret = (*combiner)(a, b); return ret; }; - auto fstore = [&](const Array& ret) { + auto fstore = [&](const ffi::Array& ret) { std::vector stores(size); for (size_t i = 0; i < size; ++i) { stores[i] = BufferStore(shared_bufs[i], ret[i], {buf_index}); @@ -633,7 +633,7 @@ class ThreadAllreduceBuilder final : public StmtExprMutator { // here to reduce thread divergence. auto loads = fload(reduce_align); - Array in_warp_local_vars; + ffi::Array in_warp_local_vars; for (auto expr : loads) { Var var( "w_" + std::to_string(reduce_align) + "_" + std::to_string(in_warp_local_vars.size()), @@ -696,9 +696,9 @@ class ThreadAllreduceBuilder final : public StmtExprMutator { } // Emit warp shuffle calls. 
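// One butterfly step of the shuffle-based reduction emitted below (a sketch,
// assuming a sum combiner and a 32-lane warp, so `offset` runs 16, 8, 4, 2, 1;
// the real code round-trips values through the local scratch buffers instead):
//   PrimExpr other = WarpShuffle(builtin::tvm_warp_shuffle_down(), mask_buffer,
//                                val, make_const(DataType::Int(32), offset));
//   val = val + other;  // after log2(32) steps, lane 0 holds the full sum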
- PrimExpr WarpShuffle(const Op& op, Optional mask_buffer, PrimExpr val, + PrimExpr WarpShuffle(const Op& op, ffi::Optional mask_buffer, PrimExpr val, PrimExpr delta_or_lane) { - Array indices = {0}; + ffi::Array indices = {0}; PrimExpr mask; if (mask_buffer.defined()) { mask = BufferLoad(mask_buffer.value(), indices); @@ -706,7 +706,7 @@ class ThreadAllreduceBuilder final : public StmtExprMutator { mask = IntImm(DataType::Int(32), 0); } PrimExpr width = IntImm(DataType::Int(32), warp_size_); - Array args{mask, val, delta_or_lane, width, width}; + ffi::Array args{mask, val, delta_or_lane, width, width}; return Call(val.dtype(), op, args); } diff --git a/src/tir/transforms/lower_tvm_builtin.cc b/src/tir/transforms/lower_tvm_builtin.cc index e74f5c7c9046..028fa4eb0368 100644 --- a/src/tir/transforms/lower_tvm_builtin.cc +++ b/src/tir/transforms/lower_tvm_builtin.cc @@ -40,7 +40,7 @@ namespace tir { class BuiltinLower : public StmtExprMutator { public: static PrimFunc Build(PrimFunc func) { - Optional device_type = std::nullopt; + ffi::Optional device_type = std::nullopt; if (auto target = func->GetAttr(tvm::attr::kTarget)) { device_type = Integer(target.value()->kind->default_device_type); } @@ -50,7 +50,7 @@ class BuiltinLower : public StmtExprMutator { return func; } - explicit BuiltinLower(Optional device_type = std::nullopt) + explicit BuiltinLower(ffi::Optional device_type = std::nullopt) : device_type_(device_type) {} // NOTE: Right now, we make the following scoping requirement @@ -317,7 +317,7 @@ class BuiltinLower : public StmtExprMutator { } if (min.same_as(op->min) && extent.same_as(op->extent) && body.same_as(op->body)) { - return GetRef(op); + return ffi::GetRef(op); } else { auto n = CopyOnWrite(op); n->min = std::move(min); @@ -370,7 +370,7 @@ class BuiltinLower : public StmtExprMutator { << "but was instead the expression " << device_type_ << " with type " << device_type_.value()->GetTypeKey(); - String device_name = runtime::DLDeviceType2Str(as_int->value); + ffi::String device_name = runtime::DLDeviceType2Str(as_int->value); return StringImm("device_api." + device_name + "." + method_name); } @@ -594,9 +594,9 @@ class BuiltinLower : public StmtExprMutator { scope.run_sizes.shape_stack = restore_shape_stack; scope.run_sizes.array_stack = restore_array_stack; scope.run_sizes.arg_stack = arg_stack_begin; - Array packed_args = {op->args[name_offset], scope.stack_ffi_any, - ConstInt32(arg_stack_begin), - ConstInt32(arg_stack_begin + num_args)}; + ffi::Array packed_args = {op->args[name_offset], scope.stack_ffi_any, + ConstInt32(arg_stack_begin), + ConstInt32(arg_stack_begin + num_args)}; if (pass_last_arg_as_traced_value) { // pass in last element as traced value // used by call_packed_traced @@ -626,7 +626,7 @@ class BuiltinLower : public StmtExprMutator { std::string fdevapi_prefix = "device_api."; fdevapi_prefix += runtime::DLDeviceType2Str(device_type_.as()->value); - Array args = { + ffi::Array args = { GetDeviceMethodName("alloc_nd"), device_type_.value(), device_id_.value(), @@ -657,8 +657,8 @@ class BuiltinLower : public StmtExprMutator { // The prepration sequence to be emitted before the current statement. 
std::vector> prep_seq_stack_; - Optional device_type_{std::nullopt}; - Optional device_id_{std::nullopt}; + ffi::Optional device_type_{std::nullopt}; + ffi::Optional device_id_{std::nullopt}; bool is_precheck_{false}; diff --git a/src/tir/transforms/lower_vtcm_alloc.cc b/src/tir/transforms/lower_vtcm_alloc.cc index 7cddfb678514..ac9a2940a942 100644 --- a/src/tir/transforms/lower_vtcm_alloc.cc +++ b/src/tir/transforms/lower_vtcm_alloc.cc @@ -40,7 +40,7 @@ class VtcmAllocator : public StmtExprMutator { std::string storage_scope = GetStorageScope(op->buffer_var); if (IsVtcmStorage(storage_scope)) { Stmt body = this->VisitStmt(op->body); - Array args; + ffi::Array args; args.push_back(StringImm(storage_scope)); args.push_back(IntImm(DataType::Int(64), op->extents.size())); args.push_back(Call(DataType::Handle(), builtin::tvm_stack_make_shape(), op->extents)); diff --git a/src/tir/transforms/lower_warp_memory.cc b/src/tir/transforms/lower_warp_memory.cc index 5708ab0746f2..1c8968aee915 100644 --- a/src/tir/transforms/lower_warp_memory.cc +++ b/src/tir/transforms/lower_warp_memory.cc @@ -150,7 +150,7 @@ class WarpStoreCoeffFinder : private StmtExprVisitor { } void UpdatePattern(const PrimExpr& index) { - Array m = arith::DetectLinearEquation(index, {warp_index_}); + ffi::Array m = arith::DetectLinearEquation(index, {warp_index_}); ICHECK_EQ(m.size(), 2U) << "LowerWarpMemory failed. Could not simplify the store index `" << index << "` into the form ax + by + cz + ... Warp memory is approximated by storing values in " @@ -254,7 +254,7 @@ class WarpAccessRewriter : protected StmtExprMutator { protected: PrimExpr RewriteIndicesAt(const CallNode* op, const std::vector& indices) { - Array new_args = op->args; + ffi::Array new_args = op->args; for (int i : indices) { if (op->args[i].get() == buffer_) { PrimExpr local_index = SplitIndexByGroup(op->args[i + 1]).first; @@ -426,7 +426,7 @@ class WarpMemoryRewriter : private StmtMutator { return stmt; } - std::unordered_map new_storage_scopes_; + std::unordered_map new_storage_scopes_; private: Stmt VisitStmt_(const AllocateNode* op) { diff --git a/src/tir/transforms/make_packed_api.cc b/src/tir/transforms/make_packed_api.cc index 198b8cfc2e32..cad095b5009a 100644 --- a/src/tir/transforms/make_packed_api.cc +++ b/src/tir/transforms/make_packed_api.cc @@ -125,7 +125,8 @@ class ReturnRewriter : public StmtMutator { class SubroutineCallRewriter : public StmtExprMutator { public: - static Optional Apply(const Map& packed_func_methods, Stmt stmt) { + static ffi::Optional Apply(const ffi::Map& packed_func_methods, + Stmt stmt) { SubroutineCallRewriter rewriter(packed_func_methods); stmt = rewriter.VisitStmt(std::move(stmt)); if (rewriter.made_change_) { @@ -136,16 +137,16 @@ class SubroutineCallRewriter : public StmtExprMutator { } private: - explicit SubroutineCallRewriter(const Map& packed_func_methods) + explicit SubroutineCallRewriter(const ffi::Map& packed_func_methods) : packed_func_methods(packed_func_methods) {} PrimExpr VisitExpr_(const CallNode* op) override { auto node = Downcast(StmtExprMutator::VisitExpr_(op)); if (auto* gvar_ptr = node->op.as()) { - auto gvar = GetRef(gvar_ptr); + auto gvar = ffi::GetRef(gvar_ptr); if (auto symbol = packed_func_methods.Get(gvar)) { - Array cpacked_args; + ffi::Array cpacked_args; cpacked_args.push_back(tir::StringImm(symbol.value())); for (auto arg : node->args) { cpacked_args.push_back(arg); @@ -160,7 +161,7 @@ class SubroutineCallRewriter : public StmtExprMutator { return node; } - const Map& packed_func_methods; 
+ const ffi::Map& packed_func_methods; bool made_change_{false}; }; @@ -182,7 +183,7 @@ inline Stmt MakeAssertNotNull(PrimExpr ptr, std::string msg) { * \returns The global_symbol to be used for the function at call * sites, or std::nullopt if the function is to remain unchanged. */ -Optional RequiresPackedAPI(const PrimFunc& func) { +ffi::Optional RequiresPackedAPI(const PrimFunc& func) { // A function with an explicit calling convention has already been // lowered, and should not be modified. if (auto opt = func->GetAttr(tvm::attr::kCallingConv)) { @@ -192,7 +193,7 @@ Optional RequiresPackedAPI(const PrimFunc& func) { } // Internal function calls do not need the ffi::Function API - auto global_symbol = func->GetAttr(tvm::attr::kGlobalSymbol); + auto global_symbol = func->GetAttr(tvm::attr::kGlobalSymbol); if (!global_symbol.has_value()) { return std::nullopt; } @@ -248,8 +249,8 @@ PrimFunc MakePackedAPI(PrimFunc func) { // local function definitions // load i-th argument as type t auto f_load_arg_value = [&](DataType arg_type, int i) { - Array call_args{v_packed_args, IntImm(DataType::Int(32), i), - IntImm(DataType::Int(32), builtin::kTVMFFIAnyUnionValue)}; + ffi::Array call_args{v_packed_args, IntImm(DataType::Int(32), i), + IntImm(DataType::Int(32), builtin::kTVMFFIAnyUnionValue)}; // load 64 bit version DataType api_type = APIType(arg_type); PrimExpr res = Call(api_type, builtin::tvm_struct_get(), call_args); @@ -347,7 +348,7 @@ PrimFunc MakePackedAPI(PrimFunc func) { } // signature: (void* handle, TVMFFIAny* packed_args, int num_args, TVMFFIAny* v_result) - Array args{v_self_handle, v_packed_args, v_num_packed_args, v_result}; + ffi::Array args{v_self_handle, v_packed_args, v_num_packed_args, v_result}; // Arg definitions are defined before buffer binding to avoid the use before // def errors. @@ -396,11 +397,11 @@ PrimFunc MakePackedAPI(PrimFunc func) { func_ptr->body = body; func_ptr->params = args; - Array undefined = UndefinedVars(func_ptr->body, func_ptr->params); + ffi::Array undefined = UndefinedVars(func_ptr->body, func_ptr->params); ICHECK_EQ(undefined.size(), 0) << "In PrimFunc " << name_hint << " variables " << undefined << " are used, but are not passed in as API arguments"; - func_ptr->buffer_map = Map(); + func_ptr->buffer_map = ffi::Map(); func_ptr->ret_type = PrimType(DataType::Int(32)); // return the function. 
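// A hedged sketch of the entry point this pass produces; `my_func_packed` and
// the body comments are illustrative, not emitted verbatim:
//   int32_t my_func_packed(void* self, TVMFFIAny* packed_args, int32_t num_args,
//                          TVMFFIAny* v_result) {
//     // 1. check num_args against the expected parameter count;
//     // 2. load each argument via tvm_struct_get(packed_args, i,
//     //    builtin::kTVMFFIAnyUnionValue) and bind buffer fields from the DLTensor;
//     // 3. run the original body; a nonzero return value signals an error.
//   }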
@@ -411,7 +412,7 @@ namespace transform { Pass MakePackedAPI() { auto pass_func = [](IRModule mod, PassContext ctx) { - Map packed_func_methods; + ffi::Map packed_func_methods; for (const auto& [gvar, base_func] : mod->functions) { if (auto opt = base_func.as()) { auto prim_func = opt.value(); diff --git a/src/tir/transforms/make_unpacked_api.cc b/src/tir/transforms/make_unpacked_api.cc index 8276d26fcfa8..fcba187d5f90 100644 --- a/src/tir/transforms/make_unpacked_api.cc +++ b/src/tir/transforms/make_unpacked_api.cc @@ -45,8 +45,8 @@ namespace { class SubroutineCallRewriter : public StmtExprMutator { public: - static Optional Apply(const std::unordered_set& external_methods, - Stmt stmt) { + static ffi::Optional Apply(const std::unordered_set& external_methods, + Stmt stmt) { SubroutineCallRewriter rewriter(external_methods); stmt = rewriter.VisitStmt(std::move(stmt)); if (rewriter.made_change_) { @@ -65,7 +65,7 @@ class SubroutineCallRewriter : public StmtExprMutator { if (auto gvar = node->op.as()) { if (external_methods_.count(gvar)) { - Array args = node->args.Map([](const PrimExpr& arg) -> PrimExpr { + ffi::Array args = node->args.Map([](const PrimExpr& arg) -> PrimExpr { if (auto* as_call = arg.as()) { if (as_call->op.same_as(builtin::tvm_stack_make_array())) { PrimExpr data_ptr = as_call->args[0]; @@ -102,7 +102,7 @@ PrimFunc MakeUnpackedAPI(PrimFunc func) { } // Internal function calls do not need API updates - auto global_symbol = func->GetAttr(tvm::attr::kGlobalSymbol); + auto global_symbol = func->GetAttr(tvm::attr::kGlobalSymbol); if (!global_symbol.has_value()) { return func; } @@ -133,7 +133,7 @@ PrimFunc MakeUnpackedAPI(PrimFunc func) { std::vector device_init; // Collect variables and buffers to map between - Array args; + ffi::Array args; for (const Var& param : func->params) { // Ideally all func params should have Buffers defined in the buffer_map @@ -156,7 +156,7 @@ PrimFunc MakeUnpackedAPI(PrimFunc func) { func_ptr->body = body; func_ptr->params = args; func_ptr->ret_type = PrimType(DataType::Int(32)); - func_ptr->buffer_map = Map(); + func_ptr->buffer_map = ffi::Map(); // return the function. return WithAttrs(std::move(func), {{tvm::attr::kTarget, target_host}}); @@ -169,7 +169,7 @@ Pass MakeUnpackedAPI() { std::unordered_set external_methods; for (const auto& [gvar, base_func] : mod->functions) { if (auto* prim_func = base_func.as()) { - if (prim_func->GetAttr(tvm::attr::kGlobalSymbol)) { + if (prim_func->GetAttr(tvm::attr::kGlobalSymbol)) { external_methods.insert(gvar.get()); } } diff --git a/src/tir/transforms/manifest_shared_memory_local_stage.cc b/src/tir/transforms/manifest_shared_memory_local_stage.cc index 73f5d7746da9..83965a29cbab 100644 --- a/src/tir/transforms/manifest_shared_memory_local_stage.cc +++ b/src/tir/transforms/manifest_shared_memory_local_stage.cc @@ -73,7 +73,7 @@ class IntermediateStageRewriter { BufferLoad new_buffer_load = BufferLoad(new_buffer, buffer_indices); BufferStore new_buffer_store = Downcast(block->body); new_buffer_store.CopyOnWrite()->value = new_buffer_load; - Block new_block = GetRef(block); + Block new_block = ffi::GetRef(block); new_block.CopyOnWrite()->body = std::move(new_buffer_store); return {target_buffer, new_buffer, new_block, local_stage}; @@ -119,7 +119,7 @@ class IntermediateStageRewriter { /*! \brief Create the intermediate stage. 
*/ Stmt MakeLocalStage(const BlockNode* block, const Buffer& new_buffer, - Array local_stage_indices, + ffi::Array local_stage_indices, std::vector relaxed_loops, const BufferStoreNode* store) { // Step 0: Create the body of the local stage, which is BufferStore to the intermediate buffer. Stmt local_stage = BufferStore(new_buffer, store->value, local_stage_indices); @@ -135,9 +135,9 @@ class IntermediateStageRewriter { Downcast(local_stage)); // Step 2: Add outer loops - Map subst_map; + ffi::Map subst_map; for (const ForNode* relaxed_loop : relaxed_loops) { - ObjectPtr for_node = make_object(*relaxed_loop); + ObjectPtr for_node = ffi::make_object(*relaxed_loop); for_node->loop_var = for_node->loop_var.copy_with_suffix(""); for_node->body = std::move(local_stage); local_stage = For(for_node); @@ -148,10 +148,10 @@ class IntermediateStageRewriter { } /*! \brief Create the intermediate buffer with the extents of the relaxed outer loops. */ - std::pair> CreateIntermediateBuffer( + std::pair> CreateIntermediateBuffer( const std::vector relaxed_loops, const Buffer& buffer) const { - Array buffer_indices; - Array new_buffer_shape; + ffi::Array buffer_indices; + ffi::Array new_buffer_shape; // Create the intermediate buffer for the local stage. The shape of the new buffer is the // extents of the relaxed outer loops. @@ -172,14 +172,14 @@ class IntermediateStageRewriter { class SharedMemoryLocalStageInserter : public StmtMutator { public: Stmt VisitStmt_(const ForNode* op) final { - ancestor_loop_or_blocks_.push_back(GetRef(op)); + ancestor_loop_or_blocks_.push_back(ffi::GetRef(op)); Stmt new_stmt = StmtMutator::VisitStmt_(op); ancestor_loop_or_blocks_.pop_back(); return new_stmt; } Stmt VisitStmt_(const BlockRealizeNode* op) final { - ancestor_loop_or_blocks_.push_back(GetRef(op)); + ancestor_loop_or_blocks_.push_back(ffi::GetRef(op)); Stmt new_stmt = StmtMutator::VisitStmt_(op); ancestor_loop_or_blocks_.pop_back(); return new_stmt; @@ -206,8 +206,8 @@ class SharedMemoryLocalStageInserter : public StmtMutator { op->alloc_buffers.begin(), op->alloc_buffers.end()); // Visit children and insert local stages (if any) to the proper location. - Array new_alloc_buffers; - Array new_seq; + ffi::Array new_alloc_buffers; + ffi::Array new_seq; // Helper function to check if the subtree (body of the block) contains any target buffers. // If so, the allocated intermediate buffer and the local stage should be lifted to the current @@ -236,7 +236,7 @@ class SharedMemoryLocalStageInserter : public StmtMutator { } } if (!changed) { - return GetRef(op); + return ffi::GetRef(op); } } else { int subtree_start = target_buffers_.size(); @@ -244,12 +244,12 @@ class SharedMemoryLocalStageInserter : public StmtMutator { int subtree_end = target_buffers_.size(); f_check_subtree(subtree_start, subtree_end); if (body.same_as(op->body)) { - return GetRef(op); + return ffi::GetRef(op); } new_seq.push_back(body); } - Block new_block = GetRef(op); + Block new_block = ffi::GetRef(op); BlockNode* new_block_node = new_block.CopyOnWrite(); // Add new buffer allocations if any. 
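// The GetRef/CopyOnWrite idiom used just above, in isolation (a minimal
// sketch; `GrowBlockAllocs` is an illustrative name):
Block GrowBlockAllocs(const BlockNode* op, ffi::Array<Buffer> new_allocs) {
  Block block = ffi::GetRef<Block>(op);      // shares the node; no copy yet
  BlockNode* n = block.CopyOnWrite();        // copies only if the node is shared
  n->alloc_buffers = std::move(new_allocs);  // mutate the now-unique copy
  return block;
}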
if (new_alloc_buffers.size() > 0) { @@ -260,9 +260,10 @@ class SharedMemoryLocalStageInserter : public StmtMutator { } std::vector ancestor_loop_or_blocks_; // ancestor loops or block realize - Map buffer_remap_; // mapping from the target buffer to the intermediate buffer - Map buffer_local_stage_; // mapping from the target buffer to the local stage - Array target_buffers_; // the target buffers for rewriting + ffi::Map + buffer_remap_; // mapping from the target buffer to the intermediate buffer + ffi::Map buffer_local_stage_; // mapping from the target buffer to the local stage + ffi::Array target_buffers_; // the target buffers for rewriting }; namespace transform { diff --git a/src/tir/transforms/memhammer_coalesce.cc b/src/tir/transforms/memhammer_coalesce.cc index 43a976fa892f..094f48e321f6 100644 --- a/src/tir/transforms/memhammer_coalesce.cc +++ b/src/tir/transforms/memhammer_coalesce.cc @@ -40,13 +40,13 @@ Stmt FuseNestLoops(Stmt body) { } suffix += "_fused"; Var fused_var = loops[0]->loop_var.copy_with_suffix(suffix); - Map subst_map; + ffi::Map subst_map; PrimExpr tot = fused_var; for (int i = n - 1; i >= 0; i--) { subst_map.Set(loops[i]->loop_var, floormod(tot, loops[i]->extent)); tot = floordiv(tot, loops[i]->extent); } - auto f_substitute = [&](const Var& v) -> Optional { + auto f_substitute = [&](const Var& v) -> ffi::Optional { return subst_map.Get(v).value_or(v); }; PrimExpr fused_extent = 1; @@ -74,19 +74,19 @@ Stmt SplitBindVectorize(const Stmt& stmt, const ConstraintSet& constraints) { // generate thread binding loops std::vector factors{-1}; std::vector thread_axis; - if (Optional o_t = constraints.thread_extent.Get("threadIdx.z")) { + if (ffi::Optional o_t = constraints.thread_extent.Get("threadIdx.z")) { int t = o_t.value()->value; tot_threads *= t; factors.push_back(t); thread_axis.push_back("threadIdx.z"); } - if (Optional o_t = constraints.thread_extent.Get("threadIdx.y")) { + if (ffi::Optional o_t = constraints.thread_extent.Get("threadIdx.y")) { int t = o_t.value()->value; tot_threads *= t; factors.push_back(t); thread_axis.push_back("threadIdx.y"); } - if (Optional o_t = constraints.thread_extent.Get("threadIdx.x")) { + if (ffi::Optional o_t = constraints.thread_extent.Get("threadIdx.x")) { int t = o_t.value()->value; tot_threads *= t; factors.push_back(t); @@ -114,7 +114,7 @@ Stmt SplitBindVectorize(const Stmt& stmt, const ConstraintSet& constraints) { substitute_value += new_loop_vars[i]; } // Construct the new loop nest - Stmt body = Substitute(loop->body, [&](const Var& v) -> Optional { + Stmt body = Substitute(loop->body, [&](const Var& v) -> ffi::Optional { if (v.same_as(loop->loop_var)) { return substitute_value; } else { @@ -152,17 +152,17 @@ Stmt CoalescedAccess::Rewrite(const Stmt& stmt, const ConstraintSet& constraints * the index mapping * \return The mapping in the form of j0, ..., jm, where j0, ... 
jm = f(i0, ..., in) */ -Array GetMapping(const Stmt& stmt, const ConstraintSet& constraints) { +ffi::Array GetMapping(const Stmt& stmt, const ConstraintSet& constraints) { Stmt body = stmt; while (const ForNode* loop = body.as()) { body = loop->body; } const BufferStoreNode* buf_store = TVM_TYPE_AS(body, BufferStoreNode); BufferRegion write_region = constraints.write_region; - const Array& write_index = buf_store->indices; + const ffi::Array& write_index = buf_store->indices; ICHECK(write_region->region.size() == write_index.size() && write_region->buffer.same_as(buf_store->buffer)); - Array result; + ffi::Array result; arith::Analyzer analyzer; for (int i = 0; i < static_cast(write_region->region.size()); i++) { PrimExpr pattern = analyzer.Simplify(write_index[i] - write_region->region[i]->min); @@ -176,10 +176,10 @@ Array GetMapping(const Stmt& stmt, const ConstraintSet& constraints) { Stmt InverseMapping::Rewrite(const Stmt& stmt, const ConstraintSet& constraints, OutputSet* output) const { Stmt body = stmt; - Map var_range; - Array loop_vars; + ffi::Map var_range; + ffi::Array loop_vars; // Step 1. Get index mapping - Array mapping_pattern = GetMapping(stmt, constraints); + ffi::Array mapping_pattern = GetMapping(stmt, constraints); while (const ForNode* loop = body.as()) { var_range.Set(loop->loop_var, Range::FromMinExtent(loop->min, loop->extent)); loop_vars.push_back(loop->loop_var); @@ -191,14 +191,15 @@ Stmt InverseMapping::Rewrite(const Stmt& stmt, const ConstraintSet& constraints, auto iter_map = arith::DetectIterMap(mapping_pattern, var_range, Bool(true), arith::Bijective, &analyzer); CHECK_EQ(iter_map->indices.size(), loop_vars.size()); - Map inverse_mapping = arith::InverseAffineIterMap(iter_map->indices, loop_vars); + ffi::Map inverse_mapping = + arith::InverseAffineIterMap(iter_map->indices, loop_vars); // Step 3. Generate new body BufferRegion read_region = constraints.read_region; BufferRegion write_region = constraints.write_region; - Array write_index; - Array read_index; - Array new_loop_vars; - Map substitute_map; + ffi::Array write_index; + ffi::Array read_index; + ffi::Array new_loop_vars; + ffi::Map substitute_map; // Step 3.1 construct target buffer indices for (int i = 0, j = 0; i < static_cast(write_region->region.size()); i++) { if (is_one(write_region->region[i]->extent)) { diff --git a/src/tir/transforms/memhammer_intermediate_stage.cc b/src/tir/transforms/memhammer_intermediate_stage.cc index 2ecb740ba327..5f7a1f494a7d 100644 --- a/src/tir/transforms/memhammer_intermediate_stage.cc +++ b/src/tir/transforms/memhammer_intermediate_stage.cc @@ -25,7 +25,7 @@ Stmt CopyLoopChain(const std::vector loops, const Stmt& inner_bo Stmt* ith_loop = nullptr) { Stmt ret = inner_body; for (int i = static_cast(loops.size() - 1); i >= 0; i--) { - ObjectPtr new_loop = make_object(*loops[i]); + ObjectPtr new_loop = ffi::make_object(*loops[i]); new_loop->body = ret; ret = For(new_loop); if (ith == i) { @@ -71,7 +71,7 @@ std::pair LiftThreadBindingLoops(Stmt stmt) { */ class IndexPatternFinder : public ExprVisitor { public: - IndexPatternFinder(const Map& var_range, Array* resulting_index) + IndexPatternFinder(const ffi::Map& var_range, ffi::Array* resulting_index) : var_range_(var_range), resulting_index_(resulting_index) {} struct Operator { enum class OpKind { Mul, FloorDiv, FloorMod }; @@ -87,19 +87,19 @@ class IndexPatternFinder : public ExprVisitor { * \param rewrite_indices The access indices after rank promotion * \return The new buffer shape after rank promotion. 
*/ - static Array getRankPromotedShape(Array indices, - const Map& var_range, - Array* rewrite_indices) { - Map var_dom = arith::AsIntSet(var_range); - Array new_shape; + static ffi::Array getRankPromotedShape(ffi::Array indices, + const ffi::Map& var_range, + ffi::Array* rewrite_indices) { + ffi::Map var_dom = arith::AsIntSet(var_range); + ffi::Array new_shape; for (const PrimExpr& expr : indices) { - Array indices_dim; + ffi::Array indices_dim; IndexPatternFinder extractor(var_range, &indices_dim); extractor(expr); if (!extractor.success_) { return {}; } - Array access_shape = extractor.access_shape_; + ffi::Array access_shape = extractor.access_shape_; PrimExpr product_shape = 1; for (PrimExpr e : access_shape) { product_shape *= e; @@ -119,8 +119,8 @@ class IndexPatternFinder : public ExprVisitor { if (!success_) { return; } - if (Optional range = var_range_.Get(GetRef(op))) { - PrimExpr index = GetRef(op); + if (ffi::Optional range = var_range_.Get(ffi::GetRef(op))) { + PrimExpr index = ffi::GetRef(op); int64_t max = range.value()->extent.as()->value; int64_t extent = max; for (int i = static_cast(operator_stack.size()) - 1; i >= 0; i--) { @@ -190,9 +190,9 @@ class IndexPatternFinder : public ExprVisitor { operator_stack.pop_back(); } - Map var_range_; - Array access_shape_; - Array* resulting_index_; + ffi::Map var_range_; + ffi::Array access_shape_; + ffi::Array* resulting_index_; std::vector operator_stack; bool success_ = true; }; @@ -225,15 +225,16 @@ class BufferLoadReplacer : public StmtExprMutator { * \return a pair. The first is the stmt after transformation. * The second is the SeqStmt that contains 2 stages (one original and another inserted). */ -std::pair InsertCacheStage(Stmt stmt, bool is_write_cache, String storage_scope, - Optional compute_location, - const Array& outer_loops, Buffer* alloc_buffer) { +std::pair InsertCacheStage(Stmt stmt, bool is_write_cache, ffi::String storage_scope, + ffi::Optional compute_location, + const ffi::Array& outer_loops, + Buffer* alloc_buffer) { Stmt body = stmt; std::vector loops; std::vector loops_under_compute_location; std::vector relaxed_thread_loops; bool need_relax = !compute_location.defined(); - Map var_range; + ffi::Map var_range; PrimExpr vector_bytes = -1; // Step 1. 
Perform rank promotion on the buffer access, turning a strided-changing dimension into // several contiguous-changing dimensions @@ -253,7 +254,7 @@ std::pair InsertCacheStage(Stmt stmt, bool is_write_cache, String } body = loop->body; } - Optional predicate; + ffi::Optional predicate; if (const auto* op = body.as()) { // the predicate is generated by coalescing predicate = op->condition; @@ -261,7 +262,7 @@ std::pair InsertCacheStage(Stmt stmt, bool is_write_cache, String } for (const For& loop : outer_loops) { if (loop->kind == ForKind::kThreadBinding) { - const String& thread_tag = loop->thread_binding.value()->thread_tag; + const ffi::String& thread_tag = loop->thread_binding.value()->thread_tag; if (CanRelaxStorageUnderThread(runtime::StorageScope::Create(storage_scope), runtime::ThreadScope::Create(thread_tag))) { var_range.Set(loop->loop_var, Range::FromMinExtent(loop->min, loop->extent)); @@ -296,11 +297,11 @@ std::pair InsertCacheStage(Stmt stmt, bool is_write_cache, String } const BufferStoreNode* buf_store = TVM_TYPE_AS(body, BufferStoreNode); - Array cache_indices; - Array new_shape; + ffi::Array cache_indices; + ffi::Array new_shape; bool use_rank_promotion = false; if (!is_write_cache && buf_store->value.as()) { - Array indices = + ffi::Array indices = is_write_cache ? buf_store->indices : buf_store->value.as()->indices; new_shape = IndexPatternFinder::getRankPromotedShape(indices, var_range, &cache_indices); // write cache disabled for now @@ -309,8 +310,8 @@ std::pair InsertCacheStage(Stmt stmt, bool is_write_cache, String use_rank_promotion = true; } } - Array new_loop_vars; - Map subst_map; + ffi::Array new_loop_vars; + ffi::Map subst_map; if (!use_rank_promotion) { cache_indices.clear(); for (const ForNode* loop : relaxed_thread_loops) { @@ -339,8 +340,8 @@ std::pair InsertCacheStage(Stmt stmt, bool is_write_cache, String cache_indices.push_back(loop->loop_var); } } - Array subst_indices; - Array subst_cache_indices; + ffi::Array subst_indices; + ffi::Array subst_cache_indices; if (is_write_cache) { for (PrimExpr e : buf_store->indices) { subst_indices.push_back(Substitute(e, subst_map)); @@ -366,8 +367,8 @@ std::pair InsertCacheStage(Stmt stmt, bool is_write_cache, String if (is_write_cache) { // copy from wmma to new cache buffer BufferLoad new_buffer_load{new_buffer, cache_indices}; - generate_body = - BufferLoadReplacer(target_buffer_load->buffer, new_buffer_load)(GetRef(buf_store)); + generate_body = BufferLoadReplacer(target_buffer_load->buffer, + new_buffer_load)(ffi::GetRef(buf_store)); generate_body = Substitute(generate_body, subst_map); } else { generate_body = @@ -384,14 +385,14 @@ std::pair InsertCacheStage(Stmt stmt, bool is_write_cache, String for (int i = static_cast(loops_under_compute_location.size()) - 1; i >= 0; i--) { const ForNode* orig_loop = loops_under_compute_location[i]; - ObjectPtr new_loop = make_object(*orig_loop); + ObjectPtr new_loop = ffi::make_object(*orig_loop); new_loop->loop_var = new_loop_vars[i + relaxed_thread_loops.size()]; new_loop->body = generate_body; generate_body = For(new_loop); } for (int i = static_cast(relaxed_thread_loops.size()) - 1; i >= 0; i--) { const ForNode* orig_loop = relaxed_thread_loops[i]; - ObjectPtr new_loop = make_object(*orig_loop); + ObjectPtr new_loop = ffi::make_object(*orig_loop); new_loop->loop_var = new_loop_vars[i]; new_loop->body = generate_body; new_loop->kind = ForKind::kSerial; @@ -402,7 +403,8 @@ std::pair InsertCacheStage(Stmt stmt, bool is_write_cache, String Stmt rewrite_body; if 
(is_write_cache) { BufferLoad new_buffer_load{new_buffer, cache_indices}; - rewrite_body = BufferStore(new_buffer, GetRef(target_buffer_load), cache_indices); + rewrite_body = + BufferStore(new_buffer, ffi::GetRef(target_buffer_load), cache_indices); } else { rewrite_body = BufferStore(buf_store->buffer, BufferLoad(new_buffer, cache_indices), buf_store->indices); @@ -412,7 +414,7 @@ std::pair InsertCacheStage(Stmt stmt, bool is_write_cache, String } for (int i = static_cast(loops_under_compute_location.size()) - 1; i >= 0; i--) { const ForNode* orig_loop = loops_under_compute_location[i]; - ObjectPtr new_loop = make_object(*orig_loop); + ObjectPtr new_loop = ffi::make_object(*orig_loop); new_loop->body = rewrite_body; rewrite_body = For(new_loop); } diff --git a/src/tir/transforms/memhammer_lower_auto_copy.cc b/src/tir/transforms/memhammer_lower_auto_copy.cc index 15dd58d4ca75..2ecf1b804107 100644 --- a/src/tir/transforms/memhammer_lower_auto_copy.cc +++ b/src/tir/transforms/memhammer_lower_auto_copy.cc @@ -90,8 +90,8 @@ class AutoPadder { * \param buffers the given buffers * \return the list of new padded buffers */ - Array PadSharedMemory(const Array& buffers) { - Array result; + ffi::Array PadSharedMemory(const ffi::Array& buffers) { + ffi::Array result; for (const Buffer& buffer : buffers) { runtime::StorageScope scope = runtime::StorageScope::Create(buffer.scope()); @@ -113,7 +113,7 @@ class AutoPadder { low_dim_iter_space[i] = last_dim_iter_space; } PrimExpr stride = 1; - Array reverse_strides; + ffi::Array reverse_strides; int pad_min = padding_min_.Get(buffer).value_or(Integer(1)).IntValue(); // Step 2. For each dimension, select a padding that has minimal bank conflict for (int k = n - 2; k >= 0; k--) { // dims @@ -165,8 +165,8 @@ class AutoPadder { reverse_strides.push_back(stride); } // Step 3. create the new padded buffer - ObjectPtr b = make_object(*buffer.get()); - Array strides; + ObjectPtr b = ffi::make_object(*buffer.get()); + ffi::Array strides; for (int i = static_cast(reverse_strides.size()) - 1; i >= 0; i--) { strides.push_back(reverse_strides[i]); } @@ -190,7 +190,7 @@ class AutoPadder { Stmt RewriteBufferAccess(const Stmt& stmt) { class Rewriter : public StmtExprMutator { public: - explicit Rewriter(const Map& buffer_map) : buffer_map_(buffer_map) {} + explicit Rewriter(const ffi::Map& buffer_map) : buffer_map_(buffer_map) {} private: PrimExpr VisitExpr_(const BufferLoadNode* _op) final { @@ -217,7 +217,7 @@ class AutoPadder { // after mutation. Otherwise we just return the original block. bool changed = false; // Step 1. Mutate the read region. - Array reads; + ffi::Array reads; for (const BufferRegion& read : op->reads) { if (buffer_map_.count(read->buffer)) { changed = true; @@ -227,7 +227,7 @@ class AutoPadder { } } // Step 2. Mutate the write region. - Array writes; + ffi::Array writes; for (const BufferRegion& write : op->writes) { if (buffer_map_.count(write->buffer)) { changed = true; @@ -238,7 +238,7 @@ class AutoPadder { } // Step 4. Mutate `match_buffers`. If an old buffer appears as a source of // MatchBufferRegion, the storage scope of the target buffer also needs to be set. 
- Array match_buffers; + ffi::Array match_buffers; for (const MatchBufferRegion& match_buffer : op->match_buffers) { if (buffer_map_.count(match_buffer->source->buffer)) { changed = true; @@ -262,10 +262,10 @@ class AutoPadder { block->match_buffers = std::move(match_buffers); return Stmt(block); } else { - return GetRef(op); + return ffi::GetRef(op); } } - const Map& buffer_map_; + const ffi::Map& buffer_map_; }; Rewriter rewriter(padded_buffer_map_); return rewriter(stmt); @@ -287,7 +287,7 @@ class AutoPadder { if (!success_) { return; } - int extent = var_range_[GetRef(op)]->extent.as()->value; + int extent = var_range_[ffi::GetRef(op)]->extent.as()->value; if (extent > 1) { stack_.push({{extent, 1}}); } else { @@ -396,7 +396,7 @@ class AutoPadder { } public: - explicit PatternCollector(const Map& var_range) : var_range_(var_range) {} + explicit PatternCollector(const ffi::Map& var_range) : var_range_(var_range) {} /*! * \brief Collect the iteration space for given indices. The iteration space is the possible @@ -409,9 +409,8 @@ class AutoPadder { * \return The iteration space. The first array represents dimensions, and the second array * represents the iteration space of one dimension */ - static std::vector> CollectIterationSpace(const Array& indices, - const Map& var_range, - int data_bits) { + static std::vector> CollectIterationSpace( + const ffi::Array& indices, const ffi::Map& var_range, int data_bits) { PatternCollector collector(var_range); std::vector> ret; for (int i = 0; i < static_cast(indices.size()); i++) { @@ -444,30 +443,30 @@ class AutoPadder { } std::stack> stack_; - const Map& var_range_; + const ffi::Map& var_range_; bool success_ = true; }; /*! A utility class for calling CollectIterationSpace to each buffer access*/ class IterSpaceAnalyzer : public StmtExprVisitor { public: - IterSpaceAnalyzer(const Map& substitute_map, AutoPadder* self, int data_bits, - const Map warp_thread_extent) + IterSpaceAnalyzer(const ffi::Map& substitute_map, AutoPadder* self, + int data_bits, const ffi::Map warp_thread_extent) : substitute_map_(substitute_map), self(self), data_bits_(data_bits), warp_thread_extent_(warp_thread_extent) {} private: - bool CheckVarContiguous(PrimExpr e, Var var, const Map& subst_map) { - PrimExpr e1 = Substitute(e, [var](const Var& v) -> Optional { + bool CheckVarContiguous(PrimExpr e, Var var, const ffi::Map& subst_map) { + PrimExpr e1 = Substitute(e, [var](const Var& v) -> ffi::Optional { if (v.same_as(var)) { return Integer(0); } else { return v; } }); - PrimExpr e2 = Substitute(e, [var](const Var& v) -> Optional { + PrimExpr e2 = Substitute(e, [var](const Var& v) -> ffi::Optional { if (v.same_as(var)) { return Integer(1); } else { @@ -508,7 +507,7 @@ class AutoPadder { void VisitStmt_(const BufferStoreNode* op) final { runtime::StorageScope scope = runtime::StorageScope::Create(op->buffer.scope()); if (scope.rank == runtime::StorageRank::kShared) { - Array substitued_indices; + ffi::Array substitued_indices; arith::Analyzer analyzer; for (const PrimExpr& e : op->indices) { substitued_indices.push_back(analyzer.Simplify(Substitute(e, substitute_map_))); @@ -536,7 +535,7 @@ class AutoPadder { void VisitExpr_(const BufferLoadNode* op) final { runtime::StorageScope scope = runtime::StorageScope::Create(op->buffer.scope()); if (scope.rank == runtime::StorageRank::kShared) { - Array substitued_indices; + ffi::Array substitued_indices; arith::Analyzer analyzer; for (const PrimExpr& e : op->indices) { substitued_indices.push_back(analyzer.Simplify(Substitute(e, 
substitute_map_))); @@ -572,13 +571,13 @@ class AutoPadder { runtime::StorageScope scope = runtime::StorageScope::Create(src_buffer.scope()); if (scope.rank == runtime::StorageRank::kShared) { Region region = r->source->region; - Array indices; + ffi::Array indices; for (int i = 0; i < static_cast(region.size()); i++) { Var var("region" + std::to_string(i)); indices.push_back(region[i]->min + var); var_range_.Set(var, Range::FromMinExtent(0, region[i]->extent)); } - Array substitued_indices; + ffi::Array substitued_indices; arith::Analyzer analyzer; for (const PrimExpr& e : indices) { substitued_indices.push_back(analyzer.Simplify(Substitute(e, substitute_map_))); @@ -595,11 +594,11 @@ class AutoPadder { } } - Map substitute_map_; + ffi::Map substitute_map_; AutoPadder* self; int data_bits_; - Map warp_thread_extent_; - Map var_range_; + ffi::Map warp_thread_extent_; + ffi::Map var_range_; int vector_length_ = -1; Var vector_var; }; @@ -611,11 +610,12 @@ class AutoPadder { * \param data_bits The length of dtype in bits * \param thread_extent The extents of all thread binding loops */ - void AnalyzeSharedMemoryAccess(const Stmt& stmt, const Array& outer_loops, int data_bits, - const Map& thread_extent) { - Map warp_thread_extent; + void AnalyzeSharedMemoryAccess(const Stmt& stmt, const ffi::Array& outer_loops, + int data_bits, + const ffi::Map& thread_extent) { + ffi::Map warp_thread_extent; Integer prod = 1; - Array thread_tags{"threadIdx.x", "threadIdx.y", "threadIdx.z"}; + ffi::Array thread_tags{"threadIdx.x", "threadIdx.y", "threadIdx.z"}; arith::Analyzer analyzer; for (int i = 0; i < 3; i++) { Integer extent = thread_extent.Get(thread_tags[i]).value_or(1); @@ -628,7 +628,7 @@ class AutoPadder { prod *= extent; } } - Map substitute_map; + ffi::Map substitute_map; for (const For& loop : outer_loops) { substitute_map.Set(loop->loop_var, loop->min); } @@ -638,11 +638,11 @@ class AutoPadder { private: /*! \brief A map from the old buffers to the new padded buffers */ - Map padded_buffer_map_; + ffi::Map padded_buffer_map_; /*! \brief A map from each buffer to the iteration spaces of the accesses*/ std::unordered_map>>> iter_spaces_; /*! \brief A map from each buffer to their minimal padding size */ - Map padding_min_; + ffi::Map padding_min_; /*! \brief max padding size in relative to the original shape*/ const double max_pad_factor_ = 0.25; @@ -651,7 +651,8 @@ class AutoPadder { class AutoCopyMutator : public StmtExprMutator { public: - explicit AutoCopyMutator(Map thread_extent) : thread_extent_(thread_extent) {} + explicit AutoCopyMutator(ffi::Map thread_extent) + : thread_extent_(thread_extent) {} /** * \brief Replace old buffers with padded buffers in the stmt * \param stmt The stmt to rewrite @@ -708,16 +709,16 @@ class AutoCopyMutator : public StmtExprMutator { } Stmt VisitStmt_(const ForNode* op) final { - outer_loops_.push_back(GetRef(op)); + outer_loops_.push_back(ffi::GetRef(op)); Stmt stmt = StmtMutator::VisitStmt_(op); outer_loops_.pop_back(); return stmt; } /*! \brief Thread extents collected. */ - Map thread_extent_; + ffi::Map thread_extent_; /*! \brief The outer loops during recursive visit */ - Array outer_loops_; + ffi::Array outer_loops_; /*! 
\brief Calculating optimal padding size */ AutoPadder padder; @@ -736,7 +737,7 @@ class AutoCopyMutator : public StmtExprMutator { */ class ThreadExtentCollector : public StmtVisitor { public: - static Map CollectThreadExtent(const Stmt& stmt) { + static ffi::Map CollectThreadExtent(const Stmt& stmt) { ThreadExtentCollector collector; collector(stmt); return collector.thread_extent_; @@ -744,7 +745,7 @@ class ThreadExtentCollector : public StmtVisitor { private: void VisitStmt_(const BlockNode* op) final { - if (Optional warp_execution = GetAnn(op, "warp_execution")) { + if (ffi::Optional warp_execution = GetAnn(op, "warp_execution")) { if (warp_execution.value()->value != 0) { thread_extent_.Set("threadIdx.x", Integer(32)); } @@ -754,14 +755,14 @@ class ThreadExtentCollector : public StmtVisitor { void VisitStmt_(const ForNode* op) final { if (op->thread_binding.defined() && op->thread_binding.value()->iter_type == kThreadIndex) { if (const auto* extent = op->extent.as()) { - thread_extent_.Set(op->thread_binding.value()->thread_tag, GetRef(extent)); + thread_extent_.Set(op->thread_binding.value()->thread_tag, ffi::GetRef(extent)); } } StmtVisitor::VisitStmt_(op); } /*! \brief the map from thread tag to its extent */ - Map thread_extent_; + ffi::Map thread_extent_; }; namespace transform { diff --git a/src/tir/transforms/memhammer_rewrite_rule.h b/src/tir/transforms/memhammer_rewrite_rule.h index 46c9a97c527d..5751aa119e36 100644 --- a/src/tir/transforms/memhammer_rewrite_rule.h +++ b/src/tir/transforms/memhammer_rewrite_rule.h @@ -37,9 +37,9 @@ namespace tir { /*! \brief The set containing all possible constraints of a data copy */ struct ConstraintSet { /*! \brief The extents of the thread binding loops */ - Map thread_extent; + ffi::Map thread_extent; /*! \brief The outer loops surrounding the data copy */ - Array outer_loops; + ffi::Array outer_loops; /*! \brief The read region of the data copy */ BufferRegion read_region; /*! \brief The write region of the data copy */ @@ -51,12 +51,12 @@ struct ConstraintSet { /*! \brief The vectorization length in bytes */ int vector_bytes = 1; - explicit ConstraintSet(Map thread_extent, // - Array outer_loops, // - BufferRegion read_region, // - BufferRegion write_region, // - int data_bits, // - const Map& ann) + explicit ConstraintSet(ffi::Map thread_extent, // + ffi::Array outer_loops, // + BufferRegion read_region, // + BufferRegion write_region, // + int data_bits, // + const ffi::Map& ann) : thread_extent(thread_extent), outer_loops(outer_loops), read_region(read_region), @@ -74,9 +74,9 @@ struct ConstraintSet { /*! \brief The set containing all possible outputs of a rewrite rule */ struct OutputSet { /*! \brief New buffers allocated after rewrite */ - Array alloc_buffer; + ffi::Array alloc_buffer; /*! \brief The minimal padding size of a buffer in base 2 logarithm */ - Map padding_min; + ffi::Map padding_min; }; /*! @@ -248,9 +248,9 @@ class WmmaToShared : public RewriteRule { * \return a pair. The first is the stmt after transformation. * The second is the SeqStmt that contains 2 stages (one original and another inserted). 
*/ -std::pair InsertCacheStage(Stmt stmt, bool is_write_cache, String storage_scope, - Optional compute_location, - const Array& outer_loops, Buffer* alloc_buffer); +std::pair InsertCacheStage(Stmt stmt, bool is_write_cache, ffi::String storage_scope, + ffi::Optional compute_location, + const ffi::Array& outer_loops, Buffer* alloc_buffer); } // namespace tir } // namespace tvm diff --git a/src/tir/transforms/memhammer_tensorcore_rewrite.cc b/src/tir/transforms/memhammer_tensorcore_rewrite.cc index 5a0d0fa2105c..c1b303e0731b 100644 --- a/src/tir/transforms/memhammer_tensorcore_rewrite.cc +++ b/src/tir/transforms/memhammer_tensorcore_rewrite.cc @@ -28,7 +28,7 @@ namespace tir { * \return A pair. The first is the stmt after transformation. * The second is the compute location where we may add write cache. */ -std::pair> TileWmmaBlock(Stmt stmt) { +std::pair> TileWmmaBlock(Stmt stmt) { Stmt body = stmt; std::vector loops; while (const ForNode* loop = body.as()) { @@ -52,7 +52,7 @@ std::pair> TileWmmaBlock(Stmt stmt) { /*3:*/ loops[n - 1]->loop_var.copy_with_suffix("_1"), }; body = Substitute(std::move(body), - Map{ + ffi::Map{ {loops[n - 2]->loop_var, new_loop_vars[0] * 16 + new_loop_vars[2]}, {loops[n - 1]->loop_var, new_loop_vars[1] * 16 + new_loop_vars[3]}, }); @@ -76,15 +76,16 @@ std::pair> TileWmmaBlock(Stmt stmt) { return {body, compute_location}; } -Array RelaxIndices(const Array& indices, const Array& shape, - const Map& var_dom) { - Array int_set; +ffi::Array RelaxIndices(const ffi::Array& indices, + const ffi::Array& shape, + const ffi::Map& var_dom) { + ffi::Array int_set; int_set.reserve(indices.size()); for (auto& indice : indices) { int_set.push_back(arith::EvalSet(indice, var_dom)); } int ndim = int_set.size(); - Array region; + ffi::Array region; region.reserve(ndim); for (int i = 0; i < ndim; ++i) { region.push_back(int_set[i].CoverRange(Range::FromMinExtent(0, shape[i]))); @@ -110,7 +111,7 @@ Stmt RewriteWmmaLoad(Stmt stmt) { } int n = loops.size(); - Map var_dom{ + ffi::Map var_dom{ {loops[n - 1]->loop_var, IntSet::FromMinExtent(loops[n - 1]->min, loops[n - 1]->extent)}, {loops[n - 2]->loop_var, IntSet::FromMinExtent(loops[n - 2]->min, loops[n - 2]->extent)}, }; @@ -141,8 +142,8 @@ Stmt RewriteWmmaLoad(Stmt stmt) { /*data_alignment=*/64, /*offset_factor=*/16, /*buffer_type=*/kDefault); - Array read_region = RelaxIndices(buf_load->indices, src_buffer->shape, var_dom); - Array write_region = RelaxIndices(buf_store->indices, tgt_buffer->shape, var_dom); + ffi::Array read_region = RelaxIndices(buf_load->indices, src_buffer->shape, var_dom); + ffi::Array write_region = RelaxIndices(buf_store->indices, tgt_buffer->shape, var_dom); Stmt wmma_body = BlockRealize( /*iter_values=*/{}, /*predicate=*/Bool(true), @@ -209,7 +210,7 @@ Stmt RewriteWmmaStore(Stmt stmt) { } int n = loops.size(); - Map var_dom{ + ffi::Map var_dom{ {loops[n - 1]->loop_var, IntSet::FromMinExtent(loops[n - 1]->min, loops[n - 1]->extent)}, {loops[n - 2]->loop_var, IntSet::FromMinExtent(loops[n - 2]->min, loops[n - 2]->extent)}, }; @@ -249,8 +250,8 @@ Stmt RewriteWmmaStore(Stmt stmt) { /*offset_factor=*/16, /*buffer_type=*/kDefault); - Array read_region = RelaxIndices(buf_load->indices, src_buffer->shape, var_dom); - Array write_region = RelaxIndices(buf_store->indices, tgt_buffer->shape, var_dom); + ffi::Array read_region = RelaxIndices(buf_load->indices, src_buffer->shape, var_dom); + ffi::Array write_region = RelaxIndices(buf_store->indices, tgt_buffer->shape, var_dom); Stmt wmma_body = BlockRealize( 
/*iter_values=*/{}, // /*predicate=*/Bool(true), @@ -333,7 +334,7 @@ class WmmaToGlobalRewriter : public StmtExprMutator { Stmt WmmaToGlobal::Rewrite(const Stmt& stmt, const ConstraintSet& constraints, OutputSet* output) const { Stmt body{nullptr}; - Optional compute_location{nullptr}; + ffi::Optional compute_location{nullptr}; std::tie(body, compute_location) = TileWmmaBlock(stmt); SeqStmt seq{nullptr}; Buffer cache_buffer; @@ -347,7 +348,7 @@ Stmt WmmaToGlobal::Rewrite(const Stmt& stmt, const ConstraintSet& constraints, return rewriter(body); } -std::pair> TileMmaToGlobalBlock(Stmt stmt) { +std::pair> TileMmaToGlobalBlock(Stmt stmt) { // i, j = sch.get_loops(block)[2:] // i_0, i_1 = sch.split(i, factors=[None, 8]) // j_0, j_1 = sch.split(j, factors=[None, 8]) @@ -376,7 +377,7 @@ std::pair> TileMmaToGlobalBlock(Stmt stmt) { /*3:*/ loops[n - 1]->loop_var.copy_with_suffix("_1"), }; body = Substitute(std::move(body), - Map{ + ffi::Map{ {loops[n - 2]->loop_var, new_loop_vars[0] * 8 + new_loop_vars[2]}, {loops[n - 1]->loop_var, new_loop_vars[1] * 8 + new_loop_vars[3]}, }); @@ -418,7 +419,7 @@ Stmt RewriteMmaStore(Stmt stmt) { } int n = loops.size(); - Map var_dom{ + ffi::Map var_dom{ {loops[n - 1]->loop_var, IntSet::FromMinExtent(loops[n - 1]->min, loops[n - 1]->extent)}, {loops[n - 2]->loop_var, IntSet::FromMinExtent(loops[n - 2]->min, loops[n - 2]->extent)}, }; @@ -468,8 +469,8 @@ Stmt RewriteMmaStore(Stmt stmt) { /*buffer_type=*/kDefault); // Step 3.2. Generate new r/w region - Array read_region = RelaxIndices(buf_load->indices, src_buffer->shape, var_dom); - Array write_region = RelaxIndices(buf_store->indices, tgt_buffer->shape, var_dom); + ffi::Array read_region = RelaxIndices(buf_load->indices, src_buffer->shape, var_dom); + ffi::Array write_region = RelaxIndices(buf_store->indices, tgt_buffer->shape, var_dom); // Step 3.3. 
Generate new inner loop body // for v in T.vectorized(2): @@ -542,7 +543,7 @@ class MmaToGlobalRewriter : public StmtExprMutator { Stmt MmaToGlobal::Rewrite(const Stmt& stmt, const ConstraintSet& constraints, OutputSet* output) const { Stmt body{nullptr}; - Optional compute_location{nullptr}; + ffi::Optional compute_location{nullptr}; std::tie(body, compute_location) = TileMmaToGlobalBlock(stmt); SeqStmt seq{nullptr}; Buffer cache_buffer; diff --git a/src/tir/transforms/merge_shared_memory_allocations.cc b/src/tir/transforms/merge_shared_memory_allocations.cc index 63342bd2ec8d..e477df27ce80 100644 --- a/src/tir/transforms/merge_shared_memory_allocations.cc +++ b/src/tir/transforms/merge_shared_memory_allocations.cc @@ -125,7 +125,7 @@ class SharedMemLinearAccessPatternFinder final : public StmtExprVisitor { auto it = alloc_info_.find(buf); if (it != alloc_info_.end() && it->second.alloc) { ICHECK_LT(it->second.level, scope_.size()); - if (IsAppropriateSharedMemory(GetRef(buf))) { + if (IsAppropriateSharedMemory(ffi::GetRef(buf))) { scope_[it->second.level].touched.push_back(buf); } } @@ -156,7 +156,7 @@ class SharedMemLinearAccessPatternFinder final : public StmtExprVisitor { auto it = alloc_info_.find(buf); if (it != alloc_info_.end() && it->second.alloc) { ICHECK_LT(it->second.level, scope_.size()) << "Load memory in places other than store."; - if (IsAppropriateSharedMemory(GetRef(buf))) { + if (IsAppropriateSharedMemory(ffi::GetRef(buf))) { scope_[it->second.level].touched.push_back(buf); } } @@ -178,7 +178,7 @@ class SharedMemLinearAccessPatternFinder final : public StmtExprVisitor { auto it = alloc_info_.find(buf); if (it != alloc_info_.end() && it->second.alloc) { ICHECK_LT(it->second.level, scope_.size()); - if (IsAppropriateSharedMemory(GetRef(buf))) { + if (IsAppropriateSharedMemory(ffi::GetRef(buf))) { scope_[it->second.level].touched.push_back(buf); } } @@ -352,8 +352,8 @@ class SharedMemoryRewriter : public StmtExprMutator { << "MergeSharedMemoryAllocations expects flat memory buffers, " << "and is to be run after " << "FlattenBuffer"; - Array indices = {node->indices[0] + - this->GetBufferOffset(node->buffer->data, node->buffer->dtype)}; + ffi::Array indices = { + node->indices[0] + this->GetBufferOffset(node->buffer->data, node->buffer->dtype)}; auto writer = node.CopyOnWrite(); writer->buffer = GetUpdatedBuffer(node->buffer); diff --git a/src/tir/transforms/narrow_datatype.cc b/src/tir/transforms/narrow_datatype.cc index b09a4dc17b26..0a95018f139c 100644 --- a/src/tir/transforms/narrow_datatype.cc +++ b/src/tir/transforms/narrow_datatype.cc @@ -212,7 +212,7 @@ class NarrowDataTypeRewriter : public IndexDataTypeRewriter { Stmt operator()(Stmt s) { visitor_(s); for (auto i = visitor_.vmap.begin(), last = visitor_.vmap.end(); i != last;) { - PrimExpr e = GetRef(i->first); + PrimExpr e = ffi::GetRef(i->first); if (e.dtype() == i->second) { i = visitor_.vmap.erase(i); } else { @@ -268,7 +268,7 @@ class NarrowDataTypeRewriter : public IndexDataTypeRewriter { PrimExpr a = this->VisitExpr(op->a); \ PrimExpr b = this->VisitExpr(op->b); \ if (op->a.same_as(a) && op->b.same_as(b) && a.dtype() == b.dtype()) { \ - return GetRef(op); \ + return ffi::GetRef(op); \ } else { \ if (a.dtype() != b.dtype()) { \ bool is_enabled = is_enabled_; \ diff --git a/src/tir/transforms/plan_update_buffer_allocation_location.cc b/src/tir/transforms/plan_update_buffer_allocation_location.cc index bcd5f53dd4f4..2a8c3d520c60 100644 --- a/src/tir/transforms/plan_update_buffer_allocation_location.cc +++ 
b/src/tir/transforms/plan_update_buffer_allocation_location.cc @@ -53,7 +53,7 @@ class CollectManagedAllocations : public StmtExprVisitor { /*! \brief Collect the allocate buffer order. */ class BufferAllocateOrderCollector : public StmtExprVisitor { public: - static Array Collect(const PrimFunc& func) { + static ffi::Array Collect(const PrimFunc& func) { BufferAllocateOrderCollector collector; for (const auto& kv : func->buffer_map) { collector.buffer_alloc_recorder_.push_back(kv.second); @@ -98,16 +98,16 @@ class BufferAllocateOrderCollector : public StmtExprVisitor { } /*! \brief The buffer allocated order recorder. */ - Array buffer_alloc_recorder_; + ffi::Array buffer_alloc_recorder_; }; class BufferAllocationLocator : public StmtExprMutator { public: explicit BufferAllocationLocator(const PrimFunc& func) { - Map> buffer_lca = DetectBufferAccessLCA(func); + ffi::Map> buffer_lca = DetectBufferAccessLCA(func); // The buffer_alloc_recorder Array is used to keep the buffer allocation order // since the buffer_lca Map is unordered. - Array buffer_alloc_recorder = BufferAllocateOrderCollector::Collect(func); + ffi::Array buffer_alloc_recorder = BufferAllocateOrderCollector::Collect(func); std::unordered_set arg_buffer_vars; CollectManagedAllocations collector; collector(func->body); @@ -145,7 +145,7 @@ class BufferAllocationLocator : public StmtExprMutator { } auto node = Downcast(StmtMutator::VisitStmt_(op)); - Array new_block_alloc_bufs; + ffi::Array new_block_alloc_bufs; for (const Buffer& buf : it->second) { if (managed_allocations_.count(buf->data.get())) { buffer_data_to_buffer_.erase(buf->data); @@ -162,7 +162,7 @@ class BufferAllocationLocator : public StmtExprMutator { Stmt VisitStmt_(const BlockNode* op) final { ICHECK(!op->init.defined()); - Array alloc_buffers; + ffi::Array alloc_buffers; auto it = alloc_buffers_.find(op); if (it != alloc_buffers_.end()) { alloc_buffers = it->second; @@ -206,7 +206,7 @@ class BufferAllocationLocator : public StmtExprMutator { throw; } - Stmt InjectOpaqueBlock(Stmt body, const Array& alloc_buffers) { + Stmt InjectOpaqueBlock(Stmt body, const ffi::Array& alloc_buffers) { ICHECK(!alloc_buffers.empty()); Block opaque_block(/*iter_vars=*/{}, /*reads=*/{}, @@ -216,7 +216,7 @@ class BufferAllocationLocator : public StmtExprMutator { /*init=*/std::nullopt, /*alloc_buffers=*/alloc_buffers); ObjectPtr n = CopyOnWrite(opaque_block.get()); - Array> access = + ffi::Array> access = GetBlockReadWriteRegion(opaque_block, buffer_data_to_buffer_); n->reads = access[0]; n->writes = access[1]; @@ -224,8 +224,9 @@ class BufferAllocationLocator : public StmtExprMutator { return realize; } - Array RemoveRedundantBufferRegion(const Array& region) const { - Array result; + ffi::Array RemoveRedundantBufferRegion( + const ffi::Array& region) const { + ffi::Array result; for (const BufferRegion& buffer_region : region) { if (buffer_data_to_buffer_.count(buffer_region->buffer->data)) { result.push_back(buffer_region); @@ -235,9 +236,9 @@ class BufferAllocationLocator : public StmtExprMutator { } /*! \brief The map from stmt to the buffers to be allocated under it. */ - std::unordered_map> alloc_buffers_; + std::unordered_map> alloc_buffers_; /*! \brief The buffer already allocated during recursive visiting. */ - Map buffer_data_to_buffer_; + ffi::Map buffer_data_to_buffer_; /*! \brief Buffers that are allocated within a BlockNode, and may be moved. 
*/ std::unordered_set managed_allocations_; }; diff --git a/src/tir/transforms/primfunc_utils.cc b/src/tir/transforms/primfunc_utils.cc index b1f3476eab73..f3c72c9e0808 100644 --- a/src/tir/transforms/primfunc_utils.cc +++ b/src/tir/transforms/primfunc_utils.cc @@ -36,7 +36,7 @@ transform::Pass AnnotateEntryFunc() { auto [gvar, base_func] = *mod->functions.begin(); if (!base_func->HasNonzeroAttr(tir::attr::kIsEntryFunc)) { if (auto ptr = base_func.as()) { - mod->Update(gvar, WithAttr(GetRef(ptr), tir::attr::kIsEntryFunc, true)); + mod->Update(gvar, WithAttr(ffi::GetRef(ptr), tir::attr::kIsEntryFunc, true)); } } return mod; @@ -47,11 +47,11 @@ transform::Pass AnnotateEntryFunc() { bool has_external_non_primfuncs = false; IRModule with_annotations; for (const auto& [gvar, base_func] : mod->functions) { - bool is_external = base_func->GetAttr(tvm::attr::kGlobalSymbol).has_value(); + bool is_external = base_func->GetAttr(tvm::attr::kGlobalSymbol).has_value(); if (is_external) { if (auto ptr = base_func.as()) { - with_annotations->Add(gvar, - WithAttr(GetRef(ptr), tir::attr::kIsEntryFunc, true)); + with_annotations->Add( + gvar, WithAttr(ffi::GetRef(ptr), tir::attr::kIsEntryFunc, true)); } else { has_external_non_primfuncs = true; } diff --git a/src/tir/transforms/remap_thread_axis.cc b/src/tir/transforms/remap_thread_axis.cc index 14ad70122798..46fb38b48ba0 100644 --- a/src/tir/transforms/remap_thread_axis.cc +++ b/src/tir/transforms/remap_thread_axis.cc @@ -70,13 +70,13 @@ class ThreadAxisRewriter : private StmtExprMutator { std::unordered_map vmap_; }; -PrimFunc RemapThreadAxis(PrimFunc func, Map thread_map) { +PrimFunc RemapThreadAxis(PrimFunc func, ffi::Map thread_map) { std::unordered_map tmap; for (const auto& kv : thread_map) { tmap[kv.first] = kv.second; } - if (auto opt = func->GetAttr>(tir::attr::kKernelLaunchParams)) { + if (auto opt = func->GetAttr>(tir::attr::kKernelLaunchParams)) { ICHECK(opt != nullptr) << "Require attribute " << tir::attr::kKernelLaunchParams; auto launch_params = opt.value(); // replace the thread axis attribute @@ -97,7 +97,7 @@ PrimFunc RemapThreadAxis(PrimFunc func, Map thread_map) { namespace transform { -Pass RemapThreadAxis(Map thread_map) { +Pass RemapThreadAxis(ffi::Map thread_map) { auto pass_func = [thread_map](PrimFunc f, IRModule m, PassContext ctx) { return RemapThreadAxis(std::move(f), thread_map); }; diff --git a/src/tir/transforms/remove_no_op.cc b/src/tir/transforms/remove_no_op.cc index 9db6f9f32808..c9c738128638 100644 --- a/src/tir/transforms/remove_no_op.cc +++ b/src/tir/transforms/remove_no_op.cc @@ -181,20 +181,20 @@ class NoOpRemover : public arith::IRMutatorWithAnalyzer { Stmt VisitStmt_(const EvaluateNode* op) final { if (HasSideEffect(op->value)) { - return GetRef(op); + return ffi::GetRef(op); } else { return Evaluate(0); } } Stmt VisitStmt_(const BufferStoreNode* op) final { - BufferStore store = GetRef(op); + BufferStore store = ffi::GetRef(op); // Helper function that returns a statement containing only the // side effects of evaluating this BufferStore, but not the store // itself. auto only_side_effects = [&]() { - Array statements; + ffi::Array statements; statements.push_back(MakeEvaluate(store->value)); for (const auto& index : store->indices) { statements.push_back(MakeEvaluate(index)); @@ -204,7 +204,7 @@ class NoOpRemover : public arith::IRMutatorWithAnalyzer { if (touch_pattern_.has_value()) { // A write that is later overwritten is a no-op. - Stmt context = context_ ? 
GetRef(context_) : store; + Stmt context = context_ ? ffi::GetRef(context_) : store; if (touch_pattern_->IsOverwrittenWithoutEffect(store, context)) { touch_pattern_->RemoveStore(store); return only_side_effects(); @@ -217,7 +217,7 @@ class NoOpRemover : public arith::IRMutatorWithAnalyzer { PrimExpr stores_existing_value = store->value - BufferLoad(store->buffer, store->indices, store->predicate) == 0; if (touch_pattern_.has_value()) { - Stmt context_arg = context_ ? GetRef(context_) : Stmt(store); + Stmt context_arg = context_ ? ffi::GetRef(context_) : Stmt(store); stores_existing_value = touch_pattern_->SimplifyInContext(stores_existing_value, context_arg, analyzer_); } else { @@ -257,7 +257,7 @@ class NoOpRemover : public arith::IRMutatorWithAnalyzer { } private: - bool ArrayValueEqual(const Array& a, const Array& b) { + bool ArrayValueEqual(const ffi::Array& a, const ffi::Array& b) { if (a.size() != b.size()) { return false; } @@ -280,8 +280,8 @@ class NoOpRemover : public arith::IRMutatorWithAnalyzer { return Evaluate(0); } } - Stmt MakeEvaluate(const Array& values) { - Array stmts; + Stmt MakeEvaluate(const ffi::Array& values) { + ffi::Array stmts; for (PrimExpr e : values) { if (SideEffect(e) > CallEffectKind::kReadState) { stmts.push_back(Evaluate(e)); diff --git a/src/tir/transforms/remove_weight_layout_rewrite_block.cc b/src/tir/transforms/remove_weight_layout_rewrite_block.cc index 13dac2789b43..561d46164b5a 100644 --- a/src/tir/transforms/remove_weight_layout_rewrite_block.cc +++ b/src/tir/transforms/remove_weight_layout_rewrite_block.cc @@ -35,8 +35,9 @@ namespace tir { class RemoveLayoutRewriteBlock : public StmtMutator { public: - static std::tuple, std::unordered_map, - std::unordered_map>> + static std::tuple, + std::unordered_map, + std::unordered_map>> Rewrite(PrimFunc f) { RemoveLayoutRewriteBlock rewriter; @@ -54,7 +55,7 @@ class RemoveLayoutRewriteBlock : public StmtMutator { if (it == block->annotations.end() || !is_one(Downcast((*it).second))) { // The block is not a weight layout block // Remove allocates if needed - Array alloc_buffers; + ffi::Array alloc_buffers; for (const Buffer& buffer : block->alloc_buffers) { if (!rewritten_buffers_.count(buffer)) { alloc_buffers.push_back(buffer); @@ -91,7 +92,7 @@ class RemoveLayoutRewriteBlock : public StmtMutator { n->reads = {}; n->writes = {}; - Array load_indices; + ffi::Array load_indices; for (auto ind : load->indices) { ICHECK(ind->IsInstance()); load_indices.push_back(Downcast(ind)); @@ -105,14 +106,14 @@ class RemoveLayoutRewriteBlock : public StmtMutator { private: /*! \brief The buffer map from original layout buffer to rewritten buffer */ - Map buf_map_; + ffi::Map buf_map_; /*! \brief The buffer map from original layout buffer to rewritten buffer */ std::unordered_set rewritten_buffers_; /*! \brief Maps a buffer load to an index map associated with the load / store in a layout rewrite block. */ std::unordered_map buffer_var_to_index_map_; /*! \brief Maps a buffer load to the shape of the corresponding rewritten buffer. 
*/ - std::unordered_map> buffer_var_to_rewritten_shape_; + std::unordered_map> buffer_var_to_rewritten_shape_; }; // After RemoveLayoutRewriteBlock, the body of a compute update block references a @@ -149,7 +150,7 @@ class AllocateConstRewrite : public StmtExprMutator { AllocateConstRewrite( const BufferVarMap& buffer_var_map, const std::unordered_map& buffer_var_to_index_map, - const std::unordered_map>& buffer_var_to_rewritten_shape, + const std::unordered_map>& buffer_var_to_rewritten_shape, bool skip_tensor_rewrite) : buffer_var_map_(buffer_var_map), buffer_var_to_index_map_(buffer_var_to_index_map), @@ -160,7 +161,7 @@ class AllocateConstRewrite : public StmtExprMutator { Stmt VisitStmt_(const BlockNode* op) final { Block block = Downcast(StmtMutator::VisitStmt_(op)); auto n = CopyOnWrite(block.get()); - Array new_reads; + ffi::Array new_reads; for (auto read_region : op->reads) { if (auto it = new_load_buf_.find(read_region->buffer->data.get()); it != new_load_buf_.end()) { @@ -180,7 +181,7 @@ class AllocateConstRewrite : public StmtExprMutator { auto new_body = StmtMutator::VisitStmt(alloc->body); auto rewritten_tensor = RewriteTensor( alloc->data.value(), it->second, buffer_var_to_rewritten_shape_[alloc->buffer_var.get()]); - Array rewritten_extents; + ffi::Array rewritten_extents; for (auto s : rewritten_tensor.Shape()) { rewritten_extents.push_back(PrimExpr(static_cast(s))); } @@ -193,9 +194,9 @@ class AllocateConstRewrite : public StmtExprMutator { PrimExpr VisitExpr_(const BufferLoadNode* op) final { if (auto it = buffer_var_map_.find(op->buffer->data.get()); it != buffer_var_map_.end()) { auto new_buffer = - Buffer(GetRef(it->second), op->buffer->dtype, op->buffer->shape, op->buffer->strides, - op->buffer->elem_offset, it->second->name_hint, op->buffer->data_alignment, - op->buffer->offset_factor, op->buffer->buffer_type); + Buffer(ffi::GetRef(it->second), op->buffer->dtype, op->buffer->shape, + op->buffer->strides, op->buffer->elem_offset, it->second->name_hint, + op->buffer->data_alignment, op->buffer->offset_factor, op->buffer->buffer_type); new_load_buf_[op->buffer->data.get()] = new_buffer; return BufferLoad(new_buffer, op->indices, op->predicate); } @@ -203,7 +204,7 @@ class AllocateConstRewrite : public StmtExprMutator { } runtime::Tensor RewriteTensor(runtime::Tensor src, const IndexMap& index_map, - const Array& dst_shape) { + const ffi::Array& dst_shape) { if (skip_tensor_rewrite_) { // Only the shape of the destination array needs to be correct. std::vector dst_shape_int; @@ -223,7 +224,7 @@ class AllocateConstRewrite : public StmtExprMutator { in a layout rewrite block. */ std::unordered_map buffer_var_to_index_map_; /*! \brief Maps a buffer load to the shape of the corresponding rewritten buffer. */ - std::unordered_map> buffer_var_to_rewritten_shape_; + std::unordered_map> buffer_var_to_rewritten_shape_; /*! \brief Maps load buffer variables to newly created buffers */ std::unordered_map new_load_buf_; /*! 
\brief Whether or not to skip rewriting of Tensor contents */ @@ -263,7 +264,7 @@ class WeightLayoutRewriteBlockRemover : public StmtMutator { buffer_var_to_rewritten_shape, skip_tensor_rewrite); n->body = rewriter(std::move(n->body)); - Map buffer_map; + ffi::Map buffer_map; for (const auto& [param, buffer] : f_->buffer_map) { auto it = buf_map.find(buffer); if (it != buf_map.end()) { diff --git a/src/tir/transforms/renew_defs.cc b/src/tir/transforms/renew_defs.cc index 167453c04fe0..47bbc73dfed6 100644 --- a/src/tir/transforms/renew_defs.cc +++ b/src/tir/transforms/renew_defs.cc @@ -37,7 +37,7 @@ namespace tir { Stmt stmt = StmtExprMutator::VisitStmt_(op); \ op = stmt.as(); \ ICHECK(op != nullptr); \ - auto n = make_object(*op); \ + auto n = ffi::make_object(*op); \ n->FIELD = std::move(new_var); \ return Stmt(n); \ } @@ -47,7 +47,7 @@ class RenewDefMutator : public StmtExprMutator { static PrimFunc Transform(const PrimFunc& func) { RenewDefMutator generator; // Redefine params - Array params; + ffi::Array params; for (const auto& param : func->params) { params.push_back(generator.ReDefineVar(param)); } @@ -56,8 +56,8 @@ class RenewDefMutator : public StmtExprMutator { const Buffer& buffer = func->buffer_map.at(param); for (const PrimExpr& e : buffer->shape) { if (const auto* v = e.as()) { - if (generator.remap_.count(GetRef(v)) == 0) { - generator.ReDefineVar(GetRef(v)); + if (generator.remap_.count(ffi::GetRef(v)) == 0) { + generator.ReDefineVar(ffi::GetRef(v)); } } } @@ -65,7 +65,7 @@ class RenewDefMutator : public StmtExprMutator { } // Redefine buffers in order // TODO(Siyuan Feng): checking var is used after define - Map buffer_map; + ffi::Map buffer_map; for (const auto& param : func->params) { if (param->dtype.is_handle()) { const Buffer& buffer = func->buffer_map.at(param); @@ -105,32 +105,32 @@ class RenewDefMutator : public StmtExprMutator { Stmt VisitStmt_(const BlockNode* op) final { // Step 0. Re-define Itervars - Array iter_vars = + ffi::Array iter_vars = op->iter_vars.Map(std::bind(&RenewDefMutator::VisitIterVar, this, std::placeholders::_1)); // Step 1. Re-define buffers allocate under the block - Array alloc_buffers = op->alloc_buffers.Map( + ffi::Array alloc_buffers = op->alloc_buffers.Map( std::bind(&RenewDefMutator::VisitBuffer, this, std::placeholders::_1, /*define=*/true)); // Step 2. Re-define match_buffers - Array match_buffers = op->match_buffers.Map( + ffi::Array match_buffers = op->match_buffers.Map( std::bind(&RenewDefMutator::VisitMatchBuffer, this, std::placeholders::_1)); // Step 3. Visit body - Optional init = std::nullopt; + ffi::Optional init = std::nullopt; if (op->init.defined()) { init = this->VisitStmt(op->init.value()); } Stmt body = this->VisitStmt(op->body); // Step 4. Revisit access region - Array reads = + ffi::Array reads = op->reads.Map(std::bind(&RenewDefMutator::VisitBufferRegion, this, std::placeholders::_1)); - Array writes = + ffi::Array writes = op->writes.Map(std::bind(&RenewDefMutator::VisitBufferRegion, this, std::placeholders::_1)); // Step 5. Regenerate block. 
Since the defs are changed, we need to create a new block - auto n = make_object(*op); + auto n = ffi::make_object(*op); n->iter_vars = std::move(iter_vars); n->alloc_buffers = std::move(alloc_buffers); n->match_buffers = std::move(match_buffers); @@ -150,7 +150,7 @@ class RenewDefMutator : public StmtExprMutator { if (buffer.same_as(op->buffer)) { return stmt; } else { - auto n = make_object(*op); + auto n = ffi::make_object(*op); n->buffer = std::move(buffer); return BufferStore(n); } @@ -164,7 +164,7 @@ class RenewDefMutator : public StmtExprMutator { if (buffer.same_as(op->buffer)) { return expr; } else { - auto n = make_object(*op); + auto n = ffi::make_object(*op); n->buffer = std::move(buffer); return BufferLoad(n); } @@ -172,7 +172,7 @@ class RenewDefMutator : public StmtExprMutator { private: Var ReDefineVar(const Var& var) { - Var new_var = Var(make_object(*var.get())); + Var new_var = Var(ffi::make_object(*var.get())); this->AddDefRemap(var, new_var); return new_var; } @@ -204,13 +204,13 @@ class RenewDefMutator : public StmtExprMutator { // update data Var data = Downcast(redefine_if_is_var(buffer->data)); // update shape - Array shape = buffer->shape.Map(redefine_if_is_var); + ffi::Array shape = buffer->shape.Map(redefine_if_is_var); // update strides - Array strides = buffer->strides.Map(redefine_if_is_var); + ffi::Array strides = buffer->strides.Map(redefine_if_is_var); // update elem_offset PrimExpr elem_offset = redefine_if_is_var(buffer->elem_offset); - auto n = make_object(*buffer.get()); + auto n = ffi::make_object(*buffer.get()); n->data = std::move(data); n->shape = std::move(shape); n->strides = std::move(strides); @@ -243,13 +243,13 @@ class RenewDefMutator : public StmtExprMutator { return Downcast((*it).second); } Var data = Downcast(VisitExpr(buffer->data)); - Array shape = + ffi::Array shape = buffer->shape.Map(std::bind(&RenewDefMutator::VisitExpr, this, std::placeholders::_1)); - Array strides = + ffi::Array strides = buffer->strides.Map(std::bind(&RenewDefMutator::VisitExpr, this, std::placeholders::_1)); PrimExpr elem_offset = VisitExpr(buffer->elem_offset); - auto n = make_object(*buffer.get()); + auto n = ffi::make_object(*buffer.get()); n->data = std::move(data); n->shape = std::move(shape); n->strides = std::move(strides); @@ -277,7 +277,7 @@ class RenewDefMutator : public StmtExprMutator { BufferRegion VisitBufferRegion(const BufferRegion& buffer_region) { Buffer buffer = VisitBuffer(buffer_region->buffer); - Array region = buffer_region->region.Map( + ffi::Array region = buffer_region->region.Map( std::bind(&RenewDefMutator::VisitRange, this, std::placeholders::_1)); if (buffer.same_as(buffer_region->buffer) && region.same_as(buffer_region->region)) { return buffer_region; @@ -286,7 +286,7 @@ class RenewDefMutator : public StmtExprMutator { } } - Map remap_; + ffi::Map remap_; }; PrimFunc RenewDefs(const PrimFunc& func) { return RenewDefMutator::Transform(func); } diff --git a/src/tir/transforms/replace_global_vars.cc b/src/tir/transforms/replace_global_vars.cc index 3e8437063775..b16926056b7d 100644 --- a/src/tir/transforms/replace_global_vars.cc +++ b/src/tir/transforms/replace_global_vars.cc @@ -35,8 +35,8 @@ namespace { using tvm::transform::GlobalVarReplacer; struct Mutator : StmtExprMutator { - Map replacements; - explicit Mutator(Map replacements) : replacements(replacements) {} + ffi::Map replacements; + explicit Mutator(ffi::Map replacements) : replacements(replacements) {} PrimExpr VisitExpr_(const CallNode* node) override { auto call = 
Downcast(StmtExprMutator::VisitExpr_(node)); @@ -53,7 +53,7 @@ struct Mutator : StmtExprMutator { TVM_STATIC_IR_FUNCTOR(GlobalVarReplacer, vtable) .set_dispatch([](const ObjectRef& obj, - Map replacements) -> BaseFunc { + ffi::Map replacements) -> BaseFunc { Mutator mutator(replacements); auto func = Downcast(obj); auto new_body = mutator(func->body); @@ -65,7 +65,7 @@ TVM_STATIC_IR_FUNCTOR(GlobalVarReplacer, vtable) // If the function is externally exposed, and is being replaced // by a GlobalVar with a new name, then the function's // kGlobalSymbol must be updated to match. - if (auto opt = func->GetAttr(tvm::attr::kGlobalSymbol)) { + if (auto opt = func->GetAttr(tvm::attr::kGlobalSymbol)) { auto name = opt.value(); for (const auto& [before, after] : replacements) { if (before->name_hint == name) { diff --git a/src/tir/transforms/simplify.cc b/src/tir/transforms/simplify.cc index 2b087c924f58..f1b79f8122c0 100644 --- a/src/tir/transforms/simplify.cc +++ b/src/tir/transforms/simplify.cc @@ -115,7 +115,7 @@ std::unordered_set CollectVarsUsedInBufferDefinition(const Stmt& void VisitBuffer(const Buffer& buf) { // Collect variables that should remain defined - VarUseDefAnalyzer usage(Array{}); + VarUseDefAnalyzer usage(ffi::Array{}); usage(buf->data); for (const auto& dim : buf->shape) { usage(dim); @@ -150,7 +150,7 @@ TVM_REGISTER_PASS_CONFIG_OPTION("tir.Simplify", SimplifyConfig); class StmtSimplifier : public IRMutatorWithAnalyzer { public: static PrimFunc Apply(PrimFunc func, Analyzer* analyzer, - Optional config_opt = std::nullopt) { + ffi::Optional config_opt = std::nullopt) { auto config = config_opt.value_or(AttrsWithDefaultValues()); analyzer->rewrite_simplify.SetEnabledExtensions(config->GetEnabledExtensions()); @@ -194,7 +194,7 @@ class StmtSimplifier : public IRMutatorWithAnalyzer { Stmt Simplify(Stmt stmt) { return operator()(std::move(stmt)); } Stmt VisitStmt(const Stmt& stmt) override { - Optional cache = this->current_stmt_; + ffi::Optional cache = this->current_stmt_; this->current_stmt_ = stmt; Stmt output = Parent::VisitStmt(stmt); this->current_stmt_ = std::move(cache); @@ -249,7 +249,7 @@ class StmtSimplifier : public IRMutatorWithAnalyzer { if (can_inline && !used_in_buffer_def) { return body; } else if (value.same_as(op->value) && body.same_as(op->body)) { - return GetRef(op); + return ffi::GetRef(op); } else { auto n = this->CopyOnWrite(op); n->value = std::move(value); @@ -259,7 +259,7 @@ class StmtSimplifier : public IRMutatorWithAnalyzer { } Stmt VisitStmt_(const IfThenElseNode* op) override { - if (Optional cond = ProveCondition(op->condition)) { + if (ffi::Optional cond = ProveCondition(op->condition)) { if (cond.value()->value) { return this->VisitStmt(op->then_case); } else if (op->else_case) { @@ -274,7 +274,7 @@ class StmtSimplifier : public IRMutatorWithAnalyzer { PrimExpr VisitExpr_(const CallNode* op) override { if (op->op.same_as(builtin::if_then_else())) { - if (Optional cond = ProveCondition(op->args[0])) { + if (ffi::Optional cond = ProveCondition(op->args[0])) { if (cond.value()->value) { return this->VisitExpr(op->args[1]); } else { @@ -303,7 +303,7 @@ class StmtSimplifier : public IRMutatorWithAnalyzer { } private: - bool ArrayDeepEqual(const Array& lhs, const Array& rhs) { + bool ArrayDeepEqual(const ffi::Array& lhs, const ffi::Array& rhs) { if (lhs.size() != rhs.size()) { return false; } @@ -320,7 +320,7 @@ class StmtSimplifier : public IRMutatorWithAnalyzer { * Uses more aggressive optimization, such as performing additional * inlining and 
tracking known buffer values. */ - Optional ProveCondition(PrimExpr condition) const { + ffi::Optional ProveCondition(PrimExpr condition) const { condition = Substitute(condition, non_inlined_bindings_); if (config_->propagate_knowns_to_prove_conditional) { ICHECK(touch_pattern_.has_value()); @@ -338,8 +338,8 @@ class StmtSimplifier : public IRMutatorWithAnalyzer { SimplifyConfig config_; std::optional touch_pattern_; - Map non_inlined_bindings_; - Optional current_stmt_{std::nullopt}; + ffi::Map non_inlined_bindings_; + ffi::Optional current_stmt_{std::nullopt}; std::unordered_set used_in_buffer_def_; }; diff --git a/src/tir/transforms/split_host_device.cc b/src/tir/transforms/split_host_device.cc index 796514e02762..feeea7b3fcfe 100644 --- a/src/tir/transforms/split_host_device.cc +++ b/src/tir/transforms/split_host_device.cc @@ -53,7 +53,7 @@ class HostDeviceSplitter : public StmtMutator { private: Stmt SplitDeviceFunc(Stmt body, Target device_target) { - auto [params, buffers_to_declare] = [&]() -> std::tuple, Array> { + auto [params, buffers_to_declare] = [&]() -> std::tuple, ffi::Array> { VarUseDefAnalyzer use_def(/*defined_vars=*/{}, /*visit_thread_extent=*/true); use_def(body); @@ -98,7 +98,7 @@ class HostDeviceSplitter : public StmtMutator { GlobalVar kernel_symbol_global = var_supply_(); (*device_mod_)->Add(kernel_symbol_global, device_func); - Array args = params.Map([](const Var& var) -> PrimExpr { return var; }); + ffi::Array args = params.Map([](const Var& var) -> PrimExpr { return var; }); if (can_propagate_errors) { Var kernel_error_code("kernel_error_code", success->dtype); @@ -137,14 +137,14 @@ Pass SplitHostDevice() { auto pass_func = [](IRModule mod, PassContext ctx) { GlobalVarSupply global_var_supply(mod); - IRModule device_mod = IRModule(Map({})); - IRModule updates = IRModule(Map({})); + IRModule device_mod = IRModule(ffi::Map({})); + IRModule updates = IRModule(ffi::Map({})); for (const auto& [gvar, base_func] : mod->functions) { if (auto opt = base_func.as()) { PrimFunc func = opt.value(); - auto global_symbol = func->GetAttr(tvm::attr::kGlobalSymbol); + auto global_symbol = func->GetAttr(tvm::attr::kGlobalSymbol); auto name_prefix = global_symbol.value_or(gvar->name_hint); auto kernel_name = name_prefix + "_kernel"; auto var_supply = [&global_var_supply, &kernel_name]() -> GlobalVar { diff --git a/src/tir/transforms/storage_access.cc b/src/tir/transforms/storage_access.cc index 8c7a7035defa..2a38e64cc7e2 100644 --- a/src/tir/transforms/storage_access.cc +++ b/src/tir/transforms/storage_access.cc @@ -171,7 +171,7 @@ void StorageAccessVisitor::VisitStmt_(const ForNode* op) { for (AccessEntry& e : s.access) { if (e.buffer.defined()) { ICHECK(e.touched.size()); - Array new_touched; + ffi::Array new_touched; for (const auto& touched : e.touched) { new_touched.push_back(arith::EvalSet(touched, relax_map)); } @@ -250,7 +250,7 @@ void StorageAccessVisitor::VisitExpr_(const CallNode* op) { PrimExpr offset = op->args[2]; PrimExpr extent = op->args[3]; const IntImmNode* flag = op->args[4].as(); - StorageScope scope = GetScope(GetRef(buffer)); + StorageScope scope = GetScope(ffi::GetRef(buffer)); // The buffer scope. if (Enabled(buffer, scope)) { ICHECK(allow_append_); diff --git a/src/tir/transforms/storage_access.h b/src/tir/transforms/storage_access.h index a0e03b35cdaa..10b26f7c2ab2 100644 --- a/src/tir/transforms/storage_access.h +++ b/src/tir/transforms/storage_access.h @@ -56,7 +56,7 @@ class StorageAccessVisitor : public StmtExprVisitor { /*! 
\brief An access entry */ struct AccessEntry { /*! \brief The thread index that access this entry */ - Array threads; + ffi::Array threads; /*! \brief The buffer variable, if any */ Var buffer = NullValue(); /*! \brief The access data type */ @@ -65,7 +65,7 @@ class StorageAccessVisitor : public StmtExprVisitor { * * Has one IntSet for each index in the buffer being accessed. */ - Array touched; + ffi::Array touched; /*! \brief The type of access */ AccessType type; /*! \brief The storage scope */ @@ -98,7 +98,7 @@ class StorageAccessVisitor : public StmtExprVisitor { /*! \return whether we are in device environment. */ bool in_device_env() const { return in_device_env_; } /*! \return environment threads */ - const Array& env_threads() const { return env_threads_; } + const ffi::Array& env_threads() const { return env_threads_; } /*! * \brief Whether we need analyze the buffer in current scope. * \param buffer The buffer to be checked @@ -138,7 +138,7 @@ class StorageAccessVisitor : public StmtExprVisitor { // the current free stmt entry. StmtEntry curr_stmt_; // The involving threads - Array env_threads_; + ffi::Array env_threads_; }; } // namespace tir } // namespace tvm diff --git a/src/tir/transforms/storage_rewrite.cc b/src/tir/transforms/storage_rewrite.cc index 7112f62a1088..9570a3f17f04 100644 --- a/src/tir/transforms/storage_rewrite.cc +++ b/src/tir/transforms/storage_rewrite.cc @@ -406,7 +406,7 @@ class StoragePlanRewriter : public StmtExprMutator { if (it != alloc_map_.end()) { Buffer buf = RemapBuffer(node->buffer, it->second->alloc_var); - Array indices = node->indices; + ffi::Array indices = node->indices; indices.Set(indices.size() - 1, RemapIndex(node->buffer->dtype, indices[indices.size() - 1], it->second)); @@ -453,7 +453,7 @@ class StoragePlanRewriter : public StmtExprMutator { } return it->second->alloc_var; } else { - return GetRef(op); + return ffi::GetRef(op); } } PrimExpr VisitExpr_(const CallNode* op) final { @@ -840,7 +840,7 @@ class StoragePlanRewriter : public StmtExprMutator { ICHECK(alloc_info.count(var)); const AllocEntry& entry = alloc_info.at(var); const AllocateNode* alloc = entry.alloc; - auto storage_scope = StorageScope::Create(GetPtrStorageScope(GetRef(var))); + auto storage_scope = StorageScope::Create(GetPtrStorageScope(ffi::GetRef(var))); StorageEntry* dst_entry = nullptr; // inplace detection if (detect_inplace) { @@ -1145,7 +1145,8 @@ class VectorTypeAccessChecker : public StmtExprVisitor { * missing a type annotation, assume that it has the same underlying * type as it is later accessed, with scalar element types. */ - VectorTypeAccessChecker(const Array& params, const Map& buffer_map, + VectorTypeAccessChecker(const ffi::Array& params, + const ffi::Map& buffer_map, bool allow_untyped_pointers = false, bool detect_scalar_read_patterns = true) : allow_untyped_pointers_(allow_untyped_pointers), @@ -1196,7 +1197,7 @@ class VectorTypeAccessChecker : public StmtExprVisitor { } void VisitStmt_(const AllocateNode* op) final { - const Array& extents = op->extents; + const ffi::Array& extents = op->extents; PrimExpr extent = extents[extents.size() - 1]; OnArrayDeclaration(op->buffer_var, op->dtype, extent, BufferVarInfo::kAllocateNode); @@ -1204,7 +1205,7 @@ class VectorTypeAccessChecker : public StmtExprVisitor { } void VisitStmt_(const AllocateConstNode* op) final { - const Array& extents = op->extents; + const ffi::Array& extents = op->extents; PrimExpr extent = extents.size() ? 
extents[extents.size() - 1] : NullValue(); OnArrayDeclaration(op->buffer_var, op->dtype, extent, BufferVarInfo::kAllocateConstNode); @@ -1271,8 +1272,8 @@ class VectorTypeAccessChecker : public StmtExprVisitor { * * @param is_buffer_load Whether the access is BufferLoad */ - void OnArrayAccess(DataType value_dtype, const VarNode* buffer, const Array& indices, - bool is_buffer_load) { + void OnArrayAccess(DataType value_dtype, const VarNode* buffer, + const ffi::Array& indices, bool is_buffer_load) { auto it = info_map_.find(buffer); ICHECK(it != info_map_.end()) << "Load/Store of buffer " << buffer->name_hint << " (" << buffer << ") occurred before its declaration."; @@ -1471,7 +1472,7 @@ class VectorTypeRewriter : public StmtExprMutator { } const auto& info = it->second; - Array indices = node->indices; + ffi::Array indices = node->indices; const PrimExpr& last_dim_index = indices[indices.size() - 1]; const RampNode* ramp_index = indices[indices.size() - 1].as(); @@ -1536,7 +1537,7 @@ class VectorTypeRewriter : public StmtExprMutator { Stmt body = this->VisitStmt(op->body); Var var = (it == rewrite_map_.end()) ? op->var : it->second.new_buffer_var; if (var.same_as(op->var) && value.same_as(op->value) && body.same_as(op->body)) { - return GetRef(op); + return ffi::GetRef(op); } return LetStmt(var, value, body); } @@ -1553,7 +1554,7 @@ class VectorTypeRewriter : public StmtExprMutator { if (info_it != rewrite_map_.end()) { auto& info = info_it->second; - Array shape = buf->shape; + ffi::Array shape = buf->shape; PrimExpr last_dim = shape[shape.size() - 1]; shape.Set(shape.size() - 1, last_dim / make_const(last_dim.dtype(), info.factor())); @@ -1591,7 +1592,7 @@ class VectorTypeRewriter : public StmtExprMutator { int factor = info.factor(); extent = extent / make_const(extent.dtype(), factor); index = index / make_const(index.dtype(), factor); - Array acc_args{e_dtype, info.new_buffer_var, index, extent, flag}; + ffi::Array acc_args{e_dtype, info.new_buffer_var, index, extent, flag}; return Call(info.new_element_dtype, builtin::tvm_access_ptr(), acc_args); } else { @@ -1612,7 +1613,7 @@ class VectorTypeRewriter : public StmtExprMutator { Var new_buffer_var = info.new_buffer_var; - Array extents = op->extents; + ffi::Array extents = op->extents; PrimExpr last_extent = extents[extents.size() - 1]; extents.Set(extents.size() - 1, last_extent / make_const(last_extent.dtype(), info.factor())); return Allocate(new_buffer_var, info.new_element_dtype, extents, op->condition, op->body); @@ -1633,7 +1634,7 @@ class VectorTypeRewriter : public StmtExprMutator { int factor = info.new_element_dtype.lanes() / op->dtype.lanes(); - Array extents = op->extents; + ffi::Array extents = op->extents; extents.Set(extents.size() - 1, extents[extents.size() - 1] / make_const(extents[0].dtype(), factor)); return AllocateConst(new_buffer_var, info.new_element_dtype, extents, op->data, op->body); @@ -1652,7 +1653,7 @@ class VectorTypeRewriter : public StmtExprMutator { auto* n = func.CopyOnWrite(); // Remap any remaining references to the old buffer variables - Map var_remap; + ffi::Map var_remap; for (const auto& pair : rewrite_map_) { const auto& info = pair.second; var_remap.Set(info.old_buffer_var, info.new_buffer_var); @@ -1660,7 +1661,7 @@ class VectorTypeRewriter : public StmtExprMutator { n->body = Substitute(n->body, var_remap); // Remap the argument list to use the new buffer variables. 
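Reviewer sketch, not part of the patch: the hunks in this file are representative of the whole change, since only the namespace qualification moves to `tvm::ffi::` while the copy-on-write container semantics stay the same. Below is a minimal standalone sketch of the remap-and-substitute idiom the surrounding code uses; `RemapVars` is a hypothetical name, and the `Substitute` overload from `tvm/tir/stmt_functor.h` taking an `ffi::Map<Var, PrimExpr>` is assumed.

.. code:: c++

   #include <tvm/tir/stmt_functor.h>
   #include <tvm/tir/var.h>

   namespace tvm {
   namespace tir {

   // Hypothetical helper mirroring the pattern above: collect old->new
   // variable pairs in an ffi::Map, then rewrite the body in one pass.
   Stmt RemapVars(Stmt body, const ffi::Map<Var, Var>& remap) {
     ffi::Map<Var, PrimExpr> expr_remap;
     for (const auto& [old_var, new_var] : remap) {
       expr_remap.Set(old_var, new_var);  // Var implicitly converts to PrimExpr
     }
     return Substitute(std::move(body), expr_remap);
   }

   }  // namespace tir
   }  // namespace tvm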
- Array new_params; + ffi::Array new_params; for (const auto& old_param : n->params) { auto it = rewrite_map_.find(old_param.get()); if (it == rewrite_map_.end()) { @@ -1674,7 +1675,7 @@ // Remap the Buffer objects in PrimFunc::buffer_map so that the // buffers use the new buffer variables - Map new_buffer_map; + ffi::Map new_buffer_map; for (const auto& pair : n->buffer_map) { Var key = pair.first; Buffer old_buffer = pair.second; @@ -1742,7 +1743,7 @@ Pass StorageRewrite() { enable_reuse = false; } - Optional target = f->GetAttr("target"); + ffi::Optional target = f->GetAttr("target"); if (target.defined() && (target.value()->kind->name == "vulkan" || target.value()->kind->name == "webgpu")) { // Require exactly same-dtype matching in smem reuse for Vulkan and WebGPU diff --git a/src/tir/transforms/tensorcore_infer_fragment.cc b/src/tir/transforms/tensorcore_infer_fragment.cc index 8285ee96279c..082f19e782ef 100644 --- a/src/tir/transforms/tensorcore_infer_fragment.cc +++ b/src/tir/transforms/tensorcore_infer_fragment.cc @@ -59,7 +59,7 @@ class FragmentGetter : public StmtExprVisitor { ICHECK(k); ICHECK(layout); - std::string scope = GetPtrStorageScope(GetRef(buffer_var)); + std::string scope = GetPtrStorageScope(ffi::GetRef(buffer_var)); if (fragments.count(buffer_var)) { // check if the fragment has met before FragmentInfo info = fragments[buffer_var]; @@ -92,7 +92,7 @@ class FragmentGetter : public StmtExprVisitor { ICHECK(n); ICHECK(k); - std::string scope = GetPtrStorageScope(GetRef(buffer_var)); + std::string scope = GetPtrStorageScope(ffi::GetRef(buffer_var)); if (fragments.count(buffer_var)) { FragmentInfo info = fragments[buffer_var]; ICHECK_EQ(m->value, info.m); diff --git a/src/tir/transforms/thread_storage_sync.cc b/src/tir/transforms/thread_storage_sync.cc index f8b0a83d4d43..bb8d733d880e 100644 --- a/src/tir/transforms/thread_storage_sync.cc +++ b/src/tir/transforms/thread_storage_sync.cc @@ -401,7 +401,7 @@ class ThreadSyncInserter : public StmtExprMutator { // private functions. Stmt InitGlobalBarrier(const AttrStmtNode* op) { ICHECK(op != nullptr); - Array pargs = {StringImm(runtime::symbol::tvm_prepare_global_barrier)}; + ffi::Array pargs = {StringImm(runtime::symbol::tvm_prepare_global_barrier)}; Stmt prep = Evaluate(Call(DataType::Int(32), builtin::tvm_call_packed(), pargs)); Stmt body = op->body; for (const auto& kv : rw_stats_) { @@ -463,7 +463,7 @@ Stmt ThreadSync(Stmt stmt, std::string storage_scope) { namespace transform { -Pass ThreadSync(String storage_scope) { +Pass ThreadSync(ffi::String storage_scope) { auto pass_func = [storage_scope](PrimFunc f, IRModule m, PassContext ctx) { auto* n = f.CopyOnWrite(); n->body = ThreadSync(std::move(n->body), storage_scope); diff --git a/src/tir/transforms/transform_mma_buffer_layout.cc b/src/tir/transforms/transform_mma_buffer_layout.cc index bee45716b17d..626bc807dea0 100644 --- a/src/tir/transforms/transform_mma_buffer_layout.cc +++ b/src/tir/transforms/transform_mma_buffer_layout.cc @@ -44,7 +44,7 @@ namespace tir { class MmaBufferLayoutTransformer : public StmtExprMutator { public: Stmt VisitStmt_(const BlockNode* op) { - Block block = GetRef(op); + Block block = ffi::GetRef(op); auto* n = block.CopyOnWrite(); auto fmutate = [this](const Buffer& buffer) { // m16n8k8.matrix[A/B/C] buffers are composed of several small blocks. 
Assume the block's @@ -164,10 +164,10 @@ class MmaBufferLayoutTransformer : public StmtExprMutator { } PrimExpr VisitExpr_(const VarNode* op) { - if (buffer_var_map_.count(GetRef(op))) { - return buffer_var_map_[GetRef(op)]; + if (buffer_var_map_.count(ffi::GetRef(op))) { + return buffer_var_map_[ffi::GetRef(op)]; } - return GetRef(op); + return ffi::GetRef(op); } private: diff --git a/src/tir/transforms/unify_thread_binding.cc b/src/tir/transforms/unify_thread_binding.cc index a9e47055e2a7..4da295980c50 100644 --- a/src/tir/transforms/unify_thread_binding.cc +++ b/src/tir/transforms/unify_thread_binding.cc @@ -60,14 +60,14 @@ class ThreadBindingUnifier : public StmtExprMutator { if (op->kind != ForKind::kThreadBinding) { return StmtExprMutator::VisitStmt_(op); } - Map annotations = op->annotations; + ffi::Map annotations = op->annotations; Stmt stmt = UnifyThreadBindingImpl(op, op->loop_var, op->thread_binding.value(), Range::FromMinExtent(op->min, op->extent)); if (annotations.empty()) { return stmt; } if (const auto* loop = stmt.as()) { - For new_loop = GetRef(loop); + For new_loop = ffi::GetRef(loop); new_loop.CopyOnWrite()->annotations = std::move(annotations); return new_loop; @@ -88,7 +88,7 @@ class ThreadBindingUnifier : public StmtExprMutator { const Range& dom) { // Step 1. Fetch the thread tag. IterVar new_iter_var{nullptr}; - const String& thread_tag = old_iter_var->thread_tag; + const ffi::String& thread_tag = old_iter_var->thread_tag; // Step 2: Increase `thread_block_depth_` if the thread tag starts with "blockIdx". If the // thread block depth is 0 before the increment, it means we are entering a new kernel, and @@ -107,7 +107,7 @@ class ThreadBindingUnifier : public StmtExprMutator { // Step 3. See if an IterVar for this kind of thread binding was created before. If so, we use // the created IterVar. Otherwise, we create a new IterVar for this thread binding and store the // IterVar in mapping `thread_tag2iter_var_map_`. - Map::iterator it = thread_tag2iter_var_map_.find(thread_tag); + ffi::Map::iterator it = thread_tag2iter_var_map_.find(thread_tag); if (it != thread_tag2iter_var_map_.end()) { new_iter_var = (*it).second; ICHECK(ana.CanProveEqual(dom->min, new_iter_var->dom->min)); @@ -164,22 +164,22 @@ class ThreadBindingUnifier : public StmtExprMutator { PrimExpr VisitExpr_(const VarNode* var) final { // If this variable appears as a key in `var_substitution_map_`, we substitute it with its // corresponding value in the mapping. - Map::iterator it = var_substitution_map_.find(GetRef(var)); - return it != var_substitution_map_.end() ? (*it).second : GetRef(var); + ffi::Map::iterator it = var_substitution_map_.find(ffi::GetRef(var)); + return it != var_substitution_map_.end() ? (*it).second : ffi::GetRef(var); } /*! * \brief A mapping from a thread tag to its corresponding IterVar that is shared by all * occurrences of the thread tag */ - Map thread_tag2iter_var_map_; + ffi::Map thread_tag2iter_var_map_; /*! * \brief A list of IterVar corresponding to threads in current kernel. This will be used to * generate for-loops to launch threads. */ - Array launch_threads_; + ffi::Array launch_threads_; /*! \brief A mapping from old variables to new variables, which is used for substitution */ - Map var_substitution_map_; + ffi::Map var_substitution_map_; /*! \brief An integer counter storing the depth of thread bindings of "blockIdx.x/y/z" */ int thread_block_depth_ = 0; /*! 
\brief An analyzer used for equality proof */ diff --git a/src/tir/transforms/unroll_loop.cc b/src/tir/transforms/unroll_loop.cc index fdddf2091141..27377309fa37 100644 --- a/src/tir/transforms/unroll_loop.cc +++ b/src/tir/transforms/unroll_loop.cc @@ -83,7 +83,7 @@ class VarLocalAccessMarker : public ExprVisitor { explicit VarLocalAccessMarker(std::unordered_set* var_touched_local) : var_touched_local_(var_touched_local) {} - void VisitExpr_(const VarNode* op) final { var_touched_local_->insert(GetRef(op)); } + void VisitExpr_(const VarNode* op) final { var_touched_local_->insert(ffi::GetRef(op)); } private: std::unordered_set* var_touched_local_; @@ -176,7 +176,7 @@ class LoopUnroller : public StmtExprMutator { } } } - return GetRef(op); + return ffi::GetRef(op); } Stmt VisitStmt_(const BufferStoreNode* op) final { @@ -222,8 +222,8 @@ class LoopUnroller : public StmtExprMutator { ICHECK_NE(value, -1) << "loop doesn't have a constant integer extent"; if (value == 0) return Evaluate(0); Stmt body = op->body; - Map vmap; - Array unrolled; + ffi::Map vmap; + ffi::Array unrolled; for (int i = 0; i < value; ++i) { vmap.Set(op->loop_var, op->min + make_const(op->loop_var.dtype(), i)); Stmt step = Substitute(body, vmap); diff --git a/src/tir/transforms/unsupported_dtype_legalize.cc b/src/tir/transforms/unsupported_dtype_legalize.cc index 92f0a6de98e1..2b26633ac4e4 100644 --- a/src/tir/transforms/unsupported_dtype_legalize.cc +++ b/src/tir/transforms/unsupported_dtype_legalize.cc @@ -60,7 +60,7 @@ class ComputeLegalizePlanner : public StmtExprVisitor { var_remap_->erase(it); } } - Array drop_buffers; + ffi::Array drop_buffers; for (auto kv : *buffer_remap_) { if (opaque_var_access_.count(kv.first->data)) { drop_buffers.push_back(kv.first); @@ -79,7 +79,7 @@ class ComputeLegalizePlanner : public StmtExprVisitor { // remap all intermediate constant buffer to promote data types (fp16/fp32) if (MatchDType(op->dtype) && op->ConstantAllocationSize() != 0) { DataType dtype = promote_dtype_.with_lanes(op->dtype.lanes()); - String storage_scope = "global"; + ffi::String storage_scope = "global"; if (auto* ptr_type = op->buffer_var->type_annotation.as()) { storage_scope = ptr_type->storage_scope; } @@ -106,7 +106,7 @@ class ComputeLegalizePlanner : public StmtExprVisitor { void VisitExpr_(const VarNode* op) final { StmtExprVisitor::VisitExpr_(op); - Var buffer_var = GetRef(op); + Var buffer_var = ffi::GetRef(op); if (buffer_var.dtype().is_handle()) { opaque_var_access_.insert(buffer_var); } @@ -153,7 +153,7 @@ class FP8ComputeLegalizePlanner : public ComputeLegalizePlanner { PrimExpr origin_b = PromoteToTarget(this->VisitExpr(op->b)); \ \ if (origin_a.same_as(op->a) && origin_b.same_as(op->b)) { \ - return GetRef(op); \ + return ffi::GetRef(op); \ } else { \ return FUNC(origin_a, origin_b); \ } \ @@ -189,7 +189,7 @@ class ComputeLegalizer : public StmtExprMutator { } if (op_val.same_as(op->value)) { - return GetRef(op); + return ffi::GetRef(op); } else { return cast(op->dtype, op_val); } @@ -201,7 +201,7 @@ class ComputeLegalizer : public StmtExprMutator { PrimExpr false_value = PromoteToTarget(this->VisitExpr(op->false_value)); if (condition.same_as(op->condition) && true_value.same_as(op->true_value) && false_value.same_as(op->false_value)) { - return GetRef(op); + return ffi::GetRef(op); } else { return Select(condition, true_value, false_value); } @@ -210,7 +210,7 @@ class ComputeLegalizer : public StmtExprMutator { PrimExpr VisitExpr_(const BroadcastNode* op) final { PrimExpr value = 
PromoteToTarget(this->VisitExpr(op->value)); if (value.same_as(op->value)) { - return GetRef(op); + return ffi::GetRef(op); } else { return Broadcast(value, op->lanes); } @@ -220,7 +220,7 @@ class ComputeLegalizer : public StmtExprMutator { auto fexpr = [this](const PrimExpr& e) { return PromoteToTarget(this->VisitExpr(e)); }; auto vectors = op->vectors.Map(fexpr); if (vectors.same_as(op->vectors)) { - return GetRef(op); + return ffi::GetRef(op); } else { return Shuffle(vectors, op->indices); } @@ -233,12 +233,12 @@ class ComputeLegalizer : public StmtExprMutator { } // update normal computations to return f32 instead. auto fmutate = [this](const PrimExpr& e) { return PromoteToTarget(this->VisitExpr(e)); }; - Array args = op->args.Map(fmutate); + ffi::Array args = op->args.Map(fmutate); if (MatchDType(op->dtype)) { return Call(promote_dtype_.with_lanes(op->dtype.lanes()), op->op, args); } if (args.same_as(op->args)) { - return GetRef(op); + return ffi::GetRef(op); } else { return Call(op->dtype, op->op, args); } @@ -248,11 +248,11 @@ class ComputeLegalizer : public StmtExprMutator { if (MatchDType(op->dtype)) { return FloatImm(promote_dtype_, op->value); } - return GetRef(op); + return ffi::GetRef(op); } PrimExpr VisitExpr_(const VarNode* op) final { - Var var = GetRef(op); + Var var = ffi::GetRef(op); auto itr = var_remap_.find(var); if (itr != var_remap_.end()) { @@ -273,7 +273,7 @@ class ComputeLegalizer : public StmtExprMutator { PrimExpr body = VisitExpr(op->body); if (value.same_as(op->value) && var.same_as(op->var) && body.same_as(op->body)) { - return GetRef(op); + return ffi::GetRef(op); } else { return Let(var, value, body); } @@ -302,7 +302,7 @@ class ComputeLegalizer : public StmtExprMutator { Stmt body = VisitStmt(op->body); if (value.same_as(op->value) && var.same_as(op->var) && body.same_as(op->body)) { - return GetRef(op); + return ffi::GetRef(op); } else { return LetStmt(var, value, body); } @@ -312,12 +312,12 @@ class ComputeLegalizer : public StmtExprMutator { PrimExpr value = this->VisitExpr(op->value); auto fmutate = [this](const PrimExpr& e) { return this->VisitExpr(e); }; - Array indices = op->indices.Map(fmutate); + ffi::Array indices = op->indices.Map(fmutate); Buffer new_buf = GetRemappedBuffer(op->buffer); if (value.same_as(op->value) && indices.same_as(op->indices) && new_buf.same_as(op->buffer)) { - return GetRef(op); + return ffi::GetRef(op); } else { if (MatchDType(new_buf->dtype)) { int index_lanes = indices.size() ? 
indices.back().dtype().lanes() : 1; @@ -526,7 +526,7 @@ class StorageLegalizer : public StmtExprMutator { private: PrimExpr VisitExpr_(const VarNode* op) final { - Var var = GetRef(op); + Var var = ffi::GetRef(op); auto itr = var_remap_.find(var); if (itr != var_remap_.end()) { return itr->second; @@ -538,7 +538,7 @@ class StorageLegalizer : public StmtExprMutator { Stmt VisitStmt_(const AllocateNode* op) final { if (MatchDType(op->dtype)) { DataType dtype = GetStorageUIntDType(op->dtype); - String storage_scope = "global"; + ffi::String storage_scope = "global"; if (auto* ptr_type = op->buffer_var->type_annotation.as()) { storage_scope = ptr_type->storage_scope; } @@ -563,7 +563,7 @@ class StorageLegalizer : public StmtExprMutator { } Stmt body = VisitStmt(op->body); if (buf.same_as(op->buffer) && body.same_as(op->body)) { - return GetRef(op); + return ffi::GetRef(op); } else { return DeclBuffer(buf, body, op->span); } @@ -575,7 +575,7 @@ class StorageLegalizer : public StmtExprMutator { PrimExpr body = VisitExpr(op->body); if (value.same_as(op->value) && var.same_as(op->var) && body.same_as(op->body)) { - return GetRef(op); + return ffi::GetRef(op); } else { return Let(var, value, body); } @@ -587,7 +587,7 @@ class StorageLegalizer : public StmtExprMutator { Stmt body = VisitStmt(op->body); if (value.same_as(op->value) && var.same_as(op->var) && body.same_as(op->body)) { - return GetRef(op); + return ffi::GetRef(op); } else { return LetStmt(var, value, body); } @@ -598,7 +598,7 @@ class StorageLegalizer : public StmtExprMutator { Buffer new_buf = GetRemappedBuffer(op->buffer); auto indices = op->indices.Map([this](PrimExpr expr) { return this->VisitExpr(expr); }); if (new_buf.same_as(op->buffer) && indices.same_as(op->indices) && value.same_as(op->value)) { - return GetRef(op); + return ffi::GetRef(op); } else { if (MatchDType(op->value.dtype())) { ICHECK(new_buf->dtype.is_uint()); @@ -654,7 +654,7 @@ class StorageLegalizer : public StmtExprMutator { return reinterpret(GetStorageUIntDType(op->dtype), value); } if (op->args[0].same_as(value)) { - return GetRef(op); + return ffi::GetRef(op); } else { return reinterpret(op->dtype, value); } @@ -780,13 +780,13 @@ TVM_FFI_STATIC_INIT_BLOCK({ refl::GlobalDef().def("tir.transform.BF16StorageLegalize", BF16StorageLegalize); }); -Pass FP8ComputeLegalize(String promote_dtype_str) { +Pass FP8ComputeLegalize(ffi::String promote_dtype_str) { auto pass_func = [=](PrimFunc f, IRModule m, PassContext ctx) { auto target = f->GetAttr(tvm::attr::kTarget).value(); if (CheckDataTypeSupport(target, "tvm.contrib.nvcc.supports_fp8")) { return f; } - return FP8ComputeLegalizer(DataType(StringToDLDataType(promote_dtype_str))).Legalize(f); + return FP8ComputeLegalizer(DataType(ffi::StringToDLDataType(promote_dtype_str))).Legalize(f); }; return CreatePrimFuncPass(pass_func, 0, "tir.FP8ComputeLegalize", {}); } diff --git a/src/tir/transforms/update_pointer_storage_scope.cc b/src/tir/transforms/update_pointer_storage_scope.cc index 9af990d1e2bf..e12ab9696a99 100644 --- a/src/tir/transforms/update_pointer_storage_scope.cc +++ b/src/tir/transforms/update_pointer_storage_scope.cc @@ -37,7 +37,7 @@ namespace tvm { namespace tir { -Var WithStorageScope(const VarNode* buffer_var, String storage_scope) { +Var WithStorageScope(const VarNode* buffer_var, ffi::String storage_scope) { auto* ptr_type = buffer_var->type_annotation.as(); ICHECK(ptr_type) << "The provided variable is not of pointer type"; return Var(buffer_var->name_hint, PointerType(ptr_type->element_type, 
storage_scope), @@ -45,7 +45,7 @@ Var WithStorageScope(const VarNode* buffer_var, String storage_scope) { } UpdatePointerStorageScope::UpdatePointerStorageScope( - const std::unordered_map& new_storage_scopes) { + const std::unordered_map& new_storage_scopes) { for (auto& kv : new_storage_scopes) { new_var_remap_[kv.first] = WithStorageScope(kv.first, kv.second); } @@ -54,7 +54,7 @@ UpdatePointerStorageScope::UpdatePointerStorageScope( PrimExpr UpdatePointerStorageScope::VisitExpr_(const VarNode* op) { auto it = new_var_remap_.find(op); if (it == new_var_remap_.end()) { - return GetRef(op); + return ffi::GetRef(op); } return it->second; } diff --git a/src/tir/transforms/update_pointer_storage_scope.h b/src/tir/transforms/update_pointer_storage_scope.h index 1f1399fba76b..a2f7027ce4f8 100644 --- a/src/tir/transforms/update_pointer_storage_scope.h +++ b/src/tir/transforms/update_pointer_storage_scope.h @@ -36,7 +36,7 @@ namespace tir { class UpdatePointerStorageScope : public StmtExprMutator { public: explicit UpdatePointerStorageScope( - const std::unordered_map& new_storage_scopes); + const std::unordered_map& new_storage_scopes); virtual PrimExpr VisitExpr_(const VarNode*); virtual PrimExpr VisitExpr_(const BufferLoadNode*); diff --git a/src/tir/transforms/using_assume_to_reduce_branches.cc b/src/tir/transforms/using_assume_to_reduce_branches.cc index 53509ce49710..f7edeb25dde7 100644 --- a/src/tir/transforms/using_assume_to_reduce_branches.cc +++ b/src/tir/transforms/using_assume_to_reduce_branches.cc @@ -119,13 +119,13 @@ class ParseAssumeAndOvercompute : public IRMutatorWithAnalyzer { using Parent::VisitStmt_; // This struct stores all the relevant data related to assume statement - struct assume_struct { // Consider the example : T.assume(i < 14 or A[i] == 0) - PrimExpr buffer_context; // The context of the assume statement (the bound on the axis) - PrimExpr buffer_predicate; // The condition inside assume statement (i < 14) excluding - // bufferload expression (A[i] == 0) - tir::BufferLoad buffer_load; // Storing the buffer load Eg: A[i] in A[i] == 0 - PrimExpr buffer_value; // Storing the value for the buffer Eg : 0 in A[i] == 0 - Array buffer_indices; // Storing the indices of the buffer Eg : i + struct assume_struct { // Consider the example : T.assume(i < 14 or A[i] == 0) + PrimExpr buffer_context; // The context of the assume statement (the bound on the axis) + PrimExpr buffer_predicate; // The condition inside assume statement (i < 14) excluding - // bufferload expression (A[i] == 0) + tir::BufferLoad buffer_load; // Storing the buffer load Eg: A[i] in A[i] == 0 + PrimExpr buffer_value; // Storing the value for the buffer Eg : 0 in A[i] == 0 + ffi::Array buffer_indices; // Storing the indices of the buffer Eg : i }; // List of conditions in a scope std::vector conditions_; @@ -162,7 +162,7 @@ class ParseAssumeAndOvercompute : public IRMutatorWithAnalyzer { With analyzer_context; size_t old_num_constraints{0}; size_t new_num_constraints{0}; - Optional assume{std::nullopt}; + ffi::Optional assume{std::nullopt}; // Disable default-generated copy/move assignment and constructors InternalConstraintContext(const InternalConstraintContext&) = delete; @@ -209,7 +209,7 @@ class ParseAssumeAndOvercompute : public IRMutatorWithAnalyzer { return buf_value; } } - return GetRef(op); + return ffi::GetRef(op); } Stmt VisitStmt_(const BufferStoreNode* op) final { @@ -358,7 +358,7 @@ Pass UseAssumeToReduceBranches() { // the primfunc has op_pattern defined and is an elementwise op. 
// AnnotateTIROpPattern pass will set op_pattern in op attributes of the primfunc. if (n->attrs.GetAttr("op_pattern").defined()) { - Optional opt_pattern = f->GetAttr("op_pattern"); + ffi::Optional opt_pattern = f->GetAttr("op_pattern"); if (opt_pattern.defined()) { relax::OpPatternKind pattern; pattern = static_cast(Downcast(opt_pattern)->value); diff --git a/src/tir/transforms/vectorize_loop.cc b/src/tir/transforms/vectorize_loop.cc index 8e350924501e..5bf60d3b675a 100644 --- a/src/tir/transforms/vectorize_loop.cc +++ b/src/tir/transforms/vectorize_loop.cc @@ -75,7 +75,7 @@ inline PrimExpr BroadcastTo(PrimExpr e, int lanes, bool is_scalable) { bool EnableBufferLevelPredication(Target target) { transform::PassContext pass_ctx = transform::PassContext::Current(); - Optional enable_buffer_predication = + ffi::Optional enable_buffer_predication = pass_ctx->GetConfig("tir.enable_buffer_level_predication"); if (enable_buffer_predication.defined()) { return enable_buffer_predication.value(); @@ -160,7 +160,7 @@ class TryPredicateBufferAccesses : public StmtExprMutator { num_accesses_analyzed_ += 1; // Do not try to predicate non-vectorized accesses - Array indices = node->indices; + ffi::Array indices = node->indices; if (!indices.size() || !indices[0]->IsInstance()) { return node; } @@ -233,7 +233,7 @@ class VecAllocAccess : public StmtExprMutator { // Extend the least significant dimension by a factor of // var_lanes_. Typically, this will be a 1-d index into a flat // memory space. - Array shape = node->buffer->shape; + ffi::Array shape = node->buffer->shape; shape.Set(shape.size() - 1, analyzer_.Simplify(shape[shape.size() - 1] * var_lanes_)); // TODO(Lunderberg): Move this pass to be prior to @@ -243,7 +243,7 @@ class VecAllocAccess : public StmtExprMutator { // are updated for consistency. // Update strides if defined. - Array strides; + ffi::Array strides; for (size_t i = 0; i < strides.size(); i++) { PrimExpr stride = strides[i]; if (i != strides.size() - 1) { @@ -262,7 +262,7 @@ class VecAllocAccess : public StmtExprMutator { // Extend the last index by the number of lanes in the vectorized // variable. 
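A minimal sketch (illustrative names, not part of the patch) of the copy-on-write update performed in the next hunk: `ffi::Array::Set` copies the backing storage on first write, so the node's original indices stay intact until the new array is written back. The usual `tvm/tir/op.h` operators on `PrimExpr` are assumed.

.. code:: c++

   // Scale the innermost index by the vector lane count and add the loop
   // variable, as VecAllocAccess does below; Set() triggers copy-on-write.
   ffi::Array<PrimExpr> ExtendLastIndex(ffi::Array<PrimExpr> indices,
                                        PrimExpr var_lanes, Var var) {
     size_t last = indices.size() - 1;
     indices.Set(last, indices[last] * var_lanes + var);
     return indices;
   }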
- Array indices = node->indices; + ffi::Array indices = node->indices; indices.Set(indices.size() - 1, analyzer_.Simplify(indices[indices.size() - 1] * var_lanes_ + var_)); @@ -322,7 +322,7 @@ class Vectorizer : public StmtMutator, public ExprFunctorVisitExpr(op->a); PrimExpr b = this->VisitExpr(op->b); if (a.same_as(op->a) && b.same_as(op->b)) { - return GetRef(op); + return ffi::GetRef(op); } else { bool is_vec_a = a.dtype().is_scalable_or_fixed_length_vector(); bool is_vec_b = b.dtype().is_scalable_or_fixed_length_vector(); @@ -369,7 +369,7 @@ class Vectorizer : public StmtMutator, public ExprFunctorVisitExpr(op->a); if (a.same_as(op->a)) { - return GetRef(op); + return ffi::GetRef(op); } else { return !(a); } @@ -396,7 +396,7 @@ class Vectorizer : public StmtMutator, public ExprFunctor elems; + ffi::Array elems; for (int i = 0; i < lanes; ++i) { elems.push_back( Ramp(Shuffle::ExtractElement(base, i), Shuffle::ExtractElement(stride, i), op->lanes)); @@ -408,10 +408,10 @@ class Vectorizer : public StmtMutator, public ExprFunctorVisitExpr(op->value); if (value.dtype().is_scalable_or_fixed_length_vector()) { need_scalarize_ = true; - return GetRef(op); + return ffi::GetRef(op); } if (value.same_as(op->value)) { - return GetRef(op); + return ffi::GetRef(op); } else { return Broadcast(op->value, op->lanes); } @@ -422,7 +422,7 @@ class Vectorizer : public StmtMutator, public ExprFunctorVisitExpr(op->true_value); PrimExpr f = this->VisitExpr(op->false_value); if (cond.same_as(op->condition) && t.same_as(op->true_value) && f.same_as(op->false_value)) { - return GetRef(op); + return ffi::GetRef(op); } else { int cond_lanes = cond.dtype().get_lanes_or_vscale_factor(); int t_lanes = t.dtype().get_lanes_or_vscale_factor(); @@ -438,7 +438,7 @@ class Vectorizer : public StmtMutator, public ExprFunctorVisitExpr(op->value); if (value.same_as(op->value)) { - return GetRef(op); + return ffi::GetRef(op); } else { if (value.dtype().is_scalable_vector()) { return Cast(op->dtype.with_scalable_vscale_factor(value.dtype().vscale_factor()), value); @@ -448,15 +448,15 @@ class Vectorizer : public StmtMutator, public ExprFunctor(op); } + PrimExpr VisitExpr_(const FloatImmNode* op) final { return ffi::GetRef(op); } - PrimExpr VisitExpr_(const IntImmNode* op) final { return GetRef(op); } + PrimExpr VisitExpr_(const IntImmNode* op) final { return ffi::GetRef(op); } - PrimExpr VisitExpr_(const StringImmNode* op) final { return GetRef(op); } + PrimExpr VisitExpr_(const StringImmNode* op) final { return ffi::GetRef(op); } // Variable PrimExpr VisitExpr_(const VarNode* op) final { - Var var = GetRef(op); + Var var = ffi::GetRef(op); if (var.same_as(var_)) { return ramp_; @@ -473,12 +473,12 @@ class Vectorizer : public StmtMutator, public ExprFunctorVisitExpr(op->args[0]); if (cond.dtype().is_scalable_or_fixed_length_vector()) { need_scalarize_ = true; - return GetRef(op); + return ffi::GetRef(op); } PrimExpr t = this->VisitExpr(op->args[1]); PrimExpr f = this->VisitExpr(op->args[2]); if (cond.same_as(op->args[0]) && t.same_as(op->args[1]) && f.same_as(op->args[2])) { - return GetRef(op); + return ffi::GetRef(op); } else { int t_lanes = t.dtype().get_lanes_or_vscale_factor(); int f_lanes = f.dtype().get_lanes_or_vscale_factor(); @@ -498,7 +498,7 @@ class Vectorizer : public StmtMutator, public ExprFunctorop.same_as(builtin::reinterpret())); PrimExpr value = this->VisitExpr(op->args[0]); if (value.same_as(op->args[0])) { - return GetRef(op); + return ffi::GetRef(op); } else { int lanes = 
value.dtype().get_lanes_or_vscale_factor(); if (value.dtype().is_scalable_vector()) { @@ -518,7 +518,7 @@ class Vectorizer : public StmtMutator, public ExprFunctorop.same_as(builtin::texture2d_load())) { int lane = 0; - Array fcd = MutateArray({op->args.back()}, &lane); + ffi::Array fcd = MutateArray({op->args.back()}, &lane); auto new_args = op->args; new_args.pop_back(); new_args.push_back(fcd[0]); @@ -526,9 +526,9 @@ class Vectorizer : public StmtMutator, public ExprFunctorop.same_as(builtin::texture2d_store())) { int lane = 0; // Vectorize the value to store - Array value{op->args.back()}; - Array mutated_value = MutateArray(value, &lane); - Array new_args{op->args[0], op->args[1], op->args[2], mutated_value[0]}; + ffi::Array value{op->args.back()}; + ffi::Array mutated_value = MutateArray(value, &lane); + ffi::Array new_args{op->args[0], op->args[1], op->args[2], mutated_value[0]}; return Call(op->dtype.with_lanes(lane), op->op, new_args); } else if (op->op.same_as(builtin::reinterpret())) { return MutateReinterpretExpr_(op); @@ -539,32 +539,32 @@ class Vectorizer : public StmtMutator, public ExprFunctor new_args; + ffi::Array new_args; for (auto arg : op->args) { auto new_arg = this->VisitExpr(arg); if (new_arg.dtype().is_scalable_or_fixed_length_vector()) { need_scalarize_ = true; - return GetRef(op); + return ffi::GetRef(op); } new_args.push_back(new_arg); } if (op->args.same_as(new_args)) { - return GetRef(op); + return ffi::GetRef(op); } else { return Call(op->dtype, op->op, new_args); } } else { int lane = 0; - Array new_args; + ffi::Array new_args; if (op->op.same_as(builtin::call_llvm_pure_intrin())) { // op->args[1], will give us total number of arguments to intrinsic - Array op_expr_args; + ffi::Array op_expr_args; for (size_t i = 1; i < op->args.size(); ++i) { // Collect all intrinsic arguments op_expr_args.push_back(op->args[i]); } // Generate RAMP nodes for intrinsic arguments - Array updated_args = MutateArray(op_expr_args, &lane); + ffi::Array updated_args = MutateArray(op_expr_args, &lane); new_args.push_back(op->args[0]); // Collect updated intrinsic arguments for (size_t i = 0; i < updated_args.size(); ++i) { @@ -575,7 +575,7 @@ class Vectorizer : public StmtMutator, public ExprFunctorargs.same_as(new_args)) { - return GetRef(op); + return ffi::GetRef(op); } else { return Call(op->dtype.with_lanes(lane), op->op, new_args); } @@ -583,10 +583,10 @@ class Vectorizer : public StmtMutator, public ExprFunctor(op); + auto load = ffi::GetRef(op); auto fmutate = [this](const PrimExpr& index) { return this->VisitExpr(index); }; - Array indices = op->indices.Map(fmutate); + ffi::Array indices = op->indices.Map(fmutate); if (!indices.same_as(op->indices)) { auto writer = load.CopyOnWrite(); @@ -619,7 +619,7 @@ class Vectorizer : public StmtMutator, public ExprFunctorvar] = op->var; PrimExpr body = this->VisitExpr(op->body); if (value.same_as(op->value) && body.same_as(op->body)) { - return GetRef(op); + return ffi::GetRef(op); } else { return Let(op->var, value, body); } @@ -631,10 +631,10 @@ class Vectorizer : public StmtMutator, public ExprFunctorvectors.size() << " and the index size is " << op->indices.size(); int lane_vectors = 0; int lane_indices = 0; - Array vectors = MutateArray(op->vectors, &lane_vectors); - Array indices = MutateArray(op->indices, &lane_indices); + ffi::Array vectors = MutateArray(op->vectors, &lane_vectors); + ffi::Array indices = MutateArray(op->indices, &lane_indices); if (vectors.same_as(op->vectors) && indices.same_as(op->indices)) { - return 
GetRef(op); + return ffi::GetRef(op); } int new_vec_length = Downcast(var_lanes_)->value / op->vectors[0].dtype().lanes(); @@ -689,10 +689,10 @@ class Vectorizer : public StmtMutator, public ExprFunctor(op); + auto store = ffi::GetRef(op); auto fmutate = [this](const PrimExpr& index) { return this->VisitExpr(index); }; - Array indices = op->indices.Map(fmutate); + ffi::Array indices = op->indices.Map(fmutate); PrimExpr value = this->VisitExpr(op->value); @@ -746,11 +746,11 @@ class Vectorizer : public StmtMutator, public ExprFunctorextent.dtype().is_scalable_or_fixed_length_vector()); PrimExpr extent = this->VisitExpr(op->extent); if (extent.dtype().is_scalable_or_fixed_length_vector()) { - return Scalarize(GetRef(op)); + return Scalarize(ffi::GetRef(op)); } Stmt body = this->VisitStmt(op->body); if (extent.same_as(op->extent) && body.same_as(op->body)) { - return GetRef(op); + return ffi::GetRef(op); } else { return For(op->loop_var, op->min, extent, op->kind, body, op->thread_binding, op->annotations); @@ -766,7 +766,7 @@ class Vectorizer : public StmtMutator, public ExprFunctorVisitStmt(op->then_case); - Optional else_case = std::nullopt; + ffi::Optional else_case = std::nullopt; if (op->else_case) { else_case = this->VisitStmt(op->else_case.value()); } @@ -782,11 +782,11 @@ class Vectorizer : public StmtMutator, public ExprFunctor(op)); + return Scalarize(ffi::GetRef(op)); } if (condition.same_as(op->condition) && then_case.same_as(op->then_case) && else_case.same_as(op->else_case)) { - return GetRef(op); + return ffi::GetRef(op); } else { return IfThenElse(condition, then_case, else_case); } @@ -802,7 +802,7 @@ class Vectorizer : public StmtMutator, public ExprFunctor(op)); + Scalarize(ffi::GetRef(op)); } ICHECK(!let_binding_.count(op->var)) << "SSA violation, a single var is binded twice"; let_binding_[op->var] = value; @@ -816,7 +816,7 @@ class Vectorizer : public StmtMutator, public ExprFunctorvar] = op->var; Stmt body = this->VisitStmt(op->body); if (value.same_as(op->value) && body.same_as(op->body)) { - return GetRef(op); + return ffi::GetRef(op); } else { return LetStmt(op->var, value, body); } @@ -828,16 +828,16 @@ class Vectorizer : public StmtMutator, public ExprFunctorVisitExpr(op->condition); if (condition.dtype().is_scalable_or_fixed_length_vector()) { LOG(WARNING) << "Cannot handle vector extent in alloc of " << op->buffer_var->name_hint; - return Scalarize(GetRef(op)); + return Scalarize(ffi::GetRef(op)); } // Mutate the extents - Array extents; + ffi::Array extents; for (const auto& extent : op->extents) { PrimExpr new_ext = this->VisitExpr(extent); if (new_ext.dtype().is_scalable_or_fixed_length_vector()) { LOG(WARNING) << "Cannot handle vector extent in alloc of " << op->buffer_var->name_hint; - return Scalarize(GetRef(op)); + return Scalarize(ffi::GetRef(op)); } extents.push_back(new_ext); } @@ -887,7 +887,7 @@ class Vectorizer : public StmtMutator, public ExprFunctor MutateArray(Array arr, int* p_lanes) { + ffi::Array MutateArray(ffi::Array arr, int* p_lanes) { if (arr.size() == 0) return arr; int& lanes = *p_lanes; bool changed = false; @@ -907,7 +907,7 @@ class Vectorizer : public StmtMutator, public ExprFunctor(new_arr); + return ffi::Array(new_arr); } template PrimExpr BinaryVec(const T* op) { @@ -915,7 +915,7 @@ class Vectorizer : public StmtMutator, public ExprFunctorVisitExpr(op->a); PrimExpr b = this->VisitExpr(op->b); if (a.same_as(op->a) && b.same_as(op->b)) { - return GetRef(op); + return ffi::GetRef(op); } else { int a_lanes = 
a.dtype().get_lanes_or_vscale_factor(); int b_lanes = b.dtype().get_lanes_or_vscale_factor(); @@ -929,7 +929,7 @@ class Vectorizer : public StmtMutator, public ExprFunctorVisitExpr(op->a); PrimExpr b = this->VisitExpr(op->b); if (a.same_as(op->a) && b.same_as(op->b)) { - return GetRef(op); + return ffi::GetRef(op); } else { int a_lanes = a.dtype().get_lanes_or_vscale_factor(); int b_lanes = b.dtype().get_lanes_or_vscale_factor(); diff --git a/src/topi/broadcast.cc b/src/topi/broadcast.cc index 1ca901c6fbf5..65cbe3680572 100644 --- a/src/topi/broadcast.cc +++ b/src/topi/broadcast.cc @@ -52,7 +52,8 @@ TVM_FFI_STATIC_INIT_BLOCK({ refl::GlobalDef() .def_packed("topi.broadcast_to", [](ffi::PackedArgs args, ffi::Any* rv) { - *rv = broadcast_to(args[0].cast(), args[1].cast>()); + *rv = broadcast_to(args[0].cast(), + args[1].cast>()); }) .TOPI_DEF_BCAST_OP("topi.add", topi::add) .TOPI_DEF_BCAST_OP("topi.subtract", topi::subtract) diff --git a/src/topi/einsum.cc b/src/topi/einsum.cc index 9586b9c5575e..32131e975b3d 100644 --- a/src/topi/einsum.cc +++ b/src/topi/einsum.cc @@ -136,26 +136,26 @@ class EinsumBuilder { * \param equation The Einsum equation * \param input_shapes The shapes of the input tensors */ - EinsumBuilder(EinsumEquation equation, Array> input_shapes) + EinsumBuilder(EinsumEquation equation, ffi::Array> input_shapes) : equation_(equation), input_shapes_(input_shapes) {} /*! * \brief Run the shape inference * \return The inferred shape of the output */ - Array InferShape() { + ffi::Array InferShape() { CHECK_EQ(equation_.inputs.size(), input_shapes_.size()) << "Number of operands does not match the " "equation"; - std::vector> + std::vector> ellipis_shapes; // the sub-shape covered by the ellipsis for each operand // Step 1: Collect the broadcasted extent for each label for (int operand_index = 0; operand_index < static_cast(input_shapes_.size()); ++operand_index) { const EinsumEquation::Subscript subscript = equation_.inputs[operand_index]; - const Array& input_shape = input_shapes_[operand_index]; + const ffi::Array& input_shape = input_shapes_[operand_index]; int current_dim = 0; for (auto label : subscript) { @@ -182,14 +182,16 @@ class EinsumBuilder { // Step 2: Infer the shape of the ellipsis if exists // The ellipsis may cover different number of dimensions for each operand, these sub-shapes // need to be broadcasted to the shape with the maximum number of dimensions - Array ellipsis_shape; + ffi::Array ellipsis_shape; if (ellipis_shapes.size()) { - ellipsis_shape = *std::max_element( - ellipis_shapes.begin(), ellipis_shapes.end(), - [](const Array& a, const Array& b) { return a.size() < b.size(); }); - for (const Array& shape : ellipis_shapes) { + ellipsis_shape = + *std::max_element(ellipis_shapes.begin(), ellipis_shapes.end(), + [](const ffi::Array& a, const ffi::Array& b) { + return a.size() < b.size(); + }); + for (const ffi::Array& shape : ellipis_shapes) { auto common_shape = detail::BroadcastShape(ellipsis_shape, shape).common_shape; - ellipsis_shape = Array(common_shape.begin(), common_shape.end()); + ellipsis_shape = ffi::Array(common_shape.begin(), common_shape.end()); } } @@ -205,10 +207,10 @@ class EinsumBuilder { return output_shape_; } - PrimExpr BuildOutputExpr(const Array inputs, const Array& indices) { + PrimExpr BuildOutputExpr(const ffi::Array inputs, const ffi::Array& indices) { std::unordered_map label_to_index; - Array ellipsis_indices; - Array reduce_axes; + ffi::Array ellipsis_indices; + ffi::Array reduce_axes; PrepareOutputIndicesMapping(indices, 
&label_to_index, &ellipsis_indices); PrepareReductionIndicesMapping(indices, &label_to_index, &ellipsis_indices, &reduce_axes); @@ -234,14 +236,15 @@ class EinsumBuilder { /*! * \brief Prepare mapping from label (including ellipsis) to the output indices */ - void PrepareOutputIndicesMapping(const Array& indices, + void PrepareOutputIndicesMapping(const ffi::Array& indices, std::unordered_map* label_to_index, - Array* ellipsis_indices) { + ffi::Array* ellipsis_indices) { int i = 0; for (auto label : equation_.output) { if (label == EinsumEquation::kEllipsis) { auto ellipsis_ndim = ellipsis_shape_.value().size(); - *ellipsis_indices = Array(indices.begin() + i, indices.begin() + i + ellipsis_ndim); + *ellipsis_indices = + ffi::Array(indices.begin() + i, indices.begin() + i + ellipsis_ndim); i += ellipsis_ndim; } else { label_to_index->emplace(label, indices[i++]); @@ -255,8 +258,9 @@ class EinsumBuilder { * necessary) to the reduction axes */ void PrepareReductionIndicesMapping( - const Array& indices, std::unordered_map* label_to_index, - Array* ellipsis_indices, Array* reduction_axes) { + const ffi::Array& indices, + std::unordered_map* label_to_index, + ffi::Array* ellipsis_indices, ffi::Array* reduction_axes) { // Collect labels that need to be reduced, which is the union(input_labels) - output_labels std::set reduction_labels; for (const EinsumEquation::Subscript& subscript : equation_.inputs) { @@ -288,18 +292,18 @@ class EinsumBuilder { } } - Array GetIndicesForOperand( + ffi::Array GetIndicesForOperand( int operand_index, const std::unordered_map& label_to_index, - const Array& ellipsis_indices) { + const ffi::Array& ellipsis_indices) { const EinsumEquation::Subscript& subscript = equation_.inputs[operand_index]; - Array indices; // the indices for the operand - const Array input_shape = input_shapes_[operand_index]; + ffi::Array indices; // the indices for the operand + const ffi::Array input_shape = input_shapes_[operand_index]; int i = 0; // index of the operand shape for (char label : subscript) { if (label == EinsumEquation::kEllipsis) { // Ellipsis - Array ellipsis_shape = ellipsis_shape_.value(); + ffi::Array ellipsis_shape = ellipsis_shape_.value(); int ellipsis_ndim = static_cast(input_shape.size()) - static_cast(subscript.size()) + 1; // use last 'ellipsis_ndim' axes @@ -320,24 +324,24 @@ class EinsumBuilder { } EinsumEquation equation_; - Array> input_shapes_; + ffi::Array> input_shapes_; // intermediate results of shape inference // The output shape - Array output_shape_; + ffi::Array output_shape_; // The extent of each label with broadcast rules applied std::unordered_map label_to_extent_; // The shape of the ellipsis if ellipsis is used. The shape covered by the // ellipsis in each operand might be different from this, this is the common // shape among them according to the broadcast rules. 
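The ellipsis shape inference above can be restated compactly. This is a sketch under the same assumptions as the file (in particular `topi::detail::BroadcastShape`, whose `common_shape` result the code already relies on); `CommonEllipsisShape` is an invented name.

.. code:: c++

   // Start from the longest ellipsis sub-shape, then fold each remaining
   // sub-shape into it under the broadcast rules.
   ffi::Array<PrimExpr> CommonEllipsisShape(
       const std::vector<ffi::Array<PrimExpr>>& sub_shapes) {
     ICHECK(!sub_shapes.empty());
     ffi::Array<PrimExpr> common = *std::max_element(
         sub_shapes.begin(), sub_shapes.end(),
         [](const ffi::Array<PrimExpr>& a, const ffi::Array<PrimExpr>& b) {
           return a.size() < b.size();
         });
     for (const ffi::Array<PrimExpr>& shape : sub_shapes) {
       auto result = topi::detail::BroadcastShape(common, shape).common_shape;
       common = ffi::Array<PrimExpr>(result.begin(), result.end());
     }
     return common;
   }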
- Optional> ellipsis_shape_; + ffi::Optional> ellipsis_shape_; }; -Tensor einsum(const std::string& subscripts_str, const Array inputs, std::string name, +Tensor einsum(const std::string& subscripts_str, const ffi::Array inputs, std::string name, std::string tag) { EinsumEquation equation = EinsumEquation::FromString(subscripts_str); - Array> input_shapes; + ffi::Array> input_shapes; for (const Tensor& input : inputs) { input_shapes.push_back(input->shape); } @@ -345,12 +349,14 @@ Tensor einsum(const std::string& subscripts_str, const Array inputs, std auto output_shape = einsum_builder.InferShape(); return te::compute( output_shape, - [&](const Array& indices) { return einsum_builder.BuildOutputExpr(inputs, indices); }, + [&](const ffi::Array& indices) { + return einsum_builder.BuildOutputExpr(inputs, indices); + }, name, tag); } -Array InferEinsumShape(const std::string& subscripts, - const std::vector>& operands) { +ffi::Array InferEinsumShape(const std::string& subscripts, + const std::vector>& operands) { EinsumEquation equation = EinsumEquation::FromString(subscripts); EinsumBuilder einsum_builder = EinsumBuilder(equation, operands); return einsum_builder.InferShape(); @@ -359,7 +365,7 @@ Array InferEinsumShape(const std::string& subscripts, TVM_FFI_STATIC_INIT_BLOCK({ namespace refl = tvm::ffi::reflection; refl::GlobalDef().def_packed("topi.einsum", [](ffi::PackedArgs args, ffi::Any* rv) { - *rv = einsum(args[0].cast(), args[1].cast>()); + *rv = einsum(args[0].cast(), args[1].cast>()); }); }); diff --git a/src/topi/elemwise.cc b/src/topi/elemwise.cc index b60256cea5f5..718f078dbe9f 100644 --- a/src/topi/elemwise.cc +++ b/src/topi/elemwise.cc @@ -100,13 +100,13 @@ TVM_FFI_STATIC_INIT_BLOCK({ }) .def_packed("topi.elemwise_sum", [](ffi::PackedArgs args, ffi::Any* rv) { - *rv = elemwise_sum(args[0].cast>()); + *rv = elemwise_sum(args[0].cast>()); }) .def_packed("topi.sign", [](ffi::PackedArgs args, ffi::Any* rv) { *rv = sign(args[0].cast()); }) .def_packed("topi.full", [](ffi::PackedArgs args, ffi::Any* rv) { - *rv = full(args[0].cast>(), args[1].cast(), + *rv = full(args[0].cast>(), args[1].cast(), args[2].cast()); }) .def_packed("topi.full_like", diff --git a/src/topi/nn.cc b/src/topi/nn.cc index d872bac2ce30..e77508a912d5 100644 --- a/src/topi/nn.cc +++ b/src/topi/nn.cc @@ -62,21 +62,21 @@ TVM_FFI_STATIC_INIT_BLOCK({ }) .def_packed("topi.nn.pad", [](ffi::PackedArgs args, ffi::Any* rv) { - *rv = pad(args[0].cast(), args[1].cast>(), - args[2].cast>(), args[3].cast()); + *rv = pad(args[0].cast(), args[1].cast>(), + args[2].cast>(), args[3].cast()); }) .def_packed("topi.nn.space_to_batch_nd", [](ffi::PackedArgs args, ffi::Any* rv) { *rv = space_to_batch_nd( - args[0].cast(), args[1].cast>(), - args[2].cast>(), args[3].cast>(), + args[0].cast(), args[1].cast>(), + args[2].cast>(), args[3].cast>(), args[4].cast()); }) .def_packed("topi.nn.batch_to_space_nd", [](ffi::PackedArgs args, ffi::Any* rv) { *rv = batch_to_space_nd( - args[0].cast(), args[1].cast>(), - args[2].cast>(), args[3].cast>(), + args[0].cast(), args[1].cast>(), + args[2].cast>(), args[3].cast>(), args[4].cast()); }) .def_packed("topi.nn.nll_loss", [](ffi::PackedArgs args, ffi::Any* rv) { @@ -107,7 +107,7 @@ TVM_FFI_STATIC_INIT_BLOCK({ TVM_FFI_STATIC_INIT_BLOCK({ namespace refl = tvm::ffi::reflection; refl::GlobalDef().def_packed("topi.nn.dilate", [](ffi::PackedArgs args, ffi::Any* rv) { - *rv = nn::dilate(args[0].cast(), args[1].cast>(), + *rv = nn::dilate(args[0].cast(), args[1].cast>(), args[2].cast()); }); }); @@ -144,8 
+144,8 @@ TVM_FFI_STATIC_INIT_BLOCK({ [](ffi::PackedArgs args, ffi::Any* rv) { *rv = nn::pool_grad( args[0].cast(), args[1].cast(), - args[2].cast>(), args[3].cast>(), - args[4].cast>(), + args[2].cast>(), args[3].cast>(), + args[4].cast>(), static_cast(args[5].cast()), args[6].cast(), args[7].cast(), args[8].cast()); }) @@ -158,46 +158,46 @@ TVM_FFI_STATIC_INIT_BLOCK({ .def_packed("topi.nn.adaptive_pool1d", [](ffi::PackedArgs args, ffi::Any* rv) { *rv = nn::adaptive_pool1d(args[0].cast(), - args[1].cast>(), + args[1].cast>(), static_cast(args[2].cast()), args[3].cast()); }) .def_packed("topi.nn.adaptive_pool", [](ffi::PackedArgs args, ffi::Any* rv) { *rv = nn::adaptive_pool(args[0].cast(), - args[1].cast>(), + args[1].cast>(), static_cast(args[2].cast()), args[3].cast()); }) .def_packed("topi.nn.adaptive_pool3d", [](ffi::PackedArgs args, ffi::Any* rv) { *rv = nn::adaptive_pool3d(args[0].cast(), - args[1].cast>(), + args[1].cast>(), static_cast(args[2].cast()), args[3].cast()); }) .def_packed("topi.nn.pool1d", [](ffi::PackedArgs args, ffi::Any* rv) { *rv = nn::pool1d( - args[0].cast(), args[1].cast>(), - args[2].cast>(), args[3].cast>(), - args[4].cast>(), + args[0].cast(), args[1].cast>(), + args[2].cast>(), args[3].cast>(), + args[4].cast>(), static_cast(args[5].cast()), args[6].cast(), args[7].cast(), args[8].cast()); }) .def_packed("topi.nn.pool2d", [](ffi::PackedArgs args, ffi::Any* rv) { *rv = nn::pool2d( - args[0].cast(), args[1].cast>(), - args[2].cast>(), args[3].cast>(), - args[4].cast>(), + args[0].cast(), args[1].cast>(), + args[2].cast>(), args[3].cast>(), + args[4].cast>(), static_cast(args[5].cast()), args[6].cast(), args[7].cast(), args[8].cast()); }) .def_packed("topi.nn.pool3d", [](ffi::PackedArgs args, ffi::Any* rv) { - *rv = nn::pool3d(args[0].cast(), args[1].cast>(), - args[2].cast>(), args[3].cast>(), - args[4].cast>(), + *rv = nn::pool3d(args[0].cast(), args[1].cast>(), + args[2].cast>(), args[3].cast>(), + args[4].cast>(), static_cast(args[5].cast()), args[6].cast(), args[7].cast(), args[8].cast()); }); @@ -239,7 +239,7 @@ TVM_FFI_STATIC_INIT_BLOCK({ namespace refl = tvm::ffi::reflection; refl::GlobalDef().def_packed("topi.nn.layer_norm", [](ffi::PackedArgs args, ffi::Any* rv) { *rv = nn::layer_norm(args[0].cast(), args[1].cast(), - args[2].cast(), args[3].cast>(), + args[2].cast(), args[3].cast>(), args[4].cast()); }); }); @@ -250,7 +250,7 @@ TVM_FFI_STATIC_INIT_BLOCK({ refl::GlobalDef().def_packed("topi.nn.group_norm", [](ffi::PackedArgs args, ffi::Any* rv) { *rv = nn::group_norm(args[0].cast(), args[1].cast(), args[2].cast(), args[3].cast(), args[4].cast(), - args[5].cast>(), args[6].cast()); + args[5].cast>(), args[6].cast()); }); }); @@ -260,7 +260,7 @@ TVM_FFI_STATIC_INIT_BLOCK({ refl::GlobalDef().def_packed("topi.nn.instance_norm", [](ffi::PackedArgs args, ffi::Any* rv) { *rv = nn::instance_norm(args[0].cast(), args[1].cast(), args[2].cast(), args[3].cast(), - args[4].cast>(), args[5].cast()); + args[4].cast>(), args[5].cast()); }); }); @@ -269,7 +269,7 @@ TVM_FFI_STATIC_INIT_BLOCK({ namespace refl = tvm::ffi::reflection; refl::GlobalDef().def_packed("topi.nn.rms_norm", [](ffi::PackedArgs args, ffi::Any* rv) { *rv = nn::rms_norm(args[0].cast(), args[1].cast(), - args[2].cast>(), args[3].cast()); + args[2].cast>(), args[3].cast()); }); }); diff --git a/src/topi/reduction.cc b/src/topi/reduction.cc index 7b10c7771b32..503840df8aae 100644 --- a/src/topi/reduction.cc +++ b/src/topi/reduction.cc @@ -76,7 +76,7 @@ TVM_FFI_STATIC_INIT_BLOCK({ args[2].cast()); }) 
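Every `def_packed` hunk in `src/topi/` is the same mechanical rewrite: the template argument of `cast` gains its `ffi::` qualifier. A hypothetical standalone registration showing the resulting pattern; the global name `topi.example.reshape_copy` is invented for illustration, and `<tvm/topi/transform.h>` is assumed for `topi::reshape`.

.. code:: c++

   TVM_FFI_STATIC_INIT_BLOCK({
     namespace refl = tvm::ffi::reflection;
     refl::GlobalDef().def_packed(
         "topi.example.reshape_copy", [](tvm::ffi::PackedArgs args, tvm::ffi::Any* rv) {
           auto x = args[0].cast<tvm::te::Tensor>();
           auto newshape = args[1].cast<tvm::ffi::Array<tvm::PrimExpr>>();
           *rv = tvm::topi::reshape(x, newshape);
         });
   });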
.def_packed("topi.collapse_sum", [](ffi::PackedArgs args, ffi::Any* rv) { - *rv = topi::collapse_sum(args[0].cast(), args[1].cast>()); + *rv = topi::collapse_sum(args[0].cast(), args[1].cast>()); }); }); diff --git a/src/topi/transform.cc b/src/topi/transform.cc index 2324e845b934..911f9320b55a 100644 --- a/src/topi/transform.cc +++ b/src/topi/transform.cc @@ -48,7 +48,7 @@ TVM_FFI_STATIC_INIT_BLOCK({ .def_packed("topi.transpose", [](ffi::PackedArgs args, ffi::Any* rv) { *rv = transpose(args[0].cast(), - args[1].cast>>()); + args[1].cast>>()); }) .def_packed("topi.flip", [](ffi::PackedArgs args, ffi::Any* rv) { @@ -63,13 +63,13 @@ TVM_FFI_STATIC_INIT_BLOCK({ }) .def_packed("topi.reshape", [](ffi::PackedArgs args, ffi::Any* rv) { - *rv = reshape(args[0].cast(), args[1].cast>()); + *rv = reshape(args[0].cast(), args[1].cast>()); }) .def_packed("topi.sliding_window", [](ffi::PackedArgs args, ffi::Any* rv) { *rv = sliding_window(args[0].cast(), args[1].cast(), - args[2].cast>(), - args[3].cast>()); + args[2].cast>(), + args[3].cast>()); }) .def_packed("topi.squeeze", [](ffi::PackedArgs args, ffi::Any* rv) { @@ -77,11 +77,11 @@ TVM_FFI_STATIC_INIT_BLOCK({ }) .def_packed("topi.concatenate", [](ffi::PackedArgs args, ffi::Any* rv) { - *rv = concatenate(args[0].cast>(), args[1].cast()); + *rv = concatenate(args[0].cast>(), args[1].cast()); }) .def_packed("topi.stack", [](ffi::PackedArgs args, ffi::Any* rv) { - *rv = stack(args[0].cast>(), args[1].cast()); + *rv = stack(args[0].cast>(), args[1].cast()); }) .def_packed("topi.shape", [](ffi::PackedArgs args, ffi::Any* rv) { @@ -97,9 +97,9 @@ TVM_FFI_STATIC_INIT_BLOCK({ *rv = split_n_sections(args[0].cast(), args[1].cast(), args[2].cast()); } else { - *rv = - split_indices_array(args[0].cast(), - args[1].cast>(), args[2].cast()); + *rv = split_indices_array(args[0].cast(), + args[1].cast>(), + args[2].cast()); } }) .def_packed("topi.layout_transform", @@ -144,7 +144,8 @@ TVM_FFI_STATIC_INIT_BLOCK({ }) .def_packed("topi.meshgrid", [](ffi::PackedArgs args, ffi::Any* rv) { - *rv = meshgrid(args[0].cast>(), args[1].cast()); + *rv = meshgrid(args[0].cast>(), + args[1].cast()); }) .def_packed("topi.repeat", [](ffi::PackedArgs args, ffi::Any* rv) { @@ -153,7 +154,7 @@ TVM_FFI_STATIC_INIT_BLOCK({ }) .def_packed("topi.tile", [](ffi::PackedArgs args, ffi::Any* rv) { - *rv = tile(args[0].cast(), args[1].cast>()); + *rv = tile(args[0].cast(), args[1].cast>()); }) .def_packed("topi.gather", [](ffi::PackedArgs args, ffi::Any* rv) { @@ -172,9 +173,9 @@ TVM_FFI_STATIC_INIT_BLOCK({ }) .def_packed("topi.sparse_to_dense", [](ffi::PackedArgs args, ffi::Any* rv) { - *rv = - sparse_to_dense(args[0].cast(), args[1].cast>(), - args[2].cast(), args[3].cast()); + *rv = sparse_to_dense(args[0].cast(), + args[1].cast>(), + args[2].cast(), args[3].cast()); }) .def_packed("topi.matmul", [](ffi::PackedArgs args, ffi::Any* rv) { @@ -202,25 +203,25 @@ TVM_FFI_STATIC_INIT_BLOCK({ *rv = tensordot(args[0].cast(), args[1].cast(), args[2].cast()); } else { - Array axes = args[3].cast>(); + ffi::Array axes = args[3].cast>(); *rv = tensordot(args[0].cast(), args[1].cast(), - args[2].cast>(), axes); + args[2].cast>(), axes); } }) .def_packed( "topi.strided_slice", [](ffi::PackedArgs args, ffi::Any* rv) { te::Tensor x = args[0].cast(); - Array begin = args[1].cast>(); - Array end = args[2].cast>(); - Array strides = args[3].cast>(); - Array axes = args[4].cast>(); + ffi::Array begin = args[1].cast>(); + ffi::Array end = args[2].cast>(); + ffi::Array strides = args[3].cast>(); + ffi::Array axes 
                     = args[4].cast<ffi::Array<Integer>>();
                 bool assume_inbound = args[6].cast<bool>();
                 if (IsConstIntArray(begin) && IsConstIntArray(end) && IsConstIntArray(strides) &&
                     IsConstIntArray(x->shape)) {
-                  Array<Integer> begin_static = args[1].cast<Array<Integer>>();
-                  Array<Integer> end_static = args[2].cast<Array<Integer>>();
-                  Array<Integer> strides_static = args[3].cast<Array<Integer>>();
+                  ffi::Array<Integer> begin_static = args[1].cast<ffi::Array<Integer>>();
+                  ffi::Array<Integer> end_static = args[2].cast<ffi::Array<Integer>>();
+                  ffi::Array<Integer> strides_static = args[3].cast<ffi::Array<Integer>>();
                   auto slice_mode = args[5].cast<std::string>();
                   if (axes.size()) {
                     *rv = strided_slice_with_axes(x, begin_static, end_static, strides_static, axes,
@@ -245,7 +246,7 @@ TVM_FFI_STATIC_INIT_BLOCK({
            })
       .def("topi.relax_dynamic_strided_slice",
            [](te::Tensor x, te::Tensor begin, te::Tensor end, te::Tensor strides,
-              Array<PrimExpr> output_shape) {
+              ffi::Array<PrimExpr> output_shape) {
              return relax::dynamic_strided_slice(x, begin, end, strides, output_shape);
            })
       .def_packed("topi.one_hot",
@@ -266,7 +267,7 @@ TVM_FFI_STATIC_INIT_BLOCK({
                       k1, k2, super_diag_right_align, sub_diag_right_align);
            })
       .def("topi.adv_index",
-           [](te::Tensor x, Array<te::Tensor> indices) { return adv_index(x, indices); });
+           [](te::Tensor x, ffi::Array<te::Tensor> indices) { return adv_index(x, indices); });
 });

 }  // namespace topi
diff --git a/src/topi/utils.cc b/src/topi/utils.cc
index 6e5c997739d7..a518d28f0277 100644
--- a/src/topi/utils.cc
+++ b/src/topi/utils.cc
@@ -33,17 +33,17 @@ TVM_FFI_STATIC_INIT_BLOCK({
   refl::GlobalDef()
       .def_packed("topi.utils.is_empty_shape",
                   [](ffi::PackedArgs args, ffi::Any* rv) {
-                    *rv = topi::detail::is_empty_shape(args[0].cast<Array<PrimExpr>>());
+                    *rv = topi::detail::is_empty_shape(args[0].cast<ffi::Array<PrimExpr>>());
                   })
       .def_packed("topi.utils.bilinear_sample_nchw",
                   [](ffi::PackedArgs args, ffi::Any* rv) {
                     *rv = detail::bilinear_sample_nchw(
-                        args[0].cast<te::Tensor>(), args[1].cast<Array<PrimExpr>>(),
+                        args[0].cast<te::Tensor>(), args[1].cast<ffi::Array<PrimExpr>>(),
                         args[2].cast<PrimExpr>(), args[3].cast<PrimExpr>());
                   })
       .def_packed("topi.utils.bilinear_sample_nhwc",
                   [](ffi::PackedArgs args, ffi::Any* rv) {
                     *rv = detail::bilinear_sample_nhwc(args[0].cast<te::Tensor>(),
-                                                       args[1].cast<Array<PrimExpr>>(),
+                                                       args[1].cast<ffi::Array<PrimExpr>>(),
                                                        args[2].cast<PrimExpr>(),
                                                        args[3].cast<PrimExpr>());
                   });
 });
diff --git a/tests/cpp-runtime/hexagon/hexagon_buffer_tests.cc b/tests/cpp-runtime/hexagon/hexagon_buffer_tests.cc
index 6f9c9f0f6f7b..febf484f8161 100644
--- a/tests/cpp-runtime/hexagon/hexagon_buffer_tests.cc
+++ b/tests/cpp-runtime/hexagon/hexagon_buffer_tests.cc
@@ -22,30 +22,31 @@
 
 #include "../src/runtime/hexagon/hexagon_buffer.h"
 
+using namespace tvm;
 using namespace tvm::runtime;
 using namespace tvm::runtime::hexagon;
 using namespace tvm::ffi;
 
 TEST(HexagonBuffer, default_scope) {
-  Optional<String> scope;
+  ffi::Optional<ffi::String> scope;
   HexagonBuffer hb(8 /* nbytes */, 8 /* alignment */, scope);
   EXPECT_EQ(hb.GetStorageScope(), HexagonBuffer::StorageScope::kDDR);
 }
 
 TEST(HexagonBuffer, ddr_scope) {
-  Optional<String> scope(String("global"));
+  ffi::Optional<ffi::String> scope(ffi::String("global"));
   HexagonBuffer hb(8 /* nbytes */, 8 /* alignment */, scope);
   EXPECT_EQ(hb.GetStorageScope(), HexagonBuffer::StorageScope::kDDR);
 }
 
 TEST(HexagonBuffer, vtcm_scope) {
-  Optional<String> scope(String("global.vtcm"));
+  ffi::Optional<ffi::String> scope(ffi::String("global.vtcm"));
   HexagonBuffer hb(8 /* nbytes */, 8 /* alignment */, scope);
   EXPECT_EQ(hb.GetStorageScope(), HexagonBuffer::StorageScope::kVTCM);
 }
 
 TEST(HexagonBuffer, invalid_scope) {
-  Optional<String> scope(String("invalid"));
+  ffi::Optional<ffi::String> scope(ffi::String("invalid"));
   EXPECT_THROW(HexagonBuffer hb(8 /* nbytes */, 8 /* alignment */, scope), InternalError);
 }
 
@@ -268,7 +269,7 @@ TEST(HexagonBuffer, macro_copies_overlapping_regions_merged) {
 }
 
 TEST(HexagonBuffer, copy_from) {
-  Optional<String> scope(String("global"));
+  ffi::Optional<ffi::String> scope(ffi::String("global"));
   HexagonBuffer hb(8 /* nbytes */, 8 /* alignment */, scope);
 
   std::vector<uint8_t> data{0, 1, 2, 3, 4, 5, 6, 7};
@@ -281,7 +282,7 @@ TEST(HexagonBuffer, copy_from) {
 }
 
 TEST(HexagonBuffer, copy_from_invalid_size) {
-  Optional<String> scope(String("global"));
+  ffi::Optional<ffi::String> scope(ffi::String("global"));
   std::vector<uint8_t> data{0, 1, 2, 3, 4, 5, 6, 7};
 
   // HexagonBuffer too small
@@ -290,7 +291,7 @@ TEST(HexagonBuffer, copy_from_invalid_size) {
 }
 
 TEST(HexagonBuffer, copy_from_smaller_size) {
-  Optional<String> scope(String("global"));
+  ffi::Optional<ffi::String> scope(ffi::String("global"));
   std::vector<uint8_t> data{0, 1, 2, 3, 4, 5, 6, 7};
 
   // HexagonBuffer is big
@@ -299,25 +300,25 @@ TEST(HexagonBuffer, copy_from_smaller_size) {
 }
 
 TEST(HexagonBuffer, nd) {
-  Optional<String> def;
+  ffi::Optional<ffi::String> def;
   HexagonBuffer hb_default(2 /* ndim */, 4 /* nbytes */, 8 /* alignment */, def);
   EXPECT_EQ(hb_default.GetStorageScope(), HexagonBuffer::StorageScope::kDDR);
 
-  Optional<String> global(String("global"));
+  ffi::Optional<ffi::String> global(ffi::String("global"));
   HexagonBuffer hb_global(2 /* ndim */, 4 /* nbytes */, 8 /* alignment */, global);
   EXPECT_EQ(hb_global.GetStorageScope(), HexagonBuffer::StorageScope::kDDR);
 
-  Optional<String> vtcm(String("global.vtcm"));
+  ffi::Optional<ffi::String> vtcm(ffi::String("global.vtcm"));
   HexagonBuffer hb_vtcm(2 /* ndim */, 4 /* nbytes */, 8 /* alignment */, vtcm);
   EXPECT_EQ(hb_vtcm.GetStorageScope(), HexagonBuffer::StorageScope::kVTCM);
 
-  Optional<String> invalid(String("invalid"));
+  ffi::Optional<ffi::String> invalid(ffi::String("invalid"));
   EXPECT_THROW(HexagonBuffer hb_invalid(2 /* ndim */, 4 /* nbytes */, 8 /* alignment */, invalid),
                InternalError);
 }
 
 TEST(HexagonBuffer, nd_copy_from) {
-  Optional<String> scope(String("global"));
+  ffi::Optional<ffi::String> scope(ffi::String("global"));
   HexagonBuffer hb(2 /* ndim */, 4 /* nbytes */, 8 /* alignment */, scope);
 
   std::vector<uint8_t> data{0, 1, 2, 3, 4, 5, 6, 7};
@@ -335,10 +336,10 @@ TEST(HexagonBuffer, nd_copy_from) {
 }
 
 TEST(HexagonBuffer, 1d_copy_from_1d) {
-  Optional<String> global(String("global"));
+  ffi::Optional<ffi::String> global(ffi::String("global"));
   HexagonBuffer from(8 /* nbytes */, 8 /* alignment */, global);
 
-  Optional<String> vtcm(String("global.vtcm"));
+  ffi::Optional<ffi::String> vtcm(ffi::String("global.vtcm"));
   HexagonBuffer to(8 /* nbytes */, 8 /* alignment */, vtcm);
 
   std::vector<uint8_t> data{0, 1, 2, 3, 4, 5, 6, 7};
@@ -352,10 +353,10 @@ TEST(HexagonBuffer, 1d_copy_from_1d) {
 }
 
 TEST(HexagonBuffer, 2d_copy_from_1d) {
-  Optional<String> vtcm(String("global.vtcm"));
+  ffi::Optional<ffi::String> vtcm(ffi::String("global.vtcm"));
   HexagonBuffer hb1d(8 /* nbytes */, 8 /* alignment */, vtcm);
 
-  Optional<String> global(String("global"));
+  ffi::Optional<ffi::String> global(ffi::String("global"));
   HexagonBuffer hb2d(2 /* ndim */, 4 /* nbytes */, 8 /* alignment */, global);
 
   std::vector<uint8_t> data{0, 1, 2, 3, 4, 5, 6, 7};
@@ -374,10 +375,10 @@ TEST(HexagonBuffer, 2d_copy_from_1d) {
 }
 
 TEST(HexagonBuffer, 1d_copy_from_2d) {
-  Optional<String> vtcm(String("global.vtcm"));
+  ffi::Optional<ffi::String> vtcm(ffi::String("global.vtcm"));
   HexagonBuffer hb2d(2 /* ndim */, 4 /* nbytes */, 8 /* alignment */, vtcm);
 
-  Optional<String> global(String("global.vtcm"));
+  ffi::Optional<ffi::String> global(ffi::String("global.vtcm"));
   HexagonBuffer hb1d(8 /* nbytes */, 8 /* alignment */, global);
 
   std::vector<uint8_t> data{0, 1, 2, 3, 4, 5, 6, 7};
@@ -391,7 +392,7 @@ TEST(HexagonBuffer, 1d_copy_from_2d) {
 }
 
 TEST(HexagonBuffer, nd_copy_from_nd_invalid_size) {
-  Optional<String> scope(String("global"));
+  ffi::Optional<ffi::String> scope(ffi::String("global"));
   HexagonBuffer hb1d(8 /* nbytes */, 8 /* alignment */, scope);
   HexagonBuffer hb2d(2 /* ndim */, 4 /* nbytes */, 8 /* alignment */, scope);
@@ -405,7 +406,7 @@ TEST(HexagonBuffer, nd_copy_from_nd_invalid_size) {
 }
 
 TEST(HexagonBuffer, nd_copy_from_nd_smaller_size) {
-  Optional<String> scope(String("global"));
+  ffi::Optional<ffi::String> scope(ffi::String("global"));
   HexagonBuffer hb1d(8 /* nbytes */, 8 /* alignment */, scope);
   HexagonBuffer hb2d(2 /* ndim */, 4 /* nbytes */, 8 /* alignment */, scope);
@@ -419,7 +420,7 @@ TEST(HexagonBuffer, nd_copy_from_nd_smaller_size) {
 }
 
 TEST(HexagonBuffer, md_copy_from_nd) {
-  Optional<String> scope(String("global"));
+  ffi::Optional<ffi::String> scope(ffi::String("global"));
   HexagonBuffer hb3d(3 /* ndim */, 4 /* nbytes */, 8 /* alignment */, scope);
   HexagonBuffer hb4d(4 /* ndim */, 3 /* nbytes */, 8 /* alignment */, scope);
@@ -436,7 +437,7 @@ TEST(HexagonBuffer, md_copy_from_nd) {
 }
 
 TEST(HexagonBuffer, copy_to) {
-  Optional<String> scope(String("global"));
+  ffi::Optional<ffi::String> scope(ffi::String("global"));
   HexagonBuffer hb(8 /* nbytes */, 8 /* alignment */, scope);
 
   std::vector<uint8_t> data_in{0, 1, 2, 3, 4, 5, 6, 7};
@@ -451,7 +452,7 @@ TEST(HexagonBuffer, copy_to) {
 }
 
 TEST(HexagonBuffer, nd_copy_to) {
-  Optional<String> scope(String("global"));
+  ffi::Optional<ffi::String> scope(ffi::String("global"));
   HexagonBuffer hb(2 /* ndim */, 4 /* nbytes */, 8 /* alignment */, scope);
 
   std::vector<uint8_t> data_in{0, 1, 2, 3, 4, 5, 6, 7};
diff --git a/tests/cpp-runtime/hexagon/hexagon_device_api_tests.cc b/tests/cpp-runtime/hexagon/hexagon_device_api_tests.cc
index 6211bd63dfbc..9c74521091aa 100644
--- a/tests/cpp-runtime/hexagon/hexagon_device_api_tests.cc
+++ b/tests/cpp-runtime/hexagon/hexagon_device_api_tests.cc
@@ -21,6 +21,7 @@
 
 #include "../src/runtime/hexagon/hexagon_device_api.h"
 
+using namespace tvm;
 using namespace tvm::runtime;
 using namespace tvm::runtime::hexagon;
 using namespace tvm::ffi;
@@ -46,10 +47,10 @@ class HexagonDeviceAPITest : public ::testing::Test {
   int64_t shape1d[1]{256};
   int64_t shape2d[2]{256, 256};
   int64_t shape3d[3]{256, 256, 256};
-  Optional<String> default_scope;
-  Optional<String> invalid_scope = String("invalid");
-  Optional<String> global_scope = String("global");
-  Optional<String> global_vtcm_scope = String("global.vtcm");
+  ffi::Optional<ffi::String> default_scope;
+  ffi::Optional<ffi::String> invalid_scope = ffi::String("invalid");
+  ffi::Optional<ffi::String> global_scope = ffi::String("global");
+  ffi::Optional<ffi::String> global_vtcm_scope = ffi::String("global.vtcm");
 };
 
 TEST_F(HexagonDeviceAPITest, global) { CHECK(hexapi != nullptr); }
diff --git a/tests/cpp-runtime/hexagon/hexagon_user_dma_tests.cc b/tests/cpp-runtime/hexagon/hexagon_user_dma_tests.cc
index 2e47473f8a17..dd95a8fb37a7 100644
--- a/tests/cpp-runtime/hexagon/hexagon_user_dma_tests.cc
+++ b/tests/cpp-runtime/hexagon/hexagon_user_dma_tests.cc
@@ -21,6 +21,7 @@
 
 #include "../src/runtime/hexagon/hexagon_device_api.h"
 
+using namespace tvm;
 using namespace tvm::runtime;
 using namespace tvm::runtime::hexagon;
 using namespace tvm::ffi;
@@ -56,8 +57,8 @@ class HexagonUserDMATest : public ::testing::Test {
   uint32_t length = 0x4000;  // 16KB
   const bool ENABLE_BYPASS = true;
   const bool DISABLE_BYPASS = false;
-  Optional<String> global_scope = String("global");
-  Optional<String> global_vtcm_scope = String("global.vtcm");
+  ffi::Optional<ffi::String> global_scope = ffi::String("global");
+  ffi::Optional<ffi::String> global_vtcm_scope = ffi::String("global.vtcm");
 };
 
 TEST_F(HexagonUserDMATest, wait) {
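The Hexagon test hunks above all apply one mechanical rule: spellings such as `Optional<String>` and `String("global")`, which previously resolved through using-declarations into namespace `tvm`, are now written with an explicit `ffi::` qualifier. A minimal sketch of the resulting idiom, assuming the `tvm/ffi` container headers; the header paths and the `MakeScope` helper are illustrative, not part of the patch:

#include <tvm/ffi/container/optional.h>  // path assumed; adjust to the local layout
#include <tvm/ffi/string.h>              // path assumed

// Before: Optional<String> scope(String("global.vtcm"));
// After:  the same object, fully qualified at every use site.
tvm::ffi::Optional<tvm::ffi::String> MakeScope(bool use_vtcm) {
  if (use_vtcm) {
    return tvm::ffi::String("global.vtcm");  // Optional converts implicitly
  }
  return tvm::ffi::String("global");
}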
diff --git a/tests/cpp-runtime/hexagon/hexagon_vtcm_pool_tests.cc b/tests/cpp-runtime/hexagon/hexagon_vtcm_pool_tests.cc
index 3cf008c874ab..baa4035e47fb 100644
--- a/tests/cpp-runtime/hexagon/hexagon_vtcm_pool_tests.cc
+++ b/tests/cpp-runtime/hexagon/hexagon_vtcm_pool_tests.cc
@@ -21,6 +21,7 @@
 
 #include "../src/runtime/hexagon/hexagon_device_api.h"
 
+using namespace tvm;
 using namespace tvm::runtime;
 using namespace tvm::runtime::hexagon;
 using namespace tvm::ffi;
@@ -256,28 +257,28 @@ TEST_F(HexagonVtcmPoolTest, vtcm_alignment) {
   void* ptr;
 
   // Invalid alignments
-  EXPECT_THROW(test_hexbuffs->AllocateHexagonBuffer(min_bytes, 128 + 1, String("global")),
+  EXPECT_THROW(test_hexbuffs->AllocateHexagonBuffer(min_bytes, 128 + 1, ffi::String("global")),
                InternalError);
-  EXPECT_THROW(test_hexbuffs->AllocateHexagonBuffer(min_bytes, 2048 + 1, String("global")),
+  EXPECT_THROW(test_hexbuffs->AllocateHexagonBuffer(min_bytes, 2048 + 1, ffi::String("global")),
                InternalError);
 
   // Valid alignments, sizes need to be adjusted
-  ptr = test_hexbuffs->AllocateHexagonBuffer(1, 128, String("global"));
+  ptr = test_hexbuffs->AllocateHexagonBuffer(1, 128, ffi::String("global"));
   CHECK((reinterpret_cast<uintptr_t>(ptr) & 0x7F) == 0) << "Must be multiple of 128 " << ptr;
-  ptr = test_hexbuffs->AllocateHexagonBuffer(127, 128, String("global"));
+  ptr = test_hexbuffs->AllocateHexagonBuffer(127, 128, ffi::String("global"));
   CHECK((reinterpret_cast<uintptr_t>(ptr) & 0x7F) == 0) << "Must be multiple of 128 " << ptr;
-  ptr = test_hexbuffs->AllocateHexagonBuffer(129, 128, String("global"));
+  ptr = test_hexbuffs->AllocateHexagonBuffer(129, 128, ffi::String("global"));
   CHECK((reinterpret_cast<uintptr_t>(ptr) & 0x7F) == 0) << "Must be multiple of 128 " << ptr;
-  ptr = test_hexbuffs->AllocateHexagonBuffer(1, 2048, String("global"));
+  ptr = test_hexbuffs->AllocateHexagonBuffer(1, 2048, ffi::String("global"));
   CHECK((reinterpret_cast<uintptr_t>(ptr) & 0x7FF) == 0) << "Must be multiple of 2k " << ptr;
-  ptr = test_hexbuffs->AllocateHexagonBuffer(2047, 2048, String("global"));
+  ptr = test_hexbuffs->AllocateHexagonBuffer(2047, 2048, ffi::String("global"));
   CHECK((reinterpret_cast<uintptr_t>(ptr) & 0x7FF) == 0) << "Must be multiple of 2k " << ptr;
-  ptr = test_hexbuffs->AllocateHexagonBuffer(2049, 2048, String("global"));
+  ptr = test_hexbuffs->AllocateHexagonBuffer(2049, 2048, ffi::String("global"));
   CHECK((reinterpret_cast<uintptr_t>(ptr) & 0x7FF) == 0) << "Must be multiple of 2k " << ptr;
 
   test_hexbuffs.reset();
diff --git a/tests/cpp-runtime/opencl/opencl_compile_to_bin.cc b/tests/cpp-runtime/opencl/opencl_compile_to_bin.cc
index 1097a21128e1..0ab2f5ff6855 100644
--- a/tests/cpp-runtime/opencl/opencl_compile_to_bin.cc
+++ b/tests/cpp-runtime/opencl/opencl_compile_to_bin.cc
@@ -194,7 +194,7 @@ TEST_F(OpenCLCompileBin, SourceVsBinaryCompilationPerf) {
   {
     OpenCLModuleNode module(m_dataSrc, "cl", m_fmap, std::string());
     module.Init();
-    module.GetFunction("opencl.SetPreCompiledPrograms").value()(tvm::String(bytes));
+    module.GetFunction("opencl.SetPreCompiledPrograms").value()(tvm::ffi::String(bytes));
     Timestamp comp_start = std::chrono::high_resolution_clock::now();
     for (size_t i = 0; i < m_kernelNames.size(); ++i) {
       OpenCLModuleNode::KTRefEntry e = {i, 1};
diff --git a/tests/cpp-runtime/opencl/texture_copy_test.cc b/tests/cpp-runtime/opencl/texture_copy_test.cc
index c9ee44515d1f..001e65b90126 100644
--- a/tests/cpp-runtime/opencl/texture_copy_test.cc
+++ b/tests/cpp-runtime/opencl/texture_copy_test.cc
@@ -63,7 +63,7 @@ TEST(TextureCopy, HostDeviceRT) {
   std::vector<int64_t> shape{16, 16, 4};
   auto cpu_arr0 = runtime::Tensor::Empty(shape, {kDLFloat, 32, 1}, {kDLCPU, 0});
   auto cpu_arr1 = runtime::Tensor::Empty(shape, {kDLFloat, 32, 1}, {kDLCPU, 0});
-  String mem_scope = "global.texture";
+  ffi::String mem_scope = "global.texture";
   auto opencl_txarr0 = runtime::Tensor::Empty(shape, {kDLFloat, 32, 1}, {kDLOpenCL, 0}, mem_scope);
 
   size_t size = 1;
@@ -97,7 +97,7 @@ TEST_F(TextureCopyTest, ViewBufferAsBuffer) {
   auto cpu_arr = runtime::Tensor::Empty(shape, {kDLFloat, 32, 1}, {kDLCPU, 0});
   auto cpu_arr_ret = runtime::Tensor::Empty(shape, {kDLFloat, 32, 1}, {kDLCPU, 0});
 
-  String mem_scope = "global";
+  ffi::String mem_scope = "global";
   DLDevice cl_dev = {kDLOpenCL, 0};
 
   auto allocator = MemoryManager::GetOrCreateAllocator(cl_dev, AllocatorType::kPooled);
diff --git a/tests/cpp/data_type_rewriter_test.cc b/tests/cpp/data_type_rewriter_test.cc
index c5e6d4f75843..1eec334344b3 100644
--- a/tests/cpp/data_type_rewriter_test.cc
+++ b/tests/cpp/data_type_rewriter_test.cc
@@ -37,7 +37,7 @@ TYPED_TEST_SUITE(DataTypeLegalizerBinaryOp, BinaryOpTypes);
 TYPED_TEST(DataTypeLegalizerBinaryOp, Basic) {
   using RefType = TypeParam;
   using NodeType = typename RefType::ContainerType;
-  auto node = make_object<NodeType>();
+  auto node = ffi::make_object<NodeType>();
   node->a = Var("a", DataType::Int(32));
   node->b = IntImm(DataType::Int(64), 2);
   DataTypeLegalizer legalizer;
@@ -48,7 +48,7 @@ TYPED_TEST(DataTypeLegalizerBinaryOp, Basic) {
 }
 
 TEST(DataTypeLegalizer, Select) {
-  auto node = make_object<SelectNode>();
+  auto node = ffi::make_object<SelectNode>();
   node->condition = Var("cond", DataType::Bool());
   node->true_value = Var("a", DataType::Int(64));
   node->false_value = IntImm(DataType::Int(32), 2);
@@ -73,8 +73,8 @@ TEST(DataTypeLegalizer, IfThenElse) {
 }
 
 TEST(DataTypeLegalizer, Block) {
-  auto block_node = make_object<BlockNode>();
-  auto iter_var_node = make_object<IterVarNode>();
+  auto block_node = ffi::make_object<BlockNode>();
+  auto iter_var_node = ffi::make_object<IterVarNode>();
   iter_var_node->var = Var("i", DataType::Int(32));
   iter_var_node->dom =
       Range::FromMinExtent(IntImm(DataType::Int(64), 0), IntImm(DataType::Int(64), 10));
@@ -84,12 +84,12 @@ TEST(DataTypeLegalizer, Block) {
   block_node->writes = {};
   block_node->name_hint = "block";
   block_node->body = Evaluate(Integer(0));
-  auto block_realize_node = make_object<BlockRealizeNode>();
+  auto block_realize_node = ffi::make_object<BlockRealizeNode>();
   auto loop_var = Var("i", DataType::Int(32));
   block_realize_node->iter_values = {loop_var};
   block_realize_node->predicate = const_true();
   block_realize_node->block = Block(block_node);
-  auto for_node = make_object<ForNode>();
+  auto for_node = ffi::make_object<ForNode>();
   for_node->loop_var = loop_var;
   for_node->min = IntImm(DataType::Int(64), 0);
   for_node->extent = IntImm(DataType::Int(64), 10);
@@ -113,7 +113,7 @@ TEST(DataTypeLegalizer, Block) {
 }
 
 TEST(DataTypeLegalizer, For) {
-  auto node = make_object<ForNode>();
+  auto node = ffi::make_object<ForNode>();
   node->body = Evaluate(Integer(0));
   node->loop_var = Var("i", DataType::Int(32));
   node->min = IntImm(DataType::Int(64), 0);
@@ -126,7 +126,7 @@ TEST(DataTypeLegalizer, For) {
 }
 
 TEST(DataTypeLegalizer, Ramp) {
-  auto node = make_object<RampNode>();
+  auto node = ffi::make_object<RampNode>();
   node->base = IntImm(DataType::Int(64), 0);
   node->stride = IntImm(DataType::Int(32), 1);
   int lanes = 4;
diff --git a/tests/cpp/expr_test.cc b/tests/cpp/expr_test.cc
index 579479ccc0e5..05fbd5ce548c 100644
--- a/tests/cpp/expr_test.cc
+++ b/tests/cpp/expr_test.cc
@@ -51,5 +51,5 @@ TEST(ExprNodeRef, Basic) {
   Var x("x");
   PrimExpr z = max(x + 1 + 2, 100);
   const tir::MaxNode* op = z.as<tir::MaxNode>();
-  ICHECK(GetRef<PrimExpr>(op).same_as(z));
+  ICHECK(ffi::GetRef<PrimExpr>(op).same_as(z));
 }
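The `ir_functor_test.cc` hunks below exercise the copy-on-write behaviour of `ffi::Array`, which is unchanged by the renaming. A self-contained sketch of that semantics using `ffi::String` elements instead of TIR statements; the `Set` accessor and header paths are assumed from the container API used elsewhere in this patch:

#include <cassert>
#include <tvm/ffi/container/array.h>
#include <tvm/ffi/string.h>

int main() {
  tvm::ffi::Array<tvm::ffi::String> arr{tvm::ffi::String("a"), tvm::ffi::String("b")};
  const auto* node_before = arr.get();            // underlying array node
  tvm::ffi::Array<tvm::ffi::String> alias = arr;  // second reference, same node
  arr.Set(0, tvm::ffi::String("c"));              // write with refcount > 1 forces a copy
  assert(arr.get() != node_before);               // arr detached onto a fresh node...
  assert(alias.get() == node_before);             // ...while alias still sees the old one
  assert(alias[0] == "a" && arr[0] == "c");
  return 0;
}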
diff --git a/tests/cpp/ir_functor_test.cc b/tests/cpp/ir_functor_test.cc
index 348792d6ff88..ec7b4111d240 100644
--- a/tests/cpp/ir_functor_test.cc
+++ b/tests/cpp/ir_functor_test.cc
@@ -215,7 +215,7 @@ TEST(IRF, StmtMutator) {
     Stmt body2 = Evaluate(1);
     Stmt bref = body.as<AllocateNode>()->body;
     auto* extentptr = body.as<AllocateNode>()->extents.get();
-    Array<Stmt> arr{std::move(body), body2, body2};
+    ffi::Array<Stmt> arr{std::move(body), body2, body2};
     auto* arrptr = arr.get();
     arr.MutateByApply([&](Stmt s) { return v(std::move(s)); });
     ICHECK(arr.get() == arrptr);
@@ -228,9 +228,9 @@ TEST(IRF, StmtMutator) {
     ICHECK(bref.as<EvaluateNode>()->value.as<AddNode>());
   }
   {
-    Array<Stmt> arr{fmakealloc()};
-    // mutate array get reference by another one, triiger copy.
-    Array<Stmt> arr2 = arr;
+    ffi::Array<Stmt> arr{fmakealloc()};
+    // mutate an array that is also referenced by another one, triggering a copy.
+    ffi::Array<Stmt> arr2 = arr;
     auto* arrptr = arr.get();
     arr.MutateByApply([&](Stmt s) { return v(std::move(s)); });
     ICHECK(arr.get() != arrptr);
@@ -242,7 +242,7 @@ TEST(IRF, StmtMutator) {
     ICHECK(arr2.get() == arr.get());
   }
   {
-    Array<Stmt> arr{fmakeif()};
+    ffi::Array<Stmt> arr{fmakeif()};
     arr.MutateByApply([&](Stmt s) { return v(std::move(s)); });
     ICHECK(arr[0].as<IfThenElseNode>()->else_case.as<EvaluateNode>()->value.same_as(x));
     // mutate but no content change.
@@ -332,7 +332,7 @@ TEST(IRF, Substitute) {
   // test substitute buffer var
   Var y = x.copy_with_suffix("subst");
   BufferLoad buffer_load = fmaketest();
-  auto f_subst = [&](const Var& var) -> Optional<PrimExpr> {
+  auto f_subst = [&](const Var& var) -> ffi::Optional<PrimExpr> {
     if (var.same_as(x)) {
       return y;
     }
@@ -345,7 +345,7 @@ TEST(IRF, Substitute) {
   {
     // test identity substitution
     PrimExpr expr = fmaketest();
-    auto f_subst = [&](const Var& var) -> Optional<PrimExpr> { return var; };
+    auto f_subst = [&](const Var& var) -> ffi::Optional<PrimExpr> { return var; };
     PrimExpr new_expr = Substitute(expr, f_subst);
     // the expression is not changed
     ICHECK(new_expr.same_as(expr));
diff --git a/tests/cpp/nested_msg_test.cc b/tests/cpp/nested_msg_test.cc
index 644a80664fe1..c9628daf0d80 100644
--- a/tests/cpp/nested_msg_test.cc
+++ b/tests/cpp/nested_msg_test.cc
@@ -138,9 +138,9 @@ TEST(NestedMsg, Equal) {
 
   EXPECT_FALSE(Equal(M(std::nullopt), M(x), fequal));
 
-  EXPECT_FALSE(Equal(M(x), M(Array<NestedMsg<relax::Expr>>({x})), fequal));
+  EXPECT_FALSE(Equal(M(x), M(ffi::Array<NestedMsg<relax::Expr>>({x})), fequal));
 
-  EXPECT_FALSE(Equal(M(Array<NestedMsg<relax::Expr>>({x})), M(x), fequal));
+  EXPECT_FALSE(Equal(M(ffi::Array<NestedMsg<relax::Expr>>({x})), M(x), fequal));
 }
 
 TEST(NestedMsg, MapAndDecompose) {
@@ -232,7 +232,7 @@ TEST(NestedMsg, NestedMsgToExpr) {
   relax::Var x("x", sf0), y("y", sf0), z("z", sf0);
 
   NestedMsg<Integer> msg = {c0, {c0, c1}, {c0, {c1, c2}}};
-  auto expr = NestedMsgToExpr(msg, [&](Optional<Integer> leaf) {
+  auto expr = NestedMsgToExpr(msg, [&](ffi::Optional<Integer> leaf) {
     ICHECK(leaf.defined());
     int value = leaf.value().IntValue();
     switch (value) {
@@ -251,7 +251,7 @@ TEST(NestedMsg, NestedMsgToExpr) {
   // test simplified
   relax::Var t("t", sf1);
   NestedMsg<relax::Expr> msg1 = {TupleGetItem(t, 0), TupleGetItem(t, 1)};
-  auto expr1 = NestedMsgToExpr(msg1, [](Optional<relax::Expr> leaf) { return leaf.value(); });
+  auto expr1 = NestedMsgToExpr(msg1, [](ffi::Optional<relax::Expr> leaf) { return leaf.value(); });
   EXPECT_TRUE(StructuralEqual()(expr1, t));
 }
diff --git a/tests/cpp/object_protocol_test.cc b/tests/cpp/object_protocol_test.cc
index be69d77ccc73..cbd8f7a94154 100644
--- a/tests/cpp/object_protocol_test.cc
+++ b/tests/cpp/object_protocol_test.cc
@@ -59,11 +59,12 @@ class ObjAA : public ObjA {
 }  // namespace tvm
 
 TEST(ObjectHierachy, Basic) {
+  using namespace tvm;
   using namespace tvm::runtime;
   using namespace tvm::test;
   using namespace tvm::ffi;
 
-  ObjectRef refA(make_object<ObjA>());
+  ObjectRef refA(ffi::make_object<ObjA>());
   ICHECK_EQ(refA->type_index(), ObjA::RuntimeTypeIndex());
   ICHECK(refA.as<Object>() != nullptr);
   ICHECK(refA.as<ObjA>() != nullptr);
@@ -71,7 +72,7 @@ TEST(ObjectHierachy, Basic) {
   ICHECK(refA.as<ObjAA>() == nullptr);
   ICHECK(refA.as<ObjB>() == nullptr);
 
-  ObjectRef refAA(make_object<ObjAA>());
+  ObjectRef refAA(ffi::make_object<ObjAA>());
   ICHECK_EQ(refAA->type_index(), ObjAA::RuntimeTypeIndex());
   ICHECK(refAA.as<Object>() != nullptr);
   ICHECK(refAA.as<ObjA>() != nullptr);
@@ -79,7 +80,7 @@ TEST(ObjectHierachy, Basic) {
   ICHECK(refAA.as<ObjAA>() != nullptr);
   ICHECK(refAA.as<ObjB>() == nullptr);
 
-  ObjectRef refB(make_object<ObjB>());
+  ObjectRef refB(ffi::make_object<ObjB>());
   ICHECK_EQ(refB->type_index(), ObjB::RuntimeTypeIndex());
   ICHECK(refB.as<Object>() != nullptr);
   ICHECK(refB.as<ObjB>() != nullptr);
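The node-construction changes above all replace unqualified `make_object` with `ffi::make_object`. A hedged sketch of the idiom on a TIR node, mirroring the `DataTypeLegalizer` tests; the exact field set of `RampNode` varies across TVM revisions, so treat this as illustrative rather than exact:

#include <tvm/ffi/memory.h>
#include <tvm/tir/expr.h>

using namespace tvm;

tir::Ramp MakeRamp() {
  auto node = ffi::make_object<tir::RampNode>();   // was: make_object<RampNode>()
  node->dtype = DataType::Int(32).with_lanes(4);   // result type of the ramp
  node->base = IntImm(DataType::Int(32), 0);
  node->stride = IntImm(DataType::Int(32), 1);
  node->lanes = IntImm(DataType::Int(32), 4);      // PrimExpr lanes in recent TVM
  return tir::Ramp(node);                          // wrap the ObjectPtr into a ref
}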
diff --git a/tests/cpp/target/parsers/aprofile_test.cc b/tests/cpp/target/parsers/aprofile_test.cc
index 26f52f4938a8..1e74b3f71599 100644
--- a/tests/cpp/target/parsers/aprofile_test.cc
+++ b/tests/cpp/target/parsers/aprofile_test.cc
@@ -44,9 +44,9 @@ static bool CheckArchitectureAvailability() {
 #if TVM_LLVM_VERSION > 120
   auto llvm_instance = std::make_unique<codegen::LLVMInstance>();
   codegen::LLVMTargetInfo llvm_backend(*llvm_instance, "llvm");
-  Array<String> targets = llvm_backend.GetAllLLVMTargets();
+  ffi::Array<ffi::String> targets = llvm_backend.GetAllLLVMTargets();
   int expected_target_count = 0;
-  for (String target : targets) {
+  for (ffi::String target : targets) {
     if (target == "aarch64" || target == "arm") {
       expected_target_count += 1;
     }
@@ -74,9 +74,10 @@ class AProfileParser : public ::testing::Test {
 class AProfileParserTestWithParam : public AProfileParser,
                                     public testing::WithParamInterface<float> {};
 
-static TargetFeatures ParseTargetWithAttrs(String mcpu, String mtriple, Array<String> mattr) {
+static TargetFeatures ParseTargetWithAttrs(ffi::String mcpu, ffi::String mtriple,
+                                           ffi::Array<ffi::String> mattr) {
   TargetJSON target_json = {
-      {"kind", String("llvm")},
+      {"kind", ffi::String("llvm")},
       {"mtriple", mtriple},
       {"mattr", mattr},
   };
@@ -93,8 +94,8 @@ std::string FloatToStringWithoutTrailingZeros(float value) {
 }
 
 TEST_F(AProfileParser, ParseTargetKeys) {
-  TargetJSON target = ParseTarget({{"kind", String("llvm")}});
-  Array<String> keys = Downcast<Array<String>>(target.at("keys"));
+  TargetJSON target = ParseTarget({{"kind", ffi::String("llvm")}});
+  ffi::Array<ffi::String> keys = Downcast<ffi::Array<ffi::String>>(target.at("keys"));
   ASSERT_EQ(keys.size(), 2);
   ASSERT_EQ(keys[0], "arm_cpu");
   ASSERT_EQ(keys[1], "cpu");
@@ -102,11 +103,11 @@ TEST_F(AProfileParser, ParseTargetKeys) {
 
 TEST_F(AProfileParser, ParseTargetWithExistingKeys) {
   TargetJSON target = ParseTarget({
-      {"kind", String("llvm")},
-      {"keys", Array<String>{"cpu"}},
+      {"kind", ffi::String("llvm")},
+      {"keys", ffi::Array<ffi::String>{"cpu"}},
   });
   TargetFeatures features = Downcast<TargetFeatures>(target.at("features"));
-  Array<String> keys = Downcast<Array<String>>(target.at("keys"));
+  ffi::Array<ffi::String> keys = Downcast<ffi::Array<ffi::String>>(target.at("keys"));
   ASSERT_EQ(keys.size(), 2);
   ASSERT_EQ(keys[0], "cpu");
   ASSERT_EQ(keys[1], "arm_cpu");
@@ -114,18 +115,18 @@ TEST_F(AProfileParser, ParseTargetWithExistingKeys) {
 
 TEST_F(AProfileParser, ParseTargetWithDuplicateKey) {
   TargetJSON target = ParseTarget({
-      {"kind", String("llvm")},
-      {"keys", Array<String>{"cpu", "arm_cpu"}},
+      {"kind", ffi::String("llvm")},
+      {"keys", ffi::Array<ffi::String>{"cpu", "arm_cpu"}},
   });
   TargetFeatures features = Downcast<TargetFeatures>(target.at("features"));
-  Array<String> keys = Downcast<Array<String>>(target.at("keys"));
+  ffi::Array<ffi::String> keys = Downcast<ffi::Array<ffi::String>>(target.at("keys"));
   ASSERT_EQ(keys.size(), 2);
   ASSERT_EQ(keys[0], "cpu");
   ASSERT_EQ(keys[1], "arm_cpu");
 }
 
 TEST_F(AProfileParser, ParseTargetDefaults) {
-  TargetJSON target = ParseTarget({{"kind", String("llvm")}});
+  TargetJSON target = ParseTarget({{"kind", ffi::String("llvm")}});
   TargetFeatures features = Downcast<TargetFeatures>(target.at("features"));
 
   ASSERT_EQ(Downcast<Bool>(features.at("is_aarch64")), false);
@@ -157,8 +158,8 @@ TEST_F(AProfileParser, IsAArch32Triple) {
 
 TEST_F(AProfileParser, IsAArch32BlankCPU) {
   TargetJSON target = ParseTarget({
-      {"kind", String("llvm")},
-      {"mtriple", String("arm-unknown-linux-gnu")},
+      {"kind", ffi::String("llvm")},
+      {"mtriple", ffi::String("arm-unknown-linux-gnu")},
   });
   TargetFeatures features = Downcast<TargetFeatures>(target.at("features"));
   ASSERT_EQ(IsArch(target), true);
@@ -396,7 +397,7 @@ TEST_F(AProfileParser, UnexpectedTargetKind) {
   EXPECT_THROW(
       {
         try {
-          ParseTarget({{"kind", String("c")}});
+          ParseTarget({{"kind", ffi::String("c")}});
         } catch (const tvm::InternalError& e) {
           EXPECT_THAT(e.what(), HasSubstr("Expected target kind 'llvm', but got 'c'"));
           throw;
@@ -409,7 +410,7 @@ TEST(AProfileParserInvalid, LLVMUnsupportedArchitecture) {
   if (has_aarch64_and_arm_targets) {
     GTEST_SKIP() << "LLVM has been compiled for the correct targets.";
   }
-  TargetJSON target = ParseTarget({{"kind", String("llvm")}});
+  TargetJSON target = ParseTarget({{"kind", ffi::String("llvm")}});
   TargetFeatures features = Downcast<TargetFeatures>(target.at("features"));
   for (auto feature : features) {
     ASSERT_EQ(Downcast<Bool>(feature.second), false);
diff --git a/tests/cpp/target/parsers/mprofile_test.cc b/tests/cpp/target/parsers/mprofile_test.cc
index 97fb227e4190..19baf006d895 100644
--- a/tests/cpp/target/parsers/mprofile_test.cc
+++ b/tests/cpp/target/parsers/mprofile_test.cc
@@ -37,30 +37,30 @@ class MProfileParserMVECPUs : public testing::TestWithParam<std::string> {};
 class MProfileParserDSPCPUs : public testing::TestWithParam<std::string> {};
 class MProfileParserNoExtensions : public testing::TestWithParam<std::string> {};
 
-static TargetFeatures ParseTargetWithAttrs(String mcpu, Array<String> mattr) {
+static TargetFeatures ParseTargetWithAttrs(ffi::String mcpu, ffi::Array<ffi::String> mattr) {
   return ParseTarget({{"mcpu", mcpu}, {"mattr", mattr}});
 }
 
 TEST(MProfileParser, CheckIsNotArch) {
-  String mcpu = "cake";
+  ffi::String mcpu = "cake";
   TargetJSON fake_target = {{"mcpu", mcpu}};
   ASSERT_EQ(IsArch(fake_target), false);
 }
 
 TEST_P(MProfileParserMVECPUs, CheckIsArch) {
-  String mcpu = GetParam();
+  ffi::String mcpu = GetParam();
   TargetJSON fake_target = {{"mcpu", mcpu}};
   ASSERT_EQ(IsArch(fake_target), true);
 }
 
 TEST_P(MProfileParserDSPCPUs, CheckIsArch) {
-  String mcpu = GetParam();
+  ffi::String mcpu = GetParam();
   TargetJSON fake_target = {{"mcpu", mcpu}};
   ASSERT_EQ(IsArch(fake_target), true);
 }
 
 TEST_P(MProfileParserNoExtensions, CheckIsArch) {
-  String mcpu = GetParam();
+  ffi::String mcpu = GetParam();
   TargetJSON fake_target = {{"mcpu", mcpu}};
   ASSERT_EQ(IsArch(fake_target), true);
 }
@@ -68,7 +68,7 @@ TEST_P(MProfileParserNoExtensions, CheckIsArch) {
 TEST(MProfileParser, ParseTarget) {
   TargetJSON target = ParseTarget({});
   TargetFeatures features = Downcast<TargetFeatures>(target.at("features"));
-  Array<String> keys = Downcast<Array<String>>(target.at("keys"));
+  ffi::Array<ffi::String> keys = Downcast<ffi::Array<ffi::String>>(target.at("keys"));
   ASSERT_EQ(keys.size(), 2);
   ASSERT_EQ(keys[0], "arm_cpu");
   ASSERT_EQ(keys[1], "cpu");
@@ -79,10 +79,10 @@ TEST(MProfileParser, ParseTarget) {
 
 TEST(MProfileParser, ParseTargetWithExistingKeys) {
   TargetJSON target = ParseTarget({
-      {"keys", Array<String>{"cpu"}},
+      {"keys", ffi::Array<ffi::String>{"cpu"}},
   });
   TargetFeatures features = Downcast<TargetFeatures>(target.at("features"));
-  Array<String> keys = Downcast<Array<String>>(target.at("keys"));
+  ffi::Array<ffi::String> keys = Downcast<ffi::Array<ffi::String>>(target.at("keys"));
   ASSERT_EQ(keys.size(), 2);
   ASSERT_EQ(keys[0], "cpu");
   ASSERT_EQ(keys[1], "arm_cpu");
@@ -90,10 +90,10 @@ TEST(MProfileParser, ParseTargetWithExistingKeys) {
 
 TEST(MProfileParser, ParseTargetWithDuplicateKey) {
   TargetJSON target = ParseTarget({
-      {"keys", Array<String>{"cpu", "arm_cpu"}},
+      {"keys", ffi::Array<ffi::String>{"cpu", "arm_cpu"}},
   });
   TargetFeatures features = Downcast<TargetFeatures>(target.at("features"));
-  Array<String> keys = Downcast<Array<String>>(target.at("keys"));
+  ffi::Array<ffi::String> keys = Downcast<ffi::Array<ffi::String>>(target.at("keys"));
   ASSERT_EQ(keys.size(), 2);
   ASSERT_EQ(keys[0], "cpu");
   ASSERT_EQ(keys[1], "arm_cpu");
diff --git a/tests/cpp/target/virtual_device_test.cc b/tests/cpp/target/virtual_device_test.cc
index d982a8ae2153..4f4b945cae8f 100644
--- a/tests/cpp/target/virtual_device_test.cc
+++ b/tests/cpp/target/virtual_device_test.cc
@@ -29,7 +29,7 @@ TEST(VirtualDevice, Join_Defined) {
     Target target_a = Target("cuda");
     VirtualDevice lhs = VirtualDevice(kDLCUDA, 3);
     VirtualDevice rhs = VirtualDevice(kDLCUDA, -1, target_a, "global");
-    Optional<VirtualDevice> actual = VirtualDevice::Join(lhs, rhs);
+    ffi::Optional<VirtualDevice> actual = VirtualDevice::Join(lhs, rhs);
     EXPECT_TRUE(actual.operator bool());
     VirtualDevice expected = VirtualDevice(kDLCUDA, 3, target_a, "global");
     EXPECT_TRUE(StructuralEqual()(actual.value(), expected));
@@ -38,7 +38,7 @@ TEST(VirtualDevice, Join_Defined) {
     Target target_a = Target("cuda");
     VirtualDevice lhs = VirtualDevice(kDLCUDA, -1, target_a, "global");
     VirtualDevice rhs = VirtualDevice(kDLCUDA, 3);
-    Optional<VirtualDevice> actual = VirtualDevice::Join(lhs, rhs);
+    ffi::Optional<VirtualDevice> actual = VirtualDevice::Join(lhs, rhs);
     EXPECT_TRUE(actual.operator bool());
     VirtualDevice expected = VirtualDevice(kDLCUDA, 3, target_a, "global");
     EXPECT_TRUE(StructuralEqual()(actual.value(), expected));
@@ -47,7 +47,7 @@ TEST(VirtualDevice, Join_Defined) {
     Target target_a = Target("cuda");
     VirtualDevice lhs = VirtualDevice(kDLCUDA);
     VirtualDevice rhs = VirtualDevice(kDLCUDA, 2, target_a);
-    Optional<VirtualDevice> actual = VirtualDevice::Join(lhs, rhs);
+    ffi::Optional<VirtualDevice> actual = VirtualDevice::Join(lhs, rhs);
     EXPECT_TRUE(actual.operator bool());
     VirtualDevice expected = VirtualDevice(kDLCUDA, 2, target_a);
     EXPECT_TRUE(StructuralEqual()(actual.value(), expected));
@@ -56,7 +56,7 @@ TEST(VirtualDevice, Join_Defined) {
     Target target_a = Target("cuda");
     VirtualDevice lhs = VirtualDevice();
     VirtualDevice rhs = VirtualDevice(kDLCUDA, 3, target_a, "global");
-    Optional<VirtualDevice> actual = VirtualDevice::Join(lhs, rhs);
+    ffi::Optional<VirtualDevice> actual = VirtualDevice::Join(lhs, rhs);
     EXPECT_TRUE(actual.operator bool());
     VirtualDevice expected = rhs;
     EXPECT_TRUE(StructuralEqual()(actual.value(), expected));
@@ -67,25 +67,25 @@ TEST(VirtualDevice, Join_Undefined) {
   {
     VirtualDevice lhs = VirtualDevice(kDLCUDA);
     VirtualDevice rhs = VirtualDevice(kDLCPU);
-    Optional<VirtualDevice> actual = VirtualDevice::Join(lhs, rhs);
+    ffi::Optional<VirtualDevice> actual = VirtualDevice::Join(lhs, rhs);
     EXPECT_FALSE(actual);
   }
   {
     VirtualDevice lhs = VirtualDevice(kDLCUDA, 3);
     VirtualDevice rhs = VirtualDevice(kDLCUDA, 4);
-    Optional<VirtualDevice> actual = VirtualDevice::Join(lhs, rhs);
+    ffi::Optional<VirtualDevice> actual = VirtualDevice::Join(lhs, rhs);
     EXPECT_FALSE(actual);
   }
   {
     VirtualDevice lhs = VirtualDevice(kDLCUDA, 3, Target("cuda"));
     VirtualDevice rhs = VirtualDevice(kDLCUDA, 3, Target("cuda"));
-    Optional<VirtualDevice> actual = VirtualDevice::Join(lhs, rhs);
+    ffi::Optional<VirtualDevice> actual = VirtualDevice::Join(lhs, rhs);
     EXPECT_FALSE(actual);
   }
   {
     VirtualDevice lhs = VirtualDevice(kDLCUDA, 3, Target("cuda"), "local");
     VirtualDevice rhs = VirtualDevice(kDLCUDA, 3, Target("cuda"), "global");
-    Optional<VirtualDevice> actual = VirtualDevice::Join(lhs, rhs);
+    ffi::Optional<VirtualDevice> actual = VirtualDevice::Join(lhs, rhs);
     EXPECT_FALSE(actual);
   }
 }
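The `target_test.cc` hunks that follow qualify every container appearing in Target configs. A hedged sketch of the resulting pattern; the config map's value type is written here as `ffi::Any`, inferred from entries such as `{"my_bool", true}` in the tests, and should be checked against the local `TargetJSON` definition:

#include <tvm/ffi/container/array.h>
#include <tvm/ffi/container/map.h>
#include <tvm/ffi/string.h>
#include <tvm/target/target.h>

using namespace tvm;

// Build a minimal "llvm" target from a fully qualified ffi config map,
// the same construction style the tests below use.
Target MakeLlvmTarget() {
  ffi::Map<ffi::String, ffi::Any> config = {
      {"kind", ffi::String("llvm")},
      {"keys", ffi::Array<ffi::String>{"cpu"}},
  };
  return Target(config);
}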
diff --git a/tests/cpp/target_test.cc b/tests/cpp/target_test.cc
index 17e3cae4ad18..6cea161f7482 100644
--- a/tests/cpp/target_test.cc
+++ b/tests/cpp/target_test.cc
@@ -32,36 +32,36 @@ using namespace tvm;
 
 TVM_REGISTER_TARGET_KIND("TestTargetKind", kDLCPU)
     .set_attr<ffi::String>("Attr1", "Value1")
     .add_attr_option<bool>("my_bool")
-    .add_attr_option<Array<String>>("your_names")
-    .add_attr_option<Map<String, Integer>>("her_maps");
+    .add_attr_option<ffi::Array<ffi::String>>("your_names")
+    .add_attr_option<ffi::Map<ffi::String, Integer>>("her_maps");
 
 TargetJSON TestTargetParser(TargetJSON target) {
-  String mcpu = Downcast<String>(target.at("mcpu"));
-  target.Set("mcpu", String("super_") + mcpu);
-  target.Set("keys", Array<String>({"super"}));
-  target.Set("features", Map<String, Bool>{{"test", true}});
+  ffi::String mcpu = Downcast<ffi::String>(target.at("mcpu"));
+  target.Set("mcpu", ffi::String("super_") + mcpu);
+  target.Set("keys", ffi::Array<ffi::String>({"super"}));
+  target.Set("features", ffi::Map<ffi::String, Bool>{{"test", true}});
   return target;
 }
 
-Map<String, Any> TestAttrsPreProcessor(Map<String, Any> attrs) {
-  attrs.Set("mattr", String("woof"));
+ffi::Map<ffi::String, ffi::Any> TestAttrsPreProcessor(ffi::Map<ffi::String, ffi::Any> attrs) {
+  attrs.Set("mattr", ffi::String("woof"));
   return attrs;
 }
 
 TVM_REGISTER_TARGET_KIND("TestTargetParser", kDLCPU)
-    .add_attr_option<String>("mattr")
-    .add_attr_option<String>("mcpu")
+    .add_attr_option<ffi::String>("mattr")
+    .add_attr_option<ffi::String>("mcpu")
     .set_default_keys({"cpu"})
     .set_target_parser(TestTargetParser);
 
 TVM_REGISTER_TARGET_KIND("TestAttrsPreprocessor", kDLCPU)
-    .add_attr_option<String>("mattr")
+    .add_attr_option<ffi::String>("mattr")
     .set_default_keys({"cpu"})
     .set_attrs_preprocessor(TestAttrsPreProcessor);
 
 TVM_REGISTER_TARGET_KIND("TestClashingPreprocessor", kDLCPU)
-    .add_attr_option<String>("mattr")
-    .add_attr_option<String>("mcpu")
+    .add_attr_option<ffi::String>("mattr")
+    .add_attr_option<ffi::String>("mcpu")
     .set_default_keys({"cpu"})
     .set_attrs_preprocessor(TestAttrsPreProcessor)
     .set_target_parser(TestTargetParser);
@@ -74,13 +74,13 @@ TEST(TargetKind, GetAttrMap) {
 }
 
 TEST(TargetCreation, NestedConfig) {
-  Map<String, Any> config = {
+  ffi::Map<ffi::String, ffi::Any> config = {
       {"my_bool", true},
-      {"your_names", Array<String>{"junru", "jian"}},
-      {"kind", String("TestTargetKind")},
+      {"your_names", ffi::Array<ffi::String>{"junru", "jian"}},
+      {"kind", ffi::String("TestTargetKind")},
       {
           "her_maps",
-          Map<String, Integer>{
+          ffi::Map<ffi::String, Integer>{
              {"a", 1},
              {"b", 2},
          },
      },
  };
@@ -92,25 +92,27 @@ TEST(TargetCreation, NestedConfig) {
   ICHECK(target->keys.empty());
   bool my_bool = target->GetAttr<bool>("my_bool").value();
   ICHECK_EQ(my_bool, true);
-  Array<String> your_names = target->GetAttr<Array<String>>("your_names").value();
+  ffi::Array<ffi::String> your_names =
+      target->GetAttr<ffi::Array<ffi::String>>("your_names").value();
   ICHECK_EQ(your_names.size(), 2U);
   ICHECK_EQ(your_names[0], "junru");
   ICHECK_EQ(your_names[1], "jian");
-  Map<String, Integer> her_maps = target->GetAttr<Map<String, Integer>>("her_maps").value();
+  ffi::Map<ffi::String, Integer> her_maps =
+      target->GetAttr<ffi::Map<ffi::String, Integer>>("her_maps").value();
   ICHECK_EQ(her_maps.size(), 2U);
   ICHECK_EQ(her_maps["a"], 1);
   ICHECK_EQ(her_maps["b"], 2);
 }
 
 TEST(TargetCreationFail, UnrecognizedConfigOption) {
-  Map<String, Any> config = {
+  ffi::Map<ffi::String, ffi::Any> config = {
       {"my_bool", true},
-      {"your_names", Array<String>{"junru", "jian"}},
-      {"kind", String("TestTargetKind")},
+      {"your_names", ffi::Array<ffi::String>{"junru", "jian"}},
+      {"kind", ffi::String("TestTargetKind")},
       {"bad", ObjectRef(nullptr)},
       {
           "her_maps",
-          Map<String, Integer>{
+          ffi::Map<ffi::String, Integer>{
              {"a", 1},
              {"b", 2},
          },
      },
  };
@@ -126,13 +128,13 @@ TEST(TargetCreationFail, UnrecognizedConfigOption) {
 }
 
 TEST(TargetCreationFail, TypeMismatch) {
-  Map<String, Any> config = {
-      {"my_bool", String("true")},
-      {"your_names", Array<String>{"junru", "jian"}},
-      {"kind", String("TestTargetKind")},
+  ffi::Map<ffi::String, ffi::Any> config = {
+      {"my_bool", ffi::String("true")},
+      {"your_names", ffi::Array<ffi::String>{"junru", "jian"}},
+      {"kind", ffi::String("TestTargetKind")},
       {
           "her_maps",
-          Map<String, Integer>{
+          ffi::Map<ffi::String, Integer>{
              {"a", 1},
              {"b", 2},
          },
      },
  };
@@ -148,12 +150,12 @@ TEST(TargetCreationFail, TypeMismatch) {
 }
 
 TEST(TargetCreationFail, TargetKindNotFound) {
-  Map<String, Any> config = {
+  ffi::Map<ffi::String, ffi::Any> config = {
       {"my_bool", "true"},
-      {"your_names", Array<String>{"junru", "jian"}},
+      {"your_names", ffi::Array<ffi::String>{"junru", "jian"}},
       {
           "her_maps",
-          Map<String, Integer>{
+          ffi::Map<ffi::String, Integer>{
              {"a", 1},
              {"b", 2},
          },
      },
  };
@@ -170,7 +172,7 @@ TEST(TargetCreationFail, TargetKindNotFound) {
 
 TEST(TargetCreation, TargetParser) {
   Target test_target("TestTargetParser -mcpu=woof");
-  ASSERT_EQ(test_target->GetAttr<String>("mcpu").value(), "super_woof");
+  ASSERT_EQ(test_target->GetAttr<ffi::String>("mcpu").value(), "super_woof");
   ASSERT_EQ(test_target->keys.size(), 1);
   ASSERT_EQ(test_target->keys[0], "super");
 }
@@ -185,10 +187,10 @@ TEST(TargetCreation, TargetFeatures) {
 }
 
 TEST(TargetCreation, TargetFeaturesBeforeParser) {
-  Map<String, Bool> features = {{"test", true}};
-  Map<String, Any> config = {
-      {"kind", String("TestTargetParser")},
-      {"mcpu", String("woof")},
+  ffi::Map<ffi::String, Bool> features = {{"test", true}};
+  ffi::Map<ffi::String, ffi::Any> config = {
+      {"kind", ffi::String("TestTargetParser")},
+      {"mcpu", ffi::String("woof")},
       {"features", features},
   };
   EXPECT_THROW(Target test(config), ffi::Error);
@@ -196,7 +198,7 @@ TEST(TargetCreation, TargetFeaturesBeforeParser) {
 
 TEST(TargetCreation, TargetAttrsPreProcessor) {
   Target test_target("TestAttrsPreprocessor -mattr=cake");
-  ASSERT_EQ(test_target->GetAttr<String>("mattr").value(), "woof");
+  ASSERT_EQ(test_target->GetAttr<ffi::String>("mattr").value(), "woof");
 }
 
 TEST(TargetCreation, ClashingTargetProcessing) {
@@ -204,45 +206,46 @@ TEST(TargetCreation, ClashingTargetProcessing) {
 }
 
 TVM_REGISTER_TARGET_KIND("TestStringKind", kDLCPU)
-    .add_attr_option<String>("single")
-    .add_attr_option<Array<String>>("array")
-    .add_attr_option<Array<Array<String>>>("nested-array")
-    .add_attr_option<Array<Array<Array<String>>>>("nested2-array");
+    .add_attr_option<ffi::String>("single")
+    .add_attr_option<ffi::Array<ffi::String>>("array")
+    .add_attr_option<ffi::Array<ffi::Array<ffi::String>>>("nested-array")
+    .add_attr_option<ffi::Array<ffi::Array<ffi::Array<ffi::String>>>>("nested2-array");
 
 TEST(TargetCreation, ProcessStrings) {
   Target test_target1("TestStringKind -single='\\'string with single quote'");
-  ASSERT_TRUE(test_target1->GetAttr<String>("single"));
-  String string1 = test_target1->GetAttr<String>("single").value();
+  ASSERT_TRUE(test_target1->GetAttr<ffi::String>("single"));
+  ffi::String string1 = test_target1->GetAttr<ffi::String>("single").value();
   ASSERT_EQ(string1, "'string with single quote");
 
   Target test_target2("TestStringKind -single='\\\'\\\\\\'blah\\\\\\'\\\''");
-  ASSERT_TRUE(test_target2->GetAttr<String>("single"));
-  String string2 = test_target2->GetAttr<String>("single").value();
+  ASSERT_TRUE(test_target2->GetAttr<ffi::String>("single"));
+  ffi::String string2 = test_target2->GetAttr<ffi::String>("single").value();
   ASSERT_EQ(string2, "'\\\'blah\\\''");
 
   Target test_target3("TestStringKind -array=-danny,-sammy=1,-kirby='string with space'");
-  ASSERT_TRUE(test_target3->GetAttr<Array<String>>("array"));
-  Array<String> array3 = test_target3->GetAttr<Array<String>>("array").value();
+  ASSERT_TRUE(test_target3->GetAttr<ffi::Array<ffi::String>>("array"));
+  ffi::Array<ffi::String> array3 = test_target3->GetAttr<ffi::Array<ffi::String>>("array").value();
   ASSERT_EQ(array3[0], "-danny");
   ASSERT_EQ(array3[1], "-sammy=1");
   ASSERT_EQ(array3[2], "-kirby='string with space'");
 
   Target test_target4("TestStringKind -array='fred, foo, bar',baz");
-  ASSERT_TRUE(test_target4->GetAttr<Array<String>>("array"));
-  Array<String> array4 = test_target4->GetAttr<Array<String>>("array").value();
+  ASSERT_TRUE(test_target4->GetAttr<ffi::Array<ffi::String>>("array"));
+  ffi::Array<ffi::String> array4 = test_target4->GetAttr<ffi::Array<ffi::String>>("array").value();
   ASSERT_EQ(array4[0], "fred, foo, bar");
   ASSERT_EQ(array4[1], "baz");
 
   Target test_target5("TestStringKind -array='fr\\'ed','f\\'oo',' bar,baz '");
-  ASSERT_TRUE(test_target5->GetAttr<Array<String>>("array"));
-  Array<String> array5 = test_target5->GetAttr<Array<String>>("array").value();
+  ASSERT_TRUE(test_target5->GetAttr<ffi::Array<ffi::String>>("array"));
+  ffi::Array<ffi::String> array5 = test_target5->GetAttr<ffi::Array<ffi::String>>("array").value();
   ASSERT_EQ(array5[0], "fr'ed");
   ASSERT_EQ(array5[1], "f'oo");
   ASSERT_EQ(array5[2], "bar,baz");
 
   Target test_target6("TestStringKind -nested-array='foo0,foo1,foo2','bar0,bar1,bar2','baz0,baz1'");
-  ASSERT_TRUE(test_target6->GetAttr<Array<Array<String>>>("nested-array"));
-  Array<Array<String>> array6 = test_target6->GetAttr<Array<Array<String>>>("nested-array").value();
+  ASSERT_TRUE(test_target6->GetAttr<ffi::Array<ffi::Array<ffi::String>>>("nested-array"));
+  ffi::Array<ffi::Array<ffi::String>> array6 =
+      test_target6->GetAttr<ffi::Array<ffi::Array<ffi::String>>>("nested-array").value();
   ASSERT_EQ(array6[0][0], "foo0");
   ASSERT_EQ(array6[0][1], "foo1");
   ASSERT_EQ(array6[0][2], "foo2");
@@ -257,9 +260,11 @@ TEST(TargetCreation, ProcessStrings) {
                       "'\\'foo0,foo1\\',\\'bar0,bar1\\',\\'baz0,baz1\\'',"
                       "'\\'zing0,zing1\\',\\'fred\\''");
-  ASSERT_TRUE(test_target7->GetAttr<Array<Array<Array<String>>>>("nested2-array"));
-  Array<Array<Array<String>>> array7 =
-      test_target7->GetAttr<Array<Array<Array<String>>>>("nested2-array").value();
+  ASSERT_TRUE(
+      test_target7->GetAttr<ffi::Array<ffi::Array<ffi::Array<ffi::String>>>>("nested2-array"));
+  ffi::Array<ffi::Array<ffi::Array<ffi::String>>> array7 =
+      test_target7->GetAttr<ffi::Array<ffi::Array<ffi::Array<ffi::String>>>>("nested2-array")
+          .value();
   // {
   //   {foo0, foo1},
   //   {bar0, bar1},
@@ -449,8 +454,8 @@ TEST(TargetCreation, LLVMCommandLineSaveRestore) {
 }
 
 TEST(TargetCreation, DetectSystemTriple) {
-  Map<String, Any> config = {
-      {"kind", String("llvm")},
+  ffi::Map<ffi::String, ffi::Any> config = {
+      {"kind", ffi::String("llvm")},
   };
 
   Target target = Target(config);
@@ -461,17 +466,17 @@ TEST(TargetCreation, DetectSystemTriple) {
     GTEST_SKIP() << "LLVM is not available, skipping test";
   }
 
-  Optional<String> mtriple = target->GetAttr<String>("mtriple");
-  ASSERT_TRUE(mtriple.value() == (*pf)().cast<String>());
+  ffi::Optional<ffi::String> mtriple = target->GetAttr<ffi::String>("mtriple");
+  ASSERT_TRUE(mtriple.value() == (*pf)().cast<ffi::String>());
 }
 #endif
 
 TEST(TargetCreation, DeduplicateKeys) {
-  Map<String, Any> config = {
-      {"kind", String("llvm")},
-      {"keys", Array<String>{"cpu", "arm_cpu"}},
-      {"device", String("arm_cpu")},
+  ffi::Map<ffi::String, ffi::Any> config = {
+      {"kind", ffi::String("llvm")},
+      {"keys", ffi::Array<ffi::String>{"cpu", "arm_cpu"}},
+      {"device", ffi::String("arm_cpu")},
   };
   Target target = Target(config);
   ICHECK_EQ(target->kind, TargetKind::Get("llvm").value());
@@ -480,17 +485,17 @@ TEST(TargetCreation, DeduplicateKeys) {
   ICHECK_EQ(target->keys[0], "cpu");
   ICHECK_EQ(target->keys[1], "arm_cpu");
   ICHECK_EQ(target->attrs.size(), 2U);
-  ICHECK_EQ(target->GetAttr<String>("device"), "arm_cpu");
+  ICHECK_EQ(target->GetAttr<ffi::String>("device"), "arm_cpu");
 }
 
 TEST(TargetKindRegistry, ListTargetKinds) {
-  Array<String> names = TargetKindRegEntry::ListTargetKinds();
+  ffi::Array<ffi::String> names = TargetKindRegEntry::ListTargetKinds();
   ICHECK_EQ(names.empty(), false);
   ICHECK_EQ(std::count(std::begin(names), std::end(names), "llvm"), 1);
 }
 
 TEST(TargetKindRegistry, ListTargetOptions) {
   TargetKind llvm = TargetKind::Get("llvm").value();
-  Map<String, String> attrs = TargetKindRegEntry::ListTargetKindOptions(llvm);
+  ffi::Map<ffi::String, ffi::String> attrs = TargetKindRegEntry::ListTargetKindOptions(llvm);
   ICHECK_EQ(attrs.empty(), false);
 }
diff --git a/web/emcc/tvmjs_support.cc b/web/emcc/tvmjs_support.cc
index b33724c722d7..d658c094796e 100644
--- a/web/emcc/tvmjs_support.cc
+++ b/web/emcc/tvmjs_support.cc
@@ -252,10 +252,10 @@ class AsyncLocalSession : public LocalSession {
   std::optional async_wait_;
 
   // time evaluator
-  ffi::Function GetTimeEvaluator(Optional<Module> opt_mod, std::string name, int device_type,
-                                 int device_id, int number, int repeat, int min_repeat_ms,
-                                 int limit_zero_time_iterations, int cooldown_interval_ms,
-                                 int repeats_to_cooldown) {
+  ffi::Function GetTimeEvaluator(ffi::Optional<ffi::Module> opt_mod, std::string name,
+                                 int device_type, int device_id, int number, int repeat,
+                                 int min_repeat_ms, int limit_zero_time_iterations,
+                                 int cooldown_interval_ms, int repeats_to_cooldown) {
     Device dev;
     dev.device_type = static_cast<DLDeviceType>(device_type);
     dev.device_id = device_id;
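The web runtime files below share the packed-function registration idiom used throughout this patch. A small self-contained sketch; the global name `demo.echo` is illustrative only, and the reflection header path is assumed from similar files in this patch:

#include <tvm/ffi/function.h>
#include <tvm/ffi/reflection/registry.h>  // path assumed

TVM_FFI_STATIC_INIT_BLOCK({
  namespace refl = tvm::ffi::reflection;
  refl::GlobalDef().def_packed(
      "demo.echo", [](tvm::ffi::PackedArgs args, tvm::ffi::Any* ret) {
        // After the migration, the string type is spelled ffi::String here too.
        *ret = args[0].cast<tvm::ffi::String>();
      });
});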
diff --git a/web/emcc/wasm_runtime.cc b/web/emcc/wasm_runtime.cc
index 146a5ae1f7cd..c0228a20b320 100644
--- a/web/emcc/wasm_runtime.cc
+++ b/web/emcc/wasm_runtime.cc
@@ -112,8 +112,9 @@ TVM_FFI_STATIC_INIT_BLOCK({
           [](ffi::PackedArgs args, ffi::Any* ret) {
             (args[0].cast<ffi::Function>()).CallPacked(args.Slice(1), ret);
           })
-      .def_packed("tvmjs.testing.log_info_str",
-                  [](ffi::PackedArgs args, ffi::Any* ret) { LOG(INFO) << args[0].cast<ffi::String>(); })
+      .def_packed(
+          "tvmjs.testing.log_info_str",
+          [](ffi::PackedArgs args, ffi::Any* ret) { LOG(INFO) << args[0].cast<ffi::String>(); })
       .def("tvmjs.testing.add_one", [](int x) { return x + 1; })
       .def_packed("tvmjs.testing.wrap_callback", [](ffi::PackedArgs args, ffi::Any* ret) {
         ffi::Function pf = args[0].cast<ffi::Function>();
@@ -162,7 +163,7 @@ TVM_FFI_STATIC_INIT_BLOCK({
             data.push_back(arr_i->at(j));
           }
         }
-        *ret = Array<Any>(data);
+        *ret = ffi::Array<Any>(data);
      });
 });
diff --git a/web/emcc/webgpu_runtime.cc b/web/emcc/webgpu_runtime.cc
index eb14a7b7d7ee..6c9f437303af 100644
--- a/web/emcc/webgpu_runtime.cc
+++ b/web/emcc/webgpu_runtime.cc
@@ -165,7 +165,7 @@ class WebGPUModuleNode final : public ffi::ModuleObj {
 
   const char* kind() const final { return "webgpu"; }
 
-  Optional<ffi::Function> GetFunction(const String& name) final {
+  ffi::Optional<ffi::Function> GetFunction(const ffi::String& name) final {
     // special function
     if (name == "webgpu.get_fmap") {
       return ffi::Function([this](ffi::PackedArgs args, ffi::Any* rv) {
@@ -211,7 +211,7 @@ class WebGPUModuleNode final : public ffi::ModuleObj {
 
   ffi::Bytes SaveToBytes() const final { LOG(FATAL) << "Not implemented"; }
 
-  String InspectSource(const String& format) const final {
+  ffi::String InspectSource(const ffi::String& format) const final {
     // can only return source code.
     return source_;
   }
@@ -237,7 +237,7 @@ ffi::Module WebGPUModuleLoadFromBytes(const ffi::Bytes& bytes) {
   stream->Read(&fmap);
   stream->Read(&smap);
 
-  return ffi::Module(make_object<WebGPUModuleNode>(smap, fmap));
+  return ffi::Module(ffi::make_object<WebGPUModuleNode>(smap, fmap));
 }
 
 // for now webgpu is hosted via a vulkan module.
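Taken together, the module-facing changes in this section make every function lookup return an explicit `ffi::Optional`. A hedged sketch of the consuming side; the module header path, the `operator->` accessor, and the zero-argument call are assumptions, not taken from the patch:

#include <tvm/ffi/function.h>
#include <tvm/ffi/module.h>  // path assumed

// Invoke a module function only when the module actually provides it.
void CallIfPresent(tvm::ffi::Module mod, const tvm::ffi::String& name) {
  if (tvm::ffi::Optional<tvm::ffi::Function> f = mod->GetFunction(name)) {
    f.value()();  // Optional is truthy only when the lookup succeeded
  }
}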