-
-
Notifications
You must be signed in to change notification settings - Fork 2k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Improve support for dims
in LKJCholeskyCov
#6828
base: main
Are you sure you want to change the base?
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -1413,14 +1413,32 @@ class LKJCholeskyCov: | |
""" | ||
|
||
def __new__(cls, name, eta, n, sd_dist, *, compute_corr=True, store_in_trace=True, **kwargs): | ||
dims = kwargs.pop("dims", None) | ||
|
||
if dims is not None: | ||
# TODO: Add check for 2d dims? | ||
packed_dim_name, packed_dim_value = cls._make_packed_coord_from_dims( | ||
n, dims, "packed_tril" | ||
) | ||
cls._register_new_coords_with_model(packed_dim_name, packed_dim_value) | ||
kwargs["dims"] = [packed_dim_name] | ||
|
||
packed_chol = _LKJCholeskyCov(name, eta=eta, n=n, sd_dist=sd_dist, **kwargs) | ||
|
||
if not compute_corr: | ||
return packed_chol | ||
else: | ||
chol, corr, stds = cls.helper_deterministics(n, packed_chol) | ||
if store_in_trace: | ||
corr = pm.Deterministic(f"{name}_corr", corr) | ||
stds = pm.Deterministic(f"{name}_stds", stds) | ||
corr_triu = corr[pt.triu_indices_from(corr, k=1)] | ||
corr_triu_dim_name, corr_triu_dim_value = cls._make_packed_coord_from_dims( | ||
n, dims, "corr", lower=False, k=1 | ||
) | ||
cls._register_new_coords_with_model(corr_triu_dim_name, corr_triu_dim_value) | ||
|
||
corr_tril = pm.Deterministic(f"{name}_corr", corr_triu, dims=corr_triu_dim_name) | ||
stds = pm.Deterministic(f"{name}_stds", stds, dims=dims[0]) | ||
|
||
return chol, corr, stds | ||
|
||
@classmethod | ||
|
@@ -1443,6 +1461,28 @@ def helper_deterministics(cls, n, packed_chol): | |
corr = inv_stds[None, :] * cov * inv_stds[:, None] | ||
return chol, corr, stds | ||
|
||
@classmethod | ||
def _make_packed_coord_from_dims(cls, n, dims, name_prefix, lower=True, k=0): | ||
mod = pm.modelcontext(None) | ||
chol_dims = [mod.coords[dim] for dim in dims] | ||
if lower: | ||
f_idx = np.tril_indices | ||
else: | ||
f_idx = np.triu_indices | ||
|
||
flat_tri_idx = np.arange(n**2, dtype=int).reshape(n, n)[f_idx(n, k=k)] | ||
coord_product = np.fromiter([f"{x}" for x in product(*chol_dims)], dtype="object") | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I'm a bit scared that this could break things when we serialize traces. Can we at least store those in netcdf and zarr? |
||
tri_coords = coord_product[flat_tri_idx].tolist() | ||
|
||
packed_dim_name = f"{name_prefix}_{dims[0]}" | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I would much prefer a postfix instead of a prefix, especially in this case, because we already use that for the deterministics. |
||
return packed_dim_name, tri_coords | ||
|
||
@classmethod | ||
def _register_new_coords_with_model(cls, name, value): | ||
mod = pm.modelcontext(None) | ||
mod.coords[name] = value | ||
mod.dim_lengths[name] = pt.TensorConstant(pt.lscalar, np.array(len(value))) | ||
Comment on lines
+1483
to
+1484
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Should probably use
Comment on lines
+1480
to
+1484
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. This must be a method already in Model no? |
||
|
||
|
||
class LKJCorrRV(RandomVariable): | ||
name = "lkjcorr" | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Not sure this should be done by default. If users want to resize the model they will have to know there is a special dims they also need to update