Skip to content

Commit

Permalink
intermediate split of GP out of LogLinear models
Browse files Browse the repository at this point in the history
  • Loading branch information
msuchard committed Nov 22, 2023
1 parent 5db23d3 commit 404542a
Show file tree
Hide file tree
Showing 8 changed files with 331 additions and 25 deletions.
41 changes: 23 additions & 18 deletions ci/TestXML/testGamGpLikelihoodAndGradient.xml
Original file line number Diff line number Diff line change
Expand Up @@ -92,18 +92,8 @@
<parameter id="loc.clock.rate" value="1E-4" lower="0.0"/>
</rate>
</strictClockBranchRates>

<glmSubstitutionModel id="loc.model" normalize="true">
<generalDataType idref="loc.dataType"/>
<rootFrequencies>
<frequencyModel id="loc.frequencyModel" normalize="true">
<generalDataType idref="loc.dataType"/>
<frequencies>
<parameter id="loc.frequencies" dimension="3"/>
</frequencies>
</frequencyModel>
</rootFrequencies>
<gamGpModel id="gamGpModel"> <!-- returns density of realized field-->

<gamGpModel id="gamGpModel"> <!-- returns density of realized field-->
<realizedField>
<parameter id="loc.coefficients0" value="0 0 0 0 0 0"/> <!-- 1 2 3 4 5put other numbers-->
</realizedField>
Expand Down Expand Up @@ -154,7 +144,22 @@
</independentVariables>

<!-- nothing -->
</gamGpModel>
</gamGpModel>

<glmSubstitutionModel id="loc.model" normalize="true">
<generalDataType idref="loc.dataType"/>
<rootFrequencies>
<frequencyModel id="loc.frequencyModel" normalize="true">
<generalDataType idref="loc.dataType"/>
<frequencies>
<parameter id="loc.frequencies" dimension="3"/>
</frequencies>
</frequencyModel>
</rootFrequencies>
<!-- <gamGpModel idref="gamGpModel"/> -->
<logRates>
<parameter idref="loc.coefficients0"/>
</logRates>
</glmSubstitutionModel>

<siteModel id="loc.siteModel">
Expand All @@ -178,16 +183,15 @@
<treeDataLikelihood idref="treeLikelihood"/>
</report>

<!--
<gamGpSubstitutionModelGradient id="gradient1" traitName="loc" effects="fixed">
<approximateLogCtmcRateGradient id="gradient1" traitName="loc">
<treeDataLikelihood idref="treeLikelihood"/>
<glmSubstitutionModel idref="loc.model"/>
</gamGpSubstitutionModelGradient>
</approximateLogCtmcRateGradient>

<report>
<glmSubstitutionModelGradient idref="gradient1"/>
<approximateLogCtmcRateGradient idref="gradient1"/>
</report>
-->


<operators id="operators" optimizationSchedule="log">
<randomWalkOperator windowSize="0.1" weight="1">
Expand Down Expand Up @@ -236,4 +240,5 @@

<traceAnalysis fileName="test.log"/>


</beast>
1 change: 1 addition & 0 deletions src/dr/app/beast/development_parsers.properties
Original file line number Diff line number Diff line change
Expand Up @@ -262,6 +262,7 @@ dr.evomodelxml.treelikelihood.thorneytreelikelihood.UniformSubtreePruneRegraftPa
dr.inferencexml.operators.MaskMoveOperatorParser
dr.evomodelxml.continuous.hmc.GlmSubstitutionModelGradientParser
dr.evomodelxml.continuous.hmc.GamGpSubstitutionModelGradientParser
dr.evomodelxml.continuous.hmc.ApproximateLogCtmcRateGradientParser
dr.evomodelxml.substmodel.BirthDeathSubstitutionModelParser

# Uncertain attributes:
Expand Down
61 changes: 59 additions & 2 deletions src/dr/evomodel/substmodel/LogAdditiveCtmcRateProvider.java
Original file line number Diff line number Diff line change
@@ -1,8 +1,7 @@
package dr.evomodel.substmodel;

import dr.inference.loggers.LogColumn;
import dr.inference.model.Likelihood;
import dr.inference.model.Model;
import dr.inference.model.*;

public interface LogAdditiveCtmcRateProvider extends Model, Likelihood {

Expand All @@ -17,4 +16,62 @@ default double[] getRates() {
}
return rates;
}

interface Integrated extends LogAdditiveCtmcRateProvider { }

interface DataAugmented extends LogAdditiveCtmcRateProvider {

Parameter getLogRateParameter();

class Basic extends AbstractModelLikelihood implements DataAugmented {

private final Parameter logRateParameter;

public Basic(String name, Parameter logRateParameter) {
super(name);
this.logRateParameter = logRateParameter;

addVariable(logRateParameter);
}

public Parameter getLogRateParameter() { return logRateParameter; }

@Override
public double[] getXBeta() { // TODO this function should _not_ exponentiate

final int fieldDim = logRateParameter.getDimension();
double[] rates = new double[fieldDim];

for (int i = 0; i < fieldDim; ++i) {
rates[i] = Math.exp(logRateParameter.getParameterValue(i));
}
return rates;
}

@Override
protected void handleModelChangedEvent(Model model, Object object, int index) { }

@Override
protected void handleVariableChangedEvent(Variable variable, int index,
Parameter.ChangeType type) { }

@Override
protected void storeState() { }

@Override
protected void restoreState() { }

@Override
protected void acceptState() { }

@Override
public Model getModel() { return this; }

@Override
public double getLogLikelihood() { return 0; }

@Override
public void makeDirty() { }
}
}
}
144 changes: 144 additions & 0 deletions src/dr/evomodel/treedatalikelihood/discrete/LogCtmcRateGradient.java
Original file line number Diff line number Diff line change
@@ -0,0 +1,144 @@
/*
* LogCtmcRateGradient.java
*
* Copyright (c) 2002-2023 Alexei Drummond, Andrew Rambaut and Marc Suchard
*
* This file is part of BEAST.
* See the NOTICE file distributed with this work for additional
* information regarding copyright ownership and licensing.
*
* BEAST is free software; you can redistribute it and/or modify
* it under the terms of the GNU Lesser General Public License as
* published by the Free Software Foundation; either version 2
* of the License, or (at your option) any later version.
*
* BEAST is distributed in the hope that it will be useful,
* but WITHOUT ANY WARRANTY; without even the implied warranty of
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
* GNU Lesser General Public License for more details.
*
* You should have received a copy of the GNU Lesser General Public
* License along with BEAST; if not, write to the
* Free Software Foundation, Inc., 51 Franklin St, Fifth Floor,
* Boston, MA 02110-1301 USA
*/

package dr.evomodel.treedatalikelihood.discrete;

import dr.evomodel.substmodel.GlmSubstitutionModel;
import dr.evomodel.substmodel.LogAdditiveCtmcRateProvider;
import dr.evomodel.treedatalikelihood.BeagleDataLikelihoodDelegate;
import dr.evomodel.treedatalikelihood.TreeDataLikelihood;
import dr.inference.loggers.LogColumn;
import dr.inference.model.Parameter;
import dr.util.Citation;
import dr.util.CommonCitations;

import java.util.Collections;
import java.util.List;

/**
* @author Filippo Monti
* @author Marc A. Suchard
*/

public class LogCtmcRateGradient extends AbstractLogAdditiveSubstitutionModelGradient {

private final LogAdditiveCtmcRateProvider.DataAugmented rateProvider;
private final int[][] mapEffectToIndices;

public LogCtmcRateGradient(String traitName,
TreeDataLikelihood treeDataLikelihood,
BeagleDataLikelihoodDelegate likelihoodDelegate,
GlmSubstitutionModel substitutionModel) {
super(traitName, treeDataLikelihood, likelihoodDelegate, substitutionModel,
ApproximationMode.FIRST_ORDER);

if (substitutionModel.getRateProvider() instanceof LogAdditiveCtmcRateProvider.DataAugmented)
this.rateProvider = (LogAdditiveCtmcRateProvider.DataAugmented)
substitutionModel.getRateProvider();
else {
throw new IllegalArgumentException("Invalid substitution model");
}

this.mapEffectToIndices = makeAsymmetricMap();
}

@Override
protected double preProcessNormalization(double[] differentials, double[] generator,
boolean normalize) {
double total = 0.0;
if (normalize) {
for (int i = 0; i < stateCount; ++i) {
for (int j = 0; j < stateCount; ++j) {
final int ij = i * stateCount + j;
total += differentials[ij] * generator[ij];
}
}
}
return total;
}

private int[][] makeAsymmetricMap() {
int[][] map = new int[stateCount * (stateCount - 1)][];

int k = 0;
for (int i = 0; i < stateCount; ++i) {
for (int j = i + 1; j < stateCount; ++j) {
map[k++] = new int[]{i, j};
}
}

for (int j = 0; j < stateCount; ++j) {
for (int i = j + 1; i < stateCount; ++i) {
map[k++] = new int[]{i, j};
}
}

return map;
}

@Override
double processSingleGradientDimension(int k, double[] differentials, double[] generator, double[] pi,
boolean normalize, double normalizationConstant) {

final int i = mapEffectToIndices[k][0], j = mapEffectToIndices[k][1];
final int ii = i * stateCount + i;
final int ij = i * stateCount + j;

double element = generator[ij];
double total = (differentials[ij] - differentials[ii]) * element;

if (normalize) {
total -= element * pi[i] * normalizationConstant;
}

return total;
}

@Override
public Parameter getParameter() {
return rateProvider.getLogRateParameter();
}

@Override
public LogColumn[] getColumns() {
throw new RuntimeException("Not yet implemented");
}

@Override
public Citation.Category getCategory() {
return Citation.Category.SUBSTITUTION_MODELS;
}

@Override
public String getDescription() {
return null; // TODO
}

@Override
public List<Citation> getCitations() {
// TODO Update
return Collections.singletonList(CommonCitations.LEMEY_2014_UNIFYING);
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,87 @@
/*
* GamGpSubstitutionModelGradientParser.java
*
* Copyright (c) 2002-2023 Alexei Drummond, Andrew Rambaut and Marc Suchard
*
* This file is part of BEAST.
* See the NOTICE file distributed with this work for additional
* information regarding copyright ownership and licensing.
*
* BEAST is free software; you can redistribute it and/or modify
* it under the terms of the GNU Lesser General Public License as
* published by the Free Software Foundation; either version 2
* of the License, or (at your option) any later version.
*
* BEAST is distributed in the hope that it will be useful,
* but WITHOUT ANY WARRANTY; without even the implied warranty of
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
* GNU Lesser General Public License for more details.
*
* You should have received a copy of the GNU Lesser General Public
* License along with BEAST; if not, write to the
* Free Software Foundation, Inc., 51 Franklin St, Fifth Floor,
* Boston, MA 02110-1301 USA
*/

package dr.evomodelxml.continuous.hmc;

import dr.evomodel.substmodel.GlmSubstitutionModel;
import dr.evomodel.treedatalikelihood.BeagleDataLikelihoodDelegate;
import dr.evomodel.treedatalikelihood.DataLikelihoodDelegate;
import dr.evomodel.treedatalikelihood.TreeDataLikelihood;
import dr.evomodel.treedatalikelihood.discrete.AbstractGlmSubstitutionModelGradient;
import dr.evomodel.treedatalikelihood.discrete.LogCtmcRateGradient;
import dr.evomodelxml.treelikelihood.TreeTraitParserUtilities;
import dr.xml.*;

import static dr.evomodelxml.treelikelihood.TreeTraitParserUtilities.DEFAULT_TRAIT_NAME;

/**
* @author Filippo Monti
* @author Marc A. Suchard
*/

public class ApproximateLogCtmcRateGradientParser extends AbstractXMLObjectParser {

private static final String PARSER_NAME = "approximateLogCtmcRateGradient";
private static final String TRAIT_NAME = TreeTraitParserUtilities.TRAIT_NAME;

public String getParserName(){ return PARSER_NAME; }

public Object parseXMLObject(XMLObject xo) throws XMLParseException {

String traitName = xo.getAttribute(TRAIT_NAME, DEFAULT_TRAIT_NAME);
final TreeDataLikelihood treeDataLikelihood = (TreeDataLikelihood) xo.getChild(TreeDataLikelihood.class);
GlmSubstitutionModel substitutionModel = (GlmSubstitutionModel) xo.getChild(GlmSubstitutionModel.class);

DataLikelihoodDelegate delegate = treeDataLikelihood.getDataLikelihoodDelegate();
if (!(delegate instanceof BeagleDataLikelihoodDelegate)) {
throw new XMLParseException("Unknown likelihood delegate type");
}

return new LogCtmcRateGradient(traitName, treeDataLikelihood,
(BeagleDataLikelihoodDelegate) delegate, substitutionModel);
}

@Override
public XMLSyntaxRule[] getSyntaxRules() {
return rules;
}

private final XMLSyntaxRule[] rules = {
AttributeRule.newStringRule(TRAIT_NAME, true),
new ElementRule(TreeDataLikelihood.class),
new ElementRule(GlmSubstitutionModel.class),
};

@Override
public String getParserDescription() {
return null;
}

@Override
public Class getReturnType() {
return AbstractGlmSubstitutionModelGradient.class;
}
}

Loading

0 comments on commit 404542a

Please sign in to comment.