Skip to content

Commit

Permalink
Merge branch 'update/eigen-3.4' of github.com:stan-dev/math into upda…
Browse files Browse the repository at this point in the history
…te/eigen-3.4
  • Loading branch information
SteveBronder committed Dec 8, 2021
2 parents 1c90adf + d9e1b11 commit fe4b1dd
Show file tree
Hide file tree
Showing 3 changed files with 28 additions and 14 deletions.
22 changes: 13 additions & 9 deletions stan/math/fwd/fun/mdivide_right.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -54,9 +54,9 @@ mdivide_right(const EigMat1& b, const EigMat2& A) {
}
}
auto A_mult_inv_b = mdivide_right(val_b, val_A).eval();
promote_scalar_t<fvar<inner_ret_t>, decltype(A_mult_inv_b)> ret(A_mult_inv_b.rows(), A_mult_inv_b.cols());
ret.val() = A_mult_inv_b;
ret.d() = mdivide_right(deriv_b, val_A)
promote_scalar_t<fvar<inner_ret_t>, decltype(A_mult_inv_b)>
ret(A_mult_inv_b.rows(), A_mult_inv_b.cols()); ret.val() = A_mult_inv_b; ret.d()
= mdivide_right(deriv_b, val_A)
- multiply(A_mult_inv_b, mdivide_right(deriv_A, val_A));
return ret;
}
Expand All @@ -70,10 +70,12 @@ template <typename EigMat1, typename EigMat2,
check_square("mdivide_right", "A", A);
check_multiplicable("mdivide_right", "b", b, "A", A);
if (A.size() == 0) {
using ret_type = decltype(A.transpose().template cast<T_return>().lu().solve(b.template cast<T_return>().transpose()).transpose().eval());
return ret_type{b.rows(), 0};
using ret_type = decltype(A.transpose().template
cast<T_return>().lu().solve(b.template
cast<T_return>().transpose()).transpose().eval()); return ret_type{b.rows(), 0};
}
return A.transpose().template cast<T_return>().lu().solve(b.template cast<T_return>().transpose()).transpose().eval();
return A.transpose().template cast<T_return>().lu().solve(b.template
cast<T_return>().transpose()).transpose().eval();
}
template <typename EigMat1, typename EigMat2,
Expand All @@ -85,10 +87,12 @@ template <typename EigMat1, typename EigMat2,
check_square("mdivide_right", "A", A);
check_multiplicable("mdivide_right", "b", b, "A", A);
if (A.size() == 0) {
using ret_type = decltype(A.transpose().template cast<T_return>().lu().solve(b.template cast<T_return>().transpose()).transpose().eval());
return ret_type{b.rows(), 0};
using ret_type = decltype(A.transpose().template
cast<T_return>().lu().solve(b.template
cast<T_return>().transpose()).transpose().eval()); return ret_type{b.rows(), 0};
}
return A.transpose().template cast<T_return>().lu().solve(b.template cast<T_return>().transpose()).transpose().eval();
return A.transpose().template cast<T_return>().lu().solve(b.template
cast<T_return>().transpose()).transpose().eval();
}
*/
} // namespace math
Expand Down
18 changes: 14 additions & 4 deletions stan/math/prim/fun/mdivide_right.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -22,16 +22,26 @@ namespace math {
*/
template <typename EigMat1, typename EigMat2,
require_all_eigen_t<EigMat1, EigMat2>* = nullptr>
inline auto
mdivide_right(const EigMat1& b, const EigMat2& A) {
inline auto mdivide_right(const EigMat1& b, const EigMat2& A) {
using T_return = return_type_t<EigMat1, EigMat2>;
check_square("mdivide_right", "A", A);
check_multiplicable("mdivide_right", "b", b, "A", A);
if (A.size() == 0) {
using ret_type = decltype(A.transpose().template cast<T_return>().lu().solve(b.template cast<T_return>().transpose()).transpose().eval());
using ret_type
= decltype(A.transpose()
.template cast<T_return>()
.lu()
.solve(b.template cast<T_return>().transpose())
.transpose()
.eval());
return ret_type{b.rows(), 0};
}
return A.transpose().template cast<T_return>().lu().solve(b.template cast<T_return>().transpose()).transpose().eval();
return A.transpose()
.template cast<T_return>()
.lu()
.solve(b.template cast<T_return>().transpose())
.transpose()
.eval();
}

} // namespace math
Expand Down
2 changes: 1 addition & 1 deletion test/unit/math/mix/fun/mdivide_right_test.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -74,7 +74,7 @@ TEST(MathMixMatFun, mdivideRight_rowvector_matrix1) {
Eigen::RowVectorXd g(2);
g << 1, 1;

stan::test::expect_ad(f, g, b);
stan::test::expect_ad(f, g, b);
// vector, matrix
/*
for (const auto& m : std::vector<Eigen::MatrixXd>{b}) {
Expand Down

0 comments on commit fe4b1dd

Please sign in to comment.