Skip to content

Commit 729701b

Browse files
[NVPTX] Constant fold NVVM add/mul/div/fma (llvm#152544)
Constant fold the NVVM intrinsics for add, mul, div, fma with specific rounding modes.
1 parent eac19d4 commit 729701b

File tree

6 files changed

+3974
-0
lines changed

6 files changed

+3974
-0
lines changed

llvm/include/llvm/IR/NVVMIntrinsicUtils.h

Lines changed: 172 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -414,6 +414,178 @@ inline DenormalMode GetNVVMDenormMode(bool ShouldFTZ) {
414414
return DenormalMode::getIEEE();
415415
}
416416

417+
/// Reports whether \p IntrinsicID names one of the FTZ ("_ftz_") variants of
/// the NVVM add intrinsics with an explicit rounding mode.
///
/// \param IntrinsicID one of the nvvm_add_{rm,rn,rp,rz}[_ftz]_{f,d} IDs.
/// \returns true for the "_ftz_f" forms, false for the plain "_f" and "_d"
///          forms.
///
/// Any other intrinsic ID is a caller error and reaches llvm_unreachable.
inline bool FAddShouldFTZ(Intrinsic::ID IntrinsicID) {
  switch (IntrinsicID) {
  // Plain variants, listed per rounding mode.
  case Intrinsic::nvvm_add_rm_f:
  case Intrinsic::nvvm_add_rm_d:
  case Intrinsic::nvvm_add_rn_f:
  case Intrinsic::nvvm_add_rn_d:
  case Intrinsic::nvvm_add_rp_f:
  case Intrinsic::nvvm_add_rp_d:
  case Intrinsic::nvvm_add_rz_f:
  case Intrinsic::nvvm_add_rz_d:
    return false;

  // FTZ variants.
  case Intrinsic::nvvm_add_rm_ftz_f:
  case Intrinsic::nvvm_add_rn_ftz_f:
  case Intrinsic::nvvm_add_rp_ftz_f:
  case Intrinsic::nvvm_add_rz_ftz_f:
    return true;

  default:
    // Not an NVVM add intrinsic handled here; diagnosed below.
    break;
  }
  llvm_unreachable("Checking FTZ flag for invalid NVVM add intrinsic");
}
437+
438+
/// Translates the rounding-mode infix of an NVVM add intrinsic into the
/// matching APFloat rounding mode.
///
/// The mapping is the same for the "_f", "_d" and "_ftz_f" forms:
///   - rm -> APFloat::rmTowardNegative
///   - rn -> APFloat::rmNearestTiesToEven
///   - rp -> APFloat::rmTowardPositive
///   - rz -> APFloat::rmTowardZero
///
/// \param IntrinsicID one of the nvvm_add_{rm,rn,rp,rz}[_ftz]_{f,d} IDs.
/// \returns the APFloat rounding mode encoded in the intrinsic's name.
///
/// Any other intrinsic ID is a caller error and reaches llvm_unreachable.
inline APFloat::roundingMode GetFAddRoundingMode(Intrinsic::ID IntrinsicID) {
  switch (IntrinsicID) {
  case Intrinsic::nvvm_add_rm_f:
  case Intrinsic::nvvm_add_rm_d:
  case Intrinsic::nvvm_add_rm_ftz_f:
    return APFloat::rmTowardNegative;
  case Intrinsic::nvvm_add_rn_f:
  case Intrinsic::nvvm_add_rn_d:
  case Intrinsic::nvvm_add_rn_ftz_f:
    return APFloat::rmNearestTiesToEven;
  case Intrinsic::nvvm_add_rp_f:
  case Intrinsic::nvvm_add_rp_d:
  case Intrinsic::nvvm_add_rp_ftz_f:
    return APFloat::rmTowardPositive;
  case Intrinsic::nvvm_add_rz_f:
  case Intrinsic::nvvm_add_rz_d:
  case Intrinsic::nvvm_add_rz_ftz_f:
    return APFloat::rmTowardZero;
  }
  llvm_unreachable("Invalid FP intrinsic rounding mode for NVVM add");
}
459+
460+
/// Reports whether \p IntrinsicID names one of the FTZ ("_ftz_") variants of
/// the NVVM mul intrinsics with an explicit rounding mode.
///
/// \param IntrinsicID one of the nvvm_mul_{rm,rn,rp,rz}[_ftz]_{f,d} IDs.
/// \returns true for the "_ftz_f" forms, false for the plain "_f" and "_d"
///          forms.
///
/// Any other intrinsic ID is a caller error and reaches llvm_unreachable.
inline bool FMulShouldFTZ(Intrinsic::ID IntrinsicID) {
  switch (IntrinsicID) {
  // Plain variants, listed per rounding mode.
  case Intrinsic::nvvm_mul_rm_f:
  case Intrinsic::nvvm_mul_rm_d:
  case Intrinsic::nvvm_mul_rn_f:
  case Intrinsic::nvvm_mul_rn_d:
  case Intrinsic::nvvm_mul_rp_f:
  case Intrinsic::nvvm_mul_rp_d:
  case Intrinsic::nvvm_mul_rz_f:
  case Intrinsic::nvvm_mul_rz_d:
    return false;

  // FTZ variants.
  case Intrinsic::nvvm_mul_rm_ftz_f:
  case Intrinsic::nvvm_mul_rn_ftz_f:
  case Intrinsic::nvvm_mul_rp_ftz_f:
  case Intrinsic::nvvm_mul_rz_ftz_f:
    return true;

  default:
    // Not an NVVM mul intrinsic handled here; diagnosed below.
    break;
  }
  llvm_unreachable("Checking FTZ flag for invalid NVVM mul intrinsic");
}
480+
481+
/// Translates the rounding-mode infix of an NVVM mul intrinsic into the
/// matching APFloat rounding mode.
///
/// The mapping is the same for the "_f", "_d" and "_ftz_f" forms:
///   - rm -> APFloat::rmTowardNegative
///   - rn -> APFloat::rmNearestTiesToEven
///   - rp -> APFloat::rmTowardPositive
///   - rz -> APFloat::rmTowardZero
///
/// \param IntrinsicID one of the nvvm_mul_{rm,rn,rp,rz}[_ftz]_{f,d} IDs.
/// \returns the APFloat rounding mode encoded in the intrinsic's name.
///
/// Any other intrinsic ID is a caller error and reaches llvm_unreachable.
inline APFloat::roundingMode GetFMulRoundingMode(Intrinsic::ID IntrinsicID) {
  switch (IntrinsicID) {
  case Intrinsic::nvvm_mul_rm_f:
  case Intrinsic::nvvm_mul_rm_d:
  case Intrinsic::nvvm_mul_rm_ftz_f:
    return APFloat::rmTowardNegative;
  case Intrinsic::nvvm_mul_rn_f:
  case Intrinsic::nvvm_mul_rn_d:
  case Intrinsic::nvvm_mul_rn_ftz_f:
    return APFloat::rmNearestTiesToEven;
  case Intrinsic::nvvm_mul_rp_f:
  case Intrinsic::nvvm_mul_rp_d:
  case Intrinsic::nvvm_mul_rp_ftz_f:
    return APFloat::rmTowardPositive;
  case Intrinsic::nvvm_mul_rz_f:
  case Intrinsic::nvvm_mul_rz_d:
  case Intrinsic::nvvm_mul_rz_ftz_f:
    return APFloat::rmTowardZero;
  }
  llvm_unreachable("Invalid FP intrinsic rounding mode for NVVM mul");
}
502+
503+
/// Reports whether \p IntrinsicID names one of the FTZ ("_ftz_") variants of
/// the NVVM div intrinsics with an explicit rounding mode.
///
/// \param IntrinsicID one of the nvvm_div_{rm,rn,rp,rz}[_ftz]_{f,d} IDs.
/// \returns true for the "_ftz_f" forms, false for the plain "_f" and "_d"
///          forms.
///
/// Any other intrinsic ID is a caller error and reaches llvm_unreachable.
inline bool FDivShouldFTZ(Intrinsic::ID IntrinsicID) {
  switch (IntrinsicID) {
  // Plain variants, listed per rounding mode.
  case Intrinsic::nvvm_div_rm_f:
  case Intrinsic::nvvm_div_rm_d:
  case Intrinsic::nvvm_div_rn_f:
  case Intrinsic::nvvm_div_rn_d:
  case Intrinsic::nvvm_div_rp_f:
  case Intrinsic::nvvm_div_rp_d:
  case Intrinsic::nvvm_div_rz_f:
  case Intrinsic::nvvm_div_rz_d:
    return false;

  // FTZ variants.
  case Intrinsic::nvvm_div_rm_ftz_f:
  case Intrinsic::nvvm_div_rn_ftz_f:
  case Intrinsic::nvvm_div_rp_ftz_f:
  case Intrinsic::nvvm_div_rz_ftz_f:
    return true;

  default:
    // Not an NVVM div intrinsic handled here; diagnosed below.
    break;
  }
  llvm_unreachable("Checking FTZ flag for invalid NVVM div intrinsic");
}
523+
524+
/// Translates the rounding-mode infix of an NVVM div intrinsic into the
/// matching APFloat rounding mode.
///
/// The mapping is the same for the "_f", "_d" and "_ftz_f" forms:
///   - rm -> APFloat::rmTowardNegative
///   - rn -> APFloat::rmNearestTiesToEven
///   - rp -> APFloat::rmTowardPositive
///   - rz -> APFloat::rmTowardZero
///
/// \param IntrinsicID one of the nvvm_div_{rm,rn,rp,rz}[_ftz]_{f,d} IDs.
/// \returns the APFloat rounding mode encoded in the intrinsic's name.
///
/// Any other intrinsic ID is a caller error and reaches llvm_unreachable.
inline APFloat::roundingMode GetFDivRoundingMode(Intrinsic::ID IntrinsicID) {
  switch (IntrinsicID) {
  case Intrinsic::nvvm_div_rm_f:
  case Intrinsic::nvvm_div_rm_d:
  case Intrinsic::nvvm_div_rm_ftz_f:
    return APFloat::rmTowardNegative;
  case Intrinsic::nvvm_div_rn_f:
  case Intrinsic::nvvm_div_rn_d:
  case Intrinsic::nvvm_div_rn_ftz_f:
    return APFloat::rmNearestTiesToEven;
  case Intrinsic::nvvm_div_rp_f:
  case Intrinsic::nvvm_div_rp_d:
  case Intrinsic::nvvm_div_rp_ftz_f:
    return APFloat::rmTowardPositive;
  case Intrinsic::nvvm_div_rz_f:
  case Intrinsic::nvvm_div_rz_d:
  case Intrinsic::nvvm_div_rz_ftz_f:
    return APFloat::rmTowardZero;
  }
  llvm_unreachable("Invalid FP intrinsic rounding mode for NVVM div");
}
545+
546+
/// Reports whether \p IntrinsicID names one of the FTZ ("_ftz_") variants of
/// the NVVM fma intrinsics with an explicit rounding mode.
///
/// \param IntrinsicID one of the nvvm_fma_{rm,rn,rp,rz}[_ftz]_{f,d} IDs.
/// \returns true for the "_ftz_f" forms, false for the plain "_f" and "_d"
///          forms.
///
/// Any other intrinsic ID is a caller error and reaches llvm_unreachable.
inline bool FMAShouldFTZ(Intrinsic::ID IntrinsicID) {
  switch (IntrinsicID) {
  // Plain variants, listed per rounding mode.
  case Intrinsic::nvvm_fma_rm_f:
  case Intrinsic::nvvm_fma_rm_d:
  case Intrinsic::nvvm_fma_rn_f:
  case Intrinsic::nvvm_fma_rn_d:
  case Intrinsic::nvvm_fma_rp_f:
  case Intrinsic::nvvm_fma_rp_d:
  case Intrinsic::nvvm_fma_rz_f:
  case Intrinsic::nvvm_fma_rz_d:
    return false;

  // FTZ variants.
  case Intrinsic::nvvm_fma_rm_ftz_f:
  case Intrinsic::nvvm_fma_rn_ftz_f:
  case Intrinsic::nvvm_fma_rp_ftz_f:
  case Intrinsic::nvvm_fma_rz_ftz_f:
    return true;

  default:
    // Not an NVVM fma intrinsic handled here; diagnosed below.
    break;
  }
  llvm_unreachable("Checking FTZ flag for invalid NVVM fma intrinsic");
}
566+
567+
/// Translates the rounding-mode infix of an NVVM fma intrinsic into the
/// matching APFloat rounding mode.
///
/// The mapping is the same for the "_f", "_d" and "_ftz_f" forms:
///   - rm -> APFloat::rmTowardNegative
///   - rn -> APFloat::rmNearestTiesToEven
///   - rp -> APFloat::rmTowardPositive
///   - rz -> APFloat::rmTowardZero
///
/// \param IntrinsicID one of the nvvm_fma_{rm,rn,rp,rz}[_ftz]_{f,d} IDs.
/// \returns the APFloat rounding mode encoded in the intrinsic's name.
///
/// Any other intrinsic ID is a caller error and reaches llvm_unreachable.
inline APFloat::roundingMode GetFMARoundingMode(Intrinsic::ID IntrinsicID) {
  switch (IntrinsicID) {
  case Intrinsic::nvvm_fma_rm_f:
  case Intrinsic::nvvm_fma_rm_d:
  case Intrinsic::nvvm_fma_rm_ftz_f:
    return APFloat::rmTowardNegative;
  case Intrinsic::nvvm_fma_rn_f:
  case Intrinsic::nvvm_fma_rn_d:
  case Intrinsic::nvvm_fma_rn_ftz_f:
    return APFloat::rmNearestTiesToEven;
  case Intrinsic::nvvm_fma_rp_f:
  case Intrinsic::nvvm_fma_rp_d:
  case Intrinsic::nvvm_fma_rp_ftz_f:
    return APFloat::rmTowardPositive;
  case Intrinsic::nvvm_fma_rz_f:
  case Intrinsic::nvvm_fma_rz_d:
  case Intrinsic::nvvm_fma_rz_ftz_f:
    return APFloat::rmTowardZero;
  }
  llvm_unreachable("Invalid FP intrinsic rounding mode for NVVM fma");
}
588+
417589
} // namespace nvvm
418590
} // namespace llvm
419591
#endif // LLVM_IR_NVVMINTRINSICUTILS_H

llvm/lib/Analysis/ConstantFolding.cpp

Lines changed: 178 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1847,6 +1847,62 @@ bool llvm::canConstantFoldCallTo(const CallBase *Call, const Function *F) {
18471847
case Intrinsic::nvvm_sqrt_rn_ftz_f:
18481848
return !Call->isStrictFP();
18491849

1850+
// NVVM add intrinsics with explicit rounding modes
1851+
case Intrinsic::nvvm_add_rm_d:
1852+
case Intrinsic::nvvm_add_rn_d:
1853+
case Intrinsic::nvvm_add_rp_d:
1854+
case Intrinsic::nvvm_add_rz_d:
1855+
case Intrinsic::nvvm_add_rm_f:
1856+
case Intrinsic::nvvm_add_rn_f:
1857+
case Intrinsic::nvvm_add_rp_f:
1858+
case Intrinsic::nvvm_add_rz_f:
1859+
case Intrinsic::nvvm_add_rm_ftz_f:
1860+
case Intrinsic::nvvm_add_rn_ftz_f:
1861+
case Intrinsic::nvvm_add_rp_ftz_f:
1862+
case Intrinsic::nvvm_add_rz_ftz_f:
1863+
1864+
// NVVM div intrinsics with explicit rounding modes
1865+
case Intrinsic::nvvm_div_rm_d:
1866+
case Intrinsic::nvvm_div_rn_d:
1867+
case Intrinsic::nvvm_div_rp_d:
1868+
case Intrinsic::nvvm_div_rz_d:
1869+
case Intrinsic::nvvm_div_rm_f:
1870+
case Intrinsic::nvvm_div_rn_f:
1871+
case Intrinsic::nvvm_div_rp_f:
1872+
case Intrinsic::nvvm_div_rz_f:
1873+
case Intrinsic::nvvm_div_rm_ftz_f:
1874+
case Intrinsic::nvvm_div_rn_ftz_f:
1875+
case Intrinsic::nvvm_div_rp_ftz_f:
1876+
case Intrinsic::nvvm_div_rz_ftz_f:
1877+
1878+
// NVVM mul intrinsics with explicit rounding modes
1879+
case Intrinsic::nvvm_mul_rm_d:
1880+
case Intrinsic::nvvm_mul_rn_d:
1881+
case Intrinsic::nvvm_mul_rp_d:
1882+
case Intrinsic::nvvm_mul_rz_d:
1883+
case Intrinsic::nvvm_mul_rm_f:
1884+
case Intrinsic::nvvm_mul_rn_f:
1885+
case Intrinsic::nvvm_mul_rp_f:
1886+
case Intrinsic::nvvm_mul_rz_f:
1887+
case Intrinsic::nvvm_mul_rm_ftz_f:
1888+
case Intrinsic::nvvm_mul_rn_ftz_f:
1889+
case Intrinsic::nvvm_mul_rp_ftz_f:
1890+
case Intrinsic::nvvm_mul_rz_ftz_f:
1891+
1892+
// NVVM fma intrinsics with explicit rounding modes
1893+
case Intrinsic::nvvm_fma_rm_d:
1894+
case Intrinsic::nvvm_fma_rn_d:
1895+
case Intrinsic::nvvm_fma_rp_d:
1896+
case Intrinsic::nvvm_fma_rz_d:
1897+
case Intrinsic::nvvm_fma_rm_f:
1898+
case Intrinsic::nvvm_fma_rn_f:
1899+
case Intrinsic::nvvm_fma_rp_f:
1900+
case Intrinsic::nvvm_fma_rz_f:
1901+
case Intrinsic::nvvm_fma_rm_ftz_f:
1902+
case Intrinsic::nvvm_fma_rn_ftz_f:
1903+
case Intrinsic::nvvm_fma_rp_ftz_f:
1904+
case Intrinsic::nvvm_fma_rz_ftz_f:
1905+
18501906
// Sign operations are actually bitwise operations, they do not raise
18511907
// exceptions even for SNANs.
18521908
case Intrinsic::fabs:
@@ -3322,6 +3378,96 @@ static Constant *ConstantFoldIntrinsicCall2(Intrinsic::ID IntrinsicID, Type *Ty,
33223378

33233379
return ConstantFP::get(Ty->getContext(), Res);
33243380
}
3381+
3382+
case Intrinsic::nvvm_add_rm_f:
3383+
case Intrinsic::nvvm_add_rn_f:
3384+
case Intrinsic::nvvm_add_rp_f:
3385+
case Intrinsic::nvvm_add_rz_f:
3386+
case Intrinsic::nvvm_add_rm_d:
3387+
case Intrinsic::nvvm_add_rn_d:
3388+
case Intrinsic::nvvm_add_rp_d:
3389+
case Intrinsic::nvvm_add_rz_d:
3390+
case Intrinsic::nvvm_add_rm_ftz_f:
3391+
case Intrinsic::nvvm_add_rn_ftz_f:
3392+
case Intrinsic::nvvm_add_rp_ftz_f:
3393+
case Intrinsic::nvvm_add_rz_ftz_f: {
3394+
3395+
bool IsFTZ = nvvm::FAddShouldFTZ(IntrinsicID);
3396+
APFloat A = IsFTZ ? FTZPreserveSign(Op1V) : Op1V;
3397+
APFloat B = IsFTZ ? FTZPreserveSign(Op2V) : Op2V;
3398+
3399+
APFloat::roundingMode RoundMode =
3400+
nvvm::GetFAddRoundingMode(IntrinsicID);
3401+
3402+
APFloat Res = A;
3403+
APFloat::opStatus Status = Res.add(B, RoundMode);
3404+
3405+
if (!Res.isNaN() &&
3406+
(Status == APFloat::opOK || Status == APFloat::opInexact)) {
3407+
Res = IsFTZ ? FTZPreserveSign(Res) : Res;
3408+
return ConstantFP::get(Ty->getContext(), Res);
3409+
}
3410+
return nullptr;
3411+
}
3412+
3413+
case Intrinsic::nvvm_mul_rm_f:
3414+
case Intrinsic::nvvm_mul_rn_f:
3415+
case Intrinsic::nvvm_mul_rp_f:
3416+
case Intrinsic::nvvm_mul_rz_f:
3417+
case Intrinsic::nvvm_mul_rm_d:
3418+
case Intrinsic::nvvm_mul_rn_d:
3419+
case Intrinsic::nvvm_mul_rp_d:
3420+
case Intrinsic::nvvm_mul_rz_d:
3421+
case Intrinsic::nvvm_mul_rm_ftz_f:
3422+
case Intrinsic::nvvm_mul_rn_ftz_f:
3423+
case Intrinsic::nvvm_mul_rp_ftz_f:
3424+
case Intrinsic::nvvm_mul_rz_ftz_f: {
3425+
3426+
bool IsFTZ = nvvm::FMulShouldFTZ(IntrinsicID);
3427+
APFloat A = IsFTZ ? FTZPreserveSign(Op1V) : Op1V;
3428+
APFloat B = IsFTZ ? FTZPreserveSign(Op2V) : Op2V;
3429+
3430+
APFloat::roundingMode RoundMode =
3431+
nvvm::GetFMulRoundingMode(IntrinsicID);
3432+
3433+
APFloat Res = A;
3434+
APFloat::opStatus Status = Res.multiply(B, RoundMode);
3435+
3436+
if (!Res.isNaN() &&
3437+
(Status == APFloat::opOK || Status == APFloat::opInexact)) {
3438+
Res = IsFTZ ? FTZPreserveSign(Res) : Res;
3439+
return ConstantFP::get(Ty->getContext(), Res);
3440+
}
3441+
return nullptr;
3442+
}
3443+
3444+
case Intrinsic::nvvm_div_rm_f:
3445+
case Intrinsic::nvvm_div_rn_f:
3446+
case Intrinsic::nvvm_div_rp_f:
3447+
case Intrinsic::nvvm_div_rz_f:
3448+
case Intrinsic::nvvm_div_rm_d:
3449+
case Intrinsic::nvvm_div_rn_d:
3450+
case Intrinsic::nvvm_div_rp_d:
3451+
case Intrinsic::nvvm_div_rz_d:
3452+
case Intrinsic::nvvm_div_rm_ftz_f:
3453+
case Intrinsic::nvvm_div_rn_ftz_f:
3454+
case Intrinsic::nvvm_div_rp_ftz_f:
3455+
case Intrinsic::nvvm_div_rz_ftz_f: {
3456+
bool IsFTZ = nvvm::FDivShouldFTZ(IntrinsicID);
3457+
APFloat A = IsFTZ ? FTZPreserveSign(Op1V) : Op1V;
3458+
APFloat B = IsFTZ ? FTZPreserveSign(Op2V) : Op2V;
3459+
APFloat::roundingMode RoundMode =
3460+
nvvm::GetFDivRoundingMode(IntrinsicID);
3461+
3462+
APFloat Res = A;
3463+
APFloat::opStatus Status = Res.divide(B, RoundMode);
3464+
if (!Res.isNaN() &&
3465+
(Status == APFloat::opOK || Status == APFloat::opInexact)) {
3466+
Res = IsFTZ ? FTZPreserveSign(Res) : Res;
3467+
return ConstantFP::get(Ty->getContext(), Res);
3468+
}
3469+
return nullptr;
3470+
}
33253471
}
33263472

33273473
if (!Ty->isHalfTy() && !Ty->isFloatTy() && !Ty->isDoubleTy())
@@ -3733,6 +3879,38 @@ static Constant *ConstantFoldScalarCall3(StringRef Name,
37333879
V.fusedMultiplyAdd(C2, C3, APFloat::rmNearestTiesToEven);
37343880
return ConstantFP::get(Ty->getContext(), V);
37353881
}
3882+
3883+
case Intrinsic::nvvm_fma_rm_f:
3884+
case Intrinsic::nvvm_fma_rn_f:
3885+
case Intrinsic::nvvm_fma_rp_f:
3886+
case Intrinsic::nvvm_fma_rz_f:
3887+
case Intrinsic::nvvm_fma_rm_d:
3888+
case Intrinsic::nvvm_fma_rn_d:
3889+
case Intrinsic::nvvm_fma_rp_d:
3890+
case Intrinsic::nvvm_fma_rz_d:
3891+
case Intrinsic::nvvm_fma_rm_ftz_f:
3892+
case Intrinsic::nvvm_fma_rn_ftz_f:
3893+
case Intrinsic::nvvm_fma_rp_ftz_f:
3894+
case Intrinsic::nvvm_fma_rz_ftz_f: {
3895+
bool IsFTZ = nvvm::FMAShouldFTZ(IntrinsicID);
3896+
APFloat A = IsFTZ ? FTZPreserveSign(C1) : C1;
3897+
APFloat B = IsFTZ ? FTZPreserveSign(C2) : C2;
3898+
APFloat C = IsFTZ ? FTZPreserveSign(C3) : C3;
3899+
3900+
APFloat::roundingMode RoundMode =
3901+
nvvm::GetFMARoundingMode(IntrinsicID);
3902+
3903+
APFloat Res = A;
3904+
APFloat::opStatus Status = Res.fusedMultiplyAdd(B, C, RoundMode);
3905+
3906+
if (!Res.isNaN() &&
3907+
(Status == APFloat::opOK || Status == APFloat::opInexact)) {
3908+
Res = IsFTZ ? FTZPreserveSign(Res) : Res;
3909+
return ConstantFP::get(Ty->getContext(), Res);
3910+
}
3911+
return nullptr;
3912+
}
3913+
37363914
case Intrinsic::amdgcn_cubeid:
37373915
case Intrinsic::amdgcn_cubema:
37383916
case Intrinsic::amdgcn_cubesc:

0 commit comments

Comments
 (0)