Make VI work on v4 (#4582)
* resolve merge conflicts

* start fixing things

* make a simple test pass

* fix some more tests

* fix some more tests

* add scaling for VI

* add shape check

* aet -> at

* use rvs_to_values from the model in opvi.py

* refactor cloning routines (fix pymc references)

* Run pre-commit and include VI tests in pytest workflow (rebase)

* Run pre-commit and include VI tests in pytest workflow

* seems like Grouped inference is not working

* spot an error in a simple test case

* fix the test case with grouping

* fix sampling with changed shape

* remove not implemented error for local inference

* support inferencedata

* get rid of shape error for batched mvnormal

* do not support AEVB with an error message

* fix some more tests

* fix some more tests

* fix full rank test

* fix tests

* test vi

* fix conversion function

* propagate model

* fix

* fix elbo

* fix elbo full rank

* Fixing broken scaling with float32

* ignore a nasty test

* xfail one test with float32

* fix pre-commit

* fix import

* fix import.1

* Update pymc/variational/opvi.py

Co-authored-by: Thomas Wiecki <thomas.wiecki@gmail.com>

* fix docstrings

* fix error with nans

* remove TODO comments

* convert print statements to logging

* revert bart test

* fix pylint issues

* fix test bart

* fix inference_data in init

* ignore pickling problems

* fix aevb test

* fix name error

* xfail test random fn

* mark xfail

* refactor test

* xfail fix

* fix xfail syntax

* pytest

* test fixed

* 5090 fixed

* do not test local flows

* change model.logpt not to return float

* add a test for the replacement in the graph

* fix sample node functionality

* Fix test with var replacement

* add uncommited changes

* resolve @ricardoV94's comment about initial point

* restore test_bart.py as in main branch

* resolve duplicated _get_scaling function

* change job order

* use compute initial point in the test file

* use compute initial point in opvi.py

* remove unnecessary pattern broadcast

* mark test as xfail before aesara release

* Do not mark anything but just wait for the new release

* use compute_initial_point

* Update pymc/variational/opvi.py

Co-authored-by: Thomas Wiecki <thomas.wiecki@gmail.com>

* run upgraded pre-commit

* move pipe back

* Update pymc/variational/opvi.py

Co-authored-by: Ricardo Vieira <28983449+ricardoV94@users.noreply.github.com>

* Update pymc/variational/opvi.py

Co-authored-by: Ricardo Vieira <28983449+ricardoV94@users.noreply.github.com>

* Update pymc/variational/opvi.py

Co-authored-by: Ricardo Vieira <28983449+ricardoV94@users.noreply.github.com>

* Add removed newline

* Use compile_pymc instead of aesara.function

* Replace None with empty list in output

* Apply suggestions from code review

Co-authored-by: Michael Osthege <michael.osthege@outlook.com>

Co-authored-by: Michael Osthege <m.osthege@fz-juelich.de>
Co-authored-by: Thomas Wiecki <thomas.wiecki@gmail.com>
Co-authored-by: Michael Osthege <michael.osthege@outlook.com>
Co-authored-by: Ricardo Vieira <28983449+ricardoV94@users.noreply.github.com>
Co-authored-by: Ricardo <ricardo.vieira1994@gmail.com>
6 people authored Feb 25, 2022
1 parent ac2b82e commit e987950
Showing 11 changed files with 237 additions and 171 deletions.
2 changes: 1 addition & 1 deletion .github/workflows/pytest.yml
@@ -48,7 +48,6 @@ jobs:
--ignore=pymc/tests/test_step.py
--ignore=pymc/tests/test_tuning.py
--ignore=pymc/tests/test_transforms.py
--ignore=pymc/tests/test_variational_inference.py
--ignore=pymc/tests/test_sampling_jax.py
--ignore=pymc/tests/test_dist_math.py
--ignore=pymc/tests/test_minibatches.py
@@ -169,6 +168,7 @@ jobs:
pymc/tests/test_distributions_random.py
pymc/tests/test_distributions_moments.py
pymc/tests/test_distributions_timeseries.py
pymc/tests/test_variational_inference.py
- |
pymc/tests/test_parallel_sampling.py
pymc/tests/test_sampling.py
1 change: 1 addition & 0 deletions pymc/backends/arviz.py
@@ -478,6 +478,7 @@ def is_data(name, var) -> bool:
and var not in self.model.observed_RVs
and var not in self.model.free_RVs
and var not in self.model.potentials
and var not in self.model.value_vars
and (self.observations is None or name not in self.observations)
and isinstance(var, (Constant, SharedVariable))
)
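For context, here is a condensed sketch of what the is_data predicate checks after this change. It is not the module's actual code: the method is rewritten as a standalone function, and the model attributes plus the Constant/SharedVariable classes are taken from the diff above on the assumption they are imported as shown.

from aesara.compile.sharedvalue import SharedVariable
from aesara.graph.basic import Constant


def is_data(name, var, model, observations=None) -> bool:
    # A shared/constant variable is exported to constant_data only if it is
    # not a random variable, a potential, a value variable, or an observation.
    return (
        var not in model.observed_RVs
        and var not in model.free_RVs
        and var not in model.potentials
        and var not in model.value_vars  # exclusion added by this PR
        and (observations is None or name not in observations)
        and isinstance(var, (Constant, SharedVariable))
    )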
25 changes: 15 additions & 10 deletions pymc/distributions/logprob.py
@@ -14,8 +14,9 @@

from collections.abc import Mapping
from functools import singledispatch
from typing import Dict, List, Optional, Union
from typing import Dict, List, Optional, Sequence, Union

import aesara
import aesara.tensor as at
import numpy as np

@@ -43,15 +44,17 @@ def logp_transform(op: Op):
return None


def _get_scaling(total_size, shape, ndim):
def _get_scaling(total_size: Optional[Union[int, Sequence[int]]], shape, ndim: int):
"""
Gets scaling constant for logp
Gets scaling constant for logp.
Parameters
----------
total_size: int or list[int]
total_size: Optional[int|List[int]]
size of a fully observed data without minibatching,
`None` means data is fully observed
shape: shape
shape to scale
shape of an observed data
ndim: int
ndim hint
@@ -60,7 +63,7 @@ def _get_scaling(total_size, shape, ndim):
scalar
"""
if total_size is None:
coef = floatX(1)
coef = 1.0
elif isinstance(total_size, int):
if ndim >= 1:
denom = shape[0]
@@ -90,21 +93,23 @@ def _get_scaling(total_size, shape, ndim):
"number of scalings is bigger that ndim, got %r" % total_size
)
elif (len(begin) + len(end)) == 0:
return floatX(1)
coef = 1.0
if len(end) > 0:
shp_end = shape[-len(end) :]
else:
shp_end = np.asarray([])
shp_begin = shape[: len(begin)]
begin_coef = [floatX(t) / shp_begin[i] for i, t in enumerate(begin) if t is not None]
end_coef = [floatX(t) / shp_end[i] for i, t in enumerate(end) if t is not None]
begin_coef = [
floatX(t) / floatX(shp_begin[i]) for i, t in enumerate(begin) if t is not None
]
end_coef = [floatX(t) / floatX(shp_end[i]) for i, t in enumerate(end) if t is not None]
coefs = begin_coef + end_coef
coef = at.prod(coefs)
else:
raise TypeError(
"Unrecognized `total_size` type, expected int or list of ints, got %r" % total_size
)
return at.as_tensor(floatX(coef))
return at.as_tensor(coef, dtype=aesara.config.floatX)


subtensor_types = (
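To illustrate the intent of _get_scaling in the common case, here is a minimal sketch under the assumption of an integer total_size and a concrete batch shape; the real function also handles per-dimension lists of sizes and symbolic shapes.

def minibatch_scaling(total_size, batch_shape):
    # With minibatching, the logp of a batch is multiplied by
    # total_size / batch_size so it approximates the full-data logp.
    if total_size is None:
        return 1.0
    return float(total_size) / float(batch_shape[0])


# A batch of 100 rows standing in for 10_000 observations is up-weighted 100x:
assert minibatch_scaling(10_000, (100, 3)) == 100.0
assert minibatch_scaling(None, (100, 3)) == 1.0

The diff keeps the intermediate coefficients as plain floats and only casts once at the end via at.as_tensor(coef, dtype=aesara.config.floatX), which is the fix for the broken float32 scaling mentioned in the commit messages.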
4 changes: 3 additions & 1 deletion pymc/model.py
@@ -58,6 +58,7 @@
from pymc.blocking import DictToArrayBijection, RaveledVars
from pymc.data import GenTensorVariable, Minibatch
from pymc.distributions import joint_logpt, logp_transform
from pymc.distributions.logprob import _get_scaling
from pymc.exceptions import ImputationWarning, SamplingError, ShapeError
from pymc.initial_point import make_initial_point_fn
from pymc.math import flatten_list
@@ -1238,6 +1239,7 @@ def register_rv(
name = self.name_for(name)
rv_var.name = name
rv_var.tag.total_size = total_size
rv_var.tag.scaling = _get_scaling(total_size, shape=rv_var.shape, ndim=rv_var.ndim)

# Associate previously unknown dimension names with
# the length of the corresponding RV dimension.
@@ -1870,7 +1872,7 @@ def Potential(name, var, model=None):
"""
model = modelcontext(model)
var.name = model.name_for(name)
var.tag.scaling = None
var.tag.scaling = 1.0
model.potentials.append(var)
model.add_random_variable(var)

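A hedged usage sketch of what the new rv_var.tag.scaling attribute corresponds to; the Minibatch and total_size API shown here is the standard PyMC one and is assumed rather than taken from this diff.

import numpy as np
import pymc as pm

data = np.random.randn(10_000)
batch = pm.Minibatch(data, batch_size=128)  # assumed Minibatch signature

with pm.Model() as model:
    mu = pm.Normal("mu", 0.0, 1.0)
    # total_size declares how large the full dataset is; register_rv now
    # stores the resulting coefficient on obs.tag.scaling.
    obs = pm.Normal("obs", mu, 1.0, observed=batch, total_size=10_000)

print(obs.tag.scaling.eval())  # roughly 10_000 / 128

The Potential change in the same file (scaling None -> 1.0) lets potentials enter the logp as unscaled terms instead of being treated as a special case downstream.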
6 changes: 3 additions & 3 deletions pymc/sampling.py
@@ -2385,7 +2385,7 @@ def init_nuts(
progressbar=progressbar,
obj_optimizer=pm.adagrad_window,
)
initial_points = list(approx.sample(draws=chains))
initial_points = list(approx.sample(draws=chains, return_inferencedata=False))
std_apoint = approx.std.eval()
cov = std_apoint**2
mean = approx.mean.get_value()
@@ -2402,7 +2402,7 @@
progressbar=progressbar,
obj_optimizer=pm.adagrad_window,
)
initial_points = list(approx.sample(draws=chains))
initial_points = list(approx.sample(draws=chains, return_inferencedata=False))
cov = approx.std.eval() ** 2
potential = quadpotential.QuadPotentialDiag(cov)
elif init == "advi_map":
@@ -2416,7 +2416,7 @@
progressbar=progressbar,
obj_optimizer=pm.adagrad_window,
)
initial_points = list(approx.sample(draws=chains))
initial_points = list(approx.sample(draws=chains, return_inferencedata=False))
cov = approx.std.eval() ** 2
potential = quadpotential.QuadPotentialDiag(cov)
elif init == "map":
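The three init_nuts call sites change for the same reason: with this PR, Approximation.sample returns an InferenceData by default, so internal callers that need raw draws now pass return_inferencedata=False. A minimal sketch of the two call styles; the model and fit settings are illustrative only.

import numpy as np
import pymc as pm

with pm.Model():
    mu = pm.Normal("mu", 0.0, 1.0)
    pm.Normal("obs", mu, 1.0, observed=np.random.randn(50))
    approx = pm.fit(n=1_000, method="advi", progressbar=False)

idata = approx.sample(draws=100)  # arviz.InferenceData (new default)
trace = approx.sample(draws=100, return_inferencedata=False)  # trace of draws
initial_points = list(trace)  # point dicts, as consumed by init_nuts above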