diff --git a/ProbabilityMonad/Base.cs b/ProbabilityMonad/Base.cs index 3b041d4..dd7df1d 100644 --- a/ProbabilityMonad/Base.cs +++ b/ProbabilityMonad/Base.cs @@ -14,40 +14,33 @@ namespace ProbCSharp public static class ProbBase { // Singleton instance of the random generator to avoid repeated values in tight loops - public static Random Gen = new Random(); + public static readonly Random Gen = new Random(); #region Primitive object constructors /// /// Probability constructor /// public static Prob Prob(double probability) - { - return new LogProb(Math.Log(probability)); - } + => new LogProb(Math.Log(probability)); /// /// ItemProb constructor /// public static ItemProb ItemProb(A item, Prob prob) - { - return new ItemProb(item, prob); - } + => new ItemProb(item, prob); /// /// Samples constructor /// public static Samples Samples(IEnumerable> itemProbs) - { - return new Samples(itemProbs); - } + => new Samples(itemProbs); /// /// Tuple constructor /// public static Tuple Tuple(A a, B b) - { - return new Tuple(a, b); - } + => new Tuple(a, b); + #endregion #region Distribution constructors @@ -65,130 +58,99 @@ public static FiniteDist EnumUniformF(IEnumerable items) /// Uniform distribution over list of items /// public static Dist UniformFromList(IEnumerable items) - { - return Primitive(EnumUniformF(items)); - } + => Primitive(EnumUniformF(items)); /// /// Uniform distribution over parameter items /// Only composable with other finite distributions. /// public static FiniteDist UniformF(params A[] items) - { - var itemList = new List(items); - return EnumUniformF(itemList); - } + => EnumUniformF(new List(items)); /// /// Uniform distribution over parameter items /// public static Dist Uniform(params A[] items) - { - return Primitive(UniformF(items)); - } + => Primitive(UniformF(items)); /// /// Bernoulli distribution constructed from success probability /// Only composable with other finite distributions. /// public static FiniteDist BernoulliF(Prob prob) - { - return new FiniteDist(ItemProb(true, prob), ItemProb(false, Prob(1 - prob.Value))); - } + => new FiniteDist(ItemProb(true, prob), ItemProb(false, Prob(1 - prob.Value))); /// /// Bernoulli distribution constructed from success probability /// public static Dist Bernoulli(Prob prob) - { - return Primitive(BernoulliF(prob)); - } + => Primitive(BernoulliF(prob)); /// /// Bernoulli distribution constructed from success probability /// public static Dist Bernoulli(double prob) - { - return Primitive(BernoulliF(Prob(prob))); - } + => Primitive(BernoulliF(Prob(prob))); /// /// Bernoulli distribution constructed from two items, and probability of first item /// public static Dist Bernoulli(double prob, A option1, A option2) - { - return Primitive(BernoulliF(Prob(prob))).Select(b => b ? option1 : option2); - } + => Primitive(BernoulliF(Prob(prob))).Select(b => b ? option1 : option2); /// /// Categorical distribution /// Only composable with other finite distributions /// public static FiniteDist CategoricalF(params ItemProb[] itemProbs) - { - return new FiniteDist(Samples(itemProbs)); - } + => new FiniteDist(Samples(itemProbs)); /// /// Categorical distribution /// Only composable with other finite distributions /// public static FiniteDist CategoricalF(Samples samples) - { - return new FiniteDist(samples); - } + => new FiniteDist(samples); /// /// Categorical distribution /// public static Dist Categorical(Samples samples) - { - return Primitive(CategoricalF(samples).ToSampleDist()); - } + => Primitive(CategoricalF(samples).ToSampleDist()); /// /// Categorical distribution /// public static CategoricalPrimitive CategoricalPrimitive(A[] items, double [] probabilities) - { - return new CategoricalPrimitive(items, probabilities, Gen); - } + => new CategoricalPrimitive(items, probabilities, Gen); /// /// Primitive studentT distribution /// Only composable with other primitive distributions /// public static StudentTPrimitive StudentTPrimitive(double location, double scale, double normality) - { - return new StudentTPrimitive(location, scale, normality, Gen); - } + => new StudentTPrimitive(location, scale, normality, Gen); /// /// Primitive Exponential distribution /// Only composable with other primitive distributions /// public static ExponentialPrimitive ExponentialPrimitive(double rate) - { - return new ExponentialPrimitive(rate); - } + => new ExponentialPrimitive(rate); /// /// Primitive Contiuous Uniform distribution /// Only composable with other primitive distributions /// public static ContinuousUniformPrimitive ContinuousUniformPrimitive() - { - return new ContinuousUniformPrimitive(Gen); - } + => new ContinuousUniformPrimitive(Gen); /// /// Primitive Contiuous Uniform distribution /// Only composable with other primitive distributions /// public static ContinuousUniformPrimitive ContinuousUniformPrimitive(double lower, double upper) - { - return new ContinuousUniformPrimitive(lower,upper, Gen); - } + => new ContinuousUniformPrimitive(lower,upper, Gen); /// @@ -196,166 +158,121 @@ public static ContinuousUniformPrimitive ContinuousUniformPrimitive(double lower /// Only composable with other primitive distributions /// public static PoissonPrimitive PoissonPrimitive(double lambda) - { - return new PoissonPrimitive(lambda, Gen); - } + => new PoissonPrimitive(lambda, Gen); /// /// Primitive Normal distribution /// Only composable with other primitive distributions /// public static NormalPrimitive NormalPrimitive(double mean, double variance) - { - return new NormalPrimitive(mean, variance, Gen); - } + => new NormalPrimitive(mean, variance, Gen); /// /// Primitive LogNormal distribution /// Only composable with other primitive distributions /// public static LogNormalPrimitive LogNormalPrimitive(double mean, double variance) - { - return new LogNormalPrimitive(mean, variance, Gen); - } + => new LogNormalPrimitive(mean, variance, Gen); public static LogNormalPrimitive LogNormalPrimitiveMu(double mu, double sigma) - { - return new LogNormalPrimitive(mu, sigma, true, Gen); - } + => new LogNormalPrimitive(mu, sigma, true, Gen); /// /// Primitive Beta distribution /// Only composable with other primitive distributions /// public static BetaPrimitive BetaPrimitive(double alpha, double beta) - { - return new BetaPrimitive(alpha, beta, Gen); - } + => new BetaPrimitive(alpha, beta, Gen); /// /// Primitive Gamma distribution /// Only composable with other primitive distributions /// public static GammaPrimitive GammaPrimitive(double shape, double rate) - { - return new GammaPrimitive(shape, rate, Gen); - } + => new GammaPrimitive(shape, rate, Gen); public static MultiVariateNormalPrimitive MultiVariateNormalPrimitive(double[] mean, Matrix covariance) - { - return new MultiVariateNormalPrimitive(mean, covariance, Gen); - } + => new MultiVariateNormalPrimitive(mean, covariance, Gen); public static WishartPrimitive WishartPrimitive(double dof, Matrix scale) - { - return new WishartPrimitive(dof, scale, Gen); - } + => new WishartPrimitive(dof, scale, Gen); /// /// Primitive Dirichlet distribution /// Only composable with other primitive distributions /// public static DirichletPrimitive DirichletPrimitive(double[] alpha) - { - return new DirichletPrimitive(alpha, Gen); - } + => new DirichletPrimitive(alpha, Gen); /// /// Poisson distribution /// public static Dist Poisson(double lambda) - { - return Primitive(PoissonPrimitive(lambda)); - } + => Primitive(PoissonPrimitive(lambda)); /// /// Normal distribution /// public static Dist Normal(double mean, double variance) - { - return Primitive(NormalPrimitive(mean, variance)); - } + => Primitive(NormalPrimitive(mean, variance)); /// /// Log Normal distribution /// public static Dist LogNormal(double mean, double variance) - { - return Primitive(LogNormalPrimitive(mean, variance)); - } + => Primitive(LogNormalPrimitive(mean, variance)); public static Dist LogNormalMu(double mu, double sigma) - { - return Primitive(LogNormalPrimitiveMu(mu, sigma)); - } + => Primitive(LogNormalPrimitiveMu(mu, sigma)); /// /// Gamma distribution /// public static Dist Gamma(double shape, double rate) - { - return Primitive(GammaPrimitive(shape, rate)); - } + => Primitive(GammaPrimitive(shape, rate)); /// /// Categorical distribution /// public static Dist Categorical(T[] items, double[] probabilities) - { - return Primitive(CategoricalPrimitive(items, probabilities)); - } + => Primitive(CategoricalPrimitive(items, probabilities)); public static Dist Uniform() - { - return Primitive(ContinuousUniformPrimitive()); - } + => Primitive(ContinuousUniformPrimitive()); public static Dist Uniform(double lower, double upper) - { - return Primitive(ContinuousUniformPrimitive(lower,upper)); - } + => Primitive(ContinuousUniformPrimitive(lower,upper)); /// /// StudenT distribution /// public static Dist StudentT(double location, double scale, double normality) - { - return Primitive(StudentTPrimitive(location, scale, normality)); - } + => Primitive(StudentTPrimitive(location, scale, normality)); /// /// Exponential distribution /// public static Dist Exponential(double rate) - { - return Primitive(ExponentialPrimitive(rate)); - } + => Primitive(ExponentialPrimitive(rate)); /// /// Beta distribution /// public static Dist Beta(double alpha, double beta) - { - return Primitive(BetaPrimitive(alpha, beta)); - } + => Primitive(BetaPrimitive(alpha, beta)); /// /// Dirichlet distribution /// public static Dist Dirichlet(double[] alpha) - { - return Primitive(DirichletPrimitive(alpha)); - } + => Primitive(DirichletPrimitive(alpha)); public static Dist MultiVariateNormal(double[] mean, Matrix covariance) - { - return Primitive(MultiVariateNormalPrimitive(mean, covariance)); - } + => Primitive(MultiVariateNormalPrimitive(mean, covariance)); public static Dist> Wishart(double dof, Matrix scale) - { - return Primitive(WishartPrimitive(dof, scale)); - } + => Primitive(WishartPrimitive(dof, scale)); + #endregion #region GADT constructors @@ -365,82 +282,63 @@ public static Dist> Wishart(double dof, Matrix scale) /// Wraps the distribution to defer evaluation until explicitly parallelized /// public static Dist> Independent(Dist dist) - { - return new Independent(dist); - } + => new Independent(dist); /// /// Evaluates two distributions in parallel, passing the results to a function /// public static Dist RunIndependentWith(Dist dist1, Dist dist2, Func> run) - { - return new RunIndependent(dist1, dist2, run); - } + => new RunIndependent(dist1, dist2, run); /// /// Evaluates two distributions in parallel. The results are collected into a tuple. /// public static Dist> RunIndependent(Dist dist1, Dist dist2) - { - return new RunIndependent>(dist1, dist2, (t1, t2) => Return(new Tuple(t1, t2))); - } + => new RunIndependent>(dist1, dist2, (t1, t2) => Return(new Tuple(t1, t2))); /// /// Evaluates three distributions in parallel. The results are collected into a tuple. /// public static Dist> RunIndependent(Dist dist1, Dist dist2, Dist dist3) - { - return new RunIndependent3>(dist1, dist2, dist3, (t1, t2, t3) => Return(new Tuple(t1, t2, t3))); - } + => new RunIndependent3>(dist1, dist2, dist3, (t1, t2, t3) => Return(new Tuple(t1, t2, t3))); /// /// Evaluates three distributions in parallel, passing the results to a function /// public static Dist RunIndependentWith(Dist dist1, Dist dist2, Dist dist3, Func> run) - { - return new RunIndependent3(dist1, dist2, dist3, run); - } + => new RunIndependent3(dist1, dist2, dist3, run); + #endregion /// /// Primitive constructor for continuous dists /// public static Dist Primitive(PrimitiveDist dist) - { - return new Primitive(dist); - } + => new Primitive(dist); /// /// Primitive constructor for finite dists /// public static Dist Primitive(FiniteDist dist) - { - return new Primitive(dist.ToSampleDist()); - } + => new Primitive(dist.ToSampleDist()); /// /// Pure constructor, monadic return /// public static Dist Return(A value) - { - return new Pure(value); - } + => new Pure(value); /// /// Create a conditional distribution with given likelihood function and distribution /// public static Dist Condition(Func likelihood, Dist dist) - { - return new Conditional(likelihood, dist); - } + => new Conditional(likelihood, dist); /// /// Conditional constructor /// public static Dist Condition(this Dist dist, Func likelihood) - { - return new Conditional(likelihood, dist); - } + => new Conditional(likelihood, dist); #endregion @@ -451,9 +349,7 @@ public static Dist Condition(this Dist dist, Func likelihood) /// /// Used to identify identical values public static Samples Compact(Samples samples, Func keyFunc) where A : IComparable - { - return Samples(CompactUnordered(samples, keyFunc).Weights.OrderBy(w => w.Item)); - } + => Samples(CompactUnordered(samples, keyFunc).Weights.OrderBy(w => w.Item)); /// /// Aggregates probabilities of samples with identical values @@ -478,9 +374,7 @@ public static Samples CompactUnordered(Samples samples, Func public static Samples Enumerate(FiniteDist dist, Func keyFunc) where A : IComparable - { - return Importance.Normalize(Compact(dist.Explicit, keyFunc)); - } + => Importance.Normalize(Compact(dist.Explicit, keyFunc)); /// /// The probability density function for a primitive distribution and point. @@ -488,22 +382,18 @@ public static Samples Enumerate(FiniteDist dist, Func keyF /// public static Prob Pdf(PrimitiveDist dist, double[] x) { - - if (dist is DirichletPrimitive) - { - var dir = dist as DirichletPrimitive; - var d = new MathNet.Numerics.Distributions.Dirichlet(dir.alpha); - - return Prob(d.Density(x)); - } - - if (dist is MultiVariateNormalPrimitive) + switch (dist) { - var d = dist as MultiVariateNormalPrimitive; - - return Prob(d.mvn.Density(DenseMatrix.OfRowArrays(x))); + case DirichletPrimitive dirichlet: + var d = new MathNet.Numerics.Distributions.Dirichlet(dirichlet.alpha); + return Prob(d.Density(x)); + + case MultiVariateNormalPrimitive multiVariateNormal: + return Prob(multiVariateNormal.mvn.Density(DenseMatrix.OfRowArrays(x))); + + default: + throw new NotImplementedException("No PDF for this distribution implemented"); } - throw new NotImplementedException("No PDF for this distribution implemented"); } /// /// The probability density function for a primitive distribution and point. @@ -511,13 +401,9 @@ public static Prob Pdf(PrimitiveDist dist, double[] x) /// public static Prob Pdf(PrimitiveDist> dist, Matrix x) { + if (dist is WishartPrimitive wish) + return Prob(wish.Wishart.Density(x)); - if (dist is WishartPrimitive) - { - var wish = dist as WishartPrimitive; - return Prob(wish.Wishart.Density(x)); - } - throw new NotImplementedException("No PDF for this distribution implemented"); } /// @@ -526,27 +412,23 @@ public static Prob Pdf(PrimitiveDist> dist, Matrix x) /// public static Prob Pdf(PrimitiveDist dist, double y) { - if (dist is NormalPrimitive) + switch (dist) { - var normal = dist as NormalPrimitive; - return Prob(MathNet.Numerics.Distributions.Normal.PDF(normal.Mean, Math.Sqrt(normal.Variance), y)); - } - if (dist is LogNormalPrimitive) - { - var lognormal = dist as LogNormalPrimitive; - return Prob(MathNet.Numerics.Distributions.LogNormal.PDF(lognormal.mu, lognormal.sigma, y)); - } - if (dist is BetaPrimitive) - { - var beta = dist as BetaPrimitive; - return Prob(MathNet.Numerics.Distributions.Beta.PDF(beta.alpha, beta.beta, y)); - } - if (dist is GammaPrimitive) - { - var gamma = dist as GammaPrimitive; - return Prob(MathNet.Numerics.Distributions.Gamma.PDF(gamma.shape, gamma.rate, y)); + case NormalPrimitive normal: + return Prob(MathNet.Numerics.Distributions.Normal.PDF(normal.Mean, Math.Sqrt(normal.Variance), y)); + + case LogNormalPrimitive lognormal: + return Prob(MathNet.Numerics.Distributions.LogNormal.PDF(lognormal.mu, lognormal.sigma, y)); + + case BetaPrimitive beta: + return Prob(MathNet.Numerics.Distributions.Beta.PDF(beta.alpha, beta.beta, y)); + + case GammaPrimitive gamma: + return Prob(MathNet.Numerics.Distributions.Gamma.PDF(gamma.shape, gamma.rate, y)); + + default: + throw new NotImplementedException("No PDF for this distribution implemented"); } - throw new NotImplementedException("No PDF for this distribution implemented"); } /// @@ -555,22 +437,17 @@ public static Prob Pdf(PrimitiveDist dist, double y) /// public static Prob Pmf(PrimitiveDist dist, int y) { - if (dist is PoissonPrimitive) - { - var poisson = dist as PoissonPrimitive; + if (dist is PoissonPrimitive poisson) return Prob(MathNet.Numerics.Distributions.Poisson.PMF(poisson.Lambda, y)); - } + throw new NotImplementedException("No PMF for this distribution implemented"); } public static Prob Pmf(PrimitiveDist dist, T k) { - if (dist is CategoricalPrimitive) - { - var cat = dist as CategoricalPrimitive; - + if (dist is CategoricalPrimitive cat) return Prob(MathNet.Numerics.Distributions.Categorical.PMF(cat.ProbabilityMass, cat.ItemIndex[k])); - } + throw new NotImplementedException("No PMF for this distribution implemented"); } @@ -580,50 +457,40 @@ public static Prob Pmf(PrimitiveDist dist, T k) /// public static Prob Pdf(Dist dist, double y) { - if (dist is Primitive) - { - var primitive = dist as Primitive; + if (dist is Primitive primitive) return Pdf(primitive.dist, y); - } + throw new ArgumentException("Can only calculate PDF for primitive distributions"); } public static Prob Pdf(Dist dist, double[] y) { - if (dist is Primitive) - { - var primitive = dist as Primitive; + if (dist is Primitive primitive) return Pdf(primitive.dist, y); - } + throw new ArgumentException("Can only calculate PDF for primitive distributions"); } public static Prob Pdf(Dist> dist, Matrix y) { - if (dist is Primitive>) - { - var primitive = dist as Primitive>; - return Pdf(primitive.dist, y); - } + if (dist is Primitive> primitive) + return Pdf(primitive.dist, y); + throw new ArgumentException("Can only calculate PDF for primitive distributions"); } public static Prob Pmf(Dist dist, T k) { - if (dist is Primitive) - { - var primitive = dist as Primitive; + if (dist is Primitive primitive) return Pmf(primitive.dist, k); - } + throw new ArgumentException("Can only calculate PDF for primitive distributions"); } public static Prob Pmf(Dist dist, int y) { - if (dist is Primitive) - { - var primitive = dist as Primitive; + if (dist is Primitive primitive) return Pmf(primitive.dist, y); - } + throw new ArgumentException("Can only calculate pmf for primitive distributions"); } @@ -642,9 +509,8 @@ public static IEnumerable Append(IEnumerable list, A value) /// Sigmoid function /// public static double Sigmoid(double x) - { - return 1 / (1 + Math.Exp(-x)); - } + => 1 / (1 + Math.Exp(-x)); + #endregion } } diff --git a/ProbabilityMonad/DistGadt.cs b/ProbabilityMonad/DistGadt.cs index 7d995b3..0243922 100644 --- a/ProbabilityMonad/DistGadt.cs +++ b/ProbabilityMonad/DistGadt.cs @@ -54,20 +54,15 @@ public interface ParallelDistInterpreter public class Pure : Dist { public readonly A Value; + public Pure(A value) - { - Value = value; - } + => Value = value; public X Run(DistInterpreter interpreter) - { - return interpreter.Pure(Value); - } + => interpreter.Pure(Value); public X RunParallel(ParallelDistInterpreter interpreter) - { - return interpreter.Pure(Value); - } + => interpreter.Pure(Value); } /// @@ -76,20 +71,15 @@ public X RunParallel(ParallelDistInterpreter interpreter) public class Primitive : Dist { public readonly PrimitiveDist dist; + public Primitive(PrimitiveDist dist) - { - this.dist = dist; - } + => this.dist = dist; public X Run(DistInterpreter interpreter) - { - return interpreter.Primitive(dist); - } + => interpreter.Primitive(dist); public X RunParallel(ParallelDistInterpreter interpreter) - { - return interpreter.Primitive(dist); - } + => interpreter.Primitive(dist); } /// @@ -99,21 +89,15 @@ public class Conditional : Dist { public readonly Func likelihood; public readonly Dist dist; + public Conditional(Func likelihood, Dist dist) - { - this.likelihood = likelihood; - this.dist = dist; - } + => (this.likelihood, this.dist) = (likelihood, dist); public X Run(DistInterpreter interpreter) - { - return interpreter.Conditional(likelihood, dist); - } + => interpreter.Conditional(likelihood, dist); public X RunParallel(ParallelDistInterpreter interpreter) - { - return interpreter.Conditional(likelihood, dist); - } + => interpreter.Conditional(likelihood, dist); } /// @@ -125,6 +109,7 @@ public class RunIndependent3 : Dist public readonly Dist second; public readonly Dist third; public readonly Func> run; + public RunIndependent3(Dist first, Dist second, Dist third, Func> run) { this.first = first; @@ -145,9 +130,7 @@ from result in run(x, y, z) } public X RunParallel(ParallelDistInterpreter interpreter) - { - return interpreter.RunIndependent3(first, second, third, run); - } + => interpreter.RunIndependent3(first, second, third, run); } /// @@ -158,6 +141,7 @@ public class RunIndependent : Dist public readonly Dist first; public readonly Dist second; public readonly Func> run; + public RunIndependent(Dist first, Dist second, Func> run) { this.first = first; @@ -176,9 +160,7 @@ from result in run(x, y) } public X RunParallel(ParallelDistInterpreter interpreter) - { - return interpreter.RunIndependent(first, second, run); - } + => interpreter.RunIndependent(first, second, run); } /// @@ -187,21 +169,16 @@ public X RunParallel(ParallelDistInterpreter interpreter) public class Independent : Dist> { public readonly Dist dist; + public Independent(Dist dist) - { - this.dist = dist; - } + => this.dist = dist; // Run sequentially if we're not using a parallel interpreter public X Run(DistInterpreter, X> interpreter) - { - return Return(dist).Run(interpreter); - } + => Return(dist).Run(interpreter); public X RunParallel(ParallelDistInterpreter, X> interpreter) - { - return interpreter.Independent(Return(dist)); - } + => interpreter.Independent(Return(dist)); } /// @@ -211,6 +188,7 @@ public class Bind : Dist { public readonly Dist dist; public readonly Func> bind; + public Bind(Dist dist, Func> bind) { this.dist = dist; @@ -218,14 +196,10 @@ public Bind(Dist dist, Func> bind) } public X Run(DistInterpreter interpreter) - { - return interpreter.Bind(dist, bind); - } + => interpreter.Bind(dist, bind); public X RunParallel(ParallelDistInterpreter interpreter) - { - return interpreter.Bind(dist, bind); - } + => interpreter.Bind(dist, bind); } /// @@ -235,36 +209,27 @@ public X RunParallel(ParallelDistInterpreter interpreter) public static class DistExt { public static Dist Select(this Dist dist, Func f) - { - return new Bind(dist, a => new Pure(f(a))); - } + => new Bind(dist, a => new Pure(f(a))); public static Dist SelectMany(this Dist dist, Func> bind) - { - return new Bind(dist, bind); - } + => new Bind(dist, bind); public static Dist SelectMany( this Dist dist, Func> bind, Func project ) - { - return - new Bind(dist, a => - new Bind(bind(a), b => - new Pure(project(a, b)) - ) - ); - } + => new Bind(dist, a => + new Bind(bind(a), b => + new Pure(project(a, b)) + ) + ); /// /// Default to using recursion depth limit of 100 /// public static Dist> Sequence(this IEnumerable> dists) - { - return SequenceWithDepth(dists, 100); - } + => SequenceWithDepth(dists, 100); /// /// This implementation sort of does trampolining to avoid stack overflows, diff --git a/ProbabilityMonad/Finite/FiniteDist.cs b/ProbabilityMonad/Finite/FiniteDist.cs index 6c5937c..7848411 100644 --- a/ProbabilityMonad/Finite/FiniteDist.cs +++ b/ProbabilityMonad/Finite/FiniteDist.cs @@ -11,14 +11,10 @@ namespace ProbCSharp public class FiniteDist { public FiniteDist(Samples samples) - { - Explicit = samples; - } + => Explicit = samples; public FiniteDist(params ItemProb[] samples) - { - Explicit = Samples(samples); - } + => Explicit = Samples(samples); public Samples Explicit { get; } } @@ -31,10 +27,8 @@ public static class FiniteDistMonad /// fmap f (FiniteDist (Samples xs)) = FiniteDist $ Samples $ map (first f) xs /// public static FiniteDist Select(this FiniteDist self, Func select) - { - return new FiniteDist(Samples(self.Explicit.Weights.Select(i => - ItemProb(select(i.Item), i.Prob)))); - } + => new FiniteDist(Samples(self.Explicit.Weights.Select(i => + ItemProb(@select(i.Item), i.Prob)))); /// /// (FiniteDist dist) >>= bind = FiniteDist $ do diff --git a/ProbabilityMonad/Finite/FiniteExtensions.cs b/ProbabilityMonad/Finite/FiniteExtensions.cs index 30558c0..91279ac 100644 --- a/ProbabilityMonad/Finite/FiniteExtensions.cs +++ b/ProbabilityMonad/Finite/FiniteExtensions.cs @@ -29,13 +29,11 @@ public static A Pick(this FiniteDist distribution, Prob pickProb) /// Lifts a FiniteDist into a SampleableDist /// public static PrimitiveDist ToSampleDist(this FiniteDist dist) - { - return new SampleDist(() => + => new SampleDist(() => { var rand = new MathNet.Numerics.Distributions.ContinuousUniform().Sample(); return dist.Pick(Prob(rand)); }); - } /// /// Returns the probability of a certain event @@ -51,71 +49,55 @@ public static Prob ProbOf(this FiniteDist dist, Func eventTest) /// Reweight by a probability that depends on associated item /// public static FiniteDist ConditionSoft(this FiniteDist distribution, Func likelihood) - { - return new FiniteDist( + => new FiniteDist( distribution.Explicit .Select(p => ItemProb(p.Item, likelihood(p.Item).Mult(p.Prob))) .Normalize() ); - } /// /// Reweight by a probability that depends on associated item, without normalizing /// public static FiniteDist ConditionSoftUnnormalized(this FiniteDist distribution, Func likelihood) - { - return new FiniteDist( + => new FiniteDist( distribution.Explicit.Select(p => ItemProb(p.Item, likelihood(p.Item).Mult(p.Prob))) ); - } /// /// Hard reweight by a condition that depends on associated item /// public static FiniteDist ConditionHard(this FiniteDist distribution, Func condition) - { - return new FiniteDist( + => new FiniteDist( distribution.Explicit .Select(p => ItemProb(p.Item, condition(p.Item) ? p.Prob : Prob(0))) .Normalize() ); - - - } /// /// Computes the posterior distribution, given a piece of data and a likelihood function /// public static FiniteDist UpdateOn(this FiniteDist prior, Func likelihood, D datum) - { - return prior.ConditionSoft(w => likelihood(w, datum)); - } + => prior.ConditionSoft(w => likelihood(w, datum)); /// /// Computes the posterior distribution, given a list of data and a likelihood function /// public static FiniteDist UpdateOn(this FiniteDist prior, Func likelihood, IEnumerable data) - { - return data.Aggregate(prior, (dist, datum) => dist.UpdateOn(likelihood, datum)); - } + => data.Aggregate(prior, (dist, datum) => dist.UpdateOn(likelihood, datum)); /// /// Normalize a finite distribution /// public static FiniteDist Normalize(this FiniteDist dist) - { - return new FiniteDist(dist.Explicit.Normalize()); - } - + => new FiniteDist(dist.Explicit.Normalize()); + /// /// Join two independent distributions /// public static FiniteDist> Join(this FiniteDist self, FiniteDist other) - { - return from a in self - from b in other - select new Tuple(a, b); - } + => from a in self + from b in other + select new Tuple(a, b); /// /// Returns all elements, and the collection without that element diff --git a/ProbabilityMonad/Histogram.cs b/ProbabilityMonad/Histogram.cs index 94e13c6..f4c70ec 100644 --- a/ProbabilityMonad/Histogram.cs +++ b/ProbabilityMonad/Histogram.cs @@ -67,15 +67,16 @@ internal static List MakeIntBucketList(int min, int max, int numBuckets) internal static string ShowBuckets(IEnumerable buckets, double scale) { var sb = new StringBuilder(); - Func formatDouble = d => $"{d:N2}"; - var minPadding = LongestString(buckets, b => formatDouble(b.Min)); - var maxPadding = LongestString(buckets, b => formatDouble(b.Max)); + string FormatDouble(double d) => d.ToString("N2"); + + var minPadding = LongestString(buckets, b => FormatDouble(b.Min)); + var maxPadding = LongestString(buckets, b => FormatDouble(b.Max)); var barScale = BarScale(buckets.Select(b => b.BarSize), scale); foreach (var bucket in buckets) { - var min = formatDouble(bucket.Min).PadLeft(minPadding, ' '); - var max = formatDouble(bucket.Max).PadLeft(maxPadding, ' '); + var min = FormatDouble(bucket.Min).PadLeft(minPadding, ' '); + var max = FormatDouble(bucket.Max).PadLeft(maxPadding, ' '); sb.AppendLine($"{min} {max} {Bar((int) (bucket.BarSize * barScale))}"); } sb.AppendLine(""); @@ -86,9 +87,7 @@ internal static string ShowBuckets(IEnumerable buckets, double scale) /// Returns the size of the longest string representation of an item in a list /// internal static int LongestString(IEnumerable list, Func toString) - { - return list.Select(x => toString(x).Length).Max(); - } + => list.Select(x => toString(x).Length).Max(); /// /// Calculate scale factor for a list of bar lengths @@ -103,22 +102,13 @@ internal static double BarScale(IEnumerable barLengths, double scale) /// Draw bar of width n /// internal static string Bar(int n) - { - var barBuilder = new StringBuilder(); - for (var i = 0; i < n; i++) - { - barBuilder.Append("#"); - } - return barBuilder.ToString(); - } + => new string('#', n); /// /// Return sum of list with a given value function /// internal static double Sum(Func getVal, IEnumerable list) - { - return list.Select(getVal).Sum(); - } + => list.Select(getVal).Sum(); /// /// Generates a weighted histogram from a list of ItemProbs @@ -147,9 +137,7 @@ public static string Weighted(IEnumerable> nums, int numBuckets /// /// Scale factor. Defaults to 40 public static string Weighted(Samples nums, int numBuckets = 10, double scale = DEFAULT_SCALE) - { - return Weighted(nums.Weights, numBuckets, scale); - } + => Weighted(nums.Weights, numBuckets, scale); /// /// Generates an unweigted histogram for a list of numbers @@ -178,9 +166,7 @@ public static string Unweighted(IEnumerable nums, int numBuckets = 10, d /// /// Scale factor. Defaults to 40 public static string Unweighted(IEnumerable nums, int numBuckets = 10, double scale = DEFAULT_SCALE) - { - return Unweighted(nums.Select(x => (double)x), numBuckets, scale); - } + => Unweighted(nums.Select(x => (double)x), numBuckets, scale); /// /// Generates a histogram from some samples, grouping by a show function. diff --git a/ProbabilityMonad/Inference/Importance.cs b/ProbabilityMonad/Inference/Importance.cs index 6833020..8fae155 100644 --- a/ProbabilityMonad/Inference/Importance.cs +++ b/ProbabilityMonad/Inference/Importance.cs @@ -10,9 +10,7 @@ public static class ImportanceExt /// Importance sample with given number of samples /// public static Dist> ImportanceSamples(this Dist dist, int numSamples) - { - return Importance.ImportanceSamples(numSamples, dist); - } + => Importance.ImportanceSamples(numSamples, dist); } /// @@ -35,11 +33,9 @@ public static Dist> Resample(Samples samples) /// Flattens nested samples /// public static Samples Flatten(Samples> samples) - { - return Samples(from outer in Normalize(samples).Weights - from inner in outer.Item.Weights - select ItemProb(inner.Item, inner.Prob.Mult(outer.Prob))); - } + => Samples(from outer in Normalize(samples).Weights + from inner in outer.Item.Weights + select ItemProb(inner.Item, inner.Prob.Mult(outer.Prob))); /// /// Performs importance sampling on a distribution @@ -57,11 +53,9 @@ public static Dist> ImportanceSamples(int numSamples, Dist dist /// /// An importance sampled distribution public static Dist ImportanceDist(int numSamples, Dist dist) - { - return from probs in ImportanceSamples(numSamples, dist) - from resampled in Categorical(probs) - select resampled; - } + => from probs in ImportanceSamples(numSamples, dist) + from resampled in Categorical(probs) + select resampled; /// /// Normalizes a list of samples diff --git a/ProbabilityMonad/Inference/MetropolisHastings.cs b/ProbabilityMonad/Inference/MetropolisHastings.cs index 2c552ac..47ae9e3 100644 --- a/ProbabilityMonad/Inference/MetropolisHastings.cs +++ b/ProbabilityMonad/Inference/MetropolisHastings.cs @@ -42,9 +42,6 @@ from accept in Primitive(BernoulliF(Prob(Math.Min(1.0, candidate.Prob.Div(chainS /// Returns the value of the metropolis chain at specified index /// public static Dist MHIndex(Dist dist, int n, int index) - { - return MHPrior(dist, n).Select(list => list.ElementAt(index)); - } - + => MHPrior(dist, n).Select(list => list.ElementAt(index)); } } diff --git a/ProbabilityMonad/Inference/Pimh.cs b/ProbabilityMonad/Inference/Pimh.cs index 29066c0..96b1d46 100644 --- a/ProbabilityMonad/Inference/Pimh.cs +++ b/ProbabilityMonad/Inference/Pimh.cs @@ -8,8 +8,6 @@ public static class Pimh /// Particle indepedent Metropolis-Hastings /// public static Dist>> Create(int numParticles, int chainLen, Dist dist) - { - return MetropolisHastings.MHPrior(dist.Run(new Smc(numParticles)), chainLen); - } + => MetropolisHastings.MHPrior(dist.Run(new Smc(numParticles)), chainLen); } } diff --git a/ProbabilityMonad/Inference/Prior.cs b/ProbabilityMonad/Inference/Prior.cs index 80b0dd3..7402e0b 100644 --- a/ProbabilityMonad/Inference/Prior.cs +++ b/ProbabilityMonad/Inference/Prior.cs @@ -8,29 +8,19 @@ namespace ProbCSharp public class Prior : DistInterpreter> { public DistInterpreter New() - { - return new Prior() as DistInterpreter; - } + => new Prior() as DistInterpreter; Dist DistInterpreter>.Bind(Dist dist, Func> bind) - { - return new Bind(dist.Run(new Prior()), bind); - } + => new Bind(dist.Run(new Prior()), bind); Dist DistInterpreter>.Conditional(Func lik, Dist dist) - { - return dist.Run(new Prior()); - } + => dist.Run(new Prior()); Dist DistInterpreter>.Primitive(PrimitiveDist dist) - { - return new Pure(dist.Sample()); - } + => new Pure(dist.Sample()); Dist DistInterpreter>.Pure(A value) - { - return new Pure(value); - } + => new Pure(value); } diff --git a/ProbabilityMonad/Inference/PriorWeighted.cs b/ProbabilityMonad/Inference/PriorWeighted.cs index 70cefd1..8ea366e 100644 --- a/ProbabilityMonad/Inference/PriorWeighted.cs +++ b/ProbabilityMonad/Inference/PriorWeighted.cs @@ -7,9 +7,7 @@ namespace ProbCSharp public static class PriorWeightedExtensions { public static Dist> WeightedPrior(this Dist dist) - { - return dist.Run(new PriorWeighted()); - } + => dist.Run(new PriorWeighted()); } /// @@ -20,32 +18,22 @@ public static Dist> WeightedPrior(this Dist dist) public class PriorWeighted : DistInterpreter>> { public Dist> Bind(Dist dist, Func> bind) - { - return from x in dist.Run(new PriorWeighted()) - from y in bind(x.Item) // Don't remove conditionals here - select new ItemProb(y, x.Prob); - } + => from x in dist.Run(new PriorWeighted()) + from y in bind(x.Item) // Don't remove conditionals here + select new ItemProb(y, x.Prob); public Dist> Conditional(Func lik, Dist dist) - { - return from itemProb in dist.Run(new PriorWeighted()) - select new ItemProb(itemProb.Item, itemProb.Prob.Mult(lik(itemProb.Item))); - } + => from itemProb in dist.Run(new PriorWeighted()) + select new ItemProb(itemProb.Item, itemProb.Prob.Mult(lik(itemProb.Item))); public DistInterpreter New() - { - return new PriorWeighted() as DistInterpreter; - } + => new PriorWeighted() as DistInterpreter; public Dist> Primitive(PrimitiveDist dist) - { - return new Primitive>(dist.Select(a => ItemProb(a, Prob(1)))); - } + => new Primitive>(dist.Select(a => ItemProb(a, Prob(1)))); public Dist> Pure(A value) - { - return new Pure>(new ItemProb(value, Prob(1))); - } + => new Pure>(new ItemProb(value, Prob(1))); } } diff --git a/ProbabilityMonad/Inference/Smc.cs b/ProbabilityMonad/Inference/Smc.cs index 4b4a8cd..94f7d5c 100644 --- a/ProbabilityMonad/Inference/Smc.cs +++ b/ProbabilityMonad/Inference/Smc.cs @@ -12,17 +12,13 @@ public static class SmcExtensions /// SMC that discards pseudo-marginal likelihoods /// public static Dist> SmcStandard(this Dist dist, int n) - { - return dist.Run(new Smc(n)).Run(new Prior>()); - } + => dist.Run(new Smc(n)).Run(new Prior>()); /// /// SMC that does importance sampling /// public static Dist> SmcMultiple(this Dist dist, int numSamples, int numParticles) - { - return dist.Run(new Smc(numParticles)).ImportanceSamples(numSamples).Select(Importance.Flatten); - } + => dist.Run(new Smc(numParticles)).ImportanceSamples(numSamples).Select(Importance.Flatten); } /// @@ -32,17 +28,13 @@ public class Smc : DistInterpreter>> { public int numParticles; public Smc(int numParticles) - { - this.numParticles = numParticles; - } + => this.numParticles = numParticles; public Dist> Bind(Dist dist, Func> bind) - { - return from ps in dist.Run(new Smc(numParticles)) - let unzipped = ps.Unzip() - from ys in unzipped.Item1.Select(bind).Sequence() - select Samples(ys.Zip(unzipped.Item2, ItemProb)); - } + => from ps in dist.Run(new Smc(numParticles)) + let unzipped = ps.Unzip() + from ys in unzipped.Item1.Select(bind).Sequence() + select Samples(ys.Zip(unzipped.Item2, ItemProb)); public Dist> Conditional(Func lik, Dist dist) { @@ -58,9 +50,7 @@ select ps.Select(ip => ItemProb(ip.Item, lik(ip.Item).Mult(ip.Prob))) } public DistInterpreter New() - { - return new Smc(numParticles) as DistInterpreter; - } + => new Smc(numParticles) as DistInterpreter; public Dist> Primitive(PrimitiveDist dist) { diff --git a/ProbabilityMonad/ItemProb.cs b/ProbabilityMonad/ItemProb.cs index df360dc..84750c4 100644 --- a/ProbabilityMonad/ItemProb.cs +++ b/ProbabilityMonad/ItemProb.cs @@ -7,15 +7,11 @@ public class ItemProb { public A Item { get; } public Prob Prob { get; } + public ItemProb(A item, Prob prob) - { - Item = item; - Prob = prob; - } + => (Item, Prob) = (item, prob); public override string ToString() - { - return $"ItemProb ({Item}, {Prob})"; - } + => $"ItemProb ({Item}, {Prob})"; } } diff --git a/ProbabilityMonad/KullbackLeibler.cs b/ProbabilityMonad/KullbackLeibler.cs index e8f5118..c60818b 100644 --- a/ProbabilityMonad/KullbackLeibler.cs +++ b/ProbabilityMonad/KullbackLeibler.cs @@ -13,17 +13,20 @@ public static class KullbackLeibler public static double KLDivergenceF(FiniteDist distQ, FiniteDist distP, Func keyFunc) where A : IComparable { var qWeights = Enumerate(distQ, keyFunc); - Func qDensity = x => + + Prob QDensity(A x) { var weight = qWeights.FirstOrDefault(ip => keyFunc(ip.Item).Equals(keyFunc(x))); - if (weight == null) return Prob(0); + if (weight == null) + return Prob(0); + return weight.Prob; - }; + } var pWeights = Enumerate(distP, keyFunc); var divergences = pWeights.Weights - .Select(w => w.Prob.Value * Math.Log(w.Prob.Div(qDensity(w.Item)).Value)); + .Select(w => w.Prob.Value * Math.Log(w.Prob.Div(QDensity(w.Item)).Value)); return divergences.Sum(); } diff --git a/ProbabilityMonad/LogNormalCreate.cs b/ProbabilityMonad/LogNormalCreate.cs deleted file mode 100644 index 97022a2..0000000 --- a/ProbabilityMonad/LogNormalCreate.cs +++ /dev/null @@ -1,12 +0,0 @@ -using MathNet.Numerics.Distributions; - -namespace ProbCSharp -{ - class LogNormalCreate - { - public static LogNormal fromMeanVariance(double mean, double variance) - { - return LogNormal.WithMeanVariance(mean, variance); - } - } -} diff --git a/ProbabilityMonad/ParallelSampler.cs b/ProbabilityMonad/ParallelSampler.cs index 30fe000..5e8b59e 100644 --- a/ProbabilityMonad/ParallelSampler.cs +++ b/ProbabilityMonad/ParallelSampler.cs @@ -14,9 +14,7 @@ public static class ParallelSampleExtensions /// This will throw an exception if the distribution contains any conditionals. /// public static A SampleParallel(this Dist dist) - { - return dist.RunParallel(new ParallelSampler()); - } + => dist.RunParallel(new ParallelSampler()); /// /// Draws n samples from a distribution in parallel @@ -24,9 +22,7 @@ public static A SampleParallel(this Dist dist) /// This will throw an exception if the distribution contains any conditionals. /// public static IEnumerable SampleNParallel(this Dist dist, int n) - { - return Enumerable.Range(0, n).Select(_ => dist.SampleParallel()); - } + => Enumerable.Range(0, n).Select(_ => dist.SampleParallel()); } /// @@ -41,24 +37,16 @@ public A Bind(Dist dist, Func> bind) } public A Conditional(Func lik, Dist dist) - { - throw new ArgumentException("Cannot sample from conditional distribution."); - } + => throw new ArgumentException("Cannot sample from conditional distribution."); public A Primitive(PrimitiveDist dist) - { - return dist.Sample(); - } + => dist.Sample(); public A Pure(A value) - { - return value; - } + => value; public A Independent(Dist independent) - { - return independent.RunParallel(new ParallelSampler()); - } + => independent.RunParallel(new ParallelSampler()); public A RunIndependent(Dist distB, Dist distC, Func> run) { diff --git a/ProbabilityMonad/Primitive/Beta.cs b/ProbabilityMonad/Primitive/Beta.cs index 9536868..86a112b 100644 --- a/ProbabilityMonad/Primitive/Beta.cs +++ b/ProbabilityMonad/Primitive/Beta.cs @@ -10,6 +10,7 @@ public class BetaPrimitive : PrimitiveDist public double alpha; public double beta; public MathNet.Numerics.Distributions.Beta dist; + public BetaPrimitive(double alpha, double beta, Random gen) { this.alpha = alpha; @@ -18,8 +19,6 @@ public BetaPrimitive(double alpha, double beta, Random gen) } public Func Sample - { - get { return () => dist.Sample(); } - } + => () => dist.Sample(); } } diff --git a/ProbabilityMonad/Primitive/Categorical.cs b/ProbabilityMonad/Primitive/Categorical.cs index 6379373..24b39c6 100644 --- a/ProbabilityMonad/Primitive/Categorical.cs +++ b/ProbabilityMonad/Primitive/Categorical.cs @@ -7,36 +7,31 @@ namespace ProbCSharp { - /// - /// Primitive Categorical distribution - /// - public class CategoricalPrimitive : PrimitiveDist - { - public double[] ProbabilityMass { get; } - public T[] Items { get; } - public Random Gen { get; } - public Categorical categorical { get; } - public Dictionary ItemIndex {get ;} + /// + /// Primitive Categorical distribution + /// + public class CategoricalPrimitive : PrimitiveDist + { + public double[] ProbabilityMass { get; } + public T[] Items { get; } + public Random Gen { get; } + public Categorical categorical { get; } + public Dictionary ItemIndex { get; } - public CategoricalPrimitive(T[] items , double[] probabilities, Random gen) - { - Gen = gen; - Items = items; - ProbabilityMass = probabilities; - categorical = new Categorical(ProbabilityMass, gen); - ItemIndex = new Dictionary(); - for (int c = 0; c < items.Length; c++) - { - ItemIndex.Add(items[c],c); - } - } + public CategoricalPrimitive(T[] items, double[] probabilities, Random gen) + { + Gen = gen; + Items = items; + ProbabilityMass = probabilities; + categorical = new Categorical(ProbabilityMass, gen); + ItemIndex = new Dictionary(); + for (int c = 0; c < items.Length; c++) + { + ItemIndex.Add(items[c], c); + } + } - public Func Sample - { - get - { - return () => Items[categorical.Sample()]; - } - } - } + public Func Sample + => () => Items[categorical.Sample()]; + } } diff --git a/ProbabilityMonad/Primitive/Dirichlet.cs b/ProbabilityMonad/Primitive/Dirichlet.cs index f826bfa..cad475a 100644 --- a/ProbabilityMonad/Primitive/Dirichlet.cs +++ b/ProbabilityMonad/Primitive/Dirichlet.cs @@ -2,7 +2,6 @@ namespace ProbCSharp { - /// /// Primitive Dirichlet distribution /// @@ -11,16 +10,15 @@ public class DirichletPrimitive : PrimitiveDist public double[] alpha; public MathNet.Numerics.Distributions.Dirichlet dist; + public DirichletPrimitive(double[] alpha, Random gen) { this.alpha = alpha; - + dist = new MathNet.Numerics.Distributions.Dirichlet(alpha, gen); } public Func Sample - { - get { return () => dist.Sample(); } - } + => () => dist.Sample(); } } diff --git a/ProbabilityMonad/Primitive/Gamma.cs b/ProbabilityMonad/Primitive/Gamma.cs index 16389dc..3d4c238 100644 --- a/ProbabilityMonad/Primitive/Gamma.cs +++ b/ProbabilityMonad/Primitive/Gamma.cs @@ -1,21 +1,24 @@ using System; -namespace ProbCSharp { +namespace ProbCSharp +{ /// /// Primitive Gamma distribution /// - public class GammaPrimitive : PrimitiveDist { + public class GammaPrimitive : PrimitiveDist + { public double shape; public double rate; public MathNet.Numerics.Distributions.Gamma dist; - public GammaPrimitive(double shape, double rate, Random gen) { + + public GammaPrimitive(double shape, double rate, Random gen) + { this.shape = shape; this.rate = rate; dist = new MathNet.Numerics.Distributions.Gamma(shape, rate); } - public Func Sample { - get { return () => dist.Sample(); } - } + public Func Sample + => () => dist.Sample(); } } diff --git a/ProbabilityMonad/Primitive/LogNormal.cs b/ProbabilityMonad/Primitive/LogNormal.cs index 479e472..4885f31 100644 --- a/ProbabilityMonad/Primitive/LogNormal.cs +++ b/ProbabilityMonad/Primitive/LogNormal.cs @@ -1,36 +1,36 @@ using System; using MathNet.Numerics.Distributions; -namespace ProbCSharp { - /// - /// Primitive Lognormal distribution - /// - public class LogNormalPrimitive : PrimitiveDist - { - public double mean; - public double variance; - public double mu; - public double sigma; - public LogNormal dist; - public LogNormalPrimitive(double mu, double sigma, bool dummy, Random gen) - { - dist = new LogNormal(mu, sigma); - mu = dist.Mu; - sigma = dist.Sigma; - } - public LogNormalPrimitive(double mean, double variance, Random gen) - { - this.mean = mean; - this.variance = variance; - dist = LogNormalCreate.fromMeanVariance(mean, variance); // MathNet.Numerics.Distributions.LogNormal(mu, sigma); - mu = dist.Mu; - sigma = dist.Sigma; - } +namespace ProbCSharp +{ + /// + /// Primitive Lognormal distribution + /// + public class LogNormalPrimitive : PrimitiveDist + { + public double mean; + public double variance; + public double mu; + public double sigma; + public LogNormal dist; - public Func Sample - { - get { return () => dist.Sample(); } - } - } + public LogNormalPrimitive(double mu, double sigma, bool dummy, Random gen) + { + dist = new LogNormal(mu, sigma); + mu = dist.Mu; + sigma = dist.Sigma; + } + + public LogNormalPrimitive(double mean, double variance, Random gen) + { + this.mean = mean; + this.variance = variance; + dist = LogNormal.WithMeanVariance(mean, variance); // MathNet.Numerics.Distributions.LogNormal(mu, sigma); + mu = dist.Mu; + sigma = dist.Sigma; + } + + public Func Sample + => () => dist.Sample(); + } } - \ No newline at end of file diff --git a/ProbabilityMonad/Primitive/MultiVariateNormal.cs b/ProbabilityMonad/Primitive/MultiVariateNormal.cs index 1e5de12..461ffcb 100644 --- a/ProbabilityMonad/Primitive/MultiVariateNormal.cs +++ b/ProbabilityMonad/Primitive/MultiVariateNormal.cs @@ -5,39 +5,31 @@ namespace ProbCSharp { - /// - /// Primitive Multivariate Normal distribution - /// - public class MultiVariateNormalPrimitive : PrimitiveDist - { - public double[] Mean { get; } - public Matrix Covariance { get; } - public Random Gen { get; } - public MatrixNormal mvn { get; } + /// + /// Primitive Multivariate Normal distribution + /// + public class MultiVariateNormalPrimitive : PrimitiveDist + { + public double[] Mean { get; } + public Matrix Covariance { get; } + public Random Gen { get; } + public MatrixNormal mvn { get; } - static public MatrixNormal CreateMultivariateNormal(double[] meanVector, Matrix cv) - { - return new MatrixNormal(DenseMatrix.OfRowArrays(meanVector), - DenseMatrix.OfRowArrays(new double[][] { new double[] { 1.0 } }), - cv); - } + public static MatrixNormal CreateMultivariateNormal(double[] meanVector, Matrix cv) + => new MatrixNormal( + m: DenseMatrix.OfRowArrays(meanVector), + v: DenseMatrix.OfRowArrays(new double[][] { new double[] { 1.0 } }), + k: cv); - public MultiVariateNormalPrimitive(double[] mean, Matrix covariance, Random gen) - { - Mean = mean; - Covariance = covariance; - Gen = gen; - mvn = MultiVariateNormalPrimitive.CreateMultivariateNormal(mean, covariance); - } + public MultiVariateNormalPrimitive(double[] mean, Matrix covariance, Random gen) + { + Mean = mean; + Covariance = covariance; + Gen = gen; + mvn = CreateMultivariateNormal(mean, covariance); + } - public Func Sample - { - get - { - return () => mvn.Sample().Row(0).ToArray(); - } - } - } + public Func Sample + => () => mvn.Sample().Row(0).ToArray(); + } } - - \ No newline at end of file diff --git a/ProbabilityMonad/Primitive/Normal.cs b/ProbabilityMonad/Primitive/Normal.cs index 0c35483..3a8274d 100644 --- a/ProbabilityMonad/Primitive/Normal.cs +++ b/ProbabilityMonad/Primitive/Normal.cs @@ -10,23 +10,18 @@ public class NormalPrimitive : PrimitiveDist { public double Mean { get; } public double Variance { get; } - public Random Gen {get;} + public Random Gen { get; } public Normal normal { get; } - public NormalPrimitive(double mean, double variance, Random gen) - { + public NormalPrimitive(double mean, double variance, Random gen) + { Mean = mean; Variance = variance; Gen = gen; normal = Normal.WithMeanVariance(Mean, Variance, Gen); - } - - public Func Sample - { - get - { - return () => normal.Sample(); - } } + + public Func Sample + => () => normal.Sample(); } } diff --git a/ProbabilityMonad/Primitive/Poisson.cs b/ProbabilityMonad/Primitive/Poisson.cs index 5911bf7..6642aad 100644 --- a/ProbabilityMonad/Primitive/Poisson.cs +++ b/ProbabilityMonad/Primitive/Poisson.cs @@ -9,9 +9,10 @@ namespace ProbCSharp /// public class PoissonPrimitive : PrimitiveDist { - public double Lambda; + public double Lambda; public Poisson Dist; public Random Gen; + public PoissonPrimitive(double lambda, Random gen) { Lambda = lambda; @@ -20,11 +21,6 @@ public PoissonPrimitive(double lambda, Random gen) } public Func Sample - { - get - { - return () => Dist.Sample(); - } - } + => () => Dist.Sample(); } } diff --git a/ProbabilityMonad/Primitive/PrimitiveDist.cs b/ProbabilityMonad/Primitive/PrimitiveDist.cs index f688b98..340fc2c 100644 --- a/ProbabilityMonad/Primitive/PrimitiveDist.cs +++ b/ProbabilityMonad/Primitive/PrimitiveDist.cs @@ -19,15 +19,12 @@ public interface PrimitiveDist public class SampleDist : PrimitiveDist { private readonly Func _sample; + public SampleDist(Func sample) - { - _sample = sample; - } + => _sample = sample; public Func Sample - { - get { return _sample; } - } + => _sample; } @@ -35,23 +32,17 @@ public Func Sample public static class PrimitiveDistMonad { public static PrimitiveDist Select(this PrimitiveDist self, Func f) - { - return new SampleDist(() => f(self.Sample())); - } + => new SampleDist(() => f(self.Sample())); public static PrimitiveDist SelectMany( this PrimitiveDist self, Func> bind, - Func project - ) - { - return new SampleDist(() => + Func project) + => new SampleDist(() => { var firstSample = self.Sample(); var secondSample = bind(firstSample).Sample(); return project(firstSample, secondSample); }); - } } - } diff --git a/ProbabilityMonad/Primitive/Uniform.cs b/ProbabilityMonad/Primitive/Uniform.cs index 4090a93..c12c017 100644 --- a/ProbabilityMonad/Primitive/Uniform.cs +++ b/ProbabilityMonad/Primitive/Uniform.cs @@ -3,20 +3,17 @@ namespace ProbCSharp { - public class ContinuousUniformPrimitive : PrimitiveDist - { - public ContinuousUniform dist; - public ContinuousUniformPrimitive(double lower, double upper, Random gen) - { - dist = new ContinuousUniform(lower,upper); - } - public ContinuousUniformPrimitive(Random gen) - { - dist = new ContinuousUniform(); - } - public Func Sample - { - get { return () => dist.Sample(); } - } - } + public class ContinuousUniformPrimitive : PrimitiveDist + { + public ContinuousUniform dist; + + public ContinuousUniformPrimitive(double lower, double upper, Random gen) + => dist = new ContinuousUniform(lower, upper); + + public ContinuousUniformPrimitive(Random gen) + => dist = new ContinuousUniform(); + + public Func Sample + => () => dist.Sample(); + } } diff --git a/ProbabilityMonad/Primitive/Wishart.cs b/ProbabilityMonad/Primitive/Wishart.cs index 10d48d7..42ac744 100644 --- a/ProbabilityMonad/Primitive/Wishart.cs +++ b/ProbabilityMonad/Primitive/Wishart.cs @@ -4,30 +4,25 @@ namespace ProbCSharp { - /// - /// Primitive Wishart distribution - /// - public class WishartPrimitive : PrimitiveDist> - { - public double DegreesOfFreedom { get; } - public Matrix Scale { get; } - public Random Gen { get; } - public Wishart Wishart { get; } - - public WishartPrimitive(double dof, Matrix scale, Random gen) + /// + /// Primitive Wishart distribution + /// + public class WishartPrimitive : PrimitiveDist> { - DegreesOfFreedom = dof; - Scale = scale; - Gen = gen; - Wishart = new Wishart(dof, scale); - } + public double DegreesOfFreedom { get; } + public Matrix Scale { get; } + public Random Gen { get; } + public Wishart Wishart { get; } - public Func> Sample - { - get - { - return () => Wishart.Sample(); - } + public WishartPrimitive(double dof, Matrix scale, Random gen) + { + DegreesOfFreedom = dof; + Scale = scale; + Gen = gen; + Wishart = new Wishart(dof, scale); + } + + public Func> Sample + => () => Wishart.Sample(); } - } } diff --git a/ProbabilityMonad/Prob.cs b/ProbabilityMonad/Prob.cs index e8eea2d..47b5bee 100644 --- a/ProbabilityMonad/Prob.cs +++ b/ProbabilityMonad/Prob.cs @@ -21,17 +21,10 @@ public class DoubleProb : Prob public double Value { get; } public double LogValue - { - get - { - return Math.Log(Value); - } - } + => Math.Log(Value); public DoubleProb(double probability) - { - Value = probability; - } + => Value = probability; public override string ToString() { @@ -40,24 +33,16 @@ public override string ToString() } public Prob Mult(Prob other) - { - return new DoubleProb(Value * other.Value); - } + => new DoubleProb(Value * other.Value); public Prob Div(Prob other) - { - return new DoubleProb(Value / other.Value); - } + => new DoubleProb(Value / other.Value); public int CompareTo(Prob other) - { - return Value.CompareTo(other.Value); - } + => Value.CompareTo(other.Value); public bool Equals(Prob other) - { - return Value.Equals(other.Value); - } + => Value.Equals(other.Value); } /// @@ -75,45 +60,24 @@ public LogProb(double logProb) } public double Value - { - get - { - return Math.Exp(logProb); - } - } + => Math.Exp(logProb); public double LogValue - { - get - { - return logProb; - } - } + => logProb; public override string ToString() - { - var str = $"{Value*100:G3}%"; - return str; - } + => $"{Value*100:G3}%"; public int CompareTo(Prob other) - { - return logProb.CompareTo(other.LogValue); - } + => logProb.CompareTo(other.LogValue); public Prob Div(Prob other) - { - return new LogProb(logProb - other.LogValue); - } + => new LogProb(logProb - other.LogValue); public bool Equals(Prob other) - { - return other.LogValue == logProb; - } + => other.LogValue == logProb; public Prob Mult(Prob other) - { - return new LogProb(logProb + other.LogValue); - } + => new LogProb(logProb + other.LogValue); } } diff --git a/ProbabilityMonad/Sampler.cs b/ProbabilityMonad/Sampler.cs index 28f2e2c..1a76909 100644 --- a/ProbabilityMonad/Sampler.cs +++ b/ProbabilityMonad/Sampler.cs @@ -11,18 +11,14 @@ public static class SamplerExt /// This will throw an exception if the distribution contains any conditionals. /// public static A Sample(this Dist dist) - { - return dist.Run(new Sampler()); - } + => dist.Run(new Sampler()); /// /// Draws n samples from a distribution. /// This will throw an exception if the distribution contains any conditionals. /// public static IEnumerable SampleN(this Dist dist, int n) - { - return Enumerable.Range(0, n).Select(_ => dist.Sample()); - } + => Enumerable.Range(0, n).Select(_ => dist.Sample()); } /// @@ -31,9 +27,7 @@ public static IEnumerable SampleN(this Dist dist, int n) public class Sampler : DistInterpreter { public DistInterpreter New() - { - return new Sampler() as DistInterpreter; - } + => new Sampler() as DistInterpreter; A DistInterpreter.Bind(Dist dist, Func> bind) { @@ -45,19 +39,13 @@ A DistInterpreter.Bind(Dist dist, Func> bind) /// All conditionals must be removed before sampling. /// A DistInterpreter.Conditional(Func lik, Dist dist) - { - throw new ArgumentException("Cannot sample from conditional distribution."); - } + => throw new ArgumentException("Cannot sample from conditional distribution."); A DistInterpreter.Primitive(PrimitiveDist dist) - { - return dist.Sample(); - } + => dist.Sample(); A DistInterpreter.Pure(A value) - { - return value; - } + => value; } } diff --git a/ProbabilityMonad/Samples.cs b/ProbabilityMonad/Samples.cs index 2e6e009..bc7c65e 100644 --- a/ProbabilityMonad/Samples.cs +++ b/ProbabilityMonad/Samples.cs @@ -13,35 +13,25 @@ namespace ProbCSharp public static class SamplesExt { public static Samples Select(this Samples self, Func, ItemProb> f) - { - return Samples(self.Weights.Select(f)); - } + => Samples(self.Weights.Select(f)); public static Samples MapSample(this Samples self, Func f) - { - return Samples(self.Weights.Select(ip => ItemProb(f(ip.Item), ip.Prob))); - } + => Samples(self.Weights.Select(ip => ItemProb(f(ip.Item), ip.Prob))); /// /// Sum all probabilities in samples /// public static Prob SumProbs(this Samples self) - { - return Prob(self.Weights.Select(ip => ip.Prob.Value).Sum()); - } + => Prob(self.Weights.Select(ip => ip.Prob.Value).Sum()); public static Samples Normalize(this Samples self) - { - return Importance.Normalize(self); - } + => Importance.Normalize(self); /// /// Unzip samples into items and weights /// public static Tuple, IEnumerable> Unzip(this Samples samples) - { - return new Tuple, IEnumerable>(samples.Weights.Select(ip => ip.Item), samples.Weights.Select(ip => ip.Prob)); - } + => new Tuple, IEnumerable>(samples.Weights.Select(ip => ip.Item), samples.Weights.Select(ip => ip.Prob)); } // This is a Wrapper class for IEnumerable> @@ -53,18 +43,12 @@ public class Samples : IEnumerable> { public readonly IEnumerable> Weights; public Samples(IEnumerable> list) - { - Weights = list; - } + => Weights = list; public IEnumerator> GetEnumerator() - { - return Weights.GetEnumerator(); - } + => Weights.GetEnumerator(); IEnumerator IEnumerable.GetEnumerator() - { - return Weights.GetEnumerator(); - } + => Weights.GetEnumerator(); } }