From 2184096f9069d0b06e6af3ba697cd88e65fd2348 Mon Sep 17 00:00:00 2001 From: William Moses <gh@wsmoses.com> Date: Wed, 23 Feb 2022 01:01:36 -0500 Subject: [PATCH] Use provided sqrt (#533) --- enzyme/Enzyme/AdjointGenerator.h | 37 +++++++++++-------- enzyme/test/Enzyme/ReverseMode/ompsqloop.ll | 2 +- .../Enzyme/ReverseMode/ompsqloopoutofplace.ll | 2 +- 3 files changed, 24 insertions(+), 17 deletions(-) diff --git a/enzyme/Enzyme/AdjointGenerator.h b/enzyme/Enzyme/AdjointGenerator.h index 046f2059091a3..1aa26df2d3b9a 100644 --- a/enzyme/Enzyme/AdjointGenerator.h +++ b/enzyme/Enzyme/AdjointGenerator.h @@ -3129,15 +3129,19 @@ class AdjointGenerator if (vdiff && !gutils->isConstantValue(orig_ops[0])) { SmallVector<Value *, 2> args = { lookup(gutils->getNewFromOriginal(orig_ops[0]), Builder2)}; - Type *tys[] = {orig_ops[0]->getType()}; - Function *SqrtF; - if (ID == Intrinsic::sqrt) - SqrtF = Intrinsic::getDeclaration(M, ID, tys); - else - SqrtF = Intrinsic::getDeclaration(M, ID); - auto cal = cast<CallInst>(Builder2.CreateCall(SqrtF, args)); - cal->setCallingConv(SqrtF->getCallingConv()); + auto &CI = cast<CallInst>(I); +#if LLVM_VERSION_MAJOR >= 11 + auto *SqrtF = CI.getCalledOperand(); +#else + auto *SqrtF = CI.getCalledValue(); +#endif + assert(SqrtF); + auto FT = + cast<FunctionType>(SqrtF->getType()->getPointerElementType()); + + auto cal = cast<CallInst>(Builder2.CreateCall(FT, SqrtF, args)); + cal->setCallingConv(CI.getCallingConv()); cal->setDebugLoc(gutils->getNewFromOriginal(I.getDebugLoc())); Value *dif0 = Builder2.CreateBinOp( @@ -3529,15 +3533,18 @@ class AdjointGenerator Value *args[1] = {gutils->getNewFromOriginal(orig_ops[0])}; Type *tys[] = {orig_ops[0]->getType()}; - Function *SqrtF; - if (ID == Intrinsic::sqrt) - SqrtF = Intrinsic::getDeclaration(M, ID, tys); - else - SqrtF = Intrinsic::getDeclaration(M, ID); + auto &CI = cast<CallInst>(I); +#if LLVM_VERSION_MAJOR >= 11 + auto *SqrtF = CI.getCalledOperand(); +#else + auto *SqrtF = CI.getCalledValue(); +#endif + assert(SqrtF); + auto FT = cast<FunctionType>(SqrtF->getType()->getPointerElementType()); auto rule = [&](Value *op) { - CallInst *cal = cast<CallInst>(Builder2.CreateCall(SqrtF, args)); - cal->setCallingConv(SqrtF->getCallingConv()); + CallInst *cal = cast<CallInst>(Builder2.CreateCall(FT, SqrtF, args)); + cal->setCallingConv(CI.getCallingConv()); cal->setDebugLoc(gutils->getNewFromOriginal(I.getDebugLoc())); Value *half = ConstantFP::get(orig_ops[0]->getType(), 0.5); diff --git a/enzyme/test/Enzyme/ReverseMode/ompsqloop.ll b/enzyme/test/Enzyme/ReverseMode/ompsqloop.ll index 8d0b5a3528060..742d7ce35dd03 100644 --- a/enzyme/test/Enzyme/ReverseMode/ompsqloop.ll +++ b/enzyme/test/Enzyme/ReverseMode/ompsqloop.ll @@ -210,7 +210,7 @@ attributes #1 = { argmemonly } ; CHECK-NEXT: %[[i9:.+]] = add nuw nsw i64 %"iv'ac.0", %_unwrap2 ; CHECK-NEXT: %[[i10:.+]] = getelementptr inbounds double, double* %truetape, i64 %[[i9]] ; CHECK-NEXT: %[[i11:.+]] = load double, double* %[[i10]], align 8, !tbaa !9, !invariant.group ! -; CHECK-NEXT: %[[i12:.+]] = call fast double @llvm.sqrt.f64(double %[[i11]]) +; CHECK-NEXT: %[[i12:.+]] = call fast double @sqrt(double %[[i11]]) ; CHECK-NEXT: %[[i13:.+]] = fmul fast double 5.000000e-01, %[[i8]] ; CHECK-NEXT: %[[i14:.+]] = fdiv fast double %[[i13]], %[[i12]] ; CHECK-NEXT: %[[i15:.+]] = fcmp fast oeq double %[[i11]], 0.000000e+00 diff --git a/enzyme/test/Enzyme/ReverseMode/ompsqloopoutofplace.ll b/enzyme/test/Enzyme/ReverseMode/ompsqloopoutofplace.ll index 6cb2ef655e8d3..0c354029ba9d3 100644 --- a/enzyme/test/Enzyme/ReverseMode/ompsqloopoutofplace.ll +++ b/enzyme/test/Enzyme/ReverseMode/ompsqloopoutofplace.ll @@ -155,7 +155,7 @@ attributes #1 = { argmemonly } ; CHECK-NEXT: store double 0.000000e+00, double* %"outidx'ipg_unwrap", align 8 ; CHECK-NEXT: %arrayidx_unwrap = getelementptr inbounds double, double* %tmp, i64 %_unwrap3 ; CHECK-NEXT: %_unwrap4 = load double, double* %arrayidx_unwrap, align 8, !tbaa !9, !invariant.group !16 -; CHECK-NEXT: %2 = call fast double @llvm.sqrt.f64(double %_unwrap4) +; CHECK-NEXT: %2 = call fast double @sqrt(double %_unwrap4) ; CHECK-NEXT: %3 = fmul fast double 5.000000e-01, %1 ; CHECK-NEXT: %4 = fdiv fast double %3, %2 ; CHECK-NEXT: %5 = fcmp fast oeq double %_unwrap4, 0.000000e+00