diff --git a/ProbabilityMonad/Base.cs b/ProbabilityMonad/Base.cs
index 3b041d4..dd7df1d 100644
--- a/ProbabilityMonad/Base.cs
+++ b/ProbabilityMonad/Base.cs
@@ -14,40 +14,33 @@ namespace ProbCSharp
public static class ProbBase
{
// Singleton instance of the random generator to avoid repeated values in tight loops
- public static Random Gen = new Random();
+ public static readonly Random Gen = new Random();
#region Primitive object constructors
///
/// Probability constructor
///
public static Prob Prob(double probability)
- {
- return new LogProb(Math.Log(probability));
- }
+ => new LogProb(Math.Log(probability));
///
/// ItemProb constructor
///
public static ItemProb ItemProb(A item, Prob prob)
- {
- return new ItemProb(item, prob);
- }
+ => new ItemProb(item, prob);
///
/// Samples constructor
///
public static Samples Samples(IEnumerable> itemProbs)
- {
- return new Samples(itemProbs);
- }
+ => new Samples(itemProbs);
///
/// Tuple constructor
///
public static Tuple Tuple(A a, B b)
- {
- return new Tuple(a, b);
- }
+ => new Tuple(a, b);
+
#endregion
#region Distribution constructors
@@ -65,130 +58,99 @@ public static FiniteDist EnumUniformF(IEnumerable items)
/// Uniform distribution over list of items
///
public static Dist UniformFromList(IEnumerable items)
- {
- return Primitive(EnumUniformF(items));
- }
+ => Primitive(EnumUniformF(items));
///
/// Uniform distribution over parameter items
/// Only composable with other finite distributions.
///
public static FiniteDist UniformF(params A[] items)
- {
- var itemList = new List(items);
- return EnumUniformF(itemList);
- }
+ => EnumUniformF(new List(items));
///
/// Uniform distribution over parameter items
///
public static Dist Uniform(params A[] items)
- {
- return Primitive(UniformF(items));
- }
+ => Primitive(UniformF(items));
///
/// Bernoulli distribution constructed from success probability
/// Only composable with other finite distributions.
///
public static FiniteDist BernoulliF(Prob prob)
- {
- return new FiniteDist(ItemProb(true, prob), ItemProb(false, Prob(1 - prob.Value)));
- }
+ => new FiniteDist(ItemProb(true, prob), ItemProb(false, Prob(1 - prob.Value)));
///
/// Bernoulli distribution constructed from success probability
///
public static Dist Bernoulli(Prob prob)
- {
- return Primitive(BernoulliF(prob));
- }
+ => Primitive(BernoulliF(prob));
///
/// Bernoulli distribution constructed from success probability
///
public static Dist Bernoulli(double prob)
- {
- return Primitive(BernoulliF(Prob(prob)));
- }
+ => Primitive(BernoulliF(Prob(prob)));
///
/// Bernoulli distribution constructed from two items, and probability of first item
///
public static Dist Bernoulli(double prob, A option1, A option2)
- {
- return Primitive(BernoulliF(Prob(prob))).Select(b => b ? option1 : option2);
- }
+ => Primitive(BernoulliF(Prob(prob))).Select(b => b ? option1 : option2);
///
/// Categorical distribution
/// Only composable with other finite distributions
///
public static FiniteDist CategoricalF(params ItemProb[] itemProbs)
- {
- return new FiniteDist(Samples(itemProbs));
- }
+ => new FiniteDist(Samples(itemProbs));
///
/// Categorical distribution
/// Only composable with other finite distributions
///
public static FiniteDist CategoricalF(Samples samples)
- {
- return new FiniteDist(samples);
- }
+ => new FiniteDist(samples);
///
/// Categorical distribution
///
public static Dist Categorical(Samples samples)
- {
- return Primitive(CategoricalF(samples).ToSampleDist());
- }
+ => Primitive(CategoricalF(samples).ToSampleDist());
///
/// Categorical distribution
///
public static CategoricalPrimitive CategoricalPrimitive(A[] items, double [] probabilities)
- {
- return new CategoricalPrimitive(items, probabilities, Gen);
- }
+ => new CategoricalPrimitive(items, probabilities, Gen);
///
/// Primitive studentT distribution
/// Only composable with other primitive distributions
///
public static StudentTPrimitive StudentTPrimitive(double location, double scale, double normality)
- {
- return new StudentTPrimitive(location, scale, normality, Gen);
- }
+ => new StudentTPrimitive(location, scale, normality, Gen);
///
/// Primitive Exponential distribution
/// Only composable with other primitive distributions
///
public static ExponentialPrimitive ExponentialPrimitive(double rate)
- {
- return new ExponentialPrimitive(rate);
- }
+ => new ExponentialPrimitive(rate);
///
/// Primitive Contiuous Uniform distribution
/// Only composable with other primitive distributions
///
public static ContinuousUniformPrimitive ContinuousUniformPrimitive()
- {
- return new ContinuousUniformPrimitive(Gen);
- }
+ => new ContinuousUniformPrimitive(Gen);
///
/// Primitive Contiuous Uniform distribution
/// Only composable with other primitive distributions
///
public static ContinuousUniformPrimitive ContinuousUniformPrimitive(double lower, double upper)
- {
- return new ContinuousUniformPrimitive(lower,upper, Gen);
- }
+ => new ContinuousUniformPrimitive(lower,upper, Gen);
///
@@ -196,166 +158,121 @@ public static ContinuousUniformPrimitive ContinuousUniformPrimitive(double lower
/// Only composable with other primitive distributions
///
public static PoissonPrimitive PoissonPrimitive(double lambda)
- {
- return new PoissonPrimitive(lambda, Gen);
- }
+ => new PoissonPrimitive(lambda, Gen);
///
/// Primitive Normal distribution
/// Only composable with other primitive distributions
///
public static NormalPrimitive NormalPrimitive(double mean, double variance)
- {
- return new NormalPrimitive(mean, variance, Gen);
- }
+ => new NormalPrimitive(mean, variance, Gen);
///
/// Primitive LogNormal distribution
/// Only composable with other primitive distributions
///
public static LogNormalPrimitive LogNormalPrimitive(double mean, double variance)
- {
- return new LogNormalPrimitive(mean, variance, Gen);
- }
+ => new LogNormalPrimitive(mean, variance, Gen);
public static LogNormalPrimitive LogNormalPrimitiveMu(double mu, double sigma)
- {
- return new LogNormalPrimitive(mu, sigma, true, Gen);
- }
+ => new LogNormalPrimitive(mu, sigma, true, Gen);
///
/// Primitive Beta distribution
/// Only composable with other primitive distributions
///
public static BetaPrimitive BetaPrimitive(double alpha, double beta)
- {
- return new BetaPrimitive(alpha, beta, Gen);
- }
+ => new BetaPrimitive(alpha, beta, Gen);
///
/// Primitive Gamma distribution
/// Only composable with other primitive distributions
///
public static GammaPrimitive GammaPrimitive(double shape, double rate)
- {
- return new GammaPrimitive(shape, rate, Gen);
- }
+ => new GammaPrimitive(shape, rate, Gen);
public static MultiVariateNormalPrimitive MultiVariateNormalPrimitive(double[] mean, Matrix covariance)
- {
- return new MultiVariateNormalPrimitive(mean, covariance, Gen);
- }
+ => new MultiVariateNormalPrimitive(mean, covariance, Gen);
public static WishartPrimitive WishartPrimitive(double dof, Matrix scale)
- {
- return new WishartPrimitive(dof, scale, Gen);
- }
+ => new WishartPrimitive(dof, scale, Gen);
///
/// Primitive Dirichlet distribution
/// Only composable with other primitive distributions
///
public static DirichletPrimitive DirichletPrimitive(double[] alpha)
- {
- return new DirichletPrimitive(alpha, Gen);
- }
+ => new DirichletPrimitive(alpha, Gen);
///
/// Poisson distribution
///
public static Dist Poisson(double lambda)
- {
- return Primitive(PoissonPrimitive(lambda));
- }
+ => Primitive(PoissonPrimitive(lambda));
///
/// Normal distribution
///
public static Dist Normal(double mean, double variance)
- {
- return Primitive(NormalPrimitive(mean, variance));
- }
+ => Primitive(NormalPrimitive(mean, variance));
///
/// Log Normal distribution
///
public static Dist LogNormal(double mean, double variance)
- {
- return Primitive(LogNormalPrimitive(mean, variance));
- }
+ => Primitive(LogNormalPrimitive(mean, variance));
public static Dist LogNormalMu(double mu, double sigma)
- {
- return Primitive(LogNormalPrimitiveMu(mu, sigma));
- }
+ => Primitive(LogNormalPrimitiveMu(mu, sigma));
///
/// Gamma distribution
///
public static Dist Gamma(double shape, double rate)
- {
- return Primitive(GammaPrimitive(shape, rate));
- }
+ => Primitive(GammaPrimitive(shape, rate));
///
/// Categorical distribution
///
public static Dist Categorical(T[] items, double[] probabilities)
- {
- return Primitive(CategoricalPrimitive(items, probabilities));
- }
+ => Primitive(CategoricalPrimitive(items, probabilities));
public static Dist Uniform()
- {
- return Primitive(ContinuousUniformPrimitive());
- }
+ => Primitive(ContinuousUniformPrimitive());
public static Dist Uniform(double lower, double upper)
- {
- return Primitive(ContinuousUniformPrimitive(lower,upper));
- }
+ => Primitive(ContinuousUniformPrimitive(lower,upper));
///
/// StudenT distribution
///
public static Dist StudentT(double location, double scale, double normality)
- {
- return Primitive(StudentTPrimitive(location, scale, normality));
- }
+ => Primitive(StudentTPrimitive(location, scale, normality));
///
/// Exponential distribution
///
public static Dist Exponential(double rate)
- {
- return Primitive(ExponentialPrimitive(rate));
- }
+ => Primitive(ExponentialPrimitive(rate));
///
/// Beta distribution
///
public static Dist Beta(double alpha, double beta)
- {
- return Primitive(BetaPrimitive(alpha, beta));
- }
+ => Primitive(BetaPrimitive(alpha, beta));
///
/// Dirichlet distribution
///
public static Dist Dirichlet(double[] alpha)
- {
- return Primitive(DirichletPrimitive(alpha));
- }
+ => Primitive(DirichletPrimitive(alpha));
public static Dist MultiVariateNormal(double[] mean, Matrix covariance)
- {
- return Primitive(MultiVariateNormalPrimitive(mean, covariance));
- }
+ => Primitive(MultiVariateNormalPrimitive(mean, covariance));
public static Dist> Wishart(double dof, Matrix scale)
- {
- return Primitive(WishartPrimitive(dof, scale));
- }
+ => Primitive(WishartPrimitive(dof, scale));
+
#endregion
#region GADT constructors
@@ -365,82 +282,63 @@ public static Dist> Wishart(double dof, Matrix scale)
/// Wraps the distribution to defer evaluation until explicitly parallelized
///
public static Dist> Independent(Dist dist)
- {
- return new Independent(dist);
- }
+ => new Independent(dist);
///
/// Evaluates two distributions in parallel, passing the results to a function
///
public static Dist RunIndependentWith(Dist dist1, Dist dist2, Func> run)
- {
- return new RunIndependent(dist1, dist2, run);
- }
+ => new RunIndependent(dist1, dist2, run);
///
/// Evaluates two distributions in parallel. The results are collected into a tuple.
///
public static Dist> RunIndependent(Dist dist1, Dist dist2)
- {
- return new RunIndependent>(dist1, dist2, (t1, t2) => Return(new Tuple(t1, t2)));
- }
+ => new RunIndependent>(dist1, dist2, (t1, t2) => Return(new Tuple(t1, t2)));
///
/// Evaluates three distributions in parallel. The results are collected into a tuple.
///
public static Dist> RunIndependent(Dist dist1, Dist dist2, Dist dist3)
- {
- return new RunIndependent3>(dist1, dist2, dist3, (t1, t2, t3) => Return(new Tuple(t1, t2, t3)));
- }
+ => new RunIndependent3>(dist1, dist2, dist3, (t1, t2, t3) => Return(new Tuple(t1, t2, t3)));
///
/// Evaluates three distributions in parallel, passing the results to a function
///
public static Dist RunIndependentWith(Dist dist1, Dist dist2, Dist dist3, Func> run)
- {
- return new RunIndependent3(dist1, dist2, dist3, run);
- }
+ => new RunIndependent3(dist1, dist2, dist3, run);
+
#endregion
///
/// Primitive constructor for continuous dists
///
public static Dist Primitive(PrimitiveDist dist)
- {
- return new Primitive(dist);
- }
+ => new Primitive(dist);
///
/// Primitive constructor for finite dists
///
public static Dist Primitive(FiniteDist dist)
- {
- return new Primitive(dist.ToSampleDist());
- }
+ => new Primitive(dist.ToSampleDist());
///
/// Pure constructor, monadic return
///
public static Dist Return(A value)
- {
- return new Pure(value);
- }
+ => new Pure(value);
///
/// Create a conditional distribution with given likelihood function and distribution
///
public static Dist Condition(Func likelihood, Dist dist)
- {
- return new Conditional(likelihood, dist);
- }
+ => new Conditional(likelihood, dist);
///
/// Conditional constructor
///
public static Dist Condition(this Dist dist, Func likelihood)
- {
- return new Conditional(likelihood, dist);
- }
+ => new Conditional(likelihood, dist);
#endregion
@@ -451,9 +349,7 @@ public static Dist Condition(this Dist dist, Func likelihood)
///
/// Used to identify identical values
public static Samples Compact(Samples samples, Func keyFunc) where A : IComparable
- {
- return Samples(CompactUnordered(samples, keyFunc).Weights.OrderBy(w => w.Item));
- }
+ => Samples(CompactUnordered(samples, keyFunc).Weights.OrderBy(w => w.Item));
///
/// Aggregates probabilities of samples with identical values
@@ -478,9 +374,7 @@ public static Samples CompactUnordered(Samples samples, Func
public static Samples Enumerate(FiniteDist dist, Func keyFunc) where A : IComparable
- {
- return Importance.Normalize(Compact(dist.Explicit, keyFunc));
- }
+ => Importance.Normalize(Compact(dist.Explicit, keyFunc));
///
/// The probability density function for a primitive distribution and point.
@@ -488,22 +382,18 @@ public static Samples Enumerate(FiniteDist dist, Func keyF
///
public static Prob Pdf(PrimitiveDist dist, double[] x)
{
-
- if (dist is DirichletPrimitive)
- {
- var dir = dist as DirichletPrimitive;
- var d = new MathNet.Numerics.Distributions.Dirichlet(dir.alpha);
-
- return Prob(d.Density(x));
- }
-
- if (dist is MultiVariateNormalPrimitive)
+ switch (dist)
{
- var d = dist as MultiVariateNormalPrimitive;
-
- return Prob(d.mvn.Density(DenseMatrix.OfRowArrays(x)));
+ case DirichletPrimitive dirichlet:
+ var d = new MathNet.Numerics.Distributions.Dirichlet(dirichlet.alpha);
+ return Prob(d.Density(x));
+
+ case MultiVariateNormalPrimitive multiVariateNormal:
+ return Prob(multiVariateNormal.mvn.Density(DenseMatrix.OfRowArrays(x)));
+
+ default:
+ throw new NotImplementedException("No PDF for this distribution implemented");
}
- throw new NotImplementedException("No PDF for this distribution implemented");
}
///
/// The probability density function for a primitive distribution and point.
@@ -511,13 +401,9 @@ public static Prob Pdf(PrimitiveDist dist, double[] x)
///
public static Prob Pdf(PrimitiveDist> dist, Matrix x)
{
+ if (dist is WishartPrimitive wish)
+ return Prob(wish.Wishart.Density(x));
- if (dist is WishartPrimitive)
- {
- var wish = dist as WishartPrimitive;
- return Prob(wish.Wishart.Density(x));
- }
-
throw new NotImplementedException("No PDF for this distribution implemented");
}
///
@@ -526,27 +412,23 @@ public static Prob Pdf(PrimitiveDist> dist, Matrix x)
///
public static Prob Pdf(PrimitiveDist dist, double y)
{
- if (dist is NormalPrimitive)
+ switch (dist)
{
- var normal = dist as NormalPrimitive;
- return Prob(MathNet.Numerics.Distributions.Normal.PDF(normal.Mean, Math.Sqrt(normal.Variance), y));
- }
- if (dist is LogNormalPrimitive)
- {
- var lognormal = dist as LogNormalPrimitive;
- return Prob(MathNet.Numerics.Distributions.LogNormal.PDF(lognormal.mu, lognormal.sigma, y));
- }
- if (dist is BetaPrimitive)
- {
- var beta = dist as BetaPrimitive;
- return Prob(MathNet.Numerics.Distributions.Beta.PDF(beta.alpha, beta.beta, y));
- }
- if (dist is GammaPrimitive)
- {
- var gamma = dist as GammaPrimitive;
- return Prob(MathNet.Numerics.Distributions.Gamma.PDF(gamma.shape, gamma.rate, y));
+ case NormalPrimitive normal:
+ return Prob(MathNet.Numerics.Distributions.Normal.PDF(normal.Mean, Math.Sqrt(normal.Variance), y));
+
+ case LogNormalPrimitive lognormal:
+ return Prob(MathNet.Numerics.Distributions.LogNormal.PDF(lognormal.mu, lognormal.sigma, y));
+
+ case BetaPrimitive beta:
+ return Prob(MathNet.Numerics.Distributions.Beta.PDF(beta.alpha, beta.beta, y));
+
+ case GammaPrimitive gamma:
+ return Prob(MathNet.Numerics.Distributions.Gamma.PDF(gamma.shape, gamma.rate, y));
+
+ default:
+ throw new NotImplementedException("No PDF for this distribution implemented");
}
- throw new NotImplementedException("No PDF for this distribution implemented");
}
///
@@ -555,22 +437,17 @@ public static Prob Pdf(PrimitiveDist dist, double y)
///
public static Prob Pmf(PrimitiveDist dist, int y)
{
- if (dist is PoissonPrimitive)
- {
- var poisson = dist as PoissonPrimitive;
+ if (dist is PoissonPrimitive poisson)
return Prob(MathNet.Numerics.Distributions.Poisson.PMF(poisson.Lambda, y));
- }
+
throw new NotImplementedException("No PMF for this distribution implemented");
}
public static Prob Pmf(PrimitiveDist dist, T k)
{
- if (dist is CategoricalPrimitive)
- {
- var cat = dist as CategoricalPrimitive;
-
+ if (dist is CategoricalPrimitive cat)
return Prob(MathNet.Numerics.Distributions.Categorical.PMF(cat.ProbabilityMass, cat.ItemIndex[k]));
- }
+
throw new NotImplementedException("No PMF for this distribution implemented");
}
@@ -580,50 +457,40 @@ public static Prob Pmf(PrimitiveDist dist, T k)
///
public static Prob Pdf(Dist dist, double y)
{
- if (dist is Primitive)
- {
- var primitive = dist as Primitive;
+ if (dist is Primitive primitive)
return Pdf(primitive.dist, y);
- }
+
throw new ArgumentException("Can only calculate PDF for primitive distributions");
}
public static Prob Pdf(Dist dist, double[] y)
{
- if (dist is Primitive)
- {
- var primitive = dist as Primitive;
+ if (dist is Primitive primitive)
return Pdf(primitive.dist, y);
- }
+
throw new ArgumentException("Can only calculate PDF for primitive distributions");
}
public static Prob Pdf(Dist> dist, Matrix y)
{
- if (dist is Primitive>)
- {
- var primitive = dist as Primitive>;
- return Pdf(primitive.dist, y);
- }
+ if (dist is Primitive> primitive)
+ return Pdf(primitive.dist, y);
+
throw new ArgumentException("Can only calculate PDF for primitive distributions");
}
public static Prob Pmf(Dist dist, T k)
{
- if (dist is Primitive)
- {
- var primitive = dist as Primitive;
+ if (dist is Primitive primitive)
return Pmf(primitive.dist, k);
- }
+
throw new ArgumentException("Can only calculate PDF for primitive distributions");
}
public static Prob Pmf(Dist dist, int y)
{
- if (dist is Primitive)
- {
- var primitive = dist as Primitive;
+ if (dist is Primitive primitive)
return Pmf(primitive.dist, y);
- }
+
throw new ArgumentException("Can only calculate pmf for primitive distributions");
}
@@ -642,9 +509,8 @@ public static IEnumerable Append(IEnumerable list, A value)
/// Sigmoid function
///
public static double Sigmoid(double x)
- {
- return 1 / (1 + Math.Exp(-x));
- }
+ => 1 / (1 + Math.Exp(-x));
+
#endregion
}
}
diff --git a/ProbabilityMonad/DistGadt.cs b/ProbabilityMonad/DistGadt.cs
index 7d995b3..0243922 100644
--- a/ProbabilityMonad/DistGadt.cs
+++ b/ProbabilityMonad/DistGadt.cs
@@ -54,20 +54,15 @@ public interface ParallelDistInterpreter
public class Pure : Dist
{
public readonly A Value;
+
public Pure(A value)
- {
- Value = value;
- }
+ => Value = value;
public X Run(DistInterpreter interpreter)
- {
- return interpreter.Pure(Value);
- }
+ => interpreter.Pure(Value);
public X RunParallel(ParallelDistInterpreter interpreter)
- {
- return interpreter.Pure(Value);
- }
+ => interpreter.Pure(Value);
}
///
@@ -76,20 +71,15 @@ public X RunParallel(ParallelDistInterpreter interpreter)
public class Primitive : Dist
{
public readonly PrimitiveDist dist;
+
public Primitive(PrimitiveDist dist)
- {
- this.dist = dist;
- }
+ => this.dist = dist;
public X Run(DistInterpreter interpreter)
- {
- return interpreter.Primitive(dist);
- }
+ => interpreter.Primitive(dist);
public X RunParallel(ParallelDistInterpreter interpreter)
- {
- return interpreter.Primitive(dist);
- }
+ => interpreter.Primitive(dist);
}
///
@@ -99,21 +89,15 @@ public class Conditional : Dist
{
public readonly Func likelihood;
public readonly Dist dist;
+
public Conditional(Func likelihood, Dist dist)
- {
- this.likelihood = likelihood;
- this.dist = dist;
- }
+ => (this.likelihood, this.dist) = (likelihood, dist);
public X Run(DistInterpreter interpreter)
- {
- return interpreter.Conditional(likelihood, dist);
- }
+ => interpreter.Conditional(likelihood, dist);
public X RunParallel(ParallelDistInterpreter interpreter)
- {
- return interpreter.Conditional(likelihood, dist);
- }
+ => interpreter.Conditional(likelihood, dist);
}
///
@@ -125,6 +109,7 @@ public class RunIndependent3 : Dist
public readonly Dist second;
public readonly Dist third;
public readonly Func> run;
+
public RunIndependent3(Dist first, Dist second, Dist third, Func> run)
{
this.first = first;
@@ -145,9 +130,7 @@ from result in run(x, y, z)
}
public X RunParallel(ParallelDistInterpreter interpreter)
- {
- return interpreter.RunIndependent3(first, second, third, run);
- }
+ => interpreter.RunIndependent3(first, second, third, run);
}
///
@@ -158,6 +141,7 @@ public class RunIndependent : Dist
public readonly Dist first;
public readonly Dist second;
public readonly Func> run;
+
public RunIndependent(Dist first, Dist second, Func> run)
{
this.first = first;
@@ -176,9 +160,7 @@ from result in run(x, y)
}
public X RunParallel(ParallelDistInterpreter interpreter)
- {
- return interpreter.RunIndependent(first, second, run);
- }
+ => interpreter.RunIndependent(first, second, run);
}
///
@@ -187,21 +169,16 @@ public X RunParallel(ParallelDistInterpreter interpreter)
public class Independent : Dist>
{
public readonly Dist dist;
+
public Independent(Dist dist)
- {
- this.dist = dist;
- }
+ => this.dist = dist;
// Run sequentially if we're not using a parallel interpreter
public X Run(DistInterpreter, X> interpreter)
- {
- return Return(dist).Run(interpreter);
- }
+ => Return(dist).Run(interpreter);
public X RunParallel(ParallelDistInterpreter, X> interpreter)
- {
- return interpreter.Independent(Return(dist));
- }
+ => interpreter.Independent(Return(dist));
}
///
@@ -211,6 +188,7 @@ public class Bind : Dist
{
public readonly Dist dist;
public readonly Func> bind;
+
public Bind(Dist dist, Func> bind)
{
this.dist = dist;
@@ -218,14 +196,10 @@ public Bind(Dist dist, Func> bind)
}
public X Run(DistInterpreter interpreter)
- {
- return interpreter.Bind(dist, bind);
- }
+ => interpreter.Bind(dist, bind);
public X RunParallel(ParallelDistInterpreter interpreter)
- {
- return interpreter.Bind(dist, bind);
- }
+ => interpreter.Bind(dist, bind);
}
///
@@ -235,36 +209,27 @@ public X RunParallel(ParallelDistInterpreter interpreter)
public static class DistExt
{
public static Dist Select(this Dist dist, Func f)
- {
- return new Bind(dist, a => new Pure(f(a)));
- }
+ => new Bind(dist, a => new Pure(f(a)));
public static Dist SelectMany(this Dist dist, Func> bind)
- {
- return new Bind(dist, bind);
- }
+ => new Bind(dist, bind);
public static Dist SelectMany(
this Dist dist,
Func> bind,
Func project
)
- {
- return
- new Bind(dist, a =>
- new Bind(bind(a), b =>
- new Pure(project(a, b))
- )
- );
- }
+ => new Bind(dist, a =>
+ new Bind(bind(a), b =>
+ new Pure(project(a, b))
+ )
+ );
///
/// Default to using recursion depth limit of 100
///
public static Dist> Sequence(this IEnumerable> dists)
- {
- return SequenceWithDepth(dists, 100);
- }
+ => SequenceWithDepth(dists, 100);
///
/// This implementation sort of does trampolining to avoid stack overflows,
diff --git a/ProbabilityMonad/Finite/FiniteDist.cs b/ProbabilityMonad/Finite/FiniteDist.cs
index 6c5937c..7848411 100644
--- a/ProbabilityMonad/Finite/FiniteDist.cs
+++ b/ProbabilityMonad/Finite/FiniteDist.cs
@@ -11,14 +11,10 @@ namespace ProbCSharp
public class FiniteDist
{
public FiniteDist(Samples samples)
- {
- Explicit = samples;
- }
+ => Explicit = samples;
public FiniteDist(params ItemProb[] samples)
- {
- Explicit = Samples(samples);
- }
+ => Explicit = Samples(samples);
public Samples Explicit { get; }
}
@@ -31,10 +27,8 @@ public static class FiniteDistMonad
/// fmap f (FiniteDist (Samples xs)) = FiniteDist $ Samples $ map (first f) xs
///
public static FiniteDist Select(this FiniteDist self, Func select)
- {
- return new FiniteDist(Samples(self.Explicit.Weights.Select(i =>
- ItemProb(select(i.Item), i.Prob))));
- }
+ => new FiniteDist(Samples(self.Explicit.Weights.Select(i =>
+ ItemProb(@select(i.Item), i.Prob))));
///
/// (FiniteDist dist) >>= bind = FiniteDist $ do
diff --git a/ProbabilityMonad/Finite/FiniteExtensions.cs b/ProbabilityMonad/Finite/FiniteExtensions.cs
index 30558c0..91279ac 100644
--- a/ProbabilityMonad/Finite/FiniteExtensions.cs
+++ b/ProbabilityMonad/Finite/FiniteExtensions.cs
@@ -29,13 +29,11 @@ public static A Pick(this FiniteDist distribution, Prob pickProb)
/// Lifts a FiniteDist into a SampleableDist
///
public static PrimitiveDist ToSampleDist(this FiniteDist dist)
- {
- return new SampleDist(() =>
+ => new SampleDist(() =>
{
var rand = new MathNet.Numerics.Distributions.ContinuousUniform().Sample();
return dist.Pick(Prob(rand));
});
- }
///
/// Returns the probability of a certain event
@@ -51,71 +49,55 @@ public static Prob ProbOf(this FiniteDist dist, Func eventTest)
/// Reweight by a probability that depends on associated item
///
public static FiniteDist ConditionSoft(this FiniteDist distribution, Func likelihood)
- {
- return new FiniteDist(
+ => new FiniteDist(
distribution.Explicit
.Select(p => ItemProb(p.Item, likelihood(p.Item).Mult(p.Prob)))
.Normalize()
);
- }
///
/// Reweight by a probability that depends on associated item, without normalizing
///
public static FiniteDist ConditionSoftUnnormalized(this FiniteDist distribution, Func likelihood)
- {
- return new FiniteDist(
+ => new FiniteDist(
distribution.Explicit.Select(p => ItemProb(p.Item, likelihood(p.Item).Mult(p.Prob)))
);
- }
///
/// Hard reweight by a condition that depends on associated item
///
public static FiniteDist ConditionHard(this FiniteDist distribution, Func condition)
- {
- return new FiniteDist(
+ => new FiniteDist(
distribution.Explicit
.Select(p => ItemProb(p.Item, condition(p.Item) ? p.Prob : Prob(0)))
.Normalize()
);
-
-
- }
///
/// Computes the posterior distribution, given a piece of data and a likelihood function
///
public static FiniteDist UpdateOn(this FiniteDist prior, Func likelihood, D datum)
- {
- return prior.ConditionSoft(w => likelihood(w, datum));
- }
+ => prior.ConditionSoft(w => likelihood(w, datum));
///
/// Computes the posterior distribution, given a list of data and a likelihood function
///
public static FiniteDist UpdateOn(this FiniteDist prior, Func likelihood, IEnumerable data)
- {
- return data.Aggregate(prior, (dist, datum) => dist.UpdateOn(likelihood, datum));
- }
+ => data.Aggregate(prior, (dist, datum) => dist.UpdateOn(likelihood, datum));
///
/// Normalize a finite distribution
///
public static FiniteDist Normalize(this FiniteDist dist)
- {
- return new FiniteDist(dist.Explicit.Normalize());
- }
-
+ => new FiniteDist(dist.Explicit.Normalize());
+
///
/// Join two independent distributions
///
public static FiniteDist> Join(this FiniteDist self, FiniteDist other)
- {
- return from a in self
- from b in other
- select new Tuple(a, b);
- }
+ => from a in self
+ from b in other
+ select new Tuple(a, b);
///
/// Returns all elements, and the collection without that element
diff --git a/ProbabilityMonad/Histogram.cs b/ProbabilityMonad/Histogram.cs
index 94e13c6..f4c70ec 100644
--- a/ProbabilityMonad/Histogram.cs
+++ b/ProbabilityMonad/Histogram.cs
@@ -67,15 +67,16 @@ internal static List MakeIntBucketList(int min, int max, int numBuckets)
internal static string ShowBuckets(IEnumerable buckets, double scale)
{
var sb = new StringBuilder();
- Func formatDouble = d => $"{d:N2}";
- var minPadding = LongestString(buckets, b => formatDouble(b.Min));
- var maxPadding = LongestString(buckets, b => formatDouble(b.Max));
+ string FormatDouble(double d) => d.ToString("N2");
+
+ var minPadding = LongestString(buckets, b => FormatDouble(b.Min));
+ var maxPadding = LongestString(buckets, b => FormatDouble(b.Max));
var barScale = BarScale(buckets.Select(b => b.BarSize), scale);
foreach (var bucket in buckets)
{
- var min = formatDouble(bucket.Min).PadLeft(minPadding, ' ');
- var max = formatDouble(bucket.Max).PadLeft(maxPadding, ' ');
+ var min = FormatDouble(bucket.Min).PadLeft(minPadding, ' ');
+ var max = FormatDouble(bucket.Max).PadLeft(maxPadding, ' ');
sb.AppendLine($"{min} {max} {Bar((int) (bucket.BarSize * barScale))}");
}
sb.AppendLine("");
@@ -86,9 +87,7 @@ internal static string ShowBuckets(IEnumerable buckets, double scale)
/// Returns the size of the longest string representation of an item in a list
///
internal static int LongestString(IEnumerable list, Func toString)
- {
- return list.Select(x => toString(x).Length).Max();
- }
+ => list.Select(x => toString(x).Length).Max();
///
/// Calculate scale factor for a list of bar lengths
@@ -103,22 +102,13 @@ internal static double BarScale(IEnumerable barLengths, double scale)
/// Draw bar of width n
///
internal static string Bar(int n)
- {
- var barBuilder = new StringBuilder();
- for (var i = 0; i < n; i++)
- {
- barBuilder.Append("#");
- }
- return barBuilder.ToString();
- }
+ => new string('#', n);
///
/// Return sum of list with a given value function
///
internal static double Sum(Func getVal, IEnumerable list)
- {
- return list.Select(getVal).Sum();
- }
+ => list.Select(getVal).Sum();
///
/// Generates a weighted histogram from a list of ItemProbs
@@ -147,9 +137,7 @@ public static string Weighted(IEnumerable> nums, int numBuckets
///
/// Scale factor. Defaults to 40
public static string Weighted(Samples nums, int numBuckets = 10, double scale = DEFAULT_SCALE)
- {
- return Weighted(nums.Weights, numBuckets, scale);
- }
+ => Weighted(nums.Weights, numBuckets, scale);
///
/// Generates an unweigted histogram for a list of numbers
@@ -178,9 +166,7 @@ public static string Unweighted(IEnumerable nums, int numBuckets = 10, d
///
/// Scale factor. Defaults to 40
public static string Unweighted(IEnumerable nums, int numBuckets = 10, double scale = DEFAULT_SCALE)
- {
- return Unweighted(nums.Select(x => (double)x), numBuckets, scale);
- }
+ => Unweighted(nums.Select(x => (double)x), numBuckets, scale);
///
/// Generates a histogram from some samples, grouping by a show function.
diff --git a/ProbabilityMonad/Inference/Importance.cs b/ProbabilityMonad/Inference/Importance.cs
index 6833020..8fae155 100644
--- a/ProbabilityMonad/Inference/Importance.cs
+++ b/ProbabilityMonad/Inference/Importance.cs
@@ -10,9 +10,7 @@ public static class ImportanceExt
/// Importance sample with given number of samples
///
public static Dist> ImportanceSamples(this Dist dist, int numSamples)
- {
- return Importance.ImportanceSamples(numSamples, dist);
- }
+ => Importance.ImportanceSamples(numSamples, dist);
}
///
@@ -35,11 +33,9 @@ public static Dist> Resample(Samples samples)
/// Flattens nested samples
///
public static Samples Flatten(Samples> samples)
- {
- return Samples(from outer in Normalize(samples).Weights
- from inner in outer.Item.Weights
- select ItemProb(inner.Item, inner.Prob.Mult(outer.Prob)));
- }
+ => Samples(from outer in Normalize(samples).Weights
+ from inner in outer.Item.Weights
+ select ItemProb(inner.Item, inner.Prob.Mult(outer.Prob)));
///
/// Performs importance sampling on a distribution
@@ -57,11 +53,9 @@ public static Dist> ImportanceSamples(int numSamples, Dist dist
///
/// An importance sampled distribution
public static Dist ImportanceDist(int numSamples, Dist dist)
- {
- return from probs in ImportanceSamples(numSamples, dist)
- from resampled in Categorical(probs)
- select resampled;
- }
+ => from probs in ImportanceSamples(numSamples, dist)
+ from resampled in Categorical(probs)
+ select resampled;
///
/// Normalizes a list of samples
diff --git a/ProbabilityMonad/Inference/MetropolisHastings.cs b/ProbabilityMonad/Inference/MetropolisHastings.cs
index 2c552ac..47ae9e3 100644
--- a/ProbabilityMonad/Inference/MetropolisHastings.cs
+++ b/ProbabilityMonad/Inference/MetropolisHastings.cs
@@ -42,9 +42,6 @@ from accept in Primitive(BernoulliF(Prob(Math.Min(1.0, candidate.Prob.Div(chainS
/// Returns the value of the metropolis chain at specified index
///
public static Dist MHIndex(Dist dist, int n, int index)
- {
- return MHPrior(dist, n).Select(list => list.ElementAt(index));
- }
-
+ => MHPrior(dist, n).Select(list => list.ElementAt(index));
}
}
diff --git a/ProbabilityMonad/Inference/Pimh.cs b/ProbabilityMonad/Inference/Pimh.cs
index 29066c0..96b1d46 100644
--- a/ProbabilityMonad/Inference/Pimh.cs
+++ b/ProbabilityMonad/Inference/Pimh.cs
@@ -8,8 +8,6 @@ public static class Pimh
/// Particle indepedent Metropolis-Hastings
///
public static Dist>> Create(int numParticles, int chainLen, Dist dist)
- {
- return MetropolisHastings.MHPrior(dist.Run(new Smc(numParticles)), chainLen);
- }
+ => MetropolisHastings.MHPrior(dist.Run(new Smc(numParticles)), chainLen);
}
}
diff --git a/ProbabilityMonad/Inference/Prior.cs b/ProbabilityMonad/Inference/Prior.cs
index 80b0dd3..7402e0b 100644
--- a/ProbabilityMonad/Inference/Prior.cs
+++ b/ProbabilityMonad/Inference/Prior.cs
@@ -8,29 +8,19 @@ namespace ProbCSharp
public class Prior : DistInterpreter>
{
public DistInterpreter New()
- {
- return new Prior() as DistInterpreter;
- }
+ => new Prior() as DistInterpreter;
Dist DistInterpreter>.Bind(Dist dist, Func> bind)
- {
- return new Bind(dist.Run(new Prior()), bind);
- }
+ => new Bind(dist.Run(new Prior()), bind);
Dist DistInterpreter>.Conditional(Func lik, Dist dist)
- {
- return dist.Run(new Prior());
- }
+ => dist.Run(new Prior());
Dist DistInterpreter>.Primitive(PrimitiveDist dist)
- {
- return new Pure(dist.Sample());
- }
+ => new Pure(dist.Sample());
Dist DistInterpreter>.Pure(A value)
- {
- return new Pure(value);
- }
+ => new Pure(value);
}
diff --git a/ProbabilityMonad/Inference/PriorWeighted.cs b/ProbabilityMonad/Inference/PriorWeighted.cs
index 70cefd1..8ea366e 100644
--- a/ProbabilityMonad/Inference/PriorWeighted.cs
+++ b/ProbabilityMonad/Inference/PriorWeighted.cs
@@ -7,9 +7,7 @@ namespace ProbCSharp
public static class PriorWeightedExtensions
{
public static Dist> WeightedPrior(this Dist dist)
- {
- return dist.Run(new PriorWeighted());
- }
+ => dist.Run(new PriorWeighted());
}
///
@@ -20,32 +18,22 @@ public static Dist> WeightedPrior(this Dist dist)
public class PriorWeighted : DistInterpreter>>
{
public Dist> Bind(Dist dist, Func> bind)
- {
- return from x in dist.Run(new PriorWeighted())
- from y in bind(x.Item) // Don't remove conditionals here
- select new ItemProb(y, x.Prob);
- }
+ => from x in dist.Run(new PriorWeighted())
+ from y in bind(x.Item) // Don't remove conditionals here
+ select new ItemProb(y, x.Prob);
public Dist> Conditional(Func lik, Dist dist)
- {
- return from itemProb in dist.Run(new PriorWeighted())
- select new ItemProb(itemProb.Item, itemProb.Prob.Mult(lik(itemProb.Item)));
- }
+ => from itemProb in dist.Run(new PriorWeighted())
+ select new ItemProb(itemProb.Item, itemProb.Prob.Mult(lik(itemProb.Item)));
public DistInterpreter New()
- {
- return new PriorWeighted() as DistInterpreter;
- }
+ => new PriorWeighted() as DistInterpreter;
public Dist> Primitive(PrimitiveDist dist)
- {
- return new Primitive>(dist.Select(a => ItemProb(a, Prob(1))));
- }
+ => new Primitive>(dist.Select(a => ItemProb(a, Prob(1))));
public Dist> Pure(A value)
- {
- return new Pure>(new ItemProb(value, Prob(1)));
- }
+ => new Pure>(new ItemProb(value, Prob(1)));
}
}
diff --git a/ProbabilityMonad/Inference/Smc.cs b/ProbabilityMonad/Inference/Smc.cs
index 4b4a8cd..94f7d5c 100644
--- a/ProbabilityMonad/Inference/Smc.cs
+++ b/ProbabilityMonad/Inference/Smc.cs
@@ -12,17 +12,13 @@ public static class SmcExtensions
/// SMC that discards pseudo-marginal likelihoods
///
public static Dist> SmcStandard(this Dist dist, int n)
- {
- return dist.Run(new Smc(n)).Run(new Prior>());
- }
+ => dist.Run(new Smc(n)).Run(new Prior>());
///
/// SMC that does importance sampling
///
public static Dist> SmcMultiple(this Dist dist, int numSamples, int numParticles)
- {
- return dist.Run(new Smc(numParticles)).ImportanceSamples(numSamples).Select(Importance.Flatten);
- }
+ => dist.Run(new Smc(numParticles)).ImportanceSamples(numSamples).Select(Importance.Flatten);
}
///
@@ -32,17 +28,13 @@ public class Smc : DistInterpreter>>
{
public int numParticles;
public Smc(int numParticles)
- {
- this.numParticles = numParticles;
- }
+ => this.numParticles = numParticles;
public Dist> Bind(Dist dist, Func> bind)
- {
- return from ps in dist.Run(new Smc(numParticles))
- let unzipped = ps.Unzip()
- from ys in unzipped.Item1.Select(bind).Sequence()
- select Samples(ys.Zip(unzipped.Item2, ItemProb));
- }
+ => from ps in dist.Run(new Smc(numParticles))
+ let unzipped = ps.Unzip()
+ from ys in unzipped.Item1.Select(bind).Sequence()
+ select Samples(ys.Zip(unzipped.Item2, ItemProb));
public Dist> Conditional(Func lik, Dist dist)
{
@@ -58,9 +50,7 @@ select ps.Select(ip => ItemProb(ip.Item, lik(ip.Item).Mult(ip.Prob)))
}
public DistInterpreter New()
- {
- return new Smc(numParticles) as DistInterpreter;
- }
+ => new Smc(numParticles) as DistInterpreter;
public Dist> Primitive(PrimitiveDist dist)
{
diff --git a/ProbabilityMonad/ItemProb.cs b/ProbabilityMonad/ItemProb.cs
index df360dc..84750c4 100644
--- a/ProbabilityMonad/ItemProb.cs
+++ b/ProbabilityMonad/ItemProb.cs
@@ -7,15 +7,11 @@ public class ItemProb
{
public A Item { get; }
public Prob Prob { get; }
+
public ItemProb(A item, Prob prob)
- {
- Item = item;
- Prob = prob;
- }
+ => (Item, Prob) = (item, prob);
public override string ToString()
- {
- return $"ItemProb ({Item}, {Prob})";
- }
+ => $"ItemProb ({Item}, {Prob})";
}
}
diff --git a/ProbabilityMonad/KullbackLeibler.cs b/ProbabilityMonad/KullbackLeibler.cs
index e8f5118..c60818b 100644
--- a/ProbabilityMonad/KullbackLeibler.cs
+++ b/ProbabilityMonad/KullbackLeibler.cs
@@ -13,17 +13,20 @@ public static class KullbackLeibler
public static double KLDivergenceF(FiniteDist distQ, FiniteDist distP, Func keyFunc) where A : IComparable
{
var qWeights = Enumerate(distQ, keyFunc);
- Func qDensity = x =>
+
+ Prob QDensity(A x)
{
var weight = qWeights.FirstOrDefault(ip => keyFunc(ip.Item).Equals(keyFunc(x)));
- if (weight == null) return Prob(0);
+ if (weight == null)
+ return Prob(0);
+
return weight.Prob;
- };
+ }
var pWeights = Enumerate(distP, keyFunc);
var divergences = pWeights.Weights
- .Select(w => w.Prob.Value * Math.Log(w.Prob.Div(qDensity(w.Item)).Value));
+ .Select(w => w.Prob.Value * Math.Log(w.Prob.Div(QDensity(w.Item)).Value));
return divergences.Sum();
}
diff --git a/ProbabilityMonad/LogNormalCreate.cs b/ProbabilityMonad/LogNormalCreate.cs
deleted file mode 100644
index 97022a2..0000000
--- a/ProbabilityMonad/LogNormalCreate.cs
+++ /dev/null
@@ -1,12 +0,0 @@
-using MathNet.Numerics.Distributions;
-
-namespace ProbCSharp
-{
- class LogNormalCreate
- {
- public static LogNormal fromMeanVariance(double mean, double variance)
- {
- return LogNormal.WithMeanVariance(mean, variance);
- }
- }
-}
diff --git a/ProbabilityMonad/ParallelSampler.cs b/ProbabilityMonad/ParallelSampler.cs
index 30fe000..5e8b59e 100644
--- a/ProbabilityMonad/ParallelSampler.cs
+++ b/ProbabilityMonad/ParallelSampler.cs
@@ -14,9 +14,7 @@ public static class ParallelSampleExtensions
/// This will throw an exception if the distribution contains any conditionals.
///
public static A SampleParallel(this Dist dist)
- {
- return dist.RunParallel(new ParallelSampler());
- }
+ => dist.RunParallel(new ParallelSampler());
///
/// Draws n samples from a distribution in parallel
@@ -24,9 +22,7 @@ public static A SampleParallel(this Dist dist)
/// This will throw an exception if the distribution contains any conditionals.
///
public static IEnumerable SampleNParallel(this Dist dist, int n)
- {
- return Enumerable.Range(0, n).Select(_ => dist.SampleParallel());
- }
+ => Enumerable.Range(0, n).Select(_ => dist.SampleParallel());
}
///
@@ -41,24 +37,16 @@ public A Bind(Dist dist, Func> bind)
}
public A Conditional(Func lik, Dist dist)
- {
- throw new ArgumentException("Cannot sample from conditional distribution.");
- }
+ => throw new ArgumentException("Cannot sample from conditional distribution.");
public A Primitive(PrimitiveDist dist)
- {
- return dist.Sample();
- }
+ => dist.Sample();
public A Pure(A value)
- {
- return value;
- }
+ => value;
public A Independent(Dist independent)
- {
- return independent.RunParallel(new ParallelSampler());
- }
+ => independent.RunParallel(new ParallelSampler());
public A RunIndependent(Dist distB, Dist distC, Func> run)
{
diff --git a/ProbabilityMonad/Primitive/Beta.cs b/ProbabilityMonad/Primitive/Beta.cs
index 9536868..86a112b 100644
--- a/ProbabilityMonad/Primitive/Beta.cs
+++ b/ProbabilityMonad/Primitive/Beta.cs
@@ -10,6 +10,7 @@ public class BetaPrimitive : PrimitiveDist
public double alpha;
public double beta;
public MathNet.Numerics.Distributions.Beta dist;
+
public BetaPrimitive(double alpha, double beta, Random gen)
{
this.alpha = alpha;
@@ -18,8 +19,6 @@ public BetaPrimitive(double alpha, double beta, Random gen)
}
public Func Sample
- {
- get { return () => dist.Sample(); }
- }
+ => () => dist.Sample();
}
}
diff --git a/ProbabilityMonad/Primitive/Categorical.cs b/ProbabilityMonad/Primitive/Categorical.cs
index 6379373..24b39c6 100644
--- a/ProbabilityMonad/Primitive/Categorical.cs
+++ b/ProbabilityMonad/Primitive/Categorical.cs
@@ -7,36 +7,31 @@
namespace ProbCSharp
{
- ///
- /// Primitive Categorical distribution
- ///
- public class CategoricalPrimitive : PrimitiveDist
- {
- public double[] ProbabilityMass { get; }
- public T[] Items { get; }
- public Random Gen { get; }
- public Categorical categorical { get; }
- public Dictionary ItemIndex {get ;}
+ ///
+ /// Primitive Categorical distribution
+ ///
+ public class CategoricalPrimitive : PrimitiveDist
+ {
+ public double[] ProbabilityMass { get; }
+ public T[] Items { get; }
+ public Random Gen { get; }
+ public Categorical categorical { get; }
+ public Dictionary ItemIndex { get; }
- public CategoricalPrimitive(T[] items , double[] probabilities, Random gen)
- {
- Gen = gen;
- Items = items;
- ProbabilityMass = probabilities;
- categorical = new Categorical(ProbabilityMass, gen);
- ItemIndex = new Dictionary();
- for (int c = 0; c < items.Length; c++)
- {
- ItemIndex.Add(items[c],c);
- }
- }
+ public CategoricalPrimitive(T[] items, double[] probabilities, Random gen)
+ {
+ Gen = gen;
+ Items = items;
+ ProbabilityMass = probabilities;
+ categorical = new Categorical(ProbabilityMass, gen);
+ ItemIndex = new Dictionary();
+ for (int c = 0; c < items.Length; c++)
+ {
+ ItemIndex.Add(items[c], c);
+ }
+ }
- public Func Sample
- {
- get
- {
- return () => Items[categorical.Sample()];
- }
- }
- }
+ public Func Sample
+ => () => Items[categorical.Sample()];
+ }
}
diff --git a/ProbabilityMonad/Primitive/Dirichlet.cs b/ProbabilityMonad/Primitive/Dirichlet.cs
index f826bfa..cad475a 100644
--- a/ProbabilityMonad/Primitive/Dirichlet.cs
+++ b/ProbabilityMonad/Primitive/Dirichlet.cs
@@ -2,7 +2,6 @@
namespace ProbCSharp
{
-
///
/// Primitive Dirichlet distribution
///
@@ -11,16 +10,15 @@ public class DirichletPrimitive : PrimitiveDist
public double[] alpha;
public MathNet.Numerics.Distributions.Dirichlet dist;
+
public DirichletPrimitive(double[] alpha, Random gen)
{
this.alpha = alpha;
-
+
dist = new MathNet.Numerics.Distributions.Dirichlet(alpha, gen);
}
public Func Sample
- {
- get { return () => dist.Sample(); }
- }
+ => () => dist.Sample();
}
}
diff --git a/ProbabilityMonad/Primitive/Gamma.cs b/ProbabilityMonad/Primitive/Gamma.cs
index 16389dc..3d4c238 100644
--- a/ProbabilityMonad/Primitive/Gamma.cs
+++ b/ProbabilityMonad/Primitive/Gamma.cs
@@ -1,21 +1,24 @@
using System;
-namespace ProbCSharp {
+namespace ProbCSharp
+{
///
/// Primitive Gamma distribution
///
- public class GammaPrimitive : PrimitiveDist {
+ public class GammaPrimitive : PrimitiveDist
+ {
public double shape;
public double rate;
public MathNet.Numerics.Distributions.Gamma dist;
- public GammaPrimitive(double shape, double rate, Random gen) {
+
+ public GammaPrimitive(double shape, double rate, Random gen)
+ {
this.shape = shape;
this.rate = rate;
dist = new MathNet.Numerics.Distributions.Gamma(shape, rate);
}
- public Func Sample {
- get { return () => dist.Sample(); }
- }
+ public Func Sample
+ => () => dist.Sample();
}
}
diff --git a/ProbabilityMonad/Primitive/LogNormal.cs b/ProbabilityMonad/Primitive/LogNormal.cs
index 479e472..4885f31 100644
--- a/ProbabilityMonad/Primitive/LogNormal.cs
+++ b/ProbabilityMonad/Primitive/LogNormal.cs
@@ -1,36 +1,36 @@
using System;
using MathNet.Numerics.Distributions;
-namespace ProbCSharp {
- ///
- /// Primitive Lognormal distribution
- ///
- public class LogNormalPrimitive : PrimitiveDist
- {
- public double mean;
- public double variance;
- public double mu;
- public double sigma;
- public LogNormal dist;
- public LogNormalPrimitive(double mu, double sigma, bool dummy, Random gen)
- {
- dist = new LogNormal(mu, sigma);
- mu = dist.Mu;
- sigma = dist.Sigma;
- }
- public LogNormalPrimitive(double mean, double variance, Random gen)
- {
- this.mean = mean;
- this.variance = variance;
- dist = LogNormalCreate.fromMeanVariance(mean, variance); // MathNet.Numerics.Distributions.LogNormal(mu, sigma);
- mu = dist.Mu;
- sigma = dist.Sigma;
- }
+namespace ProbCSharp
+{
+ ///
+ /// Primitive Lognormal distribution
+ ///
+ public class LogNormalPrimitive : PrimitiveDist
+ {
+ public double mean;
+ public double variance;
+ public double mu;
+ public double sigma;
+ public LogNormal dist;
- public Func Sample
- {
- get { return () => dist.Sample(); }
- }
- }
+ public LogNormalPrimitive(double mu, double sigma, bool dummy, Random gen)
+ {
+ dist = new LogNormal(mu, sigma);
+ mu = dist.Mu;
+ sigma = dist.Sigma;
+ }
+
+ public LogNormalPrimitive(double mean, double variance, Random gen)
+ {
+ this.mean = mean;
+ this.variance = variance;
+ dist = LogNormal.WithMeanVariance(mean, variance); // MathNet.Numerics.Distributions.LogNormal(mu, sigma);
+ mu = dist.Mu;
+ sigma = dist.Sigma;
+ }
+
+ public Func Sample
+ => () => dist.Sample();
+ }
}
-
\ No newline at end of file
diff --git a/ProbabilityMonad/Primitive/MultiVariateNormal.cs b/ProbabilityMonad/Primitive/MultiVariateNormal.cs
index 1e5de12..461ffcb 100644
--- a/ProbabilityMonad/Primitive/MultiVariateNormal.cs
+++ b/ProbabilityMonad/Primitive/MultiVariateNormal.cs
@@ -5,39 +5,31 @@
namespace ProbCSharp
{
- ///
- /// Primitive Multivariate Normal distribution
- ///
- public class MultiVariateNormalPrimitive : PrimitiveDist
- {
- public double[] Mean { get; }
- public Matrix Covariance { get; }
- public Random Gen { get; }
- public MatrixNormal mvn { get; }
+ ///
+ /// Primitive Multivariate Normal distribution
+ ///
+ public class MultiVariateNormalPrimitive : PrimitiveDist
+ {
+ public double[] Mean { get; }
+ public Matrix Covariance { get; }
+ public Random Gen { get; }
+ public MatrixNormal mvn { get; }
- static public MatrixNormal CreateMultivariateNormal(double[] meanVector, Matrix cv)
- {
- return new MatrixNormal(DenseMatrix.OfRowArrays(meanVector),
- DenseMatrix.OfRowArrays(new double[][] { new double[] { 1.0 } }),
- cv);
- }
+ public static MatrixNormal CreateMultivariateNormal(double[] meanVector, Matrix cv)
+ => new MatrixNormal(
+ m: DenseMatrix.OfRowArrays(meanVector),
+ v: DenseMatrix.OfRowArrays(new double[][] { new double[] { 1.0 } }),
+ k: cv);
- public MultiVariateNormalPrimitive(double[] mean, Matrix covariance, Random gen)
- {
- Mean = mean;
- Covariance = covariance;
- Gen = gen;
- mvn = MultiVariateNormalPrimitive.CreateMultivariateNormal(mean, covariance);
- }
+ public MultiVariateNormalPrimitive(double[] mean, Matrix covariance, Random gen)
+ {
+ Mean = mean;
+ Covariance = covariance;
+ Gen = gen;
+ mvn = CreateMultivariateNormal(mean, covariance);
+ }
- public Func Sample
- {
- get
- {
- return () => mvn.Sample().Row(0).ToArray();
- }
- }
- }
+ public Func Sample
+ => () => mvn.Sample().Row(0).ToArray();
+ }
}
-
-
\ No newline at end of file
diff --git a/ProbabilityMonad/Primitive/Normal.cs b/ProbabilityMonad/Primitive/Normal.cs
index 0c35483..3a8274d 100644
--- a/ProbabilityMonad/Primitive/Normal.cs
+++ b/ProbabilityMonad/Primitive/Normal.cs
@@ -10,23 +10,18 @@ public class NormalPrimitive : PrimitiveDist
{
public double Mean { get; }
public double Variance { get; }
- public Random Gen {get;}
+ public Random Gen { get; }
public Normal normal { get; }
- public NormalPrimitive(double mean, double variance, Random gen)
- {
+ public NormalPrimitive(double mean, double variance, Random gen)
+ {
Mean = mean;
Variance = variance;
Gen = gen;
normal = Normal.WithMeanVariance(Mean, Variance, Gen);
- }
-
- public Func Sample
- {
- get
- {
- return () => normal.Sample();
- }
}
+
+ public Func Sample
+ => () => normal.Sample();
}
}
diff --git a/ProbabilityMonad/Primitive/Poisson.cs b/ProbabilityMonad/Primitive/Poisson.cs
index 5911bf7..6642aad 100644
--- a/ProbabilityMonad/Primitive/Poisson.cs
+++ b/ProbabilityMonad/Primitive/Poisson.cs
@@ -9,9 +9,10 @@ namespace ProbCSharp
///
public class PoissonPrimitive : PrimitiveDist
{
- public double Lambda;
+ public double Lambda;
public Poisson Dist;
public Random Gen;
+
public PoissonPrimitive(double lambda, Random gen)
{
Lambda = lambda;
@@ -20,11 +21,6 @@ public PoissonPrimitive(double lambda, Random gen)
}
public Func Sample
- {
- get
- {
- return () => Dist.Sample();
- }
- }
+ => () => Dist.Sample();
}
}
diff --git a/ProbabilityMonad/Primitive/PrimitiveDist.cs b/ProbabilityMonad/Primitive/PrimitiveDist.cs
index f688b98..340fc2c 100644
--- a/ProbabilityMonad/Primitive/PrimitiveDist.cs
+++ b/ProbabilityMonad/Primitive/PrimitiveDist.cs
@@ -19,15 +19,12 @@ public interface PrimitiveDist
public class SampleDist : PrimitiveDist
{
private readonly Func _sample;
+
public SampleDist(Func sample)
- {
- _sample = sample;
- }
+ => _sample = sample;
public Func Sample
- {
- get { return _sample; }
- }
+ => _sample;
}
@@ -35,23 +32,17 @@ public Func Sample
public static class PrimitiveDistMonad
{
public static PrimitiveDist Select(this PrimitiveDist self, Func f)
- {
- return new SampleDist(() => f(self.Sample()));
- }
+ => new SampleDist(() => f(self.Sample()));
public static PrimitiveDist SelectMany(
this PrimitiveDist self,
Func> bind,
- Func project
- )
- {
- return new SampleDist(() =>
+ Func project)
+ => new SampleDist(() =>
{
var firstSample = self.Sample();
var secondSample = bind(firstSample).Sample();
return project(firstSample, secondSample);
});
- }
}
-
}
diff --git a/ProbabilityMonad/Primitive/Uniform.cs b/ProbabilityMonad/Primitive/Uniform.cs
index 4090a93..c12c017 100644
--- a/ProbabilityMonad/Primitive/Uniform.cs
+++ b/ProbabilityMonad/Primitive/Uniform.cs
@@ -3,20 +3,17 @@
namespace ProbCSharp
{
- public class ContinuousUniformPrimitive : PrimitiveDist
- {
- public ContinuousUniform dist;
- public ContinuousUniformPrimitive(double lower, double upper, Random gen)
- {
- dist = new ContinuousUniform(lower,upper);
- }
- public ContinuousUniformPrimitive(Random gen)
- {
- dist = new ContinuousUniform();
- }
- public Func Sample
- {
- get { return () => dist.Sample(); }
- }
- }
+ public class ContinuousUniformPrimitive : PrimitiveDist
+ {
+ public ContinuousUniform dist;
+
+ public ContinuousUniformPrimitive(double lower, double upper, Random gen)
+ => dist = new ContinuousUniform(lower, upper);
+
+ public ContinuousUniformPrimitive(Random gen)
+ => dist = new ContinuousUniform();
+
+ public Func Sample
+ => () => dist.Sample();
+ }
}
diff --git a/ProbabilityMonad/Primitive/Wishart.cs b/ProbabilityMonad/Primitive/Wishart.cs
index 10d48d7..42ac744 100644
--- a/ProbabilityMonad/Primitive/Wishart.cs
+++ b/ProbabilityMonad/Primitive/Wishart.cs
@@ -4,30 +4,25 @@
namespace ProbCSharp
{
- ///
- /// Primitive Wishart distribution
- ///
- public class WishartPrimitive : PrimitiveDist>
- {
- public double DegreesOfFreedom { get; }
- public Matrix Scale { get; }
- public Random Gen { get; }
- public Wishart Wishart { get; }
-
- public WishartPrimitive(double dof, Matrix scale, Random gen)
+ ///
+ /// Primitive Wishart distribution
+ ///
+ public class WishartPrimitive : PrimitiveDist>
{
- DegreesOfFreedom = dof;
- Scale = scale;
- Gen = gen;
- Wishart = new Wishart(dof, scale);
- }
+ public double DegreesOfFreedom { get; }
+ public Matrix Scale { get; }
+ public Random Gen { get; }
+ public Wishart Wishart { get; }
- public Func> Sample
- {
- get
- {
- return () => Wishart.Sample();
- }
+ public WishartPrimitive(double dof, Matrix scale, Random gen)
+ {
+ DegreesOfFreedom = dof;
+ Scale = scale;
+ Gen = gen;
+ Wishart = new Wishart(dof, scale);
+ }
+
+ public Func> Sample
+ => () => Wishart.Sample();
}
- }
}
diff --git a/ProbabilityMonad/Prob.cs b/ProbabilityMonad/Prob.cs
index e8eea2d..47b5bee 100644
--- a/ProbabilityMonad/Prob.cs
+++ b/ProbabilityMonad/Prob.cs
@@ -21,17 +21,10 @@ public class DoubleProb : Prob
public double Value { get; }
public double LogValue
- {
- get
- {
- return Math.Log(Value);
- }
- }
+ => Math.Log(Value);
public DoubleProb(double probability)
- {
- Value = probability;
- }
+ => Value = probability;
public override string ToString()
{
@@ -40,24 +33,16 @@ public override string ToString()
}
public Prob Mult(Prob other)
- {
- return new DoubleProb(Value * other.Value);
- }
+ => new DoubleProb(Value * other.Value);
public Prob Div(Prob other)
- {
- return new DoubleProb(Value / other.Value);
- }
+ => new DoubleProb(Value / other.Value);
public int CompareTo(Prob other)
- {
- return Value.CompareTo(other.Value);
- }
+ => Value.CompareTo(other.Value);
public bool Equals(Prob other)
- {
- return Value.Equals(other.Value);
- }
+ => Value.Equals(other.Value);
}
///
@@ -75,45 +60,24 @@ public LogProb(double logProb)
}
public double Value
- {
- get
- {
- return Math.Exp(logProb);
- }
- }
+ => Math.Exp(logProb);
public double LogValue
- {
- get
- {
- return logProb;
- }
- }
+ => logProb;
public override string ToString()
- {
- var str = $"{Value*100:G3}%";
- return str;
- }
+ => $"{Value*100:G3}%";
public int CompareTo(Prob other)
- {
- return logProb.CompareTo(other.LogValue);
- }
+ => logProb.CompareTo(other.LogValue);
public Prob Div(Prob other)
- {
- return new LogProb(logProb - other.LogValue);
- }
+ => new LogProb(logProb - other.LogValue);
public bool Equals(Prob other)
- {
- return other.LogValue == logProb;
- }
+ => other.LogValue == logProb;
public Prob Mult(Prob other)
- {
- return new LogProb(logProb + other.LogValue);
- }
+ => new LogProb(logProb + other.LogValue);
}
}
diff --git a/ProbabilityMonad/Sampler.cs b/ProbabilityMonad/Sampler.cs
index 28f2e2c..1a76909 100644
--- a/ProbabilityMonad/Sampler.cs
+++ b/ProbabilityMonad/Sampler.cs
@@ -11,18 +11,14 @@ public static class SamplerExt
/// This will throw an exception if the distribution contains any conditionals.
///
public static A Sample(this Dist dist)
- {
- return dist.Run(new Sampler());
- }
+ => dist.Run(new Sampler());
///
/// Draws n samples from a distribution.
/// This will throw an exception if the distribution contains any conditionals.
///
public static IEnumerable SampleN(this Dist dist, int n)
- {
- return Enumerable.Range(0, n).Select(_ => dist.Sample());
- }
+ => Enumerable.Range(0, n).Select(_ => dist.Sample());
}
///
@@ -31,9 +27,7 @@ public static IEnumerable SampleN(this Dist dist, int n)
public class Sampler : DistInterpreter
{
public DistInterpreter New()
- {
- return new Sampler() as DistInterpreter;
- }
+ => new Sampler() as DistInterpreter;
A DistInterpreter.Bind(Dist dist, Func> bind)
{
@@ -45,19 +39,13 @@ A DistInterpreter.Bind(Dist dist, Func> bind)
/// All conditionals must be removed before sampling.
///
A DistInterpreter.Conditional(Func lik, Dist dist)
- {
- throw new ArgumentException("Cannot sample from conditional distribution.");
- }
+ => throw new ArgumentException("Cannot sample from conditional distribution.");
A DistInterpreter.Primitive(PrimitiveDist dist)
- {
- return dist.Sample();
- }
+ => dist.Sample();
A DistInterpreter.Pure(A value)
- {
- return value;
- }
+ => value;
}
}
diff --git a/ProbabilityMonad/Samples.cs b/ProbabilityMonad/Samples.cs
index 2e6e009..bc7c65e 100644
--- a/ProbabilityMonad/Samples.cs
+++ b/ProbabilityMonad/Samples.cs
@@ -13,35 +13,25 @@ namespace ProbCSharp
public static class SamplesExt
{
public static Samples Select(this Samples self, Func, ItemProb> f)
- {
- return Samples(self.Weights.Select(f));
- }
+ => Samples(self.Weights.Select(f));
public static Samples MapSample(this Samples self, Func f)
- {
- return Samples(self.Weights.Select(ip => ItemProb(f(ip.Item), ip.Prob)));
- }
+ => Samples(self.Weights.Select(ip => ItemProb(f(ip.Item), ip.Prob)));
///
/// Sum all probabilities in samples
///
public static Prob SumProbs(this Samples self)
- {
- return Prob(self.Weights.Select(ip => ip.Prob.Value).Sum());
- }
+ => Prob(self.Weights.Select(ip => ip.Prob.Value).Sum());
public static Samples Normalize(this Samples self)
- {
- return Importance.Normalize(self);
- }
+ => Importance.Normalize(self);
///
/// Unzip samples into items and weights
///
public static Tuple, IEnumerable> Unzip(this Samples samples)
- {
- return new Tuple, IEnumerable>(samples.Weights.Select(ip => ip.Item), samples.Weights.Select(ip => ip.Prob));
- }
+ => new Tuple, IEnumerable>(samples.Weights.Select(ip => ip.Item), samples.Weights.Select(ip => ip.Prob));
}
// This is a Wrapper class for IEnumerable>
@@ -53,18 +43,12 @@ public class Samples : IEnumerable>
{
public readonly IEnumerable> Weights;
public Samples(IEnumerable> list)
- {
- Weights = list;
- }
+ => Weights = list;
public IEnumerator> GetEnumerator()
- {
- return Weights.GetEnumerator();
- }
+ => Weights.GetEnumerator();
IEnumerator IEnumerable.GetEnumerator()
- {
- return Weights.GetEnumerator();
- }
+ => Weights.GetEnumerator();
}
}