Skip to content

Commit

Permalink
move calculations wrt lambda out of constant terms
Browse files Browse the repository at this point in the history
  • Loading branch information
msuchard committed Dec 17, 2023
1 parent 3263d62 commit db1fa38
Show file tree
Hide file tree
Showing 2 changed files with 27 additions and 39 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -58,14 +58,6 @@ public RandomEffectsSubstitutionModelGradient(String traitName,
}
}

// Parameter makeCompoundParameter(GeneralizedLinearModel glm) {
// CompoundParameter parameter = new CompoundParameter("random.effects");
// for (int i = 0; i < glm.getNumberOfRandomEffects(); ++i) {
// parameter.addParameter(glm.getRandomEffect(i));
// }
// return parameter;
// }

ParameterMap makeParameterMap(GeneralizedLinearModel glm) {

return new ParameterMap() {
Expand Down
58 changes: 27 additions & 31 deletions src/dr/math/distributions/GaussianMarkovRandomField.java
Original file line number Diff line number Diff line change
Expand Up @@ -50,21 +50,15 @@ public class GaussianMarkovRandomField extends RandomFieldDistribution {

private final double[] mean;

SymmetricTriDiagonalMatrix Q;
private SymmetricTriDiagonalMatrix savedQ;
final SymmetricTriDiagonalMatrix Q;
private final SymmetricTriDiagonalMatrix savedQ;

private boolean meanKnown;
boolean qKnown;
private boolean savedQKnown;

private final double logMatchTerm;

public GaussianMarkovRandomField(String name,
int dim,
Parameter precision,
Parameter mean) {
this(name, dim, precision, mean, null, null, true);
}

public GaussianMarkovRandomField(String name,
int dim,
Parameter precision,
Expand Down Expand Up @@ -94,6 +88,7 @@ public GaussianMarkovRandomField(String name,
this.mean = new double[dim];

this.Q = new SymmetricTriDiagonalMatrix(dim);
this.savedQ = new SymmetricTriDiagonalMatrix(dim);

this.logMatchTerm = matchPseudoDeterminant ? matchPseudoDeterminantTerm(dim) : 0.0;

Expand Down Expand Up @@ -145,6 +140,7 @@ protected SymmetricTriDiagonalMatrix getQ() {
}
}

assert weightProvider == null : "Not yet implemented";
// TODO Update for weights

qKnown = true;
Expand Down Expand Up @@ -218,6 +214,8 @@ public double[] getGradientLogDensity(Object x) {
throw new IllegalArgumentException("Unknown mean parameter structure");
}
};
} else if (parameter == lambdaParameter) {
throw new RuntimeException("Not yet implemented"); // TODO
} else {
throw new RuntimeException("Unknown parameter");
}
Expand All @@ -234,31 +232,23 @@ private double matchPseudoDeterminantTerm(int dim) {
double x = (2 - 2 * Math.cos(i * Math.PI / dim));
term += Math.log(x);
}
return term;
} else {
double lambda = lambdaParameter.getParameterValue(0);
return (1 - dim) * Math.log(1 - lambda * lambda);
}
return term;
}

private double getLogDeterminant() {

int effectiveDim = isImproper() ? dim - 1 : dim;
double logDet = effectiveDim * Math.log(precisionParameter.getParameterValue(0)) + logMatchTerm;

if (CHECK_DETERMINANT) {

double[][] precision = new double[dim][dim];

for (int i = 0; i < dim; ++i) {
precision[i][i] = Q.diagonal[i];
}
if (!isImproper()) {
double lambda = lambdaParameter.getParameterValue(0);
logDet += (1 - dim) * Math.log(1 - lambda * lambda);
}

for (int i = 0; i < dim - 1; ++i) {
precision[i][i + 1] = Q.offDiagonal[i];
precision[i + 1][i] = Q.offDiagonal[i];
}
if (CHECK_DETERMINANT) {

double[][] precision = makePrecisionMatrix(Q);
RobustEigenDecomposition ed = new RobustEigenDecomposition(new DenseDoubleMatrix2D(precision));
DoubleMatrix1D values = ed.getRealEigenvalues();
double sum = 0.0;
Expand Down Expand Up @@ -291,7 +281,7 @@ public Variable<Double> getLocationVariable() {

@Override
public double logPdf(double[] x) {
return logPdf(x, getMean(), getQ(), precisionParameter.getParameterValue(0), isImproper(), getLogDeterminant());
return logPdf(x, getMean(), getQ(), isImproper(), getLogDeterminant());
}

public static double gradLogPdfWrtPrecision(double[] x, double[] mean, SymmetricTriDiagonalMatrix Q,
Expand Down Expand Up @@ -443,9 +433,8 @@ public static double[] diagonalHessianLogPdf(double[] x, SymmetricTriDiagonalMat
// return currentLike;
// }

@SuppressWarnings("unused")
private static double logPdf(double[] x, double[] mean, SymmetricTriDiagonalMatrix Q,
double precision, boolean isImproper, double logDeterminant) {
boolean isImproper, double logDeterminant) {
return getLogNormalization(x.length, isImproper, logDeterminant) - 0.5 * getSSE(x, mean, Q);
}

Expand Down Expand Up @@ -536,7 +525,7 @@ protected void handleModelChangedEvent(Model model, Object object, int index) {
protected void handleVariableChangedEvent(Variable variable, int index, Parameter.ChangeType type) {
if (variable == meanParameter) {
meanKnown = false;
} else if (variable == precisionParameter) {
} else if (variable == precisionParameter || variable == lambdaParameter) {
qKnown = false;
} else {
throw new IllegalArgumentException("Unknown variable");
Expand All @@ -545,13 +534,20 @@ protected void handleVariableChangedEvent(Variable variable, int index, Paramete

@Override
protected void storeState() {
// TODO
if (qKnown) {
Q.copyTo(savedQ);
}
savedQKnown = qKnown;
}

@Override
protected void restoreState() { // TODO with caching
protected void restoreState() { // TODO cache mean
meanKnown = false;
qKnown = false;

qKnown = savedQKnown;
if (qKnown) {
savedQ.swap(Q);
}
}

@Override
Expand Down

0 comments on commit db1fa38

Please sign in to comment.