From b988ba99a084a18b2c907ef2ee4aeb1f4ab56160 Mon Sep 17 00:00:00 2001 From: Hennadii Madan Date: Fri, 28 Jul 2017 11:17:28 +0200 Subject: [PATCH] Make slice sampler sample from 1D conditionals as it should (#2446) * Make Slice sampler sample from 1D conditionals In the previous implementation it would sample jointly from non-scalar variables, and hang for when the size is high (due to low probability to get a joint sample within the slice in high-D). * slicer.py Fix broken indentation due to copypaste * Apply autopep8 * Delete a superfluous commented line * Update the master sample for Slice in test_step.py --- pymc3/step_methods/slicer.py | 61 +++++++++++++++++++----------------- pymc3/tests/test_step.py | 54 +++++++++++++++++++------------ 2 files changed, 67 insertions(+), 48 deletions(-) diff --git a/pymc3/step_methods/slicer.py b/pymc3/step_methods/slicer.py index ab9059b6b7..38465aaceb 100644 --- a/pymc3/step_methods/slicer.py +++ b/pymc3/step_methods/slicer.py @@ -34,8 +34,7 @@ def __init__(self, vars=None, w=1., tune=True, model=None, **kwargs): self.model = modelcontext(model) self.w = w self.tune = tune - self.w_sum = 0 - self.n_tunes = 0 + self.n_tunes = 0. if vars is None: vars = self.model.cont_vars @@ -44,33 +43,39 @@ def __init__(self, vars=None, w=1., tune=True, model=None, **kwargs): super(Slice, self).__init__(vars, [self.model.fastlogp], **kwargs) def astep(self, q0, logp): - self.w = np.resize(self.w, len(q0)) - y = logp(q0) - nr.standard_exponential() - - # Stepping out procedure - q_left = q0 - nr.uniform(0, self.w) - q_right = q_left + self.w - - while (y < logp(q_left)).all(): - q_left -= self.w - - while (y < logp(q_right)).all(): - q_right += self.w - - q = nr.uniform(q_left, q_right, size=q_left.size) # new variable to avoid copies - while logp(q) <= y: - # Sample uniformly from slice - if (q > q0).all(): - q_right = q - elif (q < q0).all(): - q_left = q - q = nr.uniform(q_left, q_right, size=q_left.size) - + self.w = np.resize(self.w, len(q0)) # this is a repmat + q = np.copy(q0) # TODO: find out if we need this + ql = np.copy(q0) # l for left boundary + qr = np.copy(q0) # r for right boudary + for i in range(len(q0)): + # uniformly sample from 0 to p(q), but in log space + y = logp(q) - nr.standard_exponential() + ql[i] = q[i] - nr.uniform(0, self.w[i]) + qr[i] = q[i] + self.w[i] + # Stepping out procedure + while(y <= logp(ql)): # changed lt to leq for locally uniform posteriors + ql[i] -= self.w[i] + while(y <= logp(qr)): + qr[i] += self.w[i] + + q[i] = nr.uniform(ql[i], qr[i]) + while logp(q) < y: # Changed leq to lt, to accomodate for locally flat posteriors + # Sample uniformly from slice + if q[i] > q0[i]: + qr[i] = q[i] + elif q[i] < q0[i]: + ql[i] = q[i] + q[i] = nr.uniform(ql[i], qr[i]) + + if self.tune: # I was under impression from MacKays lectures that slice width can be tuned without + # breaking markovianness. Can we do it regardless of self.tune?(@madanh) + self.w[i] = self.w[i] * (self.n_tunes / (self.n_tunes + 1)) +\ + (qr[i] - ql[i]) / (self.n_tunes + 1) # same as before + # unobvious and important: return qr and ql to the same point + qr[i] = q[i] + ql[i] = q[i] if self.tune: - # Tune sampler parameters - self.w_sum += np.abs(q0 - q) - self.n_tunes += 1. - self.w = 2. * self.w_sum / self.n_tunes + self.n_tunes += 1 return q @staticmethod diff --git a/pymc3/tests/test_step.py b/pymc3/tests/test_step.py index e0f0c0b6b2..a42619a4aa 100644 --- a/pymc3/tests/test_step.py +++ b/pymc3/tests/test_step.py @@ -27,26 +27,40 @@ class TestStepMethods(object): # yield test doesn't work subclassing object master_samples = { Slice: np.array([ - -8.13087389e-01, -3.08921856e-01, -6.79377098e-01, 6.50812585e-01, -7.63577596e-01, - -8.13199793e-01, -1.63823548e+00, -7.03863676e-02, 2.05107771e+00, 1.68598170e+00, - 6.92463695e-01, -7.75120766e-01, -1.62296463e+00, 3.59722423e-01, -2.31421712e-01, - -7.80686956e-02, -6.05860731e-01, -1.13000202e-01, 1.55675942e-01, -6.78527612e-01, - 6.31052333e-01, 6.09012517e-01, -1.56621643e+00, 5.04330883e-01, 3.14824082e-03, - -1.31287073e+00, 4.10706927e-01, 8.93815792e-01, 8.19317020e-01, 3.71900919e-01, - -2.62067312e+00, -3.47616592e+00, 1.50335041e+00, -1.05993351e+00, 2.41571723e-01, - -1.06258156e+00, 5.87999429e-01, -1.78480091e-01, -3.60278680e-01, 1.90615274e-01, - -1.24399204e-01, 4.03845589e-01, -1.47797573e-01, 7.90445804e-01, -1.21043819e+00, - -1.33964776e+00, 1.36366329e+00, -7.50175388e-01, 9.25241839e-01, -4.17493767e-01, - 1.85311339e+00, -2.49715343e+00, -3.18571692e-01, -1.49099668e+00, -2.62079621e-01, - -5.82376852e-01, -2.53033395e+00, 2.07580503e+00, -9.82615856e-01, 6.00517782e-01, - -9.83941620e-01, -1.59014118e+00, -1.83931394e-03, -4.71163466e-01, 1.90073737e+00, - -2.08929125e-01, -6.98388847e-01, 1.64502092e+00, -1.19525944e+00, 1.44424109e+00, - 1.52974876e+00, -5.70140077e-01, 5.08633322e-01, -1.70862492e-02, -1.69887948e-01, - 5.19760297e-01, -4.15149647e-01, 8.63685174e-02, -3.66805233e-01, -9.24988952e-01, - 2.33307122e+00, -2.60391496e-01, -5.86271814e-01, -5.01297170e-01, -1.53866195e+00, - 5.71285373e-01, -1.30571830e+00, 8.59587795e-01, 6.72170694e-01, 9.12433943e-01, - 7.04959179e-01, 8.37863464e-01, -5.24200836e-01, 1.28261340e+00, 9.08774240e-01, - 8.80566763e-01, 7.82911967e-01, 8.01843432e-01, 7.09251098e-01, 5.73803618e-01]), + -5.95252353e-01, -1.81894861e-01, -4.98211488e-01, + -1.02262800e-01, -4.26726030e-01, 1.75446860e+00, + -1.30022548e+00, 8.35658004e-01, 8.95879638e-01, + -8.85214481e-01, -6.63530918e-01, -8.39303080e-01, + 9.42792225e-01, 9.03554344e-01, 8.45254684e-01, + -1.43299803e+00, 9.04897201e-01, -1.74303131e-01, + -6.38611581e-01, 1.50013968e+00, 1.06864438e+00, + -4.80484421e-01, -7.52199709e-01, 1.95067495e+00, + -3.67960104e+00, 2.49291588e+00, -2.11039152e+00, + 1.61674758e-01, -1.59564182e-01, 2.19089873e-01, + 1.88643940e+00, 4.04098154e-01, -4.59352326e-01, + -9.06370675e-01, 5.42817654e-01, 6.99040611e-03, + 1.66396391e-01, -4.74549281e-01, 8.19064437e-02, + 1.69689952e+00, -1.62667304e+00, 1.61295808e+00, + 1.30099144e+00, -5.46722750e-01, -7.87745494e-01, + 7.91027521e-01, -2.35706976e-02, 1.68824376e+00, + 7.10566880e-01, -7.23551374e-01, 8.85613069e-01, + -1.27300146e+00, 1.80274430e+00, 9.34266276e-01, + 2.40427061e+00, -1.85132552e-01, 4.47234196e-01, + -9.81894859e-01, -2.83399706e-01, 1.84717533e+00, + -1.58593284e+00, 3.18027270e-02, 1.40566006e+00, + -9.45758714e-01, 1.18813188e-01, -1.19938604e+00, + -8.26038466e-01, 5.03469984e-01, -4.72742758e-01, + 2.27820946e-01, -1.02608915e-03, -6.02507158e-01, + 7.72739682e-01, 7.16064505e-01, -1.63693490e+00, + -3.97161966e-01, 1.17147944e+00, -2.87796982e+00, + -1.59533297e+00, 6.73096114e-01, -3.34397247e-01, + 1.22357427e-01, -4.57299104e-02, 1.32005771e+00, + -1.29910645e+00, 8.16168850e-01, -1.47357594e+00, + 1.34688446e+00, 1.06377551e+00, 4.34296696e-02, + 8.23143354e-01, 8.40906324e-01, 1.88596864e+00, + 5.77120694e-01, 2.71732927e-01, -1.36217979e+00, + 2.41488213e+00, 4.68298379e-01, 4.86342250e-01, + -8.43949966e-01]), HamiltonianMC: np.array([ -0.74925631, -0.2566773 , -2.12480977, 1.64328926, -1.39315913, 2.04200003, 0.00706711, 0.34240498, 0.44276674, -0.21368043,