Skip to content

Commit

Permalink
Add support for differentiating switch stmt in reverse mode AD
Browse files Browse the repository at this point in the history
This commit adds support for differentiating switch statements in
the reverse mode AD.

The basic idea used to differentiate switch statement is that in the
forward pass, processing of the statements of the switch statement body
always starts from a case/default label and ends at a break statement or
at the end of the switch body.

Similarly, in the reverse pass, processing of the differentiated statements
of the switch statement body will start from the statement just above the
break statement that was hit or from the last differentiated statement in
the case when no break statement was hit.

Thus, we can keep track of which break statement was hit in the forward pass
or if no break statement was hit at all in a variable. This information is
further used by an auxiliary switch statement in the reverse pass to jump the
execution to the correct point (that is, differentiated statement of the
statement just before the break statement that was hit in the forward pass).

In this strategy, each switch case statement of the original function gets
transformed to an if condition in the reverse pass. The if condition decides
at runtime whether the processing of the differentiated statements of the switch
statement body should stop or continue. This is again based on the fact
that the processing of statements of the switch statement body always starts
at a case statement.

For an example, consider this code snippet:

switch (count) {
    case 0: a += i; break;
    case 2: a += 4 * i; break;
    default: a += 10 * i;
}

case 0 of this code snippet gets transformed to the following in the
differentiated function:

forward pass:

{
  case 0: a += i;
}
{
  clad::push(_t0, 1UL); // this is used to keep track if this break was hit; 1UL is used to represent the case number
  break;
}

reverse pass:

case 1UL:;  // this case is selected if the corresponding break was hit in the forward pass
  {
    {
      double _r_d0 = _d_a;
      _d_a += _r_d0;
      *_d_i += _r_d0;
      _d_a -= _r_d0;
    }
    if (0 == _cond0)  // If case 0: was selected in the forward pass, we should break out of processing differentiated switch stmt body here.
      break;
  }
  • Loading branch information
parth-07 committed Feb 12, 2024
1 parent 5acae9f commit 44daaad
Show file tree
Hide file tree
Showing 4 changed files with 18 additions and 13 deletions.
2 changes: 1 addition & 1 deletion lib/Differentiator/CladUtils.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -634,7 +634,7 @@ namespace clad {
}

void SetSwitchCaseSubStmt(SwitchCase* SC, Stmt* subStmt) {
if (auto *caseStmt = dyn_cast<CaseStmt>(SC))
if (auto* caseStmt = dyn_cast<CaseStmt>(SC))
caseStmt->setSubStmt(subStmt);
else
cast<DefaultStmt>(SC)->setSubStmt(subStmt);
Expand Down
14 changes: 7 additions & 7 deletions lib/Differentiator/ReverseModeVisitor.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -3317,10 +3317,10 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context,
condExpr = GlobalStoreAndRef(condDiff.getExpr(), "_cond").getExpr();
}

auto *activeBreakContHandler = PushBreakContStmtHandler(
auto* activeBreakContHandler = PushBreakContStmtHandler(
/*forSwitchStmt=*/true);
activeBreakContHandler->BeginCFSwitchStmtScope();
auto *SSData = PushSwitchStmtInfo();
auto* SSData = PushSwitchStmtInfo();

if (isInsideLoop)
SSData->switchStmtCond = condTape->Last();
Expand Down Expand Up @@ -3371,8 +3371,8 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context,
// ```
if (SSData->defaultIfBreakExpr) {
Expr* breakCond = nullptr;
for (auto *SC : SSData->cases) {
if (auto *CS = dyn_cast<CaseStmt>(SC)) {
for (auto* SC : SSData->cases) {
if (auto* CS = dyn_cast<CaseStmt>(SC)) {
if (breakCond) {
breakCond = BuildOp(BinaryOperatorKind::BO_LAnd, breakCond,
BuildOp(BinaryOperatorKind::BO_NE,
Expand Down Expand Up @@ -3423,13 +3423,13 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context,
StmtDiff ReverseModeVisitor::VisitCaseStmt(const CaseStmt* CS) {
beginBlock(direction::forward);
beginBlock(direction::reverse);
auto SSData = GetActiveSwitchStmtInfo();
SwitchStmtInfo* SSData = GetActiveSwitchStmtInfo();

Expr* lhsClone = (CS->getLHS() ? Clone(CS->getLHS()) : nullptr);
Expr* rhsClone = (CS->getRHS() ? Clone(CS->getRHS()) : nullptr);

auto *newSC = clad_compat::CaseStmt_Create(m_Sema.getASTContext(), lhsClone,
rhsClone, noLoc, noLoc, noLoc);
auto* newSC = clad_compat::CaseStmt_Create(m_Sema.getASTContext(), lhsClone,
rhsClone, noLoc, noLoc, noLoc);

Expr* ifCond = BuildOp(BinaryOperatorKind::BO_EQ, newSC->getLHS(),
SSData->switchStmtCond);
Expand Down
12 changes: 8 additions & 4 deletions test/Gradient/Switch.C
Original file line number Diff line number Diff line change
Expand Up @@ -134,6 +134,7 @@ double fn2(double i, double j) {
// CHECK: void fn2_grad(double i, double j, clad::array_ref<double> _d_i, clad::array_ref<double> _d_j) {
// CHECK-NEXT: double _d_res = 0;
// CHECK-NEXT: int _d_count = 0;
// CHECK-NEXT: int count = 0;
// CHECK-NEXT: int _cond0;
// CHECK-NEXT: double _t0;
// CHECK-NEXT: double _t1;
Expand All @@ -144,7 +145,7 @@ double fn2(double i, double j) {
// CHECK-NEXT: double _t6;
// CHECK-NEXT: double res = 0;
// CHECK-NEXT: {
// CHECK-NEXT: int count = 2;
// CHECK-NEXT: count = 2;
// CHECK-NEXT: _cond0 = count;
// CHECK-NEXT: switch (_cond0) {
// CHECK-NEXT: _t0 = res;
Expand Down Expand Up @@ -395,6 +396,7 @@ double fn4(double i, double j) {
// CHECK-NEXT: double _t0;
// CHECK-NEXT: clad::tape<unsigned long> _t1 = {};
// CHECK-NEXT: int _d_counter = 0;
// CHECK-NEXT: int counter = 0;
// CHECK-NEXT: unsigned long _t2;
// CHECK-NEXT: clad::tape<double> _t3 = {};
// CHECK-NEXT: double res = 0;
Expand All @@ -411,7 +413,7 @@ double fn4(double i, double j) {
// CHECK-NEXT: }
// CHECK-NEXT: {
// CHECK-NEXT: case 1:
// CHECK-NEXT: int counter = 2;
// CHECK-NEXT: counter = 2;
// CHECK-NEXT: }
// CHECK-NEXT: _t2 = 0;
// CHECK-NEXT: while (counter--)
Expand Down Expand Up @@ -481,12 +483,13 @@ double fn5(double i, double j) {
// CHECK: void fn5_grad(double i, double j, clad::array_ref<double> _d_i, clad::array_ref<double> _d_j) {
// CHECK-NEXT: double _d_res = 0;
// CHECK-NEXT: int _d_count = 0;
// CHECK-NEXT: int count = 0;
// CHECK-NEXT: int _cond0;
// CHECK-NEXT: double _t0;
// CHECK-NEXT: clad::tape<unsigned long> _t1 = {};
// CHECK-NEXT: double res = 0;
// CHECK-NEXT: {
// CHECK-NEXT: int count = 1;
// CHECK-NEXT: count = 1;
// CHECK-NEXT: _cond0 = count;
// CHECK-NEXT: switch (_cond0) {
// CHECK-NEXT: case 1:
Expand Down Expand Up @@ -595,13 +598,14 @@ double fn7(double u, double v) {
// CHECK-NEXT: double _d_res = 0;
// CHECK-NEXT: unsigned long _t0;
// CHECK-NEXT: int _d_i = 0;
// CHECK-NEXT: int i = 0;
// CHECK-NEXT: clad::tape<int> _cond0 = {};
// CHECK-NEXT: clad::tape<double> _t1 = {};
// CHECK-NEXT: clad::tape<unsigned long> _t2 = {};
// CHECK-NEXT: clad::tape<double> _t3 = {};
// CHECK-NEXT: double res = 0;
// CHECK-NEXT: _t0 = 0;
// CHECK-NEXT: for (int i = 0; i < 5; ++i) {
// CHECK-NEXT: for (i = 0; i < 5; ++i) {
// CHECK-NEXT: _t0++;
// CHECK-NEXT: {
// CHECK-NEXT: switch (clad::push(_cond0, i)) {
Expand Down
3 changes: 2 additions & 1 deletion test/Gradient/SwitchInit.C
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ double fn1(double i, double j) {
// CHECK: void fn1_grad(double i, double j, clad::array_ref<double> _d_i, clad::array_ref<double> _d_j) {
// CHECK-NEXT: double _d_res = 0;
// CHECK-NEXT: int _d_count = 0;
// CHECK-NEXT: int count = 0;
// CHECK-NEXT: int _cond0;
// CHECK-NEXT: double _t0;
// CHECK-NEXT: clad::tape<unsigned long> _t1 = {};
Expand All @@ -27,7 +28,7 @@ double fn1(double i, double j) {
// CHECK-NEXT: double _t4;
// CHECK-NEXT: double res = 0;
// CHECK-NEXT: {
// CHECK-NEXT: int count = 1;
// CHECK-NEXT: count = 1;
// CHECK-NEXT: _cond0 = count;
// CHECK-NEXT: switch (_cond0) {
// CHECK-NEXT: {
Expand Down

0 comments on commit 44daaad

Please sign in to comment.