Skip to content

Commit

Permalink
Add sdr.sum_distributions()
Browse files Browse the repository at this point in the history
Fixes #423
  • Loading branch information
mhostetter committed Jul 28, 2024
1 parent b9c02b0 commit 062f740
Showing 1 changed file with 137 additions and 0 deletions.
137 changes: 137 additions & 0 deletions src/sdr/_probability.py
Original file line number Diff line number Diff line change
Expand Up @@ -75,3 +75,140 @@ def Qinv(p: npt.ArrayLike) -> npt.NDArray[np.float64]:
x = np.sqrt(2) * scipy.special.erfcinv(2 * p)

return convert_output(x)


@export
def sum_distributions(
    dist1: scipy.stats.rv_continuous | scipy.stats.rv_histogram,
    dist2: scipy.stats.rv_continuous | scipy.stats.rv_histogram,
    p: float = 1e-16,
) -> scipy.stats.rv_histogram:
    r"""
    Numerically calculates the distribution of the sum of two independent random variables.

    Arguments:
        dist1: The distribution of the first random variable $X$.
        dist2: The distribution of the second random variable $Y$.
        p: The probability of exceeding the x axis, on either side, for each distribution. This is used to determine
            the bounds on the x axis for the numerical convolution. Smaller values of $p$ will result in more accurate
            analysis, but will require more computation.

    Returns:
        The distribution of the sum $Z = X + Y$.

    Raises:
        ValueError: If the x axis range determined from `p` for either distribution is not finite or does not
            satisfy `x_min < x_max`. This happens, for example, when `p` is so large that the lower and upper
            bounds cross, or when a bound of the distribution is infinite at the requested `p`.

    Notes:
        The PDF of the sum of two independent random variables is the convolution of the PDF of the two distributions.

        $$f_{X+Y}(t) = (f_X * f_Y)(t)$$

        Each PDF is sampled on a uniform grid. The grid step is 1/1000 of the narrower of the two x axis ranges and
        is shared by both distributions, so the wider distribution is sampled with proportionally more points.

    Examples:
        Compute the distribution of the sum of two normal distributions.

        .. ipython:: python

            X = scipy.stats.norm(loc=-1, scale=0.5)
            Y = scipy.stats.norm(loc=2, scale=1.5)
            x = np.linspace(-5, 10, 1000)

            @savefig sdr_sum_distributions_1.png
            plt.figure(); \
            plt.plot(x, X.pdf(x), label="X"); \
            plt.plot(x, Y.pdf(x), label="Y"); \
            plt.plot(x, sdr.sum_distributions(X, Y).pdf(x), label="X + Y"); \
            plt.hist(X.rvs(100_000) + Y.rvs(100_000), bins=101, density=True, histtype="step", label="X + Y empirical"); \
            plt.legend(); \
            plt.xlabel("Random variable"); \
            plt.ylabel("Probability density"); \
            plt.title("Sum of two Normal distributions");

        Compute the distribution of the sum of two Rayleigh distributions.

        .. ipython:: python

            X = scipy.stats.rayleigh(scale=1)
            Y = scipy.stats.rayleigh(loc=1, scale=2)
            x = np.linspace(0, 12, 1000)

            @savefig sdr_sum_distributions_2.png
            plt.figure(); \
            plt.plot(x, X.pdf(x), label="X"); \
            plt.plot(x, Y.pdf(x), label="Y"); \
            plt.plot(x, sdr.sum_distributions(X, Y).pdf(x), label="X + Y"); \
            plt.hist(X.rvs(100_000) + Y.rvs(100_000), bins=101, density=True, histtype="step", label="X + Y empirical"); \
            plt.legend(); \
            plt.xlabel("Random variable"); \
            plt.ylabel("Probability density"); \
            plt.title("Sum of two Rayleigh distributions");

        Compute the distribution of the sum of two Rician distributions.

        .. ipython:: python

            X = scipy.stats.rice(2)
            Y = scipy.stats.rice(3)
            x = np.linspace(0, 12, 1000)

            @savefig sdr_sum_distributions_3.png
            plt.figure(); \
            plt.plot(x, X.pdf(x), label="X"); \
            plt.plot(x, Y.pdf(x), label="Y"); \
            plt.plot(x, sdr.sum_distributions(X, Y).pdf(x), label="X + Y"); \
            plt.hist(X.rvs(100_000) + Y.rvs(100_000), bins=101, density=True, histtype="step", label="X + Y empirical"); \
            plt.legend(); \
            plt.xlabel("Random variable"); \
            plt.ylabel("Probability density"); \
            plt.title("Sum of two Rician distributions");

    Group:
        probability
    """
    # Determine the x axis of each distribution such that the probability of exceeding the x axis, on either side,
    # is p.
    x1_min, x1_max = _x_range(dist1, p)
    x2_min, x2_max = _x_range(dist2, p)

    # Fail loudly on ranges the numerical convolution cannot use. A reversed range would give a negative step, which
    # silently produces a meaningless histogram; an infinite or zero-width range cannot be sampled at all.
    for name, x_min, x_max in (("dist1", x1_min, x1_max), ("dist2", x2_min, x2_max)):
        if not (np.isfinite(x_min) and np.isfinite(x_max) and x_min < x_max):
            raise ValueError(
                f"Argument 'p' = {p} gives an unusable x axis range [{x_min}, {x_max}] for '{name}'; "
                "the range must be finite with x_min < x_max."
            )

    dx1 = (x1_max - x1_min) / 1_000
    dx2 = (x2_max - x2_min) / 1_000
    dx = min(dx1, dx2)  # Use the smaller delta x -- must use the same dx for both distributions
    x1 = np.arange(x1_min, x1_max, dx)
    x2 = np.arange(x2_min, x2_max, dx)

    # Compute the PDF of each distribution
    y1 = dist1.pdf(x1)
    y2 = dist2.pdf(x2)

    # The PDF of the sum of two independent random variables is the convolution of the PDF of the two distributions.
    # Multiplying by dx turns the discrete sum into an approximation of the convolution integral.
    y = np.convolve(y1, y2, mode="full") * dx

    # Determine the x axis for the output convolution. The first output sample corresponds to the sum of the first
    # sample locations of the two inputs.
    x = np.arange(y.size) * dx + x1[0] + x2[0]

    # Adjust the histograms bins to be on either side of each point. So there is one extra point added.
    x = np.append(x, x[-1] + dx)
    x -= dx / 2

    return scipy.stats.rv_histogram((y, x))


def _x_range(dist: scipy.stats.rv_continuous, p: float) -> tuple[float, float]:
r"""
Determines the range of x values for a given distribution such that the probability of exceeding the x axis, on
either side, is p.
"""
# Need to have these loops because for very small p, sometimes SciPy will return NaN instead of a valid value.
# The while loops will increase the p value until a valid value is returned.

pp = p
while True:
x_min = dist.ppf(pp)
if not np.isnan(x_min):
break
pp *= 10

pp = p
while True:
x_max = dist.isf(pp)
if not np.isnan(x_max):
break
pp *= 10

return x_min, x_max

0 comments on commit 062f740

Please sign in to comment.