Skip to content

Commit

Permalink
Remove weight argument in quadpotential add_sample
Browse files Browse the repository at this point in the history
  • Loading branch information
aseyboldt committed Sep 22, 2021
1 parent 2f9e5d1 commit b0b56ce
Show file tree
Hide file tree
Showing 2 changed files with 11 additions and 14 deletions.
21 changes: 9 additions & 12 deletions pymc3/step_methods/hmc/quadpotential.py
Original file line number Diff line number Diff line change
Expand Up @@ -183,8 +183,7 @@ def __init__(
The number of initial samples that are just discarded and not used to estimate
the mass matrix.
early_update : bool
Whether to update the mass matrix live during the first half of the first
adaptation window.
Whether to update the mass matrix live during the first adaptation window.
store_mass_matrix_trace : bool
If true, store the mass matrix at each step of the adaptation. Only for debugging
purposes.
Expand Down Expand Up @@ -268,8 +267,8 @@ def update(self, sample, grad, tune):
return

if self._n_samples > self._discard_window:
self._foreground_var.add_sample(sample, weight=1)
self._background_var.add_sample(sample, weight=1)
self._foreground_var.add_sample(sample)
self._background_var.add_sample(sample)

if self._early_update or self._n_samples > self.adaptation_window:
self._update_from_weightvar(self._foreground_var)
Expand Down Expand Up @@ -344,15 +343,13 @@ def __init__(
if self.mean.shape != (nelem,):
raise ValueError("Invalid shape for initial mean.")

def add_sample(self, x, weight):
def add_sample(self, x):
x = np.asarray(x)
if weight != 1:
raise ValueError("Setting weight != 1 is not supported.")
self.n_samples += 1
old_diff = x - self.mean
self.mean[:] += old_diff / self.n_samples
new_diff = x - self.mean
self.raw_var[:] += weight * old_diff * new_diff
self.raw_var[:] += old_diff * new_diff

def current_variance(self, out=None):
if self.n_samples == 0:
Expand Down Expand Up @@ -666,8 +663,8 @@ def update(self, sample, grad, tune):
# Steps since previous update
delta = self._n_samples - self._previous_update

self._foreground_cov.add_sample(sample, weight=1)
self._background_cov.add_sample(sample, weight=1)
self._foreground_cov.add_sample(sample)
self._background_cov.add_sample(sample)

# Update the covariance matrix and recompute the Cholesky factorization
# every "update_window" steps
Expand Down Expand Up @@ -726,13 +723,13 @@ def __init__(
if self.mean.shape != (nelem,):
raise ValueError("Invalid shape for initial mean.")

def add_sample(self, x, weight):
def add_sample(self, x):
x = np.asarray(x)
self.n_samples += 1
old_diff = x - self.mean
self.mean[:] += old_diff / self.n_samples
new_diff = x - self.mean
self.raw_cov[:] += weight * new_diff[:, None] * old_diff[None, :]
self.raw_cov[:] += new_diff[:, None] * old_diff[None, :]

def current_covariance(self, out=None):
if self.n_samples == 0:
Expand Down
4 changes: 2 additions & 2 deletions pymc3/tests/test_quadpotential.py
Original file line number Diff line number Diff line change
Expand Up @@ -169,7 +169,7 @@ def test_weighted_covariance(ndim=10, seed=5432):

est = quadpotential._WeightedCovariance(ndim)
for sample in samples:
est.add_sample(sample, 1)
est.add_sample(sample)
mu_est = est.current_mean()
cov_est = est.current_covariance()

Expand All @@ -184,7 +184,7 @@ def test_weighted_covariance(ndim=10, seed=5432):
10,
)
for sample in samples[10:]:
est2.add_sample(sample, 1)
est2.add_sample(sample)
mu_est2 = est2.current_mean()
cov_est2 = est2.current_covariance()

Expand Down

0 comments on commit b0b56ce

Please sign in to comment.