diff --git a/build.gradle b/build.gradle
index 3af1061cad7..e2754b5a489 100644
--- a/build.gradle
+++ b/build.gradle
@@ -285,6 +285,7 @@ dependencies {
implementation 'org.apache.commons:commons-lang3:3.5'
implementation 'org.apache.commons:commons-math3:3.5'
+ implementation 'org.hipparchus:hipparchus-stat:2.0'
implementation 'org.apache.commons:commons-collections4:4.1'
implementation 'org.apache.commons:commons-vfs2:2.0'
implementation 'org.apache.commons:commons-configuration2:2.4'
diff --git a/scripts/gatkcondaenv.yml.template b/scripts/gatkcondaenv.yml.template
index 10467af3cdf..dbe29ed5a28 100644
--- a/scripts/gatkcondaenv.yml.template
+++ b/scripts/gatkcondaenv.yml.template
@@ -42,6 +42,7 @@ dependencies:
- conda-forge::matplotlib=3.2.1
- conda-forge::pandas=1.0.3
- conda-forge::typing_extensions=4.1.1 # see https://github.com/broadinstitute/gatk/issues/7800 and linked PRs
+- conda-forge::dill=0.3.4 # used for pickling lambdas in TrainVariantAnnotationsModel
# core R dependencies; these should only be used for plotting and do not take precedence over core python dependencies!
- r-base=3.6.2
diff --git a/scripts/vcf_site_level_filtering_wdl/JointVcfFiltering.wdl b/scripts/vcf_site_level_filtering_wdl/JointVcfFiltering.wdl
index 89e05abcb89..87b520aca0d 100644
--- a/scripts/vcf_site_level_filtering_wdl/JointVcfFiltering.wdl
+++ b/scripts/vcf_site_level_filtering_wdl/JointVcfFiltering.wdl
@@ -192,8 +192,6 @@ task TrainVariantAnnotationModel {
command <<<
set -e
- conda install -y --name gatk dill
-
export GATK_LOCAL_JAR=~{default="/root/gatk.jar" gatk_override}
mode=$(echo "~{mode}" | awk '{print toupper($0)}')
@@ -245,8 +243,6 @@ task ScoreVariantAnnotations {
ln -s ~{sep=" . && ln -s " model_files} .
- conda install -y --name gatk dill
-
export GATK_LOCAL_JAR=~{default="/root/gatk.jar" gatk_override}
gatk --java-options "-Xmx~{command_mem}m" \
diff --git a/src/main/java/org/broadinstitute/hellbender/tools/copynumber/CollectReadCounts.java b/src/main/java/org/broadinstitute/hellbender/tools/copynumber/CollectReadCounts.java
index 4706a1e1d7d..dbe972fa541 100644
--- a/src/main/java/org/broadinstitute/hellbender/tools/copynumber/CollectReadCounts.java
+++ b/src/main/java/org/broadinstitute/hellbender/tools/copynumber/CollectReadCounts.java
@@ -68,7 +68,7 @@
* to TSV format. Using HDF5 files with {@link CreateReadCountPanelOfNormals}
* can decrease runtime, by reducing time spent on IO, so this is the default output format.
* The HDF5 format contains information in the paths defined in {@link HDF5SimpleCountCollection}. HDF5 files may be viewed using
- * hdfview or loaded in python using
+ * hdfview or loaded in Python using
* PyTables or h5py.
* The TSV format has a SAM-style header containing a read group sample name, a sequence dictionary, a row specifying the column headers contained in
* {@link SimpleCountCollection.SimpleCountTableColumn}, and the corresponding entry rows.
diff --git a/src/main/java/org/broadinstitute/hellbender/tools/copynumber/CreateReadCountPanelOfNormals.java b/src/main/java/org/broadinstitute/hellbender/tools/copynumber/CreateReadCountPanelOfNormals.java
index dbcc0cc1c4d..d4d6b8db9c0 100644
--- a/src/main/java/org/broadinstitute/hellbender/tools/copynumber/CreateReadCountPanelOfNormals.java
+++ b/src/main/java/org/broadinstitute/hellbender/tools/copynumber/CreateReadCountPanelOfNormals.java
@@ -85,7 +85,7 @@
* Panel-of-normals file.
* This is an HDF5 file containing the panel data in the paths defined in {@link HDF5SVDReadCountPanelOfNormals}.
* HDF5 files may be viewed using hdfview
- * or loaded in python using PyTables or h5py.
+ * or loaded in Python using PyTables or h5py.
*
*
*
diff --git a/src/main/java/org/broadinstitute/hellbender/tools/copynumber/utils/HDF5Utils.java b/src/main/java/org/broadinstitute/hellbender/tools/copynumber/utils/HDF5Utils.java
index 8590e3476f2..870ce37b7dc 100644
--- a/src/main/java/org/broadinstitute/hellbender/tools/copynumber/utils/HDF5Utils.java
+++ b/src/main/java/org/broadinstitute/hellbender/tools/copynumber/utils/HDF5Utils.java
@@ -135,7 +135,7 @@ public static double[][] readChunkedDoubleMatrix(final HDF5File file,
* Given a large matrix, chunks the matrix into equally sized subsets of rows
* (plus a subset containing the remainder, if necessary) and writes these submatrices to indexed sub-paths
* to avoid a hard limit in Java HDF5 on the number of elements in a matrix given by
- * {@code MAX_NUM_VALUES_PER_HDF5_MATRIX}. The number of chunks is determined by {@code maxChunkSize},
+ * {@code MAX_NUMBER_OF_VALUES_PER_HDF5_MATRIX}. The number of chunks is determined by {@code maxChunkSize},
* which should be set appropriately for the desired number of columns.
*
* @param maxChunkSize The maximum number of values in each chunk. Decreasing this number will reduce
diff --git a/src/main/java/org/broadinstitute/hellbender/tools/walkers/vqsr/VariantRecalibrator.java b/src/main/java/org/broadinstitute/hellbender/tools/walkers/vqsr/VariantRecalibrator.java
index f7148d043f1..2da7997a51c 100644
--- a/src/main/java/org/broadinstitute/hellbender/tools/walkers/vqsr/VariantRecalibrator.java
+++ b/src/main/java/org/broadinstitute/hellbender/tools/walkers/vqsr/VariantRecalibrator.java
@@ -10,6 +10,7 @@
import htsjdk.variant.vcf.VCFHeader;
import htsjdk.variant.vcf.VCFHeaderLine;
import org.broadinstitute.barclay.help.DocumentedFeature;
+import org.broadinstitute.hdf5.HDF5File;
import org.broadinstitute.hellbender.cmdline.*;
import org.broadinstitute.barclay.argparser.CommandLineException;
import org.broadinstitute.barclay.argparser.CommandLineProgramProperties;
@@ -23,6 +24,8 @@
import org.broadinstitute.hellbender.engine.ReadsContext;
import org.broadinstitute.hellbender.engine.ReferenceContext;
import org.broadinstitute.hellbender.engine.MultiVariantWalker;
+import org.broadinstitute.hellbender.exceptions.GATKException;
+import org.broadinstitute.hellbender.utils.io.IOUtils;
import picard.cmdline.programgroups.VariantFilteringProgramGroup;
import org.broadinstitute.hellbender.utils.R.RScriptExecutor;
import org.broadinstitute.hellbender.utils.SimpleInterval;
@@ -41,6 +44,7 @@
import java.io.*;
import java.util.*;
+import java.util.stream.IntStream;
/**
* Build a recalibration model to score variant quality for filtering purposes
@@ -639,6 +643,10 @@ public Object onTraversalSuccess() {
for (int i = 1; i <= max_attempts; i++) {
try {
dataManager.setData(reduceSum);
+
+ final String rawAnnotationsOutput = output.toString().endsWith(".recal") ? output.toString().split(".recal")[0] : output.toString();
+ writeAnnotationsHDF5(new File(rawAnnotationsOutput + ".annot.raw.hdf5"));
+
dataManager.normalizeData(inputModel == null, annotationOrder); // Each data point is now (x - mean) / standard deviation
final GaussianMixtureModel goodModel;
@@ -678,6 +686,9 @@ public Object onTraversalSuccess() {
}
}
+ final String annotationsOutput = output.toString().endsWith(".recal") ? output.toString().split(".recal")[0] : output.toString();
+ writeAnnotationsHDF5(new File(annotationsOutput + ".annot.hdf5"));
+
dataManager.dropAggregateData(); // Don't need the aggregate data anymore so let's free up the memory
engine.evaluateData(dataManager.getData(), badModel, true);
@@ -686,6 +697,10 @@ public Object onTraversalSuccess() {
saveModelReport(report, outputModel);
}
+ final String modelOutput = output.toString().endsWith(".recal") ? output.toString().split(".recal")[0] : output.toString();
+ writeModelHDF5(new File(modelOutput + ".positive.hdf5"), goodModel);
+ writeModelHDF5(new File(modelOutput + ".negative.hdf5"), badModel);
+
engine.calculateWorstPerformingAnnotation(dataManager.getData(), goodModel, badModel);
@@ -1176,4 +1191,43 @@ private void createArrangeFunction( final PrintStream stream ) {
stream.println("}");
stream.println("}");
}
+
+ public void writeAnnotationsHDF5(final File file) {
+ try (final HDF5File hdf5File = new HDF5File(file, HDF5File.OpenMode.CREATE)) { // TODO allow appending
+ IOUtils.canReadFile(hdf5File.getFile());
+
+ hdf5File.makeStringArray("/data/annotation_names", dataManager.getAnnotationKeys().stream().toArray(String[]::new));
+ hdf5File.makeDoubleMatrix("/data/annotations", dataManager.getData().stream().map(vd -> vd.annotations).toArray(double[][]::new));
+ hdf5File.makeDoubleArray("/data/is_training", dataManager.getData().stream().mapToDouble(vd -> vd.atTrainingSite ? 1 : 0).toArray());
+ hdf5File.makeDoubleArray("/data/is_truth", dataManager.getData().stream().mapToDouble(vd -> vd.atTruthSite ? 1 : 0).toArray());
+ hdf5File.makeDoubleArray("/data/is_anti_training", dataManager.getData().stream().mapToDouble(vd -> vd.atAntiTrainingSite ? 1 : 0).toArray());
+
+ logger.info(String.format("Annotations written to %s.", file.getAbsolutePath()));
+ } catch (final RuntimeException exception) {
+ throw new GATKException(String.format("Exception encountered during writing of annotations (%s). Output file at %s may be in a bad state.",
+ exception, file.getAbsolutePath()));
+ }
+ }
+
+ public void writeModelHDF5(final File file,
+ final GaussianMixtureModel model) {
+ try (final HDF5File hdf5File = new HDF5File(file, HDF5File.OpenMode.CREATE)) { // TODO allow appending
+ IOUtils.canReadFile(hdf5File.getFile());
+
+ final int nComponents = model.getModelGaussians().size();
+ final int nFeatures = model.getNumAnnotations();
+ hdf5File.makeDouble("/vqsr/number_of_components", nComponents);
+ hdf5File.makeDouble("/vqsr/number_of_features", nComponents);
+ hdf5File.makeDoubleArray("/vqsr/weights", model.getModelGaussians().stream().mapToDouble(g -> Math.pow(10., (g.pMixtureLog10))).toArray());
+ IntStream.range(0, nComponents).forEach(
+ k -> hdf5File.makeDoubleArray("/vqsr/means/" + k, model.getModelGaussians().get(k).mu));
+ IntStream.range(0, nComponents).forEach(
+ k -> hdf5File.makeDoubleMatrix("vqsr/covariances/" + k, model.getModelGaussians().get(k).sigma.getArray()));
+
+ logger.info(String.format("VQSR model written to %s.", file.getAbsolutePath()));
+ } catch (final RuntimeException exception) {
+ throw new GATKException(String.format("Exception encountered during writing of VQSR model (%s). Output file at %s may be in a bad state.",
+ exception, file.getAbsolutePath()));
+ }
+ }
}
diff --git a/src/main/java/org/broadinstitute/hellbender/tools/walkers/vqsr/scalable/ExtractVariantAnnotations.java b/src/main/java/org/broadinstitute/hellbender/tools/walkers/vqsr/scalable/ExtractVariantAnnotations.java
new file mode 100644
index 00000000000..48f73007767
--- /dev/null
+++ b/src/main/java/org/broadinstitute/hellbender/tools/walkers/vqsr/scalable/ExtractVariantAnnotations.java
@@ -0,0 +1,361 @@
+package org.broadinstitute.hellbender.tools.walkers.vqsr.scalable;
+
+import htsjdk.variant.variantcontext.Allele;
+import htsjdk.variant.variantcontext.VariantContext;
+import org.apache.commons.lang3.tuple.Triple;
+import org.apache.commons.math3.random.RandomGenerator;
+import org.apache.commons.math3.random.RandomGeneratorFactory;
+import org.broadinstitute.barclay.argparser.Argument;
+import org.broadinstitute.barclay.argparser.BetaFeature;
+import org.broadinstitute.barclay.argparser.CommandLineProgramProperties;
+import org.broadinstitute.barclay.help.DocumentedFeature;
+import org.broadinstitute.hellbender.cmdline.StandardArgumentDefinitions;
+import org.broadinstitute.hellbender.engine.FeatureContext;
+import org.broadinstitute.hellbender.engine.ReadsContext;
+import org.broadinstitute.hellbender.engine.ReferenceContext;
+import org.broadinstitute.hellbender.exceptions.GATKException;
+import org.broadinstitute.hellbender.tools.copynumber.utils.HDF5Utils;
+import org.broadinstitute.hellbender.tools.walkers.vqsr.VariantRecalibrator;
+import org.broadinstitute.hellbender.tools.walkers.vqsr.scalable.data.LabeledVariantAnnotationsData;
+import org.broadinstitute.hellbender.tools.walkers.vqsr.scalable.data.VariantType;
+import picard.cmdline.programgroups.VariantFilteringProgramGroup;
+
+import java.io.File;
+import java.util.List;
+import java.util.Random;
+import java.util.Set;
+import java.util.TreeSet;
+import java.util.stream.Collectors;
+
+/**
+ * Extracts site-level variant annotations, labels, and other metadata from a VCF file to HDF5 files.
+ *
+ *
+ * This tool is intended to be used as the first step in a variant-filtering workflow that supersedes the
+ * {@link VariantRecalibrator} workflow. This tool extracts site-level annotations, labels, and other relevant metadata
+ * from variant sites (or alleles, in allele-specific mode) that are or are not present in specified labeled
+ * resource VCFs (e.g., training or calibration VCFs). The former, present sites are considered labeled; each site
+ * can have multiple labels. The latter sites are considered unlabeled and can be randomly downsampled using
+ * reservoir sampling; extraction of these is optional. The outputs of the tool are HDF5 files containing the
+ * extracted data for labeled and (optional) unlabeled variant sets, as well as a sites-only indexed VCF containing
+ * the labeled variants.
+ *
+ *
+ *
+ * The extracted sets can be provided as input to the {@link TrainVariantAnnotationsModel} tool
+ * to produce an annotation-based model for scoring variant calls. This model can in turn be provided
+ * along with a VCF file to the {@link ScoreVariantAnnotations} tool, which assigns a score to each call
+ * (with a lower score indicating that a call is more likely to be an artifact and should perhaps be filtered).
+ * Each score can also be converted to a corresponding sensitivity to a calibration set, if the latter is available.
+ *
+ *
+ *
+ * Note that annotations and metadata are collected in memory during traversal until they are written to HDF5 files
+ * upon completion of the traversal. Memory requirements thus roughly scale linearly with both the number of sites
+ * extracted and the number of annotations.
+ *
+ *
+ *
+ * Note that HDF5 files may be viewed using hdfview
+ * or loaded in Python using PyTables or h5py.
+ *
+ *
+ *
Inputs
+ *
+ *
+ *
+ * Input VCF file. Site-level annotations will be extracted from the contained variants (or alleles,
+ * if the {@value USE_ALLELE_SPECIFIC_ANNOTATIONS_LONG_NAME} argument is specified).
+ *
+ *
+ * Annotations to extract.
+ *
+ *
+ * Variant types (i.e., SNP and/or INDEL) to extract. Logic for determining variant type was retained from
+ * {@link VariantRecalibrator}; see {@link VariantType}. Extracting SNPs and INDELs separately in two runs of
+ * this tool can be useful if one wishes to extract different sets of annotations for each variant type,
+ * for example.
+ *
+ *
+ * (Optional) Resource VCF file(s). Each resource should be tagged with a label, which will be assigned to
+ * extracted sites that are present in the resource. In typical use, the {@value LabeledVariantAnnotationsData#TRAINING_LABEL}
+ * and {@value LabeledVariantAnnotationsData#CALIBRATION_LABEL} labels should be used to tag at least one resource
+ * apiece. The resulting sets of sites will be used for model training and conversion of scores to
+ * calibration-set sensitivity, respectively; the trustworthiness of the respective resources should be
+ * taken into account accordingly. The {@value LabeledVariantAnnotationsData#SNP_LABEL} label is
+ * reserved by the tool, as it is used to label sites determined to be SNPs, and thus it cannot be used to tag
+ * provided resources.
+ *
+ *
+ * (Optional) Maximum number of unlabeled variants (or alleles) to randomly sample with reservoir downsampling.
+ * If nonzero, annotations will also be extracted from unlabeled sites (i.e., those that are not present
+ * in the labeled resources).
+ *
+ *
+ * Output prefix.
+ * This is used as the basename for output files.
+ *
+ *
+ *
+ *
Outputs
+ *
+ *
+ *
+ * (Optional) Labeled-annotations HDF5 file (.annot.hdf5). Annotation data and metadata for those sites that
+ * are present in labeled resources are stored in the following HDF5 directory structure:
+ *
+ *
+ * Here, each chunk is a double matrix, with dimensions given by (number of sites in the chunk) x (number of annotations).
+ * See the methods {@link HDF5Utils#writeChunkedDoubleMatrix} and {@link HDF5Utils#writeIntervals} for additional details.
+ * If {@value USE_ALLELE_SPECIFIC_ANNOTATIONS_LONG_NAME} is specified, each record corresponds to an individual allele;
+ * otherwise, each record corresponds to a variant site, which may contain multiple alleles.
+ * Storage of alleles can be omitted using the {@value OMIT_ALLELES_IN_HDF5_LONG_NAME} argument, which will reduce
+ * the size of the file. This file will only be produced if resources are provided and the number of extracted
+ * labeled sites is nonzero.
+ *
+ *
+ *
+ *
+ * Labeled sites-only VCF file and index. The VCF will not be gzipped if the {@value DO_NOT_GZIP_VCF_OUTPUT_LONG_NAME}
+ * argument is set to true. The VCF can be provided as a resource in subsequent runs of
+ * {@link ScoreVariantAnnotations} and used to indicate labeled sites that were extracted.
+ * This can be useful if the {@value StandardArgumentDefinitions#INTERVALS_LONG_NAME} argument was used to
+ * subset sites in training or calibration resources for extraction; this may occur when setting up
+ * training/validation/test splits, for example. Note that records for the random sample of unlabeled sites are
+ * currently not included in the VCF.
+ *
+ *
+ * (Optional) Unlabeled-annotations HDF5 file. This will have the same directory structure as in the
+ * labeled-annotations HDF5 file. However, note that records are currently written in the order they
+ * appear in the downsampling reservoir after random sampling, and hence, are not in genomic order.
+ * This file will only be produced if a nonzero value of the {@value MAXIMUM_NUMBER_OF_UNLABELED_VARIANTS_LONG_NAME}
+ * argument is provided.
+ *
+ *
+ *
+ *
Usage examples
+ *
+ *
+ * Extract annotations from training/calibration SNP/INDEL sites, producing the outputs
+ * 1) {@code extract.annot.hdf5}, 2) {@code extract.vcf.gz}, and 3) {@code extract.vcf.gz.tbi}.
+ * The HDF5 file can then be provided to {@link TrainVariantAnnotationsModel}
+ * to train a model using a positive-only approach.
+ *
+ *
+ * Extract annotations from both training/calibration SNP/INDEL sites and a random sample of
+ * 1000000 unlabeled (i.e., non-training/calibration) sites, producing the outputs
+ * 1) {@code extract.annot.hdf5}, 2) {@code extract.unlabeled.annot.hdf5}, 3) {@code extract.vcf.gz},
+ * and 4) {@code extract.vcf.gz.tbi}. The HDF5 files can then be provided to {@link TrainVariantAnnotationsModel}
+ * to train a model using a positive-negative approach (similar to that used in {@link VariantRecalibrator}).
+ *
+ *
+ * In the (atypical) event that resource VCFs are unavailable, one can still extract annotations from a random sample of
+ * unlabeled sites, producing the outputs 1) {@code extract.unlabeled.annot.hdf5},
+ * 2) {@code extract.vcf.gz} (which will contain no records), and 3) {@code extract.vcf.gz.tbi}.
+ * This random sample cannot be used by {@link TrainVariantAnnotationsModel}, but may still be useful for
+ * exploratory analyses.
+ *
+ *
+ *
+ *
+ * DEVELOPER NOTE: See documentation in {@link LabeledVariantAnnotationsWalker}.
+ *
+ * @author Samuel Lee <slee@broadinstitute.org>
+ */
+@CommandLineProgramProperties(
+ summary = "Extracts site-level variant annotations, labels, and other metadata from a VCF file to HDF5 files.",
+ oneLineSummary = "Extracts site-level variant annotations, labels, and other metadata from a VCF file to HDF5 files",
+ programGroup = VariantFilteringProgramGroup.class
+)
+@DocumentedFeature
+@BetaFeature
+public final class ExtractVariantAnnotations extends LabeledVariantAnnotationsWalker {
+
+ public static final String MAXIMUM_NUMBER_OF_UNLABELED_VARIANTS_LONG_NAME = "maximum-number-of-unlabeled-variants";
+ public static final String RESERVOIR_SAMPLING_RANDOM_SEED_LONG_NAME = "reservoir-sampling-random-seed";
+
+ public static final String UNLABELED_TAG = ".unlabeled";
+
+ @Argument(
+ fullName = MAXIMUM_NUMBER_OF_UNLABELED_VARIANTS_LONG_NAME,
+ doc = "Maximum number of unlabeled variants to extract. " +
+ "If greater than zero, reservoir sampling will be used to randomly sample this number " +
+ "of sites from input sites that are not present in the specified resources.",
+ minValue = 0)
+ private int maximumNumberOfUnlabeledVariants = 0;
+
+ @Argument(
+ fullName = RESERVOIR_SAMPLING_RANDOM_SEED_LONG_NAME,
+ doc = "Random seed to use for reservoir sampling of unlabeled variants.")
+ private int reservoirSamplingRandomSeed = 0;
+
+ private RandomGenerator rng;
+ private LabeledVariantAnnotationsData unlabeledDataReservoir; // will not be sorted in genomic order
+ private int unlabeledIndex = 0;
+
+ @Override
+ public void afterOnTraversalStart() {
+ if (!resourceLabels.contains(LabeledVariantAnnotationsData.TRAINING_LABEL)) {
+ logger.warn("No training set found! If you are using the downstream TrainVariantAnnotationsModel and ScoreVariantAnnotations tools, " +
+ "provide sets of known polymorphic loci marked with the training=true feature input tag. " +
+ "For example, --resource:hapmap,training=true hapmap.vcf");
+ }
+ if (!resourceLabels.contains(LabeledVariantAnnotationsData.CALIBRATION_LABEL)) {
+ logger.warn("No calibration set found! If you are using the downstream TrainVariantAnnotationsModel and ScoreVariantAnnotations tools " +
+ "and wish to convert scores to sensitivity to a calibration set of variants, " +
+ "provide sets of known polymorphic loci marked with the calibration=true feature input tag. " +
+ "For example, --resource:hapmap,calibration=true hapmap.vcf");
+ }
+
+ rng = RandomGeneratorFactory.createRandomGenerator(new Random(reservoirSamplingRandomSeed));
+ unlabeledDataReservoir = maximumNumberOfUnlabeledVariants == 0
+ ? null
+ : new LabeledVariantAnnotationsData(annotationNames, resourceLabels, useASAnnotations, maximumNumberOfUnlabeledVariants);
+ }
+
+ @Override
+ protected void nthPassApply(final VariantContext variant,
+ final ReadsContext readsContext,
+ final ReferenceContext referenceContext,
+ final FeatureContext featureContext,
+ final int n) {
+ if (n == 0) {
+ final List, VariantType, TreeSet>> metadata = extractVariantMetadata(
+ variant, featureContext, unlabeledDataReservoir != null);
+ final boolean isVariantExtracted = !metadata.isEmpty();
+ if (isVariantExtracted) {
+ final boolean isUnlabeled = metadata.stream().map(Triple::getRight).allMatch(Set::isEmpty);
+ if (!isUnlabeled) {
+ addExtractedVariantToData(data, variant, metadata);
+ writeExtractedVariantToVCF(variant, metadata);
+ } else {
+ // Algorithm R for reservoir sampling: https://en.wikipedia.org/wiki/Reservoir_sampling#Simple_algorithm
+ if (unlabeledIndex < maximumNumberOfUnlabeledVariants) {
+ addExtractedVariantToData(unlabeledDataReservoir, variant, metadata);
+ } else {
+ final int j = rng.nextInt(unlabeledIndex);
+ if (j < maximumNumberOfUnlabeledVariants) {
+ setExtractedVariantInData(unlabeledDataReservoir, variant, metadata, j);
+ }
+ }
+ unlabeledIndex++;
+ }
+ }
+ }
+ }
+
+ @Override
+ protected void afterNthPass(final int n) {
+ if (n == 0) {
+ writeAnnotationsToHDF5();
+ data.clear();
+ if (unlabeledDataReservoir != null) {
+ writeUnlabeledAnnotationsToHDF5();
+ // TODO write extracted unlabeled variants to VCF, which can be used to mark extraction in scoring step
+ unlabeledDataReservoir.clear();
+ }
+ if (vcfWriter != null) {
+ vcfWriter.close();
+ }
+ }
+ }
+
+ @Override
+ public Object onTraversalSuccess() {
+
+ logger.info(String.format("%s complete.", getClass().getSimpleName()));
+
+ return null;
+ }
+
+ private static void setExtractedVariantInData(final LabeledVariantAnnotationsData data,
+ final VariantContext variant,
+ final List, VariantType, TreeSet>> metadata,
+ final int index) {
+ data.set(index, variant,
+ metadata.stream().map(Triple::getLeft).collect(Collectors.toList()),
+ metadata.stream().map(Triple::getMiddle).collect(Collectors.toList()),
+ metadata.stream().map(Triple::getRight).collect(Collectors.toList()));
+ }
+
+ private void writeUnlabeledAnnotationsToHDF5() {
+ final File outputUnlabeledAnnotationsFile = new File(outputPrefix + UNLABELED_TAG + ANNOTATIONS_HDF5_SUFFIX);
+ if (unlabeledDataReservoir.size() == 0) {
+ throw new GATKException("No unlabeled variants were present in the input VCF.");
+ }
+ for (final VariantType variantType : variantTypesToExtract) {
+ logger.info(String.format("Extracted unlabeled annotations for %d variants of type %s.",
+ unlabeledDataReservoir.getVariantTypeFlat().stream().mapToInt(t -> t == variantType ? 1 : 0).sum(), variantType));
+ }
+ logger.info(String.format("Extracted unlabeled annotations for %s total variants.", unlabeledDataReservoir.size()));
+
+ logger.info("Writing unlabeled annotations...");
+ // TODO coordinate sort
+ unlabeledDataReservoir.writeHDF5(outputUnlabeledAnnotationsFile, omitAllelesInHDF5);
+ logger.info(String.format("Unlabeled annotations and metadata written to %s.", outputUnlabeledAnnotationsFile.getAbsolutePath()));
+ }
+}
\ No newline at end of file
diff --git a/src/main/java/org/broadinstitute/hellbender/tools/walkers/vqsr/scalable/LabeledVariantAnnotationsWalker.java b/src/main/java/org/broadinstitute/hellbender/tools/walkers/vqsr/scalable/LabeledVariantAnnotationsWalker.java
new file mode 100644
index 00000000000..128b3bcf1df
--- /dev/null
+++ b/src/main/java/org/broadinstitute/hellbender/tools/walkers/vqsr/scalable/LabeledVariantAnnotationsWalker.java
@@ -0,0 +1,382 @@
+package org.broadinstitute.hellbender.tools.walkers.vqsr.scalable;
+
+import htsjdk.samtools.SAMSequenceDictionary;
+import htsjdk.variant.variantcontext.Allele;
+import htsjdk.variant.variantcontext.VariantContext;
+import htsjdk.variant.variantcontext.VariantContextBuilder;
+import htsjdk.variant.variantcontext.writer.VariantContextWriter;
+import htsjdk.variant.vcf.VCFConstants;
+import htsjdk.variant.vcf.VCFHeader;
+import htsjdk.variant.vcf.VCFHeaderLine;
+import htsjdk.variant.vcf.VCFHeaderLineType;
+import htsjdk.variant.vcf.VCFInfoHeaderLine;
+import org.apache.commons.collections4.ListUtils;
+import org.apache.commons.lang3.tuple.Triple;
+import org.broadinstitute.barclay.argparser.Argument;
+import org.broadinstitute.barclay.argparser.CommandLineProgramProperties;
+import org.broadinstitute.barclay.help.DocumentedFeature;
+import org.broadinstitute.hellbender.cmdline.StandardArgumentDefinitions;
+import org.broadinstitute.hellbender.engine.FeatureContext;
+import org.broadinstitute.hellbender.engine.FeatureInput;
+import org.broadinstitute.hellbender.engine.MultiplePassVariantWalker;
+import org.broadinstitute.hellbender.exceptions.UserException;
+import org.broadinstitute.hellbender.tools.copynumber.arguments.CopyNumberArgumentValidationUtils;
+import org.broadinstitute.hellbender.tools.walkers.vqsr.scalable.data.LabeledVariantAnnotationsData;
+import org.broadinstitute.hellbender.tools.walkers.vqsr.scalable.data.VariantType;
+import org.broadinstitute.hellbender.tools.walkers.vqsr.scalable.modeling.VariantAnnotationsScorer;
+import org.broadinstitute.hellbender.utils.variant.GATKVCFConstants;
+import org.broadinstitute.hellbender.utils.variant.GATKVCFHeaderLines;
+import org.broadinstitute.hellbender.utils.variant.GATKVariantContextUtils;
+import org.broadinstitute.hellbender.utils.variant.VcfUtils;
+import picard.cmdline.programgroups.VariantFilteringProgramGroup;
+
+import java.io.File;
+import java.util.ArrayList;
+import java.util.Arrays;
+import java.util.Collections;
+import java.util.EnumSet;
+import java.util.List;
+import java.util.Map;
+import java.util.Set;
+import java.util.TreeSet;
+import java.util.stream.Collectors;
+
+/**
+ * Base walker for both {@link ExtractVariantAnnotations} and {@link ScoreVariantAnnotations},
+ * which enforces identical variant-extraction behavior in both tools via {@link #extractVariantMetadata}.
+ *
+ * This base implementation covers functionality for {@link ExtractVariantAnnotations}. Namely, it is a single-pass
+ * walker, performing the operations:
+ *
+ * - nthPassApply(n = 0)
+ * - if variant/alleles pass filters and variant-type/overlapping-resource checks, then:
+ * - add variant/alleles to a {@link LabeledVariantAnnotationsData} collection
+ * - write variant/alleles with labels appended to a sites-only VCF file
+ * - afterNthPass(n = 0)
+ * - write the resulting {@link LabeledVariantAnnotationsData} collection to an HDF5 file
+ *
+ * This results in the following output:
+ *
+ * - an HDF5 file, with the directory structure documented in {@link LabeledVariantAnnotationsData#writeHDF5};
+ * note that the matrix of annotations contains a single row per datum (i.e., per allele, in allele-specific mode,
+ * and per variant otherwise)
+ * - a sites-only VCF file, containing a single line per extracted variant, with labels appended
+ *
+ * In contrast, the {@link ScoreVariantAnnotations} implementation overrides methods to yield a two-pass walker,
+ * performing the operations:
+ *
+ * - nthPassApply(n = 0)
+ * - if variant/alleles pass filters and variant-type checks, then:
+ * - add variant/alleles to a {@link LabeledVariantAnnotationsData} collection
+ * - afterNthPass(n = 0)
+ * - write the resulting {@link LabeledVariantAnnotationsData} collection to an HDF5 file
+ * - pass this annotations HDF5 file to a {@link VariantAnnotationsScorer}, which generates and writes scores to an HDF5 file
+ * - read the scores back in and load them into an iterator
+ * - nthPassApply(n = 1)
+ * - if variant/alleles pass filters and variant-type checks (which are identical to the first pass), then:
+ * - draw the corresponding score (or scores, in allele-specific mode) from the iterator
+ * - write the variant (with all alleles, not just those extracted) with the score
+ * (or best score, in allele-specific mode) appended to a VCF file
+ * - else:
+ * - write an unprocessed copy of the variant to a VCF file
+ *
+ * This results in the following output:
+ *
+ * - an HDF5 file, as above
+ * - a VCF file, containing the input variants, with labels and scores appended for those passing variant-type checks TODO + calibration-sensitivity scores + filters applied?
+ */
+@CommandLineProgramProperties(
+ // TODO
+ summary = "",
+ oneLineSummary = "",
+ programGroup = VariantFilteringProgramGroup.class
+)
+@DocumentedFeature
+public abstract class LabeledVariantAnnotationsWalker extends MultiplePassVariantWalker {
+
+ public static final String MODE_LONG_NAME = "mode";
+ public static final String USE_ALLELE_SPECIFIC_ANNOTATIONS_LONG_NAME = "use-allele-specific-annotations";
+ public static final String IGNORE_FILTER_LONG_NAME = "ignore-filter";
+ public static final String IGNORE_ALL_FILTERS_LONG_NAME = "ignore-all-filters";
+ public static final String DO_NOT_TRUST_ALL_POLYMORPHIC_LONG_NAME = "do-not-trust-all-polymorphic";
+ public static final String OMIT_ALLELES_IN_HDF5_LONG_NAME = "omit-alleles-in-hdf5";
+ public static final String DO_NOT_GZIP_VCF_OUTPUT_LONG_NAME = "do-not-gzip-vcf-output";
+
+ public static final String ANNOTATIONS_HDF5_SUFFIX = ".annot.hdf5";
+
+ public static final String RESOURCE_LABEL_INFO_HEADER_LINE_FORMAT_STRING = "This site was labeled as %s according to resources";
+
+ @Argument(
+ fullName = StandardArgumentDefinitions.OUTPUT_LONG_NAME,
+ shortName = StandardArgumentDefinitions.OUTPUT_SHORT_NAME,
+ doc = "Prefix for output filenames.")
+ String outputPrefix;
+
+ @Argument(
+ fullName = StandardArgumentDefinitions.RESOURCE_LONG_NAME,
+ doc = "Resource VCFs used to label extracted variants.",
+ optional = true)
+ private List> resources = new ArrayList<>(10);
+
+ @Argument(
+ fullName = StandardArgumentDefinitions.ANNOTATION_LONG_NAME,
+ shortName = StandardArgumentDefinitions.ANNOTATION_SHORT_NAME,
+ doc = "Names of the annotations to extract. Note that a requested annotation may in fact not be present " +
+ "at any extraction site; NaN missing values will be generated for such annotations.",
+ minElements = 1)
+ List annotationNames = new ArrayList<>();
+
+ @Argument(
+ fullName = MODE_LONG_NAME,
+ doc = "Variant types to extract.",
+ minElements = 1)
+ private List variantTypesToExtractList = new ArrayList<>(Arrays.asList(VariantType.SNP, VariantType.INDEL));
+
+ @Argument(
+ fullName = USE_ALLELE_SPECIFIC_ANNOTATIONS_LONG_NAME,
+ doc = "If true, use the allele-specific versions of the specified annotations.",
+ optional = true)
+ boolean useASAnnotations = false;
+
+ @Argument(
+ fullName = IGNORE_FILTER_LONG_NAME,
+ doc = "Ignore the specified filter(s) in the input VCF.",
+ optional = true)
+ private List ignoreInputFilters = new ArrayList<>();
+
+ @Argument(
+ fullName = IGNORE_ALL_FILTERS_LONG_NAME,
+ doc = "If true, ignore all filters in the input VCF.",
+ optional = true)
+ private boolean ignoreAllFilters = false;
+
+ // TODO this is a perhaps vestigial argument inherited from VQSR; its impact and necessity could be reevaluated
+ @Argument(
+ fullName = DO_NOT_TRUST_ALL_POLYMORPHIC_LONG_NAME,
+ doc = "If true, do not trust that unfiltered records in the resources contain only polymorphic sites. " +
+ "This may increase runtime.",
+ optional = true)
+ private boolean doNotTrustAllPolymorphic = false;
+
+ @Argument(
+ fullName = OMIT_ALLELES_IN_HDF5_LONG_NAME,
+ doc = "If true, omit alleles in output HDF5 files in order to decrease file sizes.",
+ optional = true
+ )
+ boolean omitAllelesInHDF5 = false;
+
+ @Argument(
+ fullName = DO_NOT_GZIP_VCF_OUTPUT_LONG_NAME,
+ doc = "If true, VCF output will not be compressed.",
+ optional = true
+ )
+ boolean doNotGZIPVCFOutput = false;
+
+ private final Set ignoreInputFilterSet = new TreeSet<>();
+ Set variantTypesToExtract;
+ TreeSet resourceLabels = new TreeSet<>();
+
+ File outputAnnotationsFile;
+ VariantContextWriter vcfWriter;
+
+ LabeledVariantAnnotationsData data;
+
+ @Override
+ public void onTraversalStart() {
+
+ ignoreInputFilterSet.addAll(ignoreInputFilters);
+
+ variantTypesToExtract = EnumSet.copyOf(variantTypesToExtractList);
+
+ outputAnnotationsFile = new File(outputPrefix + ANNOTATIONS_HDF5_SUFFIX);
+ final String vcfSuffix = doNotGZIPVCFOutput ? ".vcf" : ".vcf.gz";
+ final File outputVCFFile = new File(outputPrefix + vcfSuffix);
+
+ // TODO this validation method should perhaps be moved outside of the CNV code
+ CopyNumberArgumentValidationUtils.validateOutputFiles(outputAnnotationsFile, outputVCFFile);
+
+ for (final FeatureInput resource : resources) {
+ final TreeSet trackResourceLabels = resource.getTagAttributes().entrySet().stream()
+ .filter(e -> e.getValue().equals("true"))
+ .map(Map.Entry::getKey)
+ .sorted()
+ .collect(Collectors.toCollection(TreeSet::new));
+ resourceLabels.addAll(trackResourceLabels);
+ logger.info( String.format("Found %s track: labels = %s", resource.getName(), trackResourceLabels));
+ }
+ resourceLabels.forEach(String::intern);
+
+ if (resourceLabels.contains(LabeledVariantAnnotationsData.SNP_LABEL)) {
+ throw new UserException.BadInput(String.format("The resource label \"%s\" is reserved for labeling variant types.",
+ LabeledVariantAnnotationsData.SNP_LABEL));
+ }
+
+ data = new LabeledVariantAnnotationsData(annotationNames, resourceLabels, useASAnnotations);
+
+ vcfWriter = createVCFWriter(outputVCFFile);
+ vcfWriter.writeHeader(constructVCFHeader(data.getSortedLabels()));
+
+ afterOnTraversalStart(); // perform additional validation, set modes in child tools, etc.
+ }
+
+ public void afterOnTraversalStart() {
+ // override
+ }
+
+ @Override
+ protected int numberOfPasses() {
+ return 1;
+ }
+
+ @Override
+ public Object onTraversalSuccess() {
+ return null;
+ }
+
+ // TODO maybe clean up all this Triple and metadata business with a class?
+ static void addExtractedVariantToData(final LabeledVariantAnnotationsData data,
+ final VariantContext variant,
+ final List, VariantType, TreeSet>> metadata) {
+ data.add(variant,
+ metadata.stream().map(Triple::getLeft).collect(Collectors.toList()),
+ metadata.stream().map(Triple::getMiddle).collect(Collectors.toList()),
+ metadata.stream().map(Triple::getRight).collect(Collectors.toList()));
+ }
+
+ void writeExtractedVariantToVCF(final VariantContext variant,
+ final List, VariantType, TreeSet>> metadata) {
+ writeExtractedVariantToVCF(variant,
+ metadata.stream().map(Triple::getLeft).flatMap(List::stream).collect(Collectors.toList()),
+ metadata.stream().map(Triple::getRight).flatMap(Set::stream).collect(Collectors.toSet()));
+ }
+
+ void writeAnnotationsToHDF5() {
+ if (data.size() == 0) {
+ logger.warn("Found no input variants for extraction. This may be because the specified " +
+ "genomic region contains no input variants of the requested type(s) or, if extracting " +
+ "training labels, because none of the input variants were contained in the resource VCFs " +
+ "or no resource VCFs were provided. The annotations HDF5 file will not be generated.");
+ return;
+ }
+ for (final VariantType variantType : variantTypesToExtract) {
+ logger.info(String.format("Extracted annotations for %d variants of type %s.",
+ data.getVariantTypeFlat().stream().mapToInt(t -> t == variantType ? 1 : 0).sum(), variantType));
+ }
+ for (final String label : data.getSortedLabels()) {
+ logger.info(String.format("Extracted annotations for %d variants labeled as %s.",
+ data.isLabelFlat(label).stream().mapToInt(b -> b ? 1 : 0).sum(), label));
+ }
+ logger.info(String.format("Extracted annotations for %s total variants.", data.size()));
+
+ logger.info("Writing annotations...");
+ data.writeHDF5(outputAnnotationsFile, omitAllelesInHDF5);
+ logger.info(String.format("Annotations and metadata written to %s.", outputAnnotationsFile.getAbsolutePath()));
+ }
+
+ /**
+ * Writes a sites-only VCF containing the extracted variants and corresponding labels.
+ */
+ void writeExtractedVariantToVCF(final VariantContext vc,
+ final List altAlleles,
+ final Set labels) {
+ final List alleles = ListUtils.union(Collections.singletonList(vc.getReference()), altAlleles);
+ final VariantContextBuilder builder = new VariantContextBuilder(
+ vc.getSource(), vc.getContig(), vc.getStart(), vc.getEnd(), alleles);
+ labels.forEach(l -> builder.attribute(l, true)); // labels should already be sorted as a TreeSet
+ vcfWriter.add(builder.make());
+ }
+
+ // modified from VQSR code
+ // TODO we're just writing a standard sites-only VCF here, maybe there's a nicer way to do this?
+ VCFHeader constructVCFHeader(final List sortedLabels) {
+ Set hInfo = getDefaultToolVCFHeaderLines();
+ hInfo.addAll(sortedLabels.stream()
+ .map(l -> new VCFInfoHeaderLine(l, 1, VCFHeaderLineType.Flag, String.format(RESOURCE_LABEL_INFO_HEADER_LINE_FORMAT_STRING, l)))
+ .collect(Collectors.toList()));
+ hInfo.add(GATKVCFHeaderLines.getFilterLine(VCFConstants.PASSES_FILTERS_v4));
+ final SAMSequenceDictionary sequenceDictionary = getBestAvailableSequenceDictionary();
+ hInfo = VcfUtils.updateHeaderContigLines(hInfo, null, sequenceDictionary, true);
+ return new VCFHeader(hInfo);
+ }
+
+ /**
+ * Performs variant-filter and variant-type checks to determine variants/alleles suitable for extraction, and returns
+ * a corresponding list of metadata. This method should not be overridden, as it is intended to enforce identical
+ * variant-extraction behavior in all child tools. Logic here and below for filtering and determining variant type
+ * was retained from VQSR, but has been heavily refactored.
+ */
+ final List, VariantType, TreeSet>> extractVariantMetadata(final VariantContext vc,
+ final FeatureContext featureContext,
+ final boolean isExtractUnlabeled) {
+ // if variant is filtered, do not consume here
+ if (vc == null || !(ignoreAllFilters || vc.isNotFiltered() || ignoreInputFilterSet.containsAll(vc.getFilters()))) {
+ return Collections.emptyList();
+ }
+ if (!useASAnnotations) {
+ // in non-allele-specific mode, get a singleton list of the triple
+ // (list of alt alleles passing variant-type and overlapping-resource checks, variant type, set of labels)
+ final VariantType variantType = VariantType.getVariantType(vc);
+ if (variantTypesToExtract.contains(variantType)) {
+ final TreeSet overlappingResourceLabels = findOverlappingResourceLabels(vc, null, null, featureContext);
+ if (isExtractUnlabeled || !overlappingResourceLabels.isEmpty()) {
+ return Collections.singletonList(Triple.of(vc.getAlternateAlleles(), variantType, overlappingResourceLabels));
+ }
+ }
+ } else {
+ // in allele-specific mode, get a list containing the triples
+ // (singleton list of alt allele, variant type, set of labels)
+ // corresponding to alt alleles that pass variant-type and overlapping-resource checks
+ return vc.getAlternateAlleles().stream()
+ .filter(a -> !GATKVCFConstants.isSpanningDeletion(a))
+ .filter(a -> variantTypesToExtract.contains(VariantType.getVariantType(vc, a)))
+ .map(a -> Triple.of(Collections.singletonList(a), VariantType.getVariantType(vc, a),
+ findOverlappingResourceLabels(vc, vc.getReference(), a, featureContext)))
+ .filter(t -> isExtractUnlabeled || !t.getRight().isEmpty())
+ .collect(Collectors.toList());
+ }
+ // if variant-type and overlapping-resource checks failed, return an empty list
+ return Collections.emptyList();
+ }
+
+ private TreeSet findOverlappingResourceLabels(final VariantContext vc,
+ final Allele refAllele,
+ final Allele altAllele,
+ final FeatureContext featureContext) {
+ final TreeSet overlappingResourceLabels = new TreeSet<>();
+ for (final FeatureInput resource : resources) {
+ final List resourceVCs = featureContext.getValues(resource, featureContext.getInterval().getStart());
+ for (final VariantContext resourceVC : resourceVCs) {
+ if (useASAnnotations && !doAllelesMatch(refAllele, altAllele, resourceVC)) {
+ continue;
+ }
+ if (isValidVariant(vc, resourceVC, !doNotTrustAllPolymorphic)) {
+ resource.getTagAttributes().entrySet().stream()
+ .filter(e -> e.getValue().equals("true"))
+ .map(Map.Entry::getKey)
+ .forEach(overlappingResourceLabels::add);
+ }
+ }
+ }
+ return overlappingResourceLabels;
+ }
+
+ private static boolean isValidVariant(final VariantContext vc,
+ final VariantContext resourceVC,
+ final boolean trustAllPolymorphic) {
+ return resourceVC != null && resourceVC.isNotFiltered() && resourceVC.isVariant() && VariantType.checkVariantType(vc, resourceVC) &&
+ (trustAllPolymorphic || !resourceVC.hasGenotypes() || resourceVC.isPolymorphicInSamples());
+ }
+
+ private static boolean doAllelesMatch(final Allele refAllele,
+ final Allele altAllele,
+ final VariantContext resourceVC) {
+ if (altAllele == null) {
+ return true;
+ }
+ try {
+ return GATKVariantContextUtils.isAlleleInList(refAllele, altAllele, resourceVC.getReference(), resourceVC.getAlternateAlleles());
+ } catch (final IllegalStateException e) {
+ throw new IllegalStateException("Reference allele mismatch at position " + resourceVC.getContig() + ':' + resourceVC.getStart() + " : ", e);
+ }
+ }
+}
\ No newline at end of file
diff --git a/src/main/java/org/broadinstitute/hellbender/tools/walkers/vqsr/scalable/ScoreVariantAnnotations.java b/src/main/java/org/broadinstitute/hellbender/tools/walkers/vqsr/scalable/ScoreVariantAnnotations.java
new file mode 100644
index 00000000000..33fefe62ad1
--- /dev/null
+++ b/src/main/java/org/broadinstitute/hellbender/tools/walkers/vqsr/scalable/ScoreVariantAnnotations.java
@@ -0,0 +1,624 @@
+package org.broadinstitute.hellbender.tools.walkers.vqsr.scalable;
+
+import com.google.common.primitives.Doubles;
+import htsjdk.variant.variantcontext.Allele;
+import htsjdk.variant.variantcontext.VariantContext;
+import htsjdk.variant.variantcontext.VariantContextBuilder;
+import htsjdk.variant.vcf.VCFFilterHeaderLine;
+import htsjdk.variant.vcf.VCFHeader;
+import htsjdk.variant.vcf.VCFHeaderLine;
+import htsjdk.variant.vcf.VCFHeaderLineType;
+import htsjdk.variant.vcf.VCFInfoHeaderLine;
+import org.apache.commons.lang3.tuple.Triple;
+import org.broadinstitute.barclay.argparser.Argument;
+import org.broadinstitute.barclay.argparser.BetaFeature;
+import org.broadinstitute.barclay.argparser.CommandLineProgramProperties;
+import org.broadinstitute.barclay.help.DocumentedFeature;
+import org.broadinstitute.hellbender.cmdline.StandardArgumentDefinitions;
+import org.broadinstitute.hellbender.engine.FeatureContext;
+import org.broadinstitute.hellbender.engine.ReadsContext;
+import org.broadinstitute.hellbender.engine.ReferenceContext;
+import org.broadinstitute.hellbender.exceptions.GATKException;
+import org.broadinstitute.hellbender.exceptions.UserException;
+import org.broadinstitute.hellbender.tools.copynumber.arguments.CopyNumberArgumentValidationUtils;
+import org.broadinstitute.hellbender.tools.copynumber.utils.HDF5Utils;
+import org.broadinstitute.hellbender.tools.walkers.vqsr.VariantRecalibrator;
+import org.broadinstitute.hellbender.tools.walkers.vqsr.scalable.data.LabeledVariantAnnotationsData;
+import org.broadinstitute.hellbender.tools.walkers.vqsr.scalable.data.VariantType;
+import org.broadinstitute.hellbender.tools.walkers.vqsr.scalable.modeling.BGMMVariantAnnotationsScorer;
+import org.broadinstitute.hellbender.tools.walkers.vqsr.scalable.modeling.PythonSklearnVariantAnnotationsScorer;
+import org.broadinstitute.hellbender.tools.walkers.vqsr.scalable.modeling.VariantAnnotationsModel;
+import org.broadinstitute.hellbender.tools.walkers.vqsr.scalable.modeling.VariantAnnotationsModelBackend;
+import org.broadinstitute.hellbender.tools.walkers.vqsr.scalable.modeling.VariantAnnotationsScorer;
+import org.broadinstitute.hellbender.utils.Utils;
+import org.broadinstitute.hellbender.utils.io.IOUtils;
+import org.broadinstitute.hellbender.utils.io.Resource;
+import org.broadinstitute.hellbender.utils.python.PythonScriptExecutor;
+import picard.cmdline.programgroups.VariantFilteringProgramGroup;
+
+import java.io.File;
+import java.util.ArrayList;
+import java.util.Arrays;
+import java.util.Collections;
+import java.util.HashSet;
+import java.util.Iterator;
+import java.util.List;
+import java.util.Set;
+import java.util.TreeSet;
+import java.util.function.Function;
+import java.util.stream.Collectors;
+import java.util.stream.IntStream;
+
+/**
+ * Scores variant calls in a VCF file based on site-level annotations using a previously trained model.
+ *
+ *
+ * This tool is intended to be used as the last step in a variant-filtering workflow that supersedes the
+ * {@link VariantRecalibrator} workflow. Using a previously trained model produced by {@link TrainVariantAnnotationsModel},
+ * this tool assigns a score to each call (with a lower score indicating that a call is more likely to be an artifact).
+ * Each score can also be converted to a corresponding sensitivity to a calibration set, if the latter is available.
+ * Each VCF record can also be annotated with additional resource labels and/or hard filtered based on its
+ * calibration-set sensitivity, if desired.
+ *
+ *
+ *
+ * Note that annotations and metadata are collected in memory during traversal until they are written to HDF5 files
+ * upon completion of the traversal. Memory requirements thus roughly scale linearly with both the number of sites
+ * scored and the number of annotations. For large callsets, this tool may be run in parallel over separate
+ * genomic shards using the {@value StandardArgumentDefinitions#INTERVALS_LONG_NAME} argument as usual.
+ *
+ *
+ *
+ * Scores and annotations are also output to HDF5 files, which may be viewed using
+ * hdfview or loaded in Python using
+ * PyTables or h5py.
+ *
+ *
+ *
Inputs
+ *
+ *
+ *
+ * Input VCF file. Site-level annotations will be extracted from the contained variants (or alleles,
+ * if the {@value USE_ALLELE_SPECIFIC_ANNOTATIONS_LONG_NAME} argument is specified).
+ *
+ *
+ * Annotations to use for scoring. These should be identical to those used in the {@link ExtractVariantAnnotations}
+ * step to create the training set.
+ *
+ *
+ * Variant types (i.e., SNP and/or INDEL) to score. Logic for determining variant type was retained from
+ * {@link VariantRecalibrator}; see {@link VariantType}. To use different models for SNPs and INDELs
+ * (e.g., if it is desired to use different sets of annotations for each variant type), one can first run
+ * this tool to score SNPs and then again on the resulting output to score INDELs.
+ *
+ *
+ * Model prefix. This should denote the path of model files produced by {@link TrainVariantAnnotationsModel}.
+ *
+ *
+ * (Optional) Model backend. This should be identical to that specified in {@link TrainVariantAnnotationsModel}.
+ * The default Python IsolationForest implementation requires either the GATK Python environment
+ * or that certain Python packages (argparse, h5py, numpy, sklearn, and dill) are otherwise available.
+ * A custom backend can also be specified in conjunction with the {@value PYTHON_SCRIPT_LONG_NAME} argument.
+ *
+ *
+ * (Optional) Resource VCF file(s). See the corresponding documentation in {@link ExtractVariantAnnotations}.
+ * In typical usage, the same resource VCFs and tags provided to that tool should also be provided here.
+ * In addition, the sites-only VCF that is produced by that tool can also be provided here and used to
+ * mark those labeled sites that were extracted, which can be useful if these are a subset of the resource sites.
+ *
+ *
+ * (Optional) Calibration-set sensitivity thresholds for SNPs and INDELs. If the corresponding SNP or INDEL
+ * calibration-set scores are available in the provided model files, sites that have a calibration-set
+ * sensitivity falling above the corresponding threshold (i.e., a score falling below the corresponding
+ * score threshold) will have a filter applied.
+ *
+ *
+ * Output prefix.
+ * This is used as the basename for output files.
+ *
+ *
+ *
+ *
Outputs
+ *
+ *
+ *
+ * Scored VCF file and index. The VCF will not be gzipped if the {@value DO_NOT_GZIP_VCF_OUTPUT_LONG_NAME}
+ * argument is set to true. The INFO field in each VCF record will be annotated with:
+ *
+ *
+ * 1) a score (with a key as given by the {@value SCORE_KEY_LONG_NAME} argument,
+ * which has a default value of {@value DEFAULT_SCORE_KEY}),
+ *
+ *
+ * 2) if resources are provided, flags corresponding to the labels (e.g.,
+ * {@value LabeledVariantAnnotationsData#TRAINING_LABEL}, {@value LabeledVariantAnnotationsData#CALIBRATION_LABEL}, etc.)
+ * of resources containing the record,
+ *
+ *
+ * 3) if the {@value SNP_KEY_LONG_NAME} argument (which has a default value of {@value DEFAULT_SNP_KEY})
+ * is non-null, a flag corresponding to whether a site is treated as a SNP,
+ *
+ *
+ * 4) if {@value SNP_CALIBRATION_SENSITIVITY_THRESHOLD_LONG_NAME} and/or
+ * {@value INDEL_CALIBRATION_SENSITIVITY_THRESHOLD_LONG_NAME} are provided, a filter (with name given by
+ * the {@value LOW_SCORE_FILTER_NAME_LONG_NAME} argument, which has a default value of
+ * {@value DEFAULT_LOW_SCORE_FILTER_NAME}) will be applied if a record has a calibration-set sensitivity
+ * falling above the appropriate threshold (i.e., if it has a score falling below the corresponding
+ * score threshold).
+ *
+ *
+ * If {@value USE_ALLELE_SPECIFIC_ANNOTATIONS_LONG_NAME} is true, the score, SNP flag, calibration sensitivity,
+ * and filter appropriate for the highest scoring allele are used; however, the resource labels for all alleles
+ * are applied.
+ *
+ *
+ *
+ *
+ * (Optional) Annotations HDF5 file (.annot.hdf5). Annotation data and metadata for all scored sites
+ * (labeled and unlabeled) are stored in the HDF5 directory structure given in the documentation for the
+ * {@link ExtractVariantAnnotations} tool. This file will only be produced if the number of scored sites
+ * is nonzero.
+ *
+ *
+ *
+ *
+ * (Optional) Scores HDF5 file (.scores.hdf5). Scores for all scored sites are stored in the
+ * HDF5 path {@value VariantAnnotationsScorer#SCORES_PATH}. Scores are given in the same order as records
+ * in both the VCF and the annotations HDF5 file. This file will only be produced if the number of scored sites
+ * is nonzero.
+ *
+ *
+ *
+ *
+ *
Usage examples
+ *
+ *
+ * Score sites using a model (produced by {@link TrainVariantAnnotationsModel} using the default
+ * {@link VariantAnnotationsModelBackend#PYTHON_IFOREST} model backend and contained in the directory
+ * {@code model_dir}), producing the outputs 1) {@code output.vcf.gz}, 2) {@code output.vcf.gz.tbi},
+ * 3) {@code output.annot.hdf5}, and 4) {@code output.scores.hdf5}. Note that {@code extract.vcf.gz} is
+ * produced by {@link ExtractVariantAnnotations}. Records will be filtered according to the values provided to the
+ * {@value SNP_CALIBRATION_SENSITIVITY_THRESHOLD_LONG_NAME} and {@value INDEL_CALIBRATION_SENSITIVITY_THRESHOLD_LONG_NAME}
+ * arguments; the values below are only meant to be illustrative and should be set as appropriate for a given analysis.
+ *
+ *
+ * One may chain together two runs of this tool to score SNPs and INDELs using different models
+ * (note that SNP and INDEL models have "snp" and "indel" tags in their respective filenames, so these
+ * models can still be contained in the same {@code model_dir} directory).
+ * This may have implications for mixed SNP/INDEL sites, especially if filters are applied; see also the
+ * {@value IGNORE_ALL_FILTERS_LONG_NAME} and {@value IGNORE_FILTER_LONG_NAME} arguments.
+ *
+ *
+ * The primary scoring functionality performed by this tool is accomplished by a "scoring backend"
+ * whose fundamental contract is to take an input annotation matrix and to output corresponding scores,
+ * with both input and output given as HDF5 files. Rather than using one of the available, implemented backends,
+ * advanced users may provide their own backend via the {@value PYTHON_SCRIPT_LONG_NAME} argument.
+ * See documentation in the modeling and scoring interfaces ({@link VariantAnnotationsModel} and
+ * {@link VariantAnnotationsScorer}, respectively), as well as the default Python IsolationForest implementation at
+ * org/broadinstitute/hellbender/tools/walkers/vqsr/scalable/isolation-forest.py.
+ *
+ *
+ * DEVELOPER NOTE: See documentation in {@link LabeledVariantAnnotationsWalker}.
+ *
+ * @author Samuel Lee <slee@broadinstitute.org>
+ */
+@CommandLineProgramProperties(
+ summary = "Scores variant calls in a VCF file based on site-level annotations using a previously trained model.",
+ oneLineSummary = "Scores variant calls in a VCF file based on site-level annotations using a previously trained model",
+ programGroup = VariantFilteringProgramGroup.class
+)
+@DocumentedFeature
+@BetaFeature
+public class ScoreVariantAnnotations extends LabeledVariantAnnotationsWalker {
+
+ public static final String MODEL_PREFIX_LONG_NAME = "model-prefix";
+ public static final String MODEL_BACKEND_LONG_NAME = TrainVariantAnnotationsModel.MODEL_BACKEND_LONG_NAME;
+ public static final String PYTHON_SCRIPT_LONG_NAME = "python-script";
+ public static final String SNP_CALIBRATION_SENSITIVITY_THRESHOLD_LONG_NAME = "snp-calibration-sensitivity-threshold";
+ public static final String INDEL_CALIBRATION_SENSITIVITY_THRESHOLD_LONG_NAME = "indel-calibration-sensitivity-threshold";
+
+ public static final String SNP_KEY_LONG_NAME = "snp-key";
+ public static final String SCORE_KEY_LONG_NAME = "score-key";
+ public static final String CALIBRATION_SENSITIVITY_KEY_LONG_NAME = "calibration-sensitivity-key";
+ public static final String LOW_SCORE_FILTER_NAME_LONG_NAME = "low-score-filter-name";
+ public static final String DOUBLE_FORMAT_LONG_NAME = "double-format";
+
+ public static final String DEFAULT_SNP_KEY = LabeledVariantAnnotationsData.SNP_LABEL;
+ public static final String DEFAULT_SCORE_KEY = "SCORE";
+ public static final String DEFAULT_CALIBRATION_SENSITIVITY_KEY = "CALIBRATION_SENSITIVITY";
+ public static final String DEFAULT_LOW_SCORE_FILTER_NAME = "LOW_SCORE";
+ public static final String DEFAULT_DOUBLE_FORMAT = "%.4f";
+
+ public static final String SCORES_HDF5_SUFFIX = ".scores.hdf5";
+
+ @Argument(
+ fullName = MODEL_PREFIX_LONG_NAME)
+ private String modelPrefix;
+
+ @Argument(
+ fullName = MODEL_BACKEND_LONG_NAME,
+ doc = "Backend to use for scoring. " +
+ "JAVA_BGMM will use a pure Java implementation (ported from Python scikit-learn) of the Bayesian Gaussian Mixture Model. " +
+ "PYTHON_IFOREST will use the Python scikit-learn implementation of the IsolationForest method and " +
+ "will require that the corresponding Python dependencies are present in the environment. " +
+ "PYTHON_SCRIPT will use the script specified by the " + PYTHON_SCRIPT_LONG_NAME + " argument. " +
+ "See the tool documentation for more details." )
+ private VariantAnnotationsModelBackend modelBackend = VariantAnnotationsModelBackend.PYTHON_IFOREST;
+
+ @Argument(
+ fullName = PYTHON_SCRIPT_LONG_NAME,
+ doc = "Python script used for specifying a custom scoring backend. If provided, " + MODEL_BACKEND_LONG_NAME + " must also be set to PYTHON_SCRIPT.",
+ optional = true)
+ private File pythonScriptFile;
+
+ @Argument(
+ fullName = SNP_CALIBRATION_SENSITIVITY_THRESHOLD_LONG_NAME,
+ doc = "If specified, SNPs with scores corresponding to a calibration sensitivity that is greater than or equal to this threshold will be hard filtered.",
+ optional = true,
+ minValue = 0.,
+ maxValue = 1.)
+ private Double snpCalibrationSensitivityThreshold;
+
+ @Argument(
+ fullName = INDEL_CALIBRATION_SENSITIVITY_THRESHOLD_LONG_NAME,
+ doc = "If specified, indels with scores corresponding to a calibration sensitivity that is greater than or equal to this threshold will be hard filtered.",
+ optional = true,
+ minValue = 0.,
+ maxValue = 1.)
+ private Double indelCalibrationSensitivityThreshold;
+
+ @Argument(
+ fullName = SNP_KEY_LONG_NAME,
+ doc = "Annotation flag to use for labeling sites as SNPs in output. " +
+ "Set this to \"null\" to omit these labels.")
+ private String snpKey = DEFAULT_SNP_KEY;
+
+ @Argument(
+ fullName = SCORE_KEY_LONG_NAME,
+ doc = "Annotation key to use for score values in output.")
+ private String scoreKey = DEFAULT_SCORE_KEY;
+
+ @Argument(
+ fullName = CALIBRATION_SENSITIVITY_KEY_LONG_NAME,
+ doc = "Annotation key to use for calibration-sensitivity values in output.")
+ private String calibrationSensitivityKey = DEFAULT_CALIBRATION_SENSITIVITY_KEY;
+
+ @Argument(
+ fullName = LOW_SCORE_FILTER_NAME_LONG_NAME,
+ doc = "Name to use for low-score filter in output.")
+ private String lowScoreFilterName = DEFAULT_LOW_SCORE_FILTER_NAME;
+
+ @Argument(
+ fullName = DOUBLE_FORMAT_LONG_NAME,
+ doc = "Format string to use for formatting score and calibration-sensitivity values in output.")
+ private String doubleFormat = DEFAULT_DOUBLE_FORMAT;
+
+ private File outputScoresFile;
+ private Iterator scoresIterator;
+ private Iterator isSNPIterator;
+
+ private VariantAnnotationsScorer snpScorer;
+ private VariantAnnotationsScorer indelScorer;
+
+ private Function snpCalibrationSensitivityConverter;
+ private Function indelCalibrationSensitivityConverter;
+
+ @Override
+ protected int numberOfPasses() {
+ return 2;
+ }
+
+ @Override
+ public void afterOnTraversalStart() {
+
+ Utils.nonNull(scoreKey);
+ Utils.nonNull(calibrationSensitivityKey);
+ Utils.nonNull(lowScoreFilterName);
+ Utils.nonNull(doubleFormat);
+
+ switch (modelBackend) {
+ case JAVA_BGMM:
+ Utils.validateArg(pythonScriptFile == null,
+ "Python script should not be provided when using JAVA_BGMM backend.");
+ logger.info("Running in JAVA_BGMM mode...");
+ snpScorer = deserializeScorerFromSerFiles(VariantType.SNP);
+ indelScorer = deserializeScorerFromSerFiles(VariantType.INDEL);
+ break;
+ case PYTHON_IFOREST:
+ Utils.validateArg(pythonScriptFile == null,
+ "Python script should not be provided when using PYTHON_IFOREST backend.");
+
+ pythonScriptFile = IOUtils.writeTempResource(new Resource(TrainVariantAnnotationsModel.ISOLATION_FOREST_PYTHON_SCRIPT, TrainVariantAnnotationsModel.class));
+ PythonScriptExecutor.checkPythonEnvironmentForPackage("argparse");
+ PythonScriptExecutor.checkPythonEnvironmentForPackage("h5py");
+ PythonScriptExecutor.checkPythonEnvironmentForPackage("numpy");
+ PythonScriptExecutor.checkPythonEnvironmentForPackage("sklearn");
+ PythonScriptExecutor.checkPythonEnvironmentForPackage("dill");
+ logger.info("Running in PYTHON_IFOREST mode...");
+ snpScorer = deserializeScorerFromPklFiles(VariantType.SNP);
+ indelScorer = deserializeScorerFromPklFiles(VariantType.INDEL);
+ break;
+ case PYTHON_SCRIPT:
+ IOUtils.canReadFile(pythonScriptFile);
+ logger.info("Running in PYTHON_SCRIPT mode...");
+ snpScorer = deserializeScorerFromPklFiles(VariantType.SNP);
+ indelScorer = deserializeScorerFromPklFiles(VariantType.INDEL);
+ break;
+ default:
+ throw new GATKException.ShouldNeverReachHereException("Unknown model-backend mode.");
+ }
+
+ if (snpScorer == null && indelScorer == null) {
+ throw new UserException.BadInput(String.format("At least one serialized scorer must be present " +
+ "in the model files with the prefix %s.", modelPrefix));
+ }
+ if (variantTypesToExtract.contains(VariantType.SNP) && snpScorer == null) {
+ throw new UserException.BadInput(String.format("SNPs were indicated for extraction via the %s argument, " +
+ "but no serialized SNP scorer was available in the model files with the prefix.", MODE_LONG_NAME, modelPrefix));
+ }
+ if (variantTypesToExtract.contains(VariantType.INDEL) && indelScorer == null) {
+ throw new UserException.BadInput(String.format("INDELs were indicated for extraction via the %s argument, " +
+ "but no serialized INDEL scorer was available in the model files with the prefix.", MODE_LONG_NAME, modelPrefix));
+ }
+
+ snpCalibrationSensitivityConverter = readCalibrationScoresAndCreateConverter(VariantType.SNP);
+ indelCalibrationSensitivityConverter = readCalibrationScoresAndCreateConverter(VariantType.INDEL);
+
+ if (snpCalibrationSensitivityConverter == null && snpCalibrationSensitivityThreshold != null) {
+ throw new UserException.BadInput(String.format("The %s argument was specified, " +
+ "but no SNP calibration scores were provided in the model files with the prefix %s.",
+ SNP_CALIBRATION_SENSITIVITY_THRESHOLD_LONG_NAME, modelPrefix));
+ }
+ if (indelCalibrationSensitivityConverter == null && indelCalibrationSensitivityThreshold != null) {
+ throw new UserException.BadInput(String.format("The %s argument was specified, " +
+ "but no INDEL calibration scores were provided in the model files with the prefix %s.",
+ INDEL_CALIBRATION_SENSITIVITY_THRESHOLD_LONG_NAME, modelPrefix));
+ }
+
+ outputScoresFile = new File(outputPrefix + SCORES_HDF5_SUFFIX);
+
+ // TODO this validation method should perhaps be moved outside of the CNV code
+ CopyNumberArgumentValidationUtils.validateOutputFiles(outputScoresFile);
+ }
+
+ @Override
+ protected void nthPassApply(final VariantContext variant,
+ final ReadsContext readsContext,
+ final ReferenceContext referenceContext,
+ final FeatureContext featureContext,
+ final int n) {
+ final List, VariantType, TreeSet>> metadata = extractVariantMetadata(variant, featureContext, true);
+ final boolean isVariantExtracted = !metadata.isEmpty();
+ if (n == 0 && isVariantExtracted) {
+ addExtractedVariantToData(data, variant, metadata);
+ }
+ if (n == 1) {
+ if (isVariantExtracted) {
+ writeExtractedVariantToVCF(variant, metadata);
+ } else {
+ vcfWriter.add(variant);
+ }
+ }
+ }
+
+ @Override
+ protected void afterNthPass(final int n) {
+ if (n == 0) {
+ // TODO if BGMM, preprocess annotations and write to HDF5 with BGMMVariantAnnotationsScorer.preprocessAnnotationsWithBGMMAndWriteHDF5
+ writeAnnotationsToHDF5();
+ if (data.size() > 0) {
+ data.clear();
+ readAnnotationsAndWriteScoresToHDF5();
+ scoresIterator = Arrays.stream(VariantAnnotationsScorer.readScores(outputScoresFile)).iterator();
+ isSNPIterator = LabeledVariantAnnotationsData.readLabel(outputAnnotationsFile, LabeledVariantAnnotationsData.SNP_LABEL).iterator();
+ } else {
+ scoresIterator = Collections.emptyIterator();
+ isSNPIterator = Collections.emptyIterator();
+ }
+ }
+ if (n == 1) {
+ if (scoresIterator.hasNext()) {
+ throw new IllegalStateException("Traversals of scores and variants " +
+ "(or alleles, in allele-specific mode) were not correctly synchronized.");
+ }
+ if (vcfWriter != null) {
+ vcfWriter.close();
+ }
+ }
+ }
+
+ private VariantAnnotationsScorer deserializeScorerFromPklFiles(final VariantType variantType) {
+ final String variantTypeTag = '.' + variantType.toString().toLowerCase();
+ final File scorerPklFile = new File(
+ modelPrefix + variantTypeTag + PythonSklearnVariantAnnotationsScorer.PYTHON_SCORER_PKL_SUFFIX);
+ final File negativeScorerPklFile = new File(
+ modelPrefix + variantTypeTag + TrainVariantAnnotationsModel.NEGATIVE_TAG + PythonSklearnVariantAnnotationsScorer.PYTHON_SCORER_PKL_SUFFIX);
+ return scorerPklFile.canRead()
+ ? negativeScorerPklFile.canRead()
+ ? VariantAnnotationsScorer.combinePositiveAndNegativeScorer(
+ new PythonSklearnVariantAnnotationsScorer(pythonScriptFile, scorerPklFile),
+ new PythonSklearnVariantAnnotationsScorer(pythonScriptFile, negativeScorerPklFile))
+ : new PythonSklearnVariantAnnotationsScorer(pythonScriptFile, scorerPklFile)
+ : null;
+ }
+
+ private VariantAnnotationsScorer deserializeScorerFromSerFiles(final VariantType variantType) {
+ final String variantTypeTag = '.' + variantType.toString().toLowerCase();
+ final File scorerSerFile = new File(
+ modelPrefix + variantTypeTag + BGMMVariantAnnotationsScorer.BGMM_SCORER_SER_SUFFIX);
+ final File negativeScorerSerFile = new File(
+ modelPrefix + variantTypeTag + TrainVariantAnnotationsModel.NEGATIVE_TAG + BGMMVariantAnnotationsScorer.BGMM_SCORER_SER_SUFFIX);
+ return scorerSerFile.canRead()
+ ? negativeScorerSerFile.canRead()
+ ? VariantAnnotationsScorer.combinePositiveAndNegativeScorer(
+ BGMMVariantAnnotationsScorer.deserialize(scorerSerFile),
+ BGMMVariantAnnotationsScorer.deserialize(negativeScorerSerFile))
+ : BGMMVariantAnnotationsScorer.deserialize(scorerSerFile)
+ : null;
+ }
+
+ private Function readCalibrationScoresAndCreateConverter(final VariantType variantType) {
+ final String variantTypeTag = '.' + variantType.toString().toLowerCase();
+ final File calibrationScores = new File(
+ modelPrefix + variantTypeTag + TrainVariantAnnotationsModel.CALIBRATION_SCORES_HDF5_SUFFIX);
+ return calibrationScores.canRead()
+ ? VariantAnnotationsScorer.createScoreToCalibrationSensitivityConverter(VariantAnnotationsScorer.readScores(calibrationScores))
+ : null;
+ }
+
+ private void readAnnotationsAndWriteScoresToHDF5() {
+ final List annotationNames = LabeledVariantAnnotationsData.readAnnotationNames(outputAnnotationsFile);
+ final List isSNP = LabeledVariantAnnotationsData.readLabel(outputAnnotationsFile, LabeledVariantAnnotationsData.SNP_LABEL);
+ final double[][] allAnnotations = LabeledVariantAnnotationsData.readAnnotations(outputAnnotationsFile);
+ final int numAll = allAnnotations.length;
+ final List allScores = new ArrayList<>(Collections.nCopies(numAll, Double.NaN));
+ if (variantTypesToExtract.contains(VariantType.SNP)) {
+ logger.info("Scoring SNP variants...");
+ scoreVariantTypeAndSetElementsOfAllScores(annotationNames, allAnnotations, isSNP, snpScorer, allScores);
+ }
+ if (variantTypesToExtract.contains(VariantType.INDEL)) {
+ logger.info("Scoring INDEL variants...");
+ final List isIndel = isSNP.stream().map(x -> !x).collect(Collectors.toList());
+ scoreVariantTypeAndSetElementsOfAllScores(annotationNames, allAnnotations, isIndel, indelScorer, allScores);
+ }
+ VariantAnnotationsScorer.writeScores(outputScoresFile, Doubles.toArray(allScores));
+ logger.info(String.format("Scores written to %s.", outputScoresFile.getAbsolutePath()));
+ }
+
+ private static void scoreVariantTypeAndSetElementsOfAllScores(final List annotationNames,
+ final double[][] allAnnotations,
+ final List isVariantType,
+ final VariantAnnotationsScorer variantTypeScorer,
+ final List allScores) {
+ final File variantTypeAnnotationsFile = LabeledVariantAnnotationsData.subsetAnnotationsToTemporaryFile(annotationNames, allAnnotations, isVariantType);
+ final File variantTypeScoresFile = IOUtils.createTempFile("temp", ".scores.hdf5");
+ variantTypeScorer.score(variantTypeAnnotationsFile, variantTypeScoresFile); // TODO we do not fail until here in the case of mismatched annotation names; we could fail earlier
+ final double[] variantTypeScores = VariantAnnotationsScorer.readScores(variantTypeScoresFile);
+ final Iterator variantTypeScoresIterator = Arrays.stream(variantTypeScores).iterator();
+ IntStream.range(0, allScores.size()).filter(isVariantType::get).forEach(i -> allScores.set(i, variantTypeScoresIterator.next()));
+ }
+
+ @Override
+ void writeExtractedVariantToVCF(final VariantContext vc,
+ final List altAlleles,
+ final Set labels) {
+ final VariantContextBuilder builder = new VariantContextBuilder(vc);
+ labels.forEach(l -> builder.attribute(l, true)); // labels should already be sorted as a TreeSet
+
+ final List scores = useASAnnotations
+ ? altAlleles.stream().map(a -> scoresIterator.next()).collect(Collectors.toList())
+ : Collections.singletonList(scoresIterator.next());
+ final double score = Collections.max(scores);
+ final int scoreIndex = scores.indexOf(score);
+ builder.attribute(scoreKey, formatDouble(score));
+
+ final List isSNP = useASAnnotations
+ ? altAlleles.stream().map(a -> isSNPIterator.next()).collect(Collectors.toList())
+ : Collections.singletonList(isSNPIterator.next());
+ final boolean isSNPMax = isSNP.get(scoreIndex);
+
+ if (snpKey != null) {
+ builder.attribute(snpKey, isSNPMax);
+ }
+
+ final Function calibrationSensitivityConverter = isSNPMax ? snpCalibrationSensitivityConverter : indelCalibrationSensitivityConverter;
+ if (calibrationSensitivityConverter != null) {
+ final double calibrationSensitivity = calibrationSensitivityConverter.apply(score);
+ builder.attribute(calibrationSensitivityKey, formatDouble(calibrationSensitivity));
+ final Double calibrationSensitivityThreshold = isSNPMax ? snpCalibrationSensitivityThreshold : indelCalibrationSensitivityThreshold;
+ if (calibrationSensitivityThreshold != null && calibrationSensitivity >= calibrationSensitivityThreshold) {
+ builder.filter(lowScoreFilterName); // TODO does this sufficiently cover the desired behavior when dealing with previously filtered sites, etc.?
+ }
+ }
+
+ vcfWriter.add(builder.make());
+ }
+
+ private String formatDouble(final double x) {
+ return String.format(doubleFormat, x);
+ }
+
+ /**
+ * Copies the header from the input VCF and adds info lines for the score, calibration-sensitivity, and label keys,
+ * as well as the filter line.
+ */
+ @Override
+ VCFHeader constructVCFHeader(final List sortedLabels) {
+ final VCFHeader inputHeader = getHeaderForVariants();
+ final Set inputHeaders = inputHeader.getMetaDataInSortedOrder();
+
+ final Set hInfo = new HashSet<>(inputHeaders);
+ hInfo.add(new VCFInfoHeaderLine(scoreKey, 1, VCFHeaderLineType.Float,
+ "Score according to the model applied by ScoreVariantAnnotations"));
+ hInfo.add(new VCFInfoHeaderLine(calibrationSensitivityKey, 1, VCFHeaderLineType.Float,
+ String.format("Calibration sensitivity corresponding to the value of %s", scoreKey)));
+ hInfo.add(new VCFFilterHeaderLine(lowScoreFilterName, "Low score (corresponding to high calibration sensitivity)"));
+
+ hInfo.addAll(getDefaultToolVCFHeaderLines());
+ if (snpKey != null) {
+ hInfo.add(new VCFInfoHeaderLine(snpKey, 1, VCFHeaderLineType.Flag, "This site was considered a SNP during filtering"));
+ }
+ hInfo.addAll(sortedLabels.stream()
+ .map(l -> new VCFInfoHeaderLine(l, 1, VCFHeaderLineType.Flag, String.format(RESOURCE_LABEL_INFO_HEADER_LINE_FORMAT_STRING, l)))
+ .collect(Collectors.toList()));
+
+ return new VCFHeader(hInfo, inputHeader.getGenotypeSamples());
+ }
+
+ @Override
+ public Object onTraversalSuccess() {
+
+ logger.info(String.format("%s complete.", getClass().getSimpleName()));
+
+ return null;
+ }
+}
\ No newline at end of file
diff --git a/src/main/java/org/broadinstitute/hellbender/tools/walkers/vqsr/scalable/TrainVariantAnnotationsModel.java b/src/main/java/org/broadinstitute/hellbender/tools/walkers/vqsr/scalable/TrainVariantAnnotationsModel.java
new file mode 100644
index 00000000000..9a8a1c8b845
--- /dev/null
+++ b/src/main/java/org/broadinstitute/hellbender/tools/walkers/vqsr/scalable/TrainVariantAnnotationsModel.java
@@ -0,0 +1,570 @@
+package org.broadinstitute.hellbender.tools.walkers.vqsr.scalable;
+
+import com.google.common.collect.Streams;
+import com.google.common.primitives.Doubles;
+import org.apache.commons.math3.stat.descriptive.moment.Variance;
+import org.apache.commons.math3.stat.descriptive.rank.Percentile;
+import org.broadinstitute.barclay.argparser.Argument;
+import org.broadinstitute.barclay.argparser.BetaFeature;
+import org.broadinstitute.barclay.argparser.CommandLineProgramProperties;
+import org.broadinstitute.barclay.help.DocumentedFeature;
+import org.broadinstitute.hellbender.cmdline.CommandLineProgram;
+import org.broadinstitute.hellbender.cmdline.StandardArgumentDefinitions;
+import org.broadinstitute.hellbender.exceptions.GATKException;
+import org.broadinstitute.hellbender.exceptions.UserException;
+import org.broadinstitute.hellbender.tools.walkers.vqsr.VariantRecalibrator;
+import org.broadinstitute.hellbender.tools.walkers.vqsr.scalable.data.LabeledVariantAnnotationsData;
+import org.broadinstitute.hellbender.tools.walkers.vqsr.scalable.data.VariantType;
+import org.broadinstitute.hellbender.tools.walkers.vqsr.scalable.modeling.BGMMVariantAnnotationsModel;
+import org.broadinstitute.hellbender.tools.walkers.vqsr.scalable.modeling.BGMMVariantAnnotationsScorer;
+import org.broadinstitute.hellbender.tools.walkers.vqsr.scalable.modeling.PythonSklearnVariantAnnotationsModel;
+import org.broadinstitute.hellbender.tools.walkers.vqsr.scalable.modeling.PythonSklearnVariantAnnotationsScorer;
+import org.broadinstitute.hellbender.tools.walkers.vqsr.scalable.modeling.VariantAnnotationsModel;
+import org.broadinstitute.hellbender.tools.walkers.vqsr.scalable.modeling.VariantAnnotationsModelBackend;
+import org.broadinstitute.hellbender.tools.walkers.vqsr.scalable.modeling.VariantAnnotationsScorer;
+import org.broadinstitute.hellbender.utils.Utils;
+import org.broadinstitute.hellbender.utils.io.IOUtils;
+import org.broadinstitute.hellbender.utils.io.Resource;
+import org.broadinstitute.hellbender.utils.python.PythonScriptExecutor;
+import picard.cmdline.programgroups.VariantFilteringProgramGroup;
+
+import java.io.File;
+import java.util.ArrayList;
+import java.util.Arrays;
+import java.util.Collections;
+import java.util.List;
+import java.util.stream.Collectors;
+import java.util.stream.IntStream;
+
+/**
+ * Trains a model for scoring variant calls based on site-level annotations.
+ *
+ *
+ * This tool is intended to be used as the second step in a variant-filtering workflow that supersedes the
+ * {@link VariantRecalibrator} workflow. Given training (and optionally, calibration) sets of site-level annotations
+ * produced by {@link ExtractVariantAnnotations}, this tool can be used to train a model for scoring variant
+ * calls. The outputs of the tool are TODO
+ *
+ *
+ *
+ * The model trained by this tool can in turn be provided along with a VCF file to the {@link ScoreVariantAnnotations}
+ * tool, which assigns a score to each call (with a lower score indicating that a call is more likely to be an artifact
+ * and should perhaps be filtered). Each score can also be converted to a corresponding sensitivity to a
+ * calibration set, if the latter is available.
+ *
+ *
+ *
+ * TODO model definition
+ *
+ *
+ *
+ * TODO calibration-sensitivity conversion, considerations, and comparison to tranche files
+ *
+ *
+ *
+ * TODO positive vs. positive-negative
+ *
+ * *
+ *
+ * TODO IsolationForest section with description of method and hyperparameters
+ *
+ *
+ *
+ * Note that HDF5 files may be viewed using hdfview
+ * or loaded in Python using PyTables or h5py.
+ *
+ *
+ *
Inputs
+ *
+ *
+ *
+ * Labeled-annotations HDF5 file (.annot.hdf5). Annotation data and metadata for labeled sites are stored in the
+ * HDF5 directory structure given in the documentation for the {@link ExtractVariantAnnotations} tool. In typical
+ * usage, both the {@value LabeledVariantAnnotationsData#TRAINING_LABEL} and
+ * {@value LabeledVariantAnnotationsData#CALIBRATION_LABEL} labels would be available for non-empty sets of
+ * sites of the requested variant type.
+ *
+ *
+ * (Optional) Unlabeled-annotations HDF5 file (.unlabeled.annot.hdf5). Annotation data and metadata for
+ * unlabeled sites are stored in the HDF5 directory structure given in the documentation for the
+ * {@link ExtractVariantAnnotations} tool. If provided, a positive-negative modeling approach (similar to
+ * that used in {@link VariantRecalibrator} will be used.
+ *
+ *
+ * Variant types (i.e., SNP and/or INDEL) for which to train models. Logic for determining variant type was retained from
+ * {@link VariantRecalibrator}; see {@link VariantType}. A separate model will be trained for each variant type
+ * and separate sets of outputs with corresponding tags in the filenames (i.e., "snp" or "indel") will be produced.
+ * TODO can run tool twice
+ *
+ *
+ * (Optional) Model backend. The default Python IsolationForest implementation requires either the GATK Python environment
+ * or that certain Python packages (argparse, h5py, numpy, sklearn, and dill) are otherwise available.
+ * A custom backend can also be specified in conjunction with the {@value PYTHON_SCRIPT_LONG_NAME} argument.
+ *
+ *
+ * (Optional) Model hyperparameters JSON file. TODO
+ *
+ *
+ * (Optional) Calibration-set sensitivity threshold. TODO if separate SNP/INDEL thresholds, run tool twice
+ *
+ *
+ * Output prefix.
+ * This is used as the basename for output files.
+ *
+ *
+ *
+ *
Outputs
+ *
+ *
+ *
+ * TODO
+ *
+ *
+ * (Optional) TODO
+ *
+ *
+ *
+ *
Usage examples
+ *
+ *
+ * TODO, positive-only, producing the outputs 1)
+ *
+ *
+ * gatk TrainVariantAnnotationsModel \
+ * TODO
+ *
+ *
+ *
+ *
+ * TODO, positive-negative, producing the outputs 1)
+ *
+ *
+ * gatk TrainVariantAnnotationsModel \
+ * TODO
+ *
+ *
+ *
+ *
Custom modeling/scoring backends (ADVANCED)
+ *
+ *
+ * The primary modeling functionality performed by this tool is accomplished by a "modeling backend"
+ * whose fundamental contract is to take an input HDF5 file containing an annotation matrix for sites of a
+ * single variant type (i.e., SNP or INDEL) and to output a serialized scorer for that variant type.
+ * Rather than using one of the available, implemented backends, advanced users may provide their own backend
+ * via the {@value PYTHON_SCRIPT_LONG_NAME} argument. See documentation in the modeling and scoring interfaces
+ * ({@link VariantAnnotationsModel} and {@link VariantAnnotationsScorer}, respectively), as well as the default
+ * Python IsolationForest implementation at org/broadinstitute/hellbender/tools/walkers/vqsr/scalable/isolation-forest.py.
+ *
+ *
+ *
+ * Extremely advanced users could potentially substitute their own implementation for the entire
+ * {@link TrainVariantAnnotationsModel} tool, while still making use of the up/downstream
+ * {@link ExtractVariantAnnotations} and {@link ScoreVariantAnnotations} tools. To do so, one would additionally
+ * have to implement functionality for subsetting training/calibration sets by variant type,
+ * calling modeling backends as appropriate, and scoring calibration sets.
+ *
+ *
+ * @author Samuel Lee <slee@broadinstitute.org>
+ */
+@CommandLineProgramProperties(
+ summary = "Trains a model for scoring variant calls based on site-level annotations.",
+ oneLineSummary = "Trains a model for scoring variant calls based on site-level annotations",
+ programGroup = VariantFilteringProgramGroup.class
+)
+@DocumentedFeature
+@BetaFeature
+public final class TrainVariantAnnotationsModel extends CommandLineProgram {
+
+ public static final String MODE_LONG_NAME = "mode";
+ public static final String ANNOTATIONS_HDF5_LONG_NAME = "annotations-hdf5";
+ public static final String UNLABELED_ANNOTATIONS_HDF5_LONG_NAME = "unlabeled-annotations-hdf5";
+ public static final String MODEL_BACKEND_LONG_NAME = "model-backend";
+ public static final String PYTHON_SCRIPT_LONG_NAME = "python-script";
+ public static final String HYPERPARAMETERS_JSON_LONG_NAME = "hyperparameters-json";
+ public static final String CALIBRATION_SENSITIVITY_THRESHOLD_LONG_NAME = "calibration-sensitivity-threshold";
+
+ public static final String ISOLATION_FOREST_PYTHON_SCRIPT = "isolation-forest.py";
+ public static final String ISOLATION_FOREST_HYPERPARAMETERS_JSON = "isolation-forest-hyperparameters.json";
+
+ enum AvailableLabelsMode {
+ POSITIVE_ONLY, POSITIVE_UNLABELED
+ }
+
+ public static final String TRAINING_SCORES_HDF5_SUFFIX = ".trainingScores.hdf5";
+ public static final String CALIBRATION_SCORES_HDF5_SUFFIX = ".calibrationScores.hdf5";
+ public static final String UNLABELED_SCORES_HDF5_SUFFIX = ".unlabeledScores.hdf5";
+ public static final String NEGATIVE_TAG = ".negative";
+
+ @Argument(
+ fullName = ANNOTATIONS_HDF5_LONG_NAME,
+ doc = "HDF5 file containing annotations extracted with ExtractVariantAnnotations.")
+ private File inputAnnotationsFile;
+
+ @Argument(
+ fullName = UNLABELED_ANNOTATIONS_HDF5_LONG_NAME,
+ doc = "HDF5 file containing annotations extracted with ExtractVariantAnnotations. " +
+ "If specified with " + CALIBRATION_SENSITIVITY_THRESHOLD_LONG_NAME + ", " +
+ "a positive-unlabeled modeling approach will be used; otherwise, a positive-only modeling " +
+ "approach will be used.",
+ optional = true)
+ private File inputUnlabeledAnnotationsFile;
+
+ @Argument(
+ fullName = MODEL_BACKEND_LONG_NAME,
+ doc = "Backend to use for training models. " +
+ "JAVA_BGMM will use a pure Java implementation (ported from Python scikit-learn) of the Bayesian Gaussian Mixture Model. " +
+ "PYTHON_IFOREST will use the Python scikit-learn implementation of the IsolationForest method and " +
+ "will require that the corresponding Python dependencies are present in the environment. " +
+ "PYTHON_SCRIPT will use the script specified by the " + PYTHON_SCRIPT_LONG_NAME + " argument. " +
+ "See the tool documentation for more details.")
+ private VariantAnnotationsModelBackend modelBackend = VariantAnnotationsModelBackend.PYTHON_IFOREST;
+
+ @Argument(
+ fullName = PYTHON_SCRIPT_LONG_NAME,
+ doc = "Python script used for specifying a custom scoring backend. If provided, " + MODEL_BACKEND_LONG_NAME + " must also be set to PYTHON_SCRIPT.",
+ optional = true)
+ private File pythonScriptFile;
+
+ @Argument(
+ fullName = HYPERPARAMETERS_JSON_LONG_NAME,
+ doc = "JSON file containing hyperparameters. Optional if the PYTHON_IFOREST backend is used " +
+ "(if not specified, a default set of hyperparameters will be used); otherwise required.",
+ optional = true)
+ private File hyperparametersJSONFile;
+
+ @Argument(
+ fullName = StandardArgumentDefinitions.OUTPUT_LONG_NAME,
+ shortName = StandardArgumentDefinitions.OUTPUT_SHORT_NAME,
+ doc = "Output prefix.")
+ private String outputPrefix;
+
+ @Argument(
+ fullName = CALIBRATION_SENSITIVITY_THRESHOLD_LONG_NAME,
+ doc = "Calibration-sensitivity threshold that determines which sites will be used for training the negative model " +
+ "in the positive-unlabeled modeling approach. " +
+ "Increasing this will decrease the corresponding positive-model score threshold; sites with scores below this score " +
+ "threshold will be used for training the negative model. Thus, this parameter should typically be chosen to " +
+ "be close to 1, so that sites that score highly according to the positive model will not be used to train the negative model. " +
+ "The " + UNLABELED_ANNOTATIONS_HDF5_LONG_NAME + " argument must be specified in conjunction with this argument. " +
+ "If separate thresholds for SNP and INDEL models are desired, run the tool separately for each mode with its respective threshold.",
+ optional = true,
+ minValue = 0.,
+ maxValue = 1.)
+ private Double calibrationSensitivityThreshold;
+
+ @Argument(
+ fullName = MODE_LONG_NAME,
+ doc = "Variant types for which to train models. Duplicate values will be ignored.",
+ minElements = 1)
+ public List variantTypes = new ArrayList<>(Arrays.asList(VariantType.SNP, VariantType.INDEL));
+
+ private AvailableLabelsMode availableLabelsMode;
+
+ @Override
+ protected Object doWork() {
+
+ validateArgumentsAndSetModes();
+
+ logger.info("Starting training...");
+
+ for (final VariantType variantType : VariantType.values()) { // enforces order in which models are trained
+ if (variantTypes.contains(variantType)) {
+ doModelingWorkForVariantType(variantType);
+ }
+ }
+
+ logger.info(String.format("%s complete.", getClass().getSimpleName()));
+
+ return null;
+ }
+
+ private void validateArgumentsAndSetModes() {
+ IOUtils.canReadFile(inputAnnotationsFile);
+
+ Utils.validateArg((inputUnlabeledAnnotationsFile == null) == (calibrationSensitivityThreshold == null),
+ "Unlabeled annotations and calibration-sensitivity threshold must both be unspecified (for positive-only model training) " +
+ "or specified (for positive-unlabeled model training).");
+
+ availableLabelsMode = inputUnlabeledAnnotationsFile != null && calibrationSensitivityThreshold != null
+ ? AvailableLabelsMode.POSITIVE_UNLABELED
+ : AvailableLabelsMode.POSITIVE_ONLY;
+
+ if (inputUnlabeledAnnotationsFile != null) {
+ IOUtils.canReadFile(inputUnlabeledAnnotationsFile);
+ final List annotationNames = LabeledVariantAnnotationsData.readAnnotationNames(inputAnnotationsFile);
+ final List unlabeledAnnotationNames = LabeledVariantAnnotationsData.readAnnotationNames(inputUnlabeledAnnotationsFile);
+ Utils.validateArg(annotationNames.equals(unlabeledAnnotationNames), "Annotation names must be identical for positive and unlabeled annotations.");
+ }
+
+ switch (modelBackend) {
+ case JAVA_BGMM:
+ Utils.validateArg(pythonScriptFile == null,
+ "Python script should not be provided when using JAVA_BGMM backend.");
+ IOUtils.canReadFile(hyperparametersJSONFile);
+ logger.info("Running in JAVA_BGMM mode...");
+ break;
+ case PYTHON_IFOREST:
+ Utils.validateArg(pythonScriptFile == null,
+ "Python script should not be provided when using PYTHON_IFOREST backend.");
+
+ pythonScriptFile = IOUtils.writeTempResource(new Resource(ISOLATION_FOREST_PYTHON_SCRIPT, TrainVariantAnnotationsModel.class));
+ if (hyperparametersJSONFile == null) {
+ hyperparametersJSONFile = IOUtils.writeTempResource(new Resource(ISOLATION_FOREST_HYPERPARAMETERS_JSON, TrainVariantAnnotationsModel.class));
+ }
+ IOUtils.canReadFile(hyperparametersJSONFile);
+ PythonScriptExecutor.checkPythonEnvironmentForPackage("argparse");
+ PythonScriptExecutor.checkPythonEnvironmentForPackage("h5py");
+ PythonScriptExecutor.checkPythonEnvironmentForPackage("numpy");
+ PythonScriptExecutor.checkPythonEnvironmentForPackage("sklearn");
+ PythonScriptExecutor.checkPythonEnvironmentForPackage("dill");
+ logger.info("Running in PYTHON_IFOREST mode...");
+ break;
+ case PYTHON_SCRIPT:
+ IOUtils.canReadFile(pythonScriptFile);
+ IOUtils.canReadFile(hyperparametersJSONFile);
+ logger.info("Running in PYTHON_SCRIPT mode...");
+ break;
+ default:
+ throw new GATKException.ShouldNeverReachHereException("Unknown model-backend mode.");
+ }
+ }
+
+ /**
+ * TODO
+ */
+ private void doModelingWorkForVariantType(final VariantType variantType) {
+ // positive model
+ final List annotationNames = LabeledVariantAnnotationsData.readAnnotationNames(inputAnnotationsFile);
+ final double[][] annotations = LabeledVariantAnnotationsData.readAnnotations(inputAnnotationsFile);
+
+ final List isTraining = LabeledVariantAnnotationsData.readLabel(inputAnnotationsFile, LabeledVariantAnnotationsData.TRAINING_LABEL);
+ final List isCalibration = LabeledVariantAnnotationsData.readLabel(inputAnnotationsFile, LabeledVariantAnnotationsData.CALIBRATION_LABEL);
+ final List isSNP = LabeledVariantAnnotationsData.readLabel(inputAnnotationsFile, LabeledVariantAnnotationsData.SNP_LABEL);
+ final List isVariantType = variantType == VariantType.SNP ? isSNP : isSNP.stream().map(x -> !x).collect(Collectors.toList());
+
+ final List isTrainingAndVariantType = Streams.zip(isTraining.stream(), isVariantType.stream(), (a, b) -> a && b).collect(Collectors.toList());
+ final int numTrainingAndVariantType = numPassingFilter(isTrainingAndVariantType);
+
+ final String variantTypeString = variantType.toString();
+ final String outputPrefixTag = '.' + variantType.toString().toLowerCase();
+
+ if (numTrainingAndVariantType > 0) {
+ logger.info(String.format("Training %s model with %d training sites x %d annotations %s...",
+ variantTypeString, numTrainingAndVariantType, annotationNames.size(), annotationNames));
+ final File labeledTrainingAndVariantTypeAnnotationsFile = LabeledVariantAnnotationsData.subsetAnnotationsToTemporaryFile(annotationNames, annotations, isTrainingAndVariantType);
+ trainAndSerializeModel(labeledTrainingAndVariantTypeAnnotationsFile, outputPrefixTag);
+ logger.info(String.format("%s model trained and serialized with output prefix \"%s\".", variantTypeString, outputPrefix + outputPrefixTag));
+
+ if (modelBackend == VariantAnnotationsModelBackend.JAVA_BGMM) {
+ BGMMVariantAnnotationsScorer.preprocessAnnotationsWithBGMMAndWriteHDF5(
+ annotationNames, outputPrefix + outputPrefixTag, labeledTrainingAndVariantTypeAnnotationsFile, logger);
+ }
+
+ logger.info(String.format("Scoring %d %s training sites...", numTrainingAndVariantType, variantTypeString));
+ final File labeledTrainingAndVariantTypeScoresFile = score(labeledTrainingAndVariantTypeAnnotationsFile, outputPrefixTag, TRAINING_SCORES_HDF5_SUFFIX);
+ logger.info(String.format("%s training scores written to %s.", variantTypeString, labeledTrainingAndVariantTypeScoresFile.getAbsolutePath()));
+
+ final List isLabeledCalibrationAndVariantType = Streams.zip(isCalibration.stream(), isVariantType.stream(), (a, b) -> a && b).collect(Collectors.toList());
+ final int numLabeledCalibrationAndVariantType = numPassingFilter(isLabeledCalibrationAndVariantType);
+ if (numLabeledCalibrationAndVariantType > 0) {
+ logger.info(String.format("Scoring %d %s calibration sites...", numLabeledCalibrationAndVariantType, variantTypeString));
+ final File labeledCalibrationAndVariantTypeAnnotationsFile = LabeledVariantAnnotationsData.subsetAnnotationsToTemporaryFile(annotationNames, annotations, isLabeledCalibrationAndVariantType);
+ final File labeledCalibrationAndVariantTypeScoresFile = score(labeledCalibrationAndVariantTypeAnnotationsFile, outputPrefixTag, CALIBRATION_SCORES_HDF5_SUFFIX);
+ logger.info(String.format("%s calibration scores written to %s.", variantTypeString, labeledCalibrationAndVariantTypeScoresFile.getAbsolutePath()));
+ } else {
+ logger.warn(String.format("No %s calibration sites were available.", variantTypeString));
+ }
+
+ // negative model
+ if (availableLabelsMode == AvailableLabelsMode.POSITIVE_UNLABELED) {
+ final double[][] unlabeledAnnotations = LabeledVariantAnnotationsData.readAnnotations(inputUnlabeledAnnotationsFile);
+ final List unlabeledIsSNP = LabeledVariantAnnotationsData.readLabel(inputUnlabeledAnnotationsFile, "snp");
+ final List isUnlabeledVariantType = variantType == VariantType.SNP ? unlabeledIsSNP : unlabeledIsSNP.stream().map(x -> !x).collect(Collectors.toList());
+
+ final int numUnlabeledVariantType = numPassingFilter(isUnlabeledVariantType);
+
+ if (numUnlabeledVariantType > 0) {
+ final File labeledCalibrationAndVariantTypeScoresFile = new File(outputPrefix + outputPrefixTag + CALIBRATION_SCORES_HDF5_SUFFIX);
+ final double[] labeledCalibrationAndVariantTypeScores = VariantAnnotationsScorer.readScores(labeledCalibrationAndVariantTypeScoresFile);
+ final double scoreThreshold = calibrationSensitivityThreshold == 1. // Percentile requires quantile > 0, so we treat this as a special case
+ ? Doubles.min(labeledCalibrationAndVariantTypeScores)
+ : new Percentile(100. * (1. - calibrationSensitivityThreshold)).evaluate(labeledCalibrationAndVariantTypeScores);
+ logger.info(String.format("Using %s score threshold of %.4f corresponding to specified calibration-sensitivity threshold of %.4f ...",
+ variantTypeString, scoreThreshold, calibrationSensitivityThreshold));
+
+ final double[] labeledTrainingAndVariantTypeScores = VariantAnnotationsScorer.readScores(labeledTrainingAndVariantTypeScoresFile);
+ final List isNegativeTrainingFromLabeledTrainingAndVariantType = Arrays.stream(labeledTrainingAndVariantTypeScores).boxed().map(s -> s < scoreThreshold).collect(Collectors.toList());
+ final int numNegativeTrainingFromLabeledTrainingAndVariantType = numPassingFilter(isNegativeTrainingFromLabeledTrainingAndVariantType);
+ logger.info(String.format("Selected %d labeled %s sites below score threshold of %.4f for negative-model training...",
+ numNegativeTrainingFromLabeledTrainingAndVariantType, variantTypeString, scoreThreshold));
+
+ logger.info(String.format("Scoring %d unlabeled %s sites...", numUnlabeledVariantType, variantTypeString));
+ final File unlabeledVariantTypeAnnotationsFile = LabeledVariantAnnotationsData.subsetAnnotationsToTemporaryFile(annotationNames, unlabeledAnnotations, isUnlabeledVariantType);
+ final File unlabeledVariantTypeScoresFile = score(unlabeledVariantTypeAnnotationsFile, outputPrefixTag, UNLABELED_SCORES_HDF5_SUFFIX);
+ final double[] unlabeledVariantTypeScores = VariantAnnotationsScorer.readScores(unlabeledVariantTypeScoresFile);
+ final List isNegativeTrainingFromUnlabeledVariantType = Arrays.stream(unlabeledVariantTypeScores).boxed().map(s -> s < scoreThreshold).collect(Collectors.toList()); // length matches unlabeledAnnotationsFile
+ final int numNegativeTrainingFromUnlabeledVariantType = numPassingFilter(isNegativeTrainingFromUnlabeledVariantType);
+ logger.info(String.format("Selected %d unlabeled %s sites below score threshold of %.4f for negative-model training...",
+ numNegativeTrainingFromUnlabeledVariantType, variantTypeString, scoreThreshold));
+
+ final double[][] negativeTrainingAndVariantTypeAnnotations = concatenateLabeledAndUnlabeledNegativeTrainingData(
+ annotationNames, annotations, unlabeledAnnotations, isNegativeTrainingFromLabeledTrainingAndVariantType, isNegativeTrainingFromUnlabeledVariantType);
+ final int numNegativeTrainingAndVariantType = negativeTrainingAndVariantTypeAnnotations.length;
+ final List isNegativeTrainingAndVariantType = Collections.nCopies(numNegativeTrainingAndVariantType, true);
+
+ logger.info(String.format("Training %s negative model with %d negative-training sites x %d annotations %s...",
+ variantTypeString, numNegativeTrainingAndVariantType, annotationNames.size(), annotationNames));
+ final File negativeTrainingAnnotationsFile = LabeledVariantAnnotationsData.subsetAnnotationsToTemporaryFile(
+ annotationNames, negativeTrainingAndVariantTypeAnnotations, isNegativeTrainingAndVariantType);
+ trainAndSerializeModel(negativeTrainingAnnotationsFile, outputPrefixTag + NEGATIVE_TAG);
+ logger.info(String.format("%s negative model trained and serialized with output prefix \"%s\".", variantTypeString, outputPrefix + outputPrefixTag + NEGATIVE_TAG));
+
+ if (numLabeledCalibrationAndVariantType > 0) {
+ logger.info(String.format("Re-scoring %d %s calibration sites...", numLabeledCalibrationAndVariantType, variantTypeString));
+ final File labeledCalibrationAnnotationsFile = LabeledVariantAnnotationsData.subsetAnnotationsToTemporaryFile(annotationNames, annotations, isLabeledCalibrationAndVariantType);
+ final File labeledCalibrationScoresFile = positiveNegativeScore(labeledCalibrationAnnotationsFile, outputPrefixTag, CALIBRATION_SCORES_HDF5_SUFFIX);
+ logger.info(String.format("Calibration scores written to %s.", labeledCalibrationScoresFile.getAbsolutePath()));
+ }
+ } else {
+ throw new UserException.BadInput(String.format("Attempted to train %s negative model, " +
+ "but no suitable sites were found in the provided annotations.", variantTypeString));
+ }
+ }
+ } else {
+ throw new UserException.BadInput(String.format("Attempted to train %s model, " +
+ "but no suitable training sites were found in the provided annotations.", variantTypeString));
+ }
+ }
+
+ private static int numPassingFilter(List isPassing) {
+ return isPassing.stream().mapToInt(x -> x ? 1 : 0).sum();
+ }
+
+ private void trainAndSerializeModel(final File trainingAnnotationsFile,
+ final String outputPrefixTag) {
+ readAndValidateTrainingAnnotations(trainingAnnotationsFile, outputPrefixTag);
+ final VariantAnnotationsModel model;
+ switch (modelBackend) {
+ case JAVA_BGMM:
+ model = new BGMMVariantAnnotationsModel(hyperparametersJSONFile);
+ break;
+ case PYTHON_IFOREST:
+ model = new PythonSklearnVariantAnnotationsModel(pythonScriptFile, hyperparametersJSONFile);
+ break;
+ case PYTHON_SCRIPT:
+ model = new PythonSklearnVariantAnnotationsModel(pythonScriptFile, hyperparametersJSONFile);
+ break;
+ default:
+ throw new GATKException.ShouldNeverReachHereException("Unknown model mode.");
+ }
+ model.trainAndSerialize(trainingAnnotationsFile, outputPrefix + outputPrefixTag);
+ }
+
+ /**
+ * When training models on data that has been subset to a given variant type,
+ * we FAIL if any annotation is completely missing and WARN if any annotation has zero variance.
+ */
+ private void readAndValidateTrainingAnnotations(final File trainingAnnotationsFile,
+ final String outputPrefixTag) {
+ final List annotationNames = LabeledVariantAnnotationsData.readAnnotationNames(trainingAnnotationsFile);
+ final double[][] annotations = LabeledVariantAnnotationsData.readAnnotations(trainingAnnotationsFile);
+
+ // these checks are redundant, but we err on the side of robustness
+ final int numAnnotationNames = annotationNames.size();
+ final int numData = annotations.length;
+ Utils.validateArg(numAnnotationNames > 0, "Number of annotation names must be positive.");
+ Utils.validateArg(numData > 0, "Number of data points must be positive.");
+ final int numFeatures = annotations[0].length;
+ Utils.validateArg(numAnnotationNames == numFeatures,
+ "Number of annotation names must match the number of features in the annotation data.");
+
+ final List completelyMissingAnnotationNames = new ArrayList<>(numFeatures);
+ IntStream.range(0, numFeatures).forEach(
+ i -> {
+ if (new Variance().evaluate(IntStream.range(0, numData).mapToDouble(n -> annotations[n][i]).toArray()) == 0.) {
+ logger.warn(String.format("All values of the annotation %s are identical in the training data for the %s model.",
+ annotationNames.get(i), outputPrefix + outputPrefixTag));
+ }
+ if (IntStream.range(0, numData).boxed().map(n -> annotations[n][i]).allMatch(x -> Double.isNaN(x))) {
+ completelyMissingAnnotationNames.add(annotationNames.get(i));
+ }
+ }
+ );
+
+ if (!completelyMissingAnnotationNames.isEmpty()) {
+ throw new UserException.BadInput(
+ String.format("All values of the following annotations are missing in the training data for the %s model: %s. " +
+ "Consider repeating the extraction step with this annotation dropped. " +
+ "If this is a negative model and the amount of negative training data is small, " +
+ "perhaps also consider lowering the value of the %s argument so that more " +
+ "training data is considered, which may ultimately admit data with non-missing values for the annotation " +
+ "(although note that this will also have implications for the resulting model fit); " +
+ "alternatively, consider excluding the %s and %s arguments and running positive-only modeling.",
+ outputPrefix + outputPrefixTag, completelyMissingAnnotationNames,
+ CALIBRATION_SENSITIVITY_THRESHOLD_LONG_NAME, UNLABELED_ANNOTATIONS_HDF5_LONG_NAME, CALIBRATION_SENSITIVITY_THRESHOLD_LONG_NAME));
+ }
+ }
+
+ private File score(final File annotationsFile,
+ final String outputPrefixTag,
+ final String outputSuffix) {
+ final VariantAnnotationsScorer scorer;
+ switch (modelBackend) {
+ case JAVA_BGMM:
+ scorer = BGMMVariantAnnotationsScorer.deserialize(new File(outputPrefix + outputPrefixTag + BGMMVariantAnnotationsScorer.BGMM_SCORER_SER_SUFFIX));
+ break;
+ case PYTHON_IFOREST:
+ case PYTHON_SCRIPT:
+ scorer = new PythonSklearnVariantAnnotationsScorer(pythonScriptFile, new File(outputPrefix + outputPrefixTag + PythonSklearnVariantAnnotationsScorer.PYTHON_SCORER_PKL_SUFFIX));
+ break;
+
+ default:
+ throw new GATKException.ShouldNeverReachHereException("Unknown model mode.");
+ }
+ final File outputScoresFile = new File(outputPrefix + outputPrefixTag + outputSuffix);
+ scorer.score(annotationsFile, outputScoresFile);
+ return outputScoresFile;
+ }
+
+ private File positiveNegativeScore(final File annotationsFile,
+ final String outputPrefixTag,
+ final String outputSuffix) {
+ final VariantAnnotationsScorer scorer;
+ switch (modelBackend) {
+ case JAVA_BGMM:
+ scorer = VariantAnnotationsScorer.combinePositiveAndNegativeScorer(
+ BGMMVariantAnnotationsScorer.deserialize(new File(outputPrefix + outputPrefixTag + BGMMVariantAnnotationsScorer.BGMM_SCORER_SER_SUFFIX)),
+ BGMMVariantAnnotationsScorer.deserialize(new File(outputPrefix + outputPrefixTag + NEGATIVE_TAG + BGMMVariantAnnotationsScorer.BGMM_SCORER_SER_SUFFIX)));
+ break;
+ case PYTHON_IFOREST:
+ case PYTHON_SCRIPT:
+ scorer = VariantAnnotationsScorer.combinePositiveAndNegativeScorer(
+ new PythonSklearnVariantAnnotationsScorer(pythonScriptFile, new File(outputPrefix + outputPrefixTag + PythonSklearnVariantAnnotationsScorer.PYTHON_SCORER_PKL_SUFFIX)),
+ new PythonSklearnVariantAnnotationsScorer(pythonScriptFile, new File(outputPrefix + outputPrefixTag + NEGATIVE_TAG + PythonSklearnVariantAnnotationsScorer.PYTHON_SCORER_PKL_SUFFIX)));
+ break;
+ default:
+ throw new GATKException.ShouldNeverReachHereException("Unknown model mode.");
+ }
+ final File outputScoresFile = new File(outputPrefix + outputPrefixTag + outputSuffix);
+ scorer.score(annotationsFile, outputScoresFile);
+ return outputScoresFile;
+ }
+
+ private static double[][] concatenateLabeledAndUnlabeledNegativeTrainingData(final List annotationNames,
+ final double[][] annotations,
+ final double[][] unlabeledAnnotations,
+ final List isNegativeTrainingFromLabeledTrainingAndVariantType,
+ final List isNegativeTrainingFromUnlabeledVariantType) {
+ final File negativeTrainingFromLabeledTrainingAndVariantTypeAnnotationsFile =
+ LabeledVariantAnnotationsData.subsetAnnotationsToTemporaryFile(annotationNames, annotations, isNegativeTrainingFromLabeledTrainingAndVariantType);
+ final double[][] negativeTrainingFromLabeledTrainingAndVariantTypeAnnotations = LabeledVariantAnnotationsData.readAnnotations(negativeTrainingFromLabeledTrainingAndVariantTypeAnnotationsFile);
+
+ final File negativeTrainingFromUnlabeledVariantTypeAnnotationsFile =
+ LabeledVariantAnnotationsData.subsetAnnotationsToTemporaryFile(annotationNames, unlabeledAnnotations, isNegativeTrainingFromUnlabeledVariantType);
+ final double[][] negativeTrainingFromUnlabeledVariantTypeAnnotations = LabeledVariantAnnotationsData.readAnnotations(negativeTrainingFromUnlabeledVariantTypeAnnotationsFile);
+
+ return Streams.concat(
+ Arrays.stream(negativeTrainingFromLabeledTrainingAndVariantTypeAnnotations),
+ Arrays.stream(negativeTrainingFromUnlabeledVariantTypeAnnotations)).toArray(double[][]::new);
+ }
+}
\ No newline at end of file
diff --git a/src/main/java/org/broadinstitute/hellbender/tools/walkers/vqsr/scalable/data/LabeledVariantAnnotationsData.java b/src/main/java/org/broadinstitute/hellbender/tools/walkers/vqsr/scalable/data/LabeledVariantAnnotationsData.java
new file mode 100644
index 00000000000..2abd7fce48b
--- /dev/null
+++ b/src/main/java/org/broadinstitute/hellbender/tools/walkers/vqsr/scalable/data/LabeledVariantAnnotationsData.java
@@ -0,0 +1,284 @@
+package org.broadinstitute.hellbender.tools.walkers.vqsr.scalable.data;
+
+import com.google.common.collect.ImmutableList;
+import htsjdk.variant.variantcontext.Allele;
+import htsjdk.variant.variantcontext.VariantContext;
+import org.apache.logging.log4j.LogManager;
+import org.apache.logging.log4j.Logger;
+import org.broadinstitute.hdf5.HDF5File;
+import org.broadinstitute.hdf5.HDF5LibException;
+import org.broadinstitute.hellbender.exceptions.GATKException;
+import org.broadinstitute.hellbender.tools.copynumber.utils.HDF5Utils;
+import org.broadinstitute.hellbender.tools.walkers.vqsr.scalable.modeling.VariantAnnotationsModel;
+import org.broadinstitute.hellbender.tools.walkers.vqsr.scalable.modeling.VariantAnnotationsScorer;
+import org.broadinstitute.hellbender.utils.Utils;
+import org.broadinstitute.hellbender.utils.io.IOUtils;
+
+import java.io.File;
+import java.util.ArrayList;
+import java.util.Arrays;
+import java.util.Collection;
+import java.util.Collections;
+import java.util.List;
+import java.util.TreeSet;
+import java.util.stream.Collectors;
+import java.util.stream.IntStream;
+import java.util.stream.Stream;
+
+/**
+ * Represents a collection of {@link LabeledVariantAnnotationsDatum} as a list of lists of datums.
+ * The outer list is always per-variant. In allele-specific mode, each datum in the inner lists
+ * corresponds to a single allele; otherwise, each inner list trivially contains a single datum corresponding
+ * to the variant.
+ */
+public final class LabeledVariantAnnotationsData {
+ private static final Logger logger = LogManager.getLogger(LabeledVariantAnnotationsData.class);
+
+ // chunk size in temporary annotation files
+ // TODO this could be exposed
+ private static final int CHUNK_DIVISOR = 16;
+ private static final int MAXIMUM_CHUNK_SIZE = HDF5Utils.MAX_NUMBER_OF_VALUES_PER_HDF5_MATRIX / CHUNK_DIVISOR;
+
+ private static final int INITIAL_SIZE = 10_000_000;
+
+ public static final String TRAINING_LABEL = "training";
+ public static final String CALIBRATION_LABEL = "calibration";
+ public static final String SNP_LABEL = "snp";
+
+ public static final String INTERVALS_PATH = "/intervals";
+ public static final String ALLELES_REF_PATH = "/alleles/ref";
+ public static final String ALLELES_ALT_PATH = "/alleles/alt";
+ public static final String ANNOTATIONS_NAMES_PATH = "/annotations/names";
+ public static final String ANNOTATIONS_PATH = "/annotations";
+ public static final String LABELS_PATH = "/labels";
+ public static final String LABELS_SNP_PATH = LABELS_PATH + "/snp";
+
+ private final List sortedAnnotationNames;
+ final List sortedLabels;
+
+ private final List> data;
+ private final boolean useASAnnotations;
+
+ public LabeledVariantAnnotationsData(final Collection annotationNames,
+ final Collection labels,
+ final boolean useASAnnotations,
+ final int initialSize) {
+ data = new ArrayList<>(initialSize);
+ sortedAnnotationNames = ImmutableList.copyOf(annotationNames.stream().distinct().sorted().collect(Collectors.toList()));
+ Utils.validateArg(sortedAnnotationNames.size() > 0, "Number of annotation names must be positive.");
+ if (sortedAnnotationNames.size() != annotationNames.size()) {
+ logger.warn(String.format("Ignoring duplicate annotations: %s.", Utils.getDuplicatedItems(annotationNames)));
+ }
+ sortedLabels = ImmutableList.copyOf(labels.stream().distinct().sorted().collect(Collectors.toList()));
+ if (sortedLabels.size() != labels.size()) {
+ logger.warn(String.format("Ignoring duplicate labels: %s.", Utils.getDuplicatedItems(labels)));
+ }
+ this.useASAnnotations = useASAnnotations;
+ }
+
+ public LabeledVariantAnnotationsData(final Collection annotationNames,
+ final Collection labels,
+ final boolean useASAnnotations) {
+ this(annotationNames, labels, useASAnnotations, INITIAL_SIZE);
+ }
+
+ public List getSortedAnnotationNames() {
+ return sortedAnnotationNames;
+ }
+
+ public List getSortedLabels() {
+ return sortedLabels;
+ }
+
+ public int size() {
+ return data.size();
+ }
+
+ public void clear() {
+ data.clear();
+ }
+
+ /**
+ * Adds an element to the underlying {@link #data} collection.
+ */
+ public void add(final VariantContext vc,
+ final List> altAllelesPerDatum,
+ final List variantTypePerDatum,
+ final List> labelsPerDatum) {
+ if (!useASAnnotations) {
+ data.add(Collections.singletonList(new LabeledVariantAnnotationsDatum(
+ vc, altAllelesPerDatum.get(0), variantTypePerDatum.get(0), labelsPerDatum.get(0), sortedAnnotationNames, useASAnnotations)));
+ } else {
+ data.add(IntStream.range(0, altAllelesPerDatum.size()).boxed()
+ .map(i -> new LabeledVariantAnnotationsDatum(
+ vc, altAllelesPerDatum.get(i), variantTypePerDatum.get(i), labelsPerDatum.get(i), sortedAnnotationNames, useASAnnotations))
+ .collect(Collectors.toList()));
+ }
+ }
+
+ /**
+ * Sets the element at a specified index in the underlying {@link #data} collection.
+ */
+ public void set(final int index,
+ final VariantContext vc,
+ final List> altAllelesPerDatum,
+ final List variantTypePerDatum,
+ final List> labelsPerDatum) {
+ if (!useASAnnotations) {
+ data.set(index, Collections.singletonList(new LabeledVariantAnnotationsDatum(
+ vc, altAllelesPerDatum.get(0), variantTypePerDatum.get(0), labelsPerDatum.get(0), sortedAnnotationNames, useASAnnotations)));
+ } else {
+ data.set(index, IntStream.range(0, altAllelesPerDatum.size()).boxed()
+ .map(i -> new LabeledVariantAnnotationsDatum(
+ vc, altAllelesPerDatum.get(i), variantTypePerDatum.get(i), labelsPerDatum.get(i), sortedAnnotationNames, useASAnnotations))
+ .collect(Collectors.toList()));
+ }
+ }
+
+ /**
+ * @return list of {@link VariantType} indicators, with length given by the number of corresponding sites
+ */
+ public List getVariantTypeFlat() {
+ return streamFlattenedData().map(datum -> datum.variantType).collect(Collectors.toList());
+ }
+
+ /**
+ * @return list of boolean label indicators, with length given by the number of sites;
+ * an element in the list will be true if the corresponding site is assigned to the specified label
+ */
+ public List isLabelFlat(final String label) {
+ return streamFlattenedData().map(datum -> datum.labels.contains(label)).collect(Collectors.toList());
+ }
+
+ private Stream streamFlattenedData() {
+ return data.stream().flatMap(List::stream);
+ }
+
+ /**
+ * Writes a representation of the collection to an HDF5 file with the following directory structure:
+ *
+ *
+ *
+ * Here, each chunk is a double matrix, with dimensions given by (number of sites in the chunk) x (number of annotations).
+ * See the methods {@link HDF5Utils#writeChunkedDoubleMatrix} and {@link HDF5Utils#writeIntervals} for additional details.
+ *
+ * @param omitAllelesInHDF5 string arrays containing ref/alt alleles can be large, so we allow the option of omitting them
+ */
+ public void writeHDF5(final File outputFile,
+ final boolean omitAllelesInHDF5) {
+
+ try (final HDF5File outputHDF5File = new HDF5File(outputFile, HDF5File.OpenMode.CREATE)) {
+ IOUtils.canReadFile(outputHDF5File.getFile());
+ HDF5Utils.writeIntervals(outputHDF5File, INTERVALS_PATH,
+ streamFlattenedData().map(datum -> datum.interval).collect(Collectors.toList()));
+ if (!omitAllelesInHDF5) {
+ outputHDF5File.makeStringArray(ALLELES_REF_PATH,
+ streamFlattenedData().map(datum -> datum.refAllele.getDisplayString()).toArray(String[]::new));
+ if (!useASAnnotations) {
+ outputHDF5File.makeStringArray(ALLELES_ALT_PATH,
+ streamFlattenedData()
+ .map(datum -> datum.altAlleles.stream().map(Allele::getDisplayString).collect(Collectors.joining(",")))
+ .toArray(String[]::new));
+ } else {
+ outputHDF5File.makeStringArray(ALLELES_ALT_PATH,
+ streamFlattenedData().map(datum -> datum.altAlleles.get(0).getDisplayString()).toArray(String[]::new));
+ }
+ }
+ outputHDF5File.makeStringArray(ANNOTATIONS_NAMES_PATH, sortedAnnotationNames.toArray(new String[0]));
+ HDF5Utils.writeChunkedDoubleMatrix(outputHDF5File, ANNOTATIONS_PATH,
+ streamFlattenedData().map(datum -> datum.annotations).toArray(double[][]::new), MAXIMUM_CHUNK_SIZE);
+ outputHDF5File.makeDoubleArray(LABELS_SNP_PATH,
+ streamFlattenedData().mapToDouble(datum -> datum.variantType == VariantType.SNP ? 1 : 0).toArray());
+ for (final String label : sortedLabels) {
+ outputHDF5File.makeDoubleArray(String.format("%s/%s", LABELS_PATH, label),
+ streamFlattenedData().mapToDouble(datum -> datum.labels.contains(label) ? 1 : 0).toArray());
+ }
+ } catch (final HDF5LibException exception) {
+ throw new GATKException(String.format("Exception encountered during writing of annotations and metadata (%s). Output file at %s may be in a bad state.",
+ exception, outputFile.getAbsolutePath()));
+ }
+ }
+
+ /**
+ * @return list of annotation names, with length given by the number of annotations, read from the specified file
+ */
+ public static List readAnnotationNames(final File annotationsFile) {
+ try (final HDF5File annotationsHDF5File = new HDF5File(annotationsFile, HDF5File.OpenMode.READ_ONLY)) {
+ IOUtils.canReadFile(annotationsHDF5File.getFile());
+ return Arrays.asList(annotationsHDF5File.readStringArray(ANNOTATIONS_NAMES_PATH));
+ } catch (final HDF5LibException exception) {
+ throw new GATKException(String.format("Exception encountered during reading of annotation names from %s: %s",
+ annotationsFile.getAbsolutePath(), exception));
+ }
+ }
+
+ /**
+ * @return matrix with dimensions (number of sites) x (number of annotations), read from the specified file
+ */
+ public static double[][] readAnnotations(final File annotationsFile) {
+ try (final HDF5File annotationsHDF5File = new HDF5File(annotationsFile, HDF5File.OpenMode.READ_ONLY)) {
+ IOUtils.canReadFile(annotationsHDF5File.getFile());
+ return HDF5Utils.readChunkedDoubleMatrix(annotationsHDF5File, ANNOTATIONS_PATH);
+ } catch (final HDF5LibException exception) {
+ throw new GATKException(String.format("Exception encountered during reading of annotations from %s: %s",
+ annotationsFile.getAbsolutePath(), exception));
+ }
+ }
+
+ /**
+ * @return list of boolean label indicators, with length given by the number of corresponding sites, read from the specified file;
+ * an element in the list will be true if the corresponding site is assigned to the specified label
+ */
+ public static List readLabel(final File annotationsFile,
+ final String label) {
+ try (final HDF5File annotationsHDF5File = new HDF5File(annotationsFile, HDF5File.OpenMode.READ_ONLY)) {
+ IOUtils.canReadFile(annotationsHDF5File.getFile());
+ return Arrays.stream(annotationsHDF5File.readDoubleArray(String.format("/labels/%s", label))).boxed().map(d -> d == 1).collect(Collectors.toList());
+ } catch (final HDF5LibException exception) {
+ throw new GATKException(String.format("Exception encountered during reading of label %s from %s: %s",
+ label, annotationsFile.getAbsolutePath(), exception));
+ }
+ }
+
+ /**
+ * Subsets annotation data according to a boolean filter and writes a limited representation to a temporary HDF5 file.
+ * Intended for passing annotations via the file interfaces of {@link VariantAnnotationsModel} and {@link VariantAnnotationsScorer}.
+ */
+ public static File subsetAnnotationsToTemporaryFile(final List annotationNames,
+ final double[][] allAnnotations,
+ final List isSubset) {
+ Utils.validateArg(annotationNames.size() > 0, "Number of annotation names must be positive.");
+ Utils.validateArg(allAnnotations.length > 0, "Number of annotation data points must be positive.");
+ Utils.validateArg(annotationNames.size() == allAnnotations[0].length,
+ "Number of annotation names must match number of features in annotation data.");
+ final double[][] subsetData = IntStream.range(0, isSubset.size()).boxed().filter(isSubset::get).map(i -> allAnnotations[i]).toArray(double[][]::new);
+ final File subsetAnnotationsFile = IOUtils.createTempFile("subset.annot", ".hdf5");
+ try (final HDF5File subsetAnnotationsHDF5File = new HDF5File(subsetAnnotationsFile, HDF5File.OpenMode.CREATE)) {
+ subsetAnnotationsHDF5File.makeStringArray(ANNOTATIONS_NAMES_PATH, annotationNames.toArray(new String[0]));
+ HDF5Utils.writeChunkedDoubleMatrix(subsetAnnotationsHDF5File, ANNOTATIONS_PATH, subsetData, MAXIMUM_CHUNK_SIZE);
+ } catch (final HDF5LibException exception) {
+ throw new GATKException(String.format("Exception encountered during writing of annotations (%s). Output file at %s may be in a bad state.",
+ exception, subsetAnnotationsFile.getAbsolutePath()));
+ }
+ return subsetAnnotationsFile;
+ }
+}
diff --git a/src/main/java/org/broadinstitute/hellbender/tools/walkers/vqsr/scalable/data/LabeledVariantAnnotationsDatum.java b/src/main/java/org/broadinstitute/hellbender/tools/walkers/vqsr/scalable/data/LabeledVariantAnnotationsDatum.java
new file mode 100644
index 00000000000..884529f5c56
--- /dev/null
+++ b/src/main/java/org/broadinstitute/hellbender/tools/walkers/vqsr/scalable/data/LabeledVariantAnnotationsDatum.java
@@ -0,0 +1,104 @@
+package org.broadinstitute.hellbender.tools.walkers.vqsr.scalable.data;
+
+import com.google.common.collect.ImmutableList;
+import com.google.common.collect.ImmutableSet;
+import htsjdk.samtools.util.Locatable;
+import htsjdk.variant.variantcontext.Allele;
+import htsjdk.variant.variantcontext.VariantContext;
+import org.broadinstitute.hellbender.exceptions.UserException;
+import org.broadinstitute.hellbender.tools.walkers.vqsr.scalable.LabeledVariantAnnotationsWalker;
+import org.broadinstitute.hellbender.utils.SimpleInterval;
+import org.broadinstitute.hellbender.utils.Utils;
+import org.broadinstitute.hellbender.utils.variant.GATKVCFConstants;
+
+import java.util.List;
+import java.util.TreeSet;
+
+/**
+ * Represents metadata and annotations extracted from either a variant or a single alt allele (if in allele-specific mode).
+ * Intended to be package-private and accessed only by {@link LabeledVariantAnnotationsData}.
+ */
+final class LabeledVariantAnnotationsDatum implements Locatable {
+ final SimpleInterval interval;
+ final Allele refAllele;
+ final ImmutableList altAlleles; // in allele-specific mode, this contains a single alt allele; otherwise, it contains all alt alleles that passed variant-type checks
+ final VariantType variantType;
+ final ImmutableSet labels; // sorted TreeSet
+ final double[] annotations; // TODO use ImmutableDoubleArray?
+
+ LabeledVariantAnnotationsDatum(final VariantContext vc,
+ final List altAlleles,
+ final VariantType variantType,
+ final TreeSet labels,
+ final List sortedAnnotationNames,
+ final boolean useASAnnotations) {
+ Utils.validate(!useASAnnotations || altAlleles.size() == 1,
+ "Datum should only be associated with one alt allele in allele-specific mode.");
+ this.interval = new SimpleInterval(vc);
+ this.refAllele = vc.getReference();
+ this.altAlleles = ImmutableList.copyOf(altAlleles);
+ this.variantType = variantType;
+ this.labels = ImmutableSet.copyOf(labels);
+ this.annotations = sortedAnnotationNames.stream()
+ .mapToDouble(a -> decodeAnnotation(vc, altAlleles, a, useASAnnotations))
+ .toArray();
+ }
+
+ @Override
+ public String getContig() {
+ return interval.getContig();
+ }
+
+ @Override
+ public int getStart() {
+ return interval.getStart();
+ }
+
+ @Override
+ public int getEnd() {
+ return interval.getEnd();
+ }
+
+ // code mostly retained from VQSR; some exception catching added
+ private static double decodeAnnotation(final VariantContext vc,
+ final List altAlleles,
+ final String annotationName,
+ final boolean useASAnnotations) {
+ double value;
+ try {
+ // if we're in allele-specific mode and an allele-specific annotation has been requested, parse the appropriate value from the list
+ // TODO: can we trigger allele-specific parsing based on annotation prefix or some other logic?
+ if (useASAnnotations && annotationName.startsWith(GATKVCFConstants.ALLELE_SPECIFIC_PREFIX)) {
+ final List