Skip to content

Commit

Permalink
ForwardMode: extractelement inst (rust-lang#353)
Browse files Browse the repository at this point in the history
  • Loading branch information
tgymnich authored Oct 14, 2021
1 parent a2f7718 commit c98cc98
Showing 1 changed file with 41 additions and 16 deletions.
57 changes: 41 additions & 16 deletions enzyme/Enzyme/AdjointGenerator.h
Original file line number Diff line number Diff line change
Expand Up @@ -1140,26 +1140,51 @@ class AdjointGenerator
if (Mode == DerivativeMode::ReverseModePrimal)
return;

IRBuilder<> Builder2(EEI.getParent());
getReverseBuilder(Builder2);
switch (Mode) {
case DerivativeMode::ForwardMode: {
IRBuilder<> Builder2(&EEI);
getForwardBuilder(Builder2);

Value *orig_vec = EEI.getVectorOperand();
Value *orig_vec = EEI.getVectorOperand();

if (!gutils->isConstantValue(orig_vec)) {
SmallVector<Value *, 4> sv;
sv.push_back(gutils->getNewFromOriginal(EEI.getIndexOperand()));
auto vec_diffe = gutils->isConstantValue(orig_vec)
? ConstantVector::getNullValue(orig_vec->getType())
: diffe(orig_vec, Builder2);
auto diffe =
Builder2.CreateExtractElement(vec_diffe, EEI.getIndexOperand());

size_t size = 1;
if (EEI.getType()->isSized())
size = (gutils->newFunc->getParent()->getDataLayout().getTypeSizeInBits(
EEI.getType()) +
7) /
8;
((DiffeGradientUtils *)gutils)
->addToDiffe(orig_vec, diffe(&EEI, Builder2), Builder2,
TR.addingType(size, &EEI), sv);
setDiffe(&EEI, diffe, Builder2);
return;
}
case DerivativeMode::ReverseModeGradient:
case DerivativeMode::ReverseModeCombined: {
IRBuilder<> Builder2(EEI.getParent());
getReverseBuilder(Builder2);

Value *orig_vec = EEI.getVectorOperand();

if (!gutils->isConstantValue(orig_vec)) {
SmallVector<Value *, 4> sv;
sv.push_back(gutils->getNewFromOriginal(EEI.getIndexOperand()));

size_t size = 1;
if (EEI.getType()->isSized())
size =
(gutils->newFunc->getParent()->getDataLayout().getTypeSizeInBits(
EEI.getType()) +
7) /
8;
((DiffeGradientUtils *)gutils)
->addToDiffe(orig_vec, diffe(&EEI, Builder2), Builder2,
TR.addingType(size, &EEI), sv);
}
setDiffe(&EEI, Constant::getNullValue(EEI.getType()), Builder2);
return;
}
case DerivativeMode::ReverseModePrimal: {
return;
}
}
setDiffe(&EEI, Constant::getNullValue(EEI.getType()), Builder2);
}

void visitInsertElementInst(llvm::InsertElementInst &IEI) {
Expand Down

0 comments on commit c98cc98

Please sign in to comment.