Skip to content

Commit

Permalink
Merge branch 'distsigs' of https://github.com/cluhmann/pymc into dist…
Browse files Browse the repository at this point in the history
…sigs
  • Loading branch information
cluhmann committed Mar 29, 2022
2 parents b83e0bb + 705aff1 commit f33fc00
Show file tree
Hide file tree
Showing 24 changed files with 490 additions and 265 deletions.
2 changes: 1 addition & 1 deletion .github/workflows/autoupdate-pre-commit-config.yml
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ jobs:
- name: Set up Python
uses: actions/setup-python@v2
- name: Cache multiple paths
uses: actions/cache@v2
uses: actions/cache@v3
with:
path: |
~/.cache/pre-commit
Expand Down
2 changes: 1 addition & 1 deletion .github/workflows/dispatched_pre-commit.yml
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ jobs:
ref: ${{github.event.client_payload.pull_request.head.ref}}
token: ${{ secrets.ACTION_TRIGGER_TOKEN }}
- name: Cache multiple paths
uses: actions/cache@v2
uses: actions/cache@v3
with:
path: |
~/.cache/pre-commit
Expand Down
20 changes: 10 additions & 10 deletions .github/workflows/tests.yml
Original file line number Diff line number Diff line change
Expand Up @@ -86,7 +86,7 @@ jobs:
steps:
- uses: actions/checkout@v2
- name: Cache conda
uses: actions/cache@v1
uses: actions/cache@v3
env:
# Increase this value to reset cache if environment-test-py37.yml has not changed
CACHE_NUMBER: 0
Expand All @@ -95,7 +95,7 @@ jobs:
key: ${{ runner.os }}-conda-${{ env.CACHE_NUMBER }}-${{
hashFiles('conda-envs/environment-test-py37.yml') }}
- name: Cache multiple paths
uses: actions/cache@v2
uses: actions/cache@v3
env:
# Increase this value to reset cache if requirements.txt has not changed
CACHE_NUMBER: 0
Expand Down Expand Up @@ -154,7 +154,7 @@ jobs:
steps:
- uses: actions/checkout@v2
- name: Cache conda
uses: actions/cache@v1
uses: actions/cache@v3
env:
# Increase this value to reset cache if conda-envs/environment-test-py38.yml has not changed
CACHE_NUMBER: 0
Expand All @@ -163,7 +163,7 @@ jobs:
key: ${{ runner.os }}-conda-${{ env.CACHE_NUMBER }}-${{
hashFiles('conda-envs/windows-environment-test-py38.yml') }}
- name: Cache multiple paths
uses: actions/cache@v2
uses: actions/cache@v3
env:
# Increase this value to reset cache if requirements.txt has not changed
CACHE_NUMBER: 0
Expand Down Expand Up @@ -230,7 +230,7 @@ jobs:
steps:
- uses: actions/checkout@v2
- name: Cache conda
uses: actions/cache@v1
uses: actions/cache@v3
env:
# Increase this value to reset cache if environment-test-py39.yml has not changed
CACHE_NUMBER: 0
Expand All @@ -239,7 +239,7 @@ jobs:
key: ${{ runner.os }}-conda-${{ env.CACHE_NUMBER }}-${{
hashFiles('conda-envs/environment-test-py39.yml') }}
- name: Cache multiple paths
uses: actions/cache@v2
uses: actions/cache@v3
env:
# Increase this value to reset cache if requirements.txt has not changed
CACHE_NUMBER: 0
Expand Down Expand Up @@ -292,7 +292,7 @@ jobs:
steps:
- uses: actions/checkout@v2
- name: Cache conda
uses: actions/cache@v1
uses: actions/cache@v3
env:
# Increase this value to reset cache if environment-test-py39.yml has not changed
CACHE_NUMBER: 0
Expand All @@ -301,7 +301,7 @@ jobs:
key: ${{ runner.os }}-conda-${{ env.CACHE_NUMBER }}-${{
hashFiles('conda-envs/environment-test-py39.yml') }}
- name: Cache multiple paths
uses: actions/cache@v2
uses: actions/cache@v3
env:
# Increase this value to reset cache if requirements.txt has not changed
CACHE_NUMBER: 0
Expand Down Expand Up @@ -359,7 +359,7 @@ jobs:
steps:
- uses: actions/checkout@v2
- name: Cache conda
uses: actions/cache@v1
uses: actions/cache@v3
env:
# Increase this value to reset cache if conda-envs/environment-test-py38.yml has not changed
CACHE_NUMBER: 0
Expand All @@ -368,7 +368,7 @@ jobs:
key: ${{ runner.os }}-conda-${{ env.CACHE_NUMBER }}-${{
hashFiles('conda-envs/windows-environment-test-py38.yml') }}
- name: Cache multiple paths
uses: actions/cache@v2
uses: actions/cache@v3
env:
# Increase this value to reset cache if requirements.txt has not changed
CACHE_NUMBER: 0
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ This guide provides an overview on how to implement a distribution for version 4
It is designed for developers who wish to add a new distribution to the library.
Users will not be aware of all this complexity and should instead make use of helper methods such as `~pymc.distributions.DensityDist`.

PyMC {class}`~pymc.distributions.Distribution` builds on top of Aesara's {class}`~aesara.tensor.random.op.RandomVariable`, and implements `logp`, `logcdf` and `get_moment` methods as well as other initialization and validation helpers.
PyMC {class}`~pymc.distributions.Distribution` builds on top of Aesara's {class}`~aesara.tensor.random.op.RandomVariable`, and implements `logp`, `logcdf` and `moment` methods as well as other initialization and validation helpers.
Most notably `shape/dims` kwargs, alternative parametrizations, and default `transforms`.

Here is a summary check-list of the steps needed to implement a new distribution.
Expand All @@ -13,7 +13,7 @@ Each section will be expanded below:
1. Creating a new `RandomVariable` `Op`
1. Implementing the corresponding `Distribution` class
1. Adding tests for the new `RandomVariable`
1. Adding tests for `logp` / `logcdf` and `get_moment` methods
1. Adding tests for `logp` / `logcdf` and `moment` methods
1. Documenting the new `Distribution`.

This guide does not attempt to explain the rationale behind the `Distributions` current implementation, and details are provided only insofar as they help to implement new "standard" distributions.
Expand Down Expand Up @@ -119,7 +119,7 @@ After implementing the new `RandomVariable` `Op`, it's time to make use of it in
PyMC 4.x works in a very {term}`functional <Functional Programming>` way, and the `distribution` classes are there mostly to facilitate porting the `PyMC3` v3.x code to the new `PyMC` v4.x version, add PyMC API features and keep related methods organized together.
In practice, they take care of:

1. Linking ({term}`Dispatching`) a rv_op class with the corresponding `get_moment`, `logp` and `logcdf` methods.
1. Linking ({term}`Dispatching`) a rv_op class with the corresponding `moment`, `logp` and `logcdf` methods.
1. Defining a standard transformation (for continuous distributions) that converts a bounded variable domain (e.g., positive line) to an unbounded domain (i.e., the real line), which many samplers prefer.
1. Validating the parametrization of a distribution and converting non-symbolic inputs (i.e., numeric literals or numpy arrays) to symbolic variables.
1. Converting multiple alternative parametrizations to the standard parametrization that the `RandomVariable` is defined in terms of.
Expand Down Expand Up @@ -154,9 +154,9 @@ class Blah(PositiveContinuous):
# the rv_op needs in order to be instantiated
return super().dist([param1, param2], **kwargs)

# get_moment returns a symbolic expression for the stable moment from which to start sampling
# moment returns a symbolic expression for the stable moment from which to start sampling
# the variable, given the implicit `rv`, `size` and `param1` ... `paramN`
def get_moment(rv, size, param1, param2):
def moment(rv, size, param1, param2):
moment, _ = at.broadcast_arrays(param1, param2)
if not rv_size_is_none(size):
moment = at.full(size, moment)
Expand Down Expand Up @@ -193,30 +193,29 @@ class Blah(PositiveContinuous):

Some notes:

1. A distribution should at the very least inherit from {class}`~pymc.distributions.Discrete` or {class}`~pymc.distributions.Continuous`. For the latter, more specific subclasses exist: `PositiveContinuous`, `UnitContinuous`, `BoundedContinuous`, `CircularContinuous`, which specify default transformations for the variables. If you need to specify a one-time custom transform you can also override the `__new__` method, as is done for the {class}`~pymc.distributions.multivariate.Dirichlet`.
1. If a distribution does not have a corresponding `random` implementation, a `RandomVariable` should still be created that raises a `NotImplementedError`. This is the case for the {class}`~pymc.distributions.continuous.Flat`. In this case it will be necessary to provide a standard `initval` by
overriding `__new__`.
1. A distribution should at the very least inherit from {class}`~pymc.distributions.Discrete` or {class}`~pymc.distributions.Continuous`. For the latter, more specific subclasses exist: `PositiveContinuous`, `UnitContinuous`, `BoundedContinuous`, `CircularContinuous`, `SimplexContinuous`, which specify default transformations for the variables. If you need to specify a one-time custom transform you can also create a `_default_transform` dispatch function as is done for the {class}`~pymc.distributions.multivariate.LKJCholeskyCov`.
1. If a distribution does not have a corresponding `random` implementation, a `RandomVariable` should still be created that raises a `NotImplementedError`. This is the case for the {class}`~pymc.distributions.continuous.Flat`. In this case it will be necessary to provide a `moment` method.
1. As mentioned above, `PyMC` v4.x works in a very {term}`functional <Functional Programming>` way, and all the information that is needed in the `logp` and `logcdf` methods is expected to be "carried" via the `RandomVariable` inputs. You may pass numerical arguments that are not strictly needed for the `rng_fn` method but are used in the `logp` and `logcdf` methods. Just keep in mind whether this affects the correct shape inference behavior of the `RandomVariable`. If specialized non-numeric information is needed you might need to define your custom`_logp` and `_logcdf` {term}`Dispatching` functions, but this should be done as a last resort.
1. The `logcdf` method is not a requirement, but it's a nice plus!
1. Currently only one moment is supported in the `get_moment` method, and probably the "higher-order" one is the most useful (that is `mean` > `median` > `mode`)... You might need to truncate the moment if you are dealing with a discrete distribution.
1. When creating the `get_moment` method, we have to be careful with `size != None` and broadcast properly when some parameters that are not used in the moment may nevertheless inform about the shape of the distribution. E.g. `pm.Normal.dist(mu=0, sigma=np.arange(1, 6))` returns a moment of `[mu, mu, mu, mu, mu]`.
1. Currently only one moment is supported in the `moment` method, and probably the "higher-order" one is the most useful (that is `mean` > `median` > `mode`)... You might need to truncate the moment if you are dealing with a discrete distribution.
1. When creating the `moment` method, we have to be careful with `size != None` and broadcast properly when some parameters that are not used in the moment may nevertheless inform about the shape of the distribution. E.g. `pm.Normal.dist(mu=0, sigma=np.arange(1, 6))` returns a moment of `[mu, mu, mu, mu, mu]`.

For a quick check that things are working you can try the following:

```python

import pymc as pm
from pymc.distributions.distribution import get_moment
from pymc.distributions.distribution import moment

# pm.blah = pm.Normal in this example
blah = pm.blah.dist(mu = 0, sigma = 1)
blah = pm.blah.dist(mu=0, sigma=1)

# Test that the returned blah_op is still working fine
blah.eval()
# array(-1.01397228)

# Test the get_moment method
get_moment(blah).eval()
# Test the moment method
moment(blah).eval()
# array(0.)

# Test the logp method
Expand Down Expand Up @@ -367,9 +366,9 @@ def test_blah_logcdf(self):

```

## 5. Adding tests for the `get_moment` method
## 5. Adding tests for the `moment` method

Tests for the `get_moment` method are contained in `pymc/tests/test_distributions_moments.py`, and make use of the function `assert_moment_is_expected`
Tests for the `moment` method are contained in `pymc/tests/test_distributions_moments.py`, and make use of the function `assert_moment_is_expected`
which checks if:
1. Moments return the `expected` values
1. Moments have the expected size and shape
Expand Down
12 changes: 9 additions & 3 deletions pymc/distributions/bound.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,11 +20,12 @@
from aesara.tensor.var import TensorVariable

from pymc.aesaraf import floatX, intX
from pymc.distributions.continuous import BoundedContinuous
from pymc.distributions.continuous import BoundedContinuous, bounded_cont_transform
from pymc.distributions.dist_math import check_parameters
from pymc.distributions.distribution import Continuous, Discrete
from pymc.distributions.logprob import logp
from pymc.distributions.shape_utils import to_tuple
from pymc.distributions.transforms import _default_transform
from pymc.model import modelcontext
from pymc.util import check_dist_not_registered

Expand Down Expand Up @@ -82,6 +83,11 @@ def logp(value, distribution, lower, upper):
)


@_default_transform.register(BoundRV)
def bound_default_transform(op, rv):
return bounded_cont_transform(op, rv, _ContinuousBounded.bound_args_indices)


class DiscreteBoundRV(BoundRV):
name = "discrete_bound"
dtype = "int64"
Expand All @@ -94,8 +100,8 @@ class _DiscreteBounded(Discrete):
rv_op = discrete_boundrv

def __new__(cls, *args, **kwargs):
transform = kwargs.get("transform", None)
if transform is not None:
kwargs.setdefault("transform", None)
if kwargs.get("transform") is not None:
raise ValueError("Cannot transform discrete variable.")
return super().__new__(cls, *args, **kwargs)

Expand Down
6 changes: 3 additions & 3 deletions pymc/distributions/censored.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@
from aesara.tensor import TensorVariable
from aesara.tensor.random.op import RandomVariable

from pymc.distributions.distribution import SymbolicDistribution, _get_moment
from pymc.distributions.distribution import SymbolicDistribution, _moment
from pymc.util import check_dist_not_registered


Expand Down Expand Up @@ -124,8 +124,8 @@ def graph_rvs(cls, rv):
return (rv.tag.dist,)


@_get_moment.register(Clip)
def get_moment_censored(op, rv, dist, lower, upper):
@_moment.register(Clip)
def moment_censored(op, rv, dist, lower, upper):
moment = at.switch(
at.eq(lower, -np.inf),
at.switch(
Expand Down
Loading

0 comments on commit f33fc00

Please sign in to comment.