Skip to content

Commit

Permalink
[Target] Enable device querying for all targets. (apache#8602)
Browse files Browse the repository at this point in the history
- Move "from_device" argument definition from "vulkan" target to all
  targets.

- Add device querying to TargetInternal::FromConfig, using
  "from_device" argument.  If present, these have lower priority than
  explicitly-specified attributes, but higher priority than the
  default attribute values.

- Add default no-op DeviceAPI::GetTargetProperty.

Co-authored-by: Eric Lunderberg <elunderberg@octoml.ai>
  • Loading branch information
2 people authored and ylc committed Jan 13, 2022
1 parent 91efdb6 commit 8038535
Show file tree
Hide file tree
Showing 5 changed files with 80 additions and 88 deletions.
9 changes: 9 additions & 0 deletions include/tvm/runtime/device_api.h
Original file line number Diff line number Diff line change
Expand Up @@ -85,6 +85,15 @@ class TVM_DLL DeviceAPI {
* \sa DeviceAttrKind
*/
virtual void GetAttr(Device dev, DeviceAttrKind kind, TVMRetValue* rv) = 0;

/*!
* \brief Query the device for specified properties.
*
* This is used to expand "-from_device=N" in the target string to
* all properties that can be determined from that device.
*/
virtual void GetTargetProperty(Device dev, const std::string& property, TVMRetValue* rv) {}

/*!
* \brief Allocate a data space on device.
* \param dev The device device to perform operation.
Expand Down
3 changes: 2 additions & 1 deletion include/tvm/target/target_kind.h
Original file line number Diff line number Diff line change
Expand Up @@ -377,7 +377,8 @@ inline TargetKindRegEntry& TargetKindRegEntry::set_name() {
.add_attr_option<String>("device") \
.add_attr_option<String>("model") \
.add_attr_option<Array<String>>("libs") \
.add_attr_option<Target>("host")
.add_attr_option<Target>("host") \
.add_attr_option<Integer>("from_device")

} // namespace tvm

Expand Down
2 changes: 1 addition & 1 deletion src/runtime/vulkan/vulkan_device_api.h
Original file line number Diff line number Diff line change
Expand Up @@ -106,7 +106,7 @@ class VulkanDeviceAPI final : public DeviceAPI {
* Returns the results of feature/property queries done during the
* device initialization.
*/
void GetTargetProperty(Device dev, const std::string& property, TVMRetValue* rv);
void GetTargetProperty(Device dev, const std::string& property, TVMRetValue* rv) final;

private:
std::vector<uint32_t> GetComputeQueueFamilies(VkPhysicalDevice phy_dev);
Expand Down
63 changes: 63 additions & 0 deletions src/target/target.cc
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
* \file src/target/target.cc
*/
#include <dmlc/thread_local.h>
#include <tvm/runtime/device_api.h>
#include <tvm/runtime/registry.h>
#include <tvm/target/tag.h>
#include <tvm/target/target.h>
Expand Down Expand Up @@ -673,6 +674,68 @@ ObjectPtr<Object> TargetInternal::FromConfig(std::unordered_map<String, ObjectRe
throw Error(": Error when parsing target[\"" + key + "\"]" + e.what());
}
}

// if requested, query attributes from the device
if (attrs.count("from_device")) {
int device_id = Downcast<Integer>(attrs.at("from_device"));
attrs.erase("from_device");

Device device{static_cast<DLDeviceType>(target->kind->device_type), device_id};

auto api = runtime::DeviceAPI::Get(device, true);
ICHECK(api) << "Requested reading the parameters for " << target->kind->name
<< " from device_id " << device_id
<< ", but support for this runtime wasn't enabled at compile-time.";

TVMRetValue ret;
api->GetAttr(device, runtime::kExist, &ret);
ICHECK(ret) << "Requested reading the parameters for " << target->kind->name
<< " from device_id " << device_id << ", but device_id " << device_id
<< " doesn't exist.";

for (const auto& kv : target->kind->key2vtype_) {
const String& key = kv.first;
const TargetKindNode::ValueTypeInfo& type_info = kv.second;

// Don't overwrite explicitly-specified values
if (attrs.count(key)) {
continue;
}

TVMRetValue ret;
api->GetTargetProperty(device, key, &ret);

switch (ret.type_code()) {
case kTVMNullptr:
// Nothing returned for this parameter, move on to the next one.
continue;

case kTVMArgInt:
if (type_info.type_index == Integer::ContainerType::_GetOrAllocRuntimeTypeIndex()) {
attrs[key] = Bool(static_cast<bool>(ret));
} else if (type_info.type_index == Bool::ContainerType::_GetOrAllocRuntimeTypeIndex()) {
attrs[key] = Integer(static_cast<int64_t>(ret));
} else {
LOG(FATAL) << "Expected " << type_info.type_key << " parameter for attribute '" << key
<< "', but received integer from device api";
}
break;

case kTVMStr:
ICHECK_EQ(type_info.type_index, String::ContainerType::_GetOrAllocRuntimeTypeIndex())
<< "Expected " << type_info.type_key << " parameter for attribute '" << key
<< "', but received string from device api";
attrs[key] = ret;
break;

default:
LOG(FATAL) << "Expected " << type_info.type_key << " parameter for attribute '" << key
<< "', but received TVMArgTypeCode(" << ret.type_code() << ") from device api";
break;
}
}
}

// set default attribute values if they do not exist
for (const auto& kv : target->kind->key2default_) {
if (!attrs.count(kv.first)) {
Expand Down
91 changes: 5 additions & 86 deletions src/target/target_kind.cc
Original file line number Diff line number Diff line change
Expand Up @@ -209,85 +209,6 @@ Map<String, ObjectRef> UpdateROCmAttrs(Map<String, ObjectRef> attrs) {
return attrs;
}

/*!
* \brief Update the attributes in the Vulkan target.
* \param attrs The original attributes
* \return The updated attributes
*/
Map<String, ObjectRef> UpdateVulkanAttrs(Map<String, ObjectRef> attrs) {
if (attrs.count("from_device")) {
int device_id = Downcast<Integer>(attrs.at("from_device"));
Device device{kDLVulkan, device_id};
const PackedFunc* get_target_property =
runtime::Registry::Get("device_api.vulkan.get_target_property");
ICHECK(get_target_property)
<< "Requested to read Vulkan parameters from device, but no Vulkan runtime available";

// Current vulkan implementation is partially a proof-of-concept,
// with long-term goal to move the -from_device functionality to
// TargetInternal::FromConfig, and to be usable by all targets.
// The duplicate list of parameters is needed until then, since
// TargetKind::Get("vulkan")->key2vtype_ is private.
std::vector<const char*> bool_opts = {
"supports_float16", "supports_float32",
"supports_float64", "supports_int8",
"supports_int16", "supports_int32",
"supports_int64", "supports_8bit_buffer",
"supports_16bit_buffer", "supports_storage_buffer_storage_class",
"supports_push_descriptor", "supports_dedicated_allocation"};
std::vector<const char*> int_opts = {"supported_subgroup_operations",
"max_num_threads",
"thread_warp_size",
"max_block_size_x",
"max_block_size_y",
"max_block_size_z",
"max_push_constants_size",
"max_uniform_buffer_range",
"max_storage_buffer_range",
"max_per_stage_descriptor_storage_buffer",
"max_shared_memory_per_block",
"driver_version",
"vulkan_api_version",
"max_spirv_version"};
std::vector<const char*> str_opts = {"device_name", "device_type"};

for (auto& key : bool_opts) {
if (!attrs.count(key)) {
attrs.Set(key, Bool(static_cast<bool>((*get_target_property)(device, key))));
}
}
for (auto& key : int_opts) {
if (!attrs.count(key)) {
attrs.Set(key, Integer(static_cast<int64_t>((*get_target_property)(device, key))));
}
}
for (auto& key : str_opts) {
if (!attrs.count(key)) {
attrs.Set(key, (*get_target_property)(device, key));
}
}

attrs.erase("from_device");
}

// Set defaults here, rather than in the .add_attr_option() calls.
// The priority should be user-specified > device-query > default,
// but defaults defined in .add_attr_option() are already applied by
// this point. Longer-term, would be good to add a
// `DeviceAPI::GetTargetProperty` function and extend "from_device"
// to work for all runtimes.
std::unordered_map<String, ObjectRef> defaults = {{"supports_float32", Bool(true)},
{"supports_int32", Bool(true)},
{"max_num_threads", Integer(256)},
{"thread_warp_size", Integer(1)}};
for (const auto& kv : defaults) {
if (!attrs.count(kv.first)) {
attrs.Set(kv.first, kv.second);
}
}
return attrs;
}

/********** Register Target kinds and attributes **********/

TVM_REGISTER_TARGET_KIND("llvm", kDLCPU)
Expand Down Expand Up @@ -362,14 +283,13 @@ TVM_REGISTER_TARGET_KIND("metal", kDLMetal)

TVM_REGISTER_TARGET_KIND("vulkan", kDLVulkan)
.add_attr_option<Bool>("system-lib")
.add_attr_option<Bool>("from_device")
// Feature support
.add_attr_option<Bool>("supports_float16")
.add_attr_option<Bool>("supports_float32")
.add_attr_option<Bool>("supports_float32", Bool(true))
.add_attr_option<Bool>("supports_float64")
.add_attr_option<Bool>("supports_int8")
.add_attr_option<Bool>("supports_int16")
.add_attr_option<Bool>("supports_int32")
.add_attr_option<Bool>("supports_int32", Bool(true))
.add_attr_option<Bool>("supports_int64")
.add_attr_option<Bool>("supports_8bit_buffer")
.add_attr_option<Bool>("supports_16bit_buffer")
Expand All @@ -378,8 +298,8 @@ TVM_REGISTER_TARGET_KIND("vulkan", kDLVulkan)
.add_attr_option<Bool>("supports_dedicated_allocation")
.add_attr_option<Integer>("supported_subgroup_operations")
// Physical device limits
.add_attr_option<Integer>("max_num_threads")
.add_attr_option<Integer>("thread_warp_size")
.add_attr_option<Integer>("max_num_threads", Integer(256))
.add_attr_option<Integer>("thread_warp_size", Integer(1))
.add_attr_option<Integer>("max_block_size_x")
.add_attr_option<Integer>("max_block_size_y")
.add_attr_option<Integer>("max_block_size_z")
Expand All @@ -395,8 +315,7 @@ TVM_REGISTER_TARGET_KIND("vulkan", kDLVulkan)
.add_attr_option<Integer>("vulkan_api_version")
.add_attr_option<Integer>("max_spirv_version")
// Tags
.set_default_keys({"vulkan", "gpu"})
.set_attrs_preprocessor(UpdateVulkanAttrs);
.set_default_keys({"vulkan", "gpu"});

TVM_REGISTER_TARGET_KIND("webgpu", kDLWebGPU)
.add_attr_option<Bool>("system-lib")
Expand Down

0 comments on commit 8038535

Please sign in to comment.