Skip to content

Commit

Permalink
add optional gradient check tolerance to coalescent gradient
Browse files Browse the repository at this point in the history
  • Loading branch information
xji3 committed Nov 7, 2023
1 parent 89dd9c5 commit 48153a7
Show file tree
Hide file tree
Showing 2 changed files with 11 additions and 40 deletions.
44 changes: 5 additions & 39 deletions src/dr/evolution/coalescent/CoalescentGradient.java
Original file line number Diff line number Diff line change
Expand Up @@ -37,11 +37,8 @@
import dr.inference.model.Likelihood;
import dr.inference.model.Parameter;
import dr.math.Binomial;
import dr.util.ComparableDouble;
import dr.util.HeapSort;
import dr.xml.Reportable;

import java.util.ArrayList;
import java.util.Arrays;

/**
Expand All @@ -55,10 +52,12 @@ public class CoalescentGradient implements GradientWrtParameterProvider, Reporta
private final Tree tree;

public CoalescentGradient(CoalescentLikelihood likelihood,
TreeModel tree) {
TreeModel tree,
double tolerance) {
this.likelihood = likelihood;
this.tree = tree;
this.parameter = new NodeHeightProxyParameter("NodeHeights", tree, true);
this.tolerance = tolerance;
}

@Override
Expand Down Expand Up @@ -86,11 +85,6 @@ public double[] getGradientLogDensity() {
return gradient;
}

int[] intervalIndices = new int[tree.getInternalNodeCount()];
int[] nodeIndices = new int[tree.getInternalNodeCount()];
double[] sortedValues = new double[tree.getInternalNodeCount()];
getIntervalIndexForInternalNodes(intervalIndices, nodeIndices, sortedValues);

IntervalList intervals = likelihood.getIntervalList();
BigFastTreeIntervals bigFastTreeIntervals = (BigFastTreeIntervals) intervals;

Expand Down Expand Up @@ -128,9 +122,6 @@ public double[] getGradientLogDensity() {
}
}
}
for (int j = 0; j < numSameHeightNodes; j++) {
gradient[nodeIndices[tree.getInternalNodeCount() - j - 1]] = thisGradient / ((double) numSameHeightNodes);
}

int j = numSameHeightNodes;
int v = bigFastTreeIntervals.getIntervalCount() - 1;
Expand All @@ -145,37 +136,12 @@ public double[] getGradientLogDensity() {
return gradient;
}

@Deprecated
private void getIntervalIndexForInternalNodes(int[] intervalIndices, int[] nodeIndices, double[] sortedValues) {
double[] nodeHeights = new double[tree.getInternalNodeCount()];
ArrayList<ComparableDouble> sortedInternalNodes = new ArrayList<ComparableDouble>();
for (int i = 0; i < nodeIndices.length; i++) {
final double nodeHeight = tree.getNodeHeight(tree.getNode(tree.getExternalNodeCount() + i));
sortedInternalNodes.add(new ComparableDouble(nodeHeight));
nodeHeights[i] = nodeHeight;
}
HeapSort.sort(sortedInternalNodes, nodeIndices);
for (int i = 0; i < nodeIndices.length; i++) {
sortedValues[i] = nodeHeights[nodeIndices[i]];
}

IntervalList intervals = likelihood.getIntervalList();
int intervalIndex = 0;
double finishTime = intervals.getInterval(intervalIndex);
for (int i = 0; i < tree.getInternalNodeCount(); i++) {
while(intervalIndex < intervals.getIntervalCount() - 1 && sortedValues[i] - finishTime > realSmallNumber) {
intervalIndex++;
finishTime += intervals.getInterval(intervalIndex);
}
intervalIndices[nodeIndices[i]] = intervalIndex;
}
}

private final double realSmallNumber = 1E-10;
private final double tolerance;

@Override
public String getReport() {
return GradientWrtParameterProvider.getReportAndCheckForError(this, 0.0, Double.POSITIVE_INFINITY, 1E-2);
return GradientWrtParameterProvider.getReportAndCheckForError(this, 0.0, Double.POSITIVE_INFINITY, tolerance);
}

@Override
Expand Down
7 changes: 6 additions & 1 deletion src/dr/evomodelxml/coalescent/CoalescentGradientParser.java
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@
import dr.evolution.coalescent.CoalescentGradient;
import dr.evomodel.coalescent.CoalescentLikelihood;
import dr.evomodel.tree.TreeModel;
import dr.inferencexml.operators.hmc.HamiltonianMonteCarloOperatorParser;
import dr.xml.*;

/**
Expand All @@ -39,11 +40,14 @@ public class CoalescentGradientParser extends AbstractXMLObjectParser {

private static final String NAME = "coalescentGradient";

private static final String TOLERANCE = HamiltonianMonteCarloOperatorParser.GRADIENT_CHECK_TOLERANCE;

@Override
public Object parseXMLObject(XMLObject xo) throws XMLParseException {
double tolerance = xo.getAttribute(TOLERANCE, 1E-1);
CoalescentLikelihood likelihood = (CoalescentLikelihood) xo.getChild(CoalescentLikelihood.class);
TreeModel tree = (TreeModel) xo.getChild(TreeModel.class);
return new CoalescentGradient(likelihood, tree);
return new CoalescentGradient(likelihood, tree, tolerance);
}

@Override
Expand All @@ -54,6 +58,7 @@ public XMLSyntaxRule[] getSyntaxRules() {
private final XMLSyntaxRule[] rules = {
new ElementRule(CoalescentLikelihood.class),
new ElementRule(TreeModel.class),
AttributeRule.newDoubleRule(TOLERANCE, true),
};


Expand Down

0 comments on commit 48153a7

Please sign in to comment.