Skip to content

Commit

Permalink
s/diffusion/diffusions
Browse files Browse the repository at this point in the history
  • Loading branch information
rlouf committed Oct 22, 2022
1 parent 164e4a3 commit f52a608
Show file tree
Hide file tree
Showing 5 changed files with 6 additions and 6 deletions.
File renamed without changes.
4 changes: 2 additions & 2 deletions blackjax/mcmc/mala.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
import jax
import jax.numpy as jnp

from blackjax.mcmc.diffusion import overdamped_langevin
import blackjax.mcmc.diffusions as diffusions
from blackjax.types import PRNGKey, PyTree

__all__ = ["MALAState", "MALAInfo", "init", "kernel"]
Expand Down Expand Up @@ -83,7 +83,7 @@ def one_step(
"""
grad_fn = jax.value_and_grad(logprob_fn)
integrator = overdamped_langevin(grad_fn)
integrator = diffusions.overdamped_langevin(grad_fn)

key_integrator, key_rmh = jax.random.split(rng_key)

Expand Down
File renamed without changes.
4 changes: 2 additions & 2 deletions blackjax/sgmcmc/sghmc.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@

import jax

from blackjax.sgmcmc.diffusion import sghmc
import blackjax.sgmcmc.diffusions as diffusions
from blackjax.types import PRNGKey, PyTree
from blackjax.util import generate_gaussian_noise

Expand All @@ -13,7 +13,7 @@
def kernel(alpha: float = 0.01, beta: float = 0) -> Callable:
"""Stochastic gradient Hamiltonian Monte Carlo (SgHMC) algorithm."""

integrator = sghmc(alpha, beta)
integrator = diffusions.sghmc(alpha, beta)

def one_step(
rng_key: PRNGKey,
Expand Down
4 changes: 2 additions & 2 deletions blackjax/sgmcmc/sgld.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
"""Public API for the Stochastic gradient Langevin Dynamics kernel."""
from typing import Callable

from blackjax.sgmcmc.diffusion import overdamped_langevin
import blackjax.sgmcmc.diffusions as diffusions
from blackjax.types import PRNGKey, PyTree

__all__ = ["kernel"]
Expand All @@ -10,7 +10,7 @@
def kernel() -> Callable:
"""Stochastic gradient Langevin Dynamics (SgLD) algorithm."""

integrator = overdamped_langevin()
integrator = diffusions.overdamped_langevin()

def one_step(
rng_key: PRNGKey,
Expand Down

0 comments on commit f52a608

Please sign in to comment.