diff --git a/ci/TestXML/testGaussianMarkovRandomField.xml b/ci/TestXML/testGaussianMarkovRandomField.xml index 5861c2fba0..871ef311b7 100644 --- a/ci/TestXML/testGaussianMarkovRandomField.xml +++ b/ci/TestXML/testGaussianMarkovRandomField.xml @@ -3,12 +3,11 @@ - - + - + - + @@ -18,17 +17,21 @@ - + + - + + + diff --git a/src/dr/inference/distribution/RandomField.java b/src/dr/inference/distribution/RandomField.java index b520861670..ea614cdcfd 100644 --- a/src/dr/inference/distribution/RandomField.java +++ b/src/dr/inference/distribution/RandomField.java @@ -66,6 +66,10 @@ public RandomField(String name, likelihoodKnown = false; } + public String toString() { + return getClass().getName() + " " + getModelName() + " (" + getLogLikelihood() + ")"; + } + public Parameter getField() { return field; } public RandomFieldDistribution getDistribution() { return distribution; } diff --git a/src/dr/inferencexml/distribution/GaussianMarkovRandomFieldParser.java b/src/dr/inferencexml/distribution/GaussianMarkovRandomFieldParser.java index e63b643ee7..e5a6c71084 100644 --- a/src/dr/inferencexml/distribution/GaussianMarkovRandomFieldParser.java +++ b/src/dr/inferencexml/distribution/GaussianMarkovRandomFieldParser.java @@ -37,6 +37,7 @@ public class GaussianMarkovRandomFieldParser extends AbstractXMLObjectParser { private static final String PRECISION = "precision"; private static final String START = "start"; private static final String WEIGHTS = "weights"; + private static final String MATCH_PSEUDO_DETERMINANT = "matchPseudoDeterminant"; public String getParserName() { return PARSER_NAME; } @@ -61,7 +62,9 @@ public Object parseXMLObject(XMLObject xo) throws XMLParseException { ") != distribution dim (" + dim + ") - 1"); } - return new GaussianMarkovRandomField(dim, incrementPrecision, start, weights); + boolean matchPseudoDeterminant = xo.getAttribute(MATCH_PSEUDO_DETERMINANT, false); + + return new GaussianMarkovRandomField(dim, incrementPrecision, start, weights, matchPseudoDeterminant); } public XMLSyntaxRule[] getSyntaxRules() { return rules; } @@ -73,7 +76,8 @@ public Object parseXMLObject(XMLObject xo) throws XMLParseException { new ElementRule(START, new XMLSyntaxRule[]{new ElementRule(Parameter.class)}, true), new ElementRule(WEIGHTS, - new XMLSyntaxRule[]{new ElementRule(RandomField.WeightProvider.class)}, true) + new XMLSyntaxRule[]{new ElementRule(RandomField.WeightProvider.class)}, true), + AttributeRule.newBooleanRule(MATCH_PSEUDO_DETERMINANT, true), }; diff --git a/src/dr/inferencexml/distribution/RandomFieldParser.java b/src/dr/inferencexml/distribution/RandomFieldParser.java index 94115c22e2..b70c2270ae 100644 --- a/src/dr/inferencexml/distribution/RandomFieldParser.java +++ b/src/dr/inferencexml/distribution/RandomFieldParser.java @@ -54,7 +54,9 @@ public Object parseXMLObject(XMLObject xo) throws XMLParseException { ") != distribution dimension (" + distribution.getDimension() + ")"); } - return new RandomField(xo.getId(), field, distribution); + String id = xo.hasId() ? xo.getId() : null; + + return new RandomField(id, field, distribution); } //************************************************************************ diff --git a/src/dr/math/distributions/GaussianMarkovRandomField.java b/src/dr/math/distributions/GaussianMarkovRandomField.java index 8cf8afcdc6..d71fb7ac15 100644 --- a/src/dr/math/distributions/GaussianMarkovRandomField.java +++ b/src/dr/math/distributions/GaussianMarkovRandomField.java @@ -25,11 +25,12 @@ package dr.math.distributions; +import cern.colt.matrix.DoubleMatrix1D; +import cern.colt.matrix.impl.DenseDoubleMatrix2D; import dr.inference.distribution.RandomField; import dr.inference.model.*; import dr.inferencexml.distribution.MultivariateNormalDistributionModelParser; -import no.uib.cipr.matrix.DenseVector; -import no.uib.cipr.matrix.SymmTridiagMatrix; +import dr.math.matrixAlgebra.RobustEigenDecomposition; import java.util.Arrays; @@ -49,23 +50,27 @@ public class GaussianMarkovRandomField extends RandomFieldDistribution { private final double[][] precision; // TODO Use a sparse matrix, like in GmrfSkyrideLikelihood private double logDet; -// private double[][] variance = null; -// private double[][] cholesky = null; + private SymmetricTriDiagonalMatrix Q; + private SymmetricTriDiagonalMatrix savedQ; private boolean meanKnown; private boolean precisionKnown; private boolean determinantKnown; + private boolean qKnown; + + private final double logMatchTerm; public GaussianMarkovRandomField(int dim, Parameter precision, Parameter start) { - this(dim, precision, start, null); + this(dim, precision, start, null, true); } public GaussianMarkovRandomField(int dim, Parameter precision, Parameter start, - RandomField.WeightProvider weightProvider) { + RandomField.WeightProvider weightProvider, + boolean matchPseudoDeterminant) { super(MultivariateNormalDistributionModelParser.NORMAL_DISTRIBUTION_MODEL); @@ -77,9 +82,14 @@ public GaussianMarkovRandomField(int dim, this.mean = new double[dim]; this.precision = new double[dim][dim]; + this.Q = new SymmetricTriDiagonalMatrix(dim); + + this.logMatchTerm = matchPseudoDeterminant ? matchPseudoDeterminantTerm(dim) : 0.0; + meanKnown = false; precisionKnown = false; determinantKnown = false; // TODO No need to be computed separately + qKnown = false; } // private void check() { @@ -110,6 +120,25 @@ public double[] getMean() { return mean; } + private SymmetricTriDiagonalMatrix getQ() { + if (!qKnown) { + double precision = precisionParameter.getParameterValue(0); + Q.diagonal[0] = precision; + for (int i = 1; i < dim - 1; ++i) { + Q.diagonal[i] = 2 * precision; + } + Q.diagonal[dim - 1] = precision; + + for (int i = 0; i < dim - 1; ++i) { + Q.offDiagonal[i] = -precision; + } + // TODO Update for lambda != 1 and for weights + + qKnown = true; + } + return Q; + } + private double[][] getPrecision() { if (!precisionKnown) { @@ -162,22 +191,44 @@ public String getType() { // return cholesky; // } + private static double matchPseudoDeterminantTerm(int dim) { + double term = 0.0; + for (int i = 1; i < dim; ++i) { + double x = (2 - 2 * Math.cos(i * Math.PI / dim)); + term += Math.log(x); + } + return term; + } + public double getLogDet() { if (!determinantKnown) { - final double k = precisionParameter.getParameterValue(0); - double det = Math.pow(k, dim); - for(int i=2; i<=dim; ++i) { - det = det * (2 - 2 * Math.cos((i-1)*(Math.PI/dim))); + logDet = (dim - 1) * Math.log(precisionParameter.getParameterValue(0)) + logMatchTerm; + determinantKnown = true; + } + + if (CHECK_DETERMINANT) { + + RobustEigenDecomposition ed = new RobustEigenDecomposition(new DenseDoubleMatrix2D(getPrecision())); + DoubleMatrix1D values = ed.getRealEigenvalues(); + double sum = 0.0; + for (int i = 0; i < values.size(); ++i) { + double v = values.get(i); + if (Math.abs(v) > 1E-6) { + sum += Math.log(v); + } } - logDet = Math.log(det); - determinantKnown = true; + if (Math.abs(sum - logDet) > 1E-6) { + throw new RuntimeException("Incorrect pseudo-determinant"); + } } return logDet; } + private static final boolean CHECK_DETERMINANT = false; + @Override public double[][] getScaleMatrix() { return getPrecision(); @@ -190,18 +241,27 @@ public Variable getLocationVariable() { @Override public double logPdf(double[] x) { - return logPdf(x, getMean(), getPrecision(), getLogDet()); +// double x1 = logPdf(x, getMean(), getPrecision(), getLogDet()); +// double x2 = logPdf(x, getMean(), getQ(), precisionParameter.getParameterValue(0), 1.0, logMatchTerm); +// +// System.err.println(x1 + " ?= " + x2); +// + return logPdf(x, getMean(), getQ(), precisionParameter.getParameterValue(0), + 1.0, logMatchTerm); } public double[] gradLogPdf(double[] x) { + // TODO Update to use Q return gradLogPdf(x, getMean(), getPrecision()); } public double[][] hessianLogPdf(double[] x) { + // TODO Update to use Q return hessianLogPdf(x, getPrecision()); } public double[] diagonalHessianLogPdf(double[] x) { + // TODO Update to use Q return diagonalHessianLogPdf(x, getPrecision()); } @@ -339,7 +399,11 @@ public static double[] diagonalHessianLogPdf(double[] x, double[][] precision) { // } private static double logPdf(double[] x, double[] mean, SymmetricTriDiagonalMatrix Q, - double precision, double lambda) { + double precision, double lambda, double logMatch) { + return getLogNormalization(x.length, precision, lambda, logMatch) - 0.5 * getSSE(x, mean, Q); + } + + private static double getSSE(double[] x, double[] mean, SymmetricTriDiagonalMatrix Q) { final int dim = x.length; final double[] delta = new double[dim]; @@ -354,14 +418,24 @@ private static double logPdf(double[] x, double[] mean, SymmetricTriDiagonalMatr } SSE += Q.diagonal[dim - 1] * delta[dim - 1] * delta[dim - 1]; - double logLikelihood = 0.5 * (dim - 1) * Math.log(precision) - 0.5 * SSE; + return SSE; + } + + private static double getLogDeterminant(int dim, double precision, double lambda, double logMatch) { + return (dim - 1) * Math.log(precision) + logMatch; + } + + private static double getLogNormalization(int dim, double precision, double lambda, double logMatch) { + + double logNorm = 0.5 * getLogDeterminant(dim, precision, lambda, logMatch); + if (lambda == 1.0) { - logLikelihood -= (dim - 1) * logNormalize; + logNorm -= (dim - 1) * HALF_LOG_TWO_PI; } else { - logLikelihood -= dim * logNormalize; + logNorm -= dim * HALF_LOG_TWO_PI; } - return logLikelihood; + return logNorm; } @@ -370,6 +444,11 @@ class SymmetricTriDiagonalMatrix { double[] diagonal; double[] offDiagonal; + SymmetricTriDiagonalMatrix(int dim) { + this.diagonal = new double[dim]; + this.offDiagonal = new double[dim - 1]; + } + SymmetricTriDiagonalMatrix(double[] diagonal, double[] offDiagonal) { this.diagonal = diagonal; this.offDiagonal = offDiagonal; @@ -400,7 +479,17 @@ public static double logPdf(double[] x, double[] mean, double[][] precision, if (logDet == Double.NEGATIVE_INFINITY) return logDet; + return getLogNormalization(x.length, logDet) - 0.5 * getSSE(x, mean, precision); + } + + public static double getLogNormalization(int dim, double logDet) { + return -(dim - 1) * HALF_LOG_TWO_PI + 0.5 * logDet; // Pratyusa's normalization constant + } + + public static double getSSE(double[] x, double[] mean, double[][] precision) { + final int dim = x.length; + final double[] delta = new double[dim]; for (int i = 0; i < dim; i++) { @@ -412,40 +501,15 @@ public static double logPdf(double[] x, double[] mean, double[][] precision, for (int i = 0; i < dim-1; i++) { SSE += precision[i][i] * delta[i] * delta[i] + 2 * precision[i][i + 1] * delta[i] * delta[i + 1]; } - return (dim-1) * logNormalize + 0.5 * (logDet - SSE); // There was an error here. - // Variance = (scale * Precision^{-1}) - } - -// private static double[][] getInverse(double[][] x) { -// return new SymmetricMatrix(x).inverse().toComponents(); -// } - -// private static double[][] getCholeskyDecomposition(double[][] variance) { -// double[][] cholesky; -// try { -// cholesky = (new CholeskyDecomposition(variance)).getL(); -// } catch (IllegalDimension illegalDimension) { -// throw new RuntimeException("Attempted Cholesky decomposition on non-square matrix"); -// } -// return cholesky; -// } - private static final double logNormalize = -0.5 * Math.log(2.0 * Math.PI); - -// public double logPdf(Object x) { -// double[] v = (double[]) x; -// return logPdf(v); -// } + return SSE; + } + private static final double HALF_LOG_TWO_PI = Math.log(2.0 * Math.PI) / 2; @Override public int getDimension() { return dim; } -// public Parameter getincrementPrecision() { return precisionParameter; } - -// public Parameter getstart() { return meanParameter; } - - @Override public double[] getGradientLogDensity(Object x) { return gradLogPdf((double[]) x); @@ -493,6 +557,7 @@ protected void restoreState() { // TODO with caching meanKnown = false; precisionKnown = false; determinantKnown = false; + qKnown = false; } @Override