Support dims >= 3 #754

Closed
wd60622 opened this issue Jun 17, 2024 · 1 comment · Fixed by #759
Comments

@wd60622
Contributor

wd60622 commented Jun 17, 2024

With the introduction of the hierarchical model config in #743, the logic that broadcasts a variable var with dims dims to the desired dims desired_dims was hardcoded up to 3D. That code is here:

def handle_scalar(var: pt.TensorLike, dims: Dims, desired_dims: Dims) -> pt.TensorLike:
    """Broadcast a scalar to the desired dims."""
    return var


def handle_1d(var: pt.TensorLike, dims: Dims, desired_dims: Dims) -> pt.TensorLike:
    """Broadcast a 1D variable to the desired dims."""
    return var


def handle_2d(var: pt.TensorLike, dims: Dims, desired_dims: Dims) -> pt.TensorLike:
    """Broadcast a 2D variable to the desired dims."""
    if dims == desired_dims:
        return var

    if dims[::-1] == desired_dims:
        return var.T

    if dims[0] == desired_dims[-2]:
        return var[:, None]

    return var


HANDLE_MAPPING = MappingProxyType({0: handle_scalar, 1: handle_1d, 2: handle_2d})

DimHandler = Callable[[pt.TensorLike, Dims], pt.TensorLike]

and implemented here:
handle = HANDLE_MAPPING[ndims]

def handle_shape(
    var: pt.TensorLike,
    dims: Dims,
) -> pt.TensorLike:
    """Handle the shape for a hierarchical parameter."""
    dims = desired_dims if dims is None else dims
    dims = dims if isinstance(dims, tuple) else (dims,)

    if not set(dims).issubset(set(desired_dims)):
        raise UnsupportedShapeError("The dims of the variable are not supported.")

    return handle(var, dims, desired_dims)

Some test cases might look like this:

var = np.array([1, 2, 3])
dims = "channel"
desired_dims = ("channel", "geo", "hierarchy")

handle(var, dims, desired_dims) == np.expand_dims(var, (1, 2))

var = np.array([[1, 2, 3], [4, 5, 6]])
dims = ("geo", "channel")
desired_dims = ("channel", "geo", "hierarchy")

handle(var, dims, desired_dims) == np.expand_dims(var.T, 2)

var = np.array([[1, 2, 3], [4, 5, 6]])
dims = ("hierarchy", "channel")
desired_dims = ("channel", "geo", "hierarchy")

handle(var, dims, desired_dims) == np.expand_dims(var.T, 1)

If this functionality doesn't already exist in numpy or pytensor, then functions like expand_dims, swapaxes, and transpose might be helpful. It might also exist in xarray itself, since xarray auto-aligns data, but I'm not too sure. A rough sketch of the idea is below.
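For illustration only, here is a numpy-only sketch of what a general handler could look like (broadcast_to_dims is a hypothetical name, not an existing function in numpy or pytensor):

import numpy as np


def broadcast_to_dims(var, dims, desired_dims):
    """Hypothetical sketch: reorder existing axes and insert new axes for missing dims."""
    dims = dims if isinstance(dims, tuple) else (dims,)

    # Reorder the existing axes to follow their relative order in desired_dims
    order = sorted(range(len(dims)), key=lambda i: desired_dims.index(dims[i]))
    var = np.transpose(var, order)

    # Insert length-1 axes for the desired dims the variable doesn't have
    missing_axes = tuple(i for i, d in enumerate(desired_dims) if d not in dims)
    return np.expand_dims(var, missing_axes)

This reproduces the expected outputs in the test cases above, e.g. broadcast_to_dims(var, ("hierarchy", "channel"), ("channel", "geo", "hierarchy")) gives np.expand_dims(var.T, 1).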

This would be super helpful for #749 and for generalizing the functionality of MMMs.

@wd60622
Contributor Author

wd60622 commented Jun 18, 2024

This seems to work! Thanks @ricardoV94

import numpy as np
from pymc.distributions.shape_utils import Dims
import pytensor.tensor as pt


def handle_dims(x: pt.TensorLike, dims: Dims, desired_dims: Dims) -> pt.TensorVariable:
    x = pt.as_tensor_variable(x)

    # Scalars broadcast automatically; nothing to do
    if np.ndim(x) == 0:
        return x

    dims = dims if isinstance(dims, tuple) else (dims,)
    desired_dims = desired_dims if isinstance(desired_dims, tuple) else (desired_dims,)

    # aligned_dims[i, j] is True when dims[i] == desired_dims[j]
    aligned_dims = np.array(dims)[:, None] == np.array(desired_dims)

    # Desired dims not present in the variable's dims
    missing_dims = aligned_dims.sum(axis=0) == 0
    # For each desired dim, the axis of x that maps onto it
    new_idx = aligned_dims.argmax(axis=0)

    # dimshuffle pattern: reorder existing axes, add "x" (length-1) axes for missing dims
    args = ["x" if missing else idx for (idx, missing) in zip(new_idx, missing_dims)]
    return x.dimshuffle(*args)
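
For example, evaluating it against the test cases above (a quick check, not part of the original snippet):

var = np.array([[1, 2, 3], [4, 5, 6]])

out = handle_dims(var, ("geo", "channel"), ("channel", "geo", "hierarchy"))
assert np.array_equal(out.eval(), np.expand_dims(var.T, 2))

out = handle_dims(var, ("hierarchy", "channel"), ("channel", "geo", "hierarchy"))
assert np.array_equal(out.eval(), np.expand_dims(var.T, 1))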
