From 44daaad410cab06d8783a77baee276f876d06f8e Mon Sep 17 00:00:00 2001 From: parth <partharora99160808@gmail.com> Date: Tue, 13 Feb 2024 01:52:30 +0530 Subject: [PATCH] Add support for differentiating switch stmt in reverse mode AD This commit adds support for differentiating switch statements in the reverse mode AD. The basic idea used to differentiate switch statement is that in the forward pass, processing of the statements of the switch statement body always starts from a case/default label and ends at a break statement or at the end of the switch body. Similarly, in the reverse pass, processing of the differentiated statements of the switch statement body will start from the statement just above the break statement that was hit or from the last differentiated statement in the case when no break statement was hit. Thus, we can keep track of which break statement was hit in the forward pass or if no break statement was hit at all in a variable. This information is further used by an auxiliary switch statement in the reverse pass to jump the execution to the correct point (that is, differentiated statement of the statement just before the break statement that was hit in the forward pass). In this strategy, each switch case statement of the original function gets transformed to an if condition in the reverse pass. The if condition decides at runtime whether the processing of the differentiated statements of the switch statement body should stop or continue. This is again based on the fact that the processing of statements of the switch statement body always starts at a case statement. For an example, consider this code snippet: switch (count) { case 0: a += i; break; case 2: a += 4 * i; break; default: a += 10 * i; } case 0 of this code snippet gets transformed to the following in the differentiated function: forward pass: { case 0: a += i; } { clad::push(_t0, 1UL); // this is used to keep track if this break was hit; 1UL is used to represent the case number break; } reverse pass: case 1UL:; // this case is selected if the corresponding break was hit in the forward pass { { double _r_d0 = _d_a; _d_a += _r_d0; *_d_i += _r_d0; _d_a -= _r_d0; } if (0 == _cond0) // If case 0: was selected in the forward pass, we should break out of processing differentiated switch stmt body here. break; } --- lib/Differentiator/CladUtils.cpp | 2 +- lib/Differentiator/ReverseModeVisitor.cpp | 14 +++++++------- test/Gradient/Switch.C | 12 ++++++++---- test/Gradient/SwitchInit.C | 3 ++- 4 files changed, 18 insertions(+), 13 deletions(-) diff --git a/lib/Differentiator/CladUtils.cpp b/lib/Differentiator/CladUtils.cpp index 459a07854..6be4b8349 100644 --- a/lib/Differentiator/CladUtils.cpp +++ b/lib/Differentiator/CladUtils.cpp @@ -634,7 +634,7 @@ namespace clad { } void SetSwitchCaseSubStmt(SwitchCase* SC, Stmt* subStmt) { - if (auto *caseStmt = dyn_cast<CaseStmt>(SC)) + if (auto* caseStmt = dyn_cast<CaseStmt>(SC)) caseStmt->setSubStmt(subStmt); else cast<DefaultStmt>(SC)->setSubStmt(subStmt); diff --git a/lib/Differentiator/ReverseModeVisitor.cpp b/lib/Differentiator/ReverseModeVisitor.cpp index d732e1dac..bddb9de11 100644 --- a/lib/Differentiator/ReverseModeVisitor.cpp +++ b/lib/Differentiator/ReverseModeVisitor.cpp @@ -3317,10 +3317,10 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context, condExpr = GlobalStoreAndRef(condDiff.getExpr(), "_cond").getExpr(); } - auto *activeBreakContHandler = PushBreakContStmtHandler( + auto* activeBreakContHandler = PushBreakContStmtHandler( /*forSwitchStmt=*/true); activeBreakContHandler->BeginCFSwitchStmtScope(); - auto *SSData = PushSwitchStmtInfo(); + auto* SSData = PushSwitchStmtInfo(); if (isInsideLoop) SSData->switchStmtCond = condTape->Last(); @@ -3371,8 +3371,8 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context, // ``` if (SSData->defaultIfBreakExpr) { Expr* breakCond = nullptr; - for (auto *SC : SSData->cases) { - if (auto *CS = dyn_cast<CaseStmt>(SC)) { + for (auto* SC : SSData->cases) { + if (auto* CS = dyn_cast<CaseStmt>(SC)) { if (breakCond) { breakCond = BuildOp(BinaryOperatorKind::BO_LAnd, breakCond, BuildOp(BinaryOperatorKind::BO_NE, @@ -3423,13 +3423,13 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context, StmtDiff ReverseModeVisitor::VisitCaseStmt(const CaseStmt* CS) { beginBlock(direction::forward); beginBlock(direction::reverse); - auto SSData = GetActiveSwitchStmtInfo(); + SwitchStmtInfo* SSData = GetActiveSwitchStmtInfo(); Expr* lhsClone = (CS->getLHS() ? Clone(CS->getLHS()) : nullptr); Expr* rhsClone = (CS->getRHS() ? Clone(CS->getRHS()) : nullptr); - auto *newSC = clad_compat::CaseStmt_Create(m_Sema.getASTContext(), lhsClone, - rhsClone, noLoc, noLoc, noLoc); + auto* newSC = clad_compat::CaseStmt_Create(m_Sema.getASTContext(), lhsClone, + rhsClone, noLoc, noLoc, noLoc); Expr* ifCond = BuildOp(BinaryOperatorKind::BO_EQ, newSC->getLHS(), SSData->switchStmtCond); diff --git a/test/Gradient/Switch.C b/test/Gradient/Switch.C index 04254e357..686a3f1b1 100644 --- a/test/Gradient/Switch.C +++ b/test/Gradient/Switch.C @@ -134,6 +134,7 @@ double fn2(double i, double j) { // CHECK: void fn2_grad(double i, double j, clad::array_ref<double> _d_i, clad::array_ref<double> _d_j) { // CHECK-NEXT: double _d_res = 0; // CHECK-NEXT: int _d_count = 0; +// CHECK-NEXT: int count = 0; // CHECK-NEXT: int _cond0; // CHECK-NEXT: double _t0; // CHECK-NEXT: double _t1; @@ -144,7 +145,7 @@ double fn2(double i, double j) { // CHECK-NEXT: double _t6; // CHECK-NEXT: double res = 0; // CHECK-NEXT: { -// CHECK-NEXT: int count = 2; +// CHECK-NEXT: count = 2; // CHECK-NEXT: _cond0 = count; // CHECK-NEXT: switch (_cond0) { // CHECK-NEXT: _t0 = res; @@ -395,6 +396,7 @@ double fn4(double i, double j) { // CHECK-NEXT: double _t0; // CHECK-NEXT: clad::tape<unsigned long> _t1 = {}; // CHECK-NEXT: int _d_counter = 0; +// CHECK-NEXT: int counter = 0; // CHECK-NEXT: unsigned long _t2; // CHECK-NEXT: clad::tape<double> _t3 = {}; // CHECK-NEXT: double res = 0; @@ -411,7 +413,7 @@ double fn4(double i, double j) { // CHECK-NEXT: } // CHECK-NEXT: { // CHECK-NEXT: case 1: -// CHECK-NEXT: int counter = 2; +// CHECK-NEXT: counter = 2; // CHECK-NEXT: } // CHECK-NEXT: _t2 = 0; // CHECK-NEXT: while (counter--) @@ -481,12 +483,13 @@ double fn5(double i, double j) { // CHECK: void fn5_grad(double i, double j, clad::array_ref<double> _d_i, clad::array_ref<double> _d_j) { // CHECK-NEXT: double _d_res = 0; // CHECK-NEXT: int _d_count = 0; +// CHECK-NEXT: int count = 0; // CHECK-NEXT: int _cond0; // CHECK-NEXT: double _t0; // CHECK-NEXT: clad::tape<unsigned long> _t1 = {}; // CHECK-NEXT: double res = 0; // CHECK-NEXT: { -// CHECK-NEXT: int count = 1; +// CHECK-NEXT: count = 1; // CHECK-NEXT: _cond0 = count; // CHECK-NEXT: switch (_cond0) { // CHECK-NEXT: case 1: @@ -595,13 +598,14 @@ double fn7(double u, double v) { // CHECK-NEXT: double _d_res = 0; // CHECK-NEXT: unsigned long _t0; // CHECK-NEXT: int _d_i = 0; +// CHECK-NEXT: int i = 0; // CHECK-NEXT: clad::tape<int> _cond0 = {}; // CHECK-NEXT: clad::tape<double> _t1 = {}; // CHECK-NEXT: clad::tape<unsigned long> _t2 = {}; // CHECK-NEXT: clad::tape<double> _t3 = {}; // CHECK-NEXT: double res = 0; // CHECK-NEXT: _t0 = 0; -// CHECK-NEXT: for (int i = 0; i < 5; ++i) { +// CHECK-NEXT: for (i = 0; i < 5; ++i) { // CHECK-NEXT: _t0++; // CHECK-NEXT: { // CHECK-NEXT: switch (clad::push(_cond0, i)) { diff --git a/test/Gradient/SwitchInit.C b/test/Gradient/SwitchInit.C index 355b91855..f2112c8f4 100644 --- a/test/Gradient/SwitchInit.C +++ b/test/Gradient/SwitchInit.C @@ -19,6 +19,7 @@ double fn1(double i, double j) { // CHECK: void fn1_grad(double i, double j, clad::array_ref<double> _d_i, clad::array_ref<double> _d_j) { // CHECK-NEXT: double _d_res = 0; // CHECK-NEXT: int _d_count = 0; +// CHECK-NEXT: int count = 0; // CHECK-NEXT: int _cond0; // CHECK-NEXT: double _t0; // CHECK-NEXT: clad::tape<unsigned long> _t1 = {}; @@ -27,7 +28,7 @@ double fn1(double i, double j) { // CHECK-NEXT: double _t4; // CHECK-NEXT: double res = 0; // CHECK-NEXT: { -// CHECK-NEXT: int count = 1; +// CHECK-NEXT: count = 1; // CHECK-NEXT: _cond0 = count; // CHECK-NEXT: switch (_cond0) { // CHECK-NEXT: {