diff --git a/src/coreclr/jit/gtlist.h b/src/coreclr/jit/gtlist.h index 0ec112781c050e..93fb4dfd358b37 100644 --- a/src/coreclr/jit/gtlist.h +++ b/src/coreclr/jit/gtlist.h @@ -280,19 +280,19 @@ GTNODE(SELECT_NEGCC , GenTreeOpCC ,0,0,GTK_BINOP|DBK_NOTHIR) #ifdef TARGET_RISCV64 // Maps to riscv64 sh1add instruction. Computes result = op2 + (op1 << 1). GTNODE(SH1ADD , GenTreeOp ,0,0,GTK_BINOP|DBK_NOTHIR) -// Maps to riscv64 sh1add.uw instruction. Computes result = op2 + zext(op1[31..0] << 1). +// Maps to riscv64 sh1add.uw instruction. Computes result = op2 + (zext(op1[31..0]) << 1). GTNODE(SH1ADD_UW , GenTreeOp ,0,0,GTK_BINOP|DBK_NOTHIR) // Maps to riscv64 sh2add instruction. Computes result = op2 + (op1 << 2). GTNODE(SH2ADD , GenTreeOp ,0,0,GTK_BINOP|DBK_NOTHIR) -// Maps to riscv64 sh2add.uw instruction. Computes result = op2 + zext(op1[31..0] << 2). +// Maps to riscv64 sh2add.uw instruction. Computes result = op2 + (zext(op1[31..0]) << 2). GTNODE(SH2ADD_UW , GenTreeOp ,0,0,GTK_BINOP|DBK_NOTHIR) // Maps to riscv64 sh3add instruction. Computes result = op2 + (op1 << 3). GTNODE(SH3ADD , GenTreeOp ,0,0,GTK_BINOP|DBK_NOTHIR) -// Maps to riscv64 sh3add.uw instruction. Computes result = op2 + zext(op1[31..0] << 3). +// Maps to riscv64 sh3add.uw instruction. Computes result = op2 + (zext(op1[31..0]) << 3). GTNODE(SH3ADD_UW , GenTreeOp ,0,0,GTK_BINOP|DBK_NOTHIR) // Maps to riscv64 add.uw instruction. Computes result = op2 + zext(op1[31..0]). GTNODE(ADD_UW , GenTreeOp ,0,0,GTK_BINOP|DBK_NOTHIR) -// Maps to riscv64 slli.uw instruction. Computes result = zext(op1[31..0] << imm). +// Maps to riscv64 slli.uw instruction. Computes result = zext(op1[31..0]) << imm. GTNODE(SLLI_UW , GenTreeOp ,0,0,GTK_BINOP|DBK_NOTHIR) #endif diff --git a/src/coreclr/jit/ifconversion.cpp b/src/coreclr/jit/ifconversion.cpp index ef81c2580a1928..abf7f53a6b0250 100644 --- a/src/coreclr/jit/ifconversion.cpp +++ b/src/coreclr/jit/ifconversion.cpp @@ -57,6 +57,9 @@ class OptIfConversionDsc bool IfConvertCheckStmts(BasicBlock* fromBlock, IfConvertOperation* foundOperation); void IfConvertJoinStmts(BasicBlock* fromBlock); + GenTree* TryTransformSelectOperOrLocal(GenTree* oper, GenTree* lcl); + GenTree* TryTransformSelectOperOrZero(GenTree* oper, GenTree* lcl); + GenTree* TryTransformSelectToOrdinaryOps(GenTree* trueInput, GenTree* falseInput); #ifdef DEBUG void IfConvertDump(); #endif @@ -678,17 +681,8 @@ bool OptIfConversionDsc::optIfConvert() GenTree* selectFalseInput; if (m_mainOper == GT_STORE_LCL_VAR) { - if (m_doElseConversion) - { - selectTrueInput = m_elseOperation.node->AsLclVar()->Data(); - selectFalseInput = m_thenOperation.node->AsLclVar()->Data(); - } - else // Duplicate the destination of the Then store. - { - GenTreeLclVar* store = m_thenOperation.node->AsLclVar(); - selectTrueInput = m_comp->gtNewLclVarNode(store->GetLclNum(), store->TypeGet()); - selectFalseInput = m_thenOperation.node->AsLclVar()->Data(); - } + selectFalseInput = m_thenOperation.node->AsLclVar()->Data(); + selectTrueInput = m_doElseConversion ? m_elseOperation.node->AsLclVar()->Data() : nullptr; // Pick the type as the type of the local, which should always be compatible even for implicit coercions. selectType = genActualType(m_thenOperation.node); @@ -704,23 +698,20 @@ bool OptIfConversionDsc::optIfConvert() selectType = genActualType(m_thenOperation.node); } - GenTree* select = nullptr; - if (selectTrueInput->TypeIs(TYP_INT) && selectFalseInput->TypeIs(TYP_INT)) + GenTree* select = TryTransformSelectToOrdinaryOps(selectTrueInput, selectFalseInput); + if (select == nullptr) { - if (selectTrueInput->IsIntegralConst(1) && selectFalseInput->IsIntegralConst(0)) - { - // compare ? true : false --> compare - select = m_cond; - } - else if (selectTrueInput->IsIntegralConst(0) && selectFalseInput->IsIntegralConst(1)) +#ifdef TARGET_RISCV64 + JITDUMP("Skipping if-conversion that cannot be transformed to ordinary operations\n"); + return false; +#endif + if (selectTrueInput == nullptr) { - // compare ? false : true --> reversed_compare - select = m_comp->gtReverseCond(m_cond); + // Duplicate the destination of the Then store. + assert(m_mainOper == GT_STORE_LCL_VAR && !m_doElseConversion); + GenTreeLclVar* store = m_thenOperation.node->AsLclVar(); + selectTrueInput = m_comp->gtNewLclVarNode(store->GetLclNum(), store->TypeGet()); } - } - - if (select == nullptr) - { // Create a select node select = m_comp->gtNewConditionalNode(GT_SELECT, m_cond, selectTrueInput, selectFalseInput, selectType); } @@ -774,6 +765,235 @@ bool OptIfConversionDsc::optIfConvert() return true; } +struct IntConstSelectOper +{ + genTreeOps oper; + var_types type; + unsigned bitIndex; + + bool isMatched() const + { + return oper != GT_NONE; + } +}; + +//----------------------------------------------------------------------------- +// MatchIntConstSelectValues: Matches an operation so that `trueVal` can be calculated as: +// oper(type, falseVal, condition) +// +// Notes: +// A non-zero bitIndex (log2(trueVal)) differentiates (condition << bitIndex) from (falseVal << condition). +// +// Return Value: +// The matched operation (if any). +// +static IntConstSelectOper MatchIntConstSelectValues(int64_t trueVal, int64_t falseVal) +{ + if (trueVal == falseVal + 1) + return {GT_ADD, TYP_LONG}; + + if (trueVal == int64_t(int32_t(falseVal) + 1)) + return {GT_ADD, TYP_INT}; + + if (falseVal == 0) + { + unsigned bitIndex = BitOperations::Log2((uint64_t)trueVal); + assert(bitIndex > 0); + if (trueVal == (int64_t(1) << bitIndex)) + return {GT_LSH, TYP_LONG, bitIndex}; + + bitIndex = BitOperations::Log2((uint32_t)trueVal); + assert(bitIndex > 0); + if (trueVal == int64_t(int32_t(int32_t(1) << bitIndex))) + return {GT_LSH, TYP_INT, bitIndex}; + } + + if (trueVal == falseVal << 1) + return {GT_LSH, TYP_LONG}; + + if (trueVal == int64_t(int32_t(falseVal) << 1)) + return {GT_LSH, TYP_INT}; + + if (trueVal == falseVal >> 1) + return {GT_RSH, TYP_LONG}; + + if (trueVal == int64_t(int32_t(falseVal) >> 1)) + return {GT_RSH, TYP_INT}; + + if (trueVal == int64_t(uint64_t(falseVal) >> 1)) + return {GT_RSZ, TYP_LONG}; + + if (trueVal == int64_t(uint32_t(falseVal) >> 1)) + return {GT_RSZ, TYP_INT}; + + return {GT_NONE}; +} + +//----------------------------------------------------------------------------- +// TryTransformSelectOperOrLocal: Try to trasform "cond ? oper(lcl, (-)1) : lcl" into "oper(')(lcl, cond)" +// +// Arguments: +// trueInput - expression to be evaluated when m_cond is true +// falseInput - expression to be evaluated when m_cond is false +// +// Return Value: +// The transformed expression, or null if no transformation took place +// +GenTree* OptIfConversionDsc::TryTransformSelectOperOrLocal(GenTree* trueInput, GenTree* falseInput) +{ + GenTree* oper = trueInput; + GenTree* lcl = falseInput; + + bool isCondReversed = !lcl->OperIsAnyLocal(); + if (isCondReversed) + std::swap(oper, lcl); + + if (lcl->OperIsAnyLocal() && (oper->OperIs(GT_ADD, GT_OR, GT_XOR) || oper->OperIsShift())) + { + GenTree* lcl2 = oper->gtGetOp1(); + GenTree* one = oper->gtGetOp2(); + if (oper->OperIsCommutative() && !one->IsIntegralConst()) + std::swap(lcl2, one); + + bool isDecrement = oper->OperIs(GT_ADD) && one->IsIntegralConst(-1); + if (one->IsIntegralConst(1) || isDecrement) + { + unsigned lclNum = lcl->AsLclVarCommon()->GetLclNum(); + if (lcl2->OperIs(GT_LCL_VAR) && (lcl2->AsLclVar()->GetLclNum() == lclNum)) + { + oper->AsOp()->gtOp1 = lcl2; + oper->AsOp()->gtOp2 = isCondReversed ? m_comp->gtReverseCond(m_cond) : m_cond; + if (isDecrement) + oper->ChangeOper(GT_SUB); + + oper->gtFlags |= m_cond->gtFlags & GTF_ALL_EFFECT; + return oper; + } + } + } + return nullptr; +} + +//----------------------------------------------------------------------------- +// TryTransformSelectOperOrZero: Try to trasform "cond ? oper(1, expr) : 0" into "oper(cond, expr)" +// +// Arguments: +// trueInput - expression to be evaluated when m_cond is true +// falseInput - expression to be evaluated when m_cond is false +// +// Return Value: +// The transformed expression, or null if no transformation took place +// +GenTree* OptIfConversionDsc::TryTransformSelectOperOrZero(GenTree* trueInput, GenTree* falseInput) +{ + GenTree* oper = trueInput; + GenTree* zero = falseInput; + + bool isCondReversed = !zero->IsIntegralConst(); + if (isCondReversed) + std::swap(oper, zero); + + if (zero->IsIntegralConst(0) && oper->OperIs(GT_AND, GT_LSH)) + { + GenTree* one = oper->gtGetOp1(); + GenTree* expr = oper->gtGetOp2(); + if (oper->OperIsCommutative() && !one->IsIntegralConst()) + std::swap(one, expr); + + if (one->IsIntegralConst(1)) + { + oper->AsOp()->gtOp1 = isCondReversed ? m_comp->gtReverseCond(m_cond) : m_cond; + oper->AsOp()->gtOp2 = expr; + + oper->gtFlags |= m_cond->gtFlags & GTF_ALL_EFFECT; + return oper; + } + } + return nullptr; +} + +//----------------------------------------------------------------------------- +// TryTransformSelectToOrdinaryOps: Try transforming the identified if-else expressions to a single expression +// +// This is meant mostly for RISC-V where the condition (1 or 0) is stored in a regular general-purpose register +// which can be fed as an argument to standard operations, e.g. +// * (cond ? 6 : 5) becomes (5 + cond) +// * (cond ? -25 : -13) becomes (-25 >> cond) +// * if (cond) a++; becomes (a + cond) +// * (cond ? 1 << a : 0) becomes (cond << a) +// +// Arguments: +// trueInput - expression to be evaluated when m_cond is true, or null if there is no else expression +// falseInput - expression to be evaluated when m_cond is false +// +// Return Value: +// The transformed single expression equivalent to the if-else expressions, or null if no transformation took place +// +GenTree* OptIfConversionDsc::TryTransformSelectToOrdinaryOps(GenTree* trueInput, GenTree* falseInput) +{ + assert(falseInput != nullptr); + + if ((trueInput != nullptr && trueInput->IsIntegralConst()) && falseInput->IsIntegralConst()) + { + int64_t trueVal = trueInput->AsIntConCommon()->IntegralValue(); + int64_t falseVal = falseInput->AsIntConCommon()->IntegralValue(); + if (trueInput->TypeIs(TYP_INT) && falseInput->TypeIs(TYP_INT)) + { + if (trueVal == 1 && falseVal == 0) + { + // compare ? true : false --> compare + return m_cond; + } + else if (trueVal == 0 && falseVal == 1) + { + // compare ? false : true --> reversed_compare + return m_comp->gtReverseCond(m_cond); + } + } +#ifdef TARGET_RISCV64 + bool isCondReversed = false; + IntConstSelectOper selectOper = MatchIntConstSelectValues(trueVal, falseVal); + if (!selectOper.isMatched()) + { + isCondReversed = true; + selectOper = MatchIntConstSelectValues(falseVal, trueVal); + } + if (selectOper.isMatched()) + { + GenTree* left = isCondReversed ? trueInput : falseInput; + GenTree* right = isCondReversed ? m_comp->gtReverseCond(m_cond) : m_cond; + if (selectOper.bitIndex > 0) + { + assert(selectOper.oper == GT_LSH); + left->AsIntConCommon()->SetIntegralValue(selectOper.bitIndex); + std::swap(left, right); + } + return m_comp->gtNewOperNode(selectOper.oper, selectOper.type, left, right); + } + return nullptr; +#endif // TARGET_RISCV64 + } +#ifdef TARGET_RISCV64 + else + { + if (trueInput == nullptr) + { + assert(m_mainOper == GT_STORE_LCL_VAR && !m_doElseConversion); + trueInput = m_thenOperation.node; + } + + GenTree* transformed = TryTransformSelectOperOrLocal(trueInput, falseInput); + if (transformed != nullptr) + return transformed; + + transformed = TryTransformSelectOperOrZero(trueInput, falseInput); + if (transformed != nullptr) + return transformed; + } +#endif // TARGET_RISCV64 + return nullptr; +} + //----------------------------------------------------------------------------- // optIfConversion: If conversion // @@ -800,7 +1020,7 @@ PhaseStatus Compiler::optIfConversion() assert(!fgSsaValid); optReachableBitVecTraits = nullptr; -#if defined(TARGET_ARM64) || defined(TARGET_XARCH) +#if defined(TARGET_ARM64) || defined(TARGET_XARCH) || defined(TARGET_RISCV64) // Reverse iterate through the blocks. BasicBlock* block = fgLastBB; while (block != nullptr) diff --git a/src/tests/JIT/opt/Compares/conditionalSimpleOps.cs b/src/tests/JIT/opt/Compares/conditionalSimpleOps.cs new file mode 100644 index 00000000000000..547962c693c911 --- /dev/null +++ b/src/tests/JIT/opt/Compares/conditionalSimpleOps.cs @@ -0,0 +1,291 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. + +// unit test for the full range comparison optimization + +using System; +using System.Runtime.CompilerServices; +using Xunit; + +public class ConditionalSimpleOpConstantTest +{ + [Theory] + [InlineData(12, 10)] + [InlineData(45, 5)] + [MethodImpl(MethodImplOptions.NoInlining)] + public static void shift_left(byte op1, int expected) + { + int result = op1 < 42 ? 10 : 5; + Assert.Equal(expected, result); + } + + [Theory] + [InlineData(12, -13)] + [InlineData(45, -25)] + [MethodImpl(MethodImplOptions.NoInlining)] + public static void shift_right_arithmetic(byte op1, int expected) + { + int result = op1 > 42 ? -25 : -13; + Assert.Equal(expected, result); + } + + [Theory] + [InlineData(12, 0x7FFF_FFF3)] + [InlineData(45, -25)] + [MethodImpl(MethodImplOptions.NoInlining)] + public static void shift_right_logic(byte op1, int expected) + { + int result = op1 < 42 ? 0x7FFF_FFF3 : -25; + Assert.Equal(expected, result); + } + + [Theory] + [InlineData(12, 0x7FFF_FFFF_FFFF_FFF3ul)] + [InlineData(45, 0xFFFF_FFFF_FFFF_FFE7ul)] + [MethodImpl(MethodImplOptions.NoInlining)] + public static void shift_right_logic_ulong(byte op1, ulong expected) + { + ulong result = op1 < 42 ? 0x7FFF_FFFF_FFFF_FFF3ul : 0xFFFF_FFFF_FFFF_FFE7ul; + Assert.Equal(expected, result); + } + + [Theory] + [InlineData(12, 0x7FFF_FFF3)] + [InlineData(45, 0xFFFF_FFE7)] + [MethodImpl(MethodImplOptions.NoInlining)] + public static void shift_right_logic_long_32(byte op1, long expected) + { + long result = op1 > 42 ? 0xFFFF_FFE7 : 0x7FFF_FFF3; + Assert.Equal(expected, result); + } + + [Theory] + [InlineData(12, 64)] + [InlineData(45, 0)] + [MethodImpl(MethodImplOptions.NoInlining)] + public static void pow2_or_zero(byte op1, int expected) + { + int result = op1 < 42 ? 64 : 0; + Assert.Equal(expected, result); + } + + [Theory] + [InlineData(12, long.MinValue)] + [InlineData(45, 0)] + [MethodImpl(MethodImplOptions.NoInlining)] + public static void pow2_or_zero_long(byte op1, long expected) + { + long result = op1 >= 42 ? 0 : long.MinValue; + Assert.Equal(expected, result); + } + + [Theory] + [InlineData(12, 0xFFFF_FFFF_8000_0000ul)] + [InlineData(45, 0ul)] + [MethodImpl(MethodImplOptions.NoInlining)] + public static void pow2_or_zero_ulong_32(byte op1, ulong expected) + { + ulong result = op1 < 42 ? 0xFFFF_FFFF_8000_0000ul : 0ul; + Assert.Equal(expected, result); + } +} + +public class ConditionalSimpleOpVariableTest +{ + [Theory] + [InlineData(11, 12)] + [InlineData(12, 13)] + [InlineData(45, 45)] + [MethodImpl(MethodImplOptions.NoInlining)] + public static void add_var(int a, int expected) + { + a = a < 42 ? a + 1 : a; + Assert.Equal(expected, a); + } + + [Theory] + [InlineData(11, 12)] + [InlineData(12, 13)] + [InlineData(45, 45)] + [MethodImpl(MethodImplOptions.NoInlining)] + public static void add_var_no_else(int a, int expected) + { + if (a < 42) + a++; + Assert.Equal(expected, a); + } + + [Theory] + [InlineData(11, 12)] + [InlineData(12, 13)] + [InlineData(45, 45)] + [MethodImpl(MethodImplOptions.NoInlining)] + public static void add_var_reversed(int a, int expected) + { + a = a > 42 ? a : ++a; + Assert.Equal(expected, a); + } + + [Theory] + [InlineData(12, 13)] + [InlineData(13, 13)] + [InlineData(45, 45)] + [MethodImpl(MethodImplOptions.NoInlining)] + public static void or_var(int a, int expected) + { + a = a < 42 ? a | 1 : a; + Assert.Equal(expected, a); + } + + [Theory] + [InlineData(12, 13)] + [InlineData(13, 13)] + [InlineData(45, 45)] + [MethodImpl(MethodImplOptions.NoInlining)] + public static void or_var_no_else(int a, int expected) + { + if (a < 42) + a = a | 1; + Assert.Equal(expected, a); + } + + [Theory] + [InlineData(11, 10)] + [InlineData(12, 11)] + [InlineData(45, 45)] + [MethodImpl(MethodImplOptions.NoInlining)] + public static void sub_var(int a, int expected) + { + a = a < 42 ? a - 1 : a; + Assert.Equal(expected, a); + } + + public static int globVar = 0; + [Theory] + [InlineData(11, 10)] + [InlineData(12, 11)] + [InlineData(45, 45)] + [MethodImpl(MethodImplOptions.NoInlining)] + public static void sub_var_globref(int a, int expected) + { + a = (a + globVar) < 42 ? a - 1 : a; + Assert.Equal(expected, a); + } + + [Theory] + [InlineData(11, 10)] + [InlineData(12, 11)] + [InlineData(45, 45)] + [MethodImpl(MethodImplOptions.NoInlining)] + public static void sub_var_globref_no_else(int a, int expected) + { + if ((a + globVar) < 42) + --a; + Assert.Equal(expected, a); + } + + [Theory] + [InlineData(12, 13)] + [InlineData(13, 12)] + [InlineData(45, 45)] + [MethodImpl(MethodImplOptions.NoInlining)] + public static void xor_var(int a, int expected) + { + a = a < 42 ? a ^ 1 : a; + Assert.Equal(expected, a); + } + + [Theory] + [InlineData(-12, -24)] + [InlineData(12, 24)] + [InlineData(43, 43)] + [MethodImpl(MethodImplOptions.NoInlining)] + public static void shift_left_var(int a, int expected) + { + long result = a > 42 ? a : a * 2; + Assert.Equal(expected, result); + } + + [Theory] + [InlineData(-12, -24)] + [InlineData(12, 24)] + [InlineData(43, 43)] + [MethodImpl(MethodImplOptions.NoInlining)] + public static void shift_left_var_no_else(int a, int expected) + { + long result = a; + if (a <= 42) + result *= 2; + Assert.Equal(expected, result); + } + + [Theory] + [InlineData(-12, -24)] + [InlineData(12, 24)] + [InlineData(43, 3)] + [MethodImpl(MethodImplOptions.NoInlining)] + public static void shift_left_var_no_else_different_var(long a, long expected) + { + long result = 3; + if (a <= 42) + result = a * 2; + Assert.Equal(expected, result); + } + + [Theory] + [InlineData(12, 6)] + [InlineData(-25, -13)] + [InlineData(45, 45)] + [InlineData(-4000_000_000_000l, -2000_000_000_000l)] + [MethodImpl(MethodImplOptions.NoInlining)] + public static void shift_right_arithmetic_var(long a, long expected) + { + long result = a > 42 ? a : a >> 1; + Assert.Equal(expected, result); + } + + [Theory] + [InlineData(43, 21)] + [InlineData(0x8000_0000, 0x4000_0000)] + [InlineData(12, 12)] + [MethodImpl(MethodImplOptions.NoInlining)] + public static void shift_right_logic_var(uint a, uint expected) + { + uint result = a > 42 ? a >> 1 : a; + Assert.Equal(expected, result); + } + + [Theory] + [InlineData(44, 0)] + [InlineData(43, 1)] + [InlineData(11, 0)] + [MethodImpl(MethodImplOptions.NoInlining)] + public static void and_or_zero_var(int a, int expected) + { + int result = a > 42 ? a & 1 : 0; + Assert.Equal(expected, result); + } + + [Theory] + [InlineData(44, 0)] + [InlineData(43, 0)] + [InlineData(11, 1)] + [InlineData(10, 0)] + [MethodImpl(MethodImplOptions.NoInlining)] + public static void and_or_zero_var_globref_reversed(uint a, uint expected) + { + uint result = (a ^ globVar) > 42 ? 0 : a & 1; + Assert.Equal(expected, result); + } + + [Theory] + [InlineData(4, 16)] + [InlineData(6, 64)] + [InlineData(43, 0)] + [MethodImpl(MethodImplOptions.NoInlining)] + public static void pow2_or_zero_var(int a, int expected) + { + int result = a > 42 ? 0 : 1 << a; + Assert.Equal(expected, result); + } +} diff --git a/src/tests/JIT/opt/Compares/conditionalSimpleOps.csproj b/src/tests/JIT/opt/Compares/conditionalSimpleOps.csproj new file mode 100644 index 00000000000000..b042a45bea9216 --- /dev/null +++ b/src/tests/JIT/opt/Compares/conditionalSimpleOps.csproj @@ -0,0 +1,15 @@ + + + + true + + + None + True + + + + + + +