From c69676c0aae9fb28e1d8263687671e7aaffea06d Mon Sep 17 00:00:00 2001 From: Amit Singh Date: Sun, 13 Feb 2022 16:12:08 +0530 Subject: [PATCH] refactor(expression): add function cast_tensor_expression for casting This function casts any `tensor_expression` to its child class, and it also handles recursive casting to get the real expression that is stored inside the layers of `tensor_expression`. --- .../boost/numeric/ublas/tensor/expression.hpp | 68 ++++++++++--------- .../ublas/tensor/expression_evaluation.hpp | 50 +++++++++----- 2 files changed, 69 insertions(+), 49 deletions(-) diff --git a/include/boost/numeric/ublas/tensor/expression.hpp b/include/boost/numeric/ublas/tensor/expression.hpp index d33c385a3..5c20f3123 100644 --- a/include/boost/numeric/ublas/tensor/expression.hpp +++ b/include/boost/numeric/ublas/tensor/expression.hpp @@ -43,6 +43,32 @@ static constexpr bool does_exp_need_cast_v = does_exp_need_cast< std::decay_t template struct does_exp_need_cast< tensor_expression > : std::true_type{}; +/** + * @brief It is a safer way of casting `tensor_expression` because it handles + * recursive expressions. Otherwise, in most of the cases, we try to access + * `operator()`, which requires a parameter argument, that is not supported + * by the `tensor_expression` class and might give an error if it is not casted + * properly. + * + * @tparam T type of the tensor + * @tparam E type of the child stored inside tensor_expression + * @param e tensor_expression that needs to be casted + * @return child of tensor_expression that is not tensor_expression + */ +template +constexpr auto const& cast_tensor_exression(tensor_expression const& e) noexcept{ + auto const& res = e(); + if constexpr(does_exp_need_cast_v) + return cast_tensor_exression(res); + else + return res; +} + + +/// @brief Any expression other than `tensor_expression` +template +constexpr auto const& cast_tensor_exression(E const& e) noexcept{ return e; } + template constexpr auto is_tensor_expression_impl(tensor_expression const*) -> std::true_type; @@ -137,33 +163,15 @@ struct binary_tensor_expression binary_tensor_expression(const binary_tensor_expression& l) = delete; binary_tensor_expression& operator=(binary_tensor_expression const& l) noexcept = delete; + constexpr auto const& left_expr() const noexcept{ return cast_tensor_exression(el); } + constexpr auto const& right_expr() const noexcept{ return cast_tensor_exression(er); } [[nodiscard]] inline - constexpr decltype(auto) operator()(size_type i) const - requires (does_exp_need_cast_v && does_exp_need_cast_v) - { - return op(el()(i), er()(i)); - } - - [[nodiscard]] inline - constexpr decltype(auto) operator()(size_type i) const - requires (does_exp_need_cast_v && !does_exp_need_cast_v) - { - return op(el()(i), er(i)); - } - - [[nodiscard]] inline - constexpr decltype(auto) operator()(size_type i) const - requires (!does_exp_need_cast_v && does_exp_need_cast_v) - { - return op(el(i), er()(i)); - } - - [[nodiscard]] inline - constexpr decltype(auto) operator()(size_type i) const { - return op(el(i), er(i)); + constexpr decltype(auto) operator()(size_type i) const { + return op(left_expr()(i), right_expr()(i)); } +private: expression_type_left el; expression_type_right er; binary_operation op; @@ -211,19 +219,15 @@ struct unary_tensor_expression constexpr unary_tensor_expression() = delete; unary_tensor_expression(unary_tensor_expression const& l) = delete; unary_tensor_expression& operator=(unary_tensor_expression const& l) noexcept = delete; - - [[nodiscard]] inline constexpr - decltype(auto) operator()(size_type i) const - requires does_exp_need_cast_v - { - return op(e()(i)); - } + + constexpr auto const& expr() const noexcept{ return cast_tensor_exression(e); } [[nodiscard]] inline constexpr - decltype(auto) operator()(size_type i) const { - return op(e(i)); + decltype(auto) operator()(size_type i) const { + return op(expr()(i)); } +private: expression_type e; unary_operation op; }; diff --git a/include/boost/numeric/ublas/tensor/expression_evaluation.hpp b/include/boost/numeric/ublas/tensor/expression_evaluation.hpp index 37e9f1e48..b18203ce2 100644 --- a/include/boost/numeric/ublas/tensor/expression_evaluation.hpp +++ b/include/boost/numeric/ublas/tensor/expression_evaluation.hpp @@ -134,17 +134,20 @@ constexpr auto& retrieve_extents(binary_tensor_expression const& exp static_assert(has_tensor_types_v>, "Error in boost::numeric::ublas::detail::retrieve_extents: Expression to evaluate should contain tensors."); + auto const& lexpr = expr.left_expr(); + auto const& rexpr = expr.right_expr(); + if constexpr ( same_exp ) - return expr.el.extents(); + return lexpr.extents(); else if constexpr ( same_exp ) - return expr.er.extents(); + return rexpr.extents(); else if constexpr ( has_tensor_types_v ) - return retrieve_extents(expr.el); + return retrieve_extents(lexpr); else if constexpr ( has_tensor_types_v ) - return retrieve_extents(expr.er); + return retrieve_extents(rexpr); } #ifdef _MSC_VER @@ -164,12 +167,14 @@ constexpr auto& retrieve_extents(unary_tensor_expression const& expr) static_assert(has_tensor_types_v>, "Error in boost::numeric::ublas::detail::retrieve_extents: Expression to evaluate should contain tensors."); + + auto const& uexpr = expr.expr(); if constexpr ( same_exp ) - return expr.e.extents(); + return uexpr.extents(); else if constexpr ( has_tensor_types_v ) - return retrieve_extents(expr.e); + return retrieve_extents(uexpr); } } // namespace boost::numeric::ublas::detail @@ -221,20 +226,23 @@ constexpr auto all_extents_equal(binary_tensor_expression const& exp using ::operator==; using ::operator!=; + auto const& lexpr = expr.left_expr(); + auto const& rexpr = expr.right_expr(); + if constexpr ( same_exp ) - if(e != expr.el.extents()) + if(e != lexpr.extents()) return false; if constexpr ( same_exp ) - if(e != expr.er.extents()) + if(e != rexpr.extents()) return false; if constexpr ( has_tensor_types_v ) - if(!all_extents_equal(expr.el, e)) + if(!all_extents_equal(lexpr, e)) return false; if constexpr ( has_tensor_types_v ) - if(!all_extents_equal(expr.er, e)) + if(!all_extents_equal(rexpr, e)) return false; return true; @@ -250,12 +258,14 @@ constexpr auto all_extents_equal(unary_tensor_expression const& expr, ex using ::operator==; + auto const& uexpr = expr.expr(); + if constexpr ( same_exp ) - if(e != expr.e.extents()) + if(e != uexpr.extents()) return false; if constexpr ( has_tensor_types_v ) - if(!all_extents_equal(expr.e, e)) + if(!all_extents_equal(uexpr, e)) return false; return true; @@ -281,9 +291,11 @@ inline void eval(tensor_type& lhs, tensor_expression if(!all_extents_equal(expr, lhs.extents() )) throw std::runtime_error("Error in boost::numeric::ublas::tensor_core: expression contains tensors with different shapes."); -#pragma omp parallel for + auto const& rhs = cast_tensor_exression(expr); + + #pragma omp parallel for for(auto i = 0u; i < lhs.size(); ++i) - lhs(i) = expr()(i); + lhs(i) = rhs(i); } /** @brief Evaluates expression for a tensor_core @@ -310,9 +322,11 @@ inline void eval(tensor_type& lhs, tensor_expression if(!all_extents_equal( expr, lhs.extents() )) throw std::runtime_error("Error in boost::numeric::ublas::tensor_core: expression contains tensors with different shapes."); + auto const& rhs = cast_tensor_exression(expr); + #pragma omp parallel for for(auto i = 0u; i < lhs.size(); ++i) - fn(lhs(i), expr()(i)); + fn(lhs(i), rhs(i)); } @@ -347,7 +363,7 @@ inline void eval(tensor_type& lhs, tensor_expression template inline void eval(tensor_type& lhs, unary_fn const& fn) { -#pragma omp parallel for + #pragma omp parallel for for(auto i = 0u; i < lhs.size(); ++i) fn(lhs(i)); }