Skip to content

Commit ef3a0ac

Browse files
authored
ForwardMode: extractvalue inst (rust-lang#354)
1 parent ffe46d5 commit ef3a0ac

File tree

1 file changed

+44
-21
lines changed

1 file changed

+44
-21
lines changed

enzyme/Enzyme/AdjointGenerator.h

+44-21
Original file line numberDiff line numberDiff line change
@@ -1260,33 +1260,56 @@ class AdjointGenerator
12601260
if (EVI.getType()->isPointerTy())
12611261
return;
12621262

1263-
if (Mode == DerivativeMode::ReverseModePrimal)
1263+
switch (Mode) {
1264+
case DerivativeMode::ForwardMode: {
1265+
IRBuilder<> Builder2(&EVI);
1266+
getForwardBuilder(Builder2);
1267+
1268+
Value *orig_aggregate = EVI.getAggregateOperand();
1269+
1270+
Value *diffe_aggregate =
1271+
gutils->isConstantValue(orig_aggregate)
1272+
? ConstantAggregate::getNullValue(orig_aggregate->getType())
1273+
: diffe(orig_aggregate, Builder2);
1274+
Value *diffe =
1275+
Builder2.CreateExtractValue(diffe_aggregate, EVI.getIndices());
1276+
1277+
setDiffe(&EVI, diffe, Builder2);
12641278
return;
1279+
}
1280+
case DerivativeMode::ReverseModeGradient:
1281+
case DerivativeMode::ReverseModeCombined: {
1282+
IRBuilder<> Builder2(EVI.getParent());
1283+
getReverseBuilder(Builder2);
12651284

1266-
Value *orig_op0 = EVI.getOperand(0);
1285+
Value *orig_op0 = EVI.getOperand(0);
12671286

1268-
IRBuilder<> Builder2(EVI.getParent());
1269-
getReverseBuilder(Builder2);
1287+
auto prediff = diffe(&EVI, Builder2);
12701288

1271-
auto prediff = diffe(&EVI, Builder2);
1289+
// todo const
1290+
if (!gutils->isConstantValue(orig_op0)) {
1291+
SmallVector<Value *, 4> sv;
1292+
for (auto i : EVI.getIndices())
1293+
sv.push_back(ConstantInt::get(Type::getInt32Ty(EVI.getContext()), i));
1294+
size_t size = 1;
1295+
if (EVI.getType()->isSized())
1296+
size =
1297+
(gutils->newFunc->getParent()->getDataLayout().getTypeSizeInBits(
1298+
EVI.getType()) +
1299+
7) /
1300+
8;
1301+
((DiffeGradientUtils *)gutils)
1302+
->addToDiffe(orig_op0, prediff, Builder2, TR.addingType(size, &EVI),
1303+
sv);
1304+
}
12721305

1273-
// todo const
1274-
if (!gutils->isConstantValue(orig_op0)) {
1275-
SmallVector<Value *, 4> sv;
1276-
for (auto i : EVI.getIndices())
1277-
sv.push_back(ConstantInt::get(Type::getInt32Ty(EVI.getContext()), i));
1278-
size_t size = 1;
1279-
if (EVI.getType()->isSized())
1280-
size = (gutils->newFunc->getParent()->getDataLayout().getTypeSizeInBits(
1281-
EVI.getType()) +
1282-
7) /
1283-
8;
1284-
((DiffeGradientUtils *)gutils)
1285-
->addToDiffe(orig_op0, prediff, Builder2, TR.addingType(size, &EVI),
1286-
sv);
1306+
setDiffe(&EVI, Constant::getNullValue(EVI.getType()), Builder2);
1307+
return;
1308+
}
1309+
case DerivativeMode::ReverseModePrimal: {
1310+
return;
1311+
}
12871312
}
1288-
1289-
setDiffe(&EVI, Constant::getNullValue(EVI.getType()), Builder2);
12901313
}
12911314

12921315
void visitInsertValueInst(llvm::InsertValueInst &IVI) {

0 commit comments

Comments
 (0)