Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add root saturation function (issue #702) #858

Merged
merged 21 commits into from
Jul 25, 2024
Merged
Show file tree
Hide file tree
Changes from 11 commits
Commits
Show all changes
21 commits
Select commit Hold shift + click to select a range
694489c
feat: adding root_saturation to transformers.py
ruariwalker Jul 23, 2024
9ef2044
feat: adding RootSaturation class to saturation.py
ruariwalker Jul 23, 2024
75c4f76
chore: adding missing RootSaturation to SATURATION_TRANSFORMATIONS
ruariwalker Jul 23, 2024
731d422
Merge branch 'main' into saturation_functions
wd60622 Jul 23, 2024
cd4efc2
feat: adding root_saturation to transformers.py
ruariwalker Jul 23, 2024
7faf104
feat: adding RootSaturation class to saturation.py
ruariwalker Jul 23, 2024
3d0526d
chore: adding missing RootSaturation to SATURATION_TRANSFORMATIONS
ruariwalker Jul 23, 2024
c7dc8a4
chore: linting edits
ruariwalker Jul 23, 2024
88ea0b0
chore: removing duplicate class names
ruariwalker Jul 23, 2024
24ac7af
chore: adding coefficient to function
ruariwalker Jul 23, 2024
0a2abb2
chore: linting corrections
ruariwalker Jul 23, 2024
fedc606
Merge branch 'main' into saturation_functions
wd60622 Jul 24, 2024
e7114a8
Merge branch 'main' into saturation_functions
wd60622 Jul 24, 2024
11b12a0
Merge branch 'main' into saturation_functions
wd60622 Jul 24, 2024
4ca571a
chore: removed empty References section of docstring
ruariwalker Jul 25, 2024
09c8fa5
chore: produce visual examples of root saturation
ruariwalker Jul 25, 2024
6b84be2
Merge branch 'main' into saturation_functions
wd60622 Jul 25, 2024
f078b59
chore: adding root to test_saturation.py
ruariwalker Jul 25, 2024
e3dd639
Merge branch 'main' into saturation_functions
wd60622 Jul 25, 2024
f09a867
chore: adding RootSaturation to init file
ruariwalker Jul 25, 2024
0ff48fe
Merge branch 'main' into saturation_functions
wd60622 Jul 25, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
36 changes: 36 additions & 0 deletions pymc_marketing/mmm/components/saturation.py
Original file line number Diff line number Diff line change
Expand Up @@ -79,6 +79,7 @@ def function(self, x, b):
inverse_scaled_logistic_saturation,
logistic_saturation,
michaelis_menten,
root_saturation,
tanh_saturation,
tanh_saturation_baselined,
)
Expand Down Expand Up @@ -369,6 +370,40 @@ class HillSaturation(SaturationTransformation):
}


class RootSaturation(SaturationTransformation):
"""Wrapper around Root saturation function.

For more information, see :func:`pymc_marketing.mmm.transformers.root_saturation`.

.. plot::
:context: close-figs

import matplotlib.pyplot as plt
import numpy as np
from pymc_marketing.mmm import RootSaturation

rng = np.random.default_rng(0)

saturation = RootSaturation()
prior = saturation.sample_prior(random_seed=rng)
curve = saturation.sample_curve(prior)
saturation.plot_curve(curve, sample_kwargs={"rng": rng})
plt.show()

"""

def function(self, x, alpha, beta):
return beta * root_saturation(x, alpha)

lookup_name = "root"

# REMINDER
default_priors = {
"alpha": Prior("Beta", alpha=1, beta=2),
"beta": Prior("Gamma", mu=1, sigma=1),
}


SATURATION_TRANSFORMATIONS: dict[str, type[SaturationTransformation]] = {
cls.lookup_name: cls
for cls in [
Expand All @@ -378,6 +413,7 @@ class HillSaturation(SaturationTransformation):
TanhSaturationBaselined,
MichaelisMentenSaturation,
HillSaturation,
RootSaturation,
]
}

Expand Down
49 changes: 49 additions & 0 deletions pymc_marketing/mmm/transformers.py
Original file line number Diff line number Diff line change
Expand Up @@ -985,3 +985,52 @@ def hill_saturation(
The value of the Hill function for each input value of x.
"""
return sigma / (1 + pt.exp(-beta * (x - lam)))


def root_saturation(
x: pt.TensorLike,
alpha: pt.TensorLike,
) -> pt.TensorVariable:
r"""Root saturation transformation.

.. math::
f(x) = x^{\alpha}

.. plot::
:context: close-figs

import matplotlib.pyplot as plt
import numpy as np
import arviz as az
from pymc_marketing.mmm.transformers import root_saturation
plt.style.use('arviz-darkgrid')
alpha_values = [0.25, 0.5, 0.75, 1, 1.25]
x = np.linspace(0, 5, 100)
ax = plt.subplot(111)
for alpha in alpha_values:
y = root_saturation(x, alpha=alpha).eval()
plt.plot(x, y, label=f'alpha = {alpha}')
plt.xlabel('spend', fontsize=12)
plt.ylabel('f(spend)', fontsize=12)
box = ax.get_position()
ax.set_position([box.x0, box.y0, box.width * 0.8, box.height])
ax.legend(loc='center left', bbox_to_anchor=(1, 0.5))
plt.show()

Parameters
----------
x : tensor
Input tensor.
alpha : float
Exponent for the root transformation. Must be non-negative.

Returns
-------
tensor
Transformed tensor.

References
----------

"""
return x**alpha
Loading