Skip to content

Commit

Permalink
Merge pull request googlegenomics#134 from deflaux/master
Browse files Browse the repository at this point in the history
Revert back to string indices for variant similarity.
  • Loading branch information
deflaux committed Aug 13, 2015
2 parents b695ce0 + c2f1e0f commit 3028a83
Show file tree
Hide file tree
Showing 5 changed files with 49 additions and 80 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -34,14 +34,9 @@
*/
public class ExtractSimilarCallsets {

abstract static class ExtractSimilarCallsetsBase<V, C> extends DoFn<V, KV<KV<Integer, Integer>, Long>> {
abstract static class ExtractSimilarCallsetsBase<V, C> extends DoFn<V, KV<KV<String, String>, Long>> {

protected BiMap<String, Integer> dataIndices;
private ImmutableMultiset.Builder<KV<Integer, Integer>> accumulator;

public ExtractSimilarCallsetsBase(BiMap<String, Integer> dataIndices) {
this.dataIndices = dataIndices;
}
private ImmutableMultiset.Builder<KV<String, String>> accumulator;

@Override
public void startBundle(Context c) {
Expand All @@ -50,37 +45,33 @@ public void startBundle(Context c) {

@Override
public void processElement(ProcessContext context) {
FluentIterable<KV<Integer, Integer>> pairs = PairGenerator.WITH_REPLACEMENT.allPairs(
FluentIterable<KV<String, String>> pairs = PairGenerator.WITH_REPLACEMENT.allPairs(
getSamplesWithVariant(context.element()), Ordering.natural());
for (KV<Integer, Integer> pair : pairs) {
for (KV<String, String> pair : pairs) {
accumulator.add(pair);
}
}

@Override
public void finishBundle(Context context) {
for (Multiset.Entry<KV<Integer, Integer>> entry : accumulator.build().entrySet()) {
for (Multiset.Entry<KV<String, String>> entry : accumulator.build().entrySet()) {
context.output(KV.of(entry.getElement(), Long.valueOf(entry.getCount())));
}
}

protected abstract ImmutableList<Integer> getSamplesWithVariant(V variant);
protected abstract ImmutableList<String> getSamplesWithVariant(V variant);
}

public static class v1 extends ExtractSimilarCallsetsBase<Variant, VariantCall> {

public v1(BiMap<String, Integer> dataIndices) {
super(dataIndices);
}

@Override
protected ImmutableList<Integer> getSamplesWithVariant(Variant variant) {
protected ImmutableList<String> getSamplesWithVariant(Variant variant) {
return ImmutableList.copyOf(Iterables.transform(
CallFilters.getSamplesWithVariantOfMinGenotype(variant, 1), new Function<VariantCall, Integer>() {
CallFilters.getSamplesWithVariantOfMinGenotype(variant, 1), new Function<VariantCall, String>() {

@Override
public Integer apply(VariantCall call) {
return dataIndices.get(call.getCallSetName());
public String apply(VariantCall call) {
return call.getCallSetName();
}

}));
Expand All @@ -90,18 +81,14 @@ public Integer apply(VariantCall call) {
@Deprecated // Remove this when fully migrated to gRPC.
public static class v1beta2 extends ExtractSimilarCallsetsBase<com.google.api.services.genomics.model.Variant, Call> {

public v1beta2(BiMap<String, Integer> dataIndices) {
super(dataIndices);
}

@Override
protected ImmutableList<Integer> getSamplesWithVariant(com.google.api.services.genomics.model.Variant variant) {
protected ImmutableList<String> getSamplesWithVariant(com.google.api.services.genomics.model.Variant variant) {
return ImmutableList.copyOf(Iterables.transform(
CallFilters.getSamplesWithVariantOfMinGenotype(variant, 1), new Function<Call, Integer>() {
CallFilters.getSamplesWithVariantOfMinGenotype(variant, 1), new Function<Call, String>() {

@Override
public Integer apply(Call call) {
return dataIndices.get(call.getCallSetName());
public String apply(Call call) {
return call.getCallSetName();
}

}));
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -37,10 +37,10 @@
* The input data must describe a similarity matrix, which will be symmetric. Note that
* Principal Coordinate Analysis is not the same as Principal Component Analysis.
*/
public class OutputPCoAFile extends PTransform<PCollection<KV<KV<Integer, Integer>, Long>>, PDone> {
public class OutputPCoAFile extends PTransform<PCollection<KV<KV<String, String>, Long>>, PDone> {

private static final Combine.CombineFn<KV<KV<Integer, Integer>, Long>,
List<KV<KV<Integer, Integer>, Long>>, Iterable<KV<KV<Integer, Integer>, Long>>> TO_LIST =
private static final Combine.CombineFn<KV<KV<String, String>, Long>,
List<KV<KV<String, String>, Long>>, Iterable<KV<KV<String, String>, Long>>> TO_LIST =
toList();

private static <X> Combine.CombineFn<X, List<X>, Iterable<X>> toList() {
Expand Down Expand Up @@ -79,12 +79,13 @@ public OutputPCoAFile(BiMap<String, Integer> dataIndices, String outputFile) {
}

@Override
public PDone apply(PCollection<KV<KV<Integer, Integer>, Long>> similarPairs) {
public PDone apply(PCollection<KV<KV<String, String>, Long>> similarPairs) {
return similarPairs
.apply(Sum.<KV<Integer, Integer>>longsPerKey())
.apply(Sum.<KV<String, String>>longsPerKey())
.apply(Combine.globally(TO_LIST))
.apply(ParDo.of(new PCoAnalysis(dataIndices)))
.apply(ParDo.of(new DoFn<Iterable<PCoAnalysis.GraphResult>, String>() {
.apply(ParDo.named("PCoAAnalysis").of(new PCoAnalysis(dataIndices)))
.apply(ParDo.named("FormatGraphData")
.of(new DoFn<Iterable<PCoAnalysis.GraphResult>, String>() {
@Override
public void processElement(ProcessContext c) throws Exception {
Iterable<PCoAnalysis.GraphResult> graphResults = c.element();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,25 +16,20 @@

package com.google.cloud.genomics.dataflow.functions;

import java.io.Serializable;
import java.util.Collection;
import java.util.List;

import Jama.EigenvalueDecomposition;
import Jama.Matrix;

import com.google.api.client.util.Lists;
import com.google.api.client.util.Preconditions;
import com.google.cloud.dataflow.sdk.transforms.DoFn;
import com.google.cloud.dataflow.sdk.values.KV;
import com.google.cloud.genomics.dataflow.functions.PCoAnalysis.GraphResult;
import com.google.cloud.genomics.dataflow.utils.GenomicsDatasetOptions;
import com.google.cloud.genomics.utils.GenomicsFactory;
import com.google.common.collect.BiMap;
import com.google.common.collect.HashBiMap;
import com.google.common.collect.ImmutableList;

import Jama.EigenvalueDecomposition;
import Jama.Matrix;

import java.io.Serializable;
import java.util.Collection;
import java.util.List;
import java.util.Map;

/**
* This function runs a Principal Coordinate Analysis inside of a SeqDo.
* It cannot be parallelized.
Expand All @@ -47,16 +42,17 @@
* The input data to this algorithm must be for a similarity matrix - and the
* resulting matrix must be symmetric.
*
* Input: KV(KV(dataIndex, dataIndex), count of how similar the data pair is)
* Input: KV(KV(dataName, dataName), count of how similar the data pair is)
* Output: GraphResults - an x/y pair and a label
*
* Example input for a tiny dataset of size 2:
*
* KV(KV(0, 0), 5)
* KV(KV(1, 1), 5)
* KV(KV(1, 0), 2)
* KV(KV(data1, data1), 5)
* KV(KV(data1, data2), 2)
* KV(KV(data2, data2), 5)
* KV(KV(data2, data1), 2)
*/
public class PCoAnalysis extends DoFn<Iterable<KV<KV<Integer, Integer>, Long>>,
public class PCoAnalysis extends DoFn<Iterable<KV<KV<String, String>, Long>>,
Iterable<PCoAnalysis.GraphResult>> {

public static class GraphResult implements Serializable {
Expand Down Expand Up @@ -198,17 +194,14 @@ private List<GraphResult> getPcaData(double[][] data, BiMap<Integer, String> dat
return results;
}

@Override
public void processElement(ProcessContext context) {
Collection<KV<KV<Integer, Integer>, Long>> element = ImmutableList.copyOf(context.element());
@Override public void processElement(ProcessContext context) {
Collection<KV<KV<String, String>, Long>> element = ImmutableList.copyOf(context.element());

// Transform the data into a matrix
int dataSize = dataIndices.size();
double[][] matrixData = new double[dataSize][dataSize];
for (KV<KV<Integer, Integer>, Long> entry : element) {
int d1 = entry.getKey().getKey();
int d2 = entry.getKey().getValue();

for (KV<KV<String, String>, Long> entry : element) {
int d1 = dataIndices.get(entry.getKey().getKey());
int d2 = dataIndices.get(entry.getKey().getValue());
double value = entry.getValue();
matrixData[d1][d2] = value;
if (d1 != d2) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -76,7 +76,6 @@ public static void main(String[] args) throws IOException, GeneralSecurityExcept

GenomicsFactory.OfflineAuth auth = GenomicsOptions.Methods.getGenomicsAuth(options);

// Use integer indices instead of string callSetNames to reduce data sizes.
List<String> callSetNames = GenomicsUtils.getCallSetsNames(options.getDatasetId() , auth);
Collections.sort(callSetNames); // Ensure a stable sort order for reproducible results.
BiMap<String, Integer> dataIndices = HashBiMap.create();
Expand All @@ -92,7 +91,7 @@ public static void main(String[] args) throws IOException, GeneralSecurityExcept
Proto2Coder.of(StreamVariantsRequest.class));

p.begin();
PCollection<KV<KV<Integer, Integer>, Long>> similarCallsets = null;
PCollection<KV<KV<String, String>, Long>> similarCallsets = null;

if(options.getUseGrpc()) {
List<StreamVariantsRequest> requests = options.isAllReferences() ?
Expand All @@ -102,7 +101,7 @@ public static void main(String[] args) throws IOException, GeneralSecurityExcept

similarCallsets = p.apply(Create.of(requests))
.apply(new VariantStreamer(auth, ShardBoundary.Requirement.STRICT, VARIANT_FIELDS))
.apply(ParDo.of(new ExtractSimilarCallsets.v1(dataIndices)));
.apply(ParDo.of(new ExtractSimilarCallsets.v1()));
} else {
List<SearchVariantsRequest> requests = options.isAllReferences() ?
ShardUtils.getPaginatedVariantRequests(options.getDatasetId(), ShardUtils.SexChromosomeFilter.EXCLUDE_XY,
Expand All @@ -111,7 +110,7 @@ public static void main(String[] args) throws IOException, GeneralSecurityExcept

similarCallsets = p.apply(Create.of(requests))
.apply(ParDo.of(new VariantReader(auth, ShardBoundary.Requirement.STRICT, VARIANT_FIELDS)))
.apply(ParDo.of(new ExtractSimilarCallsets.v1beta2(dataIndices)));
.apply(ParDo.of(new ExtractSimilarCallsets.v1beta2()));
}

similarCallsets.apply(new OutputPCoAFile(dataIndices, options.getOutput()));
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -37,20 +37,9 @@
@RunWith(JUnit4.class)
public class ExtractSimilarCallsetsTest {

BiMap<String, Integer> dataIndices;

@Before
public void setUp() {
dataIndices = HashBiMap.create();
dataIndices.put("ref", dataIndices.size());
dataIndices.put("alt1", dataIndices.size());
dataIndices.put("alt2", dataIndices.size());
dataIndices.put("alt3", dataIndices.size());
}

@Test
public void testGetSamplesWithVariant() throws Exception {
ExtractSimilarCallsets.v1 doFn = new ExtractSimilarCallsets.v1(dataIndices);
ExtractSimilarCallsets.v1 doFn = new ExtractSimilarCallsets.v1();

Variant variant = Variant.newBuilder().build();
assertEquals(Collections.emptyList(), doFn.getSamplesWithVariant(variant));
Expand All @@ -61,18 +50,18 @@ public void testGetSamplesWithVariant() throws Exception {

VariantCall alt1 = DataUtils.makeVariantCall("alt1", 1, 0);
variant = Variant.newBuilder().addCalls(ref).addCalls(alt1).build();
assertEquals(Collections.singletonList(1), doFn.getSamplesWithVariant(variant));
assertEquals(Collections.singletonList("alt1"), doFn.getSamplesWithVariant(variant));

VariantCall alt2 = DataUtils.makeVariantCall("alt2", 0, 1);
VariantCall alt3 = DataUtils.makeVariantCall("alt3", 1, 1);
variant = Variant.newBuilder().addCalls(ref)
.addCalls(alt1).addCalls(alt2).addCalls(alt3).build();
assertEquals(Arrays.asList(1, 2, 3), doFn.getSamplesWithVariant(variant));
assertEquals(Arrays.asList("alt1", "alt2", "alt3"), doFn.getSamplesWithVariant(variant));
}

@Test
public void testGetSamplesWithVariantv1beta2() throws Exception {
ExtractSimilarCallsets.v1beta2 doFn = new ExtractSimilarCallsets.v1beta2(dataIndices);
ExtractSimilarCallsets.v1beta2 doFn = new ExtractSimilarCallsets.v1beta2();

com.google.api.services.genomics.model.Variant variant = new com.google.api.services.genomics.model.Variant();
List<Call> calls = new ArrayList<Call>();
Expand All @@ -84,10 +73,10 @@ public void testGetSamplesWithVariantv1beta2() throws Exception {
assertEquals(Collections.emptyList(), doFn.getSamplesWithVariant(variant));

calls.add(DataUtils.makeCall("alt1", 1, 0));
assertEquals(Collections.singletonList(1), doFn.getSamplesWithVariant(variant));
assertEquals(Collections.singletonList("alt1"), doFn.getSamplesWithVariant(variant));

calls.add(DataUtils.makeCall("alt2", 0, 1));
calls.add(DataUtils.makeCall("alt3", 1, 1));
assertEquals(Arrays.asList(1, 2, 3), doFn.getSamplesWithVariant(variant));
assertEquals(Arrays.asList("alt1", "alt2", "alt3"), doFn.getSamplesWithVariant(variant));
}
}

0 comments on commit 3028a83

Please sign in to comment.