
no stats saved in InferenceData when using pymc.sampling_jax.sample_blackjax_nuts #5718

Closed
fbarfi opened this issue Apr 17, 2022 · 5 comments

Comments

@fbarfi

fbarfi commented Apr 17, 2022

Description of your problem

No sample stats are saved in the InferenceData (there is no sample_stats group):

arviz.InferenceData
posterior
posterior_predictive
log_likelihood
prior
observed_data

Is this specific to blackjax?
Note that I am appending two groups to the InferenceData (prior and posterior_predictive).
It works fine with sample_numpyro_nuts using exactly the same code.
Thank you!

Please provide a minimal, self-contained, and reproducible example.

pm.sampling_jax.sample_blackjax_nuts(

Please provide the full traceback.

Complete error traceback
Local current time at start : Sun Apr 17 17:57:29 2022
10000 1000 4
Compiling...
Compilation time =  0:00:00.849628
Sampling...
Sampling time =  0:00:04.327564
Transforming variables...
Transformation time =  0:00:57.827094
 100.00% [40000/40000 00:03<00:00]
---------------------------------------------------------------------------
ValueError                                Traceback (most recent call last)
Input In [18], in <module>
     89 with ph_poisson:  
     90     az.to_netcdf(idata,file_name)   
---> 92 plots_func(idata)
     94 print(pm.str_for_model(ph_poisson,formatting='Latex'))
     96 t2 = time.perf_counter()

Input In [5], in plots_func(data)
     21 file_density  = os.path.join(folders[1],'density_plots.png')
     22 plt.savefig(file_density ,dpi=300)
---> 24 az.plot_energy(data)
     25 file_energy  = os.path.join(folders[1],'energy_plots.png')
     26 plt.savefig(file_energy ,dpi=300)

File /opt/homebrew/Caskroom/miniforge/base/envs/pymc-dev-py39/lib/python3.9/site-packages/arviz/plots/energyplot.py:100, in plot_energy(data, kind, bfmi, figsize, legend, fill_alpha, fill_color, bw, textsize, fill_kwargs, plot_kwargs, ax, backend, backend_kwargs, show)
      9 def plot_energy(
     10     data,
     11     kind=None,
   (...)
     24     show=None,
     25 ):
     26     """Plot energy transition distribution and marginal energy distribution in HMC algorithms.
     27 
     28     This may help to diagnose poor exploration by gradient-based algorithms like HMC or NUTS.
   (...)
     98 
     99     """
--> 100     energy = convert_to_dataset(data, group="sample_stats").energy.values
    102     if kind == "histogram":
    103         warnings.warn(
    104             "kind histogram will be deprecated in a future release. Use `hist` "
    105             "or set rcParam `plot.density_kind` to `hist`",
    106             FutureWarning,
    107         )

File /opt/homebrew/Caskroom/miniforge/base/envs/pymc-dev-py39/lib/python3.9/site-packages/arviz/data/converters.py:182, in convert_to_dataset(obj, group, coords, dims)
    180 dataset = getattr(inference_data, group, None)
    181 if dataset is None:
--> 182     raise ValueError(
    183         "Can not extract {group} from {obj}! See {filename} for other "
    184         "conversion utilities.".format(group=group, obj=obj, filename=__file__)
    185     )
    186 return dataset

ValueError: Can not extract sample_stats from Inference data with groups:
	> posterior
	> posterior_predictive
	> log_likelihood
	> prior
	> observed_data! See /opt/homebrew/Caskroom/miniforge/base/envs/pymc-dev-py39/lib/python3.9/site-packages/arviz/data/converters.py for other conversion utilities.
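Until the sampler populates sample_stats, a plotting helper like plots_func can check for the stat before calling az.plot_energy. The sketch below mirrors the getattr(inference_data, group, None) lookup shown in the traceback; has_sample_stat and plot_energy_if_available are hypothetical helper names, not ArviZ or PyMC API.

```python
def has_sample_stat(idata, name="energy"):
    """Return True if idata has a sample_stats group containing `name`.

    Uses the same getattr(..., None) lookup ArviZ's convert_to_dataset
    performs before raising "Can not extract sample_stats".
    """
    sample_stats = getattr(idata, "sample_stats", None)
    if sample_stats is None:
        return False
    # xarray.Dataset supports `in` for variable names.
    return name in sample_stats


def plot_energy_if_available(idata, plot_fn):
    """Call plot_fn(idata) only when an energy stat exists.

    Hypothetical helper; plot_fn would typically be arviz.plot_energy.
    Returns True if the plot function was called, False if skipped.
    """
    if not has_sample_stat(idata, "energy"):
        print("Skipping energy plot: no sample_stats.energy in InferenceData")
        return False
    plot_fn(idata)
    return True
```

In plots_func, the az.plot_energy(data) call could then become plot_energy_if_available(data, az.plot_energy), so the remaining plots are still saved when a sampler returns no energy stat.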

Please provide any additional information below.

Versions and main components

Python implementation: CPython
Python version : 3.9.10
IPython version : 8.0.1

Compiler : Clang 11.1.0
OS : Darwin
Release : 21.4.0
Machine : arm64
Processor : arm
CPU cores : 10
Architecture: 64bit

csv : 1.0
xarray : 2022.3.0
sys : 3.9.10 | packaged by conda-forge | (main, Feb 1 2022, 21:27:43)
[Clang 11.1.0 ]
sklearn : 0.0
jax : 0.3.4
statsmodels: 0.13.2
platform : 1.0.8
blackjax : 0.4.0
seaborn : 0.11.2
numpy : 1.21.5
arviz : 0.12.0
aesara : 2.5.1
pandas : 1.4.1
jaxlib : 0.3.0
scipy : 1.7.3
matplotlib : 3.5.1
pymc : 4.0.0b6

  • How did you install PyMC/PyMC3: git dev
@fbarfi
Author

fbarfi commented Apr 17, 2022

Same issue using either

from pymc.sampling_jax import sample_blackjax_nuts

or

from pymc.sampling_jax_w_blax import sample_blackjax_nuts

Thanks.

@ricardoV94
Member

Do numpyro or blackjax provide energy stats?

@fbarfi
Author

fbarfi commented Apr 18, 2022

numpyro does, but blackjax does not.
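For context, az.plot_energy reads sample_stats.energy because it compares the marginal energy distribution with the distribution of energy changes between successive draws, and (with bfmi=True) reports the Bayesian fraction of missing information (BFMI) per chain. A minimal sketch of the usual BFMI estimator, assuming energies in a NumPy array shaped (chain, draw); ArviZ ships its own implementation as az.bfmi.

```python
import numpy as np


def bfmi(energy):
    """Per-chain Bayesian fraction of missing information.

    energy: Hamiltonian energies with shape (chain, draw).
    Estimator: mean squared change in energy between consecutive draws,
    divided by the (population) variance of the energy within each chain.
    """
    energy = np.atleast_2d(np.asarray(energy, dtype=float))
    num = np.square(np.diff(energy, axis=1)).mean(axis=1)
    den = np.var(energy, axis=1)
    return num / den
```

Without a per-draw energy array there is nothing to feed this statistic, which is why the plot cannot be produced from the blackjax output.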

@ricardoV94
Copy link
Member

We can add them if they are returned. Want to open a PR?
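If a sampler does return per-draw stats, adding them amounts to placing the arrays in a sample_stats dict under ArviZ's conventional names, with leading (chain, draw) dimensions, so az.plot_energy can find `energy`. The helper below is a hypothetical sketch: the source field names in STAT_NAMES are assumptions (the real names depend on what the sampler returns), and it assumes raw arrays arrive shaped (draw, chain) unless told otherwise. The resulting dict could then be passed as sample_stats= to az.from_dict.

```python
import numpy as np

# Hypothetical mapping: sampler field name -> ArviZ sample_stats name.
STAT_NAMES = {
    "energy": "energy",
    "is_divergent": "diverging",
    "acceptance_rate": "acceptance_rate",
}


def to_sample_stats(raw_stats, draws_first=True):
    """Rename known stats and return arrays shaped (chain, draw, ...)."""
    sample_stats = {}
    for src, dst in STAT_NAMES.items():
        if src not in raw_stats:
            continue  # stat not returned by this sampler: leave it out
        values = np.asarray(raw_stats[src])
        if draws_first:
            values = np.swapaxes(values, 0, 1)  # (draw, chain) -> (chain, draw)
        sample_stats[dst] = values
    return sample_stats
```

Stats the sampler does not return are simply omitted, so the group stays valid and plot functions that need a missing stat still fail with a clear error.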

@fbarfi
Author

fbarfi commented Apr 18, 2022

Apologies, I'm not sure how to open a PR; I've never done it before.

Development

No branches or pull requests

3 participants