diff --git a/ProbabilityMonad.Test/Models/HiddenMarkovModel.cs b/ProbabilityMonad.Test/Models/HiddenMarkovModel.cs index f7e8d7f..34a4b1f 100644 --- a/ProbabilityMonad.Test/Models/HiddenMarkovModel.cs +++ b/ProbabilityMonad.Test/Models/HiddenMarkovModel.cs @@ -29,7 +29,7 @@ from emit in emissionDist(initial) samples = from prev in samples from next in transitionDist(prev.Last().Item1) from emit in emissionDist(next) - select Append(prev, Tuple(next, emit)).ToList(); + select prev.Append(Tuple(next, emit)).ToList(); numSamples -= 1; } return samples; @@ -94,7 +94,7 @@ public static Dist> Hmm(List observed) Condition(xs => emission(xs.Last(), y), (from rest in d from x in transitionMatrix(rest.Last()) - select Append(rest, x).ToList())); + select rest.Append(x).ToList())); return observed.Aggregate(startDist, expand); } diff --git a/ProbabilityMonad/Base.cs b/ProbabilityMonad/Base.cs index 3b041d4..a5b1b08 100644 --- a/ProbabilityMonad/Base.cs +++ b/ProbabilityMonad/Base.cs @@ -1,650 +1,639 @@ -using System.Collections.Generic; -using System.Linq; -using System; -using MathNet.Numerics.LinearAlgebra; -using MathNet.Numerics.LinearAlgebra.Double; - -namespace ProbCSharp -{ - /// - /// This class exports a whole bunch of constructors. - /// Basically avoids us typing new X() - /// which can make code unreadable when we have really large types. 
- /// - public static class ProbBase - { - // Singleton instance of the random generator to avoid repeated values in tight loops - public static Random Gen = new Random(); - - #region Primitive object constructors - /// - /// Probability constructor - /// - public static Prob Prob(double probability) - { - return new LogProb(Math.Log(probability)); - } - - /// - /// ItemProb constructor - /// - public static ItemProb ItemProb(A item, Prob prob) - { - return new ItemProb(item, prob); - } - - /// - /// Samples constructor - /// - public static Samples Samples(IEnumerable> itemProbs) - { - return new Samples(itemProbs); - } - - /// - /// Tuple constructor - /// - public static Tuple Tuple(A a, B b) - { - return new Tuple(a, b); - } - #endregion - - #region Distribution constructors - /// - /// Finite uniform distribution over list of items. - /// Only composable with other finite distributions. - /// - public static FiniteDist EnumUniformF(IEnumerable items) - { - var uniform = Samples(items.Select(i => new ItemProb(i, Prob(1)))); - return new FiniteDist(Importance.Normalize(uniform)); - } - - /// - /// Uniform distribution over list of items - /// - public static Dist UniformFromList(IEnumerable items) - { - return Primitive(EnumUniformF(items)); - } - - /// - /// Uniform distribution over parameter items - /// Only composable with other finite distributions. - /// - public static FiniteDist UniformF(params A[] items) - { - var itemList = new List(items); - return EnumUniformF(itemList); - } - - /// - /// Uniform distribution over parameter items - /// - public static Dist Uniform(params A[] items) - { - return Primitive(UniformF(items)); - } - - /// - /// Bernoulli distribution constructed from success probability - /// Only composable with other finite distributions. 
- /// - public static FiniteDist BernoulliF(Prob prob) - { - return new FiniteDist(ItemProb(true, prob), ItemProb(false, Prob(1 - prob.Value))); - } - - /// - /// Bernoulli distribution constructed from success probability - /// - public static Dist Bernoulli(Prob prob) - { - return Primitive(BernoulliF(prob)); - } - - /// - /// Bernoulli distribution constructed from success probability - /// - public static Dist Bernoulli(double prob) - { - return Primitive(BernoulliF(Prob(prob))); - } - - /// - /// Bernoulli distribution constructed from two items, and probability of first item - /// - public static Dist Bernoulli(double prob, A option1, A option2) - { - return Primitive(BernoulliF(Prob(prob))).Select(b => b ? option1 : option2); - } - - /// - /// Categorical distribution - /// Only composable with other finite distributions - /// - public static FiniteDist CategoricalF(params ItemProb[] itemProbs) - { - return new FiniteDist(Samples(itemProbs)); - } - - /// - /// Categorical distribution - /// Only composable with other finite distributions - /// - public static FiniteDist CategoricalF(Samples samples) - { - return new FiniteDist(samples); - } - - /// - /// Categorical distribution - /// - public static Dist Categorical(Samples samples) - { - return Primitive(CategoricalF(samples).ToSampleDist()); - } - - /// - /// Categorical distribution - /// - public static CategoricalPrimitive CategoricalPrimitive(A[] items, double [] probabilities) - { - return new CategoricalPrimitive(items, probabilities, Gen); - } - - /// - /// Primitive studentT distribution - /// Only composable with other primitive distributions - /// - public static StudentTPrimitive StudentTPrimitive(double location, double scale, double normality) - { - return new StudentTPrimitive(location, scale, normality, Gen); - } - - /// - /// Primitive Exponential distribution - /// Only composable with other primitive distributions - /// - public static ExponentialPrimitive ExponentialPrimitive(double 
rate) - { - return new ExponentialPrimitive(rate); - } - - /// - /// Primitive Contiuous Uniform distribution - /// Only composable with other primitive distributions - /// - public static ContinuousUniformPrimitive ContinuousUniformPrimitive() - { - return new ContinuousUniformPrimitive(Gen); - } - - /// - /// Primitive Contiuous Uniform distribution - /// Only composable with other primitive distributions - /// - public static ContinuousUniformPrimitive ContinuousUniformPrimitive(double lower, double upper) - { - return new ContinuousUniformPrimitive(lower,upper, Gen); - } - - - /// - /// Primitive Poisson distribution - /// Only composable with other primitive distributions - /// - public static PoissonPrimitive PoissonPrimitive(double lambda) - { - return new PoissonPrimitive(lambda, Gen); - } - - /// - /// Primitive Normal distribution - /// Only composable with other primitive distributions - /// - public static NormalPrimitive NormalPrimitive(double mean, double variance) - { - return new NormalPrimitive(mean, variance, Gen); - } - - /// - /// Primitive LogNormal distribution - /// Only composable with other primitive distributions - /// - public static LogNormalPrimitive LogNormalPrimitive(double mean, double variance) - { - return new LogNormalPrimitive(mean, variance, Gen); - } - - public static LogNormalPrimitive LogNormalPrimitiveMu(double mu, double sigma) - { - return new LogNormalPrimitive(mu, sigma, true, Gen); - } - - /// - /// Primitive Beta distribution - /// Only composable with other primitive distributions - /// - public static BetaPrimitive BetaPrimitive(double alpha, double beta) - { - return new BetaPrimitive(alpha, beta, Gen); - } - - /// - /// Primitive Gamma distribution - /// Only composable with other primitive distributions - /// - public static GammaPrimitive GammaPrimitive(double shape, double rate) - { - return new GammaPrimitive(shape, rate, Gen); - } - - public static MultiVariateNormalPrimitive 
MultiVariateNormalPrimitive(double[] mean, Matrix covariance) - { - return new MultiVariateNormalPrimitive(mean, covariance, Gen); - } - - public static WishartPrimitive WishartPrimitive(double dof, Matrix scale) - { - return new WishartPrimitive(dof, scale, Gen); - } - - /// - /// Primitive Dirichlet distribution - /// Only composable with other primitive distributions - /// - public static DirichletPrimitive DirichletPrimitive(double[] alpha) - { - return new DirichletPrimitive(alpha, Gen); - } - - /// - /// Poisson distribution - /// - public static Dist Poisson(double lambda) - { - return Primitive(PoissonPrimitive(lambda)); - } - - /// - /// Normal distribution - /// - public static Dist Normal(double mean, double variance) - { - return Primitive(NormalPrimitive(mean, variance)); - } - - /// - /// Log Normal distribution - /// - public static Dist LogNormal(double mean, double variance) - { - return Primitive(LogNormalPrimitive(mean, variance)); - } - - public static Dist LogNormalMu(double mu, double sigma) - { - return Primitive(LogNormalPrimitiveMu(mu, sigma)); - } - - /// - /// Gamma distribution - /// - public static Dist Gamma(double shape, double rate) - { - return Primitive(GammaPrimitive(shape, rate)); - } - - /// - /// Categorical distribution - /// - public static Dist Categorical(T[] items, double[] probabilities) - { - return Primitive(CategoricalPrimitive(items, probabilities)); - } - - public static Dist Uniform() - { - return Primitive(ContinuousUniformPrimitive()); - } - - public static Dist Uniform(double lower, double upper) - { - return Primitive(ContinuousUniformPrimitive(lower,upper)); - } - - /// - /// StudenT distribution - /// - public static Dist StudentT(double location, double scale, double normality) - { - return Primitive(StudentTPrimitive(location, scale, normality)); - } - - /// - /// Exponential distribution - /// - public static Dist Exponential(double rate) - { - return Primitive(ExponentialPrimitive(rate)); - } - - /// - /// 
Beta distribution - /// - public static Dist Beta(double alpha, double beta) - { - return Primitive(BetaPrimitive(alpha, beta)); - } - - /// - /// Dirichlet distribution - /// - public static Dist Dirichlet(double[] alpha) - { - return Primitive(DirichletPrimitive(alpha)); - } - - public static Dist MultiVariateNormal(double[] mean, Matrix covariance) - { - return Primitive(MultiVariateNormalPrimitive(mean, covariance)); - } - - public static Dist> Wishart(double dof, Matrix scale) - { - return Primitive(WishartPrimitive(dof, scale)); - } - #endregion - - #region GADT constructors - - #region Parallel constructors - /// - /// Wraps the distribution to defer evaluation until explicitly parallelized - /// - public static Dist> Independent(Dist dist) - { - return new Independent(dist); - } - - /// - /// Evaluates two distributions in parallel, passing the results to a function - /// - public static Dist RunIndependentWith(Dist dist1, Dist dist2, Func> run) - { - return new RunIndependent(dist1, dist2, run); - } - - /// - /// Evaluates two distributions in parallel. The results are collected into a tuple. - /// - public static Dist> RunIndependent(Dist dist1, Dist dist2) - { - return new RunIndependent>(dist1, dist2, (t1, t2) => Return(new Tuple(t1, t2))); - } - - /// - /// Evaluates three distributions in parallel. The results are collected into a tuple. 
- /// - public static Dist> RunIndependent(Dist dist1, Dist dist2, Dist dist3) - { - return new RunIndependent3>(dist1, dist2, dist3, (t1, t2, t3) => Return(new Tuple(t1, t2, t3))); - } - - /// - /// Evaluates three distributions in parallel, passing the results to a function - /// - public static Dist RunIndependentWith(Dist dist1, Dist dist2, Dist dist3, Func> run) - { - return new RunIndependent3(dist1, dist2, dist3, run); - } - #endregion - - /// - /// Primitive constructor for continuous dists - /// - public static Dist Primitive(PrimitiveDist dist) - { - return new Primitive(dist); - } - - /// - /// Primitive constructor for finite dists - /// - public static Dist Primitive(FiniteDist dist) - { - return new Primitive(dist.ToSampleDist()); - } - - /// - /// Pure constructor, monadic return - /// - public static Dist Return(A value) - { - return new Pure(value); - } - - /// - /// Create a conditional distribution with given likelihood function and distribution - /// - public static Dist Condition(Func likelihood, Dist dist) - { - return new Conditional(likelihood, dist); - } - - /// - /// Conditional constructor - /// - public static Dist Condition(this Dist dist, Func likelihood) - { - return new Conditional(likelihood, dist); - } - - #endregion - - #region Utility functions - /// - /// Aggregates probabilities of samples with identical values - /// The samples are arranged in ascending order - /// - /// Used to identify identical values - public static Samples Compact(Samples samples, Func keyFunc) where A : IComparable - { - return Samples(CompactUnordered(samples, keyFunc).Weights.OrderBy(w => w.Item)); - } - - /// - /// Aggregates probabilities of samples with identical values - /// - /// Used to identify identical values - public static Samples CompactUnordered(Samples samples, Func keyFunc) - { - var compacted = - samples.Weights - .GroupBy(ip => keyFunc(ip.Item)) - .Select(g => - ItemProb( - g.First().Item, - Prob(g.Select(ip => ip.Prob.Value).Sum()) - 
) - ); - return Samples(compacted); - } - - /// - /// Aggregate & normalize samples - /// The samples are arranged in ascending order - /// - public static Samples Enumerate(FiniteDist dist, Func keyFunc) where A : IComparable - { - return Importance.Normalize(Compact(dist.Explicit, keyFunc)); - } - - /// - /// The probability density function for a primitive distribution and point. - /// Throws NotImplementedException if no PDF is defined for given distribution. - /// - public static Prob Pdf(PrimitiveDist dist, double[] x) - { - - if (dist is DirichletPrimitive) - { - var dir = dist as DirichletPrimitive; - var d = new MathNet.Numerics.Distributions.Dirichlet(dir.alpha); - - return Prob(d.Density(x)); - } - - if (dist is MultiVariateNormalPrimitive) - { - var d = dist as MultiVariateNormalPrimitive; - - return Prob(d.mvn.Density(DenseMatrix.OfRowArrays(x))); - } - throw new NotImplementedException("No PDF for this distribution implemented"); - } - /// - /// The probability density function for a primitive distribution and point. - /// Throws NotImplementedException if no PDF is defined for given distribution. - /// - public static Prob Pdf(PrimitiveDist> dist, Matrix x) - { - - if (dist is WishartPrimitive) - { - var wish = dist as WishartPrimitive; - return Prob(wish.Wishart.Density(x)); - } - - throw new NotImplementedException("No PDF for this distribution implemented"); - } - /// - /// The probability density function for a primitive distribution and point. - /// Throws NotImplementedException if no PDF is defined for given distribution. 
- /// - public static Prob Pdf(PrimitiveDist dist, double y) - { - if (dist is NormalPrimitive) - { - var normal = dist as NormalPrimitive; - return Prob(MathNet.Numerics.Distributions.Normal.PDF(normal.Mean, Math.Sqrt(normal.Variance), y)); - } - if (dist is LogNormalPrimitive) - { - var lognormal = dist as LogNormalPrimitive; - return Prob(MathNet.Numerics.Distributions.LogNormal.PDF(lognormal.mu, lognormal.sigma, y)); - } - if (dist is BetaPrimitive) - { - var beta = dist as BetaPrimitive; - return Prob(MathNet.Numerics.Distributions.Beta.PDF(beta.alpha, beta.beta, y)); - } - if (dist is GammaPrimitive) - { - var gamma = dist as GammaPrimitive; - return Prob(MathNet.Numerics.Distributions.Gamma.PDF(gamma.shape, gamma.rate, y)); - } - throw new NotImplementedException("No PDF for this distribution implemented"); - } - - /// - /// The probability mass function for a primitive distribution and point. - /// Throws NotImplementedException if no PMF is defined for given distribution. - /// - public static Prob Pmf(PrimitiveDist dist, int y) - { - if (dist is PoissonPrimitive) - { - var poisson = dist as PoissonPrimitive; - return Prob(MathNet.Numerics.Distributions.Poisson.PMF(poisson.Lambda, y)); - } - throw new NotImplementedException("No PMF for this distribution implemented"); - } - - public static Prob Pmf(PrimitiveDist dist, T k) - { - if (dist is CategoricalPrimitive) - { - var cat = dist as CategoricalPrimitive; - - return Prob(MathNet.Numerics.Distributions.Categorical.PMF(cat.ProbabilityMass, cat.ItemIndex[k])); - } - throw new NotImplementedException("No PMF for this distribution implemented"); - } - - /// - /// The probability density function for a distribution and point. - /// Throws ArgumentException if the distribution is not a Primitive. 
- /// - public static Prob Pdf(Dist dist, double y) - { - if (dist is Primitive) - { - var primitive = dist as Primitive; - return Pdf(primitive.dist, y); - } - throw new ArgumentException("Can only calculate PDF for primitive distributions"); - } - - public static Prob Pdf(Dist dist, double[] y) - { - if (dist is Primitive) - { - var primitive = dist as Primitive; - return Pdf(primitive.dist, y); - } - throw new ArgumentException("Can only calculate PDF for primitive distributions"); - } - public static Prob Pdf(Dist> dist, Matrix y) - { - if (dist is Primitive>) - { - var primitive = dist as Primitive>; - return Pdf(primitive.dist, y); - } - throw new ArgumentException("Can only calculate PDF for primitive distributions"); - } - - public static Prob Pmf(Dist dist, T k) - { - if (dist is Primitive) - { - var primitive = dist as Primitive; - return Pmf(primitive.dist, k); - } - throw new ArgumentException("Can only calculate PDF for primitive distributions"); - } - - public static Prob Pmf(Dist dist, int y) - { - if (dist is Primitive) - { - var primitive = dist as Primitive; - return Pmf(primitive.dist, y); - } - throw new ArgumentException("Can only calculate pmf for primitive distributions"); - } - - - /// - /// Appends a value to a list. Non-mutative. - /// - public static IEnumerable Append(IEnumerable list, A value) - { - var appendList = new List(list); - appendList.Add(value); - return appendList; - } - - /// - /// Sigmoid function - /// - public static double Sigmoid(double x) - { - return 1 / (1 + Math.Exp(-x)); - } - #endregion - } -} +using System.Collections.Generic; +using System.Linq; +using System; +using MathNet.Numerics.LinearAlgebra; +using MathNet.Numerics.LinearAlgebra.Double; + +namespace ProbCSharp +{ + /// + /// This class exports a whole bunch of constructors. + /// Basically avoids us typing new X() + /// which can make code unreadable when we have really large types. 
+ /// + public static class ProbBase + { + // Singleton instance of the random generator to avoid repeated values in tight loops + public static Random Gen = new Random(); + + #region Primitive object constructors + /// + /// Probability constructor + /// + public static Prob Prob(double probability) + { + return new LogProb(Math.Log(probability)); + } + + /// + /// ItemProb constructor + /// + public static ItemProb ItemProb(A item, Prob prob) + { + return new ItemProb(item, prob); + } + + /// + /// Samples constructor + /// + public static Samples Samples(IEnumerable> itemProbs) + { + return new Samples(itemProbs); + } + + /// + /// Tuple constructor + /// + public static Tuple Tuple(A a, B b) + { + return new Tuple(a, b); + } + #endregion + + #region Distribution constructors + /// + /// Finite uniform distribution over list of items. + /// Only composable with other finite distributions. + /// + public static FiniteDist EnumUniformF(IEnumerable items) + { + var uniform = Samples(items.Select(i => new ItemProb(i, Prob(1)))); + return new FiniteDist(Importance.Normalize(uniform)); + } + + /// + /// Uniform distribution over list of items + /// + public static Dist UniformFromList(IEnumerable items) + { + return Primitive(EnumUniformF(items)); + } + + /// + /// Uniform distribution over parameter items + /// Only composable with other finite distributions. + /// + public static FiniteDist UniformF(params A[] items) + { + var itemList = new List(items); + return EnumUniformF(itemList); + } + + /// + /// Uniform distribution over parameter items + /// + public static Dist Uniform(params A[] items) + { + return Primitive(UniformF(items)); + } + + /// + /// Bernoulli distribution constructed from success probability + /// Only composable with other finite distributions. 
+ /// + public static FiniteDist BernoulliF(Prob prob) + { + return new FiniteDist(ItemProb(true, prob), ItemProb(false, Prob(1 - prob.Value))); + } + + /// + /// Bernoulli distribution constructed from success probability + /// + public static Dist Bernoulli(Prob prob) + { + return Primitive(BernoulliF(prob)); + } + + /// + /// Bernoulli distribution constructed from success probability + /// + public static Dist Bernoulli(double prob) + { + return Primitive(BernoulliF(Prob(prob))); + } + + /// + /// Bernoulli distribution constructed from two items, and probability of first item + /// + public static Dist Bernoulli(double prob, A option1, A option2) + { + return Primitive(BernoulliF(Prob(prob))).Select(b => b ? option1 : option2); + } + + /// + /// Categorical distribution + /// Only composable with other finite distributions + /// + public static FiniteDist CategoricalF(params ItemProb[] itemProbs) + { + return new FiniteDist(Samples(itemProbs)); + } + + /// + /// Categorical distribution + /// Only composable with other finite distributions + /// + public static FiniteDist CategoricalF(Samples samples) + { + return new FiniteDist(samples); + } + + /// + /// Categorical distribution + /// + public static Dist Categorical(Samples samples) + { + return Primitive(CategoricalF(samples).ToSampleDist()); + } + + /// + /// Categorical distribution + /// + public static CategoricalPrimitive CategoricalPrimitive(A[] items, double [] probabilities) + { + return new CategoricalPrimitive(items, probabilities, Gen); + } + + /// + /// Primitive studentT distribution + /// Only composable with other primitive distributions + /// + public static StudentTPrimitive StudentTPrimitive(double location, double scale, double normality) + { + return new StudentTPrimitive(location, scale, normality, Gen); + } + + /// + /// Primitive Exponential distribution + /// Only composable with other primitive distributions + /// + public static ExponentialPrimitive ExponentialPrimitive(double 
rate) + { + return new ExponentialPrimitive(rate); + } + + /// + /// Primitive Continuous Uniform distribution + /// Only composable with other primitive distributions + /// + public static ContinuousUniformPrimitive ContinuousUniformPrimitive() + { + return new ContinuousUniformPrimitive(Gen); + } + + /// + /// Primitive Continuous Uniform distribution + /// Only composable with other primitive distributions + /// + public static ContinuousUniformPrimitive ContinuousUniformPrimitive(double lower, double upper) + { + return new ContinuousUniformPrimitive(lower,upper, Gen); + } + + + /// + /// Primitive Poisson distribution + /// Only composable with other primitive distributions + /// + public static PoissonPrimitive PoissonPrimitive(double lambda) + { + return new PoissonPrimitive(lambda, Gen); + } + + /// + /// Primitive Normal distribution + /// Only composable with other primitive distributions + /// + public static NormalPrimitive NormalPrimitive(double mean, double variance) + { + return new NormalPrimitive(mean, variance, Gen); + } + + /// + /// Primitive LogNormal distribution + /// Only composable with other primitive distributions + /// + public static LogNormalPrimitive LogNormalPrimitive(double mean, double variance) + { + return new LogNormalPrimitive(mean, variance, Gen); + } + + public static LogNormalPrimitive LogNormalPrimitiveMu(double mu, double sigma) + { + return new LogNormalPrimitive(mu, sigma, true, Gen); + } + + /// + /// Primitive Beta distribution + /// Only composable with other primitive distributions + /// + public static BetaPrimitive BetaPrimitive(double alpha, double beta) + { + return new BetaPrimitive(alpha, beta, Gen); + } + + /// + /// Primitive Gamma distribution + /// Only composable with other primitive distributions + /// + public static GammaPrimitive GammaPrimitive(double shape, double rate) + { + return new GammaPrimitive(shape, rate, Gen); + } + + public static MultiVariateNormalPrimitive 
MultiVariateNormalPrimitive(double[] mean, Matrix covariance) + { + return new MultiVariateNormalPrimitive(mean, covariance, Gen); + } + + public static WishartPrimitive WishartPrimitive(double dof, Matrix scale) + { + return new WishartPrimitive(dof, scale, Gen); + } + + /// + /// Primitive Dirichlet distribution + /// Only composable with other primitive distributions + /// + public static DirichletPrimitive DirichletPrimitive(double[] alpha) + { + return new DirichletPrimitive(alpha, Gen); + } + + /// + /// Poisson distribution + /// + public static Dist Poisson(double lambda) + { + return Primitive(PoissonPrimitive(lambda)); + } + + /// + /// Normal distribution + /// + public static Dist Normal(double mean, double variance) + { + return Primitive(NormalPrimitive(mean, variance)); + } + + /// + /// Log Normal distribution + /// + public static Dist LogNormal(double mean, double variance) + { + return Primitive(LogNormalPrimitive(mean, variance)); + } + + public static Dist LogNormalMu(double mu, double sigma) + { + return Primitive(LogNormalPrimitiveMu(mu, sigma)); + } + + /// + /// Gamma distribution + /// + public static Dist Gamma(double shape, double rate) + { + return Primitive(GammaPrimitive(shape, rate)); + } + + /// + /// Categorical distribution + /// + public static Dist Categorical(T[] items, double[] probabilities) + { + return Primitive(CategoricalPrimitive(items, probabilities)); + } + + public static Dist Uniform() + { + return Primitive(ContinuousUniformPrimitive()); + } + + public static Dist Uniform(double lower, double upper) + { + return Primitive(ContinuousUniformPrimitive(lower,upper)); + } + + /// + /// StudenT distribution + /// + public static Dist StudentT(double location, double scale, double normality) + { + return Primitive(StudentTPrimitive(location, scale, normality)); + } + + /// + /// Exponential distribution + /// + public static Dist Exponential(double rate) + { + return Primitive(ExponentialPrimitive(rate)); + } + + /// + /// 
Beta distribution + /// + public static Dist Beta(double alpha, double beta) + { + return Primitive(BetaPrimitive(alpha, beta)); + } + + /// + /// Dirichlet distribution + /// + public static Dist Dirichlet(double[] alpha) + { + return Primitive(DirichletPrimitive(alpha)); + } + + public static Dist MultiVariateNormal(double[] mean, Matrix covariance) + { + return Primitive(MultiVariateNormalPrimitive(mean, covariance)); + } + + public static Dist> Wishart(double dof, Matrix scale) + { + return Primitive(WishartPrimitive(dof, scale)); + } + #endregion + + #region GADT constructors + + #region Parallel constructors + /// + /// Wraps the distribution to defer evaluation until explicitly parallelized + /// + public static Dist> Independent(Dist dist) + { + return new Independent(dist); + } + + /// + /// Evaluates two distributions in parallel, passing the results to a function + /// + public static Dist RunIndependentWith(Dist dist1, Dist dist2, Func> run) + { + return new RunIndependent(dist1, dist2, run); + } + + /// + /// Evaluates two distributions in parallel. The results are collected into a tuple. + /// + public static Dist> RunIndependent(Dist dist1, Dist dist2) + { + return new RunIndependent>(dist1, dist2, (t1, t2) => Return(new Tuple(t1, t2))); + } + + /// + /// Evaluates three distributions in parallel. The results are collected into a tuple. 
+ /// + public static Dist> RunIndependent(Dist dist1, Dist dist2, Dist dist3) + { + return new RunIndependent3>(dist1, dist2, dist3, (t1, t2, t3) => Return(new Tuple(t1, t2, t3))); + } + + /// + /// Evaluates three distributions in parallel, passing the results to a function + /// + public static Dist RunIndependentWith(Dist dist1, Dist dist2, Dist dist3, Func> run) + { + return new RunIndependent3(dist1, dist2, dist3, run); + } + #endregion + + /// + /// Primitive constructor for continuous dists + /// + public static Dist Primitive(PrimitiveDist dist) + { + return new Primitive(dist); + } + + /// + /// Primitive constructor for finite dists + /// + public static Dist Primitive(FiniteDist dist) + { + return new Primitive(dist.ToSampleDist()); + } + + /// + /// Pure constructor, monadic return + /// + public static Dist Return(A value) + { + return new Pure(value); + } + + /// + /// Create a conditional distribution with given likelihood function and distribution + /// + public static Dist Condition(Func likelihood, Dist dist) + { + return new Conditional(likelihood, dist); + } + + /// + /// Conditional constructor + /// + public static Dist Condition(this Dist dist, Func likelihood) + { + return new Conditional(likelihood, dist); + } + + #endregion + + #region Utility functions + /// + /// Aggregates probabilities of samples with identical values + /// The samples are arranged in ascending order + /// + /// Used to identify identical values + public static Samples Compact(Samples samples, Func keyFunc) where A : IComparable + { + return Samples(CompactUnordered(samples, keyFunc).Weights.OrderBy(w => w.Item)); + } + + /// + /// Aggregates probabilities of samples with identical values + /// + /// Used to identify identical values + public static Samples CompactUnordered(Samples samples, Func keyFunc) + { + var compacted = + samples.Weights + .GroupBy(ip => keyFunc(ip.Item)) + .Select(g => + ItemProb( + g.First().Item, + Prob(g.Sum(ip => ip.Prob.Value)) + ) + ); + 
return Samples(compacted); + } + + /// + /// Aggregate & normalize samples + /// The samples are arranged in ascending order + /// + public static Samples Enumerate(FiniteDist dist, Func keyFunc) where A : IComparable + { + return Importance.Normalize(Compact(dist.Explicit, keyFunc)); + } + + /// + /// The probability density function for a primitive distribution and point. + /// Throws NotImplementedException if no PDF is defined for given distribution. + /// + public static Prob Pdf(PrimitiveDist dist, double[] x) + { + + if (dist is DirichletPrimitive) + { + var dir = dist as DirichletPrimitive; + var d = new MathNet.Numerics.Distributions.Dirichlet(dir.alpha); + + return Prob(d.Density(x)); + } + + if (dist is MultiVariateNormalPrimitive) + { + var d = dist as MultiVariateNormalPrimitive; + + return Prob(d.mvn.Density(DenseMatrix.OfRowArrays(x))); + } + throw new NotImplementedException("No PDF for this distribution implemented"); + } + /// + /// The probability density function for a primitive distribution and point. + /// Throws NotImplementedException if no PDF is defined for given distribution. + /// + public static Prob Pdf(PrimitiveDist> dist, Matrix x) + { + + if (dist is WishartPrimitive) + { + var wish = dist as WishartPrimitive; + return Prob(wish.Wishart.Density(x)); + } + + throw new NotImplementedException("No PDF for this distribution implemented"); + } + /// + /// The probability density function for a primitive distribution and point. + /// Throws NotImplementedException if no PDF is defined for given distribution. 
+ /// + public static Prob Pdf(PrimitiveDist dist, double y) + { + if (dist is NormalPrimitive) + { + var normal = dist as NormalPrimitive; + return Prob(MathNet.Numerics.Distributions.Normal.PDF(normal.Mean, Math.Sqrt(normal.Variance), y)); + } + if (dist is LogNormalPrimitive) + { + var lognormal = dist as LogNormalPrimitive; + return Prob(MathNet.Numerics.Distributions.LogNormal.PDF(lognormal.mu, lognormal.sigma, y)); + } + if (dist is BetaPrimitive) + { + var beta = dist as BetaPrimitive; + return Prob(MathNet.Numerics.Distributions.Beta.PDF(beta.alpha, beta.beta, y)); + } + if (dist is GammaPrimitive) + { + var gamma = dist as GammaPrimitive; + return Prob(MathNet.Numerics.Distributions.Gamma.PDF(gamma.shape, gamma.rate, y)); + } + throw new NotImplementedException("No PDF for this distribution implemented"); + } + + /// + /// The probability mass function for a primitive distribution and point. + /// Throws NotImplementedException if no PMF is defined for given distribution. + /// + public static Prob Pmf(PrimitiveDist dist, int y) + { + if (dist is PoissonPrimitive) + { + var poisson = dist as PoissonPrimitive; + return Prob(MathNet.Numerics.Distributions.Poisson.PMF(poisson.Lambda, y)); + } + throw new NotImplementedException("No PMF for this distribution implemented"); + } + + public static Prob Pmf(PrimitiveDist dist, T k) + { + if (dist is CategoricalPrimitive) + { + var cat = dist as CategoricalPrimitive; + + return Prob(MathNet.Numerics.Distributions.Categorical.PMF(cat.ProbabilityMass, cat.ItemIndex[k])); + } + throw new NotImplementedException("No PMF for this distribution implemented"); + } + + /// + /// The probability density function for a distribution and point. + /// Throws ArgumentException if the distribution is not a Primitive. 
+ /// + public static Prob Pdf(Dist dist, double y) + { + if (dist is Primitive) + { + var primitive = dist as Primitive; + return Pdf(primitive.dist, y); + } + throw new ArgumentException("Can only calculate PDF for primitive distributions"); + } + + public static Prob Pdf(Dist dist, double[] y) + { + if (dist is Primitive) + { + var primitive = dist as Primitive; + return Pdf(primitive.dist, y); + } + throw new ArgumentException("Can only calculate PDF for primitive distributions"); + } + public static Prob Pdf(Dist> dist, Matrix y) + { + if (dist is Primitive>) + { + var primitive = dist as Primitive>; + return Pdf(primitive.dist, y); + } + throw new ArgumentException("Can only calculate PDF for primitive distributions"); + } + + public static Prob Pmf(Dist dist, T k) + { + if (dist is Primitive) + { + var primitive = dist as Primitive; + return Pmf(primitive.dist, k); + } + throw new ArgumentException("Can only calculate PDF for primitive distributions"); + } + + public static Prob Pmf(Dist dist, int y) + { + if (dist is Primitive) + { + var primitive = dist as Primitive; + return Pmf(primitive.dist, y); + } + throw new ArgumentException("Can only calculate pmf for primitive distributions"); + } + + /// + /// Sigmoid function + /// + public static double Sigmoid(double x) + { + return 1 / (1 + Math.Exp(-x)); + } + #endregion + } +} diff --git a/ProbabilityMonad/DistGadt.cs b/ProbabilityMonad/DistGadt.cs index 7d995b3..3bd3dbe 100644 --- a/ProbabilityMonad/DistGadt.cs +++ b/ProbabilityMonad/DistGadt.cs @@ -293,7 +293,7 @@ private static Dist> RunSequence(IEnumerable> dists) Return>(new List()), (listDist, aDist) => from a in aDist from list in listDist - select Append(list, a) + select list.Append(a) ); } diff --git a/ProbabilityMonad/Finite/FiniteExtensions.cs b/ProbabilityMonad/Finite/FiniteExtensions.cs index 30558c0..0f3a676 100644 --- a/ProbabilityMonad/Finite/FiniteExtensions.cs +++ b/ProbabilityMonad/Finite/FiniteExtensions.cs @@ -44,7 +44,7 @@ public 
static Prob ProbOf(this FiniteDist dist, Func eventTest) { var matches = dist.Explicit.Weights.Where(p => eventTest(p.Item)); if (!matches.Any()) return Prob(0); - return Prob(matches.Select(p => p.Prob.Value).Sum()); + return Prob(matches.Sum(p => p.Prob.Value)); } /// diff --git a/ProbabilityMonad/Histogram.cs b/ProbabilityMonad/Histogram.cs index 94e13c6..586e9bb 100644 --- a/ProbabilityMonad/Histogram.cs +++ b/ProbabilityMonad/Histogram.cs @@ -87,7 +87,7 @@ internal static string ShowBuckets(IEnumerable buckets, double scale) /// internal static int LongestString(IEnumerable list, Func toString) { - return list.Select(x => toString(x).Length).Max(); + return list.Max(x => toString(x).Length); } /// @@ -103,22 +103,7 @@ internal static double BarScale(IEnumerable barLengths, double scale) /// Draw bar of width n /// internal static string Bar(int n) - { - var barBuilder = new StringBuilder(); - for (var i = 0; i < n; i++) - { - barBuilder.Append("#"); - } - return barBuilder.ToString(); - } - - /// - /// Return sum of list with a given value function - /// - internal static double Sum(Func getVal, IEnumerable list) - { - return list.Select(getVal).Sum(); - } + => new string('#', n); /// /// Generates a weighted histogram from a list of ItemProbs @@ -128,16 +113,16 @@ public static string Weighted(IEnumerable> nums, int numBuckets { if (!nums.Any()) return "No data to graph."; - var sorted = nums.OrderBy(x => x.Item); + var sorted = nums.OrderBy(x => x.Item).ToArray(); var min = sorted.First().Item; var max = sorted.Last().Item; var bucketList = MakeBucketList(min, max, numBuckets); - var totalMass = Sum(ip => ip.Prob.Value, sorted.Where(ip => !Double.IsInfinity(ip.Prob.LogValue))); + var totalMass = sorted.Where(ip => !Double.IsInfinity(ip.Prob.LogValue)).Sum(ip => ip.Prob.Value); foreach (var bucket in bucketList) { bucket.WeightedValues.AddRange(sorted.Where(x => x.Item >= bucket.Min && x.Item < bucket.Max)); - bucket.BarSize = (int) 
Math.Floor(bucket.WeightedValues.Select(x => x.Prob.Value).Sum() / totalMass * scale); + bucket.BarSize = (int) Math.Floor(bucket.WeightedValues.Sum(x => x.Prob.Value) / totalMass * scale); } return ShowBuckets(bucketList, scale); } @@ -158,10 +143,10 @@ public static string Weighted(Samples nums, int numBuckets = 10, double public static string Unweighted(IEnumerable nums, int numBuckets = 10, double scale = DEFAULT_SCALE) { if (!nums.Any()) return "No data to graph."; - var sorted = nums.OrderBy(x => x); + var sorted = nums.OrderBy(x => x).ToArray(); var min = sorted.First(); var max = sorted.Last() + 10e-10; - var total = Sum(x => x, sorted); + var total = sorted.Sum(x => x); var bucketList = MakeBucketList(min, max, numBuckets); @@ -193,8 +178,10 @@ public static string Finite(Samples itemProbs, Func showFunc = var normalized = Importance.Normalize(CompactUnordered(itemProbs, showFunc)); var sb = new StringBuilder(); - var display = normalized.Weights.Select(ip => new Tuple(showFunc(ip.Item), (int) Math.Floor(ip.Prob.Value * scale), ip.Prob)); - var maxWidth = display.Select(d => d.Item1.Length).Max(); + var display = normalized.Weights + .Select(ip => new Tuple(showFunc(ip.Item), (int)Math.Floor(ip.Prob.Value * scale), ip.Prob)) + .ToArray(); + var maxWidth = display.Max(d => d.Item1.Length); var barScale = BarScale(display.Select(d => d.Item2), scale); foreach (var line in display) { diff --git a/ProbabilityMonad/KullbackLeibler.cs b/ProbabilityMonad/KullbackLeibler.cs index e8f5118..35fbf0b 100644 --- a/ProbabilityMonad/KullbackLeibler.cs +++ b/ProbabilityMonad/KullbackLeibler.cs @@ -22,9 +22,8 @@ public static double KLDivergenceF(FiniteDist distQ, FiniteDist di var pWeights = Enumerate(distP, keyFunc); - var divergences = pWeights.Weights - .Select(w => w.Prob.Value * Math.Log(w.Prob.Div(qDensity(w.Item)).Value)); - return divergences.Sum(); + return pWeights.Weights + .Sum(w => w.Prob.Value * Math.Log(w.Prob.Div(qDensity(w.Item)).Value)); } /// diff --git 
a/ProbabilityMonad/Sampler.cs b/ProbabilityMonad/Sampler.cs index 28f2e2c..1b605fa 100644 --- a/ProbabilityMonad/Sampler.cs +++ b/ProbabilityMonad/Sampler.cs @@ -21,7 +21,8 @@ public static A Sample(this Dist dist) /// public static IEnumerable SampleN(this Dist dist, int n) { - return Enumerable.Range(0, n).Select(_ => dist.Sample()); + for (int i = 0; i < n; ++i) + yield return dist.Sample(); } } diff --git a/ProbabilityMonad/Samples.cs b/ProbabilityMonad/Samples.cs index 2e6e009..73e3be8 100644 --- a/ProbabilityMonad/Samples.cs +++ b/ProbabilityMonad/Samples.cs @@ -27,7 +27,7 @@ public static Samples MapSample(this Samples self, Func f) /// public static Prob SumProbs(this Samples self) { - return Prob(self.Weights.Select(ip => ip.Prob.Value).Sum()); + return Prob(self.Weights.Sum(ip => ip.Prob.Value)); } public static Samples Normalize(this Samples self)