Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Adds auxilary functions needed for reduce_sum #1800

Merged
merged 13 commits into from
Mar 31, 2020
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions stan/math/rev/core.hpp
Original file line number Diff line number Diff line change
@@ -1,13 +1,16 @@
#ifndef STAN_MATH_REV_CORE_HPP
#define STAN_MATH_REV_CORE_HPP

#include <stan/math/rev/core/accumulate_adjoints.hpp>
#include <stan/math/rev/core/autodiffstackstorage.hpp>
#include <stan/math/rev/core/build_vari_array.hpp>
#include <stan/math/rev/core/chainable_alloc.hpp>
#include <stan/math/rev/core/chainablestack.hpp>
#include <stan/math/rev/core/count_vars.hpp>
#include <stan/math/rev/core/init_chainablestack.hpp>
#include <stan/math/rev/core/std_iterator_traits.hpp>
#include <stan/math/rev/core/ddv_vari.hpp>
#include <stan/math/rev/core/deep_copy_vars.hpp>
#include <stan/math/rev/core/dv_vari.hpp>
#include <stan/math/rev/core/dvd_vari.hpp>
#include <stan/math/rev/core/dvv_vari.hpp>
Expand Down Expand Up @@ -63,5 +66,6 @@
#include <stan/math/rev/core/vv_vari.hpp>
#include <stan/math/rev/core/vvd_vari.hpp>
#include <stan/math/rev/core/vvv_vari.hpp>
#include <stan/math/rev/core/save_varis.hpp>

#endif
154 changes: 154 additions & 0 deletions stan/math/rev/core/accumulate_adjoints.hpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,154 @@
#ifndef STAN_MATH_REV_CORE_ACCUMULATE_ADJOINTS_HPP
#define STAN_MATH_REV_CORE_ACCUMULATE_ADJOINTS_HPP

#include <stan/math/prim/meta.hpp>
#include <stan/math/rev/meta.hpp>
#include <stan/math/rev/core/var.hpp>

#include <utility>
#include <vector>

namespace stan {
namespace math {

template <typename... Pargs>
inline double* accumulate_adjoints(double* dest, const var& x, Pargs&&... args);

template <typename VarVec, require_std_vector_vt<is_var, VarVec>* = nullptr,
typename... Pargs>
inline double* accumulate_adjoints(double* dest, VarVec&& x, Pargs&&... args);

template <typename VecContainer,
require_std_vector_st<is_var, VecContainer>* = nullptr,
require_std_vector_vt<is_container, VecContainer>* = nullptr,
typename... Pargs>
inline double* accumulate_adjoints(double* dest, VecContainer&& x,
Pargs&&... args);

template <typename EigT, require_eigen_vt<is_var, EigT>* = nullptr,
typename... Pargs>
inline double* accumulate_adjoints(double* dest, EigT&& x, Pargs&&... args);

template <typename Arith, require_arithmetic_t<scalar_type_t<Arith>>* = nullptr,
typename... Pargs>
inline double* accumulate_adjoints(double* dest, Arith&& x, Pargs&&... args);

inline double* accumulate_adjoints(double* dest);

/**
* Accumulate adjoints from x into storage pointed to by dest,
* increment the adjoint storage pointer,
* recursively accumulate the adjoints of the rest of the arguments,
* and return final position of storage pointer.
*
* @tparam Pargs Types of remaining arguments
* @param dest Pointer to where adjoints are to be accumulated
* @param x A var
* @param args Further args to accumulate over
* @return Final position of adjoint storage pointer
*/
template <typename... Pargs>
inline double* accumulate_adjoints(double* dest, const var& x,
Pargs&&... args) {
*dest += x.adj();
return accumulate_adjoints(dest + 1, std::forward<Pargs>(args)...);
}

/**
* Accumulate adjoints from std::vector x into storage pointed to by dest,
* increment the adjoint storage pointer,
* recursively accumulate the adjoints of the rest of the arguments,
* and return final position of storage pointer.
*
* @tparam Pargs Types of remaining arguments
* @param dest Pointer to where adjoints are to be accumulated
* @param x A std::vector of vars
* @param args Further args to accumulate over
* @return Final position of adjoint storage pointer
*/
template <typename VarVec, require_std_vector_vt<is_var, VarVec>*,
typename... Pargs>
inline double* accumulate_adjoints(double* dest, VarVec&& x, Pargs&&... args) {
for (auto&& x_iter : x) {
*dest += x_iter.adj();
++dest;
}
return accumulate_adjoints(dest, std::forward<Pargs>(args)...);
}

/**
* Accumulate adjoints from x (a std::vector of containers containing vars)
* into storage pointed to by dest,
* increment the adjoint storage pointer,
* recursively accumulate the adjoints of the rest of the arguments,
* and return final position of storage pointer.
*
* @tparam VecContainer the type of a standard container holding var
* containers.
* @tparam Pargs Types of remaining arguments
* @param dest Pointer to where adjoints are to be accumulated
* @param x A std::vector of containers holding vars
* @param args Further args to accumulate over
* @return Final position of adjoint storage pointer
*/
template <typename VecContainer, require_std_vector_st<is_var, VecContainer>*,
require_std_vector_vt<is_container, VecContainer>*, typename... Pargs>
inline double* accumulate_adjoints(double* dest, VecContainer&& x,
Pargs&&... args) {
for (auto&& x_iter : x) {
dest = accumulate_adjoints(dest, x_iter);
}
return accumulate_adjoints(dest, std::forward<Pargs>(args)...);
}

/**
* Accumulate adjoints from x (an Eigen type containing vars)
* into storage pointed to by dest,
* increment the adjoint storage pointer,
* recursively accumulate the adjoints of the rest of the arguments,
* and return final position of storage pointer.
*
* @tparam EigT Type derived from `EigenBase` containing vars.
* @tparam Pargs Types of remaining arguments
* @param dest Pointer to where adjoints are to be accumulated
* @param x An eigen type holding vars to accumulate over
* @param args Further args to accumulate over
* @return Final position of adjoint storage pointer
*/
template <typename EigT, require_eigen_vt<is_var, EigT>*, typename... Pargs>
inline double* accumulate_adjoints(double* dest, EigT&& x, Pargs&&... args) {
Eigen::Map<Eigen::MatrixXd>(dest, x.rows(), x.cols()) += x.adj();
return accumulate_adjoints(dest + x.size(), std::forward<Pargs>(args)...);
}

/**
* Ignore arithmetic types.
*
* Recursively accumulate the adjoints of the rest of the arguments
* and return final position of adjoint storage pointer.
*
* @tparam Arith A type satisfying `std::is_arithmetic`.
* @tparam Pargs Types of remaining arguments
* @param dest Pointer to where adjoints are to be accumulated
* @param x An object that is either arithmetic or a container of Arithmetic
* types
* @param args Further args to accumulate over
* @return Final position of adjoint storage pointer
*/
template <typename Arith, require_arithmetic_t<scalar_type_t<Arith>>*,
typename... Pargs>
inline double* accumulate_adjoints(double* dest, Arith&& x, Pargs&&... args) {
return accumulate_adjoints(dest, std::forward<Pargs>(args)...);
}

/**
* End accumulate_adjoints recursion and return pointer
*
* @param dest Pointer
*/
inline double* accumulate_adjoints(double* dest) { return dest; }

} // namespace math
} // namespace stan

#endif
151 changes: 151 additions & 0 deletions stan/math/rev/core/count_vars.hpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,151 @@
#ifndef STAN_MATH_REV_CORE_COUNT_VARS_HPP
#define STAN_MATH_REV_CORE_COUNT_VARS_HPP

#include <stan/math/prim/meta.hpp>
#include <stan/math/rev/meta.hpp>
#include <stan/math/rev/core/var.hpp>

#include <utility>
#include <vector>

namespace stan {
namespace math {

template <typename VecVar, require_std_vector_vt<is_var, VecVar>* = nullptr,
typename... Pargs>
inline size_t count_vars_impl(size_t count, VecVar&& x, Pargs&&... args);
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Sorry for bringing this up somewhat late.... but shouldn't we place the "*_impl" stuff into an "internal" namespace? This applies to the other tools as well which we are writing with an impl naming.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Oh... count_vars is the only one with a "impl" thing - for good reasons as here you have to initialise to zero. So, I think moving the "*_impl" functions into an internal namespace, makes sense to me. Right?


template <typename VecContainer,
require_std_vector_st<is_var, VecContainer>* = nullptr,
require_std_vector_vt<is_container, VecContainer>* = nullptr,
typename... Pargs>
inline size_t count_vars_impl(size_t count, VecContainer&& x, Pargs&&... args);

template <typename EigT, require_eigen_vt<is_var, EigT>* = nullptr,
typename... Pargs>
inline size_t count_vars_impl(size_t count, EigT&& x, Pargs&&... args);

template <typename... Pargs>
inline size_t count_vars_impl(size_t count, const var& x, Pargs&&... args);

template <typename Arith, require_arithmetic_t<scalar_type_t<Arith>>* = nullptr,
typename... Pargs>
inline size_t count_vars_impl(size_t count, Arith& x, Pargs&&... args);

inline size_t count_vars_impl(size_t count);

/**
* Count the number of vars in x (a std::vector of vars),
* add it to the running total,
* count the number of vars in the remaining arguments
* and return the result.
*
* @tparam VecVar type of standard container holding vars
* @tparam Pargs Types of remaining arguments
* @param[in] count The current count of the number of vars
* @param[in] x A std::vector holding vars.
* @param[in] args objects to be forwarded to recursive call of
* `count_vars_impl`
*/
template <typename VecVar, require_std_vector_vt<is_var, VecVar>*,
typename... Pargs>
inline size_t count_vars_impl(size_t count, VecVar&& x, Pargs&&... args) {
return count_vars_impl(count + x.size(), std::forward<Pargs>(args)...);
}

/**
* Count the number of vars in x (a std::vector holding other containers),
* add it to the running total,
* count the number of vars in the remaining arguments
* and return the result.
*
* @tparam VecContainer std::vector holding arguments which contain Vars
* @tparam Pargs Types of remaining arguments
* @param[in] count The current count of the number of vars
* @param[in] x A vector holding containers of vars
* @param[in] args objects to be forwarded to recursive call of
* `count_vars_impl`
*/
template <typename VecContainer, require_std_vector_st<is_var, VecContainer>*,
require_std_vector_vt<is_container, VecContainer>*, typename... Pargs>
inline size_t count_vars_impl(size_t count, VecContainer&& x, Pargs&&... args) {
for (auto&& x_iter : x) {
count = count_vars_impl(count, x_iter);
}
return count_vars_impl(count, std::forward<Pargs>(args)...);
}

/**
* Count the number of vars in x (an eigen container),
* add it to the running total,
* count the number of vars in the remaining arguments
* and return the result.
*
* @tparam EigT A type derived from `EigenBase`
* @tparam Pargs Types of remaining arguments
* @param[in] count The current count of the number of vars
* @param[in] x An Eigen container holding vars
* @param[in] args objects to be forwarded to recursive call of
* `count_vars_impl`
*/
template <typename EigT, require_eigen_vt<is_var, EigT>*, typename... Pargs>
inline size_t count_vars_impl(size_t count, EigT&& x, Pargs&&... args) {
return count_vars_impl(count + x.size(), std::forward<Pargs>(args)...);
}

/**
* Add one to the running total number of vars,
* count the number of vars in the remaining arguments
* and return the result.
*
* @tparam Pargs Types of remaining arguments
* @param[in] count The current count of the number of vars
* @param[in] x A var
* @param[in] args objects to be forwarded to recursive call of
* `count_vars_impl`
*/
template <typename... Pargs>
inline size_t count_vars_impl(size_t count, const var& x, Pargs&&... args) {
return count_vars_impl(count + 1, std::forward<Pargs>(args)...);
}

/**
* Arguments without vars contribute zero to the total number of vars.
*
* Return the running total number of vars plus the number of
* vars in the remaining aruments.
*
* @tparam Arith An object that is either arithmetic or holds arithmetic
* types
* @tparam Pargs Types of remaining arguments
* @param[in] count The current count of the number of vars
* @param[in] x An arithmetic value or container
* @param[in] args objects to be forwarded to recursive call of
* `count_vars_impl`
*/
template <typename Arith, require_arithmetic_t<scalar_type_t<Arith>>*,
typename... Pargs>
inline size_t count_vars_impl(size_t count, Arith& x, Pargs&&... args) {
return count_vars_impl(count, std::forward<Pargs>(args)...);
}

/**
* End count_vars_impl recursion and return total number of counted vars
*/
inline size_t count_vars_impl(size_t count) { return count; }

/**
* Count the number of vars in the input argument list
*
* @tparam Pargs Types of input arguments
* @return Number of vars in input
*/
template <typename... Pargs>
inline size_t count_vars(Pargs&&... args) {
return count_vars_impl(0, std::forward<Pargs>(args)...);
}

} // namespace math
} // namespace stan

#endif
Loading