Skip to content
This repository has been archived by the owner on Oct 29, 2023. It is now read-only.

Commit

Permalink
Revert back to string indices for variant similarity.
Browse files Browse the repository at this point in the history
This commit reverses the negative impact on performance of commit 98cf2c2 while maintaining a stable sort on the final results.
  • Loading branch information
deflaux committed Aug 13, 2015
1 parent e879e3c commit b5c25c4
Show file tree
Hide file tree
Showing 5 changed files with 49 additions and 80 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -34,14 +34,9 @@
*/
public class ExtractSimilarCallsets {

abstract static class ExtractSimilarCallsetsBase<V, C> extends DoFn<V, KV<KV<Integer, Integer>, Long>> {
abstract static class ExtractSimilarCallsetsBase<V, C> extends DoFn<V, KV<KV<String, String>, Long>> {

protected BiMap<String, Integer> dataIndices;
private ImmutableMultiset.Builder<KV<Integer, Integer>> accumulator;

public ExtractSimilarCallsetsBase(BiMap<String, Integer> dataIndices) {
this.dataIndices = dataIndices;
}
private ImmutableMultiset.Builder<KV<String, String>> accumulator;

@Override
public void startBundle(Context c) {
Expand All @@ -50,37 +45,33 @@ public void startBundle(Context c) {

@Override
public void processElement(ProcessContext context) {
FluentIterable<KV<Integer, Integer>> pairs = PairGenerator.WITH_REPLACEMENT.allPairs(
FluentIterable<KV<String, String>> pairs = PairGenerator.WITH_REPLACEMENT.allPairs(
getSamplesWithVariant(context.element()), Ordering.natural());
for (KV<Integer, Integer> pair : pairs) {
for (KV<String, String> pair : pairs) {
accumulator.add(pair);
}
}

@Override
public void finishBundle(Context context) {
for (Multiset.Entry<KV<Integer, Integer>> entry : accumulator.build().entrySet()) {
for (Multiset.Entry<KV<String, String>> entry : accumulator.build().entrySet()) {
context.output(KV.of(entry.getElement(), Long.valueOf(entry.getCount())));
}
}

protected abstract ImmutableList<Integer> getSamplesWithVariant(V variant);
protected abstract ImmutableList<String> getSamplesWithVariant(V variant);
}

public static class v1 extends ExtractSimilarCallsetsBase<Variant, VariantCall> {

public v1(BiMap<String, Integer> dataIndices) {
super(dataIndices);
}

@Override
protected ImmutableList<Integer> getSamplesWithVariant(Variant variant) {
protected ImmutableList<String> getSamplesWithVariant(Variant variant) {
return ImmutableList.copyOf(Iterables.transform(
CallFilters.getSamplesWithVariantOfMinGenotype(variant, 1), new Function<VariantCall, Integer>() {
CallFilters.getSamplesWithVariantOfMinGenotype(variant, 1), new Function<VariantCall, String>() {

@Override
public Integer apply(VariantCall call) {
return dataIndices.get(call.getCallSetName());
public String apply(VariantCall call) {
return call.getCallSetName();
}

}));
Expand All @@ -90,18 +81,14 @@ public Integer apply(VariantCall call) {
@Deprecated // Remove this when fully migrated to gRPC.
public static class v1beta2 extends ExtractSimilarCallsetsBase<com.google.api.services.genomics.model.Variant, Call> {

public v1beta2(BiMap<String, Integer> dataIndices) {
super(dataIndices);
}

@Override
protected ImmutableList<Integer> getSamplesWithVariant(com.google.api.services.genomics.model.Variant variant) {
protected ImmutableList<String> getSamplesWithVariant(com.google.api.services.genomics.model.Variant variant) {
return ImmutableList.copyOf(Iterables.transform(
CallFilters.getSamplesWithVariantOfMinGenotype(variant, 1), new Function<Call, Integer>() {
CallFilters.getSamplesWithVariantOfMinGenotype(variant, 1), new Function<Call, String>() {

@Override
public Integer apply(Call call) {
return dataIndices.get(call.getCallSetName());
public String apply(Call call) {
return call.getCallSetName();
}

}));
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -37,10 +37,10 @@
* The input data must be for a similarity matrix, which will be symmetric. This is not
* the same as Principal Component Analysis.
*/
public class OutputPCoAFile extends PTransform<PCollection<KV<KV<Integer, Integer>, Long>>, PDone> {
public class OutputPCoAFile extends PTransform<PCollection<KV<KV<String, String>, Long>>, PDone> {

private static final Combine.CombineFn<KV<KV<Integer, Integer>, Long>,
List<KV<KV<Integer, Integer>, Long>>, Iterable<KV<KV<Integer, Integer>, Long>>> TO_LIST =
private static final Combine.CombineFn<KV<KV<String, String>, Long>,
List<KV<KV<String, String>, Long>>, Iterable<KV<KV<String, String>, Long>>> TO_LIST =
toList();

private static <X> Combine.CombineFn<X, List<X>, Iterable<X>> toList() {
Expand Down Expand Up @@ -79,12 +79,13 @@ public OutputPCoAFile(BiMap<String, Integer> dataIndices, String outputFile) {
}

@Override
public PDone apply(PCollection<KV<KV<Integer, Integer>, Long>> similarPairs) {
public PDone apply(PCollection<KV<KV<String, String>, Long>> similarPairs) {
return similarPairs
.apply(Sum.<KV<Integer, Integer>>longsPerKey())
.apply(Sum.<KV<String, String>>longsPerKey())
.apply(Combine.globally(TO_LIST))
.apply(ParDo.of(new PCoAnalysis(dataIndices)))
.apply(ParDo.of(new DoFn<Iterable<PCoAnalysis.GraphResult>, String>() {
.apply(ParDo.named("PCoAAnalysis").of(new PCoAnalysis(dataIndices)))
.apply(ParDo.named("FormatGraphData")
.of(new DoFn<Iterable<PCoAnalysis.GraphResult>, String>() {
@Override
public void processElement(ProcessContext c) throws Exception {
Iterable<PCoAnalysis.GraphResult> graphResults = c.element();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,25 +16,20 @@

package com.google.cloud.genomics.dataflow.functions;

import java.io.Serializable;
import java.util.Collection;
import java.util.List;

import Jama.EigenvalueDecomposition;
import Jama.Matrix;

import com.google.api.client.util.Lists;
import com.google.api.client.util.Preconditions;
import com.google.cloud.dataflow.sdk.transforms.DoFn;
import com.google.cloud.dataflow.sdk.values.KV;
import com.google.cloud.genomics.dataflow.functions.PCoAnalysis.GraphResult;
import com.google.cloud.genomics.dataflow.utils.GenomicsDatasetOptions;
import com.google.cloud.genomics.utils.GenomicsFactory;
import com.google.common.collect.BiMap;
import com.google.common.collect.HashBiMap;
import com.google.common.collect.ImmutableList;

import Jama.EigenvalueDecomposition;
import Jama.Matrix;

import java.io.Serializable;
import java.util.Collection;
import java.util.List;
import java.util.Map;

/**
* This function runs a Principal Coordinate Analysis inside a SeqDo.
* It cannot be parallelized.
Expand All @@ -47,16 +42,17 @@
* The input data to this algorithm must be for a similarity matrix, and the
* resulting matrix must be symmetric.
*
* Input: KV(KV(dataIndex, dataIndex), count of how similar the data pair is)
* Input: KV(KV(dataName, dataName), count of how similar the data pair is)
* Output: GraphResults - an x/y pair and a label
*
* Example input for a tiny dataset of size 2:
*
* KV(KV(0, 0), 5)
* KV(KV(1, 1), 5)
* KV(KV(1, 0), 2)
* KV(KV(data1, data1), 5)
* KV(KV(data1, data2), 2)
* KV(KV(data2, data2), 5)
* KV(KV(data2, data1), 2)
*/
public class PCoAnalysis extends DoFn<Iterable<KV<KV<Integer, Integer>, Long>>,
public class PCoAnalysis extends DoFn<Iterable<KV<KV<String, String>, Long>>,
Iterable<PCoAnalysis.GraphResult>> {

public static class GraphResult implements Serializable {
Expand Down Expand Up @@ -198,17 +194,14 @@ private List<GraphResult> getPcaData(double[][] data, BiMap<Integer, String> dat
return results;
}

@Override
public void processElement(ProcessContext context) {
Collection<KV<KV<Integer, Integer>, Long>> element = ImmutableList.copyOf(context.element());
@Override public void processElement(ProcessContext context) {
Collection<KV<KV<String, String>, Long>> element = ImmutableList.copyOf(context.element());

// Transform the data into a matrix
int dataSize = dataIndices.size();
double[][] matrixData = new double[dataSize][dataSize];
for (KV<KV<Integer, Integer>, Long> entry : element) {
int d1 = entry.getKey().getKey();
int d2 = entry.getKey().getValue();

for (KV<KV<String, String>, Long> entry : element) {
int d1 = dataIndices.get(entry.getKey().getKey());
int d2 = dataIndices.get(entry.getKey().getValue());
double value = entry.getValue();
matrixData[d1][d2] = value;
if (d1 != d2) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -76,7 +76,6 @@ public static void main(String[] args) throws IOException, GeneralSecurityExcept

GenomicsFactory.OfflineAuth auth = GenomicsOptions.Methods.getGenomicsAuth(options);

// Use integer indices instead of string callSetNames to reduce data sizes.
List<String> callSetNames = GenomicsUtils.getCallSetsNames(options.getDatasetId() , auth);
Collections.sort(callSetNames); // Ensure a stable sort order for reproducible results.
BiMap<String, Integer> dataIndices = HashBiMap.create();
Expand All @@ -92,7 +91,7 @@ public static void main(String[] args) throws IOException, GeneralSecurityExcept
Proto2Coder.of(StreamVariantsRequest.class));

p.begin();
PCollection<KV<KV<Integer, Integer>, Long>> similarCallsets = null;
PCollection<KV<KV<String, String>, Long>> similarCallsets = null;

if(options.getUseGrpc()) {
List<StreamVariantsRequest> requests = options.isAllReferences() ?
Expand All @@ -102,7 +101,7 @@ public static void main(String[] args) throws IOException, GeneralSecurityExcept

similarCallsets = p.apply(Create.of(requests))
.apply(new VariantStreamer(auth, ShardBoundary.Requirement.STRICT, VARIANT_FIELDS))
.apply(ParDo.of(new ExtractSimilarCallsets.v1(dataIndices)));
.apply(ParDo.of(new ExtractSimilarCallsets.v1()));
} else {
List<SearchVariantsRequest> requests = options.isAllReferences() ?
ShardUtils.getPaginatedVariantRequests(options.getDatasetId(), ShardUtils.SexChromosomeFilter.EXCLUDE_XY,
Expand All @@ -111,7 +110,7 @@ public static void main(String[] args) throws IOException, GeneralSecurityExcept

similarCallsets = p.apply(Create.of(requests))
.apply(ParDo.of(new VariantReader(auth, ShardBoundary.Requirement.STRICT, VARIANT_FIELDS)))
.apply(ParDo.of(new ExtractSimilarCallsets.v1beta2(dataIndices)));
.apply(ParDo.of(new ExtractSimilarCallsets.v1beta2()));
}

similarCallsets.apply(new OutputPCoAFile(dataIndices, options.getOutput()));
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -37,20 +37,9 @@
@RunWith(JUnit4.class)
public class ExtractSimilarCallsetsTest {

BiMap<String, Integer> dataIndices;

@Before
public void setUp() {
dataIndices = HashBiMap.create();
dataIndices.put("ref", dataIndices.size());
dataIndices.put("alt1", dataIndices.size());
dataIndices.put("alt2", dataIndices.size());
dataIndices.put("alt3", dataIndices.size());
}

@Test
public void testGetSamplesWithVariant() throws Exception {
ExtractSimilarCallsets.v1 doFn = new ExtractSimilarCallsets.v1(dataIndices);
ExtractSimilarCallsets.v1 doFn = new ExtractSimilarCallsets.v1();

Variant variant = Variant.newBuilder().build();
assertEquals(Collections.emptyList(), doFn.getSamplesWithVariant(variant));
Expand All @@ -61,18 +50,18 @@ public void testGetSamplesWithVariant() throws Exception {

VariantCall alt1 = DataUtils.makeVariantCall("alt1", 1, 0);
variant = Variant.newBuilder().addCalls(ref).addCalls(alt1).build();
assertEquals(Collections.singletonList(1), doFn.getSamplesWithVariant(variant));
assertEquals(Collections.singletonList("alt1"), doFn.getSamplesWithVariant(variant));

VariantCall alt2 = DataUtils.makeVariantCall("alt2", 0, 1);
VariantCall alt3 = DataUtils.makeVariantCall("alt3", 1, 1);
variant = Variant.newBuilder().addCalls(ref)
.addCalls(alt1).addCalls(alt2).addCalls(alt3).build();
assertEquals(Arrays.asList(1, 2, 3), doFn.getSamplesWithVariant(variant));
assertEquals(Arrays.asList("alt1", "alt2", "alt3"), doFn.getSamplesWithVariant(variant));
}

@Test
public void testGetSamplesWithVariantv1beta2() throws Exception {
ExtractSimilarCallsets.v1beta2 doFn = new ExtractSimilarCallsets.v1beta2(dataIndices);
ExtractSimilarCallsets.v1beta2 doFn = new ExtractSimilarCallsets.v1beta2();

com.google.api.services.genomics.model.Variant variant = new com.google.api.services.genomics.model.Variant();
List<Call> calls = new ArrayList<Call>();
Expand All @@ -84,10 +73,10 @@ public void testGetSamplesWithVariantv1beta2() throws Exception {
assertEquals(Collections.emptyList(), doFn.getSamplesWithVariant(variant));

calls.add(DataUtils.makeCall("alt1", 1, 0));
assertEquals(Collections.singletonList(1), doFn.getSamplesWithVariant(variant));
assertEquals(Collections.singletonList("alt1"), doFn.getSamplesWithVariant(variant));

calls.add(DataUtils.makeCall("alt2", 0, 1));
calls.add(DataUtils.makeCall("alt3", 1, 1));
assertEquals(Arrays.asList(1, 2, 3), doFn.getSamplesWithVariant(variant));
assertEquals(Arrays.asList("alt1", "alt2", "alt3"), doFn.getSamplesWithVariant(variant));
}
}

0 comments on commit b5c25c4

Please sign in to comment.