Prior support for "consumed" dims #841

Open
williambdean opened this issue Jul 18, 2024 · 10 comments
Labels
enhancement (New feature or request), Prior class

Comments

@williambdean
Contributor

williambdean commented Jul 18, 2024

Some distributions will likely not work because of the check that the child's dims are a superset of the parent's dims. For instance,

from pymc_marketing.prior import Prior, UnsupportedShapeError

p = Prior("Dirichlet", a=[1, 1, 1], dims="probs")
try: 
    Prior("Categorical", p=p, dims="trial")
except UnsupportedShapeError as e: 
    print(e)

This could be relaxed using information from the rv_op (e.g. pm.Categorical.rv_op has ndim_supp=0 and ndims_params=(1,)) or by parsing its NumPy-like signature.
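A minimal sketch of the signature-parsing route follows. It assumes the signature is a NumPy gufunc-style string such as the "(p)->()" that Categorical has. It splits the string into the core dims of each input and output, then reports which core dims are consumed, meaning they appear in an input but in no output. `parse_signature` and `consumed_core_dims` are hypothetical helpers, not part of PyMC, PyTensor or PyMC-Marketing.

```python
import re


def parse_signature(signature: str) -> tuple[list[tuple[str, ...]], list[tuple[str, ...]]]:
    """Split a gufunc-style signature like "(p)->()" into core dims per input and per output."""
    inputs_str, outputs_str = signature.split("->")

    def parse_side(side: str) -> list[tuple[str, ...]]:
        # Each "(...)" group is one argument; the comma-separated names inside are its core dims.
        groups = re.findall(r"\(([^)]*)\)", side)
        return [tuple(name.strip() for name in g.split(",") if name.strip()) for g in groups]

    return parse_side(inputs_str), parse_side(outputs_str)


def consumed_core_dims(signature: str) -> set[str]:
    """Core dims that appear in at least one input but in no output."""
    inputs, outputs = parse_signature(signature)
    in_dims = {d for core in inputs for d in core}
    out_dims = {d for core in outputs for d in core}
    return in_dims - out_dims
```

For example, `consumed_core_dims("(p)->()")` gives `{"p"}`, which is the Categorical case. A Dirichlet-like "(a)->(a)" gives an empty set because its core dim is preserved in the output.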

Previous (incorrect) examples

EDIT: The following examples are wrong because Categorical has the signature (p)->()

from pymc_marketing.prior import Prior

# Does work
p = Prior("Dirichlet", a=[1, 2, 3], dims="prob")
y = Prior("Categorical", p=p, dims=("trial", "prob"))
coords = {
    "trial": [0, 1, 2, 3, 4], 
    "prob": ["A", "B", "C"]
}
samples = y.sample_prior(coords=coords)

The transposed dims don't work, though:

# Doesn't work
z = Prior("Categorical", p=p, dims=("prob", "trial"))
try:
    z.sample_prior(coords=coords)
except ValueError as e:
    print(e)

# But it would work with other distributions
mu = Prior("Normal", dims="prob")
x = Prior("Normal", mu=mu, dims=("prob", "trial"))
samples = x.sample_prior(coords=coords)

Ref: pymc-devs/pymc#7416 (reply in thread)

williambdean added the enhancement (New feature or request) label Jul 18, 2024
@ricardoV94
Contributor

ricardoV94 commented Jul 18, 2024

y = Prior("Categorical", p=p, dims=("trial", "prob")) is not really meaningful dim-wise. The prob dimension from p gets consumed when it goes into a Categorical. The signature is (p)->(), not something like (p)->(p) (preserved in the output), so it doesn't make sense to have a prob dim in the output.

This should make some sense. In the core case (without batch dims), you may have a vector of 100 probs that is consumed to generate a single scalar (a number between 0 and 99). You could ask to draw a batch of length 100 (dim probs) of these numbers, but that would in general be strange.
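To make the (p)->() signature concrete, here is a minimal NumPy sketch of a batched Categorical draw using inverse-CDF sampling. It is not how PyMC or PyTensor implement Categorical, and `categorical_draw` is a hypothetical helper. The last axis of p is the core dim and never appears in the output shape. Any batch dims are added to its left.

```python
import numpy as np


def categorical_draw(p: np.ndarray, size: tuple[int, ...], rng: np.random.Generator) -> np.ndarray:
    """Draw from a batched Categorical whose core signature is (p)->().

    The last axis of p is the core dim: it is consumed and does not appear in the output.
    `size` is the output batch shape; p's batch dims must broadcast to it from the right.
    """
    n_categories = p.shape[-1]
    # Keep the core dim last and broadcast p's batch dims against `size` (right-aligned).
    p_b = np.broadcast_to(p, tuple(size) + (n_categories,))
    # Normalised cumulative probabilities along the core dim; the last entry is exactly 1.0.
    cdf = np.cumsum(p_b, axis=-1)
    cdf = cdf / cdf[..., -1:]
    # One uniform per output element; the trailing length-1 axis lines up with the core dim.
    u = rng.uniform(size=tuple(size) + (1,))
    # Each sample is the index of the first category whose cumulative probability exceeds u.
    return (u < cdf).argmax(axis=-1)


rng = np.random.default_rng(0)
probs = np.full(100, 0.01)  # core case: a vector of 100 probs
print(categorical_draw(probs, (), rng).shape)  # () -> a single scalar
print(categorical_draw(probs, (100,), rng).shape)  # (100,) -> a batch of 100 draws, dim "trial"
p_geo = np.full((2, 5), 0.2)  # dims ("geo", "prob")
print(categorical_draw(p_geo, (100, 2), rng).shape)  # (100, 2) -> dims ("trial", "geo")
```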

@williambdean
Copy link
Contributor Author

williambdean commented Jul 18, 2024

Yes, totally. Makes sense. Let me change up the scenarios in the initial description, then.
The "consumed" one fails because of internal checks on parent dim names. Supporting it would require a logic change based on rv_op.signature.

EDIT: The first message has been modified

@williambdean
Contributor Author

Some scenarios that should work:

import numpy as np
from pymc_marketing.prior import Prior

# Broadcasting support
p = Prior("Dirichlet", a=np.ones((2, 5)), dims=("geo", "prob"))
y = Prior("Categorical", p=p, dims=("geo", "trial"))

The information from rv_op could also be used to warn when a dim name should be provided. For instance,

# Can provide a user warning
p = Prior("Dirichlet", a=[1, 1, 1])
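A minimal sketch of such a warning follows. It assumes the number of core dims per parameter has already been read off rv_op.ndims_params and paired with parameter names. For Dirichlet, a has one core dim. `warn_on_unnamed_core_dims` is hypothetical and is not part of the Prior class.

```python
import warnings


def warn_on_unnamed_core_dims(
    dist_name: str,
    param_ndims: dict[str, int],
    param_dims: dict[str, tuple[str, ...]],
) -> list[str]:
    """Warn for every parameter that has core dims but was given no dim names.

    param_ndims: parameter name -> number of core dims (assumed derived from rv_op.ndims_params).
    param_dims: parameter name -> dim names supplied by the user (missing or empty = none).
    Returns the names of the flagged parameters.
    """
    flagged = []
    for name, n_core in param_ndims.items():
        given = param_dims.get(name, ())
        if n_core > 0 and not given:
            warnings.warn(
                f"{dist_name} parameter {name!r} has {n_core} core dim(s) but no dim names; "
                "consider providing dims.",
                UserWarning,
                stacklevel=2,
            )
            flagged.append(name)
    return flagged
```

For Prior("Dirichlet", a=[1, 1, 1]), this would be called as `warn_on_unnamed_core_dims("Dirichlet", {"a": 1}, {})`. That call emits one UserWarning and returns `["a"]`.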

@ricardoV94
Contributor

ricardoV94 commented Jul 23, 2024

Is the example with the dims "geo", "trial" in that order on purpose? You would need to transpose them for it to work, since it must broadcast to the left.

To be clear: when a Categorical distribution is defined with p of dims (geo, probs), it could have dims (geo,), that is, a single vector, or (trial, geo), which would be a matrix of observations, or anything with batch dims to the left of geo, such as (year, ..., trial, geo).
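The rule above can be expressed in terms of dim names. This sketch assumes the rightmost n_core dims of p are its core dims (one for Categorical) and that named dims must line up with PyMC's right-aligned broadcasting. It does not model length-1 broadcasting. `valid_output_dims` is a hypothetical helper, not a PyMC-Marketing function.

```python
def valid_output_dims(p_dims: tuple[str, ...], out_dims: tuple[str, ...], n_core: int = 1) -> bool:
    """Check whether out_dims are valid for a distribution that consumes p's core dims.

    The rightmost n_core dims of p are core (consumed), so none may reappear in the output.
    The remaining batch dims of p must appear, in order, as the rightmost output dims.
    """
    batch_dims = p_dims[: len(p_dims) - n_core]
    core_dims = p_dims[len(p_dims) - n_core :]
    if any(d in out_dims for d in core_dims):
        return False  # a consumed dim cannot appear in the output
    if not batch_dims:
        return True  # no batch dims to align; any extra output dims are new batch dims
    return out_dims[-len(batch_dims):] == batch_dims


p_dims = ("geo", "probs")
print(valid_output_dims(p_dims, ("geo",)))  # True
print(valid_output_dims(p_dims, ("trial", "geo")))  # True
print(valid_output_dims(p_dims, ("geo", "trial")))  # False: geo must be rightmost
```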

@williambdean
Contributor Author

Yup, the order is on purpose, to support broadcasting even in the scenario of "consumed" dims.
The rightmost dim(s?) is always consumed, right? At least when ndim_supp=0 and at least one entry of ndims_params is greater than zero.

Here is some code to find the distributions that might need additional investigation:

from pprint import pprint

import pymc as pm
from pymc.distributions.distribution import DistributionMeta
from pytensor.tensor.random.op import RandomVariable

# Collect every distribution class exposed at the top level of pymc
lookup = {}
for name in dir(pm):
    obj = getattr(pm, name)
    if isinstance(obj, DistributionMeta):
        lookup[name] = obj


def needs_investigation(rv_op: RandomVariable) -> bool:
    """Flag ops with any non-scalar parameter or a non-scalar support."""
    return any(ndims != 0 for ndims in rv_op.ndims_params) or rv_op.ndim_supp != 0


rv_op_lookup = {}
for name, value in lookup.items():
    # Guard against classes (e.g. base classes) that may not define rv_op
    rv_op = getattr(value, "rv_op", None)

    # Another case to investigate: rv_op is not a RandomVariable
    if not isinstance(rv_op, RandomVariable):
        continue

    if not needs_investigation(rv_op):
        continue

    rv_op_lookup[name] = rv_op

pprint(rv_op_lookup)

Results in:

{'CAR': CARRV(name=car,ndim_supp=1,ndims_params=(1, 2, 0, 0),dtype=floatX,inplace=False),
 'Categorical': CategoricalRV(name=categorical,ndim_supp=0,ndims_params=(1,),dtype=int64,inplace=False),
 'Dirichlet': DirichletRV(name=dirichlet,ndim_supp=1,ndims_params=(1,),dtype=floatX,inplace=False),
 'ICAR': ICARRV(name=icar,ndim_supp=1,ndims_params=(2, 1, 1, 0, 0, 0),dtype=floatX,inplace=False),
 'Interpolated': InterpolatedRV(name=interpolated,ndim_supp=0,ndims_params=(1, 1, 1),dtype=floatX,inplace=False),
 'KroneckerNormal': KroneckerNormalRV(name=kroneckernormal,ndim_supp=1,ndims_params=(1, 0, 2),dtype=floatX,inplace=False),
 'MatrixNormal': MatrixNormalRV(name=matrixnormal,ndim_supp=2,ndims_params=(2, 2, 2),dtype=floatX,inplace=False),
 'Multinomial': MultinomialRV(name=multinomial,ndim_supp=1,ndims_params=(0, 1),dtype=int64,inplace=False),
 'MvNormal': MvNormalRV(name=multivariate_normal,ndim_supp=1,ndims_params=(1, 2),dtype=floatX,inplace=False),
 'MvStudentT': MvStudentTRV(name=multivariate_studentt,ndim_supp=1,ndims_params=(0, 1, 2),dtype=floatX,inplace=False),
 'StickBreakingWeights': StickBreakingWeightsRV(name=stick_breaking_weights,ndim_supp=1,ndims_params=(0, 0),dtype=floatX,inplace=False),
 'Wishart': WishartRV(name=wishart,ndim_supp=2,ndims_params=(0, 2),dtype=floatX,inplace=False)}

The current implementation already supports many distributions (any that PyMC-Marketing supported previously).

@ricardoV94
Contributor

ricardoV94 commented Jul 23, 2024

Yup, sounds like you're on it. Yeah, the rightmost dims are the "core" ones.

One day, when we implement dims in PyTensor, we'll probably need extra kwargs for the user to tell us which dims should be used for the core case. For example, a Categorical may need to be written like pt.random.categorical(p=p, dims=("geo", "trial"), p_dims="probs")

With named dims, order loses its meaning, so we can't rely on it to disambiguate core dims from batch dims.

In this case we could probably introspect the dims of p and find out which one is not "geo" or "trial". But just as we don't require size to always be provided, we may not want to require dims to be provided either, in which case we'd infer them from the dims of the parameters. This is all about a future PyTensor API, not about the work here in pymc-marketing; I'm just using this as an excuse to think about it.
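The introspection idea can be sketched as a difference on names. This sketch assumes dims are unique within each variable. `infer_core_dims` is hypothetical. For matrix-valued parameters it identifies which dims are core, but their order still has to come from somewhere else.

```python
def infer_core_dims(param_dims: tuple[str, ...], out_dims: tuple[str, ...]) -> tuple[str, ...]:
    """Return the parameter dims that do not appear in the output, keeping their original order."""
    out = set(out_dims)
    return tuple(d for d in param_dims if d not in out)


print(infer_core_dims(("geo", "probs"), ("geo", "trial")))  # ('probs',)
```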

I don't know if you want to go that route here, asking which dims are the core ones, or rely on the dims missing from the output to find out which one is "probs". I'm wondering about matrix core inputs, like cov. In some cases order may matter, so even if you can reason about which dims are core, you may still need to know which one goes first. cov is not a problem because it has to be symmetric, but that need not be the case for other matrix inputs.

The other tricky thing is multiple parametrizations. For example, a user may define MvNormal with chol (the Cholesky factor) rather than cov. Under the hood PyMC will convert it to a covariance, so the Op signature is always correct. But here you're acting before PyMC does that conversion, when you're trying to dimshuffle the dims for the user.

Those edge cases are where we start wishing PyTensor handled dims natively.

@williambdean
Contributor Author

williambdean commented Jul 23, 2024

Some things to note:

  • The current implementation does support consumed dims to some degree, as long as they aren't specified:
    p = Prior("Dirichlet", a=[1, 2, 3])
    data = Prior("Categorical", p=p, dims="trial")
    samples = data.sample_prior(coords={"trial": range(100)})
    # var_p has "var_p_dim_0" of values [0, 1, 2]
  • The Prior class doesn't support non-unique dims at the moment, i.e. nothing square works: Prior("Normal", dims=("dim", "dim")) raises a ValueError. This restriction comes from the auto-broadcasting, since the dimshuffle-based approach doesn't allow it.
  • Native support of dims in PyTensor seems like the way to go long term, and this API might then be deprecated or just wrap that (and the PyMC implementations).

As for the problem of consumed dims in the Prior class API: is there a preference between p_dims as a separate parameter and including those dims in the existing dims parameter? I see advantages to both.

@ricardoV94
Contributor

Duplicate dims are not allowed anywhere. Even if a variable is square, its dims must have different names.

@ricardoV94
Contributor

p = Prior("Dirichlet", a=[1, 2, 3])
data = Prior("Categorical", p=p, dims="trial")

Right, but that will fail when the dims are specified and/or not aligned according to the non-dim semantics of PyMC, right?

@ricardoV94
Contributor

ricardoV94 commented Jul 23, 2024

For your pragmatic question, I don't know. I suggest trying what feels better and seeing how it goes.
