Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Implement and switch to lazy initval evaluation framework #4983

Merged
merged 10 commits into from
Oct 14, 2021
6 changes: 3 additions & 3 deletions .github/workflows/pytest.yml
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@ jobs:
# → pytest will run only these files
- |
--ignore=pymc/tests/test_distributions_timeseries.py
--ignore=pymc/tests/test_initvals.py
--ignore=pymc/tests/test_initial_point.py
--ignore=pymc/tests/test_mixture.py
--ignore=pymc/tests/test_model_graph.py
--ignore=pymc/tests/test_modelcontext.py
Expand Down Expand Up @@ -61,7 +61,7 @@ jobs:
--ignore=pymc/tests/test_idata_conversion.py

- |
pymc/tests/test_initvals.py
pymc/tests/test_initial_point.py
pymc/tests/test_distributions.py

- |
Expand Down Expand Up @@ -154,7 +154,7 @@ jobs:
floatx: [float32, float64]
test-subset:
- |
pymc/tests/test_initvals.py
pymc/tests/test_initial_point.py
pymc/tests/test_distributions_random.py
pymc/tests/test_distributions_timeseries.py
- |
Expand Down
12 changes: 7 additions & 5 deletions benchmarks/benchmarks/benchmarks.py
Original file line number Diff line number Diff line change
Expand Up @@ -173,12 +173,14 @@ class NUTSInitSuite:
def time_glm_hierarchical_init(self, init):
"""How long does it take to run the initialization."""
with glm_hierarchical_model():
pm.init_nuts(init=init, chains=self.chains, progressbar=False)
pm.init_nuts(
init=init, chains=self.chains, progressbar=False, seeds=np.arange(self.chains)
)

def track_glm_hierarchical_ess(self, init):
with glm_hierarchical_model():
start, step = pm.init_nuts(
init=init, chains=self.chains, progressbar=False, random_seed=123
init=init, chains=self.chains, progressbar=False, seeds=np.arange(self.chains)
)
t0 = time.time()
idata = pm.sample(
Expand All @@ -187,7 +189,7 @@ def track_glm_hierarchical_ess(self, init):
cores=4,
chains=self.chains,
start=start,
random_seed=100,
seeds=np.arange(self.chains),
progressbar=False,
compute_convergence_checks=False,
)
Expand All @@ -199,7 +201,7 @@ def track_marginal_mixture_model_ess(self, init):
model, start = mixture_model()
with model:
_, step = pm.init_nuts(
init=init, chains=self.chains, progressbar=False, random_seed=123
init=init, chains=self.chains, progressbar=False, seeds=np.arange(self.chains)
)
start = [{k: v for k, v in start.items()} for _ in range(self.chains)]
t0 = time.time()
Expand All @@ -209,7 +211,7 @@ def track_marginal_mixture_model_ess(self, init):
cores=4,
chains=self.chains,
start=start,
random_seed=100,
seeds=np.arange(self.chains),
progressbar=False,
compute_convergence_checks=False,
)
Expand Down
2 changes: 1 addition & 1 deletion conda-envs/environment-dev-py37.yml
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ channels:
- conda-forge
- defaults
dependencies:
- aesara>=2.1.0
- aesara>=2.2.2
- arviz>=0.11.4
- cachetools>=4.2.1
- cloudpickle
Expand Down
2 changes: 1 addition & 1 deletion conda-envs/environment-dev-py38.yml
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ channels:
- conda-forge
- defaults
dependencies:
- aesara>=2.1.0
- aesara>=2.2.2
- arviz>=0.11.4
- cachetools>=4.2.1
- cloudpickle
Expand Down
2 changes: 1 addition & 1 deletion conda-envs/environment-dev-py39.yml
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ channels:
- conda-forge
- defaults
dependencies:
- aesara>=2.1.0
- aesara>=2.2.2
- arviz>=0.11.4
- cachetools>=4.2.1
- cloudpickle
Expand Down
2 changes: 1 addition & 1 deletion conda-envs/environment-test-py37.yml
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ channels:
- conda-forge
- defaults
dependencies:
- aesara>=2.1.0
- aesara>=2.2.2
- arviz>=0.11.4
- cachetools>=4.2.1
- cloudpickle
Expand Down
2 changes: 1 addition & 1 deletion conda-envs/environment-test-py38.yml
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ channels:
- conda-forge
- defaults
dependencies:
- aesara>=2.1.0
- aesara>=2.2.2
- arviz>=0.11.4
- cachetools>=4.2.1
- cloudpickle
Expand Down
2 changes: 1 addition & 1 deletion conda-envs/environment-test-py39.yml
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ channels:
- conda-forge
- defaults
dependencies:
- aesara>=2.1.0
- aesara>=2.2.2
- arviz>=0.11.4
- cachetools
- cloudpickle
Expand Down
2 changes: 1 addition & 1 deletion conda-envs/windows-environment-dev-py38.yml
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ channels:
- defaults
dependencies:
# base dependencies (see install guide for Windows)
- aesara>=2.1.0
- aesara>=2.2.2
- arviz>=0.11.4
- cachetools>=4.2.1
- cloudpickle
Expand Down
2 changes: 1 addition & 1 deletion conda-envs/windows-environment-test-py38.yml
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ channels:
- defaults
dependencies:
# base dependencies (see install guide for Windows)
- aesara>=2.1.0
- aesara>=2.2.2
- arviz>=0.11.2
- cachetools
- cloudpickle
Expand Down
10 changes: 8 additions & 2 deletions pymc/distributions/continuous.py
Original file line number Diff line number Diff line change
Expand Up @@ -365,10 +365,13 @@ class Flat(Continuous):

rv_op = flat

def __new__(cls, *args, **kwargs):
kwargs.setdefault("initval", "moment")
return super().__new__(cls, *args, **kwargs)

@classmethod
def dist(cls, *, size=None, **kwargs):
res = super().dist([], size=size, **kwargs)
res.tag.test_value = np.full(size, floatX(0.0))
return res

def get_moment(rv, size, *rv_inputs):
Expand Down Expand Up @@ -430,10 +433,13 @@ class HalfFlat(PositiveContinuous):

rv_op = halfflat

def __new__(cls, *args, **kwargs):
kwargs.setdefault("initval", "moment")
return super().__new__(cls, *args, **kwargs)

@classmethod
def dist(cls, *, size=None, **kwargs):
res = super().dist([], size=size, **kwargs)
res.tag.test_value = np.full(size, floatX(1.0))
return res

def get_moment(value_var, size, *rv_inputs):
Expand Down
33 changes: 10 additions & 23 deletions pymc/distributions/distribution.py
Original file line number Diff line number Diff line change
Expand Up @@ -165,8 +165,10 @@ def __new__(
dims : tuple, optional
A tuple of dimension names known to the model.
initval : optional
Test value to be attached to the output RV.
Must match its shape exactly.
Numeric or symbolic untransformed initial value of matching shape,
or one of the following initial value strategies: "moment", "prior".
Depending on the sampler's settings, a random jitter may be added to numeric, symbolic
or moment-based initial values in the transformed space.
observed : optional
Observed data to be passed when registering the random variable in the model.
See ``Model.register_rv``.
Expand Down Expand Up @@ -600,31 +602,16 @@ def dist(cls, *args, **kwargs):
else:
dtype = cls.rv_op.dtype
ndim_supp = cls.rv_op.ndim_supp
if not hasattr(output.tag, "test_value"):
size = to_tuple(kwargs.get("size", None)) + (1,) * ndim_supp
output.tag.test_value = np.zeros(size, dtype)
return output


def default_not_implemented(rv_name, method_name):
if method_name == "random":
# This is a hack to catch the NotImplementedError when creating the RV without random
# If the message starts with "Cannot sample from", then it uses the test_value as
# the initial_val.
message = (
f"Cannot sample from the DensityDist '{rv_name}' because the {method_name} "
"keyword argument was not provided when the distribution was "
f"but this method had not been provided when the distribution was "
f"constructed. Please re-build your model and provide a callable "
f"to '{rv_name}'s {method_name} keyword argument.\n"
)
else:
message = (
f"Attempted to run {method_name} on the DensityDist '{rv_name}', "
f"but this method had not been provided when the distribution was "
f"constructed. Please re-build your model and provide a callable "
f"to '{rv_name}'s {method_name} keyword argument.\n"
)
message = (
f"Attempted to run {method_name} on the DensityDist '{rv_name}', "
f"but this method had not been provided when the distribution was "
f"constructed. Please re-build your model and provide a callable "
f"to '{rv_name}'s {method_name} keyword argument.\n"
)

def func(*args, **kwargs):
raise NotImplementedError(message)
Expand Down
Loading