-
-
Notifications
You must be signed in to change notification settings - Fork 190
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Adds auxilary functions needed for reduce_sum #1800
Merged
Merged
Changes from 9 commits
Commits
Show all changes
13 commits
Select commit
Hold shift + click to select a range
f49dd02
Adds auxilary functions needed for reduce_sum
SteveBronder 6dedcf9
change test names for save_vari
SteveBronder fc6701f
fixup accumulate_adjoint and save_vari tests
SteveBronder 568c71c
Merge commit 'b6134fbf1a75d9bfa4716bafc8ced948b794f4b3' into HEAD
yashikno 062d981
[Jenkins] auto-formatting by clang-format version 5.0.2-svn328729-1~e…
stan-buildbot c8343c3
[Jenkins] auto-formatting by clang-format version 6.0.0 (tags/google/…
stan-buildbot 1c4c84f
Test deep_copy_vars to make sure adjoints aren't being propagated fro…
bbbales2 083de0a
[Jenkins] auto-formatting by clang-format version 5.0.0-3~16.04.1 (ta…
stan-buildbot 32ef30f
[Jenkins] auto-formatting by clang-format version 6.0.0 (tags/google/…
stan-buildbot e28b18a
Added zero argument test for save_varis
bbbales2 b679bb4
Merge commit '4649606f4cfa0e99a6a0673f3816e16b96b0709a' into HEAD
yashikno 8803029
[Jenkins] auto-formatting by clang-format version 5.0.2-svn328729-1~e…
stan-buildbot 58a7626
[Jenkins] auto-formatting by clang-format version 6.0.0 (tags/google/…
stan-buildbot File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,154 @@ | ||
#ifndef STAN_MATH_REV_CORE_ACCUMULATE_ADJOINTS_HPP | ||
#define STAN_MATH_REV_CORE_ACCUMULATE_ADJOINTS_HPP | ||
|
||
#include <stan/math/prim/meta.hpp> | ||
#include <stan/math/rev/meta.hpp> | ||
#include <stan/math/rev/core/var.hpp> | ||
|
||
#include <utility> | ||
#include <vector> | ||
|
||
namespace stan { | ||
namespace math { | ||
|
||
template <typename... Pargs> | ||
inline double* accumulate_adjoints(double* dest, const var& x, Pargs&&... args); | ||
|
||
template <typename VarVec, require_std_vector_vt<is_var, VarVec>* = nullptr, | ||
typename... Pargs> | ||
inline double* accumulate_adjoints(double* dest, VarVec&& x, Pargs&&... args); | ||
|
||
template <typename VecContainer, | ||
require_std_vector_st<is_var, VecContainer>* = nullptr, | ||
require_std_vector_vt<is_container, VecContainer>* = nullptr, | ||
typename... Pargs> | ||
inline double* accumulate_adjoints(double* dest, VecContainer&& x, | ||
Pargs&&... args); | ||
|
||
template <typename EigT, require_eigen_vt<is_var, EigT>* = nullptr, | ||
typename... Pargs> | ||
inline double* accumulate_adjoints(double* dest, EigT&& x, Pargs&&... args); | ||
|
||
template <typename Arith, require_arithmetic_t<scalar_type_t<Arith>>* = nullptr, | ||
typename... Pargs> | ||
inline double* accumulate_adjoints(double* dest, Arith&& x, Pargs&&... args); | ||
|
||
inline double* accumulate_adjoints(double* dest); | ||
|
||
/** | ||
* Accumulate adjoints from x into storage pointed to by dest, | ||
* increment the adjoint storage pointer, | ||
* recursively accumulate the adjoints of the rest of the arguments, | ||
* and return final position of storage pointer. | ||
* | ||
* @tparam Pargs Types of remaining arguments | ||
* @param dest Pointer to where adjoints are to be accumulated | ||
* @param x A var | ||
* @param args Further args to accumulate over | ||
* @return Final position of adjoint storage pointer | ||
*/ | ||
template <typename... Pargs> | ||
inline double* accumulate_adjoints(double* dest, const var& x, | ||
Pargs&&... args) { | ||
*dest += x.adj(); | ||
return accumulate_adjoints(dest + 1, std::forward<Pargs>(args)...); | ||
} | ||
|
||
/** | ||
* Accumulate adjoints from std::vector x into storage pointed to by dest, | ||
* increment the adjoint storage pointer, | ||
* recursively accumulate the adjoints of the rest of the arguments, | ||
* and return final position of storage pointer. | ||
* | ||
* @tparam Pargs Types of remaining arguments | ||
* @param dest Pointer to where adjoints are to be accumulated | ||
* @param x A std::vector of vars | ||
* @param args Further args to accumulate over | ||
* @return Final position of adjoint storage pointer | ||
*/ | ||
template <typename VarVec, require_std_vector_vt<is_var, VarVec>*, | ||
typename... Pargs> | ||
inline double* accumulate_adjoints(double* dest, VarVec&& x, Pargs&&... args) { | ||
for (auto&& x_iter : x) { | ||
*dest += x_iter.adj(); | ||
++dest; | ||
} | ||
return accumulate_adjoints(dest, std::forward<Pargs>(args)...); | ||
} | ||
|
||
/** | ||
* Accumulate adjoints from x (a std::vector of containers containing vars) | ||
* into storage pointed to by dest, | ||
* increment the adjoint storage pointer, | ||
* recursively accumulate the adjoints of the rest of the arguments, | ||
* and return final position of storage pointer. | ||
* | ||
* @tparam VecContainer the type of a standard container holding var | ||
* containers. | ||
* @tparam Pargs Types of remaining arguments | ||
* @param dest Pointer to where adjoints are to be accumulated | ||
* @param x A std::vector of containers holding vars | ||
* @param args Further args to accumulate over | ||
* @return Final position of adjoint storage pointer | ||
*/ | ||
template <typename VecContainer, require_std_vector_st<is_var, VecContainer>*, | ||
require_std_vector_vt<is_container, VecContainer>*, typename... Pargs> | ||
inline double* accumulate_adjoints(double* dest, VecContainer&& x, | ||
Pargs&&... args) { | ||
for (auto&& x_iter : x) { | ||
dest = accumulate_adjoints(dest, x_iter); | ||
} | ||
return accumulate_adjoints(dest, std::forward<Pargs>(args)...); | ||
} | ||
|
||
/** | ||
* Accumulate adjoints from x (an Eigen type containing vars) | ||
* into storage pointed to by dest, | ||
* increment the adjoint storage pointer, | ||
* recursively accumulate the adjoints of the rest of the arguments, | ||
* and return final position of storage pointer. | ||
* | ||
* @tparam EigT Type derived from `EigenBase` containing vars. | ||
* @tparam Pargs Types of remaining arguments | ||
* @param dest Pointer to where adjoints are to be accumulated | ||
* @param x An eigen type holding vars to accumulate over | ||
* @param args Further args to accumulate over | ||
* @return Final position of adjoint storage pointer | ||
*/ | ||
template <typename EigT, require_eigen_vt<is_var, EigT>*, typename... Pargs> | ||
inline double* accumulate_adjoints(double* dest, EigT&& x, Pargs&&... args) { | ||
Eigen::Map<Eigen::MatrixXd>(dest, x.rows(), x.cols()) += x.adj(); | ||
return accumulate_adjoints(dest + x.size(), std::forward<Pargs>(args)...); | ||
} | ||
|
||
/** | ||
* Ignore arithmetic types. | ||
* | ||
* Recursively accumulate the adjoints of the rest of the arguments | ||
* and return final position of adjoint storage pointer. | ||
* | ||
* @tparam Arith A type satisfying `std::is_arithmetic`. | ||
* @tparam Pargs Types of remaining arguments | ||
* @param dest Pointer to where adjoints are to be accumulated | ||
* @param x An object that is either arithmetic or a container of Arithmetic | ||
* types | ||
* @param args Further args to accumulate over | ||
* @return Final position of adjoint storage pointer | ||
*/ | ||
template <typename Arith, require_arithmetic_t<scalar_type_t<Arith>>*, | ||
typename... Pargs> | ||
inline double* accumulate_adjoints(double* dest, Arith&& x, Pargs&&... args) { | ||
return accumulate_adjoints(dest, std::forward<Pargs>(args)...); | ||
} | ||
|
||
/** | ||
* End accumulate_adjoints recursion and return pointer | ||
* | ||
* @param dest Pointer | ||
*/ | ||
inline double* accumulate_adjoints(double* dest) { return dest; } | ||
|
||
} // namespace math | ||
} // namespace stan | ||
|
||
#endif |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,151 @@ | ||
#ifndef STAN_MATH_REV_CORE_COUNT_VARS_HPP | ||
#define STAN_MATH_REV_CORE_COUNT_VARS_HPP | ||
|
||
#include <stan/math/prim/meta.hpp> | ||
#include <stan/math/rev/meta.hpp> | ||
#include <stan/math/rev/core/var.hpp> | ||
|
||
#include <utility> | ||
#include <vector> | ||
|
||
namespace stan { | ||
namespace math { | ||
|
||
template <typename VecVar, require_std_vector_vt<is_var, VecVar>* = nullptr, | ||
typename... Pargs> | ||
inline size_t count_vars_impl(size_t count, VecVar&& x, Pargs&&... args); | ||
|
||
template <typename VecContainer, | ||
require_std_vector_st<is_var, VecContainer>* = nullptr, | ||
require_std_vector_vt<is_container, VecContainer>* = nullptr, | ||
typename... Pargs> | ||
inline size_t count_vars_impl(size_t count, VecContainer&& x, Pargs&&... args); | ||
|
||
template <typename EigT, require_eigen_vt<is_var, EigT>* = nullptr, | ||
typename... Pargs> | ||
inline size_t count_vars_impl(size_t count, EigT&& x, Pargs&&... args); | ||
|
||
template <typename... Pargs> | ||
inline size_t count_vars_impl(size_t count, const var& x, Pargs&&... args); | ||
|
||
template <typename Arith, require_arithmetic_t<scalar_type_t<Arith>>* = nullptr, | ||
typename... Pargs> | ||
inline size_t count_vars_impl(size_t count, Arith& x, Pargs&&... args); | ||
|
||
inline size_t count_vars_impl(size_t count); | ||
|
||
/** | ||
* Count the number of vars in x (a std::vector of vars), | ||
* add it to the running total, | ||
* count the number of vars in the remaining arguments | ||
* and return the result. | ||
* | ||
* @tparam VecVar type of standard container holding vars | ||
* @tparam Pargs Types of remaining arguments | ||
* @param[in] count The current count of the number of vars | ||
* @param[in] x A std::vector holding vars. | ||
* @param[in] args objects to be forwarded to recursive call of | ||
* `count_vars_impl` | ||
*/ | ||
template <typename VecVar, require_std_vector_vt<is_var, VecVar>*, | ||
typename... Pargs> | ||
inline size_t count_vars_impl(size_t count, VecVar&& x, Pargs&&... args) { | ||
return count_vars_impl(count + x.size(), std::forward<Pargs>(args)...); | ||
} | ||
|
||
/** | ||
* Count the number of vars in x (a std::vector holding other containers), | ||
* add it to the running total, | ||
* count the number of vars in the remaining arguments | ||
* and return the result. | ||
* | ||
* @tparam VecContainer std::vector holding arguments which contain Vars | ||
* @tparam Pargs Types of remaining arguments | ||
* @param[in] count The current count of the number of vars | ||
* @param[in] x A vector holding containers of vars | ||
* @param[in] args objects to be forwarded to recursive call of | ||
* `count_vars_impl` | ||
*/ | ||
template <typename VecContainer, require_std_vector_st<is_var, VecContainer>*, | ||
require_std_vector_vt<is_container, VecContainer>*, typename... Pargs> | ||
inline size_t count_vars_impl(size_t count, VecContainer&& x, Pargs&&... args) { | ||
for (auto&& x_iter : x) { | ||
count = count_vars_impl(count, x_iter); | ||
} | ||
return count_vars_impl(count, std::forward<Pargs>(args)...); | ||
} | ||
|
||
/** | ||
* Count the number of vars in x (an eigen container), | ||
* add it to the running total, | ||
* count the number of vars in the remaining arguments | ||
* and return the result. | ||
* | ||
* @tparam EigT A type derived from `EigenBase` | ||
* @tparam Pargs Types of remaining arguments | ||
* @param[in] count The current count of the number of vars | ||
* @param[in] x An Eigen container holding vars | ||
* @param[in] args objects to be forwarded to recursive call of | ||
* `count_vars_impl` | ||
*/ | ||
template <typename EigT, require_eigen_vt<is_var, EigT>*, typename... Pargs> | ||
inline size_t count_vars_impl(size_t count, EigT&& x, Pargs&&... args) { | ||
return count_vars_impl(count + x.size(), std::forward<Pargs>(args)...); | ||
} | ||
|
||
/** | ||
* Add one to the running total number of vars, | ||
* count the number of vars in the remaining arguments | ||
* and return the result. | ||
* | ||
* @tparam Pargs Types of remaining arguments | ||
* @param[in] count The current count of the number of vars | ||
* @param[in] x A var | ||
* @param[in] args objects to be forwarded to recursive call of | ||
* `count_vars_impl` | ||
*/ | ||
template <typename... Pargs> | ||
inline size_t count_vars_impl(size_t count, const var& x, Pargs&&... args) { | ||
return count_vars_impl(count + 1, std::forward<Pargs>(args)...); | ||
} | ||
|
||
/** | ||
* Arguments without vars contribute zero to the total number of vars. | ||
* | ||
* Return the running total number of vars plus the number of | ||
* vars in the remaining aruments. | ||
* | ||
* @tparam Arith An object that is either arithmetic or holds arithmetic | ||
* types | ||
* @tparam Pargs Types of remaining arguments | ||
* @param[in] count The current count of the number of vars | ||
* @param[in] x An arithmetic value or container | ||
* @param[in] args objects to be forwarded to recursive call of | ||
* `count_vars_impl` | ||
*/ | ||
template <typename Arith, require_arithmetic_t<scalar_type_t<Arith>>*, | ||
typename... Pargs> | ||
inline size_t count_vars_impl(size_t count, Arith& x, Pargs&&... args) { | ||
return count_vars_impl(count, std::forward<Pargs>(args)...); | ||
} | ||
|
||
/** | ||
* End count_vars_impl recursion and return total number of counted vars | ||
*/ | ||
inline size_t count_vars_impl(size_t count) { return count; } | ||
|
||
/** | ||
* Count the number of vars in the input argument list | ||
* | ||
* @tparam Pargs Types of input arguments | ||
* @return Number of vars in input | ||
*/ | ||
template <typename... Pargs> | ||
inline size_t count_vars(Pargs&&... args) { | ||
return count_vars_impl(0, std::forward<Pargs>(args)...); | ||
} | ||
|
||
} // namespace math | ||
} // namespace stan | ||
|
||
#endif |
Oops, something went wrong.
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Sorry for bringing this up somewhat late.... but shouldn't we place the "*_impl" stuff into an "internal" namespace? This applies to the other tools as well which we are writing with an impl naming.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Oh... count_vars is the only one with a "impl" thing - for good reasons as here you have to initialise to zero. So, I think moving the "*_impl" functions into an internal namespace, makes sense to me. Right?