Skip to content

Commit

Permalink
[RISCV] Add basic cost model for vector casting
Browse files Browse the repository at this point in the history
To perform the cost model of vector casting, the patch consider most vector
casts as their scalar form and consider those vector form of free scalr castings
as 1.

Reviewed By: craig.topper

Differential Revision: https://reviews.llvm.org/D121771
  • Loading branch information
Yeting Kuo committed Mar 22, 2022
1 parent 23423c0 commit ecd7a01
Show file tree
Hide file tree
Showing 3 changed files with 359 additions and 0 deletions.
48 changes: 48 additions & 0 deletions llvm/lib/Target/RISCV/RISCVTargetTransformInfo.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
#include "llvm/Analysis/TargetTransformInfo.h"
#include "llvm/CodeGen/BasicTTIImpl.h"
#include "llvm/CodeGen/TargetLowering.h"
#include <cmath>
using namespace llvm;

#define DEBUG_TYPE "riscvtti"
Expand Down Expand Up @@ -218,6 +219,53 @@ InstructionCost RISCVTTIImpl::getGatherScatterOpCost(
return NumLoads * MemOpCost;
}

InstructionCost RISCVTTIImpl::getCastInstrCost(unsigned Opcode, Type *Dst,
Type *Src,
TTI::CastContextHint CCH,
TTI::TargetCostKind CostKind,
const Instruction *I) {
if (isa<VectorType>(Dst) && isa<VectorType>(Src)) {
// FIXME: Need to compute legalizing cost for illegal types.
if (!isTypeLegal(Src) || !isTypeLegal(Dst))
return BaseT::getCastInstrCost(Opcode, Dst, Src, CCH, CostKind, I);

// Skip if element size of Dst or Src is bigger than ELEN.
if (Src->getScalarSizeInBits() > ST->getMaxELENForFixedLengthVectors() ||
Dst->getScalarSizeInBits() > ST->getMaxELENForFixedLengthVectors())
return BaseT::getCastInstrCost(Opcode, Dst, Src, CCH, CostKind, I);

int ISD = TLI->InstructionOpcodeToISD(Opcode);
assert(ISD && "Invalid opcode");

// FIXME: Need to consider vsetvli and lmul.
int PowDiff = (int)Log2_32(Dst->getScalarSizeInBits()) -
(int)Log2_32(Src->getScalarSizeInBits());
switch (ISD) {
case ISD::SIGN_EXTEND:
case ISD::ZERO_EXTEND:
return 1;
case ISD::TRUNCATE:
case ISD::FP_EXTEND:
case ISD::FP_ROUND:
// Counts of narrow/widen instructions.
return std::abs(PowDiff);
case ISD::FP_TO_SINT:
case ISD::FP_TO_UINT:
case ISD::SINT_TO_FP:
case ISD::UINT_TO_FP:
if (std::abs(PowDiff) <= 1)
return 1;
// Backend could lower (v[sz]ext i8 to double) to vfcvt(v[sz]ext.f8 i8),
// so it only need two conversion.
if (Src->isIntOrIntVectorTy())
return 2;
// Counts of narrow/widen instructions.
return std::abs(PowDiff);
}
}
return BaseT::getCastInstrCost(Opcode, Dst, Src, CCH, CostKind, I);
}

InstructionCost
RISCVTTIImpl::getMinMaxReductionCost(VectorType *Ty, VectorType *CondTy,
bool IsUnsigned,
Expand Down
5 changes: 5 additions & 0 deletions llvm/lib/Target/RISCV/RISCVTargetTransformInfo.h
Original file line number Diff line number Diff line change
Expand Up @@ -89,6 +89,11 @@ class RISCVTTIImpl : public BasicTTIImplBase<RISCVTTIImpl> {
TTI::TargetCostKind CostKind,
const Instruction *I);

InstructionCost getCastInstrCost(unsigned Opcode, Type *Dst, Type *Src,
TTI::CastContextHint CCH,
TTI::TargetCostKind CostKind,
const Instruction *I = nullptr);

InstructionCost getMinMaxReductionCost(VectorType *Ty, VectorType *CondTy,
bool IsUnsigned,
TTI::TargetCostKind CostKind);
Expand Down
Loading

0 comments on commit ecd7a01

Please sign in to comment.