Skip to content

Commit

Permalink
Use clad::zero_vector instead of 0 in the vectorized forward mode to …
Browse files Browse the repository at this point in the history
…ensure type safety.

This is necessary to ensure type safety. ``0`` represents a zero vector but is only properly treated as a zero vector when it gets added/subtracted with a ``clad::array``. Initializing a ``clad::array`` with ``0`` will create not a zero vector but an empty array (with zero elements). To fix this, we needed exceptions for all declarations. We would need to make new exceptions when passing parameters to pushforwards.
  • Loading branch information
PetroZarytskyi committed Nov 17, 2024
1 parent c8e867c commit 906f05a
Show file tree
Hide file tree
Showing 6 changed files with 59 additions and 77 deletions.
4 changes: 2 additions & 2 deletions include/clad/Differentiator/BaseForwardModeVisitor.h
Original file line number Diff line number Diff line change
Expand Up @@ -58,12 +58,12 @@ class BaseForwardModeVisitor
StmtDiff VisitCXXDefaultArgExpr(const clang::CXXDefaultArgExpr* DE);
StmtDiff VisitDeclRefExpr(const clang::DeclRefExpr* DRE);
StmtDiff VisitDeclStmt(const clang::DeclStmt* DS);
StmtDiff VisitFloatingLiteral(const clang::FloatingLiteral* FL);
virtual StmtDiff VisitFloatingLiteral(const clang::FloatingLiteral* FL);
StmtDiff VisitForStmt(const clang::ForStmt* FS);
StmtDiff VisitIfStmt(const clang::IfStmt* If);
StmtDiff VisitImplicitCastExpr(const clang::ImplicitCastExpr* ICE);
StmtDiff VisitInitListExpr(const clang::InitListExpr* ILE);
StmtDiff VisitIntegerLiteral(const clang::IntegerLiteral* IL);
virtual StmtDiff VisitIntegerLiteral(const clang::IntegerLiteral* IL);
StmtDiff VisitMemberExpr(const clang::MemberExpr* ME);
StmtDiff VisitParenExpr(const clang::ParenExpr* PE);
virtual StmtDiff VisitReturnStmt(const clang::ReturnStmt* RS);
Expand Down
2 changes: 2 additions & 0 deletions include/clad/Differentiator/VectorForwardModeVisitor.h
Original file line number Diff line number Diff line change
Expand Up @@ -69,6 +69,8 @@ class VectorForwardModeVisitor : public BaseForwardModeVisitor {
/// For example: for size = 4, the returned expression is: {0, 0, 0, 0}
clang::Expr* getZeroInitListExpr(size_t size, clang::QualType type);

StmtDiff VisitFloatingLiteral(const clang::FloatingLiteral* FL) override;
StmtDiff VisitIntegerLiteral(const clang::IntegerLiteral* IL) override;
StmtDiff
VisitArraySubscriptExpr(const clang::ArraySubscriptExpr* ASE) override;
StmtDiff VisitReturnStmt(const clang::ReturnStmt* RS) override;
Expand Down
48 changes: 19 additions & 29 deletions lib/Differentiator/VectorForwardModeVisitor.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -506,15 +506,10 @@ StmtDiff VectorForwardModeVisitor::VisitReturnStmt(const ReturnStmt* RS) {
Expr* derivedRetValE = retValDiff.getExpr_dx();
// If we are in vector mode, we need to wrap the return value in a
// vector.
SourceLocation loc{m_DiffReq->getLocation()};
llvm::SmallVector<Expr*, 2> args = {m_IndVarCountExpr, derivedRetValE};
QualType cladArrayType = GetCladArrayOfType(utils::GetValueType(retType));
TypeSourceInfo* TSI = m_Context.getTrivialTypeSourceInfo(cladArrayType, loc);
Expr* constructorCallExpr =
m_Sema.BuildCXXTypeConstructExpr(TSI, loc, args, loc, false).get();
auto dVectorParamDecl =
BuildVarDecl(cladArrayType, "_d_vector_return", constructorCallExpr,
false, nullptr, VarDecl::InitializationStyle::CallInit);
BuildVarDecl(cladArrayType, "_d_vector_return", derivedRetValE, false,
nullptr, VarDecl::InitializationStyle::CallInit);
// Create an array of statements to hold the return statement and the
// assignments to the derivatives of the parameters.
Stmts returnStmts;
Expand Down Expand Up @@ -591,34 +586,29 @@ VectorForwardModeVisitor::DifferentiateVarDecl(const VarDecl* VD) {
VarDecl* VDClone =
BuildVarDecl(VD->getType(), VD->getNameAsString(), initDiff.getExpr(),
VD->isDirectInit(), nullptr, VD->getInitStyle());
// Create an expression to initialize the derivative vector of the
// size of the number of parameters to be differentiated and initialize
// the derivative vector to the derivative expression.
//
// For example:
// clad::array<double> _d_vector_y(2, ...);
//
// This will also help in initializing the derivative vector in the
// case when initExprDx is not an array.
// So, for example, if we have:
// clad::array<double> _d_vector_y(2, 1);
// this means that we have to initialize the derivative vector of
// size 2 with all elements equal to 1.
SourceLocation loc{m_DiffReq->getLocation()};
llvm::SmallVector<Expr*, 2> args = {m_IndVarCountExpr, initDiff.getExpr_dx()};
QualType cladArrayType =
GetCladArrayOfType(utils::GetValueType(VD->getType()));
TypeSourceInfo* TSI = m_Context.getTrivialTypeSourceInfo(cladArrayType, loc);
Expr* constructorCallExpr =
m_Sema.BuildCXXTypeConstructExpr(TSI, loc, args, loc, false).get();

VarDecl* VDDerived =
BuildVarDecl(GetCladArrayOfType(utils::GetValueType(VD->getType())),
"_d_vector_" + VD->getNameAsString(), constructorCallExpr,
"_d_vector_" + VD->getNameAsString(), initDiff.getExpr_dx(),
false, nullptr, VarDecl::InitializationStyle::CallInit);

m_Variables.emplace(VDClone, BuildDeclRef(VDDerived));
return DeclDiff<VarDecl>(VDClone, VDDerived);
}

StmtDiff VectorForwardModeVisitor::VisitFloatingLiteral(
const clang::FloatingLiteral* FL) {
SourceLocation fakeLoc = utils::GetValidSLoc(m_Sema);
auto* zero_vec = BuildCallExprToCladFunction(
"zero_vector", {m_IndVarCountExpr}, {FL->getType()}, fakeLoc);
return StmtDiff(Clone(FL), zero_vec);
}

StmtDiff
VectorForwardModeVisitor::VisitIntegerLiteral(const clang::IntegerLiteral* IL) {
SourceLocation fakeLoc = utils::GetValidSLoc(m_Sema);
auto* zero_vec = BuildCallExprToCladFunction(
"zero_vector", {m_IndVarCountExpr}, {IL->getType()}, fakeLoc);
return StmtDiff(Clone(IL), zero_vec);
}

} // namespace clad
16 changes: 8 additions & 8 deletions test/Arrays/ArrayInputsVectorForwardMode.C
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ double multiply(const double *arr) {
// CHECK-NEXT: unsigned {{int|long|long long}} indepVarCount = _d_arr.size();
// CHECK-NEXT: clad::matrix<double> _d_vector_arr = clad::identity_matrix(_d_arr.size(), indepVarCount, {{0U|0UL|0ULL}});
// CHECK-NEXT: {
// CHECK-NEXT: clad::array<double> _d_vector_return(clad::array<double>(indepVarCount, (_d_vector_arr[0]) * arr[1] + arr[0] * (_d_vector_arr[1])));
// CHECK-NEXT: clad::array<double> _d_vector_return((_d_vector_arr[0]) * arr[1] + arr[0] * (_d_vector_arr[1]));
// CHECK-NEXT: _d_arr = _d_vector_return.slice({{0U|0UL|0ULL}}, _d_arr.size());
// CHECK-NEXT: return;
// CHECK-NEXT: }
Expand All @@ -25,7 +25,7 @@ double divide(double *arr) {
// CHECK-NEXT: unsigned {{int|long|long long}} indepVarCount = _d_arr.size();
// CHECK-NEXT: clad::matrix<double> _d_vector_arr = clad::identity_matrix(_d_arr.size(), indepVarCount, {{0U|0UL|0ULL}});
// CHECK-NEXT: {
// CHECK-NEXT: clad::array<double> _d_vector_return(clad::array<double>(indepVarCount, ((_d_vector_arr[0]) * arr[1] - arr[0] * (_d_vector_arr[1])) / (arr[1] * arr[1])));
// CHECK-NEXT: _d_vector_return(((_d_vector_arr[0]) * arr[1] - arr[0] * (_d_vector_arr[1])) / (arr[1] * arr[1]));
// CHECK-NEXT: _d_arr = _d_vector_return.slice({{0U|0UL|0ULL}}, _d_arr.size());
// CHECK-NEXT: return;
// CHECK-NEXT: }
Expand All @@ -43,17 +43,17 @@ double addArr(const double *arr, int n) {
// CHECK-NEXT: unsigned {{int|long|long long}} indepVarCount = _d_arr.size();
// CHECK-NEXT: clad::matrix<double> _d_vector_arr = clad::identity_matrix(_d_arr.size(), indepVarCount, {{0U|0UL|0ULL}});
// CHECK-NEXT: clad::array<int> _d_vector_n = clad::zero_vector(indepVarCount);
// CHECK-NEXT: clad::array<double> _d_vector_ret(clad::array<double>(indepVarCount, 0));
// CHECK-NEXT: clad::array<double> _d_vector_ret(clad::zero_vector(indepVarCount));
// CHECK-NEXT: double ret = 0;
// CHECK-NEXT: {
// CHECK-NEXT: clad::array<int> _d_vector_i(clad::array<int>(indepVarCount, 0));
// CHECK-NEXT: clad::array<int> _d_vector_i(clad::zero_vector(indepVarCount));
// CHECK-NEXT: for (int i = 0; i < n; i++) {
// CHECK-NEXT: _d_vector_ret += _d_vector_arr[i];
// CHECK-NEXT: ret += arr[i];
// CHECK-NEXT: }
// CHECK-NEXT: }
// CHECK-NEXT: {
// CHECK-NEXT: clad::array<double> _d_vector_return(clad::array<double>(indepVarCount, _d_vector_ret));
// CHECK-NEXT: clad::array<double> _d_vector_return(_d_vector_ret);
// CHECK-NEXT: _d_arr = _d_vector_return.slice({{0U|0UL|0ULL}}, _d_arr.size());
// CHECK-NEXT: return;
// CHECK-NEXT: }
Expand All @@ -77,10 +77,10 @@ double maskedSum(const double *arr, int n, int *signedMask, double alpha, double
// CHECK-NEXT: clad::array<int> _d_vector_n = clad::zero_vector(indepVarCount);
// CHECK-NEXT: clad::array<double> _d_vector_alpha = clad::one_hot_vector(indepVarCount, _d_arr.size());
// CHECK-NEXT: clad::array<double> _d_vector_beta = clad::one_hot_vector(indepVarCount, _d_arr.size() + {{1U|1UL|1ULL}});
// CHECK-NEXT: clad::array<double> _d_vector_ret(clad::array<double>(indepVarCount, 0));
// CHECK-NEXT: clad::array<double> _d_vector_ret(clad::zero_vector(indepVarCount));
// CHECK-NEXT: double ret = 0;
// CHECK-NEXT: {
// CHECK-NEXT: clad::array<int> _d_vector_i(clad::array<int>(indepVarCount, 0));
// CHECK-NEXT: clad::array<int> _d_vector_i(clad::zero_vector(indepVarCount));
// CHECK-NEXT: for (int i = 0; i < n; i++) {
// CHECK-NEXT: if (signedMask[i] > 0) {
// CHECK-NEXT: _d_vector_ret += _d_vector_alpha * arr[i] + alpha * (_d_vector_arr[i]);
Expand All @@ -92,7 +92,7 @@ double maskedSum(const double *arr, int n, int *signedMask, double alpha, double
// CHECK-NEXT: }
// CHECK-NEXT: }
// CHECK-NEXT: {
// CHECK-NEXT: clad::array<double> _d_vector_return(clad::array<double>(indepVarCount, _d_vector_ret));
// CHECK-NEXT: clad::array<double> _d_vector_return(_d_vector_ret);
// CHECK-NEXT: _d_arr = _d_vector_return.slice({{0U|0UL|0ULL}}, _d_arr.size());
// CHECK-NEXT: *_d_alpha = _d_vector_return[_d_arr.size()];
// CHECK-NEXT: *_d_beta = _d_vector_return[_d_arr.size() + {{1U|1UL|1ULL}}];
Expand Down
Loading

0 comments on commit 906f05a

Please sign in to comment.