Skip to content

Commit

Permalink
Soft float (rust-lang#589)
Browse files Browse the repository at this point in the history
  • Loading branch information
tgymnich authored May 24, 2022
1 parent c990136 commit 034a18f
Show file tree
Hide file tree
Showing 13 changed files with 625 additions and 1 deletion.
6 changes: 5 additions & 1 deletion enzyme/Enzyme/ActivityAnalysis.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -242,7 +242,11 @@ const std::set<std::string> KnownInactiveFunctions = {
"_msize",
"ftnio_fmt_write64",
"f90_strcmp_klen",
"__swift_instantiateConcreteTypeFromMangledName"};
"__swift_instantiateConcreteTypeFromMangledName",
"logb",
"logbf",
"logbl",
};

/// Is the use of value val as an argument of call CI known to be inactive
/// This tool can only be used when in DOWN mode
Expand Down
317 changes: 317 additions & 0 deletions enzyme/Enzyme/AdjointGenerator.h
Original file line number Diff line number Diff line change
Expand Up @@ -9239,6 +9239,323 @@ class AdjointGenerator
return;
}

if (funcName == "__mulsc3" || funcName == "__muldc3" ||
funcName == "__multc3" || funcName == "__mulxc3") {
if (gutils->knownRecomputeHeuristic.find(orig) !=
gutils->knownRecomputeHeuristic.end()) {
if (!gutils->knownRecomputeHeuristic[orig]) {
gutils->cacheForReverse(BuilderZ, newCall,
getIndex(orig, CacheType::Self));
}
}

eraseIfUnused(*orig);
if (gutils->isConstantInstruction(orig))
return;

Value *orig_op0 = call.getOperand(0);
Value *orig_op1 = call.getOperand(1);
Value *orig_op2 = call.getOperand(2);
Value *orig_op3 = call.getOperand(3);

bool constantval0 = gutils->isConstantValue(orig_op0);
bool constantval1 = gutils->isConstantValue(orig_op1);
bool constantval2 = gutils->isConstantValue(orig_op2);
bool constantval3 = gutils->isConstantValue(orig_op3);

Value *prim[4] = {gutils->getNewFromOriginal(orig_op0),
gutils->getNewFromOriginal(orig_op1),
gutils->getNewFromOriginal(orig_op2),
gutils->getNewFromOriginal(orig_op3)};

auto mul = gutils->oldFunc->getParent()->getOrInsertFunction(
funcName, called->getFunctionType(), called->getAttributes());

switch (Mode) {
case DerivativeMode::ForwardMode:
case DerivativeMode::ForwardModeSplit: {
IRBuilder<> Builder2(&call);
getForwardBuilder(Builder2);

Value *diff[4] = {
constantval0 ? Constant::getNullValue(orig_op0->getType())
: diffe(orig_op0, Builder2),
constantval1 ? Constant::getNullValue(orig_op1->getType())
: diffe(orig_op1, Builder2),
constantval2 ? Constant::getNullValue(orig_op2->getType())
: diffe(orig_op2, Builder2),
constantval3 ? Constant::getNullValue(orig_op3->getType())
: diffe(orig_op3, Builder2)};

auto cal1 =
Builder2.CreateCall(mul, {diff[0], diff[1], prim[2], prim[3]});
auto cal2 =
Builder2.CreateCall(mul, {prim[0], prim[1], diff[2], diff[3]});

Value *resReal =
Builder2.CreateFAdd(Builder2.CreateExtractValue(cal1, {0}),
Builder2.CreateExtractValue(cal2, {0}));
Value *resImag =
Builder2.CreateFAdd(Builder2.CreateExtractValue(cal1, {1}),
Builder2.CreateExtractValue(cal2, {1}));

Value *res = Builder2.CreateInsertValue(
UndefValue::get(call.getType()), resReal, {0});
res = Builder2.CreateInsertValue(res, resImag, {1});

setDiffe(&call, res, Builder2);
return;
}
case DerivativeMode::ReverseModeGradient:
case DerivativeMode::ReverseModeCombined: {
IRBuilder<> Builder2(call.getParent());
getReverseBuilder(Builder2);

Value *idiff = diffe(&call, Builder2);
Value *idiffReal = Builder2.CreateExtractValue(idiff, {0});
Value *idiffImag = Builder2.CreateExtractValue(idiff, {1});

Value *diff0 = nullptr;
Value *diff1 = nullptr;

if (!constantval0 || !constantval1)
diff0 = Builder2.CreateCall(mul, {idiffReal, idiffImag,
lookup(prim[2], Builder2),
lookup(prim[3], Builder2)});

if (!constantval2 || !constantval3)
diff1 = Builder2.CreateCall(mul, {lookup(prim[0], Builder2),
lookup(prim[1], Builder2),
idiffReal, idiffImag});

if (diff0 || diff1)
setDiffe(&call, Constant::getNullValue(call.getType()), Builder2);

if (diff0) {
addToDiffe(orig_op0, Builder2.CreateExtractValue(diff0, {0}),
Builder2, orig_op0->getType());
addToDiffe(orig_op1, Builder2.CreateExtractValue(diff0, {1}),
Builder2, orig_op1->getType());
}

if (diff1) {
addToDiffe(orig_op2, Builder2.CreateExtractValue(diff1, {0}),
Builder2, orig_op2->getType());
addToDiffe(orig_op3, Builder2.CreateExtractValue(diff1, {1}),
Builder2, orig_op3->getType());
}

return;
}
case DerivativeMode::ReverseModePrimal:
return;
}
}

if (funcName == "__divsc3" || funcName == "__divdc3" ||
funcName == "__divtc3" || funcName == "__divxc3") {
if (gutils->knownRecomputeHeuristic.find(orig) !=
gutils->knownRecomputeHeuristic.end()) {
if (!gutils->knownRecomputeHeuristic[orig]) {
gutils->cacheForReverse(BuilderZ, newCall,
getIndex(orig, CacheType::Self));
}
}

if (gutils->isConstantInstruction(orig))
return;

StringMap<StringRef> map = {
{"__divsc3", "__mulsc3"},
{"__divdc3", "__muldc3"},
{"__divtc3", "__multc3"},
{"__divxc3", "__mulxc3"},
};

auto mul = gutils->oldFunc->getParent()->getOrInsertFunction(
map[funcName], called->getFunctionType(), called->getAttributes());

auto div = gutils->oldFunc->getParent()->getOrInsertFunction(
funcName, called->getFunctionType(), called->getAttributes());

Value *orig_op0 = call.getOperand(0);
Value *orig_op1 = call.getOperand(1);
Value *orig_op2 = call.getOperand(2);
Value *orig_op3 = call.getOperand(3);

bool constantval0 = gutils->isConstantValue(orig_op0);
bool constantval1 = gutils->isConstantValue(orig_op1);
bool constantval2 = gutils->isConstantValue(orig_op2);
bool constantval3 = gutils->isConstantValue(orig_op3);

Value *prim[4] = {gutils->getNewFromOriginal(orig_op0),
gutils->getNewFromOriginal(orig_op1),
gutils->getNewFromOriginal(orig_op2),
gutils->getNewFromOriginal(orig_op3)};

switch (Mode) {
case DerivativeMode::ForwardMode:
case DerivativeMode::ForwardModeSplit: {
IRBuilder<> Builder2(&call);
getForwardBuilder(Builder2);

Value *diff[4] = {
constantval0 ? Constant::getNullValue(orig_op0->getType())
: diffe(orig_op0, Builder2),
constantval1 ? Constant::getNullValue(orig_op1->getType())
: diffe(orig_op1, Builder2),
constantval2 ? Constant::getNullValue(orig_op2->getType())
: diffe(orig_op2, Builder2),
constantval3 ? Constant::getNullValue(orig_op3->getType())
: diffe(orig_op3, Builder2)};

auto mul1 =
Builder2.CreateCall(mul, {diff[0], diff[1], prim[2], prim[3]});
auto mul2 =
Builder2.CreateCall(mul, {prim[0], prim[1], diff[2], diff[3]});
auto sq1 =
Builder2.CreateCall(mul, {prim[2], prim[3], prim[2], prim[3]});

Value *subReal =
Builder2.CreateFSub(Builder2.CreateExtractValue(mul1, {0}),
Builder2.CreateExtractValue(mul2, {0}));
Value *subImag =
Builder2.CreateFSub(Builder2.CreateExtractValue(mul1, {1}),
Builder2.CreateExtractValue(mul2, {1}));

auto div1 = Builder2.CreateCall(
div, {subReal, subImag, Builder2.CreateExtractValue(sq1, {0}),
Builder2.CreateExtractValue(sq1, {1})});

setDiffe(&call, div1, Builder2);

eraseIfUnused(*orig);

return;
}
case DerivativeMode::ReverseModeGradient:
case DerivativeMode::ReverseModeCombined: {
IRBuilder<> Builder2(call.getParent());
getReverseBuilder(Builder2);

Value *idiff = diffe(&call, Builder2);
Value *idiffReal = Builder2.CreateExtractValue(idiff, {0});
Value *idiffImag = Builder2.CreateExtractValue(idiff, {1});

Value *diff0 = nullptr;
Value *diff1 = nullptr;

if (!constantval0 || !constantval1)
diff0 = Builder2.CreateCall(div, {idiffReal, idiffImag,
lookup(prim[2], Builder2),
lookup(prim[3], Builder2)});

if (!constantval2 || !constantval3) {
auto fdiv = Builder2.CreateCall(div, {idiffReal, idiffImag,
lookup(prim[1], Builder2),
lookup(prim[2], Builder2)});

Value *newcall = gutils->getNewFromOriginal(&call);

diff1 = Builder2.CreateCall(
mul,
{Builder2.CreateFNeg(Builder2.CreateExtractValue(newcall, {0})),
Builder2.CreateFNeg(Builder2.CreateExtractValue(newcall, {1})),
Builder2.CreateExtractValue(fdiv, {0}),
Builder2.CreateExtractValue(fdiv, {1})});
}

if (diff0 || diff1)
setDiffe(&call, Constant::getNullValue(call.getType()), Builder2);

if (diff0) {
addToDiffe(orig_op0, Builder2.CreateExtractValue(diff0, {0}),
Builder2, orig_op0->getType());
addToDiffe(orig_op1, Builder2.CreateExtractValue(diff0, {1}),
Builder2, orig_op1->getType());
}

if (diff1) {
addToDiffe(orig_op2, Builder2.CreateExtractValue(diff1, {0}),
Builder2, orig_op2->getType());
addToDiffe(orig_op3, Builder2.CreateExtractValue(diff1, {1}),
Builder2, orig_op3->getType());
}

if (constantval2 && constantval3)
eraseIfUnused(*orig);

return;
}
case DerivativeMode::ReverseModePrimal:;
return;
}
}

if (funcName == "scalbn" || funcName == "scalbnf" ||
funcName == "scalbnl" || funcName == "scalbln" ||
funcName == "scalblnf" || funcName == "scalblnl") {
eraseIfUnused(*orig);

Value *orig_op0 = call.getOperand(0);
Value *orig_op1 = call.getOperand(1);

bool constantval0 = gutils->isConstantValue(orig_op0);

if (gutils->isConstantInstruction(orig) || constantval0)
return;

Value *op0 = gutils->getNewFromOriginal(orig_op0);
Value *op1 = gutils->getNewFromOriginal(orig_op1);

auto scal = gutils->oldFunc->getParent()->getOrInsertFunction(
funcName, called->getFunctionType(), called->getAttributes());

switch (Mode) {
case DerivativeMode::ForwardMode:
case DerivativeMode::ForwardModeSplit: {
IRBuilder<> Builder2(&call);
getForwardBuilder(Builder2);

Value *diff0 = diffe(orig_op0, Builder2);

auto cal1 = Builder2.CreateCall(scal, {op0, op1});
auto cal2 = Builder2.CreateCall(scal, {diff0, op1});

Value *diff = Builder2.CreateFMul(
cal1, ConstantFP::get(call.getType(), 0.3010299957));
diff = Builder2.CreateFAdd(diff, cal2);

setDiffe(&call, diff, Builder2);
return;
}
case DerivativeMode::ReverseModeGradient:
case DerivativeMode::ReverseModeCombined: {
IRBuilder<> Builder2(call.getParent());
getReverseBuilder(Builder2);

Value *idiff = diffe(&call, Builder2);

if (idiff && !constantval0) {
op1 = lookup(op1, Builder2);

auto cal1 = Builder2.CreateCall(scal, {op0, op1});
auto cal2 = Builder2.CreateCall(scal, {idiff, op1});

Value *diff = Builder2.CreateFMul(
cal1, ConstantFP::get(call.getType(), 0.3010299957));
diff = Builder2.CreateFAdd(diff, cal2);

addToDiffe(orig_op0, diff, Builder2, call.getType());
}

return;
}
case DerivativeMode::ReverseModePrimal:;
return;
}
}

if (called) {
if (funcName == "erf" || funcName == "erfi" || funcName == "erfc" ||
funcName == "Faddeeva_erf" || funcName == "Faddeeva_erfi" ||
Expand Down
17 changes: 17 additions & 0 deletions enzyme/Enzyme/TypeAnalysis/TypeAnalysis.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -92,11 +92,23 @@ const std::map<std::string, llvm::Intrinsic::ID> LIBM_FUNCTIONS = {
{"log1p", Intrinsic::not_intrinsic},
{"log2", Intrinsic::log2},
{"logb", Intrinsic::not_intrinsic},
{"logbf", Intrinsic::not_intrinsic},
{"logbl", Intrinsic::not_intrinsic},
{"pow", Intrinsic::pow},
{"sqrt", Intrinsic::sqrt},
{"cbrt", Intrinsic::not_intrinsic},
{"hypot", Intrinsic::not_intrinsic},

{"__mulsc3", Intrinsic::not_intrinsic},
{"__muldc3", Intrinsic::not_intrinsic},
{"__multc3", Intrinsic::not_intrinsic},
{"__mulxc3", Intrinsic::not_intrinsic},

{"__divsc3", Intrinsic::not_intrinsic},
{"__divdc3", Intrinsic::not_intrinsic},
{"__divtc3", Intrinsic::not_intrinsic},
{"__divxc3", Intrinsic::not_intrinsic},

{"Faddeeva_erf", Intrinsic::not_intrinsic},
{"Faddeeva_erfc", Intrinsic::not_intrinsic},
{"Faddeeva_erfcx", Intrinsic::not_intrinsic},
Expand Down Expand Up @@ -139,6 +151,11 @@ const std::map<std::string, llvm::Intrinsic::ID> LIBM_FUNCTIONS = {
{"fma", Intrinsic::fma},
{"ilogb", Intrinsic::not_intrinsic},
{"scalbn", Intrinsic::not_intrinsic},
{"scalbnf", Intrinsic::not_intrinsic},
{"scalbnl", Intrinsic::not_intrinsic},
{"scalbln", Intrinsic::not_intrinsic},
{"scalblnf", Intrinsic::not_intrinsic},
{"scalblnl", Intrinsic::not_intrinsic},
{"powi", Intrinsic::powi},
{"cabs", Intrinsic::not_intrinsic},
{"ldexp", Intrinsic::not_intrinsic},
Expand Down
Loading

0 comments on commit 034a18f

Please sign in to comment.