Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Improved accuracy of FastSumOp #272

Merged
merged 3 commits into from
Aug 22, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion src/Compiler/Infer/ModelCompiler.cs
Original file line number Diff line number Diff line change
Expand Up @@ -98,7 +98,7 @@ public IDeclarationProvider DeclarationProvider
/// Controls if debug information is included in generated DLLs. If true, debug information will
/// be included which allows stepping through the generated code in a debugger.
/// </summary>
public bool IncludeDebugInformation { get; set; } = true;
public bool IncludeDebugInformation { get; set; } = false;

/// <summary>
/// If true, prints compilation progress information to the console during model compilation.
Expand Down
87 changes: 54 additions & 33 deletions src/Runtime/Factors/Sum.cs
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ namespace Microsoft.ML.Probabilistic.Factors
using Microsoft.ML.Probabilistic.Factors.Attributes;

/// <include file='FactorDocs.xml' path='factor_docs/message_op_class[@name="FastSumOp"]/doc/*'/>
[FactorMethod(typeof(Factor), "Sum", typeof(double[]), Default=true)]
[FactorMethod(typeof(Factor), "Sum", typeof(double[]), Default = true)]
[Quality(QualityBand.Mature)]
public static class FastSumOp
{
Expand Down Expand Up @@ -97,7 +97,6 @@ public static GaussianList ArrayAverageConditional<GaussianList>(
}
return result;
}
double totalMean, totalVariance;
double sumMean, sumVariance;
sum.GetMeanAndVarianceImproper(out sumMean, out sumVariance);

Expand All @@ -123,41 +122,63 @@ public static GaussianList ArrayAverageConditional<GaussianList>(
if (indexOfUniform >= 0)
{
// exactly one element of array is uniform.
totalMean = 0;
totalVariance = 0;
for (int i = 0; i < array.Count; i++)
{
if (i == indexOfUniform)
continue;
double meani, variancei;
array[i].GetMeanAndVarianceImproper(out meani, out variancei);
totalMean += meani;
totalVariance += variancei;
result[i] = new Gaussian();
}
// totalMean = sum_{i except indexOfUniform} array[i].GetMean()
// totalVariance = sum_{i except indexOfUniform} array[i].GetVariance()
result[indexOfUniform] = new Gaussian(sumMean - totalMean, sumVariance + totalVariance);
var (sumExceptMean, sumExceptVariance) = SumExcept(array, indexOfUniform);
// sumExceptMean = sum_{i except indexOfUniform} array[i].GetMean()
// sumExceptVariance = sum_{i except indexOfUniform} array[i].GetVariance()
result[indexOfUniform] = new Gaussian(sumMean - sumExceptMean, sumVariance + sumExceptVariance);
return result;
}
// at this point, the array has no uniform elements.

// get the mean and variance of sum of all the Gaussians;
double totalMean, totalVariance;
to_sum.GetMeanAndVarianceImproper(out totalMean, out totalVariance);

// subtract it off from the mean and variance of incoming Gaussian from Sum
totalMean = sumMean - totalMean;
totalVariance = sumVariance + totalVariance;
double totalMsgMean = sumMean - totalMean;
double totalMsgVariance = sumVariance + totalVariance;

for (int i = 0; i < array.Count; i++)
{
double meani, variancei;
array[i].GetMeanAndVarianceImproper(out meani, out variancei);
result[i] = new Gaussian(totalMean + meani, totalVariance - variancei);
double msgMean = totalMsgMean + meani;
double msgVariance = totalMsgVariance - variancei;
if (Math.Abs(msgVariance) < Math.Abs(totalMsgVariance) * 1e-10)
{
// Avoid loss of precision by recalculating the sum
// This should happen for at most one i
var (sumExceptMean, sumExceptVariance) = SumExcept(array, i);
msgMean = sumMean - sumExceptMean;
msgVariance = sumVariance + sumExceptVariance;
}
result[i] = new Gaussian(msgMean, msgVariance);
}
return result;
}

public static (double, double) SumExcept(IList<Gaussian> array, int excludedIndex)
{
double totalMean = 0;
double totalVariance = 0;
for (int j = 0; j < array.Count; j++)
{
if (j == excludedIndex)
continue;
double meanj, variancej;
array[j].GetMeanAndVarianceImproper(out meanj, out variancej);
totalMean += meanj;
totalVariance += variancej;
}
return (totalMean, totalVariance);
}

/// <include file='FactorDocs.xml' path='factor_docs/message_op_class[@name="FastSumOp"]/message_doc[@name="ArrayAverageConditional(Gaussian, Gaussian, Gaussian[], Gaussian[])"]/*'/>
public static Gaussian[] ArrayAverageConditional([SkipIfUniform] Gaussian sum, [Fresh] Gaussian to_sum, Gaussian[] array, Gaussian[] result)
{
Expand Down Expand Up @@ -459,19 +480,19 @@ public static GaussianList ArrayAverageLogarithm5<GaussianList>([SkipIfUniform]
public static GaussianList ArrayAverageLogarithm4<GaussianList>([SkipIfUniform] Gaussian sum, [Proper] IList<Gaussian> array, GaussianList to_array)
where GaussianList : IList<Gaussian>, ICloneable
{
GaussianList array2 = (GaussianList) to_array.Clone();
GaussianList array2 = (GaussianList)to_array.Clone();
for (int i = 0; i < array.Count; i++)
{
array2[i] = array[i];
}
GaussianList result = (GaussianList) to_array.Clone();
GaussianList result = (GaussianList)to_array.Clone();
for (int iter = 0; iter < 10; iter++)
{
GaussianList oldresult = (GaussianList) result.Clone();
GaussianList oldresult = (GaussianList)result.Clone();
result = ArrayAverageLogarithm1(sum, array2, result);
for (int i = 0; i < array2.Count; i++)
{
array2[i] = array2[i]*(result[i]/oldresult[i]);
array2[i] = array2[i] * (result[i] / oldresult[i]);
}
}
return result;
Expand Down Expand Up @@ -505,12 +526,12 @@ public static GaussianList ArrayAverageLogarithm3<GaussianList>([SkipIfUniform]
double mr, vr;
to_array[i].GetMeanAndVariance(out mr, out vr);
double mu1 = sumMean - partialSum;
newMarginal[i] = array[i]*(new Gaussian(mu1, sumVar)/to_array[i]);
double vpost = 1/(1/vi + 1/sumVar - 1/vr);
double m1 = vpost*(mi/vi + mu1/sumVar - mr/vr);
newMarginal[i] = array[i] * (new Gaussian(mu1, sumVar) / to_array[i]);
double vpost = 1 / (1 / vi + 1 / sumVar - 1 / vr);
double m1 = vpost * (mi / vi + mu1 / sumVar - mr / vr);
arraySumOfMean = partialSum + m1;
a[i] = -vpost/sumVar;
b[i] = m1 + vpost*(sumMean/sumVar) - vpost*mu1/sumVar;
a[i] = -vpost / sumVar;
b[i] = m1 + vpost * (sumMean / sumVar) - vpost * mu1 / sumVar;
}
Matrix x = new Matrix(n, n);
for (int i = 0; i < n; i++)
Expand Down Expand Up @@ -548,16 +569,16 @@ public static GaussianList ArrayAverageLogarithm<GaussianList>([SkipIfUniform] G
// resultIndex version. this should be the version that is used in the future but the compiler does not yet know how to schedule it correctly.
public static Gaussian ArrayAverageLogarithm7([SkipIfUniform] Gaussian sum, [Proper, AllExceptIndex] IList<Gaussian> array, int resultIndex)
{
double sumMean, sumVar;
sum.GetMeanAndVariance(out sumMean, out sumVar);
double sumMean, sumVar;
sum.GetMeanAndVariance(out sumMean, out sumVar);

double arraySumOfMean = 0;
for (int i = 0; i < array.Count; i++)
{
if(i != resultIndex) arraySumOfMean = arraySumOfMean + array[i].GetMean();
}
double partialSum = arraySumOfMean;
return new Gaussian(sumMean - partialSum, sumVar);
double arraySumOfMean = 0;
for (int i = 0; i < array.Count; i++)
{
if (i != resultIndex) arraySumOfMean = arraySumOfMean + array[i].GetMean();
}
double partialSum = arraySumOfMean;
return new Gaussian(sumMean - partialSum, sumVar);
}

/// <include file='FactorDocs.xml' path='factor_docs/message_op_class[@name="FastSumOp"]/message_doc[@name="ArrayAverageLogarithm{GaussianList}(double, IList{Gaussian}, GaussianList)"]/*'/>
Expand Down Expand Up @@ -587,7 +608,7 @@ public static double AverageLogFactor()
public static class SumOp3
{
// fully parallel version

/// <include file='FactorDocs.xml' path='factor_docs/message_op_class[@name="SumOp3"]/message_doc[@name="ArrayAverageLogarithm{GaussianList}(Gaussian, IList{Gaussian}, GaussianList)"]/*'/>
/// <typeparam name="GaussianList">The type of the message to <c>array</c>.</typeparam>
public static GaussianList ArrayAverageLogarithm<GaussianList>([SkipIfUniform] Gaussian sum, [Proper] IList<Gaussian> array, GaussianList result)
Expand Down
2 changes: 1 addition & 1 deletion test/TestFSharp/TestFSharp.fsproj
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@
</PropertyGroup>

<PropertyGroup Condition="'$(Configuration)|$(Platform)'=='Debug|AnyCPU' OR '$(Configuration)|$(Platform)'=='DebugFull|AnyCPU' OR '$(Configuration)|$(Platform)'=='DebugCore|AnyCPU'">
<DebugType>full</DebugType>
<DebugType>portable</DebugType>
<DebugSymbols>true</DebugSymbols>
<DefineConstants>$(DefineConstants);DEBUG</DefineConstants>
</PropertyGroup>
Expand Down
3 changes: 1 addition & 2 deletions test/Tests/Core/MatrixTest.cs
Original file line number Diff line number Diff line change
Expand Up @@ -395,8 +395,7 @@ public void VectorIndexAtCumulativeSumTests()
Assert.Equal(5, v.IndexAtCumulativeSum(0.67));
Assert.Equal(6, v.IndexAtCumulativeSum(0.69));
Assert.Equal(7, v.IndexAtCumulativeSum(0.79));
Assert.Equal(8, v.IndexAtCumulativeSum(0.9));
Assert.Equal(8, v.IndexAtCumulativeSum(0.9));
Assert.Equal(8, v.IndexAtCumulativeSum(0.901));
Assert.Equal(9, v.IndexAtCumulativeSum(0.93));
Assert.Equal(12, v.IndexAtCumulativeSum(0.99));
}
Expand Down
7 changes: 6 additions & 1 deletion test/Tests/Distributions/DistributionTests.cs
Original file line number Diff line number Diff line change
Expand Up @@ -1680,14 +1680,19 @@ public void GammaGetProbLessThanTest()
Gamma g = new Gamma(1.0, m);
double median = -m * System.Math.Log(0.5);
Assert.Equal(0.5, g.GetProbLessThan(median), 1e-4);
Assert.Equal(median, g.GetQuantile(0.5));
AssertAlmostEqual(median, g.GetQuantile(0.5));

g = new Gamma(2, m);
double probability = g.GetProbLessThan(median);
double quantile = g.GetQuantile(probability);
Assert.Equal(median, quantile, 1e-10);
}

internal static void AssertAlmostEqual(double x, double y)
{
Assert.False(SpecialFunctionsTests.IsErrorSignificant(1e-16, x - y));
}

[Fact]
public void GammaModeTest()
{
Expand Down
14 changes: 10 additions & 4 deletions test/Tests/Operators/OperatorTests.cs
Original file line number Diff line number Diff line change
Expand Up @@ -597,15 +597,21 @@ public void SumOpTest()
Gaussian sum_B = Gaussian.FromMeanAndPrecision(0, 1e-310);
GaussianArray array = new GaussianArray(2, i => new Gaussian(0, 1));
var result = FastSumOp.ArrayAverageConditional(sum_B, sum_F, array, new GaussianArray(array.Count));
Console.WriteLine(result);
Assert.True(!double.IsNaN(result[0].MeanTimesPrecision));

sum_B = new Gaussian(0, 1);
array[0] = Gaussian.FromMeanAndPrecision(0, 1e-310);
array[0] = Gaussian.FromMeanAndPrecision(0, 0);
sum_F = FastSumOp.SumAverageConditional(array);
result = FastSumOp.ArrayAverageConditional(sum_B, sum_F, array, new GaussianArray(array.Count));
Console.WriteLine(result);
Assert.True(!double.IsNaN(result[0].MeanTimesPrecision));
Assert.True(result[0].MaxDiff(new Gaussian(0, 2)) < 1e-10);
Assert.True(result[1].IsUniform());
for (int i = 0; i < 1030; i++)
{
array[0] = Gaussian.FromMeanAndPrecision(0, System.Math.Pow(2, -i));
sum_F = FastSumOp.SumAverageConditional(array);
result = FastSumOp.ArrayAverageConditional(sum_B, sum_F, array, new GaussianArray(array.Count));
Assert.True(result[0].MaxDiff(new Gaussian(0, 2)) < 1e-10);
}
}

[Fact]
Expand Down
5 changes: 3 additions & 2 deletions test/Tests/SharedVariableTests.cs
Original file line number Diff line number Diff line change
Expand Up @@ -1003,7 +1003,7 @@ public void OutputMessageTest()

Range W = new Range(sizeVocab).Named("W");
Range T = new Range(numTopics).Named("T");
var Theta = Variable<Vector>.DirichletSymmetric(numTopics, 0.1).Named("Theta");
var Theta = Variable<Vector>.DirichletSymmetric(numTopics, 0.125).Named("Theta");
Theta.SetValueRange(T);
var Phi = Variable.Array<Vector>(T).Named("Phi");
Phi.SetValueRange(W);
Expand All @@ -1021,14 +1021,15 @@ public void OutputMessageTest()
engine.NumberOfIterations = 1;

Word.ObservedValue = 0;
PhiPrior.ObservedValue = Util.ArrayInit(numTopics, d => Dirichlet.Symmetric(sizeVocab, 0.1));
PhiPrior.ObservedValue = Util.ArrayInit(numTopics, d => Dirichlet.Symmetric(sizeVocab, 0.125));
Phi.AddAttribute(QueryTypes.Marginal);
Phi.AddAttribute(QueryTypes.MarginalDividedByPrior);
var PhiOutput = engine.Infer<Dirichlet[]>(Phi, QueryTypes.MarginalDividedByPrior)[0];
var phiMarg = engine.Infer<Dirichlet[]>(Phi)[0];
var phiManualOut = new Dirichlet(phiMarg);
phiManualOut.SetToRatio(phiMarg, PhiPrior.ObservedValue[0], false);

// Since we check for exact equality here, the numbers in the problem must be exactly representable.
Assert.Equal(phiManualOut, PhiOutput);
}
}
Expand Down