Make VI work on v4 #4582

Merged: 97 commits, Feb 25, 2022
Commits (97 total; diff shown from 79 commits)
e59ebba
resolve merge conflicts
ferrine Mar 31, 2021
8aa290f
start fixing things
ferrine Jun 6, 2021
f96c626
make a simple test pass
ferrine Jun 7, 2021
0af6dac
fix some more tests
ferrine Jun 11, 2021
7e60bcc
fix some more tests
ferrine Jun 11, 2021
e0fbb98
add scaling for VI
ferrine Jun 18, 2021
e515217
add shape check
ferrine Jun 18, 2021
6dfc18c
aet -> at
ferrine Jun 18, 2021
39e635b
use rvs_to_values from the model in opvi.py
ferrine Jun 21, 2021
9f61021
refactor cloning routines (fix pymc references)
ferrine Jun 21, 2021
8909ac7
Run pre-commit and include VI tests in pytest workflow (rebase)
michaelosthege Jul 2, 2021
1076fa1
Run pre-commit and include VI tests in pytest workflow
michaelosthege Jul 2, 2021
7e73cd7
seems like Grouped inference not working
ferrine Jul 28, 2021
64ba837
spot an error in a simple test case
ferrine Aug 3, 2021
4b91bce
fix the test case with grouping
ferrine Aug 3, 2021
c81458a
fix sampling with changed shape
ferrine Aug 3, 2021
11ef0b6
remove not implemented error for local inference
ferrine Aug 3, 2021
98dd81d
support inferencedata
ferrine Aug 8, 2021
c08eea3
get rid of shape error for batched mvnormal
ferrine Aug 8, 2021
77443f5
do not support AEVB with an error message
ferrine Aug 8, 2021
215f92b
fix some more tests
ferrine Sep 8, 2021
94a28e5
fix some more tests
ferrine Sep 16, 2021
509f7ba
fix full rank test
ferrine Sep 16, 2021
c0c8fb9
fix tests
ferrine Sep 16, 2021
7745ac6
test vi
ferrine Sep 16, 2021
3dafc10
fix conversion function
ferrine Sep 16, 2021
2752ebd
propagate model
ferrine Sep 16, 2021
ff5f8c8
fix
ferrine Sep 16, 2021
c154063
fix elbo
ferrine Sep 19, 2021
af9c24d
fix elbo full rank
ferrine Sep 19, 2021
a9d40ef
Fixing broken scaling with float32
ferrine Sep 23, 2021
54d2a43
ignore a nasty test
ferrine Sep 23, 2021
6d46a2f
xfail one test with float 32
ferrine Sep 26, 2021
2ce5a7d
fix pre commit
ferrine Sep 26, 2021
69b9486
fix import
ferrine Sep 26, 2021
1beec12
fix import.1
ferrine Sep 26, 2021
894d5ce
Update pymc/variational/opvi.py
ferrine Sep 27, 2021
8d2ec8b
fix docstrings
ferrine Sep 27, 2021
60e5653
Merge branch 'v4-4523' of github.com:pymc-devs/pymc3 into v4-4523
ferrine Sep 27, 2021
c03352e
fix error with nans
ferrine Oct 14, 2021
00c1d14
remove TODO comments
ferrine Oct 14, 2021
27b4261
Merge branch 'main' into v4-4523
ferrine Oct 14, 2021
694286a
print statements to logging
ferrine Oct 14, 2021
8dba7d5
revert bart test
ferrine Oct 14, 2021
6a2fc35
apply changes from main
ferrine Oct 15, 2021
3a5915a
fix pylint issues
ferrine Oct 15, 2021
f6d9b98
fix test bart
ferrine Oct 20, 2021
9a79e27
fix inference_data in init
ferrine Oct 20, 2021
deafa96
ignore pickling problems
ferrine Oct 26, 2021
0f45e73
fix aevb test
ferrine Oct 26, 2021
4957765
Merge branch 'main' into v4-4523
ferrine Oct 26, 2021
0ab2fba
Merge branch 'main' into v4-4523
ferrine Nov 2, 2021
b1b4938
Merge branch 'main' into v4-4523
ferrine Nov 7, 2021
8d48870
fix name error
ferrine Nov 7, 2021
6efd630
xfail test random fn
ferrine Nov 7, 2021
b2e9c0f
mark xfail
ferrine Nov 7, 2021
a92aad8
refactor test
ferrine Nov 7, 2021
f253417
xfail fix
ferrine Nov 7, 2021
f09d33a
fix xfail syntax
ferrine Nov 8, 2021
19ea8c9
pytest
ferrine Nov 8, 2021
f14cbc1
test fixed
ferrine Nov 8, 2021
02fc30f
5090 fixed
ferrine Nov 8, 2021
baefac6
do not test local flows
ferrine Nov 15, 2021
bf38d33
Merge branch 'main' into v4-4523
ferrine Nov 16, 2021
beb75ba
change model.logpt not to return float
ferrine Nov 16, 2021
74e19fd
Merge branch 'main' into v4-4523
ferrine Nov 23, 2021
c2d24de
add a test for the replacement in the graph
ferrine Nov 27, 2021
8fdf9a2
Merge branch 'main' into v4-4523
michaelosthege Dec 20, 2021
3943e0f
merge main into PR
ferrine Jan 16, 2022
13a970e
fix sample node functionality
ferrine Jan 16, 2022
994fba5
Fix test with var replacement
ferrine Jan 16, 2022
6090029
add uncommitted changes
ferrine Jan 16, 2022
48041f5
resolve @ricardoV94's comment about initial point
ferrine Jan 23, 2022
cb0fee9
restore test_bart.py as in main branch
ferrine Jan 23, 2022
c5911ac
resolve duplicated _get_scaling function
ferrine Jan 23, 2022
a466ffc
Merge branch 'main' into v4-4523
ferrine Jan 23, 2022
78ca582
change job order
ferrine Jan 23, 2022
e4cbb33
use compute initial point in the test file
ferrine Jan 23, 2022
8fad157
use compute initial point in the opvi.py
ferrine Jan 23, 2022
7f281bd
remove unnecessary pattern broadcast
ferrine Jan 24, 2022
8e8f63e
mark test as xfail before aesara release
ferrine Jan 24, 2022
72a7556
Do not mark anything but just wait for the new release
ferrine Jan 24, 2022
57e8342
Merge branch 'main' into v4-4523
ferrine Jan 30, 2022
a6f54ac
use compute_initial_point
ferrine Feb 13, 2022
1ee5536
Merge branch 'main' into v4-4523
ferrine Feb 14, 2022
4fab824
Merge branch 'main' into v4-4523
ferrine Feb 20, 2022
b4a2f62
Update pymc/variational/opvi.py
ferrine Feb 20, 2022
f9d16a7
run upgraded pre-commit
ferrine Feb 20, 2022
bc712ef
Merge branch 'v4-4523' of github.com:pymc-devs/pymc3 into v4-4523
ferrine Feb 20, 2022
6a3ee61
move pipe back
ferrine Feb 20, 2022
cd2cda9
Update pymc/variational/opvi.py
ferrine Feb 23, 2022
670edb9
Update pymc/variational/opvi.py
ferrine Feb 23, 2022
01fb223
Update pymc/variational/opvi.py
ferrine Feb 23, 2022
32006cd
Add removed newline
ricardoV94 Feb 23, 2022
1cb1418
Use compile_pymc instead of aesara.function
ricardoV94 Feb 23, 2022
ceddb5c
Replace None by empty list in output
ricardoV94 Feb 23, 2022
ef5f91b
Apply suggestions from code review
ferrine Feb 24, 2022
Files changed
.github/workflows/pytest.yml (6 changes: 2 additions & 4 deletions)

@@ -48,7 +48,6 @@ jobs:
             --ignore=pymc/tests/test_step.py
             --ignore=pymc/tests/test_tuning.py
             --ignore=pymc/tests/test_transforms.py
-            --ignore=pymc/tests/test_variational_inference.py
             --ignore=pymc/tests/test_sampling_jax.py
             --ignore=pymc/tests/test_dist_math.py
             --ignore=pymc/tests/test_minibatches.py
@@ -67,9 +66,7 @@ jobs:
             --ignore=pymc/tests/test_bart.py
             --ignore=pymc/tests/test_missing.py

-          - |
-            pymc/tests/test_distributions.py
-
+          - pymc/tests/test_distributions.py
           - |
             pymc/tests/test_modelcontext.py
             pymc/tests/test_dist_math.py
@@ -165,6 +162,7 @@ jobs:
             pymc/tests/test_distributions_random.py
             pymc/tests/test_distributions_moments.py
             pymc/tests/test_distributions_timeseries.py
+            pymc/tests/test_variational_inference.py
           - |
             pymc/tests/test_parallel_sampling.py
             pymc/tests/test_sampling.py
pymc/backends/arviz.py (1 change: 1 addition & 0 deletions)

@@ -478,6 +478,7 @@ def is_data(name, var) -> bool:
                 and var not in self.model.observed_RVs
                 and var not in self.model.free_RVs
                 and var not in self.model.potentials
+                and var not in self.model.value_vars
                 and (self.observations is None or name not in self.observations)
                 and isinstance(var, (Constant, SharedVariable))
             )
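For context, a hypothetical illustration (written for this review, not code from the PR) of what `model.value_vars` holds and why the converter now skips it: value variables are the unconstrained inputs that logp graphs are written in terms of, not user-supplied data, so they should not end up in `constant_data`.

```python
import pymc as pm

# Hypothetical example: the log-transformed value variable of a HalfNormal
# is internal bookkeeping, not observed or constant data.
with pm.Model() as m:
    sigma = pm.HalfNormal("sigma", 1.0)

print(m.free_RVs)    # [sigma]
print(m.value_vars)  # e.g. [sigma_log__], now excluded by the check above
```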
pymc/distributions/logprob.py (21 changes: 13 additions & 8 deletions)

@@ -16,6 +16,7 @@
 from functools import singledispatch
 from typing import Dict, List, Optional, Union

+import aesara
 import aesara.tensor as at
 import numpy as np

@@ -45,13 +46,15 @@ def logp_transform(op: Op):

 def _get_scaling(total_size, shape, ndim):
     """
-    Gets scaling constant for logp
+    Gets scaling constant for logp.

     Parameters
     ----------
-    total_size: int or list[int]
+    total_size: Optional[int|List[int]]
+        size of a fully observed data without minibatching,
+        `None` means data is fully observed
     shape: shape
-        shape to scale
+        shape of an observed data
     ndim: int
         ndim hint

@@ -60,7 +63,7 @@ def _get_scaling(total_size, shape, ndim):
     scalar
     """
     if total_size is None:
-        coef = floatX(1)
+        coef = 1.0
     elif isinstance(total_size, int):
         if ndim >= 1:
             denom = shape[0]
@@ -90,21 +93,23 @@ def _get_scaling(total_size, shape, ndim):
                 "number of scalings is bigger that ndim, got %r" % total_size
             )
         elif (len(begin) + len(end)) == 0:
-            return floatX(1)
+            coef = 1.0
         if len(end) > 0:
             shp_end = shape[-len(end) :]
         else:
             shp_end = np.asarray([])
         shp_begin = shape[: len(begin)]
-        begin_coef = [floatX(t) / shp_begin[i] for i, t in enumerate(begin) if t is not None]
-        end_coef = [floatX(t) / shp_end[i] for i, t in enumerate(end) if t is not None]
+        begin_coef = [
+            floatX(t) / floatX(shp_begin[i]) for i, t in enumerate(begin) if t is not None
+        ]
+        end_coef = [floatX(t) / floatX(shp_end[i]) for i, t in enumerate(end) if t is not None]
         coefs = begin_coef + end_coef
         coef = at.prod(coefs)
     else:
         raise TypeError(
             "Unrecognized `total_size` type, expected int or list of ints, got %r" % total_size
         )
-    return at.as_tensor(floatX(coef))
+    return at.as_tensor(coef, dtype=aesara.config.floatX)


 subtensor_types = (
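As a quick illustration of what `_get_scaling` computes in the common case, here is a minimal numeric sketch written for this review (not code from the PR); it only covers the simple integer `total_size` path.

```python
import numpy as np

def scaling_coef(total_size, batch_shape):
    """Sketch of the integer total_size case: scale the minibatch logp
    so that it estimates the full-data logp."""
    if total_size is None:
        return 1.0
    return float(total_size) / batch_shape[0]

minibatch = np.random.randn(100, 3)           # 100 rows drawn from 10,000
print(scaling_coef(10_000, minibatch.shape))  # 100.0: each row counts 100x
```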
pymc/model.py (4 changes: 3 additions & 1 deletion)

@@ -59,6 +59,7 @@
 from pymc.blocking import DictToArrayBijection, RaveledVars
 from pymc.data import GenTensorVariable, Minibatch
 from pymc.distributions import joint_logpt, logp_transform
+from pymc.distributions.logprob import _get_scaling
 from pymc.exceptions import ImputationWarning, SamplingError, ShapeError
 from pymc.initial_point import make_initial_point_fn
 from pymc.math import flatten_list
@@ -1235,6 +1236,7 @@ def register_rv(
         name = self.name_for(name)
         rv_var.name = name
         rv_var.tag.total_size = total_size
+        rv_var.tag.scaling = _get_scaling(total_size, shape=rv_var.shape, ndim=rv_var.ndim)

         # Associate previously unknown dimension names with
         # the length of the corresponding RV dimension.
@@ -1870,7 +1872,7 @@ def Potential(name, var, model=None):
     """
     model = modelcontext(model)
     var.name = model.name_for(name)
-    var.tag.scaling = None
+    var.tag.scaling = 1.0
     model.potentials.append(var)
     model.add_random_variable(var)
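For reference, roughly how this surfaces to users: `register_rv` now attaches the precomputed scaling to the RV's tag whenever `total_size` is given. The sketch below was written for this review and assumes the pymc3-style `Minibatch`/`total_size` interface carries over to v4.

```python
import numpy as np
import pymc as pm

data = np.random.randn(10_000)
batch = pm.Minibatch(data, batch_size=128)  # assumes the pymc3-style Minibatch API

with pm.Model() as model:
    mu = pm.Normal("mu", 0.0, 1.0)
    # total_size declares how large the full dataset is; register_rv uses it
    # to precompute obs.tag.scaling via _get_scaling (see the diff above).
    obs = pm.Normal("obs", mu, 1.0, observed=batch, total_size=10_000)
```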
pymc/sampling.py (6 changes: 3 additions & 3 deletions)

@@ -2365,7 +2365,7 @@ def init_nuts(
             progressbar=progressbar,
             obj_optimizer=pm.adagrad_window,
         )
-        initial_points = list(approx.sample(draws=chains))
+        initial_points = list(approx.sample(draws=chains, return_inferencedata=False))
         std_apoint = approx.std.eval()
         cov = std_apoint ** 2
         mean = approx.mean.get_value()
@@ -2382,7 +2382,7 @@
             progressbar=progressbar,
             obj_optimizer=pm.adagrad_window,
         )
-        initial_points = list(approx.sample(draws=chains))
+        initial_points = list(approx.sample(draws=chains, return_inferencedata=False))
         cov = approx.std.eval() ** 2
         potential = quadpotential.QuadPotentialDiag(cov)
     elif init == "advi_map":
@@ -2396,7 +2396,7 @@
             progressbar=progressbar,
             obj_optimizer=pm.adagrad_window,
         )
-        initial_points = list(approx.sample(draws=chains))
+        initial_points = list(approx.sample(draws=chains, return_inferencedata=False))
         cov = approx.std.eval() ** 2
         potential = quadpotential.QuadPotentialDiag(cov)
     elif init == "map":
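The explicit `return_inferencedata=False` reflects that `Approximation.sample` can now return an `InferenceData` object (per the "support inferencedata" commit), while `init_nuts` only needs raw point dictionaries. A hedged usage sketch, assuming InferenceData is the new default return type:

```python
import pymc as pm

with pm.Model() as model:
    mu = pm.Normal("mu", 0.0, 1.0)
    pm.Normal("obs", mu, 1.0, observed=[0.1, -0.3, 0.2])

    approx = pm.fit(n=10_000, method="advi")

    idata = approx.sample(draws=500)  # InferenceData (assuming the new default)
    trace = approx.sample(draws=500, return_inferencedata=False)  # raw trace, as used by init_nuts
```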