From 9e09aa44b2d83c2f9aa432973d859f614adadcf3 Mon Sep 17 00:00:00 2001
From: Michael Osthege
Date: Sat, 19 Nov 2022 01:51:15 +0100
Subject: [PATCH] Fix remaining type problems in `metropolis.py`

---
 pymc/aesaraf.py                 |  4 +-
 pymc/step_methods/arraystep.py  |  2 +-
 pymc/step_methods/metropolis.py | 95 ++++++++++++++-------------------
 scripts/run_mypy.py             |  1 +
 4 files changed, 43 insertions(+), 59 deletions(-)

diff --git a/pymc/aesaraf.py b/pymc/aesaraf.py
index 37aba7402e7..0fdb53acfdf 100644
--- a/pymc/aesaraf.py
+++ b/pymc/aesaraf.py
@@ -35,7 +35,7 @@
 from aeppl.logprob import CheckParameterValue
 from aeppl.transforms import RVTransform
 from aesara import scalar
-from aesara.compile.mode import Mode, get_mode
+from aesara.compile import Function, Mode, get_mode
 from aesara.gradient import grad
 from aesara.graph import node_rewriter, rewrite_graph
 from aesara.graph.basic import (
@@ -1044,7 +1044,7 @@ def compile_pymc(
     random_seed: SeedSequenceSeed = None,
     mode=None,
     **kwargs,
-) -> Callable[..., Union[np.ndarray, List[np.ndarray]]]:
+) -> Function:
     """Use ``aesara.function`` with specialized pymc rewrites always enabled.

     This function also ensures shared RandomState/Generator used by RandomVariables
diff --git a/pymc/step_methods/arraystep.py b/pymc/step_methods/arraystep.py
index dad7fb27943..d4055dd43ed 100644
--- a/pymc/step_methods/arraystep.py
+++ b/pymc/step_methods/arraystep.py
@@ -273,7 +273,7 @@ def step(self, point) -> Tuple[PointType, StatsType]:
         return super().step(point)


-def metrop_select(mr, q, q0):
+def metrop_select(mr: np.ndarray, q: np.ndarray, q0: np.ndarray) -> Tuple[np.ndarray, bool]:
     """Perform rejection/acceptance step for Metropolis class samplers.

     Returns the new sample q if a uniform random number is less than the
diff --git a/pymc/step_methods/metropolis.py b/pymc/step_methods/metropolis.py
index b0243e2b0e6..d3d2b8f036c 100644
--- a/pymc/step_methods/metropolis.py
+++ b/pymc/step_methods/metropolis.py
@@ -13,7 +13,6 @@
 # limitations under the License.

 from typing import Callable, Dict, List, Optional, Tuple
-import aesara
 import numpy as np
 import numpy.random as nr
 import scipy.linalg
@@ -68,38 +67,29 @@ def __init__(self, s):

 class NormalProposal(Proposal):
     def __call__(self, rng: Optional[np.random.Generator] = None):
-        if rng is None:
-            rng = nr
-        return rng.normal(scale=self.s)
+        return (rng or nr).normal(scale=self.s)


 class UniformProposal(Proposal):
     def __call__(self, rng: Optional[np.random.Generator] = None):
-        if rng is None:
-            rng = nr
-        return rng.uniform(low=-self.s, high=self.s, size=len(self.s))
+        return (rng or nr).uniform(low=-self.s, high=self.s, size=len(self.s))


 class CauchyProposal(Proposal):
     def __call__(self, rng: Optional[np.random.Generator] = None):
-        if rng is None:
-            rng = nr
-        return rng.standard_cauchy(size=np.size(self.s)) * self.s
+        return (rng or nr).standard_cauchy(size=np.size(self.s)) * self.s


 class LaplaceProposal(Proposal):
     def __call__(self, rng: Optional[np.random.Generator] = None):
-        if rng is None:
-            rng = nr
         size = np.size(self.s)
-        return (rng.standard_exponential(size=size) - rng.standard_exponential(size=size)) * self.s
+        r = rng or nr
+        return (r.standard_exponential(size=size) - r.standard_exponential(size=size)) * self.s


 class PoissonProposal(Proposal):
     def __call__(self, rng: Optional[np.random.Generator] = None):
-        if rng is None:
-            rng = nr
-        return rng.poisson(lam=self.s, size=np.size(self.s)) - self.s
+        return (rng or nr).poisson(lam=self.s, size=np.size(self.s)) - self.s


 class MultivariateNormalProposal(Proposal):
@@ -111,13 +101,12 @@ def __init__(self, s):
         self.chol = scipy.linalg.cholesky(s, lower=True)

     def __call__(self, num_draws=None, rng: Optional[np.random.Generator] = None):
-        if rng is None:
-            rng = nr
+        rng_ = rng or nr
         if num_draws is not None:
-            b = rng.normal(size=(self.n, num_draws))
+            b = rng_.normal(size=(self.n, num_draws))
             return np.dot(self.chol, b).T
         else:
-            b = rng.normal(size=self.n)
+            b = rng_.normal(size=self.n)
             return np.dot(self.chol, b)


@@ -247,7 +236,7 @@ def reset_tuning(self):

     def astep(self, q0: RaveledVars) -> Tuple[RaveledVars, StatsType]:
         point_map_info = q0.point_map_info
-        q0 = q0.data
+        q0d = q0.data

         if not self.steps_until_tune and self.tune:
             # Tune scaling parameter
@@ -261,30 +250,30 @@ def astep(self, q0: RaveledVars) -> Tuple[RaveledVars, StatsType]:
         if self.any_discrete:
             if self.all_discrete:
                 delta = np.round(delta, 0).astype("int64")
-                q0 = q0.astype("int64")
-                q = (q0 + delta).astype("int64")
+                q0d = q0d.astype("int64")
+                q = (q0d + delta).astype("int64")
             else:
                 delta[self.discrete] = np.round(delta[self.discrete], 0)
-                q = q0 + delta
+                q = q0d + delta
         else:
-            q = floatX(q0 + delta)
+            q = floatX(q0d + delta)

         if self.elemwise_update:
-            q_temp = q0.copy()
+            q_temp = q0d.copy()
             # Shuffle order of updates (probably we don't need to do this in every step)
             np.random.shuffle(self.enum_dims)
             for i in self.enum_dims:
                 q_temp[i] = q[i]
-                accept_rate_i = self.delta_logp(q_temp, q0)
-                q_temp_, accepted_i = metrop_select(accept_rate_i, q_temp, q0)
+                accept_rate_i = self.delta_logp(q_temp, q0d)
+                q_temp_, accepted_i = metrop_select(accept_rate_i, q_temp, q0d)
                 q_temp[i] = q_temp_[i]
                 self.accept_rate_iter[i] = accept_rate_i
                 self.accepted_iter[i] = accepted_i
                 self.accepted_sum[i] += accepted_i
             q = q_temp
         else:
-            accept_rate = self.delta_logp(q, q0)
-            q, accepted = metrop_select(accept_rate, q, q0)
+            accept_rate = self.delta_logp(q, q0d)
+            q, accepted = metrop_select(accept_rate, q, q0d)
             self.accept_rate_iter = accept_rate
             self.accepted_iter = accepted
             self.accepted_sum += accepted
@@ -399,11 +388,11 @@ def __init__(self, vars, scaling=1.0, tune=True, tune_interval=100, model=None):

         super().__init__(vars, [model.compile_logp()])

-    def astep(self, q0: RaveledVars, *args) -> Tuple[RaveledVars, StatsType]:
+    def astep(self, apoint: RaveledVars, *args) -> Tuple[RaveledVars, StatsType]:
         logp = args[0]
-        logp_q0 = logp(q0)
-        point_map_info = q0.point_map_info
-        q0 = q0.data
+        logp_q0 = logp(apoint)
+        point_map_info = apoint.point_map_info
+        q0 = apoint.data

         # Convert adaptive_scale_factor to a jump probability
         p_jump = 1.0 - 0.5**self.scaling
@@ -425,9 +414,7 @@ def astep(self, q0: RaveledVars, *args) -> Tuple[RaveledVars, StatsType]:
             "p_jump": p_jump,
         }

-        q_new = RaveledVars(q_new, point_map_info)
-
-        return q_new, [stats]
+        return RaveledVars(q_new, point_map_info), [stats]

     @staticmethod
     def competence(var):
@@ -501,13 +488,13 @@ def __init__(self, vars, order="random", transit_p=0.8, model=None):

         super().__init__(vars, [model.compile_logp()])

-    def astep(self, q0: RaveledVars, *args) -> Tuple[RaveledVars, StatsType]:
+    def astep(self, apoint: RaveledVars, *args) -> Tuple[RaveledVars, StatsType]:
         logp: Callable[[RaveledVars], np.ndarray] = args[0]
         order = self.order
         if self.shuffle_dims:
             nr.shuffle(order)

-        q = RaveledVars(np.copy(q0.data), q0.point_map_info)
+        q = RaveledVars(np.copy(apoint.data), apoint.point_map_info)

         logp_curr = logp(q)

@@ -805,7 +792,7 @@ def __init__(

     def astep(self, q0: RaveledVars) -> Tuple[RaveledVars, StatsType]:
         point_map_info = q0.point_map_info
-        q0 = q0.data
+        q0d = q0.data

         if not self.steps_until_tune and self.tune:
             if self.tune == "scaling":
@@ -824,10 +811,10 @@ def astep(self, q0: RaveledVars) -> Tuple[RaveledVars, StatsType]:
         r1 = DictToArrayBijection.map(self.population[ir1])
         r2 = DictToArrayBijection.map(self.population[ir2])
         # propose a jump
-        q = floatX(q0 + self.lamb * (r1.data - r2.data) + epsilon)
+        q = floatX(q0d + self.lamb * (r1.data - r2.data) + epsilon)

-        accept = self.delta_logp(q, q0)
-        q_new, accepted = metrop_select(accept, q, q0)
+        accept = self.delta_logp(q, q0d)
+        q_new, accepted = metrop_select(accept, q, q0d)
         self.accepted += accepted

         self.steps_until_tune -= 1
@@ -840,9 +827,7 @@ def astep(self, q0: RaveledVars) -> Tuple[RaveledVars, StatsType]:
             "accepted": accepted,
         }

-        q_new = RaveledVars(q_new, point_map_info)
-
-        return q_new, [stats]
+        return RaveledVars(q_new, point_map_info), [stats]

     @staticmethod
     def competence(var, has_grad):
@@ -948,7 +933,7 @@ def __init__(
         self.accepted = 0

         # cache local history for the Z-proposals
-        self._history = []
+        self._history: List[np.ndarray] = []
         # remember initial settings before tuning so they can be reset
         self._untuned_settings = dict(
             scaling=self.scaling,
@@ -974,7 +959,7 @@ def reset_tuning(self):

     def astep(self, q0: RaveledVars) -> Tuple[RaveledVars, StatsType]:
         point_map_info = q0.point_map_info
-        q0 = q0.data
+        q0d = q0.data

         # same tuning scheme as DEMetropolis
         if not self.steps_until_tune and self.tune:
@@ -1001,13 +986,13 @@ def astep(self, q0: RaveledVars) -> Tuple[RaveledVars, StatsType]:
             z1 = self._history[iz1]
             z2 = self._history[iz2]
             # propose a jump
-            q = floatX(q0 + self.lamb * (z1 - z2) + epsilon)
+            q = floatX(q0d + self.lamb * (z1 - z2) + epsilon)
         else:
             # propose just with noise in the first 2 iterations
-            q = floatX(q0 + epsilon)
+            q = floatX(q0d + epsilon)

-        accept = self.delta_logp(q, q0)
-        q_new, accepted = metrop_select(accept, q, q0)
+        accept = self.delta_logp(q, q0d)
+        q_new, accepted = metrop_select(accept, q, q0d)
         self.accepted += accepted

         self._history.append(q_new)
@@ -1021,9 +1006,7 @@ def astep(self, q0: RaveledVars) -> Tuple[RaveledVars, StatsType]:
             "accepted": accepted,
         }

-        q_new = RaveledVars(q_new, point_map_info)
-
-        return q_new, [stats]
+        return RaveledVars(q_new, point_map_info), [stats]

     def stop_tuning(self):
         """At the end of the tuning phase, this method removes the first x% of the history
@@ -1053,7 +1036,7 @@ def delta_logp(
     logp: at.TensorVariable,
     vars: List[at.TensorVariable],
     shared: Dict[at.TensorVariable, at.sharedvar.TensorSharedVariable],
-) -> aesara.compile.Function:
+):
     [logp0], inarray0 = join_nonshared_inputs(
         point=point, outputs=[logp], inputs=vars, shared_inputs=shared
     )
diff --git a/scripts/run_mypy.py b/scripts/run_mypy.py
index 35dc1c91d32..f34ad7d7a5b 100644
--- a/scripts/run_mypy.py
+++ b/scripts/run_mypy.py
@@ -64,6 +64,7 @@
 pymc/step_methods/__init__.py
 pymc/step_methods/arraystep.py
 pymc/step_methods/compound.py
+pymc/step_methods/metropolis.py
 pymc/step_methods/hmc/__init__.py
 pymc/step_methods/hmc/base_hmc.py
 pymc/step_methods/hmc/hmc.py
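
Note (not part of the patch): a minimal standalone sketch of the `(rng or nr)` fallback used in the rewritten proposal classes. It works because a `numpy.random.Generator` and the legacy module-level `numpy.random` API expose the same drawing methods (`normal`, `uniform`, ...), so either can serve as the draw source; the `NormalProposal` name is reused here purely for illustration.

# sketch.py -- illustrative only, assumes only numpy
import numpy as np


class NormalProposal:
    def __init__(self, s):
        self.s = s  # proposal scale(s)

    def __call__(self, rng=None):
        # Use the explicit Generator if one was passed, otherwise fall back to the
        # global ``numpy.random`` module, which offers the same ``normal`` method.
        return (rng or np.random).normal(scale=self.s)


prop = NormalProposal(np.array([0.5, 2.0]))
print(prop())                               # draws via the global legacy RNG state
print(prop(rng=np.random.default_rng(42)))  # reproducible draws via a seeded Generator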