Skip to content

Commit

Permalink
Extracted Enum from the insert-size distribution type.
Browse files Browse the repository at this point in the history
  • Loading branch information
vruano committed Aug 30, 2018
1 parent 22196c3 commit 2eff1aa
Show file tree
Hide file tree
Showing 5 changed files with 413 additions and 319 deletions.
Original file line number Diff line number Diff line change
@@ -1,26 +1,10 @@
package org.broadinstitute.hellbender.tools.spark.sv;

import com.google.api.client.repackaged.com.google.common.annotations.VisibleForTesting;
import org.apache.commons.math3.distribution.AbstractIntegerDistribution;
import org.apache.commons.math3.distribution.IntegerDistribution;
import org.apache.commons.math3.distribution.LogNormalDistribution;
import org.apache.commons.math3.distribution.NormalDistribution;
import org.apache.commons.math3.distribution.RealDistribution;
import org.apache.commons.math3.random.JDKRandomGenerator;
import org.apache.commons.math3.random.RandomGenerator;
import org.broadinstitute.hellbender.exceptions.UserException;
import org.broadinstitute.hellbender.tools.spark.sv.evidence.LibraryStatistics;
import org.broadinstitute.hellbender.tools.spark.sv.evidence.ReadMetadata;
import org.broadinstitute.hellbender.tools.spark.utils.IntHistogram;
import org.broadinstitute.hellbender.utils.gcs.BucketUtils;

import java.io.BufferedReader;
import java.io.IOException;
import java.io.InputStreamReader;
import java.io.Serializable;
import java.util.Arrays;
import java.util.Collections;
import java.util.List;
import java.util.regex.Matcher;
import java.util.regex.Pattern;

Expand All @@ -29,280 +13,16 @@
*/
public class InsertSizeDistribution implements Serializable {

private static final long serialVersionUID = -1L;

@VisibleForTesting
static final Type[] SUPPORTED_TYPES = { new NormalType(), new LogNormalType(), new EmpiricalType() };

public interface Type {

List<String> getNames();

IntegerDistribution fromMeanAndStdDeviation(final double mean, final double stddev);

default IntegerDistribution fromReadMetadataFile(final String whereFrom) {
try {
return fromSerializationFile(whereFrom);
} catch (final RuntimeException ex) {
return fromTextFile(whereFrom);
}
}

default IntegerDistribution fromTextFile(String whereFrom) {
try (final BufferedReader reader = new BufferedReader(new InputStreamReader(BucketUtils.openFile(whereFrom)))) {
String line;
int value;
double totalSum = 0;
double totalSqSum = 0;
long totalCount = 0;
while ((line = reader.readLine()) != null) {
if (line.startsWith(ReadMetadata.CDF_PREFIX)) {
final String[] cdf = line.substring(ReadMetadata.CDF_PREFIX.length() + 1).split("\t");
long leftCdf = 0;
for (value = 0; value < cdf.length; value++) {
final long frequency = Long.parseLong(cdf[value]) - leftCdf;
leftCdf += frequency;
totalSum += frequency * value;
totalSqSum += value * value * frequency;
}
totalCount += leftCdf;
}
}
if (totalCount == 0) {
throw new UserException.MalformedFile("Could not find any insert-sizes in " + whereFrom);
}
final double mean = totalSum / totalCount;
final double stdDev = Math.sqrt(Math.abs(totalSqSum/totalCount - mean * mean));
return fromMeanAndStdDeviation(mean, stdDev);
} catch (final IOException ex2) {
throw new UserException.CouldNotReadInputFile(whereFrom);
} catch (final NumberFormatException ex2) {
throw new UserException.MalformedFile("the CDF contains non-numbers in " + whereFrom);
}
}

default IntegerDistribution fromSerializationFile(String whereFrom) {
final ReadMetadata metaData = ReadMetadata.Serializer.readStandalone(whereFrom);
double totalSum = 0;
double totalSqSum = 0;
long totalCount = 0;
for (final LibraryStatistics libStats : metaData.getAllLibraryStatistics().values()) {
final IntHistogram.CDF cdf = libStats.getCDF();
final long cdfTotalCount = cdf.getTotalObservations();
final int size = cdf.size();
for (int i = 1; i < size; i++) {
final double fraction = cdf.getFraction(i) - cdf.getFraction(i - 1);
final double count = fraction * cdfTotalCount;
totalSum += count * i;
totalSqSum += count * i * i;
}
totalCount += cdfTotalCount;
}
if (totalCount == 0) {
throw new UserException.MalformedFile("Could not find any insert-sizes in " + whereFrom);
}
final double mean = totalSum / totalCount;
final double variance = Math.abs(totalSqSum / totalCount - mean * mean);
final double stdDev = Math.sqrt(variance);
return fromMeanAndStdDeviation(mean, stdDev);
}

}

public static class NormalType implements Type {

@Override
public List<String> getNames() {
return Collections.unmodifiableList(Arrays.asList("Normal", "N", "Norm", "Gauss", "Gaussian"));
}

@Override
public IntegerDistribution fromMeanAndStdDeviation(final double mean, final double stddev) {
final int seed = ((Double.hashCode(mean) * 31) + Double.hashCode(stddev) * 31 ) + "Normal".hashCode() * 31;
final RandomGenerator rdnGen = new JDKRandomGenerator();
rdnGen.setSeed(seed);
final RealDistribution normal = new NormalDistribution(mean, stddev);

// We took the math for the truncted normal distribution from Wikipedia:
// https://en.wikipedia.org/wiki/Truncated_normal_distribution
final double zeroCumulative = normal.cumulativeProbability(0);
final double zeroProb = normal.density(0);
final double normalization = 1.0 / (1.0 - zeroCumulative); // = 1.0 / Z in Wikipedia
final double mu = normal.getNumericalMean();
final double sigmaSquare = normal.getNumericalVariance();
final double sigma = Math.sqrt(sigmaSquare);
final double newMean = mu + sigma * zeroProb * normalization;
final double newVariance = sigmaSquare * (1 + normalization * (-mu / sigma) * zeroProb
- Math.pow(zeroProb * normalization, 2));
// The actual distribution:
return new AbstractIntegerDistribution(rdnGen) {

private static final long serialVersionUID = -1L;

@Override
public double probability(int x) {
return normal.probability(x) * normalization;
}

@Override
public double cumulativeProbability(int x) {
return (normal.cumulativeProbability(x) - zeroCumulative) * normalization;
}

@Override
public double getNumericalMean() {
return newMean;
}

@Override
public double getNumericalVariance() {
return newVariance;
}

@Override
public int getSupportLowerBound() {
return 0;
}

@Override
public int getSupportUpperBound() {
return Integer.MAX_VALUE;
}

@Override
public boolean isSupportConnected() {
return true;
}
};
}

@Override
public IntegerDistribution fromSerializationFile(String whereFrom) {
final ReadMetadata metaData = ReadMetadata.Serializer.readStandalone(whereFrom);
final IntHistogram hist = new IntHistogram(2000);
for (final LibraryStatistics libStats : metaData.getAllLibraryStatistics().values()) {
final IntHistogram.CDF cdf = libStats.getCDF();
final long cdfTotalCount = cdf.getTotalObservations();
final int size = cdf.size() - 1;
hist.addObservations(0, Math.round(cdf.getFraction(0) * cdfTotalCount));
for (int i = 1; i < size; i++) {
final double fraction = cdf.getFraction(i) - cdf.getFraction(i - 1);
final long count = Math.round(fraction * cdfTotalCount);
hist.addObservations(i, count);
}
}
if (hist.getTotalObservations() == 0) {
throw new UserException.MalformedFile("Could not find any insert-sizes in " + whereFrom);
}
return hist.empiricalDistribution((int) Math.max(1, hist.getTotalObservations() / 1_000_000));
}
}

public static class EmpiricalType implements Type {

@Override
public List<String> getNames() {
return Collections.unmodifiableList(Arrays.asList("Empirical", "Emp"));
}

@Override
public IntegerDistribution fromMeanAndStdDeviation(double mean, double stddev) {
throw new UserException.BadInput("Empirical insert-size-distribution needs a meta-file");
}

@Override
public IntegerDistribution fromTextFile(String whereFrom) {
final IntHistogram hist = new IntHistogram(2000); // 2000 is the number of tracked values i.e. 0..2000
try (final BufferedReader reader = new BufferedReader(new InputStreamReader(BucketUtils.openFile(whereFrom)))) {
String line;
while ((line = reader.readLine()) != null) {
if (line.startsWith(ReadMetadata.CDF_PREFIX)) {
final String[] cdf = line.substring(ReadMetadata.CDF_PREFIX.length() + 1).split("\t");
long leftCdf = 0;
for (int value = 0; value < cdf.length; value++) {
final long frequency = Long.parseLong(cdf[value]) - leftCdf;
hist.addObservations(value, frequency);
}
}
}
if (hist.getTotalObservations() == 0) {
throw new UserException.MalformedFile("Could not find any insert-sizes in " + whereFrom);
}
// We apply a smoothing that won't yield probabilities far below 10^-6 (Phred ~ 60)
return hist.empiricalDistribution((int) Math.max(1, hist.getTotalObservations() / 1_000_000));
} catch (final IOException ex) {
throw new UserException.CouldNotReadInputFile(whereFrom);
}
}

}

public static class LogNormalType implements Type {

@Override
public List<String> getNames() {
return Collections.unmodifiableList(Arrays.asList("logN", "lnN", "logNorm", "lnNorm", "logNormal", "lnNormal"));
}

@Override
public IntegerDistribution fromMeanAndStdDeviation(final double mean, final double stddev) {
final double var = stddev * stddev;
final double scale = 2 * Math.log(mean) - 0.5 * Math.log(var + mean * mean); // scale = mu in wikipedia article.
final double shape = Math.sqrt(Math.log(1 + (var / (mean * mean)))); // shape = sigma in wikipedia article.
final RealDistribution real = new LogNormalDistribution(scale, shape);
final int seed = (((Double.hashCode(mean) * 31) + Double.hashCode(stddev) * 31) + "LogNormal".hashCode() * 31);
final RandomGenerator rdnGen = new JDKRandomGenerator();
rdnGen.setSeed(seed);
return new AbstractIntegerDistribution(rdnGen) {

private static final long serialVersionUID = -1L;

@Override
public double probability(int x) {
return real.density(x);
}

@Override
public double cumulativeProbability(int x) {
return real.cumulativeProbability(x);
}

@Override
public double getNumericalMean() {
return real.getNumericalMean();
}

@Override
public double getNumericalVariance() {
return real.getNumericalVariance();
}

@Override
public int getSupportLowerBound() {
return 0;
}

@Override
public int getSupportUpperBound() {
return Integer.MAX_VALUE;
}

@Override
public boolean isSupportConnected() {
return true;
}
};
}
}
private static final long serialVersionUID = 1L;

private static Pattern DESCRIPTION_PATTERN =
Pattern.compile("^\\s*(?<name>[^\\s\\(\\)]+)\\s*\\((?<mean>[^,\\(\\)]+?)\\s*(?:,\\s*(?<stddev>[^,\\(\\)]+?)\\s*)?\\)\\s*");

private final String description;

private transient IntegerDistribution dist;
private transient AbstractIntegerDistribution dist;

private IntegerDistribution dist() {
private AbstractIntegerDistribution dist() {
initializeDistribution();
return dist;
}
Expand All @@ -325,11 +45,13 @@ private void initializeDistribution() {
throw new UserException.BadInput("unsupported insert size distribution description format: " + description);
}
final Matcher matcher = DESCRIPTION_PATTERN.matcher(description);
matcher.find();
if (!matcher.find()) {
throw new UserException.BadInput("the insert-size distribution spec is not up to standard: " + description);
}
final String nameString = matcher.group("name");
final String meanString = matcher.group("mean");
final String stddevString = matcher.group("stddev");
final Type type = extractDistributionType(nameString, description);
final InsertSizeDistributionShape type = extractDistributionShape(nameString, description);
if (stddevString != null) {
final double mean = extractDoubleParameter("mean", description, meanString, 0, Double.MAX_VALUE);
final double stddev = extractDoubleParameter("stddev", description, stddevString, 0, Double.MAX_VALUE);
Expand All @@ -339,14 +61,13 @@ private void initializeDistribution() {
}
}

private static Type extractDistributionType(final String nameString, final String description) {
for (final Type candidate : SUPPORTED_TYPES) {
if (candidate.getNames().stream().anyMatch(name -> name.toLowerCase().equals(nameString.trim().toLowerCase()))) {
return candidate;
}
private static InsertSizeDistributionShape extractDistributionShape(final String nameString, final String description) {
final InsertSizeDistributionShape result = InsertSizeDistributionShape.decode(nameString);
if (result == null) {
throw new UserException.BadInput("unsupported insert size distribution name '" + nameString
+ "' in description: " + description);
}
throw new UserException.BadInput("unsupported insert size distribution name '" + nameString
+ "' in description: " + description);
return result;
}

private static double extractDoubleParameter(final String name, final String description,
Expand Down Expand Up @@ -391,11 +112,11 @@ public int maximum() {
return Math.min(Integer.MAX_VALUE, dist().getSupportUpperBound());
}

public double density(final int size) {
public double probability(final int size) {
return dist().probability(size);
}

public double logDensity(final int size) {
return Math.log(dist().probability(size));
public double logProbability(final int size) {
return dist().logProbability(size);
}
}
Loading

0 comments on commit 2eff1aa

Please sign in to comment.