Skip to content

Commit

Permalink
Upgrade to aesara=2.8.2 and aeppl=0.0.35
Browse files Browse the repository at this point in the history
  • Loading branch information
Armavica authored and ricardoV94 committed Aug 25, 2022
1 parent 5a7c827 commit 4ca3414
Show file tree
Hide file tree
Showing 18 changed files with 51 additions and 35 deletions.
4 changes: 2 additions & 2 deletions .pre-commit-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -26,8 +26,8 @@ repos:
- types-filelock
- types-setuptools
- arviz
- aesara==2.7.9
- aeppl==0.0.34
- aesara==2.8.2
- aeppl==0.0.35
always_run: true
require_serial: true
pass_filenames: false
Expand Down
4 changes: 2 additions & 2 deletions conda-envs/environment-dev.yml
Original file line number Diff line number Diff line change
Expand Up @@ -5,8 +5,8 @@ channels:
- defaults
dependencies:
# Base dependencies
- aeppl=0.0.34
- aesara=2.7.9
- aeppl=0.0.35
- aesara=2.8.2
- arviz>=0.12.0
- blas
- cachetools>=4.2.1
Expand Down
4 changes: 2 additions & 2 deletions conda-envs/environment-test.yml
Original file line number Diff line number Diff line change
Expand Up @@ -5,8 +5,8 @@ channels:
- defaults
dependencies:
# Base dependencies
- aeppl=0.0.34
- aesara=2.7.9
- aeppl=0.0.35
- aesara=2.8.2
- arviz>=0.12.0
- blas
- cachetools>=4.2.1
Expand Down
4 changes: 2 additions & 2 deletions conda-envs/windows-environment-dev.yml
Original file line number Diff line number Diff line change
Expand Up @@ -5,8 +5,8 @@ channels:
- defaults
dependencies:
# Base dependencies (see install guide for Windows)
- aeppl=0.0.34
- aesara=2.7.9
- aeppl=0.0.35
- aesara=2.8.2
- arviz>=0.12.0
- blas
- cachetools>=4.2.1
Expand Down
4 changes: 2 additions & 2 deletions conda-envs/windows-environment-test.yml
Original file line number Diff line number Diff line change
Expand Up @@ -5,8 +5,8 @@ channels:
- defaults
dependencies:
# Base dependencies (see install guide for Windows)
- aeppl=0.0.34
- aesara=2.7.9
- aeppl=0.0.35
- aesara=2.8.2
- arviz>=0.12.0
- blas
- cachetools>=4.2.1
Expand Down
6 changes: 3 additions & 3 deletions pymc/aesaraf.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@
from aesara import config, scalar
from aesara.compile.mode import Mode, get_mode
from aesara.gradient import grad
from aesara.graph import local_optimizer
from aesara.graph import node_rewriter
from aesara.graph.basic import (
Apply,
Constant,
Expand Down Expand Up @@ -875,7 +875,7 @@ def largest_common_dtype(tensors):
return np.stack([np.ones((), dtype=dtype) for dtype in dtypes]).dtype


@local_optimizer(tracks=[CheckParameterValue])
@node_rewriter(tracks=[CheckParameterValue])
def local_remove_check_parameter(fgraph, node):
"""Rewrite that removes Aeppl's CheckParameterValue
Expand All @@ -885,7 +885,7 @@ def local_remove_check_parameter(fgraph, node):
return [node.inputs[0]]


@local_optimizer(tracks=[CheckParameterValue])
@node_rewriter(tracks=[CheckParameterValue])
def local_check_parameter_to_ninf_switch(fgraph, node):
if isinstance(node.op, CheckParameterValue):
logp_expr, *logp_conds = node.inputs
Expand Down
26 changes: 21 additions & 5 deletions pymc/distributions/continuous.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,6 @@
from aesara.tensor.math import tanh
from aesara.tensor.random.basic import (
BetaRV,
WeibullRV,
cauchy,
chisquare,
exponential,
Expand Down Expand Up @@ -1464,7 +1463,7 @@ def dist(cls, lam, *args, **kwargs):
lam = at.as_tensor_variable(floatX(lam))

# Aesara exponential op is parametrized in terms of mu (1/lam)
return super().dist([at.inv(lam)], **kwargs)
return super().dist([at.reciprocal(lam)], **kwargs)

def moment(rv, size, mu):
if not rv_size_is_none(size):
Expand All @@ -1487,7 +1486,7 @@ def logcdf(value, mu):
-------
TensorVariable
"""
lam = at.inv(mu)
lam = at.reciprocal(mu)
res = at.switch(
at.lt(value, 0),
-np.inf,
Expand Down Expand Up @@ -2313,7 +2312,7 @@ def logcdf(value, alpha, inv_beta):
-------
TensorVariable
"""
beta = at.inv(inv_beta)
beta = at.reciprocal(inv_beta)
res = at.switch(
at.lt(value, 0),
-np.inf,
Expand Down Expand Up @@ -2518,8 +2517,15 @@ def logcdf(value, nu):


# TODO: Remove this once logp for multiplication is working!
class WeibullBetaRV(WeibullRV):
class WeibullBetaRV(RandomVariable):
name = "weibull"
ndim_supp = 0
ndims_params = [0, 0]
dtype = "floatX"
_print_name = ("Weibull", "\\operatorname{Weibull}")

def __call__(self, alpha, beta, size=None, **kwargs):
return super().__call__(alpha, beta, size=size, **kwargs)

@classmethod
def rng_fn(cls, rng, alpha, beta, size) -> np.ndarray:
Expand Down Expand Up @@ -2615,6 +2621,16 @@ def logcdf(value, alpha, beta):

return check_parameters(res, 0 < alpha, 0 < beta, msg="alpha > 0, beta > 0")

def logp(value, alpha, beta):
res = (
at.log(alpha)
- at.log(beta)
+ (alpha - 1.0) * at.log(value / beta)
- at.pow(value / beta, alpha)
)
res = at.switch(at.ge(value, 0.0), res, -np.inf)
return check_parameters(res, 0 < alpha, 0 < beta, msg="alpha > 0, beta > 0")


class HalfStudentTRV(RandomVariable):
name = "halfstudentt"
Expand Down
4 changes: 2 additions & 2 deletions pymc/distributions/dist_math.py
Original file line number Diff line number Diff line change
Expand Up @@ -157,7 +157,7 @@ def sigma2rho(sigma):
"""
`sigma -> rho` Aesara converter
:math:`mu + sigma*e = mu + log(1+exp(rho))*e`"""
return at.log(at.exp(at.abs_(sigma)) - 1.0)
return at.log(at.exp(at.abs(sigma)) - 1.0)


def rho2sigma(rho):
Expand Down Expand Up @@ -213,7 +213,7 @@ def log_normal(x, mean, **kwargs):
else:
std = tau ** (-1)
std += f(eps)
return f(c) - at.log(at.abs_(std)) - (x - mean) ** 2 / (2.0 * std**2)
return f(c) - at.log(at.abs(std)) - (x - mean) ** 2 / (2.0 * std**2)


def MvNormalLogp():
Expand Down
4 changes: 2 additions & 2 deletions pymc/distributions/logprob.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@
from aeppl.abstract import assign_custom_measurable_outputs
from aeppl.logprob import logcdf as logcdf_aeppl
from aeppl.logprob import logprob as logp_aeppl
from aeppl.transforms import TransformValuesOpt
from aeppl.transforms import TransformValuesRewrite
from aesara.graph.basic import graph_inputs, io_toposort
from aesara.tensor.random.op import RandomVariable
from aesara.tensor.subtensor import (
Expand Down Expand Up @@ -231,7 +231,7 @@ def joint_logp(
if original_value_var is not None and hasattr(original_value_var.tag, "transform"):
transform_map[value_var] = original_value_var.tag.transform

transform_opt = TransformValuesOpt(transform_map)
transform_opt = TransformValuesRewrite(transform_map)
temp_logp_var_dict = factorized_joint_logprob(
tmp_rvs_to_values,
extra_rewrites=transform_opt,
Expand Down
2 changes: 1 addition & 1 deletion pymc/distributions/multivariate.py
Original file line number Diff line number Diff line change
Expand Up @@ -539,7 +539,7 @@ def moment(rv, size, n, p):
n = at.shape_padright(n)
mode = at.round(n * p)
diff = n - at.sum(mode, axis=-1, keepdims=True)
inc_bool_arr = at.abs_(diff) > 0
inc_bool_arr = at.abs(diff) > 0
mode = at.inc_subtensor(mode[inc_bool_arr.nonzero()], diff[inc_bool_arr.nonzero()])
if not rv_size_is_none(size):
output_size = at.concatenate([size, [p.shape[-1]]])
Expand Down
2 changes: 1 addition & 1 deletion pymc/distributions/simulator.py
Original file line number Diff line number Diff line change
Expand Up @@ -272,7 +272,7 @@ def gaussian(epsilon, obs_data, sim_data):

def laplace(epsilon, obs_data, sim_data):
"""Laplace kernel."""
return -at.abs_((obs_data - sim_data) / epsilon)
return -at.abs((obs_data - sim_data) / epsilon)


class KullbackLeibler:
Expand Down
4 changes: 2 additions & 2 deletions pymc/distributions/timeseries.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@
from aeppl.logprob import _logprob
from aesara import scan
from aesara.compile.builders import OpFromGraph
from aesara.graph import FunctionGraph, optimize_graph
from aesara.graph import FunctionGraph, rewrite_graph
from aesara.graph.basic import Node
from aesara.raise_op import Assert
from aesara.tensor import TensorVariable
Expand Down Expand Up @@ -495,7 +495,7 @@ def _get_ar_order(cls, rhos: TensorVariable, ar_order: Optional[int], constant:
features=[ShapeFeature()],
clone=True,
)
(folded_shape,) = optimize_graph(shape_fg, custom_opt=topo_constant_folding).outputs
(folded_shape,) = rewrite_graph(shape_fg, custom_opt=topo_constant_folding).outputs
folded_shape = getattr(folded_shape, "data", None)
if folded_shape is None:
raise ValueError(
Expand Down
2 changes: 1 addition & 1 deletion pymc/gp/cov.py
Original file line number Diff line number Diff line change
Expand Up @@ -347,7 +347,7 @@ def dist(self, X, Xs):
Xs = at.transpose(X)
else:
Xs = at.transpose(Xs)
return at.abs_((X - Xs + self.c) % (self.c * 2) - self.c)
return at.abs((X - Xs + self.c) % (self.c * 2) - self.c)

def weinland(self, t):
return (1 + self.tau * t / self.c) * at.clip(1 - t / self.c, 0, np.inf) ** self.tau
Expand Down
4 changes: 2 additions & 2 deletions pymc/math.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@

# pylint: disable=unused-import
from aesara.tensor import (
abs_,
abs,
and_,
ceil,
clip,
Expand Down Expand Up @@ -90,7 +90,7 @@
# pylint: enable=unused-import

__all__ = [
"abs_",
"abs",
"and_",
"ceil",
"clip",
Expand Down
2 changes: 1 addition & 1 deletion pymc/sampling_jax.py
Original file line number Diff line number Diff line change
Expand Up @@ -93,7 +93,7 @@ def get_jaxified_graph(
if not (hasattr(fgraph, "destroyers") and fgraph.has_destroyers([input]))
)
)
mode.JAX.optimizer.optimize(fgraph)
mode.JAX.optimizer.rewrite(fgraph)

# We now jaxify the optimized fgraph
return jax_funcify(fgraph)
Expand Down
2 changes: 1 addition & 1 deletion pymc/tests/test_transforms.py
Original file line number Diff line number Diff line change
Expand Up @@ -100,7 +100,7 @@ def check_jacobian_det(
if not elemwise:
jac = at.log(at.nlinalg.det(jacobian(x, [y])))
else:
jac = at.log(at.abs_(at.diag(jacobian(x, [y]))))
jac = at.log(at.abs(at.diag(jacobian(x, [y]))))

# ljd = log jacobian det
actual_ljd = aesara.function([y], jac)
Expand Down
4 changes: 2 additions & 2 deletions requirements-dev.txt
Original file line number Diff line number Diff line change
@@ -1,8 +1,8 @@
# This file is auto-generated by scripts/generate_pip_deps_from_conda.py, do not modify.
# See that file for comments about the need/usage of each dependency.

aeppl==0.0.34
aesara==2.7.9
aeppl==0.0.35
aesara==2.8.2
arviz>=0.12.0
cachetools>=4.2.1
cloudpickle
Expand Down
4 changes: 2 additions & 2 deletions requirements.txt
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
aeppl==0.0.34
aesara==2.7.9
aeppl==0.0.35
aesara==2.8.2
arviz>=0.12.0
cachetools>=4.2.1
cloudpickle
Expand Down

0 comments on commit 4ca3414

Please sign in to comment.