Skip to content

Commit

Permalink
[Target] Several minor corrections to the device property query (#8651)
Browse files Browse the repository at this point in the history
- Pass parameters through TVMRetValue as std::string instead of
  runtime::String

- Remove escaping of spaces inside quotes for target attributes.
  Updated unit test to verify round-trip behavior.

- Added missing "device_type" query for Vulkan.

Co-authored-by: Eric Lunderberg <elunderberg@octoml.ai>
  • Loading branch information
Lunderberg and Lunderberg authored Aug 4, 2021
1 parent 40a9086 commit 4d2c5d5
Show file tree
Hide file tree
Showing 3 changed files with 12 additions and 4 deletions.
7 changes: 5 additions & 2 deletions src/runtime/vulkan/vulkan_device_api.cc
Original file line number Diff line number Diff line change
Expand Up @@ -122,7 +122,7 @@ void VulkanDeviceAPI::GetAttr(Device dev, DeviceAttrKind kind, TVMRetValue* rv)
break;
}
case kDeviceName:
*rv = String(prop.device_name);
*rv = std::string(prop.device_name);
break;

case kMaxClockRate:
Expand Down Expand Up @@ -237,7 +237,10 @@ void VulkanDeviceAPI::GetTargetProperty(Device dev, const std::string& property,
*rv = int64_t(prop.max_shared_memory_per_block);
}
if (property == "device_name") {
*rv = String(prop.device_name);
*rv = prop.device_name;
}
if (property == "device_type") {
*rv = prop.device_type;
}
if (property == "driver_version") {
*rv = int64_t(prop.driver_version);
Expand Down
4 changes: 2 additions & 2 deletions src/target/target.cc
Original file line number Diff line number Diff line change
Expand Up @@ -169,7 +169,7 @@ static Optional<String> JoinString(const std::vector<String>& array, char separa
} else {
os << quote;
for (char c : str) {
if (c == separator || c == quote) {
if (c == quote) {
os << escape;
}
os << c;
Expand Down Expand Up @@ -781,7 +781,7 @@ ObjectPtr<Object> TargetInternal::FromConfig(std::unordered_map<String, ObjectRe
ICHECK_EQ(type_info.type_index, String::ContainerType::_GetOrAllocRuntimeTypeIndex())
<< "Expected " << type_info.type_key << " parameter for attribute '" << key
<< "', but received string from device api";
attrs[key] = ret;
attrs[key] = String(ret.operator std::string());
break;

default:
Expand Down
5 changes: 5 additions & 0 deletions tests/python/unittest/test_target_target.py
Original file line number Diff line number Diff line change
Expand Up @@ -83,6 +83,11 @@ def test_target_string_with_spaces():
assert target.attrs["device_name"] == "Name of GPU with spaces"
assert target.attrs["device_type"] == "discrete"

target = tvm.target.Target(str(target))

assert target.attrs["device_name"] == "Name of GPU with spaces"
assert target.attrs["device_type"] == "discrete"


def test_target_create():
targets = [cuda(), rocm(), mali(), intel_graphics(), arm_cpu("rk3399"), vta(), bifrost()]
Expand Down

0 comments on commit 4d2c5d5

Please sign in to comment.