Skip to content

Commit

Permalink
Remove two subclasses
Browse files Browse the repository at this point in the history
  • Loading branch information
tclin914 committed Aug 19, 2024
1 parent c85f7ff commit 88399e1
Show file tree
Hide file tree
Showing 4 changed files with 61 additions and 132 deletions.
4 changes: 2 additions & 2 deletions llvm/lib/Target/RISCV/RISCVAsmPrinter.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1085,11 +1085,11 @@ void RISCVAsmPrinter::emitMachineConstantPoolValue(
MCSymbol *MCSym;

if (RCPV->isGlobalValue()) {
auto *GV = cast<RISCVConstantPoolConstant>(RCPV)->getGlobalValue();
auto *GV = RCPV->getGlobalValue();
MCSym = getSymbol(GV);
} else {
assert(RCPV->isExtSymbol() && "unrecognized constant pool value");
auto Sym = cast<RISCVConstantPoolSymbol>(RCPV)->getSymbol();
auto Sym = RCPV->getSymbol();
MCSym = GetExternalSymbolSymbol(Sym);
}

Expand Down
98 changes: 44 additions & 54 deletions llvm/lib/Target/RISCV/RISCVConstantPoolValue.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -21,71 +21,61 @@

using namespace llvm;

RISCVConstantPoolValue::RISCVConstantPoolValue(LLVMContext &C, RISCVCPKind Kind)
: MachineConstantPoolValue((Type *)Type::getInt64Ty(C)), Kind(Kind) {}
RISCVConstantPoolValue::RISCVConstantPoolValue(Type *Ty, const GlobalValue *GV)
: MachineConstantPoolValue(Ty), GV(GV), Kind(RISCVCPKind::GlobalValue) {}

RISCVConstantPoolValue::RISCVConstantPoolValue(Type *Ty, RISCVCPKind Kind)
: MachineConstantPoolValue(Ty), Kind(Kind) {}
RISCVConstantPoolValue::RISCVConstantPoolValue(LLVMContext &C, StringRef S)
: MachineConstantPoolValue((Type *)Type::getInt64Ty(C)), S(S),
Kind(RISCVCPKind::ExtSymbol) {}

int RISCVConstantPoolValue::getExistingMachineCPValue(MachineConstantPool *CP,
Align Alignment) {
llvm_unreachable("Shouldn't be calling this directly!");
}

RISCVConstantPoolConstant::RISCVConstantPoolConstant(Type *Ty,
const Constant *GV,
RISCVCPKind Kind)
: RISCVConstantPoolValue(Ty, Kind), CVal(GV) {}

RISCVConstantPoolConstant *
RISCVConstantPoolConstant::Create(const GlobalValue *GV) {
return new RISCVConstantPoolConstant(GV->getType(), GV,
RISCVCPKind::GlobalValue);
}

RISCVConstantPoolConstant *
RISCVConstantPoolConstant::Create(const BlockAddress *BA) {
return new RISCVConstantPoolConstant(BA->getType(), BA,
RISCVCPKind::BlockAddress);
}

int RISCVConstantPoolConstant::getExistingMachineCPValue(
MachineConstantPool *CP, Align Alignment) {
return getExistingMachineCPValueImpl<RISCVConstantPoolConstant>(CP,
Alignment);
}

void RISCVConstantPoolConstant::addSelectionDAGCSEId(FoldingSetNodeID &ID) {
ID.AddPointer(CVal);
RISCVConstantPoolValue *RISCVConstantPoolValue::Create(const GlobalValue *GV) {
return new RISCVConstantPoolValue(GV->getType(), GV);
}

void RISCVConstantPoolConstant::print(raw_ostream &O) const {
O << CVal->getName();
RISCVConstantPoolValue *RISCVConstantPoolValue::Create(LLVMContext &C,
StringRef s) {
return new RISCVConstantPoolValue(C, s);
}

const GlobalValue *RISCVConstantPoolConstant::getGlobalValue() const {
return dyn_cast_or_null<GlobalValue>(CVal);
int RISCVConstantPoolValue::getExistingMachineCPValue(MachineConstantPool *CP,
Align Alignment) {
const std::vector<MachineConstantPoolEntry> &Constants = CP->getConstants();
for (unsigned i = 0, e = Constants.size(); i != e; ++i) {
if (Constants[i].isMachineConstantPoolEntry() &&
Constants[i].getAlign() >= Alignment) {
auto *CPV =
static_cast<RISCVConstantPoolValue *>(Constants[i].Val.MachineCPVal);
if (equals(CPV))
return i;
}
}

return -1;
}

const BlockAddress *RISCVConstantPoolConstant::getBlockAddress() const {
return dyn_cast_or_null<BlockAddress>(CVal);
void RISCVConstantPoolValue::addSelectionDAGCSEId(FoldingSetNodeID &ID) {
if (isGlobalValue())
ID.AddPointer(GV);
else {
assert(isExtSymbol() && "unrecognized constant pool type");
ID.AddString(S);
}
}

RISCVConstantPoolSymbol::RISCVConstantPoolSymbol(LLVMContext &C, StringRef s)
: RISCVConstantPoolValue(C, RISCVCPKind::ExtSymbol), S(s) {}

RISCVConstantPoolSymbol *RISCVConstantPoolSymbol::Create(LLVMContext &C,
StringRef s) {
return new RISCVConstantPoolSymbol(C, s);
void RISCVConstantPoolValue::print(raw_ostream &O) const {
if (isGlobalValue())
O << GV->getName();
else {
assert(isExtSymbol() && "unrecognized constant pool type");
O << S;
}
}

int RISCVConstantPoolSymbol::getExistingMachineCPValue(MachineConstantPool *CP,
Align Alignment) {
return getExistingMachineCPValueImpl<RISCVConstantPoolSymbol>(CP, Alignment);
}
bool RISCVConstantPoolValue::equals(const RISCVConstantPoolValue *A) const {
if (isGlobalValue() && A->isGlobalValue())
return GV == A->GV;
else if (isExtSymbol() && A->isExtSymbol())
return S == A->S;

void RISCVConstantPoolSymbol::addSelectionDAGCSEId(FoldingSetNodeID &ID) {
ID.AddString(S);
return false;
}

void RISCVConstantPoolSymbol::print(raw_ostream &O) const { O << S; }
84 changes: 12 additions & 72 deletions llvm/lib/Target/RISCV/RISCVConstantPoolValue.h
Original file line number Diff line number Diff line change
Expand Up @@ -26,82 +26,26 @@ class LLVMContext;

/// A RISCV-specific constant pool value.
class RISCVConstantPoolValue : public MachineConstantPoolValue {
protected:
enum class RISCVCPKind { ExtSymbol, GlobalValue, BlockAddress };

RISCVConstantPoolValue(LLVMContext &C, RISCVCPKind Kind);

RISCVConstantPoolValue(Type *Ty, RISCVCPKind Kind);

template <typename Derived>
int getExistingMachineCPValueImpl(MachineConstantPool *CP, Align Alignment) {
const std::vector<MachineConstantPoolEntry> &Constants = CP->getConstants();
for (unsigned i = 0, e = Constants.size(); i != e; ++i) {
if (Constants[i].isMachineConstantPoolEntry() &&
Constants[i].getAlign() >= Alignment) {
auto *CPV = static_cast<RISCVConstantPoolValue *>(
Constants[i].Val.MachineCPVal);
if (Derived *APC = dyn_cast<Derived>(CPV))
if (cast<Derived>(this)->equals(APC))
return i;
}
}

return -1;
}
const GlobalValue *GV;
const StringRef S;

RISCVConstantPoolValue(Type *Ty, const GlobalValue *GV);
RISCVConstantPoolValue(LLVMContext &C, StringRef s);

private:
enum class RISCVCPKind { ExtSymbol, GlobalValue };
RISCVCPKind Kind;

public:
~RISCVConstantPoolValue() = default;

bool isExtSymbol() const { return Kind == RISCVCPKind::ExtSymbol; }
bool isGlobalValue() const { return Kind == RISCVCPKind::GlobalValue; }
bool isBlockAddress() const { return Kind == RISCVCPKind::BlockAddress; }

int getExistingMachineCPValue(MachineConstantPool *CP,
Align Alignment) override;

void addSelectionDAGCSEId(FoldingSetNodeID &ID) override {}
};

class RISCVConstantPoolConstant : public RISCVConstantPoolValue {
const Constant *CVal;

RISCVConstantPoolConstant(Type *Ty, const Constant *GV, RISCVCPKind Kind);

public:
static RISCVConstantPoolConstant *Create(const GlobalValue *GV);
static RISCVConstantPoolConstant *Create(const BlockAddress *BA);

const GlobalValue *getGlobalValue() const;
const BlockAddress *getBlockAddress() const;
static RISCVConstantPoolValue *Create(const GlobalValue *GV);
static RISCVConstantPoolValue *Create(LLVMContext &C, StringRef s);

int getExistingMachineCPValue(MachineConstantPool *CP,
Align Alignment) override;

void addSelectionDAGCSEId(FoldingSetNodeID &ID) override;

void print(raw_ostream &O) const override;

bool equals(const RISCVConstantPoolConstant *A) const {
return CVal == A->CVal;
}

static bool classof(const RISCVConstantPoolValue *RCPV) {
return RCPV->isGlobalValue() || RCPV->isBlockAddress();
}
};

class RISCVConstantPoolSymbol : public RISCVConstantPoolValue {
const StringRef S;

RISCVConstantPoolSymbol(LLVMContext &C, StringRef s);

public:
static RISCVConstantPoolSymbol *Create(LLVMContext &C, StringRef s);
bool isGlobalValue() const { return Kind == RISCVCPKind::GlobalValue; }
bool isExtSymbol() const { return Kind == RISCVCPKind::ExtSymbol; }

const GlobalValue *getGlobalValue() const { return GV; }
StringRef getSymbol() const { return S; }

int getExistingMachineCPValue(MachineConstantPool *CP,
Expand All @@ -111,11 +55,7 @@ class RISCVConstantPoolSymbol : public RISCVConstantPoolValue {

void print(raw_ostream &O) const override;

bool equals(const RISCVConstantPoolSymbol *A) const { return S == A->S; }

static bool classof(const RISCVConstantPoolValue *RCPV) {
return RCPV->isExtSymbol();
}
bool equals(const RISCVConstantPoolValue *A) const;
};

} // end namespace llvm
Expand Down
7 changes: 3 additions & 4 deletions llvm/lib/Target/RISCV/RISCVISelLowering.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -7303,8 +7303,7 @@ static SDValue getTargetNode(JumpTableSDNode *N, const SDLoc &DL, EVT Ty,

static SDValue getLargeGlobalAddress(GlobalAddressSDNode *N, SDLoc DL, EVT Ty,
SelectionDAG &DAG) {
RISCVConstantPoolConstant *CPV =
RISCVConstantPoolConstant::Create(N->getGlobal());
RISCVConstantPoolValue *CPV = RISCVConstantPoolValue::Create(N->getGlobal());
SDValue CPAddr = DAG.getTargetConstantPool(CPV, Ty, Align(8));
SDValue LC = DAG.getNode(RISCVISD::LLA, DL, Ty, CPAddr);
return DAG.getLoad(
Expand All @@ -7314,8 +7313,8 @@ static SDValue getLargeGlobalAddress(GlobalAddressSDNode *N, SDLoc DL, EVT Ty,

static SDValue getLargeExternalSymbol(ExternalSymbolSDNode *N, SDLoc DL, EVT Ty,
SelectionDAG &DAG) {
RISCVConstantPoolSymbol *CPV =
RISCVConstantPoolSymbol::Create(*DAG.getContext(), N->getSymbol());
RISCVConstantPoolValue *CPV =
RISCVConstantPoolValue::Create(*DAG.getContext(), N->getSymbol());
SDValue CPAddr = DAG.getTargetConstantPool(CPV, Ty, Align(8));
SDValue LC = DAG.getNode(RISCVISD::LLA, DL, Ty, CPAddr);
return DAG.getLoad(
Expand Down

0 comments on commit 88399e1

Please sign in to comment.