Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

NormalCdfRatioLn tests work correctly in arbitrary precision #250

Merged
merged 3 commits into from
May 21, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion src/Runtime/Core/Maths/SpecialFunctions.cs
Original file line number Diff line number Diff line change
Expand Up @@ -2418,7 +2418,7 @@ public static double NormalCdfRatioLn(double x, double y, double r, double sqrto
{
return NormalCdfRatioLn(Math.Min(x, y)) + MMath.LnSqrt2PI;
}
else if (r == -1)
else if (r == -1 && sqrtomr2 == 0)
{
// In this case, we should subtract log N(y;0,1)
bool shouldThrow = true;
Expand Down
25 changes: 25 additions & 0 deletions src/Runtime/Factors/IndexOfMaximum.cs
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ namespace Microsoft.ML.Probabilistic.Factors
{
using System;
using System.Collections.Generic;
using System.Diagnostics;
using System.Linq;
using Collections;
using Distributions;
Expand Down Expand Up @@ -328,6 +329,30 @@ public static void Unpermute<T>(IList<T> array, int[] indices)
}
}

/// <summary>
/// Sampling-based estimate of the per-element result of <c>Factor.MaxOfOthers</c> applied to draws from <paramref name="array"/>.
/// </summary>
/// <param name="array">Gaussian distributions; one value is drawn from each of them on every iteration.</param>
/// <param name="result">Receives the estimated distributions. Left untouched when <paramref name="array"/> is empty.</param>
/// <remarks>
/// With a single input element, <c>result[0]</c> is set to <c>Gaussian.Uniform()</c> and no sampling is done.
/// Otherwise a fixed number of joint draws (one million) is pushed through <c>Factor.MaxOfOthers</c> and
/// accumulated in one <c>GaussianEstimator</c> per element.
/// </remarks>
public static void MaxOfOthers_MonteCarlo(IList<Gaussian> array, IList<Gaussian> result)
{
    int count = array.Count;
    if (count == 0)
    {
        return;
    }

    if (count == 1)
    {
        result[0] = Gaussian.Uniform();
        return;
    }

    // Number of joint draws used to build the estimate.
    const int sampleCount = 1000000;
    var estimator = new ArrayEstimator<GaussianEstimator, IList<Gaussian>, Gaussian, double>(count, _ => new GaussianEstimator());

    // Scratch buffer reused for every joint draw.
    double[] draws = new double[count];
    for (int sample = 0; sample < sampleCount; sample++)
    {
        for (int k = 0; k < draws.Length; k++)
        {
            draws[k] = array[k].Sample();
        }

        estimator.Add(Factor.MaxOfOthers(draws));
    }

    estimator.GetDistribution(result);
}

public static void MaxOfOthers_Quadratic(IList<Gaussian> array, IList<Gaussian> result)
{
if (array.Count == 0)
Expand Down
34 changes: 23 additions & 11 deletions src/Tools/PythonScripts/ComputeSpecialFunctionsTestValues.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,12 +64,13 @@ def normal_cdf2(x, y, r):
x = y
y = z

if r < 0 and False:
# Avoid quadrature with r < 0 since it is sometimes inaccurate.
if r < 0 and x - y <= 0:
# phi(x,y,r) = phi(inf,y,r) - phi(-x,y,-r)
# phi(x,y,r) = phi(x,inf,r) - phi(x,-y,-r)
return ncdf(x) - normal_cdf2(x, -y, -r)

if x > 0 and x + y > 0 and False:
if x > 0 and -x + y <= 0:
return ncdf(y) - normal_cdf2(-x,y,-r)

if x + y > 0:
Expand All @@ -85,18 +86,20 @@ def f(t):
return 1 / (2 * pi * sqrt(omt2)) * exp(-(x * x + y * y - 2 * t * x * y) / (2 * omt2))

omr2 = (1+r)*(1-r)
ymrx = y - r*x
def f2(t):
return npdf(t - x) * normal_cdf((y - r*(x-t))/omr2)
return npdf(t - x) * normal_cdf((ymrx + r*t)/omr2)

# This integral excludes normal_cdf2(x,y,-1)
# which will be zero when x+y <= 0
result, err = safe_quad(f, [-1, r])
#result, err = safe_quad(f2, [0, inf]) # this method has no offset

if mpf(10)**output_dps * abs(err) > abs(result):
print(f"Suspiciously big error when evaluating an integral for normal_cdf2({nstr(x)}, {nstr(y)}, {nstr(r)}).")
print(f"Integral: {nstr(result)}")
print(f"Integral error estimate: {nstr(err)}")
result, err = safe_quad(f2, [0, inf])
if mpf(10)**output_dps * abs(err) > abs(result):
print(f"Suspiciously big error when evaluating an integral for normal_cdf2({nstr(x)}, {nstr(y)}, {nstr(r)}).")
print(f"Integral: {nstr(result)}")
print(f"Integral error estimate: {nstr(err)}")
return result

def safe_quad(f, points):
Expand All @@ -121,7 +124,11 @@ def normal_cdf2_ln(x, y, r):
return ln(normal_cdf2(x, y, r))

def normal_cdf2_ratio_ln(x, y, r, sqrtomr2):
    """Return normal_cdf2_ln(x, y, r) + (x^2 + y^2 - 2*r*x*y) / (2*(1 - r^2)) + log(2*pi).

    r and sqrtomr2 (the caller's value for sqrt(1 - r^2)) are made consistent
    with each other before use:

    * sqrtomr2 < 0.618: 1 - r^2 is taken as sqrtomr2^2 and r is recomputed
      from it, keeping the sign of the supplied r.
    * otherwise: 1 - r^2 is computed from r and sqrtomr2 is not used.

    The 0.618 cutoff is the same one applied by the C# test helper
    NormalCdfRatioLn and by this script when it rewrites 'arg2' for
    NormalCdfRatioLn2.csv rows; keep the three in agreement.
    """
    if sqrtomr2 < 0.618:
        omr2 = sqrtomr2*sqrtomr2
        r = sign(r)*sqrt(1 - omr2)
    else:
        omr2 = 1-r*r
    return normal_cdf2_ln(x, y, r) + (x*x+y*y-2*r*x*y)/2/omr2 + log(2*pi)

def logistic_gaussian(m, v):
Expand Down Expand Up @@ -317,7 +324,7 @@ def beta_cdf(x, a, b):
'NormalCdfLn2.csv': normal_cdf2_ln,
'NormalCdfLogit.csv': lambda x: log(ncdf(x)) - log(ncdf(-x)),
'NormalCdfMomentRatio.csv': normal_cdf_moment_ratio,
#'NormalCdfRatioLn2.csv': normal_cdf2_ratio_ln,
'NormalCdfRatioLn2.csv': normal_cdf2_ratio_ln,
'Tetragamma.csv': lambda x: polygamma(2, x),
'Trigamma.csv': lambda x: polygamma(1, x),
'ulp.csv': None
Expand All @@ -332,7 +339,7 @@ def float_str_python_to_csharp(s):
dir = os.path.join(os.path.dirname(os.path.realpath(__file__)), '..', '..', '..', 'test', 'Tests', 'data', 'SpecialFunctionsValues')
with os.scandir(dir) as it:
for entry in it:
if entry.name.endswith('.csv') and entry.is_file() and entry.name == 'logisticGaussian.csv':
if entry.name.endswith('.csv') and entry.is_file():
print(f'Processing {entry.name}...')
if entry.name not in pair_info.keys() or pair_info[entry.name] == None:
print("Don't know how to process. Skipping.")
Expand All @@ -344,13 +351,18 @@ def float_str_python_to_csharp(s):
arg_count = len(fieldnames) - 1
newrows = []
for row in reader:
if entry.name == 'NormalCdfRatioLn2.csv':
sqrtomr2 = mpf(float_str_csharp_to_python(row['arg3']))
r = mpf(float_str_csharp_to_python(row['arg2']))
if sqrtomr2 < 0.618:
row['arg2'] = nstr(sign(r)*sqrt(1-sqrtomr2*sqrtomr2), output_dps)
newrow = dict(row)
args = []
for i in range(arg_count):
args.append(mpf(float_str_csharp_to_python(row[f'arg{i}'])))
result_in_file = row['expectedresult']
verbose = True
if result_in_file == 'Infinity' or result_in_file == '-Infinity' or result_in_file == 'NaN' or len(newrows) != 127:
if result_in_file == 'Infinity' or result_in_file == '-Infinity' or result_in_file == 'NaN':
newrow['expectedresult'] = result_in_file
else:
try:
Expand Down
28 changes: 27 additions & 1 deletion test/Tests/Core/SpecialFunctionsTest.cs
Original file line number Diff line number Diff line change
Expand Up @@ -748,7 +748,7 @@ public void NormalCdf2Test()
CheckFunctionValues("NormalCdfLn2", new MathFcn3(MMath.NormalCdfLn), normalcdfln2_pairs);

double[,] normalcdfRatioLn2_pairs = ReadPairs(Path.Combine(TestUtils.DataFolderPath, "SpecialFunctionsValues", "NormalCdfRatioLn2.csv"));
CheckFunctionValues("NormalCdfRatioLn2", new MathFcn4(MMath.NormalCdfRatioLn), normalcdfRatioLn2_pairs);
CheckFunctionValues("NormalCdfRatioLn2", new MathFcn4(NormalCdfRatioLn), normalcdfRatioLn2_pairs);

// The true values are computed using
// x * MMath.NormalCdf(x, y, r) + System.Math.Exp(Gaussian.GetLogProb(x, 0, 1) + MMath.NormalCdfLn(ymrx)) + r * System.Math.Exp(Gaussian.GetLogProb(y, 0, 1) + MMath.NormalCdfLn(xmry))
Expand All @@ -762,6 +762,32 @@ public void NormalCdf2Test()
CheckFunctionValues("NormalCdfIntegralRatio", new MathFcn3(MMath.NormalCdfIntegralRatio), normalcdfIntegralRatio_pairs);
}

// Wrapper around MMath.NormalCdfRatioLn for the data-driven test: it first makes r and sqrtomr2
// mutually consistent, so that arbitrary-precision reference values are not compared against a
// result computed from an (r, sqrtomr2) pair that disagree after rounding to double.
// The 0.618 cutoff matches the one used in ComputeSpecialFunctionsTestValues.py.
private static double NormalCdfRatioLn(double x, double y, double r, double sqrtomr2)
{
    if (sqrtomr2 < 0.618)
    {
        // Small sqrtomr2: recompute r from sqrtomr2, keeping the sign of the supplied r.
        // Why this direction is accurate: a relative roundoff eps in sqrtomr2^2 turns
        //   sqrt(1 - sqrtomr2^2)  into  sqrt(1 - sqrtomr2^2*(1+eps))
        //   ~= sqrt(1 - sqrtomr2^2) - 0.5*eps*sqrtomr2^2/sqrt(1 - sqrtomr2^2),
        // so the perturbation stays below machine precision whenever
        //   sqrtomr2^2/sqrt(1 - sqrtomr2^2) < 1.
        // That inequality holds for sqrtomr2^2 < (sqrt(5)-1)/2 ~= 0.618, so the cutoff
        // sqrtomr2 < 0.618 used here is a conservative sufficient condition.
        double oneMinusRSquared = sqrtomr2 * sqrtomr2;
        double consistentR = System.Math.Sign(r) * System.Math.Sqrt(1 - oneMinusRSquared);
        return MMath.NormalCdfRatioLn(x, y, consistentR, sqrtomr2);
    }

    // Otherwise (this also covers a NaN sqrtomr2, which fails the comparison above):
    // it is more accurate to derive sqrtomr2 from r.
    double complement = 1 - r * r;
    double consistentSqrt = System.Math.Sqrt(complement);
    return MMath.NormalCdfRatioLn(x, y, r, consistentSqrt);
}

[Fact]
public void NormalCdfIntegralTest()
{
Expand Down
7 changes: 4 additions & 3 deletions test/Tests/Data/SpecialFunctionsValues/NormalCdfRatioLn2.csv
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
arg0,arg1,arg2,arg3,expectedresult
-9945842.62906782515347003936767578125,9945822.06800303049385547637939453125,-0.9891958110248051383450729190371930599212646484375,0.1466002982636051832354695534377242438495159149169921875,268535429229.677886962890625
-312498.3686245033168233931064605712890625,312498.2982211210182867944240570068359375,-0.9999893339082690513208717675297521054744720458984375,0.46186653587789346098180232047525350935757160186767578125e-2,249505.2149799815961159765720367431640625
379473319.22020542621612548828125,-379473319.22020542621612548828125,-0.9999999999999997779553950749686919152736663818359375,0.40991122317442558334301211400382304594902649341747746802866458892822265625e-9,-21.613063775509029795784954330883920192718505859375
-379473319.22020542621612548828125,379473319.22020542621612548828125,-0.99999999999999999991598639455782313129063619083291,0.40991122317442558334301211400382304594902649341747746802866458892822265625e-9,-21.613063775509028748108216091076265162573268973654
-9945842.62906782515347003936767578125,9945822.06800303049385547637939453125,-0.98919581102480513845229381540009496617995993256594,0.1466002982636051832354695534377242438495159149169921875,268535429229.67977366017465839378343324531827501415
-312498.3686245033168233931064605712890625,312498.2982211210182867944240570068359375,-0.99998933390826905132111329040156813258636837058731,0.46186653587789346098180232047525350935757160186767578125e-2,249505.2149808101614030573818221460088509383223012
379473319.22020542621612548828125,-379473319.22020542621612548828125,-0.99999999999999999991598639455782313129063619083291,0.40991122317442558334301211400382304594902649341747746802866458892822265625e-9,-21.613063775509028748108216091076265162573268973654