diff --git a/lib/Differentiator/ReverseModeVisitor.cpp b/lib/Differentiator/ReverseModeVisitor.cpp index 07096ce7d..ab5651a37 100644 --- a/lib/Differentiator/ReverseModeVisitor.cpp +++ b/lib/Differentiator/ReverseModeVisitor.cpp @@ -1462,10 +1462,10 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context, return foundExpr; } // Not found, construct the adjoint and register it. - VarDecl* VDDiff = - BuildVarDecl(VD->getType(), CreateUniqueIdentifier(nameDiff_str), - m_DerivativeFnScope->getParent()->getParent(), DC, - getZeroInit(VD->getType())); + QualType TyDiff = getNonConstType(VD->getType(), m_Context, m_Sema); + VarDecl* VDDiff = BuildVarDecl( + TyDiff, CreateUniqueIdentifier(nameDiff_str), + m_DerivativeFnScope->getParent()->getParent(), DC, getZeroInit(TyDiff)); DC->addDecl(VDDiff); DC->makeDeclVisibleInContext(VDDiff);