From 2184096f9069d0b06e6af3ba697cd88e65fd2348 Mon Sep 17 00:00:00 2001
From: William Moses <gh@wsmoses.com>
Date: Wed, 23 Feb 2022 01:01:36 -0500
Subject: [PATCH] Use provided sqrt (#533)

---
 enzyme/Enzyme/AdjointGenerator.h              | 37 +++++++++++--------
 enzyme/test/Enzyme/ReverseMode/ompsqloop.ll   |  2 +-
 .../Enzyme/ReverseMode/ompsqloopoutofplace.ll |  2 +-
 3 files changed, 24 insertions(+), 17 deletions(-)

diff --git a/enzyme/Enzyme/AdjointGenerator.h b/enzyme/Enzyme/AdjointGenerator.h
index 046f2059091a3..1aa26df2d3b9a 100644
--- a/enzyme/Enzyme/AdjointGenerator.h
+++ b/enzyme/Enzyme/AdjointGenerator.h
@@ -3129,15 +3129,19 @@ class AdjointGenerator
         if (vdiff && !gutils->isConstantValue(orig_ops[0])) {
           SmallVector<Value *, 2> args = {
               lookup(gutils->getNewFromOriginal(orig_ops[0]), Builder2)};
-          Type *tys[] = {orig_ops[0]->getType()};
-          Function *SqrtF;
-          if (ID == Intrinsic::sqrt)
-            SqrtF = Intrinsic::getDeclaration(M, ID, tys);
-          else
-            SqrtF = Intrinsic::getDeclaration(M, ID);
 
-          auto cal = cast<CallInst>(Builder2.CreateCall(SqrtF, args));
-          cal->setCallingConv(SqrtF->getCallingConv());
+          auto &CI = cast<CallInst>(I);
+#if LLVM_VERSION_MAJOR >= 11
+          auto *SqrtF = CI.getCalledOperand();
+#else
+          auto *SqrtF = CI.getCalledValue();
+#endif
+          assert(SqrtF);
+          // Works with opaque pointers: take the type from the call.
+          auto *FT = CI.getFunctionType();
+
+          auto cal = cast<CallInst>(Builder2.CreateCall(FT, SqrtF, args));
+          cal->setCallingConv(CI.getCallingConv());
           cal->setDebugLoc(gutils->getNewFromOriginal(I.getDebugLoc()));
 
           Value *dif0 = Builder2.CreateBinOp(
@@ -3529,15 +3533,18 @@ class AdjointGenerator
         Value *args[1] = {gutils->getNewFromOriginal(orig_ops[0])};
         Type *tys[] = {orig_ops[0]->getType()};
 
-        Function *SqrtF;
-        if (ID == Intrinsic::sqrt)
-          SqrtF = Intrinsic::getDeclaration(M, ID, tys);
-        else
-          SqrtF = Intrinsic::getDeclaration(M, ID);
+        auto &CI = cast<CallInst>(I);
+#if LLVM_VERSION_MAJOR >= 11
+        auto *SqrtF = CI.getCalledOperand();
+#else
+        auto *SqrtF = CI.getCalledValue();
+#endif
+        assert(SqrtF);
+        auto *FT = CI.getFunctionType(); // valid under opaque pointers too
 
         auto rule = [&](Value *op) {
-          CallInst *cal = cast<CallInst>(Builder2.CreateCall(SqrtF, args));
-          cal->setCallingConv(SqrtF->getCallingConv());
+          CallInst *cal = cast<CallInst>(Builder2.CreateCall(FT, SqrtF, args));
+          cal->setCallingConv(CI.getCallingConv());
           cal->setDebugLoc(gutils->getNewFromOriginal(I.getDebugLoc()));
 
           Value *half = ConstantFP::get(orig_ops[0]->getType(), 0.5);
diff --git a/enzyme/test/Enzyme/ReverseMode/ompsqloop.ll b/enzyme/test/Enzyme/ReverseMode/ompsqloop.ll
index 8d0b5a3528060..742d7ce35dd03 100644
--- a/enzyme/test/Enzyme/ReverseMode/ompsqloop.ll
+++ b/enzyme/test/Enzyme/ReverseMode/ompsqloop.ll
@@ -210,7 +210,7 @@ attributes #1 = { argmemonly }
 ; CHECK-NEXT:   %[[i9:.+]] = add nuw nsw i64 %"iv'ac.0", %_unwrap2
 ; CHECK-NEXT:   %[[i10:.+]] = getelementptr inbounds double, double* %truetape, i64 %[[i9]]
 ; CHECK-NEXT:   %[[i11:.+]] = load double, double* %[[i10]], align 8, !tbaa !9, !invariant.group !
-; CHECK-NEXT:   %[[i12:.+]] = call fast double @llvm.sqrt.f64(double %[[i11]])
+; CHECK-NEXT:   %[[i12:.+]] = call fast double @sqrt(double %[[i11]])
 ; CHECK-NEXT:   %[[i13:.+]] = fmul fast double 5.000000e-01, %[[i8]]
 ; CHECK-NEXT:   %[[i14:.+]] = fdiv fast double %[[i13]], %[[i12]]
 ; CHECK-NEXT:   %[[i15:.+]] = fcmp fast oeq double %[[i11]], 0.000000e+00
diff --git a/enzyme/test/Enzyme/ReverseMode/ompsqloopoutofplace.ll b/enzyme/test/Enzyme/ReverseMode/ompsqloopoutofplace.ll
index 6cb2ef655e8d3..0c354029ba9d3 100644
--- a/enzyme/test/Enzyme/ReverseMode/ompsqloopoutofplace.ll
+++ b/enzyme/test/Enzyme/ReverseMode/ompsqloopoutofplace.ll
@@ -155,7 +155,7 @@ attributes #1 = { argmemonly }
 ; CHECK-NEXT:   store double 0.000000e+00, double* %"outidx'ipg_unwrap", align 8
 ; CHECK-NEXT:   %arrayidx_unwrap = getelementptr inbounds double, double* %tmp, i64 %_unwrap3
 ; CHECK-NEXT:   %_unwrap4 = load double, double* %arrayidx_unwrap, align 8, !tbaa !9, !invariant.group !16
-; CHECK-NEXT:   %2 = call fast double @llvm.sqrt.f64(double %_unwrap4)
+; CHECK-NEXT:   %2 = call fast double @sqrt(double %_unwrap4)
 ; CHECK-NEXT:   %3 = fmul fast double 5.000000e-01, %1
 ; CHECK-NEXT:   %4 = fdiv fast double %3, %2
 ; CHECK-NEXT:   %5 = fcmp fast oeq double %_unwrap4, 0.000000e+00