diff --git a/pymc3/step_methods/hmc/quadpotential.py b/pymc3/step_methods/hmc/quadpotential.py index 1943c774fd0..b22fd094ff7 100644 --- a/pymc3/step_methods/hmc/quadpotential.py +++ b/pymc3/step_methods/hmc/quadpotential.py @@ -307,7 +307,7 @@ def update(self, sample, grad, tune): self._ngrads2 += 1 if self._n_samples <= 150: - super().update(sample, grad) + super().update(sample, grad, tune) else: self._update((self._ngrads1 / self._grads1) ** 2) diff --git a/pymc3/tests/test_quadpotential.py b/pymc3/tests/test_quadpotential.py index 6052f180fe2..c0ab5f7aacd 100644 --- a/pymc3/tests/test_quadpotential.py +++ b/pymc3/tests/test_quadpotential.py @@ -283,3 +283,9 @@ def test_full_adapt_sampling(seed=289586): pymc3.sample( draws=10, tune=1000, random_seed=seed, step=step, cores=1, chains=1 ) + + +def test_issue_3965(): + with pymc3.Model(): + pymc3.Normal('n') + pymc3.sample(100, tune=300, chains=1, init='advi+adapt_diag_grad')