Fix remaining type problems in metropolis.py
michaelosthege committed Nov 19, 2022
1 parent 35e7c59 commit 9e09aa4
Showing 4 changed files with 43 additions and 59 deletions.
pymc/aesaraf.py (2 additions & 2 deletions)
@@ -35,7 +35,7 @@
from aeppl.logprob import CheckParameterValue
from aeppl.transforms import RVTransform
from aesara import scalar
-from aesara.compile.mode import Mode, get_mode
+from aesara.compile import Function, Mode, get_mode
from aesara.gradient import grad
from aesara.graph import node_rewriter, rewrite_graph
from aesara.graph.basic import (
@@ -1044,7 +1044,7 @@ def compile_pymc(
random_seed: SeedSequenceSeed = None,
mode=None,
**kwargs,
-) -> Callable[..., Union[np.ndarray, List[np.ndarray]]]:
+) -> Function:
"""Use ``aesara.function`` with specialized pymc rewrites always enabled.
This function also ensures shared RandomState/Generator used by RandomVariables
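Note on the change above: ``aesara.function`` returns a ``Function`` object, so annotating ``compile_pymc`` with the concrete ``Function`` type is more precise than the previous generic ``Callable``. A minimal sketch of the difference, using a toy graph rather than anything from the commit:

    # Illustrative only: the object aesara.function returns is a Function instance,
    # which is the type compile_pymc now advertises.
    import aesara
    import aesara.tensor as at
    from aesara.compile import Function

    x = at.scalar("x")
    f: Function = aesara.function([x], 2 * x)  # compiled graph, typed as Function
    assert f(3.0) == 6.0   # still callable, as the old Callable annotation promised
    print(type(f.maker))   # but also carries attributes a bare Callable would hide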
pymc/step_methods/arraystep.py (1 addition & 1 deletion)
@@ -273,7 +273,7 @@ def step(self, point) -> Tuple[PointType, StatsType]:
return super().step(point)


-def metrop_select(mr, q, q0):
+def metrop_select(mr: np.ndarray, q: np.ndarray, q0: np.ndarray) -> Tuple[np.ndarray, bool]:
"""Perform rejection/acceptance step for Metropolis class samplers.
Returns the new sample q if a uniform random number is less than the
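The ``metrop_select`` helper gains full annotations here. For orientation, its docstring describes the standard Metropolis accept/reject rule; a simplified sketch (illustrative only; the real code lives in pymc/step_methods/arraystep.py) looks like this:

    from typing import Tuple

    import numpy as np

    def metrop_select_sketch(mr: np.ndarray, q: np.ndarray, q0: np.ndarray) -> Tuple[np.ndarray, bool]:
        # mr is the log acceptance ratio: accept the proposal q with probability
        # min(1, exp(mr)), otherwise keep the current point q0.
        if np.isfinite(mr) and np.log(np.random.uniform()) < mr:
            return q, True
        return q0, False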
pymc/step_methods/metropolis.py (39 additions & 56 deletions)
@@ -13,7 +13,6 @@
# limitations under the License.
from typing import Callable, Dict, List, Optional, Tuple

-import aesara
import numpy as np
import numpy.random as nr
import scipy.linalg
@@ -68,38 +67,29 @@ def __init__(self, s):

class NormalProposal(Proposal):
def __call__(self, rng: Optional[np.random.Generator] = None):
-if rng is None:
-rng = nr
-return rng.normal(scale=self.s)
+return (rng or nr).normal(scale=self.s)


class UniformProposal(Proposal):
def __call__(self, rng: Optional[np.random.Generator] = None):
-if rng is None:
-rng = nr
-return rng.uniform(low=-self.s, high=self.s, size=len(self.s))
+return (rng or nr).uniform(low=-self.s, high=self.s, size=len(self.s))


class CauchyProposal(Proposal):
def __call__(self, rng: Optional[np.random.Generator] = None):
-if rng is None:
-rng = nr
-return rng.standard_cauchy(size=np.size(self.s)) * self.s
+return (rng or nr).standard_cauchy(size=np.size(self.s)) * self.s


class LaplaceProposal(Proposal):
def __call__(self, rng: Optional[np.random.Generator] = None):
-if rng is None:
-rng = nr
size = np.size(self.s)
-return (rng.standard_exponential(size=size) - rng.standard_exponential(size=size)) * self.s
+r = rng or nr
+return (r.standard_exponential(size=size) - r.standard_exponential(size=size)) * self.s


class PoissonProposal(Proposal):
def __call__(self, rng: Optional[np.random.Generator] = None):
-if rng is None:
-rng = nr
-return rng.poisson(lam=self.s, size=np.size(self.s)) - self.s
+return (rng or nr).poisson(lam=self.s, size=np.size(self.s)) - self.s


class MultivariateNormalProposal(Proposal):
@@ -111,13 +101,12 @@ def __init__(self, s):
self.chol = scipy.linalg.cholesky(s, lower=True)

def __call__(self, num_draws=None, rng: Optional[np.random.Generator] = None):
-if rng is None:
-rng = nr
+rng_ = rng or nr
if num_draws is not None:
-b = rng.normal(size=(self.n, num_draws))
+b = rng_.normal(size=(self.n, num_draws))
return np.dot(self.chol, b).T
else:
-b = rng.normal(size=self.n)
+b = rng_.normal(size=self.n)
return np.dot(self.chol, b)
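The repeated ``if rng is None: rng = nr`` blocks in the proposal classes collapse into the ``(rng or nr)`` idiom: a passed-in ``numpy.random.Generator`` is truthy and is used directly, while ``None`` falls back to the module-level ``numpy.random`` functions, which expose the same sampling methods. A standalone toy function (not PyMC code) showing the behaviour:

    import numpy as np
    import numpy.random as nr

    def draw_normal(scale, rng=None):
        # A Generator is truthy, so it is used as-is; None falls through to the
        # legacy module-level numpy.random state, exactly like `rng or nr` above.
        return (rng or nr).normal(scale=scale)

    print(draw_normal(0.5, rng=np.random.default_rng(42)))  # reproducible draw
    print(draw_normal(0.5))                                  # global RNG state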


@@ -247,7 +236,7 @@ def reset_tuning(self):
def astep(self, q0: RaveledVars) -> Tuple[RaveledVars, StatsType]:

point_map_info = q0.point_map_info
-q0 = q0.data
+q0d = q0.data

if not self.steps_until_tune and self.tune:
# Tune scaling parameter
@@ -261,30 +250,30 @@ def astep(self, q0: RaveledVars) -> Tuple[RaveledVars, StatsType]:
if self.any_discrete:
if self.all_discrete:
delta = np.round(delta, 0).astype("int64")
-q0 = q0.astype("int64")
-q = (q0 + delta).astype("int64")
+q0d = q0d.astype("int64")
+q = (q0d + delta).astype("int64")
else:
delta[self.discrete] = np.round(delta[self.discrete], 0)
-q = q0 + delta
+q = q0d + delta
else:
-q = floatX(q0 + delta)
+q = floatX(q0d + delta)

if self.elemwise_update:
-q_temp = q0.copy()
+q_temp = q0d.copy()
# Shuffle order of updates (probably we don't need to do this in every step)
np.random.shuffle(self.enum_dims)
for i in self.enum_dims:
q_temp[i] = q[i]
-accept_rate_i = self.delta_logp(q_temp, q0)
-q_temp_, accepted_i = metrop_select(accept_rate_i, q_temp, q0)
+accept_rate_i = self.delta_logp(q_temp, q0d)
+q_temp_, accepted_i = metrop_select(accept_rate_i, q_temp, q0d)
q_temp[i] = q_temp_[i]
self.accept_rate_iter[i] = accept_rate_i
self.accepted_iter[i] = accepted_i
self.accepted_sum[i] += accepted_i
q = q_temp
else:
-accept_rate = self.delta_logp(q, q0)
-q, accepted = metrop_select(accept_rate, q, q0)
+accept_rate = self.delta_logp(q, q0d)
+q, accepted = metrop_select(accept_rate, q, q0d)
self.accept_rate_iter = accept_rate
self.accepted_iter = accepted
self.accepted_sum += accepted
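The rename of the unpacked array from ``q0`` to ``q0d`` in ``astep`` avoids re-binding a typed parameter to a value of a different type, which mypy (by default) rejects. A sketch of the issue with a stand-in class (names here are illustrative, not PyMC's):

    from dataclasses import dataclass

    import numpy as np

    @dataclass
    class FakeRaveledVars:  # stand-in for pymc.blocking.RaveledVars
        data: np.ndarray

    def astep_old(q0: FakeRaveledVars) -> np.ndarray:
        q0 = q0.data  # mypy: incompatible types in assignment (ndarray vs FakeRaveledVars)
        return q0

    def astep_new(q0: FakeRaveledVars) -> np.ndarray:
        q0d = q0.data  # each name keeps a single, stable type
        return q0d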
@@ -399,11 +388,11 @@ def __init__(self, vars, scaling=1.0, tune=True, tune_interval=100, model=None):

super().__init__(vars, [model.compile_logp()])

-def astep(self, q0: RaveledVars, *args) -> Tuple[RaveledVars, StatsType]:
+def astep(self, apoint: RaveledVars, *args) -> Tuple[RaveledVars, StatsType]:
logp = args[0]
-logp_q0 = logp(q0)
-point_map_info = q0.point_map_info
-q0 = q0.data
+logp_q0 = logp(apoint)
+point_map_info = apoint.point_map_info
+q0 = apoint.data

# Convert adaptive_scale_factor to a jump probability
p_jump = 1.0 - 0.5**self.scaling
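The context line above converts the tuned scaling into a per-dimension flip probability, ``p_jump = 1.0 - 0.5**self.scaling``. A quick numerical illustration (plain arithmetic, not code from the commit):

    # Worked values for BinaryMetropolis' jump probability p_jump = 1 - 0.5**scaling.
    for scaling in (0.25, 0.5, 1.0, 2.0, 4.0):
        print(f"scaling={scaling}  p_jump={1.0 - 0.5**scaling:.3f}")
    # 0.25 -> 0.159, 0.5 -> 0.293, 1.0 -> 0.500, 2.0 -> 0.750, 4.0 -> 0.938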
@@ -425,9 +414,7 @@ def astep(self, q0: RaveledVars, *args) -> Tuple[RaveledVars, StatsType]:
"p_jump": p_jump,
}

-q_new = RaveledVars(q_new, point_map_info)
-
-return q_new, [stats]
+return RaveledVars(q_new, point_map_info), [stats]

@staticmethod
def competence(var):
@@ -501,13 +488,13 @@ def __init__(self, vars, order="random", transit_p=0.8, model=None):

super().__init__(vars, [model.compile_logp()])

-def astep(self, q0: RaveledVars, *args) -> Tuple[RaveledVars, StatsType]:
+def astep(self, apoint: RaveledVars, *args) -> Tuple[RaveledVars, StatsType]:
logp: Callable[[RaveledVars], np.ndarray] = args[0]
order = self.order
if self.shuffle_dims:
nr.shuffle(order)

-q = RaveledVars(np.copy(q0.data), q0.point_map_info)
+q = RaveledVars(np.copy(apoint.data), apoint.point_map_info)

logp_curr = logp(q)

@@ -805,7 +792,7 @@ def __init__(
def astep(self, q0: RaveledVars) -> Tuple[RaveledVars, StatsType]:

point_map_info = q0.point_map_info
-q0 = q0.data
+q0d = q0.data

if not self.steps_until_tune and self.tune:
if self.tune == "scaling":
@@ -824,10 +811,10 @@ def astep(self, q0: RaveledVars) -> Tuple[RaveledVars, StatsType]:
r1 = DictToArrayBijection.map(self.population[ir1])
r2 = DictToArrayBijection.map(self.population[ir2])
# propose a jump
-q = floatX(q0 + self.lamb * (r1.data - r2.data) + epsilon)
+q = floatX(q0d + self.lamb * (r1.data - r2.data) + epsilon)

-accept = self.delta_logp(q, q0)
-q_new, accepted = metrop_select(accept, q, q0)
+accept = self.delta_logp(q, q0d)
+q_new, accepted = metrop_select(accept, q, q0d)
self.accepted += accepted

self.steps_until_tune -= 1
@@ -840,9 +827,7 @@ def astep(self, q0: RaveledVars) -> Tuple[RaveledVars, StatsType]:
"accepted": accepted,
}

-q_new = RaveledVars(q_new, point_map_info)
-
-return q_new, [stats]
+return RaveledVars(q_new, point_map_info), [stats]

@staticmethod
def competence(var, has_grad):
@@ -948,7 +933,7 @@ def __init__(
self.accepted = 0

# cache local history for the Z-proposals
-self._history = []
+self._history: List[np.ndarray] = []
# remember initial settings before tuning so they can be reset
self._untuned_settings = dict(
scaling=self.scaling,
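``self._history`` now carries an explicit ``List[np.ndarray]`` annotation. That is the usual fix for a common mypy complaint: an empty list literal has no inferable element type. A small sketch outside PyMC:

    from typing import List

    import numpy as np

    history = []                       # mypy: Need type annotation for "history"
    history_ok: List[np.ndarray] = []  # element type is explicit
    history_ok.append(np.zeros(3))     # fine for mypy and at runtime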
@@ -974,7 +959,7 @@ def reset_tuning(self):
def astep(self, q0: RaveledVars) -> Tuple[RaveledVars, StatsType]:

point_map_info = q0.point_map_info
-q0 = q0.data
+q0d = q0.data

# same tuning scheme as DEMetropolis
if not self.steps_until_tune and self.tune:
@@ -1001,13 +986,13 @@ def astep(self, q0: RaveledVars) -> Tuple[RaveledVars, StatsType]:
z1 = self._history[iz1]
z2 = self._history[iz2]
# propose a jump
-q = floatX(q0 + self.lamb * (z1 - z2) + epsilon)
+q = floatX(q0d + self.lamb * (z1 - z2) + epsilon)
else:
# propose just with noise in the first 2 iterations
-q = floatX(q0 + epsilon)
+q = floatX(q0d + epsilon)

-accept = self.delta_logp(q, q0)
-q_new, accepted = metrop_select(accept, q, q0)
+accept = self.delta_logp(q, q0d)
+q_new, accepted = metrop_select(accept, q, q0d)
self.accepted += accepted
self._history.append(q_new)

@@ -1021,9 +1006,7 @@ def astep(self, q0: RaveledVars) -> Tuple[RaveledVars, StatsType]:
"accepted": accepted,
}

-q_new = RaveledVars(q_new, point_map_info)
-
-return q_new, [stats]
+return RaveledVars(q_new, point_map_info), [stats]

def stop_tuning(self):
"""At the end of the tuning phase, this method removes the first x% of the history
@@ -1053,7 +1036,7 @@ def delta_logp(
logp: at.TensorVariable,
vars: List[at.TensorVariable],
shared: Dict[at.TensorVariable, at.sharedvar.TensorSharedVariable],
-) -> aesara.compile.Function:
+):
[logp0], inarray0 = join_nonshared_inputs(
point=point, outputs=[logp], inputs=vars, shared_inputs=shared
)
scripts/run_mypy.py (1 addition & 0 deletions)
@@ -64,6 +64,7 @@
pymc/step_methods/__init__.py
pymc/step_methods/arraystep.py
pymc/step_methods/compound.py
+pymc/step_methods/metropolis.py
pymc/step_methods/hmc/__init__.py
pymc/step_methods/hmc/base_hmc.py
pymc/step_methods/hmc/hmc.py
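scripts/run_mypy.py evidently keeps a list of modules that are expected to type-check cleanly, and this commit adds pymc/step_methods/metropolis.py to it. To spot-check a single module directly, one option (a sketch using mypy's Python API from the repository root, not the project script) is:

    from mypy import api

    # api.run returns (stdout, stderr, exit_status) for a plain `mypy <args>` call.
    stdout, stderr, exit_status = api.run(["pymc/step_methods/metropolis.py"])
    print(stdout, end="")
    print("clean" if exit_status == 0 else f"mypy exited with status {exit_status}")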
