From 076510243a575a25576c301f107630748963bec2 Mon Sep 17 00:00:00 2001 From: tqchen Date: Fri, 8 Aug 2025 11:00:22 -0400 Subject: [PATCH] [FFI] Move Downcast out of ffi for now Downcast was added for backward-compatibility reasons and duplicates features of Any.cast. This PR moves it out of ffi to node for now so the ffi part contains a minimal set of implementations. --- ffi/include/tvm/ffi/any.h | 1 - ffi/include/tvm/ffi/cast.h | 94 +---------------- ffi/tests/cpp/test_string.cc | 3 +- include/tvm/node/cast.h | 117 ++++++++++++++++++++++ include/tvm/node/node.h | 3 +- include/tvm/runtime/disco/session.h | 15 ++- include/tvm/runtime/object.h | 1 - include/tvm/runtime/vm/vm.h | 8 +- src/node/container_printing.cc | 3 +- src/node/repr_printer.cc | 1 + src/runtime/profiling.cc | 10 +- src/runtime/vm/attn_backend.cc | 28 +++--- src/runtime/vm/cuda/cuda_graph_builtin.cc | 6 +- src/runtime/vm/vm.cc | 6 +- src/target/source/codegen_metal.cc | 2 +- src/tir/transforms/memhammer_coalesce.cc | 4 +- 16 files changed, 171 insertions(+), 131 deletions(-) create mode 100644 include/tvm/node/cast.h diff --git a/ffi/include/tvm/ffi/any.h b/ffi/include/tvm/ffi/any.h index 55eff8802a3b..ed34328d1e67 100644 --- a/ffi/include/tvm/ffi/any.h +++ b/ffi/include/tvm/ffi/any.h @@ -635,7 +635,6 @@ struct AnyEqual { } } }; - } // namespace ffi // Expose to the tvm namespace for usability diff --git a/ffi/include/tvm/ffi/cast.h b/ffi/include/tvm/ffi/cast.h index 997c0bb17888..c75d4a075f97 100644 --- a/ffi/include/tvm/ffi/cast.h +++ b/ffi/include/tvm/ffi/cast.h @@ -18,21 +18,18 @@ */ /*! * \file tvm/ffi/cast.h - * \brief Value casting support + * \brief Extra value casting helpers */ #ifndef TVM_FFI_CAST_H_ #define TVM_FFI_CAST_H_ #include -#include -#include #include #include -#include - namespace tvm { namespace ffi { + /*! 
* \brief Get a reference type from a raw object ptr type * @@ -46,7 +43,7 @@ namespace ffi { * \return The corresponding RefType */ template -TVM_FFI_INLINE RefType GetRef(const ObjectType* ptr) { +inline RefType GetRef(const ObjectType* ptr) { static_assert(std::is_base_of_v, "Can only cast to the ref of same container type"); @@ -75,92 +72,9 @@ inline ObjectPtr GetObjectPtr(ObjectType* ptr) { "Can only cast to the ref of same container type"); return details::ObjectUnsafe::ObjectPtrFromUnowned(ptr); } - -/*! - * \brief Downcast a base reference type to a more specific type. - * - * \param ref The input reference - * \return The corresponding SubRef. - * \tparam SubRef The target specific reference type. - * \tparam BaseRef the current reference type. - */ -template >> -inline SubRef Downcast(BaseRef ref) { - if (ref.defined()) { - if (!ref->template IsInstance()) { - TVM_FFI_THROW(TypeError) << "Downcast from " << ref->GetTypeKey() << " to " - << SubRef::ContainerType::_type_key << " failed."; - } - return SubRef(details::ObjectUnsafe::ObjectPtrFromObjectRef(std::move(ref))); - } else { - if constexpr (is_optional_type_v || SubRef::_type_is_nullable) { - return SubRef(ObjectPtr(nullptr)); - } - TVM_FFI_THROW(TypeError) << "Downcast from undefined(nullptr) to `" - << SubRef::ContainerType::_type_key - << "` is not allowed. Use Downcast> instead."; - TVM_FFI_UNREACHABLE(); - } -} - -/*! - * \brief Downcast any to a specific type - * - * \param ref The input reference - * \return The corresponding SubRef. - * \tparam T The target specific reference type. - */ -template -inline T Downcast(const Any& ref) { - if constexpr (std::is_same_v) { - return ref; - } else { - return ref.cast(); - } -} - -/*! - * \brief Downcast any to a specific type - * - * \param ref The input reference - * \return The corresponding SubRef. - * \tparam T The target specific reference type. 
- */ -template -inline T Downcast(Any&& ref) { - if constexpr (std::is_same_v) { - return std::move(ref); - } else { - return std::move(ref).cast(); - } -} - -/*! - * \brief Downcast std::optional to std::optional - * - * \param ref The input reference - * \return The corresponding SubRef. - * \tparam OptionalType The target optional type - */ -template >> -inline OptionalType Downcast(const std::optional& ref) { - if (ref.has_value()) { - if constexpr (std::is_same_v) { - return *ref; - } else { - return (*ref).cast(); - } - } else { - return OptionalType(std::nullopt); - } -} - } // namespace ffi -// Expose to the tvm namespace -// Rationale: convinience and no ambiguity -using ffi::Downcast; +using ffi::GetObjectPtr; using ffi::GetRef; } // namespace tvm #endif // TVM_FFI_CAST_H_ diff --git a/ffi/tests/cpp/test_string.cc b/ffi/tests/cpp/test_string.cc index 364f2f6540c6..8522aa93a3b9 100644 --- a/ffi/tests/cpp/test_string.cc +++ b/ffi/tests/cpp/test_string.cc @@ -18,7 +18,6 @@ */ #include #include -#include #include namespace { @@ -266,7 +265,7 @@ TEST(String, Cast) { string source = "this is a string"; String s{source}; Any r = s; - String s2 = Downcast(r); + String s2 = r.cast(); } TEST(String, Concat) { diff --git a/include/tvm/node/cast.h b/include/tvm/node/cast.h new file mode 100644 index 000000000000..ae23c9e9aa33 --- /dev/null +++ b/include/tvm/node/cast.h @@ -0,0 +1,117 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +/*! + * \file tvm/node/cast.h + * \brief Value casting helpers + */ +#ifndef TVM_NODE_CAST_H_ +#define TVM_NODE_CAST_H_ + +#include +#include +#include +#include +#include +#include + +#include + +namespace tvm { + +/*! + * \brief Downcast a base reference type to a more specific type. + * + * \param ref The input reference + * \return The corresponding SubRef. + * \tparam SubRef The target specific reference type. + * \tparam BaseRef the current reference type. + */ +template >> +inline SubRef Downcast(BaseRef ref) { + if (ref.defined()) { + if (!ref->template IsInstance()) { + TVM_FFI_THROW(TypeError) << "Downcast from " << ref->GetTypeKey() << " to " + << SubRef::ContainerType::_type_key << " failed."; + } + return SubRef(ffi::details::ObjectUnsafe::ObjectPtrFromObjectRef(std::move(ref))); + } else { + if constexpr (ffi::is_optional_type_v || SubRef::_type_is_nullable) { + return SubRef(ffi::ObjectPtr(nullptr)); + } + TVM_FFI_THROW(TypeError) << "Downcast from undefined(nullptr) to `" + << SubRef::ContainerType::_type_key + << "` is not allowed. Use Downcast> instead."; + TVM_FFI_UNREACHABLE(); + } +} + +/*! + * \brief Downcast any to a specific type + * + * \param ref The input reference + * \return The corresponding SubRef. + * \tparam T The target specific reference type. + */ +template +inline T Downcast(const ffi::Any& ref) { + if constexpr (std::is_same_v) { + return ref; + } else { + return ref.cast(); + } +} + +/*! 
+ * \brief Downcast any to a specific type + * + * \param ref The input reference + * \return The corresponding SubRef. + * \tparam T The target specific reference type. + */ +template +inline T Downcast(ffi::Any&& ref) { + if constexpr (std::is_same_v) { + return std::move(ref); + } else { + return std::move(ref).cast(); + } +} + +/*! + * \brief Downcast std::optional to std::optional + * + * \param ref The input reference + * \return The corresponding SubRef. + * \tparam OptionalType The target optional type + */ +template >> +inline OptionalType Downcast(const std::optional& ref) { + if (ref.has_value()) { + if constexpr (std::is_same_v) { + return *ref; + } else { + return (*ref).cast(); + } + } else { + return OptionalType(std::nullopt); + } +} +} // namespace tvm +#endif // TVM_NODE_CAST_H_ diff --git a/include/tvm/node/node.h b/include/tvm/node/node.h index 4398f5881d90..734a28c13301 100644 --- a/include/tvm/node/node.h +++ b/include/tvm/node/node.h @@ -35,6 +35,7 @@ #define TVM_NODE_NODE_H_ #include +#include #include #include #include @@ -57,8 +58,6 @@ using ffi::ObjectPtrHash; using ffi::ObjectRef; using ffi::PackedArgs; using ffi::TypeIndex; -using runtime::Downcast; -using runtime::GetRef; } // namespace tvm #endif // TVM_NODE_NODE_H_ diff --git a/include/tvm/runtime/disco/session.h b/include/tvm/runtime/disco/session.h index 0c1ed7ca0aaf..4fe0e72e79c1 100644 --- a/include/tvm/runtime/disco/session.h +++ b/include/tvm/runtime/disco/session.h @@ -124,6 +124,8 @@ inline std::string DiscoAction2String(DiscoAction action) { LOG(FATAL) << "ValueError: Unknown DiscoAction: " << static_cast(action); } +class SessionObj; + /*! * \brief An object that exists on all workers. * @@ -156,6 +158,9 @@ class DRefObj : public Object { int64_t reg_id; /*! \brief Back-pointer to the host controler session */ ObjectRef session{nullptr}; + + private: + inline SessionObj* GetSession(); }; /*! 
@@ -321,18 +326,22 @@ class WorkerZeroData { // Implementation details +inline SessionObj* DRefObj::GetSession() { + return const_cast(static_cast(session.get())); +} + DRefObj::~DRefObj() { if (this->session.defined()) { - Downcast(this->session)->DeallocReg(reg_id); + GetSession()->DeallocReg(reg_id); } } ffi::Any DRefObj::DebugGetFromRemote(int worker_id) { - return Downcast(this->session)->DebugGetFromRemote(this->reg_id, worker_id); + return GetSession()->DebugGetFromRemote(this->reg_id, worker_id); } void DRefObj::DebugCopyFrom(int worker_id, ffi::AnyView value) { - return Downcast(this->session)->DebugSetRegister(this->reg_id, value, worker_id); + return GetSession()->DebugSetRegister(this->reg_id, value, worker_id); } template diff --git a/include/tvm/runtime/object.h b/include/tvm/runtime/object.h index 1fa9f248e812..302b161b6fd7 100644 --- a/include/tvm/runtime/object.h +++ b/include/tvm/runtime/object.h @@ -39,7 +39,6 @@ using tvm::ffi::ObjectPtrEqual; using tvm::ffi::ObjectPtrHash; using tvm::ffi::ObjectRef; -using tvm::ffi::Downcast; using tvm::ffi::GetObjectPtr; using tvm::ffi::GetRef; diff --git a/include/tvm/runtime/vm/vm.h b/include/tvm/runtime/vm/vm.h index 9aa34c9b4468..ed74ba7b7b2a 100644 --- a/include/tvm/runtime/vm/vm.h +++ b/include/tvm/runtime/vm/vm.h @@ -189,10 +189,12 @@ class VirtualMachine : public runtime::ModuleNode { using ContainerType = typename T::ContainerType; uint32_t key = ContainerType::RuntimeTypeIndex(); if (auto it = extensions.find(key); it != extensions.end()) { - return Downcast((*it).second); + ffi::Any value = (*it).second; + return value.cast(); } auto [it, _] = extensions.emplace(key, T::Create()); - return Downcast((*it).second); + ffi::Any value = (*it).second; + return value.cast(); } /*! @@ -224,7 +226,7 @@ class VirtualMachine : public runtime::ModuleNode { std::vector devices; /*! \brief The VM extensions. Mapping from the type index of the extension to the extension * instance. 
*/ - std::unordered_map extensions; + std::unordered_map extensions; }; } // namespace vm diff --git a/src/node/container_printing.cc b/src/node/container_printing.cc index 7441db783296..b4773b2a816b 100644 --- a/src/node/container_printing.cc +++ b/src/node/container_printing.cc @@ -22,6 +22,7 @@ * \file node/container_printint.cc */ #include +#include #include #include @@ -62,6 +63,6 @@ TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) .set_dispatch([](const ObjectRef& node, ReprPrinter* p) { - p->stream << ffi::Downcast(node); + p->stream << Downcast(node); }); } // namespace tvm diff --git a/src/node/repr_printer.cc b/src/node/repr_printer.cc index 6a60b9723d3d..04a6f7533a19 100644 --- a/src/node/repr_printer.cc +++ b/src/node/repr_printer.cc @@ -23,6 +23,7 @@ */ #include #include +#include #include #include diff --git a/src/runtime/profiling.cc b/src/runtime/profiling.cc index d60030729fba..4cce0d40d168 100644 --- a/src/runtime/profiling.cc +++ b/src/runtime/profiling.cc @@ -479,15 +479,15 @@ String ReportNode::AsTable(bool sort, bool aggregate, bool compute_col_sums) con for (size_t i = 0; i < calls.size(); i++) { auto& frame = calls[i]; auto it = frame.find("Hash"); - std::string name = Downcast(frame["Name"]); + std::string name = frame["Name"].cast(); if (it != frame.end()) { - name = Downcast((*it).second); + name = (*it).second.cast(); } if (frame.find("Argument Shapes") != frame.end()) { - name += Downcast(frame["Argument Shapes"]); + name += frame["Argument Shapes"].cast(); } if (frame.find("Device") != frame.end()) { - name += Downcast(frame["Device"]); + name += frame["Device"].cast(); } if (aggregates.find(name) == aggregates.end()) { @@ -680,7 +680,7 @@ Report Profiler::Report() { for (size_t i = 0; i < devs_.size(); i++) { auto row = rows[rows.size() - 1]; rows.pop_back(); - device_metrics[Downcast(row["Device"])] = row; + device_metrics[row["Device"].cast()] = row; overall_time_us = std::max(overall_time_us, 
row["Duration (us)"].as()->microseconds); } diff --git a/src/runtime/vm/attn_backend.cc b/src/runtime/vm/attn_backend.cc index 04e5094d8e7f..c8fbd9082103 100644 --- a/src/runtime/vm/attn_backend.cc +++ b/src/runtime/vm/attn_backend.cc @@ -33,13 +33,13 @@ std::unique_ptr ConvertPagedPrefillFunc(Array args, String backend_name = args[0].cast(); if (backend_name == "tir") { CHECK_EQ(args.size(), 2); - ffi::Function attn_func = Downcast(args[1]); + ffi::Function attn_func = args[1].cast(); return std::make_unique(std::move(attn_func), attn_kind); } if (backend_name == "flashinfer") { CHECK_EQ(args.size(), 3); - ffi::Function attn_func = Downcast(args[1]); - ffi::Function plan_func = Downcast(args[2]); + ffi::Function attn_func = args[1].cast(); + ffi::Function plan_func = args[2].cast(); return std::make_unique(std::move(attn_func), std::move(plan_func), attn_kind); } @@ -55,13 +55,13 @@ std::unique_ptr ConvertRaggedPrefillFunc(Array args String backend_name = args[0].cast(); if (backend_name == "tir") { CHECK_EQ(args.size(), 2); - ffi::Function attn_func = Downcast(args[1]); + ffi::Function attn_func = args[1].cast(); return std::make_unique(std::move(attn_func), attn_kind); } if (backend_name == "flashinfer") { CHECK_EQ(args.size(), 3); - ffi::Function attn_func = Downcast(args[1]); - ffi::Function plan_func = Downcast(args[2]); + ffi::Function attn_func = args[1].cast(); + ffi::Function plan_func = args[2].cast(); return std::make_unique(std::move(attn_func), std::move(plan_func), attn_kind); } @@ -73,16 +73,16 @@ std::unique_ptr ConvertPagedDecodeFunc(Array args, At if (args.empty()) { return nullptr; } - String backend_name = Downcast(args[0]); + String backend_name = args[0].cast(); if (backend_name == "tir") { CHECK_EQ(args.size(), 2); - ffi::Function attn_func = Downcast(args[1]); + ffi::Function attn_func = args[1].cast(); return std::make_unique(std::move(attn_func), attn_kind); } if (backend_name == "flashinfer") { CHECK_EQ(args.size(), 3); - ffi::Function 
attn_func = Downcast(args[1]); - ffi::Function plan_func = Downcast(args[2]); + ffi::Function attn_func = args[1].cast(); + ffi::Function plan_func = args[2].cast(); return std::make_unique(std::move(attn_func), std::move(plan_func), attn_kind); } @@ -95,10 +95,10 @@ std::unique_ptr ConvertPagedPrefillTreeMaskFunc(Array< if (args.empty()) { return nullptr; } - String backend_name = Downcast(args[0]); + String backend_name = args[0].cast(); if (backend_name == "tir") { CHECK_EQ(args.size(), 2); - ffi::Function attn_func = Downcast(args[1]); + ffi::Function attn_func = args[1].cast(); return std::make_unique(std::move(attn_func), attn_kind); } LOG(FATAL) << "Cannot reach here"; @@ -110,10 +110,10 @@ std::unique_ptr ConvertRaggedPrefillTreeMaskFunc(Arra if (args.empty()) { return nullptr; } - String backend_name = Downcast(args[0]); + String backend_name = args[0].cast(); if (backend_name == "tir") { CHECK_EQ(args.size(), 2); - ffi::Function attn_func = Downcast(args[1]); + ffi::Function attn_func = args[1].cast(); return std::make_unique(std::move(attn_func), attn_kind); } LOG(FATAL) << "Cannot reach here"; diff --git a/src/runtime/vm/cuda/cuda_graph_builtin.cc b/src/runtime/vm/cuda/cuda_graph_builtin.cc index 8844517973bb..691246c3bf77 100644 --- a/src/runtime/vm/cuda/cuda_graph_builtin.cc +++ b/src/runtime/vm/cuda/cuda_graph_builtin.cc @@ -149,7 +149,7 @@ class CUDAGraphExtensionNode : public VMExtensionNode { * \param entry_index The unique index of the capture function used for lookup. * \return The return value of the capture function. 
*/ - ObjectRef RunOrCapture(VirtualMachine* vm, const ObjectRef& capture_func, ObjectRef args, + ObjectRef RunOrCapture(VirtualMachine* vm, const ObjectRef& capture_func, Any args, int64_t entry_index, Optional shape_expr) { CUDAGraphCaptureKey entry_key{entry_index, shape_expr}; if (auto it = capture_cache_.find(entry_key); it != capture_cache_.end()) { @@ -160,7 +160,7 @@ class CUDAGraphExtensionNode : public VMExtensionNode { } // Set up arguments for the graph execution - Array tuple_args = Downcast>(args); + Array tuple_args = args.cast>(); int nargs = static_cast(tuple_args.size()); std::vector packed_args(nargs); @@ -250,7 +250,7 @@ TVM_FFI_STATIC_INIT_BLOCK({ VirtualMachine* vm = VirtualMachine::GetContextPtr(args[0]); auto extension = vm->GetOrCreateExtension(); auto capture_func = args[1].cast(); - auto func_args = args[2].cast(); + Any func_args = args[2]; int64_t entry_index = args[3].cast(); Optional shape_expr = std::nullopt; if (args.size() == 5) { diff --git a/src/runtime/vm/vm.cc b/src/runtime/vm/vm.cc index 17b99bbc8ce4..c28e30084fc1 100644 --- a/src/runtime/vm/vm.cc +++ b/src/runtime/vm/vm.cc @@ -74,7 +74,7 @@ ffi::Any IndexIntoNestedObject(ffi::Any obj, ffi::PackedArgs args, int starting_ LOG(FATAL) << "ValueError: Attempted to index into an object that is not an Array."; } int index = args[i].cast(); - auto arr = Downcast>(obj); + auto arr = obj.cast>(); // make sure the index is in bounds if (index >= static_cast(arr.size())) { LOG(FATAL) << "IndexError: Invalid index (" << index << " >= " << arr.size() << ")."; @@ -96,10 +96,10 @@ NDArray ConvertNDArrayToDevice(NDArray src, const DLDevice& dev, Allocator* allo Any ConvertObjectToDevice(Any src, const Device& dev, Allocator* alloc) { if (src.as()) { - return ConvertNDArrayToDevice(Downcast(src), dev, alloc); + return ConvertNDArrayToDevice(src.cast(), dev, alloc); } else if (src.as()) { std::vector ret; - auto arr = Downcast>(src); + auto arr = src.cast>(); for (size_t i = 0; i < arr.size(); 
i++) { ret.push_back(ConvertObjectToDevice(arr[i], dev, alloc)); } diff --git a/src/target/source/codegen_metal.cc b/src/target/source/codegen_metal.cc index 3cd4a6ed0d81..ffb1737a7063 100644 --- a/src/target/source/codegen_metal.cc +++ b/src/target/source/codegen_metal.cc @@ -370,7 +370,7 @@ void CodeGenMetal::VisitExpr_(const CallNode* op, std::ostream& os) { // NOLINT }; if (op->op.same_as(builtin::make_filled_simdgroup_matrix())) { ICHECK_EQ(op->args.size(), 5); - Var var = runtime::Downcast(op->args[0]); + Var var = Downcast(op->args[0]); // Get the data type of the simdgroup matrix auto it = simdgroup_dtype_.find(var.get()); ICHECK(it != simdgroup_dtype_.end()) diff --git a/src/tir/transforms/memhammer_coalesce.cc b/src/tir/transforms/memhammer_coalesce.cc index 2be5e148fbfe..43a976fa892f 100644 --- a/src/tir/transforms/memhammer_coalesce.cc +++ b/src/tir/transforms/memhammer_coalesce.cc @@ -204,9 +204,9 @@ Stmt InverseMapping::Rewrite(const Stmt& stmt, const ConstraintSet& constraints, if (is_one(write_region->region[i]->extent)) { write_index.push_back(write_region->region[i]->min); } else { - Var var = runtime::Downcast(loop_vars[j]).copy_with_suffix("_inverse"); + Var var = Downcast(loop_vars[j]).copy_with_suffix("_inverse"); new_loop_vars.push_back(var); - substitute_map.Set(runtime::Downcast(loop_vars[j++]), var); + substitute_map.Set(Downcast(loop_vars[j++]), var); write_index.push_back(write_region->region[i]->min + var); } }