Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Analytic methods #62

Open
wants to merge 10 commits into
base: main
Choose a base branch
from
360 changes: 360 additions & 0 deletions squigglepy/distributions.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,8 @@
from collections.abc import Iterable

from abc import ABC, abstractmethod
from functools import reduce
from numbers import Real


class BaseDistribution(ABC):
Expand Down Expand Up @@ -172,6 +174,43 @@ def __rpow__(self, dist):
def __hash__(self):
    """Return a hash derived from this object's ``repr`` string."""
    # Two objects whose ``repr`` strings are equal therefore hash equally.
    return hash(repr(self))

def simplify(self):
    """Simplify a distribution by evaluating all operations that can be
    performed analytically, for example by reducing a sum of normal
    distributions into a single normal distribution. Return an
    ``OperableDistribution`` that represents the simplified distribution.

    This base implementation has nothing to simplify and returns ``self``
    unchanged. Distributions that represent compound expressions are
    expected to override it.

    Possible simplifications (clipped distributions are never combined):
        any - any = any + (-any)
        any / constant = any * (1 / constant)

        -Normal = Normal

        Normal + Normal = Normal
        constant + Normal = Normal
        Bernoulli + Bernoulli = Binomial        (same p)
        Bernoulli + Binomial = Binomial         (same p)
        Binomial + Binomial = Binomial          (same p)
        Chi-Square + Chi-Square = Chi-Square
        Exponential + Exponential = Gamma       (same scale)
        Exponential + Gamma = Gamma             (same scale)
        Gamma + Gamma = Gamma                   (same scale)
        Poisson + Poisson = Poisson

        constant * Normal = Normal
        Lognormal * Lognormal = Lognormal
        constant * Lognormal = Lognormal
        constant * Exponential = Exponential
        constant * Gamma = Gamma

        Lognormal / Lognormal = Lognormal
        constant / Lognormal = Lognormal

        Lognormal ** constant = Lognormal       (constant > 0)

    """
    return self


# Distributions are either discrete, continuous, or composite

Expand Down Expand Up @@ -242,6 +281,9 @@ def __str__(self):
raise ValueError
return out

def simplify(self):
    """Return an analytically simplified form of this distribution.

    The expression is first flattened into a ``FlatTree`` and the tree is
    then asked to simplify itself; the tree's result is returned as-is.
    """
    flat_tree = FlatTree.build(self)
    return flat_tree.simplify()


def _get_fname(f, name):
if name is None:
Expand Down Expand Up @@ -1699,3 +1741,321 @@ def geometric(p):
<Distribution> geometric(0.1)
"""
return GeometricDistribution(p=p)


class FlatTree:
    """Helper class for simplifying analytic expressions. A ``FlatTree`` is
    sort of like a ``ComplexDistribution`` except that it flattens
    commutative/associative operations onto a single object instead of having
    one object per binary operation.

    This class operates in two phases.

    Phase 1: Generate a ``FlatTree`` object from a :ref:``BaseDistribution`` by
    calling ``FlatTree.build(dist)``. This generates a tree where any series of
    a single commutative/associative operation done repeatedly is flattened
    onto a single ``FlatTree`` node. It also converts operations into a
    normalized form, for example converting ``a - b`` into ``a + (-b)``.

    Phase 2: Generate a simplified ``Distribution`` by calling
    :ref:``simplify``. This works by combing through each flat list of
    distributions to find which ones can be analytically simplified (for
    example, converting a sum of normal distributions into a single normal
    distribution). Note that ``simplify`` rewrites ``self.children`` in
    place, so a tree should be treated as consumed once it has been
    simplified.

    A node is either a leaf (``dist`` holds a distribution or a number) or an
    operation (``fn``/``fn_str`` applied to ``children``). The entries of
    ``children`` are distributions, numbers, or nested ``FlatTree`` nodes.

    """

    # Operations whose nested applications may be merged into one flat list.
    COMMUTABLE_OPERATIONS = {operator.add, operator.mul}

    def __init__(self, dist=None, fn=None, fn_str=None, children=None, is_unary=False, infix=None):
        """Create a leaf node (``dist`` given) or an operation node (``fn``
        and ``children`` given).

        Raises
        ------
        ValueError
            If neither ``dist`` nor both ``fn`` and ``children`` are given.
        """
        self.dist = dist
        self.fn = fn
        self.fn_str = fn_str
        self.children = children
        self.is_unary = is_unary
        self.infix = infix
        if dist is not None:
            self.is_leaf = True
        elif fn is not None and children is not None:
            self.is_leaf = False
        else:
            raise ValueError("Missing arguments to FlatTree constructor")

    def __str__(self):
        if self.is_leaf:
            return f"FlatTree({self.dist})"
        return "FlatTree({})[{}]".format(self.fn_str, ", ".join(map(str, self.children)))

    def __repr__(self):
        return str(self)

    @staticmethod
    def _is_clipped(x):
        """Return True if ``x`` carries a non-None ``lclip`` or ``rclip``."""
        # Plain numbers have neither attribute and are never clipped.
        return getattr(x, "lclip", None) is not None or getattr(x, "rclip", None) is not None

    @classmethod
    def build(cls, dist):
        """Convert ``dist`` into a ``FlatTree``.

        Parameters
        ----------
        dist : BaseDistribution, Real, or None
            The expression to convert.

        Returns
        -------
        FlatTree or None
            ``None`` when ``dist`` is ``None`` (the missing operand of a unary
            operation); a leaf for numbers, non-composite distributions and
            clipped composite distributions; otherwise an operation node.

        Raises
        ------
        ValueError
            If ``dist`` is neither a number nor a ``BaseDistribution``.
        ZeroDivisionError
            If the expression divides by the constant zero.
        """
        if dist is None:
            return None
        if isinstance(dist, Real):
            return cls(dist=dist)
        if not isinstance(dist, BaseDistribution):
            raise ValueError(f"dist must be a BaseDistribution or numeric type, not {type(dist)}")
        # A clipped composite cannot be rebuilt from its operands without
        # losing the clip, so it is kept as an opaque leaf.
        if not isinstance(dist, ComplexDistribution) or cls._is_clipped(dist):
            return cls(dist=dist)

        is_unary = dist.right is None

        # Convert x - y into x + (-y)
        if dist.fn == operator.sub:
            return cls.build(
                ComplexDistribution(
                    dist.left,
                    ComplexDistribution(dist.right, right=None, fn=operator.neg, fn_str="-"),
                    fn=operator.add,
                    fn_str="+",
                )
            )

        if dist.fn == operator.truediv:
            # If the denominator is a constant, replace division by constant
            # with multiplication by the reciprocal of the constant
            if isinstance(dist.right, Real):
                if dist.right == 0:
                    raise ZeroDivisionError(f"Division by zero in ComplexDistribution: {dist}")
                return cls.build(
                    ComplexDistribution(
                        dist.left,
                        1 / dist.right,
                        fn=operator.mul,
                        fn_str="*",
                    )
                )
            # The reciprocal of a lognormal is a lognormal with the negated
            # underlying mean. This only holds for an unclipped denominator:
            # rebuilding a clipped one would silently discard its clip.
            if isinstance(dist.right, LognormalDistribution) and not cls._is_clipped(dist.right):
                return cls.build(
                    ComplexDistribution(
                        dist.left,
                        LognormalDistribution(
                            norm_mean=-dist.right.norm_mean, norm_sd=dist.right.norm_sd
                        ),
                        fn=operator.mul,
                        fn_str="*",
                    )
                )

        subtrees = [cls.build(dist.left)]
        right_tree = cls.build(dist.right)
        if right_tree is not None:
            subtrees.append(right_tree)

        # If a child node uses the same commutable operation as ``dist``, add
        # its flattened ``children`` list to ``children``. Leaves contribute
        # their bare distribution/number. Anything else is kept as a subtree.
        flattenable = dist.fn in cls.COMMUTABLE_OPERATIONS
        children = []
        for subtree in subtrees:
            if subtree.is_leaf:
                children.append(subtree.dist)
            elif flattenable and subtree.fn == dist.fn:
                children.extend(subtree.children)
            else:
                children.append(subtree)

        return cls(
            fn=dist.fn, fn_str=dist.fn_str, children=children, is_unary=is_unary, infix=dist.infix
        )

    def _join_dists(self, left_type, right_type, join_fn, commutative=True, condition=None):
        """Fold every joinable child into a single accumulated value.

        Scans ``self.children`` in order. The first unclipped child of
        ``left_type`` (or, for commutative operations, of ``right_type``)
        becomes the accumulator. Each later unclipped child that forms a valid
        ``(left_type, right_type)`` pair with the *current* accumulator is
        merged into it with ``join_fn(left, right)``. All other children are
        kept in their original order, and the accumulator is put back at the
        position where it was first found. ``self.children`` is replaced.

        Parameters
        ----------
        left_type, right_type : type
            Operand types accepted by ``join_fn``, in that order.
        join_fn : callable
            ``join_fn(left, right)`` returns the combined value.
        commutative : bool
            If True the pair may also appear as (right, left) among the
            children; ``join_fn`` is still called as ``join_fn(left, right)``.
        condition : callable, optional
            ``condition(left, right)`` must be truthy for a join to happen.
        """

        def allowed(left, right):
            return condition is None or condition(left, right)

        remaining = []
        acc = None
        acc_index = None
        for i, x in enumerate(self.children):
            if self._is_clipped(x):
                # We can't simplify a clipped distribution
                remaining.append(x)
            elif acc is None:
                if isinstance(x, left_type) or (commutative and isinstance(x, right_type)):
                    acc = x
                    acc_index = i
                else:
                    remaining.append(x)
            # The accumulator's type is re-checked on every step: a join may
            # return a different type (e.g. Bernoulli + Bernoulli -> Binomial,
            # or Lognormal * 0 -> 0), and ``join_fn`` must never be handed an
            # operand of a type it was not written for.
            elif isinstance(acc, left_type) and isinstance(x, right_type) and allowed(acc, x):
                acc = join_fn(acc, x)
            elif (
                commutative
                and isinstance(acc, right_type)
                and isinstance(x, left_type)
                and allowed(x, acc)
            ):
                acc = join_fn(x, acc)
            else:
                remaining.append(x)

        if acc is not None:
            # Every child before ``acc_index`` was kept, so this restores the
            # accumulator to its original position.
            remaining.insert(acc_index, acc)
        self.children = remaining

    @classmethod
    def _lognormal_times_const(cls, norm_mean, norm_sd, k):
        """Return ``k * Lognormal(norm_mean, norm_sd)``.

        The result is the number 0 for ``k == 0``, a lognormal for ``k > 0``,
        and the negation of a lognormal for ``k < 0``.
        """
        if k == 0:
            return 0
        if k > 0:
            return LognormalDistribution(norm_mean=norm_mean + np.log(k), norm_sd=norm_sd)
        return -LognormalDistribution(norm_mean=norm_mean + np.log(-k), norm_sd=norm_sd)

    def _simplify_unary(self, child):
        """Simplify a unary operation applied to the already-simplified ``child``."""
        if self.fn == operator.neg:
            if isinstance(child, Real):
                return -child
            # Negating a clipped normal would need the clip bounds mirrored
            # too, so only the unclipped case is handled analytically.
            if isinstance(child, NormalDistribution) and not self._is_clipped(child):
                return NormalDistribution(mean=-child.mean, sd=child.sd)

        return ComplexDistribution(
            child, right=None, fn=self.fn, fn_str=self.fn_str, infix=self.infix
        )

    def _simplify_add(self):
        """Apply every analytic rule for sums to ``self.children``."""

        def same_p(x, y):
            return x.p == y.p

        def same_scale(x, y):
            return x.scale == y.scale

        self._join_dists(
            NormalDistribution,
            NormalDistribution,
            lambda x, y: NormalDistribution(mean=x.mean + y.mean, sd=np.sqrt(x.sd**2 + y.sd**2)),
        )
        self._join_dists(
            NormalDistribution, Real, lambda x, y: NormalDistribution(mean=x.mean + y, sd=x.sd)
        )
        # Only the first pair of Bernoullis is merged here; the resulting
        # Binomial absorbs any further Bernoullis in the next rule.
        self._join_dists(
            BernoulliDistribution,
            BernoulliDistribution,
            lambda x, y: BinomialDistribution(n=2, p=x.p),
            condition=same_p,
        )
        self._join_dists(
            BinomialDistribution,
            BernoulliDistribution,
            lambda x, y: BinomialDistribution(n=x.n + 1, p=x.p),
            condition=same_p,
        )
        self._join_dists(
            BinomialDistribution,
            BinomialDistribution,
            lambda x, y: BinomialDistribution(n=x.n + y.n, p=x.p),
            condition=same_p,
        )
        self._join_dists(
            ChiSquareDistribution,
            ChiSquareDistribution,
            lambda x, y: ChiSquareDistribution(df=x.df + y.df),
        )
        # As with Bernoulli: the first pair becomes a Gamma, which then
        # absorbs the remaining Exponentials in the next rule.
        self._join_dists(
            ExponentialDistribution,
            ExponentialDistribution,
            lambda x, y: GammaDistribution(shape=2, scale=x.scale),
            condition=same_scale,
        )
        self._join_dists(
            ExponentialDistribution,
            GammaDistribution,
            lambda x, y: GammaDistribution(shape=y.shape + 1, scale=x.scale),
            condition=same_scale,
        )
        self._join_dists(
            GammaDistribution,
            GammaDistribution,
            lambda x, y: GammaDistribution(shape=x.shape + y.shape, scale=x.scale),
            condition=same_scale,
        )
        self._join_dists(
            PoissonDistribution,
            PoissonDistribution,
            lambda x, y: PoissonDistribution(lam=x.lam + y.lam),
        )

    def _simplify_mul(self):
        """Apply every analytic rule for products to ``self.children``."""
        # A standard deviation must stay positive, hence abs(). Multiplying by
        # zero does not yield a normal distribution, so that case is skipped.
        self._join_dists(
            NormalDistribution,
            Real,
            lambda x, y: NormalDistribution(mean=x.mean * y, sd=x.sd * abs(y)),
            condition=lambda x, y: y != 0,
        )
        self._join_dists(
            LognormalDistribution,
            LognormalDistribution,
            lambda x, y: LognormalDistribution(
                norm_mean=x.norm_mean + y.norm_mean,
                norm_sd=np.sqrt(x.norm_sd**2 + y.norm_sd**2),
            ),
        )
        self._join_dists(
            LognormalDistribution,
            Real,
            lambda x, y: self._lognormal_times_const(x.norm_mean, x.norm_sd, y),
        )
        # Scaling by a non-positive constant would produce a non-positive
        # scale parameter, so only positive constants are folded in.
        self._join_dists(
            ExponentialDistribution,
            Real,
            lambda x, y: ExponentialDistribution(scale=x.scale * y),
            condition=lambda x, y: y > 0,
        )
        self._join_dists(
            GammaDistribution,
            Real,
            lambda x, y: GammaDistribution(shape=x.shape, scale=x.scale * y),
            condition=lambda x, y: y > 0,
        )

    def _simplify_truediv(self):
        """Apply every analytic rule for quotients to ``self.children``."""
        self._join_dists(
            LognormalDistribution,
            LognormalDistribution,
            lambda x, y: LognormalDistribution(
                norm_mean=x.norm_mean - y.norm_mean,
                norm_sd=np.sqrt(x.norm_sd**2 + y.norm_sd**2),
            ),
            commutative=False,
        )
        self._join_dists(
            Real,
            LognormalDistribution,
            lambda x, y: self._lognormal_times_const(-y.norm_mean, y.norm_sd, x),
            commutative=False,
        )

    def _simplify_pow(self):
        """Apply every analytic rule for powers to ``self.children``."""
        self._join_dists(
            LognormalDistribution,
            Real,
            lambda x, y: LognormalDistribution(norm_mean=x.norm_mean * y, norm_sd=x.norm_sd * y),
            commutative=False,
            condition=lambda x, y: y > 0,
        )

    def simplify(self):
        """Convert a FlatTree back into a Distribution, simplifying as much as
        possible.

        Returns the leaf value for a leaf node. Otherwise the children are
        simplified first, the rules for this node's operation are applied, and
        whatever children remain are chained back together with
        ``ComplexDistribution``. ``self.children`` is modified in place.
        """
        if self.is_leaf:
            return self.dist

        self.children = [
            child.simplify() if isinstance(child, FlatTree) else child for child in self.children
        ]

        # Simplify unary operations
        if len(self.children) == 1:
            return self._simplify_unary(self.children[0])

        if self.fn == operator.add:
            self._simplify_add()
        elif self.fn == operator.mul:
            self._simplify_mul()
        elif self.fn == operator.truediv:
            self._simplify_truediv()
        elif self.fn == operator.pow:
            self._simplify_pow()

        # Carry the node's ``infix`` setting over to the rebuilt expression,
        # as the unary path does. When it was never set, the
        # ``ComplexDistribution`` default is left in charge.
        extra = {} if self.infix is None else {"infix": self.infix}
        return reduce(
            lambda acc, x: ComplexDistribution(acc, x, fn=self.fn, fn_str=self.fn_str, **extra),
            self.children,
        )
Loading