
BUG: Regression in JAX model ops #6993

Closed
velochy opened this issue Nov 7, 2023 · 4 comments · Fixed by #6995


@velochy
Contributor

velochy commented Nov 7, 2023

Describe the issue:

A model that sampled fine in 5.8.0 no longer works in 5.9.1. It now throws the NotImplementedError shown below.

The code might look a bit convoluted, but it is essentially a Gaussian process over a time dimension with very sparse values, which makes it a very useful construct in my models.
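To make the construct concrete, here is a NumPy-only sketch of the forward computation the model below encodes: one latent curve per response is drawn only at the distinct observed times, then gathered per observation through the `factorize` codes before the softmax. It reuses the issue's toy data; `ls` and the per-response scales are hypothetical fixed values standing in for the `Gamma` and `HalfNormal` priors, and `matern52` is a hand-written helper (unit variance), not PyMC's `Matern52` class.

```python
import numpy as np
import pandas as pd


def matern52(x, ls):
    """Matern 5/2 covariance (unit variance) between all pairs of 1-D inputs x."""
    r = np.abs(x[:, None] - x[None, :])
    s = np.sqrt(5.0) * r / ls
    return (1.0 + s + s**2 / 3.0) * np.exp(-s)


# Same toy data as the reproducible example.
df = pd.DataFrame({"t": [-1.8, -1.8, -0.9, -1.8, -1.8, -0.9, -0.9],
                   "response": list("ABABBBA")})
times_idx, times = df["t"].factorize(sort=True)          # codes into the unique times
resp_idx, responses = df["response"].factorize(sort=True)

rng = np.random.default_rng(0)
ls = 2.5                                   # hypothetical stand-in for the Gamma prior
sigma = np.full(len(responses), 0.2)       # hypothetical stand-in for σ_gp (one per response)

# The GP lives only on the unique times, so the covariance is tiny (n_times x n_times).
cov = matern52(np.asarray(times, dtype=float), ls)
jitter = 1e-9 * np.eye(len(times))
offsets = rng.multivariate_normal(np.zeros(len(times)), cov + jitter,
                                  size=len(responses))   # shape (response, time)
alpha_time = (sigma[:, None] * offsets).T                # shape (time, response)

# Gather one row per observation, then apply a numerically stable softmax.
logits = alpha_time[times_idx]                           # shape (obs, response)
p = np.exp(logits - logits.max(axis=-1, keepdims=True))
p /= p.sum(axis=-1, keepdims=True)
```

Because the covariance is built over unique times rather than observations, adding more observations at already-seen times leaves its size unchanged.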

Reproducible code example:

```python
import pandas as pd, numpy as np
import pymc as pm
print(pm.__version__)
from pymc.sampling import jax as pm_jax


df = pd.DataFrame([
       [-1.8, 'A'],
       [-1.8, 'B'],
       [-0.9, 'A'],
       [-1.8, 'B'],
       [-1.8, 'B'],
       [-0.9, 'B'],
       [-0.9, 'A']], columns=['t', 'response'])


times_idx, times = df["t"].factorize(sort=True)
resp_idx, responses = df['response'].factorize(sort=True)

COORDS = {
    'time': times,
    'response': responses,
    'obs_idx': np.array(df.index)
}

with pm.Model(coords=COORDS) as h_multinomial_model:

    obs = pm.MutableData("obs", resp_idx, dims="obs_idx")

    times_id = pm.MutableData("time_id", times_idx, dims="obs_idx")
    gp_inp = pm.MutableData('time_vals', np.array(times), dims="time")[:, None]

    ls = pm.Gamma(name='ls', alpha=5.0, beta=2.0)
    c1 = pm.gp.cov.Matern52(ls=ls, input_dim=1)

    gp_sds = pm.HalfNormal("σ_gp", 0.2, dims=('response',))
    α_time_offset = pm.MvNormal('α_time_offset', mu=0, cov=c1.full(gp_inp), dims=('response', "time"))
    α_time = pm.Deterministic('α_time', (gp_sds[:, None] * α_time_offset).transpose(), dims=("time", 'response'))

    # likelihood
    _ = pm.Categorical(
        "y",
        p=pm.math.softmax(α_time[times_id], axis=-1),
        observed=obs,
        dims="obs_idx",
    )

    idata = pm_jax.sample_numpyro_nuts()
```

Error message:

``NotImplementedError: No JAX conversion for the given `Op`: Blockwise{SolveTriangular{trans=0, unit_diagonal=False, lower=True, check_finite=True, b_ndim=1}, (m,m),(m)->(m)}``
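The `(m,m),(m)->(m)` part of the message is a NumPy-style generalized-ufunc signature: the core `SolveTriangular` op solves one lower-triangular `(m, m)` system for one length-`m` vector, and `Blockwise` repeats that core op over any leading batch dimensions. In this model the solve most likely comes from the `MvNormal` log-density, where `response` is a batch dimension sharing a single `time` by `time` covariance. The sketch below mimics that structure in plain NumPy; `solve_lower` is a hypothetical hand-written forward substitution, and `np.vectorize(..., signature=...)` plays the role of the `Blockwise` wrapper.

```python
import numpy as np


def solve_lower(L, b):
    """Core op: solve L @ x = b for one lower-triangular L (m, m) and one b (m,)."""
    m = b.shape[0]
    x = np.zeros(m)
    for i in range(m):
        # Forward substitution: row i only depends on the already-solved x[:i].
        x[i] = (b[i] - L[i, :i] @ x[:i]) / L[i, i]
    return x


# Blockwise-style wrapper: loop the core op over leading batch dimensions,
# using the same signature that appears in the error message.
batched_solve_lower = np.vectorize(solve_lower, signature="(m,m),(m)->(m)")

rng = np.random.default_rng(0)
A = rng.normal(size=(3, 3))
L = np.linalg.cholesky(A @ A.T + 3 * np.eye(3))  # one shared (m, m) factor, no batch dims
B = rng.normal(size=(2, 3))                      # batch of 2 vectors, e.g. one per response
X = batched_solve_lower(L, B)                    # core op applied once per row of B
```

JAX needs its own conversion for the `Blockwise` wrapper (or for a graph without it), which is what the error reports as missing.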

PyMC version information:

PyMC 5.9.1

Context for the issue:

This affects one of the main building blocks in the models I am working with. I can turn it off for the time being, but it adds a lot of power to the models because it lets us better model data that was gathered at different points in time.

@ricardoV94
Member

ricardoV94 commented Nov 7, 2023

Hmm, the Blockwise is now introduced because of #6897.
We haven't implemented it for JAX yet, but in the meantime I wrote an optimization that removes it in pymc-devs/pytensor#482.
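Why a `Blockwise` like this can be removed at all: when only the vector input is batched and the triangular matrix is shared, solving each vector separately gives the same result as one unbatched solve with the vectors stacked as columns of a right-hand-side matrix. The NumPy check below illustrates that equivalence only; it is not the code from pymc-devs/pytensor#482, and `np.linalg.solve` stands in for a triangular solver to keep the example NumPy-only.

```python
import numpy as np

rng = np.random.default_rng(1)
m, batch = 4, 3
A = rng.normal(size=(m, m))
L = np.linalg.cholesky(A @ A.T + m * np.eye(m))  # shared lower-triangular factor
B = rng.normal(size=(batch, m))                   # `batch` separate vector right-hand sides

# Batched form: one vector solve per row of B (what a Blockwise loops over).
X_loop = np.stack([np.linalg.solve(L, b) for b in B])

# Unbatched form: stack the vectors as columns, do a single matrix solve, transpose back.
X_single = np.linalg.solve(L, B.T).T
```

The unbatched form needs no batching wrapper, so a backend that already converts the core solve can handle it directly.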

You can copy the code from https://discourse.pymc.io/t/version-dependant-slowing-down-of-gaussian-mixture-sampling-in-ubuntu-20-04/13219/39?u=ricardov94 to fix your problem locally.

@velochy
Contributor Author

velochy commented Nov 7, 2023

Thank you for the quick response yet again. I can confirm that adding that piece of code at the beginning of the file nicely removes the issue :)

@velochy
Contributor Author

velochy commented Dec 3, 2023

This looks to be solved. Should it be closed?

@ricardoV94
Member

I want to add a test.
