From 8e21327ebe420fa1cd9e0b1c1c95fe42330a76c8 Mon Sep 17 00:00:00 2001 From: Ricardo Vieira Date: Tue, 15 Nov 2022 17:34:32 +0100 Subject: [PATCH 1/3] Update PyMC dependency --- requirements.txt | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/requirements.txt b/requirements.txt index 48d2355a..8d19789a 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1 +1 @@ -pymc>=4.3.0 +pymc>=4.4.0 From 1c4315b287e0efc0407eb5e9b0516a9a9713ebfe Mon Sep 17 00:00:00 2001 From: Ricardo Vieira Date: Tue, 22 Nov 2022 11:38:19 +0100 Subject: [PATCH 2/3] Add pytest slow mark --- conftest.py | 19 +++++++++++++++++++ 1 file changed, 19 insertions(+) create mode 100644 conftest.py diff --git a/conftest.py b/conftest.py new file mode 100644 index 00000000..3178a8fd --- /dev/null +++ b/conftest.py @@ -0,0 +1,19 @@ +import pytest + + +def pytest_addoption(parser): + parser.addoption("--runslow", action="store_true", default=False, help="run slow tests") + + +def pytest_configure(config): + config.addinivalue_line("markers", "slow: mark test as slow to run") + + +def pytest_collection_modifyitems(config, items): + if config.getoption("--runslow"): + # --runslow given in cli: do not skip slow tests + return + skip_slow = pytest.mark.skip(reason="need --runslow option to run") + for item in items: + if "slow" in item.keywords: + item.add_marker(skip_slow) From c0f622cb7cd2aca2d547d86ec0d58be58b4a4a9f Mon Sep 17 00:00:00 2001 From: Ricardo Vieira Date: Mon, 14 Nov 2022 18:57:24 +0100 Subject: [PATCH 3/3] Automatic logp marginalization of finite discrete variables --- .../marginalized_changepoint_model.ipynb | 821 ++++++++++++++++++ pymc_experimental/__init__.py | 1 + pymc_experimental/marginal_model.py | 477 ++++++++++ .../tests/test_marginal_model.py | 410 +++++++++ 4 files changed, 1709 insertions(+) create mode 100644 notebooks/marginalized_changepoint_model.ipynb create mode 100644 pymc_experimental/marginal_model.py create mode 100644 
pymc_experimental/tests/test_marginal_model.py diff --git a/notebooks/marginalized_changepoint_model.ipynb b/notebooks/marginalized_changepoint_model.ipynb new file mode 100644 index 00000000..e0bd9cb1 --- /dev/null +++ b/notebooks/marginalized_changepoint_model.ipynb @@ -0,0 +1,821 @@ +{ + "cells": [ + { + "cell_type": "code", + "execution_count": 1, + "metadata": {}, + "outputs": [], + "source": [ + "import pymc as pm\n", + "from pymc_experimental.marginal_model import MarginalModel\n", + "import pandas as pd\n", + "import numpy as np\n", + "import arviz as az" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## The original model" + ] + }, + { + "cell_type": "code", + "execution_count": 2, + "metadata": {}, + "outputs": [], + "source": [ + "# TODO: Try to handle np.nan\n", + "# fmt: off\n", + "disaster_data = pd.Series(\n", + " [4, 5, 4, 0, 1, 4, 3, 4, 0, 6, 3, 3, 4, 0, 2, 6,\n", + " 3, 3, 5, 4, 5, 3, 1, 4, 4, 1, 5, 5, 3, 4, 2, 5,\n", + " 2, 2, 3, 4, 2, 1, 3, np.nan, 2, 1, 1, 1, 1, 3, 0, 0,\n", + " 1, 0, 1, 1, 0, 0, 3, 1, 0, 3, 2, 2, 0, 1, 1, 1,\n", + " 0, 1, 0, 1, 0, 0, 0, 2, 1, 0, 0, 0, 1, 1, 0, 2,\n", + " 3, 3, 1, np.nan, 2, 1, 1, 1, 1, 2, 4, 2, 0, 0, 1, 4,\n", + " 0, 0, 0, 1, 0, 0, 0, 0, 0, 1, 0, 0, 1, 0, 1]\n", + ")\n", + "# fmt: on\n", + "\n", + "years = np.arange(1851, 1962)" + ] + }, + { + "cell_type": "code", + "execution_count": 3, + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "/home/ricardo/Documents/Projects/pymc/pymc/model.py:1378: ImputationWarning: Data in disasters contains missing values and will be automatically imputed from the sampling distribution.\n", + " warnings.warn(impute_message, ImputationWarning)\n" + ] + } + ], + "source": [ + "with MarginalModel() as disaster_model:\n", + " switchpoint = pm.DiscreteUniform(\"switchpoint\", lower=years.min(), upper=years.max())\n", + "\n", + " early_rate = pm.Exponential(\"early_rate\", 1.0)\n", + " late_rate = 
pm.Exponential(\"late_rate\", 1.0)\n", + " rate = pm.math.switch(switchpoint >= years, early_rate, late_rate)\n", + "\n", + " disasters = pm.Poisson(\"disasters\", rate, observed=disaster_data)" + ] + }, + { + "cell_type": "code", + "execution_count": 4, + "metadata": {}, + "outputs": [ + { + "data": { + "image/svg+xml": [ + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "cluster2\n", + "\n", + "2\n", + "\n", + "\n", + "cluster109\n", + "\n", + "109\n", + "\n", + "\n", + "cluster111\n", + "\n", + "111\n", + "\n", + "\n", + "\n", + "switchpoint\n", + "\n", + "switchpoint\n", + "~\n", + "DiscreteUniform\n", + "\n", + "\n", + "\n", + "disasters_missing\n", + "\n", + "disasters_missing\n", + "~\n", + "Poisson\n", + "\n", + "\n", + "\n", + "switchpoint->disasters_missing\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "disasters_observed\n", + "\n", + "disasters_observed\n", + "~\n", + "Poisson\n", + "\n", + "\n", + "\n", + "switchpoint->disasters_observed\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "early_rate\n", + "\n", + "early_rate\n", + "~\n", + "Exponential\n", + "\n", + "\n", + "\n", + "early_rate->disasters_missing\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "early_rate->disasters_observed\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "late_rate\n", + "\n", + "late_rate\n", + "~\n", + "Exponential\n", + "\n", + "\n", + "\n", + "late_rate->disasters_missing\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "late_rate->disasters_observed\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "disasters\n", + "\n", + "disasters\n", + "~\n", + "Deterministic\n", + "\n", + "\n", + "\n", + "disasters_missing->disasters\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "disasters_observed->disasters\n", + "\n", + "\n", + "\n", + "\n", + "\n" + ], + "text/plain": [ + "" + ] + }, + "execution_count": 4, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "pm.model_to_graphviz(disaster_model)" + ] + }, + { + "cell_type": "code", + 
"execution_count": 5, + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "Multiprocess sampling (4 chains in 4 jobs)\n", + "CompoundStep\n", + ">CompoundStep\n", + ">>Metropolis: [switchpoint]\n", + ">>Metropolis: [disasters_missing]\n", + ">NUTS: [early_rate, late_rate]\n" + ] + }, + { + "data": { + "text/html": [ + "\n", + "\n" + ], + "text/plain": [ + "" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "\n", + "
\n", + " \n", + " 100.00% [8000/8000 00:05<00:00 Sampling 4 chains, 0 divergences]\n", + "
\n", + " " + ], + "text/plain": [ + "" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "Sampling 4 chains for 1_000 tune and 1_000 draw iterations (4_000 + 4_000 draws total) took 6 seconds.\n" + ] + } + ], + "source": [ + "with disaster_model:\n", + " trace = pm.sample()" + ] + }, + { + "cell_type": "code", + "execution_count": 6, + "metadata": {}, + "outputs": [ + { + "data": { + "image/png": "\n", + "text/plain": [ + "
" + ] + }, + "metadata": {}, + "output_type": "display_data" + } + ], + "source": [ + "az.plot_posterior(trace, var_names=[\"~switchpoint\", \"~disasters\"]);" + ] + }, + { + "cell_type": "code", + "execution_count": 7, + "metadata": {}, + "outputs": [ + { + "data": { + "text/html": [ + "
\n", + "\n", + "\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + "
meansdhdi_3%hdi_97%mcse_meanmcse_sdess_bulkess_tailr_hat
disasters_missing[0]2.2031.8690.0005.0000.0950.067373.0417.01.01
disasters_missing[1]0.9130.9670.0003.0000.0370.027723.0932.01.00
early_rate3.0810.2892.5513.6460.0060.0042427.02785.01.00
late_rate0.9300.1160.7131.1490.0020.0022725.02749.01.00
\n", + "
" + ], + "text/plain": [ + " mean sd hdi_3% hdi_97% mcse_mean mcse_sd \\\n", + "disasters_missing[0] 2.203 1.869 0.000 5.000 0.095 0.067 \n", + "disasters_missing[1] 0.913 0.967 0.000 3.000 0.037 0.027 \n", + "early_rate 3.081 0.289 2.551 3.646 0.006 0.004 \n", + "late_rate 0.930 0.116 0.713 1.149 0.002 0.002 \n", + "\n", + " ess_bulk ess_tail r_hat \n", + "disasters_missing[0] 373.0 417.0 1.01 \n", + "disasters_missing[1] 723.0 932.0 1.00 \n", + "early_rate 2427.0 2785.0 1.00 \n", + "late_rate 2725.0 2749.0 1.00 " + ] + }, + "execution_count": 7, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "az.summary(trace, var_names=[\"~switchpoint\", \"~disasters\"])" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Marginalized model" + ] + }, + { + "cell_type": "code", + "execution_count": 8, + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "/home/ricardo/Documents/Projects/pymc-experimental/pymc_experimental/marginal_model.py:117: UserWarning: There are multiple dependent variables in a FiniteDiscreteMarginalRV.\n", + "Their joint logp terms will be assigned to the first RV: disasters_missing\n", + " warnings.warn(\n" + ] + } + ], + "source": [ + "disaster_model.marginalize([switchpoint])" + ] + }, + { + "cell_type": "code", + "execution_count": 9, + "metadata": {}, + "outputs": [ + { + "data": { + "image/svg+xml": [ + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "cluster2\n", + "\n", + "2\n", + "\n", + "\n", + "cluster109\n", + "\n", + "109\n", + "\n", + "\n", + "cluster111\n", + "\n", + "111\n", + "\n", + "\n", + "\n", + "early_rate\n", + "\n", + "early_rate\n", + "~\n", + "Exponential\n", + "\n", + "\n", + "\n", + "disasters_missing\n", + "\n", + "disasters_missing\n", + "~\n", + "Poisson\n", + "\n", + "\n", + "\n", + "early_rate->disasters_missing\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "disasters_observed\n", + "\n", + 
"disasters_observed\n", + "~\n", + "Poisson\n", + "\n", + "\n", + "\n", + "early_rate->disasters_observed\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "late_rate\n", + "\n", + "late_rate\n", + "~\n", + "Exponential\n", + "\n", + "\n", + "\n", + "late_rate->disasters_missing\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "late_rate->disasters_observed\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "disasters\n", + "\n", + "disasters\n", + "~\n", + "Deterministic\n", + "\n", + "\n", + "\n", + "disasters_missing->disasters\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "disasters_observed->disasters\n", + "\n", + "\n", + "\n", + "\n", + "\n" + ], + "text/plain": [ + "" + ] + }, + "execution_count": 9, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "pm.model_to_graphviz(disaster_model)" + ] + }, + { + "cell_type": "code", + "execution_count": 10, + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "Multiprocess sampling (4 chains in 4 jobs)\n", + "CompoundStep\n", + ">NUTS: [early_rate, late_rate]\n", + ">Metropolis: [disasters_missing]\n" + ] + }, + { + "data": { + "text/html": [ + "\n", + "\n" + ], + "text/plain": [ + "" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "\n", + "
\n", + " \n", + " 100.00% [8000/8000 01:16<00:00 Sampling 4 chains, 0 divergences]\n", + "
\n", + " " + ], + "text/plain": [ + "" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "Sampling 4 chains for 1_000 tune and 1_000 draw iterations (4_000 + 4_000 draws total) took 77 seconds.\n" + ] + } + ], + "source": [ + "with disaster_model:\n", + " trace = pm.sample()" + ] + }, + { + "cell_type": "code", + "execution_count": 11, + "metadata": {}, + "outputs": [ + { + "data": { + "image/png": "\n", + "text/plain": [ + "
" + ] + }, + "metadata": {}, + "output_type": "display_data" + } + ], + "source": [ + "az.plot_posterior(trace, var_names=\"~disasters\");" + ] + }, + { + "cell_type": "code", + "execution_count": 12, + "metadata": {}, + "outputs": [ + { + "data": { + "text/html": [ + "
\n", + "\n", + "\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + "
meansdhdi_3%hdi_97%mcse_meanmcse_sdess_bulkess_tailr_hat
disasters_missing[0]2.2721.9060.0006.0000.0790.060618.0685.01.01
disasters_missing[1]0.9901.0450.0003.0000.0400.030712.0652.01.01
early_rate3.0850.2832.5383.5920.0050.0033770.03235.01.00
late_rate0.9290.1180.7111.1470.0020.0013214.02258.01.00
\n", + "
" + ], + "text/plain": [ + " mean sd hdi_3% hdi_97% mcse_mean mcse_sd \\\n", + "disasters_missing[0] 2.272 1.906 0.000 6.000 0.079 0.060 \n", + "disasters_missing[1] 0.990 1.045 0.000 3.000 0.040 0.030 \n", + "early_rate 3.085 0.283 2.538 3.592 0.005 0.003 \n", + "late_rate 0.929 0.118 0.711 1.147 0.002 0.001 \n", + "\n", + " ess_bulk ess_tail r_hat \n", + "disasters_missing[0] 618.0 685.0 1.01 \n", + "disasters_missing[1] 712.0 652.0 1.01 \n", + "early_rate 3770.0 3235.0 1.00 \n", + "late_rate 3214.0 2258.0 1.00 " + ] + }, + "execution_count": 12, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "az.summary(trace, var_names=\"~disasters\")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [] + } + ], + "metadata": { + "hide_input": false, + "kernelspec": { + "display_name": "pymcx", + "language": "python", + "name": "pymcx" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.10.6" + }, + "toc": { + "base_numbering": 1, + "nav_menu": {}, + "number_sections": true, + "sideBar": true, + "skip_h1_title": false, + "title_cell": "Table of Contents", + "title_sidebar": "Contents", + "toc_cell": false, + "toc_position": {}, + "toc_section_display": true, + "toc_window_display": false + } + }, + "nbformat": 4, + "nbformat_minor": 4 +} diff --git a/pymc_experimental/__init__.py b/pymc_experimental/__init__.py index 838149d4..ed021a51 100644 --- a/pymc_experimental/__init__.py +++ b/pymc_experimental/__init__.py @@ -28,3 +28,4 @@ from pymc_experimental import distributions, gp, utils from pymc_experimental.inference.fit import fit +from pymc_experimental.marginal_model import MarginalModel diff --git a/pymc_experimental/marginal_model.py b/pymc_experimental/marginal_model.py new file mode 
100644 index 00000000..0d94dcc9 --- /dev/null +++ b/pymc_experimental/marginal_model.py @@ -0,0 +1,477 @@ +import warnings +from typing import Sequence, Tuple, Union + +import aesara.tensor as at +import numpy as np +from aeppl import factorized_joint_logprob +from aeppl.abstract import _get_measurable_outputs +from aeppl.logprob import _logprob +from aesara import Mode +from aesara.compile import SharedVariable +from aesara.compile.builders import OpFromGraph +from aesara.graph import Constant, FunctionGraph, ancestors, clone_replace +from aesara.scan import map as scan_map +from aesara.tensor import TensorVariable +from aesara.tensor.elemwise import Elemwise +from pymc import SymbolicRandomVariable +from pymc.aesaraf import constant_fold, inputvars +from pymc.distributions.discrete import Bernoulli, Categorical, DiscreteUniform +from pymc.model import Model + +__all__ = ["MarginalModel"] + + +class MarginalModel(Model): + """Subclass of PyMC Model that implements functionality for automatic + marginalization of variables in the logp transformation + + After defining the full Model, the `marginalize` method can be used to indicate a + subset of variables that should be marginalized + + Notes + ----- + Marginalization functionality is still very restricted. Only finite discrete + variables can be marginalized. Deterministics and Potentials cannot be conditionally + dependent on the marginalized variables. + + Furthermore, not all instances of such variables can be marginalized. If a variable + has batched dimensions, it is required that any conditionally dependent variables + use information from an individual batched dimension. In other words, the graph + connecting the marginalized variable(s) to the dependent variable(s) must be + composed strictly of Elemwise Operations. This is necessary to ensure an efficient + logprob graph can be generated. 
If you want to bypass this restriction you can + separate each dimension of the marginalized variable into the scalar components + and then stack them together. Note that such graphs will grow exponentially in the + number of marginalized variables. + + For the same reason, it's not possible to marginalize RVs with multivariate + dependent RVs. + + Examples + -------- + + Marginalize over a single variable + + .. code-block:: python + import pymc as pm + from pymc_experimental import MarginalModel + + with MarginalModel() as m: + p = pm.Beta("p", 1, 1) + x = pm.Bernoulli("x", p=p, shape=(3,)) + y = pm.Normal("y", pm.math.switch(x, -10, 10), observed=[10, 10, -10]) + + m.marginalize([x]) + + idata = pm.sample() + + """ + + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + self.marginalized_rvs = [] + + def _delete_rv_mappings(self, rv: TensorVariable) -> None: + """Remove all model mappings referring to rv + + This can be used to "delete" an RV from a model + """ + assert rv in self.basic_RVs, "rv is not part of the Model" + + name = rv.name + self.named_vars.pop(name) + if name in self.named_vars_to_dims: + self.named_vars_to_dims.pop(name) + + value = self.rvs_to_values.pop(rv) + self.values_to_rvs.pop(value) + + self.rvs_to_transforms.pop(rv) + if rv in self.free_RVs: + self.free_RVs.remove(rv) + self.rvs_to_initial_values.pop(rv) + else: + self.observed_RVs.remove(rv) + if rv in self.rvs_to_total_sizes: + self.rvs_to_total_sizes.pop(rv) + + def _transfer_rv_mappings(self, old_rv: TensorVariable, new_rv: TensorVariable) -> None: + """Transfer model mappings from old_rv to new_rv""" + + assert old_rv in self.basic_RVs, "old_rv is not part of the Model" + assert new_rv not in self.basic_RVs, "new_rv is already part of the Model" + + self.named_vars.pop(old_rv.name) + new_rv.name = old_rv.name + self.named_vars[new_rv.name] = new_rv + if old_rv in self.named_vars_to_dims: + self._RV_dims[new_rv] = self._RV_dims.pop(old_rv) + + value = 
self.rvs_to_values.pop(old_rv) + self.rvs_to_values[new_rv] = value + self.values_to_rvs[value] = new_rv + + self.rvs_to_transforms[new_rv] = self.rvs_to_transforms.pop(old_rv) + if old_rv in self.free_RVs: + index = self.free_RVs.index(old_rv) + self.free_RVs.pop(index) + self.free_RVs.insert(index, new_rv) + self.rvs_to_initial_values[new_rv] = self.rvs_to_initial_values.pop(old_rv) + elif old_rv in self.observed_RVs: + index = self.observed_RVs.index(old_rv) + self.observed_RVs.pop(index) + self.observed_RVs.insert(index, new_rv) + if old_rv in self.rvs_to_total_sizes: + self.rvs_to_total_sizes[new_rv] = self.rvs_to_total_sizes.pop(old_rv) + + def _marginalize(self, user_warnings=False): + fg = FunctionGraph(outputs=self.basic_RVs + self.marginalized_rvs, clone=False) + + toposort = fg.toposort() + rvs_left_to_marginalize = self.marginalized_rvs + for rv_to_marginalize in sorted( + self.marginalized_rvs, + key=lambda rv: toposort.index(rv.owner), + reverse=True, + ): + # Check that no deterministics or potentials dependend on the rv to marginalize + for det in self.deterministics: + if is_conditional_dependent( + det, rv_to_marginalize, self.basic_RVs + rvs_left_to_marginalize + ): + raise NotImplementedError( + f"Cannot marginalize {rv_to_marginalize} due to dependent Deterministic {det}" + ) + for pot in self.potentials: + if is_conditional_dependent( + pot, rv_to_marginalize, self.basic_RVs + rvs_left_to_marginalize + ): + raise NotImplementedError( + f"Cannot marginalize {rv_to_marginalize} due to dependent Potential {pot}" + ) + + old_rvs, new_rvs = replace_finite_discrete_marginal_subgraph( + fg, rv_to_marginalize, self.basic_RVs + rvs_left_to_marginalize + ) + + if user_warnings and len(new_rvs) > 2: + warnings.warn( + "There are multiple dependent variables in a FiniteDiscreteMarginalRV. 
" + f"Their joint logp terms will be assigned to the first RV: {old_rvs[1]}", + UserWarning, + ) + + rvs_left_to_marginalize.remove(rv_to_marginalize) + + for old_rv, new_rv in zip(old_rvs, new_rvs): + new_rv.name = old_rv.name + if old_rv in self.marginalized_rvs: + idx = self.marginalized_rvs.index(old_rv) + self.marginalized_rvs.pop(idx) + self.marginalized_rvs.insert(idx, new_rv) + if old_rv in self.basic_RVs: + self._transfer_rv_mappings(old_rv, new_rv) + if user_warnings: + transform = self.rvs_to_transforms[new_rv] + if transform is not None: + warnings.warn( + "Transforms for variables that depend on marginalized RVs are currently not working, " + f"rv={new_rv}, transform={transform}", + UserWarning, + ) + return self + + def _logp(self, *args, **kwargs): + return super().logp(*args, **kwargs) + + def logp(self, vars=None, **kwargs): + m = self.clone()._marginalize() + if vars is not None: + if not isinstance(vars, Sequence): + vars = (vars,) + vars = [m[var.name] for var in vars] + return m._logp(vars=vars, **kwargs) + + def clone(self): + m = MarginalModel() + vars = self.basic_RVs + self.potentials + self.deterministics + self.marginalized_rvs + cloned_vars = clone_replace(vars) + vars_to_clone = {var: cloned_var for var, cloned_var in zip(vars, cloned_vars)} + + m.named_vars = {name: vars_to_clone[var] for name, var in self.named_vars.items()} + m.named_vars_to_dims = self.named_vars_to_dims + m.values_to_rvs = {i: vars_to_clone[rv] for i, rv in self.values_to_rvs.items()} + m.rvs_to_values = {vars_to_clone[rv]: i for rv, i in self.rvs_to_values.items()} + m.rvs_to_transforms = {vars_to_clone[rv]: i for rv, i in self.rvs_to_transforms.items()} + # Special logic due to bug in pm.Model + m.rvs_to_total_sizes = { + vars_to_clone[rv]: i for rv, i in self.rvs_to_total_sizes.items() if rv in vars_to_clone + } + m.rvs_to_initial_values = { + vars_to_clone[rv]: i for rv, i in self.rvs_to_initial_values.items() + } + m.free_RVs = [vars_to_clone[rv] for rv in 
self.free_RVs] + m.observed_RVs = [vars_to_clone[rv] for rv in self.observed_RVs] + m.potentials = [vars_to_clone[pot] for pot in self.potentials] + m.deterministics = [vars_to_clone[det] for det in self.deterministics] + + m.marginalized_rvs = [vars_to_clone[rv] for rv in self.marginalized_rvs] + return m + + def marginalize(self, rvs_to_marginalize: Union[TensorVariable, Sequence[TensorVariable]]): + if not isinstance(rvs_to_marginalize, Sequence): + rvs_to_marginalize = (rvs_to_marginalize,) + + supported_dists = (Bernoulli, Categorical, DiscreteUniform) + for rv_to_marginalize in rvs_to_marginalize: + if rv_to_marginalize not in self.free_RVs: + raise ValueError( + f"Marginalized RV {rv_to_marginalize} is not a free RV in the model" + ) + if not isinstance(rv_to_marginalize.owner.op, supported_dists): + raise NotImplementedError( + f"RV with distribution {rv_to_marginalize.owner.op} cannot be marginalized. " + f"Supported distribution include {supported_dists}" + ) + + self._delete_rv_mappings(rv_to_marginalize) + self.marginalized_rvs.append(rv_to_marginalize) + + # Raise errors and warnings immediately + self.clone()._marginalize(user_warnings=True) + + +class MarginalRV(SymbolicRandomVariable): + """Base class for Marginalized RVs""" + + +class FiniteDiscreteMarginalRV(MarginalRV): + """Base class for Finite Discrete Marginalized RVs""" + + +def find_conditional_input_rvs(output_rvs, all_rvs): + """Find conditionally independent input RVs""" + blockers = [other_rv for other_rv in all_rvs if other_rv not in output_rvs] + return [ + var + for var in ancestors(output_rvs, blockers=blockers) + if var in blockers + or (var.owner is None and not isinstance(var, (Constant, SharedVariable))) + ] + + +def is_conditional_dependent( + dependent_rv: TensorVariable, dependable_rv: TensorVariable, all_rvs +) -> bool: + """Check if dependent_rv is conditionally dependent on dependable_rv, + given all conditionally independent all_rvs""" + + return dependable_rv in 
find_conditional_input_rvs((dependent_rv,), all_rvs) + + +def find_conditional_dependent_rvs(dependable_rv, all_rvs): + """Find rvs that depend on dependable_rv""" + return [ + rv + for rv in all_rvs + if (rv is not dependable_rv and is_conditional_dependent(rv, dependable_rv, all_rvs)) + ] + + +def is_elemwise_subgraph(rv_to_marginalize, other_input_rvs, output_rvs): + # TODO: No need to consider apply nodes outside the subgraph... + fg = FunctionGraph(outputs=output_rvs, clone=False) + + non_elemwise_blockers = [ + o for node in fg.apply_nodes if not isinstance(node.op, Elemwise) for o in node.outputs + ] + blocker_candidates = [rv_to_marginalize] + other_input_rvs + non_elemwise_blockers + blockers = [var for var in blocker_candidates if var not in output_rvs] + + truncated_inputs = [ + var + for var in ancestors(output_rvs, blockers=blockers) + if ( + var in blockers + or (var.owner is None and not isinstance(var, (Constant, SharedVariable))) + ) + ] + + # Check that we reach the marginalized rv following a pure elemwise graph + if rv_to_marginalize not in truncated_inputs: + return False + + # Check that none of the truncated inputs depends on the marginalized_rv + other_truncated_inputs = [inp for inp in truncated_inputs if inp is not rv_to_marginalize] + # TODO: We don't need to go all the way to the root variables + if rv_to_marginalize in ancestors( + other_truncated_inputs, blockers=[rv_to_marginalize, *other_input_rvs] + ): + return False + return True + + +def replace_finite_discrete_marginal_subgraph(fgraph, rv_to_marginalize, all_rvs): + # TODO: This should eventually be integrated in a more general routine that can + # identify other types of supported marginalization, of which finite discrete + # RVs is just one + + dependent_rvs = find_conditional_dependent_rvs(rv_to_marginalize, all_rvs) + if not dependent_rvs: + raise ValueError(f"No RVs depend on marginalized RV {rv_to_marginalize}") + + ndim_supp = {rv.owner.op.ndim_supp for rv in dependent_rvs} + 
if max(ndim_supp) > 0: + raise NotImplementedError( + "Marginalization of withe dependent Multivariate RVs not implemented" + ) + + marginalized_rv_input_rvs = find_conditional_input_rvs([rv_to_marginalize], all_rvs) + dependent_rvs_input_rvs = [ + rv + for rv in find_conditional_input_rvs(dependent_rvs, all_rvs) + if rv is not rv_to_marginalize + ] + + # If the marginalized RV has batched dimensions, check that graph between + # marginalized RV and dependent RVs is composed strictly of Elemwise Operations. + # This implies (?) that the dimensions are completely independent and a logp graph + # can ultimately be generated that is proportional to the support domain and not + # to the variables dimensions + # We don't need to worry about this if the RV is scalar. + if np.prod(constant_fold(tuple(rv_to_marginalize.shape))) > 1: + if not is_elemwise_subgraph(rv_to_marginalize, dependent_rvs_input_rvs, dependent_rvs): + raise NotImplementedError( + "The subgraph between a marginalized RV and its dependents includes non Elemwise operations. " + "This is currently not supported", + ) + + input_rvs = [*marginalized_rv_input_rvs, *dependent_rvs_input_rvs] + rvs_to_marginalize = [rv_to_marginalize, *dependent_rvs] + + outputs = rvs_to_marginalize + # Clone replace inner RV rng inputs so that we can be sure of the update order + # replace_inputs = {rng: rng.type() for rng in updates_rvs_to_marginalize.keys()} + # Clone replace outer RV inputs, so that their shared RNGs don't make it into + # the inner graph of the marginalized RVs + # FIXME: This shouldn't be needed! 
+ replace_inputs = {} + replace_inputs.update({input_rv: input_rv.type() for input_rv in input_rvs}) + cloned_outputs = clone_replace(outputs, replace=replace_inputs) + + marginalization_op = FiniteDiscreteMarginalRV( + inputs=list(replace_inputs.values()), + outputs=cloned_outputs, + ndim_supp=0, + ) + marginalized_rvs = marginalization_op(*replace_inputs.keys()) + fgraph.replace_all(tuple(zip(rvs_to_marginalize, marginalized_rvs))) + return rvs_to_marginalize, marginalized_rvs + + +@_get_measurable_outputs.register(FiniteDiscreteMarginalRV) +def _get_measurable_outputs_finite_discrete_marginal_rv(op, node): + # Marginalized RVs are not measurable + return node.outputs[1:] + + +def get_domain_of_finite_discrete_rv(rv: TensorVariable) -> Tuple[int, ...]: + op = rv.owner.op + if isinstance(op, Bernoulli): + return (0, 1) + elif isinstance(op, Categorical): + p_param = rv.owner.inputs[3] + return tuple(range(at.get_vector_length(p_param))) + elif isinstance(op, DiscreteUniform): + lower, upper = constant_fold(rv.owner.inputs[3:]) + return tuple(range(lower, upper + 1)) + + raise NotImplementedError(f"Cannot compute domain for op {op}") + + +@_logprob.register(FiniteDiscreteMarginalRV) +def finite_discrete_marginal_rv_logp(op, values, *inputs, **kwargs): + # Clone the inner RV graph of the Marginalized RV + marginalized_rvs_node = op.make_node(*inputs) + marginalized_rv, *dependent_rvs = clone_replace( + op.inner_outputs, + replace={u: v for u, v in zip(op.inner_inputs, marginalized_rvs_node.inputs)}, + ) + + # Obtain the joint_logp graph of the inner RV graph + # Some inputs are not root inputs (such as transformed projections of value variables) + # Or cannot be used as inputs to an OpFromGraph (shared variables and constants) + inputs = list(inputvars(inputs)) + rvs_to_values = {} + dummy_marginalized_value = marginalized_rv.clone() + rvs_to_values[marginalized_rv] = dummy_marginalized_value + rvs_to_values.update(zip(dependent_rvs, values)) + logps_dict = 
factorized_joint_logprob(rv_values=rvs_to_values, **kwargs) + + # Reduce logp dimensions corresponding to broadcasted variables + values_axis_bcast = [] + for value in values: + vbcast = value.type.broadcastable + mbcast = dummy_marginalized_value.type.broadcastable + mbcast = (True,) * (len(vbcast) - len(mbcast)) + mbcast + values_axis_bcast.append([i for i, (m, v) in enumerate(zip(mbcast, vbcast)) if m != v]) + joint_logp = logps_dict[dummy_marginalized_value] + for value, values_axis_bcast in zip(values, values_axis_bcast): + joint_logp += logps_dict[value].sum(values_axis_bcast, keepdims=True) + + # Wrap the joint_logp graph in an OpFromGraph, so that we can evaluate it at different + # values of the marginalized RV + # OpFromGraph does not accept constant inputs + non_const_values = [ + value + for value in rvs_to_values.values() + if not isinstance(value, (Constant, SharedVariable)) + ] + joint_logp_op = OpFromGraph([*non_const_values, *inputs], [joint_logp], inline=True) + + # Compute the joint_logp for all possible n values of the marginalized RV. 
We assume + # each original dimension is independent so that it suffices to evaluate the graph + # n times, once with each possible value of the marginalized RV replicated across + # batched dimensions of the marginalized RV + + # PyMC does not allow RVs in the logp graph, even if we are just using the shape + marginalized_rv_shape = constant_fold(tuple(marginalized_rv.shape)) + marginalized_rv_domain = get_domain_of_finite_discrete_rv(marginalized_rv) + marginalized_rv_domain_tensor = at.swapaxes( + at.full( + (*marginalized_rv_shape, len(marginalized_rv_domain)), + marginalized_rv_domain, + dtype=marginalized_rv.dtype, + ), + axis1=0, + axis2=-1, + ) + + # OpFromGraph does not accept constant inputs + non_const_values = [ + value for value in values if not isinstance(value, (Constant, SharedVariable)) + ] + # Arbitrary cutoff to switch to Scan implementation to keep graph size under control + if len(marginalized_rv_domain) <= 10: + joint_logps = [ + joint_logp_op(marginalized_rv_domain_tensor[i], *non_const_values, *inputs) + for i in range(len(marginalized_rv_domain)) + ] + else: + # Make sure this rewrite is registered + from pymc.aesaraf import local_remove_check_parameter + + def logp_fn(marginalized_rv_const, *non_sequences): + return joint_logp_op(marginalized_rv_const, *non_sequences) + + joint_logps, _ = scan_map( + fn=logp_fn, + sequences=marginalized_rv_domain_tensor, + non_sequences=[*non_const_values, *inputs], + mode=Mode().including("local_remove_check_parameter"), + ) + + joint_logps = at.logsumexp(joint_logps, axis=0) + + # We have to add dummy logps for the remaining value variables, otherwise AePPL will raise + return joint_logps, *(at.constant(0),) * (len(values) - 1) diff --git a/pymc_experimental/tests/test_marginal_model.py b/pymc_experimental/tests/test_marginal_model.py new file mode 100644 index 00000000..6048bc24 --- /dev/null +++ b/pymc_experimental/tests/test_marginal_model.py @@ -0,0 +1,410 @@ +import itertools +from contextlib 
import suppress as does_not_warn + +import aesara.tensor as at +import numpy as np +import pandas as pd +import pymc as pm +import pytest +from aeppl.logprob import _logprob +from pymc import ImputationWarning, inputvars +from pymc.distributions import transforms +from pymc.util import UNSET +from scipy.special import logsumexp + +from pymc_experimental.marginal_model import FiniteDiscreteMarginalRV, MarginalModel + + +@pytest.fixture +def disaster_model(): + # fmt: off + disaster_data = pd.Series( + [4, 5, 4, 0, 1, 4, 3, 4, 0, 6, 3, 3, 4, 0, 2, 6, + 3, 3, 5, 4, 5, 3, 1, 4, 4, 1, 5, 5, 3, 4, 2, 5, + 2, 2, 3, 4, 2, 1, 3, np.nan, 2, 1, 1, 1, 1, 3, 0, 0, + 1, 0, 1, 1, 0, 0, 3, 1, 0, 3, 2, 2, 0, 1, 1, 1, + 0, 1, 0, 1, 0, 0, 0, 2, 1, 0, 0, 0, 1, 1, 0, 2, + 3, 3, 1, np.nan, 2, 1, 1, 1, 1, 2, 4, 2, 0, 0, 1, 4, + 0, 0, 0, 1, 0, 0, 0, 0, 0, 1, 0, 0, 1, 0, 1] + ) + # fmt: on + years = np.arange(1851, 1962) + + with MarginalModel() as disaster_model: + switchpoint = pm.DiscreteUniform("switchpoint", lower=years.min(), upper=years.max()) + early_rate = pm.Exponential("early_rate", 1.0, initval=3) + late_rate = pm.Exponential("late_rate", 1.0, initval=1) + rate = pm.math.switch(switchpoint >= years, early_rate, late_rate) + with pytest.warns(ImputationWarning): + disasters = pm.Poisson("disasters", rate, observed=disaster_data) + + return disaster_model, years + + +@pytest.mark.filterwarnings("error") +def test_marginalized_bernoulli_logp(): + """Test logp of IR FiniteDiscreteMarginalRV directly""" + mu = at.vector("mu") + + idx = pm.Bernoulli.dist(0.7, name="idx") + y = pm.Normal.dist(mu=mu[idx], sigma=1.0, name="y") + marginal_rv_node = FiniteDiscreteMarginalRV([mu], [idx, y], ndim_supp=None, n_updates=0,)( + mu + )[0].owner + + y_vv = y.clone() + (logp,) = _logprob( + marginal_rv_node.op, + (y_vv,), + *marginal_rv_node.inputs, + ) + + ref_logp = pm.logp(pm.NormalMixture.dist(w=[0.3, 0.7], mu=mu, sigma=1.0), y_vv) + np.testing.assert_almost_equal( + logp.eval({mu: [-1, 
1], y_vv: 2}), + ref_logp.eval({mu: [-1, 1], y_vv: 2}), + ) + + +@pytest.mark.filterwarnings("error") +def test_marginalized_basic(): + data = [2] * 5 + + with MarginalModel() as m: + sigma = pm.HalfNormal("sigma") + idx = pm.Categorical("idx", p=[0.1, 0.3, 0.6]) + mu = at.switch( + at.eq(idx, 0), + -1.0, + at.switch( + at.eq(idx, 1), + 0.0, + 1.0, + ), + ) + y = pm.Normal("y", mu=mu, sigma=sigma) + z = pm.Normal("z", y, observed=data) + + m.marginalize([idx]) + assert idx not in m.free_RVs + assert [rv.name for rv in m.marginalized_rvs] == ["idx"] + + # Test logp + with pm.Model() as m_ref: + sigma = pm.HalfNormal("sigma") + y = pm.NormalMixture("y", w=[0.1, 0.3, 0.6], mu=[-1, 0, 1], sigma=sigma) + z = pm.Normal("z", y, observed=data) + + test_point = m_ref.initial_point() + ref_logp = m_ref.compile_logp()(test_point) + ref_dlogp = m_ref.compile_dlogp([m_ref["y"]])(test_point) + + # Assert we can marginalize and unmarginalize internally non-destructively + for i in range(3): + np.testing.assert_almost_equal( + m.compile_logp()(test_point), + ref_logp, + ) + np.testing.assert_almost_equal( + m.compile_dlogp([m["y"]])(test_point), + ref_dlogp, + ) + + +@pytest.mark.filterwarnings("error") +def test_multiple_independent_marginalized_rvs(): + with MarginalModel() as m: + sigma = pm.HalfNormal("sigma") + idx1 = pm.Bernoulli("idx1", p=0.75) + x = pm.Normal("x", mu=idx1, sigma=sigma) + idx2 = pm.Bernoulli("idx2", p=0.75, shape=(5,)) + y = pm.Normal("y", mu=(idx2 * 2 - 1), sigma=sigma, shape=(5,)) + + m.marginalize([idx1, idx2]) + m["x"].owner is not m["y"].owner + _m = m.clone()._marginalize() + _m["x"].owner is not _m["y"].owner + + with pm.Model() as m_ref: + sigma = pm.HalfNormal("sigma") + x = pm.NormalMixture("x", w=[0.25, 0.75], mu=[0, 1], sigma=sigma) + y = pm.NormalMixture("y", w=[0.25, 0.75], mu=[-1, 1], sigma=sigma, shape=(5,)) + + # Test logp + test_point = m_ref.initial_point() + x_logp, y_logp = m.compile_logp(vars=[m["x"], m["y"]], sum=False)(test_point) + 
x_ref_log, y_ref_logp = m_ref.compile_logp(vars=[m_ref["x"], m_ref["y"]], sum=False)(test_point) + np.testing.assert_array_almost_equal(x_logp, x_ref_log.sum()) + np.testing.assert_array_almost_equal(y_logp, y_ref_logp) + + +@pytest.mark.filterwarnings("error") +def test_multiple_dependent_marginalized_rvs(): + """Test that marginalization works when there is more than one dependent RV""" + with MarginalModel() as m: + sigma = pm.HalfNormal("sigma") + idx = pm.Bernoulli("idx", p=0.75) + x = pm.Normal("x", mu=idx, sigma=sigma) + y = pm.Normal("y", mu=(idx * 2 - 1), sigma=sigma, shape=(5,)) + + ref_logp_x_y_fn = m.compile_logp([idx, x, y]) + + with pytest.warns(UserWarning, match="There are multiple dependent variables"): + m.marginalize([idx]) + + m["x"].owner is not m["y"].owner + _m = m.clone()._marginalize() + _m["x"].owner is _m["y"].owner + + tp = m.initial_point() + ref_logp_x_y = logsumexp([ref_logp_x_y_fn({**tp, **{"idx": idx}}) for idx in (0, 1)]) + logp_x_y = m.compile_logp([x, y])(tp) + np.testing.assert_array_almost_equal(logp_x_y, ref_logp_x_y) + + +@pytest.mark.filterwarnings("error") +def test_nested_marginalized_rvs(): + """Test that marginalization works when there are nested marginalized RVs""" + + with MarginalModel() as m: + sigma = pm.HalfNormal("sigma") + + idx = pm.Bernoulli("idx", p=0.75) + dep = pm.Normal("dep", mu=at.switch(at.eq(idx, 0), -1000, 1000), sigma=sigma) + + sub_idx = pm.Bernoulli("sub_idx", p=at.switch(at.eq(idx, 0), 0.15, 0.95), shape=(5,)) + sub_dep = pm.Normal("sub_dep", mu=dep + sub_idx * 100, sigma=sigma, shape=(5,)) + + ref_logp_fn = m.compile_logp(vars=[idx, dep, sub_idx, sub_dep]) + + with pytest.warns(UserWarning, match="There are multiple dependent variables"): + m.marginalize([idx, sub_idx]) + + assert set(m.marginalized_rvs) == {idx, sub_idx} + + # Test logp + test_point = m.initial_point() + test_point["dep"] = 1000 + test_point["sub_dep"] = np.full((5,), 1000 + 100) + + ref_logp = [ + ref_logp_fn({**test_point, 
**{"idx": idx, "sub_idx": np.array(sub_idxs)}}) + for idx in (0, 1) + for sub_idxs in itertools.product((0, 1), repeat=5) + ] + logp = m.compile_logp(vars=[dep, sub_dep])(test_point) + + np.testing.assert_almost_equal( + logp, + logsumexp(ref_logp), + ) + + +@pytest.mark.filterwarnings("error") +def test_marginalized_change_point_model(disaster_model): + m, years = disaster_model + + ip = m.initial_point() + ip.pop("switchpoint") + ref_logp_fn = m.compile_logp( + [m["switchpoint"], m["disasters_observed"], m["disasters_missing"]] + ) + ref_logp = logsumexp([ref_logp_fn({**ip, **{"switchpoint": year}}) for year in years]) + + with pytest.warns(UserWarning, match="There are multiple dependent variables"): + m.marginalize(m["switchpoint"]) + + logp = m.compile_logp([m["disasters_observed"], m["disasters_missing"]])(ip) + np.testing.assert_almost_equal(logp, ref_logp) + + +@pytest.mark.slow +@pytest.mark.filterwarnings("error") +def test_marginalized_change_point_model_sampling(disaster_model): + m, _ = disaster_model + + rng = np.random.default_rng(211) + + with m: + before_marg = pm.sample(chains=2, random_seed=rng).posterior.stack(sample=("draw", "chain")) + + with pytest.warns(UserWarning, match="There are multiple dependent variables"): + m.marginalize([m["switchpoint"]]) + + with m: + after_marg = pm.sample(chains=2, random_seed=rng).posterior.stack(sample=("draw", "chain")) + + np.testing.assert_allclose( + before_marg["early_rate"].mean(), after_marg["early_rate"].mean(), rtol=1e-2 + ) + np.testing.assert_allclose( + before_marg["late_rate"].mean(), after_marg["late_rate"].mean(), rtol=1e-2 + ) + np.testing.assert_allclose( + before_marg["disasters_missing"].mean(), after_marg["disasters_missing"].mean(), rtol=1e-2 + ) + + +@pytest.mark.filterwarnings("error") +def test_not_supported_marginalized(): + """Marginalized graphs with non-Elemwise Operations are not supported as they + would violate the batching logp assumption""" + mu = at.constant([-1, 1]) + + # 
Allowed, as only elemwise operations connect idx to y + with MarginalModel() as m: + p = pm.Beta("p", 1, 1) + idx = pm.Bernoulli("idx", p=p, size=2) + y = pm.Normal("y", mu=pm.math.switch(idx, 0, 1)) + m.marginalize([idx]) + + # Allowed, as index operation does not connect idx to y + with MarginalModel() as m: + p = pm.Beta("p", 1, 1) + idx = pm.Bernoulli("idx", p=p, size=2) + y = pm.Normal("y", mu=pm.math.switch(idx, mu[0], mu[1])) + m.marginalize([idx]) + + # Not allowed, as index operation connects idx to y + with MarginalModel() as m: + p = pm.Beta("p", 1, 1) + idx = pm.Bernoulli("idx", p=p, size=2) + # Not allowed + y = pm.Normal("y", mu=mu[idx]) + with pytest.raises(NotImplementedError): + m.marginalize(idx) + + # Not allowed, as index operation connects idx to y, even though there is a + # pure Elemwise connection between the two + with MarginalModel() as m: + p = pm.Beta("p", 1, 1) + idx = pm.Bernoulli("idx", p=p, size=2) + y = pm.Normal("y", mu=mu[idx] + idx) + with pytest.raises(NotImplementedError): + m.marginalize(idx) + + # Multivariate dependent RVs not supported + with MarginalModel() as m: + x = pm.Bernoulli("x", p=0.7) + y = pm.Dirichlet("y", a=pm.math.switch(x, [1, 1, 1], [10, 10, 10])) + with pytest.raises( + NotImplementedError, + match="Marginalization of withe dependent Multivariate RVs not implemented", + ): + m.marginalize(x) + + +@pytest.mark.filterwarnings("error") +def test_marginalized_deterministic_and_potential(): + with MarginalModel() as m: + x = pm.Bernoulli("x", p=0.7) + y = pm.Normal("y", x) + z = pm.Normal("z", x) + det = pm.Deterministic("det", y + z) + pot = pm.Potential("pot", y + z + 1) + + with pytest.warns(UserWarning, match="There are multiple dependent variables"): + m.marginalize([x]) + + y_draw, z_draw, det_draw, pot_draw = pm.draw([y, z, det, pot], draws=5) + np.testing.assert_almost_equal(y_draw + z_draw, det_draw) + np.testing.assert_almost_equal(det_draw, pot_draw - 1) + + y_value = m.rvs_to_values[y] + z_value = 
m.rvs_to_values[z] + det_value, pot_value = m.replace_rvs_by_values([det, pot]) + assert set(inputvars([det_value, pot_value])) == {y_value, z_value} + assert det_value.eval({y_value: 2, z_value: 5}) == 7 + assert pot_value.eval({y_value: 2, z_value: 5}) == 8 + + +@pytest.mark.filterwarnings("error") +def test_not_supported_marginalized_deterministic_and_potential(): + with MarginalModel() as m: + x = pm.Bernoulli("x", p=0.7) + y = pm.Normal("y", x) + det = pm.Deterministic("det", x + y) + + with pytest.raises( + NotImplementedError, match="Cannot marginalize x due to dependent Deterministic det" + ): + m.marginalize([x]) + + with MarginalModel() as m: + x = pm.Bernoulli("x", p=0.7) + y = pm.Normal("y", x) + pot = pm.Potential("pot", x + y) + + with pytest.raises( + NotImplementedError, match="Cannot marginalize x due to dependent Potential pot" + ): + m.marginalize([x]) + + +@pytest.mark.filterwarnings("error") +@pytest.mark.parametrize( + "transform, expected_warning", + ( + (None, does_not_warn()), + pytest.param( + UNSET, + pytest.warns( + UserWarning, match="Transforms for variables that depend on marginalized RVs" + ), + marks=pytest.mark.xfail( + reason="AePPL transform rewrite does not support multiple_output nodes", + raises=AssertionError, + ), + ), + pytest.param( + transforms.log, + pytest.warns( + UserWarning, match="Transforms for variables that depend on marginalized RVs" + ), + marks=pytest.mark.xfail( + reason="AePPL transform rewrite does not support multiple_output nodes", + raises=AssertionError, + ), + ), + ), +) +def test_marginalized_transforms(transform, expected_warning): + w = [0.1, 0.3, 0.6] + data = [0, 5, 10] + initval = 0.5 # Value that will be negative on the unconstrained space + + with pm.Model() as m_ref: + sigma = pm.Mixture( + "sigma", + w=w, + comp_dists=pm.HalfNormal.dist([1, 2, 3]), + initval=initval, + transform=transform, + ) + y = pm.Normal("y", 0, sigma, observed=data) + + with MarginalModel() as m: + idx = 
pm.Categorical("idx", p=w) + sigma = pm.HalfNormal( + "sigma", + at.switch( + at.eq(idx, 0), + 1, + at.switch( + at.eq(idx, 1), + 2, + 3, + ), + ), + initval=initval, + transform=transform, + ) + y = pm.Normal("y", 0, sigma, observed=data) + + with expected_warning: + m.marginalize([idx]) + + ip = m.initial_point() + if transform is not None: + assert "sigma_log__" in ip + np.testing.assert_allclose(m.compile_logp()(ip), m_ref.compile_logp()(ip))