Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Target] Framework for device querying for all targets. #8602

Merged
merged 1 commit into from
Aug 3, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 9 additions & 0 deletions include/tvm/runtime/device_api.h
Original file line number Diff line number Diff line change
Expand Up @@ -85,6 +85,15 @@ class TVM_DLL DeviceAPI {
* \sa DeviceAttrKind
*/
virtual void GetAttr(Device dev, DeviceAttrKind kind, TVMRetValue* rv) = 0;

/*!
* \brief Query the device for specified properties.
*
* This is used to expand "-from_device=N" in the target string to
* all properties that can be determined from that device.
*/
virtual void GetTargetProperty(Device dev, const std::string& property, TVMRetValue* rv) {}

/*!
* \brief Allocate a data space on device.
* \param dev The device device to perform operation.
Expand Down
3 changes: 2 additions & 1 deletion include/tvm/target/target_kind.h
Original file line number Diff line number Diff line change
Expand Up @@ -377,7 +377,8 @@ inline TargetKindRegEntry& TargetKindRegEntry::set_name() {
.add_attr_option<String>("device") \
.add_attr_option<String>("model") \
.add_attr_option<Array<String>>("libs") \
.add_attr_option<Target>("host")
.add_attr_option<Target>("host") \
.add_attr_option<Integer>("from_device")

} // namespace tvm

Expand Down
2 changes: 1 addition & 1 deletion src/runtime/vulkan/vulkan_device_api.h
Original file line number Diff line number Diff line change
Expand Up @@ -106,7 +106,7 @@ class VulkanDeviceAPI final : public DeviceAPI {
* Returns the results of feature/property queries done during the
* device initialization.
*/
void GetTargetProperty(Device dev, const std::string& property, TVMRetValue* rv);
void GetTargetProperty(Device dev, const std::string& property, TVMRetValue* rv) final;
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

A bit curious about why adding final here?

Copy link
Contributor Author

@Lunderberg Lunderberg Jul 30, 2021

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It isn't strictly necessary, but I try to be in the habit of including it for virtual functions that aren't intended to be subclassed further. It's always safe to remove final later, but until then it serves as a warning to my future self that I should be look into the class's implementation in detail before making any subclass of it.

I can go either way on having it or not, since final could be interpreted either as "not intended to be subclassed, so take a careful look first" or as "intended not to be subclassed, please don't change it". I tend to take the first interpretation, but the second is completely valid.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I see, thanks!


private:
std::vector<uint32_t> GetComputeQueueFamilies(VkPhysicalDevice phy_dev);
Expand Down
63 changes: 63 additions & 0 deletions src/target/target.cc
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
* \file src/target/target.cc
*/
#include <dmlc/thread_local.h>
#include <tvm/runtime/device_api.h>
#include <tvm/runtime/registry.h>
#include <tvm/target/tag.h>
#include <tvm/target/target.h>
Expand Down Expand Up @@ -673,6 +674,68 @@ ObjectPtr<Object> TargetInternal::FromConfig(std::unordered_map<String, ObjectRe
throw Error(": Error when parsing target[\"" + key + "\"]" + e.what());
}
}

// if requested, query attributes from the device
if (attrs.count("from_device")) {
int device_id = Downcast<Integer>(attrs.at("from_device"));
attrs.erase("from_device");

Device device{static_cast<DLDeviceType>(target->kind->device_type), device_id};

auto api = runtime::DeviceAPI::Get(device, true);
ICHECK(api) << "Requested reading the parameters for " << target->kind->name
<< " from device_id " << device_id
<< ", but support for this runtime wasn't enabled at compile-time.";

TVMRetValue ret;
api->GetAttr(device, runtime::kExist, &ret);
ICHECK(ret) << "Requested reading the parameters for " << target->kind->name
<< " from device_id " << device_id << ", but device_id " << device_id
<< " doesn't exist.";

for (const auto& kv : target->kind->key2vtype_) {
const String& key = kv.first;
const TargetKindNode::ValueTypeInfo& type_info = kv.second;

// Don't overwrite explicitly-specified values
if (attrs.count(key)) {
continue;
}

TVMRetValue ret;
api->GetTargetProperty(device, key, &ret);

switch (ret.type_code()) {
case kTVMNullptr:
// Nothing returned for this parameter, move on to the next one.
continue;

case kTVMArgInt:
if (type_info.type_index == Integer::ContainerType::_GetOrAllocRuntimeTypeIndex()) {
attrs[key] = Bool(static_cast<bool>(ret));
} else if (type_info.type_index == Bool::ContainerType::_GetOrAllocRuntimeTypeIndex()) {
attrs[key] = Integer(static_cast<int64_t>(ret));
} else {
LOG(FATAL) << "Expected " << type_info.type_key << " parameter for attribute '" << key
<< "', but received integer from device api";
}
break;

case kTVMStr:
ICHECK_EQ(type_info.type_index, String::ContainerType::_GetOrAllocRuntimeTypeIndex())
<< "Expected " << type_info.type_key << " parameter for attribute '" << key
<< "', but received string from device api";
attrs[key] = ret;
break;

default:
LOG(FATAL) << "Expected " << type_info.type_key << " parameter for attribute '" << key
<< "', but received TVMArgTypeCode(" << ret.type_code() << ") from device api";
break;
}
}
}

// set default attribute values if they do not exist
for (const auto& kv : target->kind->key2default_) {
if (!attrs.count(kv.first)) {
Expand Down
91 changes: 5 additions & 86 deletions src/target/target_kind.cc
Original file line number Diff line number Diff line change
Expand Up @@ -209,85 +209,6 @@ Map<String, ObjectRef> UpdateROCmAttrs(Map<String, ObjectRef> attrs) {
return attrs;
}

/*!
* \brief Update the attributes in the Vulkan target.
* \param attrs The original attributes
* \return The updated attributes
*/
Map<String, ObjectRef> UpdateVulkanAttrs(Map<String, ObjectRef> attrs) {
if (attrs.count("from_device")) {
int device_id = Downcast<Integer>(attrs.at("from_device"));
Device device{kDLVulkan, device_id};
const PackedFunc* get_target_property =
runtime::Registry::Get("device_api.vulkan.get_target_property");
ICHECK(get_target_property)
<< "Requested to read Vulkan parameters from device, but no Vulkan runtime available";

// Current vulkan implementation is partially a proof-of-concept,
// with long-term goal to move the -from_device functionality to
// TargetInternal::FromConfig, and to be usable by all targets.
// The duplicate list of parameters is needed until then, since
// TargetKind::Get("vulkan")->key2vtype_ is private.
std::vector<const char*> bool_opts = {
"supports_float16", "supports_float32",
"supports_float64", "supports_int8",
"supports_int16", "supports_int32",
"supports_int64", "supports_8bit_buffer",
"supports_16bit_buffer", "supports_storage_buffer_storage_class",
"supports_push_descriptor", "supports_dedicated_allocation"};
std::vector<const char*> int_opts = {"supported_subgroup_operations",
"max_num_threads",
"thread_warp_size",
"max_block_size_x",
"max_block_size_y",
"max_block_size_z",
"max_push_constants_size",
"max_uniform_buffer_range",
"max_storage_buffer_range",
"max_per_stage_descriptor_storage_buffer",
"max_shared_memory_per_block",
"driver_version",
"vulkan_api_version",
"max_spirv_version"};
std::vector<const char*> str_opts = {"device_name", "device_type"};

for (auto& key : bool_opts) {
if (!attrs.count(key)) {
attrs.Set(key, Bool(static_cast<bool>((*get_target_property)(device, key))));
}
}
for (auto& key : int_opts) {
if (!attrs.count(key)) {
attrs.Set(key, Integer(static_cast<int64_t>((*get_target_property)(device, key))));
}
}
for (auto& key : str_opts) {
if (!attrs.count(key)) {
attrs.Set(key, (*get_target_property)(device, key));
}
}

attrs.erase("from_device");
}

// Set defaults here, rather than in the .add_attr_option() calls.
// The priority should be user-specified > device-query > default,
// but defaults defined in .add_attr_option() are already applied by
// this point. Longer-term, would be good to add a
// `DeviceAPI::GetTargetProperty` function and extend "from_device"
// to work for all runtimes.
std::unordered_map<String, ObjectRef> defaults = {{"supports_float32", Bool(true)},
{"supports_int32", Bool(true)},
{"max_num_threads", Integer(256)},
{"thread_warp_size", Integer(1)}};
for (const auto& kv : defaults) {
if (!attrs.count(kv.first)) {
attrs.Set(kv.first, kv.second);
}
}
return attrs;
}

/********** Register Target kinds and attributes **********/

TVM_REGISTER_TARGET_KIND("llvm", kDLCPU)
Expand Down Expand Up @@ -360,14 +281,13 @@ TVM_REGISTER_TARGET_KIND("metal", kDLMetal)

TVM_REGISTER_TARGET_KIND("vulkan", kDLVulkan)
.add_attr_option<Bool>("system-lib")
.add_attr_option<Bool>("from_device")
// Feature support
.add_attr_option<Bool>("supports_float16")
.add_attr_option<Bool>("supports_float32")
.add_attr_option<Bool>("supports_float32", Bool(true))
.add_attr_option<Bool>("supports_float64")
.add_attr_option<Bool>("supports_int8")
.add_attr_option<Bool>("supports_int16")
.add_attr_option<Bool>("supports_int32")
.add_attr_option<Bool>("supports_int32", Bool(true))
.add_attr_option<Bool>("supports_int64")
.add_attr_option<Bool>("supports_8bit_buffer")
.add_attr_option<Bool>("supports_16bit_buffer")
Expand All @@ -376,8 +296,8 @@ TVM_REGISTER_TARGET_KIND("vulkan", kDLVulkan)
.add_attr_option<Bool>("supports_dedicated_allocation")
.add_attr_option<Integer>("supported_subgroup_operations")
// Physical device limits
.add_attr_option<Integer>("max_num_threads")
.add_attr_option<Integer>("thread_warp_size")
.add_attr_option<Integer>("max_num_threads", Integer(256))
.add_attr_option<Integer>("thread_warp_size", Integer(1))
.add_attr_option<Integer>("max_block_size_x")
.add_attr_option<Integer>("max_block_size_y")
.add_attr_option<Integer>("max_block_size_z")
Expand All @@ -393,8 +313,7 @@ TVM_REGISTER_TARGET_KIND("vulkan", kDLVulkan)
.add_attr_option<Integer>("vulkan_api_version")
.add_attr_option<Integer>("max_spirv_version")
// Tags
.set_default_keys({"vulkan", "gpu"})
.set_attrs_preprocessor(UpdateVulkanAttrs);
.set_default_keys({"vulkan", "gpu"});

TVM_REGISTER_TARGET_KIND("webgpu", kDLWebGPU)
.add_attr_option<Bool>("system-lib")
Expand Down