Skip to content

Commit c98cc98

Browse files
authored
ForwardMode: extractelement inst (rust-lang#353)
1 parent a2f7718 commit c98cc98

File tree

1 file changed

+41
-16
lines changed

1 file changed

+41
-16
lines changed

enzyme/Enzyme/AdjointGenerator.h

Lines changed: 41 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -1140,26 +1140,51 @@ class AdjointGenerator
11401140
if (Mode == DerivativeMode::ReverseModePrimal)
11411141
return;
11421142

1143-
IRBuilder<> Builder2(EEI.getParent());
1144-
getReverseBuilder(Builder2);
1143+
switch (Mode) {
1144+
case DerivativeMode::ForwardMode: {
1145+
IRBuilder<> Builder2(&EEI);
1146+
getForwardBuilder(Builder2);
11451147

1146-
Value *orig_vec = EEI.getVectorOperand();
1148+
Value *orig_vec = EEI.getVectorOperand();
11471149

1148-
if (!gutils->isConstantValue(orig_vec)) {
1149-
SmallVector<Value *, 4> sv;
1150-
sv.push_back(gutils->getNewFromOriginal(EEI.getIndexOperand()));
1150+
auto vec_diffe = gutils->isConstantValue(orig_vec)
1151+
? ConstantVector::getNullValue(orig_vec->getType())
1152+
: diffe(orig_vec, Builder2);
1153+
auto diffe =
1154+
Builder2.CreateExtractElement(vec_diffe, EEI.getIndexOperand());
11511155

1152-
size_t size = 1;
1153-
if (EEI.getType()->isSized())
1154-
size = (gutils->newFunc->getParent()->getDataLayout().getTypeSizeInBits(
1155-
EEI.getType()) +
1156-
7) /
1157-
8;
1158-
((DiffeGradientUtils *)gutils)
1159-
->addToDiffe(orig_vec, diffe(&EEI, Builder2), Builder2,
1160-
TR.addingType(size, &EEI), sv);
1156+
setDiffe(&EEI, diffe, Builder2);
1157+
return;
1158+
}
1159+
case DerivativeMode::ReverseModeGradient:
1160+
case DerivativeMode::ReverseModeCombined: {
1161+
IRBuilder<> Builder2(EEI.getParent());
1162+
getReverseBuilder(Builder2);
1163+
1164+
Value *orig_vec = EEI.getVectorOperand();
1165+
1166+
if (!gutils->isConstantValue(orig_vec)) {
1167+
SmallVector<Value *, 4> sv;
1168+
sv.push_back(gutils->getNewFromOriginal(EEI.getIndexOperand()));
1169+
1170+
size_t size = 1;
1171+
if (EEI.getType()->isSized())
1172+
size =
1173+
(gutils->newFunc->getParent()->getDataLayout().getTypeSizeInBits(
1174+
EEI.getType()) +
1175+
7) /
1176+
8;
1177+
((DiffeGradientUtils *)gutils)
1178+
->addToDiffe(orig_vec, diffe(&EEI, Builder2), Builder2,
1179+
TR.addingType(size, &EEI), sv);
1180+
}
1181+
setDiffe(&EEI, Constant::getNullValue(EEI.getType()), Builder2);
1182+
return;
1183+
}
1184+
case DerivativeMode::ReverseModePrimal: {
1185+
return;
1186+
}
11611187
}
1162-
setDiffe(&EEI, Constant::getNullValue(EEI.getType()), Builder2);
11631188
}
11641189

11651190
void visitInsertElementInst(llvm::InsertElementInst &IEI) {

0 commit comments

Comments
 (0)