From d7321240dcd2f20624facaffcc51908d6d8cbe59 Mon Sep 17 00:00:00 2001
From: Ilia Tulchinsky
Date: Sun, 19 Jul 2015 22:44:34 -0400
Subject: [PATCH] Sharded BAM Writer, merged from dev.branch

---
 pom.xml                                            |   2 +-
 .../dataflow/pipelines/ShardedBAMWriting.java      | 347 +++++++++++++-----
 .../genomics/dataflow/readers/bam/BAMIO.java       |  20 +-
 .../readers/bam/ReadBAMTransform.java              |  21 +-
 .../genomics/dataflow/readers/bam/Reader.java      |  24 +-
 .../dataflow/readers/bam/Sharder.java              |   9 +-
 .../dataflow/utils/TruncatedOutputStream.java      |  17 +-
 7 files changed, 318 insertions(+), 122 deletions(-)

diff --git a/pom.xml b/pom.xml
index 1bb3004..aad218c 100644
--- a/pom.xml
+++ b/pom.xml
@@ -208,7 +208,7 @@
       com.google.protobuf
       protobuf-java
       3.0.0-alpha-3
-
+

diff --git a/src/main/java/com/google/cloud/genomics/dataflow/pipelines/ShardedBAMWriting.java b/src/main/java/com/google/cloud/genomics/dataflow/pipelines/ShardedBAMWriting.java
index 4b0fb13..79bde24 100644
--- a/src/main/java/com/google/cloud/genomics/dataflow/pipelines/ShardedBAMWriting.java
+++ b/src/main/java/com/google/cloud/genomics/dataflow/pipelines/ShardedBAMWriting.java
@@ -16,23 +16,26 @@
 import com.google.api.services.genomics.model.Read;
 import com.google.api.services.storage.Storage;
 import com.google.api.services.storage.Storage.Objects.Compose;
+import com.google.api.services.storage.model.Bucket;
 import com.google.api.services.storage.model.ComposeRequest;
 import com.google.api.services.storage.model.ComposeRequest.SourceObjects;
 import com.google.api.services.storage.model.StorageObject;
 import com.google.cloud.dataflow.sdk.Pipeline;
 import com.google.cloud.dataflow.sdk.coders.Coder;
 import com.google.cloud.dataflow.sdk.coders.DelegateCoder;
-import com.google.cloud.dataflow.sdk.coders.SerializableCoder;
 import com.google.cloud.dataflow.sdk.coders.StringUtf8Coder;
 import com.google.cloud.dataflow.sdk.io.TextIO;
 import com.google.cloud.dataflow.sdk.options.Default;
 import com.google.cloud.dataflow.sdk.options.Description;
 import com.google.cloud.dataflow.sdk.options.PipelineOptionsFactory;
+import com.google.cloud.dataflow.sdk.transforms.Aggregator;
 import com.google.cloud.dataflow.sdk.transforms.Create;
 import com.google.cloud.dataflow.sdk.transforms.DoFn;
 import com.google.cloud.dataflow.sdk.transforms.GroupByKey;
 import com.google.cloud.dataflow.sdk.transforms.PTransform;
 import com.google.cloud.dataflow.sdk.transforms.ParDo;
+import com.google.cloud.dataflow.sdk.transforms.Sum;
+import com.google.cloud.dataflow.sdk.transforms.Sum.SumIntegerFn;
 import com.google.cloud.dataflow.sdk.transforms.View;
 import com.google.cloud.dataflow.sdk.util.GcsUtil;
 import com.google.cloud.dataflow.sdk.util.Transport;
@@ -43,6 +46,7 @@
 import com.google.cloud.dataflow.sdk.values.PCollectionView;
 import com.google.cloud.dataflow.sdk.values.TupleTag;
 import com.google.cloud.genomics.dataflow.readers.bam.BAMIO;
+import com.google.cloud.genomics.dataflow.readers.bam.BAMShard;
 import com.google.cloud.genomics.dataflow.readers.bam.ReadBAMTransform;
 import com.google.cloud.genomics.dataflow.readers.bam.ReaderOptions;
 import com.google.cloud.genomics.dataflow.readers.bam.ShardingPolicy;
@@ -54,9 +58,11 @@
 import com.google.cloud.genomics.utils.Contig;
 import com.google.cloud.genomics.utils.GenomicsFactory;
 import com.google.cloud.genomics.utils.ReadUtils;
+import com.google.common.base.Stopwatch;
 import com.google.common.collect.Lists;
 
 import htsjdk.samtools.BAMBlockWriter;
+import htsjdk.samtools.BAMIndexer;
 import htsjdk.samtools.SAMFileHeader;
 import htsjdk.samtools.SAMRecord;
 import htsjdk.samtools.SAMRecordIterator;
@@ -66,19 +72,17 @@
 import htsjdk.samtools.util.BlockCompressedStreamConstants;
 import htsjdk.samtools.util.StringLineReader;
 
-import java.io.BufferedWriter;
 import java.io.IOException;
 import java.io.OutputStream;
-import java.io.OutputStreamWriter;
-import java.io.Serializable;
 import java.io.StringWriter;
-import java.io.Writer;
 import java.nio.channels.Channels;
 import java.security.GeneralSecurityException;
 import java.util.ArrayList;
 import java.util.Arrays;
 import java.util.Collections;
 import java.util.Comparator;
+import java.util.List;
+import java.util.concurrent.TimeUnit;
 import java.util.logging.Logger;
 
 /**
@@ -86,6 +90,7 @@
  */
 public class ShardedBAMWriting {
   private static final Logger LOG = Logger.getLogger(ShardedBAMWriting.class.getName());
+  private static final int MAX_RETRIES_FOR_WRITING_A_SHARD = 4;
   private static ShardedBAMWritingOptions options;
   private static Pipeline p;
   private static GenomicsFactory.OfflineAuth auth;
@@ -97,6 +102,12 @@ public static interface ShardedBAMWritingOptions extends GenomicsDatasetOptions,
     String getBAMFilePath();
 
     void setBAMFilePath(String filePath);
+
+    @Description("Loci per writing shard")
+    @Default.Long(10000)
+    long getLociPerWritingShard();
+
+    void setLociPerWritingShard(long lociPerShard);
   }
 
   public static void main(String[] args) throws GeneralSecurityException, IOException {
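For context, a hypothetical wiring of the new option (the paths and reference
range below are placeholders; the real pipeline builds its options from the
command line exactly like this, via PipelineOptionsFactory.fromArgs(args)):

    // Sketch only: flag names are derived from the option getters above.
    ShardedBAMWritingOptions opts = PipelineOptionsFactory.fromArgs(new String[] {
        "--BAMFilePath=gs://my-bucket/input/sample.bam",
        "--references=chr20:0:63025520",
        "--lociPerWritingShard=10000",
        "--output=gs://my-bucket/output/sample.bam"
    }).as(ShardedBAMWritingOptions.class);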
@@ -110,11 +121,11 @@ public static void main(String[] args) throws GeneralSecurityException, IOExcept
     // Register coders
     DataflowWorkarounds.registerGenomicsCoders(p);
     DataflowWorkarounds.registerCoder(p, Contig.class, CONTIG_CODER);
-    // Get contigs
+    // Process options
     contigs = Contig.parseContigsFromCommandLine(options.getReferences());
     // Get header info
     final HeaderInfo headerInfo = getHeader();
-
+
     // Get the reads and shard them
     final PCollection<Read> reads = getReadsFromBAMFile();
     final PCollection<KV<Contig, Iterable<Read>>> shardedReads = ShardReadsTransform.shard(reads);
@@ -125,22 +136,22 @@ public static void main(String[] args) throws GeneralSecurityException, IOExcept
         .to(options.getOutput() + "-result")
         .named("Write Output Result")
         .withoutSharding());
-    p.run();
+    p.run();
   }
-
+
   public static class HeaderInfo {
     public SAMFileHeader header;
     public Contig firstShard;
-
+
     public HeaderInfo(SAMFileHeader header, Contig firstShard) {
       this.header = header;
       this.firstShard = firstShard;
     }
   }
-
+
   private static HeaderInfo getHeader() throws IOException {
     HeaderInfo result = null;
-
+
     // Get first contig
     final ArrayList<Contig> contigsList = Lists.newArrayList(contigs);
     if (contigsList.size() <= 0) {
@@ -157,7 +168,7 @@ public int compare(Contig o1, Contig o2) {
       }
     });
     final Contig firstContig = contigsList.get(0);
-
+
     // Open and read start of BAM
     final Storage.Objects storage = Transport.newStorageClient(
         options
@@ -168,23 +179,24 @@
     final SamReader samReader = BAMIO
         .openBAM(storage, options.getBAMFilePath(), ValidationStringency.DEFAULT_STRINGENCY);
     final SAMFileHeader header = samReader.getFileHeader();
-
+
     LOG.info("Reading first chunk of reads from " + options.getBAMFilePath());
     final SAMRecordIterator recordIterator = samReader.query(
         firstContig.referenceName, (int)firstContig.start + 1,
         (int)firstContig.end + 1, false);
+
     Contig firstShard = null;
-    while (recordIterator.hasNext()) {
+    while (recordIterator.hasNext() && result == null) {
       SAMRecord record = recordIterator.next();
       final int alignmentStart = record.getAlignmentStart();
       if (firstShard == null &&
           alignmentStart > firstContig.start &&
           alignmentStart < firstContig.end) {
-        firstShard = shardFromAlignmentStart(firstContig.referenceName, alignmentStart);
+        firstShard = shardFromAlignmentStart(firstContig.referenceName, alignmentStart,
+            options.getLociPerWritingShard());
         LOG.info("Determined first shard to be " + firstShard);
         result = new HeaderInfo(header, firstShard);
       }
     }
     recordIterator.close();
     samReader.close();
-
+
     if (result == null) {
       throw new IOException("Did not find reads for the first contig " + firstContig.toString());
     }
@@ -192,22 +204,30 @@
     return result;
   }
 
+  private static final ShardingPolicy READ_SHARDING_POLICY = ShardingPolicy.BYTE_SIZE_POLICY;
+      /* new ShardingPolicy() {
+        @Override
+        public boolean shardBigEnough(BAMShard shard) {
+          return shard.sizeInLoci() > 10000000;
+        }
+      };*/
+
   private static PCollection<Read> getReadsFromBAMFile() throws IOException {
     LOG.info("Sharded reading of "+ options.getBAMFilePath());
-
+
     final ReaderOptions readerOptions = new ReaderOptions(
-        ValidationStringency.LENIENT,
+        ValidationStringency.DEFAULT_STRINGENCY,
         true);
-
+
     return ReadBAMTransform.getReadsFromBAMFilesSharded(p,
         auth,
         contigs,
         readerOptions,
         options.getBAMFilePath(),
-        ShardingPolicy.BYTE_SIZE_POLICY);
+        READ_SHARDING_POLICY);
   }
-
-  public static class ShardReadsTransform extends PTransform<PCollection<Read>,
+
+  public static class ShardReadsTransform extends PTransform<PCollection<Read>,
       PCollection<KV<Contig, Iterable<Read>>>> {
     @Override
     public PCollection<KV<Contig, Iterable<Read>>> apply(PCollection<Read> reads) {
@@ -215,21 +235,43 @@ public PCollection<KV<Contig, Iterable<Read>>> apply(PCollection<Read> reads) {
           .apply(ParDo.named("KeyReads").of(new KeyReadsFn()))
           .apply(GroupByKey.create());
     }
-
+
     public static PCollection<KV<Contig, Iterable<Read>>> shard(PCollection<Read> reads) {
       return (new ShardReadsTransform()).apply(reads);
     }
-  }
+  }
 
   public static class KeyReadsFn extends DoFn<Read, KV<Contig, Read>> {
+    private Aggregator<Integer, Integer> readCountAggregator;
+    private Aggregator<Integer, Integer> unmappedReadCountAggregator;
+    private long lociPerShard;
+
+    public KeyReadsFn() {
+      readCountAggregator = createAggregator("Keyed reads", new SumIntegerFn());
+      unmappedReadCountAggregator = createAggregator("Keyed unmapped reads", new SumIntegerFn());
+    }
+
+    @Override
+    public void startBundle(Context c) {
+      lociPerShard = c.getPipelineOptions()
+          .as(ShardedBAMWritingOptions.class)
+          .getLociPerWritingShard();
+    }
 
     @Override
     public void processElement(DoFn<Read, KV<Contig, Read>>.ProcessContext c) throws Exception {
       final Read read = c.element();
-      c.output(KV.of(shardKeyForRead(read), read));
+      c.output(
+          KV.of(
+              shardKeyForRead(read, lociPerShard),
+              read));
+      readCountAggregator.addValue(1);
+      if (isUnmapped(read)) {
+        unmappedReadCountAggregator.addValue(1);
+      }
     }
   }
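A small worked example of what KeyReadsFn plus GroupByKey produce (the values
are illustrative; the arithmetic is the one in shardFromAlignmentStart below):

    // With lociPerShard = 10000, a read aligned at chr1:123456 is keyed by
    //   shardFromAlignmentStart("chr1", 123456, 10000)
    //     -> shardStart = (123456 / 10000) * 10000 = 120000
    //     -> new Contig("chr1", 120000, 130000)
    // After GroupByKey, each KV<Contig, Iterable<Read>> element holds every
    // read whose alignment start falls in one 10000-locus window -- exactly
    // the unit that writeShard() later turns into a single GCS object.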
Contig.parseContigsFromCommandLine(contigStr).iterator().next()); } }); - static final long LOCI_PER_SHARD = 10000; - - static Contig shardKeyForRead(Read read) { + static boolean isUnmapped(Read read) { + if (read.getAlignment() == null || read.getAlignment().getPosition() == null) { + return true; + } + final String reference = read.getAlignment().getPosition().getReferenceName(); + if (reference == null || reference.isEmpty() || reference.equals("*")) { + return true; + } + return false; + } + + static Contig shardKeyForRead(Read read, long lociPerShard) { String referenceName = null; Long alignmentStart = null; if (read.getAlignment() != null) { @@ -285,79 +336,85 @@ static Contig shardKeyForRead(Read read) { } } // If this read is unmapped but its mate is mapped, group them together - if (referenceName == null || alignmentStart == null) { + if (referenceName == null || referenceName.isEmpty() || + referenceName.equals("*") || alignmentStart == null) { if (read.getNextMatePosition() != null) { referenceName = read.getNextMatePosition().getReferenceName(); alignmentStart = read.getNextMatePosition().getPosition(); } } - if (referenceName == null || alignmentStart == null) { + if (referenceName == null || referenceName.isEmpty()) { referenceName = "*"; + } + if (alignmentStart == null) { alignmentStart = new Long(0); } - return shardFromAlignmentStart(referenceName, alignmentStart); + return shardFromAlignmentStart(referenceName, alignmentStart, lociPerShard); } - static Contig shardFromAlignmentStart(String referenceName, long alignmentStart) { - final long shardStart = (alignmentStart / LOCI_PER_SHARD) * LOCI_PER_SHARD; - return new Contig(referenceName, shardStart, shardStart + LOCI_PER_SHARD); + static Contig shardFromAlignmentStart(String referenceName, long alignmentStart, long lociPerShard) { + final long shardStart = (alignmentStart / lociPerShard) * lociPerShard; + return new Contig(referenceName, shardStart, shardStart + lociPerShard); } - + public static TupleTag>> SHARDED_READS_TAG = new TupleTag<>(); public static TupleTag HEADER_TAG = new TupleTag<>(); - - public static class WriteReadsTransform + + public static class WriteReadsTransform extends PTransform> { - + @Override public PCollection apply(PCollectionTuple tuple) { final PCollection header = tuple.get(HEADER_TAG); final PCollectionView headerView = header.apply(View.asSingleton()); - + final PCollection>> shardedReads = tuple.get(SHARDED_READS_TAG); - - final PCollection writtenShardNames = + + final PCollection writtenShardNames = shardedReads.apply(ParDo.named("Write shards") .withSideInputs(Arrays.asList(headerView)) .of(new WriteShardFn(headerView))); - - final PCollectionView> writtenShardsView = + + final PCollectionView> writtenShardsView = writtenShardNames.apply(View.asIterable()); - + final PCollection destinationPath = p.apply( Create.of(options.getOutput())); - + final PCollection writtenFile = destinationPath.apply( ParDo.named("Combine shards") .withSideInputs(writtenShardsView) .of(new CombineShardsFn(writtenShardsView))); - + return writtenFile; } - + public static PCollection write(PCollection>> shardedReads, HeaderInfo headerInfo) { final PCollectionTuple tuple = PCollectionTuple .of(SHARDED_READS_TAG,shardedReads) - .and(HEADER_TAG, p.apply(Create.of(headerInfo)) - .setCoder(HEADER_INFO_CODER)); + .and(HEADER_TAG, p.apply(Create.of(headerInfo).withCoder(HEADER_INFO_CODER))); return (new WriteReadsTransform()).apply(tuple); } } - + public static class WriteShardFn extends DoFn>, String> { final 
 
  public static class WriteShardFn extends DoFn<KV<Contig, Iterable<Read>>, String> {
    final PCollectionView<HeaderInfo> headerView;
    Storage.Objects storage;
+    Aggregator<Integer, Integer> readCountAggregator;
+    Aggregator<Integer, Integer> unmappedReadCountAggregator;
 
    public WriteShardFn(final PCollectionView<HeaderInfo> headerView) {
      this.headerView = headerView;
+      readCountAggregator = createAggregator("Written reads", new SumIntegerFn());
+      unmappedReadCountAggregator = createAggregator("Written unmapped reads", new SumIntegerFn());
    }
-
+
    @Override
    public void startBundle(DoFn<KV<Contig, Iterable<Read>>, String>.Context c) throws IOException {
      storage = Transport.newStorageClient(c.getPipelineOptions().as(GCSOptions.class)).build().objects();
    }
-
+
    @Override
    public void processElement(DoFn<KV<Contig, Iterable<Read>>, String>.ProcessContext c)
        throws Exception {
@@ -366,65 +423,94 @@ public void processElement(DoFn<KV<Contig, Iterable<Read>>, String>.ProcessConte
      final Contig shardContig = shard.getKey();
      final Iterable<Read> reads = shard.getValue();
      final boolean isFirstShard = shardContig.equals(headerInfo.firstShard);
-
      if (isFirstShard) {
        LOG.info("Writing first shard " + shardContig);
      } else {
        LOG.info("Writing non-first shard " + shardContig);
      }
-      final String writeResult = writeShard(headerInfo.header,
-          shardContig, reads,
-          c.getPipelineOptions().as(ShardedBAMWritingOptions.class),
-          isFirstShard);
-      c.output(writeResult);
+      int numRetriesLeft = MAX_RETRIES_FOR_WRITING_A_SHARD;
+      boolean done = false;
+      do {
+        try {
+          final String writeResult = writeShard(headerInfo.header,
+              shardContig, reads,
+              c.getPipelineOptions().as(ShardedBAMWritingOptions.class),
+              isFirstShard);
+          c.output(writeResult);
+          done = true;
+        } catch (IOException iox) {
+          LOG.warning("Write shard failed for " + shardContig + ": " + iox.getMessage());
+          if (--numRetriesLeft <= 0) {
+            LOG.warning("No more retries - failing the task for " + shardContig);
+            throw iox;
+          }
+        }
+      } while (!done);
      LOG.info("Finished writing " + shardContig);
    }
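The method below leans on BGZF framing: a BAM file is a series of independent
BGZF blocks terminated by a fixed 28-byte empty block that readers treat as
EOF. Each shard suppresses its own EOF marker (that is what the
TruncatedOutputStream wrapping is for), and combineShards() later appends a
single marker object after the last shard. Schematically (a reading of the
code, not new machinery):

    // shard-0 | shard-1 | ... | shard-N | EMPTY_GZIP_BLOCK  ==  one valid BAM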
 
-    String writeShard(SAMFileHeader header, Contig shardContig, Iterable<Read> reads,
+    String writeShard(SAMFileHeader header, Contig shardContig, Iterable<Read> reads,
        ShardedBAMWritingOptions options, boolean isFirstShard) throws IOException {
      final String outputFileName = options.getOutput();
-      final String shardName = outputFileName + "-" + shardContig;
+      final String shardName = outputFileName + "-" + shardContig.referenceName
+          + ":" + String.format("%012d", shardContig.start) + "-" +
+          String.format("%012d", shardContig.end);
      LOG.info("Writing shard file " + shardName);
-      final OutputStream outputStream =
+      final OutputStream outputStream =
          Channels.newOutputStream(
              new GcsUtil.GcsUtilFactory().create(options)
-                .create(GcsPath.fromUri(shardName),
+                .create(GcsPath.fromUri(shardName),
                    "application/octet-stream"));
      int count = 0;
+      int countUnmapped = 0;
      // Use a TruncatedOutputStream to avoid writing the empty gzip block that
      // indicates EOF.
      final BAMBlockWriter bw = new BAMBlockWriter(new TruncatedOutputStream(
          outputStream, BlockCompressedStreamConstants.EMPTY_GZIP_BLOCK.length),
          null /*file*/);
-      bw.setSortOrder(header.getSortOrder(), true /*presorted*/);
+      // If the reads are unsorted then we do not care about their order;
+      // otherwise we need to sort them as we write.
+      final boolean treatReadsAsPresorted =
+          header.getSortOrder() == SAMFileHeader.SortOrder.unsorted;
+      bw.setSortOrder(header.getSortOrder(), treatReadsAsPresorted);
      bw.setHeader(header);
      if (isFirstShard) {
+        LOG.info("First shard - writing header to " + shardName);
        bw.writeHeader(header);
      }
      for (Read read : reads) {
        SAMRecord samRecord = ReadUtils.makeSAMRecord(read, header);
+        if (samRecord.getReadUnmappedFlag()) {
+          if (!samRecord.getMateUnmappedFlag()) {
+            samRecord.setReferenceName(samRecord.getMateReferenceName());
+            samRecord.setAlignmentStart(samRecord.getMateAlignmentStart());
+          }
+          countUnmapped++;
+        }
        bw.addAlignment(samRecord);
        count++;
      }
      bw.close();
-      LOG.info("Wrote " + count + " reads into " + shardName);
+      LOG.info("Wrote " + count + " reads, " + countUnmapped + " unmapped, into " + shardName);
+      readCountAggregator.addValue(count);
+      unmappedReadCountAggregator.addValue(countUnmapped);
      return shardName;
    }
  }
-
+
  public static class CombineShardsFn extends DoFn<String, String> {
    final PCollectionView<Iterable<String>> shards;
-
+
    public CombineShardsFn(PCollectionView<Iterable<String>> shards) {
      this.shards = shards;
    }
-
+
    @Override
    public void processElement(DoFn<String, String>.ProcessContext c) throws Exception {
-      final String result =
+      final String result =
          combineShards(
-              c.getPipelineOptions().as(ShardedBAMWritingOptions.class),
+              c.getPipelineOptions().as(ShardedBAMWritingOptions.class),
              c.element(),
              c.sideInput(shards));
      c.output(result);
    }
@@ -439,11 +525,6 @@ static String combineShards(ShardedBAMWritingOptions options, String dest,
        .build()
        .objects();
 
-    final GcsPath destPath = GcsPath.fromUri(dest);
-
-    StorageObject destination = new StorageObject()
-        .setContentType("application/octet-stream");
-
    ArrayList<String> sortedShardsNames = Lists.newArrayList(shards);
    Collections.sort(sortedShardsNames);
@@ -456,22 +537,100 @@
    os.write(BlockCompressedStreamConstants.EMPTY_GZIP_BLOCK);
    os.close();
    sortedShardsNames.add(eofFileName);
-    // list of files to concatenate
-    ArrayList<SourceObjects> sourceObjects = new ArrayList<SourceObjects>();
-    for (String shard : sortedShardsNames) {
-      final GcsPath shardPath = GcsPath.fromUri(shard);
-      LOG.info("Adding object " + shardPath);
-      sourceObjects.add( new SourceObjects().setName(shardPath.getObject()) );
+
+    int stageNumber = 0;
+    while (sortedShardsNames.size() > 32) {
+      LOG.info("Have " + sortedShardsNames.size() +
+          " shards: must combine in groups of 32");
+      final ArrayList<String> combinedShards = Lists.newArrayList();
+      for (int idx = 0; idx < sortedShardsNames.size(); idx += 32) {
+        final int endIdx = Math.min(idx + 32, sortedShardsNames.size());
+        final List<String> combinableShards = sortedShardsNames.subList(
+            idx, endIdx);
+        final String intermediateCombineResultName = dest + "-" +
+            String.format("%02d", stageNumber) + "-" +
+            String.format("%02d", idx) + "-" +
+            String.format("%02d", endIdx);
+        final String combineResult = composeAndCleanupShards(storage,
+            combinableShards, intermediateCombineResultName);
+        combinedShards.add(combineResult);
+      }
+      sortedShardsNames = combinedShards;
+      stageNumber++;
    }
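The staged reduction above exists because a GCS compose request accepts at
most 32 source objects. A worked count with hypothetical numbers:

    // 1,000 objects to merge (shards plus the EOF marker already in the list):
    //   stage 0: ceil(1000 / 32) = 32 compose requests -> 32 intermediates
    //            (31 groups of 32, plus one group of 8)
    //   the while-loop then exits at 32 objects, and the final compose below
    //   merges them into `dest` in a single request.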
 
+    LOG.info("Combining a final group of " + sortedShardsNames.size() + " shards");
+    final String combineResult = composeAndCleanupShards(storage,
+        sortedShardsNames, dest);
+    generateIndex(options, storage, combineResult);
    return combineResult;
  }
+
+  static void generateIndex(ShardedBAMWritingOptions options,
+      Storage.Objects storage, String bamFilePath) throws IOException {
+    final String baiFilePath = bamFilePath + ".bai";
+    Stopwatch timer = Stopwatch.createStarted();
+    LOG.info("Generating BAM index: " + baiFilePath);
+    LOG.info("Reading BAM file: " + bamFilePath);
+    final SamReader reader = BAMIO.openBAM(storage, bamFilePath, ValidationStringency.LENIENT, true);
+
+    final OutputStream outputStream =
+        Channels.newOutputStream(
+            new GcsUtil.GcsUtilFactory().create(options)
+                .create(GcsPath.fromUri(baiFilePath),
+                    "application/octet-stream"));
+    BAMIndexer indexer = new BAMIndexer(outputStream, reader.getFileHeader());
+
+    long processedReads = 0;
+
+    // create and write the content
+    for (SAMRecord rec : reader) {
+      if (++processedReads % 1000000 == 0) {
+        dumpStats(processedReads, timer);
+      }
+      indexer.processAlignment(rec);
+    }
+    indexer.finish();
+    dumpStats(processedReads, timer);
+  }
+
+  static void dumpStats(long processedReads, Stopwatch timer) {
+    LOG.info("Processed " + processedReads + " records in " + timer +
+        ". Speed: " + (processedReads * 1000) / timer.elapsed(TimeUnit.MILLISECONDS) + " reads/sec");
+  }
  }
 
+  static String composeAndCleanupShards(Storage.Objects storage,
+      List<String> shardNames, String dest) throws IOException {
+    LOG.info("Combining shards into " + dest);
+
+    final GcsPath destPath = GcsPath.fromUri(dest);
+
+    StorageObject destination = new StorageObject()
+        .setContentType("application/octet-stream");
+
+    ArrayList<SourceObjects> sourceObjects = new ArrayList<SourceObjects>();
+    for (String shard : shardNames) {
+      final GcsPath shardPath = GcsPath.fromUri(shard);
+      LOG.info("Adding shard " + shardPath + " for result " + dest);
+      sourceObjects.add( new SourceObjects().setName(shardPath.getObject()) );
+    }
+
+    final ComposeRequest composeRequest = new ComposeRequest()
+        .setDestination(destination)
+        .setSourceObjects(sourceObjects);
+    final Compose compose = storage.compose(
+        destPath.getBucket(), destPath.getObject(), composeRequest);
+    final StorageObject result = compose.execute();
+    final String combineResult = GcsPath.fromObject(result).toString();
+    LOG.info("Combine result is " + combineResult);
+    for (SourceObjects sourceObject : sourceObjects) {
+      final String shardToDelete = sourceObject.getName();
+      LOG.info("Cleaning up shard " + shardToDelete + " for result " + dest);
+      storage.delete(destPath.getBucket(), shardToDelete).execute();
+    }
+
+    return combineResult;
+  }
 }
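A caveat worth noting on generateIndex(): BAM indexing assumes
coordinate-sorted input, which holds here because shard names embed
zero-padded start positions (see writeShard), so Collections.sort composes
them in genomic order. A minimal local sketch of the same indexing loop
(file names are hypothetical, error handling elided; uses
htsjdk.samtools.* and java.io.File):

    try (SamReader reader = SamReaderFactory.makeDefault()
            .validationStringency(ValidationStringency.LENIENT)
            .open(new File("sample.bam"))) {
      BAMIndexer indexer = new BAMIndexer(new File("sample.bam.bai"),
          reader.getFileHeader());
      for (SAMRecord rec : reader) {
        indexer.processAlignment(rec);
      }
      indexer.finish();  // flushes the last bin and writes the index
    }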
diff --git a/src/main/java/com/google/cloud/genomics/dataflow/readers/bam/BAMIO.java b/src/main/java/com/google/cloud/genomics/dataflow/readers/bam/BAMIO.java
index b8edf8b..cebd15b 100644
--- a/src/main/java/com/google/cloud/genomics/dataflow/readers/bam/BAMIO.java
+++ b/src/main/java/com/google/cloud/genomics/dataflow/readers/bam/BAMIO.java
@@ -43,13 +43,18 @@ public static ReaderAndIndex openBAMAndExposeIndex(Storage.Objects storageClient
     ReaderAndIndex result = new ReaderAndIndex();
     result.index = openIndexForPath(storageClient, gcsStoragePath);
     result.reader = openBAMReader(
-        openBAMFile(storageClient, gcsStoragePath,result.index), stringency);
+        openBAMFile(storageClient, gcsStoragePath,result.index), stringency, false);
     return result;
   }
 
-  public static SamReader openBAM(Storage.Objects storageClient, String gcsStoragePath, ValidationStringency stringency) throws IOException {
+  public static SamReader openBAM(Storage.Objects storageClient, String gcsStoragePath,
+      ValidationStringency stringency, boolean includeFileSource) throws IOException {
     return openBAMReader(openBAMFile(storageClient, gcsStoragePath,
-        openIndexForPath(storageClient, gcsStoragePath)), stringency);
+        openIndexForPath(storageClient, gcsStoragePath)), stringency, includeFileSource);
+  }
+
+  public static SamReader openBAM(Storage.Objects storageClient, String gcsStoragePath, ValidationStringency stringency) throws IOException {
+    return openBAM(storageClient, gcsStoragePath, stringency, false);
   }
 
   private static SeekableStream openIndexForPath(Storage.Objects storageClient,String gcsStoragePath) {
@@ -74,9 +79,14 @@ private static SamInputResource openBAMFile(Storage.Objects storageClient, Strin
     return samInputResource;
   }
 
-  private static SamReader openBAMReader(SamInputResource resource, ValidationStringency stringency) {
-    SamReaderFactory samReaderFactory = SamReaderFactory.makeDefault().validationStringency(stringency)
+  private static SamReader openBAMReader(SamInputResource resource, ValidationStringency stringency, boolean includeFileSource) {
+    SamReaderFactory samReaderFactory = SamReaderFactory
+        .makeDefault()
+        .validationStringency(stringency)
         .enable(SamReaderFactory.Option.CACHE_FILE_BASED_INDEXES);
+    if (includeFileSource) {
+      samReaderFactory.enable(SamReaderFactory.Option.INCLUDE_SOURCE_IN_RECORDS);
+    }
     final SamReader samReader = samReaderFactory.open(resource);
     return samReader;
   }
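A hypothetical call site for the new overload (the storage client setup and
the bucket path are placeholders). With includeFileSource = true, records
carry their SAMFileSource -- which is why generateIndex() above opens the
combined BAM this way before feeding records to BAMIndexer:

    final SamReader reader = BAMIO.openBAM(
        storageClient,
        "gs://my-bucket/sample.bam",
        ValidationStringency.LENIENT,
        true /* includeFileSource */);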
diff --git a/src/main/java/com/google/cloud/genomics/dataflow/readers/bam/ReadBAMTransform.java b/src/main/java/com/google/cloud/genomics/dataflow/readers/bam/ReadBAMTransform.java
index f7b7a0b..4f8b648 100644
--- a/src/main/java/com/google/cloud/genomics/dataflow/readers/bam/ReadBAMTransform.java
+++ b/src/main/java/com/google/cloud/genomics/dataflow/readers/bam/ReadBAMTransform.java
@@ -20,11 +20,13 @@
 import com.google.cloud.dataflow.sdk.Pipeline;
 import com.google.cloud.dataflow.sdk.coders.SerializableCoder;
 import com.google.cloud.dataflow.sdk.coders.StringUtf8Coder;
+import com.google.cloud.dataflow.sdk.transforms.Aggregator;
 import com.google.cloud.dataflow.sdk.transforms.Create;
 import com.google.cloud.dataflow.sdk.transforms.DoFn;
 import com.google.cloud.dataflow.sdk.transforms.PTransform;
 import com.google.cloud.dataflow.sdk.transforms.ParDo;
 import com.google.cloud.dataflow.sdk.transforms.View;
+import com.google.cloud.dataflow.sdk.transforms.Sum.SumIntegerFn;
 import com.google.cloud.dataflow.sdk.util.Transport;
 import com.google.cloud.dataflow.sdk.values.PCollection;
 import com.google.cloud.dataflow.sdk.values.PCollectionTuple;
@@ -52,10 +54,20 @@ public static class ReadFn extends DoFn<BAMShard, Read> {
     GenomicsFactory.OfflineAuth auth;
     Storage.Objects storage;
     ReaderOptions options;
+    Aggregator<Integer, Integer> recordCountAggregator;
+    Aggregator<Integer, Integer> readCountAggregator;
+    Aggregator<Integer, Integer> skippedStartCountAggregator;
+    Aggregator<Integer, Integer> skippedEndCountAggregator;
+    Aggregator<Integer, Integer> skippedRefMismatchAggregator;
 
     public ReadFn(GenomicsFactory.OfflineAuth auth, ReaderOptions options) {
       this.auth = auth;
       this.options = options;
+      recordCountAggregator = createAggregator("Processed records", new SumIntegerFn());
+      readCountAggregator = createAggregator("Reads generated", new SumIntegerFn());
+      skippedStartCountAggregator = createAggregator("Skipped start", new SumIntegerFn());
+      skippedEndCountAggregator = createAggregator("Skipped end", new SumIntegerFn());
+      skippedRefMismatchAggregator = createAggregator("Ref mismatch", new SumIntegerFn());
     }
 
     @Override
@@ -65,8 +77,13 @@ public void startBundle(DoFn.Context c) throws IOException {
 
     @Override
     public void processElement(ProcessContext c) throws java.lang.Exception {
-      (new Reader(storage, options, c.element(), c))
-          .process();
+      final Reader reader = new Reader(storage, options, c.element(), c);
+      reader.process();
+      recordCountAggregator.addValue(reader.recordsProcessed);
+      skippedStartCountAggregator.addValue(reader.recordsBeforeStart);
+      skippedEndCountAggregator.addValue(reader.recordsAfterEnd);
+      skippedRefMismatchAggregator.addValue(reader.mismatchedSequence);
+      readCountAggregator.addValue(reader.readsGenerated);
     }
   }
diff --git a/src/main/java/com/google/cloud/genomics/dataflow/readers/bam/Reader.java b/src/main/java/com/google/cloud/genomics/dataflow/readers/bam/Reader.java
index 81401b5..78cbb1d 100644
--- a/src/main/java/com/google/cloud/genomics/dataflow/readers/bam/Reader.java
+++ b/src/main/java/com/google/cloud/genomics/dataflow/readers/bam/Reader.java
@@ -56,10 +56,11 @@ enum Filter {
 
   Filter filter;
 
-  int recordsBeforeStart = 0;
-  int recordsAfterEnd = 0;
-  int mismatchedSequence = 0;
-  int recordsProcessed = 0;
+  public int recordsBeforeStart = 0;
+  public int recordsAfterEnd = 0;
+  public int mismatchedSequence = 0;
+  public int recordsProcessed = 0;
+  public int readsGenerated = 0;
 
   public Reader(Objects storageClient, ReaderOptions options, BAMShard shard, DoFn<BAMShard, Read>.ProcessContext c) {
     super();
@@ -110,10 +111,10 @@ void openFile() throws IOException {
       LOG.info("Processing unmapped");
       iterator = reader.queryUnmapped();
     } else if (shard.span != null) {
-      LOG.info("Processing span");
+      LOG.info("Processing span for " + shard.contig);
       iterator = reader.indexing().iterator(shard.span);
     } else if (shard.contig.referenceName != null && !shard.contig.referenceName.isEmpty()) {
-      LOG.info("Processing all bases for " + shard.contig.referenceName);
+      LOG.info("Processing all bases for " + shard.contig);
       iterator = reader.query(shard.contig.referenceName, (int) shard.contig.start,
           (int) shard.contig.end, false);
     }
@@ -160,6 +161,7 @@ boolean passesFilter(SAMRecord record) {
   }
 
   void processRecord(SAMRecord record) {
+    recordsProcessed++;
     if (!passesFilter(record)) {
       mismatchedSequence++;
       return;
     }
@@ -168,20 +170,20 @@
       recordsBeforeStart++;
       return;
     }
-    if (record.getAlignmentStart() >= shard.contig.end) {
+    if (record.getAlignmentStart() > shard.contig.end) {
       recordsAfterEnd++;
       return;
     }
     c.output(ReadUtils.makeRead(record));
-    recordsProcessed++;
+    readsGenerated++;
   }
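The comparison change from >= to > treats the shard's end coordinate as
inclusive, consistent with htsjdk's 1-based, inclusive alignment starts.
Spelled out with hypothetical numbers:

    // shard.contig = chr1:10000-20000 (end now inclusive on output)
    // record.getAlignmentStart() == 20000 -> kept   (was dropped under ">=")
    // record.getAlignmentStart() == 20001 -> recordsAfterEnd++, skipped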
 
  void dumpStats() {
    timer.stop();
-    LOG.info("Processed " + recordsProcessed +
+    LOG.info("Processed " + recordsProcessed + " outputted " + readsGenerated +
        " in " + timer +
        ". Speed: " + (recordsProcessed*1000)/timer.elapsed(TimeUnit.MILLISECONDS) + " reads/sec"
-        + ", skipped other sequences " + mismatchedSequence
+        + ", filtered out by reference and mapping " + mismatchedSequence
        + ", skippedBefore " + recordsBeforeStart
        + ", skipped after " + recordsAfterEnd);
  }
@@ -219,7 +221,7 @@ public static Iterable<Read> readSequentiallyForTesting(Objects storageClient,
        recordsBeforeStart++;
        continue;
      }
-      if (record.getAlignmentStart() >= contig.end) {
+      if (record.getAlignmentStart() > contig.end) {
        recordsAfterEnd++;
        continue;
      }
diff --git a/src/main/java/com/google/cloud/genomics/dataflow/readers/bam/Sharder.java b/src/main/java/com/google/cloud/genomics/dataflow/readers/bam/Sharder.java
index f652537..c3a8f0a 100644
--- a/src/main/java/com/google/cloud/genomics/dataflow/readers/bam/Sharder.java
+++ b/src/main/java/com/google/cloud/genomics/dataflow/readers/bam/Sharder.java
@@ -178,6 +178,7 @@ Contig desiredContigForReference(SAMSequenceRecord reference) {
  }
 
  void createShardsForReference(SAMSequenceRecord reference, Contig contig) {
+    LOG.info("Creating shard for: " + contig);
    final BitSet overlappingBins = GenomicIndexUtil.regionToBins(
        (int) contig.start, (int) contig.end);
    if (overlappingBins == null) {
@@ -185,6 +186,7 @@ void createShardsForReference(SAMSequenceRecord reference, Contig contig) {
      return;
    }
 
+    BAMShard currentShard = null;
    for (int binIndex = overlappingBins.nextSetBit(0); binIndex >= 0; binIndex = overlappingBins.nextSetBit(binIndex + 1)) {
      final Bin bin = index.getBinData(reference.getSequenceIndex(), binIndex);
@@ -234,13 +236,16 @@ void createShardsForReference(SAMSequenceRecord reference, Contig contig) {
      if (shardingPolicy.shardBigEnough(currentShard)) {
        LOG.info("Shard size is big enough to finalize: " +
            currentShard.sizeInLoci() + ", " + currentShard.approximateSizeInBytes() + " bytes");
-        output.output(currentShard.finalize(index, Math.min(index.getLastLocusInBin(bin), (int)contig.end)));
+        final BAMShard bamShard = currentShard.finalize(index, Math.min(index.getLastLocusInBin(bin), (int)contig.end));
+        LOG.info("Outputting shard: " + bamShard.contig);
+        output.output(bamShard);
        currentShard = null;
      }
    }
    if (currentShard != null) {
      LOG.info("Outputting last shard of size " +
-          currentShard.sizeInLoci() + ", " + currentShard.approximateSizeInBytes() + " bytes");
+          currentShard.sizeInLoci() + ", " + currentShard.approximateSizeInBytes() + " bytes " +
+          currentShard.contig);
      output.output(currentShard.finalize(index, (int)contig.end));
    }
  }
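For reference, ShardingPolicy is the knob that decides when
createShardsForReference() finalizes a shard; the commented-out block in
ShardedBAMWriting shows the loci-based variant. As a compilable sketch
(the threshold is illustrative):

    static final ShardingPolicy LOCI_BASED_POLICY = new ShardingPolicy() {
      @Override
      public boolean shardBigEnough(BAMShard shard) {
        // Finalize once a shard spans ten million loci; the BYTE_SIZE_POLICY
        // the pipeline actually uses keys off approximateSizeInBytes() instead.
        return shard.sizeInLoci() > 10000000;
      }
    };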
diff --git a/src/main/java/com/google/cloud/genomics/dataflow/utils/TruncatedOutputStream.java b/src/main/java/com/google/cloud/genomics/dataflow/utils/TruncatedOutputStream.java
index 1a4f6dc..70e01cb 100644
--- a/src/main/java/com/google/cloud/genomics/dataflow/utils/TruncatedOutputStream.java
+++ b/src/main/java/com/google/cloud/genomics/dataflow/utils/TruncatedOutputStream.java
@@ -46,16 +46,18 @@ public void write(byte[] data, int offset, int length) throws IOException {
       // We have more than bytesToTruncate to write, so clear the buffer
       // completely, and write all but bytesToTruncate directly to the stream.
       os.write(buf, 0, count);
-      os.write(data, offset, length - bytesToTruncate);
-      System.arraycopy(data, offset + length - bytesToTruncate, buf, 0, bytesToTruncate);
+      final int bytesToWriteDirectly = length - bytesToTruncate;
+      os.write(data, offset, bytesToWriteDirectly);
+      System.arraycopy(data, offset + bytesToWriteDirectly, buf, 0, bytesToTruncate);
       count = bytesToTruncate;
     } else {
       // Need this many of the current bytes to stay in the buffer to ensure we
       // have at least bytesToTruncate.
-      int keepInBuffer = bytesToTruncate - length;
+      final int keepInBuffer = bytesToTruncate - length;
       // Write the rest to the stream.
-      os.write(buf, 0, count - keepInBuffer);
-      System.arraycopy(buf, count - keepInBuffer, buf, 0, keepInBuffer);
+      final int bytesToDumpFromBuffer = count - keepInBuffer;
+      os.write(buf, 0, bytesToDumpFromBuffer);
+      System.arraycopy(buf, bytesToDumpFromBuffer, buf, 0, keepInBuffer);
       System.arraycopy(data, offset, buf, keepInBuffer, length);
       count = bytesToTruncate;
     }
@@ -74,9 +76,10 @@ public void close() throws IOException {
   }
 
   private void flushBuffer() throws IOException {
+    final int bytesWeCanSafelyWrite = count - bytesToTruncate;
     if (count > bytesToTruncate) {
-      os.write(buf, 0, count - bytesToTruncate);
-      System.arraycopy(buf, count - bytesToTruncate, buf, 0, bytesToTruncate);
+      os.write(buf, 0, bytesWeCanSafelyWrite);
+      System.arraycopy(buf, bytesWeCanSafelyWrite, buf, 0, bytesToTruncate);
       count = bytesToTruncate;
     }
   }
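A minimal behavioral sketch of the truncation invariant (JUnit scaffolding and
the throws-clause elided; assumes, as the constructor arranges, an internal
buffer at least bytesToTruncate bytes long; uses java.io.ByteArrayOutputStream
and java.util.Random):

    // Everything written through the stream should reach the underlying sink
    // except the final `bytesToTruncate` bytes -- sized here like the 28-byte
    // BGZF EOF block that writeShard() suppresses.
    ByteArrayOutputStream sink = new ByteArrayOutputStream();
    TruncatedOutputStream tos = new TruncatedOutputStream(sink, 28);
    byte[] data = new byte[100];
    new Random(42).nextBytes(data);
    tos.write(data, 0, data.length);
    tos.close();
    // sink.size() == 72: the last 28 bytes stayed in the internal buffer
    // and were deliberately dropped at close().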