Skip to content

Commit

Permalink
forgot gradients in the refactor
Browse files Browse the repository at this point in the history
  • Loading branch information
jtmccr1 committed Aug 18, 2023
1 parent 197ab11 commit c6a9ff9
Show file tree
Hide file tree
Showing 3 changed files with 8 additions and 7 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -4,5 +4,5 @@
public interface BranchLengthLikelihoodDelegate {
double getLogLikelihood(double mutations, double branchLength);

double getGradientWrtTime(double mutations, double time);
public double getGradientWrtTime(double mutations, double time, double branchRate);
}
Original file line number Diff line number Diff line change
Expand Up @@ -21,13 +21,12 @@ public double getLogLikelihood(double mutations, double branchLength) {


@Override
public double getGradientWrtTime(double mutations, double branchLength) { // TODO: better chain rule handling
public double getGradientWrtTime(double mutations, double time, double branchRate) { // TODO: better chain rule handling
// if (!(this.branchRateModel instanceof StrictClockBranchRates)){
// throw new RuntimeException("gradients are only implemented for a strict clock model");
// }
// double rate = (double) branchRateModel.getVariable(0).getValue(0);
// return SaddlePointExpansion.logPoissonMeanDerivative(time * rate * scale, (int) Math.round(mutations)) * rate * scale;
throw new RuntimeException("gradients are not implemented for this model");
return SaddlePointExpansion.logPoissonMeanDerivative(time * branchRate * scale, (int) Math.round(mutations)) * branchRate * scale;
}


Expand Down
8 changes: 5 additions & 3 deletions src/dr/evomodel/bigfasttree/thorney/ThorneyTreeGradient.java
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@
package dr.evomodel.bigfasttree.thorney;

import dr.evolution.tree.NodeRef;
import dr.evolution.tree.Tree;
import dr.evomodel.branchratemodel.BranchRateModel;
import dr.evomodel.tree.TreeModel;
import dr.evomodel.tree.TreeParameterModel;
import dr.evomodel.treedatalikelihood.TreeDataLikelihood;
Expand All @@ -48,14 +48,15 @@ public class ThorneyTreeGradient implements GradientWrtParameterProvider, Report
private final TreeParameterModel indexHelper;
private final double[] branchGradient;
private final ThorneyDataLikelihoodDelegate dataLikelihoodDelegate;

private final BranchRateModel branchRateModel;
public ThorneyTreeGradient(TreeDataLikelihood likelihood) {
this.likelihood = likelihood;
this.tree = (TreeModel)likelihood.getTree();
this.dataLikelihoodDelegate = (ThorneyDataLikelihoodDelegate) likelihood.getDataLikelihoodDelegate();
this.nodeHeightProxyParameter = new NodeHeightProxyParameter("ThorneyTreeGradient.NodeHeightProxyParameter", this.tree, true);
this.branchGradient = new double[tree.getNodeCount() - 1];
this.indexHelper = new TreeParameterModel(tree, new Parameter.Default(branchGradient), false);
this.branchRateModel = likelihood.getBranchRateModel();


}
Expand All @@ -80,7 +81,8 @@ private void calculateBranchGradient() {
NodeRef node = tree.getNode(indexHelper.getNodeNumberFromParameterIndex(i));
double time = tree.getBranchLength(node);
double mutations = dataLikelihoodDelegate.getMutationMap().getMutations(node);
branchGradient[i] = dataLikelihoodDelegate.getBranchLengthLikelihoodDelegate().getGradientWrtTime(mutations, time);
double rate = branchRateModel.getBranchRate(tree,node);
branchGradient[i] = dataLikelihoodDelegate.getBranchLengthLikelihoodDelegate().getGradientWrtTime(mutations, time, rate);
}
}

Expand Down

0 comments on commit c6a9ff9

Please sign in to comment.