Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Simplify error estimation by removing _EERepl_ and _delta_ #773

Merged
merged 5 commits into from
Mar 12, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
124 changes: 19 additions & 105 deletions include/clad/Differentiator/ErrorEstimator.h
Original file line number Diff line number Diff line change
Expand Up @@ -27,8 +27,6 @@ class ErrorEstimationHandler : public ExternalRMVSource {
// multiple header files.
// `Stmts` is originally defined in `VisitorBase`.
using Stmts = llvm::SmallVector<clang::Stmt*, 16>;
/// Keeps a track of the delta error expression we shouldn't emit.
bool m_DoNotEmitDelta;
/// Reference to the final error parameter in the augumented target
/// function.
clang::Expr* m_FinalError;
Expand All @@ -42,23 +40,20 @@ class ErrorEstimationHandler : public ExternalRMVSource {
Stmts m_ReverseErrorStmts;
/// The index expression for emitting final errors for input param errors.
clang::Expr* m_IdxExpr;
/// A set of declRefExprs for parameter value replacements.
std::unordered_map<const clang::VarDecl*, clang::Expr*> m_ParamRepls;
/// An expression to match nested function call errors with their
/// assignee (if any exists).
clang::Expr* m_NestedFuncError = nullptr;

std::stack<bool> m_ShouldEmit;
ReverseModeVisitor* m_RMV;
clang::Expr* m_DeltaVar = nullptr;
llvm::SmallVectorImpl<clang::QualType>* m_ParamTypes = nullptr;
llvm::SmallVectorImpl<clang::ParmVarDecl*>* m_Params = nullptr;

public:
using direction = rmv::direction;
ErrorEstimationHandler()
: m_DoNotEmitDelta(false), m_FinalError(nullptr), m_RetErrorExpr(nullptr),
m_EstModel(nullptr), m_IdxExpr(nullptr) {}
: m_FinalError(nullptr), m_RetErrorExpr(nullptr), m_EstModel(nullptr),
m_IdxExpr(nullptr) {}
~ErrorEstimationHandler() override = default;

/// Function to set the error estimation model currently in use.
Expand All @@ -70,33 +65,16 @@ class ErrorEstimationHandler : public ExternalRMVSource {
/// \param[in] finErrExpr The final error expression.
void SetFinalErrorExpr(clang::Expr* finErrExpr) { m_FinalError = finErrExpr; }

/// Shorthand to get array subscript expressions.
///
/// \param[in] arrBase The base expression of the array.
/// \param[in] idx The index expression.
/// \param[in] isCladSpType Keeps track of if we have to build a clad
/// special type (i.e. clad::Array or clad::ArrayRef).
///
/// \returns An expression of the kind arrBase[idx].
clang::Expr* getArraySubscriptExpr(clang::Expr* arrBase, clang::Expr* idx,
bool isCladSpType = true);

/// \returns The final error expression so far.
clang::Expr* GetFinalErrorExpr() { return m_FinalError; }

/// Function to build the final error statemnt of the function. This is the
/// last statement of any target function in error estimation and
/// aggregates the error in all the registered variables.
void BuildFinalErrorStmt();
/// Function to build the error statement corresponding
/// to the function's return statement.
void BuildReturnErrorStmt();

/// Function to emit error statements into the derivative body.
///
/// \param[in] var The variable whose error statement we want to emit.
/// \param[in] deltaVar The "_delta_" expression of the variable 'var'.
/// \param[in] errorExpr The error expression (LHS) of the variable 'var'.
/// \param[in] isInsideLoop A flag to indicate if 'val' is inside a loop.
void AddErrorStmtToBlock(clang::Expr* var, clang::Expr* deltaVar,
clang::Expr* errorExpr, bool isInsideLoop = false);
/// \param[in] errorExpr The error expression (LHS) of the variable.
/// \param[in] addToTheFront A flag to decide whether the error stmts
/// should be added to the beginning of the block or the current position.
void AddErrorStmtToBlock(clang::Expr* errorExpr, bool addToTheFront = true);

/// Emit the error estimation related statements that were saved to be
/// emitted at later points into specific blocks.
Expand Down Expand Up @@ -124,44 +102,12 @@ class ErrorEstimationHandler : public ExternalRMVSource {
llvm::SmallVectorImpl<clang::Expr*>& CallArgs,
llvm::SmallVectorImpl<clang::VarDecl*>& ArgResultDecls, size_t numArgs);

/// Save values of registered variables so that they can be replaced
/// properly in case of re-assignments.
///
/// \param[in] val The value to save.
/// \param[in] isInsideLoop A flag to indicate if 'val' is inside a loop.
///
/// \returns The saved variable and its derivative.
StmtDiff SaveValue(clang::Expr* val, bool isInLoop = false);

/// Save the orignal values of the input parameters in case of
/// re-assignments.
///
/// \param[in] paramRef The DeclRefExpr of the input parameter.
void SaveParamValue(clang::DeclRefExpr* paramRef);

/// Register variables to be used while accumulating error.
/// Register variable declarations so that they may be used while
/// calculating the final error estimates. Any unregistered variables will
/// not be considered for the final estimation.
///
/// \param[in] VD The variable declaration to be registered.
/// \param[in] toCurrentScope Add the created "_delta_" variable declaration
/// to the current scope instead of the global scope.
///
/// \returns The Variable declaration of the '_delta_' prefixed variable.
clang::Expr* RegisterVariable(clang::VarDecl* VD,
bool toCurrentScope = false);

/// Checks if a variable can be registered for error estimation.
/// Checks if a variable should be considered in error estimation.
///
/// \param[in] VD The variable declaration to be registered.
/// \param[in] VD The variable declaration.
///
/// \returns True if the variable can be registered, false otherwise.
bool CanRegisterVariable(clang::VarDecl* VD);

/// Calculate aggregate error from m_EstimateVar.
/// Builds the final error estimation statement.
clang::Stmt* CalculateAggregateError();
/// \returns true if the variable should be considered, false otherwise.
bool ShouldEstimateErrorFor(clang::VarDecl* VD);

/// Get the underlying DeclRefExpr type it it exists.
///
Expand All @@ -170,14 +116,6 @@ class ErrorEstimationHandler : public ExternalRMVSource {
/// \returns The DeclRefExpr of input or null.
clang::DeclRefExpr* GetUnderlyingDeclRefOrNull(clang::Expr* expr);

/// Get the parameter replacement (if any).
///
/// \param[in] VD The parameter variable declaration to get replacement
/// for.
///
/// \returns The underlying replaced Expr.
clang::Expr* GetParamReplacement(const clang::ParmVarDecl* VD);

/// An abstraction of the error estimation model's AssignError.
///
/// \param[in] val The variable to get the error for.
Expand All @@ -190,16 +128,6 @@ class ErrorEstimationHandler : public ExternalRMVSource {
return m_EstModel->AssignError({var, varDiff}, varName);
}

/// An abstraction of the error estimation model's IsVariableRegistered.
///
/// \param[in] VD The variable declaration to check the status of.
///
/// \returns the reference to the respective '_delta_' expression if the
/// variable is registered, null otherwise.
clang::Expr* IsRegistered(clang::VarDecl* VD) {
return m_EstModel->IsVariableRegistered(VD);
}

/// This function adds the final error and the other parameter errors to the
/// forward block.
///
Expand All @@ -215,17 +143,6 @@ class ErrorEstimationHandler : public ExternalRMVSource {
/// loop.
void EmitUnaryOpErrorStmts(StmtDiff var, bool isInsideLoop);

/// This function registers all LHS declRefExpr in binary operations.
///
/// \param[in] LExpr The LHS of the operation.
/// \param[in] RExpr The RHS of the operation.
/// \param[in] isAssign A flag to know if the current operation is a simple
/// assignment.
///
/// \returns The delta value of the input 'var'.
clang::Expr* RegisterBinaryOpLHS(clang::Expr* LExpr, clang::Expr* RExpr,
bool isAssign);

/// This function emits the error in a binary operation.
///
/// \param[in] LExpr The LHS of the operation.
Expand All @@ -234,8 +151,7 @@ class ErrorEstimationHandler : public ExternalRMVSource {
/// \param[in] deltaVar The delta value of the LHS.
/// \param[in] isInsideLoop A flag to keep track of if we are inside a
/// loop.
void EmitBinaryOpErrorStmts(clang::Expr* LExpr, clang::Expr* oldValue,
clang::Expr* deltaVar, bool isInsideLoop);
void EmitBinaryOpErrorStmts(clang::Expr* LExpr, clang::Expr* oldValue);

/// This function emits the error in declaration statements.
///
Expand All @@ -256,21 +172,19 @@ class ErrorEstimationHandler : public ExternalRMVSource {
void ActBeforeDifferentiatingStmtInVisitCompoundStmt() override;
void ActAfterProcessingStmtInVisitCompoundStmt() override;
void ActBeforeDifferentiatingSingleStmtBranchInVisitIfStmt() override;
void ActBeforeFinalisingVisitBranchSingleStmtInIfVisitStmt() override;
void ActBeforeFinalizingVisitBranchSingleStmtInIfVisitStmt() override;
void ActBeforeDifferentiatingLoopInitStmt() override;
void ActBeforeDifferentiatingSingleStmtLoopBody() override;
void ActAfterProcessingSingleStmtBodyInVisitForLoop() override;
void ActBeforeFinalisingVisitReturnStmt(StmtDiff& retExprDiff) override;
void ActBeforeFinalisingPostIncDecOp(StmtDiff& diff) override;
void ActBeforeFinalizingVisitReturnStmt(StmtDiff& retExprDiff) override;
void ActBeforeFinalizingPostIncDecOp(StmtDiff& diff) override;
void ActBeforeFinalizingVisitCallExpr(
const clang::CallExpr*& CE, clang::Expr*& fnDecl,
llvm::SmallVectorImpl<clang::Expr*>& derivedCallArgs,
llvm::SmallVectorImpl<clang::VarDecl*>& ArgResultDecls,
bool asGrad) override;
void
ActAfterCloningLHSOfAssignOp(clang::Expr*& LCloned, clang::Expr*& R,
clang::BinaryOperator::Opcode& opCode) override;
void ActBeforeFinalisingAssignOp(clang::Expr*&, clang::Expr*&) override;
void ActBeforeFinalizingAssignOp(clang::Expr*&, clang::Expr*&, clang::Expr*&,
clang::BinaryOperator::Opcode&) override;
void ActBeforeFinalizingDifferentiateSingleStmt(const direction& d) override;
void ActBeforeFinalizingDifferentiateSingleExpr(const direction& d) override;
void ActBeforeDifferentiatingCallExpr(
Expand Down
39 changes: 0 additions & 39 deletions include/clad/Differentiator/EstimationModel.h
Original file line number Diff line number Diff line change
Expand Up @@ -34,20 +34,6 @@ namespace clad {
/// Clear the variable estimate map so that we can start afresh.
void clearEstimationVariables() { m_EstimateVar.clear(); }

/// Check if a variable is registered for estimation.
///
/// \param[in] VD The variable to check.
///
/// \returns The delta expression of the variable if it is registered,
/// nullptr otherwise.
clang::Expr* IsVariableRegistered(const clang::VarDecl* VD);

/// Track the variable declaration and utilize it in error
/// estimation.
///
/// \param[in] VD The declaration to track.
void AddVarToEstimate(clang::VarDecl* VD, clang::Expr* VDRef);

/// Helper to build a function call expression.
///
/// \param[in] funcName The name of the function to build the expression
Expand Down Expand Up @@ -86,31 +72,6 @@ namespace clad {
virtual clang::Expr* AssignError(StmtDiff refExpr,
const std::string& name) = 0;

/// Initializes errors for '_delta_' statements.
/// This function returns the initial error assignment. Similar to
/// AssignError, however, this function is only called during declaration of
/// variables. This function is separate from AssignError to keep
/// implementation of different estimation models more flexible.
///
/// The default definition is as follows:
/// \n \code
/// clang::Expr* SetError(clang::VarDecl* declStmt) {
/// return nullptr;
/// }
/// \endcode
/// The above will return a 0 expression to be assigned to the '_delta_'
/// declaration of input decl.
///
/// \param[in] decl The declaration to which the error has to be assigned.
///
/// \returns The error expression for declaration statements.
virtual clang::Expr* SetError(clang::VarDecl* decl);

/// Calculate aggregate error from m_EstimateVar.
///
/// \returns the final error estimation statement.
clang::Expr* CalculateAggregateError();

friend class ErrorEstimationHandler;
};

Expand Down
12 changes: 7 additions & 5 deletions include/clad/Differentiator/ExternalRMVSource.h
Original file line number Diff line number Diff line change
Expand Up @@ -104,7 +104,7 @@

/// This is called just before finalising processing of Single statement
/// branch in `VisitBranch` lambda in
virtual void ActBeforeFinalisingVisitBranchSingleStmtInIfVisitStmt() {}
virtual void ActBeforeFinalizingVisitBranchSingleStmtInIfVisitStmt() {}

Check warning on line 107 in include/clad/Differentiator/ExternalRMVSource.h

View check run for this annotation

Codecov / codecov/patch

include/clad/Differentiator/ExternalRMVSource.h#L107

Added line #L107 was not covered by tests

/// This is called just before differentiating init statement of loops.
virtual void ActBeforeDifferentiatingLoopInitStmt() {}
Expand All @@ -117,7 +117,7 @@
virtual void ActAfterProcessingSingleStmtBodyInVisitForLoop() {}

/// This is called just before finalising `VisitReturnStmt`.
virtual void ActBeforeFinalisingVisitReturnStmt(StmtDiff& retExprDiff) {}
virtual void ActBeforeFinalizingVisitReturnStmt(StmtDiff& retExprDiff) {}

Check warning on line 120 in include/clad/Differentiator/ExternalRMVSource.h

View check run for this annotation

Codecov / codecov/patch

include/clad/Differentiator/ExternalRMVSource.h#L120

Added line #L120 was not covered by tests

/// This ic called just before finalising `VisitCallExpr`.
///
Expand All @@ -131,15 +131,17 @@

/// This is called just before finalising processing of post and pre
/// increment and decrement operations.
virtual void ActBeforeFinalisingPostIncDecOp(StmtDiff& diff){};
virtual void ActBeforeFinalizingPostIncDecOp(StmtDiff& diff){};

Check warning on line 134 in include/clad/Differentiator/ExternalRMVSource.h

View check run for this annotation

Codecov / codecov/patch

include/clad/Differentiator/ExternalRMVSource.h#L134

Added line #L134 was not covered by tests

/// This is called just after cloning of LHS assignment operation.
virtual void ActAfterCloningLHSOfAssignOp(clang::Expr*&, clang::Expr*&,
clang::BinaryOperatorKind& opCode) {
}

/// This is called just after finaising processing of assignment operator.
virtual void ActBeforeFinalisingAssignOp(clang::Expr*&, clang::Expr*&){};
/// This is called just after finalising processing of assignment operator.
virtual void ActBeforeFinalizingAssignOp(clang::Expr*&, clang::Expr*&,

Check warning on line 142 in include/clad/Differentiator/ExternalRMVSource.h

View check run for this annotation

Codecov / codecov/patch

include/clad/Differentiator/ExternalRMVSource.h#L142

Added line #L142 was not covered by tests
clang::Expr*&,
clang::BinaryOperator::Opcode&){};

Check warning on line 144 in include/clad/Differentiator/ExternalRMVSource.h

View check run for this annotation

Codecov / codecov/patch

include/clad/Differentiator/ExternalRMVSource.h#L144

Added line #L144 was not covered by tests

/// This is called at that beginning of
/// `ReverseModeVisitor::DifferentiateSingleStmt`.
Expand Down
9 changes: 5 additions & 4 deletions include/clad/Differentiator/MultiplexExternalRMVSource.h
Original file line number Diff line number Diff line change
Expand Up @@ -40,20 +40,21 @@ class MultiplexExternalRMVSource : public ExternalRMVSource {
void ActBeforeDifferentiatingStmtInVisitCompoundStmt() override;
void ActAfterProcessingStmtInVisitCompoundStmt() override;
void ActBeforeDifferentiatingSingleStmtBranchInVisitIfStmt() override;
void ActBeforeFinalisingVisitBranchSingleStmtInIfVisitStmt() override;
void ActBeforeFinalizingVisitBranchSingleStmtInIfVisitStmt() override;
void ActBeforeDifferentiatingLoopInitStmt() override;
void ActBeforeDifferentiatingSingleStmtLoopBody() override;
void ActAfterProcessingSingleStmtBodyInVisitForLoop() override;
void ActBeforeFinalisingVisitReturnStmt(StmtDiff& retExprDiff) override;
void ActBeforeFinalizingVisitReturnStmt(StmtDiff& retExprDiff) override;
void ActBeforeFinalizingVisitCallExpr(
const clang::CallExpr*& CE, clang::Expr*& OverloadedDerivedFn,
llvm::SmallVectorImpl<clang::Expr*>& derivedCallArgs,
llvm::SmallVectorImpl<clang::VarDecl*>& ArgResultDecls,
bool asGrad) override;
void ActBeforeFinalisingPostIncDecOp(StmtDiff& diff) override;
void ActBeforeFinalizingPostIncDecOp(StmtDiff& diff) override;
void ActAfterCloningLHSOfAssignOp(clang::Expr*&, clang::Expr*&,
clang::BinaryOperatorKind& opCode) override;
void ActBeforeFinalisingAssignOp(clang::Expr*&, clang::Expr*&) override;
void ActBeforeFinalizingAssignOp(clang::Expr*&, clang::Expr*&, clang::Expr*&,
clang::BinaryOperator::Opcode&) override;
void ActOnStartOfDifferentiateSingleStmt() override;
void ActBeforeFinalizingDifferentiateSingleStmt(const direction& d) override;
void ActBeforeFinalizingDifferentiateSingleExpr(const direction& d) override;
Expand Down
7 changes: 7 additions & 0 deletions include/clad/Differentiator/VisitorBase.h
Original file line number Diff line number Diff line change
Expand Up @@ -424,6 +424,13 @@ namespace clad {
clang::Expr*
BuildArraySubscript(clang::Expr* Base,
const llvm::SmallVectorImpl<clang::Expr*>& IS);

/// Build an array subscript expression with a given base expression and
/// one index.
clang::Expr* BuildArraySubscript(clang::Expr* Base, clang::Expr*& Idx) {
llvm::SmallVector<clang::Expr*, 1> IS = {Idx};
return BuildArraySubscript(Base, IS);
}
/// Find namespace clad declaration.
clang::NamespaceDecl* GetCladNamespace();
/// Find declaration of clad::class templated type
Expand Down
Loading
Loading