Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add tests for adstock and saturation components #5

Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
18 changes: 12 additions & 6 deletions pymc_marketing/mmm/components/adstock.py
Original file line number Diff line number Diff line change
Expand Up @@ -121,8 +121,8 @@ class WeibullAdstock(AdstockTransformation):

def __init__(
self,
l_max: int = 10,
normalize: bool = False,
l_max: int,
normalize: bool = True,
kind=WeibullType.PDF,
mode: ConvMode = ConvMode.After,
priors: dict | None = None,
Expand Down Expand Up @@ -161,14 +161,20 @@ def _get_adstock_function(
function: str | AdstockTransformation,
**kwargs,
) -> AdstockTransformation:
"""Helper for use in the MMM to get an adstock function."""
if isinstance(function, AdstockTransformation):
return function

if function not in ADSTOCK_TRANSFORMATIONS:
raise ValueError(
f"Unknown adstock function: {function}. Choose from {list(ADSTOCK_TRANSFORMATIONS.keys())}"
)

if kwargs:
warnings.warn(
"The preferred method of initializing a lagging function is to use the class directly.",
DeprecationWarning,
stacklevel=1,
)

if isinstance(function, str):
return ADSTOCK_TRANSFORMATIONS[function](**kwargs)

return function
return ADSTOCK_TRANSFORMATIONS[function](**kwargs)
1 change: 1 addition & 0 deletions pymc_marketing/mmm/components/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -131,6 +131,7 @@ class MyTransformation(Transformation):
new_priors = {
parameter_name: priors[variable_name]
for parameter_name, variable_name in self.variable_mapping.items()
if variable_name in priors
}
if not new_priors:
available_priors = list(self.variable_mapping.values())
Expand Down
12 changes: 9 additions & 3 deletions pymc_marketing/mmm/components/saturation.py
Original file line number Diff line number Diff line change
Expand Up @@ -164,7 +164,13 @@ class HillSaturation(SaturationTransformation):
def _get_saturation_function(
function: str | SaturationTransformation,
) -> SaturationTransformation:
if isinstance(function, str):
return SATURATION_TRANSFORMATIONS[function]()
"""Helper for use in the MMM to get a saturation function."""
if isinstance(function, SaturationTransformation):
return function

return function
if function not in SATURATION_TRANSFORMATIONS:
raise ValueError(
f"Unknown saturation function: {function}. Choose from {list(SATURATION_TRANSFORMATIONS.keys())}"
)

return SATURATION_TRANSFORMATIONS[function]()
94 changes: 94 additions & 0 deletions tests/mmm/components/test_adstock.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,3 +11,97 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import numpy as np
import pymc as pm
import pytensor.tensor as pt
import pytest

from pymc_marketing.mmm.components.adstock import (
AdstockTransformation,
DelayedAdstock,
GeometricAdstock,
WeibullAdstock,
_get_adstock_function,
)


def adstocks() -> list[AdstockTransformation]:
return [
DelayedAdstock(l_max=10),
GeometricAdstock(l_max=10),
WeibullAdstock(l_max=10, kind="PDF"),
WeibullAdstock(l_max=10, kind="CDF"),
]


@pytest.fixture
def model() -> pm.Model:
coords = {"channel": ["a", "b", "c"]}
return pm.Model(coords=coords)


x = np.zeros(20)
x[0] = 1


@pytest.mark.parametrize(
"adstock",
adstocks(),
)
@pytest.mark.parametrize(
"x, dims",
[
(x, None),
(np.broadcast_to(x, (3, 20)).T, "channel"),
],
)
def test_apply(model, adstock, x, dims) -> None:
with model:
y = adstock.apply(x, dim_name=dims)

assert isinstance(y, pt.TensorVariable)
assert y.eval().shape == x.shape


@pytest.mark.parametrize(
"adstock",
adstocks(),
)
def test_default_prefix(adstock) -> None:
assert adstock.prefix == "adstock"
for value in adstock.variable_mapping.values():
assert value.startswith("adstock_")


@pytest.mark.parametrize(
"name, adstock_cls, kwargs",
[
("delayed", DelayedAdstock, {"l_max": 10}),
("geometric", GeometricAdstock, {"l_max": 10}),
("weibull", WeibullAdstock, {"l_max": 10}),
],
)
def test_get_adstock_function(name, adstock_cls, kwargs):
# Test for a warning
with pytest.warns(DeprecationWarning, match="The preferred method of initializing"):
adstock = _get_adstock_function(name, **kwargs)

assert isinstance(adstock, adstock_cls)


@pytest.mark.parametrize(
"adstock",
adstocks(),
)
def test_get_adstock_function_passthrough(adstock) -> None:
id_before = id(adstock)
id_after = id(_get_adstock_function(adstock))

assert id_after == id_before


def test_get_adstock_function_unknown():
with pytest.raises(
ValueError, match="Unknown adstock function: Unknown. Choose from"
):
_get_adstock_function(function="Unknown")
38 changes: 37 additions & 1 deletion tests/mmm/components/test_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import numpy as np
import pymc as pm
import pytest

from pymc_marketing.mmm.components.base import (
Expand Down Expand Up @@ -148,7 +149,7 @@ def test_new_transformation_function_priors(new_transformation) -> None:
}


def test_new_transformation_update_priors(new_transformation_class) -> None:
def test_new_transformation_priors_at_init(new_transformation_class) -> None:
new_prior = {"a": {"dist": "HalfNormal", "kwargs": {"sigma": 2}}}
new_transformation = new_transformation_class(priors=new_prior)
assert new_transformation.function_priors == {
Expand All @@ -161,6 +162,16 @@ def test_new_transformation_variable_mapping(new_transformation) -> None:
assert new_transformation.variable_mapping == {"a": "new_a", "b": "new_b"}


def test_apply(new_transformation):
x = np.array([1, 2, 3])
expected = np.array([6, 12, 18])
with pm.Model() as generative_model:
pm.Deterministic("y", new_transformation.apply(x, dim_name=None))

fixed_model = pm.do(generative_model, {"new_a": 2, "new_b": 3})
np.testing.assert_allclose(fixed_model["y"].eval(), expected)


def test_new_transformation_access_function(new_transformation) -> None:
x = np.array([1, 2, 3])
expected = np.array([6, 12, 18])
Expand All @@ -170,3 +181,28 @@ def test_new_transformation_access_function(new_transformation) -> None:
def test_new_transformation_apply_outside_model(new_transformation) -> None:
with pytest.raises(TypeError, match="on context stack"):
new_transformation.apply(1)


def test_model_config(new_transformation) -> None:
assert new_transformation.model_config == {
"new_a": {"dist": "HalfNormal", "kwargs": {"sigma": 1}},
"new_b": {"dist": "HalfNormal", "kwargs": {"sigma": 1}},
}


def test_new_transform_update_priors(new_transformation) -> None:
new_transformation.update_priors(
{"new_a": {"dist": "HalfNormal", "kwargs": {"sigma": 2}}}
)

assert new_transformation.function_priors == {
"a": {"dist": "HalfNormal", "kwargs": {"sigma": 2}},
"b": {"dist": "HalfNormal", "kwargs": {"sigma": 1}},
}


def test_new_transformation_warning_no_priors_updated(new_transformation) -> None:
with pytest.warns(UserWarning, match="No priors were updated"):
new_transformation.update_priors(
{"new_c": {"dist": "HalfNormal", "kwargs": {"sigma": 1}}}
)
108 changes: 108 additions & 0 deletions tests/mmm/components/test_saturation.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,3 +11,111 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from inspect import signature

import numpy as np
import pymc as pm
import pytensor.tensor as pt
import pytest

from pymc_marketing.mmm.components.saturation import (
HillSaturation,
LogisticSaturation,
MichaelisMentenSaturation,
TanhSaturation,
TanhSaturationBaselined,
_get_saturation_function,
)


@pytest.fixture
def model() -> pm.Model:
coords = {"channel": ["a", "b", "c"]}
return pm.Model(coords=coords)


def saturation_functions():
return [
LogisticSaturation(),
TanhSaturation(),
TanhSaturationBaselined(),
MichaelisMentenSaturation(),
HillSaturation(),
]


@pytest.mark.parametrize(
"saturation",
saturation_functions(),
)
@pytest.mark.parametrize(
"x, dims",
[
(np.linspace(0, 1, 100), None),
(np.ones((100, 3)), "channel"),
],
)
def test_apply_method(model, saturation, x, dims) -> None:
with model:
y = saturation.apply(x, dim_name=dims)

assert isinstance(y, pt.TensorVariable)
assert y.eval().shape == x.shape


@pytest.mark.parametrize(
"saturation",
saturation_functions(),
)
def test_default_prefix(saturation) -> None:
assert saturation.prefix == "saturation"
for value in saturation.variable_mapping.values():
assert value.startswith("saturation_")


@pytest.mark.parametrize(
"saturation",
saturation_functions(),
)
def test_support_for_lift_test_integrations(saturation) -> None:
function_parameters = signature(saturation.function).parameters

for key in saturation.variable_mapping.keys():
assert isinstance(key, str)
assert key in function_parameters

assert len(saturation.variable_mapping) == len(function_parameters) - 1


@pytest.mark.parametrize(
"name, saturation_cls",
[
("logistic", LogisticSaturation),
("tanh", TanhSaturation),
("tanh_baselined", TanhSaturationBaselined),
("michaelis_menten", MichaelisMentenSaturation),
("hill", HillSaturation),
],
)
def test_get_saturation_function(name, saturation_cls) -> None:
saturation = _get_saturation_function(name)

assert isinstance(saturation, saturation_cls)


@pytest.mark.parametrize(
"saturation",
saturation_functions(),
)
def test_get_saturation_function_passthrough(saturation) -> None:
id_before = id(saturation)
id_after = id(_get_saturation_function(saturation))

assert id_after == id_before


def test_get_saturation_function_unknown() -> None:
with pytest.raises(
ValueError, match="Unknown saturation function: unknown. Choose from"
):
_get_saturation_function("unknown")