diff --git a/enzyme/Enzyme/GradientUtils.cpp b/enzyme/Enzyme/GradientUtils.cpp index 1e3c92384022a..e1100836bac5c 100644 --- a/enzyme/Enzyme/GradientUtils.cpp +++ b/enzyme/Enzyme/GradientUtils.cpp @@ -3542,7 +3542,8 @@ Constant *GradientUtils::GetOrCreateShadowFunction( std::pair>(&a, {})); DIFFE_TYPE typ; if (a.getType()->isFPOrFPVectorTy()) { - typ = DIFFE_TYPE::OUT_DIFF; + typ = mode == DerivativeMode::ForwardMode ? DIFFE_TYPE::DUP_ARG + : DIFFE_TYPE::OUT_DIFF; } else if (a.getType()->isIntegerTy() && cast(a.getType())->getBitWidth() < 16) { typ = DIFFE_TYPE::CONSTANT; @@ -3554,7 +3555,8 @@ Constant *GradientUtils::GetOrCreateShadowFunction( types.push_back(typ); } - DIFFE_TYPE retType = fn->getReturnType()->isFPOrFPVectorTy() + DIFFE_TYPE retType = fn->getReturnType()->isFPOrFPVectorTy() && + mode != DerivativeMode::ForwardMode ? DIFFE_TYPE::OUT_DIFF : DIFFE_TYPE::DUP_ARG; if (fn->getReturnType()->isVoidTy() || fn->getReturnType()->isEmptyTy() ||