Skip to content

Commit

Permalink
fix pseudo-determinant in GMRF
Browse files Browse the repository at this point in the history
  • Loading branch information
msuchard committed Nov 2, 2023
1 parent a2c0c9e commit b36914c
Show file tree
Hide file tree
Showing 5 changed files with 133 additions and 55 deletions.
15 changes: 9 additions & 6 deletions ci/TestXML/testGaussianMarkovRandomField.xml
Original file line number Diff line number Diff line change
Expand Up @@ -3,12 +3,11 @@

<parameter id="data" value="1 2 3 4"/>


<multivariateDistributionLikelihood id="gmrf">
<randomField id="gmrf">
<distribution>
<gaussianMarkovRandomField dim="4">
<gaussianMarkovRandomField dim="4" matchPseudoDeterminant="true">
<precision>
<parameter value="1.0"/>
<parameter value="1.5"/>
</precision>
<start>
<parameter value="0.0"/>
Expand All @@ -18,17 +17,21 @@
<data>
<parameter idref="data"/>
</data>
</multivariateDistributionLikelihood>
</randomField>


<report>
<multivariateDistributionLikelihood idref="gmrf"/>
<randomField idref="gmrf"/>
</report>



<!--
<report>
<gradient>
<multivariateDistributionLikelihood idref="gmrf"/>
</gradient>
</report>
-->

</beast>
4 changes: 4 additions & 0 deletions src/dr/inference/distribution/RandomField.java
Original file line number Diff line number Diff line change
Expand Up @@ -66,6 +66,10 @@ public RandomField(String name,
likelihoodKnown = false;
}

public String toString() {
return getClass().getName() + " " + getModelName() + " (" + getLogLikelihood() + ")";
}

public Parameter getField() { return field; }

public RandomFieldDistribution getDistribution() { return distribution; }
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,7 @@ public class GaussianMarkovRandomFieldParser extends AbstractXMLObjectParser {
private static final String PRECISION = "precision";
private static final String START = "start";
private static final String WEIGHTS = "weights";
private static final String MATCH_PSEUDO_DETERMINANT = "matchPseudoDeterminant";

public String getParserName() { return PARSER_NAME; }

Expand All @@ -61,7 +62,9 @@ public Object parseXMLObject(XMLObject xo) throws XMLParseException {
") != distribution dim (" + dim + ") - 1");
}

return new GaussianMarkovRandomField(dim, incrementPrecision, start, weights);
boolean matchPseudoDeterminant = xo.getAttribute(MATCH_PSEUDO_DETERMINANT, false);

return new GaussianMarkovRandomField(dim, incrementPrecision, start, weights, matchPseudoDeterminant);
}

public XMLSyntaxRule[] getSyntaxRules() { return rules; }
Expand All @@ -73,7 +76,8 @@ public Object parseXMLObject(XMLObject xo) throws XMLParseException {
new ElementRule(START,
new XMLSyntaxRule[]{new ElementRule(Parameter.class)}, true),
new ElementRule(WEIGHTS,
new XMLSyntaxRule[]{new ElementRule(RandomField.WeightProvider.class)}, true)
new XMLSyntaxRule[]{new ElementRule(RandomField.WeightProvider.class)}, true),
AttributeRule.newBooleanRule(MATCH_PSEUDO_DETERMINANT, true),

};

Expand Down
4 changes: 3 additions & 1 deletion src/dr/inferencexml/distribution/RandomFieldParser.java
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,9 @@ public Object parseXMLObject(XMLObject xo) throws XMLParseException {
") != distribution dimension (" + distribution.getDimension() + ")");
}

return new RandomField(xo.getId(), field, distribution);
String id = xo.hasId() ? xo.getId() : null;

return new RandomField(id, field, distribution);
}

//************************************************************************
Expand Down
157 changes: 111 additions & 46 deletions src/dr/math/distributions/GaussianMarkovRandomField.java
Original file line number Diff line number Diff line change
Expand Up @@ -25,11 +25,12 @@

package dr.math.distributions;

import cern.colt.matrix.DoubleMatrix1D;
import cern.colt.matrix.impl.DenseDoubleMatrix2D;
import dr.inference.distribution.RandomField;
import dr.inference.model.*;
import dr.inferencexml.distribution.MultivariateNormalDistributionModelParser;
import no.uib.cipr.matrix.DenseVector;
import no.uib.cipr.matrix.SymmTridiagMatrix;
import dr.math.matrixAlgebra.RobustEigenDecomposition;

import java.util.Arrays;

Expand All @@ -49,23 +50,27 @@ public class GaussianMarkovRandomField extends RandomFieldDistribution {
private final double[][] precision; // TODO Use a sparse matrix, like in GmrfSkyrideLikelihood
private double logDet;

// private double[][] variance = null;
// private double[][] cholesky = null;
private SymmetricTriDiagonalMatrix Q;
private SymmetricTriDiagonalMatrix savedQ;

private boolean meanKnown;
private boolean precisionKnown;
private boolean determinantKnown;
private boolean qKnown;

private final double logMatchTerm;

public GaussianMarkovRandomField(int dim,
Parameter precision,
Parameter start) {
this(dim, precision, start, null);
this(dim, precision, start, null, true);
}

public GaussianMarkovRandomField(int dim,
Parameter precision,
Parameter start,
RandomField.WeightProvider weightProvider) {
RandomField.WeightProvider weightProvider,
boolean matchPseudoDeterminant) {

super(MultivariateNormalDistributionModelParser.NORMAL_DISTRIBUTION_MODEL);

Expand All @@ -77,9 +82,14 @@ public GaussianMarkovRandomField(int dim,
this.mean = new double[dim];
this.precision = new double[dim][dim];

this.Q = new SymmetricTriDiagonalMatrix(dim);

this.logMatchTerm = matchPseudoDeterminant ? matchPseudoDeterminantTerm(dim) : 0.0;

meanKnown = false;
precisionKnown = false;
determinantKnown = false; // TODO No need to be computed separately
qKnown = false;
}

// private void check() {
Expand Down Expand Up @@ -110,6 +120,25 @@ public double[] getMean() {
return mean;
}

private SymmetricTriDiagonalMatrix getQ() {
if (!qKnown) {
double precision = precisionParameter.getParameterValue(0);
Q.diagonal[0] = precision;
for (int i = 1; i < dim - 1; ++i) {
Q.diagonal[i] = 2 * precision;
}
Q.diagonal[dim - 1] = precision;

for (int i = 0; i < dim - 1; ++i) {
Q.offDiagonal[i] = -precision;
}
// TODO Update for lambda != 1 and for weights

qKnown = true;
}
return Q;
}

private double[][] getPrecision() {

if (!precisionKnown) {
Expand Down Expand Up @@ -162,22 +191,44 @@ public String getType() {
// return cholesky;
// }

private static double matchPseudoDeterminantTerm(int dim) {
double term = 0.0;
for (int i = 1; i < dim; ++i) {
double x = (2 - 2 * Math.cos(i * Math.PI / dim));
term += Math.log(x);
}
return term;
}

public double getLogDet() {

if (!determinantKnown) {
final double k = precisionParameter.getParameterValue(0);
double det = Math.pow(k, dim);
for(int i=2; i<=dim; ++i) {
det = det * (2 - 2 * Math.cos((i-1)*(Math.PI/dim)));
logDet = (dim - 1) * Math.log(precisionParameter.getParameterValue(0)) + logMatchTerm;
determinantKnown = true;
}

if (CHECK_DETERMINANT) {

RobustEigenDecomposition ed = new RobustEigenDecomposition(new DenseDoubleMatrix2D(getPrecision()));
DoubleMatrix1D values = ed.getRealEigenvalues();
double sum = 0.0;
for (int i = 0; i < values.size(); ++i) {
double v = values.get(i);
if (Math.abs(v) > 1E-6) {
sum += Math.log(v);
}
}
logDet = Math.log(det);

determinantKnown = true;
if (Math.abs(sum - logDet) > 1E-6) {
throw new RuntimeException("Incorrect pseudo-determinant");
}
}

return logDet;
}

private static final boolean CHECK_DETERMINANT = false;

@Override
public double[][] getScaleMatrix() {
return getPrecision();
Expand All @@ -190,18 +241,27 @@ public Variable<Double> getLocationVariable() {

@Override
public double logPdf(double[] x) {
return logPdf(x, getMean(), getPrecision(), getLogDet());
// double x1 = logPdf(x, getMean(), getPrecision(), getLogDet());
// double x2 = logPdf(x, getMean(), getQ(), precisionParameter.getParameterValue(0), 1.0, logMatchTerm);
//
// System.err.println(x1 + " ?= " + x2);
//
return logPdf(x, getMean(), getQ(), precisionParameter.getParameterValue(0),
1.0, logMatchTerm);
}

public double[] gradLogPdf(double[] x) {
// TODO Update to use Q
return gradLogPdf(x, getMean(), getPrecision());
}

public double[][] hessianLogPdf(double[] x) {
// TODO Update to use Q
return hessianLogPdf(x, getPrecision());
}

public double[] diagonalHessianLogPdf(double[] x) {
// TODO Update to use Q
return diagonalHessianLogPdf(x, getPrecision());
}

Expand Down Expand Up @@ -339,7 +399,11 @@ public static double[] diagonalHessianLogPdf(double[] x, double[][] precision) {
// }

private static double logPdf(double[] x, double[] mean, SymmetricTriDiagonalMatrix Q,
double precision, double lambda) {
double precision, double lambda, double logMatch) {
return getLogNormalization(x.length, precision, lambda, logMatch) - 0.5 * getSSE(x, mean, Q);
}

private static double getSSE(double[] x, double[] mean, SymmetricTriDiagonalMatrix Q) {

final int dim = x.length;
final double[] delta = new double[dim];
Expand All @@ -354,14 +418,24 @@ private static double logPdf(double[] x, double[] mean, SymmetricTriDiagonalMatr
}
SSE += Q.diagonal[dim - 1] * delta[dim - 1] * delta[dim - 1];

double logLikelihood = 0.5 * (dim - 1) * Math.log(precision) - 0.5 * SSE;
return SSE;
}

private static double getLogDeterminant(int dim, double precision, double lambda, double logMatch) {
return (dim - 1) * Math.log(precision) + logMatch;
}

private static double getLogNormalization(int dim, double precision, double lambda, double logMatch) {

double logNorm = 0.5 * getLogDeterminant(dim, precision, lambda, logMatch);

if (lambda == 1.0) {
logLikelihood -= (dim - 1) * logNormalize;
logNorm -= (dim - 1) * HALF_LOG_TWO_PI;
} else {
logLikelihood -= dim * logNormalize;
logNorm -= dim * HALF_LOG_TWO_PI;
}

return logLikelihood;
return logNorm;
}


Expand All @@ -370,6 +444,11 @@ class SymmetricTriDiagonalMatrix {
double[] diagonal;
double[] offDiagonal;

SymmetricTriDiagonalMatrix(int dim) {
this.diagonal = new double[dim];
this.offDiagonal = new double[dim - 1];
}

SymmetricTriDiagonalMatrix(double[] diagonal, double[] offDiagonal) {
this.diagonal = diagonal;
this.offDiagonal = offDiagonal;
Expand Down Expand Up @@ -400,7 +479,17 @@ public static double logPdf(double[] x, double[] mean, double[][] precision,
if (logDet == Double.NEGATIVE_INFINITY)
return logDet;

return getLogNormalization(x.length, logDet) - 0.5 * getSSE(x, mean, precision);
}

public static double getLogNormalization(int dim, double logDet) {
return -(dim - 1) * HALF_LOG_TWO_PI + 0.5 * logDet; // Pratyusa's normalization constant
}

public static double getSSE(double[] x, double[] mean, double[][] precision) {

final int dim = x.length;

final double[] delta = new double[dim];

for (int i = 0; i < dim; i++) {
Expand All @@ -412,40 +501,15 @@ public static double logPdf(double[] x, double[] mean, double[][] precision,
for (int i = 0; i < dim-1; i++) {
SSE += precision[i][i] * delta[i] * delta[i] + 2 * precision[i][i + 1] * delta[i] * delta[i + 1];
}
return (dim-1) * logNormalize + 0.5 * (logDet - SSE); // There was an error here.
// Variance = (scale * Precision^{-1})
}

// private static double[][] getInverse(double[][] x) {
// return new SymmetricMatrix(x).inverse().toComponents();
// }

// private static double[][] getCholeskyDecomposition(double[][] variance) {
// double[][] cholesky;
// try {
// cholesky = (new CholeskyDecomposition(variance)).getL();
// } catch (IllegalDimension illegalDimension) {
// throw new RuntimeException("Attempted Cholesky decomposition on non-square matrix");
// }
// return cholesky;
// }

private static final double logNormalize = -0.5 * Math.log(2.0 * Math.PI);

// public double logPdf(Object x) {
// double[] v = (double[]) x;
// return logPdf(v);
// }
return SSE;
}

private static final double HALF_LOG_TWO_PI = Math.log(2.0 * Math.PI) / 2;

@Override
public int getDimension() { return dim; }

// public Parameter getincrementPrecision() { return precisionParameter; }

// public Parameter getstart() { return meanParameter; }


@Override
public double[] getGradientLogDensity(Object x) {
return gradLogPdf((double[]) x);
Expand Down Expand Up @@ -493,6 +557,7 @@ protected void restoreState() { // TODO with caching
meanKnown = false;
precisionKnown = false;
determinantKnown = false;
qKnown = false;
}

@Override
Expand Down

0 comments on commit b36914c

Please sign in to comment.