Skip to content

Commit

Permalink
Refactor Model, Uniform, and Normal so that they work with RandomVariables
Browse files Browse the repository at this point in the history
  • Loading branch information
brandonwillard committed Jan 25, 2021
1 parent 823906a commit eb1e01f
Show file tree
Hide file tree
Showing 10 changed files with 425 additions and 326 deletions.
3 changes: 2 additions & 1 deletion pymc3/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,8 @@ def __set_compiler_flags():
from pymc3.distributions import *
from pymc3.distributions import transforms
from pymc3.exceptions import *
from pymc3.glm import *

# from pymc3.glm import *
from pymc3.math import (
expand_packed_triangular,
invlogit,
Expand Down
149 changes: 132 additions & 17 deletions pymc3/distributions/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,8 +11,125 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from functools import singledispatch
from typing import Optional, Tuple

from pymc3.distributions import shape_utils, timeseries, transforms
import theano.tensor as tt

from theano.tensor.random.op import Observed
from theano.tensor.var import TensorVariable


def get_rv_var_and_value(
    rv_var: TensorVariable,
    rv_value: Optional[TensorVariable] = None,
) -> Tuple[TensorVariable, TensorVariable]:
    """Resolve a random variable and the value variable at which it is evaluated.

    Parameters
    ----------
    rv_var
        A random variable.  If it is the output of an ``Observed`` ``Op``,
        the observed value is extracted from it.
    rv_value
        The value variable.  If ``None``, it is taken from ``rv_var``
        (either from an ``Observed`` node or from ``rv_var.tag.value_var``).

    Returns
    -------
    A ``(rv_var, rv_value)`` pair.

    Raises
    ------
    ValueError
        If no value variable is given and none can be found on ``rv_var``.
    """
    if rv_value is None:
        if rv_var.owner and isinstance(rv_var.owner.op, Observed):
            # `Observed` nodes pair the random variable with its data.
            rv_var, rv_value = rv_var.owner.inputs
        elif hasattr(rv_var.tag, "value_var"):
            rv_value = rv_var.tag.value_var
        else:
            raise ValueError("value is unspecified")

    return rv_var, rv_value


def logpt(
    rv_var: TensorVariable,
    rv_value: Optional[TensorVariable] = None,
    jacobian: bool = True,
    scaling: Optional[TensorVariable] = None,
    **kwargs
) -> TensorVariable:
    """Get a graph of the log-likelihood for a random variable at a point.

    Parameters
    ----------
    rv_var
        The output of a ``RandomVariable`` ``Op`` (or an ``Observed`` node).
    rv_value
        The point at which the log-likelihood is evaluated.  If ``None``, it
        is resolved from ``rv_var`` (see ``get_rv_var_and_value``).
    jacobian
        Whether to include Jacobian terms for transformed variables.
    scaling
        Optional factor by which the log-likelihood is multiplied.

    Raises
    ------
    TypeError
        If ``rv_var`` is not the output of a ``RandomVariable`` ``Op``.
    """
    rv_var, rv_value = get_rv_var_and_value(rv_var, rv_value)
    rv_node = rv_var.owner

    if not rv_node:
        raise TypeError("rv_var must be the output of a RandomVariable Op")

    # A `RandomVariable`'s first three inputs are the RNG, the size, and the
    # dtype; the remaining inputs are the distribution parameters.
    rng, size, dtype, *dist_params = rv_node.inputs

    if jacobian:
        logp_var = _logp(rv_node.op, rv_value, *dist_params, **kwargs)
    else:
        logp_var = _logp_nojac(rv_node.op, rv_value, *dist_params, **kwargs)

    # `scaling` may be a symbolic variable, whose truth value cannot be
    # determined, so test against `None` explicitly instead of `if scaling:`.
    if scaling is not None:
        logp_var *= scaling

    if rv_var.name is not None:
        logp_var.name = "__logp_%s" % rv_var.name

    return logp_var


@singledispatch
def _logp(op, value, *dist_params, **kwargs):
    """Fallback log-likelihood: `Op`s with no registered `_logp` contribute zero."""
    zero_logp = tt.zeros_like(value)
    return zero_logp


def logcdf(rv_var, rv_value, **kwargs):
    """Create a graph of the log-CDF for a random variable at a point.

    Parameters
    ----------
    rv_var
        The output of a ``RandomVariable`` ``Op`` (or an ``Observed`` node).
    rv_value
        The point at which the log-CDF is evaluated.

    Raises
    ------
    TypeError
        If ``rv_var`` is not the output of a ``RandomVariable`` ``Op``.
    """
    rv_var, rv_value = get_rv_var_and_value(rv_var, rv_value)
    rv_node = rv_var.owner

    if not rv_node:
        # Match the error raised by `logpt` for consistency.
        raise TypeError("rv_var must be the output of a RandomVariable Op")

    # The first three inputs are the RNG, size, and dtype; the rest are the
    # distribution parameters.
    rng, size, dtype, *dist_params = rv_node.inputs

    return _logcdf(rv_node.op, rv_value, *dist_params, **kwargs)


@singledispatch
def _logcdf(op, value, *args, **kwargs):
raise NotImplementedError()


def logp_nojac(rv_var, rv_value=None, **kwargs):
    """Create a graph of the log-likelihood, without Jacobian terms, at a point.

    Parameters
    ----------
    rv_var
        The output of a ``RandomVariable`` ``Op`` (or an ``Observed`` node).
    rv_value
        The point at which the log-likelihood is evaluated.  If ``None``, it
        is resolved from ``rv_var`` (see ``get_rv_var_and_value``).

    Raises
    ------
    TypeError
        If ``rv_var`` is not the output of a ``RandomVariable`` ``Op``.
    """
    rv_var, rv_value = get_rv_var_and_value(rv_var, rv_value)
    rv_node = rv_var.owner

    if not rv_node:
        # Match the error raised by `logpt` for consistency.
        raise TypeError("rv_var must be the output of a RandomVariable Op")

    # The first three inputs are the RNG, size, and dtype; the rest are the
    # distribution parameters.
    rng, size, dtype, *dist_params = rv_node.inputs

    # The distribution parameters must be forwarded to the dispatcher, just
    # as `logpt` does; the original dropped them.
    return _logp_nojac(rv_node.op, rv_value, *dist_params, **kwargs)


@singledispatch
def _logp_nojac(op, value, *args, **kwargs):
    """Return the logp, but do not include a jacobian term for transforms.

    If we use different parametrizations for the same distribution, we
    need to add the determinant of the jacobian of the transformation
    to make sure the densities still describe the same distribution.
    However, MAP estimates are not invariant with respect to the
    parameterization, we need to exclude the jacobian terms in this case.

    This function should be overwritten in base classes for transformed
    distributions.
    """
    # With no transform registered, the plain logp applies.  Dispatch on the
    # `Op` via `_logp` directly: `logpt` expects the `TensorVariable` output
    # of a `RandomVariable`, not an `Op`, so calling it with `op` would fail.
    return _logp(op, value, *args, **kwargs)


def logp_sum(var, value, *args, **kwargs):
    """Return the sum of the logp values for the given observations.

    Subclasses can use this to improve the speed of logp evaluations
    if only the sum of the logp values is needed.
    """
    logp_terms = logpt(var, value, *args, **kwargs)
    return tt.sum(logp_terms)


# from pymc3.distributions import timeseries
from pymc3.distributions import shape_utils, transforms
from pymc3.distributions.bart import BART
from pymc3.distributions.bound import Bound
from pymc3.distributions.continuous import (
Expand Down Expand Up @@ -74,7 +191,6 @@
Discrete,
Distribution,
NoDistribution,
TensorType,
draw_values,
generate_samples,
)
Expand All @@ -94,15 +210,15 @@
)
from pymc3.distributions.posterior_predictive import fast_sample_posterior_predictive
from pymc3.distributions.simulator import Simulator
from pymc3.distributions.timeseries import (
AR,
AR1,
GARCH11,
GaussianRandomWalk,
MvGaussianRandomWalk,
MvStudentTRandomWalk,
)

# from pymc3.distributions.timeseries import (
# AR,
# AR1,
# GARCH11,
# GaussianRandomWalk,
# MvGaussianRandomWalk,
# MvStudentTRandomWalk,
# )
__all__ = [
"Uniform",
"Flat",
Expand Down Expand Up @@ -149,7 +265,6 @@
"Continuous",
"Discrete",
"NoDistribution",
"TensorType",
"MvNormal",
"MatrixNormal",
"KroneckerNormal",
Expand All @@ -161,13 +276,13 @@
"WishartBartlett",
"LKJCholeskyCov",
"LKJCorr",
"AR1",
"AR",
# "AR1",
# "AR",
"AsymmetricLaplace",
"GaussianRandomWalk",
"MvGaussianRandomWalk",
"MvStudentTRandomWalk",
"GARCH11",
# "GaussianRandomWalk",
# "MvGaussianRandomWalk",
# "MvStudentTRandomWalk",
# "GARCH11",
"SkewNormal",
"Mixture",
"NormalMixture",
Expand Down
Loading

0 comments on commit eb1e01f

Please sign in to comment.