Skip to content

Commit

Permalink
vqsr serialized model sets annotation order
Browse files Browse the repository at this point in the history
  • Loading branch information
lucidtronix committed Aug 23, 2018
1 parent 06f2ab9 commit 9944059
Show file tree
Hide file tree
Showing 9 changed files with 423 additions and 8 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,15 @@ public List<VariantDatum> getData() {
return data;
}

public void normalizeData(final boolean calculateMeans) {
/**
* Normalize annotations to mean 0 and standard deviation 1.
* Order the variant annotations by the provided list {@code theOrder} or standard deviation.
*
* @param calculateMeans Boolean indicating whether or not to calculate the means
* @param theOrder a list of integers specifying the desired annotation order. If this is null
* annotations will get sorted in decreasing size of their standard deviations.
*/
public void normalizeData(final boolean calculateMeans, List<Integer> theOrder) {
boolean foundZeroVarianceAnnotation = false;
for( int iii = 0; iii < meanVector.length; iii++ ) {
final double theMean, theSTD;
Expand Down Expand Up @@ -96,15 +104,18 @@ public void normalizeData(final boolean calculateMeans) {

// re-order the data by increasing standard deviation so that the results don't depend on the order things were specified on the command line
// standard deviation over the training points is used as a simple proxy for information content, perhaps there is a better thing to use here
final List<Integer> theOrder = calculateSortOrder(meanVector);
// or use the serialized report's annotation order via the argument theOrder
if (theOrder == null){
theOrder = calculateSortOrder(meanVector);
}
annotationKeys = reorderList(annotationKeys, theOrder);
varianceVector = ArrayUtils.toPrimitive(reorderArray(ArrayUtils.toObject(varianceVector), theOrder));
meanVector = ArrayUtils.toPrimitive(reorderArray(ArrayUtils.toObject(meanVector), theOrder));
for( final VariantDatum datum : data ) {
datum.annotations = ArrayUtils.toPrimitive(reorderArray(ArrayUtils.toObject(datum.annotations), theOrder));
datum.isNull = ArrayUtils.toPrimitive(reorderArray(ArrayUtils.toObject(datum.isNull), theOrder));
}
logger.info("Annotations are now ordered by their information content: " + annotationKeys.toString());
logger.info("Annotation order is: " + annotationKeys.toString());
}

public double[] getMeanVector() {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -369,6 +369,9 @@ public class VariantRecalibrator extends MultiVariantWalker {
optional=true)
private boolean TRUST_ALL_POLYMORPHIC = false;

/**
 * Annotation ordering handed to the data manager's {@code normalizeData} call.
 * Stays {@code null} unless {@code orderAndValidateAnnotations} has populated it from a serialized
 * model report; a {@code null} value lets {@code normalizeData} compute its own sort order.
 * Each entry is an index into the command-line annotation key list.
 */
@VisibleForTesting
protected List<Integer> annotationOrder = null;

/////////////////////////////
// Private Member Variables
/////////////////////////////
Expand Down Expand Up @@ -447,11 +450,9 @@ public void onTraversalStart() {
pPMixTable = reportIn.getTable("GoodGaussianPMix");
final GATKReportTable anMeansTable = reportIn.getTable("AnnotationMeans");
final GATKReportTable anStDevsTable = reportIn.getTable("AnnotationStdevs");
numAnnotations = dataManager.annotationKeys.size();

if ( numAnnotations != pmmTable.getNumColumns()-1 || numAnnotations != nmmTable.getNumColumns()-1 ) { // -1 because the first column is the gaussian number.
throw new CommandLineException( "Annotations specified on the command line do not match annotations in the model report." );
}
orderAndValidateAnnotations(anMeansTable, dataManager.annotationKeys);
numAnnotations = annotationOrder.size();

final Map<String, Double> anMeans = getMapFromVectorTable(anMeansTable);
final Map<String, Double> anStdDevs = getMapFromVectorTable(anStDevsTable);
Expand Down Expand Up @@ -482,6 +483,31 @@ else if (null != sequenceDictionary) {
}
}

/**
 * Order and validate the command-line annotations against the annotations stored in a serialized model.
 *
 * On success, {@code annotationOrder} is set so that its i-th entry is the index, within
 * {@code annotationKeys}, of the annotation found in row i of {@code annotationTable}; i.e. it maps
 * the command-line order onto the model report's order.
 *
 * The two annotation sets must match one-to-one. Comparing only the number of matches is not enough:
 * a duplicated command-line annotation could otherwise mask a missing one (model {A,B,C} versus
 * keys {A,A,B} yields three matches), so every model annotation must resolve to a distinct key index.
 *
 * The lookup is quadratic in the number of annotations, which is fine because typically only a
 * handful (seven or fewer) are used.
 *
 * @param annotationTable GATKReportTable of annotations read from the serialized model file;
 *                        annotation names are read from its "Annotation" column
 * @param annotationKeys  annotation names in the order they were given on the command line
 * @throws CommandLineException if the command-line annotations are not exactly the annotations in the
 *                              model report; {@code annotationOrder} is left untouched in that case
 */
protected void orderAndValidateAnnotations(final GATKReportTable annotationTable, final List<String> annotationKeys){
    final int numSerializedAnnotations = annotationTable.getNumRows();
    final List<Integer> serializedOrder = new ArrayList<>(annotationKeys.size());

    // Differing counts can never be a one-to-one match, so there is no point in scanning the table.
    boolean annotationsMatch = numSerializedAnnotations == annotationKeys.size();

    for (int i = 0; annotationsMatch && i < numSerializedAnnotations; i++){
        final String serialAnno = (String)annotationTable.get(i, "Annotation");
        final int keyIndex = annotationKeys.indexOf(serialAnno);
        // A key index may be claimed only once; a repeat means one side contains a duplicate.
        annotationsMatch = keyIndex >= 0 && !serializedOrder.contains(keyIndex);
        if (annotationsMatch) {
            serializedOrder.add(keyIndex);
        }
    }

    if (!annotationsMatch) {
        throw new CommandLineException( "Annotations specified on the command line do not match annotations in the model report." );
    }

    // Publish only a fully validated ordering.
    annotationOrder = serializedOrder;
}

//---------------------------------------------------------------------------------------------------------------
//
// apply
Expand Down Expand Up @@ -607,7 +633,7 @@ public Object onTraversalSuccess() {
for (int i = 1; i <= max_attempts; i++) {
try {
dataManager.setData(reduceSum);
dataManager.normalizeData(inputModel == null); // Each data point is now (x - mean) / standard deviation
dataManager.normalizeData(inputModel == null, annotationOrder); // Each data point is now (x - mean) / standard deviation

final GaussianMixtureModel goodModel;
final GaussianMixtureModel badModel;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -239,6 +239,51 @@ public void testVariantRecalibratorModelInput() throws IOException {
spec.executeTest("testVariantRecalibratorModelInput"+ inputFile, this);
}

// Fixtures for testVQSRAnnotationOrder:
// expected recal and tranches outputs, shared by both annotation orderings exercised by the test
private final String annoOrderRecal = getLargeVQSRTestDataDir() + "expected/anno_order.recal";
private final String annoOrderTranches = getLargeVQSRTestDataDir() + "expected/anno_order.tranches";
// serialized model report supplied to the tool through --input-model
private final String exacModelReportFilename = publicTestDir + "/subsetExAC.snps_model.report";

/**
 * Runs VariantRecalibrator twice against the same serialized model, changing only the order of the
 * {@code -an} arguments, and expects both runs to reproduce the same recal and tranches files.
 */
@Test
public void testVQSRAnnotationOrder() throws IOException {
    final String inputFile = publicTestDir + "/oneSNP.vcf";

    // We don't actually need resources because we are using a serialized model,
    // so we just pass input as resource to prevent a crash
    final String argsBeforeAnnotations =
            " --variant " + inputFile +
            " -L 1:110201699" +
            " --resource hapmap,known=false,training=true,truth=true,prior=15:" + inputFile;
    final String argsAfterAnnotations =
            " --output %s" +
            " -tranches-file %s" +
            " --input-model " + exacModelReportFilename +
            " --add-output-vcf-command-line false" +
            " -ignore-all-filters -mode SNP";
    final List<String> expectedOutputs = Arrays.asList(annoOrderRecal, annoOrderTranches);

    // Start both runs from the same random generator state so they are directly comparable
    // regardless of which tests ran earlier.
    Utils.resetRandomGenerator();
    final IntegrationTestSpec spec = new IntegrationTestSpec(
            argsBeforeAnnotations +
            " -an FS -an ReadPosRankSum -an MQ -an MQRankSum -an QD -an SOR" +
            argsAfterAnnotations,
            expectedOutputs);
    // Label the run with this test's own name rather than the one copied from testVariantRecalibratorModelInput.
    spec.executeTest("testVQSRAnnotationOrder" + inputFile, this);

    Utils.resetRandomGenerator();
    // Change annotation order and assert consistent outputs
    final IntegrationTestSpec spec2 = new IntegrationTestSpec(
            argsBeforeAnnotations +
            " -an ReadPosRankSum -an MQ -an MQRankSum -an QD -an SOR -an FS" +
            argsAfterAnnotations,
            expectedOutputs);
    spec2.executeTest("testVQSRAnnotationOrderReordered" + inputFile, this);
}


@DataProvider(name="VarRecalSNPScattered")
public Object[][] getVarRecalSNPScatteredData() {
return new Object[][] {
Expand Down
Original file line number Diff line number Diff line change
@@ -1,7 +1,9 @@
package org.broadinstitute.hellbender.tools.walkers.vqsr;

import avro.shaded.com.google.common.collect.Lists;
import org.apache.log4j.Logger;

import org.broadinstitute.barclay.argparser.CommandLineException;
import org.broadinstitute.hellbender.utils.report.GATKReport;
import org.broadinstitute.hellbender.utils.report.GATKReportTable;

Expand Down Expand Up @@ -225,4 +227,42 @@ private GaussianMixtureModel getBadGMM(){
return new GaussianMixtureModel(badGaussianList, shrinkage, dirichlet, priorCounts);
}

/**
 * Checks that orderAndValidateAnnotations maps command-line annotation positions onto the order of
 * the annotation table, and that it rejects command-line lists with too many or too few annotations.
 */
@Test
public void testAnnotationOrderAndValidate() {
    final VariantRecalibrator vqsr = new VariantRecalibrator();
    final List<String> annotationList = new ArrayList<>();
    annotationList.add("QD");
    annotationList.add("FS");
    annotationList.add("ReadPosRankSum");
    annotationList.add("MQ");
    annotationList.add("MQRankSum");
    annotationList.add("SOR");

    final double[] meanVector = {16.13, 2.45, 0.37, 59.08, 0.14, 0.91};
    final String columnName = "Mean";
    final String formatString = "%.3f";
    final GATKReportTable annotationTable = vqsr.makeVectorTable("AnnotationMeans", "Mean for each annotation, used to normalize data", annotationList, meanVector, columnName, formatString);

    // Same order on both sides: expect the identity mapping.
    vqsr.orderAndValidateAnnotations(annotationTable, annotationList);
    // Pin the size first, otherwise the loop below would pass vacuously on an empty order.
    Assert.assertEquals(vqsr.annotationOrder.size(), annotationList.size());
    for (int i = 0; i < vqsr.annotationOrder.size(); i++){
        Assert.assertEquals(vqsr.annotationOrder.get(i).intValue(), i);
    }

    // Reversed command-line order: table row i must map to the mirrored key index.
    final List<String> reversed = Lists.reverse(annotationList);
    vqsr.orderAndValidateAnnotations(annotationTable, reversed);
    Assert.assertEquals(vqsr.annotationOrder.size(), reversed.size());
    for (int i = 0; i < vqsr.annotationOrder.size(); i++){
        Assert.assertEquals(vqsr.annotationOrder.get(i).intValue(), reversed.size()-i-1);
    }

    // Now break things...
    // we should throw an error if there are too many or too few annotations on the command line.
    annotationList.add("BaseQRankSum");
    Assert.assertThrows(CommandLineException.class, () -> vqsr.orderAndValidateAnnotations(annotationTable, annotationList));

    // Drop the first entry and the extra one just added, leaving one annotation fewer than the table has.
    annotationList.remove(0);
    annotationList.remove(annotationList.size()-1);
    Assert.assertThrows(CommandLineException.class, () -> vqsr.orderAndValidateAnnotations(annotationTable, annotationList));
}

}
3 changes: 3 additions & 0 deletions src/test/resources/large/VQSR/expected/anno_order.recal
Git LFS file not shown
3 changes: 3 additions & 0 deletions src/test/resources/large/VQSR/expected/anno_order.tranches
Git LFS file not shown
Loading

0 comments on commit 9944059

Please sign in to comment.