Skip to content

Commit

Permalink
Don't mark nonvaried constant params
Browse files Browse the repository at this point in the history
  • Loading branch information
Max Andriychuk committed Nov 21, 2024
1 parent 7e0901f commit 63d431f
Show file tree
Hide file tree
Showing 5 changed files with 62 additions and 33 deletions.
7 changes: 7 additions & 0 deletions include/clad/Differentiator/DiffPlanner.h
Original file line number Diff line number Diff line change
Expand Up @@ -144,6 +144,13 @@ struct DiffRequest {

bool shouldBeRecorded(clang::Expr* E) const;
bool shouldHaveAdjoint(const clang::VarDecl* VD) const;

void setToBeRecorded(std::set<const clang::VarDecl*> init) {
this->m_ActivityRunInfo.ToBeRecorded = init;
}
std::set<const clang::VarDecl*> getToBeRecorded() const {
return this->m_ActivityRunInfo.ToBeRecorded;
}
};

using DiffInterval = std::vector<clang::SourceRange>;
Expand Down
42 changes: 27 additions & 15 deletions lib/Differentiator/ActivityAnalyzer.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -126,37 +126,54 @@ bool VariedAnalyzer::VisitCallExpr(CallExpr* CE) {
clang::Expr* par = CE->getArg(i);

QualType parType = FDparam[i]->getType();
while (parType->isPointerType())
parType = parType->getPointeeType();
if((parType->isReferenceType() || utils::isArrayOrPointerType(parType)) && !parType.isConstQualified()){
QualType innermostType = parType;
while (innermostType->isPointerType())
innermostType = innermostType->getPointeeType();

if ((parType->isReferenceType() ||
utils::isArrayOrPointerType(parType)) &&
!innermostType.isConstQualified()) {
m_Marking = true;
m_Varied = true;
}

TraverseStmt(par);
if ((parType->isReferenceType() ||
utils::isArrayOrPointerType(parType)) &&
!innermostType.isConstQualified()) {
m_Marking = false; //?
m_Varied = false;
}

m_Marking = false;
m_Varied = false;

if(!parType.isConstQualified())
if ((m_Varied || !innermostType.isConstQualified() ||
!parType->isPointerType())) {
m_VariedDecls.insert(FDparam[i]);
}
}
}
return true;
}

bool VariedAnalyzer::VisitDeclStmt(DeclStmt* DS) {
for (Decl* D : DS->decls()) {

QualType VDTy = cast<VarDecl>(D)->getType();
if(utils::isArrayOrPointerType(VDTy)){
QualType innermost = VDTy;
while (innermost->isPointerType())
innermost = innermost->getPointeeType();
if (VDTy->isPointerType() && !innermost.isConstQualified()) {
copyVarToCurBlock(cast<VarDecl>(D));
continue;
} else if (VDTy->isArrayType()) {
copyVarToCurBlock(cast<VarDecl>(D));
continue;
}

if (Expr* init = cast<VarDecl>(D)->getInit()) {
m_Varied = false;
TraverseStmt(init);
m_Marking = true;
if (m_Varied )
if (m_Varied)
copyVarToCurBlock(cast<VarDecl>(D));
m_Marking = false;
}
Expand All @@ -165,12 +182,7 @@ bool VariedAnalyzer::VisitDeclStmt(DeclStmt* DS) {
}

bool VariedAnalyzer::VisitUnaryOperator(UnaryOperator* UnOp) {
const auto opCode = UnOp->getOpcode();
Expr* E = UnOp->getSubExpr();
if (opCode == UO_AddrOf || opCode == UO_Deref) {
m_Varied = true;
m_Marking = true;
}
TraverseStmt(E);
m_Marking = false;
return true;
Expand All @@ -181,7 +193,7 @@ bool VariedAnalyzer::VisitDeclRefExpr(DeclRefExpr* DRE) {
if (!VD)
return true;

if (isVaried(VD))
if (isVaried(VD) || VD->getType()->isArrayType())
m_Varied = true;

if (m_Varied && m_Marking)
Expand Down
17 changes: 3 additions & 14 deletions lib/Differentiator/DiffPlanner.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -637,20 +637,9 @@ namespace clad {
return true;

if (!m_ActivityRunInfo.HasAnalysisRun) {
ArrayRef<ParmVarDecl*> FDparam = Function->parameters();
std::vector<ParmVarDecl*> derivedParam;

for (auto* parameter : FDparam) {
QualType parType = parameter->getType();
while (parType->isPointerType())
parType = parType->getPointeeType();
if (!parType.isConstQualified())
derivedParam.push_back(parameter);
}

std::copy(derivedParam.begin(), derivedParam.end(),
std::inserter(m_ActivityRunInfo.ToBeRecorded,
m_ActivityRunInfo.ToBeRecorded.end()));
if (Args)
for (const auto& dParam : DVI)
m_ActivityRunInfo.ToBeRecorded.insert(cast<VarDecl>(dParam.param));

VariedAnalyzer analyzer(Function->getASTContext(),
m_ActivityRunInfo.ToBeRecorded);
Expand Down
1 change: 1 addition & 0 deletions lib/Differentiator/ReverseModeVisitor.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2201,6 +2201,7 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context,
pullbackRequest.VerboseDiags = false;
pullbackRequest.EnableTBRAnalysis = m_DiffReq.EnableTBRAnalysis;
pullbackRequest.EnableVariedAnalysis = m_DiffReq.EnableVariedAnalysis;
pullbackRequest.setToBeRecorded(m_DiffReq.getToBeRecorded());
bool isaMethod = isa<CXXMethodDecl>(FD);
for (size_t i = 0, e = FD->getNumParams(); i < e; ++i)
if (MD && isLambdaCallOperator(MD)) {
Expand Down
28 changes: 24 additions & 4 deletions test/Analyses/ActivityReverse.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,26 @@

#include "clad/Differentiator/Differentiator.h"

inline double interpolate1d(double low, double high, double val, unsigned int numBins, double const* vals)
{
double binWidth = (high - low) / numBins;
int idx = val >= high ? numBins - 1 : std::abs((val - low) / binWidth);

// interpolation
double central = low + (idx + 0.5) * binWidth;
if (val > low + 0.5 * binWidth && val < high - 0.5 * binWidth) {
double slope;
if (val < central) {
slope = vals[idx] - vals[idx - 1];
} else {
slope = vals[idx + 1] - vals[idx];
}
return vals[idx] + slope * (val - central) / binWidth;
}

return vals[idx];
}

double f1(double x){
double a = x*x;
double b = 1;
Expand Down Expand Up @@ -273,7 +293,7 @@ double f9(double x, double const *obs)
return res;
}

// CHECK: void f9_grad(double x, const double *obs, double *_d_x, double *_d_obs) {
// CHECK: void f9_grad_0(double x, const double *obs, double *_d_x) {
// CHECK-NEXT: int loopIdx0 = 0;
// CHECK-NEXT: clad::tape<double> _t1 = {};
// CHECK-NEXT: double _d_res = 0.;
Expand Down Expand Up @@ -364,7 +384,7 @@ int main(){
double arr[] = {1,2,3,4,5};
double darr[] = {0,0,0,0,0};
double result[3] = {};
double dx;
double dx = 0;
TEST(f1, 3);// CHECK-EXEC: {6.00}
TEST(f2, 3);// CHECK-EXEC: {6.00}
TEST(f3, 3);// CHECK-EXEC: {0.00}
Expand All @@ -373,8 +393,8 @@ int main(){
TEST(f6, 3);// CHECK-EXEC: {0.00}
TEST(f7, 3);// CHECK-EXEC: {1.00}
TEST(f8, 3);// CHECK-EXEC: {1.00}
auto grad = clad::gradient<clad::opts::enable_va>(f9);
grad.execute(3, arr, &dx, darr);
auto grad9 = clad::gradient<clad::opts::enable_va>(f9, "x");
grad9.execute(3, arr, &dx, darr);
printf("%.2f\n", dx);// CHECK-EXEC: 2.00
TEST(f10, 3);// CHECK-EXEC: {1.00}
TEST(f11, 3);// CHECK-EXEC: {1.00}
Expand Down

0 comments on commit 63d431f

Please sign in to comment.