diff --git a/src/dr/evomodel/bigfasttree/thorney/BranchLengthLikelihoodDelegate.java b/src/dr/evomodel/bigfasttree/thorney/BranchLengthLikelihoodDelegate.java index 1c3d72a948..311bb91fc5 100644 --- a/src/dr/evomodel/bigfasttree/thorney/BranchLengthLikelihoodDelegate.java +++ b/src/dr/evomodel/bigfasttree/thorney/BranchLengthLikelihoodDelegate.java @@ -4,5 +4,5 @@ public interface BranchLengthLikelihoodDelegate { double getLogLikelihood(double mutations, double branchLength); - double getGradientWrtTime(double mutations, double time); + public double getGradientWrtTime(double mutations, double time, double branchRate); } \ No newline at end of file diff --git a/src/dr/evomodel/bigfasttree/thorney/PoissonBranchLengthLikelihoodDelegate.java b/src/dr/evomodel/bigfasttree/thorney/PoissonBranchLengthLikelihoodDelegate.java index b82ce93eaf..b898d2d17b 100644 --- a/src/dr/evomodel/bigfasttree/thorney/PoissonBranchLengthLikelihoodDelegate.java +++ b/src/dr/evomodel/bigfasttree/thorney/PoissonBranchLengthLikelihoodDelegate.java @@ -21,13 +21,12 @@ public double getLogLikelihood(double mutations, double branchLength) { @Override - public double getGradientWrtTime(double mutations, double branchLength) { // TODO: better chain rule handling + public double getGradientWrtTime(double mutations, double time, double branchRate) { // TODO: better chain rule handling // if (!(this.branchRateModel instanceof StrictClockBranchRates)){ // throw new RuntimeException("gradients are only implemented for a strict clock model"); // } // double rate = (double) branchRateModel.getVariable(0).getValue(0); - // return SaddlePointExpansion.logPoissonMeanDerivative(time * rate * scale, (int) Math.round(mutations)) * rate * scale; - throw new RuntimeException("gradients are not implemented for this model"); + return SaddlePointExpansion.logPoissonMeanDerivative(time * branchRate * scale, (int) Math.round(mutations)) * branchRate * scale; } diff --git a/src/dr/evomodel/bigfasttree/thorney/ThorneyTreeGradient.java b/src/dr/evomodel/bigfasttree/thorney/ThorneyTreeGradient.java index 4f2cb07490..0c7ef0d073 100644 --- a/src/dr/evomodel/bigfasttree/thorney/ThorneyTreeGradient.java +++ b/src/dr/evomodel/bigfasttree/thorney/ThorneyTreeGradient.java @@ -26,7 +26,7 @@ package dr.evomodel.bigfasttree.thorney; import dr.evolution.tree.NodeRef; -import dr.evolution.tree.Tree; +import dr.evomodel.branchratemodel.BranchRateModel; import dr.evomodel.tree.TreeModel; import dr.evomodel.tree.TreeParameterModel; import dr.evomodel.treedatalikelihood.TreeDataLikelihood; @@ -48,7 +48,7 @@ public class ThorneyTreeGradient implements GradientWrtParameterProvider, Report private final TreeParameterModel indexHelper; private final double[] branchGradient; private final ThorneyDataLikelihoodDelegate dataLikelihoodDelegate; - + private final BranchRateModel branchRateModel; public ThorneyTreeGradient(TreeDataLikelihood likelihood) { this.likelihood = likelihood; this.tree = (TreeModel)likelihood.getTree(); @@ -56,6 +56,7 @@ public ThorneyTreeGradient(TreeDataLikelihood likelihood) { this.nodeHeightProxyParameter = new NodeHeightProxyParameter("ThorneyTreeGradient.NodeHeightProxyParameter", this.tree, true); this.branchGradient = new double[tree.getNodeCount() - 1]; this.indexHelper = new TreeParameterModel(tree, new Parameter.Default(branchGradient), false); + this.branchRateModel = likelihood.getBranchRateModel(); } @@ -80,7 +81,8 @@ private void calculateBranchGradient() { NodeRef node = tree.getNode(indexHelper.getNodeNumberFromParameterIndex(i)); double time = tree.getBranchLength(node); double mutations = dataLikelihoodDelegate.getMutationMap().getMutations(node); - branchGradient[i] = dataLikelihoodDelegate.getBranchLengthLikelihoodDelegate().getGradientWrtTime(mutations, time); + double rate = branchRateModel.getBranchRate(tree,node); + branchGradient[i] = dataLikelihoodDelegate.getBranchLengthLikelihoodDelegate().getGradientWrtTime(mutations, time, rate); } }