Skip to content

Commit

Permalink
Support differentiation of switch condition
Browse files Browse the repository at this point in the history
  • Loading branch information
parth-07 committed Feb 2, 2024
1 parent ad01388 commit c6f3adf
Show file tree
Hide file tree
Showing 2 changed files with 78 additions and 12 deletions.
27 changes: 15 additions & 12 deletions lib/Differentiator/ReverseModeVisitor.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
#include "ConstantFolder.h"

#include "TBRAnalyzer.h"
#include "clad/Differentiator/DerivativeBuilder.h"
#include "clad/Differentiator/DiffPlanner.h"
#include "clad/Differentiator/ErrorEstimator.h"
#include "clad/Differentiator/ExternalRMVSource.h"
Expand All @@ -17,7 +18,9 @@

#include "clang/AST/ASTContext.h"
#include "clang/AST/Expr.h"
#include "clang/AST/Stmt.h"
#include "clang/AST/TemplateBase.h"
#include "clang/Basic/TokenKinds.h"
#include "clang/Sema/Lookup.h"
#include "clang/Sema/Overload.h"
#include "clang/Sema/Scope.h"
Expand Down Expand Up @@ -3252,7 +3255,7 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context,
StmtDiff ReverseModeVisitor::VisitSwitchStmt(const SwitchStmt* SS) {
// Scope and blocks for the compound statement that encloses the switch
// statement in both the forward and the reverse pass. Block is required
// handling condition variable and switch-init statement.
// for handling condition variable and switch-init statement.
beginScope(Scope::DeclScope);
beginBlock(direction::forward);
beginBlock(direction::reverse);
Expand All @@ -3271,14 +3274,12 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context,
addToCurrentBlock(condVarDiff.getStmt(), direction::forward);
addToCurrentBlock(condVarDiff.getStmt_dx(), direction::reverse);
}
// Condition is only cloned, and not differentiated.
// Its because conditions generally contain non-differentiable constructs,
// but this behaviour will lead to incorrect results if the condition
// expression modifies any variable.
Expr* condClone = (SS->getCond() ? Clone(SS->getCond()) : nullptr);

StmtDiff condDiff = DifferentiateSingleStmt(SS->getCond());
addToCurrentBlock(condDiff.getStmt(), direction::forward);
addToCurrentBlock(condDiff.getStmt_dx(), direction::reverse);
Expr* condExpr = nullptr;
llvm::Optional<CladTapeResult> condTape;
clad_compat::llvm_Optional<CladTapeResult> condTape;

if (isInsideLoop) {
// If we are inside a loop, condition will be stored and used as follows:
Expand All @@ -3289,10 +3290,10 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context,
// reverse block:
// switch (...) { ... }
// clad::pop(...);
condTape.emplace(MakeCladTapeFor(condClone, "_cond"));
condTape.emplace(MakeCladTapeFor(condDiff.getExpr(), "_cond"));
condExpr = condTape->Push;
} else {
condExpr = GlobalStoreAndRef(condClone, "_cond").getExpr();
condExpr = GlobalStoreAndRef(condDiff.getExpr(), "_cond").getExpr();
}

auto activeBreakContHandler = PushBreakContStmtHandler(
Expand Down Expand Up @@ -3361,6 +3362,8 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context,
}
}
}
if (!breakCond)
breakCond = m_Sema.ActOnCXXBoolLiteral(noLoc, tok::kw_true).get();
SSData->defaultIfBreakExpr->setCond(breakCond);
}

Expand All @@ -3378,7 +3381,7 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context,
activeBreakContHandler->UpdateForwAndRevBlocks(bodyDiff);

// Registers all the cases to the switch statement.
for (auto SC : SSData->cases)
for (auto *SC : SSData->cases)
forwardSS->addSwitchCase(SC);

forwardSS =
Expand Down Expand Up @@ -3425,8 +3428,8 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context,
StmtDiff ReverseModeVisitor::VisitDefaultStmt(const DefaultStmt* DS) {
beginBlock(direction::reverse);
beginBlock(direction::forward);
auto SSData = GetActiveSwitchStmtInfo();
auto newDefaultStmt =
auto *SSData = GetActiveSwitchStmtInfo();
DefaultStmt *newDefaultStmt =
new (m_Sema.getASTContext()) DefaultStmt(noLoc, noLoc, nullptr);
Stmt* ifThen = m_Sema.ActOnBreakStmt(noLoc, getCurrentScope()).get();
Stmt* ifBreakExpr = clad_compat::IfStmt_Create(
Expand Down
63 changes: 63 additions & 0 deletions test/Gradient/Switch.C
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
//CHECK-NOT: {{.*error|warning|note:.*}}

#include "clad/Differentiator/Differentiator.h"
#include "../TestUtils.h"

double fn1(double i, double j) {
double res = 0;
Expand Down Expand Up @@ -513,6 +514,65 @@ double fn5(double i, double j) {
// CHECK-NEXT: }
// CHECK-NEXT: }

double fn6(double u, double v) {
int res = 0;
double temp = 0;
switch(res = u * v) {
default:
temp = 1;
}
return res;
}

// CHECK: void fn6_grad(double u, double v, clad::array_ref<double> _d_u, clad::array_ref<double> _d_v) {
// CHECK-NEXT: int _d_res = 0;
// CHECK-NEXT: double _d_temp = 0;
// CHECK-NEXT: int _t0;
// CHECK-NEXT: int _cond0;
// CHECK-NEXT: double _t1;
// CHECK-NEXT: clad::tape<unsigned long> _t2 = {};
// CHECK-NEXT: int res = 0;
// CHECK-NEXT: double temp = 0;
// CHECK-NEXT: {
// CHECK-NEXT: _t0 = res;
// CHECK-NEXT: res = u * v;
// CHECK-NEXT: _cond0 = res = u * v;
// CHECK-NEXT: switch (_cond0) {
// CHECK-NEXT: {
// CHECK-NEXT: default:
// CHECK-NEXT: temp = 1;
// CHECK-NEXT: _t1 = temp;
// CHECK-NEXT: }
// CHECK-NEXT: clad::push(_t2, 1UL);
// CHECK-NEXT: }
// CHECK-NEXT: }
// CHECK-NEXT: goto _label0;
// CHECK-NEXT: _label0:
// CHECK-NEXT: _d_res += 1;
// CHECK-NEXT: {
// CHECK-NEXT: switch (clad::pop(_t2)) {
// CHECK-NEXT: case 1UL:
// CHECK-NEXT: ;
// CHECK-NEXT: {
// CHECK-NEXT: {
// CHECK-NEXT: temp = _t1;
// CHECK-NEXT: double _r_d1 = _d_temp;
// CHECK-NEXT: _d_temp -= _r_d1;
// CHECK-NEXT: }
// CHECK-NEXT: if (true)
// CHECK-NEXT: break;
// CHECK-NEXT: }
// CHECK-NEXT: }
// CHECK-NEXT: {
// CHECK-NEXT: res = _t0;
// CHECK-NEXT: int _r_d0 = _d_res;
// CHECK-NEXT: _d_res -= _r_d0;
// CHECK-NEXT: * _d_u += _r_d0 * v;
// CHECK-NEXT: * _d_v += u * _r_d0;
// CHECK-NEXT: }
// CHECK-NEXT: }
// CHECK-NEXT: }

#define TEST_2(F, x, y) \
{ \
result[0] = result[1] = 0; \
Expand All @@ -530,4 +590,7 @@ int main() {
TEST_2(fn3, 3, 5); // CHECK-EXEC: {162.00, 90.00}
TEST_2(fn4, 3, 5); // CHECK-EXEC: {10.00, 6.00}
TEST_2(fn5, 3, 5); // CHECK-EXEC: {5.00, 3.00}

INIT_GRADIENT(fn6);
TEST_GRADIENT(fn6, 2, 3, 5, &result[0], &result[1]); // CHECK-EXEC: {5.00, 3.00}
}

0 comments on commit c6f3adf

Please sign in to comment.