Skip to content
This repository has been archived by the owner on Oct 29, 2023. It is now read-only.

Commit

Permalink
Switch from string to integer callSet indices.
Browse files Browse the repository at this point in the history
This will help with performance.

It also allows us to put a stable sort order on the rows/columns of the similarity matrix, facilitating reproducible results from the linear algebra operations.
  • Loading branch information
deflaux committed Jul 30, 2015
1 parent 15794a6 commit 98cf2c2
Show file tree
Hide file tree
Showing 6 changed files with 108 additions and 74 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -19,19 +19,26 @@
import com.google.cloud.dataflow.sdk.values.KV;
import com.google.cloud.genomics.dataflow.utils.CallFilters;
import com.google.cloud.genomics.dataflow.utils.PairGenerator;
import com.google.common.annotations.VisibleForTesting;
import com.google.common.base.Function;
import com.google.common.collect.BiMap;
import com.google.common.collect.FluentIterable;
import com.google.common.collect.ImmutableList;
import com.google.common.collect.ImmutableMultiset;
import com.google.common.collect.Iterables;
import com.google.common.collect.Multiset;
import com.google.common.collect.Ordering;

/**
* Emits a callset pair every time they share a variant.
*/
public class ExtractSimilarCallsets extends DoFn<Variant, KV<KV<String, String>, Long>> {
public class ExtractSimilarCallsets extends DoFn<Variant, KV<KV<Integer, Integer>, Long>> {

private ImmutableMultiset.Builder<KV<String, String>> accumulator;
private BiMap<String, Integer> dataIndices;
private ImmutableMultiset.Builder<KV<Integer, Integer>> accumulator;

/**
* Creates the function with the callset-name-to-index mapping it uses when emitting pairs.
*
* @param dataIndices mapping from callset name to integer index; getSamplesWithVariant looks up
*     each call's callset name in this map
*/
public ExtractSimilarCallsets(BiMap<String, Integer> dataIndices) {
this.dataIndices = dataIndices;
}

@Override
public void startBundle(Context c) {
Expand All @@ -40,27 +47,27 @@ public void startBundle(Context c) {

@Override
public void processElement(ProcessContext context) {
for (KV<String, String> pair : PairGenerator.WITH_REPLACEMENT.allPairs(
getSamplesWithVariant(context.element()), String.CASE_INSENSITIVE_ORDER)) {
FluentIterable<KV<Integer, Integer>> pairs = PairGenerator.WITH_REPLACEMENT.allPairs(
getSamplesWithVariant(context.element()), Ordering.natural());
for (KV<Integer, Integer> pair : pairs) {
accumulator.add(pair);
}
}

@Override
public void finishBundle(Context context) {
for (Multiset.Entry<KV<String, String>> entry : accumulator.build().entrySet()) {
for (Multiset.Entry<KV<Integer, Integer>> entry : accumulator.build().entrySet()) {
context.output(KV.of(entry.getElement(), Long.valueOf(entry.getCount())));
}
}

@VisibleForTesting
static ImmutableList<String> getSamplesWithVariant(Variant variant) {
ImmutableList<Integer> getSamplesWithVariant(Variant variant) {
return ImmutableList.copyOf(Iterables.transform(
CallFilters.getSamplesWithVariantOfMinGenotype(variant, 1), new Function<Call, String>() {
CallFilters.getSamplesWithVariantOfMinGenotype(variant, 1), new Function<Call, Integer>() {

@Override
public String apply(Call call) {
return call.getCallSetName();
public Integer apply(Call call) {
return dataIndices.get(call.getCallSetName());
}

}));
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@
import com.google.cloud.dataflow.sdk.values.KV;
import com.google.cloud.dataflow.sdk.values.PCollection;
import com.google.cloud.dataflow.sdk.values.PDone;
import com.google.common.collect.BiMap;

import java.util.ArrayList;
import java.util.List;
Expand All @@ -36,10 +37,10 @@
* The input data must be for a similarity matrix which will be symmetric. This is not
* the same as Principal Component Analysis.
*/
public class OutputPCoAFile extends PTransform<PCollection<KV<KV<String, String>, Long>>, PDone> {
public class OutputPCoAFile extends PTransform<PCollection<KV<KV<Integer, Integer>, Long>>, PDone> {

private static final Combine.CombineFn<KV<KV<String, String>, Long>,
List<KV<KV<String, String>, Long>>, Iterable<KV<KV<String, String>, Long>>> TO_LIST =
private static final Combine.CombineFn<KV<KV<Integer, Integer>, Long>,
List<KV<KV<Integer, Integer>, Long>>, Iterable<KV<KV<Integer, Integer>, Long>>> TO_LIST =
toList();

private static <X> Combine.CombineFn<X, List<X>, Iterable<X>> toList() {
Expand Down Expand Up @@ -69,20 +70,21 @@ private static <X> Combine.CombineFn<X, List<X>, Iterable<X>> toList() {
};
}

private BiMap<String, Integer> dataIndices;
private final String outputFile;

public OutputPCoAFile(String outputFile) {
public OutputPCoAFile(BiMap<String, Integer> dataIndices, String outputFile) {
this.dataIndices = dataIndices;
this.outputFile = outputFile;
}

@Override
public PDone apply(PCollection<KV<KV<String, String>, Long>> similarPairs) {
public PDone apply(PCollection<KV<KV<Integer, Integer>, Long>> similarPairs) {
return similarPairs
.apply(Sum.<KV<String, String>>longsPerKey())
.apply(Sum.<KV<Integer, Integer>>longsPerKey())
.apply(Combine.globally(TO_LIST))
.apply(ParDo.named("PCoAAnalysis").of(PCoAnalysis.of()))
.apply(ParDo.named("FormatGraphData")
.of(new DoFn<Iterable<PCoAnalysis.GraphResult>, String>() {
.apply(ParDo.of(new PCoAnalysis(dataIndices)))
.apply(ParDo.of(new DoFn<Iterable<PCoAnalysis.GraphResult>, String>() {
@Override
public void processElement(ProcessContext c) throws Exception {
Iterable<PCoAnalysis.GraphResult> graphResults = c.element();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,9 @@
import com.google.api.client.util.Preconditions;
import com.google.cloud.dataflow.sdk.transforms.DoFn;
import com.google.cloud.dataflow.sdk.values.KV;
import com.google.cloud.genomics.dataflow.functions.PCoAnalysis.GraphResult;
import com.google.cloud.genomics.dataflow.utils.GenomicsDatasetOptions;
import com.google.cloud.genomics.utils.GenomicsFactory;
import com.google.common.collect.BiMap;
import com.google.common.collect.HashBiMap;
import com.google.common.collect.ImmutableList;
Expand All @@ -44,17 +47,16 @@
* The input data to this algorithm must be for a similarity matrix - and the
* resulting matrix must be symmetric.
*
* Input: KV(KV(dataName, dataName), count of how similar the data pair is)
* Input: KV(KV(dataIndex, dataIndex), count of how similar the data pair is)
* Output: GraphResults - an x/y pair and a label
*
* Example input for a tiny dataset of size 2:
*
* KV(KV(data1, data1), 5)
* KV(KV(data1, data2), 2)
* KV(KV(data2, data2), 5)
* KV(KV(data2, data1), 2)
* KV(KV(0, 0), 5)
* KV(KV(1, 1), 5)
* KV(KV(1, 0), 2)
*/
public class PCoAnalysis extends DoFn<Iterable<KV<KV<String, String>, Long>>,
public class PCoAnalysis extends DoFn<Iterable<KV<KV<Integer, Integer>, Long>>,
Iterable<PCoAnalysis.GraphResult>> {

public static class GraphResult implements Serializable {
Expand Down Expand Up @@ -118,20 +120,10 @@ public boolean equals(Object obj) {
}
}

private static final PCoAnalysis INSTANCE = new PCoAnalysis();
private BiMap<String, Integer> dataIndicies;

public static PCoAnalysis of() {
return INSTANCE;
}

private PCoAnalysis() {}
private BiMap<String, Integer> dataIndices;

private int getDataIndex(Map<String, Integer> dataIndicies, String dataName) {
if (!dataIndicies.containsKey(dataName)) {
dataIndicies.put(dataName, dataIndicies.size());
}
return dataIndicies.get(dataName);
/**
* Creates the analysis function with the name-to-index mapping for the similarity matrix.
*
* @param dataIndices mapping from data name to integer index; processElement uses its size as the
*     matrix dimension and passes its inverse (index to name) to getPcaData
*/
public PCoAnalysis(BiMap<String, Integer> dataIndices) {
this.dataIndices = dataIndices;
}

// Convert the similarity matrix to an Eigen matrix.
Expand Down Expand Up @@ -206,29 +198,23 @@ private List<GraphResult> getPcaData(double[][] data, BiMap<Integer, String> dat
return results;
}

@Override public void processElement(ProcessContext context) {
Collection<KV<KV<String, String>, Long>> element = ImmutableList.copyOf(context.element());
// Transform the data into a matrix
BiMap<String, Integer> dataIndicies = HashBiMap.create();
@Override
public void processElement(ProcessContext context) {
Collection<KV<KV<Integer, Integer>, Long>> element = ImmutableList.copyOf(context.element());

// TODO: Clean up this code
for (KV<KV<String, String>, Long> entry : element) {
getDataIndex(dataIndicies, entry.getKey().getKey());
getDataIndex(dataIndicies, entry.getKey().getValue());
}

int dataSize = dataIndicies.size();
// Transform the data into a matrix
int dataSize = dataIndices.size();
double[][] matrixData = new double[dataSize][dataSize];
for (KV<KV<String, String>, Long> entry : element) {
int d1 = getDataIndex(dataIndicies, entry.getKey().getKey());
int d2 = getDataIndex(dataIndicies, entry.getKey().getValue());
for (KV<KV<Integer, Integer>, Long> entry : element) {
int d1 = entry.getKey().getKey();
int d2 = entry.getKey().getValue();

double value = entry.getValue();
matrixData[d1][d2] = value;
if (d1 != d2) {
matrixData[d2][d1] = value;
}
}
context.output(getPcaData(matrixData, dataIndicies.inverse()));
context.output(getPcaData(matrixData, dataIndices.inverse()));
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -23,15 +23,19 @@
import com.google.cloud.genomics.dataflow.functions.ExtractSimilarCallsets;
import com.google.cloud.genomics.dataflow.functions.OutputPCoAFile;
import com.google.cloud.genomics.dataflow.readers.VariantReader;
import com.google.cloud.genomics.dataflow.readers.VariantStreamer;
import com.google.cloud.genomics.dataflow.utils.DataflowWorkarounds;
import com.google.cloud.genomics.dataflow.utils.GenomicsDatasetOptions;
import com.google.cloud.genomics.dataflow.utils.GenomicsOptions;
import com.google.cloud.genomics.utils.Contig.SexChromosomeFilter;
import com.google.cloud.genomics.utils.GenomicsFactory;
import com.google.cloud.genomics.utils.Paginator.ShardBoundary;
import com.google.common.collect.BiMap;
import com.google.common.collect.HashBiMap;

import java.io.IOException;
import java.security.GeneralSecurityException;
import java.util.Collections;
import java.util.List;

/**
Expand All @@ -57,15 +61,21 @@ public static void main(String[] args) throws IOException, GeneralSecurityExcept
GenomicsDatasetOptions.Methods.getVariantRequests(options, auth,
SexChromosomeFilter.EXCLUDE_XY);

// Use integer indices instead of string callSetNames to reduce data sizes.
List<String> callSetNames = VariantStreamer.getCallSetsNames(options.getDatasetId() , auth);
Collections.sort(callSetNames); // Ensure a stable sort order for reproducible results.
BiMap<String, Integer> dataIndices = HashBiMap.create();
for(String callSetName : callSetNames) {
dataIndices.put(callSetName, dataIndices.size());
}

Pipeline p = Pipeline.create(options);
DataflowWorkarounds.registerGenomicsCoders(p);
p.begin()
.apply(Create.of(requests))
.apply(
ParDo.named("VariantReader").of(
new VariantReader(auth, ShardBoundary.STRICT, VARIANT_FIELDS)))
.apply(ParDo.named("ExtractSimilarCallsets").of(new ExtractSimilarCallsets()))
.apply(new OutputPCoAFile(options.getOutput()));
.apply(ParDo.of(new VariantReader(auth, ShardBoundary.STRICT, VARIANT_FIELDS)))
.apply(ParDo.of(new ExtractSimilarCallsets(dataIndices)))
.apply(new OutputPCoAFile(dataIndices, options.getOutput()));

p.run();
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,8 @@
*/
package com.google.cloud.genomics.dataflow.readers;

import com.google.api.services.genomics.model.CallSet;
import com.google.api.services.genomics.model.SearchCallSetsRequest;
import com.google.api.services.genomics.model.SearchVariantSetsRequest;
import com.google.api.services.genomics.model.VariantSet;
import com.google.cloud.dataflow.sdk.transforms.Aggregator;
Expand All @@ -32,6 +34,7 @@
import com.google.genomics.v1.StreamVariantsResponse;
import com.google.genomics.v1.StreamingVariantServiceGrpc;
import com.google.genomics.v1.Variant;

import java.io.IOException;
import java.security.GeneralSecurityException;
import java.util.Iterator;
Expand Down Expand Up @@ -64,6 +67,25 @@ public static List<String> getVariantSetIds(String datasetId, GenomicsFactory.Of
return output;
}

/**
 * Looks up the names of the CallSets found for a variant set using the Genomics API.
 *
 * @param variantSetId id supplied to the CallSet search as its only variant set id
 * @param auth offline credentials used to build the Genomics client for the search
 * @return the name of every CallSet yielded by the search, in the order they were returned
 * @throws IOException if the search yields no CallSets
 * @throws GeneralSecurityException presumably raised while building the Genomics client from
 *     {@code auth} (TODO confirm)
 */
public static List<String> getCallSetsNames(String variantSetId, GenomicsFactory.OfflineAuth auth)
    throws IOException, GeneralSecurityException {
  SearchCallSetsRequest request =
      new SearchCallSetsRequest().setVariantSetIds(Lists.newArrayList(variantSetId));
  // Only the name field (plus the paging token) is requested from the API.
  Iterable<CallSet> callSets = Paginator.Callsets
      .create(auth.getGenomics(auth.getDefaultFactory()))
      .search(request, "callSets/name,nextPageToken");

  List<String> names = Lists.newArrayList();
  for (CallSet callSet : callSets) {
    names.add(callSet.getName());
  }
  if (names.isEmpty()) {
    throw new IOException("VariantSet " + variantSetId + " does not contain any CallSets");
  }
  return names;
}

/**
* Constructs a StreamVariantsRequest for a variantSetId, assuming that the user wants to
* include all references.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -15,46 +15,53 @@
*/
package com.google.cloud.genomics.dataflow.functions;

import static com.google.cloud.genomics.dataflow.functions.ExtractSimilarCallsets.getSamplesWithVariant;
import static org.junit.Assert.assertEquals;

import com.google.api.client.util.Lists;
import com.google.api.services.genomics.model.Call;
import com.google.api.services.genomics.model.Variant;
import com.google.cloud.dataflow.sdk.values.KV;
import com.google.cloud.genomics.dataflow.utils.PairGenerator;
import com.google.cloud.genomics.dataflow.utils.DataUtils;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.List;

import org.junit.Test;
import org.junit.runner.RunWith;
import org.junit.runners.JUnit4;

import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.List;
import com.google.api.services.genomics.model.Call;
import com.google.api.services.genomics.model.Variant;
import com.google.cloud.genomics.dataflow.utils.DataUtils;
import com.google.common.collect.BiMap;
import com.google.common.collect.HashBiMap;

@RunWith(JUnit4.class)
public class ExtractSimilarCallsetsTest {


@Test
public void testGetSamplesWithVariant() throws Exception {

BiMap<String, Integer> dataIndices = HashBiMap.create();
dataIndices.put("ref", dataIndices.size());
dataIndices.put("alt1", dataIndices.size());
dataIndices.put("alt2", dataIndices.size());
dataIndices.put("alt3", dataIndices.size());

ExtractSimilarCallsets doFn = new ExtractSimilarCallsets(dataIndices);

Variant variant = new Variant();
List<Call> calls = new ArrayList<Call>();

variant.setCalls(calls);
assertEquals(Collections.emptyList(), getSamplesWithVariant(variant));
assertEquals(Collections.emptyList(), doFn.getSamplesWithVariant(variant));

calls.add(DataUtils.makeCall("ref", 0, 0));
assertEquals(Collections.emptyList(), getSamplesWithVariant(variant));
assertEquals(Collections.emptyList(), doFn.getSamplesWithVariant(variant));

calls.add(DataUtils.makeCall("alt1", 1, 0));
assertEquals(Collections.singletonList("alt1"), getSamplesWithVariant(variant));
assertEquals(Collections.singletonList(1), doFn.getSamplesWithVariant(variant));

calls.add(DataUtils.makeCall("alt2", 0, 1));
calls.add(DataUtils.makeCall("alt3", 1, 1));
assertEquals(Arrays.asList("alt1", "alt2", "alt3"), getSamplesWithVariant(variant));
assertEquals(Arrays.asList(1, 2, 3), doFn.getSamplesWithVariant(variant));
}

}

0 comments on commit 98cf2c2

Please sign in to comment.