1818import sys
1919sys .setrecursionlimit (10000 )
2020
21- __all__ = ['sample' , 'iter_sample' , 'sample_ppc' , 'init_nuts' ]
21+ __all__ = ['sample' , 'iter_sample' , 'sample_ppc' , 'sample_ppc_w' , ' init_nuts' ]
2222
2323STEP_METHODS = (NUTS , HamiltonianMC , Metropolis , BinaryMetropolis ,
2424 BinaryGibbsMetropolis , Slice , CategoricalGibbsMetropolis )
@@ -489,14 +489,15 @@ def _update_start_vals(a, b, model):
489489
490490 a .update ({k : v for k , v in b .items () if k not in a })
491491
492+
492493def sample_ppc (trace , samples = None , model = None , vars = None , size = None ,
493494 random_seed = None , progressbar = True ):
494495 """Generate posterior predictive samples from a model given a trace.
495496
496497 Parameters
497498 ----------
498499 trace : backend, list, or MultiTrace
499- Trace generated from MCMC sampling
500+ Trace generated from MCMC sampling.
500501 samples : int
501502 Number of posterior predictive samples to generate. Defaults to the
502503 length of `trace`
@@ -508,12 +509,19 @@ def sample_ppc(trace, samples=None, model=None, vars=None, size=None,
508509 size : int
509510 The number of random draws from the distribution specified by the
510511 parameters in each sample of the trace.
512+ random_seed : int
513+ Seed for the random number generator.
514+ progressbar : bool
515+ Whether or not to display a progress bar in the command line. The
516+ bar shows the percentage of completion, the sampling speed in
517+ samples per second (SPS), and the estimated remaining time until
518+ completion ("expected time of arrival"; ETA).
511519
512520 Returns
513521 -------
514522 samples : dict
515- Dictionary with the variables as keys. The values corresponding
516- to the posterior predictive samples.
523+ Dictionary with the variables as keys. The values corresponding to the
524+ posterior predictive samples.
517525 """
518526 if samples is None :
519527 samples = len (trace )
@@ -526,18 +534,124 @@ def sample_ppc(trace, samples=None, model=None, vars=None, size=None,
526534
527535 seed (random_seed )
528536
537+ indices = randint (0 , len (trace ), samples )
529538 if progressbar :
530- indices = tqdm (randint (0 , len (trace ), samples ), total = samples )
531- else :
532- indices = randint (0 , len (trace ), samples )
539+ indices = tqdm (indices , total = samples )
533540
534541 try :
535542 ppc = defaultdict (list )
536543 for idx in indices :
537544 param = trace [idx ]
538545 for var in vars :
539- vals = var .distribution .random (point = param , size = size )
540- ppc [var .name ].append (vals )
546+ ppc [var .name ].append (var .distribution .random (point = param ,
547+ size = size ))
548+
549+ except KeyboardInterrupt :
550+ pass
551+
552+ finally :
553+ if progressbar :
554+ indices .close ()
555+
556+ return {k : np .asarray (v ) for k , v in ppc .items ()}
557+
558+
559+ def sample_ppc_w (traces , samples = None , models = None , size = None , weights = None ,
560+ random_seed = None , progressbar = True ):
561+ """Generate weighted posterior predictive samples from a list of models and
562+ a list of traces according to a set of weights.
563+
564+ Parameters
565+ ----------
566+ traces : list
567+ List of traces generated from MCMC sampling. The number of traces should
568+ be equal to the number of weights.
569+ samples : int
570+ Number of posterior predictive samples to generate. Defaults to the
571+ length of the shorter trace in traces.
572+ models : list
573+ List of models used to generate the list of traces. The number of models
574+ should be equal to the number of weights and the number of observed RVs
575+ should be the same for all models.
576+ By default a single model will be inferred from `with` context, in this
577+ case results will only be meaningful if all models share the same
578+ distributions for the observed RVs.
579+ size : int
580+ The number of random draws from the distributions specified by the
581+ parameters in each sample of the trace.
582+ weights: array-like
583+ Individual weights for each trace. Default, same weight for each model.
584+ random_seed : int
585+ Seed for the random number generator.
586+ progressbar : bool
587+ Whether or not to display a progress bar in the command line. The
588+ bar shows the percentage of completion, the sampling speed in
589+ samples per second (SPS), and the estimated remaining time until
590+ completion ("expected time of arrival"; ETA).
591+
592+ Returns
593+ -------
594+ samples : dict
595+ Dictionary with the variables as keys. The values corresponding to the
596+ posterior predictive samples from the weighted models.
597+ """
598+ seed (random_seed )
599+
600+ if models is None :
601+ models = [modelcontext (models )] * len (traces )
602+
603+ if weights is None :
604+ weights = [1 ] * len (traces )
605+
606+ if len (traces ) != len (weights ):
607+ raise ValueError ('The number of traces and weights should be the same' )
608+
609+ if len (models ) != len (weights ):
610+ raise ValueError ('The number of models and weights should be the same' )
611+
612+ lenght_morv = len (models [0 ].observed_RVs )
613+ if not all (len (i .observed_RVs ) == lenght_morv for i in models ):
614+ raise ValueError (
615+ 'The number of observed RVs should be the same for all models' )
616+
617+ weights = np .asarray (weights )
618+ p = weights / np .sum (weights )
619+
620+ min_tr = min ([len (i ) for i in traces ])
621+
622+ n = (min_tr * p ).astype ('int' )
623+ # ensure n sum up to min_tr
624+ idx = np .argmax (n )
625+ n [idx ] = n [idx ] + min_tr - np .sum (n )
626+
627+ trace = np .concatenate ([np .random .choice (traces [i ], j )
628+ for i , j in enumerate (n )])
629+
630+ variables = []
631+ for i , m in enumerate (models ):
632+ variables .extend (m .observed_RVs * n [i ])
633+
634+ len_trace = len (trace )
635+
636+ if samples is None :
637+ samples = len_trace
638+
639+ indices = randint (0 , len_trace , samples )
640+
641+ if progressbar :
642+ indices = tqdm (indices , total = samples )
643+
644+ try :
645+ ppc = defaultdict (list )
646+ for idx in indices :
647+ param = trace [idx ]
648+ var = variables [idx ]
649+ ppc [var .name ].append (var .distribution .random (point = param ,
650+ size = size ))
651+
652+ except KeyboardInterrupt :
653+ pass
654+
541655 finally :
542656 if progressbar :
543657 indices .close ()
0 commit comments