From b11372a3ab611acae4c7a7aef7af460c41afdd78 Mon Sep 17 00:00:00 2001 From: Pablo de Roque Date: Mon, 5 Aug 2024 18:19:03 +0200 Subject: [PATCH] Move adstock and saturation method imports to mmm.__all__ (#908) * Resolves #892: Move adstock and saturation method imports to mmm.__all__ * fix: patch uml GHA * Update .github/workflows/uml.yml * Update .github/workflows/uml.yml * Add permissions: write-all to .github/workflows/uml.yml --------- Co-authored-by: Will Dean <57733339+wd60622@users.noreply.github.com> --- .github/workflows/uml.yml | 7 ++++++- environment.yml | 3 +++ pymc_marketing/mmm/__init__.py | 8 ++++++++ tests/mmm/components/test_adstock.py | 8 +++++--- tests/mmm/components/test_saturation.py | 4 ++-- 5 files changed, 24 insertions(+), 6 deletions(-) diff --git a/.github/workflows/uml.yml b/.github/workflows/uml.yml index c581fff30..057c24974 100644 --- a/.github/workflows/uml.yml +++ b/.github/workflows/uml.yml @@ -10,6 +10,8 @@ on: paths: - "pymc_marketing/**" +permissions: write-all + jobs: build: runs-on: ubuntu-latest @@ -39,6 +41,9 @@ jobs: if git diff --staged --exit-code; then echo "No changes to commit" else + echo "Committing the changes" git commit -m "Update UML Diagrams" - git push + git push origin HEAD:${GITHUB_HEAD_REF} fi + env: + GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }} diff --git a/environment.yml b/environment.yml index 7a50b348f..6eb6e06b7 100644 --- a/environment.yml +++ b/environment.yml @@ -12,6 +12,8 @@ dependencies: - pandas - streamlit>=1.25.0 - pip +- pydantic +- preliz # NOTE: Keep minimum pymc version in sync with ci.yml `OLDEST_PYMC_VERSION` - pymc>=5.12.0,<5.16.0 - scikit-learn>=1.1.1 @@ -41,3 +43,4 @@ dependencies: - lifetimes==0.11.3 - pytest==7.0.1 - pytest-cov==3.0.0 +- pytest-mock diff --git a/pymc_marketing/mmm/__init__.py b/pymc_marketing/mmm/__init__.py index 4af4f39d2..1c452e70b 100644 --- a/pymc_marketing/mmm/__init__.py +++ b/pymc_marketing/mmm/__init__.py @@ -20,6 +20,8 @@ WeibullAdstock, WeibullCDFAdstock, WeibullPDFAdstock, + adstock_from_dict, + register_adstock_transformation, ) from pymc_marketing.mmm.components.saturation import ( HillSaturation, @@ -30,6 +32,8 @@ SaturationTransformation, TanhSaturation, TanhSaturationBaselined, + register_saturation_transformation, + saturation_from_dict, ) from pymc_marketing.mmm.delayed_saturated_mmm import MMM, DelayedSaturatedMMM from pymc_marketing.mmm.fourier import MonthlyFourier, YearlyFourier @@ -56,9 +60,13 @@ "SaturationTransformation", "TanhSaturation", "TanhSaturationBaselined", + "saturation_from_dict", + "register_saturation_transformation", "WeibullAdstock", "WeibullCDFAdstock", "WeibullPDFAdstock", + "adstock_from_dict", + "register_adstock_transformation", "YearlyFourier", "base", "delayed_saturated_mmm", diff --git a/tests/mmm/components/test_adstock.py b/tests/mmm/components/test_adstock.py index 7a91044a3..c932d7e5f 100644 --- a/tests/mmm/components/test_adstock.py +++ b/tests/mmm/components/test_adstock.py @@ -20,18 +20,20 @@ import xarray as xr from pydantic import ValidationError -from pymc_marketing.mmm.components.adstock import ( - ADSTOCK_TRANSFORMATIONS, +from pymc_marketing.mmm import ( AdstockTransformation, DelayedAdstock, GeometricAdstock, WeibullAdstock, WeibullCDFAdstock, WeibullPDFAdstock, - _get_adstock_function, adstock_from_dict, register_adstock_transformation, ) +from pymc_marketing.mmm.components.adstock import ( + ADSTOCK_TRANSFORMATIONS, + _get_adstock_function, +) from pymc_marketing.mmm.transformers import ConvMode from pymc_marketing.prior import Prior diff --git a/tests/mmm/components/test_saturation.py b/tests/mmm/components/test_saturation.py index f3c096775..43245f9b5 100644 --- a/tests/mmm/components/test_saturation.py +++ b/tests/mmm/components/test_saturation.py @@ -20,7 +20,7 @@ import xarray as xr from pydantic import ValidationError -from pymc_marketing.mmm.components.saturation import ( +from pymc_marketing.mmm import ( HillSaturation, InverseScaledLogisticSaturation, LogisticSaturation, @@ -28,9 +28,9 @@ RootSaturation, TanhSaturation, TanhSaturationBaselined, - _get_saturation_function, saturation_from_dict, ) +from pymc_marketing.mmm.components.saturation import _get_saturation_function from pymc_marketing.prior import Prior