-
Notifications
You must be signed in to change notification settings - Fork 899
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Quadratic deal with a=0 and speed up #1103
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -30,6 +30,7 @@ | |
using System; | ||
using Complex = System.Numerics.Complex; | ||
using NUnit.Framework; | ||
using System.Diagnostics; | ||
|
||
namespace MathNet.Numerics.Tests.RootFindingTests | ||
{ | ||
|
@@ -40,7 +41,7 @@ internal class FindRootsTest | |
public void MultipleRoots() | ||
{ | ||
// Roots at -2, 2 | ||
Func<double, double> f1 = x => x*x - 4; | ||
Func<double, double> f1 = x => x * x - 4; | ||
Assert.AreEqual(0, f1(FindRoots.OfFunction(f1, 0, 5, 1e-14))); | ||
Assert.AreEqual(-2, FindRoots.OfFunction(f1, -5, -1, 1e-14)); | ||
Assert.AreEqual(2, FindRoots.OfFunction(f1, 1, 4, 1e-14)); | ||
|
@@ -49,7 +50,7 @@ public void MultipleRoots() | |
Assert.AreEqual(2, FindRoots.OfFunction(x => -f1(x), 1, 4, 1e-14)); | ||
|
||
// Roots at 3, 4 | ||
Func<double, double> f2 = x => (x - 3)*(x - 4); | ||
Func<double, double> f2 = x => (x - 3) * (x - 4); | ||
Assert.AreEqual(0, f2(FindRoots.OfFunction(f2, 3.5, 5, 1e-14)), 1e-14); | ||
Assert.AreEqual(3, FindRoots.OfFunction(f2, -5, 3.5, 1e-14)); | ||
Assert.AreEqual(4, FindRoots.OfFunction(f2, 3.2, 5, 1e-14)); | ||
|
@@ -60,23 +61,23 @@ public void MultipleRoots() | |
[Test] | ||
public void LocalMinima() | ||
{ | ||
Func<double, double> f1 = x => x*x*x - 2*x + 2; | ||
Func<double, double> f1 = x => x * x * x - 2 * x + 2; | ||
Assert.AreEqual(0, f1(FindRoots.OfFunction(f1, -5, 5, 1e-14)), 1e-14); | ||
Assert.AreEqual(0, f1(FindRoots.OfFunction(f1, -2, 4, 1e-14)), 1e-14); | ||
} | ||
|
||
[Test] | ||
public void NoRoot() | ||
{ | ||
Func<double, double> f1 = x => x*x + 4; | ||
Func<double, double> f1 = x => x * x + 4; | ||
Assert.That(() => FindRoots.OfFunction(f1, -5, 5, 1e-14), Throws.TypeOf<NonConvergenceException>()); | ||
} | ||
|
||
[Test] | ||
public void Oneeq3() | ||
{ | ||
// Test case from http://www.polymath-software.com/library/nle/Oneeq3.htm | ||
Func<double, double> f1 = T => Math.Exp(21000/T)/(T*T) - 1.11e11; | ||
Func<double, double> f1 = T => Math.Exp(21000 / T) / (T * T) - 1.11e11; | ||
double x = FindRoots.OfFunction(f1, 550, 560, 1e-12); | ||
Assert.AreEqual(551.773822885233, x, 1e-5); | ||
Assert.AreEqual(0, f1(x), 1e-2); | ||
|
@@ -92,7 +93,7 @@ public void Oneeq5() | |
const double BETA = 2.298E-3; | ||
const double GAMMA = 0.283E-6; | ||
const double DH = -57798.0; | ||
return DH + TR*(ALPHA + TR*(BETA/2 + TR*GAMMA/3)) - 298.0*(ALPHA + 298.0*(BETA/2 + 298.0*GAMMA/3)); | ||
return DH + TR * (ALPHA + TR * (BETA / 2 + TR * GAMMA / 3)) - 298.0 * (ALPHA + 298.0 * (BETA / 2 + 298.0 * GAMMA / 3)); | ||
}; | ||
|
||
double x = FindRoots.OfFunction(f1, 3000, 5000, 1e-10); | ||
|
@@ -114,11 +115,11 @@ public void Oneeq6a() | |
const double A = 0.01855; | ||
const double B = -0.01587; | ||
const double P = 100.0; | ||
const double Beta = R*T*B0 - A0 - R*C/(T*T); | ||
const double Gama = -R*T*B0*B + A0*A - R*C*B0/(T*T); | ||
const double Delta = R*B0*B*C/(T*T); | ||
const double Beta = R * T * B0 - A0 - R * C / (T * T); | ||
const double Gama = -R * T * B0 * B + A0 * A - R * C * B0 / (T * T); | ||
const double Delta = R * B0 * B * C / (T * T); | ||
|
||
return R*T/V + Beta/(V*V) + Gama/(V*V*V) + Delta/(V*V*V*V) - P; | ||
return R * T / V + Beta / (V * V) + Gama / (V * V * V) + Delta / (V * V * V * V) - P; | ||
}; | ||
|
||
double x = FindRoots.OfFunction(f1, 0.1, 1); | ||
|
@@ -130,7 +131,7 @@ public void Oneeq6a() | |
public void Oneeq7() | ||
{ | ||
// Test case from http://www.polymath-software.com/library/nle/Oneeq7.htm | ||
Func<double, double> f1 = x => x/(1 - x) - 5*Math.Log(0.4*(1 - x)/(0.4 - 0.5*x)) + 4.45977; | ||
Func<double, double> f1 = x => x / (1 - x) - 5 * Math.Log(0.4 * (1 - x) / (0.4 - 0.5 * x)) + 4.45977; | ||
double r = FindRoots.OfFunction(f1, 0, 0.79, 1e-10); | ||
Assert.AreEqual(0.757396293891, r, 1e-6); | ||
Assert.AreEqual(0, f1(r), 1e-6); | ||
|
@@ -146,7 +147,7 @@ public void Oneeq8() | |
const double b = 40; | ||
const double c = 200; | ||
|
||
return a*v*v + b*Math.Pow(v, 7/4) - c; | ||
return a * v * v + b * Math.Pow(v, 7 / 4) - c; | ||
}; | ||
|
||
double x = FindRoots.OfFunction(f1, 0.01, 1); | ||
|
@@ -158,7 +159,7 @@ public void Oneeq8() | |
public void StackOverflow39935588() | ||
{ | ||
// Roots at -2, 2 | ||
Func<double, double> f1 = x => (x - 3.0)*(x - 4.0); | ||
Func<double, double> f1 = x => (x - 3.0) * (x - 4.0); | ||
Assert.AreEqual(3.0, FindRoots.OfFunction(f1, -2.0, 3.5), 1e-10); | ||
Assert.AreEqual(4.0, FindRoots.OfFunction(f1, 3.5, 5.5), 1e-10); | ||
Assert.AreEqual(0.0, f1(FindRoots.OfFunction(f1, -2.0, 5.5, 1e-14)), 1e-14); | ||
|
@@ -177,7 +178,7 @@ void AssertComplexEqual(Complex expected, Complex actual, double delta) | |
public void QuadraticExpanded(double u, double v, double t) | ||
{ | ||
// t*(x+u)*(x+v) = t*u*v + t*(u+v)*x + t*x^2 | ||
double c = t*u*v, b = t*(u + v), a = t; | ||
double c = t * u * v, b = t * (u + v), a = t; | ||
var x = FindRoots.Quadratic(c, b, a); | ||
Complex x1 = x.Item1, x2 = x.Item2; | ||
|
||
|
@@ -187,8 +188,8 @@ public void QuadraticExpanded(double u, double v, double t) | |
|| x1.AlmostEqualRelative(r2, 1e-14) && x2.AlmostEqualRelative(r1, 1e-14)); | ||
|
||
// Verify they really are roots | ||
AssertComplexEqual(Complex.Zero, c + b*x1 + a*x1*x1, 1e-14); | ||
AssertComplexEqual(Complex.Zero, c + b*x2 + a*x2*x2, 1e-14); | ||
AssertComplexEqual(Complex.Zero, c + b * x1 + a * x1 * x1, 1e-14); | ||
AssertComplexEqual(Complex.Zero, c + b * x2 + a * x2 * x2, 1e-14); | ||
} | ||
|
||
[TestCase(1d, 1d, 1d, -0.5, -0.866025403784439, -0.5, 0.866025403784439)] | ||
|
@@ -207,8 +208,101 @@ public void QuadraticExpected(double c, double b, double a, double x1R, double x | |
|| x1.AlmostEqualRelative(r2, 1e-14) && x2.AlmostEqualRelative(r1, 1e-14)); | ||
|
||
// Verify they really are roots | ||
AssertComplexEqual(Complex.Zero, c + b*x1 + a*x1*x1, 1e-14); | ||
AssertComplexEqual(Complex.Zero, c + b*x2 + a*x2*x2, 1e-14); | ||
AssertComplexEqual(Complex.Zero, c + b * x1 + a * x1 * x1, 1e-14); | ||
AssertComplexEqual(Complex.Zero, c + b * x2 + a * x2 * x2, 1e-14); | ||
} | ||
|
||
public static (Complex, Complex) FasterQuadratic(double c, double b, double a) | ||
{ | ||
Complex x1, x2; | ||
if (a == 0.0) | ||
{ | ||
if (b == 0.0) | ||
{ | ||
x1 = new Complex(double.NaN, double.NaN); // Complex.NaN; | ||
x2 = x1; | ||
} | ||
else | ||
{ | ||
x1 = new Complex(-c / b, 0.0); | ||
x2 = x1; | ||
} | ||
} | ||
else | ||
{ | ||
a = 1.0 / a; | ||
b = -0.5 * b * a; | ||
c = c * a; | ||
double delta = b * b - c; | ||
if (delta < 0.0) | ||
{ | ||
double sqrtDelta = Math.Sqrt(-delta); | ||
x1 = new Complex(b, sqrtDelta); | ||
x2 = new Complex(b, -sqrtDelta); | ||
} | ||
else | ||
{ | ||
double sqrtDelta = Math.Sqrt(delta); | ||
x1 = new Complex(b + sqrtDelta, 0.0); | ||
x2 = new Complex(b - sqrtDelta, 0.0); | ||
} | ||
} | ||
return (x1, x2); | ||
} | ||
|
||
[TestCase(70_000_000)] | ||
public void TestQuadraticSpeed(int N) | ||
{ | ||
Complex x1, x2; | ||
double sum = 0.0; | ||
double a, b, c; | ||
Stopwatch stopwatch = new Stopwatch(); | ||
for (int i = 11; i < N; i++) | ||
{ | ||
a = -0.1 * i + 1.0; | ||
b = 0.3464 * i + 5.0; | ||
c = -0.3 * i + 7.0; | ||
(x1, x2) = FindRoots.Quadratic(c, b, a); | ||
sum += x1.Real + x2.Imaginary; | ||
} | ||
Console.WriteLine($"sum={sum}"); | ||
stopwatch.Stop(); | ||
Console.WriteLine($"FindRoots.Quadratic time: {stopwatch.ElapsedMilliseconds * 0.001:F3}s"); | ||
|
||
sum = 0.0; | ||
stopwatch.Restart(); | ||
for (int i = 11; i < N; i++) | ||
{ | ||
a = -0.1 * i + 1.0; | ||
b = 0.3464 * i + 5.0; | ||
c = -0.3 * i + 7.0; | ||
(x1, x2) = FasterQuadratic(c, b, a); | ||
sum += x1.Real + x2.Imaginary; | ||
} | ||
Console.WriteLine($"sum={sum}"); | ||
stopwatch.Stop(); | ||
Console.WriteLine($"FasterQuadratic time: {stopwatch.ElapsedMilliseconds * 0.001:F3}s"); | ||
} | ||
|
||
[TestCase(7_000_000)] | ||
public void TestFasterQuadraticCorrectness(int N) | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I don't think these tests bring some value to the quality control of the code base. Performance benchmarks are indeed significant, but the results are not tested. And the correctness test is checking the same results, since the faster quadratic implementation is the same as the production code There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I gave up committing and don't want to explain such an obvious problem anymore. |
||
{ | ||
Complex x1, x2, x3, x4; | ||
double a, b, c, b_a, c_a; | ||
for (int i = 11; i < N; i++) | ||
{ | ||
a = -0.1 * i + 1.0; | ||
b = 0.3464 * i + 5.0; | ||
c = -0.3 * i + 7.0; | ||
b_a = -b / a; | ||
c_a = c / a; | ||
(x1, x2) = FasterQuadratic(c, b, a); | ||
(x3, x4) = FindRoots.Quadratic(c, b, a); | ||
AssertComplexEqual(x1 + x2, b_a, 1e-14); | ||
AssertComplexEqual(x1 * x2, c_a, 1e-14); | ||
AssertComplexEqual(x3 + x4, x1 + x2, 1e-14); | ||
AssertComplexEqual(x3 * x4, x1 * x2, 1e-14); | ||
} | ||
} | ||
} | ||
} |
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -43,7 +43,7 @@ public static class FindRoots | |
/// <param name="maxIterations">Maximum number of iterations. Example: 100.</param> | ||
public static double OfFunction(Func<double, double> f, double lowerBound, double upperBound, double accuracy = 1e-8, int maxIterations = 100) | ||
{ | ||
if (!ZeroCrossingBracketing.ExpandReduce(f, ref lowerBound, ref upperBound, 1.6, maxIterations, maxIterations*10)) | ||
if (!ZeroCrossingBracketing.ExpandReduce(f, ref lowerBound, ref upperBound, 1.6, maxIterations, maxIterations * 10)) | ||
{ | ||
throw new NonConvergenceException("The algorithm has failed, exceeded the number of iterations allowed or there is no root within the provided bounds."); | ||
} | ||
|
@@ -86,17 +86,40 @@ public static double OfFunctionDerivative(Func<double, double> f, Func<double, d | |
/// </summary> | ||
public static (Complex, Complex) Quadratic(double c, double b, double a) | ||
{ | ||
if (b == 0d) | ||
Complex x1, x2; | ||
if (a == 0.0) | ||
{ | ||
var t = new Complex(-c/a, 0d).SquareRoot(); | ||
return (t, -t); | ||
if (b == 0.0) | ||
{ | ||
x1 = Complex.Zero / Complex.Zero; // Complex.NaN; | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Why not write here directly There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Because "Complex.NaN" is not available in net48 and netstandard2.0, it is introduced since net5.0, "Complex.Zero / Complex.Zero" or "Complex(double.NaN, double.NaN)" is used as a substitute for "Complex.NaN" in early versions. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Ok, understood. Then I would preferably write |
||
x2 = x1; | ||
} | ||
else | ||
{ | ||
x1 = new Complex(-c / b, 0.0); | ||
x2 = x1; | ||
} | ||
} | ||
|
||
var q = b > 0d | ||
? -0.5*(b + new Complex(b*b - 4*a*c, 0d).SquareRoot()) | ||
: -0.5*(b - new Complex(b*b - 4*a*c, 0d).SquareRoot()); | ||
|
||
return (q/a, c/q); | ||
else | ||
{ | ||
a = 1.0 / a; | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I find this a bit confusing, I would prefer the standard textbook notation of There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. The computational cost of Complex( , ).SquareRoot() and Complex division(return ( , c/q)) is too high and completely unnecessary. So I use Math.Sqrt() to perform the square root operation of real numbers, which has much lower computational cost. As for why it's not written in the well-known form x1,x2=(-b±sqrt(b * b-4 * a * c))/(2 * a), it is because we need to reduce the number of floating-point multiplication and division operations, especially division(which costs too much time), and temporarily store some intermediate results to avoid repeated calculations. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. It doesn't seem to me that performance here is a big deal, I think the discussion leans more towards premature optimization. Did you at least make any benchmarking tests in various platforms and architectures which backup this claim? Just to make my earlier comment clear: I wouldn't reject the PR purely on this, I think unit tests are more critical, and even though performance "might" be better currently (a big IF), I think the readability is more important in such a big code base. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. public static (Complex, Complex) FasterQuadratic(double c, double b, double a)
{
Complex x1, x2;
if (a == 0.0)
{
if (b == 0.0)
{
x1 = new Complex(double.NaN, double.NaN); // Complex.NaN;
x2 = x1;
}
else
{
x1 = new Complex(-c / b, 0.0);
x2 = x1;
}
}
else
{
a = 1.0 / a;
b = -0.5 * b * a;
c = c * a;
double delta = b * b - c;
if (delta < 0.0)
{
double sqrtDelta = Math.Sqrt(-delta);
x1 = new Complex(b, sqrtDelta);
x2 = new Complex(b, -sqrtDelta);
}
else
{
double sqrtDelta = Math.Sqrt(delta);
x1 = new Complex(b + sqrtDelta, 0.0);
x2 = new Complex(b - sqrtDelta, 0.0);
}
}
return (x1, x2);
}
[TestCase(70000000)]
public void TestQuadraticSpeed(int N)
{
Complex x1, x2;
double sum1 = 0.0;
double a, b, c;
Stopwatch stopwatch = new Stopwatch();
for (int i = 11; i < N; i++)
{
a = -0.1 * i + 1.0;
b = 0.3464 * i + 5.0;
c = -0.3 * i + 7.0;
(x1, x2) = FindRoots.Quadratic(c, b, a);
sum1 += x1.Real + x2.Real + Math.Abs(x2.Imaginary);
}
Console.WriteLine($"sum1={sum1}");
stopwatch.Stop();
Console.WriteLine($"FindRoots.Quadratic time: {stopwatch.ElapsedMilliseconds * 0.001:F3}s");
double sum2 = 0.0;
stopwatch.Restart();
for (int i = 11; i < N; i++)
{
a = -0.1 * i + 1.0;
b = 0.3464 * i + 5.0;
c = -0.3 * i + 7.0;
(x1, x2) = FasterQuadratic(c, b, a);
sum2 += x1.Real + x2.Real + Math.Abs(x2.Imaginary);
}
Console.WriteLine($"sum2={sum2}");
stopwatch.Stop();
Console.WriteLine($"FasterQuadratic time: {stopwatch.ElapsedMilliseconds * 0.001:F3}s");
Assert.AreEqual(sum1, sum2, 1e-12 * sum2);
}
[TestCase(7000000)]
public void TestFasterQuadraticCorrectness(int N)
{
Complex x1, x2, x3, x4;
double a, b, c, b_a, c_a;
for (int i = 11; i < N; i++)
{
a = -0.1 * i + 1.0;
b = 0.3464 * i + 5.0;
c = -0.3 * i + 7.0;
b_a = -b / a;
c_a = c / a;
(x1, x2) = FasterQuadratic(c, b, a);
(x3, x4) = FindRoots.Quadratic(c, b, a);
AssertComplexEqual(x1 + x2, b_a, 1e-14);
AssertComplexEqual(x1 * x2, c_a, 1e-14);
AssertComplexEqual(x3 + x4, x1 + x2, 1e-14);
AssertComplexEqual(x3 * x4, x1 * x2, 1e-14);
}
} There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. FasterQuadratic is 3.8 times faster than FindRoots.Quadratic, and the latter does not handle the case of a=0, which is a bug. |
||
b = -0.5 * b * a; | ||
c = c * a; | ||
double delta = b * b - c; | ||
if (delta < 0.0) | ||
{ | ||
double sqrtDelta = Math.Sqrt(-delta); | ||
x1 = new Complex(b, sqrtDelta); | ||
x2 = new Complex(b, -sqrtDelta); | ||
} | ||
else | ||
{ | ||
double sqrtDelta = Math.Sqrt(delta); | ||
x1 = new Complex(b + sqrtDelta, 0.0); | ||
x2 = new Complex(b - sqrtDelta, 0.0); | ||
} | ||
} | ||
return (x1, x2); | ||
} | ||
|
||
/// <summary> | ||
|
@@ -143,16 +166,16 @@ public static double[] ChebychevPolynomialFirstKind(int degree, double intervalB | |
} | ||
|
||
// transform to map to [-1..1] interval | ||
double location = 0.5*(intervalBegin + intervalEnd); | ||
double scale = 0.5*(intervalEnd - intervalBegin); | ||
double location = 0.5 * (intervalBegin + intervalEnd); | ||
double scale = 0.5 * (intervalEnd - intervalBegin); | ||
|
||
// evaluate first kind chebychev nodes | ||
double angleFactor = Constants.Pi/(2*degree); | ||
double angleFactor = Constants.Pi / (2 * degree); | ||
|
||
var samples = new double[degree]; | ||
for (int i = 0; i < samples.Length; i++) | ||
{ | ||
samples[i] = location + scale*Math.Cos(((2*i) + 1)*angleFactor); | ||
samples[i] = location + scale * Math.Cos(((2 * i) + 1) * angleFactor); | ||
} | ||
return samples; | ||
} | ||
|
@@ -172,16 +195,16 @@ public static double[] ChebychevPolynomialSecondKind(int degree, double interval | |
} | ||
|
||
// transform to map to [-1..1] interval | ||
double location = 0.5*(intervalBegin + intervalEnd); | ||
double scale = 0.5*(intervalEnd - intervalBegin); | ||
double location = 0.5 * (intervalBegin + intervalEnd); | ||
double scale = 0.5 * (intervalEnd - intervalBegin); | ||
|
||
// evaluate second kind chebychev nodes | ||
double angleFactor = Constants.Pi/(degree + 1); | ||
double angleFactor = Constants.Pi / (degree + 1); | ||
|
||
var samples = new double[degree]; | ||
for (int i = 0; i < samples.Length; i++) | ||
{ | ||
samples[i] = location + scale*Math.Cos((i + 1)*angleFactor); | ||
samples[i] = location + scale * Math.Cos((i + 1) * angleFactor); | ||
} | ||
return samples; | ||
} | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Isn't this the exact same implementation of
FindRoots.Quadratic
?