From b5a198a834410745b4840ada435ce6e97365b0c7 Mon Sep 17 00:00:00 2001 From: vruano Date: Wed, 29 Aug 2018 16:07:15 -0400 Subject: [PATCH] Added the Empirical Insert-Size-Distribution Shape (arbitrary fraction for each isize). Move the different shape types (normal, lognormal and empirical) into an enum. --- .../spark/sv/InsertSizeDistribution.java | 166 ++------- .../spark/sv/InsertSizeDistributionShape.java | 314 ++++++++++++++++++ .../tools/spark/utils/IntHistogram.java | 113 ++++++- .../InsertSizeDistributionShapeUnitTest.java | 40 +++ .../sv/InsertSizeDistributionUnitTest.java | 69 ++-- .../hellbender/utils/IntHistogramTest.java | 39 ++- 6 files changed, 569 insertions(+), 172 deletions(-) create mode 100644 src/main/java/org/broadinstitute/hellbender/tools/spark/sv/InsertSizeDistributionShape.java create mode 100644 src/test/java/org/broadinstitute/hellbender/tools/spark/sv/InsertSizeDistributionShapeUnitTest.java diff --git a/src/main/java/org/broadinstitute/hellbender/tools/spark/sv/InsertSizeDistribution.java b/src/main/java/org/broadinstitute/hellbender/tools/spark/sv/InsertSizeDistribution.java index 00482e4aa75..c805e218daf 100644 --- a/src/main/java/org/broadinstitute/hellbender/tools/spark/sv/InsertSizeDistribution.java +++ b/src/main/java/org/broadinstitute/hellbender/tools/spark/sv/InsertSizeDistribution.java @@ -1,26 +1,10 @@ package org.broadinstitute.hellbender.tools.spark.sv; -import com.esotericsoftware.kryo.Kryo; -import com.esotericsoftware.kryo.io.Input; -import com.google.api.client.repackaged.com.google.common.annotations.VisibleForTesting; -import org.apache.commons.math3.distribution.AbstractRealDistribution; -import org.apache.commons.math3.distribution.LogNormalDistribution; -import org.apache.commons.math3.distribution.NormalDistribution; -import org.apache.commons.math3.distribution.RealDistribution; +import org.apache.commons.math3.distribution.AbstractIntegerDistribution; +import 
org.apache.commons.math3.distribution.IntegerDistribution; import org.broadinstitute.hellbender.exceptions.UserException; -import org.broadinstitute.hellbender.tools.spark.sv.evidence.LibraryStatistics; -import org.broadinstitute.hellbender.tools.spark.sv.evidence.ReadMetadata; -import org.broadinstitute.hellbender.tools.spark.utils.IntHistogram; -import org.broadinstitute.hellbender.utils.gcs.BucketUtils; - -import java.io.BufferedReader; -import java.io.IOException; -import java.io.InputStream; -import java.io.InputStreamReader; + import java.io.Serializable; -import java.util.Arrays; -import java.util.Collections; -import java.util.List; import java.util.regex.Matcher; import java.util.regex.Pattern; @@ -29,123 +13,16 @@ */ public class InsertSizeDistribution implements Serializable { - private static final long serialVersionUID = -1L; - - @VisibleForTesting - static final Type[] SUPPORTED_TYPES = { new NormalType(), new LogNormalType() }; - - public interface Type { - List getNames(); - - AbstractRealDistribution fromMeanAndStdDeviation(final double mean, final double stddev); - - default AbstractRealDistribution fromReadMetadataFile(final String whereFrom) { - try { - return fromSerializationFile(whereFrom); - } catch (final RuntimeException ex) { - return fromTextFile(whereFrom); - } - } - - default AbstractRealDistribution fromTextFile(String whereFrom) { - try (final BufferedReader reader = new BufferedReader(new InputStreamReader(BucketUtils.openFile(whereFrom)))) { - String line; - int value; - double totalSum = 0; - double totalSqSum = 0; - long totalCount = 0; - while ((line = reader.readLine()) != null) { - if (line.startsWith(ReadMetadata.CDF_PREFIX)) { - final String[] cdf = line.substring(ReadMetadata.CDF_PREFIX.length() + 1).split("\t"); - long leftCdf = 0; - for (value = 0; value < cdf.length; value++) { - final long frequency = Long.parseLong(cdf[value]) - leftCdf; - leftCdf += frequency; - totalSum += frequency * value; - totalSqSum += value * 
value * frequency; - } - totalCount += leftCdf; - } - } - if (totalCount == 0) { - throw new UserException.MalformedFile("Could not find any insert-sizes in " + whereFrom); - } - final double mean = totalSum / totalCount; - final double stdDev = Math.sqrt(Math.abs(totalSqSum/totalCount - mean * mean)); - return fromMeanAndStdDeviation(mean, stdDev); - } catch (final IOException ex2) { - throw new UserException.CouldNotReadInputFile(whereFrom); - } catch (final NumberFormatException ex2) { - throw new UserException.MalformedFile("the CDF contains non-numbers in " + whereFrom); - } - } - - default AbstractRealDistribution fromSerializationFile(String whereFrom) { - final ReadMetadata metaData = ReadMetadata.Serializer.readStandalone(whereFrom); - double totalSum = 0; - double totalSqSum = 0; - long totalCount = 0; - for (final LibraryStatistics libStats : metaData.getAllLibraryStatistics().values()) { - final IntHistogram.CDF cdf = libStats.getCDF(); - final long cdfTotalCount = cdf.getTotalObservations(); - final int size = cdf.size(); - for (int i = 1; i < size; i++) { - final double fraction = cdf.getFraction(i) - cdf.getFraction(i - 1); - final double count = fraction * cdfTotalCount; - totalSum += count * i; - totalSqSum += count * i * i; - } - totalCount += cdfTotalCount; - } - if (totalCount == 0) { - throw new UserException.MalformedFile("Could not find any insert-sizes in " + whereFrom); - } - final double mean = totalSum / totalCount; - final double variance = Math.abs(totalSqSum / totalCount - mean * mean); - final double stdDev = Math.sqrt(variance); - return fromMeanAndStdDeviation(mean, stdDev); - } - - } - - public static class NormalType implements Type { - - @Override - public List getNames() { - return Collections.unmodifiableList(Arrays.asList("Normal", "N", "Norm", "Gauss", "Gaussian")); - } - - @Override - public AbstractRealDistribution fromMeanAndStdDeviation(final double mean, final double stddev) { - return new NormalDistribution(mean, 
stddev); - } - - } - - public static class LogNormalType implements Type { - - @Override - public List getNames() { - return Collections.unmodifiableList(Arrays.asList("logN", "lnN", "logNorm", "lnNorm", "logNormal", "lnNormal")); - } - - @Override - public AbstractRealDistribution fromMeanAndStdDeviation(final double mean, final double stddev) { - final double var = stddev * stddev; - final double scale = 2 * Math.log(mean) - 0.5 * Math.log(var + mean * mean); // scale = mu in wikipedia article. - final double shape = Math.sqrt(Math.log(1 + (var / (mean * mean)))); // shape = sigma in wikipedia article. - return new LogNormalDistribution(scale, shape); - } - } + private static final long serialVersionUID = 1L; private static Pattern DESCRIPTION_PATTERN = Pattern.compile("^\\s*(?[^\\s\\(\\)]+)\\s*\\((?[^,\\(\\)]+?)\\s*(?:,\\s*(?[^,\\(\\)]+?)\\s*)?\\)\\s*"); private final String description; - private transient AbstractRealDistribution dist; + private transient AbstractIntegerDistribution dist; - private AbstractRealDistribution dist() { + private AbstractIntegerDistribution dist() { initializeDistribution(); return dist; } @@ -168,11 +45,13 @@ private void initializeDistribution() { throw new UserException.BadInput("unsupported insert size distribution description format: " + description); } final Matcher matcher = DESCRIPTION_PATTERN.matcher(description); - matcher.find(); + if (!matcher.find()) { + throw new UserException.BadInput("the insert-size distribution spec is not up to standard: " + description); + } final String nameString = matcher.group("name"); final String meanString = matcher.group("mean"); final String stddevString = matcher.group("stddev"); - final Type type = extractDistributionType(nameString, description); + final InsertSizeDistributionShape type = extractDistributionShape(nameString, description); if (stddevString != null) { final double mean = extractDoubleParameter("mean", description, meanString, 0, Double.MAX_VALUE); final double stddev = 
extractDoubleParameter("stddev", description, stddevString, 0, Double.MAX_VALUE); @@ -182,14 +61,13 @@ private void initializeDistribution() { } } - private static Type extractDistributionType(final String nameString, final String description) { - for (final Type candidate : SUPPORTED_TYPES) { - if (candidate.getNames().stream().anyMatch(name -> name.toLowerCase().equals(nameString.trim().toLowerCase()))) { - return candidate; - } + private static InsertSizeDistributionShape extractDistributionShape(final String nameString, final String description) { + final InsertSizeDistributionShape result = InsertSizeDistributionShape.decode(nameString); + if (result == null) { + throw new UserException.BadInput("unsupported insert size distribution name '" + nameString + + "' in description: " + description); } - throw new UserException.BadInput("unsupported insert size distribution name '" + nameString - + "' in description: " + description); + return result; } private static double extractDoubleParameter(final String name, final String description, @@ -227,18 +105,18 @@ public boolean equals(final Object obj) { } public int minimum() { - return (int) Math.max(0, dist().getSupportLowerBound()); + return Math.max(0, dist().getSupportLowerBound()); } public int maximum() { - return (int) Math.min(Integer.MAX_VALUE, dist().getSupportUpperBound()); + return Math.min(Integer.MAX_VALUE, dist().getSupportUpperBound()); } - public double density(final int size) { - return dist().density(size); + public double probability(final int size) { + return dist().probability(size); } - public double logDensity(final int size) { - return dist().logDensity(size); + public double logProbability(final int size) { + return dist().logProbability(size); } } diff --git a/src/main/java/org/broadinstitute/hellbender/tools/spark/sv/InsertSizeDistributionShape.java b/src/main/java/org/broadinstitute/hellbender/tools/spark/sv/InsertSizeDistributionShape.java new file mode 100644 index 
00000000000..e644261db72 --- /dev/null +++ b/src/main/java/org/broadinstitute/hellbender/tools/spark/sv/InsertSizeDistributionShape.java @@ -0,0 +1,314 @@ +package org.broadinstitute.hellbender.tools.spark.sv; + +import org.apache.commons.math3.distribution.AbstractIntegerDistribution; +import org.apache.commons.math3.distribution.IntegerDistribution; +import org.apache.commons.math3.distribution.LogNormalDistribution; +import org.apache.commons.math3.distribution.NormalDistribution; +import org.apache.commons.math3.distribution.RealDistribution; +import org.apache.commons.math3.random.JDKRandomGenerator; +import org.apache.commons.math3.random.RandomGenerator; +import org.broadinstitute.hellbender.exceptions.UserException; +import org.broadinstitute.hellbender.tools.spark.sv.evidence.LibraryStatistics; +import org.broadinstitute.hellbender.tools.spark.sv.evidence.ReadMetadata; +import org.broadinstitute.hellbender.tools.spark.utils.IntHistogram; +import org.broadinstitute.hellbender.utils.Utils; +import org.broadinstitute.hellbender.utils.gcs.BucketUtils; + +import java.io.BufferedReader; +import java.io.IOException; +import java.io.InputStreamReader; +import java.util.Arrays; +import java.util.Collections; +import java.util.HashMap; +import java.util.List; +import java.util.Map; + +/** + * Supported insert size distributions shapes. + */ +public enum InsertSizeDistributionShape { + /** + * The insert size distribution follow a normal distribution truncated at 0. 
+ */ + NORMAL("N", "Gauss", "Gaussian") { + public AbstractIntegerDistribution fromMeanAndStdDeviation(final double mean, final double stddev) { + final int seed = ((Double.hashCode(mean) * 31) + Double.hashCode(stddev) * 31 ) + this.name().hashCode() * 31; + final RandomGenerator rdnGen = new JDKRandomGenerator(); + rdnGen.setSeed(seed); + final RealDistribution normal = new NormalDistribution(mean, stddev); + + // We took the math for the truncted normal distribution from Wikipedia: + // https://en.wikipedia.org/wiki/Truncated_normal_distribution +// final NormalDistribution normal = new NormalDistribution(mean, stddev); +// final double Z = 1.0 / (1.0 - normal.cumulativeProbability(-0.5)); +// expectedDensity = (x) -> Z * (normal.cumulativeProbability(x + 0.5) - normal.cumulativeProbability(x - 0.5)); + + final double zeroCumulative = normal.cumulativeProbability(-0.5); + final double zeroProb = normal.density(0); + final double normalization = 1.0 / (1.0 - zeroCumulative); // = 1.0 / Z in Wikipedia + final double mu = normal.getNumericalMean(); + final double sigmaSquare = normal.getNumericalVariance(); + final double sigma = Math.sqrt(sigmaSquare); + final double newMean = mu + sigma * zeroProb * normalization; + final double newVariance = sigmaSquare * (1 + normalization * (-mu / sigma) * zeroProb + - Math.pow(zeroProb * normalization, 2)); + // The actual distribution: + return new AbstractIntegerDistribution(rdnGen) { + + private static final long serialVersionUID = -1L; + + @Override + public double probability(int x) { + return normalization * (normal.cumulativeProbability(x + 0.5) - normal.cumulativeProbability(x - 0.5)); + } + + @Override + public double cumulativeProbability(int x) { + return (normal.cumulativeProbability(x + 0.5) - zeroCumulative) * normalization; + } + + @Override + public double getNumericalMean() { + return newMean; + } + + @Override + public double getNumericalVariance() { + return newVariance; + } + + @Override + public int 
getSupportLowerBound() { + return 0; + } + + @Override + public int getSupportUpperBound() { + return Integer.MAX_VALUE; + } + + @Override + public boolean isSupportConnected() { + return true; + } + }; + } + + @Override + public AbstractIntegerDistribution fromSerializationFile(String whereFrom) { + final ReadMetadata metaData = ReadMetadata.Serializer.readStandalone(whereFrom); + final IntHistogram hist = new IntHistogram(2000); + long modeCount = 0; + for (final LibraryStatistics libStats : metaData.getAllLibraryStatistics().values()) { + final IntHistogram.CDF cdf = libStats.getCDF(); + final long cdfTotalCount = cdf.getTotalObservations(); + final int size = cdf.size() - 1; + hist.addObservations(0, Math.round(cdf.getFraction(0) * cdfTotalCount)); + for (int i = 1; i < size; i++) { + final double fraction = cdf.getFraction(i) - cdf.getFraction(i - 1); + final long count = Math.round(fraction * cdfTotalCount); + if (modeCount < count) { + modeCount = count; + } + hist.addObservations(i, count); + } + } + if (hist.getTotalObservations() == 0) { + throw new UserException.MalformedFile("Could not find any insert-sizes in " + whereFrom); + } + return hist.empiricalDistribution((int) Math.max(1, modeCount / 1_000_000)); + } + + + }, /** + * The insert size distribution follows a log-normal; i.e. the exp(isize) ~ Normal. + */ + LOG_NORMAL("LogNormal", "LnN") { + @Override + public AbstractIntegerDistribution fromMeanAndStdDeviation(final double mean, final double stddev) { + final double var = stddev * stddev; + final double scale = 2 * Math.log(mean) - 0.5 * Math.log(var + mean * mean); // scale = mu in wikipedia article. + final double shape = Math.sqrt(Math.log(1 + (var / (mean * mean)))); // shape = sigma in wikipedia article. 
+ final RealDistribution real = new LogNormalDistribution(scale, shape); + final int seed = (((Double.hashCode(mean) * 31) + Double.hashCode(stddev) * 31) + name().hashCode() * 31); + final RandomGenerator rdnGen = new JDKRandomGenerator(); + rdnGen.setSeed(seed); + return new AbstractIntegerDistribution(rdnGen) { + + private static final long serialVersionUID = -1L; + + @Override + public double probability(int x) { + return real.cumulativeProbability(x + 0.5) - real.cumulativeProbability(x - 0.5); + } + + @Override + public double cumulativeProbability(int x) { + return real.cumulativeProbability(x + 0.5); + } + + @Override + public double getNumericalMean() { + return real.getNumericalMean(); + } + + @Override + public double getNumericalVariance() { + return real.getNumericalVariance(); + } + + @Override + public int getSupportLowerBound() { + return 0; + } + + @Override + public int getSupportUpperBound() { + return Integer.MAX_VALUE; + } + + @Override + public boolean isSupportConnected() { + return true; + } + }; + } + }, /** + * Arbitrary densities are set for each possible insert size. + */ + EMPIRICAL("E", "Emp") { + + @Override + public AbstractIntegerDistribution fromMeanAndStdDeviation(double mean, double stddev) { + throw new UserException.BadInput("Empirical insert-size-distribution needs a meta-file"); + } + + @Override + public AbstractIntegerDistribution fromTextFile(String whereFrom) { + final IntHistogram hist = new IntHistogram(2000); // 2000 is the number of tracked values i.e. 
0..2000 + long modeCount = 0; + try (final BufferedReader reader = new BufferedReader(new InputStreamReader(BucketUtils.openFile(whereFrom)))) { + String line; + while ((line = reader.readLine()) != null) { + if (line.startsWith(ReadMetadata.CDF_PREFIX)) { + final String[] cdf = line.substring(ReadMetadata.CDF_PREFIX.length() + 1).split("\t"); + long leftCdf = 0; + for (int value = 0; value < cdf.length; value++) { + final long frequency = Long.parseLong(cdf[value]) - leftCdf; + hist.addObservations(value, frequency); + if (hist.getNObservations(value) > modeCount) { + modeCount = hist.getNObservations(value); + } + } + } + } + if (hist.getTotalObservations() == 0) { + throw new UserException.MalformedFile("Could not find any insert-sizes in " + whereFrom); + } + // We apply a smoothing that won't yield probabilities far below 10^-6 (Phred ~ 60) + return hist.empiricalDistribution((int) Math.max(1, modeCount / 1_000_000)); + } catch (final IOException ex) { + throw new UserException.CouldNotReadInputFile(whereFrom); + } + } + }; + + private final List aliases; + + private static final Map byLowercaseName; + + static { + final InsertSizeDistributionShape[] shapes = values(); + byLowercaseName = new HashMap<>(shapes.length * (1 + 5)); + for (final InsertSizeDistributionShape shape : shapes) { + byLowercaseName.put(shape.name().toLowerCase(), shape); + for (final String alias : shape.aliases) { + byLowercaseName.put(alias.toLowerCase(), shape); + } + } + } + + InsertSizeDistributionShape(final String ... 
names) { + final List nameList = Arrays.asList(names); + aliases = Collections.unmodifiableList(nameList); + } + + public List aliases() { + return aliases; + } + + public static InsertSizeDistributionShape decode(final String name) { + Utils.nonNull(name); + return byLowercaseName.get(name.toLowerCase()); + } + + protected abstract AbstractIntegerDistribution fromMeanAndStdDeviation(final double mean, final double stddev); + + protected AbstractIntegerDistribution fromReadMetadataFile(final String whereFrom) { + try { + return fromSerializationFile(whereFrom); + } catch (final RuntimeException ex) { + return fromTextFile(whereFrom); + } + } + + protected AbstractIntegerDistribution fromTextFile(String whereFrom) { + try (final BufferedReader reader = new BufferedReader(new InputStreamReader(BucketUtils.openFile(whereFrom)))) { + String line; + int value; + double totalSum = 0; + double totalSqSum = 0; + long totalCount = 0; + while ((line = reader.readLine()) != null) { + if (line.startsWith(ReadMetadata.CDF_PREFIX)) { + final String[] cdf = line.substring(ReadMetadata.CDF_PREFIX.length() + 1).split("\t"); + long leftCdf = 0; + for (value = 0; value < cdf.length; value++) { + final long frequency = Long.parseLong(cdf[value]) - leftCdf; + leftCdf += frequency; + totalSum += frequency * value; + totalSqSum += value * value * frequency; + } + totalCount += leftCdf; + } + } + if (totalCount == 0) { + throw new UserException.MalformedFile("Could not find any insert-sizes in " + whereFrom); + } + final double mean = totalSum / totalCount; + final double stdDev = Math.sqrt(Math.abs(totalSqSum/totalCount - mean * mean)); + return fromMeanAndStdDeviation(mean, stdDev); + } catch (final IOException ex2) { + throw new UserException.CouldNotReadInputFile(whereFrom); + } catch (final NumberFormatException ex2) { + throw new UserException.MalformedFile("the CDF contains non-numbers in " + whereFrom); + } + } + + protected AbstractIntegerDistribution fromSerializationFile(String 
whereFrom) { + final ReadMetadata metaData = ReadMetadata.Serializer.readStandalone(whereFrom); + double totalSum = 0; + double totalSqSum = 0; + long totalCount = 0; + for (final LibraryStatistics libStats : metaData.getAllLibraryStatistics().values()) { + final IntHistogram.CDF cdf = libStats.getCDF(); + final long cdfTotalCount = cdf.getTotalObservations(); + final int size = cdf.size(); + for (int i = 1; i < size; i++) { + final double fraction = cdf.getFraction(i) - cdf.getFraction(i - 1); + final double count = fraction * cdfTotalCount; + totalSum += count * i; + totalSqSum += count * i * i; + } + totalCount += cdfTotalCount; + } + if (totalCount == 0) { + throw new UserException.MalformedFile("Could not find any insert-sizes in " + whereFrom); + } + final double mean = totalSum / totalCount; + final double variance = Math.abs(totalSqSum / totalCount - mean * mean); + final double stdDev = Math.sqrt(variance); + return fromMeanAndStdDeviation(mean, stdDev); + } + +} diff --git a/src/main/java/org/broadinstitute/hellbender/tools/spark/utils/IntHistogram.java b/src/main/java/org/broadinstitute/hellbender/tools/spark/utils/IntHistogram.java index 35bac4e8c15..4abacf308d9 100644 --- a/src/main/java/org/broadinstitute/hellbender/tools/spark/utils/IntHistogram.java +++ b/src/main/java/org/broadinstitute/hellbender/tools/spark/utils/IntHistogram.java @@ -4,9 +4,15 @@ import com.esotericsoftware.kryo.Kryo; import com.esotericsoftware.kryo.io.Input; import com.esotericsoftware.kryo.io.Output; +import org.apache.commons.math3.distribution.AbstractIntegerDistribution; +import org.apache.commons.math3.distribution.IntegerDistribution; +import org.apache.commons.math3.random.JDKRandomGenerator; +import org.apache.commons.math3.random.RandomGenerator; import org.broadinstitute.hellbender.utils.Utils; +import org.broadinstitute.hellbender.utils.param.ParamUtils; import java.util.Arrays; +import java.util.Random; /** Histogram of observations on a compact set of non-negative 
integer values. */ @DefaultSerializer(IntHistogram.Serializer.class) @@ -22,7 +28,7 @@ public IntHistogram( final int maxTrackedValue ) { totalObservations = 0; } - private IntHistogram( final long[] counts, final long totalObservations ) { + private IntHistogram( final long[] counts, final long totalObservations) { this.counts = counts; this.totalObservations = totalObservations; } @@ -33,6 +39,7 @@ private IntHistogram( final Kryo kryo, final Input input ) { long total = 0L; for ( int idx = 0; idx != len; ++idx ) { final long val = input.readLong(); + final double idxSum = val * idx; total += val; counts[idx] = val; } @@ -122,6 +129,110 @@ public IntHistogram read( final Kryo kryo, final Input input, final Class + * Later changes in the histogram won't affect the returned distribution. + *

+ *

+ * You can indicate an arbitrary "smoothing" count number which is added to the + * observed frequency of every value. This way non-observed values won't necessarily + * have a probability of zero. + *

+ *

+ * The supported space is enclosed in [0,{@link Integer#MAX_VALUE max_int}]. + * So the probability of negative values is 0 and the probability of 0 and positive values is + * always at least smoothing / num. of observations. + *

+ *

+ * Although the probability of very large values not tracked by the histogram is not zero, + * the cumulative probability is said to reach 1.0 at the largest tracked number. + * As a consequence the distribution is actually not a proper one (unknown normalization constant). + *

+ *

+ * However since we expect that just a small fraction of the observations will fall outside the tracked range + * we assume that it is a proper distribution as far as the calculation of the mean, variance, density and + * cumulative density is concerned. + *

+ * + * @param smoothing zero or a positive number of counts to each value. + * @return never {@code null}. + */ + public AbstractIntegerDistribution empiricalDistribution(final int smoothing) { + ParamUtils.isPositiveOrZero(smoothing, "the smoothing must be zero or positive"); + final long[] counts = Arrays.copyOfRange(this.counts, 0, this.counts.length - 1); + final double[] cumulativeCounts = new double[counts.length]; + double sum = 0; + double sqSum = 0; + for (int i = 0; i < counts.length; i++) { + final long newCount = (counts[i] += smoothing); + sum += newCount * i; + sqSum += i * newCount * i; + } + cumulativeCounts[0] = counts[0]; + for (int i = 1; i < counts.length; i++) { + cumulativeCounts[i] = counts[i] + cumulativeCounts[i - 1]; + } + final double totalCounts = cumulativeCounts[counts.length - 1]; + final double inverseTotalCounts = 1.0 / totalCounts; + final double mean = sum / totalCounts; + final double variance = sqSum / totalCounts - mean * mean; + final int seed = Arrays.hashCode(counts); + final RandomGenerator rdnGen = new JDKRandomGenerator(); + rdnGen.setSeed(seed); + return new AbstractIntegerDistribution(rdnGen) { + + private static final long serialVersionUID = -1L; + + @Override + public double probability(int x) { + if (x < 0) { + return 0.0; + } else if (x >= counts.length) { + return smoothing * inverseTotalCounts; + } else { + return counts[x] * inverseTotalCounts; + } + } + + @Override + public double cumulativeProbability(int x) { + if (x < 0) { + return 0; + } else if (x >= counts.length) { + return 1.0; + } else { + return cumulativeCounts[x] * inverseTotalCounts; + } + } + + @Override + public double getNumericalMean() { + return mean; + } + + @Override + public double getNumericalVariance() { + return variance; + } + + @Override + public int getSupportLowerBound() { + return 0; + } + + @Override + public int getSupportUpperBound() { + return Integer.MAX_VALUE; + } + + @Override + public boolean isSupportConnected() { + return 
true; + } + }; + } + @DefaultSerializer(CDF.Serializer.class) public final static class CDF { final float[] cdfFractions; diff --git a/src/test/java/org/broadinstitute/hellbender/tools/spark/sv/InsertSizeDistributionShapeUnitTest.java b/src/test/java/org/broadinstitute/hellbender/tools/spark/sv/InsertSizeDistributionShapeUnitTest.java new file mode 100644 index 00000000000..018aa4baa6c --- /dev/null +++ b/src/test/java/org/broadinstitute/hellbender/tools/spark/sv/InsertSizeDistributionShapeUnitTest.java @@ -0,0 +1,40 @@ +package org.broadinstitute.hellbender.tools.spark.sv; + +import org.testng.Assert; +import org.testng.annotations.Test; + +/** + * Unit test for {@link InsertSizeDistributionShape}. + */ +public class InsertSizeDistributionShapeUnitTest { + + @Test + public void testUniqueAliases() { + for (final InsertSizeDistributionShape shape1 : InsertSizeDistributionShape.values()) { + for (final InsertSizeDistributionShape shape2 : InsertSizeDistributionShape.values()) { + if (shape1 != shape2) { + for (final String alias1 : shape1.aliases()) { + for (final String alias2 : shape2.aliases()) { + Assert.assertNotEquals(alias1.toLowerCase(), alias2.toLowerCase()); + } + } + } + } + } + } + + @Test + public void testAliasesAndNameEncodeEachShape() { + for (final InsertSizeDistributionShape shape : InsertSizeDistributionShape.values()) { + Assert.assertSame(InsertSizeDistributionShape.decode(shape.name()), shape); + for (final String alias : shape.aliases()) { + Assert.assertSame(InsertSizeDistributionShape.decode(alias), shape); + } + } + } + + @Test + public void testDecodeOnGarbageReturnsNull() { + Assert.assertNull(InsertSizeDistributionShape.decode("Garbage")); + } +} diff --git a/src/test/java/org/broadinstitute/hellbender/tools/spark/sv/InsertSizeDistributionUnitTest.java b/src/test/java/org/broadinstitute/hellbender/tools/spark/sv/InsertSizeDistributionUnitTest.java index 2fb951c1edb..e12e582f3f1 100644 --- 
a/src/test/java/org/broadinstitute/hellbender/tools/spark/sv/InsertSizeDistributionUnitTest.java +++ b/src/test/java/org/broadinstitute/hellbender/tools/spark/sv/InsertSizeDistributionUnitTest.java @@ -1,6 +1,8 @@ package org.broadinstitute.hellbender.tools.spark.sv; import org.apache.commons.lang3.StringUtils; +import org.apache.commons.math3.distribution.LogNormalDistribution; +import org.apache.commons.math3.distribution.NormalDistribution; import org.apache.commons.math3.distribution.PoissonDistribution; import org.apache.commons.math3.random.JDKRandomGenerator; import org.broadinstitute.hellbender.GATKBaseTest; @@ -38,11 +40,8 @@ public void testFromSerializedMetaData() throws IOException { tempFile.deleteOnExit(); ReadMetadata.Serializer.writeStandalone(readMetadata, tempFile.toString()); final InsertSizeDistribution lnDist = new InsertSizeDistribution("LogNormal(" + tempFile.toString() + ")"); - final InsertSizeDistribution nDist = new InsertSizeDistribution("Normal(" + tempFile.toString() + ")"); Assert.assertEquals(lnDist.mean(), ReadMetadataTest.LIBRARY_STATISTICS_MEAN, ReadMetadataTest.LIBRARY_STATISTIC_MEAN_DIFF); Assert.assertEquals(lnDist.stddev(), ReadMetadataTest.LIBRARY_STATISTICS_SDEV, ReadMetadataTest.LIBRARY_STATISTICS_SDEV_DIFF); - Assert.assertEquals(nDist.mean(), ReadMetadataTest.LIBRARY_STATISTICS_MEAN, ReadMetadataTest.LIBRARY_STATISTIC_MEAN_DIFF); - Assert.assertEquals(nDist.stddev(), ReadMetadataTest.LIBRARY_STATISTICS_SDEV, ReadMetadataTest.LIBRARY_STATISTICS_SDEV_DIFF); } finally { tempFile.delete(); @@ -58,12 +57,8 @@ public void testFromTextMetaData() throws IOException { tempFile.deleteOnExit(); ReadMetadata.writeMetadata(readMetadata, tempFile.toString()); final InsertSizeDistribution lnDist = new InsertSizeDistribution("LogNormal(" + tempFile.toString() + ")"); - final InsertSizeDistribution nDist = new InsertSizeDistribution("Normal(" + tempFile.toString() + ")"); Assert.assertEquals(lnDist.mean(), 
ReadMetadataTest.LIBRARY_STATISTICS_MEAN, ReadMetadataTest.LIBRARY_STATISTIC_MEAN_DIFF); Assert.assertEquals(lnDist.stddev(), ReadMetadataTest.LIBRARY_STATISTICS_SDEV, ReadMetadataTest.LIBRARY_STATISTICS_SDEV_DIFF); - Assert.assertEquals(nDist.mean(), ReadMetadataTest.LIBRARY_STATISTICS_MEAN, ReadMetadataTest.LIBRARY_STATISTIC_MEAN_DIFF); - Assert.assertEquals(nDist.stddev(), ReadMetadataTest.LIBRARY_STATISTICS_SDEV, ReadMetadataTest.LIBRARY_STATISTICS_SDEV_DIFF); - } finally { tempFile.delete(); } @@ -74,17 +69,23 @@ public void testFromRealTextMetaData() { final InsertSizeDistribution lnDist = new InsertSizeDistribution("LogNormal(" + READ_METADATA_FILE.toString() + ")"); final InsertSizeDistribution nDist = new InsertSizeDistribution("Normal(" + READ_METADATA_FILE.toString() + ")"); for (final InsertSizeDistribution dist : Arrays.asList(lnDist, nDist)) { - Assert.assertEquals(dist.mean(), 379.1432, 0.00005); // calculated independently using R. - Assert.assertEquals(dist.variance(), 18163.74, 0.005); - Assert.assertEquals(dist.stddev(), 134.7729, 0.00005); + Assert.assertEquals(dist.mean(), 379.1432, 0.01); // calculated independently using R. + Assert.assertEquals(dist.variance(), 18162., 2); + Assert.assertEquals(dist.stddev(), 134.76, 0.02); } } @Test(dataProvider = "testData") public void testProbability(final String description, final int x, final double expected, final double logExpected) { final InsertSizeDistribution isd = new InsertSizeDistribution(description); - Assert.assertEquals(isd.density(x), expected, 0.00001); - Assert.assertEquals(isd.logDensity(x), logExpected, 0.01); + Assert.assertEquals(isd.probability(x), expected, 0.00001); + if (isd.probability(x) > 0 && expected > 0) { + Assert.assertEquals(isd.logProbability(x), logExpected, 0.01); + } else { + //TODO currently we don't have the hability of producing a finite log-prob if the prob is == 0. + //TODO due to limitations in apache common-math. So this else avoid to fail in these instances. 
+ Assert.assertTrue(Math.abs(isd.logProbability(x) - logExpected) < 0.01 || isd.logProbability(x) == Double.NEGATIVE_INFINITY || logExpected == Double.NEGATIVE_INFINITY); + } } @DataProvider(name = "testData") @@ -100,8 +101,8 @@ public Object[][] testData() { final int[] fixedSizes = {1, 11, 113, 143, 243, 321, 494, 539, 10190, 301298712}; final double[] sizeSigmas = {0, -1, 1, 3.5, -2, 2, -6.9, 6.9}; final PoissonDistribution spacesDistr = new PoissonDistribution(randomGenerator, 0.1, 0.0001, 100); - for (final InsertSizeDistribution.Type type : InsertSizeDistribution.SUPPORTED_TYPES) { - final List distrNames = type.getNames(); + for (final InsertSizeDistributionShape type : Arrays.asList(InsertSizeDistributionShape.NORMAL, InsertSizeDistributionShape.LOG_NORMAL)) { + final List distrNames = type.aliases(); for (final double mean : means) { for (final double cv : cvs) { final double stddev = mean * cv; @@ -109,18 +110,21 @@ public Object[][] testData() { // the actual implementation in main relies on apache common math, so is difficult to fall into the // same error mode thus masking bugs. 
final IntToDoubleFunction expectedDensity; - final IntToDoubleFunction logExpecetedDensity; - if (type.getClass() == InsertSizeDistribution.NormalType.class) { - expectedDensity = (x) -> Math.exp(-.5 * Math.pow((((double) x) - mean) / stddev, 2)) * (1.0 / (stddev * Math.sqrt(2 * Math.PI))); - logExpecetedDensity = (x) -> -.5 * Math.pow((((double) x) - mean) / stddev, 2) - Math.log(stddev * Math.sqrt(2 * Math.PI)); - } else if (type.getClass() == InsertSizeDistribution.LogNormalType.class) { + final IntToDoubleFunction expectedLogDensity; + if (type == InsertSizeDistributionShape.NORMAL) { + final NormalDistribution normal = new NormalDistribution(mean, stddev); + final double Z = 1.0 / (1.0 - normal.cumulativeProbability(-0.5)); + expectedDensity = (x) -> Z * (normal.cumulativeProbability(x + 0.5) - normal.cumulativeProbability(x - 0.5)); + expectedLogDensity = (x) -> Math.log(expectedDensity.applyAsDouble(x)); + } else if (type == InsertSizeDistributionShape.LOG_NORMAL) { final double var = stddev * stddev; final double logMean = Math.log(mean) - Math.log(Math.sqrt(1 + (var / (mean * mean)))); final double logStddev = Math.sqrt(Math.log(1 + var / (mean * mean))); - expectedDensity = (x) -> Math.exp(-.5 * Math.pow((Math.log(x) - logMean) / logStddev, 2)) / (x * logStddev * Math.sqrt(2 * Math.PI)); - logExpecetedDensity = (x) -> -.5 * Math.pow((Math.log(x) - logMean) / logStddev, 2) - Math.log(x * logStddev * Math.sqrt(2 * Math.PI)); + final LogNormalDistribution logNormal = new LogNormalDistribution(logMean, logStddev); + expectedDensity = (x) -> logNormal.cumulativeProbability(x + 0.5) - logNormal.cumulativeProbability(x - 0.5); + expectedLogDensity = (x) -> Math.log(expectedDensity.applyAsDouble(x)); } else { - throw new IllegalStateException("test do not support one of the type supported by InsertSizeDistribution: " + type.getNames().get(0)); + throw new IllegalStateException("test do not support one of the type supported by InsertSizeDistribution: " + 
type.aliases().get(0)); } // We add fixed length cases for (final int fixedSize : fixedSizes) { @@ -128,7 +132,7 @@ public Object[][] testData() { result.add(new Object[]{ composeDescriptionString(distrName, mean, stddev, spacesDistr), fixedSize, expectedDensity.applyAsDouble(fixedSize), - logExpecetedDensity.applyAsDouble(fixedSize)}); + expectedLogDensity.applyAsDouble(fixedSize)}); } // We add relative length cases (expressed in sigmas) for (final double sizeSigma : sizeSigmas) { @@ -140,15 +144,28 @@ public Object[][] testData() { result.add(new Object[]{ composeDescriptionString(distrName, mean, stddev, spacesDistr), x, expectedDensity.applyAsDouble(x), - logExpecetedDensity.applyAsDouble(x)}); + expectedLogDensity.applyAsDouble(x)}); } } } } // A couple of hard-wired cases to attest that the code above is generating genuine cases // rather than have the same error mode as the implementation in main. - result.add(new Object[] { "N(300,150)", 231, 0.002392602, Math.log(0.002392602)}); - result.add(new Object[] { "lnN(100,1)", 103, 0.004833924, Math.log(0.004833924)}); + // I enclose, commented out, the code used to calculate these values in R: + + result.add(new Object[] { "N(300,150)", 231, 0.002447848, Math.log(0.002447848)}); + // > probs = pnorm(c(230.5, 231.5), 300, 150) + // > (probs[2] - probs[1]) / (1 - pnorm(-0.5, 300, 150)) # normalized for truncation at isize >= 0 + // 0.002447848 + + result.add(new Object[] { "lnN(100,10)", 103, 0.03655933, Math.log(0.03655933)}); + // > mu = 100; sigma = 10 + // > meanlog = log(mu/sqrt(1+ (sigma^2)/mu^2)) # eq. from wikipedia article. + // > sdlog = sqrt(log(1 + (sigma^2)/mu^2)) # eq. from wikipedia article. 
+ // > probs = plnorm(c(102.5, 103.5), meanlog, sdlog) + // > probs[2] - probs[1] + // 0.03655933 + return result.toArray(new Object[result.size()][]); } diff --git a/src/test/java/org/broadinstitute/hellbender/utils/IntHistogramTest.java b/src/test/java/org/broadinstitute/hellbender/utils/IntHistogramTest.java index 06b03d95564..34ecc4ed919 100644 --- a/src/test/java/org/broadinstitute/hellbender/utils/IntHistogramTest.java +++ b/src/test/java/org/broadinstitute/hellbender/utils/IntHistogramTest.java @@ -1,14 +1,17 @@ package org.broadinstitute.hellbender.utils; +import org.apache.commons.math3.distribution.IntegerDistribution; import org.broadinstitute.hellbender.tools.spark.utils.IntHistogram; import org.broadinstitute.hellbender.GATKBaseTest; import org.testng.Assert; +import org.testng.annotations.DataProvider; import org.testng.annotations.Test; +import java.util.Arrays; import java.util.Random; public class IntHistogramTest extends GATKBaseTest { - private static final int MAX_TRACKED_VALUE = 10000; + private static final int MAX_TRACKED_VALUE = 2000; private static final float SIGNIFICANCE = .05f; @Test @@ -79,6 +82,40 @@ public void testBiModalDistribution() { Assert.assertTrue(cdf.isDifferentByKSStatistic(sample1, SIGNIFICANCE)); } + @Test + public void testEmpiricalDistributionWithoutSmoothingSampling() { + final IntHistogram largeSample = genNormalSample(480, 25, 10000); + final IntegerDistribution dist = largeSample.empiricalDistribution(0); + final IntHistogram distSample = new IntHistogram(MAX_TRACKED_VALUE); + Arrays.stream(dist.sample(1000)).forEach(distSample::addObservation); + Assert.assertFalse(largeSample.getCDF().isDifferentByKSStatistic(distSample, SIGNIFICANCE)); + } + + @Test(dataProvider = "smoothingValues") + public void testEmpiricalDistributionSmoothing(final int smoothing) { + final IntHistogram largeSample = genNormalSample(480, 25, 10000); + final IntegerDistribution dist = largeSample.empiricalDistribution(smoothing); + final long 
smoothedNumberOfObservations = largeSample.getMaximumTrackedValue() * smoothing + largeSample.getTotalObservations(); + double cumulative = 0; + double expectation = 0; + double sqExpectation = 0; + for (int i = 0; i <= largeSample.getMaximumTrackedValue(); i++) { + final double distProb = dist.probability(i); + Assert.assertEquals(distProb, (largeSample.getNObservations(i) + smoothing) / (double) smoothedNumberOfObservations, 0.0001); + cumulative += distProb; + Assert.assertEquals(dist.cumulativeProbability(i), cumulative, 0.00001); + expectation += distProb * i; + sqExpectation += i * distProb * i; + } + Assert.assertEquals(dist.getNumericalMean(), expectation, 0.00001); + Assert.assertEquals(dist.getNumericalVariance(), sqExpectation - expectation * expectation, 0.00001); + } + + @DataProvider + public Object[][] smoothingValues() { + return new Object[][] { { 0 }, { 1 }, { 2 }, { 13 }, {100 }}; + } + public static IntHistogram genNormalSample( final int mean, final int stdDev, final int nSamples ) { Random random = new Random(47L); final IntHistogram histogram = new IntHistogram(MAX_TRACKED_VALUE);