diff --git a/blackjax/__init__.py b/blackjax/__init__.py index 5858c34aa..9a60b6617 100644 --- a/blackjax/__init__.py +++ b/blackjax/__init__.py @@ -10,6 +10,7 @@ from .adaptation.window_adaptation import window_adaptation from .base import SamplingAlgorithm, VIAlgorithm from .diagnostics import effective_sample_size as ess +from .diagnostics import nested_rhat as nested_rhat from .diagnostics import potential_scale_reduction as rhat from .mcmc import barker from .mcmc import dynamic_hmc as _dynamic_hmc @@ -161,5 +162,6 @@ def generate_top_level_api_from(module): "pathfinder_adaptation", "mclmc_find_L_and_step_size", # mclmc adaptation "ess", # diagnostics + "nested_rhat", "rhat", ] diff --git a/blackjax/diagnostics.py b/blackjax/diagnostics.py index 93480302e..b95fb9016 100644 --- a/blackjax/diagnostics.py +++ b/blackjax/diagnostics.py @@ -19,7 +19,7 @@ from blackjax.types import Array, ArrayLike -__all__ = ["potential_scale_reduction", "effective_sample_size"] +__all__ = ["potential_scale_reduction", "nested_rhat", "effective_sample_size"] def potential_scale_reduction( @@ -75,6 +75,66 @@ def potential_scale_reduction( return rhat_value.squeeze() +def nested_rhat( + input_array: ArrayLike, + superchain_axis: int = 0, + chain_axis: int = 1, + sample_axis: int = 2, +) -> Array: + """Margossian et al. (2024)'s nested R-hat for computing multiple MCMC superchain convergence. + + Parameters + ---------- + input_array + An array representing multiple superchains of MCMC smaples. The array must + contain a superchain dimension, chain dimension, and sample dimension. + superchain_axis + The axis indicating the multiple superchains. Default to 0. + chain_axis + The axis indicating the multiple chains. Default to 1. + sample_axis + The axis indicating a single chain of MCMC samples. Default to 2. + + Returns + ------- + NDArray of the resulting statistics (r-hat), with the chain and sample dimensions squeezed. + + """ + assert input_array.ndim == 4, "The input array must have 4 dimensions." + num_chains = input_array.shape[chain_axis] + num_samples = input_array.shape[sample_axis] + param_axis = 3 - (chain_axis + sample_axis + superchain_axis) + num_params = input_array.shape[param_axis] + assert ( + num_chains > 1 or num_samples > 1 + ), "num_chains or num_samples must be greater than 1 for valid nested R-hat." + + chain_means = jnp.mean(input_array, axis=sample_axis) + super_means = jnp.mean(chain_means, axis=chain_axis) + total_mean = jnp.mean(super_means, axis=superchain_axis) + + between_var = jnp.mean(jnp.square(super_means - total_mean), axis=superchain_axis) + + if num_chains > 1: + within_chain_var = jnp.mean( + jnp.square(chain_means - super_means), axis=chain_axis + ) + else: + within_chain_var = jnp.zeros(num_params) + + if num_samples > 1: + within_super_var = jnp.mean( + jnp.square(input_array - chain_means), axis=(chain_axis, sample_axis) + ) + else: + within_super_var = jnp.zeros(num_params) + + within_var = jnp.mean(within_chain_var + within_super_var, axis=superchain_axis) + + nested_rhat_value = jnp.sqrt(1 + between_var / within_var) + return nested_rhat_value.squeeze() + + def effective_sample_size( input_array: ArrayLike, chain_axis: int = 0, sample_axis: int = 1 ) -> Array: