diff --git a/.github/workflows/release.yml b/.github/workflows/release.yml index c55a742ee..a03dc8423 100644 --- a/.github/workflows/release.yml +++ b/.github/workflows/release.yml @@ -51,7 +51,7 @@ jobs: - name: Set up Python uses: actions/setup-python@v1 with: - python-version: 3.8 + python-version: 3.9 - name: Give PyPI some time to update the index run: sleep 240 - name: Attempt install from PyPI diff --git a/.github/workflows/test.yml b/.github/workflows/test.yml index be5624169..d56c8ade1 100644 --- a/.github/workflows/test.yml +++ b/.github/workflows/test.yml @@ -24,7 +24,7 @@ jobs: - style strategy: matrix: - python-version: [ '3.8', '3.10'] + python-version: [ '3.9', '3.11'] steps: - uses: actions/checkout@v1 - name: Set up Python ${{ matrix.python-version }} diff --git a/.readthedocs.yaml b/.readthedocs.yaml index cb5f162c6..611a9bebb 100644 --- a/.readthedocs.yaml +++ b/.readthedocs.yaml @@ -1,7 +1,7 @@ version: 2 python: - version: "3.8" + version: "3.9" install: - method: pip path: . diff --git a/blackjax/adaptation/mass_matrix.py b/blackjax/adaptation/mass_matrix.py index 4cd84492e..dc0730161 100644 --- a/blackjax/adaptation/mass_matrix.py +++ b/blackjax/adaptation/mass_matrix.py @@ -18,7 +18,7 @@ parameters used in Hamiltonian Monte Carlo. """ -from typing import Callable, NamedTuple, Tuple +from typing import Callable, NamedTuple import jax import jax.numpy as jnp @@ -68,7 +68,7 @@ class MassMatrixAdaptationState(NamedTuple): def mass_matrix_adaptation( is_diagonal_matrix: bool = True, -) -> Tuple[Callable, Callable, Callable]: +) -> tuple[Callable, Callable, Callable]: """Adapts the values in the mass matrix by computing the covariance between parameters. @@ -156,7 +156,7 @@ def final(mm_state: MassMatrixAdaptationState) -> MassMatrixAdaptationState: return init, update, final -def welford_algorithm(is_diagonal_matrix: bool) -> Tuple[Callable, Callable, Callable]: +def welford_algorithm(is_diagonal_matrix: bool) -> tuple[Callable, Callable, Callable]: r"""Welford's online estimator of covariance. It is possible to compute the variance of a population of values in an @@ -231,7 +231,7 @@ def update( def final( wa_state: WelfordAlgorithmState, - ) -> Tuple[Array, int, Array]: + ) -> tuple[Array, int, Array]: mean, m2, sample_size = wa_state covariance = m2 / (sample_size - 1) return covariance, sample_size, mean diff --git a/blackjax/adaptation/pathfinder_adaptation.py b/blackjax/adaptation/pathfinder_adaptation.py index 3d05bc1d1..c70ed3f99 100644 --- a/blackjax/adaptation/pathfinder_adaptation.py +++ b/blackjax/adaptation/pathfinder_adaptation.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. """Implementation of the Pathinder warmup for the HMC family of sampling algorithms.""" -from typing import Callable, NamedTuple, Tuple, Union +from typing import Callable, NamedTuple, Union import jax import jax.numpy as jnp @@ -128,7 +128,7 @@ def update( new_ss_state, new_step_size, adaptation_state.inverse_mass_matrix ) - def final(warmup_state: PathfinderAdaptationState) -> Tuple[float, Array]: + def final(warmup_state: PathfinderAdaptationState) -> tuple[float, Array]: """Return the final values for the step size and inverse mass matrix.""" step_size = jnp.exp(warmup_state.ss_state.log_step_size_avg) inverse_mass_matrix = warmup_state.inverse_mass_matrix diff --git a/blackjax/adaptation/step_size.py b/blackjax/adaptation/step_size.py index 298f702f6..2d6b0182f 100644 --- a/blackjax/adaptation/step_size.py +++ b/blackjax/adaptation/step_size.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. """Step size adaptation""" -from typing import Callable, NamedTuple, Tuple +from typing import Callable, NamedTuple import jax import jax.numpy as jnp @@ -64,7 +64,7 @@ class DualAveragingAdaptationState(NamedTuple): def dual_averaging_adaptation( target: float, t0: int = 10, gamma: float = 0.05, kappa: float = 0.75 -) -> Tuple[Callable, Callable, Callable]: +) -> tuple[Callable, Callable, Callable]: """Tune the step size in order to achieve a desired target acceptance rate. Let us note :math:`\\epsilon` the current step size, :math:`\\alpha_t` the diff --git a/blackjax/adaptation/window_adaptation.py b/blackjax/adaptation/window_adaptation.py index b99d787f1..cc871b4b6 100644 --- a/blackjax/adaptation/window_adaptation.py +++ b/blackjax/adaptation/window_adaptation.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. """Implementation of the Stan warmup for the HMC family of sampling algorithms.""" -from typing import Callable, List, NamedTuple, Tuple, Union +from typing import Callable, NamedTuple, Union import jax import jax.numpy as jnp @@ -45,7 +45,7 @@ class WindowAdaptationState(NamedTuple): def base( is_mass_matrix_diagonal: bool, target_acceptance_rate: float = 0.80, -) -> Tuple[Callable, Callable, Callable]: +) -> tuple[Callable, Callable, Callable]: """Warmup scheme for sampling procedures based on euclidean manifold HMC. The schedule and algorithms used match Stan's :cite:p:`stan_hmc_param` as closely as possible. @@ -191,7 +191,7 @@ def slow_final(warmup_state: WindowAdaptationState) -> WindowAdaptationState: def update( adaptation_state: WindowAdaptationState, - adaptation_stage: Tuple, + adaptation_stage: tuple, position: ArrayLikeTree, acceptance_rate: float, ) -> WindowAdaptationState: @@ -233,7 +233,7 @@ def update( return warmup_state - def final(warmup_state: WindowAdaptationState) -> Tuple[float, Array]: + def final(warmup_state: WindowAdaptationState) -> tuple[float, Array]: """Return the final values for the step size and mass matrix.""" step_size = jnp.exp(warmup_state.ss_state.log_step_size_avg) inverse_mass_matrix = warmup_state.imm_state.inverse_mass_matrix @@ -362,7 +362,7 @@ def build_schedule( initial_buffer_size: int = 75, final_buffer_size: int = 50, first_window_size: int = 25, -) -> List[Tuple[int, bool]]: +) -> list[tuple[int, bool]]: """Return the schedule for Stan's warmup. The schedule below is intended to be as close as possible to Stan's :cite:p:`stan_hmc_param`. diff --git a/blackjax/base.py b/blackjax/base.py index 0ad6a1628..7f709b895 100644 --- a/blackjax/base.py +++ b/blackjax/base.py @@ -10,7 +10,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from typing import Callable, NamedTuple, Tuple +from typing import Callable, NamedTuple from typing_extensions import Protocol @@ -64,7 +64,7 @@ class UpdateFn(Protocol): """ - def __call__(self, rng_key: PRNGKey, state: State) -> Tuple[State, Info]: + def __call__(self, rng_key: PRNGKey, state: State) -> tuple[State, Info]: """Update the current state using the sampling algorithm. Parameters diff --git a/blackjax/mcmc/elliptical_slice.py b/blackjax/mcmc/elliptical_slice.py index e8010ffb5..4ff310445 100644 --- a/blackjax/mcmc/elliptical_slice.py +++ b/blackjax/mcmc/elliptical_slice.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. """Public API for the Elliptical Slice sampling Kernel""" -from typing import Callable, NamedTuple, Tuple +from typing import Callable, NamedTuple import jax import jax.numpy as jnp @@ -110,7 +110,7 @@ def kernel( rng_key: PRNGKey, state: EllipSliceState, logdensity_fn: Callable, - ) -> Tuple[EllipSliceState, EllipSliceInfo]: + ) -> tuple[EllipSliceState, EllipSliceInfo]: proposal_generator = elliptical_proposal( logdensity_fn, momentum_generator, mean ) @@ -205,7 +205,7 @@ def elliptical_proposal( def generate( rng_key: PRNGKey, state: EllipSliceState - ) -> Tuple[EllipSliceState, EllipSliceInfo]: + ) -> tuple[EllipSliceState, EllipSliceInfo]: position, logdensity = state key_momentum, key_uniform, key_theta = jax.random.split(rng_key, 3) # step 1: sample momentum diff --git a/blackjax/mcmc/ghmc.py b/blackjax/mcmc/ghmc.py index 62462ae68..a068acee7 100644 --- a/blackjax/mcmc/ghmc.py +++ b/blackjax/mcmc/ghmc.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. """Public API for the Generalized (Non-reversible w/ persistent momentum) HMC Kernel""" -from typing import Callable, NamedTuple, Tuple +from typing import Callable, NamedTuple import jax import jax.numpy as jnp @@ -104,7 +104,7 @@ def kernel( momentum_inverse_scale: ArrayLikeTree, alpha: float, delta: float, - ) -> Tuple[GHMCState, hmc.HMCInfo]: + ) -> tuple[GHMCState, hmc.HMCInfo]: """Generate new sample with the Generalized HMC kernel. Parameters diff --git a/blackjax/mcmc/hmc.py b/blackjax/mcmc/hmc.py index ecdf2394e..228fd0b51 100644 --- a/blackjax/mcmc/hmc.py +++ b/blackjax/mcmc/hmc.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. """Public API for the HMC Kernel""" -from typing import Callable, NamedTuple, Tuple, Union +from typing import Callable, NamedTuple, Union import jax @@ -112,7 +112,7 @@ def kernel( step_size: float, inverse_mass_matrix: Array, num_integration_steps: int, - ) -> Tuple[HMCState, HMCInfo]: + ) -> tuple[HMCState, HMCInfo]: """Generate a new sample with the HMC kernel.""" momentum_generator, kinetic_energy_fn, _ = metrics.gaussian_euclidean( @@ -281,7 +281,7 @@ def hmc_proposal( def generate( rng_key, state: integrators.IntegratorState - ) -> Tuple[integrators.IntegratorState, HMCInfo]: + ) -> tuple[integrators.IntegratorState, HMCInfo]: """Generate a new chain state.""" end_state = build_trajectory(state, step_size, num_integration_steps) end_state = flip_momentum(end_state) diff --git a/blackjax/mcmc/mala.py b/blackjax/mcmc/mala.py index 7e76202a0..0f4295a0e 100644 --- a/blackjax/mcmc/mala.py +++ b/blackjax/mcmc/mala.py @@ -13,7 +13,7 @@ # limitations under the License. """Public API for Metropolis Adjusted Langevin kernels.""" import operator -from typing import Callable, NamedTuple, Tuple +from typing import Callable, NamedTuple import jax import jax.numpy as jnp @@ -96,7 +96,7 @@ def transition_energy(state, new_state, step_size): def kernel( rng_key: PRNGKey, state: MALAState, logdensity_fn: Callable, step_size: float - ) -> Tuple[MALAState, MALAInfo]: + ) -> tuple[MALAState, MALAInfo]: """Generate a new sample with the MALA kernel.""" grad_fn = jax.value_and_grad(logdensity_fn) integrator = diffusions.overdamped_langevin(grad_fn) diff --git a/blackjax/mcmc/metrics.py b/blackjax/mcmc/metrics.py index 262345abd..a24bc00b4 100644 --- a/blackjax/mcmc/metrics.py +++ b/blackjax/mcmc/metrics.py @@ -27,7 +27,7 @@ We can also generate a relativistic dynamic :cite:p:`lu2017relativistic`. """ -from typing import Callable, Tuple +from typing import Callable import jax.numpy as jnp import jax.scipy as jscipy @@ -43,7 +43,7 @@ def gaussian_euclidean( inverse_mass_matrix: Array, -) -> Tuple[Callable, EuclideanKineticEnergy, Callable]: +) -> tuple[Callable, EuclideanKineticEnergy, Callable]: r"""Hamiltonian dynamic on euclidean manifold with normally-distributed momentum :cite:p:`betancourt2013general`. The gaussian euclidean metric is a euclidean metric further characterized diff --git a/blackjax/mcmc/nuts.py b/blackjax/mcmc/nuts.py index 6f3b4fc4b..e09841ccf 100644 --- a/blackjax/mcmc/nuts.py +++ b/blackjax/mcmc/nuts.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. """Public API for the NUTS Kernel""" -from typing import Callable, NamedTuple, Tuple +from typing import Callable, NamedTuple import jax import jax.numpy as jnp @@ -121,7 +121,7 @@ def kernel( logdensity_fn: Callable, step_size: float, inverse_mass_matrix: Array, - ) -> Tuple[hmc.HMCState, NUTSInfo]: + ) -> tuple[hmc.HMCState, NUTSInfo]: """Generate a new sample with the NUTS kernel.""" ( diff --git a/blackjax/mcmc/periodic_orbital.py b/blackjax/mcmc/periodic_orbital.py index ae0d9fede..a8bb54787 100644 --- a/blackjax/mcmc/periodic_orbital.py +++ b/blackjax/mcmc/periodic_orbital.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. """Public API for Periodic Orbital Kernel""" -from typing import Callable, NamedTuple, Tuple +from typing import Callable, NamedTuple import jax import jax.numpy as jnp @@ -142,7 +142,7 @@ def kernel( step_size: float, inverse_mass_matrix: Array, period: int, - ) -> Tuple[PeriodicOrbitalState, PeriodicOrbitalInfo]: + ) -> tuple[PeriodicOrbitalState, PeriodicOrbitalInfo]: """Generate a new orbit with the Periodic Orbital kernel. Choose a step from the orbit with probability proportional to its weights. @@ -325,7 +325,7 @@ def periodic_orbital_proposal( def generate( direction: int, init_state: integrators.IntegratorState - ) -> Tuple[PeriodicOrbitalState, PeriodicOrbitalInfo]: + ) -> tuple[PeriodicOrbitalState, PeriodicOrbitalInfo]: """Generate orbit by applying bijection forwards and backwards on period. As described in algorithm 2 of :cite:p:`neklyudov2022orbital`, each iteration of the periodic orbital diff --git a/blackjax/mcmc/proposal.py b/blackjax/mcmc/proposal.py index 19983a450..642c76623 100644 --- a/blackjax/mcmc/proposal.py +++ b/blackjax/mcmc/proposal.py @@ -11,7 +11,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from typing import Callable, NamedTuple, Tuple +from typing import Callable, NamedTuple import jax import jax.numpy as jnp @@ -42,7 +42,7 @@ class Proposal(NamedTuple): def proposal_generator( energy: Callable, divergence_threshold: float -) -> Tuple[Callable, Callable]: +) -> tuple[Callable, Callable]: """ Parameters @@ -61,7 +61,7 @@ def proposal_generator( def new(state: TrajectoryState) -> Proposal: return Proposal(state, energy(state), 0.0, -jnp.inf) - def update(initial_energy: float, state: TrajectoryState) -> Tuple[Proposal, bool]: + def update(initial_energy: float, state: TrajectoryState) -> tuple[Proposal, bool]: """Generate a new proposal from a trajectory state. The trajectory state records information about the position in the state @@ -95,7 +95,7 @@ def proposal_from_energy_diff( new_energy: float, divergence_threshold: float, state: TrajectoryState, -) -> Tuple[Proposal, bool]: +) -> tuple[Proposal, bool]: """Computes a new proposal from the energy difference between two states. It also verifies whether this difference is a divergence, if the @@ -141,7 +141,7 @@ def asymmetric_proposal_generator( transition_energy_fn: Callable, divergence_threshold: float, proposal_factory: Callable = proposal_from_energy_diff, -) -> Tuple[Callable, Callable]: +) -> tuple[Callable, Callable]: """A proposal generator that takes into account the transition between two states to compute a new proposal. @@ -171,7 +171,7 @@ def update( initial_state: TrajectoryState, state: TrajectoryState, **energy_params, - ) -> Tuple[Proposal, bool]: + ) -> tuple[Proposal, bool]: new_energy = transition_energy_fn(initial_state, state, **energy_params) prev_energy = transition_energy_fn(state, initial_state, **energy_params) return proposal_factory(prev_energy, new_energy, divergence_threshold, state) diff --git a/blackjax/mcmc/random_walk.py b/blackjax/mcmc/random_walk.py index 3f260ffa4..6d97c7c08 100644 --- a/blackjax/mcmc/random_walk.py +++ b/blackjax/mcmc/random_walk.py @@ -60,7 +60,7 @@ new_state, info = step(rng_key, state) """ -from typing import Callable, NamedTuple, Optional, Tuple +from typing import Callable, NamedTuple, Optional import jax import numpy as np @@ -171,7 +171,7 @@ def build_additive_step(): def kernel( rng_key: PRNGKey, state: RWState, logdensity_fn: Callable, random_step: Callable - ) -> Tuple[RWState, RWInfo]: + ) -> tuple[RWState, RWInfo]: def proposal_generator(key_proposal, position): move_proposal = random_step(key_proposal, position) new_position = jax.tree_util.tree_map(jnp.add, position, move_proposal) @@ -271,7 +271,7 @@ def kernel( state: RWState, logdensity_fn: Callable, proposal_distribution: Callable, - ) -> Tuple[RWState, RWInfo]: + ) -> tuple[RWState, RWInfo]: """ Parameters @@ -362,7 +362,7 @@ def kernel( logdensity_fn: Callable, transition_generator: Callable, proposal_logdensity_fn: Optional[Callable] = None, - ) -> Tuple[RWState, RWInfo]: + ) -> tuple[RWState, RWInfo]: """Move the chain by one step using the Rosenbluth Metropolis Hastings algorithm. @@ -493,7 +493,7 @@ def build_trajectory(rng_key, initial_state: RWState) -> RWState: new_position = transition_distribution(rng_key, position) return RWState(new_position, logdensity_fn(new_position)) - def generate(rng_key, state: RWState) -> Tuple[RWState, bool, float]: + def generate(rng_key, state: RWState) -> tuple[RWState, bool, float]: key_proposal, key_accept = jax.random.split(rng_key, 2) end_state = build_trajectory(key_proposal, state) new_proposal, _ = generate_proposal(state, end_state) diff --git a/blackjax/mcmc/trajectory.py b/blackjax/mcmc/trajectory.py index 265afc351..81d369c0b 100644 --- a/blackjax/mcmc/trajectory.py +++ b/blackjax/mcmc/trajectory.py @@ -36,7 +36,7 @@ memory by keeping states that will subsequently be discarded. """ -from typing import Callable, NamedTuple, Tuple +from typing import Callable, NamedTuple import jax import jax.numpy as jnp @@ -70,7 +70,7 @@ def append_to_trajectory(trajectory: Trajectory, state: IntegratorState) -> Traj def reorder_trajectories( direction: int, trajectory: Trajectory, new_trajectory: Trajectory -) -> Tuple[Trajectory, Trajectory]: +) -> tuple[Trajectory, Trajectory]: """Order the two trajectories depending on the direction.""" return jax.lax.cond( direction > 0, diff --git a/blackjax/optimizers/dual_averaging.py b/blackjax/optimizers/dual_averaging.py index 1e3eca782..94b7aaa34 100644 --- a/blackjax/optimizers/dual_averaging.py +++ b/blackjax/optimizers/dual_averaging.py @@ -11,7 +11,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from typing import Callable, NamedTuple, Tuple +from typing import Callable, NamedTuple import jax.numpy as jnp @@ -52,7 +52,7 @@ class DualAveragingState(NamedTuple): def dual_averaging( t0: int = 10, gamma: float = 0.05, kappa: float = 0.75 -) -> Tuple[Callable, Callable, Callable]: +) -> tuple[Callable, Callable, Callable]: """Find the state that minimizes an objective function using a primal-dual subgradient method. diff --git a/blackjax/optimizers/lbfgs.py b/blackjax/optimizers/lbfgs.py index ba688ff8e..de653108c 100644 --- a/blackjax/optimizers/lbfgs.py +++ b/blackjax/optimizers/lbfgs.py @@ -11,7 +11,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from typing import Callable, NamedTuple, Tuple +from typing import Callable, NamedTuple import jax import jax.numpy as jnp @@ -70,7 +70,7 @@ def minimize_lbfgs( gtol: float = 1e-08, ftol: float = 1e-05, maxls: int = 1000, -) -> Tuple[OptStep, LBFGSHistory]: +) -> tuple[OptStep, LBFGSHistory]: """ Minimize a function using L-BFGS @@ -152,7 +152,7 @@ def _minimize_lbfgs( gtol: float, ftol: float, maxls: int, -) -> Tuple[OptStep, LBFGSHistory]: +) -> tuple[OptStep, LBFGSHistory]: def lbfgs_one_step(carry, i): (params, state), previous_history = carry diff --git a/blackjax/smc/adaptive_tempered.py b/blackjax/smc/adaptive_tempered.py index d2b24a9f7..bbb71761a 100644 --- a/blackjax/smc/adaptive_tempered.py +++ b/blackjax/smc/adaptive_tempered.py @@ -11,7 +11,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from typing import Callable, Dict, Tuple +from typing import Callable import jax import jax.numpy as jnp @@ -95,7 +95,7 @@ def kernel( state: tempered.TemperedSMCState, num_mcmc_steps: int, mcmc_parameters: dict, - ) -> Tuple[tempered.TemperedSMCState, base.SMCInfo]: + ) -> tuple[tempered.TemperedSMCState, base.SMCInfo]: delta = compute_delta(state) lmbda = delta + state.lmbda return tempered_kernel(rng_key, state, num_mcmc_steps, lmbda, mcmc_parameters) @@ -143,7 +143,7 @@ def __new__( # type: ignore[misc] loglikelihood_fn: Callable, mcmc_step_fn: Callable, mcmc_init_fn: Callable, - mcmc_parameters: Dict, + mcmc_parameters: dict, resampling_fn: Callable, target_ess: float, root_solver: Callable = solver.dichotomy, diff --git a/blackjax/smc/base.py b/blackjax/smc/base.py index 0732e0404..52d1338ef 100644 --- a/blackjax/smc/base.py +++ b/blackjax/smc/base.py @@ -11,7 +11,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from typing import Callable, NamedTuple, Optional, Tuple +from typing import Callable, NamedTuple, Optional import jax import jax.numpy as jnp @@ -59,7 +59,7 @@ def step( weigh_fn: Callable, resample_fn: Callable, num_resampled: Optional[int] = None, -) -> Tuple[SMCState, SMCInfo]: +) -> tuple[SMCState, SMCInfo]: """General SMC sampling step. `update_fn` here corresponds to the Markov kernel $M_{t+1}$, and `weigh_fn` diff --git a/blackjax/smc/tempered.py b/blackjax/smc/tempered.py index f7de5768f..40e95b665 100644 --- a/blackjax/smc/tempered.py +++ b/blackjax/smc/tempered.py @@ -11,7 +11,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from typing import Callable, Dict, NamedTuple, Tuple +from typing import Callable, NamedTuple import jax import jax.numpy as jnp @@ -97,7 +97,7 @@ def kernel( num_mcmc_steps: int, lmbda: float, mcmc_parameters: dict, - ) -> Tuple[TemperedSMCState, smc.base.SMCInfo]: + ) -> tuple[TemperedSMCState, smc.base.SMCInfo]: """Move the particles one step using the Tempered SMC algorithm. Parameters @@ -191,7 +191,7 @@ def __new__( # type: ignore[misc] loglikelihood_fn: Callable, mcmc_step_fn: Callable, mcmc_init_fn: Callable, - mcmc_parameters: Dict, + mcmc_parameters: dict, resampling_fn: Callable, num_mcmc_steps: int = 10, ) -> SamplingAlgorithm: diff --git a/blackjax/vi/meanfield_vi.py b/blackjax/vi/meanfield_vi.py index e7f5c409d..8d5defa15 100644 --- a/blackjax/vi/meanfield_vi.py +++ b/blackjax/vi/meanfield_vi.py @@ -11,7 +11,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from typing import Callable, NamedTuple, Tuple +from typing import Callable, NamedTuple import jax import jax.numpy as jnp @@ -61,7 +61,7 @@ def step( optimizer: GradientTransformation, num_samples: int = 5, stl_estimator: bool = True, -) -> Tuple[MFVIState, MFVIInfo]: +) -> tuple[MFVIState, MFVIInfo]: """Approximate the target density using the mean-field approximation. Parameters @@ -141,7 +141,7 @@ def __new__( def init_fn(position: ArrayLikeTree): return cls.init(position, optimizer) - def step_fn(rng_key: PRNGKey, state: MFVIState) -> Tuple[MFVIState, MFVIInfo]: + def step_fn(rng_key: PRNGKey, state: MFVIState) -> tuple[MFVIState, MFVIInfo]: return cls.step(rng_key, state, logdensity_fn, optimizer, num_samples) def sample_fn(rng_key: PRNGKey, state: MFVIState, num_samples: int): diff --git a/blackjax/vi/pathfinder.py b/blackjax/vi/pathfinder.py index 7cb40e437..e494b026c 100644 --- a/blackjax/vi/pathfinder.py +++ b/blackjax/vi/pathfinder.py @@ -11,7 +11,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from typing import Callable, NamedTuple, Tuple, Union +from typing import Callable, NamedTuple, Union import jax import jax.numpy as jnp @@ -79,7 +79,7 @@ def approximate( maxls=1000, gtol=1e-08, ftol=1e-05, -) -> Tuple[PathfinderState, PathfinderInfo]: +) -> tuple[PathfinderState, PathfinderInfo]: """Pathfinder variational inference algorithm. Pathfinder locates normal approximations to the target density along a @@ -200,7 +200,7 @@ def path_finder_body_fn(rng_key, S, Z, alpha_l, theta, theta_grad): def sample( rng_key: PRNGKey, state: PathfinderState, - num_samples: Union[int, Tuple[()], Tuple[int]] = (), + num_samples: Union[int, tuple[()], tuple[int]] = (), ) -> ArrayTree: """Draw from the Pathfinder approximation of the target distribution. diff --git a/blackjax/vi/svgd.py b/blackjax/vi/svgd.py index 9ec7f28ff..f93941aee 100644 --- a/blackjax/vi/svgd.py +++ b/blackjax/vi/svgd.py @@ -1,5 +1,5 @@ import functools -from typing import Any, Callable, Dict, NamedTuple +from typing import Any, Callable, NamedTuple import jax import jax.numpy as jnp @@ -14,13 +14,13 @@ class SVGDState(NamedTuple): particles: ArrayTree - kernel_parameters: Dict[str, ArrayTree] + kernel_parameters: dict[str, ArrayTree] opt_state: Any def init( initial_particles: ArrayLikeTree, - kernel_parameters: Dict[str, Any], + kernel_parameters: dict[str, Any], optimizer: optax.GradientTransformation, ) -> SVGDState: """ @@ -156,7 +156,7 @@ def __new__( def init_fn( initial_position: ArrayLikeTree, - kernel_parameters: Dict[str, Any] = {"length_scale": 1.0}, + kernel_parameters: dict[str, Any] = {"length_scale": 1.0}, ): return cls.init(initial_position, kernel_parameters, optimizer) diff --git a/pyproject.toml b/pyproject.toml index 51ac3b636..ae630a9ab 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -6,7 +6,7 @@ build-backend = "setuptools.build_meta" name = "blackjax" authors= [{name = "The Blackjax team", email = "remi@thetypicalset.com"}] description = "Flexible and fast sampling in Python" -requires-python = ">=3.8" +requires-python = ">=3.9" keywords=[ "probability", "machine learning", @@ -22,9 +22,9 @@ classifiers = [ "License :: OSI Approved :: Apache Software License", "Operating System :: MacOS", "Operating System :: POSIX", - "Programming Language :: Python :: 3.8", "Programming Language :: Python :: 3.9", "Programming Language :: Python :: 3.10", + "Programming Language :: Python :: 3.11", "Topic :: Education", "Topic :: Scientific/Engineering", "Topic :: Scientific/Engineering :: Artificial Intelligence",