-
Notifications
You must be signed in to change notification settings - Fork 52
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
more accurate sqrt function #129
base: master
Are you sure you want to change the base?
Changes from 1 commit
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change | ||||||||
---|---|---|---|---|---|---|---|---|---|---|
|
@@ -281,40 +281,87 @@ impl<T: Float> Complex<T> { | |||||||||
/// | ||||||||||
/// The branch satisfies `-π/2 ≤ arg(sqrt(z)) ≤ π/2`. | ||||||||||
#[inline] | ||||||||||
pub fn sqrt(self) -> Self { | ||||||||||
if self.im.is_zero() { | ||||||||||
if self.re.is_sign_positive() { | ||||||||||
// simple positive real √r, and copy `im` for its sign | ||||||||||
Self::new(self.re.sqrt(), self.im) | ||||||||||
pub fn sqrt(mut self) -> Self { | ||||||||||
// complex sqrt algorithm based on the algorithm from | ||||||||||
// dl.acm.org/doi/abs/10.1145/363717.363780 with additional tweaks | ||||||||||
// to increase accuracy. Compared to a naive implementation that | ||||||||||
// reuses the complex exp/ln implementations this algorithm has better | ||||||||||
// accuracy since both (real) sqrt and (real) hypot are guaranteed to | ||||||||||
// round perfectly. It's also faster since this implementation requires | ||||||||||
// fewer transcendental functions and those it does use (sqrt/hypot) are | ||||||||||
// faster compared to exp/sin/cos. | ||||||||||
// | ||||||||||
// The musl libc implementation was referenced while implementing the | ||||||||||
// algorithm here: | ||||||||||
// https://git.musl-libc.org/cgit/musl/tree/src/complex/csqrt.c | ||||||||||
|
||||||||||
// TODO: rounding for very tiny subnormal numbers isn't perfect yet, so | ||||||||||
// the assert shown below fails. In the very worst case this leads to about | ||||||||||
// 10% accuracy loss (see example below). As the magnitude increases, the | ||||||||||
// error quickly drops to basically zero. | ||||||||||
// | ||||||||||
// glibc handles that (but other implementations like musl and numpy do | ||||||||||
// not) by upscaling very small values. That upscaling (and particularly | ||||||||||
// its reversal) are weird and hard to understand (and rely on mantissa | ||||||||||
// bit size which we can't get out of the trait). In general the glibc | ||||||||||
// implementation is ever so subtly different and I wouldn't want to | ||||||||||
// introduce bugs by trying to adapt the underflow handling. | ||||||||||
// | ||||||||||
// assert_eq!( | ||||||||||
// Complex64::new(5.212e-324, 5.212e-324).sqrt(), | ||||||||||
// Complex64::new(2.4421097261308304e-162, 1.0115549693666347e-162) | ||||||||||
// ); | ||||||||||
|
||||||||||
// special cases for correct nan/inf handling | ||||||||||
// see https://en.cppreference.com/w/c/numeric/complex/csqrt | ||||||||||
|
||||||||||
if self.re.is_zero() && self.im.is_zero() { | ||||||||||
// 0 +/- 0 i | ||||||||||
return Self::new(T::zero(), self.im); | ||||||||||
} | ||||||||||
if self.im.is_infinite() { | ||||||||||
// inf +/- inf i | ||||||||||
return Self::new(T::infinity(), self.im); | ||||||||||
} | ||||||||||
if self.re.is_nan() { | ||||||||||
// nan + nan i | ||||||||||
return Self::new(self.re, T::nan()); | ||||||||||
} | ||||||||||
if self.re.is_infinite() { | ||||||||||
// √(inf +/- NaN i) = inf +/- NaN i | ||||||||||
// √(inf +/- x i) = inf +/- 0 i | ||||||||||
// √(-inf +/- NaN i) = NaN +/- inf i | ||||||||||
// √(-inf +/- x i) = 0 +/- inf i | ||||||||||
|
||||||||||
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Maybe add a variable to make this clearer:
Suggested change
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. That is indeed more readable; I also added comments. Good point. |
||||||||||
// if im is inf (or nan) this is nan, otherwise it's zero | ||||||||||
#[allow(clippy::eq_op)] | ||||||||||
let zero_or_nan = self.im - self.im; | ||||||||||
if self.re.is_sign_negative() { | ||||||||||
return Self::new(zero_or_nan.abs(), self.re.copysign(self.im)); | ||||||||||
} else { | ||||||||||
// √(r e^(iπ)) = √r e^(iπ/2) = i√r | ||||||||||
// √(r e^(-iπ)) = √r e^(-iπ/2) = -i√r | ||||||||||
let re = T::zero(); | ||||||||||
let im = (-self.re).sqrt(); | ||||||||||
if self.im.is_sign_positive() { | ||||||||||
Self::new(re, im) | ||||||||||
} else { | ||||||||||
Self::new(re, -im) | ||||||||||
} | ||||||||||
} | ||||||||||
} else if self.re.is_zero() { | ||||||||||
// √(r e^(iπ/2)) = √r e^(iπ/4) = √(r/2) + i√(r/2) | ||||||||||
// √(r e^(-iπ/2)) = √r e^(-iπ/4) = √(r/2) - i√(r/2) | ||||||||||
let one = T::one(); | ||||||||||
let two = one + one; | ||||||||||
let x = (self.im.abs() / two).sqrt(); | ||||||||||
if self.im.is_sign_positive() { | ||||||||||
Self::new(x, x) | ||||||||||
} else { | ||||||||||
Self::new(x, -x) | ||||||||||
return Self::new(self.re, zero_or_nan.copysign(self.im)); | ||||||||||
} | ||||||||||
} | ||||||||||
let two = T::one() + T::one(); | ||||||||||
let four = two + two; | ||||||||||
let overflow = T::max_value() / (T::one() + T::sqrt(two)); | ||||||||||
let max_magnitude = self.re.abs().max(self.im.abs()); | ||||||||||
let scale = max_magnitude >= overflow; | ||||||||||
if scale { | ||||||||||
self = self / four; | ||||||||||
} | ||||||||||
if self.re.is_sign_negative() { | ||||||||||
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. We could also use a citation and link in a comment for the algorithm you mentioned. There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. I added a citation to the algorithm and the musl libc implementation, and provided some additional background in a comment. |
||||||||||
let tmp = ((-self.re + self.norm()) / two).sqrt(); | ||||||||||
self.re = self.im.abs() / (two * tmp); | ||||||||||
self.im = tmp.copysign(self.im); | ||||||||||
} else { | ||||||||||
// formula: sqrt(r e^(it)) = sqrt(r) e^(it/2) | ||||||||||
let one = T::one(); | ||||||||||
let two = one + one; | ||||||||||
let (r, theta) = self.to_polar(); | ||||||||||
Self::from_polar(r.sqrt(), theta / two) | ||||||||||
self.re = ((self.re + self.norm()) / two).sqrt(); | ||||||||||
self.im = self.im / (two * self.re); | ||||||||||
} | ||||||||||
if scale { | ||||||||||
self = self * two; | ||||||||||
} | ||||||||||
self | ||||||||||
} | ||||||||||
|
||||||||||
/// Computes the principal value of the cube root of `self`. | ||||||||||
|
@@ -2065,6 +2112,50 @@ pub(crate) mod test { | |||||||||
} | ||||||||||
} | ||||||||||
|
||||||||||
#[test] | ||||||||||
fn test_sqrt_nan() { | ||||||||||
assert!(close_naninf( | ||||||||||
Complex64::new(f64::INFINITY, f64::NAN).sqrt(), | ||||||||||
Complex64::new(f64::INFINITY, f64::NAN), | ||||||||||
)); | ||||||||||
assert!(close_naninf( | ||||||||||
Complex64::new(f64::NAN, f64::INFINITY).sqrt(), | ||||||||||
Complex64::new(f64::INFINITY, f64::INFINITY), | ||||||||||
)); | ||||||||||
assert!(close_naninf( | ||||||||||
Complex64::new(f64::NEG_INFINITY, -f64::NAN).sqrt(), | ||||||||||
Complex64::new(f64::NAN, f64::NEG_INFINITY), | ||||||||||
)); | ||||||||||
assert!(close_naninf( | ||||||||||
Complex64::new(f64::NEG_INFINITY, f64::NAN).sqrt(), | ||||||||||
Complex64::new(f64::NAN, f64::INFINITY), | ||||||||||
)); | ||||||||||
assert!(close_naninf( | ||||||||||
Complex64::new(-0.0, 0.0).sqrt(), | ||||||||||
Complex64::new(0.0, 0.0), | ||||||||||
)); | ||||||||||
for x in (-100..100).map(f64::from) { | ||||||||||
assert!(close_naninf( | ||||||||||
Complex64::new(x, f64::INFINITY).sqrt(), | ||||||||||
Complex64::new(f64::INFINITY, f64::INFINITY), | ||||||||||
)); | ||||||||||
assert!(close_naninf( | ||||||||||
Complex64::new(f64::NAN, x).sqrt(), | ||||||||||
Complex64::new(f64::NAN, f64::NAN), | ||||||||||
)); | ||||||||||
// √(inf + x i) = inf + 0 i | ||||||||||
assert!(close_naninf( | ||||||||||
Complex64::new(f64::INFINITY, x).sqrt(), | ||||||||||
Complex64::new(f64::INFINITY, 0.0.copysign(x)), | ||||||||||
)); | ||||||||||
// √(-inf + x i) = 0 + inf i | ||||||||||
assert!(close_naninf( | ||||||||||
Complex64::new(f64::NEG_INFINITY, x).sqrt(), | ||||||||||
Complex64::new(0.0, f64::INFINITY.copysign(x)), | ||||||||||
)); | ||||||||||
} | ||||||||||
} | ||||||||||
|
||||||||||
#[test] | ||||||||||
fn test_cbrt() { | ||||||||||
assert!(close(_0_0i.cbrt(), _0_0i)); | ||||||||||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Can you add a source for all these special cases? e.g.
https://en.cppreference.com/w/c/numeric/complex/csqrt
(and make sure all those are covered)
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I added more tests to
test_nan
to make sure all of these are covered by these tests, and added a comment