@@ -1260,33 +1260,56 @@ class AdjointGenerator
1260
1260
if (EVI.getType ()->isPointerTy ())
1261
1261
return ;
1262
1262
1263
- if (Mode == DerivativeMode::ReverseModePrimal)
1263
+ switch (Mode) {
1264
+ case DerivativeMode::ForwardMode: {
1265
+ IRBuilder<> Builder2 (&EVI);
1266
+ getForwardBuilder (Builder2);
1267
+
1268
+ Value *orig_aggregate = EVI.getAggregateOperand ();
1269
+
1270
+ Value *diffe_aggregate =
1271
+ gutils->isConstantValue (orig_aggregate)
1272
+ ? ConstantAggregate::getNullValue (orig_aggregate->getType ())
1273
+ : diffe (orig_aggregate, Builder2);
1274
+ Value *diffe =
1275
+ Builder2.CreateExtractValue (diffe_aggregate, EVI.getIndices ());
1276
+
1277
+ setDiffe (&EVI, diffe, Builder2);
1264
1278
return ;
1279
+ }
1280
+ case DerivativeMode::ReverseModeGradient:
1281
+ case DerivativeMode::ReverseModeCombined: {
1282
+ IRBuilder<> Builder2 (EVI.getParent ());
1283
+ getReverseBuilder (Builder2);
1265
1284
1266
- Value *orig_op0 = EVI.getOperand (0 );
1285
+ Value *orig_op0 = EVI.getOperand (0 );
1267
1286
1268
- IRBuilder<> Builder2 (EVI.getParent ());
1269
- getReverseBuilder (Builder2);
1287
+ auto prediff = diffe (&EVI, Builder2);
1270
1288
1271
- auto prediff = diffe (&EVI, Builder2);
1289
+ // todo const
1290
+ if (!gutils->isConstantValue (orig_op0)) {
1291
+ SmallVector<Value *, 4 > sv;
1292
+ for (auto i : EVI.getIndices ())
1293
+ sv.push_back (ConstantInt::get (Type::getInt32Ty (EVI.getContext ()), i));
1294
+ size_t size = 1 ;
1295
+ if (EVI.getType ()->isSized ())
1296
+ size =
1297
+ (gutils->newFunc ->getParent ()->getDataLayout ().getTypeSizeInBits (
1298
+ EVI.getType ()) +
1299
+ 7 ) /
1300
+ 8 ;
1301
+ ((DiffeGradientUtils *)gutils)
1302
+ ->addToDiffe (orig_op0, prediff, Builder2, TR.addingType (size, &EVI),
1303
+ sv);
1304
+ }
1272
1305
1273
- // todo const
1274
- if (!gutils->isConstantValue (orig_op0)) {
1275
- SmallVector<Value *, 4 > sv;
1276
- for (auto i : EVI.getIndices ())
1277
- sv.push_back (ConstantInt::get (Type::getInt32Ty (EVI.getContext ()), i));
1278
- size_t size = 1 ;
1279
- if (EVI.getType ()->isSized ())
1280
- size = (gutils->newFunc ->getParent ()->getDataLayout ().getTypeSizeInBits (
1281
- EVI.getType ()) +
1282
- 7 ) /
1283
- 8 ;
1284
- ((DiffeGradientUtils *)gutils)
1285
- ->addToDiffe (orig_op0, prediff, Builder2, TR.addingType (size, &EVI),
1286
- sv);
1306
+ setDiffe (&EVI, Constant::getNullValue (EVI.getType ()), Builder2);
1307
+ return ;
1308
+ }
1309
+ case DerivativeMode::ReverseModePrimal: {
1310
+ return ;
1311
+ }
1287
1312
}
1288
-
1289
- setDiffe (&EVI, Constant::getNullValue (EVI.getType ()), Builder2);
1290
1313
}
1291
1314
1292
1315
void visitInsertValueInst (llvm::InsertValueInst &IVI) {
0 commit comments