diff --git a/src/main/java/com/google/cloud/genomics/dataflow/functions/ExtractSimilarCallsets.java b/src/main/java/com/google/cloud/genomics/dataflow/functions/ExtractSimilarCallsets.java index b66aecf..7e46c66 100644 --- a/src/main/java/com/google/cloud/genomics/dataflow/functions/ExtractSimilarCallsets.java +++ b/src/main/java/com/google/cloud/genomics/dataflow/functions/ExtractSimilarCallsets.java @@ -34,14 +34,9 @@ */ public class ExtractSimilarCallsets { - abstract static class ExtractSimilarCallsetsBase extends DoFn, Long>> { + abstract static class ExtractSimilarCallsetsBase extends DoFn, Long>> { - protected BiMap dataIndices; - private ImmutableMultiset.Builder> accumulator; - - public ExtractSimilarCallsetsBase(BiMap dataIndices) { - this.dataIndices = dataIndices; - } + private ImmutableMultiset.Builder> accumulator; @Override public void startBundle(Context c) { @@ -50,37 +45,33 @@ public void startBundle(Context c) { @Override public void processElement(ProcessContext context) { - FluentIterable> pairs = PairGenerator.WITH_REPLACEMENT.allPairs( + FluentIterable> pairs = PairGenerator.WITH_REPLACEMENT.allPairs( getSamplesWithVariant(context.element()), Ordering.natural()); - for (KV pair : pairs) { + for (KV pair : pairs) { accumulator.add(pair); } } @Override public void finishBundle(Context context) { - for (Multiset.Entry> entry : accumulator.build().entrySet()) { + for (Multiset.Entry> entry : accumulator.build().entrySet()) { context.output(KV.of(entry.getElement(), Long.valueOf(entry.getCount()))); } } - protected abstract ImmutableList getSamplesWithVariant(V variant); + protected abstract ImmutableList getSamplesWithVariant(V variant); } public static class v1 extends ExtractSimilarCallsetsBase { - public v1(BiMap dataIndices) { - super(dataIndices); - } - @Override - protected ImmutableList getSamplesWithVariant(Variant variant) { + protected ImmutableList getSamplesWithVariant(Variant variant) { return ImmutableList.copyOf(Iterables.transform( - CallFilters.getSamplesWithVariantOfMinGenotype(variant, 1), new Function() { + CallFilters.getSamplesWithVariantOfMinGenotype(variant, 1), new Function() { @Override - public Integer apply(VariantCall call) { - return dataIndices.get(call.getCallSetName()); + public String apply(VariantCall call) { + return call.getCallSetName(); } })); @@ -90,18 +81,14 @@ public Integer apply(VariantCall call) { @Deprecated // Remove this when fully migrated to gRPC. public static class v1beta2 extends ExtractSimilarCallsetsBase { - public v1beta2(BiMap dataIndices) { - super(dataIndices); - } - @Override - protected ImmutableList getSamplesWithVariant(com.google.api.services.genomics.model.Variant variant) { + protected ImmutableList getSamplesWithVariant(com.google.api.services.genomics.model.Variant variant) { return ImmutableList.copyOf(Iterables.transform( - CallFilters.getSamplesWithVariantOfMinGenotype(variant, 1), new Function() { + CallFilters.getSamplesWithVariantOfMinGenotype(variant, 1), new Function() { @Override - public Integer apply(Call call) { - return dataIndices.get(call.getCallSetName()); + public String apply(Call call) { + return call.getCallSetName(); } })); diff --git a/src/main/java/com/google/cloud/genomics/dataflow/functions/OutputPCoAFile.java b/src/main/java/com/google/cloud/genomics/dataflow/functions/OutputPCoAFile.java index 9cb1d59..532425d 100644 --- a/src/main/java/com/google/cloud/genomics/dataflow/functions/OutputPCoAFile.java +++ b/src/main/java/com/google/cloud/genomics/dataflow/functions/OutputPCoAFile.java @@ -37,10 +37,10 @@ * The input data must be for a similarity matrix which will be symmetric. This is not * the same as Principal Component Analysis. */ -public class OutputPCoAFile extends PTransform, Long>>, PDone> { +public class OutputPCoAFile extends PTransform, Long>>, PDone> { - private static final Combine.CombineFn, Long>, - List, Long>>, Iterable, Long>>> TO_LIST = + private static final Combine.CombineFn, Long>, + List, Long>>, Iterable, Long>>> TO_LIST = toList(); private static Combine.CombineFn, Iterable> toList() { @@ -79,12 +79,13 @@ public OutputPCoAFile(BiMap dataIndices, String outputFile) { } @Override - public PDone apply(PCollection, Long>> similarPairs) { + public PDone apply(PCollection, Long>> similarPairs) { return similarPairs - .apply(Sum.>longsPerKey()) + .apply(Sum.>longsPerKey()) .apply(Combine.globally(TO_LIST)) - .apply(ParDo.of(new PCoAnalysis(dataIndices))) - .apply(ParDo.of(new DoFn, String>() { + .apply(ParDo.named("PCoAAnalysis").of(new PCoAnalysis(dataIndices))) + .apply(ParDo.named("FormatGraphData") + .of(new DoFn, String>() { @Override public void processElement(ProcessContext c) throws Exception { Iterable graphResults = c.element(); diff --git a/src/main/java/com/google/cloud/genomics/dataflow/functions/PCoAnalysis.java b/src/main/java/com/google/cloud/genomics/dataflow/functions/PCoAnalysis.java index 97a14f9..16cff4c 100644 --- a/src/main/java/com/google/cloud/genomics/dataflow/functions/PCoAnalysis.java +++ b/src/main/java/com/google/cloud/genomics/dataflow/functions/PCoAnalysis.java @@ -16,25 +16,20 @@ package com.google.cloud.genomics.dataflow.functions; +import java.io.Serializable; +import java.util.Collection; +import java.util.List; + +import Jama.EigenvalueDecomposition; +import Jama.Matrix; + import com.google.api.client.util.Lists; import com.google.api.client.util.Preconditions; import com.google.cloud.dataflow.sdk.transforms.DoFn; import com.google.cloud.dataflow.sdk.values.KV; -import com.google.cloud.genomics.dataflow.functions.PCoAnalysis.GraphResult; -import com.google.cloud.genomics.dataflow.utils.GenomicsDatasetOptions; -import com.google.cloud.genomics.utils.GenomicsFactory; import com.google.common.collect.BiMap; -import com.google.common.collect.HashBiMap; import com.google.common.collect.ImmutableList; -import Jama.EigenvalueDecomposition; -import Jama.Matrix; - -import java.io.Serializable; -import java.util.Collection; -import java.util.List; -import java.util.Map; - /** * This function runs a Principal Coordinate Analysis inside of a SeqDo. * It can not be parallelized. @@ -47,16 +42,17 @@ * The input data to this algorithm must be for a similarity matrix - and the * resulting matrix must be symmetric. * - * Input: KV(KV(dataIndex, dataIndex), count of how similar the data pair is) + * Input: KV(KV(dataName, dataName), count of how similar the data pair is) * Output: GraphResults - an x/y pair and a label * * Example input for a tiny dataset of size 2: * - * KV(KV(0, 0), 5) - * KV(KV(1, 1), 5) - * KV(KV(1, 0), 2) + * KV(KV(data1, data1), 5) + * KV(KV(data1, data2), 2) + * KV(KV(data2, data2), 5) + * KV(KV(data2, data1), 2) */ -public class PCoAnalysis extends DoFn, Long>>, +public class PCoAnalysis extends DoFn, Long>>, Iterable> { public static class GraphResult implements Serializable { @@ -198,17 +194,14 @@ private List getPcaData(double[][] data, BiMap dat return results; } - @Override - public void processElement(ProcessContext context) { - Collection, Long>> element = ImmutableList.copyOf(context.element()); + @Override public void processElement(ProcessContext context) { + Collection, Long>> element = ImmutableList.copyOf(context.element()); - // Transform the data into a matrix int dataSize = dataIndices.size(); double[][] matrixData = new double[dataSize][dataSize]; - for (KV, Long> entry : element) { - int d1 = entry.getKey().getKey(); - int d2 = entry.getKey().getValue(); - + for (KV, Long> entry : element) { + int d1 = dataIndices.get(entry.getKey().getKey()); + int d2 = dataIndices.get(entry.getKey().getValue()); double value = entry.getValue(); matrixData[d1][d2] = value; if (d1 != d2) { diff --git a/src/main/java/com/google/cloud/genomics/dataflow/pipelines/VariantSimilarity.java b/src/main/java/com/google/cloud/genomics/dataflow/pipelines/VariantSimilarity.java index dc6ad82..3c3f4e7 100644 --- a/src/main/java/com/google/cloud/genomics/dataflow/pipelines/VariantSimilarity.java +++ b/src/main/java/com/google/cloud/genomics/dataflow/pipelines/VariantSimilarity.java @@ -76,7 +76,6 @@ public static void main(String[] args) throws IOException, GeneralSecurityExcept GenomicsFactory.OfflineAuth auth = GenomicsOptions.Methods.getGenomicsAuth(options); - // Use integer indices instead of string callSetNames to reduce data sizes. List callSetNames = GenomicsUtils.getCallSetsNames(options.getDatasetId() , auth); Collections.sort(callSetNames); // Ensure a stable sort order for reproducible results. BiMap dataIndices = HashBiMap.create(); @@ -92,7 +91,7 @@ public static void main(String[] args) throws IOException, GeneralSecurityExcept Proto2Coder.of(StreamVariantsRequest.class)); p.begin(); - PCollection, Long>> similarCallsets = null; + PCollection, Long>> similarCallsets = null; if(options.getUseGrpc()) { List requests = options.isAllReferences() ? @@ -102,7 +101,7 @@ public static void main(String[] args) throws IOException, GeneralSecurityExcept similarCallsets = p.apply(Create.of(requests)) .apply(new VariantStreamer(auth, ShardBoundary.Requirement.STRICT, VARIANT_FIELDS)) - .apply(ParDo.of(new ExtractSimilarCallsets.v1(dataIndices))); + .apply(ParDo.of(new ExtractSimilarCallsets.v1())); } else { List requests = options.isAllReferences() ? ShardUtils.getPaginatedVariantRequests(options.getDatasetId(), ShardUtils.SexChromosomeFilter.EXCLUDE_XY, @@ -111,7 +110,7 @@ public static void main(String[] args) throws IOException, GeneralSecurityExcept similarCallsets = p.apply(Create.of(requests)) .apply(ParDo.of(new VariantReader(auth, ShardBoundary.Requirement.STRICT, VARIANT_FIELDS))) - .apply(ParDo.of(new ExtractSimilarCallsets.v1beta2(dataIndices))); + .apply(ParDo.of(new ExtractSimilarCallsets.v1beta2())); } similarCallsets.apply(new OutputPCoAFile(dataIndices, options.getOutput())); diff --git a/src/test/java/com/google/cloud/genomics/dataflow/functions/ExtractSimilarCallsetsTest.java b/src/test/java/com/google/cloud/genomics/dataflow/functions/ExtractSimilarCallsetsTest.java index ca525ec..c6fa513 100644 --- a/src/test/java/com/google/cloud/genomics/dataflow/functions/ExtractSimilarCallsetsTest.java +++ b/src/test/java/com/google/cloud/genomics/dataflow/functions/ExtractSimilarCallsetsTest.java @@ -37,20 +37,9 @@ @RunWith(JUnit4.class) public class ExtractSimilarCallsetsTest { - BiMap dataIndices; - - @Before - public void setUp() { - dataIndices = HashBiMap.create(); - dataIndices.put("ref", dataIndices.size()); - dataIndices.put("alt1", dataIndices.size()); - dataIndices.put("alt2", dataIndices.size()); - dataIndices.put("alt3", dataIndices.size()); - } - @Test public void testGetSamplesWithVariant() throws Exception { - ExtractSimilarCallsets.v1 doFn = new ExtractSimilarCallsets.v1(dataIndices); + ExtractSimilarCallsets.v1 doFn = new ExtractSimilarCallsets.v1(); Variant variant = Variant.newBuilder().build(); assertEquals(Collections.emptyList(), doFn.getSamplesWithVariant(variant)); @@ -61,18 +50,18 @@ public void testGetSamplesWithVariant() throws Exception { VariantCall alt1 = DataUtils.makeVariantCall("alt1", 1, 0); variant = Variant.newBuilder().addCalls(ref).addCalls(alt1).build(); - assertEquals(Collections.singletonList(1), doFn.getSamplesWithVariant(variant)); + assertEquals(Collections.singletonList("alt1"), doFn.getSamplesWithVariant(variant)); VariantCall alt2 = DataUtils.makeVariantCall("alt2", 0, 1); VariantCall alt3 = DataUtils.makeVariantCall("alt3", 1, 1); variant = Variant.newBuilder().addCalls(ref) .addCalls(alt1).addCalls(alt2).addCalls(alt3).build(); - assertEquals(Arrays.asList(1, 2, 3), doFn.getSamplesWithVariant(variant)); + assertEquals(Arrays.asList("alt1", "alt2", "alt3"), doFn.getSamplesWithVariant(variant)); } @Test public void testGetSamplesWithVariantv1beta2() throws Exception { - ExtractSimilarCallsets.v1beta2 doFn = new ExtractSimilarCallsets.v1beta2(dataIndices); + ExtractSimilarCallsets.v1beta2 doFn = new ExtractSimilarCallsets.v1beta2(); com.google.api.services.genomics.model.Variant variant = new com.google.api.services.genomics.model.Variant(); List calls = new ArrayList(); @@ -84,10 +73,10 @@ public void testGetSamplesWithVariantv1beta2() throws Exception { assertEquals(Collections.emptyList(), doFn.getSamplesWithVariant(variant)); calls.add(DataUtils.makeCall("alt1", 1, 0)); - assertEquals(Collections.singletonList(1), doFn.getSamplesWithVariant(variant)); + assertEquals(Collections.singletonList("alt1"), doFn.getSamplesWithVariant(variant)); calls.add(DataUtils.makeCall("alt2", 0, 1)); calls.add(DataUtils.makeCall("alt3", 1, 1)); - assertEquals(Arrays.asList(1, 2, 3), doFn.getSamplesWithVariant(variant)); + assertEquals(Arrays.asList("alt1", "alt2", "alt3"), doFn.getSamplesWithVariant(variant)); } }