Skip to content

Commit

Permalink
ChainWithTransitionParameterTest2 uses more iterations when OptimiseI…
Browse files Browse the repository at this point in the history
…nferenceCode=false (dotnet#239)
  • Loading branch information
tminka committed Aug 24, 2020
1 parent 96004d7 commit 7447228
Showing 1 changed file with 8 additions and 18 deletions.
26 changes: 8 additions & 18 deletions test/Tests/SerialTests.cs
Original file line number Diff line number Diff line change
Expand Up @@ -3664,7 +3664,7 @@ public void ChainWithTransitionParameterVectorTest()
Console.WriteLine("shift = {0}", shiftActual1);

int maxIter = 200;
if (false)
if (!engine.ShowProgress)
{
for (int iter = 1; iter <= maxIter; iter++)
{
Expand Down Expand Up @@ -3711,7 +3711,7 @@ public void ChainWithTransitionParameterVectorTest()
[Fact]
public void ChainWithTransitionParameterTest()
{
// without PointEstimator, EP doesn't converge unless length <= 700,
// without PointEstimate, EP doesn't converge unless length <= 700,
// where it gets correct mean but variance = 1 (prior variance)
int length = 1000;

Expand All @@ -3722,12 +3722,6 @@ public void ChainWithTransitionParameterTest()
VariableArray<double> observation = Variable.Array<double>(rows).Named("observations");
Variable<double> shift = Variable.GaussianFromMeanAndVariance(0, 1).Named("shift");
shift.AddAttribute(new PointEstimate());
//Variable<double> shiftPoint = Variable<double>.Factor(PointEstimator.Forward<double>, shift).Named("shiftPoint");
//Variable<double> shiftPoint = Variable<double>.Factor(PointEstimator.Forward2<double>, shift, (double)length).Named("shiftPoint");
//Variable<double> shiftPoint = Variable<double>.Factor(Damp.Forward<double>, shift, 0.1).Named("shiftPoint");
// when using PointEstimator.Forward with Secant, should always initialize
//shiftPoint.InitialiseTo(Gaussian.PointMass(0));
//shift.InitialiseTo(Gaussian.PointMass(1));

using (ForEachBlock rowBlock = Variable.ForEach(rows))
{
Expand Down Expand Up @@ -3767,7 +3761,7 @@ public void ChainWithTransitionParameterTest()
Console.WriteLine("shift = {0}", shiftActual1);

int maxIter = 100;
if (false)
if (!engine.ShowProgress)
{
for (int iter = 1; iter < maxIter; iter++)
{
Expand Down Expand Up @@ -3817,11 +3811,6 @@ public void ChainWithTransitionParameterTest2()
VariableArray<double> observation = Variable.Array<double>(rows).Named("observations");
Variable<double> shift = Variable.GaussianFromMeanAndVariance(0, 1).Named("shift");
Variable<double> shift2 = Variable.GaussianFromMeanAndVariance(0, 1).Named("shift2");
//Variable<double> shiftPoint = Variable<double>.Factor(PointEstimator.Forward<double>, shift).Named("shiftPoint");
//Variable<double> shift2Point = Variable<double>.Factor(PointEstimator.Forward<double>, shift2).Named("shift2Point");
// when using PointEstimator.Forward with Secant, should always initialize
shift.InitialiseTo(Gaussian.PointMass(0));
shift2.InitialiseTo(Gaussian.PointMass(0));

using (ForEachBlock rowBlock = Variable.ForEach(rows))
{
Expand Down Expand Up @@ -3850,6 +3839,7 @@ public void ChainWithTransitionParameterTest2()
observation.ObservedValue = observationValues;

InferenceEngine engine = new InferenceEngine();
engine.ShowProgress = false;
//engine.Compiler.GivePriorityTo(typeof(VariablePointOp_Secant));
engine.OptimiseForVariables = new List<IVariable>() { states, shift, shift2 };
engine.NumberOfIterations = 3;
Expand All @@ -3867,8 +3857,8 @@ public void ChainWithTransitionParameterTest2()
Gaussian shift2Actual1 = engine.Infer<Gaussian>(shift2);
Console.WriteLine("shift = {0}, shift2 = {1}", shiftActual1, shift2Actual1);

int maxIter = 100;
if (false)
int maxIter = engine.Compiler.OptimiseInferenceCode ? 100 : 200;
if (!engine.ShowProgress)
{
for (int iter = 1; iter < maxIter; iter++)
{
Expand Down Expand Up @@ -3978,14 +3968,14 @@ public void ChainWithTransitionParameterTest3()
//const double shiftMeanExpected = 0.767681001959026;
const double shiftMeanExpected = 0.746576051928051;
int maxIter = 100;
if (false)
if (!engine.ShowProgress)
{
for (int iter = 1; iter < maxIter; iter++)
{
engine.NumberOfIterations = iter;
var shiftTemp = engine.Infer<Gaussian>(shift);
var precisionTemp = engine.Infer<Gamma>(precision);
Debug.WriteLine("{0} shift={1} prec={2}", iter, shiftTemp.GetMean(), precisionTemp.GetMean());
Console.WriteLine("{0} shift={1} prec={2}", iter, shiftTemp.GetMean(), precisionTemp.GetMean());
//if (shiftTemp.GetMean() == shiftMeanExpected)
// throw new Exception("converged at iter " + iter);
}
Expand Down

0 comments on commit 7447228

Please sign in to comment.