Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Quadratic deal with a=0 and speed up #1103

Closed
wants to merge 2 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
130 changes: 112 additions & 18 deletions src/Numerics.Tests/RootFindingTests/FindRootsTest.cs
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@
using System;
using Complex = System.Numerics.Complex;
using NUnit.Framework;
using System.Diagnostics;

namespace MathNet.Numerics.Tests.RootFindingTests
{
Expand All @@ -40,7 +41,7 @@ internal class FindRootsTest
public void MultipleRoots()
{
// Roots at -2, 2
Func<double, double> f1 = x => x*x - 4;
Func<double, double> f1 = x => x * x - 4;
Assert.AreEqual(0, f1(FindRoots.OfFunction(f1, 0, 5, 1e-14)));
Assert.AreEqual(-2, FindRoots.OfFunction(f1, -5, -1, 1e-14));
Assert.AreEqual(2, FindRoots.OfFunction(f1, 1, 4, 1e-14));
Expand All @@ -49,7 +50,7 @@ public void MultipleRoots()
Assert.AreEqual(2, FindRoots.OfFunction(x => -f1(x), 1, 4, 1e-14));

// Roots at 3, 4
Func<double, double> f2 = x => (x - 3)*(x - 4);
Func<double, double> f2 = x => (x - 3) * (x - 4);
Assert.AreEqual(0, f2(FindRoots.OfFunction(f2, 3.5, 5, 1e-14)), 1e-14);
Assert.AreEqual(3, FindRoots.OfFunction(f2, -5, 3.5, 1e-14));
Assert.AreEqual(4, FindRoots.OfFunction(f2, 3.2, 5, 1e-14));
Expand All @@ -60,23 +61,23 @@ public void MultipleRoots()
[Test]
public void LocalMinima()
{
Func<double, double> f1 = x => x*x*x - 2*x + 2;
Func<double, double> f1 = x => x * x * x - 2 * x + 2;
Assert.AreEqual(0, f1(FindRoots.OfFunction(f1, -5, 5, 1e-14)), 1e-14);
Assert.AreEqual(0, f1(FindRoots.OfFunction(f1, -2, 4, 1e-14)), 1e-14);
}

[Test]
public void NoRoot()
{
Func<double, double> f1 = x => x*x + 4;
Func<double, double> f1 = x => x * x + 4;
Assert.That(() => FindRoots.OfFunction(f1, -5, 5, 1e-14), Throws.TypeOf<NonConvergenceException>());
}

[Test]
public void Oneeq3()
{
// Test case from http://www.polymath-software.com/library/nle/Oneeq3.htm
Func<double, double> f1 = T => Math.Exp(21000/T)/(T*T) - 1.11e11;
Func<double, double> f1 = T => Math.Exp(21000 / T) / (T * T) - 1.11e11;
double x = FindRoots.OfFunction(f1, 550, 560, 1e-12);
Assert.AreEqual(551.773822885233, x, 1e-5);
Assert.AreEqual(0, f1(x), 1e-2);
Expand All @@ -92,7 +93,7 @@ public void Oneeq5()
const double BETA = 2.298E-3;
const double GAMMA = 0.283E-6;
const double DH = -57798.0;
return DH + TR*(ALPHA + TR*(BETA/2 + TR*GAMMA/3)) - 298.0*(ALPHA + 298.0*(BETA/2 + 298.0*GAMMA/3));
return DH + TR * (ALPHA + TR * (BETA / 2 + TR * GAMMA / 3)) - 298.0 * (ALPHA + 298.0 * (BETA / 2 + 298.0 * GAMMA / 3));
};

double x = FindRoots.OfFunction(f1, 3000, 5000, 1e-10);
Expand All @@ -114,11 +115,11 @@ public void Oneeq6a()
const double A = 0.01855;
const double B = -0.01587;
const double P = 100.0;
const double Beta = R*T*B0 - A0 - R*C/(T*T);
const double Gama = -R*T*B0*B + A0*A - R*C*B0/(T*T);
const double Delta = R*B0*B*C/(T*T);
const double Beta = R * T * B0 - A0 - R * C / (T * T);
const double Gama = -R * T * B0 * B + A0 * A - R * C * B0 / (T * T);
const double Delta = R * B0 * B * C / (T * T);

return R*T/V + Beta/(V*V) + Gama/(V*V*V) + Delta/(V*V*V*V) - P;
return R * T / V + Beta / (V * V) + Gama / (V * V * V) + Delta / (V * V * V * V) - P;
};

double x = FindRoots.OfFunction(f1, 0.1, 1);
Expand All @@ -130,7 +131,7 @@ public void Oneeq6a()
public void Oneeq7()
{
// Test case from http://www.polymath-software.com/library/nle/Oneeq7.htm
Func<double, double> f1 = x => x/(1 - x) - 5*Math.Log(0.4*(1 - x)/(0.4 - 0.5*x)) + 4.45977;
Func<double, double> f1 = x => x / (1 - x) - 5 * Math.Log(0.4 * (1 - x) / (0.4 - 0.5 * x)) + 4.45977;
double r = FindRoots.OfFunction(f1, 0, 0.79, 1e-10);
Assert.AreEqual(0.757396293891, r, 1e-6);
Assert.AreEqual(0, f1(r), 1e-6);
Expand All @@ -146,7 +147,7 @@ public void Oneeq8()
const double b = 40;
const double c = 200;

return a*v*v + b*Math.Pow(v, 7/4) - c;
return a * v * v + b * Math.Pow(v, 7 / 4) - c;
};

double x = FindRoots.OfFunction(f1, 0.01, 1);
Expand All @@ -158,7 +159,7 @@ public void Oneeq8()
public void StackOverflow39935588()
{
// Roots at -2, 2
Func<double, double> f1 = x => (x - 3.0)*(x - 4.0);
Func<double, double> f1 = x => (x - 3.0) * (x - 4.0);
Assert.AreEqual(3.0, FindRoots.OfFunction(f1, -2.0, 3.5), 1e-10);
Assert.AreEqual(4.0, FindRoots.OfFunction(f1, 3.5, 5.5), 1e-10);
Assert.AreEqual(0.0, f1(FindRoots.OfFunction(f1, -2.0, 5.5, 1e-14)), 1e-14);
Expand All @@ -177,7 +178,7 @@ void AssertComplexEqual(Complex expected, Complex actual, double delta)
public void QuadraticExpanded(double u, double v, double t)
{
// t*(x+u)*(x+v) = t*u*v + t*(u+v)*x + t*x^2
double c = t*u*v, b = t*(u + v), a = t;
double c = t * u * v, b = t * (u + v), a = t;
var x = FindRoots.Quadratic(c, b, a);
Complex x1 = x.Item1, x2 = x.Item2;

Expand All @@ -187,8 +188,8 @@ public void QuadraticExpanded(double u, double v, double t)
|| x1.AlmostEqualRelative(r2, 1e-14) && x2.AlmostEqualRelative(r1, 1e-14));

// Verify they really are roots
AssertComplexEqual(Complex.Zero, c + b*x1 + a*x1*x1, 1e-14);
AssertComplexEqual(Complex.Zero, c + b*x2 + a*x2*x2, 1e-14);
AssertComplexEqual(Complex.Zero, c + b * x1 + a * x1 * x1, 1e-14);
AssertComplexEqual(Complex.Zero, c + b * x2 + a * x2 * x2, 1e-14);
}

[TestCase(1d, 1d, 1d, -0.5, -0.866025403784439, -0.5, 0.866025403784439)]
Expand All @@ -207,8 +208,101 @@ public void QuadraticExpected(double c, double b, double a, double x1R, double x
|| x1.AlmostEqualRelative(r2, 1e-14) && x2.AlmostEqualRelative(r1, 1e-14));

// Verify they really are roots
AssertComplexEqual(Complex.Zero, c + b*x1 + a*x1*x1, 1e-14);
AssertComplexEqual(Complex.Zero, c + b*x2 + a*x2*x2, 1e-14);
AssertComplexEqual(Complex.Zero, c + b * x1 + a * x1 * x1, 1e-14);
AssertComplexEqual(Complex.Zero, c + b * x2 + a * x2 * x2, 1e-14);
}

public static (Complex, Complex) FasterQuadratic(double c, double b, double a)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Isn't this the exact same implementation of FindRoots.Quadratic?

{
Complex x1, x2;
if (a == 0.0)
{
if (b == 0.0)
{
x1 = new Complex(double.NaN, double.NaN); // Complex.NaN;
x2 = x1;
}
else
{
x1 = new Complex(-c / b, 0.0);
x2 = x1;
}
}
else
{
a = 1.0 / a;
b = -0.5 * b * a;
c = c * a;
double delta = b * b - c;
if (delta < 0.0)
{
double sqrtDelta = Math.Sqrt(-delta);
x1 = new Complex(b, sqrtDelta);
x2 = new Complex(b, -sqrtDelta);
}
else
{
double sqrtDelta = Math.Sqrt(delta);
x1 = new Complex(b + sqrtDelta, 0.0);
x2 = new Complex(b - sqrtDelta, 0.0);
}
}
return (x1, x2);
}

[TestCase(70_000_000)]
public void TestQuadraticSpeed(int N)
{
Complex x1, x2;
double sum = 0.0;
double a, b, c;
Stopwatch stopwatch = new Stopwatch();
for (int i = 11; i < N; i++)
{
a = -0.1 * i + 1.0;
b = 0.3464 * i + 5.0;
c = -0.3 * i + 7.0;
(x1, x2) = FindRoots.Quadratic(c, b, a);
sum += x1.Real + x2.Imaginary;
}
Console.WriteLine($"sum={sum}");
stopwatch.Stop();
Console.WriteLine($"FindRoots.Quadratic time: {stopwatch.ElapsedMilliseconds * 0.001:F3}s");

sum = 0.0;
stopwatch.Restart();
for (int i = 11; i < N; i++)
{
a = -0.1 * i + 1.0;
b = 0.3464 * i + 5.0;
c = -0.3 * i + 7.0;
(x1, x2) = FasterQuadratic(c, b, a);
sum += x1.Real + x2.Imaginary;
}
Console.WriteLine($"sum={sum}");
stopwatch.Stop();
Console.WriteLine($"FasterQuadratic time: {stopwatch.ElapsedMilliseconds * 0.001:F3}s");
}

[TestCase(7_000_000)]
public void TestFasterQuadraticCorrectness(int N)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I don't think these tests bring some value to the quality control of the code base. Performance benchmarks are indeed significant, but the results are not tested. And the correctness test is checking the same results, since the faster quadratic implementation is the same as the production code

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I gave up committing and don't want to explain such an obvious problem anymore.

{
Complex x1, x2, x3, x4;
double a, b, c, b_a, c_a;
for (int i = 11; i < N; i++)
{
a = -0.1 * i + 1.0;
b = 0.3464 * i + 5.0;
c = -0.3 * i + 7.0;
b_a = -b / a;
c_a = c / a;
(x1, x2) = FasterQuadratic(c, b, a);
(x3, x4) = FindRoots.Quadratic(c, b, a);
AssertComplexEqual(x1 + x2, b_a, 1e-14);
AssertComplexEqual(x1 * x2, c_a, 1e-14);
AssertComplexEqual(x3 + x4, x1 + x2, 1e-14);
AssertComplexEqual(x3 * x4, x1 * x2, 1e-14);
}
}
}
}
59 changes: 41 additions & 18 deletions src/Numerics/FindRoots.cs
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,7 @@ public static class FindRoots
/// <param name="maxIterations">Maximum number of iterations. Example: 100.</param>
public static double OfFunction(Func<double, double> f, double lowerBound, double upperBound, double accuracy = 1e-8, int maxIterations = 100)
{
if (!ZeroCrossingBracketing.ExpandReduce(f, ref lowerBound, ref upperBound, 1.6, maxIterations, maxIterations*10))
if (!ZeroCrossingBracketing.ExpandReduce(f, ref lowerBound, ref upperBound, 1.6, maxIterations, maxIterations * 10))
{
throw new NonConvergenceException("The algorithm has failed, exceeded the number of iterations allowed or there is no root within the provided bounds.");
}
Expand Down Expand Up @@ -86,17 +86,40 @@ public static double OfFunctionDerivative(Func<double, double> f, Func<double, d
/// </summary>
public static (Complex, Complex) Quadratic(double c, double b, double a)
{
if (b == 0d)
Complex x1, x2;
if (a == 0.0)
{
var t = new Complex(-c/a, 0d).SquareRoot();
return (t, -t);
if (b == 0.0)
{
x1 = Complex.Zero / Complex.Zero; // Complex.NaN;
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why not write here directly Complex.NaN?

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Because "Complex.NaN" is not available in net48 and netstandard2.0, it is introduced since net5.0, "Complex.Zero / Complex.Zero" or "Complex(double.NaN, double.NaN)" is used as a substitute for "Complex.NaN" in early versions.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Ok, understood. Then I would preferably write Complex(double.NaN, double.NaN)

x2 = x1;
}
else
{
x1 = new Complex(-c / b, 0.0);
x2 = x1;
}
}

var q = b > 0d
? -0.5*(b + new Complex(b*b - 4*a*c, 0d).SquareRoot())
: -0.5*(b - new Complex(b*b - 4*a*c, 0d).SquareRoot());

return (q/a, c/q);
else
{
a = 1.0 / a;
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I find this a bit confusing, I would prefer the standard textbook notation of b^2 - 4ac

Copy link
Author

@FrogGuaGuaGua FrogGuaGuaGua Dec 23, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The computational cost of Complex( , ).SquareRoot() and Complex division(return ( , c/q)) is too high and completely unnecessary. So I use Math.Sqrt() to perform the square root operation of real numbers, which has much lower computational cost. As for why it's not written in the well-known form x1,x2=(-b±sqrt(b * b-4 * a * c))/(2 * a), it is because we need to reduce the number of floating-point multiplication and division operations, especially division(which costs too much time), and temporarily store some intermediate results to avoid repeated calculations.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It doesn't seem to me that performance here is a big deal, I think the discussion leans more towards premature optimization. Did you at least make any benchmarking tests in various platforms and architectures which backup this claim?

Just to make my earlier comment clear: I wouldn't reject the PR purely on this, I think unit tests are more critical, and even though performance "might" be better currently (a big IF), I think the readability is more important in such a big code base.

Copy link
Author

@FrogGuaGuaGua FrogGuaGuaGua Dec 29, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

public static (Complex, Complex) FasterQuadratic(double c, double b, double a)
{
   Complex x1, x2;
   if (a == 0.0)
   {
	   if (b == 0.0)
	   {
		   x1 = new Complex(double.NaN, double.NaN);  // Complex.NaN;
		   x2 = x1;
	   }
	   else
	   {
		   x1 = new Complex(-c / b, 0.0);
		   x2 = x1;
	   }
   }
   else
   {
	   a = 1.0 / a;
	   b = -0.5 * b * a;
	   c = c * a;
	   double delta = b * b - c;
	   if (delta < 0.0)
	   {
		   double sqrtDelta = Math.Sqrt(-delta);
		   x1 = new Complex(b, sqrtDelta);
		   x2 = new Complex(b, -sqrtDelta);
	   }
	   else
	   {
		   double sqrtDelta = Math.Sqrt(delta);
		   x1 = new Complex(b + sqrtDelta, 0.0);
		   x2 = new Complex(b - sqrtDelta, 0.0);
	   }
   }
   return (x1, x2);
}

[TestCase(70000000)]
public void TestQuadraticSpeed(int N)
{
   Complex x1, x2;
   double sum1 = 0.0;
   double a, b, c;
   Stopwatch stopwatch = new Stopwatch();
   for (int i = 11; i < N; i++)
   {
	   a = -0.1 * i + 1.0;
	   b = 0.3464 * i + 5.0;
	   c = -0.3 * i + 7.0;
	   (x1, x2) = FindRoots.Quadratic(c, b, a);
	   sum1 += x1.Real + x2.Real + Math.Abs(x2.Imaginary);
   }
   Console.WriteLine($"sum1={sum1}");
   stopwatch.Stop();
   Console.WriteLine($"FindRoots.Quadratic time: {stopwatch.ElapsedMilliseconds * 0.001:F3}s");

   double sum2 = 0.0;
   stopwatch.Restart();
   for (int i = 11; i < N; i++)
   {
	   a = -0.1 * i + 1.0;
	   b = 0.3464 * i + 5.0;
	   c = -0.3 * i + 7.0;
	   (x1, x2) = FasterQuadratic(c, b, a);
	   sum2 += x1.Real + x2.Real + Math.Abs(x2.Imaginary);
   }
   Console.WriteLine($"sum2={sum2}");
   stopwatch.Stop();
   Console.WriteLine($"FasterQuadratic time: {stopwatch.ElapsedMilliseconds * 0.001:F3}s");
   Assert.AreEqual(sum1, sum2, 1e-12 * sum2);
}

[TestCase(7000000)]
public void TestFasterQuadraticCorrectness(int N)
{
   Complex x1, x2, x3, x4;
   double a, b, c, b_a, c_a;
   for (int i = 11; i < N; i++)
   {
	   a = -0.1 * i + 1.0;
	   b = 0.3464 * i + 5.0;
	   c = -0.3 * i + 7.0;
	   b_a = -b / a;
	   c_a = c / a;
	   (x1, x2) = FasterQuadratic(c, b, a);
	   (x3, x4) = FindRoots.Quadratic(c, b, a);
	   AssertComplexEqual(x1 + x2, b_a, 1e-14);
	   AssertComplexEqual(x1 * x2, c_a, 1e-14);
	   AssertComplexEqual(x3 + x4, x1 + x2, 1e-14);
	   AssertComplexEqual(x3 * x4, x1 * x2, 1e-14);
   }
}

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

FasterQuadratic is 3.8 times faster than FindRoots.Quadratic, and the latter does not handle the case of a=0, which is a bug.

b = -0.5 * b * a;
c = c * a;
double delta = b * b - c;
if (delta < 0.0)
{
double sqrtDelta = Math.Sqrt(-delta);
x1 = new Complex(b, sqrtDelta);
x2 = new Complex(b, -sqrtDelta);
}
else
{
double sqrtDelta = Math.Sqrt(delta);
x1 = new Complex(b + sqrtDelta, 0.0);
x2 = new Complex(b - sqrtDelta, 0.0);
}
}
return (x1, x2);
}

/// <summary>
Expand Down Expand Up @@ -143,16 +166,16 @@ public static double[] ChebychevPolynomialFirstKind(int degree, double intervalB
}

// transform to map to [-1..1] interval
double location = 0.5*(intervalBegin + intervalEnd);
double scale = 0.5*(intervalEnd - intervalBegin);
double location = 0.5 * (intervalBegin + intervalEnd);
double scale = 0.5 * (intervalEnd - intervalBegin);

// evaluate first kind chebychev nodes
double angleFactor = Constants.Pi/(2*degree);
double angleFactor = Constants.Pi / (2 * degree);

var samples = new double[degree];
for (int i = 0; i < samples.Length; i++)
{
samples[i] = location + scale*Math.Cos(((2*i) + 1)*angleFactor);
samples[i] = location + scale * Math.Cos(((2 * i) + 1) * angleFactor);
}
return samples;
}
Expand All @@ -172,16 +195,16 @@ public static double[] ChebychevPolynomialSecondKind(int degree, double interval
}

// transform to map to [-1..1] interval
double location = 0.5*(intervalBegin + intervalEnd);
double scale = 0.5*(intervalEnd - intervalBegin);
double location = 0.5 * (intervalBegin + intervalEnd);
double scale = 0.5 * (intervalEnd - intervalBegin);

// evaluate second kind chebychev nodes
double angleFactor = Constants.Pi/(degree + 1);
double angleFactor = Constants.Pi / (degree + 1);

var samples = new double[degree];
for (int i = 0; i < samples.Length; i++)
{
samples[i] = location + scale*Math.Cos((i + 1)*angleFactor);
samples[i] = location + scale * Math.Cos((i + 1) * angleFactor);
}
return samples;
}
Expand Down
Loading