Skip to content

Commit

Permalink
Implement visitor to check varied expression
Browse files Browse the repository at this point in the history
Max Andriychuk committed Nov 7, 2024

Verified

This commit was created on GitHub.com and signed with GitHub’s verified signature.
1 parent 90d17a0 commit b41e0eb
Showing 4 changed files with 71 additions and 15 deletions.
7 changes: 2 additions & 5 deletions lib/Differentiator/ActivityAnalyzer.cpp
Original file line number Diff line number Diff line change
@@ -122,15 +122,11 @@ bool VariedAnalyzer::VisitCallExpr(CallExpr* CE) {
bool noHiddenParam = (CE->getNumArgs() == FD->getNumParams());
if (noHiddenParam) {
MutableArrayRef<ParmVarDecl*> FDparam = FD->parameters();
m_Varied = true;
m_Marking = true;
for (std::size_t i = 0, e = CE->getNumArgs(); i != e; ++i) {
clang::Expr* par = CE->getArg(i);
TraverseStmt(par);
m_VariedDecls.insert(FDparam[i]);
}
m_Varied = false;
m_Marking = false;
}
return true;
}
@@ -141,7 +137,8 @@ bool VariedAnalyzer::VisitDeclStmt(DeclStmt* DS) {
m_Varied = false;
TraverseStmt(init);
m_Marking = true;
if (m_Varied)
QualType VDTy = cast<VarDecl>(D)->getType();
if (m_Varied || utils::isArrayOrPointerType(VDTy))
copyVarToCurBlock(cast<VarDecl>(D));
m_Marking = false;
}
14 changes: 5 additions & 9 deletions lib/Differentiator/DiffPlanner.cpp
Original file line number Diff line number Diff line change
@@ -620,20 +620,16 @@ namespace clad {
if (!EnableVariedAnalysis)
return true;

if (VD->getType()->isPointerType() || isa<ArrayType>(VD->getType()))
return true;

if (!m_ActivityRunInfo.HasAnalysisRun) {
ArrayRef<ParmVarDecl*> FDparam = Function->parameters();
std::vector<ParmVarDecl*> derivedParam;

for(auto* parameter: FDparam){
for (auto* parameter : FDparam) {
QualType parType = parameter->getType();
if(parType->isPointerType()){
if(!parType->getPointeeType().isConstQualified())
derivedParam.push_back(parameter);
}else if(!parType.isConstQualified())
derivedParam.push_back(parameter);
while (parType->isPointerType())
parType = parType->getPointeeType();
if (!parType.isConstQualified())
derivedParam.push_back(parameter);
}

std::copy(derivedParam.begin(), derivedParam.end(),
19 changes: 18 additions & 1 deletion lib/Differentiator/ReverseModeVisitor.cpp
Original file line number Diff line number Diff line change
@@ -1800,7 +1800,24 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context,
// subexpression.
if (const auto* MTE = dyn_cast<MaterializeTemporaryExpr>(arg))
arg = clad_compat::GetSubExpr(MTE)->IgnoreImpCasts();
if (!arg->isEvaluatable(m_Context)) {
class VariedChecker : public RecursiveASTVisitor<VariedChecker> {
const DiffRequest& Request;

public:
VariedChecker(const DiffRequest& DR) : Request(DR) {}
bool isVariedE(const clang::Expr* E) {
// NOLINTNEXTLINE(cppcoreguidelines-pro-type-const-cast)
return !TraverseStmt(const_cast<clang::Expr*>(E));
}
bool VisitDeclRefExpr(const clang::DeclRefExpr* DRE) {
if (!isa<VarDecl>(DRE->getDecl()))
return true;
if (Request.shouldHaveAdjoint(cast<VarDecl>(DRE->getDecl())))
return false;
return true;
}
} analyzer(m_DiffReq);
if (analyzer.isVariedE(arg)) {
allArgsAreConstantLiterals = false;
break;
}
46 changes: 46 additions & 0 deletions test/Analyses/ActivityReverse.cpp
Original file line number Diff line number Diff line change
@@ -264,6 +264,46 @@ double f8(double x){
// CHECK-NEXT: }
// CHECK-NEXT: }

double fn9(double x, double const *obs)
{
double res = 0.0;
for (int loopIdx0 = 0; loopIdx0 < 2; loopIdx0++) {
res += std::lgamma(obs[2 + loopIdx0] + 1) + x;
}
return res;
}

// CHECK: void fn9_grad(double x, const double *obs, double *_d_x, double *_d_obs) {
// CHECK-NEXT: int loopIdx0 = 0;
// CHECK-NEXT: clad::tape<double> _t1 = {};
// CHECK-NEXT: double _d_res = 0.;
// CHECK-NEXT: double res = 0.;
// CHECK-NEXT: unsigned {{int|long}} _t0 = {{0U|0UL}};
// CHECK-NEXT: for (loopIdx0 = 0; ; loopIdx0++) {
// CHECK-NEXT: {
// CHECK-NEXT: if (!(loopIdx0 < 2))
// CHECK-NEXT: break;
// CHECK-NEXT: }
// CHECK-NEXT: _t0++;
// CHECK-NEXT: clad::push(_t1, res);
// CHECK-NEXT: res += std::lgamma(obs[2 + loopIdx0] + 1) + x;
// CHECK-NEXT: }
// CHECK-NEXT: _d_res += 1;
// CHECK-NEXT: for (;; _t0--) {
// CHECK-NEXT: {
// CHECK-NEXT: if (!_t0)
// CHECK-NEXT: break;
// CHECK-NEXT: }
// CHECK-NEXT: loopIdx0--;
// CHECK-NEXT: {
// CHECK-NEXT: res = clad::pop(_t1);
// CHECK-NEXT: double _r_d0 = _d_res;
// CHECK-NEXT: *_d_x += _r_d0;
// CHECK-NEXT: }
// CHECK-NEXT: }
// CHECK-NEXT: }


#define TEST(F, x) { \
result[0] = 0; \
auto F##grad = clad::gradient<clad::opts::enable_va>(F);\
@@ -272,7 +312,10 @@ double f8(double x){
}

int main(){
double arr[] = {1,2,3,4,5};
double darr[] = {0,0,0,0,0};
double result[3] = {};
double dx;
TEST(f1, 3);// CHECK-EXEC: {6.00}
TEST(f2, 3);// CHECK-EXEC: {6.00}
TEST(f3, 3);// CHECK-EXEC: {0.00}
@@ -281,6 +324,9 @@ int main(){
TEST(f6, 3);// CHECK-EXEC: {0.00}
TEST(f7, 3);// CHECK-EXEC: {1.00}
TEST(f8, 3);// CHECK-EXEC: {1.00}
auto grad = clad::gradient<clad::opts::enable_va>(fn9);
grad.execute(3, arr, &dx, darr);
printf("%.2f\n", dx);// CHECK-EXEC: 2.00
}

// CHECK: void f4_1_pullback(double v, double u, double _d_y, double *_d_v, double *_d_u) {

0 comments on commit b41e0eb

Please sign in to comment.