@@ -1140,26 +1140,51 @@ class AdjointGenerator
1140
1140
if (Mode == DerivativeMode::ReverseModePrimal)
1141
1141
return ;
1142
1142
1143
- IRBuilder<> Builder2 (EEI.getParent ());
1144
- getReverseBuilder (Builder2);
1143
+ switch (Mode) {
1144
+ case DerivativeMode::ForwardMode: {
1145
+ IRBuilder<> Builder2 (&EEI);
1146
+ getForwardBuilder (Builder2);
1145
1147
1146
- Value *orig_vec = EEI.getVectorOperand ();
1148
+ Value *orig_vec = EEI.getVectorOperand ();
1147
1149
1148
- if (!gutils->isConstantValue (orig_vec)) {
1149
- SmallVector<Value *, 4 > sv;
1150
- sv.push_back (gutils->getNewFromOriginal (EEI.getIndexOperand ()));
1150
+ auto vec_diffe = gutils->isConstantValue (orig_vec)
1151
+ ? ConstantVector::getNullValue (orig_vec->getType ())
1152
+ : diffe (orig_vec, Builder2);
1153
+ auto diffe =
1154
+ Builder2.CreateExtractElement (vec_diffe, EEI.getIndexOperand ());
1151
1155
1152
- size_t size = 1 ;
1153
- if (EEI.getType ()->isSized ())
1154
- size = (gutils->newFunc ->getParent ()->getDataLayout ().getTypeSizeInBits (
1155
- EEI.getType ()) +
1156
- 7 ) /
1157
- 8 ;
1158
- ((DiffeGradientUtils *)gutils)
1159
- ->addToDiffe (orig_vec, diffe (&EEI, Builder2), Builder2,
1160
- TR.addingType (size, &EEI), sv);
1156
+ setDiffe (&EEI, diffe, Builder2);
1157
+ return ;
1158
+ }
1159
+ case DerivativeMode::ReverseModeGradient:
1160
+ case DerivativeMode::ReverseModeCombined: {
1161
+ IRBuilder<> Builder2 (EEI.getParent ());
1162
+ getReverseBuilder (Builder2);
1163
+
1164
+ Value *orig_vec = EEI.getVectorOperand ();
1165
+
1166
+ if (!gutils->isConstantValue (orig_vec)) {
1167
+ SmallVector<Value *, 4 > sv;
1168
+ sv.push_back (gutils->getNewFromOriginal (EEI.getIndexOperand ()));
1169
+
1170
+ size_t size = 1 ;
1171
+ if (EEI.getType ()->isSized ())
1172
+ size =
1173
+ (gutils->newFunc ->getParent ()->getDataLayout ().getTypeSizeInBits (
1174
+ EEI.getType ()) +
1175
+ 7 ) /
1176
+ 8 ;
1177
+ ((DiffeGradientUtils *)gutils)
1178
+ ->addToDiffe (orig_vec, diffe (&EEI, Builder2), Builder2,
1179
+ TR.addingType (size, &EEI), sv);
1180
+ }
1181
+ setDiffe (&EEI, Constant::getNullValue (EEI.getType ()), Builder2);
1182
+ return ;
1183
+ }
1184
+ case DerivativeMode::ReverseModePrimal: {
1185
+ return ;
1186
+ }
1161
1187
}
1162
- setDiffe (&EEI, Constant::getNullValue (EEI.getType ()), Builder2);
1163
1188
}
1164
1189
1165
1190
void visitInsertElementInst (llvm::InsertElementInst &IEI) {
0 commit comments