From b5f23bfc83b2eadc3312518f1c00b4c020e1b9a3 Mon Sep 17 00:00:00 2001 From: Steven Perron Date: Thu, 1 Feb 2024 09:24:40 -0500 Subject: [PATCH 1/6] Fold 64-bit int operations. Adds folding rules that will fold basic arithmetic for signed and unsigned integers of all sizes, including 64-bit. Also folds OpSConvert and OpUConvert. --- source/opt/const_folding_rules.cpp | 215 +++++++++++++++++++- test/opt/fold_test.cpp | 313 ++++++++++++++++++++++++++++- 2 files changed, 524 insertions(+), 4 deletions(-) diff --git a/source/opt/const_folding_rules.cpp b/source/opt/const_folding_rules.cpp index e676974c8c..db4212fb17 100644 --- a/source/opt/const_folding_rules.cpp +++ b/source/opt/const_folding_rules.cpp @@ -21,6 +21,59 @@ namespace opt { namespace { constexpr uint32_t kExtractCompositeIdInIdx = 0; +// Returns the value obtained by setting clearing the `number_of_bits` most +// significant bits of `value`. +uint64_t SignExtendValue(uint64_t value, uint32_t number_of_bits) { + if (number_of_bits == 64) return value; + + uint64_t mask_for_sign_bit = 1ull << (number_of_bits - 1); + uint64_t mask_for_significant_bits = (mask_for_sign_bit << 1) - 1ull; + if (value & mask_for_sign_bit) { + // Set upper bits to 1 + value |= ~mask_for_significant_bits; + } else { + // Clear the upper bits + value &= mask_for_significant_bits; + } + return value; +} + +// Returns the value obtained from clearing the `number_of_bits` most +// significant bits of `value`. +uint64_t ClearUpperBits(uint64_t value, uint32_t number_of_bits) { + if (number_of_bits == 0) return value; + + uint64_t mask_for_first_bit_to_clear = 1ull << (64 - number_of_bits); + uint64_t mask_for_bits_to_keep = mask_for_first_bit_to_clear - 1; + value &= mask_for_bits_to_keep; + return value; +} + +// Returns a constant whose value is `value` and type is `type`. This constant +// will be generated by `const_mgr`. The type must be a scalar integer type. +const analysis::Constant* GenerateIntegerConstant( + const analysis::Integer* integer_type, uint64_t result, + analysis::ConstantManager* const_mgr) { + assert(integer_type != nullptr); + + std::vector words; + if (integer_type->width() == 64) { + // In the 64-bit case, two words are needed to represent the value. + words = {static_cast(result), + static_cast(result >> 32)}; + } else { + // In all other cases, only a single word is needed. + assert(integer_type->width() <= 32); + if (integer_type->IsSigned()) { + result = SignExtendValue(result, integer_type->width()); + } else { + result = ClearUpperBits(result, 64 - integer_type->width()); + } + words = {static_cast(result)}; + } + return const_mgr->GetConstant(integer_type, words); +} + // Returns a constants with the value NaN of the given type. Only works for // 32-bit and 64-bit float point types. Returns |nullptr| if an error occurs.
const analysis::Constant* GetNan(const analysis::Type* type, @@ -676,7 +729,6 @@ ConstantFoldingRule FoldUnaryOp(UnaryScalarFoldingRule scalar_rule) { return [scalar_rule](IRContext* context, Instruction* inst, const std::vector& constants) -> const analysis::Constant* { - analysis::ConstantManager* const_mgr = context->get_constant_mgr(); analysis::TypeManager* type_mgr = context->get_type_mgr(); const analysis::Type* result_type = type_mgr->GetType(inst->type_id()); @@ -716,6 +768,63 @@ ConstantFoldingRule FoldUnaryOp(UnaryScalarFoldingRule scalar_rule) { }; } +// Returns a |ConstantFoldingRule| that folds binary scalar ops +// using |scalar_rule| and unary vectors ops by applying +// |scalar_rule| to the elements of the vector. The |ConstantFoldingRule| +// that is returned assumes that |constants| contains 2 entries. If they are +// not |nullptr|, then their type is either |Float| or |Integer| or a |Vector| +// whose element type is |Float| or |Integer|. +ConstantFoldingRule FoldBinaryOp(BinaryScalarFoldingRule scalar_rule) { + return [scalar_rule](IRContext* context, Instruction* inst, + const std::vector& constants) + -> const analysis::Constant* { + analysis::ConstantManager* const_mgr = context->get_constant_mgr(); + analysis::TypeManager* type_mgr = context->get_type_mgr(); + const analysis::Type* result_type = type_mgr->GetType(inst->type_id()); + const analysis::Vector* vector_type = result_type->AsVector(); + + const analysis::Constant* arg1 = + (inst->opcode() == spv::Op::OpExtInst) ? constants[1] : constants[0]; + const analysis::Constant* arg2 = + (inst->opcode() == spv::Op::OpExtInst) ? constants[2] : constants[1]; + + if (arg1 == nullptr) { + return nullptr; + } + if (arg2 == nullptr) { + return nullptr; + } + + if (vector_type != nullptr) { + std::vector a_components; + std::vector b_components; + std::vector results_components; + + a_components = arg1->GetVectorComponents(const_mgr); + b_components = arg2->GetVectorComponents(const_mgr); + + // Fold each component of the vector. + for (uint32_t i = 0; i < a_components.size(); ++i) { + results_components.push_back(scalar_rule(vector_type->element_type(), + a_components[i], + b_components[i], const_mgr)); + if (results_components[i] == nullptr) { + return nullptr; + } + } + + // Build the constant object and return it. + std::vector ids; + for (const analysis::Constant* member : results_components) { + ids.push_back(const_mgr->GetDefiningInstruction(member)->result_id()); + } + return const_mgr->GetConstant(vector_type, ids); + } else { + return scalar_rule(result_type, arg1, arg2, const_mgr); + } + }; +} + // Returns a |ConstantFoldingRule| that folds unary floating point scalar ops // using |scalar_rule| and unary float point vectors ops by applying // |scalar_rule| to the elements of the vector. The |ConstantFoldingRule| @@ -1587,6 +1696,68 @@ BinaryScalarFoldingRule FoldFTranscendentalBinary(double (*fp)(double, return nullptr; }; } + +enum Sign { Signed, Unsigned }; + +// Returns a BinaryScalarFoldingRule that applies `op` to the scalars. +// The `signedness` is used to determine if the operands should be interpreted +// as signed or unsigned. If the operands are signed, the will be sign extended +// before the value is passed to `op`. Otherwise the values will be zero +// extended. 
+template +BinaryScalarFoldingRule FoldBinaryIntegerOperation(uint64_t (*op)(uint64_t, + uint64_t)) { + return + [op](const analysis::Type* result_type, const analysis::Constant* a, + const analysis::Constant* b, + analysis::ConstantManager* const_mgr) -> const analysis::Constant* { + assert(result_type != nullptr && a != nullptr); + const analysis::Integer* integer_type = a->type()->AsInteger(); + assert(integer_type != nullptr); + assert(integer_type == result_type->AsInteger()); + assert(integer_type == b->type()->AsInteger()); + + // In SPIR-V, the signedness of the operands is determined by the + // opcode, and not by the type of the operands. This is why we use the + // template argument to determine how to interpret the operands. + uint64_t ia = (signedness == Signed ? a->GetSignExtendedValue() + : a->GetZeroExtendedValue()); + uint64_t ib = (signedness == Signed ? b->GetSignExtendedValue() + : b->GetZeroExtendedValue()); + uint64_t result = op(ia, ib); + + const analysis::Constant* result_constant = + GenerateIntegerConstant(integer_type, result, const_mgr); + return result_constant; + }; +} + +// A scalar folding rule that foles OpSConvert. +const analysis::Constant* FoldScalarSConvert( + const analysis::Type* result_type, const analysis::Constant* a, + analysis::ConstantManager* const_mgr) { + assert(a); + const analysis::Integer* integer_type = result_type->AsInteger(); + assert(integer_type && "The result type of an SConvert"); + int64_t value = a->GetSignExtendedValue(); + return GenerateIntegerConstant(integer_type, value, const_mgr); +} + +// A scalar folding rule that foles OpSConvert. +const analysis::Constant* FoldScalarUConvert( + const analysis::Type* result_type, const analysis::Constant* a, + analysis::ConstantManager* const_mgr) { + assert(a); + const analysis::Integer* integer_type = result_type->AsInteger(); + assert(integer_type && "The result type of an SConvert"); + uint64_t value = a->GetZeroExtendedValue(); + + // If the operand was an unsigned value with less than 32-bit, it would have + // been sign extended earlier, and we need to clear those bits. + auto* operand_type = a->type()->AsInteger(); + value = ClearUpperBits(value, 64 - operand_type->width()); + return GenerateIntegerConstant(integer_type, value, const_mgr); +} } // namespace void ConstantFoldingRules::AddFoldingRules() { @@ -1604,6 +1775,8 @@ void ConstantFoldingRules::AddFoldingRules() { rules_[spv::Op::OpConvertFToU].push_back(FoldFToI()); rules_[spv::Op::OpConvertSToF].push_back(FoldIToF()); rules_[spv::Op::OpConvertUToF].push_back(FoldIToF()); + rules_[spv::Op::OpSConvert].push_back(FoldUnaryOp(FoldScalarSConvert)); + rules_[spv::Op::OpUConvert].push_back(FoldUnaryOp(FoldScalarUConvert)); rules_[spv::Op::OpDot].push_back(FoldOpDotWithConstants()); rules_[spv::Op::OpFAdd].push_back(FoldFAdd()); @@ -1662,6 +1835,46 @@ void ConstantFoldingRules::AddFoldingRules() { rules_[spv::Op::OpSNegate].push_back(FoldSNegate()); rules_[spv::Op::OpQuantizeToF16].push_back(FoldQuantizeToF16()); + rules_[spv::Op::OpIAdd].push_back( + FoldBinaryOp(FoldBinaryIntegerOperation( + [](uint64_t a, uint64_t b) { return a + b; }))); + rules_[spv::Op::OpISub].push_back( + FoldBinaryOp(FoldBinaryIntegerOperation( + [](uint64_t a, uint64_t b) { return a - b; }))); + rules_[spv::Op::OpIMul].push_back( + FoldBinaryOp(FoldBinaryIntegerOperation( + [](uint64_t a, uint64_t b) { return a * b; }))); + rules_[spv::Op::OpUDiv].push_back( + FoldBinaryOp(FoldBinaryIntegerOperation( + [](uint64_t a, uint64_t b) { return (b != 0 ? 
a / b : 0); }))); + rules_[spv::Op::OpSDiv].push_back(FoldBinaryOp( + FoldBinaryIntegerOperation([](uint64_t a, uint64_t b) { + return (b != 0 ? static_cast(static_cast(a) / + static_cast(b)) + : 0); + }))); + rules_[spv::Op::OpUMod].push_back( + FoldBinaryOp(FoldBinaryIntegerOperation( + [](uint64_t a, uint64_t b) { return (b != 0 ? a % b : 0); }))); + + rules_[spv::Op::OpSRem].push_back(FoldBinaryOp( + FoldBinaryIntegerOperation([](uint64_t a, uint64_t b) { + return (b != 0 ? static_cast(static_cast(a) % + static_cast(b)) + : 0); + }))); + + rules_[spv::Op::OpSMod].push_back(FoldBinaryOp( + FoldBinaryIntegerOperation([](uint64_t a, uint64_t b) { + if (b == 0) return static_cast(0ull); + + int64_t signed_a = static_cast(a); + int64_t signed_b = static_cast(b); + int64_t result = signed_a % signed_b; + if ((signed_b < 0) != (result < 0)) result += signed_b; + return static_cast(result); + }))); + // Add rules for GLSLstd450 FeatureManager* feature_manager = context_->get_feature_mgr(); uint32_t ext_inst_glslstd450_id = diff --git a/test/opt/fold_test.cpp b/test/opt/fold_test.cpp index d1a81dff32..794d8c9797 100644 --- a/test/opt/fold_test.cpp +++ b/test/opt/fold_test.cpp @@ -107,6 +107,10 @@ std::tuple, Instruction*> FoldInstruction( std::tie(context, inst) = GetInstructionToFold(test_body, id_to_fold, spv_env); + if (context == nullptr) { + return {nullptr, nullptr}; + } + std::unique_ptr original_inst(inst->Clone(context.get())); bool succeeded = context->get_instruction_folder().FoldInstruction(inst); EXPECT_EQ(inst->result_id(), original_inst->result_id()); @@ -237,9 +241,13 @@ OpName %main "main" %ulong = OpTypeInt 64 0 %v2int = OpTypeVector %int 2 %v4int = OpTypeVector %int 4 +%v2short = OpTypeVector %short 2 +%v2long = OpTypeVector %long 2 +%v4long = OpTypeVector %long 4 %v4float = OpTypeVector %float 4 %v4double = OpTypeVector %double 4 %v2uint = OpTypeVector %uint 2 +%v2ulong = OpTypeVector %ulong 2 %v2float = OpTypeVector %float 2 %v2double = OpTypeVector %double 2 %v2half = OpTypeVector %half 2 @@ -270,6 +278,7 @@ OpName %main "main" %short_0 = OpConstant %short 0 %short_2 = OpConstant %short 2 %short_3 = OpConstant %short 3 +%short_n5 = OpConstant %short -5 %ubyte_1 = OpConstant %ubyte 1 %byte_n1 = OpConstant %byte -1 %100 = OpConstant %int 0 ; Need a def with an numerical id to define id maps. 
@@ -289,7 +298,13 @@ OpName %main "main" %long_1 = OpConstant %long 1 %long_2 = OpConstant %long 2 %long_3 = OpConstant %long 3 +%long_n3 = OpConstant %long -3 +%long_7 = OpConstant %long 7 +%long_n7 = OpConstant %long -7 %long_10 = OpConstant %long 10 +%long_32768 = OpConstant %long 32768 +%long_n57344 = OpConstant %long -57344 +%long_n4611686018427387904 = OpConstant %long -4611686018427387904 %long_4611686018427387904 = OpConstant %long 4611686018427387904 %long_n1 = OpConstant %long -1 %long_n3689348814741910323 = OpConstant %long -3689348814741910323 @@ -318,6 +333,9 @@ OpName %main "main" %v2int_n1_n24 = OpConstantComposite %v2int %int_n1 %int_n24 %v2int_4_4 = OpConstantComposite %v2int %int_4 %int_4 %v2int_min_max = OpConstantComposite %v2int %int_min %int_max +%v2short_2_n5 = OpConstantComposite %v2short %short_2 %short_n5 +%v2long_2_2 = OpConstantComposite %v2long %long_2 %long_2 +%v2long_2_3 = OpConstantComposite %v2long %long_2 %long_3 %v2bool_null = OpConstantNull %v2bool %v2bool_true_false = OpConstantComposite %v2bool %true %false %v2bool_false_true = OpConstantComposite %v2bool %false %true @@ -1016,10 +1034,238 @@ INSTANTIATE_TEST_SUITE_P(TestCase, IntegerInstructionFoldingTest, "%2 = OpSNegate %ushort %ushort_0x4400\n" + "OpReturn\n" + "OpFunctionEnd", - 2, 0xBC00 /* expected to be zero extended. */) + 2, 0xBC00 /* expected to be zero extended. */), + // Test case 67: Fold 2 + 3 (short) + InstructionFoldingCase( + Header() + "%main = OpFunction %void None %void_func\n" + + "%main_lab = OpLabel\n" + + "%2 = OpIAdd %short %short_2 %short_3\n" + + "OpReturn\n" + + "OpFunctionEnd", + 2, 5), + // Test case 68: Fold 2 + -5 (short) + InstructionFoldingCase( + Header() + "%main = OpFunction %void None %void_func\n" + + "%main_lab = OpLabel\n" + + "%2 = OpIAdd %short %short_2 %short_n5\n" + + "OpReturn\n" + + "OpFunctionEnd", + 2, -3), + // Test case 69: Fold int(3ll) + InstructionFoldingCase( + Header() + "%main = OpFunction %void None %void_func\n" + + "%main_lab = OpLabel\n" + + "%2 = OpSConvert %int %long_3\n" + + "OpReturn\n" + + "OpFunctionEnd", + 2, 3), + // Test case 70: Fold short(-3ll) + InstructionFoldingCase( + Header() + "%main = OpFunction %void None %void_func\n" + + "%main_lab = OpLabel\n" + + "%2 = OpSConvert %short %long_n3\n" + + "OpReturn\n" + + "OpFunctionEnd", + 2, -3), + // Test case 71: Fold short(32768ll) - This should do a sign extend when + // converting to short. + InstructionFoldingCase( + Header() + "%main = OpFunction %void None %void_func\n" + + "%main_lab = OpLabel\n" + + "%2 = OpSConvert %short %long_32768\n" + + "OpReturn\n" + + "OpFunctionEnd", + 2, -32768), + // Test case 72: Fold short(-57344) - This should do a sign extend when + // converting to short making the upper bits 0. + InstructionFoldingCase( + Header() + "%main = OpFunction %void None %void_func\n" + + "%main_lab = OpLabel\n" + + "%2 = OpSConvert %short %long_n57344\n" + + "OpReturn\n" + + "OpFunctionEnd", + 2, 8192), + // Test case 73: Fold int(-5(short)). The -5 should be interpreted as an unsigned value, and be zero extended to 32-bits. + InstructionFoldingCase( + Header() + "%main = OpFunction %void None %void_func\n" + + "%main_lab = OpLabel\n" + + "%2 = OpUConvert %uint %short_n5\n" + + "OpReturn\n" + + "OpFunctionEnd", + 2, 65531), + // Test case 74: Fold short(-24(int)). The upper bits should be cleared. So 0xFFFFFFE8 should become 0x0000FFE8. 
+ InstructionFoldingCase( + Header() + "%main = OpFunction %void None %void_func\n" + + "%main_lab = OpLabel\n" + + "%2 = OpUConvert %ushort %int_n24\n" + + "OpReturn\n" + + "OpFunctionEnd", + 2, 65512) )); // clang-format on +using LongIntegerInstructionFoldingTest = + ::testing::TestWithParam>; + +TEST_P(LongIntegerInstructionFoldingTest, Case) { + const auto& tc = GetParam(); + + std::unique_ptr context; + Instruction* inst; + std::tie(context, inst) = + FoldInstruction(tc.test_body, tc.id_to_fold, SPV_ENV_UNIVERSAL_1_1); + CheckForExpectedScalarConstant( + inst, tc.expected_result, [](const analysis::Constant* c) { + return c->AsScalarConstant()->GetU64BitValue(); + }); +} + +INSTANTIATE_TEST_SUITE_P( + TestCase, LongIntegerInstructionFoldingTest, + ::testing::Values( + // Test case 0: fold 1+4611686018427387904 + InstructionFoldingCase( + Header() + "%main = OpFunction %void None %void_func\n" + + "%main_lab = OpLabel\n" + + "%n = OpVariable %_ptr_int Function\n" + + "%load = OpLoad %int %n\n" + + "%2 = OpIAdd %long %long_1 %long_4611686018427387904\n" + + "OpReturn\n" + "OpFunctionEnd", + 2, 1 + 4611686018427387904), + // Test case 1: fold 1-4611686018427387904 + InstructionFoldingCase( + Header() + "%main = OpFunction %void None %void_func\n" + + "%main_lab = OpLabel\n" + + "%n = OpVariable %_ptr_int Function\n" + + "%load = OpLoad %int %n\n" + + "%2 = OpISub %long %long_1 %long_4611686018427387904\n" + + "OpReturn\n" + "OpFunctionEnd", + 2, 1 - 4611686018427387904), + // Test case 2: fold 2*4611686018427387904 + InstructionFoldingCase( + Header() + "%main = OpFunction %void None %void_func\n" + + "%main_lab = OpLabel\n" + + "%n = OpVariable %_ptr_int Function\n" + + "%load = OpLoad %int %n\n" + + "%2 = OpIMul %long %long_2 %long_4611686018427387904\n" + + "OpReturn\n" + "OpFunctionEnd", + 2, 2 * 4611686018427387904), + // Test case 3: fold 4611686018427387904/2 (unsigned) + InstructionFoldingCase( + Header() + "%main = OpFunction %void None %void_func\n" + + "%main_lab = OpLabel\n" + + "%n = OpVariable %_ptr_int Function\n" + + "%load = OpLoad %int %n\n" + + "%2 = OpUDiv %long %long_4611686018427387904 %long_2\n" + + "OpReturn\n" + "OpFunctionEnd", + 2, 4611686018427387904 / 2), + // Test case 4: fold 4611686018427387904/2 (signed) + InstructionFoldingCase( + Header() + "%main = OpFunction %void None %void_func\n" + + "%main_lab = OpLabel\n" + + "%n = OpVariable %_ptr_int Function\n" + + "%load = OpLoad %int %n\n" + + "%2 = OpSDiv %long %long_4611686018427387904 %long_2\n" + + "OpReturn\n" + "OpFunctionEnd", + 2, 4611686018427387904 / 2), + // Test case 5: fold -4611686018427387904/2 (signed) + InstructionFoldingCase( + Header() + "%main = OpFunction %void None %void_func\n" + + "%main_lab = OpLabel\n" + + "%n = OpVariable %_ptr_int Function\n" + + "%load = OpLoad %int %n\n" + + "%2 = OpSDiv %long %long_n4611686018427387904 %long_2\n" + + "OpReturn\n" + "OpFunctionEnd", + 2, -4611686018427387904 / 2), + // Test case 6: fold 4611686018427387904 mod 7 (unsigned) + InstructionFoldingCase( + Header() + "%main = OpFunction %void None %void_func\n" + + "%main_lab = OpLabel\n" + + "%n = OpVariable %_ptr_int Function\n" + + "%load = OpLoad %int %n\n" + + "%2 = OpUMod %long %long_4611686018427387904 %long_7\n" + + "OpReturn\n" + "OpFunctionEnd", + 2, 4611686018427387904ull % 7ull), + // Test case 7: fold 7 mod 3 (signed) + InstructionFoldingCase( + Header() + "%main = OpFunction %void None %void_func\n" + + "%main_lab = OpLabel\n" + + "%n = OpVariable %_ptr_int Function\n" + + "%load = 
OpLoad %int %n\n" + + "%2 = OpSMod %long %long_7 %long_3\n" + "OpReturn\n" + + "OpFunctionEnd", + 2, 1ull), + // Test case 8: fold 7 rem 3 (signed) + InstructionFoldingCase( + Header() + "%main = OpFunction %void None %void_func\n" + + "%main_lab = OpLabel\n" + + "%n = OpVariable %_ptr_int Function\n" + + "%load = OpLoad %int %n\n" + + "%2 = OpSRem %long %long_7 %long_3\n" + "OpReturn\n" + + "OpFunctionEnd", + 2, 1ull), + // Test case 9: fold 7 mod -3 (signed) + InstructionFoldingCase( + Header() + "%main = OpFunction %void None %void_func\n" + + "%main_lab = OpLabel\n" + + "%n = OpVariable %_ptr_int Function\n" + + "%load = OpLoad %int %n\n" + + "%2 = OpSMod %long %long_7 %long_n3\n" + "OpReturn\n" + + "OpFunctionEnd", + 2, -2ll), + // Test case 10: fold 7 rem 3 (signed) + InstructionFoldingCase( + Header() + "%main = OpFunction %void None %void_func\n" + + "%main_lab = OpLabel\n" + + "%n = OpVariable %_ptr_int Function\n" + + "%load = OpLoad %int %n\n" + + "%2 = OpSRem %long %long_7 %long_n3\n" + "OpReturn\n" + + "OpFunctionEnd", + 2, 1ll), + // Test case 11: fold -7 mod 3 (signed) + InstructionFoldingCase( + Header() + "%main = OpFunction %void None %void_func\n" + + "%main_lab = OpLabel\n" + + "%n = OpVariable %_ptr_int Function\n" + + "%load = OpLoad %int %n\n" + + "%2 = OpSMod %long %long_n7 %long_3\n" + "OpReturn\n" + + "OpFunctionEnd", + 2, 2ll), + // Test case 12: fold -7 rem 3 (signed) + InstructionFoldingCase( + Header() + "%main = OpFunction %void None %void_func\n" + + "%main_lab = OpLabel\n" + + "%n = OpVariable %_ptr_int Function\n" + + "%load = OpLoad %int %n\n" + + "%2 = OpSRem %long %long_n7 %long_3\n" + "OpReturn\n" + + "OpFunctionEnd", + 2, -1ll), + // Test case 13: fold long(-24) + InstructionFoldingCase( + Header() + "%main = OpFunction %void None %void_func\n" + + "%main_lab = OpLabel\n" + + "%n = OpVariable %_ptr_int Function\n" + + "%load = OpLoad %int %n\n" + + "%2 = OpSConvert %long %int_n24\n" + "OpReturn\n" + + "OpFunctionEnd", + 2, -24ll), + // Test case 14: fold long(-24) + InstructionFoldingCase( + Header() + "%main = OpFunction %void None %void_func\n" + + "%main_lab = OpLabel\n" + + "%n = OpVariable %_ptr_int Function\n" + + "%load = OpLoad %int %n\n" + "%2 = OpSConvert %long %int_10\n" + + "OpReturn\n" + "OpFunctionEnd", + 2, 10ll), + // Test case 15: fold long(-24(short)). + // The upper bits should be cleared. So 0xFFFFFFE8 should become + // 0x000000000000FFE8. 
+ InstructionFoldingCase( + Header() + "%main = OpFunction %void None %void_func\n" + + "%main_lab = OpLabel\n" + "%2 = OpUConvert %ulong %short_n5\n" + + "OpReturn\n" + "OpFunctionEnd", + 2, 65531ull))); + using UIntVectorInstructionFoldingTest = ::testing::TestWithParam>>; @@ -1077,14 +1323,30 @@ ::testing::Values( "OpReturn\n" + "OpFunctionEnd", 2, {static_cast(-0x3f800000), static_cast(-0xbf800000)}), - // Test case 6: fold vector components of uint (incuding integer overflow) + // Test case 6: fold vector components of uint (including integer overflow) InstructionFoldingCase>( Header() + "%main = OpFunction %void None %void_func\n" + "%main_lab = OpLabel\n" + "%2 = OpIAdd %v2uint %v2uint_0x3f800000_0xbf800000 %v2uint_0x3f800000_0xbf800000\n" + "OpReturn\n" + "OpFunctionEnd", - 2, {0x7f000000u, 0x7f000000u}) + 2, {0x7f000000u, 0x7f000000u}), + // Test case 6: fold vector components of uint + InstructionFoldingCase>( + Header() + "%main = OpFunction %void None %void_func\n" + + "%main_lab = OpLabel\n" + + "%2 = OpSConvert %v2int %v2short_2_n5\n" + + "OpReturn\n" + + "OpFunctionEnd", + 2, {2,static_cast(-5)}), + // Test case 6: fold vector components of uint (incuding integer overflow) + InstructionFoldingCase>( + Header() + "%main = OpFunction %void None %void_func\n" + + "%main_lab = OpLabel\n" + + "%2 = OpUConvert %v2uint %v2short_2_n5\n" + + "OpReturn\n" + + "OpFunctionEnd", + 2, {2,65531}) )); // clang-format on @@ -1142,6 +1404,51 @@ ::testing::Values( )); // clang-format on +using LongIntVectorInstructionFoldingTest = + ::testing::TestWithParam>>; + +TEST_P(LongIntVectorInstructionFoldingTest, Case) { + const auto& tc = GetParam(); + + std::unique_ptr context; + Instruction* inst; + std::tie(context, inst) = + FoldInstruction(tc.test_body, tc.id_to_fold, SPV_ENV_UNIVERSAL_1_1); + CheckForExpectedVectorConstant( + inst, tc.expected_result, + [](const analysis::Constant* c) { return c->GetU64(); }); +} + +// clang-format off +INSTANTIATE_TEST_SUITE_P(TestCase, LongIntVectorInstructionFoldingTest, + ::testing::Values( + // Test case 0: fold {2,2} + {2,3} (Testing that the vector logic works + // correctly. Scalar tests will check that the 64-bit values are correctly + // folded.) + InstructionFoldingCase>( + Header() + "%main = OpFunction %void None %void_func\n" + + "%main_lab = OpLabel\n" + + "%n = OpVariable %_ptr_int Function\n" + + "%load = OpLoad %int %n\n" + + "%2 = OpIAdd %v2long %v2long_2_2 %v2long_2_3\n" + + "OpReturn\n" + + "OpFunctionEnd", + 2, {4,5}), + // Test case 0: fold {2,2} / {2,3} (Testing that the vector logic works + // correctly. Scalar tests will check that the 64-bit values are correctly + // folded.) + InstructionFoldingCase>( + Header() + "%main = OpFunction %void None %void_func\n" + + "%main_lab = OpLabel\n" + + "%n = OpVariable %_ptr_int Function\n" + + "%load = OpLoad %int %n\n" + + "%2 = OpSDiv %v2long %v2long_2_2 %v2long_2_3\n" + + "OpReturn\n" + + "OpFunctionEnd", + 2, {1,0}) + )); +// clang-format on + using DoubleVectorInstructionFoldingTest = ::testing::TestWithParam>>; From e50b100dfd36a11775d8aaf1463cb2d37ec8b095 Mon Sep 17 00:00:00 2001 From: Steven Perron Date: Thu, 8 Feb 2024 09:42:45 -0500 Subject: [PATCH 2/6] Fixes based on code review. Remove stale test cases. 
--- source/opt/const_folding_rules.cpp | 101 +++++++++++++++-------------- test/opt/fold_test.cpp | 22 +------ 2 files changed, 56 insertions(+), 67 deletions(-) diff --git a/source/opt/const_folding_rules.cpp b/source/opt/const_folding_rules.cpp index db4212fb17..7723f94ed2 100644 --- a/source/opt/const_folding_rules.cpp +++ b/source/opt/const_folding_rules.cpp @@ -21,8 +21,8 @@ namespace opt { namespace { constexpr uint32_t kExtractCompositeIdInIdx = 0; -// Returns the value obtained by setting clearing the `number_of_bits` most -// significant bits of `value`. +// Returns the value obtained by extracting the |number_of_bits| least +// significant bits from |value|, and sign-extending it to 64-bits. uint64_t SignExtendValue(uint64_t value, uint32_t number_of_bits) { if (number_of_bits == 64) return value; @@ -38,12 +38,12 @@ uint64_t SignExtendValue(uint64_t value, uint32_t number_of_bits) { return value; } -// Returns the value obtained from clearing the `number_of_bits` most -// significant bits of `value`. -uint64_t ClearUpperBits(uint64_t value, uint32_t number_of_bits) { - if (number_of_bits == 0) return value; +// Returns the value obtained by extracting the |number_of_bits| least +// significant bits from |value|, and zero-extending it to 64-bits. +uint64_t ZeroExtendValue(uint64_t value, uint32_t number_of_bits) { + if (number_of_bits == 64) return value; - uint64_t mask_for_first_bit_to_clear = 1ull << (64 - number_of_bits); + uint64_t mask_for_first_bit_to_clear = 1ull << (number_of_bits); uint64_t mask_for_bits_to_keep = mask_for_first_bit_to_clear - 1; value &= mask_for_bits_to_keep; return value; @@ -67,7 +67,7 @@ const analysis::Constant* GenerateIntegerConstant( if (integer_type->IsSigned()) { result = SignExtendValue(result, integer_type->width()); } else { - result = ClearUpperBits(result, 64 - integer_type->width()); + result = ZeroExtendValue(result, integer_type->width()); } words = {static_cast(result)}; } @@ -769,15 +769,18 @@ ConstantFoldingRule FoldUnaryOp(UnaryScalarFoldingRule scalar_rule) { } // Returns a |ConstantFoldingRule| that folds binary scalar ops -// using |scalar_rule| and unary vectors ops by applying -// |scalar_rule| to the elements of the vector. The |ConstantFoldingRule| -// that is returned assumes that |constants| contains 2 entries. If they are -// not |nullptr|, then their type is either |Float| or |Integer| or a |Vector| -// whose element type is |Float| or |Integer|. +// using |scalar_rule| and binary vectors ops by applying +// |scalar_rule| to the elements of the vector. The folding rule assumes that op +// has two inputs. For regular instruction, those are in operands 0 and 1. For +// extended instruction, they are in operands 1 and 2. If an element in +// |constants| is not nullprt, then the constant's type is |Float|, |Integer|, +// or |Vector| whose element type is |Float| or |Integer|. ConstantFoldingRule FoldBinaryOp(BinaryScalarFoldingRule scalar_rule) { return [scalar_rule](IRContext* context, Instruction* inst, const std::vector& constants) -> const analysis::Constant* { + assert(constants.size() == inst->NumInOperands()); + assert(constants.size() == (inst->opcode() == spv::Op::OpExtInst ? 
3 : 2)); analysis::ConstantManager* const_mgr = context->get_constant_mgr(); analysis::TypeManager* type_mgr = context->get_type_mgr(); const analysis::Type* result_type = type_mgr->GetType(inst->type_id()); @@ -788,40 +791,38 @@ ConstantFoldingRule FoldBinaryOp(BinaryScalarFoldingRule scalar_rule) { const analysis::Constant* arg2 = (inst->opcode() == spv::Op::OpExtInst) ? constants[2] : constants[1]; - if (arg1 == nullptr) { + if (arg1 == nullptr || arg2 == nullptr) { return nullptr; } - if (arg2 == nullptr) { - return nullptr; + + if (vector_type == nullptr) { + return scalar_rule(result_type, arg1, arg2, const_mgr); } - if (vector_type != nullptr) { - std::vector a_components; - std::vector b_components; - std::vector results_components; + std::vector a_components; + std::vector b_components; + std::vector results_components; - a_components = arg1->GetVectorComponents(const_mgr); - b_components = arg2->GetVectorComponents(const_mgr); + a_components = arg1->GetVectorComponents(const_mgr); + b_components = arg2->GetVectorComponents(const_mgr); + assert(a_components.size() == b_components.size()); - // Fold each component of the vector. - for (uint32_t i = 0; i < a_components.size(); ++i) { - results_components.push_back(scalar_rule(vector_type->element_type(), - a_components[i], - b_components[i], const_mgr)); - if (results_components[i] == nullptr) { - return nullptr; - } + // Fold each component of the vector. + for (uint32_t i = 0; i < a_components.size(); ++i) { + results_components.push_back(scalar_rule(vector_type->element_type(), + a_components[i], b_components[i], + const_mgr)); + if (results_components[i] == nullptr) { + return nullptr; } + } - // Build the constant object and return it. - std::vector ids; - for (const analysis::Constant* member : results_components) { - ids.push_back(const_mgr->GetDefiningInstruction(member)->result_id()); - } - return const_mgr->GetConstant(vector_type, ids); - } else { - return scalar_rule(result_type, arg1, arg2, const_mgr); + // Build the constant object and return it. + std::vector ids; + for (const analysis::Constant* member : results_components) { + ids.push_back(const_mgr->GetDefiningInstruction(member)->result_id()); } + return const_mgr->GetConstant(vector_type, ids); }; } @@ -1701,9 +1702,9 @@ enum Sign { Signed, Unsigned }; // Returns a BinaryScalarFoldingRule that applies `op` to the scalars. // The `signedness` is used to determine if the operands should be interpreted -// as signed or unsigned. If the operands are signed, the will be sign extended -// before the value is passed to `op`. Otherwise the values will be zero -// extended. +// as signed or unsigned. If the operands are signed, the value will be sign +// extended before the value is passed to `op`. Otherwise the values will be +// zero extended. 
template BinaryScalarFoldingRule FoldBinaryIntegerOperation(uint64_t (*op)(uint64_t, uint64_t)) { @@ -1711,7 +1712,7 @@ BinaryScalarFoldingRule FoldBinaryIntegerOperation(uint64_t (*op)(uint64_t, [op](const analysis::Type* result_type, const analysis::Constant* a, const analysis::Constant* b, analysis::ConstantManager* const_mgr) -> const analysis::Constant* { - assert(result_type != nullptr && a != nullptr); + assert(result_type != nullptr && a != nullptr && b != nullptr); const analysis::Integer* integer_type = a->type()->AsInteger(); assert(integer_type != nullptr); assert(integer_type == result_type->AsInteger()); @@ -1732,30 +1733,34 @@ BinaryScalarFoldingRule FoldBinaryIntegerOperation(uint64_t (*op)(uint64_t, }; } -// A scalar folding rule that foles OpSConvert. +// A scalar folding rule that folds OpSConvert. const analysis::Constant* FoldScalarSConvert( const analysis::Type* result_type, const analysis::Constant* a, analysis::ConstantManager* const_mgr) { - assert(a); + assert(result_type != nullptr); + assert(a != nullptr); + assert(const_mgr != nullptr); const analysis::Integer* integer_type = result_type->AsInteger(); assert(integer_type && "The result type of an SConvert"); int64_t value = a->GetSignExtendedValue(); return GenerateIntegerConstant(integer_type, value, const_mgr); } -// A scalar folding rule that foles OpSConvert. +// A scalar folding rule that folds OpUConvert. const analysis::Constant* FoldScalarUConvert( const analysis::Type* result_type, const analysis::Constant* a, analysis::ConstantManager* const_mgr) { - assert(a); + assert(result_type != nullptr); + assert(a != nullptr); + assert(const_mgr != nullptr); const analysis::Integer* integer_type = result_type->AsInteger(); - assert(integer_type && "The result type of an SConvert"); + assert(integer_type && "The result type of an UConvert"); uint64_t value = a->GetZeroExtendedValue(); // If the operand was an unsigned value with less than 32-bit, it would have // been sign extended earlier, and we need to clear those bits. 
auto* operand_type = a->type()->AsInteger(); - value = ClearUpperBits(value, 64 - operand_type->width()); + value = ZeroExtendValue(value, operand_type->width()); return GenerateIntegerConstant(integer_type, value, const_mgr); } } // namespace diff --git a/test/opt/fold_test.cpp b/test/opt/fold_test.cpp index 794d8c9797..e5f663f149 100644 --- a/test/opt/fold_test.cpp +++ b/test/opt/fold_test.cpp @@ -4118,23 +4118,7 @@ INSTANTIATE_TEST_SUITE_P(IntegerArithmeticTestCases, GeneralInstructionFoldingTe "OpReturn\n" + "OpFunctionEnd", 2, 0), - // Test case 38: Don't fold 2 + 3 (long), bad length - InstructionFoldingCase( - Header() + "%main = OpFunction %void None %void_func\n" + - "%main_lab = OpLabel\n" + - "%2 = OpIAdd %long %long_2 %long_3\n" + - "OpReturn\n" + - "OpFunctionEnd", - 2, 0), - // Test case 39: Don't fold 2 + 3 (short), bad length - InstructionFoldingCase( - Header() + "%main = OpFunction %void None %void_func\n" + - "%main_lab = OpLabel\n" + - "%2 = OpIAdd %short %short_2 %short_3\n" + - "OpReturn\n" + - "OpFunctionEnd", - 2, 0), - // Test case 40: fold 1*n + // Test case 38: fold 1*n InstructionFoldingCase( Header() + "%main = OpFunction %void None %void_func\n" + "%main_lab = OpLabel\n" + @@ -4144,7 +4128,7 @@ INSTANTIATE_TEST_SUITE_P(IntegerArithmeticTestCases, GeneralInstructionFoldingTe "OpReturn\n" + "OpFunctionEnd", 2, 3), - // Test case 41: fold n*1 + // Test case 39: fold n*1 InstructionFoldingCase( Header() + "%main = OpFunction %void None %void_func\n" + "%main_lab = OpLabel\n" + @@ -4154,7 +4138,7 @@ INSTANTIATE_TEST_SUITE_P(IntegerArithmeticTestCases, GeneralInstructionFoldingTe "OpReturn\n" + "OpFunctionEnd", 2, 3), - // Test case 42: Don't fold comparisons of 64-bit types + // Test case 40: Don't fold comparisons of 64-bit types // (https://github.com/KhronosGroup/SPIRV-Tools/issues/3343). InstructionFoldingCase( Header() + "%main = OpFunction %void None %void_func\n" + From c5dfe790f624e60b189a32104f68771909c842b8 Mon Sep 17 00:00:00 2001 From: Steven Perron Date: Thu, 8 Feb 2024 09:59:01 -0500 Subject: [PATCH 3/6] Fix undefined overflow --- test/opt/fold_test.cpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/test/opt/fold_test.cpp b/test/opt/fold_test.cpp index e5f663f149..92a76100cc 100644 --- a/test/opt/fold_test.cpp +++ b/test/opt/fold_test.cpp @@ -1149,7 +1149,7 @@ INSTANTIATE_TEST_SUITE_P( "%load = OpLoad %int %n\n" + "%2 = OpIMul %long %long_2 %long_4611686018427387904\n" + "OpReturn\n" + "OpFunctionEnd", - 2, 2 * 4611686018427387904), + 2, 9223372036854775808ull), // Test case 3: fold 4611686018427387904/2 (unsigned) InstructionFoldingCase( Header() + "%main = OpFunction %void None %void_func\n" + From 99c466da11e8a72059630772062a4d2d8c0979cf Mon Sep 17 00:00:00 2001 From: Steven Perron Date: Thu, 8 Feb 2024 14:31:42 -0500 Subject: [PATCH 4/6] Use unsigned types with unsigned opcodes in tests. 
--- test/opt/fold_test.cpp | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/test/opt/fold_test.cpp b/test/opt/fold_test.cpp index 92a76100cc..a4e0447c10 100644 --- a/test/opt/fold_test.cpp +++ b/test/opt/fold_test.cpp @@ -310,6 +310,8 @@ OpName %main "main" %long_n3689348814741910323 = OpConstant %long -3689348814741910323 %long_min = OpConstant %long -9223372036854775808 %long_max = OpConstant %long 9223372036854775807 +%ulong_7 = OpConstant %ulong 7 +%ulong_4611686018427387904 = OpConstant %ulong 4611686018427387904 %uint_0 = OpConstant %uint 0 %uint_1 = OpConstant %uint 1 %uint_2 = OpConstant %uint 2 @@ -1156,7 +1158,7 @@ INSTANTIATE_TEST_SUITE_P( "%main_lab = OpLabel\n" + "%n = OpVariable %_ptr_int Function\n" + "%load = OpLoad %int %n\n" + - "%2 = OpUDiv %long %long_4611686018427387904 %long_2\n" + + "%2 = OpUDiv %ulong %ulong_4611686018427387904 %ulong_2\n" + "OpReturn\n" + "OpFunctionEnd", 2, 4611686018427387904 / 2), // Test case 4: fold 4611686018427387904/2 (signed) @@ -1183,7 +1185,7 @@ INSTANTIATE_TEST_SUITE_P( "%main_lab = OpLabel\n" + "%n = OpVariable %_ptr_int Function\n" + "%load = OpLoad %int %n\n" + - "%2 = OpUMod %long %long_4611686018427387904 %long_7\n" + + "%2 = OpUMod %ulong %ulong_4611686018427387904 %ulong_7\n" + "OpReturn\n" + "OpFunctionEnd", 2, 4611686018427387904ull % 7ull), // Test case 7: fold 7 mod 3 (signed) From dfe137b40f5f5efa2ce689084b2182c6c283b459 Mon Sep 17 00:00:00 2001 From: Steven Perron Date: Thu, 8 Feb 2024 14:34:33 -0500 Subject: [PATCH 5/6] Fix comment describing SPIR-V handling of integers. --- source/opt/const_folding_rules.cpp | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/source/opt/const_folding_rules.cpp b/source/opt/const_folding_rules.cpp index 7723f94ed2..ac119c8311 100644 --- a/source/opt/const_folding_rules.cpp +++ b/source/opt/const_folding_rules.cpp @@ -1718,9 +1718,8 @@ BinaryScalarFoldingRule FoldBinaryIntegerOperation(uint64_t (*op)(uint64_t, assert(integer_type == result_type->AsInteger()); assert(integer_type == b->type()->AsInteger()); - // In SPIR-V, the signedness of the operands is determined by the - // opcode, and not by the type of the operands. This is why we use the - // template argument to determine how to interpret the operands. + // In SPIR-V, all operations support unsigned types, but the way they are interpreted depends on the opcode. + // This is why we use the template argument to determine how to interpret the operands. uint64_t ia = (signedness == Signed ? a->GetSignExtendedValue() : a->GetZeroExtendedValue()); uint64_t ib = (signedness == Signed ? b->GetSignExtendedValue() From fca83ffd5cd18f3f8617d5e74ebc569518e60707 Mon Sep 17 00:00:00 2001 From: Steven Perron Date: Thu, 8 Feb 2024 15:29:24 -0500 Subject: [PATCH 6/6] Fix release build.
--- source/opt/const_folding_rules.cpp | 9 +++++---- 1 file changed, 5 insertions(+), 4 deletions(-) diff --git a/source/opt/const_folding_rules.cpp b/source/opt/const_folding_rules.cpp index ac119c8311..79f34acd3c 100644 --- a/source/opt/const_folding_rules.cpp +++ b/source/opt/const_folding_rules.cpp @@ -1713,13 +1713,14 @@ BinaryScalarFoldingRule FoldBinaryIntegerOperation(uint64_t (*op)(uint64_t, const analysis::Constant* b, analysis::ConstantManager* const_mgr) -> const analysis::Constant* { assert(result_type != nullptr && a != nullptr && b != nullptr); - const analysis::Integer* integer_type = a->type()->AsInteger(); + const analysis::Integer* integer_type = result_type->AsInteger(); assert(integer_type != nullptr); - assert(integer_type == result_type->AsInteger()); + assert(integer_type == a->type()->AsInteger()); assert(integer_type == b->type()->AsInteger()); - // In SPIR-V, all operations support unsigned types, but the way they are interpreted depends on the opcode. - // This is why we use the template argument to determine how to interpret the operands. + // In SPIR-V, all operations support unsigned types, but the way they + // are interpreted depends on the opcode. This is why we use the + // template argument to determine how to interpret the operands. uint64_t ia = (signedness == Signed ? a->GetSignExtendedValue() : a->GetZeroExtendedValue()); uint64_t ib = (signedness == Signed ? b->GetSignExtendedValue()