diff --git a/src/dr/evomodel/treedatalikelihood/discrete/RandomEffectsSubstitutionModelGradient.java b/src/dr/evomodel/treedatalikelihood/discrete/RandomEffectsSubstitutionModelGradient.java index 51f375aaf1..70bd93b417 100644 --- a/src/dr/evomodel/treedatalikelihood/discrete/RandomEffectsSubstitutionModelGradient.java +++ b/src/dr/evomodel/treedatalikelihood/discrete/RandomEffectsSubstitutionModelGradient.java @@ -58,14 +58,6 @@ public RandomEffectsSubstitutionModelGradient(String traitName, } } -// Parameter makeCompoundParameter(GeneralizedLinearModel glm) { -// CompoundParameter parameter = new CompoundParameter("random.effects"); -// for (int i = 0; i < glm.getNumberOfRandomEffects(); ++i) { -// parameter.addParameter(glm.getRandomEffect(i)); -// } -// return parameter; -// } - ParameterMap makeParameterMap(GeneralizedLinearModel glm) { return new ParameterMap() { diff --git a/src/dr/math/distributions/GaussianMarkovRandomField.java b/src/dr/math/distributions/GaussianMarkovRandomField.java index 6fd93fc6a3..05866efef6 100644 --- a/src/dr/math/distributions/GaussianMarkovRandomField.java +++ b/src/dr/math/distributions/GaussianMarkovRandomField.java @@ -50,21 +50,15 @@ public class GaussianMarkovRandomField extends RandomFieldDistribution { private final double[] mean; - SymmetricTriDiagonalMatrix Q; - private SymmetricTriDiagonalMatrix savedQ; + final SymmetricTriDiagonalMatrix Q; + private final SymmetricTriDiagonalMatrix savedQ; private boolean meanKnown; boolean qKnown; + private boolean savedQKnown; private final double logMatchTerm; - public GaussianMarkovRandomField(String name, - int dim, - Parameter precision, - Parameter mean) { - this(name, dim, precision, mean, null, null, true); - } - public GaussianMarkovRandomField(String name, int dim, Parameter precision, @@ -94,6 +88,7 @@ public GaussianMarkovRandomField(String name, this.mean = new double[dim]; this.Q = new SymmetricTriDiagonalMatrix(dim); + this.savedQ = new SymmetricTriDiagonalMatrix(dim); this.logMatchTerm = matchPseudoDeterminant ? matchPseudoDeterminantTerm(dim) : 0.0; @@ -145,6 +140,7 @@ protected SymmetricTriDiagonalMatrix getQ() { } } + assert weightProvider == null : "Not yet implemented"; // TODO Update for weights qKnown = true; @@ -218,6 +214,8 @@ public double[] getGradientLogDensity(Object x) { throw new IllegalArgumentException("Unknown mean parameter structure"); } }; + } else if (parameter == lambdaParameter) { + throw new RuntimeException("Not yet implemented"); // TODO } else { throw new RuntimeException("Unknown parameter"); } @@ -234,11 +232,8 @@ private double matchPseudoDeterminantTerm(int dim) { double x = (2 - 2 * Math.cos(i * Math.PI / dim)); term += Math.log(x); } - return term; - } else { - double lambda = lambdaParameter.getParameterValue(0); - return (1 - dim) * Math.log(1 - lambda * lambda); } + return term; } private double getLogDeterminant() { @@ -246,19 +241,14 @@ private double getLogDeterminant() { int effectiveDim = isImproper() ? dim - 1 : dim; double logDet = effectiveDim * Math.log(precisionParameter.getParameterValue(0)) + logMatchTerm; - if (CHECK_DETERMINANT) { - - double[][] precision = new double[dim][dim]; - - for (int i = 0; i < dim; ++i) { - precision[i][i] = Q.diagonal[i]; - } + if (!isImproper()) { + double lambda = lambdaParameter.getParameterValue(0); + logDet += (1 - dim) * Math.log(1 - lambda * lambda); + } - for (int i = 0; i < dim - 1; ++i) { - precision[i][i + 1] = Q.offDiagonal[i]; - precision[i + 1][i] = Q.offDiagonal[i]; - } + if (CHECK_DETERMINANT) { + double[][] precision = makePrecisionMatrix(Q); RobustEigenDecomposition ed = new RobustEigenDecomposition(new DenseDoubleMatrix2D(precision)); DoubleMatrix1D values = ed.getRealEigenvalues(); double sum = 0.0; @@ -291,7 +281,7 @@ public Variable getLocationVariable() { @Override public double logPdf(double[] x) { - return logPdf(x, getMean(), getQ(), precisionParameter.getParameterValue(0), isImproper(), getLogDeterminant()); + return logPdf(x, getMean(), getQ(), isImproper(), getLogDeterminant()); } public static double gradLogPdfWrtPrecision(double[] x, double[] mean, SymmetricTriDiagonalMatrix Q, @@ -443,9 +433,8 @@ public static double[] diagonalHessianLogPdf(double[] x, SymmetricTriDiagonalMat // return currentLike; // } - @SuppressWarnings("unused") private static double logPdf(double[] x, double[] mean, SymmetricTriDiagonalMatrix Q, - double precision, boolean isImproper, double logDeterminant) { + boolean isImproper, double logDeterminant) { return getLogNormalization(x.length, isImproper, logDeterminant) - 0.5 * getSSE(x, mean, Q); } @@ -536,7 +525,7 @@ protected void handleModelChangedEvent(Model model, Object object, int index) { protected void handleVariableChangedEvent(Variable variable, int index, Parameter.ChangeType type) { if (variable == meanParameter) { meanKnown = false; - } else if (variable == precisionParameter) { + } else if (variable == precisionParameter || variable == lambdaParameter) { qKnown = false; } else { throw new IllegalArgumentException("Unknown variable"); @@ -545,13 +534,20 @@ protected void handleVariableChangedEvent(Variable variable, int index, Paramete @Override protected void storeState() { - // TODO + if (qKnown) { + Q.copyTo(savedQ); + } + savedQKnown = qKnown; } @Override - protected void restoreState() { // TODO with caching + protected void restoreState() { // TODO cache mean meanKnown = false; - qKnown = false; + + qKnown = savedQKnown; + if (qKnown) { + savedQ.swap(Q); + } } @Override