Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Improve support for dims in LKJCholeskyCov #6828

Draft
wants to merge 1 commit into
base: main
Choose a base branch
from
Draft
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
44 changes: 42 additions & 2 deletions pymc/distributions/multivariate.py
Original file line number Diff line number Diff line change
Expand Up @@ -1413,14 +1413,32 @@ class LKJCholeskyCov:
"""

def __new__(cls, name, eta, n, sd_dist, *, compute_corr=True, store_in_trace=True, **kwargs):
    """Create the packed Cholesky factor and, optionally, its derived quantities.

    Parameters
    ----------
    name, eta, n, sd_dist
        Forwarded unchanged to ``_LKJCholeskyCov``.
    compute_corr : bool
        When false, only the packed Cholesky variable is returned.
    store_in_trace : bool
        When true (and ``compute_corr`` is true), the correlations and the
        standard deviations are registered as ``pm.Deterministic`` under
        ``f"{name}_corr"`` and ``f"{name}_stds"``.
    **kwargs
        Forwarded to ``_LKJCholeskyCov``. ``dims`` is intercepted:

        * a pair of dim names is treated as the labels of the ``(n, n)``
          matrix; derived packed coordinates are registered on the model and
          used for the packed variable and for ``f"{name}_corr"``;
        * anything else (``None``, a string, a sequence that is not a pair)
          is handed to ``_LKJCholeskyCov`` untouched and the deterministics
          are registered without dims.

    Returns
    -------
    The packed Cholesky variable when ``compute_corr`` is false, otherwise the
    tuple ``(chol, corr, stds)`` produced by ``helper_deterministics`` (with
    ``stds``, and in the no-matrix-dims case ``corr``, replaced by their
    registered deterministics when ``store_in_trace`` is true).
    """
    dims = kwargs.pop("dims", None)

    # Only a pair of names can label the two axes of the matrix. A string is
    # excluded explicitly so that a two-character name is not mistaken for a pair.
    matrix_dims = None
    if dims is not None and not isinstance(dims, str) and len(dims) == 2:
        matrix_dims = tuple(dims)

    if matrix_dims is not None:
        # NOTE(review): Not sure this should be done by default. If users want
        # to resize the model they will have to know there is a special dim
        # they also need to update.
        packed_dim_name, packed_dim_value = cls._make_packed_coord_from_dims(
            n, matrix_dims, "packed_tril"
        )
        cls._register_new_coords_with_model(packed_dim_name, packed_dim_value)
        kwargs["dims"] = [packed_dim_name]
    elif dims is not None:
        # Not a pair: keep passing the caller's dims through to the packed
        # variable instead of reinterpreting them as matrix dims.
        kwargs["dims"] = dims

    packed_chol = _LKJCholeskyCov(name, eta=eta, n=n, sd_dist=sd_dist, **kwargs)
    if not compute_corr:
        return packed_chol

    chol, corr, stds = cls.helper_deterministics(n, packed_chol)
    if not store_in_trace:
        return chol, corr, stds

    if matrix_dims is None:
        # No matrix labels available: register the full correlation matrix
        # and the standard deviations without dims.
        corr = pm.Deterministic(f"{name}_corr", corr)
        stds = pm.Deterministic(f"{name}_stds", stds)
        return chol, corr, stds

    # With matrix labels, store only the strict upper triangle of the
    # correlation matrix, labelled by a derived packed coordinate. Each name is
    # registered exactly once.
    corr_triu = corr[pt.triu_indices_from(corr, k=1)]
    corr_triu_dim_name, corr_triu_dim_value = cls._make_packed_coord_from_dims(
        n, matrix_dims, "corr", lower=False, k=1
    )
    cls._register_new_coords_with_model(corr_triu_dim_name, corr_triu_dim_value)

    pm.Deterministic(f"{name}_corr", corr_triu, dims=corr_triu_dim_name)
    stds = pm.Deterministic(f"{name}_stds", stds, dims=matrix_dims[0])

    return chol, corr, stds

@classmethod
Expand All @@ -1443,6 +1461,28 @@ def helper_deterministics(cls, n, packed_chol):
corr = inv_stds[None, :] * cov * inv_stds[:, None]
return chol, corr, stds

@classmethod
def _make_packed_coord_from_dims(cls, n, dims, name_prefix, lower=True, k=0):
    """Derive a coordinate for the packed triangle of an ``(n, n)`` matrix.

    Parameters
    ----------
    n : int
        Side length of the square matrix.
    dims : sequence of two dim names
        Names looked up in the coords of the model on the context stack; each
        must have exactly ``n`` coordinate values.
    name_prefix : str
        Prefix of the returned dim name.
    lower : bool
        Select the lower triangle (``np.tril_indices``) when true, otherwise
        the upper triangle (``np.triu_indices``).
    k : int
        Diagonal offset forwarded to the index function.

    Returns
    -------
    tuple
        ``(f"{name_prefix}_{dims[0]}", labels)`` where ``labels`` is a list
        holding, for each selected matrix entry in row-major order, the string
        form of its ``(row_label, column_label)`` tuple.

    Raises
    ------
    ValueError
        If ``dims`` is not a pair, or one of its coordinates does not have
        length ``n``.
    """
    # Local import: nothing visible in this change brings `product` into scope.
    from itertools import product

    if dims is None or isinstance(dims, str) or len(dims) != 2:
        raise ValueError(
            f"Expected two dims labelling the rows and columns of the matrix, got {dims!r}"
        )

    mod = pm.modelcontext(None)
    chol_dims = [mod.coords[dim] for dim in dims]
    for dim, values in zip(dims, chol_dims):
        # The flat indices below address an n*n grid, so every axis must
        # contribute exactly n labels.
        if values is None or len(values) != n:
            raise ValueError(f"Coordinate values of dim {dim!r} must have length n={n}")

    f_idx = np.tril_indices if lower else np.triu_indices

    # Row-major flat positions of the selected triangle inside the n*n grid;
    # `product` enumerates label pairs in that same row-major order.
    flat_tri_idx = np.arange(n**2, dtype=int).reshape(n, n)[f_idx(n, k=k)]
    # NOTE(review): I'm a bit scared that this could break things when we
    # serialize traces. Can we at least store those in netcdf and zarr?
    # I wouldn't mind if this just had integer coords either...
    coord_product = np.array([f"{x}" for x in product(*chol_dims)], dtype=object)
    tri_coords = coord_product[flat_tri_idx].tolist()

    # NOTE(review): I would much prefer a postfix instead of a prefix,
    # especially in this case, because we already use that for the
    # deterministics.
    packed_dim_name = f"{name_prefix}_{dims[0]}"
    return packed_dim_name, tri_coords

@classmethod
def _register_new_coords_with_model(cls, name, value):
    """Register coordinate ``name`` with labels ``value`` on the context model.

    Delegates to ``Model.add_coord`` rather than writing to ``coords`` and
    ``dim_lengths`` directly, so that the model keeps both in sync itself and
    an existing coordinate of the same name is not silently overwritten.

    Parameters
    ----------
    name : str
        Name of the dim to register.
    value : sequence
        Coordinate labels for the dim.

    Notes
    -----
    NOTE(review): per the PyMC docs, ``add_coord`` is expected to raise
    ``ValueError`` when ``name`` already exists with different values, and to
    accept a repeat registration with identical values - confirm against the
    targeted PyMC version.
    """
    mod = pm.modelcontext(None)
    mod.add_coord(name, value)
Comment on lines +1483 to +1484
Copy link
Member

@ricardoV94 ricardoV94 Aug 3, 2023

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This should probably use `Model.add_coord` / `Model.add_coords`, and check that the name doesn't clash with an existing one.

Comment on lines +1480 to +1484
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This must already exist as a method on `Model`, no?



class LKJCorrRV(RandomVariable):
name = "lkjcorr"
Expand Down
Loading