Skip to content

Commit

Permalink
[DAGCombiner] Combine vp.strided.load with unit stride to vp.load
Browse files Browse the repository at this point in the history
This is the VP equivalent of #65674. We already combine MGATHER loads with unit
stride to MLOAD, so this extends it for EXPERIMENTAL_VP_STRIDED_LOAD.
  • Loading branch information
lukel97 committed Sep 19, 2023
1 parent 33dac56 commit 6ef5440
Show file tree
Hide file tree
Showing 3 changed files with 35 additions and 28 deletions.
21 changes: 21 additions & 0 deletions llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -539,6 +539,7 @@ namespace {
SDValue visitMSCATTER(SDNode *N);
SDValue visitVPGATHER(SDNode *N);
SDValue visitVPSCATTER(SDNode *N);
SDValue visitVP_STRIDED_LOAD(SDNode *N);
SDValue visitFP_TO_FP16(SDNode *N);
SDValue visitFP16_TO_FP(SDNode *N);
SDValue visitFP_TO_BF16(SDNode *N);
Expand Down Expand Up @@ -11959,6 +11960,22 @@ SDValue DAGCombiner::visitMLOAD(SDNode *N) {
return SDValue();
}

SDValue DAGCombiner::visitVP_STRIDED_LOAD(SDNode *N) {
auto *SLD = cast<VPStridedLoadSDNode>(N);
EVT EltVT = SLD->getValueType(0).getVectorElementType();
// Combine strided loads with unit-stride to a regular load.
if (auto *CStride = dyn_cast<ConstantSDNode>(SLD->getStride());
CStride && CStride->getZExtValue() == EltVT.getStoreSize()) {
SDValue NewLd = DAG.getLoadVP(
SLD->getAddressingMode(), SLD->getExtensionType(), SLD->getValueType(0),
SDLoc(N), SLD->getChain(), SLD->getBasePtr(), SLD->getOffset(),
SLD->getMask(), SLD->getVectorLength(), SLD->getMemoryVT(),
SLD->getMemOperand(), SLD->isExpandingLoad());
return CombineTo(N, NewLd, NewLd.getValue(1));
}
return SDValue();
}

/// A vector select of 2 constant vectors can be simplified to math/logic to
/// avoid a variable select instruction and possibly avoid constant loads.
SDValue DAGCombiner::foldVSelectOfConstants(SDNode *N) {
Expand Down Expand Up @@ -25976,6 +25993,10 @@ SDValue DAGCombiner::visitVPOp(SDNode *N) {
if (SDValue SD = visitVPSCATTER(N))
return SD;

if (N->getOpcode() == ISD::EXPERIMENTAL_VP_STRIDED_LOAD)
if (SDValue SD = visitVP_STRIDED_LOAD(N))
return SD;

// VP operations in which all vector elements are disabled - either by
// determining that the mask is all false or that the EVL is 0 - can be
// eliminated.
Expand Down
21 changes: 7 additions & 14 deletions llvm/test/CodeGen/RISCV/rvv/fixed-vectors-strided-vpload.ll
Original file line number Diff line number Diff line change
Expand Up @@ -99,9 +99,8 @@ define <8 x i8> @strided_vpload_v8i8(ptr %ptr, i32 signext %stride, <8 x i1> %m,
define <8 x i8> @strided_vpload_v8i8_unit_stride(ptr %ptr, <8 x i1> %m, i32 zeroext %evl) {
; CHECK-LABEL: strided_vpload_v8i8_unit_stride:
; CHECK: # %bb.0:
; CHECK-NEXT: li a2, 1
; CHECK-NEXT: vsetvli zero, a1, e8, mf2, ta, ma
; CHECK-NEXT: vlse8.v v8, (a0), a2, v0.t
; CHECK-NEXT: vle8.v v8, (a0), v0.t
; CHECK-NEXT: ret
%load = call <8 x i8> @llvm.experimental.vp.strided.load.v8i8.p0.i32(ptr %ptr, i32 1, <8 x i1> %m, i32 %evl)
ret <8 x i8> %load
Expand Down Expand Up @@ -146,9 +145,8 @@ define <8 x i16> @strided_vpload_v8i16(ptr %ptr, i32 signext %stride, <8 x i1> %
define <8 x i16> @strided_vpload_v8i16_unit_stride(ptr %ptr, <8 x i1> %m, i32 zeroext %evl) {
; CHECK-LABEL: strided_vpload_v8i16_unit_stride:
; CHECK: # %bb.0:
; CHECK-NEXT: li a2, 2
; CHECK-NEXT: vsetvli zero, a1, e16, m1, ta, ma
; CHECK-NEXT: vlse16.v v8, (a0), a2, v0.t
; CHECK-NEXT: vle16.v v8, (a0), v0.t
; CHECK-NEXT: ret
%load = call <8 x i16> @llvm.experimental.vp.strided.load.v8i16.p0.i32(ptr %ptr, i32 2, <8 x i1> %m, i32 %evl)
ret <8 x i16> %load
Expand Down Expand Up @@ -193,9 +191,8 @@ define <4 x i32> @strided_vpload_v4i32(ptr %ptr, i32 signext %stride, <4 x i1> %
define <4 x i32> @strided_vpload_v4i32_unit_stride(ptr %ptr, <4 x i1> %m, i32 zeroext %evl) {
; CHECK-LABEL: strided_vpload_v4i32_unit_stride:
; CHECK: # %bb.0:
; CHECK-NEXT: li a2, 4
; CHECK-NEXT: vsetvli zero, a1, e32, m1, ta, ma
; CHECK-NEXT: vlse32.v v8, (a0), a2, v0.t
; CHECK-NEXT: vle32.v v8, (a0), v0.t
; CHECK-NEXT: ret
%load = call <4 x i32> @llvm.experimental.vp.strided.load.v4i32.p0.i32(ptr %ptr, i32 4, <4 x i1> %m, i32 %evl)
ret <4 x i32> %load
Expand Down Expand Up @@ -240,9 +237,8 @@ define <2 x i64> @strided_vpload_v2i64(ptr %ptr, i32 signext %stride, <2 x i1> %
define <2 x i64> @strided_vpload_v2i64_unit_stride(ptr %ptr, <2 x i1> %m, i32 zeroext %evl) {
; CHECK-LABEL: strided_vpload_v2i64_unit_stride:
; CHECK: # %bb.0:
; CHECK-NEXT: li a2, 8
; CHECK-NEXT: vsetvli zero, a1, e64, m1, ta, ma
; CHECK-NEXT: vlse64.v v8, (a0), a2, v0.t
; CHECK-NEXT: vle64.v v8, (a0), v0.t
; CHECK-NEXT: ret
%load = call <2 x i64> @llvm.experimental.vp.strided.load.v2i64.p0.i32(ptr %ptr, i32 8, <2 x i1> %m, i32 %evl)
ret <2 x i64> %load
Expand Down Expand Up @@ -335,9 +331,8 @@ define <8 x half> @strided_vpload_v8f16(ptr %ptr, i32 signext %stride, <8 x i1>
define <8 x half> @strided_vpload_v8f16_unit_stride(ptr %ptr, <8 x i1> %m, i32 zeroext %evl) {
; CHECK-LABEL: strided_vpload_v8f16_unit_stride:
; CHECK: # %bb.0:
; CHECK-NEXT: li a2, 2
; CHECK-NEXT: vsetvli zero, a1, e16, m1, ta, ma
; CHECK-NEXT: vlse16.v v8, (a0), a2, v0.t
; CHECK-NEXT: vle16.v v8, (a0), v0.t
; CHECK-NEXT: ret
%load = call <8 x half> @llvm.experimental.vp.strided.load.v8f16.p0.i32(ptr %ptr, i32 2, <8 x i1> %m, i32 %evl)
ret <8 x half> %load
Expand Down Expand Up @@ -370,9 +365,8 @@ define <4 x float> @strided_vpload_v4f32(ptr %ptr, i32 signext %stride, <4 x i1>
define <4 x float> @strided_vpload_v4f32_unit_stride(ptr %ptr, <4 x i1> %m, i32 zeroext %evl) {
; CHECK-LABEL: strided_vpload_v4f32_unit_stride:
; CHECK: # %bb.0:
; CHECK-NEXT: li a2, 4
; CHECK-NEXT: vsetvli zero, a1, e32, m1, ta, ma
; CHECK-NEXT: vlse32.v v8, (a0), a2, v0.t
; CHECK-NEXT: vle32.v v8, (a0), v0.t
; CHECK-NEXT: ret
%load = call <4 x float> @llvm.experimental.vp.strided.load.v4f32.p0.i32(ptr %ptr, i32 4, <4 x i1> %m, i32 %evl)
ret <4 x float> %load
Expand Down Expand Up @@ -417,9 +411,8 @@ define <2 x double> @strided_vpload_v2f64(ptr %ptr, i32 signext %stride, <2 x i1
define <2 x double> @strided_vpload_v2f64_unit_stride(ptr %ptr, <2 x i1> %m, i32 zeroext %evl) {
; CHECK-LABEL: strided_vpload_v2f64_unit_stride:
; CHECK: # %bb.0:
; CHECK-NEXT: li a2, 8
; CHECK-NEXT: vsetvli zero, a1, e64, m1, ta, ma
; CHECK-NEXT: vlse64.v v8, (a0), a2, v0.t
; CHECK-NEXT: vle64.v v8, (a0), v0.t
; CHECK-NEXT: ret
%load = call <2 x double> @llvm.experimental.vp.strided.load.v2f64.p0.i32(ptr %ptr, i32 8, <2 x i1> %m, i32 %evl)
ret <2 x double> %load
Expand Down
21 changes: 7 additions & 14 deletions llvm/test/CodeGen/RISCV/rvv/strided-vpload.ll
Original file line number Diff line number Diff line change
Expand Up @@ -129,9 +129,8 @@ define <vscale x 8 x i8> @strided_vpload_nxv8i8(ptr %ptr, i32 signext %stride, <
define <vscale x 8 x i8> @strided_vpload_nxv8i8_unit_stride(ptr %ptr, <vscale x 8 x i1> %m, i32 zeroext %evl) {
; CHECK-LABEL: strided_vpload_nxv8i8_unit_stride:
; CHECK: # %bb.0:
; CHECK-NEXT: li a2, 1
; CHECK-NEXT: vsetvli zero, a1, e8, m1, ta, ma
; CHECK-NEXT: vlse8.v v8, (a0), a2, v0.t
; CHECK-NEXT: vle8.v v8, (a0), v0.t
; CHECK-NEXT: ret
%load = call <vscale x 8 x i8> @llvm.experimental.vp.strided.load.nxv8i8.p0.i32(ptr %ptr, i32 1, <vscale x 8 x i1> %m, i32 %evl)
ret <vscale x 8 x i8> %load
Expand Down Expand Up @@ -200,9 +199,8 @@ define <vscale x 4 x i16> @strided_vpload_nxv4i16(ptr %ptr, i32 signext %stride,
define <vscale x 4 x i16> @strided_vpload_nxv4i16_unit_stride(ptr %ptr, <vscale x 4 x i1> %m, i32 zeroext %evl) {
; CHECK-LABEL: strided_vpload_nxv4i16_unit_stride:
; CHECK: # %bb.0:
; CHECK-NEXT: li a2, 2
; CHECK-NEXT: vsetvli zero, a1, e16, m1, ta, ma
; CHECK-NEXT: vlse16.v v8, (a0), a2, v0.t
; CHECK-NEXT: vle16.v v8, (a0), v0.t
; CHECK-NEXT: ret
%load = call <vscale x 4 x i16> @llvm.experimental.vp.strided.load.nxv4i16.p0.i32(ptr %ptr, i32 2, <vscale x 4 x i1> %m, i32 %evl)
ret <vscale x 4 x i16> %load
Expand Down Expand Up @@ -247,9 +245,8 @@ define <vscale x 2 x i32> @strided_vpload_nxv2i32(ptr %ptr, i32 signext %stride,
define <vscale x 2 x i32> @strided_vpload_nxv2i32_unit_stride(ptr %ptr, <vscale x 2 x i1> %m, i32 zeroext %evl) {
; CHECK-LABEL: strided_vpload_nxv2i32_unit_stride:
; CHECK: # %bb.0:
; CHECK-NEXT: li a2, 4
; CHECK-NEXT: vsetvli zero, a1, e32, m1, ta, ma
; CHECK-NEXT: vlse32.v v8, (a0), a2, v0.t
; CHECK-NEXT: vle32.v v8, (a0), v0.t
; CHECK-NEXT: ret
%load = call <vscale x 2 x i32> @llvm.experimental.vp.strided.load.nxv2i32.p0.i32(ptr %ptr, i32 4, <vscale x 2 x i1> %m, i32 %evl)
ret <vscale x 2 x i32> %load
Expand Down Expand Up @@ -306,9 +303,8 @@ define <vscale x 1 x i64> @strided_vpload_nxv1i64(ptr %ptr, i32 signext %stride,
define <vscale x 1 x i64> @strided_vpload_nxv1i64_unit_stride(ptr %ptr, <vscale x 1 x i1> %m, i32 zeroext %evl) {
; CHECK-LABEL: strided_vpload_nxv1i64_unit_stride:
; CHECK: # %bb.0:
; CHECK-NEXT: li a2, 8
; CHECK-NEXT: vsetvli zero, a1, e64, m1, ta, ma
; CHECK-NEXT: vlse64.v v8, (a0), a2, v0.t
; CHECK-NEXT: vle64.v v8, (a0), v0.t
; CHECK-NEXT: ret
%load = call <vscale x 1 x i64> @llvm.experimental.vp.strided.load.nxv1i64.p0.i32(ptr %ptr, i32 8, <vscale x 1 x i1> %m, i32 %evl)
ret <vscale x 1 x i64> %load
Expand Down Expand Up @@ -413,9 +409,8 @@ define <vscale x 4 x half> @strided_vpload_nxv4f16(ptr %ptr, i32 signext %stride
define <vscale x 4 x half> @strided_vpload_nxv4f16_unit_stride(ptr %ptr, <vscale x 4 x i1> %m, i32 zeroext %evl) {
; CHECK-LABEL: strided_vpload_nxv4f16_unit_stride:
; CHECK: # %bb.0:
; CHECK-NEXT: li a2, 2
; CHECK-NEXT: vsetvli zero, a1, e16, m1, ta, ma
; CHECK-NEXT: vlse16.v v8, (a0), a2, v0.t
; CHECK-NEXT: vle16.v v8, (a0), v0.t
; CHECK-NEXT: ret
%load = call <vscale x 4 x half> @llvm.experimental.vp.strided.load.nxv4f16.p0.i32(ptr %ptr, i32 2, <vscale x 4 x i1> %m, i32 %evl)
ret <vscale x 4 x half> %load
Expand Down Expand Up @@ -460,9 +455,8 @@ define <vscale x 2 x float> @strided_vpload_nxv2f32(ptr %ptr, i32 signext %strid
define <vscale x 2 x float> @strided_vpload_nxv2f32_unit_stride(ptr %ptr, <vscale x 2 x i1> %m, i32 zeroext %evl) {
; CHECK-LABEL: strided_vpload_nxv2f32_unit_stride:
; CHECK: # %bb.0:
; CHECK-NEXT: li a2, 4
; CHECK-NEXT: vsetvli zero, a1, e32, m1, ta, ma
; CHECK-NEXT: vlse32.v v8, (a0), a2, v0.t
; CHECK-NEXT: vle32.v v8, (a0), v0.t
; CHECK-NEXT: ret
%load = call <vscale x 2 x float> @llvm.experimental.vp.strided.load.nxv2f32.p0.i32(ptr %ptr, i32 4, <vscale x 2 x i1> %m, i32 %evl)
ret <vscale x 2 x float> %load
Expand Down Expand Up @@ -519,9 +513,8 @@ define <vscale x 1 x double> @strided_vpload_nxv1f64(ptr %ptr, i32 signext %stri
define <vscale x 1 x double> @strided_vpload_nxv1f64_unit_stride(ptr %ptr, <vscale x 1 x i1> %m, i32 zeroext %evl) {
; CHECK-LABEL: strided_vpload_nxv1f64_unit_stride:
; CHECK: # %bb.0:
; CHECK-NEXT: li a2, 8
; CHECK-NEXT: vsetvli zero, a1, e64, m1, ta, ma
; CHECK-NEXT: vlse64.v v8, (a0), a2, v0.t
; CHECK-NEXT: vle64.v v8, (a0), v0.t
; CHECK-NEXT: ret
%load = call <vscale x 1 x double> @llvm.experimental.vp.strided.load.nxv1f64.p0.i32(ptr %ptr, i32 8, <vscale x 1 x i1> %m, i32 %evl)
ret <vscale x 1 x double> %load
Expand Down

0 comments on commit 6ef5440

Please sign in to comment.