presorted avro files, fix performance issue #7635

Merged · 3 commits · Jan 12, 2022
Changes from 2 commits
2 changes: 1 addition & 1 deletion scripts/variantstore/wdl/GvsCreateFilterSet.wdl
@@ -24,7 +24,7 @@ workflow GvsCreateFilterSet {
File? excluded_intervals

String output_file_base_name
File? gatk_override = "gs://broad-dsp-spec-ops/scratch/bigquery-jointcalling/jars/kc_fix_flush_20220105/gatk-package-4.2.0.0-452-gb9496ed-SNAPSHOT-local.jar"
File? gatk_override = "gs://broad-dsp-spec-ops/scratch/bigquery-jointcalling/jars/kc_extract_perf_20220111/gatk-package-4.2.0.0-455-g40a40bc-SNAPSHOT-local.jar"

File dbsnp_vcf
File dbsnp_vcf_index
2 changes: 1 addition & 1 deletion scripts/variantstore/wdl/GvsExtractCallset.wdl
@@ -46,7 +46,7 @@ workflow GvsExtractCallset {

String output_file_base_name
String? output_gcs_dir
File? gatk_override = "gs://broad-dsp-spec-ops/scratch/bigquery-jointcalling/jars/kc_fix_flush_20220105/gatk-package-4.2.0.0-452-gb9496ed-SNAPSHOT-local.jar"
File? gatk_override = "gs://broad-dsp-spec-ops/scratch/bigquery-jointcalling/jars/kc_extract_perf_20220111/gatk-package-4.2.0.0-455-g40a40bc-SNAPSHOT-local.jar"
Int local_disk_for_extract = 150

String fq_samples_to_extract_table = "~{data_project}.~{default_dataset}.~{extract_table_prefix}__SAMPLES"
2 changes: 1 addition & 1 deletion scripts/variantstore/wdl/GvsImportGenomes.wdl
@@ -17,7 +17,7 @@ workflow GvsImportGenomes {
Int batch_size = 1

Int? preemptible_tries
File? gatk_override = "gs://broad-dsp-spec-ops/scratch/bigquery-jointcalling/jars/kc_fix_flush_20220105/gatk-package-4.2.0.0-452-gb9496ed-SNAPSHOT-local.jar"
File? gatk_override = "gs://broad-dsp-spec-ops/scratch/bigquery-jointcalling/jars/kc_extract_perf_20220111/gatk-package-4.2.0.0-455-g40a40bc-SNAPSHOT-local.jar"
String? docker
}

@@ -88,20 +88,28 @@ public enum VQSLODFilteringType { GENOTYPE, SITES, NONE }

@Argument(
fullName = "vet-avro-file-name",
doc = "Path to unsorted data from Vet table in Avro format",
doc = "Path to data from Vet table in Avro format",
mutex = {"cohort-extract-table"},
optional = true
)
private GATKPath vetAvroFileName = null;

@Argument(
fullName = "ref-ranges-avro-file-name",
doc = "Path to unsorted data from Vet table in Avro format",
doc = "Path to data from Vet table in Avro format",
mutex = {"cohort-extract-table"},
optional = true
)
private GATKPath refRangesAvroFileName = null;

@Argument(
fullName = "presorted-avro-files",
doc = "Indicates if Avro data is pre-sorted",
mutex = {"cohort-extract-table"},
optional = true
)
private boolean presortedAvroFiles = false;

@Argument(
fullName = "filter-set-name",
doc = "Name in filter_set_name column of filtering table to use. Which training set should be applied in extract.",
@@ -318,7 +326,8 @@ protected void onStartup() {
emitPLs,
vqslodfilteringType,
excludeFilteredSites,
inferredReferenceState);
inferredReferenceState,
presortedAvroFiles);

vcfWriter.writeHeader(header);
}
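Taken together, the changes above add a --presorted-avro-files switch (default false) and thread it through to the engine. The essence of what the flag enables is easiest to see in miniature: when the caller promises the input is already sorted, the record stream can be passed through untouched, and the local sort is skipped entirely. The sketch below is an illustrative distillation of that pattern, not GATK code; every name in it is hypothetical.

import java.util.ArrayList;
import java.util.Comparator;
import java.util.List;

public class MaybeSortDemo {
    // Mirrors the PR's core idea: hand the input through unchanged when it is
    // presorted; otherwise buffer and sort it locally first. Downstream code
    // only sees an Iterable either way, just like the engine change below.
    static <T> Iterable<T> maybeSort(Iterable<T> input, boolean presorted, Comparator<T> comparator) {
        if (presorted) {
            return input; // no buffering, no sort -- the performance win
        }
        List<T> buffered = new ArrayList<>();
        input.forEach(buffered::add);
        buffered.sort(comparator);
        return buffered;
    }

    public static void main(String[] args) {
        List<String> rows = List.of("chr1:100", "chr1:050", "chr1:200");
        for (String row : maybeSort(rows, false, Comparator.naturalOrder())) {
            System.out.println(row); // chr1:050, chr1:100, chr1:200
        }
    }
}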
@@ -1,9 +1,8 @@
package org.broadinstitute.hellbender.tools.gvs.extract;

import com.google.cloud.bigquery.storage.v1.ReadSession;
import com.google.common.collect.Sets;
import static java.util.stream.Collectors.toList;
import htsjdk.samtools.util.CloseableIterator;

import htsjdk.samtools.util.OverlapDetector;
import htsjdk.variant.variantcontext.Allele;
import htsjdk.variant.variantcontext.GenotypeBuilder;
@@ -86,7 +85,7 @@ public class ExtractCohortEngine {
private final String filterSetName;

private final GQStateEnum inferredReferenceState;

private final boolean presortedAvroFiles;

public ExtractCohortEngine(final String projectID,
final VariantContextWriter vcfWriter,
@@ -114,7 +113,8 @@ public ExtractCohortEngine(final String projectID,
final boolean emitPLs,
final ExtractCohort.VQSLODFilteringType VQSLODFilteringType,
final boolean excludeFilteredSites,
final GQStateEnum inferredReferenceState
final GQStateEnum inferredReferenceState,
final boolean presortedAvroFiles
) {
this.localSortMaxRecordsInRam = localSortMaxRecordsInRam;

@@ -156,6 +156,8 @@ public ExtractCohortEngine(final String projectID,
this.variantContextMerger = new ReferenceConfidenceVariantContextMerger(annotationEngine, vcfHeader);

this.inferredReferenceState = inferredReferenceState;

this.presortedAvroFiles = presortedAvroFiles;
}

int getTotalNumberOfVariants() { return totalNumberOfVariants; }
@@ -230,10 +232,12 @@ public void traverse() {
throw new GATKException("Can not process cross-contig boundaries for Ranges implementation");
}

SortedSet<Long> sampleIdsToExtract = new TreeSet<>(this.sampleIdToName.keySet());

if (vetRangesFQDataSet != null) {
createVariantsFromUnsortedBigQueryRanges(vetRangesFQDataSet, this.sampleIdToName.keySet(), minLocation, maxLocation, fullVqsLodMap, fullYngMap, siteFilterMap, noVqslodFilteringRequested);
createVariantsFromUnsortedBigQueryRanges(vetRangesFQDataSet, sampleIdsToExtract, minLocation, maxLocation, fullVqsLodMap, fullYngMap, siteFilterMap, noVqslodFilteringRequested);
} else {
createVariantsFromUnsortedAvroRanges(vetAvroFileName, refRangesAvroFileName, this.sampleIdToName.keySet(), minLocation, maxLocation, fullVqsLodMap, fullYngMap, siteFilterMap, noVqslodFilteringRequested);
createVariantsFromUnsortedAvroRanges(vetAvroFileName, refRangesAvroFileName, sampleIdsToExtract, minLocation, maxLocation, fullVqsLodMap, fullYngMap, siteFilterMap, noVqslodFilteringRequested, presortedAvroFiles);
}
} else {
if (cohortTableRef != null) {
@@ -296,14 +300,17 @@ private SortingCollection<GenericRecord> addToVetSortingCollection(final Sorting
if (intervalsOverlapDetector.overlapsAny(simpleInverval)) {
vbs.setVariant(location);
sortingCollection.add(queryRow);
if (recordsProcessed++ % 1000000 == 0) {
if (++recordsProcessed % 1000000 == 0) {
long endTime = System.currentTimeMillis();
logger.info("Processed " + recordsProcessed + " VET records in " + (endTime - startTime) + " ms");
startTime = endTime;
}
}
}

long endTime = System.currentTimeMillis();
logger.info("Processed " + recordsProcessed + " VET records in " + (endTime - startTime) + " ms");

sortingCollection.printTempFileStats();
return sortingCollection;
}
@@ -320,13 +327,16 @@ private SortingCollection<GenericRecord> addToRefSortingCollection(final Sorting
sortingCollection.add(queryRow);
}

if (recordsProcessed++ % 1000000 == 0) {
if (++recordsProcessed % 1000000 == 0) {
long endTime = System.currentTimeMillis();
logger.info("Processed " + recordsProcessed + " Reference Ranges records in " + (endTime - startTime) + " ms");
startTime = endTime;
}
}

long endTime = System.currentTimeMillis();
logger.info("Processed " + recordsProcessed + " Reference Ranges records in " + (endTime - startTime) + " ms");

sortingCollection.printTempFileStats();
return sortingCollection;
}
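Both logging hunks above swap recordsProcessed++ for ++recordsProcessed. With the post-increment, the counter is still 0 when the first record is tested, so 0 % 1000000 == 0 fires immediately and logs a misleading "Processed 1 ... records" line; the pre-increment fires exactly at each million, and the new unconditional log after the loop reports the final total. A self-contained illustration (the loop bound and message are illustrative):

public class IncrementLogDemo {
    public static void main(String[] args) {
        long recordsProcessed = 0;
        for (int i = 0; i < 2_500_000; i++) {
            // Old form: if (recordsProcessed++ % 1_000_000 == 0) -- fires on the
            // very first record, because the counter is tested before incrementing.
            if (++recordsProcessed % 1_000_000 == 0) {
                System.out.println("Processed " + recordsProcessed + " records");
            }
        }
        // The PR also adds a final log outside the loop, so totals that are not
        // an exact multiple of a million (here 2,500,000) still get reported.
        System.out.println("Processed " + recordsProcessed + " records (final)");
    }
}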
@@ -880,7 +890,7 @@ private SortingCollection<GenericRecord> createSortedReferenceRangeCollectionFro

private void createVariantsFromUnsortedBigQueryRanges(
final String fqDatasetName,
final Set<Long> sampleIdsToExtract,
final SortedSet<Long> sampleIdsToExtract,
final Long minLocation,
final Long maxLocation,
final HashMap<Long, HashMap<Allele, HashMap<Allele, Double>>> fullVqsLodMap,
@@ -918,31 +928,44 @@ private void createVariantsFromUnsortedBigQueryRanges(
private void createVariantsFromUnsortedAvroRanges(
final GATKPath vetAvroFileName,
final GATKPath refRangesAvroFileName,
final Set<Long> sampleIdsToExtract,
final SortedSet<Long> sampleIdsToExtract,
final Long minLocation,
final Long maxLocation,
final HashMap<Long, HashMap<Allele, HashMap<Allele, Double>>> fullVqsLodMap,
final HashMap<Long, HashMap<Allele, HashMap<Allele, String>>> fullYngMap,
final HashMap<Long, List<String>> siteFilterMap,
final boolean noVqslodFilteringRequested) {
final boolean noVqslodFilteringRequested,
final boolean presortedAvroFiles) {

final AvroFileReader vetReader = new AvroFileReader(vetAvroFileName);
final AvroFileReader refRangesReader = new AvroFileReader(refRangesAvroFileName);

VariantBitSet vbs = new VariantBitSet(minLocation, maxLocation);
Iterable<GenericRecord> sortedVet;
Iterable<GenericRecord> sortedReferenceRange;

if (presortedAvroFiles) {
sortedVet = vetReader;
sortedReferenceRange = refRangesReader;
} else {
VariantBitSet vbs = new VariantBitSet(minLocation, maxLocation);

SortingCollection<GenericRecord> sortedVet = getAvroSortingCollection(vetReader.getSchema(), localSortMaxRecordsInRam);
addToVetSortingCollection(sortedVet, vetReader, vbs);
SortingCollection<GenericRecord> localSortedVet = getAvroSortingCollection(vetReader.getSchema(), localSortMaxRecordsInRam);
addToVetSortingCollection(localSortedVet, vetReader, vbs);

SortingCollection<GenericRecord> sortedReferenceRange = getAvroSortingCollection(refRangesReader.getSchema(), localSortMaxRecordsInRam);
addToRefSortingCollection(sortedReferenceRange, refRangesReader, vbs);
SortingCollection<GenericRecord> localSortedReferenceRange = getAvroSortingCollection(refRangesReader.getSchema(), localSortMaxRecordsInRam);
addToRefSortingCollection(localSortedReferenceRange, refRangesReader, vbs);

sortedVet = localSortedVet;
sortedReferenceRange = localSortedReferenceRange;
}

createVariantsFromSortedRanges(sampleIdsToExtract, sortedVet, sortedReferenceRange, fullVqsLodMap, fullYngMap, siteFilterMap, noVqslodFilteringRequested);

}

private void createVariantsFromSortedRanges(final Set<Long> sampleIdsToExtract,
final SortingCollection<GenericRecord> sortedVet,
SortingCollection<GenericRecord> sortedReferenceRange,
private void createVariantsFromSortedRanges(final SortedSet<Long> sampleIdsToExtract,
final Iterable<GenericRecord> sortedVet,
Iterable<GenericRecord> sortedReferenceRange,
final HashMap<Long, HashMap<Allele, HashMap<Allele, Double>>> fullVqsLodMap,
final HashMap<Long, HashMap<Allele, HashMap<Allele, String>>> fullYngMap,
final HashMap<Long, List<String>> siteFilterMap,
@@ -971,7 +994,7 @@ private void createVariantsFromSortedRanges(final Set<Long> sampleIdsToExtract,
// NOTE: if OverlapDetector takes too long, try using RegionChecker from tws_sv_local_assembler
final OverlapDetector<SimpleInterval> intervalsOverlapDetector = OverlapDetector.create(traversalIntervals);

CloseableIterator<GenericRecord> sortedReferenceRangeIterator = sortedReferenceRange.iterator();
Iterator<GenericRecord> sortedReferenceRangeIterator = sortedReferenceRange.iterator();

for (final GenericRecord sortedRow : sortedVet) {
final ExtractCohortRecord vetRow = new ExtractCohortRecord(sortedRow);
@@ -1049,16 +1072,21 @@ private void handlePotentialSpanningDeletion(ExtractCohortRecord vetRow, Map<Lon
}
}

private void processReferenceData(Map<Long, ExtractCohortRecord> currentPositionRecords, CloseableIterator<GenericRecord> sortedReferenceRangeIterator, Map<Long, TreeSet<ReferenceRecord>> referenceCache, long location, long fromSampleId, long toSampleId, Set<Long> sampleIdsToExtract) {
private void processReferenceData(Map<Long, ExtractCohortRecord> currentPositionRecords, Iterator<GenericRecord> sortedReferenceRangeIterator, Map<Long, TreeSet<ReferenceRecord>> referenceCache, long location, long fromSampleId, long toSampleId, SortedSet<Long> sampleIdsToExtract) {
// in the case where there are two adjacent samples with variants, this method is called where from is greater than to
// this is ok, there is just no reference data to process but subSet will throw an exception so we handle it with this if block
if (toSampleId >= fromSampleId) {
SortedSet<Long> samples = sampleIdsToExtract.subSet(fromSampleId, toSampleId + 1); // subset is start-inclusive, end-exclusive

List<Long> samples = sampleIdsToExtract.stream().filter(x -> x >= fromSampleId && x <= toSampleId).sorted().collect(toList());
for(Long s : samples) {
ExtractCohortRecord e = processReferenceData(sortedReferenceRangeIterator, referenceCache, location, s);
currentPositionRecords.merge(s, e, this::mergeSampleRecord);
// List<Long> samples = sampleIdsToExtract.stream().filter(x -> x >= fromSampleId && x <= toSampleId).sorted().collect(toList());
[Review comment on the commented-out line above] nit: should this line just be deleted?

for (Long s : samples) {
ExtractCohortRecord e = processReferenceData(sortedReferenceRangeIterator, referenceCache, location, s);
currentPositionRecords.merge(s, e, this::mergeSampleRecord);
}
}
}

private ExtractCohortRecord processReferenceData(CloseableIterator<GenericRecord> sortedReferenceRangeIterator, Map<Long, TreeSet<ReferenceRecord>> referenceCache, long location, long sampleId) {
private ExtractCohortRecord processReferenceData(Iterator<GenericRecord> sortedReferenceRangeIterator, Map<Long, TreeSet<ReferenceRecord>> referenceCache, long location, long sampleId) {
String state = processReferenceDataFromCache(referenceCache, location, sampleId);

if (state == null) {
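The hunk above replaces a per-call stream pipeline, which filtered and re-sorted every extract sample ID, with a subSet range query; this works because sampleIdsToExtract is now a SortedSet built from a TreeSet in traverse(). A minimal demonstration of the semantics, including the start-inclusive/end-exclusive convention the in-line comment mentions (class and variable names here are illustrative):

import java.util.SortedSet;
import java.util.TreeSet;

public class SubSetDemo {
    public static void main(String[] args) {
        SortedSet<Long> sampleIdsToExtract = new TreeSet<>();
        for (long id = 1; id <= 10; id++) {
            sampleIdsToExtract.add(id);
        }

        long fromSampleId = 4L;
        long toSampleId = 7L;
        // subSet is start-inclusive and end-exclusive -- hence the +1 in the PR.
        // It returns a live O(log n) view backed by the tree: no copy, no re-sort,
        // unlike the old stream().filter(...).sorted().collect(toList()) pass.
        // (subSet throws IllegalArgumentException when from > to, which is why
        // the toSampleId >= fromSampleId guard is kept.)
        SortedSet<Long> samples = sampleIdsToExtract.subSet(fromSampleId, toSampleId + 1);
        System.out.println(samples); // prints [4, 5, 6, 7]
    }
}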
@@ -1101,7 +1129,7 @@ private String processReferenceDataFromCache(Map<Long, TreeSet<ReferenceRecord>>
}
}

private String processReferenceDataFromStream(CloseableIterator<GenericRecord> sortedReferenceRangeIterator, Map<Long, TreeSet<ReferenceRecord>> referenceCache, long location, long sampleId) {
private String processReferenceDataFromStream(Iterator<GenericRecord> sortedReferenceRangeIterator, Map<Long, TreeSet<ReferenceRecord>> referenceCache, long location, long sampleId) {
while(sortedReferenceRangeIterator.hasNext()) {
final ReferenceRecord refRow = new ReferenceRecord(sortedReferenceRangeIterator.next());
totalRangeRecords++;
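Widening the shared parameter types in ExtractCohortEngine from SortingCollection<GenericRecord> to Iterable<GenericRecord> (and CloseableIterator to plain Iterator) is what lets a raw Avro reader stand in for the sorted collection: Avro data files replay records in the order they were written. A hedged, self-contained sketch using Apache Avro's public DataFileReader -- GATK's AvroFileReader is assumed here to wrap something similar:

import java.io.File;
import java.io.IOException;

import org.apache.avro.file.DataFileReader;
import org.apache.avro.generic.GenericDatumReader;
import org.apache.avro.generic.GenericRecord;

public class AvroFileOrderScan {
    public static void main(String[] args) throws IOException {
        // DataFileReader implements Iterable<GenericRecord>, so a presorted
        // export can feed createVariantsFromSortedRanges-style code directly.
        try (DataFileReader<GenericRecord> reader =
                     new DataFileReader<>(new File(args[0]), new GenericDatumReader<>())) {
            long n = 0;
            for (GenericRecord record : reader) {
                n++; // records arrive in file order; if the export was sorted by
                     // location and sample_id, no local re-sort is needed
            }
            System.out.println("Read " + n + " records in file order");
        }
    }
}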
@@ -226,8 +226,11 @@ public void setDestructiveIteration(boolean destructiveIteration) {
*/
public void spillToDisk() {
try {
long startTime = System.currentTimeMillis();
Arrays.parallelSort(this.ramRecords, 0, this.numRecordsInRam, this.comparator);
log.info(String.format("%d records in ram sorted in %d ms. ", numRecordsInRam, System.currentTimeMillis() - startTime));

startTime = System.currentTimeMillis();
final Path f = newTempFile();
try (OutputStream os
= tempStreamFactory.wrapTempOutputStream(Files.newOutputStream(f), Defaults.BUFFER_SIZE)) {
@@ -242,6 +245,7 @@ public void spillToDisk() {
throw new RuntimeIOException("Problem writing temporary file " + f.toUri() +
". Try setting TMP_DIR to a file system with lots of space.", ex);
}
log.info(String.format("%d records in ram spilled to disk in %d ms. ", numRecordsInRam, System.currentTimeMillis() - startTime));

this.numRecordsInRam = 0;
this.files.add(f);
Expand Down Expand Up @@ -464,10 +468,13 @@ class InMemoryIterator implements CloseableIterator<T> {
private int iterationIndex = 0;

InMemoryIterator() {
long startTime = System.currentTimeMillis();
Arrays.parallelSort(SortingCollection.this.ramRecords,
0,
SortingCollection.this.numRecordsInRam,
SortingCollection.this.comparator);
log.info(String.format("%d records in ram sorted in %d ms. ", numRecordsInRam, System.currentTimeMillis() - startTime));

}

@Override
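The SortingCollection changes are purely observability: they time the in-RAM Arrays.parallelSort separately from the spill write, so a slow sort and slow disk I/O show up as distinct log lines. A minimal, self-contained version of the pattern follows; the array contents, sizes, and messages are illustrative, and the temp-file write is elided:

import java.util.Arrays;
import java.util.Random;

public class SpillTimingDemo {
    public static void main(String[] args) {
        long[] ramRecords = new Random(42).longs(5_000_000).toArray();

        // First timed phase: the in-memory parallel sort.
        long startTime = System.currentTimeMillis();
        Arrays.parallelSort(ramRecords);
        System.out.printf("%d records in ram sorted in %d ms.%n",
                ramRecords.length, System.currentTimeMillis() - startTime);

        // Second timed phase: the diff resets startTime here and logs the disk
        // write separately, so sort cost and I/O cost never blur together.
        startTime = System.currentTimeMillis();
        // ... write the sorted records to a temp file here (elided) ...
        System.out.printf("%d records in ram spilled to disk in %d ms.%n",
                ramRecords.length, System.currentTimeMillis() - startTime);
    }
}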
@@ -47,7 +47,8 @@ public void testRemoveAnnotations() {
false,
ExtractCohort.VQSLODFilteringType.NONE,
false,
GQStateEnum.SIXTY
GQStateEnum.SIXTY,
false
);

List<VariantContext> variantContexts = VariantContextTestUtils.getVariantContexts(ORIGINAL_TEST_FILE); // list variantContexts from VCF file