diff --git a/stan/math/fwd/core.hpp b/stan/math/fwd/core.hpp index 35219018993..543a8206bf4 100644 --- a/stan/math/fwd/core.hpp +++ b/stan/math/fwd/core.hpp @@ -17,6 +17,7 @@ #include #include #include +#include #include #include diff --git a/stan/math/fwd/core/std_complex.hpp b/stan/math/fwd/core/std_complex.hpp new file mode 100644 index 00000000000..b7b7263df88 --- /dev/null +++ b/stan/math/fwd/core/std_complex.hpp @@ -0,0 +1,68 @@ +#ifndef STAN_MATH_FWD_CORE_STD_COMPLEX_HPP +#define STAN_MATH_FWD_CORE_STD_COMPLEX_HPP + +#include +#include +#include +#include + +namespace std { + +/** + * Specialization of the standard libary complex number type for + * reverse-mode autodiff type `stan::math::fvar`. + * + * @tparam T forward-mode autodiff value type + */ +template +class complex> + : public stan::math::complex_base> { + public: + using base_t = stan::math::complex_base>; + + /** + * Construct a complex number with zero real and imaginary parts. + */ + complex() = default; + + /** + * Construct a complex number with the specified real part and a zero + * imaginary part. + * + * @tparam Scalar real type (must be assignable to `value_type`) + * @param[in] re real part + */ + template > + complex(U&& re) : base_t(re) {} // NOLINT(runtime/explicit) + + /** + * Construct a complex number from the specified real and imaginary + * parts. + * + * @tparam U type of real part + * @tparam V type of imaginary part + * @param[in] re real part + * @param[in] im imaginary part + */ + template + complex(const U& re, const V& im) : base_t(re, im) {} + + /** + * Set the real and imaginary parts to those of the specified + * complex number. + * + * @tparam U value type of argument + * @param[in] x complex number to set + * @return this + */ + template > + auto& operator=(const std::complex& x) { + this->re_ = x.real(); + this->im_ = x.imag(); + return *this; + } +}; + +} // namespace std + +#endif diff --git a/stan/math/prim/core/complex_base.hpp b/stan/math/prim/core/complex_base.hpp new file mode 100644 index 00000000000..4d21c88b585 --- /dev/null +++ b/stan/math/prim/core/complex_base.hpp @@ -0,0 +1,236 @@ +#ifndef STAN_MATH_PRIM_CORE_COMPLEX_BASE_HPP +#define STAN_MATH_PRIM_CORE_COMPLEX_BASE_HPP + +#include +#include +#include + +namespace stan { +namespace math { + +/** + * Base class for complex numbers. Extending classes must be of + * of the form `complex`. + * + * @tparam ValueType type of real and imaginary parts + */ +template +class complex_base { + public: + /** + * Type of real and imaginary parts + */ + using value_type = ValueType; + + /** + * Derived complex type used for function return types + */ + using complex_type = std::complex; + + /** + * Construct a complex base with zero real and imaginary parts. + */ + complex_base() = default; + + /** + * Construct a complex base with the specified real part and a zero + * imaginary part. + * + * @tparam U real type (assignable to `value_type`) + * @param[in] re real part + */ + template > + complex_base(const U& re) : re_(re) {} // NOLINT(runtime/explicit) + + /** + * Construct a complex base with the specified real and imaginary + * parts. + * + * @tparam U real type (assignable to `value_type`) + * @tparam V imaginary type (assignable to `value_type`) + * @param[in] re real part + * @param[in] im imaginary part + */ + template + complex_base(const U& re, const V& im) : re_(re), im_(im) {} + + /** + * Return the real part. + * + * @return real part + */ + value_type real() const { return re_; } + + /** + * Set the real part to the specified value. + * + * @param[in] re real part + */ + void real(const value_type& re) { re_ = re; } + + /** + * Return the imaginary part. + * + * @return imaginary part + */ + value_type imag() const { return im_; } + + /** + * Set the imaginary part to the specified value. + * + * @param[in] im imaginary part + */ + void imag(const value_type& im) { im_ = im; } + + /** + * Assign the specified value to the real part of this complex number + * and set imaginary part to zero. + * + * @tparam U argument type (assignable to `value_type`) + * @param[in] re real part + * @return this + */ + template > + complex_type& operator=(U&& re) { + re_ = re; + im_ = 0; + return derived(); + } + + /** + * Add specified real value to real part. + * + * @tparam U argument type (assignable to `value_type`) + * @param[in] x real number to add + * @return this + */ + template + complex_type& operator+=(const U& x) { + re_ += x; + return derived(); + } + + /** + * Adds specified complex number to this. + * + * @tparam U value type of argument (assignable to `value_type`) + * @param[in] other complex number to add + * @return this + */ + template + complex_type& operator+=(const std::complex& other) { + re_ += other.real(); + im_ += other.imag(); + return derived(); + } + + /** + * Subtracts specified real number from real part. + * + * @tparam U argument type (assignable to `value_type`) + * @param[in] x real number to subtract + * @return this + */ + template + complex_type& operator-=(const U& x) { + re_ -= x; + return derived(); + } + + /** + * Subtracts specified complex number from this. + * + * @tparam U value type of argument (assignable to `value_type`) + * @param[in] other complex number to subtract + * @return this + */ + template + complex_type& operator-=(const std::complex& other) { + re_ -= other.real(); + im_ -= other.imag(); + return derived(); + } + + /** + * Multiplies this by the specified real number. + * + * @tparam U type of argument (assignable to `value_type`) + * @param[in] x real number to multiply + * @return this + */ + template + complex_type& operator*=(const U& x) { + re_ *= x; + im_ *= x; + return derived(); + } + + /** + * Multiplies this by specified complex number. + * + * @tparam U value type of argument (assignable to `value_type`) + * @param[in] other complex number to multiply + * @return this + */ + template + complex_type& operator*=(const std::complex& other) { + value_type re_temp = re_ * other.real() - im_ * other.imag(); + im_ = re_ * other.imag() + other.real() * im_; + re_ = re_temp; + return derived(); + } + + /** + * Divides this by the specified real number. + * + * @tparam U type of argument (assignable to `value_type`) + * @param[in] x real number to divide by + * @return this + */ + template + complex_type& operator/=(const U& x) { + re_ /= x; + im_ /= x; + return derived(); + } + + /** + * Divides this by the specified complex number. + * + * @tparam U value type of argument (assignable to `value_type`) + * @param[in] other number to divide by + * @return this + */ + template + complex_type& operator/=(const std::complex& other) { + using stan::math::square; + value_type sum_sq_im = square(other.real()) + square(other.imag()); + value_type re_temp = (re_ * other.real() + im_ * other.imag()) / sum_sq_im; + im_ = (im_ * other.real() - re_ * other.imag()) / sum_sq_im; + re_ = re_temp; + return derived(); + } + + protected: + /** + * Real part + */ + value_type re_{0}; + + /** + * Imaginary part + */ + value_type im_{0}; + + /** + * Return this complex base cast to the complex type. + * + * @return this complex base cast to the complex type + */ + complex_type& derived() { return static_cast(*this); } +}; + +} // namespace math +} // namespace stan + +#endif diff --git a/stan/math/prim/meta/scalar_type.hpp b/stan/math/prim/meta/scalar_type.hpp index bac5a4e4c08..5d373575377 100644 --- a/stan/math/prim/meta/scalar_type.hpp +++ b/stan/math/prim/meta/scalar_type.hpp @@ -50,6 +50,7 @@ struct scalar_type::value>> { }; /** \ingroup type_trait + * * Template metaprogram defining the scalar type for values * stored in a complex number. * diff --git a/stan/math/rev/core.hpp b/stan/math/rev/core.hpp index bfa25a9145c..5d3c91f423b 100644 --- a/stan/math/rev/core.hpp +++ b/stan/math/rev/core.hpp @@ -48,6 +48,7 @@ #include #include #include +#include #include #include #include diff --git a/stan/math/rev/core/std_complex.hpp b/stan/math/rev/core/std_complex.hpp new file mode 100644 index 00000000000..d2f76450c5f --- /dev/null +++ b/stan/math/rev/core/std_complex.hpp @@ -0,0 +1,66 @@ +#ifndef STAN_MATH_REV_CORE_STD_COMPLEX_HPP +#define STAN_MATH_REV_CORE_STD_COMPLEX_HPP + +#include +#include +#include +#include +#include + +namespace std { + +/** + * Specialization of the standard libary complex number type for + * reverse-mode autodiff type `stan::math::var`. + */ +template <> +class complex + : public stan::math::complex_base { + public: + using base_t = stan::math::complex_base; + + /** + * Construct a complex number with zero real and imaginary parts. + */ + complex() = default; + + /** + * Construct a complex number from real and imaginary parts. + * + * @tparam U type of real part (assignable to `value_type`) + * @tparam V type of imaginary part (assignable to `value_type`) + * @param[in] re real part + * @param[in] im imaginary part + */ + template + complex(const U& re, const V& im) : base_t(re, im) {} + + /** + * Construct a complex number with specified real part and zero + * imaginary part. + * + * @tparam U type of real part (assignable to `value_type`) + * @param[in] re real part + */ + template > + complex(U&& re) : base_t(re) {} // NOLINT(runtime/explicit) + + /** + * Set the real and imaginary components of this complex number to + * those of the specified complex number. + * + * @tparam U value type of argument (must be an arithmetic type) + * @param[in] x complex argument + * @return this + */ + template > + auto& operator=(const std::complex& x) { + re_ = x.real(); + im_ = x.imag(); + return *this; + } +}; + +} // namespace std + +#endif diff --git a/test/unit/math/mix/core/std_complex_test.cpp b/test/unit/math/mix/core/std_complex_test.cpp new file mode 100644 index 00000000000..23f253b348d --- /dev/null +++ b/test/unit/math/mix/core/std_complex_test.cpp @@ -0,0 +1,217 @@ +#include +#include +#include + +template +void test_constructor_init_type() { + S a = 2; + std::complex z(a); + EXPECT_EQ(a, z.real()); + EXPECT_EQ(0, z.imag()); +} + +template +void test_binary_constructor(const T1& x, const T2& y) { + using stan::math::value_of_rec; + std::complex z(x, y); + EXPECT_EQ(1.1, value_of_rec(z.real())); + EXPECT_EQ(2.3, value_of_rec(z.imag())); +} + +template +void test_set_real_imag() { + std::complex z; + z.real(3.2); + EXPECT_TRUE(z.real() == 3.2); + z.imag(-1.9); + EXPECT_TRUE(z.imag() == -1.9); +} + +template +void test_std_complex_constructor() { + using stan::math::value_of_rec; + using c_t = std::complex; + + // set real and imaginary parts + test_set_real_imag(); + + // binary constructor + test_binary_constructor(1.1, 2.3); + test_binary_constructor(1.1, 2.3); + test_binary_constructor(1.1, 2.3); + test_binary_constructor(1.1, 2.3); + + // copy constructor + c_t c(4.9, -15.8); + c_t d(c); + EXPECT_EQ(4.9, value_of_rec(d.real())); + EXPECT_EQ(-15.8, value_of_rec(d.imag())); + + // default constructor + c_t e; + EXPECT_EQ(0, value_of_rec(e.real())); + EXPECT_EQ(0, value_of_rec(e.imag())); + + // unary constructors from constants + test_constructor_init_type(); + test_constructor_init_type(); + test_constructor_init_type(); + test_constructor_init_type(); // NOLINT(runtime/int) + test_constructor_init_type(); + test_constructor_init_type(); // NOLINT(runtime/int) +} + +TEST(mathMixCore, stdComplexConstructor) { + using stan::math::fvar; + using stan::math::var; + using stan::test::expect_ad; + + // test constructors + test_std_complex_constructor(); + test_std_complex_constructor(); + test_std_complex_constructor>(); + test_std_complex_constructor>>(); + test_std_complex_constructor>(); + test_std_complex_constructor>>(); +} + +// convenience for type inference of T +template +std::complex to_std_complex(const T& x) { + return {x}; +} + +template +void expect_common_complex(const F& f) { + // cover all quadrants and projections + for (double re : std::vector{-1.4, -1e-3, 0, 2e-3, 2.3}) { + for (double im : std::vector{-0.5, -3e-3, 0, 4e-3, 1.5}) { + stan::test::expect_ad(f, std::complex(re, im)); + } + } +} + +template +void expect_common_for_complex(const F& f) { + for (double re : std::vector{-3.9, -1e-3, 0, 2e-3, 4.1}) { + stan::test::expect_ad(f, re); + } +} + +// remaining tests are for operators; after here, each operator +// is tested for complex and real assignability, including for +// std::complex, double, and int; in each test an auto +// variable is set up to get a matching type to the input, then +// it is modified using std::complex or double, then +// with the argument. + +TEST(mathMixCore, stdComplexOperatorEqual) { + // operator=(std::complex) + auto f = [](const auto& a) { + auto b = a; // for auto type + b = std::complex(-1.1, 5.5); + EXPECT_TRUE(b.real() == -1.1); + EXPECT_TRUE(b.imag() == 5.5); + + b = a; + return b; + }; + expect_common_complex(f); + + // operator=(Arith) + auto g = [](const auto& a) { + auto b = to_std_complex(a); // for auto type + b = 3.1; + EXPECT_TRUE(b.real() == 3.1); + b = 3; + EXPECT_TRUE(b.real() == 3.0); + + b = a; + return b; + }; + expect_common_for_complex(g); +} + +TEST(mathMixCore, stdComplexOperatorPlusEqual) { + // operator+=(std::complex) + auto f = [](const auto& a) { + auto b = a; + b += std::complex(-3.9, 1.8); + b += a; + return b; + }; + expect_common_complex(f); + + // operator+=(Arith) + auto g = [](const auto& a) { + auto b = to_std_complex(a); + b += 2.0; + b += 2; + b += a; + return b; + }; + expect_common_for_complex(g); +} + +TEST(mathMixCore, stdComplexOperatorMinusEqual) { + // operator-=(std::complex) + auto f = [](const auto& a) { + auto b = a; + b -= std::complex(18.3, -21.2); + b -= a; + return b; + }; + expect_common_complex(f); + + // operator-=(Arith) + auto g = [](const auto& a) { + auto b = to_std_complex(a); + b -= 5.8; + b -= -1; + b -= a; + return b; + }; + expect_common_for_complex(g); +} + +TEST(mathMixCore, stdComplexOperatorTimesEqual) { + // operator-=(std::complex) + auto f = [](const auto& a) { + auto b = a; + b *= std::complex(-1.2, -6.3); + b *= a; + return b; + }; + expect_common_complex(f); + + // operator-=(Arith) + auto g = [](const auto& a) { + auto b = to_std_complex(a); + b *= 3.0; + b *= -2; + b *= a; + return b; + }; + expect_common_for_complex(g); +} + +TEST(mathMixCore, stdComplexOperatorDivideEqual) { + // operator-=(std::complex) + auto f = [](const auto& a) { + auto b = a; + b /= std::complex(1.2, -5.5); + b /= a; + return b; + }; + expect_common_complex(f); + + // operator-=(Arith) + auto g = [](const auto& a) { + auto b = to_std_complex(a); + b /= 5.5; + b /= -2; + b /= a; + return b; + }; + expect_common_for_complex(g); +} diff --git a/test/unit/math/serializer.hpp b/test/unit/math/serializer.hpp index 9aee03d1319..29b5ab3d17e 100644 --- a/test/unit/math/serializer.hpp +++ b/test/unit/math/serializer.hpp @@ -3,6 +3,7 @@ #include #include +#include #include #include @@ -64,6 +65,23 @@ struct deserializer { return vals_[position_++]; } + /** + * Read a complex number comforming to the shape of the specified + * argument. The specified argument is only used for its + * shape---there is no relationship between the value type of the + * argument and the type of the result. + * + * @tparam U type of pattern value type + * @param x pattern argument to determine result shape + * @return deserialized value with shape and size matching argument + */ + template + std::complex read(const std::complex& x) { + T re = read(x.real()); + T im = read(x.imag()); + return {re, im}; + } + /** * Read a standard vector conforming to the shape of the specified * argument, here a standard vector. The specified argument is only @@ -139,6 +157,18 @@ struct serializer { vals_.push_back(x); } + /** + * Serialize the specified complex number. + * + * @tparam U value type of complex number; must be assignable to T + * @param x complex number to serialize + */ + template + void write(const std::complex& x) { + write(x.real()); + write(x.imag()); + } + /** * Serialize the specified standard vector. * @@ -206,6 +236,11 @@ deserializer to_deserializer(const Eigen::Matrix& vals) { return deserializer(vals); } +template +deserializer to_deserializer(const std::complex& vals) { + return to_deserializer(std::vector{vals.real(), vals.imag()}); +} + template void serialize_helper(serializer& s) {} @@ -239,8 +274,8 @@ std::vector serialize(const Ts... xs) { * @return serialized argument */ template -std::vector::type> serialize_return(const T& x) { - return serialize::type>(x); +std::vector> serialize_return(const T& x) { + return serialize>(x); } /** diff --git a/test/unit/math/serializer_test.cpp b/test/unit/math/serializer_test.cpp index 2bed5bacfd2..a4dceaa043e 100644 --- a/test/unit/math/serializer_test.cpp +++ b/test/unit/math/serializer_test.cpp @@ -1,6 +1,7 @@ #include #include #include +#include #include TEST(testUnitMathSerializer, serializer_deserializer) { @@ -8,6 +9,9 @@ TEST(testUnitMathSerializer, serializer_deserializer) { s.write(3.2); s.write(-1); + + s.write(std::complex(-1.5, 15.75)); + s.write(std::vector{10, 20, 30}); Eigen::VectorXd a(2); @@ -22,8 +26,8 @@ TEST(testUnitMathSerializer, serializer_deserializer) { c << 1, 2, 3, 4, 5, 6; // << is row major; index order is col major s.write(c); - std::vector expected{3.2, -1, 10, 20, 30, -10, -20, 101, - 102, 103, 1, 3, 5, 2, 4, 6}; + std::vector expected{3.2, -1, -1.5, 15.75, 10, 20, 30, -10, -20, + 101, 102, 103, 1, 3, 5, 2, 4, 6}; for (size_t i = 0; i < expected.size(); ++i) EXPECT_EQ(expected[i], s.vals_[i]); @@ -31,6 +35,9 @@ TEST(testUnitMathSerializer, serializer_deserializer) { EXPECT_EQ(3.2, d.read(0.0)); EXPECT_EQ(-1, d.read(0.0)); + std::complex w = d.read(std::complex(1, 1)); + EXPECT_EQ(-1.5, w.real()); + EXPECT_EQ(15.75, w.imag()); std::vector x = d.read(std::vector{0, 0, 0}); EXPECT_EQ(10, x[0]); EXPECT_EQ(20, x[1]); @@ -56,12 +63,13 @@ TEST(testUnitMathSerializer, serialize) { EXPECT_EQ(0, xs.size()); double a = 2; - std::vector b{3, 4, 5}; - Eigen::MatrixXd c(2, 3); - c << -1, -2, -3, -4, -5, -6; - std::vector ys = stan::test::serialize(a, b, c); + std::complex b(1, 2); + std::vector c{3, 4, 5}; + Eigen::MatrixXd d(2, 3); + d << -1, -2, -3, -4, -5, -6; + std::vector ys = stan::test::serialize(a, b, c, d); - std::vector expected{2, 3, 4, 5, -1, -4, -2, -5, -3, -6}; + std::vector expected{2, 1, 2, 3, 4, 5, -1, -4, -2, -5, -3, -6}; for (size_t i = 0; i < expected.size(); ++i) EXPECT_EQ(expected[i], ys[i]); } diff --git a/test/unit/math/test_ad.hpp b/test/unit/math/test_ad.hpp index 6acb5bce857..ed1f1b15485 100644 --- a/test/unit/math/test_ad.hpp +++ b/test/unit/math/test_ad.hpp @@ -16,39 +16,57 @@ namespace test { namespace internal { /** - * Evaluates expression. A no-op for scalars. - * @tparam T nested type of fvar - * @param x value + * Evaluates nested matrix template expressions, which is a no-op for + * arithmetic arguments. + * + * @tparam T arithmetic type + * @param[in] x value * @return value */ -template -auto eval(const stan::math::fvar& x) { +template ::value>> +auto eval(T x) { return x; } /** - * Evaluates expression. A no-op for scalars. - * @param x value + * Evaluates nested matrix template expressions, which is a no-op for + * complex arguments. + * + * @tparam T complex value type + * @param x[in] value * @return value */ -auto eval(const stan::math::var& x) { return x; } +template +auto eval(const std::complex& x) { + return x; +} /** - * Evaluates expression. A no-op for scalars. - * @param x value + * Evaluates all nested matrix expression templates, which is a no-op for + * reverse-mode autodiff variables. + * + * @param[in] x value * @return value */ -auto eval(double x) { return x; } +auto eval(const stan::math::var& x) { return x; } /** - * Evaluates expression. A no-op for scalars. - * @param x value + * Evaluates all matrix expression templates, which is a no-op for + * forward-mode autodiff variables. + * + * @tparam T value type of fvar + * @param[in] x value * @return value */ -auto eval(int x) { return x; } +template +auto eval(const stan::math::fvar& x) { + return x; +} /** - * Evaluates expression. + * Evaluates all nested matrix expression templates, which evaluates + * the specified derived matrix. + * * @tparam Derived derived type of the expression * @param x expression * @return evaluated expression @@ -57,11 +75,13 @@ template auto eval(const Eigen::EigenBase& x) { return x.derived().eval(); } + /** - * Evaluates expressions in a \c std::vector. - * @tparam T type of \c std::vector elements - * @param x a \c std::vector of expressions - * @return a \cstd::vector of evaluated expressions + * Evaluates all nested matrix expression templates elementwise. + * + * @tparam T type of elements + * @param[in] x vector of expressions + * @return vector of evaluated expressions */ template auto eval(const std::vector& x) {