[NVPTX] Legalize ctpop and ctlz in operation legalization #130668

Merged
3 changes: 2 additions & 1 deletion llvm/lib/CodeGen/SelectionDAG/LegalizeDAG.cpp
@@ -5113,7 +5113,8 @@ void SelectionDAGLegalize::PromoteNode(SDNode *Node) {
DAG.getConstant(NVT.getSizeInBits() -
OVT.getSizeInBits(), dl, NVT));
}
Results.push_back(DAG.getNode(ISD::TRUNCATE, dl, OVT, Tmp1));
Results.push_back(
DAG.getNode(ISD::TRUNCATE, dl, OVT, Tmp1, SDNodeFlags::NoWrap));
break;
}
case ISD::CTLZ_ZERO_UNDEF: {
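A rough scalar sketch (not part of the patch; plain C++ used for illustration, names are hypothetical) of why the truncate emitted after promotion can now carry the NoWrap flag: a ctpop/ctlz result for a 16-bit input is at most 16, so the widened 32-bit result always fits back into the original type.

#include <bit>
#include <cassert>
#include <cstdint>

// Models PromoteNode for ctpop: compute in the wider type, truncate back.
uint16_t ctpop16_via_promotion(uint16_t X) {
  uint32_t Wide = std::popcount(static_cast<uint32_t>(X)); // result is 0..16
  assert(Wide <= 0xFFFFu && "count always fits in i16, so the trunc cannot wrap");
  return static_cast<uint16_t>(Wide); // corresponds to the NoWrap-tagged ISD::TRUNCATE
}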
31 changes: 22 additions & 9 deletions llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp
@@ -764,16 +764,13 @@ NVPTXTargetLowering::NVPTXTargetLowering(const NVPTXTargetMachine &TM,
// Custom handling for i8 intrinsics
setOperationAction(ISD::INTRINSIC_W_CHAIN, MVT::i8, Custom);

for (const auto& Ty : {MVT::i16, MVT::i32, MVT::i64}) {
setOperationAction(ISD::ABS, Ty, Legal);
setOperationAction(ISD::SMIN, Ty, Legal);
setOperationAction(ISD::SMAX, Ty, Legal);
setOperationAction(ISD::UMIN, Ty, Legal);
setOperationAction(ISD::UMAX, Ty, Legal);
setOperationAction({ISD::ABS, ISD::SMIN, ISD::SMAX, ISD::UMIN, ISD::UMAX},
{MVT::i16, MVT::i32, MVT::i64}, Legal);

setOperationAction(ISD::CTPOP, Ty, Legal);
setOperationAction(ISD::CTLZ, Ty, Legal);
}
setOperationAction({ISD::CTPOP, ISD::CTLZ, ISD::CTLZ_ZERO_UNDEF}, MVT::i16,
Promote);
setOperationAction({ISD::CTPOP, ISD::CTLZ}, MVT::i32, Legal);
setOperationAction({ISD::CTPOP, ISD::CTLZ}, MVT::i64, Custom);

setI16x2OperationAction(ISD::ABS, MVT::v2i16, Legal, Custom);
setI16x2OperationAction(ISD::SMIN, MVT::v2i16, Legal, Custom);
@@ -2748,6 +2745,19 @@ static SDValue LowerIntrinsicVoid(SDValue Op, SelectionDAG &DAG) {
return Op;
}

// In PTX 64-bit CTLZ and CTPOP are supported, but they return a 32-bit value.
// Lower these into a node returning the correct type which is zero-extended
// back to the correct size.
static SDValue lowerCTLZCTPOP(SDValue Op, SelectionDAG &DAG) {
SDValue V = Op->getOperand(0);
assert(V.getValueType() == MVT::i64 &&
"Unexpected CTLZ/CTPOP type to legalize");

SDLoc DL(Op);
SDValue CT = DAG.getNode(Op->getOpcode(), DL, MVT::i32, V);
return DAG.getNode(ISD::ZERO_EXTEND, DL, MVT::i64, CT, SDNodeFlags::NonNeg);
}

SDValue
NVPTXTargetLowering::LowerOperation(SDValue Op, SelectionDAG &DAG) const {
switch (Op.getOpcode()) {
@@ -2833,6 +2843,9 @@ NVPTXTargetLowering::LowerOperation(SDValue Op, SelectionDAG &DAG) const {
case ISD::FMUL:
// Used only for bf16 on SM80, where we select fma for non-ftz operation
return PromoteBinOpIfF32FTZ(Op, DAG);
case ISD::CTPOP:
case ISD::CTLZ:
return lowerCTLZCTPOP(Op, DAG);

default:
llvm_unreachable("Custom lowering not defined for operation");
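For reference, a scalar sketch (not part of the patch; plain C++ used for illustration, names are hypothetical) of what lowerCTLZCTPOP computes for the Custom-lowered i64 case: the count is produced as a 32-bit value, mirroring the 32-bit destination of PTX clz.b64/popc.b64, and then zero-extended to the i64 result type; since the count is at most 64, the extension is known non-negative, hence the NonNeg flag.

#include <bit>
#include <cstdint>

uint64_t ctlz64_lowered(uint64_t X) {
  uint32_t Count = std::countl_zero(X); // 0..64, naturally a 32-bit quantity
  return static_cast<uint64_t>(Count);  // models the NonNeg zero-extend to i64
}

uint64_t ctpop64_lowered(uint64_t X) {
  uint32_t Count = std::popcount(X);    // 0..64
  return static_cast<uint64_t>(Count);
}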
71 changes: 11 additions & 60 deletions llvm/lib/Target/NVPTX/NVPTXInstrInfo.td
@@ -3267,69 +3267,20 @@ def : Pat<(i32 (int_nvvm_fshr_clamp i32:$hi, i32:$lo, i32:$amt)),
def : Pat<(i32 (int_nvvm_fshr_clamp i32:$hi, i32:$lo, (i32 imm:$amt))),
(SHF_R_CLAMP_i $lo, $hi, imm:$amt)>;

// Count leading zeros
let hasSideEffects = false in {
def CLZr32 : NVPTXInst<(outs Int32Regs:$d), (ins Int32Regs:$a),
"clz.b32 \t$d, $a;", []>;
def CLZr64 : NVPTXInst<(outs Int32Regs:$d), (ins Int64Regs:$a),
"clz.b64 \t$d, $a;", []>;
foreach RT = [I32RT, I64RT] in {
// Count leading zeros
def CLZr # RT.Size : NVPTXInst<(outs Int32Regs:$d), (ins RT.RC:$a),
"clz.b" # RT.Size # " \t$d, $a;",
[(set i32:$d, (ctlz RT.Ty:$a))]>;

// Population count
def POPCr # RT.Size : NVPTXInst<(outs Int32Regs:$d), (ins RT.RC:$a),
"popc.b" # RT.Size # " \t$d, $a;",
[(set i32:$d, (ctpop RT.Ty:$a))]>;
}
}

// 32-bit has a direct PTX instruction
def : Pat<(i32 (ctlz i32:$a)), (CLZr32 $a)>;

// The return type of the ctlz ISD node is the same as its input, but the PTX
// ctz instruction always returns a 32-bit value. For ctlz.i64, convert the
// ptx value to 64 bits to match the ISD node's semantics, unless we know we're
// truncating back down to 32 bits.
def : Pat<(i64 (ctlz i64:$a)), (CVT_u64_u32 (CLZr64 $a), CvtNONE)>;
def : Pat<(i32 (trunc (i64 (ctlz i64:$a)))), (CLZr64 $a)>;

// For 16-bit ctlz, we zero-extend to 32-bit, perform the count, then trunc the
// result back to 16-bits if necessary. We also need to subtract 16 because
// the high-order 16 zeros were counted.
//
// TODO: NVPTX has a mov.b32 b32reg, {imm, b16reg} instruction, which we could
// use to save one SASS instruction (on sm_35 anyway):
//
// mov.b32 $tmp, {0xffff, $a}
// ctlz.b32 $result, $tmp
//
// That is, instead of zero-extending the input to 32 bits, we'd "one-extend"
// and then ctlz that value. This way we don't have to subtract 16 from the
// result. Unfortunately today we don't have a way to generate
// "mov b32reg, {b16imm, b16reg}", so we don't do this optimization.
def : Pat<(i16 (ctlz i16:$a)),
(SUBi16ri (CVT_u16_u32
(CLZr32 (CVT_u32_u16 $a, CvtNONE)), CvtNONE), 16)>;
def : Pat<(i32 (zext (i16 (ctlz i16:$a)))),
(SUBi32ri (CLZr32 (CVT_u32_u16 $a, CvtNONE)), 16)>;

// Population count
let hasSideEffects = false in {
def POPCr32 : NVPTXInst<(outs Int32Regs:$d), (ins Int32Regs:$a),
"popc.b32 \t$d, $a;", []>;
def POPCr64 : NVPTXInst<(outs Int32Regs:$d), (ins Int64Regs:$a),
"popc.b64 \t$d, $a;", []>;
}

// 32-bit has a direct PTX instruction
def : Pat<(i32 (ctpop i32:$a)), (POPCr32 $a)>;

// For 64-bit, the result in PTX is actually 32-bit so we zero-extend to 64-bit
// to match the LLVM semantics. Just as with ctlz.i64, we provide a second
// pattern that avoids the type conversion if we're truncating the result to
// i32 anyway.
def : Pat<(ctpop i64:$a), (CVT_u64_u32 (POPCr64 $a), CvtNONE)>;
def : Pat<(i32 (trunc (i64 (ctpop i64:$a)))), (POPCr64 $a)>;

// For 16-bit, we zero-extend to 32-bit, then trunc the result back to 16-bits.
// If we know that we're storing into an i32, we can avoid the final trunc.
def : Pat<(ctpop i16:$a),
(CVT_u16_u32 (POPCr32 (CVT_u32_u16 $a, CvtNONE)), CvtNONE)>;
def : Pat<(i32 (zext (i16 (ctpop i16:$a)))),
(POPCr32 (CVT_u32_u16 $a, CvtNONE))>;

// fpround f32 -> f16
def : Pat<(f16 (fpround f32:$a)),
(CVT_f16_f32 $a, CvtRN)>;
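With the hand-written i16 patterns above removed, the i16 case now goes through type promotion. A scalar sketch (not part of the patch; plain C++ used for illustration, names are hypothetical) of what the promoted forms compute: zero-extending to 32 bits over-counts leading zeros by 16, so 16 is subtracted afterwards (the add.s32 %r3, %r2, -16 in the updated tests below), while the zero-undef form is promoted by shifting the value into the high half and counting directly (the shl.b32 + clz.b32 sequence in myctlz_ret16_2 below).

#include <bit>
#include <cstdint>

uint16_t ctlz16_promoted(uint16_t X) {
  // clz of the zero-extended value counts the 16 extra high zeros; remove them.
  return static_cast<uint16_t>(std::countl_zero(static_cast<uint32_t>(X)) - 16);
}

uint16_t ctlz16_zero_undef(uint16_t X) {
  // Result is unspecified for X == 0; shifting into the high half avoids the -16.
  return static_cast<uint16_t>(std::countl_zero(static_cast<uint32_t>(X) << 16));
}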
167 changes: 105 additions & 62 deletions llvm/test/CodeGen/NVPTX/ctlz.ll
@@ -1,3 +1,4 @@
; NOTE: Assertions have been autogenerated by utils/update_llc_test_checks.py UTC_ARGS: --version 5
; RUN: llc < %s -mtriple=nvptx64 -mcpu=sm_20 -verify-machineinstrs | FileCheck %s
; RUN: %if ptxas %{ llc < %s -mtriple=nvptx64 -mcpu=sm_20 -verify-machineinstrs | %ptxas-verify %}

@@ -10,67 +11,95 @@ declare i64 @llvm.ctlz.i64(i64, i1) readnone
; There should be no difference between llvm.ctlz.i32(%a, true) and
; llvm.ctlz.i32(%a, false), as ptx's clz(0) is defined to return 0.

; CHECK-LABEL: myctlz(
define i32 @myctlz(i32 %a) {
; CHECK: ld.param.
; CHECK-NEXT: clz.b32
; CHECK-NEXT: st.param.
; CHECK-NEXT: ret;
; CHECK-LABEL: myctlz(
; CHECK: {
; CHECK-NEXT: .reg .b32 %r<3>;
; CHECK-EMPTY:
; CHECK-NEXT: // %bb.0:
; CHECK-NEXT: ld.param.u32 %r1, [myctlz_param_0];
; CHECK-NEXT: clz.b32 %r2, %r1;
; CHECK-NEXT: st.param.b32 [func_retval0], %r2;
; CHECK-NEXT: ret;
%val = call i32 @llvm.ctlz.i32(i32 %a, i1 false) readnone
ret i32 %val
}
; CHECK-LABEL: myctlz_2(
define i32 @myctlz_2(i32 %a) {
; CHECK: ld.param.
; CHECK-NEXT: clz.b32
; CHECK-NEXT: st.param.
; CHECK-NEXT: ret;
; CHECK-LABEL: myctlz_2(
; CHECK: {
; CHECK-NEXT: .reg .b32 %r<3>;
; CHECK-EMPTY:
; CHECK-NEXT: // %bb.0:
; CHECK-NEXT: ld.param.u32 %r1, [myctlz_2_param_0];
; CHECK-NEXT: clz.b32 %r2, %r1;
; CHECK-NEXT: st.param.b32 [func_retval0], %r2;
; CHECK-NEXT: ret;
%val = call i32 @llvm.ctlz.i32(i32 %a, i1 true) readnone
ret i32 %val
}

; PTX's clz.b64 returns a 32-bit value, but LLVM's intrinsic returns a 64-bit
; value, so here we have to zero-extend it.
; CHECK-LABEL: myctlz64(
define i64 @myctlz64(i64 %a) {
; CHECK: ld.param.
; CHECK-NEXT: clz.b64
; CHECK-NEXT: cvt.u64.u32
; CHECK-NEXT: st.param.
; CHECK-NEXT: ret;
; CHECK-LABEL: myctlz64(
; CHECK: {
; CHECK-NEXT: .reg .b32 %r<2>;
; CHECK-NEXT: .reg .b64 %rd<3>;
; CHECK-EMPTY:
; CHECK-NEXT: // %bb.0:
; CHECK-NEXT: ld.param.u64 %rd1, [myctlz64_param_0];
; CHECK-NEXT: clz.b64 %r1, %rd1;
; CHECK-NEXT: cvt.u64.u32 %rd2, %r1;
; CHECK-NEXT: st.param.b64 [func_retval0], %rd2;
; CHECK-NEXT: ret;
%val = call i64 @llvm.ctlz.i64(i64 %a, i1 false) readnone
ret i64 %val
}
; CHECK-LABEL: myctlz64_2(
define i64 @myctlz64_2(i64 %a) {
; CHECK: ld.param.
; CHECK-NEXT: clz.b64
; CHECK-NEXT: cvt.u64.u32
; CHECK-NEXT: st.param.
; CHECK-NEXT: ret;
; CHECK-LABEL: myctlz64_2(
; CHECK: {
; CHECK-NEXT: .reg .b32 %r<2>;
; CHECK-NEXT: .reg .b64 %rd<3>;
; CHECK-EMPTY:
; CHECK-NEXT: // %bb.0:
; CHECK-NEXT: ld.param.u64 %rd1, [myctlz64_2_param_0];
; CHECK-NEXT: clz.b64 %r1, %rd1;
; CHECK-NEXT: cvt.u64.u32 %rd2, %r1;
; CHECK-NEXT: st.param.b64 [func_retval0], %rd2;
; CHECK-NEXT: ret;
%val = call i64 @llvm.ctlz.i64(i64 %a, i1 true) readnone
ret i64 %val
}

; Here we truncate the 64-bit value of LLVM's ctlz intrinsic to 32 bits, the
; natural return width of ptx's clz.b64 instruction. No conversions should be
; necessary in the PTX.
; CHECK-LABEL: myctlz64_as_32(
define i32 @myctlz64_as_32(i64 %a) {
; CHECK: ld.param.
; CHECK-NEXT: clz.b64
; CHECK-NEXT: st.param.
; CHECK-NEXT: ret;
; CHECK-LABEL: myctlz64_as_32(
; CHECK: {
; CHECK-NEXT: .reg .b32 %r<2>;
; CHECK-NEXT: .reg .b64 %rd<2>;
; CHECK-EMPTY:
; CHECK-NEXT: // %bb.0:
; CHECK-NEXT: ld.param.u64 %rd1, [myctlz64_as_32_param_0];
; CHECK-NEXT: clz.b64 %r1, %rd1;
; CHECK-NEXT: st.param.b32 [func_retval0], %r1;
; CHECK-NEXT: ret;
%val = call i64 @llvm.ctlz.i64(i64 %a, i1 false) readnone
%trunc = trunc i64 %val to i32
ret i32 %trunc
}
; CHECK-LABEL: myctlz64_as_32_2(
define i32 @myctlz64_as_32_2(i64 %a) {
; CHECK: ld.param.
; CHECK-NEXT: clz.b64
; CHECK-NEXT: st.param.
; CHECK-NEXT: ret;
; CHECK-LABEL: myctlz64_as_32_2(
; CHECK: {
; CHECK-NEXT: .reg .b32 %r<2>;
; CHECK-NEXT: .reg .b64 %rd<2>;
; CHECK-EMPTY:
; CHECK-NEXT: // %bb.0:
; CHECK-NEXT: ld.param.u64 %rd1, [myctlz64_as_32_2_param_0];
; CHECK-NEXT: clz.b64 %r1, %rd1;
; CHECK-NEXT: st.param.b32 [func_retval0], %r1;
; CHECK-NEXT: ret;
%val = call i64 @llvm.ctlz.i64(i64 %a, i1 false) readnone
%trunc = trunc i64 %val to i32
ret i32 %trunc
@@ -80,53 +109,67 @@ define i32 @myctlz64_as_32_2(i64 %a) {
; and then truncating the result back down to i16. But the NVPTX ABI
; zero-extends i16 return values to i32, so the final truncation doesn't appear
; in this function.
; CHECK-LABEL: myctlz_ret16(
define i16 @myctlz_ret16(i16 %a) {
; CHECK: ld.param.
; CHECK-NEXT: cvt.u32.u16
; CHECK-NEXT: clz.b32
; CHECK-NEXT: sub.
; CHECK-NEXT: st.param.
; CHECK-NEXT: ret;
; CHECK-LABEL: myctlz_ret16(
; CHECK: {
; CHECK-NEXT: .reg .b32 %r<4>;
; CHECK-EMPTY:
; CHECK-NEXT: // %bb.0:
; CHECK-NEXT: ld.param.u16 %r1, [myctlz_ret16_param_0];
; CHECK-NEXT: clz.b32 %r2, %r1;
; CHECK-NEXT: add.s32 %r3, %r2, -16;
; CHECK-NEXT: st.param.b32 [func_retval0], %r3;
; CHECK-NEXT: ret;
%val = call i16 @llvm.ctlz.i16(i16 %a, i1 false) readnone
ret i16 %val
}
; CHECK-LABEL: myctlz_ret16_2(
define i16 @myctlz_ret16_2(i16 %a) {
; CHECK: ld.param.
; CHECK-NEXT: cvt.u32.u16
; CHECK-NEXT: clz.b32
; CHECK-NEXT: sub.
; CHECK-NEXT: st.param.
; CHECK-NEXT: ret;
; CHECK-LABEL: myctlz_ret16_2(
; CHECK: {
; CHECK-NEXT: .reg .b32 %r<4>;
; CHECK-EMPTY:
; CHECK-NEXT: // %bb.0:
; CHECK-NEXT: ld.param.u16 %r1, [myctlz_ret16_2_param_0];
; CHECK-NEXT: shl.b32 %r2, %r1, 16;
; CHECK-NEXT: clz.b32 %r3, %r2;
; CHECK-NEXT: st.param.b32 [func_retval0], %r3;
; CHECK-NEXT: ret;
%val = call i16 @llvm.ctlz.i16(i16 %a, i1 true) readnone
ret i16 %val
}

; Here we store the result of ctlz.16 into an i16 pointer, so the trunc should
; remain.
; CHECK-LABEL: myctlz_store16(
define void @myctlz_store16(i16 %a, ptr %b) {
; CHECK: ld.param.
; CHECK-NEXT: cvt.u32.u16
; CHECK-NEXT: clz.b32
; CHECK-DAG: cvt.u16.u32
; CHECK-DAG: sub.
; CHECK: st.{{[a-z]}}16
; CHECK: ret;
; CHECK-LABEL: myctlz_store16(
; CHECK: {
; CHECK-NEXT: .reg .b32 %r<4>;
; CHECK-NEXT: .reg .b64 %rd<2>;
; CHECK-EMPTY:
; CHECK-NEXT: // %bb.0:
; CHECK-NEXT: ld.param.u16 %r1, [myctlz_store16_param_0];
; CHECK-NEXT: clz.b32 %r2, %r1;
; CHECK-NEXT: add.s32 %r3, %r2, -16;
; CHECK-NEXT: ld.param.u64 %rd1, [myctlz_store16_param_1];
; CHECK-NEXT: st.u16 [%rd1], %r3;
; CHECK-NEXT: ret;
%val = call i16 @llvm.ctlz.i16(i16 %a, i1 false) readnone
store i16 %val, ptr %b
ret void
}
; CHECK-LABEL: myctlz_store16_2(
define void @myctlz_store16_2(i16 %a, ptr %b) {
; CHECK: ld.param.
; CHECK-NEXT: cvt.u32.u16
; CHECK-NEXT: clz.b32
; CHECK-DAG: cvt.u16.u32
; CHECK-DAG: sub.
; CHECK: st.{{[a-z]}}16
; CHECK: ret;
; CHECK-LABEL: myctlz_store16_2(
; CHECK: {
; CHECK-NEXT: .reg .b32 %r<4>;
; CHECK-NEXT: .reg .b64 %rd<2>;
; CHECK-EMPTY:
; CHECK-NEXT: // %bb.0:
; CHECK-NEXT: ld.param.u16 %r1, [myctlz_store16_2_param_0];
; CHECK-NEXT: clz.b32 %r2, %r1;
; CHECK-NEXT: add.s32 %r3, %r2, -16;
; CHECK-NEXT: ld.param.u64 %rd1, [myctlz_store16_2_param_1];
; CHECK-NEXT: st.u16 [%rd1], %r3;
; CHECK-NEXT: ret;
%val = call i16 @llvm.ctlz.i16(i16 %a, i1 false) readnone
store i16 %val, ptr %b
ret void
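A final scalar sketch (not part of the patch; plain C++ used for illustration, names are hypothetical) of the case exercised by myctlz64_as_32 above: when the i64 ctlz result is immediately truncated to i32, the zero-extend introduced by lowerCTLZCTPOP cancels against the truncate, so the 32-bit count from clz.b64 is stored directly and no cvt.u64.u32 appears in the PTX.

#include <bit>
#include <cstdint>

uint32_t ctlz64_as_32(uint64_t X) {
  // clz.b64 already yields a 32-bit count; the zext-then-trunc pair folds away.
  return static_cast<uint32_t>(std::countl_zero(X));
}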