Skip to content

Commit

Permalink
deprecate WeibullAdstock in favor of WeibullCDFAdstock and WeibullPDF…
Browse files Browse the repository at this point in the history
…Adstock (#957)

* deprecate in favor of WeibullCDFAdstock and WeibullPDFAdstock

* Update UML Diagrams

* Update UML Diagrams
  • Loading branch information
wd60622 authored Aug 22, 2024
1 parent 39d38b7 commit 255eac1
Show file tree
Hide file tree
Showing 5 changed files with 2 additions and 76 deletions.
Binary file modified docs/source/uml/classes_mmm.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
2 changes: 0 additions & 2 deletions pymc_marketing/mmm/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,6 @@
AdstockTransformation,
DelayedAdstock,
GeometricAdstock,
WeibullAdstock,
WeibullCDFAdstock,
WeibullPDFAdstock,
adstock_from_dict,
Expand Down Expand Up @@ -66,7 +65,6 @@
"TanhSaturationBaselined",
"saturation_from_dict",
"register_saturation_transformation",
"WeibullAdstock",
"WeibullCDFAdstock",
"WeibullPDFAdstock",
"adstock_from_dict",
Expand Down
4 changes: 2 additions & 2 deletions pymc_marketing/mmm/components/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@
from pymc_marketing.mmm import (
SaturationTransformation,
MMM,
WeibullAdstock,
WeibullPDFAdstock,
)
class InfiniteReturns(SaturationTransformation):
Expand All @@ -34,7 +34,7 @@ def function(self, x, b):
saturation = InfiniteReturns()
adstock = WeibullAdstock(l_max=15, kind="PDF")
adstock = WeibullPDFAdstock(l_max=15)
mmm = MMM(
...,
Expand Down
68 changes: 0 additions & 68 deletions pymc_marketing/mmm/components/adstock.py
Original file line number Diff line number Diff line change
Expand Up @@ -320,79 +320,11 @@ def function(self, x, lam, k):
}


class WeibullAdstock(AdstockTransformation):
"""Wrapper around weibull adstock function.
For more information, see :func:`pymc_marketing.mmm.transformers.weibull_adstock`.
.. plot::
:context: close-figs
import matplotlib.pyplot as plt
import numpy as np
from pymc_marketing.mmm import WeibullAdstock
rng = np.random.default_rng(0)
adstock = WeibullAdstock(l_max=10, kind="CDF")
prior = adstock.sample_prior(random_seed=rng)
curve = adstock.sample_curve(prior)
adstock.plot_curve(curve, sample_kwargs={"rng": rng})
plt.show()
"""

lookup_name = "weibull"

def __init__(
self,
l_max: int,
normalize: bool = True,
kind=WeibullType.PDF,
mode: ConvMode = ConvMode.After,
priors: dict | None = None,
prefix: str | None = None,
) -> None:
self.kind = kind

super().__init__(
l_max=l_max, normalize=normalize, mode=mode, priors=priors, prefix=prefix
)

msg = (
f"Use the Weibull{kind}Adstock class instead for better default priors. "
"This class will deprecate in 0.9.0."
)
warnings.warn(
msg,
UserWarning,
stacklevel=1,
)

def function(self, x, lam, k):
"""Weibull adstock function."""
return weibull_adstock(
x=x,
lam=lam,
k=k,
l_max=self.l_max,
mode=self.mode,
type=self.kind,
normalize=self.normalize,
)

default_priors = {
"lam": Prior("HalfNormal", sigma=1),
"k": Prior("HalfNormal", sigma=1),
}


ADSTOCK_TRANSFORMATIONS: dict[str, type[AdstockTransformation]] = {
cls.lookup_name: cls # type: ignore
for cls in [
GeometricAdstock,
DelayedAdstock,
WeibullAdstock,
WeibullPDFAdstock,
WeibullCDFAdstock,
]
Expand Down
4 changes: 0 additions & 4 deletions tests/mmm/components/test_adstock.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,6 @@
AdstockTransformation,
DelayedAdstock,
GeometricAdstock,
WeibullAdstock,
WeibullCDFAdstock,
WeibullPDFAdstock,
adstock_from_dict,
Expand All @@ -44,8 +43,6 @@ def adstocks() -> list[AdstockTransformation]:
return [
DelayedAdstock(l_max=10),
GeometricAdstock(l_max=10),
WeibullAdstock(l_max=10, kind="PDF"),
WeibullAdstock(l_max=10, kind="CDF"),
WeibullPDFAdstock(l_max=10),
WeibullCDFAdstock(l_max=10),
]
Expand Down Expand Up @@ -95,7 +92,6 @@ def test_default_prefix(adstock) -> None:
[
("delayed", DelayedAdstock, {"l_max": 10}),
("geometric", GeometricAdstock, {"l_max": 10}),
("weibull", WeibullAdstock, {"l_max": 10}),
],
)
def test_get_adstock_function(name, adstock_cls, kwargs):
Expand Down

0 comments on commit 255eac1

Please sign in to comment.