Skip to content

Commit

Permalink
[RISCV] Handle zeroinitializer of vector tuple Type (#113995)
Browse files Browse the repository at this point in the history
It doesn't make sense to add a new generic ISD opcode just to handle the
RISC-V vector tuple type. Instead, we reuse the existing `SPLAT_VECTOR` ISD
opcode and lower it further to `VMV_V_X`.

Note: If a `visitSPLAT_VECTOR` is ever added to the generic DAG combiner, it
needs to skip the RISC-V vector tuple type.

Stacked on #114329
  • Loading branch information
4vtomat authored Dec 4, 2024
1 parent 4f41862 commit 109e4a1
Show file tree
Hide file tree
Showing 4 changed files with 79 additions and 1 deletion.
12 changes: 12 additions & 0 deletions llvm/lib/CodeGen/SelectionDAG/SelectionDAGBuilder.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1896,6 +1896,18 @@ SDValue SelectionDAGBuilder::getValueImpl(const Value *V) {
DAG.getConstant(0, getCurSDLoc(), MVT::nxv16i1));
}

// Constant of RISC-V vector tuple type. The only constant accepted here is
// the null value (zeroinitializer); anything else trips the assert below.
// The zero is materialized as a SPLAT_VECTOR of i8 0 and then BITCAST to the
// tuple type, so no dedicated generic ISD opcode is needed for tuples.
if (VT.isRISCVVectorTuple()) {
assert(C->isNullValue() && "Can only zero this target type!");
// Cache the resulting node in NodeMap so later lookups of V reuse it.
return NodeMap[V] = DAG.getNode(
ISD::BITCAST, getCurSDLoc(), VT,
DAG.getNode(
ISD::SPLAT_VECTOR, getCurSDLoc(),
// i8 vector with one element per byte of the tuple's known-minimum
// size; the trailing `true` is presumably the IsScalable flag of
// EVT::getVectorVT -- confirm against the EVT API.
EVT::getVectorVT(*DAG.getContext(), MVT::i8,
VT.getSizeInBits().getKnownMinValue() / 8,
true),
// Splatted scalar: an i8 zero.
DAG.getConstant(0, getCurSDLoc(), MVT::getIntegerVT(8))));
}

VectorType *VecTy = cast<VectorType>(V->getType());

// Now that we know the number and type of the elements, get that number of
Expand Down
2 changes: 1 addition & 1 deletion llvm/lib/IR/Type.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -990,7 +990,7 @@ static TargetTypeInfo getTargetTypeInfo(const TargetExtType *Ty) {
Ty->getIntParameter(0);
return TargetTypeInfo(
ScalableVectorType::get(Type::getInt8Ty(C), TotalNumElts),
TargetExtType::CanBeLocal);
TargetExtType::CanBeLocal, TargetExtType::HasZeroInit);
}

// DirectX resources
Expand Down
14 changes: 14 additions & 0 deletions llvm/lib/Target/RISCV/RISCVISelLowering.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -18060,6 +18060,20 @@ SDValue RISCVTargetLowering::PerformDAGCombine(SDNode *N,
SDValue N0 = N->getOperand(0);
EVT VT = N->getValueType(0);
EVT SrcVT = N0.getValueType();
// BITCAST of a SPLAT_VECTOR to a RISC-V vector tuple type: rebuild the tuple
// field by field instead of keeping the bitcast. Each of the NF fields is
// filled with the same i8 zero splat via RISCVISD::TUPLE_INSERT.
// NOTE(review): the splat operand of N0 is never inspected; a zero constant
// is always used. This presumably relies on the only producer of this pattern
// being the zeroinitializer path in SelectionDAGBuilder -- confirm, or guard
// on the splat value being zero.
if (VT.isRISCVVectorTuple() && N0->getOpcode() == ISD::SPLAT_VECTOR) {
// Number of fields in the tuple.
unsigned NF = VT.getRISCVVectorTupleNumFields();
// Known-minimum number of i8 elements in a single field: total tuple size
// in bits divided by (fields * 8 bits per element).
unsigned NumScalElts = VT.getSizeInBits().getKnownMinValue() / (NF * 8);
// The splatted scalar is created with the XLen type rather than i8.
SDValue EltVal = DAG.getConstant(0, DL, Subtarget.getXLenVT());
// Scalable i8 vector type used for one tuple field.
MVT ScalTy = MVT::getScalableVectorVT(MVT::getIntegerVT(8), NumScalElts);

// One zero splat node, shared by every field insertion below.
SDValue Splat = DAG.getNode(ISD::SPLAT_VECTOR, DL, ScalTy, EltVal);

// Start from an undefined tuple and insert the splat at indices 0..NF-1.
SDValue Result = DAG.getUNDEF(VT);
for (unsigned i = 0; i < NF; ++i)
Result = DAG.getNode(RISCVISD::TUPLE_INSERT, DL, VT, Result, Splat,
DAG.getVectorIdxConstant(i, DL));
return Result;
}
// If this is a bitcast between a MVT::v4i1/v2i1/v1i1 and an illegal integer
// type, widen both sides to avoid a trip through memory.
if ((SrcVT == MVT::v1i1 || SrcVT == MVT::v2i1 || SrcVT == MVT::v4i1) &&
Expand Down
52 changes: 52 additions & 0 deletions llvm/test/CodeGen/RISCV/vector-tuple-zeroinitializer.ll
Original file line number Diff line number Diff line change
@@ -0,0 +1,52 @@
; NOTE: Assertions have been autogenerated by utils/update_llc_test_checks.py UTC_ARGS: --version 5
; RUN: llc -mtriple=riscv32 -mattr=+v \
; RUN: -verify-machineinstrs %s -o - | FileCheck %s --check-prefixes=CHECK
; RUN: llc -mtriple=riscv64 -mattr=+v \
; RUN: -verify-machineinstrs %s -o - | FileCheck %s --check-prefixes=CHECK

; Returns a zeroinitializer of a 2-field vector tuple (power-of-two field
; count). The expected output zeroes two registers, v8 and v10, with vmv.v.i
; under an e8/m2 vsetvli.
define target("riscv.vector.tuple", <vscale x 16 x i8>, 2) @test_tuple_zero_power_of_2() {
; CHECK-LABEL: test_tuple_zero_power_of_2:
; CHECK: # %bb.0: # %entry
; CHECK-NEXT: vsetvli a0, zero, e8, m2, ta, ma
; CHECK-NEXT: vmv.v.i v8, 0
; CHECK-NEXT: vmv.v.i v10, 0
; CHECK-NEXT: ret
entry:
ret target("riscv.vector.tuple", <vscale x 16 x i8>, 2) zeroinitializer
}

; Returns a zeroinitializer of a 3-field vector tuple (non-power-of-two field
; count). The expected output zeroes three registers, v8, v10 and v12, with
; vmv.v.i under an e8/m2 vsetvli.
define target("riscv.vector.tuple", <vscale x 16 x i8>, 3) @test_tuple_zero_non_power_of_2() {
; CHECK-LABEL: test_tuple_zero_non_power_of_2:
; CHECK: # %bb.0: # %entry
; CHECK-NEXT: vsetvli a0, zero, e8, m2, ta, ma
; CHECK-NEXT: vmv.v.i v8, 0
; CHECK-NEXT: vmv.v.i v10, 0
; CHECK-NEXT: vmv.v.i v12, 0
; CHECK-NEXT: ret
entry:
ret target("riscv.vector.tuple", <vscale x 16 x i8>, 3) zeroinitializer
}

; Inserts %a at field index 0 of a zeroinitializer 2-field tuple and returns
; the result. The expected output zeroes only v10; no write to v8 is expected
; (presumably because %a already arrives in v8 -- confirm against the RISC-V
; vector calling convention).
define target("riscv.vector.tuple", <vscale x 16 x i8>, 2) @test_tuple_zero_insert1(<vscale x 4 x i32> %a) {
; CHECK-LABEL: test_tuple_zero_insert1:
; CHECK: # %bb.0: # %entry
; CHECK-NEXT: vsetvli a0, zero, e8, m2, ta, ma
; CHECK-NEXT: vmv.v.i v10, 0
; CHECK-NEXT: ret
entry:
%1 = call target("riscv.vector.tuple", <vscale x 16 x i8>, 2) @llvm.riscv.tuple.insert.triscv.vector.tuple_nxv16i8_2t.nxv4i32(target("riscv.vector.tuple", <vscale x 16 x i8>, 2) zeroinitializer, <vscale x 4 x i32> %a, i32 0)
ret target("riscv.vector.tuple", <vscale x 16 x i8>, 2) %1
}

; Inserts %a at field index 1 of a zeroinitializer 2-field tuple and returns
; the result. The expected output zeroes v6, copies v8 into v10, then copies
; the zeroed v6 into v8 (presumably %a arrives in v8 and has to move to the
; second field's register -- confirm against the RISC-V vector calling
; convention).
define target("riscv.vector.tuple", <vscale x 16 x i8>, 2) @test_tuple_zero_insert2(<vscale x 4 x i32> %a) {
; CHECK-LABEL: test_tuple_zero_insert2:
; CHECK: # %bb.0: # %entry
; CHECK-NEXT: vsetvli a0, zero, e8, m2, ta, ma
; CHECK-NEXT: vmv.v.i v6, 0
; CHECK-NEXT: vmv2r.v v10, v8
; CHECK-NEXT: vmv2r.v v8, v6
; CHECK-NEXT: ret
entry:
%1 = call target("riscv.vector.tuple", <vscale x 16 x i8>, 2) @llvm.riscv.tuple.insert.triscv.vector.tuple_nxv16i8_2t.nxv4i32(target("riscv.vector.tuple", <vscale x 16 x i8>, 2) zeroinitializer, <vscale x 4 x i32> %a, i32 1)
ret target("riscv.vector.tuple", <vscale x 16 x i8>, 2) %1
}

0 comments on commit 109e4a1

Please sign in to comment.