diff --git a/pymc/sampling.py b/pymc/sampling.py index f5174b37a01..fad0ee1f19d 100644 --- a/pymc/sampling.py +++ b/pymc/sampling.py @@ -278,19 +278,18 @@ def sample( draws : int The number of samples to draw. Defaults to 1000. The number of tuned samples are discarded by default. See ``discard_tuned_samples``. - init : str - Initialization method to use for auto-assigned NUTS samplers. - See `pm.init_nuts` for a list of all options. + init : str or None + Initialization method to use for auto-assigned NUTS samplers. See `pm.init_nuts` for a list + of all options. If ``None`` no initialization method will be used. step : function or iterable of functions A step function or collection of functions. If there are variables without step methods, - step methods for those variables will be assigned automatically. By default the NUTS step - method will be used, if appropriate to the model; this is a good default for beginning - users. + step methods for those variables will be assigned automatically. By default the NUTS step + method will be used, if appropriate to the model. n_init : int Number of iterations of initializer. Only works for 'ADVI' init methods. initvals : optional, dict, array of dict - Dict or list of dicts with initial value strategies to use instead of the defaults from `Model.initial_values`. - The keys should be names of transformed random variables. + Dict or list of dicts with initial value strategies to use instead of the defaults from + `Model.initial_values`. The keys should be names of transformed random variables. Initialization methods for NUTS (see ``init`` keyword) can overwrite the default. trace : backend or list This should be a backend instance, or a list of variables to track. @@ -317,8 +316,8 @@ def sample( model : Model (optional if in ``with`` context) Model to sample from. The model needs to have free random variables. random_seed : int or list of ints - Random seed(s) used by the sampling steps. A list is accepted if - ``cores`` is greater than one. + Random seed(s) used by the sampling steps. A list is accepted if ``cores`` is greater than + one. discard_tuned_samples : bool Whether to discard posterior samples of the tune interval. compute_convergence_checks : bool, default=True @@ -330,17 +329,17 @@ def sample( is drawn from. Sampling can be interrupted by throwing a ``KeyboardInterrupt`` in the callback. jitter_max_retries : int - Maximum number of repeated attempts (per chain) at creating an initial matrix with uniform jitter - that yields a finite probability. This applies to ``jitter+adapt_diag`` and ``jitter+adapt_full`` - init methods. + Maximum number of repeated attempts (per chain) at creating an initial matrix with uniform + jitter that yields a finite probability. This applies to ``jitter+adapt_diag`` and + ``jitter+adapt_full`` init methods. return_inferencedata : bool - Whether to return the trace as an :class:`arviz:arviz.InferenceData` (True) object or a `MultiTrace` (False) - Defaults to `True`. + Whether to return the trace as an :class:`arviz:arviz.InferenceData` (True) object or a + `MultiTrace` (False). Defaults to `True`. idata_kwargs : dict, optional Keyword arguments for :func:`pymc.to_inference_data` mp_ctx : multiprocessing.context.BaseContent - A multiprocessing context for parallel sampling. See multiprocessing - documentation for details. + A multiprocessing context for parallel sampling. + See multiprocessing documentation for details. Returns ------- @@ -373,7 +372,7 @@ def sample( The initial guess for the step size scaled down by :math:`1/n**(1/4)`, where n is the dimensionality of the parameter space - Alternativelly, if you manually declare the ``step_method``\ s, within the ``step`` + Alternatively, if you manually declare the ``step_method``\ s, within the ``step`` kwarg, then you can address the ``step_method`` kwargs directly. e.g. for a CompoundStep comprising NUTS and BinaryGibbsMetropolis, you could send :: @@ -464,7 +463,7 @@ def sample( if isinstance(step, list): step = CompoundStep(step) - elif isinstance(step, NUTS): + elif isinstance(step, NUTS) and init is not None: if "nuts" in kwargs: nuts_kwargs = kwargs.pop("nuts") [kwargs.setdefault(k, v) for k, v in nuts_kwargs.items()] @@ -2272,8 +2271,7 @@ def init_nuts( if not isinstance(init, str): raise TypeError("init must be a string.") - if init is not None: - init = init.lower() + init = init.lower() if init == "auto": init = "jitter+adapt_diag"