1313 Slice , CompoundStep )
1414from .plots .traceplot import traceplot
1515from .util import update_start_vals
16+ from pymc3 .step_methods .hmc import quadpotential
17+ from pymc3 .distributions import distribution
1618from tqdm import tqdm
1719
1820import sys
@@ -118,20 +120,27 @@ def sample(draws=500, step=None, init='auto', n_init=200000, start=None,
118120 A step function or collection of functions. If there are variables
119121 without a step methods, step methods for those variables will
120122 be assigned automatically.
121- init : str {'ADVI', 'ADVI_MAP', 'MAP', 'NUTS', 'auto', None}
122- Initialization method to use. Only works for auto-assigned step methods.
123-
124- * ADVI: Run ADVI to estimate starting points and diagonal covariance
125- matrix. If njobs > 1 it will sample starting points from the estimated
126- posterior, otherwise it will use the estimated posterior mean.
127- * ADVI_MAP: Initialize ADVI with MAP and use MAP as starting point.
128- * MAP: Use the MAP as starting point.
129- * NUTS: Run NUTS to estimate starting points and covariance matrix. If
130- njobs > 1 it will sample starting points from the estimated posterior,
131- otherwise it will use the estimated posterior mean.
132- * auto : Auto-initialize, if possible. Currently only works when NUTS
133- is auto-assigned as step method (default).
134- * None: Do not initialize.
123+ init : str
124+ Initialization method to use for auto-assigned NUTS samplers.
125+
126+ * auto : Choose a default initialization method automatically.
127+ Currently, this is `'advi+adapt_diag'`, but this can change in
128+ the future. If you depend on the exact behaviour, choose an
129+ initialization method explicitly.
130+ * adapt_diag : Start with a identity mass matrix and then adapt
131+ a diagonal based on the variance of the tuning samples.
132+ * advi+adapt_diag : Run ADVI and then adapt the resulting diagonal
133+ mass matrix based on the sample variance of the tuning samples.
134+ * advi+adapt_diag_grad : Run ADVI and then adapt the resulting
135+ diagonal mass matrix based on the variance of the gradients
136+ during tuning. This is **experimental** and might be removed
137+ in a future release.
138+ * advi : Run ADVI to estimate posterior mean and diagonal mass
139+ matrix.
140+ * advi_map: Initialize ADVI with MAP and use MAP as starting point.
141+ * map : Use the MAP as starting point. This is discouraged.
142+ * nuts : Run NUTS and estimate posterior mean and mass matrix from
143+ the trace.
135144 n_init : int
136145 Number of iterations of initializer
137146 If 'ADVI', number of iterations, if 'nuts', number of draws.
@@ -220,9 +229,6 @@ def sample(draws=500, step=None, init='auto', n_init=200000, start=None,
220229
221230 draws += tune
222231
223- if init is not None :
224- init = init .lower ()
225-
226232 if nuts_kwargs is not None :
227233 if step_kwargs is not None :
228234 raise ValueError ("Specify only one of step_kwargs and nuts_kwargs" )
@@ -236,8 +242,6 @@ def sample(draws=500, step=None, init='auto', n_init=200000, start=None,
236242 pm ._log .info ('Auto-assigning NUTS sampler...' )
237243 args = step_kwargs if step_kwargs is not None else {}
238244 args = args .get ('nuts' , {})
239- if init == 'auto' :
240- init = 'ADVI'
241245 start_ , step = init_nuts (init = init , njobs = njobs , n_init = n_init ,
242246 model = model , random_seed = random_seed ,
243247 progressbar = progressbar , ** args )
@@ -643,28 +647,42 @@ def sample_ppc_w(traces, samples=None, models=None, size=None, weights=None,
643647 return {k : np .asarray (v ) for k , v in ppc .items ()}
644648
645649
646- def init_nuts (init = 'ADVI ' , njobs = 1 , n_init = 500000 , model = None ,
650+ def init_nuts (init = 'auto ' , njobs = 1 , n_init = 500000 , model = None ,
647651 random_seed = - 1 , progressbar = True , ** kwargs ):
648- """Initialize and sample from posterior of a continuous model .
652+ """Set up the mass matrix initialization for NUTS .
649653
650- This is a convenience function. NUTS convergence and sampling speed is extremely
651- dependent on the choice of mass/scaling matrix. In our experience, using ADVI
652- to estimate a diagonal covariance matrix and using this as the scaling matrix
653- produces robust results over a wide class of continuous models.
654+ NUTS convergence and sampling speed is extremely dependent on the
655+ choice of mass/scaling matrix. This function implements different
656+ methods for choosing or adapting the mass matrix.
654657
655658 Parameters
656659 ----------
657- init : str {'ADVI', 'ADVI_MAP', 'MAP', 'NUTS'}
660+ init : str
658661 Initialization method to use.
659- * ADVI : Run ADVI to estimate posterior mean and diagonal covariance matrix.
660- * ADVI_MAP: Initialize ADVI with MAP and use MAP as starting point.
661- * MAP : Use the MAP as starting point.
662- * NUTS : Run NUTS and estimate posterior mean and covariance matrix.
662+
663+ * auto : Choose a default initialization method automatically.
664+ Currently, this is `'advi+adapt_diag'`, but this can change in
665+ the future. If you depend on the exact behaviour, choose an
666+ initialization method explicitly.
667+ * adapt_diag : Start with a identity mass matrix and then adapt
668+ a diagonal based on the variance of the tuning samples.
669+ * advi+adapt_diag : Run ADVI and then adapt the resulting diagonal
670+ mass matrix based on the sample variance of the tuning samples.
671+ * advi+adapt_diag_grad : Run ADVI and then adapt the resulting
672+ diagonal mass matrix based on the variance of the gradients
673+ during tuning. This is **experimental** and might be removed
674+ in a future release.
675+ * advi : Run ADVI to estimate posterior mean and diagonal mass
676+ matrix.
677+ * advi_map: Initialize ADVI with MAP and use MAP as starting point.
678+ * map : Use the MAP as starting point. This is discouraged.
679+ * nuts : Run NUTS and estimate posterior mean and mass matrix from
680+ the trace.
663681 njobs : int
664682 Number of parallel jobs to start.
665683 n_init : int
666684 Number of iterations of initializer
667- If 'ADVI', number of iterations, if 'metropolis ', number of draws.
685+ If 'ADVI', number of iterations, if 'nuts ', number of draws.
668686 model : Model (optional if in `with` context)
669687 progressbar : bool
670688 Whether or not to display a progressbar for advi sampling.
@@ -678,20 +696,83 @@ def init_nuts(init='ADVI', njobs=1, n_init=500000, model=None,
678696 nuts_sampler : pymc3.step_methods.NUTS
679697 Instantiated and initialized NUTS sampler object
680698 """
681-
682699 model = pm .modelcontext (model )
683700
684- pm ._log .info ('Initializing NUTS using {}...' .format (init ))
701+ vars = kwargs .get ('vars' , model .vars )
702+ if set (vars ) != set (model .vars ):
703+ raise ValueError ('Must use init_nuts on all variables of a model.' )
704+ if not pm .model .all_continuous (vars ):
705+ raise ValueError ('init_nuts can only be used for models with only '
706+ 'continuous variables.' )
685707
686- random_seed = int (np .atleast_1d (random_seed )[0 ])
708+ if not isinstance (init , str ):
709+ raise TypeError ('init must be a string.' )
687710
688711 if init is not None :
689712 init = init .lower ()
713+
714+ if init == 'auto' :
715+ init = 'advi+adapt_diag'
716+
717+ pm ._log .info ('Initializing NUTS using {}...' .format (init ))
718+
719+ random_seed = int (np .atleast_1d (random_seed )[0 ])
720+
690721 cb = [
691722 pm .callbacks .CheckParametersConvergence (tolerance = 1e-2 , diff = 'absolute' ),
692723 pm .callbacks .CheckParametersConvergence (tolerance = 1e-2 , diff = 'relative' ),
693724 ]
694- if init == 'advi' :
725+
726+ if init == 'adapt_diag' :
727+ start = []
728+ for _ in range (njobs ):
729+ vals = distribution .draw_values (model .free_RVs )
730+ point = {var .name : vals [i ] for i , var in enumerate (model .free_RVs )}
731+ start .append (point )
732+ mean = np .mean ([model .dict_to_array (vals ) for vals in start ], axis = 0 )
733+ var = np .ones_like (mean )
734+ potential = quadpotential .QuadPotentialDiagAdapt (model .ndim , mean , var , 10 )
735+ if njobs == 1 :
736+ start = start [0 ]
737+ elif init == 'advi+adapt_diag_grad' :
738+ approx = pm .fit (
739+ random_seed = random_seed ,
740+ n = n_init , method = 'advi' , model = model ,
741+ callbacks = cb ,
742+ progressbar = progressbar ,
743+ obj_optimizer = pm .adagrad_window ,
744+ )
745+ start = approx .sample (draws = njobs )
746+ start = list (start )
747+ stds = approx .gbij .rmap (approx .std .eval ())
748+ cov = model .dict_to_array (stds ) ** 2
749+ mean = approx .gbij .rmap (approx .mean .get_value ())
750+ mean = model .dict_to_array (mean )
751+ weight = 50
752+ potential = quadpotential .QuadPotentialDiagAdaptGrad (
753+ model .ndim , mean , cov , weight )
754+ if njobs == 1 :
755+ start = start [0 ]
756+ elif init == 'advi+adapt_diag' :
757+ approx = pm .fit (
758+ random_seed = random_seed ,
759+ n = n_init , method = 'advi' , model = model ,
760+ callbacks = cb ,
761+ progressbar = progressbar ,
762+ obj_optimizer = pm .adagrad_window ,
763+ )
764+ start = approx .sample (draws = njobs )
765+ start = list (start )
766+ stds = approx .gbij .rmap (approx .std .eval ())
767+ cov = model .dict_to_array (stds ) ** 2
768+ mean = approx .gbij .rmap (approx .mean .get_value ())
769+ mean = model .dict_to_array (mean )
770+ weight = 50
771+ potential = quadpotential .QuadPotentialDiagAdapt (
772+ model .ndim , mean , cov , weight )
773+ if njobs == 1 :
774+ start = start [0 ]
775+ elif init == 'advi' :
695776 approx = pm .fit (
696777 random_seed = random_seed ,
697778 n = n_init , method = 'advi' , model = model ,
@@ -700,8 +781,10 @@ def init_nuts(init='ADVI', njobs=1, n_init=500000, model=None,
700781 obj_optimizer = pm .adagrad_window
701782 ) # type: pm.MeanField
702783 start = approx .sample (draws = njobs )
784+ start = list (start )
703785 stds = approx .gbij .rmap (approx .std .eval ())
704786 cov = model .dict_to_array (stds ) ** 2
787+ potential = quadpotential .QuadPotentialDiag (cov )
705788 if njobs == 1 :
706789 start = start [0 ]
707790 elif init == 'advi_map' :
@@ -715,24 +798,31 @@ def init_nuts(init='ADVI', njobs=1, n_init=500000, model=None,
715798 obj_optimizer = pm .adagrad_window
716799 )
717800 start = approx .sample (draws = njobs )
801+ start = list (start )
718802 stds = approx .gbij .rmap (approx .std .eval ())
719803 cov = model .dict_to_array (stds ) ** 2
804+ potential = quadpotential .QuadPotentialDiag (cov )
720805 if njobs == 1 :
721806 start = start [0 ]
722807 elif init == 'map' :
723808 start = pm .find_MAP ()
724809 cov = pm .find_hessian (point = start )
810+ start = [start ] * njobs
811+ potential = quadpotential .QuadPotentialFull (cov )
812+ if njobs == 1 :
813+ start = start [0 ]
725814 elif init == 'nuts' :
726815 init_trace = pm .sample (draws = n_init , step = pm .NUTS (),
727816 tune = n_init // 2 ,
728817 random_seed = random_seed )
729818 cov = np .atleast_1d (pm .trace_cov (init_trace ))
730- start = np .random .choice (init_trace , njobs )
819+ start = list (np .random .choice (init_trace , njobs ))
820+ potential = quadpotential .QuadPotentialFull (cov )
731821 if njobs == 1 :
732822 start = start [0 ]
733823 else :
734824 raise NotImplementedError ('Initializer {} is not supported.' .format (init ))
735825
736- step = pm .NUTS (scaling = cov , is_cov = True , ** kwargs )
826+ step = pm .NUTS (potential = potential , ** kwargs )
737827
738828 return start , step
0 commit comments