diff --git a/include/tvm/runtime/device_api.h b/include/tvm/runtime/device_api.h index 71188574ac2a..c3d83bf2993f 100644 --- a/include/tvm/runtime/device_api.h +++ b/include/tvm/runtime/device_api.h @@ -85,6 +85,15 @@ class TVM_DLL DeviceAPI { * \sa DeviceAttrKind */ virtual void GetAttr(Device dev, DeviceAttrKind kind, TVMRetValue* rv) = 0; + + /*! + * \brief Query the device for specified properties. + * + * This is used to expand "-from_device=N" in the target string to + * all properties that can be determined from that device. + */ + virtual void GetTargetProperty(Device dev, const std::string& property, TVMRetValue* rv) {} + /*! * \brief Allocate a data space on device. * \param dev The device device to perform operation. diff --git a/include/tvm/target/target_kind.h b/include/tvm/target/target_kind.h index e7da2dd413a0..8a2bbcbd0121 100644 --- a/include/tvm/target/target_kind.h +++ b/include/tvm/target/target_kind.h @@ -377,7 +377,8 @@ inline TargetKindRegEntry& TargetKindRegEntry::set_name() { .add_attr_option("device") \ .add_attr_option("model") \ .add_attr_option>("libs") \ - .add_attr_option("host") + .add_attr_option("host") \ + .add_attr_option("from_device") } // namespace tvm diff --git a/src/runtime/vulkan/vulkan_device_api.h b/src/runtime/vulkan/vulkan_device_api.h index b8be3eb43c79..851fede3067f 100644 --- a/src/runtime/vulkan/vulkan_device_api.h +++ b/src/runtime/vulkan/vulkan_device_api.h @@ -106,7 +106,7 @@ class VulkanDeviceAPI final : public DeviceAPI { * Returns the results of feature/property queries done during the * device initialization. */ - void GetTargetProperty(Device dev, const std::string& property, TVMRetValue* rv); + void GetTargetProperty(Device dev, const std::string& property, TVMRetValue* rv) final; private: std::vector GetComputeQueueFamilies(VkPhysicalDevice phy_dev); diff --git a/src/target/target.cc b/src/target/target.cc index df810185784e..d8e71de762e8 100644 --- a/src/target/target.cc +++ b/src/target/target.cc @@ -21,6 +21,7 @@ * \file src/target/target.cc */ #include +#include #include #include #include @@ -673,6 +674,68 @@ ObjectPtr TargetInternal::FromConfig(std::unordered_map(attrs.at("from_device")); + attrs.erase("from_device"); + + Device device{static_cast(target->kind->device_type), device_id}; + + auto api = runtime::DeviceAPI::Get(device, true); + ICHECK(api) << "Requested reading the parameters for " << target->kind->name + << " from device_id " << device_id + << ", but support for this runtime wasn't enabled at compile-time."; + + TVMRetValue ret; + api->GetAttr(device, runtime::kExist, &ret); + ICHECK(ret) << "Requested reading the parameters for " << target->kind->name + << " from device_id " << device_id << ", but device_id " << device_id + << " doesn't exist."; + + for (const auto& kv : target->kind->key2vtype_) { + const String& key = kv.first; + const TargetKindNode::ValueTypeInfo& type_info = kv.second; + + // Don't overwrite explicitly-specified values + if (attrs.count(key)) { + continue; + } + + TVMRetValue ret; + api->GetTargetProperty(device, key, &ret); + + switch (ret.type_code()) { + case kTVMNullptr: + // Nothing returned for this parameter, move on to the next one. + continue; + + case kTVMArgInt: + if (type_info.type_index == Integer::ContainerType::_GetOrAllocRuntimeTypeIndex()) { + attrs[key] = Bool(static_cast(ret)); + } else if (type_info.type_index == Bool::ContainerType::_GetOrAllocRuntimeTypeIndex()) { + attrs[key] = Integer(static_cast(ret)); + } else { + LOG(FATAL) << "Expected " << type_info.type_key << " parameter for attribute '" << key + << "', but received integer from device api"; + } + break; + + case kTVMStr: + ICHECK_EQ(type_info.type_index, String::ContainerType::_GetOrAllocRuntimeTypeIndex()) + << "Expected " << type_info.type_key << " parameter for attribute '" << key + << "', but received string from device api"; + attrs[key] = ret; + break; + + default: + LOG(FATAL) << "Expected " << type_info.type_key << " parameter for attribute '" << key + << "', but received TVMArgTypeCode(" << ret.type_code() << ") from device api"; + break; + } + } + } + // set default attribute values if they do not exist for (const auto& kv : target->kind->key2default_) { if (!attrs.count(kv.first)) { diff --git a/src/target/target_kind.cc b/src/target/target_kind.cc index a56916248858..65ec9a04fe05 100644 --- a/src/target/target_kind.cc +++ b/src/target/target_kind.cc @@ -209,85 +209,6 @@ Map UpdateROCmAttrs(Map attrs) { return attrs; } -/*! - * \brief Update the attributes in the Vulkan target. - * \param attrs The original attributes - * \return The updated attributes - */ -Map UpdateVulkanAttrs(Map attrs) { - if (attrs.count("from_device")) { - int device_id = Downcast(attrs.at("from_device")); - Device device{kDLVulkan, device_id}; - const PackedFunc* get_target_property = - runtime::Registry::Get("device_api.vulkan.get_target_property"); - ICHECK(get_target_property) - << "Requested to read Vulkan parameters from device, but no Vulkan runtime available"; - - // Current vulkan implementation is partially a proof-of-concept, - // with long-term goal to move the -from_device functionality to - // TargetInternal::FromConfig, and to be usable by all targets. - // The duplicate list of parameters is needed until then, since - // TargetKind::Get("vulkan")->key2vtype_ is private. - std::vector bool_opts = { - "supports_float16", "supports_float32", - "supports_float64", "supports_int8", - "supports_int16", "supports_int32", - "supports_int64", "supports_8bit_buffer", - "supports_16bit_buffer", "supports_storage_buffer_storage_class", - "supports_push_descriptor", "supports_dedicated_allocation"}; - std::vector int_opts = {"supported_subgroup_operations", - "max_num_threads", - "thread_warp_size", - "max_block_size_x", - "max_block_size_y", - "max_block_size_z", - "max_push_constants_size", - "max_uniform_buffer_range", - "max_storage_buffer_range", - "max_per_stage_descriptor_storage_buffer", - "max_shared_memory_per_block", - "driver_version", - "vulkan_api_version", - "max_spirv_version"}; - std::vector str_opts = {"device_name", "device_type"}; - - for (auto& key : bool_opts) { - if (!attrs.count(key)) { - attrs.Set(key, Bool(static_cast((*get_target_property)(device, key)))); - } - } - for (auto& key : int_opts) { - if (!attrs.count(key)) { - attrs.Set(key, Integer(static_cast((*get_target_property)(device, key)))); - } - } - for (auto& key : str_opts) { - if (!attrs.count(key)) { - attrs.Set(key, (*get_target_property)(device, key)); - } - } - - attrs.erase("from_device"); - } - - // Set defaults here, rather than in the .add_attr_option() calls. - // The priority should be user-specified > device-query > default, - // but defaults defined in .add_attr_option() are already applied by - // this point. Longer-term, would be good to add a - // `DeviceAPI::GetTargetProperty` function and extend "from_device" - // to work for all runtimes. - std::unordered_map defaults = {{"supports_float32", Bool(true)}, - {"supports_int32", Bool(true)}, - {"max_num_threads", Integer(256)}, - {"thread_warp_size", Integer(1)}}; - for (const auto& kv : defaults) { - if (!attrs.count(kv.first)) { - attrs.Set(kv.first, kv.second); - } - } - return attrs; -} - /********** Register Target kinds and attributes **********/ TVM_REGISTER_TARGET_KIND("llvm", kDLCPU) @@ -360,14 +281,13 @@ TVM_REGISTER_TARGET_KIND("metal", kDLMetal) TVM_REGISTER_TARGET_KIND("vulkan", kDLVulkan) .add_attr_option("system-lib") - .add_attr_option("from_device") // Feature support .add_attr_option("supports_float16") - .add_attr_option("supports_float32") + .add_attr_option("supports_float32", Bool(true)) .add_attr_option("supports_float64") .add_attr_option("supports_int8") .add_attr_option("supports_int16") - .add_attr_option("supports_int32") + .add_attr_option("supports_int32", Bool(true)) .add_attr_option("supports_int64") .add_attr_option("supports_8bit_buffer") .add_attr_option("supports_16bit_buffer") @@ -376,8 +296,8 @@ TVM_REGISTER_TARGET_KIND("vulkan", kDLVulkan) .add_attr_option("supports_dedicated_allocation") .add_attr_option("supported_subgroup_operations") // Physical device limits - .add_attr_option("max_num_threads") - .add_attr_option("thread_warp_size") + .add_attr_option("max_num_threads", Integer(256)) + .add_attr_option("thread_warp_size", Integer(1)) .add_attr_option("max_block_size_x") .add_attr_option("max_block_size_y") .add_attr_option("max_block_size_z") @@ -393,8 +313,7 @@ TVM_REGISTER_TARGET_KIND("vulkan", kDLVulkan) .add_attr_option("vulkan_api_version") .add_attr_option("max_spirv_version") // Tags - .set_default_keys({"vulkan", "gpu"}) - .set_attrs_preprocessor(UpdateVulkanAttrs); + .set_default_keys({"vulkan", "gpu"}); TVM_REGISTER_TARGET_KIND("webgpu", kDLWebGPU) .add_attr_option("system-lib")