Skip to content

Commit

Permalink
Bump to Py3.9 (blackjax-devs#554)
Browse files Browse the repository at this point in the history
  • Loading branch information
junpenglao authored Jun 21, 2023
1 parent 5abb248 commit cf94b27
Show file tree
Hide file tree
Showing 27 changed files with 73 additions and 73 deletions.
2 changes: 1 addition & 1 deletion .github/workflows/release.yml
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,7 @@ jobs:
- name: Set up Python
uses: actions/setup-python@v1
with:
python-version: 3.8
python-version: 3.9
- name: Give PyPI some time to update the index
run: sleep 240
- name: Attempt install from PyPI
Expand Down
2 changes: 1 addition & 1 deletion .github/workflows/test.yml
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ jobs:
- style
strategy:
matrix:
python-version: [ '3.8', '3.10']
python-version: [ '3.9', '3.11']
steps:
- uses: actions/checkout@v1
- name: Set up Python ${{ matrix.python-version }}
Expand Down
2 changes: 1 addition & 1 deletion .readthedocs.yaml
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
version: 2

python:
version: "3.8"
version: "3.9"
install:
- method: pip
path: .
Expand Down
8 changes: 4 additions & 4 deletions blackjax/adaptation/mass_matrix.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@
parameters used in Hamiltonian Monte Carlo.
"""
from typing import Callable, NamedTuple, Tuple
from typing import Callable, NamedTuple

import jax
import jax.numpy as jnp
Expand Down Expand Up @@ -68,7 +68,7 @@ class MassMatrixAdaptationState(NamedTuple):

def mass_matrix_adaptation(
is_diagonal_matrix: bool = True,
) -> Tuple[Callable, Callable, Callable]:
) -> tuple[Callable, Callable, Callable]:
"""Adapts the values in the mass matrix by computing the covariance
between parameters.
Expand Down Expand Up @@ -156,7 +156,7 @@ def final(mm_state: MassMatrixAdaptationState) -> MassMatrixAdaptationState:
return init, update, final


def welford_algorithm(is_diagonal_matrix: bool) -> Tuple[Callable, Callable, Callable]:
def welford_algorithm(is_diagonal_matrix: bool) -> tuple[Callable, Callable, Callable]:
r"""Welford's online estimator of covariance.
It is possible to compute the variance of a population of values in an
Expand Down Expand Up @@ -231,7 +231,7 @@ def update(

def final(
wa_state: WelfordAlgorithmState,
) -> Tuple[Array, int, Array]:
) -> tuple[Array, int, Array]:
mean, m2, sample_size = wa_state
covariance = m2 / (sample_size - 1)
return covariance, sample_size, mean
Expand Down
4 changes: 2 additions & 2 deletions blackjax/adaptation/pathfinder_adaptation.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
"""Implementation of the Pathinder warmup for the HMC family of sampling algorithms."""
from typing import Callable, NamedTuple, Tuple, Union
from typing import Callable, NamedTuple, Union

import jax
import jax.numpy as jnp
Expand Down Expand Up @@ -128,7 +128,7 @@ def update(
new_ss_state, new_step_size, adaptation_state.inverse_mass_matrix
)

def final(warmup_state: PathfinderAdaptationState) -> Tuple[float, Array]:
def final(warmup_state: PathfinderAdaptationState) -> tuple[float, Array]:
"""Return the final values for the step size and inverse mass matrix."""
step_size = jnp.exp(warmup_state.ss_state.log_step_size_avg)
inverse_mass_matrix = warmup_state.inverse_mass_matrix
Expand Down
4 changes: 2 additions & 2 deletions blackjax/adaptation/step_size.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
"""Step size adaptation"""
from typing import Callable, NamedTuple, Tuple
from typing import Callable, NamedTuple

import jax
import jax.numpy as jnp
Expand Down Expand Up @@ -64,7 +64,7 @@ class DualAveragingAdaptationState(NamedTuple):

def dual_averaging_adaptation(
target: float, t0: int = 10, gamma: float = 0.05, kappa: float = 0.75
) -> Tuple[Callable, Callable, Callable]:
) -> tuple[Callable, Callable, Callable]:
"""Tune the step size in order to achieve a desired target acceptance rate.
Let us note :math:`\\epsilon` the current step size, :math:`\\alpha_t` the
Expand Down
10 changes: 5 additions & 5 deletions blackjax/adaptation/window_adaptation.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
"""Implementation of the Stan warmup for the HMC family of sampling algorithms."""
from typing import Callable, List, NamedTuple, Tuple, Union
from typing import Callable, NamedTuple, Union

import jax
import jax.numpy as jnp
Expand Down Expand Up @@ -45,7 +45,7 @@ class WindowAdaptationState(NamedTuple):
def base(
is_mass_matrix_diagonal: bool,
target_acceptance_rate: float = 0.80,
) -> Tuple[Callable, Callable, Callable]:
) -> tuple[Callable, Callable, Callable]:
"""Warmup scheme for sampling procedures based on euclidean manifold HMC.
The schedule and algorithms used match Stan's :cite:p:`stan_hmc_param` as closely as possible.
Expand Down Expand Up @@ -191,7 +191,7 @@ def slow_final(warmup_state: WindowAdaptationState) -> WindowAdaptationState:

def update(
adaptation_state: WindowAdaptationState,
adaptation_stage: Tuple,
adaptation_stage: tuple,
position: ArrayLikeTree,
acceptance_rate: float,
) -> WindowAdaptationState:
Expand Down Expand Up @@ -233,7 +233,7 @@ def update(

return warmup_state

def final(warmup_state: WindowAdaptationState) -> Tuple[float, Array]:
def final(warmup_state: WindowAdaptationState) -> tuple[float, Array]:
"""Return the final values for the step size and mass matrix."""
step_size = jnp.exp(warmup_state.ss_state.log_step_size_avg)
inverse_mass_matrix = warmup_state.imm_state.inverse_mass_matrix
Expand Down Expand Up @@ -362,7 +362,7 @@ def build_schedule(
initial_buffer_size: int = 75,
final_buffer_size: int = 50,
first_window_size: int = 25,
) -> List[Tuple[int, bool]]:
) -> list[tuple[int, bool]]:
"""Return the schedule for Stan's warmup.
The schedule below is intended to be as close as possible to Stan's :cite:p:`stan_hmc_param`.
Expand Down
4 changes: 2 additions & 2 deletions blackjax/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Callable, NamedTuple, Tuple
from typing import Callable, NamedTuple

from typing_extensions import Protocol

Expand Down Expand Up @@ -64,7 +64,7 @@ class UpdateFn(Protocol):
"""

def __call__(self, rng_key: PRNGKey, state: State) -> Tuple[State, Info]:
def __call__(self, rng_key: PRNGKey, state: State) -> tuple[State, Info]:
"""Update the current state using the sampling algorithm.
Parameters
Expand Down
6 changes: 3 additions & 3 deletions blackjax/mcmc/elliptical_slice.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
"""Public API for the Elliptical Slice sampling Kernel"""
from typing import Callable, NamedTuple, Tuple
from typing import Callable, NamedTuple

import jax
import jax.numpy as jnp
Expand Down Expand Up @@ -110,7 +110,7 @@ def kernel(
rng_key: PRNGKey,
state: EllipSliceState,
logdensity_fn: Callable,
) -> Tuple[EllipSliceState, EllipSliceInfo]:
) -> tuple[EllipSliceState, EllipSliceInfo]:
proposal_generator = elliptical_proposal(
logdensity_fn, momentum_generator, mean
)
Expand Down Expand Up @@ -205,7 +205,7 @@ def elliptical_proposal(

def generate(
rng_key: PRNGKey, state: EllipSliceState
) -> Tuple[EllipSliceState, EllipSliceInfo]:
) -> tuple[EllipSliceState, EllipSliceInfo]:
position, logdensity = state
key_momentum, key_uniform, key_theta = jax.random.split(rng_key, 3)
# step 1: sample momentum
Expand Down
4 changes: 2 additions & 2 deletions blackjax/mcmc/ghmc.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
"""Public API for the Generalized (Non-reversible w/ persistent momentum) HMC Kernel"""
from typing import Callable, NamedTuple, Tuple
from typing import Callable, NamedTuple

import jax
import jax.numpy as jnp
Expand Down Expand Up @@ -104,7 +104,7 @@ def kernel(
momentum_inverse_scale: ArrayLikeTree,
alpha: float,
delta: float,
) -> Tuple[GHMCState, hmc.HMCInfo]:
) -> tuple[GHMCState, hmc.HMCInfo]:
"""Generate new sample with the Generalized HMC kernel.
Parameters
Expand Down
6 changes: 3 additions & 3 deletions blackjax/mcmc/hmc.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
"""Public API for the HMC Kernel"""
from typing import Callable, NamedTuple, Tuple, Union
from typing import Callable, NamedTuple, Union

import jax

Expand Down Expand Up @@ -112,7 +112,7 @@ def kernel(
step_size: float,
inverse_mass_matrix: Array,
num_integration_steps: int,
) -> Tuple[HMCState, HMCInfo]:
) -> tuple[HMCState, HMCInfo]:
"""Generate a new sample with the HMC kernel."""

momentum_generator, kinetic_energy_fn, _ = metrics.gaussian_euclidean(
Expand Down Expand Up @@ -281,7 +281,7 @@ def hmc_proposal(

def generate(
rng_key, state: integrators.IntegratorState
) -> Tuple[integrators.IntegratorState, HMCInfo]:
) -> tuple[integrators.IntegratorState, HMCInfo]:
"""Generate a new chain state."""
end_state = build_trajectory(state, step_size, num_integration_steps)
end_state = flip_momentum(end_state)
Expand Down
4 changes: 2 additions & 2 deletions blackjax/mcmc/mala.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
# limitations under the License.
"""Public API for Metropolis Adjusted Langevin kernels."""
import operator
from typing import Callable, NamedTuple, Tuple
from typing import Callable, NamedTuple

import jax
import jax.numpy as jnp
Expand Down Expand Up @@ -96,7 +96,7 @@ def transition_energy(state, new_state, step_size):

def kernel(
rng_key: PRNGKey, state: MALAState, logdensity_fn: Callable, step_size: float
) -> Tuple[MALAState, MALAInfo]:
) -> tuple[MALAState, MALAInfo]:
"""Generate a new sample with the MALA kernel."""
grad_fn = jax.value_and_grad(logdensity_fn)
integrator = diffusions.overdamped_langevin(grad_fn)
Expand Down
4 changes: 2 additions & 2 deletions blackjax/mcmc/metrics.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@
We can also generate a relativistic dynamic :cite:p:`lu2017relativistic`.
"""
from typing import Callable, Tuple
from typing import Callable

import jax.numpy as jnp
import jax.scipy as jscipy
Expand All @@ -43,7 +43,7 @@

def gaussian_euclidean(
inverse_mass_matrix: Array,
) -> Tuple[Callable, EuclideanKineticEnergy, Callable]:
) -> tuple[Callable, EuclideanKineticEnergy, Callable]:
r"""Hamiltonian dynamic on euclidean manifold with normally-distributed momentum :cite:p:`betancourt2013general`.
The gaussian euclidean metric is a euclidean metric further characterized
Expand Down
4 changes: 2 additions & 2 deletions blackjax/mcmc/nuts.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
"""Public API for the NUTS Kernel"""
from typing import Callable, NamedTuple, Tuple
from typing import Callable, NamedTuple

import jax
import jax.numpy as jnp
Expand Down Expand Up @@ -121,7 +121,7 @@ def kernel(
logdensity_fn: Callable,
step_size: float,
inverse_mass_matrix: Array,
) -> Tuple[hmc.HMCState, NUTSInfo]:
) -> tuple[hmc.HMCState, NUTSInfo]:
"""Generate a new sample with the NUTS kernel."""

(
Expand Down
6 changes: 3 additions & 3 deletions blackjax/mcmc/periodic_orbital.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
"""Public API for Periodic Orbital Kernel"""
from typing import Callable, NamedTuple, Tuple
from typing import Callable, NamedTuple

import jax
import jax.numpy as jnp
Expand Down Expand Up @@ -142,7 +142,7 @@ def kernel(
step_size: float,
inverse_mass_matrix: Array,
period: int,
) -> Tuple[PeriodicOrbitalState, PeriodicOrbitalInfo]:
) -> tuple[PeriodicOrbitalState, PeriodicOrbitalInfo]:
"""Generate a new orbit with the Periodic Orbital kernel.
Choose a step from the orbit with probability proportional to its weights.
Expand Down Expand Up @@ -325,7 +325,7 @@ def periodic_orbital_proposal(

def generate(
direction: int, init_state: integrators.IntegratorState
) -> Tuple[PeriodicOrbitalState, PeriodicOrbitalInfo]:
) -> tuple[PeriodicOrbitalState, PeriodicOrbitalInfo]:
"""Generate orbit by applying bijection forwards and backwards on period.
As described in algorithm 2 of :cite:p:`neklyudov2022orbital`, each iteration of the periodic orbital
Expand Down
12 changes: 6 additions & 6 deletions blackjax/mcmc/proposal.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Callable, NamedTuple, Tuple
from typing import Callable, NamedTuple

import jax
import jax.numpy as jnp
Expand Down Expand Up @@ -42,7 +42,7 @@ class Proposal(NamedTuple):

def proposal_generator(
energy: Callable, divergence_threshold: float
) -> Tuple[Callable, Callable]:
) -> tuple[Callable, Callable]:
"""
Parameters
Expand All @@ -61,7 +61,7 @@ def proposal_generator(
def new(state: TrajectoryState) -> Proposal:
return Proposal(state, energy(state), 0.0, -jnp.inf)

def update(initial_energy: float, state: TrajectoryState) -> Tuple[Proposal, bool]:
def update(initial_energy: float, state: TrajectoryState) -> tuple[Proposal, bool]:
"""Generate a new proposal from a trajectory state.
The trajectory state records information about the position in the state
Expand Down Expand Up @@ -95,7 +95,7 @@ def proposal_from_energy_diff(
new_energy: float,
divergence_threshold: float,
state: TrajectoryState,
) -> Tuple[Proposal, bool]:
) -> tuple[Proposal, bool]:
"""Computes a new proposal from the energy difference between two states.
It also verifies whether this difference is a divergence, if the
Expand Down Expand Up @@ -141,7 +141,7 @@ def asymmetric_proposal_generator(
transition_energy_fn: Callable,
divergence_threshold: float,
proposal_factory: Callable = proposal_from_energy_diff,
) -> Tuple[Callable, Callable]:
) -> tuple[Callable, Callable]:
"""A proposal generator that takes into account the transition between
two states to compute a new proposal.
Expand Down Expand Up @@ -171,7 +171,7 @@ def update(
initial_state: TrajectoryState,
state: TrajectoryState,
**energy_params,
) -> Tuple[Proposal, bool]:
) -> tuple[Proposal, bool]:
new_energy = transition_energy_fn(initial_state, state, **energy_params)
prev_energy = transition_energy_fn(state, initial_state, **energy_params)
return proposal_factory(prev_energy, new_energy, divergence_threshold, state)
Expand Down
Loading

0 comments on commit cf94b27

Please sign in to comment.