Skip to content

Commit

Permalink
allow init to be None
Browse files Browse the repository at this point in the history
  • Loading branch information
aloctavodia committed Mar 19, 2022
1 parent aab25b3 commit aa20d7e
Showing 1 changed file with 19 additions and 21 deletions.
40 changes: 19 additions & 21 deletions pymc/sampling.py
Original file line number Diff line number Diff line change
Expand Up @@ -278,19 +278,18 @@ def sample(
draws : int
The number of samples to draw. Defaults to 1000. The number of tuned samples are discarded
by default. See ``discard_tuned_samples``.
init : str
Initialization method to use for auto-assigned NUTS samplers.
See `pm.init_nuts` for a list of all options.
init : str or None
Initialization method to use for auto-assigned NUTS samplers. See `pm.init_nuts` for a list
of all options. If ``None`` no initialization method will be used.
step : function or iterable of functions
A step function or collection of functions. If there are variables without step methods,
step methods for those variables will be assigned automatically. By default the NUTS step
method will be used, if appropriate to the model; this is a good default for beginning
users.
step methods for those variables will be assigned automatically. By default the NUTS step
method will be used, if appropriate to the model.
n_init : int
Number of iterations of initializer. Only works for 'ADVI' init methods.
initvals : optional, dict, array of dict
Dict or list of dicts with initial value strategies to use instead of the defaults from `Model.initial_values`.
The keys should be names of transformed random variables.
Dict or list of dicts with initial value strategies to use instead of the defaults from
`Model.initial_values`. The keys should be names of transformed random variables.
Initialization methods for NUTS (see ``init`` keyword) can overwrite the default.
trace : backend or list
This should be a backend instance, or a list of variables to track.
Expand All @@ -317,8 +316,8 @@ def sample(
model : Model (optional if in ``with`` context)
Model to sample from. The model needs to have free random variables.
random_seed : int or list of ints
Random seed(s) used by the sampling steps. A list is accepted if
``cores`` is greater than one.
Random seed(s) used by the sampling steps. A list is accepted if ``cores`` is greater than
one.
discard_tuned_samples : bool
Whether to discard posterior samples of the tune interval.
compute_convergence_checks : bool, default=True
Expand All @@ -330,17 +329,17 @@ def sample(
is drawn from.
Sampling can be interrupted by throwing a ``KeyboardInterrupt`` in the callback.
jitter_max_retries : int
Maximum number of repeated attempts (per chain) at creating an initial matrix with uniform jitter
that yields a finite probability. This applies to ``jitter+adapt_diag`` and ``jitter+adapt_full``
init methods.
Maximum number of repeated attempts (per chain) at creating an initial matrix with uniform
jitter that yields a finite probability. This applies to ``jitter+adapt_diag`` and
``jitter+adapt_full`` init methods.
return_inferencedata : bool
Whether to return the trace as an :class:`arviz:arviz.InferenceData` (True) object or a `MultiTrace` (False)
Defaults to `True`.
Whether to return the trace as an :class:`arviz:arviz.InferenceData` (True) object or a
`MultiTrace` (False). Defaults to `True`.
idata_kwargs : dict, optional
Keyword arguments for :func:`pymc.to_inference_data`
mp_ctx : multiprocessing.context.BaseContent
A multiprocessing context for parallel sampling. See multiprocessing
documentation for details.
A multiprocessing context for parallel sampling.
See multiprocessing documentation for details.
Returns
-------
Expand Down Expand Up @@ -373,7 +372,7 @@ def sample(
The initial guess for the step size scaled down by :math:`1/n**(1/4)`,
where n is the dimensionality of the parameter space
Alternativelly, if you manually declare the ``step_method``\ s, within the ``step``
Alternatively, if you manually declare the ``step_method``\ s, within the ``step``
kwarg, then you can address the ``step_method`` kwargs directly.
e.g. for a CompoundStep comprising NUTS and BinaryGibbsMetropolis,
you could send ::
Expand Down Expand Up @@ -464,7 +463,7 @@ def sample(

if isinstance(step, list):
step = CompoundStep(step)
elif isinstance(step, NUTS):
elif isinstance(step, NUTS) and init is not None:
if "nuts" in kwargs:
nuts_kwargs = kwargs.pop("nuts")
[kwargs.setdefault(k, v) for k, v in nuts_kwargs.items()]
Expand Down Expand Up @@ -2272,8 +2271,7 @@ def init_nuts(
if not isinstance(init, str):
raise TypeError("init must be a string.")

if init is not None:
init = init.lower()
init = init.lower()

if init == "auto":
init = "jitter+adapt_diag"
Expand Down

0 comments on commit aa20d7e

Please sign in to comment.