diff --git a/src/main/java/com/google/cloud/genomics/dataflow/functions/ExtractSimilarCallsets.java b/src/main/java/com/google/cloud/genomics/dataflow/functions/ExtractSimilarCallsets.java index 25f288a..329f4b4 100644 --- a/src/main/java/com/google/cloud/genomics/dataflow/functions/ExtractSimilarCallsets.java +++ b/src/main/java/com/google/cloud/genomics/dataflow/functions/ExtractSimilarCallsets.java @@ -19,19 +19,26 @@ import com.google.cloud.dataflow.sdk.values.KV; import com.google.cloud.genomics.dataflow.utils.CallFilters; import com.google.cloud.genomics.dataflow.utils.PairGenerator; -import com.google.common.annotations.VisibleForTesting; import com.google.common.base.Function; +import com.google.common.collect.BiMap; +import com.google.common.collect.FluentIterable; import com.google.common.collect.ImmutableList; import com.google.common.collect.ImmutableMultiset; import com.google.common.collect.Iterables; import com.google.common.collect.Multiset; +import com.google.common.collect.Ordering; /** * Emits a callset pair every time they share a variant. */ -public class ExtractSimilarCallsets extends DoFn, Long>> { +public class ExtractSimilarCallsets extends DoFn, Long>> { - private ImmutableMultiset.Builder> accumulator; + private BiMap dataIndices; + private ImmutableMultiset.Builder> accumulator; + + public ExtractSimilarCallsets(BiMap dataIndices) { + this.dataIndices = dataIndices; + } @Override public void startBundle(Context c) { @@ -40,27 +47,27 @@ public void startBundle(Context c) { @Override public void processElement(ProcessContext context) { - for (KV pair : PairGenerator.WITH_REPLACEMENT.allPairs( - getSamplesWithVariant(context.element()), String.CASE_INSENSITIVE_ORDER)) { + FluentIterable> pairs = PairGenerator.WITH_REPLACEMENT.allPairs( + getSamplesWithVariant(context.element()), Ordering.natural()); + for (KV pair : pairs) { accumulator.add(pair); } } @Override public void finishBundle(Context context) { - for (Multiset.Entry> entry : accumulator.build().entrySet()) { + for (Multiset.Entry> entry : accumulator.build().entrySet()) { context.output(KV.of(entry.getElement(), Long.valueOf(entry.getCount()))); } } - @VisibleForTesting - static ImmutableList getSamplesWithVariant(Variant variant) { + ImmutableList getSamplesWithVariant(Variant variant) { return ImmutableList.copyOf(Iterables.transform( - CallFilters.getSamplesWithVariantOfMinGenotype(variant, 1), new Function() { + CallFilters.getSamplesWithVariantOfMinGenotype(variant, 1), new Function() { @Override - public String apply(Call call) { - return call.getCallSetName(); + public Integer apply(Call call) { + return dataIndices.get(call.getCallSetName()); } })); diff --git a/src/main/java/com/google/cloud/genomics/dataflow/functions/OutputPCoAFile.java b/src/main/java/com/google/cloud/genomics/dataflow/functions/OutputPCoAFile.java index 518249e..9cb1d59 100644 --- a/src/main/java/com/google/cloud/genomics/dataflow/functions/OutputPCoAFile.java +++ b/src/main/java/com/google/cloud/genomics/dataflow/functions/OutputPCoAFile.java @@ -24,6 +24,7 @@ import com.google.cloud.dataflow.sdk.values.KV; import com.google.cloud.dataflow.sdk.values.PCollection; import com.google.cloud.dataflow.sdk.values.PDone; +import com.google.common.collect.BiMap; import java.util.ArrayList; import java.util.List; @@ -36,10 +37,10 @@ * The input data must be for a similarity matrix which will be symmetric. This is not * the same as Principal Component Analysis. */ -public class OutputPCoAFile extends PTransform, Long>>, PDone> { +public class OutputPCoAFile extends PTransform, Long>>, PDone> { - private static final Combine.CombineFn, Long>, - List, Long>>, Iterable, Long>>> TO_LIST = + private static final Combine.CombineFn, Long>, + List, Long>>, Iterable, Long>>> TO_LIST = toList(); private static Combine.CombineFn, Iterable> toList() { @@ -69,20 +70,21 @@ private static Combine.CombineFn, Iterable> toList() { }; } + private BiMap dataIndices; private final String outputFile; - public OutputPCoAFile(String outputFile) { + public OutputPCoAFile(BiMap dataIndices, String outputFile) { + this.dataIndices = dataIndices; this.outputFile = outputFile; } @Override - public PDone apply(PCollection, Long>> similarPairs) { + public PDone apply(PCollection, Long>> similarPairs) { return similarPairs - .apply(Sum.>longsPerKey()) + .apply(Sum.>longsPerKey()) .apply(Combine.globally(TO_LIST)) - .apply(ParDo.named("PCoAAnalysis").of(PCoAnalysis.of())) - .apply(ParDo.named("FormatGraphData") - .of(new DoFn, String>() { + .apply(ParDo.of(new PCoAnalysis(dataIndices))) + .apply(ParDo.of(new DoFn, String>() { @Override public void processElement(ProcessContext c) throws Exception { Iterable graphResults = c.element(); diff --git a/src/main/java/com/google/cloud/genomics/dataflow/functions/PCoAnalysis.java b/src/main/java/com/google/cloud/genomics/dataflow/functions/PCoAnalysis.java index 269a68f..97a14f9 100644 --- a/src/main/java/com/google/cloud/genomics/dataflow/functions/PCoAnalysis.java +++ b/src/main/java/com/google/cloud/genomics/dataflow/functions/PCoAnalysis.java @@ -20,6 +20,9 @@ import com.google.api.client.util.Preconditions; import com.google.cloud.dataflow.sdk.transforms.DoFn; import com.google.cloud.dataflow.sdk.values.KV; +import com.google.cloud.genomics.dataflow.functions.PCoAnalysis.GraphResult; +import com.google.cloud.genomics.dataflow.utils.GenomicsDatasetOptions; +import com.google.cloud.genomics.utils.GenomicsFactory; import com.google.common.collect.BiMap; import com.google.common.collect.HashBiMap; import com.google.common.collect.ImmutableList; @@ -44,17 +47,16 @@ * The input data to this algorithm must be for a similarity matrix - and the * resulting matrix must be symmetric. * - * Input: KV(KV(dataName, dataName), count of how similar the data pair is) + * Input: KV(KV(dataIndex, dataIndex), count of how similar the data pair is) * Output: GraphResults - an x/y pair and a label * * Example input for a tiny dataset of size 2: * - * KV(KV(data1, data1), 5) - * KV(KV(data1, data2), 2) - * KV(KV(data2, data2), 5) - * KV(KV(data2, data1), 2) + * KV(KV(0, 0), 5) + * KV(KV(1, 1), 5) + * KV(KV(1, 0), 2) */ -public class PCoAnalysis extends DoFn, Long>>, +public class PCoAnalysis extends DoFn, Long>>, Iterable> { public static class GraphResult implements Serializable { @@ -118,20 +120,10 @@ public boolean equals(Object obj) { } } - private static final PCoAnalysis INSTANCE = new PCoAnalysis(); - private BiMap dataIndicies; - - public static PCoAnalysis of() { - return INSTANCE; - } - - private PCoAnalysis() {} + private BiMap dataIndices; - private int getDataIndex(Map dataIndicies, String dataName) { - if (!dataIndicies.containsKey(dataName)) { - dataIndicies.put(dataName, dataIndicies.size()); - } - return dataIndicies.get(dataName); + public PCoAnalysis(BiMap dataIndices) { + this.dataIndices = dataIndices; } // Convert the similarity matrix to an Eigen matrix. @@ -206,22 +198,16 @@ private List getPcaData(double[][] data, BiMap dat return results; } - @Override public void processElement(ProcessContext context) { - Collection, Long>> element = ImmutableList.copyOf(context.element()); - // Transform the data into a matrix - BiMap dataIndicies = HashBiMap.create(); + @Override + public void processElement(ProcessContext context) { + Collection, Long>> element = ImmutableList.copyOf(context.element()); - // TODO: Clean up this code - for (KV, Long> entry : element) { - getDataIndex(dataIndicies, entry.getKey().getKey()); - getDataIndex(dataIndicies, entry.getKey().getValue()); - } - - int dataSize = dataIndicies.size(); + // Transform the data into a matrix + int dataSize = dataIndices.size(); double[][] matrixData = new double[dataSize][dataSize]; - for (KV, Long> entry : element) { - int d1 = getDataIndex(dataIndicies, entry.getKey().getKey()); - int d2 = getDataIndex(dataIndicies, entry.getKey().getValue()); + for (KV, Long> entry : element) { + int d1 = entry.getKey().getKey(); + int d2 = entry.getKey().getValue(); double value = entry.getValue(); matrixData[d1][d2] = value; @@ -229,6 +215,6 @@ private List getPcaData(double[][] data, BiMap dat matrixData[d2][d1] = value; } } - context.output(getPcaData(matrixData, dataIndicies.inverse())); + context.output(getPcaData(matrixData, dataIndices.inverse())); } } \ No newline at end of file diff --git a/src/main/java/com/google/cloud/genomics/dataflow/pipelines/VariantSimilarity.java b/src/main/java/com/google/cloud/genomics/dataflow/pipelines/VariantSimilarity.java index 0918ac0..94b61cc 100644 --- a/src/main/java/com/google/cloud/genomics/dataflow/pipelines/VariantSimilarity.java +++ b/src/main/java/com/google/cloud/genomics/dataflow/pipelines/VariantSimilarity.java @@ -23,15 +23,19 @@ import com.google.cloud.genomics.dataflow.functions.ExtractSimilarCallsets; import com.google.cloud.genomics.dataflow.functions.OutputPCoAFile; import com.google.cloud.genomics.dataflow.readers.VariantReader; +import com.google.cloud.genomics.dataflow.readers.VariantStreamer; import com.google.cloud.genomics.dataflow.utils.DataflowWorkarounds; import com.google.cloud.genomics.dataflow.utils.GenomicsDatasetOptions; import com.google.cloud.genomics.dataflow.utils.GenomicsOptions; import com.google.cloud.genomics.utils.Contig.SexChromosomeFilter; import com.google.cloud.genomics.utils.GenomicsFactory; import com.google.cloud.genomics.utils.Paginator.ShardBoundary; +import com.google.common.collect.BiMap; +import com.google.common.collect.HashBiMap; import java.io.IOException; import java.security.GeneralSecurityException; +import java.util.Collections; import java.util.List; /** @@ -57,15 +61,21 @@ public static void main(String[] args) throws IOException, GeneralSecurityExcept GenomicsDatasetOptions.Methods.getVariantRequests(options, auth, SexChromosomeFilter.EXCLUDE_XY); + // Use integer indices instead of string callSetNames to reduce data sizes. + List callSetNames = VariantStreamer.getCallSetsNames(options.getDatasetId() , auth); + Collections.sort(callSetNames); // Ensure a stable sort order for reproducible results. + BiMap dataIndices = HashBiMap.create(); + for(String callSetName : callSetNames) { + dataIndices.put(callSetName, dataIndices.size()); + } + Pipeline p = Pipeline.create(options); DataflowWorkarounds.registerGenomicsCoders(p); p.begin() .apply(Create.of(requests)) - .apply( - ParDo.named("VariantReader").of( - new VariantReader(auth, ShardBoundary.STRICT, VARIANT_FIELDS))) - .apply(ParDo.named("ExtractSimilarCallsets").of(new ExtractSimilarCallsets())) - .apply(new OutputPCoAFile(options.getOutput())); + .apply(ParDo.of(new VariantReader(auth, ShardBoundary.STRICT, VARIANT_FIELDS))) + .apply(ParDo.of(new ExtractSimilarCallsets(dataIndices))) + .apply(new OutputPCoAFile(dataIndices, options.getOutput())); p.run(); } diff --git a/src/main/java/com/google/cloud/genomics/dataflow/readers/VariantStreamer.java b/src/main/java/com/google/cloud/genomics/dataflow/readers/VariantStreamer.java index e87ebb8..8892c9c 100644 --- a/src/main/java/com/google/cloud/genomics/dataflow/readers/VariantStreamer.java +++ b/src/main/java/com/google/cloud/genomics/dataflow/readers/VariantStreamer.java @@ -13,6 +13,8 @@ */ package com.google.cloud.genomics.dataflow.readers; +import com.google.api.services.genomics.model.CallSet; +import com.google.api.services.genomics.model.SearchCallSetsRequest; import com.google.api.services.genomics.model.SearchVariantSetsRequest; import com.google.api.services.genomics.model.VariantSet; import com.google.cloud.dataflow.sdk.transforms.Aggregator; @@ -32,6 +34,7 @@ import com.google.genomics.v1.StreamVariantsResponse; import com.google.genomics.v1.StreamingVariantServiceGrpc; import com.google.genomics.v1.Variant; + import java.io.IOException; import java.security.GeneralSecurityException; import java.util.Iterator; @@ -64,6 +67,25 @@ public static List getVariantSetIds(String datasetId, GenomicsFactory.Of return output; } + /** + * Gets CallSets Names for a given variantSetId using the Genomics API. + */ + public static List getCallSetsNames(String variantSetId, GenomicsFactory.OfflineAuth auth) + throws IOException, GeneralSecurityException { + List output = Lists.newArrayList(); + Iterable cs = Paginator.Callsets.create( + auth.getGenomics(auth.getDefaultFactory())) + .search(new SearchCallSetsRequest().setVariantSetIds(Lists.newArrayList(variantSetId)), + "callSets/name,nextPageToken"); + for (CallSet c : cs) { + output.add(c.getName()); + } + if (output.isEmpty()) { + throw new IOException("VariantSet " + variantSetId + " does not contain any CallSets"); + } + return output; + } + /** * Constructs a StreamVariantsRequest for a variantSetId, assuming that the user wants all * to include all references. diff --git a/src/test/java/com/google/cloud/genomics/dataflow/functions/ExtractSimilarCallsetsTest.java b/src/test/java/com/google/cloud/genomics/dataflow/functions/ExtractSimilarCallsetsTest.java index db2fd06..fac1138 100644 --- a/src/test/java/com/google/cloud/genomics/dataflow/functions/ExtractSimilarCallsetsTest.java +++ b/src/test/java/com/google/cloud/genomics/dataflow/functions/ExtractSimilarCallsetsTest.java @@ -15,24 +15,22 @@ */ package com.google.cloud.genomics.dataflow.functions; -import static com.google.cloud.genomics.dataflow.functions.ExtractSimilarCallsets.getSamplesWithVariant; import static org.junit.Assert.assertEquals; -import com.google.api.client.util.Lists; -import com.google.api.services.genomics.model.Call; -import com.google.api.services.genomics.model.Variant; -import com.google.cloud.dataflow.sdk.values.KV; -import com.google.cloud.genomics.dataflow.utils.PairGenerator; -import com.google.cloud.genomics.dataflow.utils.DataUtils; +import java.util.ArrayList; +import java.util.Arrays; +import java.util.Collections; +import java.util.List; import org.junit.Test; import org.junit.runner.RunWith; import org.junit.runners.JUnit4; -import java.util.ArrayList; -import java.util.Arrays; -import java.util.Collections; -import java.util.List; +import com.google.api.services.genomics.model.Call; +import com.google.api.services.genomics.model.Variant; +import com.google.cloud.genomics.dataflow.utils.DataUtils; +import com.google.common.collect.BiMap; +import com.google.common.collect.HashBiMap; @RunWith(JUnit4.class) public class ExtractSimilarCallsetsTest { @@ -40,21 +38,30 @@ public class ExtractSimilarCallsetsTest { @Test public void testGetSamplesWithVariant() throws Exception { + + BiMap dataIndices = HashBiMap.create(); + dataIndices.put("ref", dataIndices.size()); + dataIndices.put("alt1", dataIndices.size()); + dataIndices.put("alt2", dataIndices.size()); + dataIndices.put("alt3", dataIndices.size()); + + ExtractSimilarCallsets doFn = new ExtractSimilarCallsets(dataIndices); + Variant variant = new Variant(); List calls = new ArrayList(); variant.setCalls(calls); - assertEquals(Collections.emptyList(), getSamplesWithVariant(variant)); + assertEquals(Collections.emptyList(), doFn.getSamplesWithVariant(variant)); calls.add(DataUtils.makeCall("ref", 0, 0)); - assertEquals(Collections.emptyList(), getSamplesWithVariant(variant)); + assertEquals(Collections.emptyList(), doFn.getSamplesWithVariant(variant)); calls.add(DataUtils.makeCall("alt1", 1, 0)); - assertEquals(Collections.singletonList("alt1"), getSamplesWithVariant(variant)); + assertEquals(Collections.singletonList(1), doFn.getSamplesWithVariant(variant)); calls.add(DataUtils.makeCall("alt2", 0, 1)); calls.add(DataUtils.makeCall("alt3", 1, 1)); - assertEquals(Arrays.asList("alt1", "alt2", "alt3"), getSamplesWithVariant(variant)); + assertEquals(Arrays.asList(1, 2, 3), doFn.getSamplesWithVariant(variant)); } }