diff --git a/ffi/include/tvm/ffi/any.h b/ffi/include/tvm/ffi/any.h index 55eff8802a3b..ed34328d1e67 100644 --- a/ffi/include/tvm/ffi/any.h +++ b/ffi/include/tvm/ffi/any.h @@ -635,7 +635,6 @@ struct AnyEqual { } } }; - } // namespace ffi // Expose to the tvm namespace for usability diff --git a/ffi/include/tvm/ffi/cast.h b/ffi/include/tvm/ffi/cast.h index 997c0bb17888..c75d4a075f97 100644 --- a/ffi/include/tvm/ffi/cast.h +++ b/ffi/include/tvm/ffi/cast.h @@ -18,21 +18,18 @@ */ /*! * \file tvm/ffi/cast.h - * \brief Value casting support + * \brief Extra value casting helpers */ #ifndef TVM_FFI_CAST_H_ #define TVM_FFI_CAST_H_ #include -#include -#include #include #include -#include - namespace tvm { namespace ffi { + /*! * \brief Get a reference type from a raw object ptr type * @@ -46,7 +43,7 @@ namespace ffi { * \return The corresponding RefType */ template -TVM_FFI_INLINE RefType GetRef(const ObjectType* ptr) { +inline RefType GetRef(const ObjectType* ptr) { static_assert(std::is_base_of_v, "Can only cast to the ref of same container type"); @@ -75,92 +72,9 @@ inline ObjectPtr GetObjectPtr(ObjectType* ptr) { "Can only cast to the ref of same container type"); return details::ObjectUnsafe::ObjectPtrFromUnowned(ptr); } - -/*! - * \brief Downcast a base reference type to a more specific type. - * - * \param ref The input reference - * \return The corresponding SubRef. - * \tparam SubRef The target specific reference type. - * \tparam BaseRef the current reference type. - */ -template >> -inline SubRef Downcast(BaseRef ref) { - if (ref.defined()) { - if (!ref->template IsInstance()) { - TVM_FFI_THROW(TypeError) << "Downcast from " << ref->GetTypeKey() << " to " - << SubRef::ContainerType::_type_key << " failed."; - } - return SubRef(details::ObjectUnsafe::ObjectPtrFromObjectRef(std::move(ref))); - } else { - if constexpr (is_optional_type_v || SubRef::_type_is_nullable) { - return SubRef(ObjectPtr(nullptr)); - } - TVM_FFI_THROW(TypeError) << "Downcast from undefined(nullptr) to `" - << SubRef::ContainerType::_type_key - << "` is not allowed. Use Downcast> instead."; - TVM_FFI_UNREACHABLE(); - } -} - -/*! - * \brief Downcast any to a specific type - * - * \param ref The input reference - * \return The corresponding SubRef. - * \tparam T The target specific reference type. - */ -template -inline T Downcast(const Any& ref) { - if constexpr (std::is_same_v) { - return ref; - } else { - return ref.cast(); - } -} - -/*! - * \brief Downcast any to a specific type - * - * \param ref The input reference - * \return The corresponding SubRef. - * \tparam T The target specific reference type. - */ -template -inline T Downcast(Any&& ref) { - if constexpr (std::is_same_v) { - return std::move(ref); - } else { - return std::move(ref).cast(); - } -} - -/*! - * \brief Downcast std::optional to std::optional - * - * \param ref The input reference - * \return The corresponding SubRef. - * \tparam OptionalType The target optional type - */ -template >> -inline OptionalType Downcast(const std::optional& ref) { - if (ref.has_value()) { - if constexpr (std::is_same_v) { - return *ref; - } else { - return (*ref).cast(); - } - } else { - return OptionalType(std::nullopt); - } -} - } // namespace ffi -// Expose to the tvm namespace -// Rationale: convinience and no ambiguity -using ffi::Downcast; +using ffi::GetObjectPtr; using ffi::GetRef; } // namespace tvm #endif // TVM_FFI_CAST_H_ diff --git a/ffi/tests/cpp/test_string.cc b/ffi/tests/cpp/test_string.cc index 364f2f6540c6..8522aa93a3b9 100644 --- a/ffi/tests/cpp/test_string.cc +++ b/ffi/tests/cpp/test_string.cc @@ -18,7 +18,6 @@ */ #include #include -#include #include namespace { @@ -266,7 +265,7 @@ TEST(String, Cast) { string source = "this is a string"; String s{source}; Any r = s; - String s2 = Downcast(r); + String s2 = r.cast(); } TEST(String, Concat) { diff --git a/include/tvm/node/cast.h b/include/tvm/node/cast.h new file mode 100644 index 000000000000..ae23c9e9aa33 --- /dev/null +++ b/include/tvm/node/cast.h @@ -0,0 +1,117 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +/*! + * \file tvm/node/cast.h + * \brief Value casting helpers + */ +#ifndef TVM_NODE_CAST_H_ +#define TVM_NODE_CAST_H_ + +#include +#include +#include +#include +#include +#include + +#include + +namespace tvm { + +/*! + * \brief Downcast a base reference type to a more specific type. + * + * \param ref The input reference + * \return The corresponding SubRef. + * \tparam SubRef The target specific reference type. + * \tparam BaseRef the current reference type. + */ +template >> +inline SubRef Downcast(BaseRef ref) { + if (ref.defined()) { + if (!ref->template IsInstance()) { + TVM_FFI_THROW(TypeError) << "Downcast from " << ref->GetTypeKey() << " to " + << SubRef::ContainerType::_type_key << " failed."; + } + return SubRef(ffi::details::ObjectUnsafe::ObjectPtrFromObjectRef(std::move(ref))); + } else { + if constexpr (ffi::is_optional_type_v || SubRef::_type_is_nullable) { + return SubRef(ffi::ObjectPtr(nullptr)); + } + TVM_FFI_THROW(TypeError) << "Downcast from undefined(nullptr) to `" + << SubRef::ContainerType::_type_key + << "` is not allowed. Use Downcast> instead."; + TVM_FFI_UNREACHABLE(); + } +} + +/*! + * \brief Downcast any to a specific type + * + * \param ref The input reference + * \return The corresponding SubRef. + * \tparam T The target specific reference type. + */ +template +inline T Downcast(const ffi::Any& ref) { + if constexpr (std::is_same_v) { + return ref; + } else { + return ref.cast(); + } +} + +/*! + * \brief Downcast any to a specific type + * + * \param ref The input reference + * \return The corresponding SubRef. + * \tparam T The target specific reference type. + */ +template +inline T Downcast(ffi::Any&& ref) { + if constexpr (std::is_same_v) { + return std::move(ref); + } else { + return std::move(ref).cast(); + } +} + +/*! + * \brief Downcast std::optional to std::optional + * + * \param ref The input reference + * \return The corresponding SubRef. + * \tparam OptionalType The target optional type + */ +template >> +inline OptionalType Downcast(const std::optional& ref) { + if (ref.has_value()) { + if constexpr (std::is_same_v) { + return *ref; + } else { + return (*ref).cast(); + } + } else { + return OptionalType(std::nullopt); + } +} +} // namespace tvm +#endif // TVM_NODE_CAST_H_ diff --git a/include/tvm/node/node.h b/include/tvm/node/node.h index 4398f5881d90..734a28c13301 100644 --- a/include/tvm/node/node.h +++ b/include/tvm/node/node.h @@ -35,6 +35,7 @@ #define TVM_NODE_NODE_H_ #include +#include #include #include #include @@ -57,8 +58,6 @@ using ffi::ObjectPtrHash; using ffi::ObjectRef; using ffi::PackedArgs; using ffi::TypeIndex; -using runtime::Downcast; -using runtime::GetRef; } // namespace tvm #endif // TVM_NODE_NODE_H_ diff --git a/include/tvm/runtime/disco/session.h b/include/tvm/runtime/disco/session.h index 0c1ed7ca0aaf..4fe0e72e79c1 100644 --- a/include/tvm/runtime/disco/session.h +++ b/include/tvm/runtime/disco/session.h @@ -124,6 +124,8 @@ inline std::string DiscoAction2String(DiscoAction action) { LOG(FATAL) << "ValueError: Unknown DiscoAction: " << static_cast(action); } +class SessionObj; + /*! * \brief An object that exists on all workers. * @@ -156,6 +158,9 @@ class DRefObj : public Object { int64_t reg_id; /*! \brief Back-pointer to the host controler session */ ObjectRef session{nullptr}; + + private: + inline SessionObj* GetSession(); }; /*! @@ -321,18 +326,22 @@ class WorkerZeroData { // Implementation details +inline SessionObj* DRefObj::GetSession() { + return const_cast(static_cast(session.get())); +} + DRefObj::~DRefObj() { if (this->session.defined()) { - Downcast(this->session)->DeallocReg(reg_id); + GetSession()->DeallocReg(reg_id); } } ffi::Any DRefObj::DebugGetFromRemote(int worker_id) { - return Downcast(this->session)->DebugGetFromRemote(this->reg_id, worker_id); + return GetSession()->DebugGetFromRemote(this->reg_id, worker_id); } void DRefObj::DebugCopyFrom(int worker_id, ffi::AnyView value) { - return Downcast(this->session)->DebugSetRegister(this->reg_id, value, worker_id); + return GetSession()->DebugSetRegister(this->reg_id, value, worker_id); } template diff --git a/include/tvm/runtime/object.h b/include/tvm/runtime/object.h index 1fa9f248e812..302b161b6fd7 100644 --- a/include/tvm/runtime/object.h +++ b/include/tvm/runtime/object.h @@ -39,7 +39,6 @@ using tvm::ffi::ObjectPtrEqual; using tvm::ffi::ObjectPtrHash; using tvm::ffi::ObjectRef; -using tvm::ffi::Downcast; using tvm::ffi::GetObjectPtr; using tvm::ffi::GetRef; diff --git a/include/tvm/runtime/vm/vm.h b/include/tvm/runtime/vm/vm.h index 9aa34c9b4468..ed74ba7b7b2a 100644 --- a/include/tvm/runtime/vm/vm.h +++ b/include/tvm/runtime/vm/vm.h @@ -189,10 +189,12 @@ class VirtualMachine : public runtime::ModuleNode { using ContainerType = typename T::ContainerType; uint32_t key = ContainerType::RuntimeTypeIndex(); if (auto it = extensions.find(key); it != extensions.end()) { - return Downcast((*it).second); + ffi::Any value = (*it).second; + return value.cast(); } auto [it, _] = extensions.emplace(key, T::Create()); - return Downcast((*it).second); + ffi::Any value = (*it).second; + return value.cast(); } /*! @@ -224,7 +226,7 @@ class VirtualMachine : public runtime::ModuleNode { std::vector devices; /*! \brief The VM extensions. Mapping from the type index of the extension to the extension * instance. */ - std::unordered_map extensions; + std::unordered_map extensions; }; } // namespace vm diff --git a/src/node/container_printing.cc b/src/node/container_printing.cc index 7441db783296..b4773b2a816b 100644 --- a/src/node/container_printing.cc +++ b/src/node/container_printing.cc @@ -22,6 +22,7 @@ * \file node/container_printint.cc */ #include +#include #include #include @@ -62,6 +63,6 @@ TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) .set_dispatch([](const ObjectRef& node, ReprPrinter* p) { - p->stream << ffi::Downcast(node); + p->stream << Downcast(node); }); } // namespace tvm diff --git a/src/node/repr_printer.cc b/src/node/repr_printer.cc index 6a60b9723d3d..04a6f7533a19 100644 --- a/src/node/repr_printer.cc +++ b/src/node/repr_printer.cc @@ -23,6 +23,7 @@ */ #include #include +#include #include #include diff --git a/src/runtime/profiling.cc b/src/runtime/profiling.cc index d60030729fba..4cce0d40d168 100644 --- a/src/runtime/profiling.cc +++ b/src/runtime/profiling.cc @@ -479,15 +479,15 @@ String ReportNode::AsTable(bool sort, bool aggregate, bool compute_col_sums) con for (size_t i = 0; i < calls.size(); i++) { auto& frame = calls[i]; auto it = frame.find("Hash"); - std::string name = Downcast(frame["Name"]); + std::string name = frame["Name"].cast(); if (it != frame.end()) { - name = Downcast((*it).second); + name = (*it).second.cast(); } if (frame.find("Argument Shapes") != frame.end()) { - name += Downcast(frame["Argument Shapes"]); + name += frame["Argument Shapes"].cast(); } if (frame.find("Device") != frame.end()) { - name += Downcast(frame["Device"]); + name += frame["Device"].cast(); } if (aggregates.find(name) == aggregates.end()) { @@ -680,7 +680,7 @@ Report Profiler::Report() { for (size_t i = 0; i < devs_.size(); i++) { auto row = rows[rows.size() - 1]; rows.pop_back(); - device_metrics[Downcast(row["Device"])] = row; + device_metrics[row["Device"].cast()] = row; overall_time_us = std::max(overall_time_us, row["Duration (us)"].as()->microseconds); } diff --git a/src/runtime/vm/attn_backend.cc b/src/runtime/vm/attn_backend.cc index 04e5094d8e7f..c8fbd9082103 100644 --- a/src/runtime/vm/attn_backend.cc +++ b/src/runtime/vm/attn_backend.cc @@ -33,13 +33,13 @@ std::unique_ptr ConvertPagedPrefillFunc(Array args, String backend_name = args[0].cast(); if (backend_name == "tir") { CHECK_EQ(args.size(), 2); - ffi::Function attn_func = Downcast(args[1]); + ffi::Function attn_func = args[1].cast(); return std::make_unique(std::move(attn_func), attn_kind); } if (backend_name == "flashinfer") { CHECK_EQ(args.size(), 3); - ffi::Function attn_func = Downcast(args[1]); - ffi::Function plan_func = Downcast(args[2]); + ffi::Function attn_func = args[1].cast(); + ffi::Function plan_func = args[2].cast(); return std::make_unique(std::move(attn_func), std::move(plan_func), attn_kind); } @@ -55,13 +55,13 @@ std::unique_ptr ConvertRaggedPrefillFunc(Array args String backend_name = args[0].cast(); if (backend_name == "tir") { CHECK_EQ(args.size(), 2); - ffi::Function attn_func = Downcast(args[1]); + ffi::Function attn_func = args[1].cast(); return std::make_unique(std::move(attn_func), attn_kind); } if (backend_name == "flashinfer") { CHECK_EQ(args.size(), 3); - ffi::Function attn_func = Downcast(args[1]); - ffi::Function plan_func = Downcast(args[2]); + ffi::Function attn_func = args[1].cast(); + ffi::Function plan_func = args[2].cast(); return std::make_unique(std::move(attn_func), std::move(plan_func), attn_kind); } @@ -73,16 +73,16 @@ std::unique_ptr ConvertPagedDecodeFunc(Array args, At if (args.empty()) { return nullptr; } - String backend_name = Downcast(args[0]); + String backend_name = args[0].cast(); if (backend_name == "tir") { CHECK_EQ(args.size(), 2); - ffi::Function attn_func = Downcast(args[1]); + ffi::Function attn_func = args[1].cast(); return std::make_unique(std::move(attn_func), attn_kind); } if (backend_name == "flashinfer") { CHECK_EQ(args.size(), 3); - ffi::Function attn_func = Downcast(args[1]); - ffi::Function plan_func = Downcast(args[2]); + ffi::Function attn_func = args[1].cast(); + ffi::Function plan_func = args[2].cast(); return std::make_unique(std::move(attn_func), std::move(plan_func), attn_kind); } @@ -95,10 +95,10 @@ std::unique_ptr ConvertPagedPrefillTreeMaskFunc(Array< if (args.empty()) { return nullptr; } - String backend_name = Downcast(args[0]); + String backend_name = args[0].cast(); if (backend_name == "tir") { CHECK_EQ(args.size(), 2); - ffi::Function attn_func = Downcast(args[1]); + ffi::Function attn_func = args[1].cast(); return std::make_unique(std::move(attn_func), attn_kind); } LOG(FATAL) << "Cannot reach here"; @@ -110,10 +110,10 @@ std::unique_ptr ConvertRaggedPrefillTreeMaskFunc(Arra if (args.empty()) { return nullptr; } - String backend_name = Downcast(args[0]); + String backend_name = args[0].cast(); if (backend_name == "tir") { CHECK_EQ(args.size(), 2); - ffi::Function attn_func = Downcast(args[1]); + ffi::Function attn_func = args[1].cast(); return std::make_unique(std::move(attn_func), attn_kind); } LOG(FATAL) << "Cannot reach here"; diff --git a/src/runtime/vm/cuda/cuda_graph_builtin.cc b/src/runtime/vm/cuda/cuda_graph_builtin.cc index 8844517973bb..691246c3bf77 100644 --- a/src/runtime/vm/cuda/cuda_graph_builtin.cc +++ b/src/runtime/vm/cuda/cuda_graph_builtin.cc @@ -149,7 +149,7 @@ class CUDAGraphExtensionNode : public VMExtensionNode { * \param entry_index The unique index of the capture function used for lookup. * \return The return value of the capture function. */ - ObjectRef RunOrCapture(VirtualMachine* vm, const ObjectRef& capture_func, ObjectRef args, + ObjectRef RunOrCapture(VirtualMachine* vm, const ObjectRef& capture_func, Any args, int64_t entry_index, Optional shape_expr) { CUDAGraphCaptureKey entry_key{entry_index, shape_expr}; if (auto it = capture_cache_.find(entry_key); it != capture_cache_.end()) { @@ -160,7 +160,7 @@ class CUDAGraphExtensionNode : public VMExtensionNode { } // Set up arguments for the graph execution - Array tuple_args = Downcast>(args); + Array tuple_args = args.cast>(); int nargs = static_cast(tuple_args.size()); std::vector packed_args(nargs); @@ -250,7 +250,7 @@ TVM_FFI_STATIC_INIT_BLOCK({ VirtualMachine* vm = VirtualMachine::GetContextPtr(args[0]); auto extension = vm->GetOrCreateExtension(); auto capture_func = args[1].cast(); - auto func_args = args[2].cast(); + Any func_args = args[2]; int64_t entry_index = args[3].cast(); Optional shape_expr = std::nullopt; if (args.size() == 5) { diff --git a/src/runtime/vm/vm.cc b/src/runtime/vm/vm.cc index 17b99bbc8ce4..c28e30084fc1 100644 --- a/src/runtime/vm/vm.cc +++ b/src/runtime/vm/vm.cc @@ -74,7 +74,7 @@ ffi::Any IndexIntoNestedObject(ffi::Any obj, ffi::PackedArgs args, int starting_ LOG(FATAL) << "ValueError: Attempted to index into an object that is not an Array."; } int index = args[i].cast(); - auto arr = Downcast>(obj); + auto arr = obj.cast>(); // make sure the index is in bounds if (index >= static_cast(arr.size())) { LOG(FATAL) << "IndexError: Invalid index (" << index << " >= " << arr.size() << ")."; @@ -96,10 +96,10 @@ NDArray ConvertNDArrayToDevice(NDArray src, const DLDevice& dev, Allocator* allo Any ConvertObjectToDevice(Any src, const Device& dev, Allocator* alloc) { if (src.as()) { - return ConvertNDArrayToDevice(Downcast(src), dev, alloc); + return ConvertNDArrayToDevice(src.cast(), dev, alloc); } else if (src.as()) { std::vector ret; - auto arr = Downcast>(src); + auto arr = src.cast>(); for (size_t i = 0; i < arr.size(); i++) { ret.push_back(ConvertObjectToDevice(arr[i], dev, alloc)); } diff --git a/src/target/source/codegen_metal.cc b/src/target/source/codegen_metal.cc index 3cd4a6ed0d81..ffb1737a7063 100644 --- a/src/target/source/codegen_metal.cc +++ b/src/target/source/codegen_metal.cc @@ -370,7 +370,7 @@ void CodeGenMetal::VisitExpr_(const CallNode* op, std::ostream& os) { // NOLINT }; if (op->op.same_as(builtin::make_filled_simdgroup_matrix())) { ICHECK_EQ(op->args.size(), 5); - Var var = runtime::Downcast(op->args[0]); + Var var = Downcast(op->args[0]); // Get the data type of the simdgroup matrix auto it = simdgroup_dtype_.find(var.get()); ICHECK(it != simdgroup_dtype_.end()) diff --git a/src/tir/transforms/memhammer_coalesce.cc b/src/tir/transforms/memhammer_coalesce.cc index 2be5e148fbfe..43a976fa892f 100644 --- a/src/tir/transforms/memhammer_coalesce.cc +++ b/src/tir/transforms/memhammer_coalesce.cc @@ -204,9 +204,9 @@ Stmt InverseMapping::Rewrite(const Stmt& stmt, const ConstraintSet& constraints, if (is_one(write_region->region[i]->extent)) { write_index.push_back(write_region->region[i]->min); } else { - Var var = runtime::Downcast(loop_vars[j]).copy_with_suffix("_inverse"); + Var var = Downcast(loop_vars[j]).copy_with_suffix("_inverse"); new_loop_vars.push_back(var); - substitute_map.Set(runtime::Downcast(loop_vars[j++]), var); + substitute_map.Set(Downcast(loop_vars[j++]), var); write_index.push_back(write_region->region[i]->min + var); } }