
Commit 6feadda

xjing76 authored and brandonwillard committed
Add Gibbs step for negative-binomial dispersion term
1 parent 6e416ed commit 6feadda

File tree

2 files changed: +381 -3 lines changed

aemcmc/gibbs.py: +223 -3
@@ -1,8 +1,9 @@
-from typing import Dict, List, Tuple, Union
+from typing import Dict, List, Mapping, Tuple, Union
 
 import aesara
 import aesara.tensor as at
 from aesara.graph import optimize_graph
+from aesara.graph.basic import Variable
 from aesara.graph.opt import EquilibriumOptimizer
 from aesara.graph.unify import eval_if_etuple
 from aesara.ifelse import ifelse
@@ -212,6 +213,25 @@ def horseshoe_step(
     etuple(etuplize(at.mul), neg_one_lv, etuple(etuple(Dot), X_lv, beta_lv)),
 )
 
+a_lv = var()
+b_lv = var()
+gamma_pattern = etuple(etuplize(at.random.gamma), var(), var(), var(), a_lv, b_lv)
+
+
+def gamma_match(
+    graph: TensorVariable,
+) -> Tuple[TensorVariable, TensorVariable]:
+    graph_opt = optimize_graph(graph)
+    graph_et = etuplize(graph_opt)
+    s = unify(graph_et, gamma_pattern)
+    if s is False:
+        raise ValueError("Not a gamma prior.")
+
+    a = eval_if_etuple(s[a_lv])
+    b = eval_if_etuple(s[b_lv])
+
+    return a, b
+
 
 h_lv = var()
 nbinom_sigmoid_dot_pattern = etuple(
@@ -264,6 +284,113 @@ def nbinom_horseshoe_match(
     return h, X, beta_rv, lmbda_rv, tau_rv
 
 
+def sample_CRT(
+    srng: RandomStream, y: TensorVariable, h: TensorVariable
+) -> Tuple[TensorVariable, Mapping[Variable, Variable]]:
+    r"""Sample a Chinese restaurant table (CRT) distribution value: :math:`l \sim \operatorname{CRT}(y, h)`.
+
+    Sampling is performed according to the following:
+
+    .. math::
+
+        \begin{gather*}
+            l = \sum_{n=1}^{y} b_n, \quad
+            b_n \sim \operatorname{Bern}\left(\frac{h}{n - 1 + h}\right)
+        \end{gather*}
+
+    References
+    ----------
+    .. [1] Zhou, Mingyuan, and Lawrence Carin. 2012. “Augment-and-Conquer Negative Binomial Processes.” Advances in Neural Information Processing Systems 25.
+
+    """
+
+    def single_sample_CRT(y_i: TensorVariable, h: TensorVariable):
+        n_i = at.arange(2, y_i + 1)
+        return at.switch(y_i < 1, 0, 1 + srng.bernoulli(h / (n_i - 1 + h)).sum())
+
+    res, updates = aesara.scan(
+        single_sample_CRT,
+        sequences=[y.ravel()],
+        non_sequences=[h],
+        strict=True,
+    )
+    res = res.reshape(y.shape)
+    res.name = "CRT sample"
+
+    return res, updates
+
+
+def h_step(
+    srng: RandomStream,
+    h_last: TensorVariable,
+    p: TensorVariable,
+    a: TensorVariable,
+    b: TensorVariable,
+    y: TensorVariable,
+) -> Tuple[TensorVariable, Mapping[Variable, Variable]]:
+    r"""Sample the conditional posterior of the dispersion parameter under a negative-binomial likelihood and a gamma prior.
+
+    In other words, this draws a sample from :math:`h \mid Y = y` per
+
+    .. math::
+
+        \begin{align*}
+            Y_i &\sim \operatorname{NB}(h, p_i) \\
+            h &\sim \operatorname{Gamma}(a, b)
+        \end{align*}
+
+    where `y` is a sample from :math:`y \sim Y`.
+
+    The conditional posterior sample step is derived from the following decomposition:
+
+    .. math::
+        \begin{gather*}
+            Y_i = \sum_{j=1}^{l_i} u_{i j}, \quad u_{i j} \sim \operatorname{Log}(p_i), \quad
+            l_i \sim \operatorname{Pois}\left(-h \log(1 - p_i)\right)
+        \end{gather*}
+
+    where :math:`\operatorname{Log}` is the logarithmic distribution. Under a
+    gamma prior, :math:`h` is conjugate to :math:`l`. We draw samples of
+    :math:`l` according to :math:`l \sim \operatorname{CRT}(y, h)`.
+
+    The resulting posterior is
+
+    .. math::
+
+        \begin{gather*}
+            \left(h \mid Y = y\right) \sim \operatorname{Gamma}\left(a + \sum_{i=1}^N l_i, \frac{1}{1/b - \sum_{i=1}^N \log(1 - p_i)} \right)
+        \end{gather*}
+
+
+    References
+    ----------
+    .. [1] Zhou, Mingyuan, Lingbo Li, David Dunson, and Lawrence Carin. 2012. “Lognormal and Gamma Mixed Negative Binomial Regression.” Proceedings of the International Conference on Machine Learning. International Conference on Machine Learning 2012: 1343–50. https://www.ncbi.nlm.nih.gov/pmc/articles/PMC4180062/.
+
+    """
+    Ls, updates = sample_CRT(srng, y, h_last)
+    L_sum = Ls.sum(axis=-1)
+    h = srng.gamma(a + L_sum, at.reciprocal(b) - at.sum(at.log(1 - p), axis=-1))
+    h.name = f"{h_last.name or 'h'}-posterior"
+    return h, updates
+
+
+def nbinom_horseshoe_with_dispersion_match(
+    Y_rv: TensorVariable,
+) -> Tuple[
+    TensorVariable,
+    TensorVariable,
+    TensorVariable,
+    TensorVariable,
+    TensorVariable,
+    TensorVariable,
+    TensorVariable,
+]:
+    X, h_rv, beta_rv = nbinom_sigmoid_dot_match(Y_rv)
+    lmbda_rv, tau_rv = horseshoe_match(beta_rv)
+    a, b = gamma_match(h_rv)
+    return X, beta_rv, lmbda_rv, tau_rv, h_rv, a, b
+
+
 def nbinom_horseshoe_gibbs(
     srng: RandomStream, Y_rv: TensorVariable, y: TensorVariable, num_samples: int
 ) -> Tuple[Union[TensorVariable, List[TensorVariable]], Dict]:
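(Aside, not part of this commit: the CRT draw and the conjugate update added above can be mirrored in a few lines of plain NumPy. The helper names below are hypothetical, and the gamma prior is assumed to use a scale parameterization, i.e. b plays the role of a scale.)

import numpy as np

rng = np.random.default_rng(1234)


def crt_sample(y, h):
    # l ~ CRT(y, h): a sum of y Bernoulli(h / (n - 1 + h)) trials; the n = 1
    # trial always succeeds, which is why the Aesara code above starts at n = 2.
    if y < 1:
        return 0
    n = np.arange(1, y + 1)
    return int(rng.binomial(1, h / (n - 1 + h)).sum())


def h_update(y, p, h, a, b):
    # One conjugate update of h given counts y and probabilities p,
    # with the prior h ~ Gamma(shape=a, scale=b).
    l_sum = sum(crt_sample(y_i, h) for y_i in y)
    post_shape = a + l_sum
    post_scale = 1.0 / (1.0 / b - np.sum(np.log1p(-p)))
    return rng.gamma(post_shape, post_scale)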
@@ -275,7 +402,7 @@ def nbinom_horseshoe_gibbs(
     .. math::
 
         \begin{align*}
-            Y_i &\sim \operatorname{NB}\left(p_i, h\right) \\
+            Y_i &\sim \operatorname{NB}\left(h, p_i\right) \\
             p_i &= \frac{\exp(\psi_i)}{1 + \exp(\psi_i)} \\
             \psi_i &= x_i^\top \beta \\
             \beta_j &\sim \operatorname{N}(0, \lambda_j^2 \tau^2) \\
@@ -361,7 +488,6 @@ def nbinom_horseshoe_step(
         lmbda_inv_new, tau_inv_new = horseshoe_step(
             srng, beta_new, 1.0, lmbda_inv, tau_inv
         )
-
         return beta_new, 1.0 / lmbda_inv_new, 1.0 / tau_inv_new
 
     h, X, beta_rv, lmbda_rv, tau_rv = nbinom_horseshoe_match(Y_rv)
@@ -377,6 +503,100 @@ def nbinom_horseshoe_step(
     return outputs, updates
 
 
+def nbinom_horseshoe_gibbs_with_dispersion(
+    srng: RandomStream,
+    Y_rv: TensorVariable,
+    y: TensorVariable,
+    num_samples: TensorVariable,
+) -> Tuple[Union[TensorVariable, List[TensorVariable]], Mapping[Variable, Variable]]:
+    r"""Build a Gibbs sampler for negative-binomial regression with a horseshoe prior and a gamma-prior dispersion term.
+
+    This is a direct extension of `nbinom_horseshoe_gibbs` that adds a gamma
+    prior on the :math:`h` parameter of the negative-binomial and samples it
+    according to [1]_.
+
+    In other words, this model is the same as `nbinom_horseshoe_gibbs` except
+    for the additional assumption:
+
+    .. math::
+
+        \begin{gather*}
+            h \sim \operatorname{Gamma}\left(a, b\right)
+        \end{gather*}
+
+
+    References
+    ----------
+    .. [1] Zhou, Mingyuan, Lingbo Li, David Dunson, and Lawrence Carin. 2012. “Lognormal and Gamma Mixed Negative Binomial Regression.” Proceedings of the International Conference on Machine Learning. International Conference on Machine Learning 2012: 1343–50. https://www.ncbi.nlm.nih.gov/pmc/articles/PMC4180062/.
+
+    """
+
+    def nbinom_horseshoe_step(
+        beta: TensorVariable,
+        lmbda: TensorVariable,
+        tau: TensorVariable,
+        h: TensorVariable,
+        y: TensorVariable,
+        X: TensorVariable,
+        a: TensorVariable,
+        b: TensorVariable,
+    ):
+        """Complete one full update of the Gibbs sampler and return the new
+        state of the posterior conditional parameters.
+
+        Parameters
+        ----------
+        beta
+            Coefficients (other than the intercept) of the regression model.
+        lmbda
+            Local shrinkage parameters of the horseshoe prior.
+        tau
+            Global shrinkage parameter of the horseshoe prior.
+        h
+            The "number of successes" parameter of the negative-binomial distribution
+            used to model the data.
+        y
+            The observed count data.
+        X
+            The covariate matrix.
+        a
+            The shape parameter of the :math:`h` gamma prior.
+        b
+            The rate parameter of the :math:`h` gamma prior.
+
+        """
+        xb = X @ beta
+        w = srng.gen(polyagamma, y + h, xb)
+        z = 0.5 * (y - h) / w
+
+        lmbda_inv = 1.0 / lmbda
+        tau_inv = 1.0 / tau
+        beta_new = update_beta(srng, w, lmbda_inv * tau_inv, X, z)
+
+        lmbda_inv_new, tau_inv_new = horseshoe_step(
+            srng, beta_new, 1.0, lmbda_inv, tau_inv
+        )
+        eta = X @ beta_new
+        p = at.sigmoid(-eta)
+        h_new, h_updates = h_step(srng, h, p, a, b, y)
+
+        return (beta_new, 1.0 / lmbda_inv_new, 1.0 / tau_inv_new, h_new), h_updates
+
+    X, beta_rv, lmbda_rv, tau_rv, h_rv, a, b = nbinom_horseshoe_with_dispersion_match(
+        Y_rv
+    )
+
+    outputs, updates = aesara.scan(
+        nbinom_horseshoe_step,
+        outputs_info=[beta_rv, lmbda_rv, tau_rv, h_rv],
+        non_sequences=[y, X, a, b],
+        n_steps=num_samples,
+        strict=True,
+    )
+
+    return outputs, updates
+
+
 bernoulli_sigmoid_dot_pattern = etuple(
     etuplize(at.random.bernoulli), var(), var(), var(), sigmoid_dot_pattern
 )
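A quick usage note (a sketch under assumptions, not from this commit): `sample_CRT` and `h_step` return Aesara update mappings from their internal scan and random draws, and those updates have to be handed to `aesara.function` when compiling. Something along these lines should exercise the new CRT helper; the `RandomStream` import path is the usual Aesara one and is an assumption here.

import aesara
import aesara.tensor as at
from aesara.tensor.random.utils import RandomStream

from aemcmc.gibbs import sample_CRT

srng = RandomStream(2309)

y = at.lvector("y")   # observed counts
h = at.dscalar("h")   # negative-binomial dispersion

l, updates = sample_CRT(srng, y, h)

# Forgetting `updates=updates` would reuse the same random state on every call.
crt_fn = aesara.function([y, h], l, updates=updates)

print(crt_fn([0, 1, 5, 12], 2.0))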

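Finally, a rough end-to-end sketch of how the new sampler builder might be driven. The exact graph that `nbinom_horseshoe_with_dispersion_match` accepts (the sign convention inside `at.sigmoid`, the argument order of `srng.nbinom`, and the half-Cauchy form of the horseshoe scales) is inferred from the patterns above and should be treated as an assumption, as should the variable names.

import aesara
import aesara.tensor as at
from aesara.tensor.random.utils import RandomStream

from aemcmc.gibbs import nbinom_horseshoe_gibbs_with_dispersion

srng = RandomStream(2309)

X = at.dmatrix("X")          # covariate matrix (N x M)
N_tt, M_tt = X.shape

# Horseshoe prior on the regression coefficients.
tau_rv = srng.halfcauchy(0, 1)
lmbda_rv = srng.halfcauchy(0, 1, size=M_tt)
beta_rv = srng.normal(0, lmbda_rv * tau_rv, size=M_tt)

# Gamma prior on the dispersion term; its (a, b) parameters are what
# `gamma_match` extracts from the graph.
h_rv = srng.gamma(1.0, 1.0)

p = at.sigmoid(-(X @ beta_rv))
Y_rv = srng.nbinom(h_rv, p)

y = at.lvector("y")          # observed counts
outputs, updates = nbinom_horseshoe_gibbs_with_dispersion(srng, Y_rv, y, 2000)
sample_fn = aesara.function([X, y], outputs, updates=updates)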