Skip to content

Commit

Permalink
Use provided sqrt (rust-lang#533)
Browse files (browse the repository at this point in the history)
  • Loading branch information
wsmoses authored Feb 23, 2022
1 parent 26d31e0 commit 2184096
Show file tree
Hide file tree
Showing 3 changed files with 24 additions and 17 deletions.
37 changes: 22 additions & 15 deletions enzyme/Enzyme/AdjointGenerator.h
Original file line number | Diff line number | Diff line change
Expand Up @@ -3129,15 +3129,19 @@ class AdjointGenerator
if (vdiff && !gutils->isConstantValue(orig_ops[0])) {
SmallVector<Value *, 2> args = {
lookup(gutils->getNewFromOriginal(orig_ops[0]), Builder2)};
Type *tys[] = {orig_ops[0]->getType()};
Function *SqrtF;
if (ID == Intrinsic::sqrt)
SqrtF = Intrinsic::getDeclaration(M, ID, tys);
else
SqrtF = Intrinsic::getDeclaration(M, ID);

auto cal = cast<CallInst>(Builder2.CreateCall(SqrtF, args));
cal->setCallingConv(SqrtF->getCallingConv());
auto &CI = cast<CallInst>(I);
#if LLVM_VERSION_MAJOR >= 11
auto *SqrtF = CI.getCalledOperand();
#else
auto *SqrtF = CI.getCalledValue();
#endif
assert(SqrtF);
auto FT =
cast<FunctionType>(SqrtF->getType()->getPointerElementType());

auto cal = cast<CallInst>(Builder2.CreateCall(FT, SqrtF, args));
cal->setCallingConv(CI.getCallingConv());
cal->setDebugLoc(gutils->getNewFromOriginal(I.getDebugLoc()));

Value *dif0 = Builder2.CreateBinOp(
Expand Down Expand Up @@ -3529,15 +3533,18 @@ class AdjointGenerator
Value *args[1] = {gutils->getNewFromOriginal(orig_ops[0])};
Type *tys[] = {orig_ops[0]->getType()};

Function *SqrtF;
if (ID == Intrinsic::sqrt)
SqrtF = Intrinsic::getDeclaration(M, ID, tys);
else
SqrtF = Intrinsic::getDeclaration(M, ID);
auto &CI = cast<CallInst>(I);
#if LLVM_VERSION_MAJOR >= 11
auto *SqrtF = CI.getCalledOperand();
#else
auto *SqrtF = CI.getCalledValue();
#endif
assert(SqrtF);
auto FT = cast<FunctionType>(SqrtF->getType()->getPointerElementType());

auto rule = [&](Value *op) {
CallInst *cal = cast<CallInst>(Builder2.CreateCall(SqrtF, args));
cal->setCallingConv(SqrtF->getCallingConv());
CallInst *cal = cast<CallInst>(Builder2.CreateCall(FT, SqrtF, args));
cal->setCallingConv(CI.getCallingConv());
cal->setDebugLoc(gutils->getNewFromOriginal(I.getDebugLoc()));

Value *half = ConstantFP::get(orig_ops[0]->getType(), 0.5);
Expand Down
2 changes: 1 addition & 1 deletion enzyme/test/Enzyme/ReverseMode/ompsqloop.ll
Original file line number | Diff line number | Diff line change
Expand Up @@ -210,7 +210,7 @@ attributes #1 = { argmemonly }
; CHECK-NEXT: %[[i9:.+]] = add nuw nsw i64 %"iv'ac.0", %_unwrap2
; CHECK-NEXT: %[[i10:.+]] = getelementptr inbounds double, double* %truetape, i64 %[[i9]]
; CHECK-NEXT: %[[i11:.+]] = load double, double* %[[i10]], align 8, !tbaa !9, !invariant.group !
; CHECK-NEXT: %[[i12:.+]] = call fast double @llvm.sqrt.f64(double %[[i11]])
; CHECK-NEXT: %[[i12:.+]] = call fast double @sqrt(double %[[i11]])
; CHECK-NEXT: %[[i13:.+]] = fmul fast double 5.000000e-01, %[[i8]]
; CHECK-NEXT: %[[i14:.+]] = fdiv fast double %[[i13]], %[[i12]]
; CHECK-NEXT: %[[i15:.+]] = fcmp fast oeq double %[[i11]], 0.000000e+00
Expand Down
2 changes: 1 addition & 1 deletion enzyme/test/Enzyme/ReverseMode/ompsqloopoutofplace.ll
Original file line number | Diff line number | Diff line change
Expand Up @@ -155,7 +155,7 @@ attributes #1 = { argmemonly }
; CHECK-NEXT: store double 0.000000e+00, double* %"outidx'ipg_unwrap", align 8
; CHECK-NEXT: %arrayidx_unwrap = getelementptr inbounds double, double* %tmp, i64 %_unwrap3
; CHECK-NEXT: %_unwrap4 = load double, double* %arrayidx_unwrap, align 8, !tbaa !9, !invariant.group !16
; CHECK-NEXT: %2 = call fast double @llvm.sqrt.f64(double %_unwrap4)
; CHECK-NEXT: %2 = call fast double @sqrt(double %_unwrap4)
; CHECK-NEXT: %3 = fmul fast double 5.000000e-01, %1
; CHECK-NEXT: %4 = fdiv fast double %3, %2
; CHECK-NEXT: %5 = fcmp fast oeq double %_unwrap4, 0.000000e+00
Expand Down

0 comments on commit 2184096

Please sign in to comment.