-
-
Notifications
You must be signed in to change notification settings - Fork 190
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Feature/0123 complex spec #1720
Changes from all commits
994d363
2d93950
c679e61
5c13696
879f3ac
ed61d81
fe1bb95
622d294
1b18e5f
db7728c
f1c33dd
f5e0107
4bcad0c
6da512c
3b66c6f
a01d73a
9d040b5
273533b
95a84fd
92d3ffa
3944dc6
b3c1ef7
2296e3f
590ea12
af5ffbe
39e305a
78ca46f
3e89f98
1f18017
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,68 @@ | ||
#ifndef STAN_MATH_FWD_CORE_STD_COMPLEX_HPP | ||
#define STAN_MATH_FWD_CORE_STD_COMPLEX_HPP | ||
|
||
#include <stan/math/prim/core/complex_base.hpp> | ||
#include <stan/math/fwd/core/fvar.hpp> | ||
#include <stan/math/prim/meta.hpp> | ||
#include <complex> | ||
|
||
namespace std { | ||
|
||
/** | ||
* Specialization of the standard libary complex number type for | ||
* reverse-mode autodiff type `stan::math::fvar<T>`. | ||
* | ||
* @tparam T forward-mode autodiff value type | ||
*/ | ||
template <typename T> | ||
class complex<stan::math::fvar<T>> | ||
: public stan::math::complex_base<stan::math::fvar<T>> { | ||
public: | ||
using base_t = stan::math::complex_base<stan::math::fvar<T>>; | ||
|
||
/** | ||
* Construct a complex number with zero real and imaginary parts. | ||
*/ | ||
complex() = default; | ||
|
||
/** | ||
* Construct a complex number with the specified real part and a zero | ||
* imaginary part. | ||
* | ||
* @tparam Scalar real type (must be assignable to `value_type`) | ||
* @param[in] re real part | ||
*/ | ||
template <typename U, typename = stan::require_stan_scalar_t<U>> | ||
complex(U&& re) : base_t(re) {} // NOLINT(runtime/explicit) | ||
|
||
/** | ||
* Construct a complex number from the specified real and imaginary | ||
* parts. | ||
* | ||
* @tparam U type of real part | ||
* @tparam V type of imaginary part | ||
* @param[in] re real part | ||
* @param[in] im imaginary part | ||
*/ | ||
template <typename U, typename V> | ||
complex(const U& re, const V& im) : base_t(re, im) {} | ||
|
||
/** | ||
* Set the real and imaginary parts to those of the specified | ||
* complex number. | ||
* | ||
* @tparam U value type of argument | ||
* @param[in] x complex number to set | ||
* @return this | ||
*/ | ||
template <typename U, typename = stan::require_arithmetic_t<U>> | ||
auto& operator=(const std::complex<U>& x) { | ||
this->re_ = x.real(); | ||
this->im_ = x.imag(); | ||
return *this; | ||
} | ||
}; | ||
|
||
} // namespace std | ||
|
||
#endif |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,236 @@ | ||
#ifndef STAN_MATH_PRIM_CORE_COMPLEX_BASE_HPP | ||
#define STAN_MATH_PRIM_CORE_COMPLEX_BASE_HPP | ||
|
||
#include <stan/math/prim/fun/square.hpp> | ||
#include <stan/math/prim/meta.hpp> | ||
#include <complex> | ||
|
||
namespace stan { | ||
namespace math { | ||
|
||
/** | ||
* Base class for complex numbers. Extending classes must be of | ||
* of the form `complex<ValueType>`. | ||
* | ||
* @tparam ValueType type of real and imaginary parts | ||
*/ | ||
template <typename ValueType> | ||
class complex_base { | ||
public: | ||
/** | ||
* Type of real and imaginary parts | ||
*/ | ||
using value_type = ValueType; | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Is our standard to do underscore_naming for usings and CamelCase for template parameters? There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. That's consistent with how I understand the code but it's weird to see if directly in a situation like this :/. |
||
|
||
/** | ||
* Derived complex type used for function return types | ||
*/ | ||
using complex_type = std::complex<value_type>; | ||
|
||
/** | ||
* Construct a complex base with zero real and imaginary parts. | ||
*/ | ||
complex_base() = default; | ||
|
||
/** | ||
* Construct a complex base with the specified real part and a zero | ||
* imaginary part. | ||
* | ||
* @tparam U real type (assignable to `value_type`) | ||
* @param[in] re real part | ||
*/ | ||
template <typename U, typename = require_stan_scalar_t<U>> | ||
complex_base(const U& re) : re_(re) {} // NOLINT(runtime/explicit) | ||
|
||
/** | ||
* Construct a complex base with the specified real and imaginary | ||
* parts. | ||
* | ||
* @tparam U real type (assignable to `value_type`) | ||
* @tparam V imaginary type (assignable to `value_type`) | ||
* @param[in] re real part | ||
* @param[in] im imaginary part | ||
*/ | ||
template <typename U, typename V> | ||
complex_base(const U& re, const V& im) : re_(re), im_(im) {} | ||
|
||
/** | ||
* Return the real part. | ||
* | ||
* @return real part | ||
*/ | ||
value_type real() const { return re_; } | ||
|
||
/** | ||
* Set the real part to the specified value. | ||
* | ||
* @param[in] re real part | ||
*/ | ||
void real(const value_type& re) { re_ = re; } | ||
|
||
/** | ||
* Return the imaginary part. | ||
* | ||
* @return imaginary part | ||
*/ | ||
value_type imag() const { return im_; } | ||
|
||
/** | ||
* Set the imaginary part to the specified value. | ||
* | ||
* @param[in] im imaginary part | ||
*/ | ||
void imag(const value_type& im) { im_ = im; } | ||
|
||
/** | ||
* Assign the specified value to the real part of this complex number | ||
* and set imaginary part to zero. | ||
* | ||
* @tparam U argument type (assignable to `value_type`) | ||
* @param[in] re real part | ||
* @return this | ||
*/ | ||
template <typename U, typename = require_stan_scalar_t<U>> | ||
complex_type& operator=(U&& re) { | ||
re_ = re; | ||
im_ = 0; | ||
return derived(); | ||
} | ||
|
||
/** | ||
* Add specified real value to real part. | ||
* | ||
* @tparam U argument type (assignable to `value_type`) | ||
* @param[in] x real number to add | ||
* @return this | ||
*/ | ||
template <typename U> | ||
complex_type& operator+=(const U& x) { | ||
re_ += x; | ||
return derived(); | ||
} | ||
|
||
/** | ||
* Adds specified complex number to this. | ||
* | ||
* @tparam U value type of argument (assignable to `value_type`) | ||
* @param[in] other complex number to add | ||
* @return this | ||
*/ | ||
template <typename U> | ||
complex_type& operator+=(const std::complex<U>& other) { | ||
re_ += other.real(); | ||
im_ += other.imag(); | ||
return derived(); | ||
} | ||
|
||
/** | ||
* Subtracts specified real number from real part. | ||
* | ||
* @tparam U argument type (assignable to `value_type`) | ||
* @param[in] x real number to subtract | ||
* @return this | ||
*/ | ||
template <typename U> | ||
complex_type& operator-=(const U& x) { | ||
re_ -= x; | ||
return derived(); | ||
} | ||
|
||
/** | ||
* Subtracts specified complex number from this. | ||
* | ||
* @tparam U value type of argument (assignable to `value_type`) | ||
* @param[in] other complex number to subtract | ||
* @return this | ||
*/ | ||
template <typename U> | ||
complex_type& operator-=(const std::complex<U>& other) { | ||
re_ -= other.real(); | ||
im_ -= other.imag(); | ||
return derived(); | ||
} | ||
|
||
/** | ||
* Multiplies this by the specified real number. | ||
* | ||
* @tparam U type of argument (assignable to `value_type`) | ||
* @param[in] x real number to multiply | ||
* @return this | ||
*/ | ||
template <typename U> | ||
complex_type& operator*=(const U& x) { | ||
re_ *= x; | ||
im_ *= x; | ||
return derived(); | ||
} | ||
|
||
/** | ||
* Multiplies this by specified complex number. | ||
* | ||
* @tparam U value type of argument (assignable to `value_type`) | ||
* @param[in] other complex number to multiply | ||
* @return this | ||
*/ | ||
template <typename U> | ||
complex_type& operator*=(const std::complex<U>& other) { | ||
value_type re_temp = re_ * other.real() - im_ * other.imag(); | ||
im_ = re_ * other.imag() + other.real() * im_; | ||
re_ = re_temp; | ||
return derived(); | ||
} | ||
|
||
/** | ||
* Divides this by the specified real number. | ||
* | ||
* @tparam U type of argument (assignable to `value_type`) | ||
* @param[in] x real number to divide by | ||
* @return this | ||
*/ | ||
template <typename U> | ||
complex_type& operator/=(const U& x) { | ||
re_ /= x; | ||
im_ /= x; | ||
return derived(); | ||
} | ||
|
||
/** | ||
* Divides this by the specified complex number. | ||
* | ||
* @tparam U value type of argument (assignable to `value_type`) | ||
* @param[in] other number to divide by | ||
* @return this | ||
*/ | ||
template <typename U> | ||
complex_type& operator/=(const std::complex<U>& other) { | ||
using stan::math::square; | ||
value_type sum_sq_im = square(other.real()) + square(other.imag()); | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Can/should these be There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. What's the feeling about There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I personally like |
||
value_type re_temp = (re_ * other.real() + im_ * other.imag()) / sum_sq_im; | ||
im_ = (im_ * other.real() - re_ * other.imag()) / sum_sq_im; | ||
re_ = re_temp; | ||
return derived(); | ||
} | ||
|
||
protected: | ||
/** | ||
* Real part | ||
*/ | ||
value_type re_{0}; | ||
|
||
/** | ||
* Imaginary part | ||
*/ | ||
value_type im_{0}; | ||
|
||
/** | ||
* Return this complex base cast to the complex type. | ||
* | ||
* @return this complex base cast to the complex type | ||
*/ | ||
complex_type& derived() { return static_cast<complex_type&>(*this); } | ||
}; | ||
|
||
} // namespace math | ||
} // namespace stan | ||
|
||
#endif |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,66 @@ | ||
#ifndef STAN_MATH_REV_CORE_STD_COMPLEX_HPP | ||
#define STAN_MATH_REV_CORE_STD_COMPLEX_HPP | ||
|
||
#include <stan/math/prim/meta.hpp> | ||
#include <stan/math/prim/core/complex_base.hpp> | ||
#include <stan/math/rev/core/var.hpp> | ||
#include <cmath> | ||
#include <complex> | ||
|
||
namespace std { | ||
|
||
/** | ||
* Specialization of the standard libary complex number type for | ||
* reverse-mode autodiff type `stan::math::var`. | ||
*/ | ||
template <> | ||
class complex<stan::math::var> | ||
: public stan::math::complex_base<stan::math::var> { | ||
public: | ||
using base_t = stan::math::complex_base<stan::math::var>; | ||
|
||
/** | ||
* Construct a complex number with zero real and imaginary parts. | ||
*/ | ||
complex() = default; | ||
|
||
/** | ||
* Construct a complex number from real and imaginary parts. | ||
* | ||
* @tparam U type of real part (assignable to `value_type`) | ||
* @tparam V type of imaginary part (assignable to `value_type`) | ||
* @param[in] re real part | ||
* @param[in] im imaginary part | ||
*/ | ||
template <typename U, typename V> | ||
complex(const U& re, const V& im) : base_t(re, im) {} | ||
|
||
/** | ||
* Construct a complex number with specified real part and zero | ||
* imaginary part. | ||
* | ||
* @tparam U type of real part (assignable to `value_type`) | ||
* @param[in] re real part | ||
*/ | ||
template <typename U, typename = stan::require_stan_scalar_t<U>> | ||
complex(U&& re) : base_t(re) {} // NOLINT(runtime/explicit) | ||
|
||
/** | ||
* Set the real and imaginary components of this complex number to | ||
* those of the specified complex number. | ||
* | ||
* @tparam U value type of argument (must be an arithmetic type) | ||
* @param[in] x complex argument | ||
* @return this | ||
*/ | ||
template <typename U, typename = stan::require_arithmetic_t<U>> | ||
auto& operator=(const std::complex<U>& x) { | ||
re_ = x.real(); | ||
im_ = x.imag(); | ||
return *this; | ||
} | ||
}; | ||
|
||
} // namespace std | ||
|
||
#endif |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
If this is a crtp should this be
So the type is recurrent in the base class?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
No. The
complex
(which would have to bestd::complex
) is redundant. I only take the value typeT
as the template parameter, but the "derived" class isstd::complex<T>
. If I suppoied all ofcomplex<fvar<T>>
, then I'd just have to go fishing theT
out with traits again.There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
It is definitely redundant, though if I use a pattern I like keeping it standard instead of 'ish. But whether a strict CRTP pattern here would be better is totally subjective.
Would argue that type traits change here would be fishing in a barrel