Skip to content

Commit

Permalink
[RISCV] Merge GPRPair and GPRF64Pair (llvm#116094)
Browse files Browse the repository at this point in the history
As suggested by Craig, this tries to merge the two sets of register
classes created in llvm#112983, GPRPair* and GPRF64Pair*.

- I added some explicit annotations to `RISCVInstrInfoD.td` which fixed
the type inference issues I was seeing from tablegen for select
patterns.
- I've had to make the behaviour of `splitValueIntoRegisterParts` and
`joinRegisterPartsIntoValue` cover more cases, because you cannot
bitcast to/from untyped (the bitcast would otherwise have been inserted
automatically by TargetLowering code).
- I apparently didn't need to change `getNumRegisters` again, which
continues to tell me there's a bug in the code for tied inputs. I added
some more test coverage of this case but it didn't seem to help find the
asserts I was finding before - I think the difference is between the
default behaviour for integers which doesn't apply to floats.
- There's still a difference between BuildGPRPair and BuildPairF64 (and
the same for SplitGPRPair and SplitF64). I'm not happy with this, I
think it's quite confusing, as they're very similar, just differing in
whether they give a `untyped` or a `f64`. I haven't really worked out
how the DAGCombiner copes if one meets the other, I know we have some of
this for the f64 variants already, but they're a lot more complex than
the GPRPair variants anyway.
  • Loading branch information
lenary authored Nov 20, 2024
1 parent 77bf34c commit 408659c
Show file tree
Hide file tree
Showing 6 changed files with 90 additions and 64 deletions.
10 changes: 2 additions & 8 deletions llvm/lib/Target/RISCV/AsmParser/RISCVAsmParser.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -497,16 +497,10 @@ struct RISCVOperand final : public MCParsedAsmOperand {
RISCVMCRegisterClasses[RISCV::GPRF32RegClassID].contains(Reg.RegNum);
}

bool isGPRF64Pair() const {
return Kind == KindTy::Register &&
RISCVMCRegisterClasses[RISCV::GPRF64PairRegClassID].contains(
Reg.RegNum);
}

bool isGPRAsFPR() const { return isGPR() && Reg.IsGPRAsFPR; }
bool isGPRAsFPR16() const { return isGPRF16() && Reg.IsGPRAsFPR; }
bool isGPRAsFPR32() const { return isGPRF32() && Reg.IsGPRAsFPR; }
bool isGPRPairAsFPR64() const { return isGPRF64Pair() && Reg.IsGPRAsFPR; }
bool isGPRPairAsFPR64() const { return isGPRPair() && Reg.IsGPRAsFPR; }

static bool evaluateConstantImm(const MCExpr *Expr, int64_t &Imm,
RISCVMCExpr::VariantKind &VK) {
Expand Down Expand Up @@ -2405,7 +2399,7 @@ ParseStatus RISCVAsmParser::parseGPRPairAsFPR64(OperandVector &Operands) {
const MCRegisterInfo *RI = getContext().getRegisterInfo();
MCRegister Pair = RI->getMatchingSuperReg(
Reg, RISCV::sub_gpr_even,
&RISCVMCRegisterClasses[RISCV::GPRF64PairRegClassID]);
&RISCVMCRegisterClasses[RISCV::GPRPairRegClassID]);
Operands.push_back(RISCVOperand::createReg(Pair, S, E, /*isGPRAsFPR=*/true));
return ParseStatus::Success;
}
Expand Down
17 changes: 6 additions & 11 deletions llvm/lib/Target/RISCV/RISCVISelDAGToDAG.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -958,20 +958,14 @@ void RISCVDAGToDAGISel::Select(SDNode *Node) {
assert((!Subtarget->is64Bit() || Opcode == RISCVISD::BuildGPRPair) &&
"BuildPairF64 only handled here on rv32i_zdinx");

int RegClassID = (Opcode == RISCVISD::BuildGPRPair)
? RISCV::GPRPairRegClassID
: RISCV::GPRF64PairRegClassID;
MVT OutType = (Opcode == RISCVISD::BuildGPRPair) ? MVT::Untyped : MVT::f64;

SDValue Ops[] = {
CurDAG->getTargetConstant(RegClassID, DL, MVT::i32),
CurDAG->getTargetConstant(RISCV::GPRPairRegClassID, DL, MVT::i32),
Node->getOperand(0),
CurDAG->getTargetConstant(RISCV::sub_gpr_even, DL, MVT::i32),
Node->getOperand(1),
CurDAG->getTargetConstant(RISCV::sub_gpr_odd, DL, MVT::i32)};

SDNode *N =
CurDAG->getMachineNode(TargetOpcode::REG_SEQUENCE, DL, OutType, Ops);
SDNode *N = CurDAG->getMachineNode(TargetOpcode::REG_SEQUENCE, DL, VT, Ops);
ReplaceNode(Node, N);
return;
}
Expand All @@ -982,14 +976,15 @@ void RISCVDAGToDAGISel::Select(SDNode *Node) {
"SplitF64 only handled here on rv32i_zdinx");

if (!SDValue(Node, 0).use_empty()) {
SDValue Lo = CurDAG->getTargetExtractSubreg(RISCV::sub_gpr_even, DL, VT,
SDValue Lo = CurDAG->getTargetExtractSubreg(RISCV::sub_gpr_even, DL,
Node->getValueType(0),
Node->getOperand(0));
ReplaceUses(SDValue(Node, 0), Lo);
}

if (!SDValue(Node, 1).use_empty()) {
SDValue Hi = CurDAG->getTargetExtractSubreg(RISCV::sub_gpr_odd, DL, VT,
Node->getOperand(0));
SDValue Hi = CurDAG->getTargetExtractSubreg(
RISCV::sub_gpr_odd, DL, Node->getValueType(1), Node->getOperand(0));
ReplaceUses(SDValue(Node, 1), Hi);
}

Expand Down
51 changes: 34 additions & 17 deletions llvm/lib/Target/RISCV/RISCVISelLowering.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -133,7 +133,7 @@ RISCVTargetLowering::RISCVTargetLowering(const TargetMachine &TM,
if (Subtarget.is64Bit())
addRegisterClass(MVT::f64, &RISCV::GPRRegClass);
else
addRegisterClass(MVT::f64, &RISCV::GPRF64PairRegClass);
addRegisterClass(MVT::f64, &RISCV::GPRPairRegClass);
}

static const MVT::SimpleValueType BoolVecVTs[] = {
Expand Down Expand Up @@ -20507,7 +20507,7 @@ RISCVTargetLowering::getRegForInlineAsmConstraint(const TargetRegisterInfo *TRI,
if (VT == MVT::f32 && Subtarget.hasStdExtZfinx())
return std::make_pair(0U, &RISCV::GPRF32NoX0RegClass);
if (VT == MVT::f64 && Subtarget.hasStdExtZdinx() && !Subtarget.is64Bit())
return std::make_pair(0U, &RISCV::GPRF64PairNoX0RegClass);
return std::make_pair(0U, &RISCV::GPRPairNoX0RegClass);
return std::make_pair(0U, &RISCV::GPRNoX0RegClass);
case 'f':
if (VT == MVT::f16) {
Expand All @@ -20524,14 +20524,14 @@ RISCVTargetLowering::getRegForInlineAsmConstraint(const TargetRegisterInfo *TRI,
if (Subtarget.hasStdExtD())
return std::make_pair(0U, &RISCV::FPR64RegClass);
if (Subtarget.hasStdExtZdinx() && !Subtarget.is64Bit())
return std::make_pair(0U, &RISCV::GPRF64PairNoX0RegClass);
return std::make_pair(0U, &RISCV::GPRPairNoX0RegClass);
if (Subtarget.hasStdExtZdinx() && Subtarget.is64Bit())
return std::make_pair(0U, &RISCV::GPRNoX0RegClass);
}
break;
case 'R':
if (VT == MVT::f64 && !Subtarget.is64Bit() && Subtarget.hasStdExtZdinx())
return std::make_pair(0U, &RISCV::GPRF64PairNoX0RegClass);
return std::make_pair(0U, &RISCV::GPRPairNoX0RegClass);
return std::make_pair(0U, &RISCV::GPRPairNoX0RegClass);
default:
break;
Expand Down Expand Up @@ -20570,7 +20570,7 @@ RISCVTargetLowering::getRegForInlineAsmConstraint(const TargetRegisterInfo *TRI,
if (VT == MVT::f32 && Subtarget.hasStdExtZfinx())
return std::make_pair(0U, &RISCV::GPRF32CRegClass);
if (VT == MVT::f64 && Subtarget.hasStdExtZdinx() && !Subtarget.is64Bit())
return std::make_pair(0U, &RISCV::GPRF64PairCRegClass);
return std::make_pair(0U, &RISCV::GPRPairCRegClass);
if (!VT.isVector())
return std::make_pair(0U, &RISCV::GPRCRegClass);
} else if (Constraint == "cf") {
Expand All @@ -20588,7 +20588,7 @@ RISCVTargetLowering::getRegForInlineAsmConstraint(const TargetRegisterInfo *TRI,
if (Subtarget.hasStdExtD())
return std::make_pair(0U, &RISCV::FPR64CRegClass);
if (Subtarget.hasStdExtZdinx() && !Subtarget.is64Bit())
return std::make_pair(0U, &RISCV::GPRF64PairCRegClass);
return std::make_pair(0U, &RISCV::GPRPairCRegClass);
if (Subtarget.hasStdExtZdinx() && Subtarget.is64Bit())
return std::make_pair(0U, &RISCV::GPRCRegClass);
}
Expand Down Expand Up @@ -20752,7 +20752,7 @@ RISCVTargetLowering::getRegForInlineAsmConstraint(const TargetRegisterInfo *TRI,
// Subtarget into account.
if (Res.second == &RISCV::GPRF16RegClass ||
Res.second == &RISCV::GPRF32RegClass ||
Res.second == &RISCV::GPRF64PairRegClass)
Res.second == &RISCV::GPRPairRegClass)
return std::make_pair(Res.first, &RISCV::GPRRegClass);

return Res;
Expand Down Expand Up @@ -21379,12 +21379,19 @@ bool RISCVTargetLowering::splitValueIntoRegisterParts(
bool IsABIRegCopy = CC.has_value();
EVT ValueVT = Val.getValueType();

if (ValueVT == (Subtarget.is64Bit() ? MVT::i128 : MVT::i64) &&
MVT PairVT = Subtarget.is64Bit() ? MVT::i128 : MVT::i64;
if ((ValueVT == PairVT ||
(!Subtarget.is64Bit() && Subtarget.hasStdExtZdinx() &&
ValueVT == MVT::f64)) &&
NumParts == 1 && PartVT == MVT::Untyped) {
// Pairs in Inline Assembly
// Pairs in Inline Assembly, f64 in Inline assembly on rv32_zdinx
MVT XLenVT = Subtarget.getXLenVT();
if (ValueVT == MVT::f64)
Val = DAG.getBitcast(MVT::i64, Val);
auto [Lo, Hi] = DAG.SplitScalar(Val, DL, XLenVT, XLenVT);
Parts[0] = DAG.getNode(RISCVISD::BuildGPRPair, DL, MVT::Untyped, Lo, Hi);
// Always creating an MVT::Untyped part, so always use
// RISCVISD::BuildGPRPair.
Parts[0] = DAG.getNode(RISCVISD::BuildGPRPair, DL, PartVT, Lo, Hi);
return true;
}

Expand All @@ -21396,7 +21403,7 @@ bool RISCVTargetLowering::splitValueIntoRegisterParts(
Val = DAG.getNode(ISD::ANY_EXTEND, DL, MVT::i32, Val);
Val = DAG.getNode(ISD::OR, DL, MVT::i32, Val,
DAG.getConstant(0xFFFF0000, DL, MVT::i32));
Val = DAG.getNode(ISD::BITCAST, DL, MVT::f32, Val);
Val = DAG.getNode(ISD::BITCAST, DL, PartVT, Val);
Parts[0] = Val;
return true;
}
Expand Down Expand Up @@ -21465,14 +21472,24 @@ SDValue RISCVTargetLowering::joinRegisterPartsIntoValue(
MVT PartVT, EVT ValueVT, std::optional<CallingConv::ID> CC) const {
bool IsABIRegCopy = CC.has_value();

if (ValueVT == (Subtarget.is64Bit() ? MVT::i128 : MVT::i64) &&
MVT PairVT = Subtarget.is64Bit() ? MVT::i128 : MVT::i64;
if ((ValueVT == PairVT ||
(!Subtarget.is64Bit() && Subtarget.hasStdExtZdinx() &&
ValueVT == MVT::f64)) &&
NumParts == 1 && PartVT == MVT::Untyped) {
// Pairs in Inline Assembly
// Pairs in Inline Assembly, f64 in Inline assembly on rv32_zdinx
MVT XLenVT = Subtarget.getXLenVT();
SDValue Res = DAG.getNode(RISCVISD::SplitGPRPair, DL,
DAG.getVTList(XLenVT, XLenVT), Parts[0]);
return DAG.getNode(ISD::BUILD_PAIR, DL, ValueVT, Res.getValue(0),
Res.getValue(1));

SDValue Val = Parts[0];
// Always starting with an MVT::Untyped part, so always use
// RISCVISD::SplitGPRPair
Val = DAG.getNode(RISCVISD::SplitGPRPair, DL, DAG.getVTList(XLenVT, XLenVT),
Val);
Val = DAG.getNode(ISD::BUILD_PAIR, DL, PairVT, Val.getValue(0),
Val.getValue(1));
if (ValueVT == MVT::f64)
Val = DAG.getBitcast(ValueVT, Val);
return Val;
}

if (IsABIRegCopy && (ValueVT == MVT::f16 || ValueVT == MVT::bf16) &&
Expand Down
18 changes: 9 additions & 9 deletions llvm/lib/Target/RISCV/RISCVInstrInfoD.td
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,7 @@ def FPR64INX : RegisterOperand<GPR> {
let DecoderMethod = "DecodeGPRRegisterClass";
}

def FPR64IN32X : RegisterOperand<GPRF64Pair> {
def FPR64IN32X : RegisterOperand<GPRPair> {
let ParserMatchClass = GPRPairAsFPR;
}

Expand Down Expand Up @@ -457,16 +457,16 @@ def : PatSetCC<FPR64INX, any_fsetccs, SETOLE, FLE_D_INX, f64>;

let Predicates = [HasStdExtZdinx, IsRV32] in {
// Match signaling FEQ_D
def : Pat<(XLenVT (strict_fsetccs FPR64IN32X:$rs1, FPR64IN32X:$rs2, SETEQ)),
def : Pat<(XLenVT (strict_fsetccs (f64 FPR64IN32X:$rs1), FPR64IN32X:$rs2, SETEQ)),
(AND (XLenVT (FLE_D_IN32X $rs1, $rs2)),
(XLenVT (FLE_D_IN32X $rs2, $rs1)))>;
def : Pat<(XLenVT (strict_fsetccs FPR64IN32X:$rs1, FPR64IN32X:$rs2, SETOEQ)),
def : Pat<(XLenVT (strict_fsetccs (f64 FPR64IN32X:$rs1), FPR64IN32X:$rs2, SETOEQ)),
(AND (XLenVT (FLE_D_IN32X $rs1, $rs2)),
(XLenVT (FLE_D_IN32X $rs2, $rs1)))>;
// If both operands are the same, use a single FLE.
def : Pat<(XLenVT (strict_fsetccs FPR64IN32X:$rs1, FPR64IN32X:$rs1, SETEQ)),
def : Pat<(XLenVT (strict_fsetccs (f64 FPR64IN32X:$rs1), FPR64IN32X:$rs1, SETEQ)),
(FLE_D_IN32X $rs1, $rs1)>;
def : Pat<(XLenVT (strict_fsetccs FPR64IN32X:$rs1, FPR64IN32X:$rs1, SETOEQ)),
def : Pat<(XLenVT (strict_fsetccs (f64 FPR64IN32X:$rs1), FPR64IN32X:$rs1, SETOEQ)),
(FLE_D_IN32X $rs1, $rs1)>;

def : PatSetCC<FPR64IN32X, any_fsetccs, SETLT, FLT_D_IN32X, f64>;
Expand Down Expand Up @@ -523,15 +523,15 @@ def PseudoFROUND_D_IN32X : PseudoFROUND<FPR64IN32X, f64>;

/// Loads
let isCall = 0, mayLoad = 1, mayStore = 0, Size = 8, isCodeGenOnly = 1 in
def PseudoRV32ZdinxLD : Pseudo<(outs GPRF64Pair:$dst), (ins GPR:$rs1, simm12:$imm12), []>;
def PseudoRV32ZdinxLD : Pseudo<(outs GPRPair:$dst), (ins GPR:$rs1, simm12:$imm12), []>;
def : Pat<(f64 (load (AddrRegImmINX (XLenVT GPR:$rs1), simm12:$imm12))),
(PseudoRV32ZdinxLD GPR:$rs1, simm12:$imm12)>;

/// Stores
let isCall = 0, mayLoad = 0, mayStore = 1, Size = 8, isCodeGenOnly = 1 in
def PseudoRV32ZdinxSD : Pseudo<(outs), (ins GPRF64Pair:$rs2, GPRNoX0:$rs1, simm12:$imm12), []>;
def : Pat<(store (f64 GPRF64Pair:$rs2), (AddrRegImmINX (XLenVT GPR:$rs1), simm12:$imm12)),
(PseudoRV32ZdinxSD GPRF64Pair:$rs2, GPR:$rs1, simm12:$imm12)>;
def PseudoRV32ZdinxSD : Pseudo<(outs), (ins GPRPair:$rs2, GPRNoX0:$rs1, simm12:$imm12), []>;
def : Pat<(store (f64 GPRPair:$rs2), (AddrRegImmINX (XLenVT GPR:$rs1), simm12:$imm12)),
(PseudoRV32ZdinxSD GPRPair:$rs2, GPR:$rs1, simm12:$imm12)>;
} // Predicates = [HasStdExtZdinx, IsRV32]

let Predicates = [HasStdExtD, IsRV32] in {
Expand Down
22 changes: 3 additions & 19 deletions llvm/lib/Target/RISCV/RISCVRegisterInfo.td
Original file line number Diff line number Diff line change
Expand Up @@ -325,7 +325,7 @@ let RegAltNameIndices = [ABIRegAltName] in {

let RegInfos = XLenPairRI,
DecoderMethod = "DecodeGPRPairRegisterClass" in {
def GPRPair : RISCVRegisterClass<[XLenPairVT], 64, (add
def GPRPair : RISCVRegisterClass<[XLenPairVT, XLenPairFVT], 64, (add
X10_X11, X12_X13, X14_X15, X16_X17,
X6_X7,
X28_X29, X30_X31,
Expand All @@ -334,11 +334,11 @@ def GPRPair : RISCVRegisterClass<[XLenPairVT], 64, (add
X0_Pair, X2_X3, X4_X5
)>;

def GPRPairNoX0 : RISCVRegisterClass<[XLenPairVT], 64, (sub GPRPair, X0_Pair)>;
def GPRPairNoX0 : RISCVRegisterClass<[XLenPairVT, XLenPairFVT], 64, (sub GPRPair, X0_Pair)>;
} // let RegInfos = XLenPairRI, DecoderMethod = "DecodeGPRPairRegisterClass"

let RegInfos = XLenPairRI in
def GPRPairC : RISCVRegisterClass<[XLenPairVT], 64, (add
def GPRPairC : RISCVRegisterClass<[XLenPairVT, XLenPairFVT], 64, (add
X10_X11, X12_X13, X14_X15, X8_X9
)>;

Expand Down Expand Up @@ -464,22 +464,6 @@ def GPRF32C : RISCVRegisterClass<[f32], 32, (add (sequence "X%u_W", 10, 15),
(sequence "X%u_W", 8, 9))>;
def GPRF32NoX0 : RISCVRegisterClass<[f32], 32, (sub GPRF32, X0_W)>;

let DecoderMethod = "DecodeGPRPairRegisterClass" in
def GPRF64Pair : RISCVRegisterClass<[XLenPairFVT], 64, (add
X10_X11, X12_X13, X14_X15, X16_X17,
X6_X7,
X28_X29, X30_X31,
X8_X9,
X18_X19, X20_X21, X22_X23, X24_X25, X26_X27,
X0_Pair, X2_X3, X4_X5
)>;

def GPRF64PairC : RISCVRegisterClass<[XLenPairFVT], 64, (add
X10_X11, X12_X13, X14_X15, X8_X9
)>;

def GPRF64PairNoX0 : RISCVRegisterClass<[XLenPairFVT], 64, (sub GPRF64Pair, X0_Pair)>;

//===----------------------------------------------------------------------===//
// Vector type mapping to LLVM types.
//===----------------------------------------------------------------------===//
Expand Down
36 changes: 36 additions & 0 deletions llvm/test/CodeGen/RISCV/zdinx-asm-constraint.ll
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,42 @@ entry:
ret void
}

define dso_local void @zdinx_asm_inout(ptr nocapture noundef writeonly %a, double noundef %b) nounwind {
; CHECK-LABEL: zdinx_asm_inout:
; CHECK: # %bb.0: # %entry
; CHECK-NEXT: mv a3, a2
; CHECK-NEXT: mv a2, a1
; CHECK-NEXT: #APP
; CHECK-NEXT: fmv.d a2, a2
; CHECK-NEXT: #NO_APP
; CHECK-NEXT: sw a2, 8(a0)
; CHECK-NEXT: sw a3, 12(a0)
; CHECK-NEXT: ret
entry:
%arrayidx = getelementptr inbounds double, ptr %a, i32 1
%0 = tail call double asm "fsgnj.d $0, $1, $1", "=r,0"(double %b)
store double %0, ptr %arrayidx, align 8
ret void
}

define dso_local void @zdinx_asm_Pr_inout(ptr nocapture noundef writeonly %a, double noundef %b) nounwind {
; CHECK-LABEL: zdinx_asm_Pr_inout:
; CHECK: # %bb.0: # %entry
; CHECK-NEXT: mv a3, a2
; CHECK-NEXT: mv a2, a1
; CHECK-NEXT: #APP
; CHECK-NEXT: fabs.d a2, a2
; CHECK-NEXT: #NO_APP
; CHECK-NEXT: sw a2, 8(a0)
; CHECK-NEXT: sw a3, 12(a0)
; CHECK-NEXT: ret
entry:
%arrayidx = getelementptr inbounds double, ptr %a, i32 1
%0 = tail call double asm "fsgnjx.d $0, $1, $1", "=R,0"(double %b)
store double %0, ptr %arrayidx, align 8
ret void
}

define dso_local void @zfinx_asm(ptr nocapture noundef writeonly %a, float noundef %b, float noundef %c) nounwind {
; CHECK-LABEL: zfinx_asm:
; CHECK: # %bb.0: # %entry
Expand Down

0 comments on commit 408659c

Please sign in to comment.