Removal of Algorithm classes. #657

Merged
13 commits merged on Apr 22, 2024
181 changes: 133 additions & 48 deletions blackjax/__init__.py
@@ -1,71 +1,156 @@
import dataclasses
from typing import Callable

from blackjax._version import __version__

from .adaptation.chees_adaptation import chees_adaptation
from .adaptation.mclmc_adaptation import mclmc_find_L_and_step_size
from .adaptation.meads_adaptation import meads_adaptation
from .adaptation.pathfinder_adaptation import pathfinder_adaptation
from .adaptation.window_adaptation import window_adaptation
from .base import SamplingAlgorithm, VIAlgorithm
from .diagnostics import effective_sample_size as ess
from .diagnostics import potential_scale_reduction as rhat
from .mcmc.barker import barker_proposal
from .mcmc.dynamic_hmc import dynamic_hmc
from .mcmc.elliptical_slice import elliptical_slice
from .mcmc.ghmc import ghmc
from .mcmc.hmc import hmc
from .mcmc.mala import mala
from .mcmc.marginal_latent_gaussian import mgrad_gaussian
from .mcmc.mclmc import mclmc
from .mcmc.nuts import nuts
from .mcmc.periodic_orbital import orbital_hmc
from .mcmc.random_walk import additive_step_random_walk, irmh, rmh
from .mcmc.rmhmc import rmhmc
from .mcmc import barker
from .mcmc import dynamic_hmc as _dynamic_hmc
from .mcmc import elliptical_slice, ghmc
from .mcmc import hmc as _hmc
from .mcmc import mala, marginal_latent_gaussian, mclmc
from .mcmc import nuts as _nuts
from .mcmc import periodic_orbital, random_walk, rmhmc
from .mcmc.random_walk import additive_step_random_walk as _additive_step_random_walk
from .mcmc.random_walk import (
irmh_as_top_level_api,
normal_random_walk,
rmh_as_top_level_api,
)
from .optimizers import dual_averaging, lbfgs
from .sgmcmc.csgld import csgld
from .sgmcmc.sghmc import sghmc
from .sgmcmc.sgld import sgld
from .sgmcmc.sgnht import sgnht
from .smc.adaptive_tempered import adaptive_tempered_smc
from .smc.inner_kernel_tuning import inner_kernel_tuning
from .smc.tempered import tempered_smc
from .vi.meanfield_vi import meanfield_vi
from .vi.pathfinder import pathfinder
from .vi.schrodinger_follmer import schrodinger_follmer
from .vi.svgd import svgd
from .sgmcmc import csgld, sghmc, sgld, sgnht
from .smc import adaptive_tempered
from .smc import inner_kernel_tuning as _inner_kernel_tuning
from .smc import tempered
from .vi import meanfield_vi as _meanfield_vi
from .vi import pathfinder as _pathfinder
from .vi import schrodinger_follmer as _schrodinger_follmer
from .vi import svgd as _svgd
from .vi.pathfinder import PathFinderAlgorithm

"""
The above three classes exist as a backwards compatible way of exposing both the high level, differentiable
factory and the low level components, which may not be differentiable. Moreover, this design allows for the lower
level to be mostly functional programming in nature and reducing boilerplate code.
"""


@dataclasses.dataclass
class GenerateSamplingAPI:
differentiable: Callable
init: Callable
build_kernel: Callable

def __call__(self, *args, **kwargs) -> SamplingAlgorithm:
return self.differentiable(*args, **kwargs)

def register_factory(self, name, callable):
setattr(self, name, callable)


@dataclasses.dataclass
class GenerateVariationalAPI:
differentiable: Callable
init: Callable
step: Callable
sample: Callable

def __call__(self, *args, **kwargs) -> VIAlgorithm:
return self.differentiable(*args, **kwargs)


@dataclasses.dataclass
class GeneratePathfinderAPI:
differentiable: Callable
approximate: Callable
sample: Callable

def __call__(self, *args, **kwargs) -> PathFinderAlgorithm:
return self.differentiable(*args, **kwargs)


def generate_top_level_api_from(module):
return GenerateSamplingAPI(
module.as_top_level_api, module.init, module.build_kernel
)


# MCMC
hmc = generate_top_level_api_from(_hmc)
nuts = generate_top_level_api_from(_nuts)
rmh = GenerateSamplingAPI(rmh_as_top_level_api, random_walk.init, random_walk.build_rmh)
irmh = GenerateSamplingAPI(
irmh_as_top_level_api, random_walk.init, random_walk.build_irmh
)
dynamic_hmc = generate_top_level_api_from(_dynamic_hmc)
rmhmc = generate_top_level_api_from(rmhmc)
mala = generate_top_level_api_from(mala)
mgrad_gaussian = generate_top_level_api_from(marginal_latent_gaussian)
orbital_hmc = generate_top_level_api_from(periodic_orbital)

additive_step_random_walk = GenerateSamplingAPI(
_additive_step_random_walk, random_walk.init, random_walk.build_additive_step
)

additive_step_random_walk.register_factory("normal_random_walk", normal_random_walk)

mclmc = generate_top_level_api_from(mclmc)
elliptical_slice = generate_top_level_api_from(elliptical_slice)
ghmc = generate_top_level_api_from(ghmc)
barker_proposal = generate_top_level_api_from(barker)

hmc_family = [hmc, nuts]

# SMC
adaptive_tempered_smc = generate_top_level_api_from(adaptive_tempered)
tempered_smc = generate_top_level_api_from(tempered)
inner_kernel_tuning = generate_top_level_api_from(_inner_kernel_tuning)

smc_family = [tempered_smc, adaptive_tempered_smc]
"Step_fn returning state has a .particles attribute"

# stochastic gradient mcmc
sgld = generate_top_level_api_from(sgld)
sghmc = generate_top_level_api_from(sghmc)
sgnht = generate_top_level_api_from(sgnht)
csgld = generate_top_level_api_from(csgld)
svgd = generate_top_level_api_from(_svgd)

# variational inference
meanfield_vi = GenerateVariationalAPI(
_meanfield_vi.as_top_level_api,
_meanfield_vi.init,
_meanfield_vi.step,
_meanfield_vi.sample,
)
schrodinger_follmer = GenerateVariationalAPI(
_schrodinger_follmer.as_top_level_api,
_schrodinger_follmer.init,
_schrodinger_follmer.step,
_schrodinger_follmer.sample,
)

pathfinder = GeneratePathfinderAPI(
_pathfinder.as_top_level_api, _pathfinder.approximate, _pathfinder.sample
)


__all__ = [
"__version__",
"dual_averaging", # optimizers
"lbfgs",
"hmc", # mcmc
"dynamic_hmc",
"rmhmc",
"mala",
"mgrad_gaussian",
"nuts",
"orbital_hmc",
"additive_step_random_walk",
"rmh",
"irmh",
"mclmc",
"elliptical_slice",
"ghmc",
"barker_proposal",
"sgld", # stochastic gradient mcmc
"sghmc",
"sgnht",
"csgld",
"window_adaptation", # mcmc adaptation
"meads_adaptation",
"chees_adaptation",
"pathfinder_adaptation",
"mclmc_find_L_and_step_size", # mclmc adaptation
"adaptive_tempered_smc", # smc
"tempered_smc",
"inner_kernel_tuning",
"meanfield_vi", # variational inference
"pathfinder",
"schrodinger_follmer",
"svgd",
"ess", # diagnostics
"rhat",
]
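
To make the new design concrete, here is a minimal sketch of how the top-level wrappers behave after this change: calling a wrapper forwards to `as_top_level_api` and returns a `SamplingAlgorithm`, exactly as the old class-based `__new__` did, while the low-level `init` and `build_kernel` stay reachable on the same name. The standard-normal target and all parameter values are illustrative only.

```python
import jax
import jax.numpy as jnp
import blackjax

# Illustrative target: a standard normal log density.
logdensity_fn = lambda x: -0.5 * jnp.sum(x**2)

# High-level, differentiable factory: __call__ returns a SamplingAlgorithm.
hmc = blackjax.hmc(
    logdensity_fn,
    step_size=1e-2,
    inverse_mass_matrix=jnp.ones(2),
    num_integration_steps=10,
)
state = hmc.init(jnp.zeros(2))
state, info = hmc.step(jax.random.PRNGKey(0), state)

# Low-level, possibly non-differentiable components remain attributes
# of the same top-level name.
kernel = blackjax.hmc.build_kernel()
state0 = blackjax.hmc.init(jnp.zeros(2), logdensity_fn)
```

`register_factory` plays the same role for named variants, e.g. `normal_random_walk` registered on `additive_step_random_walk` above.
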
5 changes: 2 additions & 3 deletions blackjax/adaptation/pathfinder_adaptation.py
@@ -12,12 +12,11 @@
# See the License for the specific language governing permissions and
# limitations under the License.
"""Implementation of the Pathinder warmup for the HMC family of sampling algorithms."""
from typing import Callable, NamedTuple, Union
from typing import Callable, NamedTuple

import jax
import jax.numpy as jnp

import blackjax.mcmc as mcmc
import blackjax.vi as vi
from blackjax.adaptation.base import AdaptationInfo, AdaptationResults
from blackjax.adaptation.step_size import (
@@ -138,7 +137,7 @@ def final(warmup_state: PathfinderAdaptationState) -> tuple[float, Array]:


def pathfinder_adaptation(
algorithm: Union[mcmc.hmc.hmc, mcmc.nuts.nuts],
algorithm,
logdensity_fn: Callable,
initial_step_size: float = 1.0,
target_acceptance_rate: float = 0.80,
7 changes: 3 additions & 4 deletions blackjax/adaptation/window_adaptation.py
@@ -12,12 +12,11 @@
# See the License for the specific language governing permissions and
# limitations under the License.
"""Implementation of the Stan warmup for the HMC family of sampling algorithms."""
from typing import Callable, NamedTuple, Union
from typing import Callable, NamedTuple

import jax
import jax.numpy as jnp

import blackjax.mcmc as mcmc
from blackjax.adaptation.base import AdaptationInfo, AdaptationResults
from blackjax.adaptation.mass_matrix import (
MassMatrixAdaptationState,
@@ -243,7 +242,7 @@ def final(warmup_state: WindowAdaptationState) -> tuple[float, Array]:


def window_adaptation(
algorithm: Union[mcmc.hmc.hmc, mcmc.nuts.nuts],
algorithm,
Contributor:
is this type not SamplingAlgorithm? I guess it is somewhat more specific than that, but it seems like it needs to at least implement the SamplingAlgorithm protocol...

Contributor Author:
Yes, but adding that type annotation would tell the user that any subtype of SamplingAlgorithm can be used as a parameter to that function, which would not be true.

logdensity_fn: Callable,
is_mass_matrix_diagonal: bool = True,
initial_step_size: float = 1.0,
@@ -252,7 +251,7 @@ def window_adaptation(
**extra_parameters,
) -> AdaptationAlgorithm:
"""Adapt the value of the inverse mass matrix and step size parameters of
algorithms in the HMC fmaily.
algorithms in the HMC family. See Blackjax.hmc_family

Algorithms in the HMC family on a euclidean manifold depend on the value of
at least two parameters: the step size, related to the trajectory
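
Since `algorithm` is now deliberately untyped, the call site is unchanged; any top-level hmc-family wrapper can be passed in. A minimal sketch, with an illustrative standard-normal target and step count:

```python
import jax
import jax.numpy as jnp
import blackjax

logdensity_fn = lambda x: -0.5 * jnp.sum(x**2)  # illustrative target

# Any member of blackjax.hmc_family (blackjax.hmc, blackjax.nuts) works here.
warmup = blackjax.window_adaptation(blackjax.nuts, logdensity_fn)
(state, parameters), _ = warmup.run(jax.random.PRNGKey(0), jnp.zeros(2), num_steps=500)

# The adapted step size and inverse mass matrix feed straight back into the factory.
nuts = blackjax.nuts(logdensity_fn, **parameters)
```
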
29 changes: 12 additions & 17 deletions blackjax/mcmc/barker.py
@@ -24,7 +24,7 @@
from blackjax.mcmc.proposal import static_binomial_sampling
from blackjax.types import ArrayLikeTree, ArrayTree, PRNGKey

__all__ = ["BarkerState", "BarkerInfo", "init", "build_kernel", "barker_proposal"]
__all__ = ["BarkerState", "BarkerInfo", "init", "build_kernel", "as_top_level_api"]


class BarkerState(NamedTuple):
@@ -128,7 +128,10 @@ def kernel(
return kernel


class barker_proposal:
def as_top_level_api(
logdensity_fn: Callable,
step_size: float,
) -> SamplingAlgorithm:
"""Implements the (basic) user interface for the Barker's proposal :cite:p:`Livingstone2022Barker` kernel with a
Gaussian base kernel.

@@ -179,24 +182,16 @@ class barker_proposal:

"""

init = staticmethod(init)
build_kernel = staticmethod(build_kernel)
kernel = build_kernel()

def __new__( # type: ignore[misc]
cls,
logdensity_fn: Callable,
step_size: float,
) -> SamplingAlgorithm:
kernel = cls.build_kernel()
def init_fn(position: ArrayLikeTree, rng_key=None):
del rng_key
return init(position, logdensity_fn)

def init_fn(position: ArrayLikeTree, rng_key=None):
del rng_key
return cls.init(position, logdensity_fn)
def step_fn(rng_key: PRNGKey, state):
return kernel(rng_key, state, logdensity_fn, step_size)

def step_fn(rng_key: PRNGKey, state):
return kernel(rng_key, state, logdensity_fn, step_size)

return SamplingAlgorithm(init_fn, step_fn)
return SamplingAlgorithm(init_fn, step_fn)


def _barker_sample_nd(key, mean, a, scale):
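
The pattern above (repeated across the MCMC modules in this PR) leaves both entry points intact. A short sketch, with an illustrative target and step size:

```python
import jax
import jax.numpy as jnp
import blackjax
from blackjax.mcmc import barker

logdensity_fn = lambda x: -0.5 * jnp.sum(x**2)  # illustrative

# The module-level function that replaced the class...
algo = barker.as_top_level_api(logdensity_fn, step_size=0.5)

# ...and the package-level wrapper from __init__.py, whose __call__
# forwards to the same function, so existing user code keeps working.
algo = blackjax.barker_proposal(logdensity_fn, step_size=0.5)

state = algo.init(jnp.zeros(2))
state, info = algo.step(jax.random.PRNGKey(0), state)
```
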
63 changes: 28 additions & 35 deletions blackjax/mcmc/dynamic_hmc.py
@@ -27,7 +27,6 @@
"DynamicHMCState",
"init",
"build_kernel",
"dynamic_hmc",
"halton_sequence",
]

@@ -115,7 +114,16 @@ def kernel(
return kernel


class dynamic_hmc:
def as_top_level_api(
logdensity_fn: Callable,
step_size: float,
inverse_mass_matrix: Array,
*,
divergence_threshold: int = 1000,
integrator: Callable = integrators.velocity_verlet,
next_random_arg_fn: Callable = lambda key: jax.random.split(key)[1],
integration_steps_fn: Callable = lambda key: jax.random.randint(key, (), 1, 10),
) -> SamplingAlgorithm:
"""Implements the (basic) user interface for the dynamic HMC kernel.

Parameters
@@ -144,41 +152,26 @@ class dynamic_hmc:
-------
A ``SamplingAlgorithm``.
"""

init = staticmethod(init)
build_kernel = staticmethod(build_kernel)

def __new__( # type: ignore[misc]
cls,
logdensity_fn: Callable,
step_size: float,
inverse_mass_matrix: Array,
*,
divergence_threshold: int = 1000,
integrator: Callable = integrators.velocity_verlet,
next_random_arg_fn: Callable = lambda key: jax.random.split(key)[1],
integration_steps_fn: Callable = lambda key: jax.random.randint(key, (), 1, 10),
) -> SamplingAlgorithm:
kernel = cls.build_kernel(
integrator, divergence_threshold, next_random_arg_fn, integration_steps_fn
kernel = build_kernel(
integrator, divergence_threshold, next_random_arg_fn, integration_steps_fn
)

def init_fn(position: ArrayLikeTree, rng_key: Array):
        # Note that rng_key here is not necessarily a PRNGKey; it could be any
        # Array that generates a sequence of pseudo- or quasi-random numbers
        # (previously named `random_generator_arg`).
return init(position, logdensity_fn, rng_key)

def step_fn(rng_key: PRNGKey, state):
return kernel(
rng_key,
state,
logdensity_fn,
step_size,
inverse_mass_matrix,
)

def init_fn(position: ArrayLikeTree, rng_key: Array):
            # Note that rng_key here is not necessarily a PRNGKey; it could be any
            # Array that generates a sequence of pseudo- or quasi-random numbers
            # (previously named `random_generator_arg`).
return cls.init(position, logdensity_fn, rng_key)

def step_fn(rng_key: PRNGKey, state):
return kernel(
rng_key,
state,
logdensity_fn,
step_size,
inverse_mass_matrix,
)

return SamplingAlgorithm(init_fn, step_fn)
return SamplingAlgorithm(init_fn, step_fn)


def halton_sequence(i: Array, max_bits: int = 10) -> float:
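
One subtlety the rewrite preserves, sketched below: `init` threads a `random_generator_arg` through the state, so the returned algorithm's `init` takes a second argument, unlike most samplers. Target, seed, and parameter values are illustrative.

```python
import jax
import jax.numpy as jnp
import blackjax

logdensity_fn = lambda x: -0.5 * jnp.sum(x**2)  # illustrative target

dhmc = blackjax.dynamic_hmc(
    logdensity_fn,
    step_size=1e-2,
    inverse_mass_matrix=jnp.ones(2),
)

# init takes the initial "random generator arg" consumed by
# integration_steps_fn; with the defaults it is a PRNGKey, but a
# counter-style Array works with a matching next_random_arg_fn /
# integration_steps_fn pair (see halton_sequence).
state = dhmc.init(jnp.zeros(2), jax.random.PRNGKey(1))
state, info = dhmc.step(jax.random.PRNGKey(2), state)
```
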