
Commit e8ed9b6
Revert "Add start_sigma to ADVI (pymc-devs#6096)"
This reverts commit ec27b5c.
ricardoV94 committed Sep 15, 2022
1 parent c53cd2f commit e8ed9b6
Showing 3 changed files with 11 additions and 61 deletions.
22 changes: 0 additions & 22 deletions pymc/tests/test_variational_inference.py
@@ -571,28 +571,6 @@ def test_fit_oo(inference, fit_kwargs, simple_model_data):
     np.testing.assert_allclose(np.std(trace.posterior["mu"]), np.sqrt(1.0 / d), rtol=0.2)


-def test_fit_start(inference_spec, simple_model):
-    mu_init = 17
-    mu_sigma_init = 13
-
-    with simple_model:
-        if type(inference_spec()) == ADVI:
-            has_start_sigma = True
-        else:
-            has_start_sigma = False
-
-    kw = {"start": {"mu": mu_init}}
-    if has_start_sigma:
-        kw.update({"start_sigma": {"mu": mu_sigma_init}})
-
-    with simple_model:
-        inference = inference_spec(**kw)
-    trace = inference.fit(n=0).sample(10000)
-    np.testing.assert_allclose(np.mean(trace.posterior["mu"]), mu_init, rtol=0.05)
-    if has_start_sigma:
-        np.testing.assert_allclose(np.std(trace.posterior["mu"]), mu_sigma_init, rtol=0.05)
-
-
 def test_profile(inference):
     inference.run_profiling(n=100).summary()

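For context, a minimal sketch of the reverted feature; the model and data are illustrative, and only the start values echo the deleted test above. Before this commit, ADVI could be seeded with a per-variable starting standard deviation via start_sigma alongside start:

# Illustrative sketch, not taken from the repository: model and data are made up,
# only the start/start_sigma values mirror the deleted test above.
import numpy as np
import pymc as pm

rng = np.random.default_rng(0)

with pm.Model():
    mu = pm.Normal("mu", mu=0.0, sigma=10.0)
    pm.Normal("obs", mu=mu, sigma=1.0, observed=rng.normal(size=100))

    approx = pm.fit(
        n=1_000,
        method="advi",
        start={"mu": 17},
        # start_sigma={"mu": 13},  # accepted before this revert, removed by it
        progressbar=False,
    )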
21 changes: 3 additions & 18 deletions pymc/variational/approximations.py
@@ -67,27 +67,12 @@ def std(self):
     def __init_group__(self, group):
         super().__init_group__(group)
         if not self._check_user_params():
-            self.shared_params = self.create_shared_params(
-                self._kwargs.get("start", None), self._kwargs.get("start_sigma", None)
-            )
+            self.shared_params = self.create_shared_params(self._kwargs.get("start", None))
         self._finalize_init()

-    def create_shared_params(self, start=None, start_sigma=None):
-        # NOTE: `Group._prepare_start` uses `self.model.free_RVs` to identify free variables and
-        # `DictToArrayBijection` to turn them into a flat array, while `Approximation.rslice` assumes that the free
-        # variables are given by `self.group` and that the mapping between original variables and flat array is given
-        # by `self.ordering`. In the cases I looked into these turn out to be the same, but there may be edge cases or
-        # future code changes that break this assumption.
+    def create_shared_params(self, start=None):
         start = self._prepare_start(start)
-        rho1 = np.zeros((self.ddim,))
-
-        if start_sigma is not None:
-            for name, slice_, *_ in self.ordering.values():
-                sigma = start_sigma.get(name)
-                if sigma is not None:
-                    rho1[slice_] = np.log(np.expm1(np.abs(sigma)))
-        rho = rho1
-
+        rho = np.zeros((self.ddim,))
         return {
             "mu": aesara.shared(pm.floatX(start), "mu"),
             "rho": aesara.shared(pm.floatX(rho), "rho"),
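The start_sigma branch removed above converted a user-supplied standard deviation into the rho parameter with log(expm1(|sigma|)). Below is a standalone numpy sketch of that transform and its inverse, under the assumption that the mean-field approximation recovers sigma from rho with a softplus, sigma = log1p(exp(rho)):

# Standalone numpy sketch of the removed transform; the rho_to_sigma mapping is
# an assumption about MeanField's parameterization, not code from this diff.
import numpy as np

def sigma_to_rho(sigma):
    # Inverse softplus, as in the removed line: rho = log(expm1(|sigma|))
    return np.log(np.expm1(np.abs(sigma)))

def rho_to_sigma(rho):
    # Assumed forward mapping used when the approximation builds its std
    return np.log1p(np.exp(rho))

start_sigma = np.array([13.0, 0.5])
rho = sigma_to_rho(start_sigma)
np.testing.assert_allclose(rho_to_sigma(rho), start_sigma)  # exact round trip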
29 changes: 8 additions & 21 deletions pymc/variational/inference.py
@@ -257,9 +257,7 @@ def _infmean(input_array):
                     )
                 )
         else:
-            if n == 0:
-                logger.info(f"Initialization only")
-            elif n < 10:
+            if n < 10:
                 logger.info(f"Finished [100%]: Loss = {scores[-1]:,.5g}")
             else:
                 avg_loss = _infmean(scores[max(0, i - 1000) : i + 1])
@@ -435,10 +433,8 @@ class ADVI(KLqp):
     random_seed: None or int
         leave None to use package global RandomStream or other
         valid value to create instance specific one
-    start: `dict[str, np.ndarray]` or `StartDict`
+    start: `Point`
         starting point for inference
-    start_sigma: `dict[str, np.ndarray]`
-        starting standard deviation for inference, only available for method 'advi'

     References
     ----------
@@ -468,7 +464,7 @@ class FullRankADVI(KLqp):
     random_seed: None or int
         leave None to use package global RandomStream or other
         valid value to create instance specific one
-    start: `dict[str, np.ndarray]` or `StartDict`
+    start: `Point`
         starting point for inference

     References
@@ -536,11 +532,13 @@ class SVGD(ImplicitGradient):
         kernel function for KSD :math:`f(histogram) -> (k(x,.), \nabla_x k(x,.))`
     temperature: float
         parameter responsible for exploration, higher temperature gives more broad posterior estimate
-    start: `dict[str, np.ndarray]` or `StartDict`
+    start: `dict`
         initial point for inference
     random_seed: None or int
         leave None to use package global RandomStream or other
         valid value to create instance specific one
+    start: `Point`
+        starting point for inference
     kwargs: other keyword arguments passed to estimator

     References
@@ -631,11 +629,7 @@ def __init__(self, approx=None, estimator=KSD, kernel=test_functions.rbf, **kwargs):
                 "is often **underestimated** when using temperature = 1."
             )
         if approx is None:
-            approx = FullRank(
-                model=kwargs.pop("model", None),
-                random_seed=kwargs.pop("random_seed", None),
-                start=kwargs.pop("start", None),
-            )
+            approx = FullRank(model=kwargs.pop("model", None))
         super().__init__(estimator=estimator, approx=approx, kernel=kernel, **kwargs)

     def fit(
@@ -666,7 +660,6 @@ def fit(
     model=None,
     random_seed=None,
     start=None,
-    start_sigma=None,
     inf_kwargs=None,
     **kwargs,
 ):
@@ -691,10 +684,8 @@ def fit(
         valid value to create instance specific one
     inf_kwargs: dict
         additional kwargs passed to :class:`Inference`
-    start: `dict[str, np.ndarray]` or `StartDict`
+    start: `Point`
         starting point for inference
-    start_sigma: `dict[str, np.ndarray]`
-        starting standard deviation for inference, only available for method 'advi'

     Other Parameters
     ----------------
@@ -737,10 +728,6 @@ def fit(
         inf_kwargs["random_seed"] = random_seed
     if start is not None:
         inf_kwargs["start"] = start
-    if start_sigma is not None:
-        if method != "advi":
-            raise NotImplementedError("start_sigma is only available for method advi")
-        inf_kwargs["start_sigma"] = start_sigma
     if model is None:
         model = pm.modelcontext(model)
     _select = dict(advi=ADVI, fullrank_advi=FullRankADVI, svgd=SVGD, asvgd=ASVGD)
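The last hunk leaves fit()'s method dispatch in place; here is a condensed, hypothetical sketch of that pattern. Only the _select mapping is taken from the diff; the wrapper function and the model are illustrative:

# Hypothetical condensed sketch of fit()'s dispatch; select_inference is not a
# PyMC function, only the _select mapping mirrors the hunk above.
import pymc as pm
from pymc.variational.inference import ADVI, ASVGD, SVGD, FullRankADVI

def select_inference(method, **inf_kwargs):
    _select = dict(advi=ADVI, fullrank_advi=FullRankADVI, svgd=SVGD, asvgd=ASVGD)
    try:
        return _select[method.lower()](**inf_kwargs)
    except KeyError:
        raise KeyError(f"method should be one of {set(_select.keys())}") from None

with pm.Model():
    mu = pm.Normal("mu", 0.0, 1.0)
    # After this revert only `start` can be forwarded; `start_sigma` is gone.
    inference = select_inference("advi", start={"mu": 0.5})
    approx = inference.fit(n=100, progressbar=False)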
