Skip to content

Commit

Permalink
NormalCdfRatioLn tests work correctly in arbitrary precision (dotnet#250
Browse files Browse the repository at this point in the history
)
  • Loading branch information
tminka committed Aug 24, 2020
1 parent 870fb13 commit 4b10322
Show file tree
Hide file tree
Showing 5 changed files with 80 additions and 16 deletions.
2 changes: 1 addition & 1 deletion src/Runtime/Core/Maths/SpecialFunctions.cs
Original file line number Diff line number Diff line change
Expand Up @@ -2418,7 +2418,7 @@ public static double NormalCdfRatioLn(double x, double y, double r, double sqrto
{
return NormalCdfRatioLn(Math.Min(x, y)) + MMath.LnSqrt2PI;
}
else if (r == -1)
else if (r == -1 && sqrtomr2 == 0)
{
// In this case, we should subtract log N(y;0,1)
bool shouldThrow = true;
Expand Down
25 changes: 25 additions & 0 deletions src/Runtime/Factors/IndexOfMaximum.cs
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ namespace Microsoft.ML.Probabilistic.Factors
{
using System;
using System.Collections.Generic;
using System.Diagnostics;
using System.Linq;
using Collections;
using Distributions;
Expand Down Expand Up @@ -328,6 +329,30 @@ public static void Unpermute<T>(IList<T> array, int[] indices)
}
}

public static void MaxOfOthers_MonteCarlo(IList<Gaussian> array, IList<Gaussian> result)
{
if (array.Count == 0)
return;
if (array.Count == 1)
{
result[0] = Gaussian.Uniform();
return;
}
int iterCount = 1000000;
double[] x = new double[array.Count];
var est = new ArrayEstimator<GaussianEstimator, IList<Gaussian>, Gaussian, double>(array.Count, i => new GaussianEstimator());
for (int iter = 0; iter < iterCount; iter++)
{
for (int i = 0; i < x.Length; i++)
{
x[i] = array[i].Sample();
}
double[] maxOfOthers = Factor.MaxOfOthers(x);
est.Add(maxOfOthers);
}
est.GetDistribution(result);
}

public static void MaxOfOthers_Quadratic(IList<Gaussian> array, IList<Gaussian> result)
{
if (array.Count == 0)
Expand Down
34 changes: 23 additions & 11 deletions src/Tools/PythonScripts/ComputeSpecialFunctionsTestValues.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,12 +64,13 @@ def normal_cdf2(x, y, r):
x = y
y = z

if r < 0 and False:
# Avoid quadrature with r < 0 since it is sometimes inaccurate.
if r < 0 and x - y <= 0:
# phi(x,y,r) = phi(inf,y,r) - phi(-x,y,-r)
# phi(x,y,r) = phi(x,inf,r) - phi(x,-y,-r)
return ncdf(x) - normal_cdf2(x, -y, -r)

if x > 0 and x + y > 0 and False:
if x > 0 and -x + y <= 0:
return ncdf(y) - normal_cdf2(-x,y,-r)

if x + y > 0:
Expand All @@ -85,18 +86,20 @@ def f(t):
return 1 / (2 * pi * sqrt(omt2)) * exp(-(x * x + y * y - 2 * t * x * y) / (2 * omt2))

omr2 = (1+r)*(1-r)
ymrx = y - r*x
def f2(t):
return npdf(t - x) * normal_cdf((y - r*(x-t))/omr2)
return npdf(t - x) * normal_cdf((ymrx + r*t)/omr2)

# This integral excludes normal_cdf2(x,y,-1)
# which will be zero when x+y <= 0
result, err = safe_quad(f, [-1, r])
#result, err = safe_quad(f2, [0, inf]) # this method has no offset

if mpf(10)**output_dps * abs(err) > abs(result):
print(f"Suspiciously big error when evaluating an integral for normal_cdf2({nstr(x)}, {nstr(y)}, {nstr(r)}).")
print(f"Integral: {nstr(result)}")
print(f"Integral error estimate: {nstr(err)}")
result, err = safe_quad(f2, [0, inf])
if mpf(10)**output_dps * abs(err) > abs(result):
print(f"Suspiciously big error when evaluating an integral for normal_cdf2({nstr(x)}, {nstr(y)}, {nstr(r)}).")
print(f"Integral: {nstr(result)}")
print(f"Integral error estimate: {nstr(err)}")
return result

def safe_quad(f, points):
Expand All @@ -121,7 +124,11 @@ def normal_cdf2_ln(x, y, r):
return ln(normal_cdf2(x, y, r))

def normal_cdf2_ratio_ln(x, y, r, sqrtomr2):
omr2 = 1-r*r
if sqrtomr2 < 0.618:
omr2 = sqrtomr2*sqrtomr2
r = sign(r)*sqrt(1 - omr2)
else:
omr2 = 1-r*r
return normal_cdf2_ln(x, y, r) + (x*x+y*y-2*r*x*y)/2/omr2 + log(2*pi)

def logistic_gaussian(m, v):
Expand Down Expand Up @@ -317,7 +324,7 @@ def beta_cdf(x, a, b):
'NormalCdfLn2.csv': normal_cdf2_ln,
'NormalCdfLogit.csv': lambda x: log(ncdf(x)) - log(ncdf(-x)),
'NormalCdfMomentRatio.csv': normal_cdf_moment_ratio,
#'NormalCdfRatioLn2.csv': normal_cdf2_ratio_ln,
'NormalCdfRatioLn2.csv': normal_cdf2_ratio_ln,
'Tetragamma.csv': lambda x: polygamma(2, x),
'Trigamma.csv': lambda x: polygamma(1, x),
'ulp.csv': None
Expand All @@ -332,7 +339,7 @@ def float_str_python_to_csharp(s):
dir = os.path.join(os.path.dirname(os.path.realpath(__file__)), '..', '..', '..', 'test', 'Tests', 'data', 'SpecialFunctionsValues')
with os.scandir(dir) as it:
for entry in it:
if entry.name.endswith('.csv') and entry.is_file() and entry.name == 'logisticGaussian.csv':
if entry.name.endswith('.csv') and entry.is_file():
print(f'Processing {entry.name}...')
if entry.name not in pair_info.keys() or pair_info[entry.name] == None:
print("Don't know how to process. Skipping.")
Expand All @@ -344,13 +351,18 @@ def float_str_python_to_csharp(s):
arg_count = len(fieldnames) - 1
newrows = []
for row in reader:
if entry.name == 'NormalCdfRatioLn2.csv':
sqrtomr2 = mpf(float_str_csharp_to_python(row['arg3']))
r = mpf(float_str_csharp_to_python(row['arg2']))
if sqrtomr2 < 0.618:
row['arg2'] = nstr(sign(r)*sqrt(1-sqrtomr2*sqrtomr2), output_dps)
newrow = dict(row)
args = []
for i in range(arg_count):
args.append(mpf(float_str_csharp_to_python(row[f'arg{i}'])))
result_in_file = row['expectedresult']
verbose = True
if result_in_file == 'Infinity' or result_in_file == '-Infinity' or result_in_file == 'NaN' or len(newrows) != 127:
if result_in_file == 'Infinity' or result_in_file == '-Infinity' or result_in_file == 'NaN':
newrow['expectedresult'] = result_in_file
else:
try:
Expand Down
28 changes: 27 additions & 1 deletion test/Tests/Core/SpecialFunctionsTest.cs
Original file line number Diff line number Diff line change
Expand Up @@ -748,7 +748,7 @@ public void NormalCdf2Test()
CheckFunctionValues("NormalCdfLn2", new MathFcn3(MMath.NormalCdfLn), normalcdfln2_pairs);

double[,] normalcdfRatioLn2_pairs = ReadPairs(Path.Combine(TestUtils.DataFolderPath, "SpecialFunctionsValues", "NormalCdfRatioLn2.csv"));
CheckFunctionValues("NormalCdfRatioLn2", new MathFcn4(MMath.NormalCdfRatioLn), normalcdfRatioLn2_pairs);
CheckFunctionValues("NormalCdfRatioLn2", new MathFcn4(NormalCdfRatioLn), normalcdfRatioLn2_pairs);

// The true values are computed using
// x * MMath.NormalCdf(x, y, r) + System.Math.Exp(Gaussian.GetLogProb(x, 0, 1) + MMath.NormalCdfLn(ymrx)) + r * System.Math.Exp(Gaussian.GetLogProb(y, 0, 1) + MMath.NormalCdfLn(xmry))
Expand All @@ -762,6 +762,32 @@ public void NormalCdf2Test()
CheckFunctionValues("NormalCdfIntegralRatio", new MathFcn3(MMath.NormalCdfIntegralRatio), normalcdfIntegralRatio_pairs);
}

// Same as MMath.NormalCdfRatioLn but avoids inconsistent values of r and sqrtomr2 when using arbitrary precision.
private static double NormalCdfRatioLn(double x, double y, double r, double sqrtomr2)
{
if (sqrtomr2 < 0.618)
{
// In this regime, it is more accurate to compute r from sqrtomr2.
// Proof:
// In the presence of roundoff, sqrt(1 - sqrtomr2^2) becomes
// sqrt(1 - sqrtomr2^2*(1+eps))
// = sqrt(1 - sqrtomr2^2 - sqrtomr2^2*eps)
// =approx sqrt(1 - sqrtomr2^2) - sqrtomr2^2*eps*0.5/sqrt(1-sqrtomr2^2)
// The error is below machine precision when
// sqrtomr2^2/sqrt(1-sqrtomr2^2) < 1
// which is equivalent to sqrtomr2 < (sqrt(5)-1)/2 =approx 0.618
double omr2 = sqrtomr2 * sqrtomr2;
r = System.Math.Sign(r) * System.Math.Sqrt(1 - omr2);
}
else
{
// In this regime, it is more accurate to compute sqrtomr2 from r.
double omr2 = 1 - r * r;
sqrtomr2 = System.Math.Sqrt(omr2);
}
return MMath.NormalCdfRatioLn(x, y, r, sqrtomr2);
}

[Fact]
public void NormalCdfIntegralTest()
{
Expand Down
7 changes: 4 additions & 3 deletions test/Tests/Data/SpecialFunctionsValues/NormalCdfRatioLn2.csv
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
arg0,arg1,arg2,arg3,expectedresult
-9945842.62906782515347003936767578125,9945822.06800303049385547637939453125,-0.9891958110248051383450729190371930599212646484375,0.1466002982636051832354695534377242438495159149169921875,268535429229.677886962890625
-312498.3686245033168233931064605712890625,312498.2982211210182867944240570068359375,-0.9999893339082690513208717675297521054744720458984375,0.46186653587789346098180232047525350935757160186767578125e-2,249505.2149799815961159765720367431640625
379473319.22020542621612548828125,-379473319.22020542621612548828125,-0.9999999999999997779553950749686919152736663818359375,0.40991122317442558334301211400382304594902649341747746802866458892822265625e-9,-21.613063775509029795784954330883920192718505859375
-379473319.22020542621612548828125,379473319.22020542621612548828125,-0.99999999999999999991598639455782313129063619083291,0.40991122317442558334301211400382304594902649341747746802866458892822265625e-9,-21.613063775509028748108216091076265162573268973654
-9945842.62906782515347003936767578125,9945822.06800303049385547637939453125,-0.98919581102480513845229381540009496617995993256594,0.1466002982636051832354695534377242438495159149169921875,268535429229.67977366017465839378343324531827501415
-312498.3686245033168233931064605712890625,312498.2982211210182867944240570068359375,-0.99998933390826905132111329040156813258636837058731,0.46186653587789346098180232047525350935757160186767578125e-2,249505.2149808101614030573818221460088509383223012
379473319.22020542621612548828125,-379473319.22020542621612548828125,-0.99999999999999999991598639455782313129063619083291,0.40991122317442558334301211400382304594902649341747746802866458892822265625e-9,-21.613063775509028748108216091076265162573268973654

0 comments on commit 4b10322

Please sign in to comment.