Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

SPV_NV_shader_atomic_fp16_vector #5581

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion DEPS
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ vars = {
'protobuf_revision': 'v21.12',

're2_revision': 'b4c6fe091b74b65f706ff9c9ff369b396c2a3177',
'spirv_headers_revision': 'd3c2a6fa95ad463ca8044d7fc45557db381a6a64',
'spirv_headers_revision': '05cc486580771e4fa7ddc89f5c9ee1e97382689a',
}

deps = {
Expand Down
50 changes: 36 additions & 14 deletions source/val/validate_atomics.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -144,12 +144,13 @@ spv_result_t AtomicsPass(ValidationState_t& _, const Instruction* inst) {
case spv::Op::OpAtomicFlagClear: {
const uint32_t result_type = inst->type_id();

// All current atomics only are scalar result
// Validate return type first so can just check if pointer type is same
// (if applicable)
if (HasReturnType(opcode)) {
if (HasOnlyFloatReturnType(opcode) &&
!_.IsFloatScalarType(result_type)) {
(!(_.HasCapability(spv::Capability::AtomicFloat16VectorNV) &&
_.IsFloat16Vector2Or4Type(result_type)) &&
!_.IsFloatScalarType(result_type))) {
return _.diag(SPV_ERROR_INVALID_DATA, inst)
<< spvOpcodeString(opcode)
<< ": expected Result Type to be float scalar type";
Expand All @@ -160,6 +161,9 @@ spv_result_t AtomicsPass(ValidationState_t& _, const Instruction* inst) {
<< ": expected Result Type to be integer scalar type";
} else if (HasIntOrFloatReturnType(opcode) &&
!_.IsFloatScalarType(result_type) &&
!(opcode == spv::Op::OpAtomicExchange &&
_.HasCapability(spv::Capability::AtomicFloat16VectorNV) &&
_.IsFloat16Vector2Or4Type(result_type)) &&
!_.IsIntScalarType(result_type)) {
return _.diag(SPV_ERROR_INVALID_DATA, inst)
<< spvOpcodeString(opcode)
Expand Down Expand Up @@ -222,12 +226,21 @@ spv_result_t AtomicsPass(ValidationState_t& _, const Instruction* inst) {

if (opcode == spv::Op::OpAtomicFAddEXT) {
// result type being float checked already
if ((_.GetBitWidth(result_type) == 16) &&
(!_.HasCapability(spv::Capability::AtomicFloat16AddEXT))) {
return _.diag(SPV_ERROR_INVALID_DATA, inst)
<< spvOpcodeString(opcode)
<< ": float add atomics require the AtomicFloat32AddEXT "
"capability";
if (_.GetBitWidth(result_type) == 16) {
if (_.IsFloat16Vector2Or4Type(result_type)) {
if (!_.HasCapability(spv::Capability::AtomicFloat16VectorNV))
return _.diag(SPV_ERROR_INVALID_DATA, inst)
<< spvOpcodeString(opcode)
<< ": float vector atomics require the "
"AtomicFloat16VectorNV capability";
} else {
if (!_.HasCapability(spv::Capability::AtomicFloat16AddEXT)) {
return _.diag(SPV_ERROR_INVALID_DATA, inst)
<< spvOpcodeString(opcode)
<< ": float add atomics require the AtomicFloat32AddEXT "
"capability";
}
}
}
if ((_.GetBitWidth(result_type) == 32) &&
(!_.HasCapability(spv::Capability::AtomicFloat32AddEXT))) {
Expand All @@ -245,12 +258,21 @@ spv_result_t AtomicsPass(ValidationState_t& _, const Instruction* inst) {
}
} else if (opcode == spv::Op::OpAtomicFMinEXT ||
opcode == spv::Op::OpAtomicFMaxEXT) {
if ((_.GetBitWidth(result_type) == 16) &&
(!_.HasCapability(spv::Capability::AtomicFloat16MinMaxEXT))) {
return _.diag(SPV_ERROR_INVALID_DATA, inst)
<< spvOpcodeString(opcode)
<< ": float min/max atomics require the "
"AtomicFloat16MinMaxEXT capability";
if (_.GetBitWidth(result_type) == 16) {
if (_.IsFloat16Vector2Or4Type(result_type)) {
if (!_.HasCapability(spv::Capability::AtomicFloat16VectorNV))
return _.diag(SPV_ERROR_INVALID_DATA, inst)
<< spvOpcodeString(opcode)
<< ": float vector atomics require the "
"AtomicFloat16VectorNV capability";
} else {
if (!_.HasCapability(spv::Capability::AtomicFloat16MinMaxEXT)) {
return _.diag(SPV_ERROR_INVALID_DATA, inst)
<< spvOpcodeString(opcode)
<< ": float min/max atomics require the "
"AtomicFloat16MinMaxEXT capability";
}
}
}
if ((_.GetBitWidth(result_type) == 32) &&
(!_.HasCapability(spv::Capability::AtomicFloat32MinMaxEXT))) {
Expand Down
19 changes: 16 additions & 3 deletions source/val/validate_image.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1118,7 +1118,10 @@ spv_result_t ValidateImageTexelPointer(ValidationState_t& _,
const auto ptr_type = result_type->GetOperandAs<uint32_t>(2);
const auto ptr_opcode = _.GetIdOpcode(ptr_type);
if (ptr_opcode != spv::Op::OpTypeInt && ptr_opcode != spv::Op::OpTypeFloat &&
ptr_opcode != spv::Op::OpTypeVoid) {
ptr_opcode != spv::Op::OpTypeVoid &&
!(ptr_opcode == spv::Op::OpTypeVector &&
_.HasCapability(spv::Capability::AtomicFloat16VectorNV) &&
_.IsFloat16Vector2Or4Type(ptr_type))) {
return _.diag(SPV_ERROR_INVALID_DATA, inst)
<< "Expected Result Type to be OpTypePointer whose Type operand "
"must be a scalar numerical type or OpTypeVoid";
Expand All @@ -1142,7 +1145,14 @@ spv_result_t ValidateImageTexelPointer(ValidationState_t& _,
<< "Corrupt image type definition";
}

if (info.sampled_type != ptr_type) {
if (info.sampled_type != ptr_type &&
!(_.HasCapability(spv::Capability::AtomicFloat16VectorNV) &&
_.IsFloat16Vector2Or4Type(ptr_type) &&
_.GetIdOpcode(info.sampled_type) == spv::Op::OpTypeFloat &&
((_.GetDimension(ptr_type) == 2 &&
info.format == spv::ImageFormat::Rg16f) ||
(_.GetDimension(ptr_type) == 4 &&
info.format == spv::ImageFormat::Rgba16f)))) {
return _.diag(SPV_ERROR_INVALID_DATA, inst)
<< "Expected Image 'Sampled Type' to be the same as the Type "
"pointed to by Result Type";
Expand Down Expand Up @@ -1213,7 +1223,10 @@ spv_result_t ValidateImageTexelPointer(ValidationState_t& _,
(info.format != spv::ImageFormat::R64ui) &&
(info.format != spv::ImageFormat::R32f) &&
(info.format != spv::ImageFormat::R32i) &&
(info.format != spv::ImageFormat::R32ui)) {
(info.format != spv::ImageFormat::R32ui) &&
!((info.format == spv::ImageFormat::Rg16f ||
info.format == spv::ImageFormat::Rgba16f) &&
_.HasCapability(spv::Capability::AtomicFloat16VectorNV))) {
return _.diag(SPV_ERROR_INVALID_DATA, inst)
<< _.VkErrorID(4658)
<< "Expected the Image Format in Image to be R64i, R64ui, R32f, "
Expand Down
14 changes: 14 additions & 0 deletions source/val/validation_state.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -954,6 +954,20 @@ bool ValidationState_t::IsFloatVectorType(uint32_t id) const {
return false;
}

bool ValidationState_t::IsFloat16Vector2Or4Type(uint32_t id) const {
const Instruction* inst = FindDef(id);
assert(inst);

if (inst->opcode() == spv::Op::OpTypeVector) {
uint32_t vectorDim = GetDimension(id);
return IsFloatScalarType(GetComponentType(id)) &&
(vectorDim == 2 || vectorDim == 4) &&
(GetBitWidth(GetComponentType(id)) == 16);
}

return false;
}

bool ValidationState_t::IsFloatScalarOrVectorType(uint32_t id) const {
const Instruction* inst = FindDef(id);
if (!inst) {
Expand Down
1 change: 1 addition & 0 deletions source/val/validation_state.h
Original file line number Diff line number Diff line change
Expand Up @@ -602,6 +602,7 @@ class ValidationState_t {
bool IsVoidType(uint32_t id) const;
bool IsFloatScalarType(uint32_t id) const;
bool IsFloatVectorType(uint32_t id) const;
// Returns true if |id| is an OpTypeVector with 2 or 4 components whose
// component type is a 16-bit float scalar.
bool IsFloat16Vector2Or4Type(uint32_t id) const;
bool IsFloatScalarOrVectorType(uint32_t id) const;
bool IsFloatMatrixType(uint32_t id) const;
bool IsIntScalarType(uint32_t id) const;
Expand Down
142 changes: 138 additions & 4 deletions test/val/val_atomics_test.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -318,7 +318,8 @@ TEST_F(ValidateAtomics, AtomicAddFloatVulkan) {
EXPECT_THAT(
getDiagnosticString(),
HasSubstr("Opcode AtomicFAddEXT requires one of these capabilities: "
"AtomicFloat32AddEXT AtomicFloat64AddEXT AtomicFloat16AddEXT"));
"AtomicFloat16VectorNV AtomicFloat32AddEXT AtomicFloat64AddEXT "
"AtomicFloat16AddEXT"));
}

TEST_F(ValidateAtomics, AtomicMinFloatVulkan) {
Expand All @@ -331,7 +332,8 @@ TEST_F(ValidateAtomics, AtomicMinFloatVulkan) {
EXPECT_THAT(
getDiagnosticString(),
HasSubstr("Opcode AtomicFMinEXT requires one of these capabilities: "
"AtomicFloat32MinMaxEXT AtomicFloat64MinMaxEXT AtomicFloat16MinMaxEXT"));
"AtomicFloat16VectorNV AtomicFloat32MinMaxEXT "
"AtomicFloat64MinMaxEXT AtomicFloat16MinMaxEXT"));
}

TEST_F(ValidateAtomics, AtomicMaxFloatVulkan) {
Expand All @@ -343,8 +345,10 @@ TEST_F(ValidateAtomics, AtomicMaxFloatVulkan) {
ASSERT_EQ(SPV_ERROR_INVALID_CAPABILITY, ValidateInstructions());
EXPECT_THAT(
getDiagnosticString(),
HasSubstr("Opcode AtomicFMaxEXT requires one of these capabilities: "
"AtomicFloat32MinMaxEXT AtomicFloat64MinMaxEXT AtomicFloat16MinMaxEXT"));
HasSubstr(
"Opcode AtomicFMaxEXT requires one of these capabilities: "
"AtomicFloat16VectorNV AtomicFloat32MinMaxEXT AtomicFloat64MinMaxEXT "
"AtomicFloat16MinMaxEXT"));
}

TEST_F(ValidateAtomics, AtomicAddFloatVulkanWrongType1) {
Expand Down Expand Up @@ -2713,6 +2717,136 @@ TEST_F(ValidateAtomics, IIncrementBadPointerDataType) {
"value of type Result Type"));
}

// Positive test: FMin/FMax/FAdd/Exchange atomics on 2- and 4-component fp16
// vectors in Workgroup storage must validate for Vulkan 1.0 when the
// AtomicFloat16VectorNV capability and its extension are declared.
// (%device and %relaxed are not defined here; presumably they come from the
// GenerateShaderComputeCode preamble.)
TEST_F(ValidateAtomics, AtomicFloat16VectorSuccess) {
  // Capabilities/extension shared by every instruction in this module.
  constexpr char kCapsAndExtensions[] =
      "OpCapability Float16\n"
      "OpCapability AtomicFloat16VectorNV\n"
      "OpExtension \"SPV_NV_shader_atomic_fp16_vector\"\n";

  // Types, constants and Workgroup variables for the vec2 and vec4 cases.
  const std::string type_decls = R"(
%f16 = OpTypeFloat 16
%f16vec2 = OpTypeVector %f16 2
%f16vec4 = OpTypeVector %f16 4

%f16_1 = OpConstant %f16 1
%f16vec2_1 = OpConstantComposite %f16vec2 %f16_1 %f16_1
%f16vec4_1 = OpConstantComposite %f16vec4 %f16_1 %f16_1 %f16_1 %f16_1

%f16vec2_ptr = OpTypePointer Workgroup %f16vec2
%f16vec4_ptr = OpTypePointer Workgroup %f16vec4
%f16vec2_var = OpVariable %f16vec2_ptr Workgroup
%f16vec4_var = OpVariable %f16vec4_ptr Workgroup
)";

  // One instruction per supported opcode, first on vec2 and then on vec4.
  const std::string atomic_ops = R"(
%val3 = OpAtomicFMinEXT %f16vec2 %f16vec2_var %device %relaxed %f16vec2_1
%val4 = OpAtomicFMaxEXT %f16vec2 %f16vec2_var %device %relaxed %f16vec2_1
%val8 = OpAtomicFAddEXT %f16vec2 %f16vec2_var %device %relaxed %f16vec2_1
%val9 = OpAtomicExchange %f16vec2 %f16vec2_var %device %relaxed %f16vec2_1

%val11 = OpAtomicFMinEXT %f16vec4 %f16vec4_var %device %relaxed %f16vec4_1
%val12 = OpAtomicFMaxEXT %f16vec4 %f16vec4_var %device %relaxed %f16vec4_1
%val18 = OpAtomicFAddEXT %f16vec4 %f16vec4_var %device %relaxed %f16vec4_1
%val19 = OpAtomicExchange %f16vec4 %f16vec4_var %device %relaxed %f16vec4_1

)";

  const auto shader =
      GenerateShaderComputeCode(atomic_ops, kCapsAndExtensions, type_decls);
  CompileSuccessfully(shader, SPV_ENV_VULKAN_1_0);

  ASSERT_EQ(SPV_SUCCESS, ValidateInstructions(SPV_ENV_VULKAN_1_0));
}

// Shared SPIR-V definitions for the 3-component negative tests that follow:
// a 16-bit float type, a 3-component vector of it, a constant vector whose
// components are all 1, and a Workgroup-storage variable of that vector type.
static constexpr char Float16Vector3Defs[] = R"(
%f16 = OpTypeFloat 16
%f16vec3 = OpTypeVector %f16 3

%f16_1 = OpConstant %f16 1
%f16vec3_1 = OpConstantComposite %f16vec3 %f16_1 %f16_1 %f16_1

%f16vec3_ptr = OpTypePointer Workgroup %f16vec3
%f16vec3_var = OpVariable %f16vec3_ptr Workgroup
)";

// Negative test: OpAtomicFMinEXT with a 3-component fp16 vector result type
// must be rejected with SPV_ERROR_INVALID_DATA even though
// AtomicFloat16VectorNV is declared.
// NOTE(review): the module is compiled for SPV_ENV_VULKAN_1_0 but
// ValidateInstructions() is called without an explicit environment - confirm
// that this is intended.
TEST_F(ValidateAtomics, AtomicFloat16Vector3MinFail) {
  constexpr char kCapsAndExtensions[] =
      "OpCapability Float16\n"
      "OpCapability AtomicFloat16VectorNV\n"
      "OpExtension \"SPV_NV_shader_atomic_fp16_vector\"\n";

  const std::string vec3_decls{Float16Vector3Defs};
  const std::string atomic_op = R"(
%val11 = OpAtomicFMinEXT %f16vec3 %f16vec3_var %device %relaxed %f16vec3_1
)";

  const auto shader =
      GenerateShaderComputeCode(atomic_op, kCapsAndExtensions, vec3_decls);
  CompileSuccessfully(shader, SPV_ENV_VULKAN_1_0);

  EXPECT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions());
  const auto diagnostic = getDiagnosticString();
  EXPECT_THAT(
      diagnostic,
      HasSubstr("AtomicFMinEXT: expected Result Type to be float scalar type"));
}

// Negative test: OpAtomicFMaxEXT with a 3-component fp16 vector result type
// must be rejected with SPV_ERROR_INVALID_DATA even though
// AtomicFloat16VectorNV is declared.
// NOTE(review): the module is compiled for SPV_ENV_VULKAN_1_0 but
// ValidateInstructions() is called without an explicit environment - confirm
// that this is intended.
TEST_F(ValidateAtomics, AtomicFloat16Vector3MaxFail) {
  constexpr char kCapsAndExtensions[] =
      "OpCapability Float16\n"
      "OpCapability AtomicFloat16VectorNV\n"
      "OpExtension \"SPV_NV_shader_atomic_fp16_vector\"\n";

  const std::string vec3_decls{Float16Vector3Defs};
  const std::string atomic_op = R"(
%val12 = OpAtomicFMaxEXT %f16vec3 %f16vec3_var %device %relaxed %f16vec3_1
)";

  const auto shader =
      GenerateShaderComputeCode(atomic_op, kCapsAndExtensions, vec3_decls);
  CompileSuccessfully(shader, SPV_ENV_VULKAN_1_0);

  EXPECT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions());
  const auto diagnostic = getDiagnosticString();
  EXPECT_THAT(
      diagnostic,
      HasSubstr("AtomicFMaxEXT: expected Result Type to be float scalar type"));
}

// Negative test: OpAtomicFAddEXT with a 3-component fp16 vector result type
// must be rejected with SPV_ERROR_INVALID_DATA even though
// AtomicFloat16VectorNV is declared.
// NOTE(review): the module is compiled for SPV_ENV_VULKAN_1_0 but
// ValidateInstructions() is called without an explicit environment - confirm
// that this is intended.
TEST_F(ValidateAtomics, AtomicFloat16Vector3AddFail) {
  constexpr char kCapsAndExtensions[] =
      "OpCapability Float16\n"
      "OpCapability AtomicFloat16VectorNV\n"
      "OpExtension \"SPV_NV_shader_atomic_fp16_vector\"\n";

  const std::string vec3_decls{Float16Vector3Defs};
  const std::string atomic_op = R"(
%val18 = OpAtomicFAddEXT %f16vec3 %f16vec3_var %device %relaxed %f16vec3_1
)";

  const auto shader =
      GenerateShaderComputeCode(atomic_op, kCapsAndExtensions, vec3_decls);
  CompileSuccessfully(shader, SPV_ENV_VULKAN_1_0);

  EXPECT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions());
  const auto diagnostic = getDiagnosticString();
  EXPECT_THAT(
      diagnostic,
      HasSubstr("AtomicFAddEXT: expected Result Type to be float scalar type"));
}

// Negative test: OpAtomicExchange with a 3-component fp16 vector result type
// must be rejected with SPV_ERROR_INVALID_DATA even though
// AtomicFloat16VectorNV is declared.
// NOTE(review): the module is compiled for SPV_ENV_VULKAN_1_0 but
// ValidateInstructions() is called without an explicit environment - confirm
// that this is intended.
TEST_F(ValidateAtomics, AtomicFloat16Vector3ExchangeFail) {
  constexpr char kCapsAndExtensions[] =
      "OpCapability Float16\n"
      "OpCapability AtomicFloat16VectorNV\n"
      "OpExtension \"SPV_NV_shader_atomic_fp16_vector\"\n";

  const std::string vec3_decls{Float16Vector3Defs};
  const std::string atomic_op = R"(
%val19 = OpAtomicExchange %f16vec3 %f16vec3_var %device %relaxed %f16vec3_1
)";

  const auto shader =
      GenerateShaderComputeCode(atomic_op, kCapsAndExtensions, vec3_decls);
  CompileSuccessfully(shader, SPV_ENV_VULKAN_1_0);

  EXPECT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions());
  const auto diagnostic = getDiagnosticString();
  EXPECT_THAT(diagnostic,
              HasSubstr("AtomicExchange: expected Result Type to be integer or "
                        "float scalar type"));
}

} // namespace
} // namespace val
} // namespace spvtools