diff --git a/enzyme/Enzyme/AdjointGenerator.h b/enzyme/Enzyme/AdjointGenerator.h index 146ddff2eef53..0504ff430248c 100644 --- a/enzyme/Enzyme/AdjointGenerator.h +++ b/enzyme/Enzyme/AdjointGenerator.h @@ -1140,26 +1140,51 @@ class AdjointGenerator if (Mode == DerivativeMode::ReverseModePrimal) return; - IRBuilder<> Builder2(EEI.getParent()); - getReverseBuilder(Builder2); + switch (Mode) { + case DerivativeMode::ForwardMode: { + IRBuilder<> Builder2(&EEI); + getForwardBuilder(Builder2); - Value *orig_vec = EEI.getVectorOperand(); + Value *orig_vec = EEI.getVectorOperand(); - if (!gutils->isConstantValue(orig_vec)) { - SmallVector sv; - sv.push_back(gutils->getNewFromOriginal(EEI.getIndexOperand())); + auto vec_diffe = gutils->isConstantValue(orig_vec) + ? ConstantVector::getNullValue(orig_vec->getType()) + : diffe(orig_vec, Builder2); + auto diffe = + Builder2.CreateExtractElement(vec_diffe, EEI.getIndexOperand()); - size_t size = 1; - if (EEI.getType()->isSized()) - size = (gutils->newFunc->getParent()->getDataLayout().getTypeSizeInBits( - EEI.getType()) + - 7) / - 8; - ((DiffeGradientUtils *)gutils) - ->addToDiffe(orig_vec, diffe(&EEI, Builder2), Builder2, - TR.addingType(size, &EEI), sv); + setDiffe(&EEI, diffe, Builder2); + return; + } + case DerivativeMode::ReverseModeGradient: + case DerivativeMode::ReverseModeCombined: { + IRBuilder<> Builder2(EEI.getParent()); + getReverseBuilder(Builder2); + + Value *orig_vec = EEI.getVectorOperand(); + + if (!gutils->isConstantValue(orig_vec)) { + SmallVector sv; + sv.push_back(gutils->getNewFromOriginal(EEI.getIndexOperand())); + + size_t size = 1; + if (EEI.getType()->isSized()) + size = + (gutils->newFunc->getParent()->getDataLayout().getTypeSizeInBits( + EEI.getType()) + + 7) / + 8; + ((DiffeGradientUtils *)gutils) + ->addToDiffe(orig_vec, diffe(&EEI, Builder2), Builder2, + TR.addingType(size, &EEI), sv); + } + setDiffe(&EEI, Constant::getNullValue(EEI.getType()), Builder2); + return; + } + case DerivativeMode::ReverseModePrimal: { + return; + } } - setDiffe(&EEI, Constant::getNullValue(EEI.getType()), Builder2); } void visitInsertElementInst(llvm::InsertElementInst &IEI) {