Skip to content

Commit

Permalink
Do not wrap pullback args in array_ref
Browse files Browse the repository at this point in the history
  • Loading branch information
PetroZarytskyi committed Feb 20, 2024
1 parent d023a99 commit 20e2b32
Show file tree
Hide file tree
Showing 3 changed files with 5 additions and 25 deletions.
18 changes: 0 additions & 18 deletions lib/Differentiator/ReverseModeVisitor.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1689,24 +1689,6 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context,
passByRef = false;
}
if (passByRef) {
// If derivative type is constant array type instead of
// `clad::array_ref` or `clad::array` type, then create an
// `clad::array_ref` variable that references this constant array. It
// is required because the pullback function expects `clad::array_ref`
// type for representing array derivatives. Currently, only constant
// array data members have derivatives of constant array types.
if ((argDerivative != nullptr) &&
isa<ConstantArrayType>(argDerivative->getType())) {
Expr* init =
utils::BuildCladArrayInitByConstArray(m_Sema, argDerivative);
auto* derivativeArrayRefVD = BuildVarDecl(
GetCladArrayRefOfType(argDerivative->getType()
->getPointeeOrArrayElementType()
->getCanonicalTypeInternal()),
"_t", init);
ArgDeclStmts.push_back(BuildDeclStmt(derivativeArrayRefVD));
argDerivative = BuildDeclRef(derivativeArrayRefVD);
}
if (argDerivative) {
if (utils::isArrayOrPointerType(argDerivative->getType()) ||
isCladArrayType(argDerivative->getType()) ||
Expand Down
9 changes: 4 additions & 5 deletions test/Gradient/FunctionCalls.C
Original file line number Diff line number Diff line change
Expand Up @@ -468,21 +468,20 @@ double fn8(double x, double y) {
// CHECK: void fn8_grad(double x, double y, double *_d_x, double *_d_y) {
// CHECK-NEXT: double _t0;
// CHECK-NEXT: double _t1;
// CHECK-NEXT: double _t3;
// CHECK-NEXT: _t3 = check_and_return(x, 'a', "aa");
// CHECK-NEXT: double _t2;
// CHECK-NEXT: _t2 = check_and_return(x, 'a', "aa");
// CHECK-NEXT: _t1 = std::tanh(1.);
// CHECK-NEXT: _t0 = std::max(1., 2.);
// CHECK-NEXT: goto _label0;
// CHECK-NEXT: _label0:
// CHECK-NEXT: {
// CHECK-NEXT: double _grad0 = 0.;
// CHECK-NEXT: char _grad1 = 0i8;
// CHECK-NEXT: clad::array_ref<char> _t2 = {"", 3UL};
// CHECK-NEXT: check_and_return_pullback(x, 'a', "aa", 1 * _t0 * _t1 * y, &_grad0, &_grad1, _t2);
// CHECK-NEXT: check_and_return_pullback(x, 'a', "aa", 1 * _t0 * _t1 * y, &_grad0, &_grad1, "");
// CHECK-NEXT: double _r0 = _grad0;
// CHECK-NEXT: *_d_x += _r0;
// CHECK-NEXT: char _r1 = _grad1;
// CHECK-NEXT: *_d_y += _t3 * 1 * _t0 * _t1;
// CHECK-NEXT: *_d_y += _t2 * 1 * _t0 * _t1;
// CHECK-NEXT: }
// CHECK-NEXT: }

Expand Down
3 changes: 1 addition & 2 deletions test/Gradient/UserDefinedTypes.C
Original file line number Diff line number Diff line change
Expand Up @@ -131,8 +131,7 @@ double fn2(Tangent t, double i) {
// CHECK-NEXT: {
// CHECK-NEXT: res = _t1;
// CHECK-NEXT: double _r_d0 = _d_res;
// CHECK-NEXT: clad::array_ref<double> _t2 = {(*_d_t).data, 5UL};
// CHECK-NEXT: sum_pullback(t.data, _r_d0, _t2);
// CHECK-NEXT: sum_pullback(t.data, _r_d0, (*_d_t).data);
// CHECK-NEXT: *_d_i += _r_d0;
// CHECK-NEXT: (*_d_t).data[0] += 2 * _r_d0;
// CHECK-NEXT: }
Expand Down

0 comments on commit 20e2b32

Please sign in to comment.