Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

ChainWithTransitionParameterTest2 uses more iterations #239

Merged
merged 2 commits into from
Apr 24, 2020
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
26 changes: 8 additions & 18 deletions test/Tests/SerialTests.cs
Original file line number Diff line number Diff line change
Expand Up @@ -3664,7 +3664,7 @@ public void ChainWithTransitionParameterVectorTest()
Console.WriteLine("shift = {0}", shiftActual1);

int maxIter = 200;
if (false)
if (!engine.ShowProgress)
{
for (int iter = 1; iter <= maxIter; iter++)
{
Expand Down Expand Up @@ -3711,7 +3711,7 @@ public void ChainWithTransitionParameterVectorTest()
[Fact]
public void ChainWithTransitionParameterTest()
{
// without PointEstimator, EP doesn't converge unless length <= 700,
// without PointEstimate, EP doesn't converge unless length <= 700,
// where it gets correct mean but variance = 1 (prior variance)
int length = 1000;

Expand All @@ -3722,12 +3722,6 @@ public void ChainWithTransitionParameterTest()
VariableArray<double> observation = Variable.Array<double>(rows).Named("observations");
Variable<double> shift = Variable.GaussianFromMeanAndVariance(0, 1).Named("shift");
shift.AddAttribute(new PointEstimate());
//Variable<double> shiftPoint = Variable<double>.Factor(PointEstimator.Forward<double>, shift).Named("shiftPoint");
//Variable<double> shiftPoint = Variable<double>.Factor(PointEstimator.Forward2<double>, shift, (double)length).Named("shiftPoint");
//Variable<double> shiftPoint = Variable<double>.Factor(Damp.Forward<double>, shift, 0.1).Named("shiftPoint");
// when using PointEstimator.Forward with Secant, should always initialize
//shiftPoint.InitialiseTo(Gaussian.PointMass(0));
//shift.InitialiseTo(Gaussian.PointMass(1));

using (ForEachBlock rowBlock = Variable.ForEach(rows))
{
Expand Down Expand Up @@ -3767,7 +3761,7 @@ public void ChainWithTransitionParameterTest()
Console.WriteLine("shift = {0}", shiftActual1);

int maxIter = 100;
if (false)
if (!engine.ShowProgress)
{
for (int iter = 1; iter < maxIter; iter++)
{
Expand Down Expand Up @@ -3817,11 +3811,6 @@ public void ChainWithTransitionParameterTest2()
VariableArray<double> observation = Variable.Array<double>(rows).Named("observations");
Variable<double> shift = Variable.GaussianFromMeanAndVariance(0, 1).Named("shift");
Variable<double> shift2 = Variable.GaussianFromMeanAndVariance(0, 1).Named("shift2");
//Variable<double> shiftPoint = Variable<double>.Factor(PointEstimator.Forward<double>, shift).Named("shiftPoint");
//Variable<double> shift2Point = Variable<double>.Factor(PointEstimator.Forward<double>, shift2).Named("shift2Point");
// when using PointEstimator.Forward with Secant, should always initialize
shift.InitialiseTo(Gaussian.PointMass(0));
shift2.InitialiseTo(Gaussian.PointMass(0));

using (ForEachBlock rowBlock = Variable.ForEach(rows))
{
Expand Down Expand Up @@ -3850,6 +3839,7 @@ public void ChainWithTransitionParameterTest2()
observation.ObservedValue = observationValues;

InferenceEngine engine = new InferenceEngine();
engine.ShowProgress = false;
//engine.Compiler.GivePriorityTo(typeof(VariablePointOp_Secant));
engine.OptimiseForVariables = new List<IVariable>() { states, shift, shift2 };
engine.NumberOfIterations = 3;
Expand All @@ -3867,8 +3857,8 @@ public void ChainWithTransitionParameterTest2()
Gaussian shift2Actual1 = engine.Infer<Gaussian>(shift2);
Console.WriteLine("shift = {0}, shift2 = {1}", shiftActual1, shift2Actual1);

int maxIter = 100;
if (false)
int maxIter = engine.Compiler.OptimiseInferenceCode ? 100 : 200;
if (!engine.ShowProgress)
{
for (int iter = 1; iter < maxIter; iter++)
{
Expand Down Expand Up @@ -3978,14 +3968,14 @@ public void ChainWithTransitionParameterTest3()
//const double shiftMeanExpected = 0.767681001959026;
const double shiftMeanExpected = 0.746576051928051;
int maxIter = 100;
if (false)
if (!engine.ShowProgress)
{
for (int iter = 1; iter < maxIter; iter++)
{
engine.NumberOfIterations = iter;
var shiftTemp = engine.Infer<Gaussian>(shift);
var precisionTemp = engine.Infer<Gamma>(precision);
Debug.WriteLine("{0} shift={1} prec={2}", iter, shiftTemp.GetMean(), precisionTemp.GetMean());
Console.WriteLine("{0} shift={1} prec={2}", iter, shiftTemp.GetMean(), precisionTemp.GetMean());
//if (shiftTemp.GetMean() == shiftMeanExpected)
// throw new Exception("converged at iter " + iter);
}
Expand Down