[Vulkan] Support passing 64 bit scalar #7572

Merged · 13 commits · Mar 5, 2021
4 changes: 2 additions & 2 deletions src/runtime/metal/metal_module.mm
@@ -180,7 +180,7 @@ void Init(MetalModuleNode* m, ObjectPtr<Object> sptr, const std::string& func_name
scache_[dev_id] = m->GetPipelineState(dev_id, func_name);
}
// invoke the function with void arguments
- void operator()(TVMArgs args, TVMRetValue* rv, const ArgUnion* pack_args) const {
+ void operator()(TVMArgs args, TVMRetValue* rv, const ArgUnion64* pack_args) const {
metal::MetalThreadEntry* t = metal::MetalThreadEntry::ThreadLocal();
int device_id = t->context.device_id;
if (scache_[device_id] == nil) {
@@ -197,7 +197,7 @@ void operator()(TVMArgs args, TVMRetValue* rv, const ArgUnion* pack_args) const
}
if (num_pack_args_ != 0) {
[encoder setBytes:pack_args
- length:num_pack_args_ * sizeof(ArgUnion)
+ length:num_pack_args_ * sizeof(ArgUnion64)
atIndex:num_buffer_args_];
}
// launch
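Why the length changes: each packed scalar now occupies a full 8-byte slot on the runtime side, whether the value itself is 32- or 64-bit. A minimal sketch of that layout contract, assuming the ArgUnion64 shape introduced in pack_args.h below (illustrative, not part of the diff):

```cpp
// Layout contract behind `num_pack_args_ * sizeof(ArgUnion64)`:
// one 8-byte, 8-byte-aligned slot per packed scalar argument.
#include <cstdint>

union ArgUnion64 {
  int32_t v_int32[2];
  uint32_t v_uint32[2];
  float v_float32[2];
  int64_t v_int64;
  uint64_t v_uint64;
  double v_float64;
};

static_assert(sizeof(ArgUnion64) == 8, "one slot per argument");
static_assert(alignof(ArgUnion64) == 8, "slots are 8-byte aligned");

int main() {
  ArgUnion64 slot{};
  slot.v_float32[0] = 1.5f;  // 32-bit values occupy lane 0; lane 1 is padding
  slot.v_int64 = -1;         // 64-bit values use the whole slot
  return 0;
}
```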
36 changes: 25 additions & 11 deletions src/runtime/pack_args.h
@@ -40,13 +40,24 @@ namespace tvm {
namespace runtime {
/*!
* \brief argument union type of 32bit.
* Choose 32 bit because most GPU API do not work well with 64 bit.
*/
- union ArgUnion {
+ union ArgUnion32 {
int32_t v_int32;
uint32_t v_uint32;
float v_float32;
};

+ /*!
+  * \brief argument union type of 64 bit, for use by Vulkan and Metal runtime.
+  */
+ union ArgUnion64 {
+   int32_t v_int32[2];
+   uint32_t v_uint32[2];
+   float v_float32[2];
+   int64_t v_int64;
+   uint64_t v_uint64;
+   double v_float64;
+ };
/*!
* \brief Create a packed function from void addr types.
*
@@ -140,9 +151,9 @@ inline PackedFunc PackFuncVoidAddr_(F f, const std::vector<ArgConvertCode>& codes) {
int num_args = static_cast<int>(codes.size());
auto ret = [f, codes, num_args](TVMArgs args, TVMRetValue* ret) {
TempArray<void*, N> addr_(num_args);
- TempArray<ArgUnion, N> holder_(num_args);
+ TempArray<ArgUnion32, N> holder_(num_args);
void** addr = addr_.data();
- ArgUnion* holder = holder_.data();
+ ArgUnion32* holder = holder_.data();
for (int i = 0; i < num_args; ++i) {
switch (codes[i]) {
case INT64_TO_INT64:
@@ -177,25 +188,28 @@ template <int N, typename F>
inline PackedFunc PackFuncNonBufferArg_(F f, int base, const std::vector<ArgConvertCode>& codes) {
int num_args = static_cast<int>(codes.size());
auto ret = [f, codes, base, num_args](TVMArgs args, TVMRetValue* ret) {
- TempArray<ArgUnion, N> holder_(num_args);
- ArgUnion* holder = holder_.data();
+ TempArray<ArgUnion64, N> holder_(num_args);
+ ArgUnion64* holder = holder_.data();
for (int i = 0; i < num_args; ++i) {
switch (codes[i]) {
- case INT64_TO_INT64:
+ case INT64_TO_INT64: {
+   holder[i].v_int64 = args.values[base + i].v_int64;
+   break;
+ }
case FLOAT64_TO_FLOAT64: {
-   LOG(FATAL) << "Do not support 64bit argument to device function";
+   holder[i].v_float64 = args.values[base + i].v_float64;
break;
}
case INT64_TO_INT32: {
- holder[i].v_int32 = static_cast<int32_t>(args.values[base + i].v_int64);
+ holder[i].v_int32[0] = static_cast<int32_t>(args.values[base + i].v_int64);
break;
}
case INT64_TO_UINT32: {
- holder[i].v_uint32 = static_cast<uint32_t>(args.values[base + i].v_int64);
+ holder[i].v_uint32[0] = static_cast<uint32_t>(args.values[base + i].v_int64);
break;
}
case FLOAT64_TO_FLOAT32: {
- holder[i].v_float32 = static_cast<float>(args.values[base + i].v_float64);
+ holder[i].v_float32[0] = static_cast<float>(args.values[base + i].v_float64);
break;
}
case HANDLE_TO_HANDLE: {
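Taken together, the pack_args.h changes route every non-buffer scalar through an 8-byte slot instead of rejecting 64-bit values outright. A self-contained sketch of the conversion step (simplified from the hunk above; the TempArray/TVMArgs plumbing and the HANDLE_TO_HANDLE case are omitted):

```cpp
// Sketch of the per-argument conversion in PackFuncNonBufferArg_ after
// this change: every scalar is widened into one 8-byte slot.
#include <cstdint>

union ArgUnion64 {
  int32_t v_int32[2];
  uint32_t v_uint32[2];
  float v_float32[2];
  int64_t v_int64;
  uint64_t v_uint64;
  double v_float64;
};

enum ArgConvertCode {
  INT64_TO_INT64,
  FLOAT64_TO_FLOAT64,
  INT64_TO_INT32,
  INT64_TO_UINT32,
  FLOAT64_TO_FLOAT32,
};

ArgUnion64 PackScalar(ArgConvertCode code, int64_t v_int64, double v_float64) {
  ArgUnion64 out{};
  switch (code) {
    case INT64_TO_INT64:
      out.v_int64 = v_int64;  // now passed through; previously a fatal error
      break;
    case FLOAT64_TO_FLOAT64:
      out.v_float64 = v_float64;  // likewise newly supported
      break;
    case INT64_TO_INT32:
      out.v_int32[0] = static_cast<int32_t>(v_int64);  // narrow into lane 0
      break;
    case INT64_TO_UINT32:
      out.v_uint32[0] = static_cast<uint32_t>(v_int64);
      break;
    case FLOAT64_TO_FLOAT32:
      out.v_float32[0] = static_cast<float>(v_float64);
      break;
  }
  return out;
}
```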
14 changes: 8 additions & 6 deletions src/runtime/vulkan/vulkan.cc
@@ -711,7 +711,7 @@ class VulkanWrappedFunc {
thread_axis_cfg_.Init(num_buffer_args + num_pack_args, thread_axis_tags);
}

- void operator()(TVMArgs args, TVMRetValue* rv, const ArgUnion* pack_args) const;
+ void operator()(TVMArgs args, TVMRetValue* rv, const ArgUnion64* pack_args) const;

private:
// internal module
@@ -875,7 +875,7 @@ class VulkanModuleNode final : public runtime::ModuleNode {
VkPushConstantRange crange;
crange.stageFlags = VK_SHADER_STAGE_COMPUTE_BIT;
crange.offset = 0;
- crange.size = sizeof(ArgUnion) * num_pack_args;
+ crange.size = sizeof(ArgUnion64) * num_pack_args;

VkPipelineLayoutCreateInfo playout_cinfo;
playout_cinfo.sType = VK_STRUCTURE_TYPE_PIPELINE_LAYOUT_CREATE_INFO;
@@ -1046,7 +1046,8 @@ VulkanStream* VulkanThreadEntry::Stream(size_t device_id) {
return streams_[device_id].get();
}

- void VulkanWrappedFunc::operator()(TVMArgs args, TVMRetValue* rv, const ArgUnion* pack_args) const {
+ void VulkanWrappedFunc::operator()(TVMArgs args, TVMRetValue* rv,
+                                    const ArgUnion64* pack_args) const {
int device_id = VulkanThreadEntry::ThreadLocal()->ctx.device_id;
ICHECK_LT(device_id, kVulkanMaxNumDevice);
const auto& vctx = VulkanDeviceAPI::Global()->context(device_id);
@@ -1075,7 +1076,7 @@ void VulkanWrappedFunc::operator()(TVMArgs args, TVMRetValue* rv, const ArgUnion
descriptor_buffers.data());
if (num_pack_args_ != 0) {
vkCmdPushConstants(state->cmd_buffer_, pipeline->pipeline_layout,
- VK_SHADER_STAGE_COMPUTE_BIT, 0, num_pack_args_ * sizeof(ArgUnion),
+ VK_SHADER_STAGE_COMPUTE_BIT, 0, num_pack_args_ * sizeof(ArgUnion64),
pack_args);
}
vkCmdDispatch(state->cmd_buffer_, wl.grid_dim(0), wl.grid_dim(1), wl.grid_dim(2));
@@ -1093,7 +1094,7 @@ void VulkanWrappedFunc::operator()(TVMArgs args, TVMRetValue* rv, const ArgUnion
}

// Otherwise, the more expensive deferred path.
- std::vector<ArgUnion> pack_args_storage(pack_args, pack_args + num_pack_args_);
+ std::vector<ArgUnion64> pack_args_storage(pack_args, pack_args + num_pack_args_);
const auto& deferred_initializer = [&vctx, pipeline, descriptor_buffers]() {
std::vector<VkWriteDescriptorSet> write_descriptor_sets;
write_descriptor_sets.resize(descriptor_buffers.size());
@@ -1119,7 +1120,8 @@ void VulkanWrappedFunc::operator()(TVMArgs args, TVMRetValue* rv, const ArgUnion
nullptr);
if (pack_args_storage.size() != 0) {
vkCmdPushConstants(state->cmd_buffer_, pipeline->pipeline_layout, VK_SHADER_STAGE_COMPUTE_BIT,
- 0, pack_args_storage.size() * sizeof(ArgUnion), pack_args_storage.data());
+ 0, pack_args_storage.size() * sizeof(ArgUnion64),
+ pack_args_storage.data());
}
vkCmdDispatch(state->cmd_buffer_, wl.grid_dim(0), wl.grid_dim(1), wl.grid_dim(2));
VkMemoryBarrier barrier_info;
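Worth noting when reading the vkCmdPushConstants changes: doubling the slot size halves how many scalars fit in the push constant range, and the Vulkan spec only guarantees 128 bytes of maxPushConstantsSize (16 slots). A hedged sketch of the kind of budget check a runtime could perform, not code from this PR (vkGetPhysicalDeviceProperties and the limit field are standard Vulkan; the helper itself is hypothetical):

```cpp
// Sketch: verify that the packed-argument block fits the device's
// push constant budget before building the pipeline layout.
#include <vulkan/vulkan.h>
#include <cstddef>
#include <cstdio>

bool FitsPushConstantBudget(VkPhysicalDevice phy, size_t num_pack_args) {
  VkPhysicalDeviceProperties props;
  vkGetPhysicalDeviceProperties(phy, &props);
  const size_t needed = num_pack_args * 8;  // one 8-byte ArgUnion64 slot each
  if (needed > props.limits.maxPushConstantsSize) {
    // The spec guarantees only 128 bytes, i.e. 16 slots at 8 bytes apiece.
    std::fprintf(stderr, "need %zu push constant bytes, device allows %u\n",
                 needed, props.limits.maxPushConstantsSize);
    return false;
  }
  return true;
}
```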
7 changes: 6 additions & 1 deletion src/target/source/codegen_metal.cc
@@ -47,7 +47,7 @@ CodeGenMetal::CodeGenMetal() {
decl_stream << "#include <metal_stdlib>\n";
decl_stream << "using namespace metal;\n\n";
decl_stream << "union __TVMArgUnion {\n"
<< " int v_int;\n"
<< " int v_int[2];\n"
<< "};\n\n";
}

@@ -102,6 +102,11 @@ void CodeGenMetal::AddFunction(const PrimFunc& f) {
std::string vid = AllocVarID(v.get());
std::ostringstream vref;
if (v.dtype().bits() == 32) {
+ decl_stream << "  ";
+ PrintType(v.dtype(), decl_stream);
+ decl_stream << " " << vid << "[2];\n";
+ vref << varg << "." << vid << "[0]";
+ } else if (v.dtype().bits() == 64) {
decl_stream << " ";
PrintType(v.dtype(), decl_stream);
decl_stream << " " << vid << ";\n";
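To make the codegen change concrete: for a hypothetical kernel taking a float32 scalar alpha and an int64 scalar n, the emitted Metal-side argument declarations would plausibly take the following shape (illustrative identifiers, not actual compiler output). A 32-bit field is widened to a two-lane array so it fills an 8-byte slot, and the value is read back through index [0]:

```cpp
// Illustrative shape of the generated argument declarations after this
// change (hypothetical names; Metal's `long` is 64-bit, so the snippet
// is also valid C++ on LP64 platforms).
union __TVMArgUnion {
  int v_int[2];
};

struct KernelArgs_t {
  float alpha[2];  // 32-bit scalar: lane 0 holds the value, lane 1 pads
  long n;          // 64-bit scalar: fills the slot directly
};
// Generated code would then read `args.alpha[0]` and `args.n`.
```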
6 changes: 6 additions & 0 deletions src/target/spirv/intrin_rule_spirv.cc
@@ -62,8 +62,14 @@ TVM_REGISTER_GLOBAL("tvm.intrin.rule.vulkan.fabs").set_body(DispatchGLSLPureIntrin<GLSLstd450FAbs>);

TVM_REGISTER_GLOBAL("tvm.intrin.rule.vulkan.exp").set_body(DispatchGLSLPureIntrin<GLSLstd450Exp>);

TVM_REGISTER_GLOBAL("tvm.intrin.rule.vulkan.sin").set_body(DispatchGLSLPureIntrin<GLSLstd450Sin>);

TVM_REGISTER_GLOBAL("tvm.intrin.rule.vulkan.cos").set_body(DispatchGLSLPureIntrin<GLSLstd450Cos>);

TVM_REGISTER_GLOBAL("tvm.intrin.rule.vulkan.log").set_body(DispatchGLSLPureIntrin<GLSLstd450Log>);

TVM_REGISTER_GLOBAL("tvm.intrin.rule.vulkan.log2").set_body(DispatchGLSLPureIntrin<GLSLstd450Log2>);

TVM_REGISTER_GLOBAL("tvm.intrin.rule.vulkan.sqrt").set_body(DispatchGLSLPureIntrin<GLSLstd450Sqrt>);

TVM_REGISTER_GLOBAL("tvm.intrin.rule.vulkan.pow").set_body(DispatchGLSLPureIntrin<GLSLstd450Pow>);
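For context on these registrations: each one maps a TIR intrinsic (sin, cos, log2, ...) to a single GLSL.std.450 extended instruction. A reconstructed sketch of the dispatch pattern, inferred from the registration lines rather than copied from the file (consult intrin_rule_spirv.cc for the authoritative version):

```cpp
// Rewrite a pure intrinsic call, e.g. sin(x), into a
// call_spirv_pure_glsl450 builtin whose first argument is the
// GLSL.std.450 opcode `id`. Reconstructed for illustration only.
#include <tvm/runtime/packed_func.h>
#include <tvm/tir/builtin.h>
#include <tvm/tir/expr.h>
#include <tvm/tir/op.h>

namespace tvm {
namespace codegen {

template <unsigned id>
inline void DispatchGLSLPureIntrin(const runtime::TVMArgs& targs,
                                   runtime::TVMRetValue* rv) {
  PrimExpr e = targs[0];
  const tir::CallNode* call = e.as<tir::CallNode>();
  ICHECK(call != nullptr) << "expected a call expression";
  Array<PrimExpr> cargs;
  cargs.push_back(IntImm(DataType::UInt(32), id));  // extended-instruction id
  for (PrimExpr arg : call->args) {
    cargs.push_back(arg);
  }
  *rv = tir::Call(call->dtype, tir::builtin::call_spirv_pure_glsl450(), cargs);
}

}  // namespace codegen
}  // namespace tvm
```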
7 changes: 7 additions & 0 deletions tests/python/topi/python/test_topi_cumsum.py
@@ -28,6 +28,8 @@ def check_cumsum(np_ref, data, axis=None, dtype=None):
"generic": (lambda x: topi.cumsum(x, axis, dtype), topi.generic.schedule_extern),
"cuda": (lambda x: topi.cuda.cumsum(x, axis, dtype), topi.cuda.schedule_scan),
"nvptx": (lambda x: topi.cuda.cumsum(x, axis, dtype), topi.cuda.schedule_scan),
"vulkan": (lambda x: topi.cuda.cumsum(x, axis, dtype), topi.cuda.schedule_scan),
"metal": (lambda x: topi.cuda.cumsum(x, axis, dtype), topi.cuda.schedule_scan),
}
fcompute, fschedule = tvm.topi.testing.dispatch(target, implementations)
tvm.topi.testing.compare_numpy_tvm([data], np_ref, target, ctx, fcompute, fschedule)
@@ -44,6 +46,9 @@ def check_cumsum(np_ref, data, axis=None, dtype=None):
check_cumsum(np.cumsum(data, dtype=np.int32), data, dtype="int32")

for in_dtype in ["float32", "float64"]:
+ if target == "metal" and in_dtype == "float64":
+     # float64 is not supported in metal
+     continue
data = np.random.randn(10, 10).astype(in_dtype)
check_cumsum(np.cumsum(data), data)
check_cumsum(np.cumsum(data, axis=0), data, axis=0)
@@ -70,3 +75,5 @@ def check_cumsum(np_ref, data, axis=None, dtype=None):
test_cumsum(tvm.context("cpu"), tvm.target.Target("llvm"))
test_cumsum(tvm.context("cuda"), tvm.target.Target("cuda"))
test_cumsum(tvm.context("nvptx"), tvm.target.Target("nvptx"))
test_cumsum(tvm.context("vulkan"), tvm.target.Target("vulkan"))
test_cumsum(tvm.context("metal"), tvm.target.Target("metal"))
2 changes: 1 addition & 1 deletion tests/python/topi/python/test_topi_vision.py
@@ -112,7 +112,7 @@ def check_device(device):
tvm.testing.assert_allclose(tvm_out2.asnumpy(), np_out2, rtol=1e-3)
tvm.testing.assert_allclose(tvm_out3.asnumpy(), np_out3, rtol=1e-3)

- for device in ["llvm", "cuda", "opencl"]:
+ for device in ["llvm", "cuda", "opencl", "vulkan"]:
check_device(device)

