diff --git a/source/opt/convert_to_half_pass.cpp b/source/opt/convert_to_half_pass.cpp index cb0065d2d3..e243bedf0c 100644 --- a/source/opt/convert_to_half_pass.cpp +++ b/source/opt/convert_to_half_pass.cpp @@ -171,6 +171,19 @@ bool ConvertToHalfPass::RemoveRelaxedDecoration(uint32_t id) { bool ConvertToHalfPass::GenHalfArith(Instruction* inst) { bool modified = false; + // If this is a OpCompositeExtract instruction and has a struct operand, we + // should not relax this instruction. Doing so could cause a mismatch between + // the result type and the struct member type. + bool hasStructOperand = false; + if (inst->opcode() == spv::Op::OpCompositeExtract) { + inst->ForEachInId([&hasStructOperand, this](uint32_t* idp) { + Instruction* op_inst = get_def_use_mgr()->GetDef(*idp); + if (IsStruct(op_inst)) hasStructOperand = true; + }); + if (hasStructOperand) { + return false; + } + } // Convert all float32 based operands to float16 equivalent and change // instruction type to float16 equivalent. inst->ForEachInId([&inst, &modified, this](uint32_t* idp) { @@ -303,12 +316,19 @@ bool ConvertToHalfPass::CloseRelaxInst(Instruction* inst) { if (closure_ops_.count(inst->opcode()) == 0) return false; // Can relax if all float operands are relaxed bool relax = true; - inst->ForEachInId([&relax, this](uint32_t* idp) { + bool hasStructOperand = false; + inst->ForEachInId([&relax, &hasStructOperand, this](uint32_t* idp) { Instruction* op_inst = get_def_use_mgr()->GetDef(*idp); - if (IsStruct(op_inst)) relax = false; + if (IsStruct(op_inst)) hasStructOperand = true; if (!IsFloat(op_inst, 32)) return; if (!IsRelaxed(*idp)) relax = false; }); + // If the instruction has a struct operand, we should not relax it, even if + // all its uses are relaxed. Doing so could cause a mismatch between the + // result type and the struct member type. + if (hasStructOperand) { + return false; + } if (relax) { AddRelaxed(inst->result_id()); return true; diff --git a/test/opt/convert_relaxed_to_half_test.cpp b/test/opt/convert_relaxed_to_half_test.cpp index 62b9ae4535..c5774045c3 100644 --- a/test/opt/convert_relaxed_to_half_test.cpp +++ b/test/opt/convert_relaxed_to_half_test.cpp @@ -1713,6 +1713,75 @@ TEST_F(ConvertToHalfTest, PreserveImageOperandPrecision) { SinglePassRunAndMatch(test, true); } +TEST_F(ConvertToHalfTest, DontRelaxDecoratedOpCompositeExtract) { + // This test checks that a OpCompositeExtract with a Struct operand won't be + // relaxed, even if it is explicitly decorated with RelaxedPrecision. + const std::string test = + R"(OpCapability Shader +OpMemoryModel Logical GLSL450 +OpEntryPoint Fragment %1 "main" +OpExecutionMode %1 OriginUpperLeft +OpDecorate %9 RelaxedPrecision +%void = OpTypeVoid +%3 = OpTypeFunction %void +%float = OpTypeFloat 32 +%v4float = OpTypeVector %float 4 +%_struct_6 = OpTypeStruct %v4float +%7 = OpUndef %_struct_6 +%1 = OpFunction %void None %3 +%8 = OpLabel +%9 = OpCompositeExtract %float %7 0 3 +OpReturn +OpFunctionEnd +)"; + + const std::string expected = + R"(OpCapability Shader +OpMemoryModel Logical GLSL450 +OpEntryPoint Fragment %1 "main" +OpExecutionMode %1 OriginUpperLeft +%void = OpTypeVoid +%3 = OpTypeFunction %void +%float = OpTypeFloat 32 +%v4float = OpTypeVector %float 4 +%_struct_6 = OpTypeStruct %v4float +%7 = OpUndef %_struct_6 +%1 = OpFunction %void None %3 +%8 = OpLabel +%9 = OpCompositeExtract %float %7 0 3 +OpReturn +OpFunctionEnd +)"; + + SetAssembleOptions(SPV_TEXT_TO_BINARY_OPTION_PRESERVE_NUMERIC_IDS); + SinglePassRunAndCheck(test, expected, true); +} + +TEST_F(ConvertToHalfTest, DontRelaxOpCompositeExtract) { + // This test checks that a OpCompositeExtract with a Struct operand won't be + // relaxed, even if its result has no uses. + const std::string test = + R"(OpCapability Shader +OpMemoryModel Logical GLSL450 +OpEntryPoint Fragment %1 "main" +OpExecutionMode %1 OriginUpperLeft +%void = OpTypeVoid +%3 = OpTypeFunction %void +%float = OpTypeFloat 32 +%v4float = OpTypeVector %float 4 +%_struct_6 = OpTypeStruct %v4float +%7 = OpUndef %_struct_6 +%1 = OpFunction %void None %3 +%8 = OpLabel +%9 = OpCompositeExtract %float %7 0 3 +OpReturn +OpFunctionEnd +)"; + + SetAssembleOptions(SPV_TEXT_TO_BINARY_OPTION_PRESERVE_NUMERIC_IDS); + SinglePassRunAndCheck(test, test, true); +} + } // namespace } // namespace opt } // namespace spvtools