Skip to content
Open
73 changes: 66 additions & 7 deletions llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@
#include "llvm/CodeGen/MachineFunction.h"
#include "llvm/CodeGen/MachineJumpTableInfo.h"
#include "llvm/CodeGen/MachineMemOperand.h"
#include "llvm/CodeGen/SDPatternMatch.h"
#include "llvm/CodeGen/SelectionDAG.h"
#include "llvm/CodeGen/SelectionDAGNodes.h"
#include "llvm/CodeGen/TargetCallingConv.h"
Expand Down Expand Up @@ -867,13 +868,14 @@ NVPTXTargetLowering::NVPTXTargetLowering(const NVPTXTargetMachine &TM,

// We have some custom DAG combine patterns for these nodes
setTargetDAGCombine(
{ISD::ADD, ISD::AND, ISD::EXTRACT_VECTOR_ELT,
ISD::FADD, ISD::FMAXNUM, ISD::FMINNUM,
ISD::FMAXIMUM, ISD::FMINIMUM, ISD::FMAXIMUMNUM,
ISD::FMINIMUMNUM, ISD::MUL, ISD::SHL,
ISD::SREM, ISD::UREM, ISD::VSELECT,
ISD::BUILD_VECTOR, ISD::ADDRSPACECAST, ISD::LOAD,
ISD::STORE, ISD::ZERO_EXTEND, ISD::SIGN_EXTEND});
{ISD::ADD, ISD::AND, ISD::EXTRACT_VECTOR_ELT,
ISD::FADD, ISD::FMAXNUM, ISD::FMINNUM,
ISD::FMAXIMUM, ISD::FMINIMUM, ISD::FMAXIMUMNUM,
ISD::FMINIMUMNUM, ISD::MUL, ISD::SELECT,
ISD::SHL, ISD::SREM, ISD::UREM,
ISD::VSELECT, ISD::BUILD_VECTOR, ISD::ADDRSPACECAST,
ISD::LOAD, ISD::STORE, ISD::ZERO_EXTEND,
ISD::SIGN_EXTEND});

// setcc for f16x2 and bf16x2 needs special handling to prevent
// legalizer's attempt to scalarize it due to v2i1 not being legal.
Expand Down Expand Up @@ -6233,6 +6235,61 @@ static SDValue PerformEXTRACTCombine(SDNode *N,
return Result;
}

/// Fold a guarded logical shift into a single clamping shift node.
///
/// Transform patterns like:
///   (select (ugt shift_amt, BitWidth-1), 0, (srl/shl x, shift_amt))
///   (select (ult shift_amt, BitWidth), (srl/shl x, shift_amt), 0)
/// Into:
///   (NVPTXISD::SRL_CLAMP x, shift_amt) or (NVPTXISD::SHL_CLAMP x, shift_amt)
///
/// These patterns arise from C/C++ code like `shift >= 32 ? 0 : x >> shift`
/// which guards against undefined behavior. PTX shr/shl instructions clamp
/// shift amounts >= BitWidth to produce 0 for logical shifts, making the
/// guard redundant.
///
/// Note: We only handle SRL and SHL, not SRA, because arithmetic right
/// shifts could produce 0 or -1 when shift >= BitWidth.
/// Note: We don't handle uge or ule. These don't appear because of
/// canonicalization.
/// Note: The shift may use a truncated copy of the guarded value. In that case
/// the guard tests bits that the shift never sees, so the fold is only
/// performed when those truncated-away bits are known to be zero.
///
/// \param N   The ISD::SELECT node being combined.
/// \param DCI Combiner state; the fold only runs after DAG legalization.
/// \returns The replacement clamp-shift node, or an empty SDValue when the
///          pattern does not match or the fold cannot be proven safe.
static SDValue PerformSELECTShiftCombine(SDNode *N,
                                         TargetLowering::DAGCombinerInfo &DCI) {
  if (!DCI.isAfterLegalizeDAG())
    return SDValue();

  using namespace SDPatternMatch;
  unsigned BitWidth = N->getValueType(0).getSizeInBits();
  SDValue ShiftAmt, ShiftOp;

  // Match logical shifts where the shift amount in the guard matches the shift
  // amount in the operation (possibly through a truncate).
  auto LogicalShift =
      m_AllOf(m_Value(ShiftOp),
              m_AnyOf(m_Srl(m_Value(), m_TruncOrSelf(m_Deferred(ShiftAmt))),
                      m_Shl(m_Value(), m_TruncOrSelf(m_Deferred(ShiftAmt)))));

  // shift_amt > BitWidth-1 ? 0 : shift_op
  bool MatchedUGT =
      sd_match(N, m_Select(m_SetCC(m_Value(ShiftAmt),
                                   m_SpecificInt(APInt(BitWidth, BitWidth - 1)),
                                   m_SpecificCondCode(ISD::SETUGT)),
                           m_Zero(), LogicalShift));
  // shift_amt < BitWidth ? shift_op : 0
  bool MatchedULT =
      !MatchedUGT &&
      sd_match(N, m_Select(m_SetCC(m_Value(ShiftAmt),
                                   m_SpecificInt(APInt(BitWidth, BitWidth)),
                                   m_SpecificCondCode(ISD::SETULT)),
                           LogicalShift, m_Zero()));

  if (!MatchedUGT && !MatchedULT)
    return SDValue();

  // If the shift consumes trunc(shift_amt) rather than shift_amt itself, the
  // guard and the shift disagree whenever any truncated-away bit is set: e.g.
  // a 64-bit amount of 2^32 makes the guard select 0, while the truncated
  // amount is 0 and the clamp shift would return x unchanged. Only fold when
  // the dropped high bits are provably zero.
  SDValue ShiftAmtOperand = ShiftOp.getOperand(1);
  if (ShiftAmtOperand != ShiftAmt) {
    const unsigned WideBits = ShiftAmt.getScalarValueSizeInBits();
    const unsigned NarrowBits = ShiftAmtOperand.getScalarValueSizeInBits();
    const APInt DroppedBits =
        APInt::getHighBitsSet(WideBits, WideBits - NarrowBits);
    if (!DCI.DAG.MaskedValueIsZero(ShiftAmt, DroppedBits))
      return SDValue();
  }

  // Return a clamp shift operation, which has the same semantics as PTX shift.
  unsigned ClampOpc = ShiftOp.getOpcode() == ISD::SRL ? NVPTXISD::SRL_CLAMP
                                                      : NVPTXISD::SHL_CLAMP;
  return DCI.DAG.getNode(ClampOpc, SDLoc(N), ShiftOp.getValueType(),
                         ShiftOp.getOperand(0), ShiftAmtOperand);
}

static SDValue PerformVSELECTCombine(SDNode *N,
TargetLowering::DAGCombinerInfo &DCI) {
SDValue VA = N->getOperand(1);
Expand Down Expand Up @@ -6544,6 +6601,8 @@ SDValue NVPTXTargetLowering::PerformDAGCombine(SDNode *N,
case NVPTXISD::StoreV2:
case NVPTXISD::StoreV4:
return combineSTORE(N, DCI, STI);
case ISD::SELECT:
return PerformSELECTShiftCombine(N, DCI);
case ISD::VSELECT:
return PerformVSELECTCombine(N, DCI);
}
Expand Down
9 changes: 9 additions & 0 deletions llvm/lib/Target/NVPTX/NVPTXInstrInfo.td
Original file line number Diff line number Diff line change
Expand Up @@ -1323,6 +1323,15 @@ defm SHL : SHIFT<"shl.b", shl>;
defm SRA : SHIFT<"shr.s", sra>;
defm SRL : SHIFT<"shr.u", srl>;

// Shift with clamping semantics - these have defined behavior for shift amounts
// >= BitWidth (returning 0 for logical shifts). Used to optimize guarded shift
// patterns like `shift >= 32 ? 0 : x >> shift`.
//
// Both nodes are declared with the SDTIntShiftOp type profile and an empty
// node-property list.
def shl_clamp : SDNode<"NVPTXISD::SHL_CLAMP", SDTIntShiftOp, []>;
def srl_clamp : SDNode<"NVPTXISD::SRL_CLAMP", SDTIntShiftOp, []>;

// Instantiate the SHIFT multiclass with the same "shl.b" / "shr.u" mnemonic
// strings that the SHL and SRL instantiations use, so the clamp nodes are
// emitted with the ordinary PTX logical shift mnemonics.
defm SHL_CLAMP : SHIFT<"shl.b", shl_clamp>;
defm SRL_CLAMP : SHIFT<"shr.u", srl_clamp>;

// Bit-reverse
foreach t = [I64RT, I32RT] in
def BREV_ # t.PtxType :
Expand Down
Loading