Skip to content

Commit

Permalink
enable amd_apu device on vulkan target (#5659)
Browse files Browse the repository at this point in the history
  • Loading branch information
mei-ye authored May 26, 2020
1 parent b6bd367 commit 03d21ff
Show file tree
Hide file tree
Showing 4 changed files with 22 additions and 6 deletions.
10 changes: 7 additions & 3 deletions apps/benchmark/gpu_imagenet_bench.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,13 +56,17 @@ def benchmark(network, target):
'vgg-16', 'vgg-19', 'densenet-121', 'inception_v3',
'mobilenet', 'squeezenet_v1.0', 'squeezenet_v1.1'],
help='The name of neural network')
parser.add_argument("--device", type=str,
choices=['amd_apu'], default='amd_apu',
help="The name of the test device. If your device is not listed in "
"the choices list, pick the most similar one as argument.")
parser.add_argument("--model", type=str,
choices=['1080ti', 'titanx', 'tx2', 'gfx900'], default='1080ti',
choices=['1080ti', 'titanx', 'tx2', 'gfx900', 'v1000'], default='1080ti',
help="The model of the test device. If your device is not listed in "
"the choices list, pick the most similar one as argument.")
parser.add_argument("--repeat", type=int, default=600)
parser.add_argument("--target", type=str,
choices=['cuda', 'opencl', 'rocm', 'nvptx', 'metal'], default='cuda',
choices=['cuda', 'opencl', 'rocm', 'nvptx', 'metal', 'vulkan'], default='cuda',
help="The tvm compilation target")
parser.add_argument("--thread", type=int, default=1, help="The number of threads to be run.")
args = parser.parse_args()
Expand All @@ -74,7 +78,7 @@ def benchmark(network, target):
else:
networks = [args.network]

target = tvm.target.create('%s -model=%s' % (args.target, args.model))
target = tvm.target.create('%s -device=%s -model=%s' % (args.target, args.device, args.model))

print("--------------------------------------------------")
print("%-20s %-20s" % ("Network Name", "Mean Inference Time (std dev)"))
Expand Down
2 changes: 1 addition & 1 deletion python/tvm/autotvm/measure/measure_methods.py
Original file line number Diff line number Diff line change
Expand Up @@ -231,7 +231,7 @@ def set_task(self, task):
def get_build_kwargs(self):
kwargs = {}
if 'cuda' in self.task.target.keys or 'opencl' in self.task.target.keys or \
'rocm' in self.task.target.keys:
'rocm' in self.task.target.keys or 'vulkan' in self.task.target.keys:
remote = request_remote(self.key, self.host, self.port)
ctx = remote.context(str(self.task.target), 0)
max_dims = ctx.max_thread_dimensions
Expand Down
2 changes: 2 additions & 0 deletions python/tvm/autotvm/tophub.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,7 @@
'intel_graphics': "v0.02",

'vta': "v0.08",
'amd_apu': "v0.01",
}

logger = logging.getLogger('autotvm')
Expand All @@ -69,6 +70,7 @@ def _alias(name):
'webgpu': 'opencl',
'vulkan': 'opencl',
'nvptx': 'cuda',
'amd_apu': 'amd_apu'
}
return table.get(name, name)

Expand Down
14 changes: 12 additions & 2 deletions src/runtime/vulkan/vulkan.cc
Original file line number Diff line number Diff line change
Expand Up @@ -368,7 +368,7 @@ void VulkanDeviceAPI::GetAttr(TVMContext ctx, DeviceAttrKind kind, TVMRetValue*
case kMaxThreadsPerBlock: {
VkPhysicalDeviceProperties phy_prop;
vkGetPhysicalDeviceProperties(vctx.phy_device, &phy_prop);
int64_t value = phy_prop.limits.maxComputeWorkGroupSize[0];
int64_t value = phy_prop.limits.maxComputeWorkGroupInvocations;
*rv = value;
break;
}
Expand Down Expand Up @@ -401,8 +401,18 @@ void VulkanDeviceAPI::GetAttr(TVMContext ctx, DeviceAttrKind kind, TVMRetValue*
return;
case kExist:
break;
case kMaxThreadDimensions:
case kMaxThreadDimensions: {
VkPhysicalDeviceProperties phy_prop;
vkGetPhysicalDeviceProperties(vctx.phy_device, &phy_prop);
int64_t dims[3];
dims[0] = phy_prop.limits.maxComputeWorkGroupSize[0];
dims[1] = phy_prop.limits.maxComputeWorkGroupSize[1];
dims[2] = phy_prop.limits.maxComputeWorkGroupSize[2];
std::stringstream ss; // use json string to return multiple int values;
ss << "[" << dims[0] << ", " << dims[1] << ", " << dims[2] << "]";
*rv = ss.str();
break;
}
case kGcnArch:
return;
}
Expand Down

0 comments on commit 03d21ff

Please sign in to comment.