Skip to content

Commit

Permalink
[Vulkan] Implement sync for SyncThread("warp") (#8320)
Browse files Browse the repository at this point in the history
- Add sync if a SyncThread("warp") node is present.  The sync is done
  at spv::ScopeSubgroup if supported (Vulkan 1.1+), and at
  spv::ScopeWorkgroup otherwise.

Co-authored-by: Eric Lunderberg <elunderberg@octoml.ai>
  • Loading branch information
Lunderberg and Lunderberg authored Jun 24, 2021
1 parent 07701f2 commit 3e28716
Show file tree
Hide file tree
Showing 5 changed files with 56 additions and 14 deletions.
21 changes: 19 additions & 2 deletions src/target/spirv/build_vulkan.cc
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,24 @@ namespace codegen {

class SPIRVTools {
public:
SPIRVTools() { ctx_ = spvContextCreate(SPV_ENV_VULKAN_1_0); }
explicit SPIRVTools(Target target) {
uint32_t vulkan_version =
target->GetAttr<Integer>("vulkan_api_version").value_or(VK_API_VERSION_1_0);
uint32_t spirv_version = target->GetAttr<Integer>("max_spirv_version").value_or(0x10000);

spv_target_env validation_version;
if (vulkan_version >= VK_API_VERSION_1_2) {
validation_version = SPV_ENV_VULKAN_1_2;
} else if (vulkan_version >= VK_API_VERSION_1_1 && spirv_version >= 0x10400) {
validation_version = SPV_ENV_VULKAN_1_1_SPIRV_1_4;
} else if (vulkan_version >= VK_API_VERSION_1_1) {
validation_version = SPV_ENV_VULKAN_1_1;
} else {
validation_version = SPV_ENV_VULKAN_1_0;
}

ctx_ = spvContextCreate(validation_version);
}
~SPIRVTools() { spvContextDestroy(ctx_); }
std::string BinaryToText(const std::vector<uint32_t>& bin) {
spv_text text = nullptr;
Expand Down Expand Up @@ -80,7 +97,7 @@ runtime::Module BuildSPIRV(IRModule mod, Target target, bool webgpu_restriction)
using tvm::runtime::VulkanShader;

std::ostringstream code_data;
static SPIRVTools spirv_tools;
SPIRVTools spirv_tools(target);
std::unordered_map<std::string, VulkanShader> smap;

const auto* postproc = Registry::Get("tvm_callback_vulkan_postproc");
Expand Down
29 changes: 18 additions & 11 deletions src/target/spirv/codegen_spirv.cc
Original file line number Diff line number Diff line change
Expand Up @@ -140,20 +140,27 @@ spirv::Value CodeGenSPIRV::GetThreadIndex(const IterVar& iv, const PrimExpr& ext
spirv::Value CodeGenSPIRV::CreateStorageSync(const CallNode* op) {
const std::string& sync = op->args[0].as<StringImmNode>()->value;
spirv::Value value;
if (sync == "warp") {
return value;
} else if (sync == "shared") {
auto type_int = builder_->GetSType(DataType::Int(32));
builder_->MakeInst(
spv::OpControlBarrier,
builder_->IntImm(type_int, static_cast<int64_t>(spv::ScopeWorkgroup)),
builder_->IntImm(type_int, static_cast<int64_t>(spv::ScopeWorkgroup)),
builder_->IntImm(type_int,
static_cast<int64_t>(spv::MemorySemanticsSequentiallyConsistentMask |
spv::MemorySemanticsWorkgroupMemoryMask)));

uint32_t vulkan_api_version = spirv_support_.vulkan_api_version;

int64_t sync_scope;
int64_t memory_semantics;
if ((sync == "warp") && (vulkan_api_version >= VK_API_VERSION_1_1)) {
sync_scope = spv::ScopeSubgroup;
memory_semantics =
spv::MemorySemanticsSequentiallyConsistentMask | spv::MemorySemanticsSubgroupMemoryMask;
} else if ((sync == "shared") || (sync == "warp")) {
sync_scope = spv::ScopeWorkgroup;
memory_semantics =
spv::MemorySemanticsSequentiallyConsistentMask | spv::MemorySemanticsWorkgroupMemoryMask;
} else {
LOG(FATAL) << "Do not support sync " << sync;
}

auto type_int = builder_->GetSType(DataType::Int(32));
builder_->MakeInst(spv::OpControlBarrier, builder_->IntImm(type_int, sync_scope),
builder_->IntImm(type_int, sync_scope),
builder_->IntImm(type_int, memory_semantics));
return value;
}

Expand Down
4 changes: 4 additions & 0 deletions src/target/spirv/spirv_support.cc
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,10 @@ SPIRVSupport::SPIRVSupport(tvm::Target target) {
ICHECK_EQ(target->kind->device_type, kDLVulkan)
<< "SPIRVSupport can only be checked for vulkan device type";

if (target->GetAttr<Integer>("vulkan_api_version")) {
vulkan_api_version = target->GetAttr<Integer>("vulkan_api_version").value();
}

if (target->GetAttr<Integer>("supported_subgroup_operations")) {
supported_subgroup_operations =
target->GetAttr<Integer>("supported_subgroup_operations").value();
Expand Down
14 changes: 14 additions & 0 deletions src/target/spirv/spirv_support.h
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@
#define TVM_TARGET_SPIRV_SPIRV_SUPPORT_H_

#include <tvm/target/target.h>
#include <vulkan/vulkan_core.h>

namespace tvm {
namespace codegen {
Expand All @@ -37,6 +38,19 @@ struct SPIRVSupport {
*/
explicit SPIRVSupport(Target target);

/*! \brief The Vulkan API version supported by the device.
*
* Vulkan struct: VkPhysicalDeviceProperties
* Device property: apiVersion
*
* If VK_KHR_driver_properties is present, will also check the
* driver conformance version. If the version advertised does not
* pass the Vulkan conformance test, vulkan_api_version will be the
* latest Vulkan version that does pass the conformance test
* instead.
*/
uint32_t vulkan_api_version{VK_MAKE_VERSION(1, 0, 0)};

/*!
* \brief The supported subgroup operations
*
Expand Down
2 changes: 1 addition & 1 deletion src/tir/transforms/lower_thread_allreduce.cc
Original file line number Diff line number Diff line change
Expand Up @@ -425,7 +425,7 @@ class ThreadAllreduceBuilder final : public StmtExprMutator {
while (reduce_align > 1) {
reduce_align = reduce_align >> 1;
in_warp_seq.emplace_back(freduce(reduce_align));
seq.emplace_back(SyncThread("warp"));
in_warp_seq.emplace_back(SyncThread("warp"));
}
if (in_warp_seq.size() != 0) {
Stmt warp_body = SeqStmt::Flatten(in_warp_seq);
Expand Down

0 comments on commit 3e28716

Please sign in to comment.