Skip to content

Commit

Permalink
Merge pull request #1797 from bstatcomp/cl_kernel_generator_broadcast
Browse files Browse the repository at this point in the history
Add broadcasting to kernel generator
  • Loading branch information
t4c1 authored Apr 2, 2020
2 parents ad4600f + 4154ff4 commit b225daf
Show file tree
Hide file tree
Showing 9 changed files with 424 additions and 14 deletions.
1 change: 1 addition & 0 deletions stan/math/opencl/kernel_generator.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -122,6 +122,7 @@
#include <stan/math/opencl/kernel_generator/rowwise_reduction.hpp>
#include <stan/math/opencl/kernel_generator/colwise_reduction.hpp>
#include <stan/math/opencl/kernel_generator/transpose.hpp>
#include <stan/math/opencl/kernel_generator/broadcast.hpp>

#include <stan/math/opencl/kernel_generator/multi_result_kernel.hpp>
#include <stan/math/opencl/kernel_generator/get_kernel_source_for_evaluating_into.hpp>
Expand Down
9 changes: 3 additions & 6 deletions stan/math/opencl/kernel_generator/block.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -184,19 +184,16 @@ class block_
* @return index of bottom diagonal
*/
inline int bottom_diagonal() const {
// NOTE(review): this page renders a diff without +/- markers, so two versions
// of the body appear back to back. The hunk header reports "3 additions & 6
// deletions" for block.hpp, which suggests the std::max form directly below is
// the removed one - confirm against the repository.
// This form clamps the shifted diagonal index from below at 1 - rows_.
return std::max(
this->template get_arg<0>().bottom_diagonal() - start_col_ + start_row_,
1 - rows_);
// Apparent replacement: the argument's bottom diagonal shifted by the block
// offset (- start_col_ + start_row_), with no clamping in this method.
return this->template get_arg<0>().bottom_diagonal() - start_col_
+ start_row_;
}

/**
* Determine index of top diagonal written.
* @return index of top diagonal
*/
inline int top_diagonal() const {
// NOTE(review): this page renders a diff without +/- markers, so two versions
// of the body appear back to back. The hunk header reports "3 additions & 6
// deletions" for block.hpp, which suggests the std::min form directly below is
// the removed one - confirm against the repository.
// This form clamps the shifted diagonal index from above at cols_ - 1.
return std::min(
this->template get_arg<0>().top_diagonal() - start_col_ + start_row_,
cols_ - 1);
// Apparent replacement: the argument's top diagonal shifted by the block
// offset (- start_col_ + start_row_), with no clamping in this method.
return this->template get_arg<0>().top_diagonal() - start_col_ + start_row_;
}

/**
Expand Down
208 changes: 208 additions & 0 deletions stan/math/opencl/kernel_generator/broadcast.hpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,208 @@
#ifndef STAN_MATH_OPENCL_KERNEL_GENERATOR_BROADCAST_HPP
#define STAN_MATH_OPENCL_KERNEL_GENERATOR_BROADCAST_HPP
#ifdef STAN_OPENCL

#include <stan/math/opencl/matrix_cl_view.hpp>
#include <stan/math/prim/meta.hpp>
#include <stan/math/opencl/kernel_generator/type_str.hpp>
#include <stan/math/opencl/kernel_generator/name_generator.hpp>
#include <stan/math/opencl/kernel_generator/operation_cl.hpp>
#include <stan/math/opencl/kernel_generator/as_operation_cl.hpp>
#include <stan/math/opencl/kernel_generator/is_valid_expression.hpp>
#include <limits>
#include <string>
#include <type_traits>
#include <set>
#include <utility>

namespace stan {
namespace math {

/**
 * Represents a broadcasting operation in kernel generator expressions.
 *
 * The operation emits no kernel code of its own: it reuses the variable of
 * its argument and pins the index along each broadcast dimension to 0, so the
 * single row and/or column of the argument is read for every requested index.
 *
 * @tparam T type of the argument expression
 * @tparam Colwise whether to broadcast colwise (argument must have one row)
 * @tparam Rowwise whether to broadcast rowwise (argument must have one column)
 */
template <typename T, bool Colwise, bool Rowwise>
class broadcast_
    : public operation_cl<broadcast_<T, Colwise, Rowwise>,
                          typename std::remove_reference_t<T>::Scalar, T> {
 public:
  using Scalar = typename std::remove_reference_t<T>::Scalar;
  using base = operation_cl<broadcast_<T, Colwise, Rowwise>, Scalar, T>;
  using base::var_name;

  /**
   * Constructor. Verifies that the argument has size 1 along every
   * broadcast dimension.
   * @param a expression
   * @throw whatever check_size_match raises when the argument has more than
   * one row (Colwise) or more than one column (Rowwise)
   */
  explicit broadcast_(T&& a) : base(std::forward<T>(a)) {
    const char* function = "broadcast";
    // `a` has already been forwarded into the base class, so for a
    // non-reference T it may be in a moved-from state. Validate the argument
    // the base now stores instead of touching `a` again.
    if (Colwise) {
      check_size_match(function, "Rows of ", "a",
                       this->template get_arg<0>().rows(), "", "", 1);
    }
    if (Rowwise) {
      check_size_match(function, "Columns of ", "a",
                       this->template get_arg<0>().cols(), "", "", 1);
    }
  }

  /**
   * Creates a deep copy of this expression.
   * @return copy of \c *this
   */
  inline auto deep_copy() {
    auto&& arg_copy = this->template get_arg<0>().deep_copy();
    return broadcast_<std::remove_reference_t<decltype(arg_copy)>, Colwise,
                      Rowwise>{std::move(arg_copy)};
  }

  /**
   * Generates kernel code for this expression. Broadcasting produces no code
   * itself; this expression takes over the variable name of its argument.
   * @param i row index variable name (unused here)
   * @param j column index variable name (unused here)
   * @param var_name_arg variable name of the argument expression (unused
   * here; the name is read from the argument directly)
   * @return empty kernel parts
   */
  inline kernel_parts generate(const std::string& i, const std::string& j,
                               const std::string& var_name_arg) const {
    var_name = this->template get_arg<0>().var_name;
    return {};
  }

  /**
   * Sets index/indices along broadcasted dimension(s) to 0.
   * @param[in, out] i row index
   * @param[in, out] j column index
   */
  inline void modify_argument_indices(std::string& i, std::string& j) const {
    if (Colwise) {
      i = "0";
    }
    if (Rowwise) {
      j = "0";
    }
  }

  /**
   * Number of rows of a matrix that would be the result of evaluating this
   * expression.
   * @return \c base::dynamic when broadcasting colwise, otherwise the number
   * of rows of the argument
   */
  inline int rows() const {
    return Colwise ? base::dynamic : this->template get_arg<0>().rows();
  }

  /**
   * Number of columns of a matrix that would be the result of evaluating this
   * expression.
   * @return \c base::dynamic when broadcasting rowwise, otherwise the number
   * of columns of the argument
   */
  inline int cols() const {
    return Rowwise ? base::dynamic : this->template get_arg<0>().cols();
  }

  /**
   * View of a matrix that would be the result of evaluating this expression.
   * Starts from the view of the argument and widens it with \c Lower when
   * broadcasting colwise and with \c Upper when broadcasting rowwise.
   * @return view
   */
  inline matrix_cl_view view() const {
    matrix_cl_view view = this->template get_arg<0>().view();
    if (Colwise) {
      view = either(view, matrix_cl_view::Lower);
    }
    if (Rowwise) {
      view = either(view, matrix_cl_view::Upper);
    }
    return view;
  }

  /**
   * Determine index of bottom diagonal written.
   * @return smallest \c int when broadcasting colwise, otherwise the bottom
   * diagonal of the argument
   */
  inline int bottom_diagonal() const {
    if (Colwise) {
      return std::numeric_limits<int>::min();
    } else {
      return this->template get_arg<0>().bottom_diagonal();
    }
  }

  /**
   * Determine index of top diagonal written.
   * @return largest \c int when broadcasting rowwise, otherwise the top
   * diagonal of the argument
   */
  inline int top_diagonal() const {
    if (Rowwise) {
      return std::numeric_limits<int>::max();
    } else {
      return this->template get_arg<0>().top_diagonal();
    }
  }
};

/**
 * Wraps an expression in a broadcast over the selected dimension(s).
 *
 * A rowwise broadcast requires an argument with exactly one column; a colwise
 * broadcast requires an argument with exactly one row. Expressions built on
 * top of the result may treat it as having any size along the broadcast
 * dimension(s); the values are repeated.
 *
 * The argument expression is evaluated once per use, i.e. multiple times.
 * Avoid broadcasting expensive operations; evaluate those in a separate
 * kernel first.
 * @tparam Colwise whether to broadcast colwise
 * @tparam Rowwise whether to broadcast rowwise
 * @tparam T type of input expression
 * @param a input expression
 * @return broadcast expression
 */
template <bool Colwise, bool Rowwise, typename T,
          typename = require_all_valid_expressions_and_none_scalar_t<T>>
inline broadcast_<as_operation_cl_t<T>, Colwise, Rowwise> broadcast(T&& a) {
  using result_t = broadcast_<as_operation_cl_t<T>, Colwise, Rowwise>;
  // Convert the input to a kernel generator operation and take a deep copy of
  // it before handing it to the broadcast node.
  auto&& operand = as_operation_cl(std::forward<T>(a)).deep_copy();
  return result_t(std::move(operand));
}

/**
 * Broadcasts an expression along the rowwise dimension. The argument must
 * have a single column. Expressions built on top of the result may treat it
 * as having any number of columns; the values are repeated.
 *
 * The argument expression is evaluated multiple times. Avoid broadcasting
 * expensive operations; evaluate those in a separate kernel first.
 * @tparam T type of input expression
 * @param a input expression
 * @return broadcast expression
 */
template <typename T,
          typename = require_all_valid_expressions_and_none_scalar_t<T>>
inline auto rowwise_broadcast(T&& a) {
  constexpr bool broadcast_colwise = false;
  constexpr bool broadcast_rowwise = true;
  return broadcast<broadcast_colwise, broadcast_rowwise>(std::forward<T>(a));
}

/**
 * Broadcasts an expression along the colwise dimension. The argument must
 * have a single row. Expressions built on top of the result may treat it as
 * having any number of rows; the values are repeated.
 *
 * The argument expression is evaluated multiple times. Avoid broadcasting
 * expensive operations; evaluate those in a separate kernel first.
 * @tparam T type of input expression
 * @param a input expression
 * @return broadcast expression
 */
template <typename T,
          typename = require_all_valid_expressions_and_none_scalar_t<T>>
inline auto colwise_broadcast(T&& a) {
  constexpr bool broadcast_colwise = true;
  constexpr bool broadcast_rowwise = false;
  return broadcast<broadcast_colwise, broadcast_rowwise>(std::forward<T>(a));
}

} // namespace math
} // namespace stan
#endif
#endif
6 changes: 3 additions & 3 deletions stan/math/opencl/kernel_generator/load.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,7 @@ class load_
* Creates a deep copy of this expression.
* @return copy of \c *this
*/
// NOTE(review): diff view without +/- markers - the next two lines appear to
// be the old and new spelling of the same lvalue overload (`const &` vs
// `const&`); they differ only in whitespace. Confirm against the repository.
// Lvalue overload: returns a load_<T&> constructed from a_.
inline load_<T&> deep_copy() const & { return load_<T&>(a_); }
inline load_<T&> deep_copy() const& { return load_<T&>(a_); }
// Rvalue overload: forwards a_ into a new load_<T>.
inline load_<T> deep_copy() && { return load_<T>(std::forward<T>(a_)); }

/**
Expand Down Expand Up @@ -181,12 +181,12 @@ class load_
*/
inline void set_view(int bottom_diagonal, int top_diagonal,
int bottom_zero_diagonal, int top_zero_diagonal) const {
if (bottom_diagonal < 0) {
if (bottom_zero_diagonal <= top_diagonal && bottom_diagonal < 0) {
a_.view(either(a_.view(), matrix_cl_view::Lower));
} else if (bottom_zero_diagonal <= 1 - a_.rows()) {
a_.view(both(a_.view(), matrix_cl_view::Upper));
}
if (top_diagonal > 0) {
if (top_zero_diagonal >= bottom_diagonal && top_diagonal > 0) {
a_.view(either(a_.view(), matrix_cl_view::Upper));
} else if (top_zero_diagonal >= a_.cols() - 1) {
a_.view(both(a_.view(), matrix_cl_view::Lower));
Expand Down
13 changes: 9 additions & 4 deletions stan/math/opencl/kernel_generator/multi_result_kernel.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -2,13 +2,15 @@
#define STAN_MATH_OPENCL_KERNEL_GENERATOR_MULTI_RESULT_KERNEL_HPP
#ifdef STAN_OPENCL

#include <stan/math/prim/err.hpp>
#include <stan/math/opencl/kernel_generator/wrapper.hpp>
#include <stan/math/opencl/kernel_generator/is_valid_expression.hpp>
#include <stan/math/opencl/kernel_generator/name_generator.hpp>
#include <stan/math/opencl/kernel_generator/as_operation_cl.hpp>
#include <stan/math/opencl/kernel_generator/calc_if.hpp>
#include <stan/math/opencl/kernel_generator/load.hpp>
#include <stan/math/opencl/opencl_context.hpp>
#include <algorithm>
#include <string>
#include <tuple>
#include <utility>
Expand Down Expand Up @@ -75,8 +77,11 @@ struct multi_result_kernel_internal {
"first expression", n_cols);
if (!is_without_output<T_current_expression>::value) {
result.check_assign_dimensions(expression.rows(), expression.cols());
result.set_view(expression.bottom_diagonal(), expression.top_diagonal(),
1 - expression.rows(), expression.cols() - 1);
int bottom_written = 1 - expression.rows();
int top_written = expression.cols() - 1;
result.set_view(std::max(expression.bottom_diagonal(), bottom_written),
std::min(expression.top_diagonal(), top_written),
bottom_written, top_written);
}
}

Expand Down Expand Up @@ -410,8 +415,8 @@ class results_cl {
if (n_rows * n_cols == 0) {
return;
}
check_positive(function, "number of rows", n_rows);
check_positive(function, "number of columns", n_cols);
check_nonnegative(function, "expr.rows()", n_rows);
check_nonnegative(function, "expr.cols()", n_cols);

try {
if (impl::kernel_() == NULL) {
Expand Down
7 changes: 6 additions & 1 deletion stan/math/opencl/kernel_generator/operation_cl.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
#ifdef STAN_OPENCL

#include <stan/math/prim/meta.hpp>
#include <stan/math/prim/err.hpp>
#include <stan/math/opencl/kernel_generator/wrapper.hpp>
#include <stan/math/opencl/kernel_generator/type_str.hpp>
#include <stan/math/opencl/kernel_generator/name_generator.hpp>
Expand Down Expand Up @@ -117,7 +118,11 @@ class operation_cl : public operation_cl_base {
* @return Result of the expression.
*/
matrix_cl<Scalar> eval() const {
matrix_cl<Scalar> res(derived().rows(), derived().cols(), derived().view());
int rows = derived().rows();
int cols = derived().cols();
check_nonnegative("operation_cl.eval", "this->rows()", rows);
check_nonnegative("operation_cl.eval", "this->cols()", cols);
matrix_cl<Scalar> res(rows, cols, derived().view());
if (res.size() > 0) {
this->evaluate_into(res);
}
Expand Down
15 changes: 15 additions & 0 deletions test/unit/math/opencl/kernel_generator/block_test.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -126,4 +126,19 @@ TEST(KernelGenerator, two_blocks_of_same_expression) {
EXPECT_MATRIX_NEAR(res, correct, 1e-9);
}

// Checks the view reported for 2x2 blocks taken from a 4x4 matrix that is
// constructed with a Diagonal view.
// NOTE(review): the block() calls are read here as
// (matrix, start_row, start_col, rows, cols) - confirm against block.hpp.
// `res` is reused through assignment, so the statement order is kept as is in
// case a later view depends on the state left by an earlier assignment.
TEST(MathMatrixCL, block_view_test) {
using stan::math::block;
matrix_cl<double> m(4, 4, stan::math::matrix_cl_view::Diagonal);
// Block anchored on the main diagonal: expected to stay Diagonal.
matrix_cl<double> res = block(m, 0, 0, 2, 2);
EXPECT_EQ(res.view(), stan::math::matrix_cl_view::Diagonal);
// Block shifted down one row: expected view is Upper.
res = block(m, 1, 0, 2, 2);
EXPECT_EQ(res.view(), stan::math::matrix_cl_view::Upper);
// Block shifted right one column: expected view is Lower.
res = block(m, 0, 1, 2, 2);
EXPECT_EQ(res.view(), stan::math::matrix_cl_view::Lower);
// Blocks shifted by two columns / two rows: expected view is Diagonal.
res = block(m, 0, 2, 2, 2);
EXPECT_EQ(res.view(), stan::math::matrix_cl_view::Diagonal);
res = block(m, 2, 0, 2, 2);
EXPECT_EQ(res.view(), stan::math::matrix_cl_view::Diagonal);
}

#endif
Loading

0 comments on commit b225daf

Please sign in to comment.