diff --git a/src/main/java/org/broadinstitute/hellbender/tools/variantdb/ArrayExtractCohort.java b/src/main/java/org/broadinstitute/hellbender/tools/variantdb/ArrayExtractCohort.java new file mode 100644 index 00000000000..d4ac5343a4d --- /dev/null +++ b/src/main/java/org/broadinstitute/hellbender/tools/variantdb/ArrayExtractCohort.java @@ -0,0 +1,202 @@ +package org.broadinstitute.hellbender.tools.variantdb; + +import htsjdk.variant.variantcontext.writer.VariantContextWriter; +import htsjdk.variant.vcf.VCFHeader; +import org.apache.logging.log4j.LogManager; +import org.apache.logging.log4j.Logger; +import org.broadinstitute.barclay.argparser.Argument; +import org.broadinstitute.barclay.argparser.CommandLineProgramProperties; +import org.broadinstitute.barclay.help.DocumentedFeature; +import org.broadinstitute.hellbender.cmdline.StandardArgumentDefinitions; +import org.broadinstitute.hellbender.cmdline.programgroups.ShortVariantDiscoveryProgramGroup; +import org.broadinstitute.hellbender.engine.GATKTool; +import org.broadinstitute.hellbender.tools.walkers.annotator.Annotation; +import org.broadinstitute.hellbender.tools.walkers.annotator.StandardAnnotation; +import org.broadinstitute.hellbender.tools.walkers.annotator.VariantAnnotatorEngine; +import org.broadinstitute.hellbender.tools.walkers.annotator.allelespecific.AS_StandardAnnotation; +import org.broadinstitute.hellbender.utils.bigquery.TableReference; +import org.broadinstitute.hellbender.utils.io.IOUtils; + +import java.io.BufferedReader; +import java.io.FileReader; +import java.io.IOException; +import java.util.*; + + +@CommandLineProgramProperties( + summary = "(\"ExtractCohort\") - Filter and extract arrayvariants out of big query.", + oneLineSummary = "Tool to extract variants out of big query for a subset of samples", + programGroup = ShortVariantDiscoveryProgramGroup.class +) +@DocumentedFeature +public class ArrayExtractCohort extends GATKTool { + private static final Logger logger = LogManager.getLogger(ExtractCohort.class); + public static final int DEFAULT_LOCAL_SORT_MAX_RECORDS_IN_RAM = 1000000; + private VariantContextWriter vcfWriter = null; + private ArrayExtractCohortEngine engine; + + public enum QueryMode { + LOCAL_SORT, + QUERY + } + + @Argument( + shortName = StandardArgumentDefinitions.OUTPUT_SHORT_NAME, + fullName = StandardArgumentDefinitions.OUTPUT_LONG_NAME, + doc = "Output VCF file to which annotated variants should be written.", + optional = false + ) + private String outputVcfPathString = null; + + @Argument( + fullName = "project-id", + doc = "ID of the Google Cloud project to use when executing queries", + optional = false + ) + private String projectID = null; + + @Argument( + fullName = "sample-info-table", + doc = "Fully qualified name of a bigquery table containing a single column `sample` that describes the full list of samples to evoque", + optional = true + ) + private String sampleTableName = null; + + @Argument( + fullName = "probe-info-table", + doc = "Fully qualified name of a bigquery table containing probe information", + optional = true + ) + private String probeTableName = null; + + @Argument( + fullName = "probe-info-csv", + doc = "Filepath to CSV export of probe-info table", + optional = true +) + private String probeCsvExportFile = null; + + @Argument( + fullName = "cohort-extract-table", + doc = "Fully qualified name of the table where the cohort data exists (already subsetted)", + optional = false + ) + private String cohortTable = null; + + @Argument( + fullName = "print-debug-information", + doc = "If true, print extra debugging output", + optional = true) + private boolean printDebugInformation = false; + + @Argument( + fullName = "local-sort-max-records-in-ram", + doc = "When doing local sort, store at most this many records in memory at once", + optional = true + ) + private int localSortMaxRecordsInRam = DEFAULT_LOCAL_SORT_MAX_RECORDS_IN_RAM; + + @Override + public boolean requiresReference() { + return true; + } + + @Override + public boolean useVariantAnnotations() { return true; } + + @Override + public List> getDefaultVariantAnnotationGroups() { + return Arrays.asList( + StandardAnnotation.class, AS_StandardAnnotation.class + ); + } + + @Override + protected void onStartup() { + super.onStartup(); + + //TODO verify what we really need here + final VariantAnnotatorEngine annotationEngine = new VariantAnnotatorEngine(makeVariantAnnotations(), null, Collections.emptyList(), false, false); + + vcfWriter = createVCFWriter(IOUtils.getPath(outputVcfPathString)); + + Map sampleIdMap = ExtractCohortBQ.getSampleIdMap(sampleTableName, printDebugInformation); + + Collection sampleNames = sampleIdMap.values(); + VCFHeader header = CommonCode.generateRawArrayVcfHeader(new HashSet<>(sampleNames), reference.getSequenceDictionary()); + + Map probeIdMap; + + if (probeCsvExportFile == null) { + probeIdMap = ExtractCohortBQ.getProbeIdMap(probeTableName, printDebugInformation); + } else { + probeIdMap = new HashMap<>(); + String line = ""; + try (BufferedReader br = new BufferedReader(new FileReader(probeCsvExportFile))) { + /// skip the header + br.readLine(); + + while ((line = br.readLine()) != null) { + + // use comma as separator + String[] fields = line.split(","); + //ProbeId,Name,GenomeBuild,Chr,Position,Ref,AlleleA,AlleleB,build37Flag + //6,ilmnseq_rs9651229_F2BT,37,1,567667,,,,PROBE_SEQUENCE_MISMATCH + ProbeInfo p = new ProbeInfo(Long.parseLong(fields[0]), + fields[1], // name + fields[3], // contig + Long.parseLong(fields[4]), // position + fields[5], // ref + fields[6], // alleleA + fields[7]);// alleleB + + probeIdMap.put(p.probeId, p); + } + } catch (IOException e) { + e.printStackTrace(); + } + } + + + + //ChromosomeEnum.setRefVersion(refVersion); + + engine = new ArrayExtractCohortEngine( + projectID, + vcfWriter, + header, + annotationEngine, + reference, + sampleIdMap, + probeIdMap, + cohortTable, + localSortMaxRecordsInRam, + false, + printDebugInformation, + progressMeter); + vcfWriter.writeHeader(header); + } + + @Override + // maybe think about creating a BigQuery Row walker? + public void traverse() { + progressMeter.setRecordsBetweenTimeChecks(100L); + engine.traverse(); + } + + @Override + protected void onShutdown() { + super.onShutdown(); + + if ( engine != null ) { + logger.info(String.format("***Processed %d total sites", engine.getTotalNumberOfSites())); + logger.info(String.format("***Processed %d total variants", engine.getTotalNumberOfVariants())); + } + + // Close up our writer if we have to: + if ( vcfWriter != null ) { + vcfWriter.close(); + } + } + +} diff --git a/src/main/java/org/broadinstitute/hellbender/tools/variantdb/ArrayExtractCohortEngine.java b/src/main/java/org/broadinstitute/hellbender/tools/variantdb/ArrayExtractCohortEngine.java new file mode 100644 index 00000000000..e2d45e49e34 --- /dev/null +++ b/src/main/java/org/broadinstitute/hellbender/tools/variantdb/ArrayExtractCohortEngine.java @@ -0,0 +1,372 @@ +package org.broadinstitute.hellbender.tools.variantdb; + +import htsjdk.variant.variantcontext.Allele; +import htsjdk.variant.variantcontext.GenotypeBuilder; +import htsjdk.variant.variantcontext.VariantContext; +import htsjdk.variant.variantcontext.VariantContextBuilder; +import htsjdk.variant.variantcontext.writer.VariantContextWriter; +import htsjdk.variant.vcf.VCFHeader; +import org.apache.avro.generic.GenericRecord; +import org.apache.logging.log4j.LogManager; +import org.apache.logging.log4j.Logger; +import org.broadinstitute.hellbender.engine.ProgressMeter; +import org.broadinstitute.hellbender.engine.ReferenceDataSource; +import org.broadinstitute.hellbender.tools.variantdb.RawArrayData.ArrayGenotype; +import org.broadinstitute.hellbender.tools.walkers.ReferenceConfidenceVariantContextMerger; +import org.broadinstitute.hellbender.tools.walkers.annotator.VariantAnnotatorEngine; +import org.broadinstitute.hellbender.utils.SimpleInterval; +import org.broadinstitute.hellbender.utils.bigquery.*; +import org.broadinstitute.hellbender.utils.localsort.SortingCollection; + +import java.text.DecimalFormat; +import java.util.*; +import static org.broadinstitute.hellbender.tools.variantdb.ExtractCohortBQ.*; + + +public class ArrayExtractCohortEngine { + private final DecimalFormat df = new DecimalFormat(); + private final String DOT = "."; + + private static final Logger logger = LogManager.getLogger(ArrayExtractCohortEngine.class); + + private final VariantContextWriter vcfWriter; + + private final boolean useCompressedData; + private final boolean printDebugInformation; + private final int localSortMaxRecordsInRam; + private final TableReference cohortTableRef; + private final ReferenceDataSource refSource; + + private final ProgressMeter progressMeter; + private final String projectID; + + /** List of sample names seen in the variant data from BigQuery. */ + private final Map sampleIdMap; + private final Set sampleNames; + + private final Map probeIdMap; + private final ReferenceConfidenceVariantContextMerger variantContextMerger; + + private int totalNumberOfVariants = 0; + private int totalNumberOfSites = 0; + + public ArrayExtractCohortEngine(final String projectID, + final VariantContextWriter vcfWriter, + final VCFHeader vcfHeader, + final VariantAnnotatorEngine annotationEngine, + final ReferenceDataSource refSource, + final Map sampleIdMap, + final Map probeIdMap, + final String cohortTableName, + final int localSortMaxRecordsInRam, + final boolean useCompressedData, + final boolean printDebugInformation, + final ProgressMeter progressMeter) { + + this.df.setMaximumFractionDigits(3); + this.df.setGroupingSize(0); + + this.localSortMaxRecordsInRam = localSortMaxRecordsInRam; + + this.projectID = projectID; + this.vcfWriter = vcfWriter; + this.refSource = refSource; + this.sampleIdMap = sampleIdMap; + this.sampleNames = new HashSet<>(sampleIdMap.values()); + + this.probeIdMap = probeIdMap; + + this.cohortTableRef = new TableReference(cohortTableName, useCompressedData?SchemaUtils.RAW_ARRAY_COHORT_FIELDS_COMPRESSED:SchemaUtils.RAW_ARRAY_COHORT_FIELDS_UNCOMPRESSED); + + this.useCompressedData = useCompressedData; + this.printDebugInformation = printDebugInformation; + this.progressMeter = progressMeter; + + // TODO: what is the right variant context merger for arrays? + this.variantContextMerger = new ReferenceConfidenceVariantContextMerger(annotationEngine, vcfHeader); + + } + + int getTotalNumberOfVariants() { return totalNumberOfVariants; } + int getTotalNumberOfSites() { return totalNumberOfSites; } + + public void traverse() { + if (printDebugInformation) { + logger.debug("using storage api with local sort"); + } + final StorageAPIAvroReader storageAPIAvroReader = new StorageAPIAvroReader(cohortTableRef); + createVariantsFromUngroupedTableResult(storageAPIAvroReader); + } + + + private void createVariantsFromUngroupedTableResult(final GATKAvroReader avroReader) { + + // stream out the data and sort locally + final org.apache.avro.Schema schema = avroReader.getSchema(); + final Set columnNames = new HashSet<>(); + schema.getFields().forEach(field -> columnNames.add(field.name())); + + Comparator comparator = this.useCompressedData ? COMPRESSED_PROBE_ID_COMPARATOR : UNCOMPRESSED_PROBE_ID_COMPARATOR; + + SortingCollection sortingCollection = getAvroProbeIdSortingCollection(schema, localSortMaxRecordsInRam, comparator); + for ( final GenericRecord queryRow : avroReader ) { + sortingCollection.add(queryRow); + } + + sortingCollection.printTempFileStats(); + + // iterate through records and process them + final List currentPositionRecords = new ArrayList<>(sampleIdMap.size() * 2); + long currentProbeId = -1; + + for ( final GenericRecord sortedRow : sortingCollection ) { + long probeId; + if (useCompressedData) { + final long rawData = (Long) sortedRow.get(SchemaUtils.RAW_ARRAY_DATA_FIELD_NAME); + RawArrayData data = RawArrayData.decode(rawData); + probeId = data.probeId; + } else { + probeId = (Long) sortedRow.get("probe_id"); + } + + if ( probeId != currentProbeId && currentProbeId != -1 ) { + ++totalNumberOfSites; + processSampleRecordsForLocation(currentProbeId, currentPositionRecords, columnNames); + currentPositionRecords.clear(); + } + + currentPositionRecords.add(sortedRow); + currentProbeId = probeId; + } + + if ( ! currentPositionRecords.isEmpty() ) { + ++totalNumberOfSites; + processSampleRecordsForLocation(currentProbeId, currentPositionRecords, columnNames); + } + } + + private void processSampleRecordsForLocation(final long probeId, final Iterable sampleRecordsAtPosition, final Set columnNames) { + final List unmergedCalls = new ArrayList<>(); + final Set currentPositionSamplesSeen = new HashSet<>(); + boolean currentPositionHasVariant = false; + + final ProbeInfo probeInfo = probeIdMap.get(probeId); + if (probeInfo == null) { + throw new RuntimeException("Unable to find probeInfo for " + probeId); + } + + final String contig = probeInfo.contig; + final long position = probeInfo.position; + final Allele refAllele = Allele.create(refSource.queryAndPrefetch(contig, position, position).getBaseString(), true); + + int numRecordsAtPosition = 0; + + for ( final GenericRecord sampleRecord : sampleRecordsAtPosition ) { + final long sampleId = (Long) sampleRecord.get(SchemaUtils.SAMPLE_ID_FIELD_NAME); + + // TODO: handle missing values + String sampleName = sampleIdMap.get((int) sampleId); + currentPositionSamplesSeen.add(sampleName); + + ++numRecordsAtPosition; + + if ( printDebugInformation ) { + logger.info("\t" + contig + ":" + position + ": found record for sample " + sampleName + ": " + sampleRecord); + } + + ++totalNumberOfVariants; + unmergedCalls.add(createVariantContextFromSampleRecord(probeInfo, sampleRecord, columnNames, contig, position, sampleName)); + + } + + if ( printDebugInformation ) { + logger.info(contig + ":" + position + ": processed " + numRecordsAtPosition + " total sample records"); + } + + finalizeCurrentVariant(unmergedCalls, currentPositionSamplesSeen, contig, position, refAllele); + } + + private void finalizeCurrentVariant(final List unmergedCalls, final Set currentVariantSamplesSeen, final String contig, final long start, final Allele refAllele) { + + // TODO: this is where we infer missing data points... once we know what we want to drop + // final Set samplesNotEncountered = Sets.difference(sampleNames, currentVariantSamplesSeen); + // for ( final String missingSample : samplesNotEncountered ) { + // unmergedCalls.add(createRefSiteVariantContext(missingSample, contig, start, refAllele)); + // } + + final VariantContext mergedVC = variantContextMerger.merge( + unmergedCalls, + new SimpleInterval(contig, (int) start, (int) start), + refAllele.getBases()[0], + true, + false, + true); + + + final VariantContext finalVC = mergedVC; + + // TODO: this was commented out... probably need to re-enable +// final VariantContext annotatedVC = enableVariantAnnotator ? +// variantAnnotator.annotateContext(finalizedVC, new FeatureContext(), null, null, a -> true): finalVC; + +// if ( annotatedVC != null ) { +// vcfWriter.add(annotatedVC); +// progressMeter.update(annotatedVC); +// } + + if ( finalVC != null ) { + vcfWriter.add(finalVC); + progressMeter.update(finalVC); + } else { + // TODO should i print a warning here? + vcfWriter.add(mergedVC); + progressMeter.update(mergedVC); + } + } + + private String formatFloatForVcf(final Float value) { + if (value == null || Double.isNaN(value)) { + return DOT; + } + return df.format(value); + } + + private Float getNullableFloatFromDouble(Object d) { + return d == null ? null : (float) ((Double) d).doubleValue(); + } + + private VariantContext createVariantContextFromSampleRecord(final ProbeInfo probeInfo, final GenericRecord sampleRecord, final Set columnNames, final String contig, final long startPosition, final String sample) { + final VariantContextBuilder builder = new VariantContextBuilder(); + final GenotypeBuilder genotypeBuilder = new GenotypeBuilder(); + + builder.chr(contig); + builder.start(startPosition); + builder.id(probeInfo.name); + + final List alleles = new ArrayList<>(); + Allele ref = Allele.create(probeInfo.ref, true); + alleles.add(ref); + + Allele alleleA = Allele.create(probeInfo.alleleA, false); + Allele alleleB = Allele.create(probeInfo.alleleB, false); + + boolean alleleAisRef = probeInfo.ref.equals(probeInfo.alleleA); + boolean alleleBisRef = probeInfo.ref.equals(probeInfo.alleleB); + + if (alleleAisRef) { + alleleA = ref; + } else { + alleles.add(alleleA); + } + + if (alleleBisRef) { + alleleB = ref; + } else { + alleles.add(alleleB); + } + + builder.alleles(alleles); + builder.stop(startPosition + alleles.get(0).length() - 1); + + Float normx; + Float normy; + Float baf; + Float lrr; + List genotypeAlleles = new ArrayList(); + + if (this.useCompressedData) { + final RawArrayData data = RawArrayData.decode((Long) sampleRecord.get(SchemaUtils.RAW_ARRAY_DATA_FIELD_NAME)); + normx = data.normx; + normy = data.normy; + lrr = data.lrr; + baf = data.baf; + + if (data.genotype == ArrayGenotype.AA) { + genotypeAlleles.add(alleleA); + genotypeAlleles.add(alleleA); + } else if (data.genotype == ArrayGenotype.AB) { + genotypeAlleles.add(alleleA); + genotypeAlleles.add(alleleB); + } else if (data.genotype == ArrayGenotype.BB) { + genotypeAlleles.add(alleleB); + genotypeAlleles.add(alleleB); + } else { + genotypeAlleles.add(Allele.NO_CALL); + genotypeAlleles.add(Allele.NO_CALL); + } + } else { + Object gt = sampleRecord.get("GT_encoded"); + ArrayGenotype agt; + if (gt == null || gt.toString().length() == 0) { + genotypeAlleles.add(alleleA); + genotypeAlleles.add(alleleA); + agt = ArrayGenotype.AA; + } else if ("X".equals(gt.toString())) { + genotypeAlleles.add(alleleA); + genotypeAlleles.add(alleleB); + agt = ArrayGenotype.AB; + } else if ("B".equals(gt.toString())) { + genotypeAlleles.add(alleleB); + genotypeAlleles.add(alleleB); + agt = ArrayGenotype.BB; + } else if ("U".equals(gt.toString())) { + genotypeAlleles.add(Allele.NO_CALL); + genotypeAlleles.add(Allele.NO_CALL); + agt = ArrayGenotype.NO_CALL; + } else { + System.out.println("Processing getnotype " + gt.toString()); + throw new RuntimeException(); + } + + // TODO: constantize + try { + normx = getNullableFloatFromDouble(sampleRecord.get("NORMX")); + normy = getNullableFloatFromDouble(sampleRecord.get("NORMY")); + baf = getNullableFloatFromDouble(sampleRecord.get("BAF")); + lrr = getNullableFloatFromDouble(sampleRecord.get("LRR")); + + // Hack to pack and unpack data + RawArrayData d = new RawArrayData(); + d.probeId = (int) probeInfo.probeId; + d.genotype = agt; + d.baf = baf; + d.lrr = lrr; + d.normx = normx; + d.normy = normy; + + long bits = d.encode(); + RawArrayData d2 = RawArrayData.decode(bits); + normx = d2.normx; + normy = d2.normy; + baf = d2.baf; + lrr = d2.lrr; + + } catch (NullPointerException npe) { + System.out.println("NPE on " + sampleRecord); + System.out.println("NPE on BAF " + sampleRecord.get("BAF")); + System.out.println("NPE on LRR " +sampleRecord.get("LRR")); + throw npe; + } + + + } + genotypeBuilder.alleles(genotypeAlleles); + + genotypeBuilder.attribute(CommonCode.NORMX, formatFloatForVcf(normx)); + genotypeBuilder.attribute(CommonCode.NORMY, formatFloatForVcf(normy)); + genotypeBuilder.attribute(CommonCode.BAF, formatFloatForVcf(baf)); + genotypeBuilder.attribute(CommonCode.LRR, formatFloatForVcf(lrr)); + + genotypeBuilder.name(sample); + + builder.genotypes(genotypeBuilder.make()); + + try { + VariantContext vc = builder.make(); + return vc; + } catch (Exception e) { + System.out.println("Error: "+ e.getMessage() + " processing " + sampleRecord + " and ref: " +ref + " PI: " + probeInfo.alleleA + "/" +probeInfo.alleleB + " with ga " + genotypeAlleles + " and alleles " + alleles); + throw e; + } + } +} diff --git a/src/main/java/org/broadinstitute/hellbender/tools/variantdb/BinaryUtils.java b/src/main/java/org/broadinstitute/hellbender/tools/variantdb/BinaryUtils.java new file mode 100644 index 00000000000..9684ebacd65 --- /dev/null +++ b/src/main/java/org/broadinstitute/hellbender/tools/variantdb/BinaryUtils.java @@ -0,0 +1,68 @@ +package org.broadinstitute.hellbender.tools.variantdb; + +public class BinaryUtils { + // Function to extract k bits from p position (0-based) + // and returns the extracted value as integer + static long extractBits(long number, int p, int k) { + // make a bit-mask of the desired number of bits + long mask = ((1L << k) - 1L); + + // shift desired data to be the lowest ordered bits, and apply mask + return (mask) & (number >>> p); + } + + // 0xFF (255) is reserved as NULL + static long encodeTo8Bits(Float e, float minValue, float maxValue) { + if (e == null) { + return 255; + } + + if (e > maxValue) { + e = maxValue; + } + + if (e < minValue) { + e = minValue; + } + + float range = maxValue - minValue; + float n = (e - minValue) / range; + return Math.round(n * 254.0f); + } + + + // 0xFF (255) is reserved as NULL + static Float decodeFrom8Bits(int i, float minValue, float maxValue) { + if (i == 255) { + return null; + } + + float range = maxValue - minValue; + float n = (1.0f / 254.0f) * ((float) i); + return n * range + minValue; + } + + /** + * Converts an long to a 64-bit binary string + * @param number + * The number to convert + * @param groupSize + * The number of bits in a group + * @return + * The 64-bit long bit string + */ + public static String longToBinaryString(long number, int groupSize) { + StringBuilder result = new StringBuilder(); + + for(long i = 63; i >= 0 ; i--) { + long mask = 1L << i; + result.append((number & mask) != 0 ? "1" : "0"); + + if (i % groupSize == 0) + result.append(" "); + } + result.replace(result.length() - 1, result.length(), ""); + + return result.toString(); + } +} diff --git a/src/main/java/org/broadinstitute/hellbender/tools/variantdb/CommonCode.java b/src/main/java/org/broadinstitute/hellbender/tools/variantdb/CommonCode.java index 40f28072e7a..a4f4a0aaa34 100644 --- a/src/main/java/org/broadinstitute/hellbender/tools/variantdb/CommonCode.java +++ b/src/main/java/org/broadinstitute/hellbender/tools/variantdb/CommonCode.java @@ -10,6 +10,27 @@ //TODO rename this or get rid of it. a place holder for now public class CommonCode { + public static final String NORMX = "NORMX"; + public static final String NORMY = "NORMY"; + public static final String BAF = "BAF"; + public static final String LRR = "LRR"; + + + public static VCFHeader generateRawArrayVcfHeader(Set sampleNames, final SAMSequenceDictionary sequenceDictionary) { + final Set lines = new HashSet<>(); + + lines.add(VCFStandardHeaderLines.getFormatLine(VCFConstants.GENOTYPE_KEY)); + lines.add(new VCFFormatHeaderLine(NORMX, 1, VCFHeaderLineType.Float, "Normalized X intensity")); + lines.add(new VCFFormatHeaderLine(NORMY, 1, VCFHeaderLineType.Float, "Normalized Y intensity")); + lines.add(new VCFFormatHeaderLine(BAF, 1, VCFHeaderLineType.Float, "B Allele Frequency")); + lines.add(new VCFFormatHeaderLine(LRR, 1, VCFHeaderLineType.Float, "Log R Ratio")); + + + final VCFHeader header = new VCFHeader(lines, sampleNames); + header.setSequenceDictionary(sequenceDictionary); + + return header; + } public static VCFHeader generateVcfHeader(Set sampleNames,//) { //final Set defaultHeaderLines, final SAMSequenceDictionary sequenceDictionary) { diff --git a/src/main/java/org/broadinstitute/hellbender/tools/variantdb/ExtractCohortBQ.java b/src/main/java/org/broadinstitute/hellbender/tools/variantdb/ExtractCohortBQ.java index 34ad452a583..cf8b3265a25 100644 --- a/src/main/java/org/broadinstitute/hellbender/tools/variantdb/ExtractCohortBQ.java +++ b/src/main/java/org/broadinstitute/hellbender/tools/variantdb/ExtractCohortBQ.java @@ -1,24 +1,39 @@ package org.broadinstitute.hellbender.tools.variantdb; +import com.google.cloud.bigquery.FieldValue; import com.google.cloud.bigquery.FieldValueList; import com.google.cloud.bigquery.TableResult; import org.apache.logging.log4j.LogManager; import org.apache.logging.log4j.Logger; import org.broadinstitute.hellbender.utils.bigquery.BigQueryUtils; import org.broadinstitute.hellbender.utils.bigquery.TableReference; +import org.broadinstitute.hellbender.utils.localsort.AvroSortingCollectionCodec; +import org.broadinstitute.hellbender.utils.localsort.SortingCollection; +import org.apache.avro.generic.GenericRecord; +import java.util.Comparator; +import java.util.HashMap; import java.util.HashSet; +import java.util.Map; import java.util.Set; public class ExtractCohortBQ { private static final Logger logger = LogManager.getLogger(ExtractCohortBQ.class); public static Set populateSampleNames(TableReference sampleTableRef, boolean printDebugInformation) { - Set results = new HashSet<>(); + String fqSampleTableName = sampleTableRef.getFQTableName(); + return new HashSet(getSampleIdMap(fqSampleTableName, printDebugInformation).values()); + } + + public static Map getSampleIdMap(String fqSampleTableName, boolean printDebugInformation) { + + Map results = new HashMap<>(); // Get the query string: - final String sampleListQueryString = "SELECT " + SchemaUtils.SAMPLE_NAME_FIELD_NAME + " FROM `" + sampleTableRef.getFQTableName() + "`"; - ; + final String sampleListQueryString = + "SELECT " + SchemaUtils.SAMPLE_ID_FIELD_NAME + ", " + SchemaUtils.SAMPLE_NAME_FIELD_NAME + + " FROM `" + fqSampleTableName + "`"; + // Execute the query: final TableResult result = BigQueryUtils.executeQuery(sampleListQueryString); @@ -32,10 +47,84 @@ public static Set populateSampleNames(TableReference sampleTableRef, boo // Add our samples to our map: for (final FieldValueList row : result.iterateAll()) { - results.add(row.get(0).getStringValue()); + results.put((int) row.get(0).getLongValue(), row.get(1).getStringValue()); + } + + return results; + } + + private static String getOptionalString(FieldValue v) { + return (v == null || v.isNull()) ? null : v.getStringValue(); + } + + public static Map getProbeIdMap(String fqProbeTableName, boolean printDebugInformation) { + + Map results = new HashMap<>(); + + // Get the query string: + final String sampleListQueryString = + "SELECT probeId, Name, Chr, Position, Ref, AlleleA, AlleleB" + + " FROM `" + fqProbeTableName + "`"; + + + // Execute the query: + final TableResult result = BigQueryUtils.executeQuery(sampleListQueryString); + + System.out.println("Beginning probe retrieval..."); + for (final FieldValueList row : result.iterateAll()) { + ProbeInfo p = new ProbeInfo(row.get(0).getLongValue(), + getOptionalString(row.get(1)), // name + row.get(2).getStringValue(), // contig + row.get(3).getLongValue(), // position + getOptionalString(row.get(4)), // ref + getOptionalString(row.get(5)), // alleleA + getOptionalString(row.get(6)));// alleleB + + results.put(p.probeId, p); + } + System.out.println("Done probe retrieval..."); return results; } + public static SortingCollection getAvroSortingCollection(org.apache.avro.Schema schema, int localSortMaxRecordsInRam) { + final SortingCollection.Codec sortingCollectionCodec = new AvroSortingCollectionCodec(schema); + final Comparator sortingCollectionComparator = new Comparator() { + @Override + public int compare( GenericRecord o1, GenericRecord o2 ) { + final long firstPosition = Long.parseLong(o1.get(SchemaUtils.LOCATION_FIELD_NAME).toString()); + final long secondPosition = Long.parseLong(o2.get(SchemaUtils.LOCATION_FIELD_NAME).toString()); + + return Long.compare(firstPosition, secondPosition); + } + }; + return SortingCollection.newInstance(GenericRecord.class, sortingCollectionCodec, sortingCollectionComparator, localSortMaxRecordsInRam); + } + + public static SortingCollection getAvroProbeIdSortingCollection(org.apache.avro.Schema schema, int localSortMaxRecordsInRam, Comparator comparator) { + final SortingCollection.Codec sortingCollectionCodec = new AvroSortingCollectionCodec(schema); + return SortingCollection.newInstance(GenericRecord.class, sortingCollectionCodec, comparator, localSortMaxRecordsInRam); + } + + final static Comparator COMPRESSED_PROBE_ID_COMPARATOR = new Comparator() { + @Override + public int compare( GenericRecord o1, GenericRecord o2 ) { + final long firstProbeId = RawArrayData.decode((Long) o1.get(SchemaUtils.RAW_ARRAY_DATA_FIELD_NAME)).probeId; + final long secondProbeId = RawArrayData.decode((Long) o2.get(SchemaUtils.RAW_ARRAY_DATA_FIELD_NAME)).probeId; + + return Long.compare(firstProbeId, secondProbeId); + } + }; + + final static Comparator UNCOMPRESSED_PROBE_ID_COMPARATOR = new Comparator() { + @Override + public int compare( GenericRecord o1, GenericRecord o2 ) { + final long firstProbeId = (Long) o1.get("probe_id"); + final long secondProbeId = (Long) o2.get("probe_id"); + return Long.compare(firstProbeId, secondProbeId); + } + }; + + } diff --git a/src/main/java/org/broadinstitute/hellbender/tools/variantdb/ExtractCohortEngine.java b/src/main/java/org/broadinstitute/hellbender/tools/variantdb/ExtractCohortEngine.java index d557bc6b3a8..2a25c82da9d 100644 --- a/src/main/java/org/broadinstitute/hellbender/tools/variantdb/ExtractCohortEngine.java +++ b/src/main/java/org/broadinstitute/hellbender/tools/variantdb/ExtractCohortEngine.java @@ -28,6 +28,7 @@ import java.util.*; import java.util.stream.Collectors; +import static org.broadinstitute.hellbender.tools.variantdb.ExtractCohortBQ.*; public class ExtractCohortEngine { private static final Logger logger = LogManager.getLogger(ExtractCohortEngine.class); @@ -158,20 +159,7 @@ private void createVariantsFromSortedTableResults(final QueryAPIRowReader reader } - public static SortingCollection getAvroSortingCollection(org.apache.avro.Schema schema, int localSortMaxRecordsInRam) { - final SortingCollection.Codec sortingCollectionCodec = new AvroSortingCollectionCodec(schema); - final Comparator sortingCollectionComparator = new Comparator() { - @Override - public int compare( GenericRecord o1, GenericRecord o2 ) { - final long firstPosition = Long.parseLong(o1.get(SchemaUtils.LOCATION_FIELD_NAME).toString()); - final long secondPosition = Long.parseLong(o2.get(SchemaUtils.LOCATION_FIELD_NAME).toString()); - - return Long.compare(firstPosition, secondPosition); - } - }; - return SortingCollection.newInstance(GenericRecord.class, sortingCollectionCodec, sortingCollectionComparator, localSortMaxRecordsInRam); - } - + diff --git a/src/main/java/org/broadinstitute/hellbender/tools/variantdb/ProbeInfo.java b/src/main/java/org/broadinstitute/hellbender/tools/variantdb/ProbeInfo.java new file mode 100644 index 00000000000..4a4960c35b4 --- /dev/null +++ b/src/main/java/org/broadinstitute/hellbender/tools/variantdb/ProbeInfo.java @@ -0,0 +1,23 @@ +package org.broadinstitute.hellbender.tools.variantdb; + +public class ProbeInfo { + long probeId; + + String contig; + long position; + String ref; + String alleleA; + String alleleB; + String name; + + public ProbeInfo(long probeId, String name, String contig, long position, String ref, String alleleA, String alleleB) { + this.probeId = probeId; + this.name = name; + this.contig = contig; + this.position = position; + this.ref = ref; + this.alleleA = alleleA; + this.alleleB = alleleB; + } + +} diff --git a/src/main/java/org/broadinstitute/hellbender/tools/variantdb/RawArrayData.java b/src/main/java/org/broadinstitute/hellbender/tools/variantdb/RawArrayData.java new file mode 100644 index 00000000000..cbf2291a129 --- /dev/null +++ b/src/main/java/org/broadinstitute/hellbender/tools/variantdb/RawArrayData.java @@ -0,0 +1,76 @@ +package org.broadinstitute.hellbender.tools.variantdb; + +import static org.broadinstitute.hellbender.tools.variantdb.BinaryUtils.*; + +public class RawArrayData { + public static enum ArrayGenotype { + // Order is critical here, the ordinal is the int encoding + AA,AB, BB, NO_CALL + } + + // TODO: turn these all into getters/setters with precision checks (e.g. baf) + int probeId; + ArrayGenotype genotype; + Float normx; + Float normy; + Float baf; + Float lrr; + + static ArrayGenotype decodeGenotype(int i) { + return ArrayGenotype.values()[i]; + } + + static int encodeGenotype(ArrayGenotype g) { + return g.ordinal(); + } + + public static final int LRR_OFFSET = 0; + public static final float LRR_MIN = -28; + public static final float LRR_MAX = 7; + + public static final int BAF_OFFSET = 8; + public static final float BAF_MIN = 0; + public static final float BAF_MAX = 1; + + public static final int NORMX_OFFSET = 16; + public static final float NORMX_MIN = 0; + public static final float NORMX_MAX = 8; + + public static final int NORMY_OFFSET = 24; + public static final float NORMY_MIN = 0; + public static final float NORMY_MAX = 8; + + public static final int GT_OFFSET = 32; + public static final int PROBE_ID_OFFSET = 42; + + // GTC Data Ranges: https://github.com/Illumina/BeadArrayFiles/blob/develop/docs/GTC_File_Format_v5.pdf + public static RawArrayData decode(long bits) { + + RawArrayData data = new RawArrayData(); + data.lrr = decodeFrom8Bits((int) extractBits(bits, LRR_OFFSET, 8), LRR_MIN, LRR_MAX); + data.baf = decodeFrom8Bits((int) extractBits(bits, BAF_OFFSET, 8), BAF_MIN, BAF_MAX); + data.normx = decodeFrom8Bits((int) extractBits(bits, NORMX_OFFSET, 8), NORMX_MIN, NORMX_MAX); + data.normy = decodeFrom8Bits((int) extractBits(bits, NORMY_OFFSET, 8), NORMY_MIN, NORMY_MAX); + data.genotype = decodeGenotype((int) extractBits(bits, GT_OFFSET, 2)); + data.probeId = (int) extractBits(bits, PROBE_ID_OFFSET, 22); + + return data; + } + + public long encode() { + long lrrBits = encodeTo8Bits(this.lrr, LRR_MIN, LRR_MAX); + long bafBits = encodeTo8Bits(this.baf, BAF_MIN, BAF_MAX); + long normxBits = encodeTo8Bits(this.normx, NORMX_MIN, NORMX_MAX); + long normyBits = encodeTo8Bits(this.normy, NORMX_MIN, NORMX_MAX); + long gtBits = (long) encodeGenotype(this.genotype); + + return ( + (lrrBits << LRR_OFFSET) | + (bafBits << BAF_OFFSET) | + (normxBits << NORMX_OFFSET) | + (normyBits << NORMY_OFFSET) | + (gtBits << GT_OFFSET) | + ((long) this.probeId << PROBE_ID_OFFSET ) + ); + } +} diff --git a/src/main/java/org/broadinstitute/hellbender/tools/variantdb/SchemaUtils.java b/src/main/java/org/broadinstitute/hellbender/tools/variantdb/SchemaUtils.java index d736f500f19..f96e45904ac 100644 --- a/src/main/java/org/broadinstitute/hellbender/tools/variantdb/SchemaUtils.java +++ b/src/main/java/org/broadinstitute/hellbender/tools/variantdb/SchemaUtils.java @@ -13,6 +13,8 @@ public class SchemaUtils { public static final String SAMPLE_NAME_FIELD_NAME = "sample_name"; public static final String SAMPLE_ID_FIELD_NAME = "sample_id"; + public static final String RAW_ARRAY_DATA_FIELD_NAME = "raw_array_data"; + // TODO remove this one - we should not have this ambiguous field // public static final String SAMPLE_FIELD_NAME = "sample"; public static final String STATE_FIELD_NAME = "state"; @@ -32,6 +34,11 @@ public class SchemaUtils { public static final List COHORT_FIELDS = Arrays.asList(LOCATION_FIELD_NAME, SAMPLE_NAME_FIELD_NAME, STATE_FIELD_NAME, REF_ALLELE_FIELD_NAME, ALT_ALLELE_FIELD_NAME, "call_GT", "call_GQ", "call_RGQ"); public static final List ARRAY_COHORT_FIELDS = Arrays.asList(LOCATION_FIELD_NAME, SAMPLE_NAME_FIELD_NAME, STATE_FIELD_NAME, REF_ALLELE_FIELD_NAME, ALT_ALLELE_FIELD_NAME, "call_GT", "call_GQ"); + + public static final List RAW_ARRAY_COHORT_FIELDS_COMPRESSED = Arrays.asList(SAMPLE_ID_FIELD_NAME, RAW_ARRAY_DATA_FIELD_NAME); + public static final List RAW_ARRAY_COHORT_FIELDS_UNCOMPRESSED = + Arrays.asList(SAMPLE_ID_FIELD_NAME, "probe_id", "GT_encoded","NORMX","NORMY","BAF","LRR"); + public static final List SAMPLE_FIELDS = Arrays.asList(SAMPLE_NAME_FIELD_NAME); public static final List YNG_FIELDS = Arrays.asList(LOCATION_FIELD_NAME, REF_ALLELE_FIELD_NAME, ALT_ALLELE_FIELD_NAME); diff --git a/src/test/java/org/broadinstitute/hellbender/tools/variantdb/RawArrayDataTest.java b/src/test/java/org/broadinstitute/hellbender/tools/variantdb/RawArrayDataTest.java new file mode 100644 index 00000000000..7bdd8fba7fd --- /dev/null +++ b/src/test/java/org/broadinstitute/hellbender/tools/variantdb/RawArrayDataTest.java @@ -0,0 +1,70 @@ +package org.broadinstitute.hellbender.tools.variantdb; + +//import org.testng.annotations.Test; +import org.testng.annotations.Test; +import static org.testng.Assert.*; + +import org.broadinstitute.hellbender.tools.variantdb.RawArrayData.ArrayGenotype; +import static org.broadinstitute.hellbender.tools.variantdb.BinaryUtils.*; +import static org.broadinstitute.hellbender.tools.variantdb.RawArrayData.*; + + +public final class RawArrayDataTest { + + @Test + public void testEncodeDecode() { + RawArrayData original = new RawArrayData(); + + original.lrr = -0.035f; + original.baf = 0.988f; + original.normx = 0.059f; + original.normy = 0.849f; + original.genotype = ArrayGenotype.AA; + original.probeId = 222324; + + long bits = original.encode(); + RawArrayData r = RawArrayData.decode(bits); + + assertEquals(r.lrr, original.lrr, ( (LRR_MAX - LRR_MIN) / 255.0 )); + assertEquals(r.baf, original.baf, ( (BAF_MAX - BAF_MIN) / 255.0 )); + assertEquals(r.normx, original.normx, ( (NORMX_MAX - NORMX_MIN) / 255.0 )); + assertEquals(r.normy, original.normy, ( (NORMY_MAX - NORMY_MIN) / 255.0 )); + assertEquals(r.genotype, original.genotype); + assertEquals(r.probeId, original.probeId); + } + + @Test + public void testBasicFloat() { + float orig = (1f/255f); + float f = decodeFrom8Bits((int) encodeTo8Bits(orig, 0, 1), 0, 1); + assertEquals(orig, f, ( (1 - 0) / 255.0 )); + } + + @Test + public void testBasicNull() { + Float orig = null; + Float f = decodeFrom8Bits((int) encodeTo8Bits(orig, 0, 1), 0, 1); + assertNull(f); + } + + @Test + public void testBasicMax() { + Float orig = 1.0f; + Float f = decodeFrom8Bits((int) encodeTo8Bits(orig, 0, 1), 0, 1); + assertEquals(orig, f); + } + + @Test + public void testBasicMin() { + Float orig = 0.0f; + Float f = decodeFrom8Bits((int) encodeTo8Bits(orig, 0, 1), 0, 1); + assertEquals(orig, f); + } + + @Test + public void testBasicLrr() { + Float orig = 0.058f; + Float f = decodeFrom8Bits((int) encodeTo8Bits(orig, LRR_MIN, LRR_MAX), LRR_MIN, LRR_MAX); + assertEquals(orig, f, ( (LRR_MAX - LRR_MIN) / 255.0 )); + } +}