Skip to content

Sampling from model with LKJCholeskyCov raises ValueError #5441

Closed
@fonnesbeck

Description

@fonnesbeck

Description of your problem

While updating the LKJ notebook in the examples following the merge of #5382, I am unable to get sampling to run. The model fails during initialization with `ValueError: cannot reshape array of size 1 into shape (1,2)`.

Please provide the full traceback.

Complete error traceback
---------------------------------------------------------------------------
ValueError                                Traceback (most recent call last)
File ~/miniforge3/envs/pymc-dev-py39/lib/python3.9/site-packages/aesara/link/vm.py:309, in LoopGC.__call__(self)
    306 for thunk, node, old_storage in zip(
    307     self.thunks, self.nodes, self.post_thunk_clear
    308 ):
--> 309     thunk()
    310     for old_s in old_storage:

File ~/miniforge3/envs/pymc-dev-py39/lib/python3.9/site-packages/aesara/graph/op.py:508, in Op.make_py_thunk.<locals>.rval(p, i, o, n)
    507 def rval(p=p, i=node_input_storage, o=node_output_storage, n=node):
--> 508     r = p(n, [x[0] for x in i], o)
    509     for o in node.outputs:

File ~/miniforge3/envs/pymc-dev-py39/lib/python3.9/site-packages/aesara/tensor/random/op.py:382, in RandomVariable.perform(self, node, inputs, outputs)
    380 rng_var_out[0] = rng
--> 382 smpl_val = self.rng_fn(rng, *(args + [size]))
    384 if (
    385     not isinstance(smpl_val, np.ndarray)
    386     or str(smpl_val.dtype) != out_var.type.dtype
    387 ):

File ~/pymc/pymc/distributions/multivariate.py:1150, in _LKJCholeskyCovRV.rng_fn(self, rng, n, eta, D, size)
   1148 C = LKJCorrRV._random_corr_matrix(rng, n, eta, flat_size)
-> 1150 D = D.reshape(flat_size, n)
   1151 C *= D[..., :, np.newaxis] * D[..., np.newaxis, :]

ValueError: cannot reshape array of size 1 into shape (1,2)

During handling of the above exception, another exception occurred:

ValueError                                Traceback (most recent call last)
Input In [49], in <module>
      1 with model:
----> 2     trace = pm.sample(
      3         idata_kwargs={"dims": {"chol_stds": ["axis"], "chol_corr": ["axis", "axis_bis"]}},
      4     )
      5 az.summary(trace, var_names="~chol", round_to=2)

File ~/pymc/pymc/sampling.py:496, in sample(draws, step, init, n_init, initvals, trace, chain_idx, chains, cores, tune, progressbar, model, random_seed, discard_tuned_samples, compute_convergence_checks, callback, jitter_max_retries, return_inferencedata, idata_kwargs, mp_ctx, **kwargs)
    493 try:
    494     # By default, try to use NUTS
    495     _log.info("Auto-assigning NUTS sampler...")
--> 496     initial_points, step = init_nuts(
    497         init=init,
    498         chains=chains,
    499         n_init=n_init,
    500         model=model,
    501         seeds=random_seed,
    502         progressbar=progressbar,
    503         jitter_max_retries=jitter_max_retries,
    504         tune=tune,
    505         initvals=initvals,
    506         **kwargs,
    507     )
    508 except (AttributeError, NotImplementedError, tg.NullTypeGradError):
    509     # gradient computation failed
    510     _log.info("Initializing NUTS failed. Falling back to elementwise auto-assignment.")

File ~/pymc/pymc/sampling.py:2320, in init_nuts(init, chains, n_init, model, seeds, progressbar, jitter_max_retries, tune, initvals, **kwargs)
   2313 _log.info(f"Initializing NUTS using {init}...")
   2315 cb = [
   2316     pm.callbacks.CheckParametersConvergence(tolerance=1e-2, diff="absolute"),
   2317     pm.callbacks.CheckParametersConvergence(tolerance=1e-2, diff="relative"),
   2318 ]
-> 2320 initial_points = _init_jitter(
   2321     model,
   2322     initvals,
   2323     seeds=seeds,
   2324     jitter="jitter" in init,
   2325     jitter_max_retries=jitter_max_retries,
   2326 )
   2328 apoints = [DictToArrayBijection.map(point) for point in initial_points]
   2329 apoints_data = [apoint.data for apoint in apoints]

File ~/pymc/pymc/sampling.py:2203, in _init_jitter(model, initvals, seeds, jitter, jitter_max_retries)
   2201 rng = np.random.RandomState(seed)
   2202 for i in range(jitter_max_retries + 1):
-> 2203     point = ipfn(seed)
   2204     if i < jitter_max_retries:
   2205         try:

File ~/pymc/pymc/initial_point.py:216, in make_initial_point_fn.<locals>.make_seeded_function.<locals>.inner(seed, *args, **kwargs)
    214         new_rng = np.random.Generator(seed)
    215     rng.set_value(new_rng, True)
--> 216 values = func(*args, **kwargs)
    217 return dict(zip(varnames, values))

File ~/miniforge3/envs/pymc-dev-py39/lib/python3.9/site-packages/aesara/compile/function/types.py:969, in Function.__call__(self, *args, **kwargs)
    966 t0_fn = time.time()
    967 try:
    968     outputs = (
--> 969         self.fn()
    970         if output_subset is None
    971         else self.fn(output_subset=output_subset)
    972     )
    973 except Exception:
    974     restore_defaults()

File ~/miniforge3/envs/pymc-dev-py39/lib/python3.9/site-packages/aesara/link/vm.py:313, in LoopGC.__call__(self)
    311             old_s[0] = None
    312 except Exception:
--> 313     raise_with_op(self.fgraph, node, thunk)

File ~/miniforge3/envs/pymc-dev-py39/lib/python3.9/site-packages/aesara/link/utils.py:525, in raise_with_op(fgraph, node, thunk, exc_info, storage_map)
    520     warnings.warn(
    521         f"{exc_type} error does not allow us to add an extra error message"
    522     )
    523     # Some exception need extra parameter in inputs. So forget the
    524     # extra long error message in that case.
--> 525 raise exc_value.with_traceback(exc_trace)

File ~/miniforge3/envs/pymc-dev-py39/lib/python3.9/site-packages/aesara/link/vm.py:309, in LoopGC.__call__(self)
    305 try:
    306     for thunk, node, old_storage in zip(
    307         self.thunks, self.nodes, self.post_thunk_clear
    308     ):
--> 309         thunk()
    310         for old_s in old_storage:
    311             old_s[0] = None

File ~/miniforge3/envs/pymc-dev-py39/lib/python3.9/site-packages/aesara/graph/op.py:508, in Op.make_py_thunk.<locals>.rval(p, i, o, n)
    507 def rval(p=p, i=node_input_storage, o=node_output_storage, n=node):
--> 508     r = p(n, [x[0] for x in i], o)
    509     for o in node.outputs:
    510         compute_map[o][0] = True

File ~/miniforge3/envs/pymc-dev-py39/lib/python3.9/site-packages/aesara/tensor/random/op.py:382, in RandomVariable.perform(self, node, inputs, outputs)
    378     rng = copy(rng)
    380 rng_var_out[0] = rng
--> 382 smpl_val = self.rng_fn(rng, *(args + [size]))
    384 if (
    385     not isinstance(smpl_val, np.ndarray)
    386     or str(smpl_val.dtype) != out_var.type.dtype
    387 ):
    388     smpl_val = _asarray(smpl_val, dtype=out_var.type.dtype)

File ~/pymc/pymc/distributions/multivariate.py:1150, in _LKJCholeskyCovRV.rng_fn(self, rng, n, eta, D, size)
   1146     flat_size = np.prod(size)
   1148 C = LKJCorrRV._random_corr_matrix(rng, n, eta, flat_size)
-> 1150 D = D.reshape(flat_size, n)
   1151 C *= D[..., :, np.newaxis] * D[..., np.newaxis, :]
   1153 tril_idx = np.tril_indices(n, k=0)

ValueError: cannot reshape array of size 1 into shape (1,2)
Apply node that caused the error: _lkjcholeskycov_rv{1, (0, 0, 1), floatX, False}(RandomStateSharedVariable(<RandomState(PCG64) at 0x7F002A93CC40>), TensorConstant{[]}, TensorConstant{11}, TensorConstant{2}, TensorConstant{2.0}, exponential_rv{0, (0,), floatX, False}.out)
Toposort index: 2
Inputs types: [RandomStateType, TensorType(int64, (0,)), TensorType(int64, ()), TensorType(int32, ()), TensorType(float64, ()), TensorType(float64, ())]
Inputs shapes: ['No shapes', (0,), (), (), (), ()]
Inputs strides: ['No strides', (8,), (), (), (), ()]
Inputs values: [RandomState(PCG64) at 0x7F002A93CC40, array([], dtype=int64), array(11), array(2, dtype=int32), array(2.), array(0.3459677)]
Outputs clients: [[], [Elemwise{second,no_inplace}(chol, TensorConstant{(1,) of 0.0})]]

Backtrace when the node is created (use Aesara flag traceback__limit=N to make it longer):
  File "/home/fonnesbeck/pymc/pymc/distributions/multivariate.py", line 1181, in __new__
    return super().__new__(cls, name, eta, n, sd_dist, **kwargs)
  File "/home/fonnesbeck/pymc/pymc/distributions/distribution.py", line 266, in __new__
    rv_out, dims, observed, resize_shape = _make_rv_and_resize_shape(
  File "/home/fonnesbeck/pymc/pymc/distributions/distribution.py", line 165, in _make_rv_and_resize_shape
    rv_out = cls.dist(*args, **kwargs)
  File "/home/fonnesbeck/pymc/pymc/distributions/multivariate.py", line 1203, in dist
    return super().dist([n, eta, sd_dist], size=size, **kwargs)
  File "/home/fonnesbeck/pymc/pymc/distributions/distribution.py", line 353, in dist
    rv_out = cls.rv_op(*dist_params, size=create_size, **kwargs)
  File "/home/fonnesbeck/miniforge3/envs/pymc-dev-py39/lib/python3.9/site-packages/aesara/tensor/random/op.py", line 293, in __call__
    res = super().__call__(rng, size, dtype, *args, **kwargs)
  File "/home/fonnesbeck/miniforge3/envs/pymc-dev-py39/lib/python3.9/site-packages/aesara/graph/op.py", line 283, in __call__
    node = self.make_node(*inputs, **kwargs)
  File "/home/fonnesbeck/pymc/pymc/distributions/multivariate.py", line 1134, in make_node
    return super().make_node(rng, size, dtype, n, eta, D)

HINT: Use the Aesara flag `exception_verbosity=high` for a debug print-out and storage map footprint of this Apply node.

Please provide any additional information below.

Versions and main components

  • PyMC/PyMC3 Version: current master
  • Aesara/Theano Version: 2.3.8
  • Python Version: 3.9
  • Operating system: Linux (WSL)
  • How did you install PyMC/PyMC3: pip

Metadata

Metadata

Assignees

Labels

No labels
No labels

Type

No type

Projects

No projects

Milestone

No milestone

Relationships

None yet

Development

No branches or pull requests

Issue actions