Skip to content

Commit

Permalink
Merge Aeppl as logprob submodule
Browse files Browse the repository at this point in the history
  • Loading branch information
pymc-devs authored and twiecki committed Nov 25, 2022
1 parent 0654e7f commit e0d25c8
Show file tree
Hide file tree
Showing 69 changed files with 7,841 additions and 338 deletions.
12 changes: 12 additions & 0 deletions .github/workflows/tests.yml
Original file line number Diff line number Diff line change
Expand Up @@ -85,6 +85,18 @@ jobs:
pymc/tests/ode/test_utils.py
pymc/tests/step_methods/hmc/test_quadpotential.py
- |
pymc/tests/logprob/test_abstract.py
pymc/tests/logprob/test_censoring.py
pymc/tests/logprob/test_composite_logprob.py
pymc/tests/logprob/test_cumsum.py
pymc/tests/logprob/test_mixture.py
pymc/tests/logprob/test_rewriting.py
pymc/tests/logprob/test_scan.py
pymc/tests/logprob/test_tensor.py
pymc/tests/logprob/test_transforms.py
pymc/tests/logprob/test_utils.py
fail-fast: false
runs-on: ${{ matrix.os }}
env:
Expand Down
1 change: 0 additions & 1 deletion .pre-commit-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -85,7 +85,6 @@ repos:
(?x)(arviz-devs.github.io|
python.arviz.org|
aesara.readthedocs.io|
aeppl.readthedocs.io|
pymc-experimental.readthedocs.io|
docs.pymc.io|
www.pymc.io|
Expand Down
2 changes: 1 addition & 1 deletion conda-envs/environment-dev.yml
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,6 @@ channels:
- defaults
dependencies:
# Base dependencies
- aeppl=0.0.38
- aesara=2.8.7
- arviz>=0.13.0
- blas
Expand Down Expand Up @@ -41,3 +40,4 @@ dependencies:
- types-cachetools
- pip:
- git+https://github.com/pymc-devs/pymc-sphinx-theme
- numdifftools>=0.9.40
3 changes: 2 additions & 1 deletion conda-envs/environment-test.yml
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,6 @@ channels:
- defaults
dependencies:
# Base dependencies
- aeppl=0.0.38
- aesara=2.8.7
- arviz>=0.13.0
- blas
Expand All @@ -29,3 +28,5 @@ dependencies:
- pytest>=3.0
- mypy=0.990
- types-cachetools
- pip:
- numdifftools>=0.9.40
2 changes: 1 addition & 1 deletion conda-envs/windows-environment-dev.yml
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,6 @@ channels:
- defaults
dependencies:
# Base dependencies (see install guide for Windows)
- aeppl=0.0.38
- aesara=2.8.7
- arviz>=0.13.0
- blas
Expand Down Expand Up @@ -38,3 +37,4 @@ dependencies:
- types-cachetools
- pip:
- git+https://github.com/pymc-devs/pymc-sphinx-theme
- numdifftools>=0.9.40
3 changes: 2 additions & 1 deletion conda-envs/windows-environment-test.yml
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,6 @@ channels:
- defaults
dependencies:
# Base dependencies (see install guide for Windows)
- aeppl=0.0.38
- aesara=2.8.7
- arviz>=0.13.0
- blas
Expand All @@ -30,3 +29,5 @@ dependencies:
- pytest>=3.0
- mypy=0.990
- types-cachetools
- pip:
- numdifftools>=0.9.40
1 change: 0 additions & 1 deletion docs/source/conf.py
Original file line number Diff line number Diff line change
Expand Up @@ -188,7 +188,6 @@
intersphinx_mapping = {
"arviz": ("https://python.arviz.org/en/latest/", None),
"aesara": ("https://aesara.readthedocs.io/en/latest/", None),
"aeppl": ("https://aeppl.readthedocs.io/en/latest/", None),
"home": ("https://www.pymc.io", None),
"pmx": ("https://www.pymc.io/projects/experimental/en/latest", None),
"numpy": ("https://numpy.org/doc/stable/", None),
Expand Down
20 changes: 17 additions & 3 deletions docs/source/learn/core_notebooks/GLM_linear.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -522,14 +522,15 @@
"source": [
"%load_ext watermark\n",
"\n",
"%watermark -n -u -v -iv -w -p aesara,aeppl"
"%watermark -n -u -v -iv -w -p aesara"
]
}
],
"metadata": {
"anaconda-cloud": {},
"hide_input": false,
"kernelspec": {
"display_name": "Python 3 (ipykernel)",
"display_name": "Python 3",
"language": "python",
"name": "python3"
},
Expand All @@ -543,14 +544,27 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.9.7"
"version": "3.8.10"
},
"latex_envs": {
"bibliofile": "biblio.bib",
"cite_by": "apalike",
"current_citInitial": 1,
"eqLabelWithNumbers": true,
"eqNumInitial": 0
},
"toc": {
"base_numbering": 1,
"nav_menu": {},
"number_sections": true,
"sideBar": true,
"skip_h1_title": false,
"title_cell": "Table of Contents",
"title_sidebar": "Contents",
"toc_cell": false,
"toc_position": {},
"toc_section_display": true,
"toc_window_display": false
}
},
"nbformat": 4,
Expand Down
20 changes: 17 additions & 3 deletions docs/source/learn/core_notebooks/model_comparison.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -536,16 +536,17 @@
}
],
"source": [
"%watermark -n -u -v -iv -w -p xarray,aesara,aeppl"
"%watermark -n -u -v -iv -w -p xarray,aesara"
]
}
],
"metadata": {
"hide_input": false,
"interpreter": {
"hash": "baf205d70af30bf8b721a304f5a44beb31bf8af014f6b7340f1a7ae004926653"
},
"kernelspec": {
"display_name": "Python 3 (ipykernel)",
"display_name": "Python 3",
"language": "python",
"name": "python3"
},
Expand All @@ -559,7 +560,20 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.9.7"
"version": "3.8.10"
},
"toc": {
"base_numbering": 1,
"nav_menu": {},
"number_sections": true,
"sideBar": true,
"skip_h1_title": false,
"title_cell": "Table of Contents",
"title_sidebar": "Contents",
"toc_cell": false,
"toc_position": {},
"toc_section_display": true,
"toc_window_display": false
}
},
"nbformat": 4,
Expand Down
20 changes: 17 additions & 3 deletions docs/source/learn/core_notebooks/posterior_predictive.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -4649,14 +4649,15 @@
],
"source": [
"%load_ext watermark\n",
"%watermark -n -u -v -iv -w -p aesara,aeppl"
"%watermark -n -u -v -iv -w -p aesara"
]
}
],
"metadata": {
"anaconda-cloud": {},
"hide_input": false,
"kernelspec": {
"display_name": "Python 3.9.13 ('pymc-dev-py39')",
"display_name": "Python 3",
"language": "python",
"name": "python3"
},
Expand All @@ -4670,7 +4671,20 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.9.13"
"version": "3.8.10"
},
"toc": {
"base_numbering": 1,
"nav_menu": {},
"number_sections": true,
"sideBar": true,
"skip_h1_title": false,
"title_cell": "Table of Contents",
"title_sidebar": "Contents",
"toc_cell": false,
"toc_position": {},
"toc_section_display": true,
"toc_window_display": false
},
"vscode": {
"interpreter": {
Expand Down
185 changes: 140 additions & 45 deletions docs/source/learn/core_notebooks/pymc_aesara.ipynb

Large diffs are not rendered by default.

20 changes: 17 additions & 3 deletions docs/source/learn/core_notebooks/pymc_overview.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -4340,14 +4340,15 @@
],
"source": [
"%load_ext watermark\n",
"%watermark -n -u -v -iv -w -p xarray,aeppl"
"%watermark -n -u -v -iv -w -p xarray"
]
}
],
"metadata": {
"anaconda-cloud": {},
"hide_input": false,
"kernelspec": {
"display_name": "Python 3 (ipykernel)",
"display_name": "Python 3",
"language": "python",
"name": "python3"
},
Expand All @@ -4361,7 +4362,20 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.8.5"
"version": "3.8.10"
},
"toc": {
"base_numbering": 1,
"nav_menu": {},
"number_sections": true,
"sideBar": true,
"skip_h1_title": false,
"title_cell": "Table of Contents",
"title_sidebar": "Contents",
"toc_cell": false,
"toc_position": {},
"toc_section_display": true,
"toc_window_display": false
}
},
"nbformat": 4,
Expand Down
1 change: 1 addition & 0 deletions pymc/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,7 @@ def __set_compiler_flags():
from pymc.distributions import *
from pymc.exceptions import *
from pymc.func_utils import find_constrained_prior
from pymc.logprob import *
from pymc.math import (
expand_packed_triangular,
invlogit,
Expand Down
10 changes: 5 additions & 5 deletions pymc/aesaraf.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,8 +32,6 @@
import pandas as pd
import scipy.sparse as sps

from aeppl.logprob import CheckParameterValue
from aeppl.transforms import RVTransform
from aesara import scalar
from aesara.compile import Function, Mode, get_mode
from aesara.gradient import grad
Expand Down Expand Up @@ -65,6 +63,8 @@
from aesara.tensor.var import TensorConstant, TensorVariable

from pymc.exceptions import NotConstantValueError
from pymc.logprob.transforms import RVTransform
from pymc.logprob.utils import CheckParameterValue
from pymc.vartypes import continuous_types, isgenerator, typefilter

PotentialShapeType = Union[int, np.ndarray, Sequence[Union[int, Variable]], TensorVariable]
Expand Down Expand Up @@ -944,7 +944,7 @@ def largest_common_dtype(tensors):

@node_rewriter(tracks=[CheckParameterValue])
def local_remove_check_parameter(fgraph, node):
"""Rewrite that removes Aeppl's CheckParameterValue
"""Rewrite that removes CheckParameterValue
This is used when compile_rv_inplace
"""
Expand Down Expand Up @@ -1068,13 +1068,13 @@ def compile_pymc(
Ensures that compiled functions containing random variables will produce new
samples on each call.
local_check_parameter_to_ninf_switch
Replaces Aeppl's CheckParameterValue assertions is logp expressions with Switches
Replaces CheckParameterValue assertions is logp expressions with Switches
that return -inf in place of the assert.
Optional rewrites
-----------------
local_remove_check_parameter
Replaces Aeppl's CheckParameterValue assertions is logp expressions. This is used
Replaces CheckParameterValue assertions is logp expressions. This is used
as an alteranative to the default local_check_parameter_to_ninf_switch whenenver
this function is called within a model context and the model `check_bounds` flag
is set to False.
Expand Down
2 changes: 0 additions & 2 deletions pymc/distributions/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,6 @@
from pymc.distributions.logprob import ( # isort:skip
logcdf,
logp,
joint_logp,
)

from pymc.distributions.bound import Bound
Expand Down Expand Up @@ -198,7 +197,6 @@
"Censored",
"CAR",
"PolyaGamma",
"joint_logp",
"logp",
"logcdf",
]
2 changes: 1 addition & 1 deletion pymc/distributions/censored.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@
class CensoredRV(SymbolicRandomVariable):
"""Censored random variable"""

inline_aeppl = True
inline_logprob = True
_print_name = ("Censored", "\\operatorname{Censored}")


Expand Down
10 changes: 5 additions & 5 deletions pymc/distributions/continuous.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,6 @@
import aesara.tensor as at
import numpy as np

from aeppl.logprob import _logprob, logcdf, logprob
from aesara.graph.basic import Apply, Variable
from aesara.graph.op import Op
from aesara.raise_op import Assert
Expand Down Expand Up @@ -57,6 +56,8 @@
from aesara.tensor.random.op import RandomVariable
from aesara.tensor.var import TensorConstant

from pymc.logprob.abstract import _logprob, logcdf, logprob

try:
from polyagamma import polyagamma_cdf, polyagamma_pdf, random_polyagamma
except ImportError: # pragma: no cover
Expand Down Expand Up @@ -531,6 +532,9 @@ def logcdf(value, mu, sigma):
msg="sigma > 0",
)

def icdf(value, mu, sigma):
return mu + sigma * -np.sqrt(2.0) * at.erfcinv(2 * value)


class TruncatedNormalRV(RandomVariable):
name = "truncated_normal"
Expand Down Expand Up @@ -1290,10 +1294,6 @@ class Exponential(PositiveContinuous):
Variance :math:`\dfrac{1}{\lambda^2}`
======== ============================
Notes
-----
Logp calculation is defined in `aeppl.logprob <https://github.com/aesara-devs/aeppl/blob/main/aeppl/logprob.py/>`_.
Parameters
----------
lam : tensor_like of float
Expand Down
3 changes: 3 additions & 0 deletions pymc/distributions/discrete.py
Original file line number Diff line number Diff line change
Expand Up @@ -819,6 +819,9 @@ def logcdf(value, p):
msg="0 <= p <= 1",
)

def icdf(value, p):
return at.ceil(at.log1p(-value) / at.log1p(-p)).astype("int64")


class HyperGeometric(Discrete):
R"""
Expand Down
2 changes: 1 addition & 1 deletion pymc/distributions/dist_math.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,6 @@
import scipy.linalg
import scipy.stats

from aeppl.logprob import CheckParameterValue
from aesara.compile.builders import OpFromGraph
from aesara.graph.basic import Apply, Variable
from aesara.graph.op import Op
Expand All @@ -38,6 +37,7 @@

from pymc.aesaraf import floatX
from pymc.distributions.shape_utils import to_tuple
from pymc.logprob.utils import CheckParameterValue

solve_lower = SolveTriangular(lower=True)
solve_upper = SolveTriangular(lower=False)
Expand Down
Loading

0 comments on commit e0d25c8

Please sign in to comment.