diff --git a/include/tvm/relax/tir_pattern.h b/include/tvm/relax/tir_pattern.h index 6d8620b65a34..1397bafc36ff 100644 --- a/include/tvm/relax/tir_pattern.h +++ b/include/tvm/relax/tir_pattern.h @@ -74,7 +74,7 @@ class MatchResult : public ObjectRef { TVM_DEFINE_OBJECT_REF_METHODS(MatchResult, ObjectRef, MatchResultNode); }; -using FCodegen = ffi::TypedFunction(Array match_results)>; +using FCodegen = ffi::TypedFunction(Array match_results)>; } // namespace relax } // namespace tvm #endif // TVM_RELAX_TIR_PATTERN_H_ diff --git a/include/tvm/runtime/profiling.h b/include/tvm/runtime/profiling.h index cf467870c60c..66c0b64f18e5 100644 --- a/include/tvm/runtime/profiling.h +++ b/include/tvm/runtime/profiling.h @@ -317,7 +317,7 @@ class MetricCollectorNode : public Object { * \returns A set of metric names and the associated values. Values must be * one of DurationNode, PercentNode, CountNode, or StringObj. */ - virtual Map Stop(ObjectRef obj) = 0; + virtual Map Stop(ffi::ObjectRef obj) = 0; virtual ~MetricCollectorNode() {} @@ -340,7 +340,7 @@ struct CallFrame { /*! Runtime of the function or op */ Timer timer; /*! Extra performance metrics */ - std::unordered_map extra_metrics; + std::unordered_map extra_metrics; /*! User defined metric collectors. Each pair is the MetricCollector and its * associated data (returned from MetricCollector.Start). */ @@ -404,12 +404,12 @@ class Profiler { * `StartCall` and `StopCall` must be nested properly. */ void StartCall(String name, Device dev, - std::unordered_map extra_metrics = {}); + std::unordered_map extra_metrics = {}); /*! \brief Stop the last `StartCall`. * \param extra_metrics Optional additional profiling information to add to * the frame (input sizes, allocations). */ - void StopCall(std::unordered_map extra_metrics = {}); + void StopCall(std::unordered_map extra_metrics = {}); /*! \brief A report of total runtime between `Start` and `Stop` as * well as individual statistics for each `StartCall`-`StopCall` pair. * \returns A `Report` that can either be formatted as CSV (with `.AsCSV`) diff --git a/include/tvm/script/printer/doc.h b/include/tvm/script/printer/doc.h index dfb9c0beee72..b19bcab4c3ef 100644 --- a/include/tvm/script/printer/doc.h +++ b/include/tvm/script/printer/doc.h @@ -246,7 +246,7 @@ class LiteralDocNode : public ExprDocNode { * - String * - null */ - ObjectRef value; + ffi::Any value; static void RegisterReflection() { namespace refl = tvm::ffi::reflection; @@ -265,16 +265,14 @@ class LiteralDocNode : public ExprDocNode { */ class LiteralDoc : public ExprDoc { protected: - explicit LiteralDoc(ObjectRef value, const Optional& object_path); + explicit LiteralDoc(ffi::Any value, const Optional& object_path); public: /*! * \brief Create a LiteralDoc to represent None/null/empty value. * \param p The object path */ - static LiteralDoc None(const Optional& p) { - return LiteralDoc(ObjectRef(nullptr), p); - } + static LiteralDoc None(const Optional& p) { return LiteralDoc(ffi::Any(nullptr), p); } /*! * \brief Create a LiteralDoc to represent integer. * \param v The integer value. diff --git a/include/tvm/script/printer/ir_docsifier.h b/include/tvm/script/printer/ir_docsifier.h index 909f13ecc051..9d189dda0915 100644 --- a/include/tvm/script/printer/ir_docsifier.h +++ b/include/tvm/script/printer/ir_docsifier.h @@ -145,7 +145,7 @@ class IRDocsifierNode : public Object { /*! \brief Mapping from a var to its info */ std::unordered_map obj2info; /*! \brief Metadata printing */ - std::unordered_map> metadata; + std::unordered_map> metadata; /*! \brief GlobalInfo printing */ std::unordered_map> global_infos; /*! \brief The variable names used already */ @@ -206,7 +206,7 @@ class IRDocsifierNode : public Object { */ Optional GetVarDoc(const ObjectRef& obj) const; /*! \brief Add a TVM object to the metadata section*/ - ExprDoc AddMetadata(const ObjectRef& obj); + ExprDoc AddMetadata(const ffi::Any& obj); /*! \brief Add a GlobalInfo to the global_infos map. * \param name The name of key of global_infos. * \param ginfo The GlobalInfo to be added. @@ -275,7 +275,7 @@ inline static void AddDocDecoration(const Doc& d, const ObjectRef& obj, const Ob const PrinterConfig& cfg) { if (cfg->obj_to_annotate.count(obj)) { if (const auto* stmt = d.as()) { - if (stmt->comment.defined()) { + if (stmt->comment.has_value()) { stmt->comment = stmt->comment.value() + "\n" + cfg->obj_to_annotate.at(obj); } else { stmt->comment = cfg->obj_to_annotate.at(obj); @@ -295,7 +295,7 @@ inline static void AddDocDecoration(const Doc& d, const ObjectRef& obj, const Ob String attn = pair.second; if (p->IsPrefixOf(path) && path->IsPrefixOf(p)) { if (const auto* stmt = d.as()) { - if (stmt->comment.defined()) { + if (stmt->comment.has_value()) { stmt->comment = stmt->comment.value() + "\n" + attn; } else { stmt->comment = attn; @@ -319,8 +319,16 @@ inline TDoc IRDocsifierNode::AsDoc(const Any& value, const ObjectPath& path) con return Downcast(LiteralDoc::Int(value.as().value(), path)); case ffi::TypeIndex::kTVMFFIFloat: return Downcast(LiteralDoc::Float(value.as().value(), path)); - case ffi::TypeIndex::kTVMFFIStr: - return Downcast(LiteralDoc::Str(value.as().value(), path)); + case ffi::TypeIndex::kTVMFFIStr: { + std::string string_value = value.cast(); + bool has_multiple_lines = string_value.find_first_of('\n') != std::string::npos; + if (has_multiple_lines) { + Doc d = const_cast(this)->AddMetadata(string_value); + // TODO(tqchen): cross check AddDocDecoration + return Downcast(d); + } + return Downcast(LiteralDoc::Str(string_value, path)); + } case ffi::TypeIndex::kTVMFFIDataType: return Downcast(LiteralDoc::DataType(value.as().value(), path)); case ffi::TypeIndex::kTVMFFIDevice: diff --git a/include/tvm/tir/stmt.h b/include/tvm/tir/stmt.h index 250475c61d90..37410b1271cc 100644 --- a/include/tvm/tir/stmt.h +++ b/include/tvm/tir/stmt.h @@ -115,7 +115,7 @@ class LetStmt : public Stmt { class AttrStmtNode : public StmtNode { public: /*! \brief this is attribute about certain node */ - ObjectRef node; + ffi::Any node; /*! \brief the type key of the attribute */ String attr_key; /*! \brief The attribute value, value is well defined at current scope. */ @@ -142,7 +142,7 @@ class AttrStmtNode : public StmtNode { */ class AttrStmt : public Stmt { public: - TVM_DLL AttrStmt(ObjectRef node, String attr_key, PrimExpr value, Stmt body, Span span = Span()); + TVM_DLL AttrStmt(ffi::Any node, String attr_key, PrimExpr value, Stmt body, Span span = Span()); TVM_DEFINE_OBJECT_REF_METHODS(AttrStmt, Stmt, AttrStmtNode); TVM_DEFINE_OBJECT_REF_COW_METHOD(AttrStmtNode); diff --git a/src/contrib/msc/core/ir/graph_builder.cc b/src/contrib/msc/core/ir/graph_builder.cc index f1f9e08ab9b8..4670abe52ec1 100644 --- a/src/contrib/msc/core/ir/graph_builder.cc +++ b/src/contrib/msc/core/ir/graph_builder.cc @@ -342,7 +342,7 @@ const MSCJoint GraphBuilder::AddNode(const Expr& expr, const Optional& bin if (const auto* v_node = call_node->op.as()) { const auto& func = Downcast(ref_module_->Lookup(v_node->name_hint)); const auto& name_opt = func->GetAttr(relax::attr::kComposite); - if (name_opt.defined()) { + if (name_opt.has_value()) { attrs = FuncAttrGetter().GetAttrs(func); } } else if (call_node->op->IsInstance()) { @@ -760,7 +760,7 @@ void GraphBuilder::VisitBinding_(const VarBindingNode* binding, const DataflowVa void GraphBuilder::VisitBinding_(const VarBindingNode* binding, const FunctionNode* val) { const auto& name_opt = val->GetAttr(relax::attr::kComposite); - ICHECK(name_opt.defined()) << "Unexpected target func without composite"; + ICHECK(name_opt.has_value()) << "Unexpected target func without composite"; ICHECK(config_.target.size() > 0 && StringUtils::StartsWith(name_opt.value(), config_.target)) << "Target should be given for target function"; target_funcs_.Set(binding->var, GetRef(val)); @@ -770,18 +770,18 @@ const std::tuple GraphBuilder::ParseFunc(const Function& String node_name, optype, layout; const auto& name_opt = func->GetAttr(msc_attr::kUnique); // get node_name - if (name_opt.defined()) { + if (name_opt.has_value()) { node_name = name_opt.value(); } // get optype const auto& codegen_opt = func->GetAttr(relax::attr::kCodegen); const auto& optype_opt = func->GetAttr(msc_attr::kOptype); const auto& composite_opt = func->GetAttr(relax::attr::kComposite); - if (codegen_opt.defined()) { + if (codegen_opt.has_value()) { optype = codegen_opt.value(); - } else if (optype_opt.defined()) { + } else if (optype_opt.has_value()) { optype = optype_opt.value(); - } else if (composite_opt.defined()) { + } else if (composite_opt.has_value()) { optype = composite_opt.value(); if (config_.target.size() > 0) { optype = StringUtils::Replace(composite_opt.value(), config_.target + ".", ""); @@ -789,7 +789,7 @@ const std::tuple GraphBuilder::ParseFunc(const Function& } // get layout const auto& layout_opt = func->GetAttr(msc_attr::kLayout); - if (layout_opt.defined()) { + if (layout_opt.has_value()) { layout = layout_opt.value(); } return std::make_tuple(node_name, optype, layout); diff --git a/src/contrib/msc/core/printer/cpp_printer.cc b/src/contrib/msc/core/printer/cpp_printer.cc index f162f5db1e85..6ae71860b64e 100644 --- a/src/contrib/msc/core/printer/cpp_printer.cc +++ b/src/contrib/msc/core/printer/cpp_printer.cc @@ -28,9 +28,9 @@ namespace contrib { namespace msc { void CppPrinter::PrintTypedDoc(const LiteralDoc& doc) { - const ObjectRef& value = doc->value; + const ffi::Any& value = doc->value; bool defined = false; - if (!value.defined()) { + if (value == nullptr) { output_ << "nullptr"; defined = true; } else if (const auto* int_imm = value.as()) { @@ -217,7 +217,7 @@ void CppPrinter::PrintTypedDoc(const ClassDoc& doc) { } void CppPrinter::PrintTypedDoc(const CommentDoc& doc) { - if (doc->comment.defined()) { + if (doc->comment.has_value()) { output_ << "// " << doc->comment.value(); } } diff --git a/src/contrib/msc/core/printer/msc_base_printer.cc b/src/contrib/msc/core/printer/msc_base_printer.cc index 0f0b24fd3a28..31869f29bbab 100644 --- a/src/contrib/msc/core/printer/msc_base_printer.cc +++ b/src/contrib/msc/core/printer/msc_base_printer.cc @@ -158,7 +158,7 @@ void MSCBasePrinter::PrintTypedDoc(const ExprStmtDoc& doc) { } void MSCBasePrinter::MaybePrintComment(const StmtDoc& stmt, bool multi_lines) { - if (stmt->comment.defined()) { + if (stmt->comment.has_value()) { if (multi_lines) { for (const auto& l : StringUtils::Split(stmt->comment.value(), "\n")) { PrintDoc(CommentDoc(l)); diff --git a/src/contrib/msc/core/printer/prototxt_printer.cc b/src/contrib/msc/core/printer/prototxt_printer.cc index 44a915ae7b8d..82d15dc71842 100644 --- a/src/contrib/msc/core/printer/prototxt_printer.cc +++ b/src/contrib/msc/core/printer/prototxt_printer.cc @@ -30,7 +30,7 @@ namespace tvm { namespace contrib { namespace msc { -LiteralDoc PrototxtPrinter::ToLiteralDoc(const ObjectRef& obj) { +LiteralDoc PrototxtPrinter::ToLiteralDoc(const ffi::Any& obj) { if (obj.as()) { return LiteralDoc::Str(Downcast(obj), std::nullopt); } else if (obj.as()) { @@ -51,7 +51,7 @@ DictDoc PrototxtPrinter::ToDictDoc(const Map& dict) { if (pair.second.as()) { values.push_back(Downcast(pair.second)); } else { - values.push_back(ToLiteralDoc(pair.second.cast())); + values.push_back(ToLiteralDoc(pair.second)); } } return DictDoc(keys, values); @@ -65,7 +65,7 @@ DictDoc PrototxtPrinter::ToDictDoc(const std::vector>& di if (pair.second.as()) { values.push_back(Downcast(pair.second)); } else { - values.push_back(ToLiteralDoc(pair.second.cast())); + values.push_back(ToLiteralDoc(pair.second)); } } return DictDoc(keys, values); diff --git a/src/contrib/msc/core/printer/prototxt_printer.h b/src/contrib/msc/core/printer/prototxt_printer.h index 97c0f91818d2..e760a179d8dd 100644 --- a/src/contrib/msc/core/printer/prototxt_printer.h +++ b/src/contrib/msc/core/printer/prototxt_printer.h @@ -50,7 +50,7 @@ class PrototxtPrinter : public MSCBasePrinter { explicit PrototxtPrinter(const std::string& options = "") : MSCBasePrinter(options) {} /*! \brief Change object to LiteralDoc*/ - static LiteralDoc ToLiteralDoc(const ObjectRef& obj); + static LiteralDoc ToLiteralDoc(const ffi::Any& obj); /*! \brief Change map to DictDoc*/ static DictDoc ToDictDoc(const Map& dict); diff --git a/src/contrib/msc/core/printer/python_printer.cc b/src/contrib/msc/core/printer/python_printer.cc index f1a13c7fd04f..184d7ce87059 100644 --- a/src/contrib/msc/core/printer/python_printer.cc +++ b/src/contrib/msc/core/printer/python_printer.cc @@ -30,9 +30,9 @@ namespace contrib { namespace msc { void PythonPrinter::PrintTypedDoc(const LiteralDoc& doc) { - const ObjectRef& value = doc->value; + const ffi::Any& value = doc->value; bool defined = false; - if (!value.defined()) { + if (value == nullptr) { output_ << "None"; defined = true; } else if (const auto* int_imm = value.as()) { @@ -176,7 +176,7 @@ void PythonPrinter::PrintTypedDoc(const FunctionDoc& doc) { output_ << ":"; - if (doc->comment.defined()) { + if (doc->comment.has_value()) { IncreaseIndent(); MaybePrintComment(doc, true); DecreaseIndent(); @@ -197,7 +197,7 @@ void PythonPrinter::PrintTypedDoc(const ClassDoc& doc) { } void PythonPrinter::PrintTypedDoc(const CommentDoc& doc) { - if (doc->comment.defined()) { + if (doc->comment.has_value()) { output_ << "# " << doc->comment.value(); } } @@ -234,7 +234,7 @@ void PythonPrinter::PrintTypedDoc(const SwitchDoc& doc) { } void PythonPrinter::MaybePrintComment(const StmtDoc& stmt, bool multi_lines) { - if (stmt->comment.defined() && multi_lines) { + if (stmt->comment.has_value() && multi_lines) { NewLine(); output_ << "\"\"\""; for (const auto& l : StringUtils::Split(stmt->comment.value(), "\n")) { diff --git a/src/contrib/msc/core/transform/bind_named_params.cc b/src/contrib/msc/core/transform/bind_named_params.cc index 481a3092fec9..df534f4cfae6 100644 --- a/src/contrib/msc/core/transform/bind_named_params.cc +++ b/src/contrib/msc/core/transform/bind_named_params.cc @@ -49,7 +49,7 @@ std::tuple, Map> NormalizeNamedBindings( Map relax_var_remap; - auto normalize_key = [&](ObjectRef obj) -> relax::Var { + auto normalize_key = [&](ffi::Any obj) -> relax::Var { if (auto opt_str = obj.as()) { std::string str = opt_str.value(); auto it = string_lookup.find(str); @@ -77,18 +77,17 @@ std::tuple, Map> NormalizeNamedBindings( LOG(FATAL) << "Expected bound parameter to be a relax::Var, " << " or a string that uniquely identifies a relax::Var param within the function. " - << "However, received object " << obj << " of type " << obj->GetTypeKey(); + << "However, received object " << obj << " of type " << obj.GetTypeKey(); } }; - auto normalize_value = [&](Var key, ObjectRef obj) -> relax::Expr { + auto normalize_value = [&](Var key, ffi::Any obj) -> relax::Expr { if (auto opt = obj.as()) { return opt.value(); } else if (auto opt = obj.as()) { const auto& span = SpanUtils::CreateWithAttr(msc_attr::kName, key->name_hint()); return Constant(opt.value(), StructInfo(), span); } else { - LOG(FATAL) << "Cannot coerce object of type " << obj->GetTypeKey() - << " into relax expression"; + LOG(FATAL) << "Cannot coerce object of type " << obj.GetTypeKey() << " into relax expression"; } }; @@ -130,7 +129,7 @@ IRModule BindNamedParam(IRModule m, String func_name, Map if (relax_f->GetLinkageType() == LinkageType::kExternal) { // Use global_symbol if it's external linkage Optional gsymbol = relax_f->GetAttr(tvm::attr::kGlobalSymbol); - if (gsymbol.defined() && gsymbol.value() == func_name) { + if (gsymbol.has_value() && gsymbol.value() == func_name) { Function f_after_bind = FunctionBindNamedParams(GetRef(relax_f), bind_params); new_module->Update(func_pr.first, f_after_bind); } diff --git a/src/contrib/msc/core/transform/fuse_tuple.cc b/src/contrib/msc/core/transform/fuse_tuple.cc index 297f4a94fe1c..19b8f08f4780 100644 --- a/src/contrib/msc/core/transform/fuse_tuple.cc +++ b/src/contrib/msc/core/transform/fuse_tuple.cc @@ -55,7 +55,7 @@ class TupleFuser : public ExprMutator { main_var = gv; } else { const auto& name_opt = func->GetAttr(attr::kComposite); - if (name_opt.defined() && StringUtils::StartsWith(name_opt.value(), target_)) { + if (name_opt.has_value() && StringUtils::StartsWith(name_opt.value(), target_)) { target_funcs_.Set(gv, Downcast(func)); } } @@ -76,7 +76,7 @@ class TupleFuser : public ExprMutator { if (arg->IsInstance()) { String tuple_name; const auto& name_opt = target_funcs_[val->op]->GetAttr(msc_attr::kUnique); - if (name_opt.defined()) { + if (name_opt.has_value()) { if (val->args.size() == 1) { tuple_name = name_opt.value() + "_input"; } else { diff --git a/src/contrib/msc/core/transform/inline_params.cc b/src/contrib/msc/core/transform/inline_params.cc index c68948cef5d2..086c475f6d1f 100644 --- a/src/contrib/msc/core/transform/inline_params.cc +++ b/src/contrib/msc/core/transform/inline_params.cc @@ -63,7 +63,7 @@ class ParamsInliner : public ExprMutator { } if (struct_info->IsInstance()) { const auto& optype_opt = func->GetAttr(msc_attr::kOptype); - ICHECK(optype_opt.defined()) + ICHECK(optype_opt.has_value()) << "Can not find attr " << msc_attr::kOptype << " form extern func"; extern_types_.Set(p, optype_opt.value()); continue; diff --git a/src/contrib/msc/core/transform/set_byoc_attrs.cc b/src/contrib/msc/core/transform/set_byoc_attrs.cc index 8687f7264790..85819ea58dc6 100644 --- a/src/contrib/msc/core/transform/set_byoc_attrs.cc +++ b/src/contrib/msc/core/transform/set_byoc_attrs.cc @@ -55,7 +55,7 @@ class ByocNameSetter : public ExprMutator { continue; } const auto& name_opt = func->GetAttr(attr::kCodegen); - if (name_opt.defined() && name_opt.value() == target_) { + if (name_opt.has_value() && name_opt.value() == target_) { const String& func_name = target_ + "_" + std::to_string(func_cnt); const auto& new_func = Downcast(VisitExpr(func)); builder_->UpdateFunction(gv, WithAttr(new_func, msc_attr::kUnique, func_name)); @@ -75,7 +75,7 @@ class ByocNameSetter : public ExprMutator { if (val->op->IsInstance()) { ICHECK(local_funcs_.count(val->op)) << "Can not find local func " << val->op; const auto& name_opt = local_funcs_[val->op]->GetAttr(msc_attr::kUnique); - if (name_opt.defined()) { + if (name_opt.has_value()) { val->span = SpanUtils::SetAttr(val->span, "name", name_opt.value()); } } diff --git a/src/contrib/msc/core/transform/set_expr_name.cc b/src/contrib/msc/core/transform/set_expr_name.cc index c9cf65e783dc..14ea3ccfec7b 100644 --- a/src/contrib/msc/core/transform/set_expr_name.cc +++ b/src/contrib/msc/core/transform/set_expr_name.cc @@ -160,7 +160,7 @@ class RelaxExprNameSetter : public ExprVisitor { void VisitBinding_(const VarBindingNode* binding, const FunctionNode* val) { ExprVisitor::VisitBinding_(binding, val); const auto& name_opt = val->GetAttr(attr::kComposite); - if (name_opt.defined()) { + if (name_opt.has_value()) { local_funcs_.Set(binding->var, GetRef(val)); } } @@ -260,9 +260,9 @@ class RelaxExprNameSetter : public ExprVisitor { String optype; const auto& comp_opt = func->GetAttr(attr::kComposite); const auto& code_opt = func->GetAttr(attr::kCodegen); - if (comp_opt.defined()) { + if (comp_opt.has_value()) { optype = comp_opt.value(); - } else if (code_opt.defined()) { + } else if (code_opt.has_value()) { optype = code_opt.value(); } else { optype = "extern_func"; @@ -277,7 +277,7 @@ class RelaxExprNameSetter : public ExprVisitor { String name; // get from unique const auto& name_opt = func->GetAttr(msc_attr::kUnique); - if (name_opt.defined()) { + if (name_opt.has_value()) { return name_opt.value(); } // get from exprs in the func diff --git a/src/contrib/msc/core/utils.cc b/src/contrib/msc/core/utils.cc index caac67d6f511..f4a79602f506 100644 --- a/src/contrib/msc/core/utils.cc +++ b/src/contrib/msc/core/utils.cc @@ -108,6 +108,7 @@ const String CommonUtils::ToAttrKey(const String& key) { return msc_attr::kConsumerType; } LOG_FATAL << "Unexpected key " << key; + TVM_FFI_UNREACHABLE(); } bool StringUtils::Contains(const String& src_string, const String& sub_string) { @@ -261,12 +262,12 @@ const String StringUtils::Lower(const String& src_string) { return str; } -const String StringUtils::ToString(const runtime::ObjectRef& obj) { +const String StringUtils::ToString(const ffi::Any& obj) { String obj_string; - if (!obj.defined()) { + if (obj == nullptr) { obj_string = ""; - } else if (obj.as()) { - obj_string = Downcast(obj); + } else if (auto opt_str = obj.as()) { + obj_string = *opt_str; } else if (const auto* n = obj.as()) { obj_string = std::to_string(n->value); } else if (const auto* n = obj.as()) { @@ -370,7 +371,7 @@ const Span SpanUtils::SetAttr(const Span& span, const String& key, const String& return Span(SourceName::Get(new_source), 0, 0, 0, 0); } -const String SpanUtils::GetAttr(const Span& span, const String& key) { +String SpanUtils::GetAttr(const Span& span, const String& key) { if (span.defined() && span->source_name.defined()) { Array tokens{"<" + key + ">", ""}; return StringUtils::GetClosureOnce(span->source_name->name, tokens[0], tokens[1]); diff --git a/src/contrib/msc/core/utils.h b/src/contrib/msc/core/utils.h index 84e3c667410c..aeb7f9eb88fd 100644 --- a/src/contrib/msc/core/utils.h +++ b/src/contrib/msc/core/utils.h @@ -173,7 +173,7 @@ class StringUtils { * \brief Change Object to String. * \return The String. */ - TVM_DLL static const String ToString(const runtime::ObjectRef& obj); + TVM_DLL static const String ToString(const ffi::Any& obj); }; /*! @@ -287,7 +287,7 @@ class SpanUtils { * \brief Get the value in value from the Span. * \return The value String. */ - TVM_DLL static const String GetAttr(const Span& span, const String& key); + TVM_DLL static String GetAttr(const Span& span, const String& key); /*! * \brief Get all the key:value in format value from the Span. diff --git a/src/contrib/msc/framework/tensorrt/codegen.cc b/src/contrib/msc/framework/tensorrt/codegen.cc index 55a46789b892..684abbe38c17 100644 --- a/src/contrib/msc/framework/tensorrt/codegen.cc +++ b/src/contrib/msc/framework/tensorrt/codegen.cc @@ -606,7 +606,7 @@ Array MSCTensorRTCompiler(Array functions, for (const auto& func : functions) { VLOG(1) << "MSC.TensorRT partition:" << std::endl << func; const auto& name_opt = func->GetAttr(msc_attr::kUnique); - ICHECK(name_opt.defined()) << "Can not find " << msc_attr::kUnique << " from attrs"; + ICHECK(name_opt.has_value()) << "Can not find " << msc_attr::kUnique << " from attrs"; const auto& name = name_opt.value(); std::string func_name = GetExtSymbol(func); ICHECK(target_option.count(name)) << "Can not find target option for " << name; diff --git a/src/contrib/msc/framework/tensorrt/transform_tensorrt.cc b/src/contrib/msc/framework/tensorrt/transform_tensorrt.cc index 5c2965de7648..3d43c74958ec 100644 --- a/src/contrib/msc/framework/tensorrt/transform_tensorrt.cc +++ b/src/contrib/msc/framework/tensorrt/transform_tensorrt.cc @@ -285,7 +285,7 @@ Expr RewriteAttention(BlockBuilder builder, const Var& var, const Call& src_call // causal_mask Expr s_value; - if (!src_attrs->causal_mask.defined()) { + if (!src_attrs->causal_mask.has_value()) { auto softmax_attrs = make_object(); softmax_attrs->axis = 2; s_value = RewriteUtils::MakeCall(builder, ExprUtils::GetSpanName(call, "act"), softmax_op, diff --git a/src/ir/apply_pass_to_function.cc b/src/ir/apply_pass_to_function.cc index c77bbf89d0aa..3436d49b02ee 100644 --- a/src/ir/apply_pass_to_function.cc +++ b/src/ir/apply_pass_to_function.cc @@ -73,7 +73,7 @@ Pass ApplyPassToFunction(Pass pass, String func_name_regex, std::string name = gvar->name_hint; if (tvm::runtime::regex_match(name, func_name_regex)) { at_least_one_function_matched_regex = true; - if (!func->GetAttr(tvm::attr::kGlobalSymbol).defined()) { + if (!func->GetAttr(tvm::attr::kGlobalSymbol).has_value()) { // Function may be mutated, but is an internal function. Mark // it as externally-exposed, so that any call-tracing internal // transforms do not remove this function, in case it its diff --git a/src/ir/name_supply.cc b/src/ir/name_supply.cc index 40412ad3075c..77c0480f85ca 100644 --- a/src/ir/name_supply.cc +++ b/src/ir/name_supply.cc @@ -68,7 +68,6 @@ String NameSupplyNode::add_prefix_to_name(const String& name) { } std::ostringstream ss; - ICHECK(name.defined()); ss << prefix_ << "_" << name; return ss.str(); } diff --git a/src/meta_schedule/arg_info.cc b/src/meta_schedule/arg_info.cc index 7994ff68635e..78278103fe64 100644 --- a/src/meta_schedule/arg_info.cc +++ b/src/meta_schedule/arg_info.cc @@ -69,7 +69,7 @@ ArgInfo ArgInfo::FromJSON(const ObjectRef& json_obj) { // The JSON object is always an array whose first element is a tag. For example: // `['TENSOR', 'float32', [1, 224, 224, 3]] // Step 1. Extract the tag - String tag{ffi::ObjectPtr(nullptr)}; + Optional tag{std::nullopt}; try { const ffi::ArrayObj* json_array = json_obj.as(); CHECK(json_array && json_array->size() >= 1); @@ -124,7 +124,7 @@ ObjectRef TensorInfoNode::AsJSON() const { static String tag = "TENSOR"; String dtype = DLDataTypeToString(this->dtype); Array shape = support::AsArray(this->shape); - return Array{tag, dtype, shape}; + return Array{tag, dtype, shape}; } TensorInfo TensorInfo::FromJSON(const ObjectRef& json_obj) { diff --git a/src/meta_schedule/database/database.cc b/src/meta_schedule/database/database.cc index 11b3c6dc3eb9..a3a48c6a9f31 100644 --- a/src/meta_schedule/database/database.cc +++ b/src/meta_schedule/database/database.cc @@ -46,7 +46,7 @@ ObjectRef WorkloadNode::AsJSON() const { // Dump the JSON string to base64 std::string b64_mod = Base64Encode(json_mod); // Output - return Array{SHash2Str(this->shash), String(b64_mod)}; + return Array{SHash2Str(this->shash), String(b64_mod)}; } Workload Workload::FromJSON(const ObjectRef& json_obj) { diff --git a/src/meta_schedule/database/database_utils.cc b/src/meta_schedule/database/database_utils.cc index 1f396882720b..230e4d350924 100644 --- a/src/meta_schedule/database/database_utils.cc +++ b/src/meta_schedule/database/database_utils.cc @@ -75,7 +75,7 @@ void JSONDumps(Any json_obj, std::ostringstream& os) { if (i != 0) { os << ","; } - os << '"' << support::StrEscape(kv.first->data, kv.first->size) << '"'; + os << '"' << support::StrEscape(kv.first.data(), kv.first.size()) << '"'; os << ":"; JSONDumps(kv.second, os); } diff --git a/src/meta_schedule/measure_callback/update_cost_model.cc b/src/meta_schedule/measure_callback/update_cost_model.cc index 65089c5ab85b..97062d112275 100644 --- a/src/meta_schedule/measure_callback/update_cost_model.cc +++ b/src/meta_schedule/measure_callback/update_cost_model.cc @@ -44,8 +44,8 @@ class UpdateCostModelNode : public MeasureCallbackNode { pruned_candidate.reserve(n); pruned_runner_result.reserve(n); for (int i = 0; i < n; i++) { - if (!builder_results[i]->error_msg.defined() && // - (runner_results[i]->error_msg.defined() || // + if (!builder_results[i]->error_msg.has_value() && // + (runner_results[i]->error_msg.has_value() || // (runner_results[i]->run_secs.defined() && Sum(runner_results[i]->run_secs.value()) > 0))) { pruned_candidate.push_back(measure_candidates[i]); diff --git a/src/meta_schedule/postproc/rewrite_reduction_block.cc b/src/meta_schedule/postproc/rewrite_reduction_block.cc index 571ed5675e16..c17c90fe2df8 100644 --- a/src/meta_schedule/postproc/rewrite_reduction_block.cc +++ b/src/meta_schedule/postproc/rewrite_reduction_block.cc @@ -147,7 +147,7 @@ bool RewriteReductionBlockNode::Apply(const tir::Schedule& sch) { tir::BlockRV init_block_rv = sch->DecomposeReduction(block_rv, loop_rvs[decompose_point]); // Rewrite auto tensorization related annotations - if (tir::GetAnn(block_sref, tir::attr::meta_schedule_auto_tensorize).defined()) { + if (tir::GetAnn(block_sref, tir::attr::meta_schedule_auto_tensorize).has_value()) { // Remove tensorization annotation as it shouldn't be propagated to the init block. sch->Unannotate(init_block_rv, tir::attr::meta_schedule_auto_tensorize); Optional tensorize_init = @@ -157,7 +157,7 @@ bool RewriteReductionBlockNode::Apply(const tir::Schedule& sch) { // Annotate to hint `RewriteTensorize` postprocessor even if tensorize_init is std::nullopt. sch->Annotate(init_block_rv, tir::attr::meta_schedule_auto_tensorize, tensorize_init.value_or("")); - if (tensorize_init.defined()) { + if (tensorize_init.has_value()) { sch->Unannotate(block_rv, tir::attr::meta_schedule_auto_tensorize_init); sch->Unannotate(init_block_rv, tir::attr::meta_schedule_auto_tensorize_init); } diff --git a/src/meta_schedule/schedule_rule/auto_inline.cc b/src/meta_schedule/schedule_rule/auto_inline.cc index bcf927803e18..746b0487addd 100644 --- a/src/meta_schedule/schedule_rule/auto_inline.cc +++ b/src/meta_schedule/schedule_rule/auto_inline.cc @@ -162,7 +162,7 @@ inline InlineType AutoInlineNode::CheckInline(const tir::Schedule& sch, if (producer_srefs.size() == 1 && tir::IsCompleteBlock(sch->state(), producer_srefs[0], scope_block) && CanReverseComputeInline(state, block_sref) && - !GetAnn(producer_srefs[0], tir::attr::meta_schedule_auto_tensorize).defined()) { + !GetAnn(producer_srefs[0], tir::attr::meta_schedule_auto_tensorize).has_value()) { return InlineType::kInlineIntoProducer; } } diff --git a/src/meta_schedule/space_generator/post_order_apply.cc b/src/meta_schedule/space_generator/post_order_apply.cc index 9f71989fa41e..780b4042999e 100644 --- a/src/meta_schedule/space_generator/post_order_apply.cc +++ b/src/meta_schedule/space_generator/post_order_apply.cc @@ -80,7 +80,7 @@ class PostOrderApplyNode : public SpaceGeneratorNode { continue; } if (!ScheduleRule::IsApplyCustomRule(sch_rule)) { - if (tir::GetAnn(sch->GetSRef(block_rv), "schedule_rule").defined()) { + if (tir::GetAnn(sch->GetSRef(block_rv), "schedule_rule").has_value()) { stack.emplace_back(sch, blocks); continue; } diff --git a/src/meta_schedule/task_scheduler/task_scheduler.cc b/src/meta_schedule/task_scheduler/task_scheduler.cc index 1d9d8e89ad7c..453239dd4ac4 100644 --- a/src/meta_schedule/task_scheduler/task_scheduler.cc +++ b/src/meta_schedule/task_scheduler/task_scheduler.cc @@ -71,7 +71,7 @@ void SendToRunner(TaskRecordNode* self, const Runner& runner) { for (int i = 0; i < n; ++i) { const MeasureCandidate& candidate = candidates[i]; const BuilderResult& builder_result = builder_results[i]; - if (builder_result->error_msg.defined()) { + if (builder_result->error_msg.has_value()) { ++n_build_errors; continue; } @@ -88,7 +88,7 @@ void SendToRunner(TaskRecordNode* self, const Runner& runner) { results.reserve(n); for (int i = 0, j = 0; i < n; ++i) { const BuilderResult& builder_result = builder_results[i]; - if (builder_result->error_msg.defined()) { + if (builder_result->error_msg.has_value()) { results.push_back(RunnerFuture( /*f_done=*/[]() -> bool { return true; }, /*f_result=*/ @@ -129,7 +129,7 @@ void TaskCleanUp(TaskRecordNode* self, int task_id, const Array& r TVM_PY_LOG(INFO, logger) << std::fixed << std::setprecision(4) // << "[Task #" << task_id << ": " << name << "] Trial #" << trials << ": Error in " - << (builder_result->error_msg.defined() ? "building" : "running") + << (builder_result->error_msg.has_value() ? "building" : "running") << ":\n" << err << "\n" << sch->mod() << "\n" diff --git a/src/meta_schedule/trace_apply.cc b/src/meta_schedule/trace_apply.cc index 4037f6757b3a..114afc0ad72e 100644 --- a/src/meta_schedule/trace_apply.cc +++ b/src/meta_schedule/trace_apply.cc @@ -57,7 +57,6 @@ void InlinePostBlocks(Schedule sch, Trace anchor_trace, Target target) { for (const auto& inst : anchor_trace->insts) { if (inst->kind.same_as(kind_get_block)) { auto block_name = Downcast(inst->attrs[0]); - ICHECK(block_name.defined()); get_block_names.insert(block_name); } } diff --git a/src/node/object_path.cc b/src/node/object_path.cc index 6fd7a43a0492..3e68e0d0efa0 100644 --- a/src/node/object_path.cc +++ b/src/node/object_path.cc @@ -101,7 +101,7 @@ ObjectPath ObjectPathNode::Attr(const char* attr_key) const { } ObjectPath ObjectPathNode::Attr(Optional attr_key) const { - if (attr_key.defined()) { + if (attr_key.has_value()) { return ObjectPath(make_object(this, attr_key.value())); } else { return ObjectPath(make_object(this)); @@ -235,7 +235,7 @@ RootPathNode::RootPathNode(Optional name) : ObjectPathNode(nullptr), nam bool RootPathNode::LastNodeEqual(const ObjectPathNode* other_path) const { const auto* other = static_cast(other_path); - if (other->name.defined() != name.defined()) { + if (other->name.has_value() != name.has_value()) { return false; } else if (name && other->name) { return name.value() == other->name.value(); diff --git a/src/node/repr_printer.cc b/src/node/repr_printer.cc index d273ba6a734c..34b08994d04e 100644 --- a/src/node/repr_printer.cc +++ b/src/node/repr_printer.cc @@ -26,6 +26,8 @@ #include #include +#include "../support/str_escape.h" + namespace tvm { void ReprPrinter::Print(const ObjectRef& node) { diff --git a/src/node/serialization.cc b/src/node/serialization.cc index 2570a5e80004..c3060fc91f55 100644 --- a/src/node/serialization.cc +++ b/src/node/serialization.cc @@ -108,6 +108,9 @@ class NodeIndexer { MakeIndex(kv.second); } } + } else if (node.type_index() == ffi::TypeIndex::kTVMFFIStr || + node.type_index() == ffi::TypeIndex::kTVMFFIBytes) { + // skip content index for string and bytes } else if (auto opt_object = node.as()) { Object* n = const_cast(opt_object.value()); // if the node already have repr bytes, no need to visit Attrs. @@ -272,6 +275,10 @@ class JSONAttrGetter { node_->data.push_back(node_index_->at(kv.second)); } } + } else if (auto opt_str = node.as()) { + node_->repr_bytes = *opt_str; + } else if (auto opt_bytes = node.as()) { + node_->repr_bytes = *opt_bytes; } else if (auto opt_object = node.as()) { Object* n = const_cast(opt_object.value()); // do not need to print additional things once we have repr bytes. @@ -399,6 +406,11 @@ class FieldDependencyFinder { if (node.type_index() < ffi::TypeIndex::kTVMFFIStaticObjectBegin) { return; } + if (node.type_index() == ffi::TypeIndex::kTVMFFIStr || + node.type_index() == ffi::TypeIndex::kTVMFFIBytes) { + // skip indexing content of string and bytes + return; + } // Skip the objects that have their own string repr if (jnode->repr_bytes.length() > 0 || reflection_->GetReprBytes(node.cast(), nullptr)) { @@ -552,6 +564,10 @@ class JSONAttrSetter { setter.ParseValue("v_device_type", &device_type); setter.ParseValue("v_device_id", &device_id); return Any(DLDevice{static_cast(device_type), device_id}); + } else if (jnode->type_key == ffi::StaticTypeKey::kTVMFFIStr) { + return Any(String(jnode->repr_bytes)); + } else if (jnode->type_key == ffi::StaticTypeKey::kTVMFFIBytes) { + return Any(Bytes(jnode->repr_bytes)); } else { return ObjectRef(reflection->CreateInitObject(jnode->type_key, jnode->repr_bytes)); } @@ -581,6 +597,9 @@ class JSONAttrSetter { } } *node = result; + } else if (jnode->type_key == ffi::StaticTypeKey::kTVMFFIStr || + jnode->type_key == ffi::StaticTypeKey::kTVMFFIBytes) { + // skip set attrs for string and bytes } else if (auto opt_object = node->as()) { Object* n = const_cast(opt_object.value()); if (n == nullptr) return; diff --git a/src/node/structural_hash.cc b/src/node/structural_hash.cc index 383f344facae..bf9d7b23d5a7 100644 --- a/src/node/structural_hash.cc +++ b/src/node/structural_hash.cc @@ -58,24 +58,12 @@ struct RefToObjectPtr : public ObjectRef { } }; -TVM_REGISTER_REFLECTION_VTABLE(ffi::StringObj) - .set_creator([](const std::string& bytes) { return RefToObjectPtr::Get(String(bytes)); }) - .set_repr_bytes([](const Object* n) -> std::string { - return GetRef(static_cast(n)).operator std::string(); - }); - TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) .set_dispatch([](const ObjectRef& node, ReprPrinter* p) { auto* op = static_cast(node.get()); p->stream << '"' << support::StrEscape(op->data, op->size) << '"'; }); -TVM_REGISTER_REFLECTION_VTABLE(ffi::BytesObj) - .set_creator([](const std::string& bytes) { return RefToObjectPtr::Get(String(bytes)); }) - .set_repr_bytes([](const Object* n) -> std::string { - return GetRef(static_cast(n)).operator std::string(); - }); - TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) .set_dispatch([](const ObjectRef& node, ReprPrinter* p) { auto* op = static_cast(node.get()); diff --git a/src/relax/analysis/well_formed.cc b/src/relax/analysis/well_formed.cc index db216aba967c..a1bc99ee75bf 100644 --- a/src/relax/analysis/well_formed.cc +++ b/src/relax/analysis/well_formed.cc @@ -140,7 +140,7 @@ class WellFormedChecker : public relax::ExprVisitor, // check name in global var and gsymbol Optional gsymbol = func->GetAttr(tvm::attr::kGlobalSymbol); - if (gsymbol.defined() && gsymbol != var->name_hint) { + if (gsymbol.has_value() && gsymbol != var->name_hint) { Malformed(Diagnostic::Error(func->span) << "Name in GlobalVar is not equal to name in gsymbol: " << var << " != " << gsymbol.value()); diff --git a/src/relax/backend/contrib/codegen_json/codegen_json.h b/src/relax/backend/contrib/codegen_json/codegen_json.h index 8d2e97a0116a..b2c3e47c73a0 100644 --- a/src/relax/backend/contrib/codegen_json/codegen_json.h +++ b/src/relax/backend/contrib/codegen_json/codegen_json.h @@ -112,7 +112,7 @@ class OpAttrExtractor { } } - void Visit(const char* key, runtime::ObjectRef* value) { + void Visit(const char* key, ffi::Any* value) { if (const auto* an = (*value).as()) { std::vector attr; for (size_t i = 0; i < an->size(); ++i) { @@ -120,25 +120,23 @@ class OpAttrExtractor { attr.push_back(std::to_string(im->value)); } else if (const auto* fm = (*an)[i].as()) { attr.push_back(Fp2String(fm->value)); - } else if (const auto* str = (*an)[i].as()) { - String s = GetRef(str); - attr.push_back(s); + } else if (auto opt_str = (*an)[i].as()) { + attr.push_back(*opt_str); } else { LOG(FATAL) << "Not supported type: " << (*an)[i].GetTypeKey(); } } SetNodeAttr(key, attr); - } else if (!(*value).defined()) { // Skip NullValue + } else if (*value == nullptr) { // Skip NullValue SetNodeAttr(key, std::vector{""}); } else if (const auto* im = (*value).as()) { SetNodeAttr(key, std::vector{std::to_string(im->value)}); } else if (const auto* fm = (*value).as()) { SetNodeAttr(key, std::vector{Fp2String(fm->value)}); - } else if (const auto* str = (*value).as()) { - String s = GetRef(str); - SetNodeAttr(key, std::vector{s}); + } else if (const auto opt_str = (*value).as()) { + SetNodeAttr(key, std::vector{*opt_str}); } else { - LOG(FATAL) << "Not yet supported type: " << (*value)->GetTypeKey() << ": " << *value; + LOG(FATAL) << "Not yet supported type: " << (*value).GetTypeKey(); } } @@ -178,14 +176,12 @@ class OpAttrExtractor { break; } case ffi::TypeIndex::kTVMFFINDArray: { - runtime::NDArray value = field_value.cast(); - this->Visit(field_info->name.data, &value); + this->Visit(field_info->name.data, &field_value); break; } default: { if (field_value.type_index() >= ffi::TypeIndex::kTVMFFIStaticObjectBegin) { - ObjectRef obj = field_value.cast(); - this->Visit(field_info->name.data, &obj); + this->Visit(field_info->name.data, &field_value); break; } LOG(FATAL) << "Unsupported type: " << field_value.GetTypeKey(); @@ -294,7 +290,7 @@ class JSONSerializer : public relax::MemoizedExprTranslator { } else if (const auto* fn = cn->op.as()) { ICHECK(false); auto pattern = fn->GetAttr(attr::kPartitionedFromPattern); - ICHECK(pattern.defined()); + ICHECK(pattern.has_value()); std::vector values; values.push_back(pattern.value()); std::vector attr; @@ -394,7 +390,7 @@ class JSONSerializer : public relax::MemoizedExprTranslator { name = op_node->name; } else if (const auto* fn = cn->op.as()) { auto comp = fn->GetAttr(attr::kComposite); - ICHECK(comp.defined()) << "JSON runtime only supports composite functions."; + ICHECK(comp.has_value()) << "JSON runtime only supports composite functions."; name = comp.value(); } else { LOG(FATAL) << "JSON runtime does not support calls to " << cn->op->GetTypeKey(); @@ -422,7 +418,7 @@ class JSONSerializer : public relax::MemoizedExprTranslator { } NodeEntries VisitExpr_(const FunctionNode* fn) { - ICHECK(fn->GetAttr(attr::kComposite).defined()) + ICHECK(fn->GetAttr(attr::kComposite).has_value()) << "JSON runtime only supports composite functions"; // FunctionNode should be handled by the caller. diff --git a/src/relax/backend/contrib/tensorrt/codegen.cc b/src/relax/backend/contrib/tensorrt/codegen.cc index 8665fe347ecb..53c1ca03976d 100644 --- a/src/relax/backend/contrib/tensorrt/codegen.cc +++ b/src/relax/backend/contrib/tensorrt/codegen.cc @@ -141,7 +141,7 @@ class TensorRTJSONSerializer : public JSONSerializer { const auto fn = Downcast(bindings_[GetRef(fn_var)]); auto opt_composite = fn->GetAttr(attr::kComposite); - ICHECK(opt_composite.defined()); + ICHECK(opt_composite.has_value()); std::string name = opt_composite.value(); // Collect the constants and attributes of all operator calls inside the composite body. diff --git a/src/relax/backend/vm/codegen_vm.cc b/src/relax/backend/vm/codegen_vm.cc index 13fe82d4bc9c..27165db34350 100644 --- a/src/relax/backend/vm/codegen_vm.cc +++ b/src/relax/backend/vm/codegen_vm.cc @@ -83,8 +83,8 @@ class CodeGenVM : public ExprFunctor { void Codegen(const Function& func) { Optional gsymbol = func->GetAttr(tvm::attr::kGlobalSymbol); - ICHECK(gsymbol.defined()) << "there should be no local functions in Relax VM codegen phase. " - "Did you forget to apply LambdaLift or AttachGlobalSymbol Pass?"; + ICHECK(gsymbol.has_value()) << "there should be no local functions in Relax VM codegen phase. " + "Did you forget to apply LambdaLift or AttachGlobalSymbol Pass?"; Array param_names; for (Var param : func->params) { @@ -293,12 +293,12 @@ class CodeGenVM : public ExprFunctor { // At this point: all global var must corresponds to the right symbol. // TODO(relax-team): switch everything to extern before splitting TIR/relax // so we do not have idle global var here. - if (!symbol.defined()) { + if (!symbol.has_value()) { symbol = gvar->name_hint; kind = VMFuncInfo::FuncKind::kPackedFunc; } // declare the function to be safe. - ICHECK(symbol.defined()); + ICHECK(symbol.has_value()); builder_->DeclareFunction(symbol.value(), kind); return builder_->GetFunction(symbol.value()); } diff --git a/src/relax/backend/vm/codegen_vm_tir.cc b/src/relax/backend/vm/codegen_vm_tir.cc index 042bd5301a28..c7cf06ea9d7f 100644 --- a/src/relax/backend/vm/codegen_vm_tir.cc +++ b/src/relax/backend/vm/codegen_vm_tir.cc @@ -127,7 +127,7 @@ class CodeGenVMTIR : public ExprFunctor(const Expr&)> { void EmitCallCPacked(const tir::PrimFunc& prim_func, const Array& args, int64_t dst_anylist_slot = -1) { Optional gsymbol = prim_func->GetAttr(tvm::attr::kGlobalSymbol); - ICHECK(gsymbol.defined()) << "All functions must have global symbol at this phase"; + ICHECK(gsymbol.has_value()) << "All functions must have global symbol at this phase"; Array all_args; // negative index indicate return value can be discarded, emit call_packed if (dst_anylist_slot >= 0) { @@ -148,8 +148,8 @@ class CodeGenVMTIR : public ExprFunctor(const Expr&)> { tir::PrimFunc Codegen(const Function& func) { Optional gsymbol = func->GetAttr(tvm::attr::kGlobalSymbol); - ICHECK(gsymbol.defined()) << "there should be no local functions in Relax VM codegen phase. " - "Did you forget to apply LambdaLift or AttachGlobalSymbol Pass?"; + ICHECK(gsymbol.has_value()) << "there should be no local functions in Relax VM codegen phase. " + "Did you forget to apply LambdaLift or AttachGlobalSymbol Pass?"; // initialize the state stmt_stack_ = {}; registers_num_ = 0; @@ -379,7 +379,7 @@ class CodeGenVMTIR : public ExprFunctor(const Expr&)> { Optional VisitExpr_(const GlobalVarNode* op) final { VMFuncInfo::FuncKind kind; auto symbol = LookupFunction(GetRef(op), &kind); - ICHECK(symbol.defined()); + ICHECK(symbol.has_value()); builder_->DeclareFunction(symbol.value(), kind); return FuncListGet(builder_->GetFunction(symbol.value()).value()); } @@ -452,7 +452,7 @@ class CodeGenVMTIR : public ExprFunctor(const Expr&)> { VMFuncInfo::FuncKind kind; auto symbol = LookupFunction(call_node->op, &kind); - if (symbol.defined() && kind == VMFuncInfo::FuncKind::kPackedFunc) { + if (symbol.has_value() && kind == VMFuncInfo::FuncKind::kPackedFunc) { // primfunc in the same module. // use cpacked to directly invoke without named based lookup if (Optional prim_func = LookupPrimFunc(symbol.value())) { diff --git a/src/relax/ir/dataflow_expr_rewriter.cc b/src/relax/ir/dataflow_expr_rewriter.cc index 123b18d81cf5..5462154babc8 100644 --- a/src/relax/ir/dataflow_expr_rewriter.cc +++ b/src/relax/ir/dataflow_expr_rewriter.cc @@ -689,7 +689,7 @@ PatternMatchingRewriter PatternMatchingRewriter::FromModule(IRModule mod) { Map new_subroutines; for (const auto& [gvar, func] : mod->functions) { if (gvar->name_hint != "pattern" && gvar->name_hint != "replacement") { - bool is_public = func->GetAttr(tvm::attr::kGlobalSymbol).defined(); + bool is_public = func->GetAttr(tvm::attr::kGlobalSymbol).has_value(); CHECK(!is_public) << "ValueError: " << "Expected module to have no publicly-exposed functions " << "other than 'pattern' and 'replacement'. " diff --git a/src/relax/transform/alter_op_impl.cc b/src/relax/transform/alter_op_impl.cc index 9ae492262d75..4013d3aad17e 100644 --- a/src/relax/transform/alter_op_impl.cc +++ b/src/relax/transform/alter_op_impl.cc @@ -123,7 +123,7 @@ class AlterOpImplMutator : public ExprMutator { // If the callee does not have kOperatorName attribute or no replacement is requested for // it, nothing to do here. - if (!maybe_op_kind.defined() || op_impl_map_.count(maybe_op_kind.value()) == 0) return call; + if (!maybe_op_kind.has_value() || op_impl_map_.count(maybe_op_kind.value()) == 0) return call; auto op_kind = maybe_op_kind.value(); const auto& replacement_func = op_impl_map_[op_kind]; diff --git a/src/relax/transform/attach_global_symbol.cc b/src/relax/transform/attach_global_symbol.cc index 5c2fb9a79761..9ef135608dc4 100644 --- a/src/relax/transform/attach_global_symbol.cc +++ b/src/relax/transform/attach_global_symbol.cc @@ -55,7 +55,7 @@ Pass AttachGlobalSymbol() { new_func = WithAttr(GetRef(relax_func), tvm::attr::kGlobalSymbol, new_name); } - if (new_name.defined() && (!old_name.defined() || old_name.value() != new_name.value())) { + if (new_name.has_value() && (!old_name.has_value() || old_name.value() != new_name.value())) { updates->Add(gvar, new_func); if (new_name.value() != gvar->name_hint) { GlobalVar new_gvar(new_name.value()); diff --git a/src/relax/transform/bind_params.cc b/src/relax/transform/bind_params.cc index 6103dbbaec5b..49fe469e8927 100644 --- a/src/relax/transform/bind_params.cc +++ b/src/relax/transform/bind_params.cc @@ -97,7 +97,7 @@ std::tuple, Map> NormalizeBindings( Map relax_var_remap; - auto normalize_key = [&](ObjectRef obj) -> relax::Var { + auto normalize_key = [&](ffi::Any obj) -> relax::Var { if (auto opt_str = obj.as()) { std::string str = opt_str.value(); auto it = string_lookup.find(str); @@ -125,17 +125,16 @@ std::tuple, Map> NormalizeBindings( LOG(FATAL) << "Expected bound parameter to be a relax::Var, " << " or a string that uniquely identifies a relax::Var param within the function. " - << "However, received object " << obj << " of type " << obj->GetTypeKey(); + << "However, received object " << obj << " of type " << obj.GetTypeKey(); } }; - auto normalize_value = [&](ObjectRef obj) -> relax::Expr { + auto normalize_value = [&](ffi::Any obj) -> relax::Expr { if (auto opt = obj.as()) { return opt.value(); } else if (auto opt = obj.as()) { return Constant(opt.value()); } else { - LOG(FATAL) << "Cannot coerce object of type " << obj->GetTypeKey() - << " into relax expression"; + LOG(FATAL) << "Cannot coerce object of type " << obj.GetTypeKey() << " into relax expression"; } }; @@ -181,7 +180,7 @@ IRModule BindParam(IRModule m, String func_name, Map bind_ if (relax_f->GetLinkageType() == LinkageType::kExternal) { // Use global_symbol if it's external linkage Optional gsymbol = relax_f->GetAttr(tvm::attr::kGlobalSymbol); - if (gsymbol.defined() && gsymbol.value() == func_name) { + if (gsymbol.has_value() && gsymbol.value() == func_name) { Function f_after_bind = FunctionBindParams(GetRef(relax_f), bind_params); new_module->Update(func_pr.first, f_after_bind); } diff --git a/src/relax/transform/bind_symbolic_vars.cc b/src/relax/transform/bind_symbolic_vars.cc index 7fd75ed7d3f7..22c557874cde 100644 --- a/src/relax/transform/bind_symbolic_vars.cc +++ b/src/relax/transform/bind_symbolic_vars.cc @@ -31,7 +31,7 @@ namespace tvm { namespace relax { -Function FunctionBindSymbolicVars(Function func, Map obj_remap) { +Function FunctionBindSymbolicVars(Function func, Map obj_remap) { // Early bail-out if no updates need to be made. if (obj_remap.empty()) { return func; @@ -50,7 +50,7 @@ Function FunctionBindSymbolicVars(Function func, Map obj_re // Replacement map to be used when rewriting the function. Map var_remap; for (const auto& [key, replacement] : obj_remap) { - if (auto opt = key.as()) { + if (auto opt = key.as()) { String string_key = opt.value(); auto it = string_lookup.find(string_key); CHECK(it != string_lookup.end()) @@ -74,7 +74,7 @@ Function FunctionBindSymbolicVars(Function func, Map obj_re var_remap.Set(var, replacement); } else { LOG(FATAL) << "Expected symbolic variable to be a tir::Var or a string name, " - << "but " << key << " was of type " << key->GetTypeKey(); + << "but " << key << " was of type " << key.GetTypeKey(); } } @@ -90,15 +90,15 @@ Function FunctionBindSymbolicVars(Function func, Map obj_re } namespace { -IRModule ModuleBindSymbolicVars(IRModule mod, Map binding_map) { - std::unordered_set used; +IRModule ModuleBindSymbolicVars(IRModule mod, Map binding_map) { + std::unordered_set used; IRModule updates; for (const auto& [gvar, base_func] : mod->functions) { if (auto opt = base_func.as()) { auto func = opt.value(); // Collect bindings that are used by this function. - auto func_binding_map = [&]() -> Map { + auto func_binding_map = [&]() -> Map { std::unordered_set var_names; std::unordered_set vars; for (const auto& var : DefinedSymbolicVars(func)) { @@ -106,7 +106,7 @@ IRModule ModuleBindSymbolicVars(IRModule mod, Map binding_m vars.insert(var.get()); } - Map out; + Map out; for (const auto& [key, replacement] : binding_map) { bool used_by_function = false; if (auto opt = key.as()) { @@ -115,10 +115,10 @@ IRModule ModuleBindSymbolicVars(IRModule mod, Map binding_m used_by_function = vars.count(ptr); } else { LOG(FATAL) << "Expected symbolic variable to be a tir::Var " - << "or a string name, but " << key << " was of type " << key->GetTypeKey(); + << "or a string name, but " << key << " was of type " << key.GetTypeKey(); } if (used_by_function) { - used.insert(key.get()); + used.insert(key); out.Set(key, replacement); } } @@ -132,9 +132,9 @@ IRModule ModuleBindSymbolicVars(IRModule mod, Map binding_m } } - Array unused; + Array unused; for (const auto& [key, replacement] : binding_map) { - if (!used.count(key.get())) { + if (!used.count(key)) { unused.push_back(key); } } diff --git a/src/relax/transform/expand_tuple_arguments.cc b/src/relax/transform/expand_tuple_arguments.cc index 17ab181ec946..5b711b767562 100644 --- a/src/relax/transform/expand_tuple_arguments.cc +++ b/src/relax/transform/expand_tuple_arguments.cc @@ -33,7 +33,7 @@ template using PMap = std::unordered_map; Optional ExpandParams(Function func) { - bool is_exposed = func->attrs.GetAttr(tvm::attr::kGlobalSymbol).defined(); + bool is_exposed = func->attrs.GetAttr(tvm::attr::kGlobalSymbol).has_value(); if (is_exposed) return std::nullopt; bool has_tuple_param = std::any_of( diff --git a/src/relax/transform/few_shot_tuning.cc b/src/relax/transform/few_shot_tuning.cc index a9ebdfebf303..819de35e20f0 100644 --- a/src/relax/transform/few_shot_tuning.cc +++ b/src/relax/transform/few_shot_tuning.cc @@ -86,7 +86,7 @@ tir::PrimFunc FewShotTunePrimFunc(const tir::PrimFunc& prim_func, const Target& int idx = 0; bool no_valid = true; // whether there is no valid schedule in this iteration for (const meta_schedule::BuilderResult& builder_result : builder_results) { - if (!builder_result->error_msg.defined()) { + if (!builder_result->error_msg.has_value()) { results.push_back(candidates.value()[idx]->sch->mod()); valid_count--; no_valid = false; @@ -98,7 +98,7 @@ tir::PrimFunc FewShotTunePrimFunc(const tir::PrimFunc& prim_func, const Target& Array runner_inputs; int idx = 0; for (const meta_schedule::BuilderResult& builder_result : builder_results) { - if (!builder_result->error_msg.defined()) { + if (!builder_result->error_msg.has_value()) { runner_inputs.push_back(meta_schedule::RunnerInput( /*artifact_path=*/builder_result->artifact_path.value(), /*device_type=*/target->kind->name, @@ -109,7 +109,7 @@ tir::PrimFunc FewShotTunePrimFunc(const tir::PrimFunc& prim_func, const Target& Array runner_futures = runner->Run(runner_inputs); for (const meta_schedule::RunnerFuture& runner_future : runner_futures) { meta_schedule::RunnerResult runner_result = runner_future->Result(); - if (runner_result->error_msg.defined()) { + if (runner_result->error_msg.has_value()) { costs.push_back(1e10); } else { double sum = 0; diff --git a/src/relax/transform/fuse_ops.cc b/src/relax/transform/fuse_ops.cc index bfc278b9c779..c6f947016755 100644 --- a/src/relax/transform/fuse_ops.cc +++ b/src/relax/transform/fuse_ops.cc @@ -120,7 +120,7 @@ class GraphCreator : public ExprVisitor { // true. const auto* func = it.second.as(); if (func == nullptr || func->HasNonzeroAttr(attr::kPrimitive) || - func->GetAttr(attr::kCodegen).defined()) { + func->GetAttr(attr::kCodegen).has_value()) { continue; } creator(GetRef(func)); @@ -733,7 +733,7 @@ class OperatorFusor : public ExprMutator { // Only visit Relax functions with neither attr::kPrimitive nor // attr::kCodegen. if (func->IsInstance() && !func->HasNonzeroAttr(attr::kPrimitive) && - !func->GetAttr(attr::kCodegen).defined()) { + !func->GetAttr(attr::kCodegen).has_value()) { auto updated_func = Downcast(VisitExpr(func)); builder_->UpdateFunction(gv, updated_func); } @@ -1263,8 +1263,8 @@ class CompositeFunctionAnnotator : public ExprMutator { } const auto& base_func = (*it).second; if (const auto* func = base_func.as()) { - if (func->GetAttr(attr::kComposite).defined() || - func->GetAttr(attr::kCodegen).defined()) { + if (func->GetAttr(attr::kComposite).has_value() || + func->GetAttr(attr::kCodegen).has_value()) { continue; } @@ -1363,8 +1363,8 @@ IRModule FuseOpsByPattern(const tvm::Array& patterns, } const FunctionNode* function = base_func.as(); if (function->GetAttr(attr::kPrimitive).value_or(false) || - function->GetAttr(attr::kComposite).defined() || - function->GetAttr(attr::kCodegen).defined()) { + function->GetAttr(attr::kComposite).has_value() || + function->GetAttr(attr::kCodegen).has_value()) { continue; } entry_functions.push_back(Downcast(base_func)); diff --git a/src/relax/transform/inline_functions.cc b/src/relax/transform/inline_functions.cc index 2c393a4a93d6..44363e19464f 100644 --- a/src/relax/transform/inline_functions.cc +++ b/src/relax/transform/inline_functions.cc @@ -178,7 +178,7 @@ Pass InlinePrivateFunctions() { for (const auto& [gvar, base_func] : mod->functions) { if (auto opt = base_func.as()) { auto func = opt.value(); - bool is_private = !func->GetAttr(tvm::attr::kGlobalSymbol).defined(); + bool is_private = !func->GetAttr(tvm::attr::kGlobalSymbol).has_value(); if (is_private) { replacements.Set(gvar, func); } diff --git a/src/relax/transform/lazy_transform_params.cc b/src/relax/transform/lazy_transform_params.cc index 23c99eb92884..9b59b680eceb 100644 --- a/src/relax/transform/lazy_transform_params.cc +++ b/src/relax/transform/lazy_transform_params.cc @@ -249,7 +249,7 @@ namespace transform { Pass LazyGetInput() { auto pass_func = [](Function func, IRModule, PassContext) -> Function { - if (!func->GetAttr(tvm::attr::kGlobalSymbol).defined()) { + if (!func->GetAttr(tvm::attr::kGlobalSymbol).has_value()) { return func; } return WithLazyInputs(func); @@ -267,7 +267,7 @@ TVM_FFI_STATIC_INIT_BLOCK({ Pass LazySetOutput() { auto pass_func = [](Function func, IRModule, PassContext) -> Function { - if (!func->GetAttr(tvm::attr::kGlobalSymbol).defined()) { + if (!func->GetAttr(tvm::attr::kGlobalSymbol).has_value()) { return func; } return WithLazyOutputs(func); diff --git a/src/relax/transform/merge_composite_functions.cc b/src/relax/transform/merge_composite_functions.cc index 3c8511f8edb2..025e91c3c3ab 100644 --- a/src/relax/transform/merge_composite_functions.cc +++ b/src/relax/transform/merge_composite_functions.cc @@ -270,7 +270,7 @@ class CompositeGroupsBuilder : public MemoizedExprTranslator { std::vector GetGroupsToMerge(const CallNode* call) { Optional codegen_name = GetCodegenName(call->op); - if (!codegen_name.defined()) { + if (!codegen_name.has_value()) { return {}; } diff --git a/src/relax/transform/meta_schedule.cc b/src/relax/transform/meta_schedule.cc index 85b021e2f552..acad7d154402 100644 --- a/src/relax/transform/meta_schedule.cc +++ b/src/relax/transform/meta_schedule.cc @@ -84,7 +84,7 @@ Pass MetaScheduleApplyDatabase(Optional work_dir, bool enable_warning = if (Database::Current().defined()) { database = Database::Current().value(); } else { - ICHECK(work_dir.defined()); + ICHECK(work_dir.has_value()); String path_workload = work_dir.value() + "/database_workload.json"; String path_tuning_record = work_dir.value() + "/database_tuning_record.json"; LOG(WARNING) << "Creating JSONDatabase. Workload at: " << path_workload diff --git a/src/relax/transform/remove_unused_outputs.cc b/src/relax/transform/remove_unused_outputs.cc index 3a2dd9c2194d..26145cde1d48 100644 --- a/src/relax/transform/remove_unused_outputs.cc +++ b/src/relax/transform/remove_unused_outputs.cc @@ -44,7 +44,7 @@ class PartialTupleUsageCollector : ExprVisitor { PMap num_outputs; for (const auto& [gvar, base_func] : mod->functions) { - bool is_exposed = base_func->attrs.GetAttr(tvm::attr::kGlobalSymbol).defined(); + bool is_exposed = base_func->attrs.GetAttr(tvm::attr::kGlobalSymbol).has_value(); if (!is_exposed) { if (auto relax_func = base_func.as()) { diff --git a/src/relax/transform/remove_unused_parameters.cc b/src/relax/transform/remove_unused_parameters.cc index 778e551e9a65..2e88ebe417b3 100644 --- a/src/relax/transform/remove_unused_parameters.cc +++ b/src/relax/transform/remove_unused_parameters.cc @@ -55,7 +55,7 @@ struct CalleeAnalysis { }; std::optional AnalyzeCallee(Function func) { - bool is_exposed = func->attrs.GetAttr(tvm::attr::kGlobalSymbol).defined(); + bool is_exposed = func->attrs.GetAttr(tvm::attr::kGlobalSymbol).has_value(); if (is_exposed) return std::nullopt; auto free_relax_vars = [&]() -> PSet { diff --git a/src/relax/transform/split_call_tir_by_pattern.cc b/src/relax/transform/split_call_tir_by_pattern.cc index e669979d0949..41528c7d8690 100644 --- a/src/relax/transform/split_call_tir_by_pattern.cc +++ b/src/relax/transform/split_call_tir_by_pattern.cc @@ -571,7 +571,7 @@ std::pair> SplitFunctions(PrimFunc func, if (match_results.empty()) { return {func, std::nullopt}; } - Array codegen_result = f_codegen(match_results); + Array codegen_result = f_codegen(match_results); ICHECK(codegen_result.size() == 3); String library_code = Downcast(codegen_result[0]); int num_matched_ops = Downcast(codegen_result[1])->value; @@ -662,7 +662,7 @@ void StringReplace(std::string* subject, const std::string& search, const std::s tvm::BaseFunc CodegenWithLibrary(const tir::PrimFuncNode* pf, String global_symbol) { using namespace tvm::tir; Optional library_code = pf->attrs.GetAttr(kLibraryKernel); - if (!library_code.defined()) { + if (!library_code.has_value()) { return GetRef(pf); } std::string source = library_code.value(); diff --git a/src/relax/transform/static_plan_block_memory.cc b/src/relax/transform/static_plan_block_memory.cc index 9e54a8fac8d2..f2e185ebd2d4 100644 --- a/src/relax/transform/static_plan_block_memory.cc +++ b/src/relax/transform/static_plan_block_memory.cc @@ -379,30 +379,16 @@ void SetTIRVarUpperBound(Function func, arith::Analyzer* ana, // appear in the **function signature**. Map var_upper_bound_attr_raw = func->GetAttr>("tir_var_upper_bound").value_or(Map()); - Array non_negative_var_attr_raw = - func->GetAttr>("tir_non_negative_var").value_or(Array()); + Array non_negative_var_attr_raw = + func->GetAttr>("tir_non_negative_var").value_or(Array()); std::unordered_map var_upper_bound_attr; std::unordered_set non_negative_var_attr; // We manually check the value type to ensure the values are all positive IntImm. - for (auto it : var_upper_bound_attr_raw) { - const auto* key = it.first.as(); - const auto* value = it.second.as(); - CHECK(key != nullptr) - << "The entry key of attr `tir_var_upper_bound` should be string. However " - << it.first->GetTypeKey() << " is got."; - CHECK(value != nullptr) - << "The entry value of attr `tir_var_upper_bound` should be integer. However " - << it.second.GetTypeKey() << " is got."; - CHECK_GT(value->value, 0) - << "The entry value of attr `tir_var_upper_bound` should be a positive integer, while " - << value->value << " is got."; - var_upper_bound_attr[GetRef(key)] = GetRef(value); - } - for (ObjectRef var_name : non_negative_var_attr_raw) { - const auto* key = var_name.as(); - CHECK(key != nullptr) << "The element of attr `tir_non_negative_var` should be string. However " - << var_name->GetTypeKey() << " is got."; - non_negative_var_attr.insert(GetRef(key)); + for (auto [key, value] : var_upper_bound_attr_raw) { + var_upper_bound_attr[key] = value; + } + for (const String& var_name : non_negative_var_attr_raw) { + non_negative_var_attr.insert(var_name); } Array var_in_signature = TIRVarsInStructInfo(GetStructInfo(func)); for (const tir::Var& tir_var : var_in_signature) { diff --git a/src/relax/transform/utils.h b/src/relax/transform/utils.h index edd953e3126e..009d00260781 100644 --- a/src/relax/transform/utils.h +++ b/src/relax/transform/utils.h @@ -125,7 +125,7 @@ TVM_DLL IRModule DeadCodeElimination(const IRModule& mod, Array entry_fu */ inline std::string GetExtSymbol(const Function& func) { const auto name_node = func->GetAttr(tvm::attr::kGlobalSymbol); - ICHECK(name_node.defined()) << "Fail to retrieve external symbol."; + ICHECK(name_node.has_value()) << "Fail to retrieve external symbol."; return std::string(name_node.value()); } diff --git a/src/runtime/device_api.cc b/src/runtime/device_api.cc index 06eb0284f9d0..947d8884a59c 100644 --- a/src/runtime/device_api.cc +++ b/src/runtime/device_api.cc @@ -107,7 +107,7 @@ static size_t GetDataAlignment(const DLDataType dtype) { } size_t DeviceAPI::GetDataSize(const DLTensor& arr, Optional mem_scope) { - if (!mem_scope.defined() || mem_scope.value().empty() || mem_scope.value() == "global") { + if (!mem_scope.has_value() || mem_scope.value().empty() || mem_scope.value() == "global") { size_t size = 1; for (int i = 0; i < arr.ndim; ++i) { size *= static_cast(arr.shape[i]); @@ -121,7 +121,7 @@ size_t DeviceAPI::GetDataSize(const DLTensor& arr, Optional mem_scope) { void* DeviceAPI::AllocDataSpace(Device dev, int ndim, const int64_t* shape, DLDataType dtype, Optional mem_scope) { - if (!mem_scope.defined() || mem_scope.value() == "" || mem_scope.value() == "global") { + if (!mem_scope.has_value() || mem_scope.value().empty() || mem_scope.value() == "global") { // by default, we can always redirect to the flat memory allocations DLTensor temp; temp.data = nullptr; diff --git a/src/runtime/disco/protocol.h b/src/runtime/disco/protocol.h index 30a1e6ed6609..a6af311e6ef7 100644 --- a/src/runtime/disco/protocol.h +++ b/src/runtime/disco/protocol.h @@ -81,7 +81,7 @@ struct DiscoProtocol { } support::Arena arena_; - std::vector object_arena_; + std::vector object_arena_; friend struct RPCReference; }; @@ -175,7 +175,7 @@ inline void DiscoProtocol::WriteObject(Object* obj) { template inline void DiscoProtocol::ReadObject(TVMFFIAny* out) { SubClassType* self = static_cast(this); - ObjectRef result{nullptr}; + ffi::Any result{nullptr}; uint32_t type_index; self->template Read(&type_index); if (type_index == TypeIndex::kRuntimeDiscoDRef) { diff --git a/src/runtime/hexagon/hexagon_device_api.cc b/src/runtime/hexagon/hexagon_device_api.cc index fd8d6e53bfed..a26f113f1e9b 100644 --- a/src/runtime/hexagon/hexagon_device_api.cc +++ b/src/runtime/hexagon/hexagon_device_api.cc @@ -74,7 +74,7 @@ void* HexagonDeviceAPI::AllocDataSpace(Device dev, int ndim, const int64_t* shap // in Hexagon's "indirect tensor" format: // - shape[0] indicates the number of tensor-content memory allocations. // - shape[1] indicates the size of each tensor-content memory allocation. - if (!mem_scope.defined() || mem_scope.value() == "global") { + if (!mem_scope.has_value() || mem_scope.value().empty() || mem_scope.value() == "global") { return DeviceAPI::AllocDataSpace(dev, ndim, shape, dtype, mem_scope); } diff --git a/src/runtime/memory/memory_manager.cc b/src/runtime/memory/memory_manager.cc index 763410080927..cef445ee91c0 100644 --- a/src/runtime/memory/memory_manager.cc +++ b/src/runtime/memory/memory_manager.cc @@ -234,10 +234,10 @@ NDArray Allocator::Empty(ffi::Shape shape, DLDataType dtype, DLDevice dev, size_t size = ffi::GetDataSize(shape.Product(), dtype); Buffer buffer; - if (!mem_scope.defined() || mem_scope.value().empty() || mem_scope.value() == "global") { + if (!mem_scope.has_value() || (*mem_scope).empty() || (*mem_scope) == "global") { buffer = this->Alloc(dev, size, alignment, dtype); } else { - buffer = this->Alloc(dev, shape, dtype, mem_scope.value()); + buffer = this->Alloc(dev, shape, dtype, *mem_scope); } return NDArray::FromNDAlloc(BufferAlloc(buffer), shape, dtype, dev); } diff --git a/src/runtime/meta_data.h b/src/runtime/meta_data.h index 8acefecaad8a..aa629aef50a7 100644 --- a/src/runtime/meta_data.h +++ b/src/runtime/meta_data.h @@ -40,8 +40,6 @@ namespace runtime { inline String get_name_mangled(const String& module_name, const String& name) { std::stringstream ss; - ICHECK(module_name.defined()); - ICHECK(name.defined()); ss << module_name << "_" << name; return ss.str(); } diff --git a/src/runtime/opencl/opencl_device_api.cc b/src/runtime/opencl/opencl_device_api.cc index 15616f126724..176884383d83 100644 --- a/src/runtime/opencl/opencl_device_api.cc +++ b/src/runtime/opencl/opencl_device_api.cc @@ -77,7 +77,7 @@ ImageInfo GetImageInfo(const cl::BufferDescriptor* desc, const DLTensor* tensor) cl::BufferDescriptor::MemoryLayout cl::BufferDescriptor::MemoryLayoutFromScope( Optional mem_scope) { - if (!mem_scope.defined()) { + if (!mem_scope.has_value()) { return cl::BufferDescriptor::MemoryLayout::kBuffer1D; } else if (mem_scope.value() == "global.texture") { return cl::BufferDescriptor::MemoryLayout::kImage2DActivation; @@ -277,7 +277,7 @@ void* OpenCLWorkspace::AllocDataSpace(Device dev, size_t width, size_t height, D back_buffer->mbuf = buf; } - if (!mem_scope.defined()) { + if (!mem_scope.has_value()) { mem_scope = String("global.texture"); } return AllocCLImage(dev, back_buffer, width, height, row_pitch, type_hint, mem_scope); @@ -286,7 +286,7 @@ void* OpenCLWorkspace::AllocDataSpace(Device dev, size_t width, size_t height, D void* OpenCLWorkspace::AllocDataSpace(Device dev, int ndim, const int64_t* shape, DLDataType dtype, Optional mem_scope) { this->Init(); - if (!mem_scope.defined() || mem_scope.value().empty() || mem_scope.value() == "global") { + if (!mem_scope.has_value() || (*mem_scope).empty() || (*mem_scope) == "global") { size_t size = GetMemObjectSize(dev, ndim, shape, dtype); cl::BufferDescriptor* ret_buffer = nullptr; auto buf = MemoryManager::GetOrCreateAllocator(dev, AllocatorType::kPooled) @@ -349,7 +349,7 @@ void* OpenCLWorkspace::AllocCLImage(Device dev, void* back_buffer, size_t width, } size_t OpenCLWorkspace::GetDataSize(const DLTensor& arr, Optional mem_scope) { - if (!mem_scope.defined() || mem_scope.value().empty() || mem_scope.value() == "global") { + if (!mem_scope.has_value() || (*mem_scope).empty() || (*mem_scope) == "global") { return DeviceAPI::GetDataSize(arr); } cl_uint row_align = GetImageAlignment(GetThreadEntry()->device.device_id); @@ -366,7 +366,7 @@ void* OpenCLWorkspace::AllocDataSpaceView(Device dev, void* data, ffi::Shape sha // Fall back for devices w/o "cl_khr_image2d_from_buffer" if (!IsBufferToImageSupported(dev.device_id)) { cl::BufferDescriptor* ret_desc = desc; // buffer -> buffer - if (!mem_scope.defined() || mem_scope.value().empty() || mem_scope.value() == "global") { + if (!mem_scope.has_value() || (*mem_scope).empty() || (*mem_scope) == "global") { if (desc->layout != cl::BufferDescriptor::MemoryLayout::kBuffer1D) { // image -> buffer size_t nbytes = GetMemObjectSize(dev, shape.size(), shape.data(), dtype); @@ -389,7 +389,7 @@ void* OpenCLWorkspace::AllocDataSpaceView(Device dev, void* data, ffi::Shape sha return ret_desc; } - if (!mem_scope.defined() || mem_scope.value().empty() || mem_scope.value() == "global") { + if (!mem_scope.has_value() || (*mem_scope).empty() || (*mem_scope) == "global") { if (desc->layout == cl::BufferDescriptor::MemoryLayout::kBuffer1D) { // buffer -> buffer return desc; diff --git a/src/runtime/profiling.cc b/src/runtime/profiling.cc index 390f383c006a..4e29dcc39232 100644 --- a/src/runtime/profiling.cc +++ b/src/runtime/profiling.cc @@ -151,7 +151,7 @@ void Profiler::Start() { } void Profiler::StartCall(String name, Device dev, - std::unordered_map extra_metrics) { + std::unordered_map extra_metrics) { std::vector> objs; for (auto& collector : collectors_) { ObjectRef obj = collector->Start(dev); @@ -162,7 +162,7 @@ void Profiler::StartCall(String name, Device dev, in_flight_.push(CallFrame{dev, name, Timer::Start(dev), extra_metrics, objs}); } -void Profiler::StopCall(std::unordered_map extra_metrics) { +void Profiler::StopCall(std::unordered_map extra_metrics) { CallFrame cf = in_flight_.top(); cf.timer->Stop(); for (auto& p : extra_metrics) { @@ -172,7 +172,7 @@ void Profiler::StopCall(std::unordered_map extra_metrics for (const auto& obj : cf.extra_collectors) { auto collector_metrics = obj.first->Stop(obj.second); for (auto& p : collector_metrics) { - cf.extra_metrics[p.first] = p.second.cast(); + cf.extra_metrics[p.first] = p.second; } } in_flight_.pop(); @@ -303,10 +303,10 @@ String ReportNode::AsCSV() const { } namespace { -void metric_as_json(std::ostream& os, ObjectRef o) { - if (o.as()) { +void metric_as_json(std::ostream& os, ffi::Any o) { + if (auto opt_str = o.as()) { os << "{\"string\":" - << "\"" << Downcast(o) << "\"" + << "\"" << *opt_str << "\"" << "}"; } else if (const CountNode* n = o.as()) { os << "{\"count\":" << n->value << "}"; @@ -320,7 +320,7 @@ void metric_as_json(std::ostream& os, ObjectRef o) { os << "{\"ratio\":" << std::setprecision(std::numeric_limits::max_digits10) << std::fixed << n->ratio << "}"; } else { - LOG(FATAL) << "Unprintable type " << o->GetTypeKey(); + LOG(FATAL) << "Unprintable type " << o.GetTypeKey(); } } } // namespace @@ -340,7 +340,7 @@ String ReportNode::AsJSON() const { s << "{"; for (const auto& kv : calls[i]) { s << "\"" << kv.first << "\":"; - metric_as_json(s, kv.second.cast()); + metric_as_json(s, kv.second); if (j < calls[i].size() - 1) { s << ","; } @@ -360,7 +360,7 @@ String ReportNode::AsJSON() const { s << "\"" << dev_kv.first << "\":{"; for (const auto& metric_kv : dev_kv.second) { s << "\"" << metric_kv.first << "\":"; - metric_as_json(s, metric_kv.second.cast()); + metric_as_json(s, metric_kv.second); if (j < dev_kv.second.size() - 1) { s << ","; } @@ -378,7 +378,7 @@ String ReportNode::AsJSON() const { size_t k = 0; for (const auto& kv : configuration) { s << "\"" << kv.first << "\":"; - metric_as_json(s, kv.second.cast()); + metric_as_json(s, kv.second); if (k < configuration.size() - 1) { s << ","; } @@ -392,7 +392,7 @@ String ReportNode::AsJSON() const { // Aggregate a set of values for a metric. Computes sum for Duration, Count, // and Percent; average for Ratio; and assumes all Strings are the same. All // ObjectRefs in metrics must have the same type. -ObjectRef AggregateMetric(const std::vector& metrics) { +Any AggregateMetric(const std::vector& metrics) { ICHECK_GT(metrics.size(), 0) << "Must pass a non-zero number of metrics"; if (metrics[0].as()) { double sum = 0; @@ -421,7 +421,7 @@ ObjectRef AggregateMetric(const std::vector& metrics) { } else if (metrics[0].as()) { for (auto& m : metrics) { if (Downcast(metrics[0]) != Downcast(m)) { - return ObjectRef(String("")); + return String(""); } } // Assume all strings in metrics are the same. @@ -429,8 +429,8 @@ ObjectRef AggregateMetric(const std::vector& metrics) { } else { LOG(FATAL) << "Can only aggregate metrics with types DurationNode, CountNode, " "PercentNode, RatioNode, and StringObj, but got " - << metrics[0]->GetTypeKey(); - return ObjectRef(); // To silence warnings + << metrics[0].GetTypeKey(); + return ffi::Any(); // To silence warnings } } @@ -446,7 +446,7 @@ static void set_locale_for_separators(std::stringstream& s) { } } -static String print_metric(ObjectRef metric) { +static String print_metric(ffi::Any metric) { std::string val; if (metric.as()) { std::stringstream s; @@ -470,7 +470,7 @@ static String print_metric(ObjectRef metric) { } else if (metric.as()) { val = Downcast(metric); } else { - LOG(FATAL) << "Cannot print metric of type " << metric->GetTypeKey(); + LOG(FATAL) << "Cannot print metric of type " << metric.GetTypeKey(); } return val; } @@ -509,7 +509,7 @@ String ReportNode::AsTable(bool sort, bool aggregate, bool compute_col_sums) con } } for (const std::string& metric : metrics) { - std::vector per_call; + std::vector per_call; for (auto i : p.second) { auto& call = calls[i]; auto it = std::find_if(call.begin(), call.end(), @@ -517,7 +517,7 @@ String ReportNode::AsTable(bool sort, bool aggregate, bool compute_col_sums) con return std::string(call_metric.first) == metric; }); if (it != call.end()) { - per_call.push_back((*it).second.cast()); + per_call.push_back((*it).second); } } if (per_call.size() > 0) { @@ -719,7 +719,7 @@ Map parse_metrics(dmlc::JSONReader* reader) { std::string metric_name, metric_value_name; Map metrics; while (reader->NextObjectItem(&metric_name)) { - ObjectRef o; + ffi::Any o; reader->BeginObject(); reader->NextObjectItem(&metric_value_name); if (metric_value_name == "microseconds") { diff --git a/src/runtime/rpc/rpc_endpoint.cc b/src/runtime/rpc/rpc_endpoint.cc index 7ee721405619..9b9816a4d993 100644 --- a/src/runtime/rpc/rpc_endpoint.cc +++ b/src/runtime/rpc/rpc_endpoint.cc @@ -1160,7 +1160,7 @@ class RPCClientSession : public RPCSession, public DeviceAPI { temp.shape = const_cast(shape); temp.strides = nullptr; temp.byte_offset = 0; - if (mem_scope.defined()) { + if (mem_scope.has_value()) { return endpoint_ ->SysCallRemote(RPCCode::kDevAllocDataWithScope, &temp, static_cast(mem_scope.value())) diff --git a/src/runtime/vm/attn_backend.cc b/src/runtime/vm/attn_backend.cc index f3aaf3f68835..04e5094d8e7f 100644 --- a/src/runtime/vm/attn_backend.cc +++ b/src/runtime/vm/attn_backend.cc @@ -25,12 +25,12 @@ namespace tvm { namespace runtime { namespace vm { -std::unique_ptr ConvertPagedPrefillFunc(Array args, +std::unique_ptr ConvertPagedPrefillFunc(Array args, AttnKind attn_kind) { if (args.empty()) { return nullptr; } - String backend_name = Downcast(args[0]); + String backend_name = args[0].cast(); if (backend_name == "tir") { CHECK_EQ(args.size(), 2); ffi::Function attn_func = Downcast(args[1]); @@ -47,12 +47,12 @@ std::unique_ptr ConvertPagedPrefillFunc(Array args, throw; } -std::unique_ptr ConvertRaggedPrefillFunc(Array args, +std::unique_ptr ConvertRaggedPrefillFunc(Array args, AttnKind attn_kind) { if (args.empty()) { return nullptr; } - String backend_name = Downcast(args[0]); + String backend_name = args[0].cast(); if (backend_name == "tir") { CHECK_EQ(args.size(), 2); ffi::Function attn_func = Downcast(args[1]); @@ -69,7 +69,7 @@ std::unique_ptr ConvertRaggedPrefillFunc(Array arg throw; } -std::unique_ptr ConvertPagedDecodeFunc(Array args, AttnKind attn_kind) { +std::unique_ptr ConvertPagedDecodeFunc(Array args, AttnKind attn_kind) { if (args.empty()) { return nullptr; } @@ -90,7 +90,7 @@ std::unique_ptr ConvertPagedDecodeFunc(Array args, A throw; } -std::unique_ptr ConvertPagedPrefillTreeMaskFunc(Array args, +std::unique_ptr ConvertPagedPrefillTreeMaskFunc(Array args, AttnKind attn_kind) { if (args.empty()) { return nullptr; @@ -105,7 +105,7 @@ std::unique_ptr ConvertPagedPrefillTreeMaskFunc(Array< throw; } -std::unique_ptr ConvertRaggedPrefillTreeMaskFunc(Array args, +std::unique_ptr ConvertRaggedPrefillTreeMaskFunc(Array args, AttnKind attn_kind) { if (args.empty()) { return nullptr; diff --git a/src/runtime/vm/attn_backend.h b/src/runtime/vm/attn_backend.h index 21d8d81a8be8..449a1def0a38 100644 --- a/src/runtime/vm/attn_backend.h +++ b/src/runtime/vm/attn_backend.h @@ -499,8 +499,7 @@ class TIRRaggedPrefillTreeMaskFunc : public RaggedPrefillTreeMaskFunc { * ffi::Functions. \param attn_kind The attention kind of the function. \return The created * PagedPrefillFunc pointer. */ -std::unique_ptr ConvertPagedPrefillFunc(Array args, - AttnKind attn_kind); +std::unique_ptr ConvertPagedPrefillFunc(Array args, AttnKind attn_kind); /*! * \brief Create a PagedDecodeFunc from the given arguments and the attention kind. @@ -508,7 +507,7 @@ std::unique_ptr ConvertPagedPrefillFunc(Array args, * ffi::Functions. \param attn_kind The attention kind of the function. \return The created * PagedDecodeFunc pointer. */ -std::unique_ptr ConvertPagedDecodeFunc(Array args, AttnKind attn_kind); +std::unique_ptr ConvertPagedDecodeFunc(Array args, AttnKind attn_kind); /*! * \brief Create a RaggedPrefillFunc from the given arguments and the attention kind. @@ -516,7 +515,7 @@ std::unique_ptr ConvertPagedDecodeFunc(Array args, A * ffi::Functions. \param attn_kind The attention kind of the function. \return The created * RaggedPrefillFunc pointer. */ -std::unique_ptr ConvertRaggedPrefillFunc(Array args, +std::unique_ptr ConvertRaggedPrefillFunc(Array args, AttnKind attn_kind); /*! @@ -525,7 +524,7 @@ std::unique_ptr ConvertRaggedPrefillFunc(Array arg * ffi::Functions. \param attn_kind The attention kind of the function. \return The created * PagedPrefillTreeMaskFunc pointer. */ -std::unique_ptr ConvertPagedPrefillTreeMaskFunc(Array args, +std::unique_ptr ConvertPagedPrefillTreeMaskFunc(Array args, AttnKind attn_kind); /*! @@ -534,7 +533,7 @@ std::unique_ptr ConvertPagedPrefillTreeMaskFunc(Array< * ffi::Functions. \param attn_kind The attention kind of the function. \return The created * RaggedPrefillTreeMaskFunc pointer. */ -std::unique_ptr ConvertRaggedPrefillTreeMaskFunc(Array args, +std::unique_ptr ConvertRaggedPrefillTreeMaskFunc(Array args, AttnKind attn_kind); } // namespace vm diff --git a/src/runtime/vm/paged_kv_cache.cc b/src/runtime/vm/paged_kv_cache.cc index 4ce854df5985..bfd3fbd02505 100644 --- a/src/runtime/vm/paged_kv_cache.cc +++ b/src/runtime/vm/paged_kv_cache.cc @@ -2470,21 +2470,21 @@ TVM_FFI_STATIC_INIT_BLOCK({ Optional f_transpose_append_mha = std::nullopt; // args[13] Optional f_transpose_append_mla = std::nullopt; // args[14] std::unique_ptr f_attention_prefill_ragged = - ConvertRaggedPrefillFunc(args[15].cast>(), AttnKind::kMHA); + ConvertRaggedPrefillFunc(args[15].cast>(), AttnKind::kMHA); std::unique_ptr f_attention_prefill = - ConvertPagedPrefillFunc(args[16].cast>(), AttnKind::kMHA); + ConvertPagedPrefillFunc(args[16].cast>(), AttnKind::kMHA); std::unique_ptr f_attention_decode = - ConvertPagedDecodeFunc(args[17].cast>(), AttnKind::kMHA); + ConvertPagedDecodeFunc(args[17].cast>(), AttnKind::kMHA); std::unique_ptr f_attention_prefill_sliding_window = - ConvertPagedPrefillFunc(args[18].cast>(), AttnKind::kMHA); + ConvertPagedPrefillFunc(args[18].cast>(), AttnKind::kMHA); std::unique_ptr f_attention_decode_sliding_window = - ConvertPagedDecodeFunc(args[19].cast>(), AttnKind::kMHA); + ConvertPagedDecodeFunc(args[19].cast>(), AttnKind::kMHA); std::unique_ptr f_attention_prefill_with_tree_mask_paged_kv = - ConvertPagedPrefillTreeMaskFunc(args[20].cast>(), AttnKind::kMHA); + ConvertPagedPrefillTreeMaskFunc(args[20].cast>(), AttnKind::kMHA); std::unique_ptr f_attention_prefill_with_tree_mask = - ConvertRaggedPrefillTreeMaskFunc(args[21].cast>(), AttnKind::kMHA); + ConvertRaggedPrefillTreeMaskFunc(args[21].cast>(), AttnKind::kMHA); std::unique_ptr f_mla_prefill = - ConvertPagedPrefillFunc(args[22].cast>(), AttnKind::kMLA); + ConvertPagedPrefillFunc(args[22].cast>(), AttnKind::kMLA); Array f_merge_inplace = args[23].cast>(); ffi::Function f_split_rotary = args[24].cast(); ffi::Function f_copy_single_page = args[25].cast(); diff --git a/src/runtime/vm/vm.cc b/src/runtime/vm/vm.cc index 55a5a87d279a..4a026f9dadcc 100644 --- a/src/runtime/vm/vm.cc +++ b/src/runtime/vm/vm.cc @@ -1051,7 +1051,7 @@ class VirtualMachineProfiler : public VirtualMachineImpl { } } - std::unordered_map metrics; + std::unordered_map metrics; metrics["Argument Shapes"] = profiling::ShapeString(arrs); // If a suitable device is found, enable profiling. diff --git a/src/script/ir_builder/relax/frame.cc b/src/script/ir_builder/relax/frame.cc index 0cde34879e64..b0475e4fb055 100644 --- a/src/script/ir_builder/relax/frame.cc +++ b/src/script/ir_builder/relax/frame.cc @@ -72,7 +72,7 @@ void FunctionFrameNode::ExitWithScope() { Expr body = this->block_builder->Normalize(tvm::relax::SeqExpr(binding_blocks, output.value())); // if the function is not private, add a global symbol to its attributes - if (!is_private.value_or(Bool(false))->value && name.defined() && + if (!is_private.value_or(Bool(false))->value && name.has_value() && !attrs.count(tvm::attr::kGlobalSymbol)) { attrs.Set(tvm::attr::kGlobalSymbol, name.value()); } @@ -89,8 +89,8 @@ void FunctionFrameNode::ExitWithScope() { builder->result = func; } else if (Optional opt_frame = builder->FindFrame()) { // Case 1. A global function of an IRModule - CHECK(name.defined()) << "ValueError: The function name must be defined before exiting the " - "function scope, if it's defined in a Module"; + CHECK(name.has_value()) << "ValueError: The function name must be defined before exiting the " + "function scope, if it's defined in a Module"; const IRModuleFrame& frame = opt_frame.value(); const String& func_name = name.value_or(""); if (!frame->global_var_map.count(func_name)) { diff --git a/src/script/ir_builder/relax/ir.cc b/src/script/ir_builder/relax/ir.cc index 0bb73abf4f31..b845434e917b 100644 --- a/src/script/ir_builder/relax/ir.cc +++ b/src/script/ir_builder/relax/ir.cc @@ -78,7 +78,7 @@ tvm::relax::Var Arg(const String& name, const tvm::relax::StructInfo& struct_inf void FuncName(const String& name) { FunctionFrame frame = FindFunctionFrame("R.func_name"); - if (frame->name.defined()) { + if (frame->name.has_value()) { LOG(FATAL) << "ValueError: Duplicate function name, previous one is: \"" << frame->name.value() << "\""; } diff --git a/src/script/ir_builder/tir/frame.cc b/src/script/ir_builder/tir/frame.cc index 1eb46f70eb71..931e7e77d128 100644 --- a/src/script/ir_builder/tir/frame.cc +++ b/src/script/ir_builder/tir/frame.cc @@ -52,7 +52,7 @@ void PrimFuncFrameNode::ExitWithScope() { TIRFrameNode::ExitWithScope(); // if the prim func is not private and there isn't already a global symbol, // add a global symbol - if (!is_private && name.defined() && !attrs.count(tvm::attr::kGlobalSymbol)) { + if (!is_private && name.has_value() && !attrs.count(tvm::attr::kGlobalSymbol)) { attrs.Set(tvm::attr::kGlobalSymbol, name.value()); } @@ -68,8 +68,8 @@ void PrimFuncFrameNode::ExitWithScope() { ICHECK(!builder->result.defined()) << "ValueError: Builder.result has already been set"; builder->result = func; } else if (Optional opt_frame = builder->FindFrame()) { - CHECK(name.defined()) << "ValueError: The function name must be defined before exiting the " - "function scope, if it's defined in a Module"; + CHECK(name.has_value()) << "ValueError: The function name must be defined before exiting the " + "function scope, if it's defined in a Module"; const ir::IRModuleFrame& frame = opt_frame.value(); const String& func_name = name.value_or(""); if (!frame->global_var_map.count(func_name)) { diff --git a/src/script/ir_builder/tir/ir.cc b/src/script/ir_builder/tir/ir.cc index e8c8d62c9b23..9d5d9dade5ea 100644 --- a/src/script/ir_builder/tir/ir.cc +++ b/src/script/ir_builder/tir/ir.cc @@ -87,7 +87,7 @@ Buffer Arg(String name, Buffer buffer) { void FuncName(String name) { PrimFuncFrame frame = FindPrimFuncFrame("T.func_name"); - if (frame->name.defined()) { + if (frame->name.has_value()) { LOG(FATAL) << "ValueError: Duplicate prim func name, previous one is " << frame->name.value(); } frame->name = name; diff --git a/src/script/printer/doc.cc b/src/script/printer/doc.cc index 5137266fa48b..23a2e94a7faa 100644 --- a/src/script/printer/doc.cc +++ b/src/script/printer/doc.cc @@ -79,7 +79,7 @@ StmtBlockDoc::StmtBlockDoc(Array stmts) { this->data_ = std::move(n); } -LiteralDoc::LiteralDoc(ObjectRef value, const Optional& object_path) { +LiteralDoc::LiteralDoc(ffi::Any value, const Optional& object_path) { ObjectPtr n = make_object(); n->value = value; if (object_path.defined()) { diff --git a/src/script/printer/doc_printer/python_doc_printer.cc b/src/script/printer/doc_printer/python_doc_printer.cc index a6b8a8db096c..8c352298c1a5 100644 --- a/src/script/printer/doc_printer/python_doc_printer.cc +++ b/src/script/printer/doc_printer/python_doc_printer.cc @@ -252,7 +252,7 @@ class PythonDocPrinter : public DocPrinter { } void MaybePrintCommentInline(const StmtDoc& stmt) { - if (stmt->comment.defined()) { + if (stmt->comment.has_value()) { const std::string& comment = stmt->comment.value(); bool has_newline = std::find(comment.begin(), comment.end(), '\n') != comment.end(); CHECK(!has_newline) << "ValueError: the comment string of " << stmt->GetTypeKey() @@ -265,7 +265,7 @@ class PythonDocPrinter : public DocPrinter { } void MaybePrintCommenMultiLines(const StmtDoc& stmt, bool new_line = false) { - if (stmt->comment.defined()) { + if (stmt->comment.has_value()) { std::vector comment_lines = support::Split(stmt->comment.value(), '\n'); bool first_line = true; size_t start_pos = output_.tellp(); @@ -313,8 +313,8 @@ class PythonDocPrinter : public DocPrinter { }; void PythonDocPrinter::PrintTypedDoc(const LiteralDoc& doc) { - const ObjectRef& value = doc->value; - if (!value.defined()) { + const ffi::Any& value = doc->value; + if (value == nullptr) { output_ << "None"; } else if (const auto* int_imm = value.as()) { if (int_imm->dtype.is_bool()) { @@ -354,7 +354,7 @@ void PythonDocPrinter::PrintTypedDoc(const LiteralDoc& doc) { } else if (const auto* string_obj = value.as()) { output_ << "\"" << support::StrEscape(string_obj->data, string_obj->size) << "\""; } else { - LOG(FATAL) << "TypeError: Unsupported literal value type: " << value->GetTypeKey(); + LOG(FATAL) << "TypeError: Unsupported literal value type: " << value.GetTypeKey(); } } @@ -682,7 +682,7 @@ void PythonDocPrinter::PrintTypedDoc(const FunctionDoc& doc) { output_ << ":"; - if (doc->comment.defined()) { + if (doc->comment.has_value()) { PrintBlockComment(doc->comment.value()); } PrintIndentedBlock(doc->body); @@ -696,20 +696,20 @@ void PythonDocPrinter::PrintTypedDoc(const ClassDoc& doc) { PrintDoc(doc->name); output_ << ":"; - if (doc->comment.defined()) { + if (doc->comment.has_value()) { PrintBlockComment(doc->comment.value()); } PrintIndentedBlock(doc->body); } void PythonDocPrinter::PrintTypedDoc(const CommentDoc& doc) { - if (doc->comment.defined()) { + if (doc->comment.has_value()) { MaybePrintCommenMultiLines(doc, false); } } void PythonDocPrinter::PrintTypedDoc(const DocStringDoc& doc) { - if (doc->comment.defined() && !doc->comment.value().empty()) { + if (doc->comment.has_value() && !doc->comment.value().empty()) { PrintDocString(doc->comment.value()); } } diff --git a/src/script/printer/ir/misc.cc b/src/script/printer/ir/misc.cc index caa5cbe895bd..8288016d3ecd 100644 --- a/src/script/printer/ir/misc.cc +++ b/src/script/printer/ir/misc.cc @@ -22,14 +22,6 @@ namespace tvm { namespace script { namespace printer { -TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable) - .set_dispatch("", [](String s, ObjectPath p, IRDocsifier d) -> Doc { - if (HasMultipleLines(s)) { - return d->AddMetadata(s); - } - return LiteralDoc::Str(s, p); - }); - TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable) .set_dispatch>( // "", [](Array array, ObjectPath p, IRDocsifier d) -> Doc { diff --git a/src/script/printer/ir_docsifier.cc b/src/script/printer/ir_docsifier.cc index 0eb5c951e567..33c0076c3044 100644 --- a/src/script/printer/ir_docsifier.cc +++ b/src/script/printer/ir_docsifier.cc @@ -81,11 +81,13 @@ Optional IRDocsifierNode::GetVarDoc(const ObjectRef& obj) const { return it->second.creator(); } -ExprDoc IRDocsifierNode::AddMetadata(const ObjectRef& obj) { - ICHECK(obj.defined()) << "TypeError: Cannot add nullptr to metadata"; - String key = obj->GetTypeKey(); - Array& array = metadata[key]; - int index = std::find(array.begin(), array.end(), obj) - array.begin(); +ExprDoc IRDocsifierNode::AddMetadata(const ffi::Any& obj) { + ICHECK(obj != nullptr) << "TypeError: Cannot add nullptr to metadata"; + String key = obj.GetTypeKey(); + Array& array = metadata[key]; + int index = std::find_if(array.begin(), array.end(), + [&](const ffi::Any& a) { return ffi::AnyEqual()(a, obj); }) - + array.begin(); if (index == static_cast(array.size())) { array.push_back(obj); } @@ -104,7 +106,7 @@ bool IRDocsifierNode::IsVarDefined(const ObjectRef& obj) const { return obj2info void IRDocsifierNode::RemoveVar(const ObjectRef& obj) { auto it = obj2info.find(obj); ICHECK(it != obj2info.end()) << "No such object: " << obj; - if (it->second.name.defined()) { + if (it->second.name.has_value()) { defined_names.erase(it->second.name.value()); } obj2info.erase(it); diff --git a/src/script/printer/tir/expr.cc b/src/script/printer/tir/expr.cc index d2f02f7908b9..d0b14753cc16 100644 --- a/src/script/printer/tir/expr.cc +++ b/src/script/printer/tir/expr.cc @@ -326,8 +326,8 @@ TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable) bool IsNumber(const ExprDoc& e) { if (const auto* n = e.as()) { - if (n->value.defined()) { - return n->value->IsInstance() || n->value->IsInstance(); + if (n->value != nullptr) { + return n->value.as() || n->value.as(); } } return false; diff --git a/src/script/printer/tir/stmt.cc b/src/script/printer/tir/stmt.cc index 239d9f67216f..50756bceb706 100644 --- a/src/script/printer/tir/stmt.cc +++ b/src/script/printer/tir/stmt.cc @@ -415,7 +415,8 @@ TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable) ObjectPath body_p = stmt_p->Attr("body"); if (stmt->attr_key == "realize_scope") { if (const auto* realize = stmt->body.as()) { - if (realize->buffer.same_as(stmt->node)) { + // TODO(tqchen): add any.same_as(ObjectRef) + if (realize->buffer.same_as(stmt->node.cast())) { rhs = DocsifyBufferRealize( realize, /*value=*/d->AsDoc(stmt->value, stmt_p->Attr("value")), @@ -426,7 +427,7 @@ TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable) } } if (stmt->attr_key == "thread_extent" || stmt->attr_key == "virtual_thread") { - if (stmt->node->IsInstance()) { + if (stmt->node.as()) { rhs = DocsifyLaunchThread(stmt, stmt_p, &define_var, d); } } diff --git a/src/support/ffi_testing.cc b/src/support/ffi_testing.cc index 12b08b209b62..0becec1f3ff6 100644 --- a/src/support/ffi_testing.cc +++ b/src/support/ffi_testing.cc @@ -184,9 +184,9 @@ TVM_FFI_STATIC_INIT_BLOCK({ .def("testing.AcceptsVariant", [](Variant arg) -> String { if (auto opt_str = arg.as()) { - return opt_str.value()->GetTypeKey(); + return ffi::StringObj::_type_key; } else { - return arg.get()->GetTypeKey(); + return arg.get().GetTypeKey(); } }) .def("testing.AcceptsBool", [](bool arg) -> bool { return arg; }) diff --git a/src/target/llvm/codegen_cpu.cc b/src/target/llvm/codegen_cpu.cc index b04d71da319c..b85b51e3d2bb 100644 --- a/src/target/llvm/codegen_cpu.cc +++ b/src/target/llvm/codegen_cpu.cc @@ -143,7 +143,7 @@ void CodeGenCPU::Init(const std::string& module_name, LLVMTarget* llvm_target, t_void_p_, t_int_}, false); // initialize TVM runtime API - if (system_lib_prefix_.defined() && !target_c_runtime) { + if (system_lib_prefix_.has_value() && !target_c_runtime) { // We will need this in environment for backward registration. // Defined in include/tvm/runtime/c_backend_api.h: // int TVMBackendRegisterSystemLibSymbol(const char* name, void* ptr); @@ -153,7 +153,7 @@ void CodeGenCPU::Init(const std::string& module_name, LLVMTarget* llvm_target, } else { f_tvm_register_system_symbol_ = nullptr; } - if (dynamic_lookup || system_lib_prefix_.defined()) { + if (dynamic_lookup || system_lib_prefix_.has_value()) { f_tvm_ffi_func_call_ = llvm::Function::Create(ftype_tvm_ffi_func_call_, llvm::Function::ExternalLinkage, "TVMFFIFunctionCall", module_.get()); diff --git a/src/target/llvm/codegen_hexagon.cc b/src/target/llvm/codegen_hexagon.cc index f0f1797a6c5f..6f90da3d8aea 100644 --- a/src/target/llvm/codegen_hexagon.cc +++ b/src/target/llvm/codegen_hexagon.cc @@ -495,7 +495,7 @@ runtime::Module BuildHexagon(IRModule mod, Target target) { auto f = Downcast(kv.second); if (f->HasNonzeroAttr(tir::attr::kIsEntryFunc)) { auto global_symbol = f->GetAttr(tvm::attr::kGlobalSymbol); - ICHECK(global_symbol.defined()); + ICHECK(global_symbol.has_value()); entry_func = global_symbol.value(); } } diff --git a/src/target/llvm/codegen_nvptx.cc b/src/target/llvm/codegen_nvptx.cc index ed636b82a2ed..bcea45cfa70e 100644 --- a/src/target/llvm/codegen_nvptx.cc +++ b/src/target/llvm/codegen_nvptx.cc @@ -313,7 +313,7 @@ llvm::Value* CodeGenNVPTX::CreateIntrinsic(const CallNode* op) { int GetCUDAComputeVersion(const Target& target) { Optional mcpu = target->GetAttr("mcpu"); - ICHECK(mcpu.defined()) << "InternalError: \"-mcpu\" is undefined in the NVPTX target"; + ICHECK(mcpu.has_value()) << "InternalError: \"-mcpu\" is undefined in the NVPTX target"; std::string sm_version = mcpu.value(); return std::stoi(sm_version.substr(3)); } diff --git a/src/target/llvm/llvm_module.cc b/src/target/llvm/llvm_module.cc index b3e5249c025c..924f520082fd 100644 --- a/src/target/llvm/llvm_module.cc +++ b/src/target/llvm/llvm_module.cc @@ -352,7 +352,7 @@ void LLVMModuleNode::Init(const IRModule& mod, const Target& target) { // ICHECK(funcs.size() > 0); // TODO(tqchen): remove the entry function behavior as it does not // makes sense when we start to use multiple modules. - cg->Init("TVMMod", llvm_target.get(), system_lib_prefix, system_lib_prefix.defined(), false); + cg->Init("TVMMod", llvm_target.get(), system_lib_prefix, system_lib_prefix.has_value(), false); cg->SetFastMathFlags(llvm_target->GetFastMathFlags()); cg->AddFunctionsOrdered(mod->functions.begin(), mod->functions.end()); if (entry_func.length() != 0) { diff --git a/src/target/parsers/cpu.cc b/src/target/parsers/cpu.cc index c28d80133f97..ee9bf814d323 100644 --- a/src/target/parsers/cpu.cc +++ b/src/target/parsers/cpu.cc @@ -44,7 +44,7 @@ TargetJSON ParseTarget(TargetJSON target) { Optional mcpu = Downcast>(target.Get("mcpu")); // Try to fill in the blanks by detecting target information from the system - if (kind == "llvm" && !mtriple.defined() && !mcpu.defined()) { + if (kind == "llvm" && !mtriple.has_value() && !mcpu.has_value()) { String system_triple = DetectSystemTriple().value_or(""); target.Set("mtriple", system_triple); } diff --git a/src/target/source/codegen_c_host.cc b/src/target/source/codegen_c_host.cc index 233353412286..2e808738ef4c 100644 --- a/src/target/source/codegen_c_host.cc +++ b/src/target/source/codegen_c_host.cc @@ -73,7 +73,7 @@ void CodeGenCHost::AddFunction(const GlobalVar& gvar, const PrimFunc& func, emit_fwd_func_decl_ = emit_fwd_func_decl; CodeGenC::AddFunction(gvar, func); if (func->HasNonzeroAttr(tir::attr::kIsEntryFunc)) { - ICHECK(global_symbol.defined()) + ICHECK(global_symbol.has_value()) << "CodeGenCHost: The entry func must have the global_symbol attribute, " << "but function " << gvar << " only has attributes " << func->attrs; diff --git a/src/target/source/codegen_metal.cc b/src/target/source/codegen_metal.cc index 962bd777f1d2..3cd4a6ed0d81 100644 --- a/src/target/source/codegen_metal.cc +++ b/src/target/source/codegen_metal.cc @@ -78,7 +78,7 @@ void CodeGenMetal::AddFunction(const GlobalVar& gvar, const PrimFunc& func) { // add to alloc buffer type. auto global_symbol = func->GetAttr(tvm::attr::kGlobalSymbol); - ICHECK(global_symbol.defined()) + ICHECK(global_symbol.has_value()) << "CodeGenC: Expect PrimFunc to have the global_symbol attribute"; // Function header. @@ -443,7 +443,7 @@ runtime::Module BuildMetal(IRModule mod, Target target) { for (auto kv : mod->functions) { ICHECK(kv.second->IsInstance()) << "CodeGenMetal: Can only take PrimFunc"; auto global_symbol = kv.second->GetAttr(tvm::attr::kGlobalSymbol); - ICHECK(global_symbol.defined()); + ICHECK(global_symbol.has_value()); std::string func_name = global_symbol.value(); source_maker << "// Function: " << func_name << "\n"; diff --git a/src/target/source/codegen_webgpu.cc b/src/target/source/codegen_webgpu.cc index 31f989901686..f5bfd80fee25 100644 --- a/src/target/source/codegen_webgpu.cc +++ b/src/target/source/codegen_webgpu.cc @@ -137,7 +137,7 @@ runtime::FunctionInfo CodeGenWebGPU::AddFunction(const PrimFunc& f, bool skip_re // add to alloc buffer type. auto global_symbol = f->GetAttr(tvm::attr::kGlobalSymbol); - ICHECK(global_symbol.defined()) + ICHECK(global_symbol.has_value()) << "CodeGenWebGPU: Expect PrimFunc to have the global_symbol attribute"; header_stream << "//----------------------------------------\n" @@ -767,7 +767,7 @@ runtime::Module BuildWebGPU(IRModule mod, Target target) { ICHECK(calling_conv == CallingConv::kDeviceKernelLaunch) << "CodeGenWebGPU: expect calling_conv equals CallingConv::kDeviceKernelLaunch"; auto global_symbol = f->GetAttr(tvm::attr::kGlobalSymbol); - ICHECK(global_symbol.defined()) + ICHECK(global_symbol.has_value()) << "CodeGenWebGPU: Expect PrimFunc to have the global_symbol attribute"; std::string f_name = global_symbol.value(); cg.Init(output_ssa); diff --git a/src/target/spirv/spirv_utils.cc b/src/target/spirv/spirv_utils.cc index 6afd087e5d85..f0226466f625 100644 --- a/src/target/spirv/spirv_utils.cc +++ b/src/target/spirv/spirv_utils.cc @@ -130,7 +130,7 @@ std::pair, std::string> Lo ICHECK(calling_conv == CallingConv::kDeviceKernelLaunch) << "CodeGenSPIRV: expect calling_conv equals CallingConv::kDeviceKernelLaunch"; auto global_symbol = f->GetAttr(tvm::attr::kGlobalSymbol); - ICHECK(global_symbol.defined()) + ICHECK(global_symbol.has_value()) << "CodeGenSPIRV: Expect PrimFunc to have the global_symbol attribute"; std::string f_name = global_symbol.value(); diff --git a/src/tir/ir/data_type_rewriter.cc b/src/tir/ir/data_type_rewriter.cc index f65566109f86..b0457a12398a 100644 --- a/src/tir/ir/data_type_rewriter.cc +++ b/src/tir/ir/data_type_rewriter.cc @@ -91,7 +91,7 @@ Stmt DataTypeLegalizer::VisitStmt_(const AttrStmtNode* op) { << ", but get " << s->GetTypeKey(); const IterVarNode* iv = op->node.as(); ICHECK(iv != nullptr) << "Expected type to be IterVarNode" - << ", but get " << op->node->GetTypeKey(); + << ", but get " << op->node.GetTypeKey(); PrimExpr e = VisitExpr(iv->var); Var var = Downcast(e); if (ivmap_.find(iv) == ivmap_.end()) { diff --git a/src/tir/ir/index_map.cc b/src/tir/ir/index_map.cc index 3b91c5e84b53..7b3c951587dd 100644 --- a/src/tir/ir/index_map.cc +++ b/src/tir/ir/index_map.cc @@ -329,7 +329,7 @@ IndexMap IndexMap::RenameVariables( } visited.emplace(obj.get()); Var var = Downcast(obj); - if (Optional opt_name = f_name_map(var); opt_name.defined()) { + if (Optional opt_name = f_name_map(var); opt_name.has_value()) { String name = opt_name.value(); ICHECK(!name_supply->ContainsName(name, /*add_prefix=*/false)); name_supply->ReserveName(name, /*add_prefix=*/false); diff --git a/src/tir/ir/stmt.cc b/src/tir/ir/stmt.cc index 17c763c6e4be..6803e01f50ba 100644 --- a/src/tir/ir/stmt.cc +++ b/src/tir/ir/stmt.cc @@ -84,7 +84,7 @@ TVM_FFI_STATIC_INIT_BLOCK({ TVM_REGISTER_NODE_TYPE(LetStmtNode); // AttrStmt -AttrStmt::AttrStmt(ObjectRef node, String attr_key, PrimExpr value, Stmt body, Span span) { +AttrStmt::AttrStmt(ffi::Any node, String attr_key, PrimExpr value, Stmt body, Span span) { auto n = make_object(); n->node = node; n->attr_key = std::move(attr_key); diff --git a/src/tir/ir/tir_visitor_with_path.cc b/src/tir/ir/tir_visitor_with_path.cc index 78cfd004dd4d..1dbbe75528d7 100644 --- a/src/tir/ir/tir_visitor_with_path.cc +++ b/src/tir/ir/tir_visitor_with_path.cc @@ -40,7 +40,7 @@ void TIRVisitorWithPath::Visit(const IRModule& mod, ObjectPath path) { std::unordered_set externally_exposed; for (const auto& [gvar, func] : mod->functions) { gvars.push_back(gvar); - if (func->GetAttr(tvm::attr::kGlobalSymbol).defined()) { + if (func->GetAttr(tvm::attr::kGlobalSymbol).has_value()) { externally_exposed.insert(gvar); } } diff --git a/src/tir/schedule/analysis/analysis.cc b/src/tir/schedule/analysis/analysis.cc index 691ce8ebd162..3aacfa15832b 100644 --- a/src/tir/schedule/analysis/analysis.cc +++ b/src/tir/schedule/analysis/analysis.cc @@ -949,7 +949,7 @@ StmtSRef GetSRefLowestCommonAncestor(const Array& srefs) { } bool HasBeenMultiLevelTiled(const StmtSRef& block_sref) { - return tir::GetAnn(block_sref, tir::attr::meta_schedule_tiling_structure).defined(); + return tir::GetAnn(block_sref, tir::attr::meta_schedule_tiling_structure).has_value(); } std::pair, std::vector> CollectComputeLocation(const ScheduleState& self, diff --git a/src/tir/schedule/concrete_schedule.cc b/src/tir/schedule/concrete_schedule.cc index 0b8aeec82c1f..c00c946852a5 100644 --- a/src/tir/schedule/concrete_schedule.cc +++ b/src/tir/schedule/concrete_schedule.cc @@ -311,9 +311,9 @@ BlockRV ConcreteScheduleNode::GetBlock(const String& name, const Optional blocks_; }; GlobalVar gv = NullValue(); - if (func_name.defined()) { + if (func_name.has_value()) { gv = state_->mod->GetGlobalVar(func_name.value()); - } else if (func_working_on_.defined()) { + } else if (func_working_on_.has_value()) { gv = this->func_working_on_.value(); } else { LOG(FATAL) << "ValueError: `get_block` does not know which function to be working on. Please " diff --git a/src/tir/schedule/instruction_traits.h b/src/tir/schedule/instruction_traits.h index cbd5185ff8f1..5507c02bfe73 100644 --- a/src/tir/schedule/instruction_traits.h +++ b/src/tir/schedule/instruction_traits.h @@ -541,7 +541,7 @@ void PythonAPICall::OutputList(Array outputs) { String PythonAPICall::Str() const { std::ostringstream os; - if (output_.defined()) { + if (output_.has_value()) { os << output_.value() << " = "; } os << "sch." << method_name_ << '('; diff --git a/src/tir/schedule/primitive/for_kind.cc b/src/tir/schedule/primitive/for_kind.cc index f1e035d92ef7..6dd1eafcc076 100644 --- a/src/tir/schedule/primitive/for_kind.cc +++ b/src/tir/schedule/primitive/for_kind.cc @@ -164,13 +164,13 @@ void ParallelizeComputation(const ScheduleState& self, const StmtSRef& loop_sref // Step 2. Check whether the loop can be parallelized/vectorized/bound with regard to each // underlying block. CheckParallelizability(self, GetRef(loop), for_kind, - thread_axis.defined() ? runtime::ThreadScope::Create(thread_axis.value()) - : runtime::ThreadScope{-1, -1}); + thread_axis.has_value() ? runtime::ThreadScope::Create(thread_axis.value()) + : runtime::ThreadScope{-1, -1}); // Step 3. Loop update and IR replacement ObjectPtr new_loop = make_object(*loop); new_loop->kind = for_kind; - if (thread_axis.defined()) { + if (thread_axis.has_value()) { new_loop->thread_binding = IterVar(/*dom=*/Range(nullptr), // /*var=*/Var(thread_axis.value(), loop->loop_var.dtype()), // /*iter_type=*/kThreadIndex, // diff --git a/src/tir/schedule/trace.cc b/src/tir/schedule/trace.cc index 8160035e3d23..6efb17de25aa 100644 --- a/src/tir/schedule/trace.cc +++ b/src/tir/schedule/trace.cc @@ -240,7 +240,7 @@ Array TranslateAddOutputRVs( ICHECK(!rv_names->count(output.cast())) << "ValueError: The random variable has been produced once: " << rv_names->at(output.cast()); - String result{ffi::ObjectPtr{nullptr}}; + String result; if (output == nullptr) { result = "_"; } else if (output.as()) { @@ -320,8 +320,8 @@ void TraceNode::ApplyToSchedule( ObjectRef TraceNode::AsJSON(bool remove_postproc) const { std::unordered_map rv_names; - Array json_insts; - Array json_decisions; + Array json_insts; + Array json_decisions; json_insts.reserve(this->insts.size()); json_decisions.reserve(this->insts.size()); @@ -331,7 +331,7 @@ ObjectRef TraceNode::AsJSON(bool remove_postproc) const { if (remove_postproc && kind->IsPostproc()) { break; } - json_insts.push_back(Array{ + json_insts.push_back(Array{ /* 0: inst name */ kind->name, /* 1: inputs */ TranslateInputRVs(inst->inputs, rv_names), /* 2: attrs */ kind->f_attrs_as_json != nullptr ? kind->f_attrs_as_json(inst->attrs) @@ -346,7 +346,7 @@ ObjectRef TraceNode::AsJSON(bool remove_postproc) const { } ++i; } - return Array{ + return Array{ /* 0: trace */ std::move(json_insts), /* 1: decision */ std::move(json_decisions), }; diff --git a/src/tir/schedule/traced_schedule.cc b/src/tir/schedule/traced_schedule.cc index d3e77e0e3b84..b9718c1a5f9c 100644 --- a/src/tir/schedule/traced_schedule.cc +++ b/src/tir/schedule/traced_schedule.cc @@ -118,7 +118,7 @@ LoopRV TracedScheduleNode::SampleComputeLocation(const BlockRV& block_rv, BlockRV TracedScheduleNode::GetBlock(const String& name, const Optional& func_name) { GlobalVar gv = NullValue(); - if (func_name.defined()) { + if (func_name.has_value()) { gv = state_->mod->GetGlobalVar(func_name.value()); } else if (func_working_on_.defined()) { gv = this->func_working_on_.value(); diff --git a/src/tir/schedule/utils.h b/src/tir/schedule/utils.h index deedfd6f68dc..0c35c5f043a2 100644 --- a/src/tir/schedule/utils.h +++ b/src/tir/schedule/utils.h @@ -290,7 +290,7 @@ inline Optional GetAnn(const StmtSRef& sref, const String& ann_key) */ inline bool HasAnn(const StmtSRef& sref, const String& ann_key, const String& ann_val) { Optional result = GetAnn(sref, ann_key); - return result.defined() && result.value() == ann_val; + return result.has_value() && result.value() == ann_val; } /*! diff --git a/src/tir/transforms/bind_target.cc b/src/tir/transforms/bind_target.cc index 281249f4add8..46a40228eaa1 100644 --- a/src/tir/transforms/bind_target.cc +++ b/src/tir/transforms/bind_target.cc @@ -71,7 +71,7 @@ class FunctionClassifierVisitor : public StmtExprVisitor { // Only analyze externally exposed functions as potential callers // since they represent the entry points where host/device calls originate for (const auto& [gvar, func] : mod->functions) { - bool is_externally_exposed = func->GetAttr(tvm::attr::kGlobalSymbol).defined(); + bool is_externally_exposed = func->GetAttr(tvm::attr::kGlobalSymbol).has_value(); const auto* prim_func = func.as(); if (is_externally_exposed && prim_func != nullptr) { @@ -268,7 +268,7 @@ IRModule BindTarget(IRModule mod, const Target& target) { } auto prim_func = GetRef(prim_func_node); - bool is_externally_exposed = prim_func->GetAttr(tvm::attr::kGlobalSymbol).defined(); + bool is_externally_exposed = prim_func->GetAttr(tvm::attr::kGlobalSymbol).has_value(); if (auto func_target = func->GetAttr(tvm::attr::kTarget)) { // Rule 1: If the function has a target, and the target has a host, and the function does not @@ -341,7 +341,7 @@ IRModule BindTarget(IRModule mod, const Target& target) { continue; } - bool is_externally_exposed = prim_func->GetAttr(tvm::attr::kGlobalSymbol).defined(); + bool is_externally_exposed = prim_func->GetAttr(tvm::attr::kGlobalSymbol).has_value(); if (is_externally_exposed) { // Update calls in externally exposed functions to use host duplicates PrimFunc new_func = substitutor.Substitute(Downcast(func)); diff --git a/src/tir/transforms/compact_buffer_region.cc b/src/tir/transforms/compact_buffer_region.cc index 6f5a496d1f4a..a1e99313b663 100644 --- a/src/tir/transforms/compact_buffer_region.cc +++ b/src/tir/transforms/compact_buffer_region.cc @@ -415,7 +415,6 @@ class BufferAccessRegionCollector : public StmtExprVisitor { if (iter->iter_type != IterVarType::kThreadIndex) { return false; } - ICHECK(iter->thread_tag.defined()); // When there is warp memory // threadIdx.x must be set to be warp index. return CanRelaxStorageUnderThread(scope, runtime::ThreadScope::Create((iter->thread_tag))); diff --git a/src/tir/transforms/inject_permuted_layout.cc b/src/tir/transforms/inject_permuted_layout.cc index 02bdfcbfedc3..f90752e26418 100644 --- a/src/tir/transforms/inject_permuted_layout.cc +++ b/src/tir/transforms/inject_permuted_layout.cc @@ -104,9 +104,9 @@ class PermutedLayoutInjector : private IRMutatorWithAnalyzer { } static bool CheckAnnotation(const Any& annotation) { - if (auto* node = annotation.as()) { + if (auto opt_str = annotation.as()) { // Support string annotation for backward compatibility - return GetRef(node) != ""; + return *opt_str != ""; } else if (auto* node = annotation.as()) { return node->value != 0; } else if (auto opt_val = annotation.try_cast()) { diff --git a/src/tir/transforms/inline_private_functions.cc b/src/tir/transforms/inline_private_functions.cc index 9e87ffe5b2e3..8521607f893e 100644 --- a/src/tir/transforms/inline_private_functions.cc +++ b/src/tir/transforms/inline_private_functions.cc @@ -103,7 +103,7 @@ bool IsInlinablePrimFunc(const GlobalVar& gvar, const PrimFunc& prim_func, // Only inline private functions. Externally-exposed functions // must be preserved so to avoid breaking callsites outside of // the IRModule. - bool is_exposed = prim_func->GetAttr(tvm::attr::kGlobalSymbol).defined(); + bool is_exposed = prim_func->GetAttr(tvm::attr::kGlobalSymbol).has_value(); if (is_exposed) return false; // We do not currently implement any analysis for termination of diff --git a/src/tir/transforms/loop_partition.cc b/src/tir/transforms/loop_partition.cc index b3825908b79c..d29c380b35e7 100644 --- a/src/tir/transforms/loop_partition.cc +++ b/src/tir/transforms/loop_partition.cc @@ -155,9 +155,9 @@ class CandidateSelector final : public StmtExprVisitor { } else if (op->attr_key == attr::pragma_loop_partition_hint) { if (analyzer_.CanProve(op->value)) { const VarNode* var = nullptr; - if (op->node->IsInstance()) { + if (op->node.as()) { var = op->node.as(); - } else if (op->node->IsInstance()) { + } else if (op->node.as()) { var = op->node.as()->var.get(); } ICHECK(var); diff --git a/src/tir/transforms/make_packed_api.cc b/src/tir/transforms/make_packed_api.cc index 9f0228dc16d9..d95a02a0ba9c 100644 --- a/src/tir/transforms/make_packed_api.cc +++ b/src/tir/transforms/make_packed_api.cc @@ -187,7 +187,7 @@ Optional RequiresPackedAPI(const PrimFunc& func) { // Internal function calls do not need the ffi::Function API auto global_symbol = func->GetAttr(tvm::attr::kGlobalSymbol); - if (!global_symbol.defined()) { + if (!global_symbol.has_value()) { return std::nullopt; } @@ -196,7 +196,7 @@ Optional RequiresPackedAPI(const PrimFunc& func) { PrimFunc MakePackedAPI(PrimFunc func) { auto global_symbol = RequiresPackedAPI(func); - if (!global_symbol.defined()) { + if (!global_symbol.has_value()) { return func; } std::string name_hint = global_symbol.value(); @@ -365,7 +365,7 @@ PrimFunc MakePackedAPI(PrimFunc func) { StringImm(name_hint + "_compute_"), body); // Set device context if (vmap.count(device_id.get())) { - ObjectRef node = String("default"); + ffi::Any node = ffi::String("default"); seq_check.push_back(AttrStmt(node, attr::device_id, device_id, nop)); seq_check.push_back(AttrStmt(node, attr::device_type, device_type, nop)); diff --git a/src/tir/transforms/make_unpacked_api.cc b/src/tir/transforms/make_unpacked_api.cc index 898e7819062f..8276d26fcfa8 100644 --- a/src/tir/transforms/make_unpacked_api.cc +++ b/src/tir/transforms/make_unpacked_api.cc @@ -103,7 +103,7 @@ PrimFunc MakeUnpackedAPI(PrimFunc func) { // Internal function calls do not need API updates auto global_symbol = func->GetAttr(tvm::attr::kGlobalSymbol); - if (!global_symbol.defined()) { + if (!global_symbol.has_value()) { return func; } @@ -128,7 +128,7 @@ PrimFunc MakeUnpackedAPI(PrimFunc func) { // Setup device context Integer device_type(target_device_type); Integer device_id(0); - ObjectRef node = String("default"); + ffi::Any node = ffi::String("default"); const Stmt nop = Evaluate(0); std::vector device_init; diff --git a/src/tir/transforms/primfunc_utils.cc b/src/tir/transforms/primfunc_utils.cc index 274199a6c4fd..b1f3476eab73 100644 --- a/src/tir/transforms/primfunc_utils.cc +++ b/src/tir/transforms/primfunc_utils.cc @@ -47,7 +47,7 @@ transform::Pass AnnotateEntryFunc() { bool has_external_non_primfuncs = false; IRModule with_annotations; for (const auto& [gvar, base_func] : mod->functions) { - bool is_external = base_func->GetAttr(tvm::attr::kGlobalSymbol).defined(); + bool is_external = base_func->GetAttr(tvm::attr::kGlobalSymbol).has_value(); if (is_external) { if (auto ptr = base_func.as()) { with_annotations->Add(gvar,