Skip to content

Commit

Permalink
fix lint
Browse files Browse the repository at this point in the history
  • Loading branch information
fehiepsi committed Oct 20, 2023
1 parent 49f064a commit 4a34fd7
Show file tree
Hide file tree
Showing 2 changed files with 5 additions and 4 deletions.
6 changes: 3 additions & 3 deletions numpyro/handlers.py
Original file line number Diff line number Diff line change
Expand Up @@ -794,9 +794,9 @@ def __init__(self, fn=None, data=None, substitute_fn=None):
super(substitute, self).__init__(fn)

def process_message(self, msg):
if (msg["type"] not in ("sample", "param", "mutable", "plate", "deterministic")) or msg.get(
"_control_flow_done", False
):
if (
msg["type"] not in ("sample", "param", "mutable", "plate", "deterministic")
) or msg.get("_control_flow_done", False):
if msg["type"] == "control_flow":
if self.data is not None:
msg["kwargs"]["substitute_stack"].append(("substitute", self.data))
Expand Down
3 changes: 2 additions & 1 deletion numpyro/ops/provenance.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,13 +4,14 @@
import jax
from jax.api_util import flatten_fun, shaped_abstractify
import jax.core as core
import jax.util as util
from jax.experimental.pjit import pjit_p
import jax.util as util

try:
import jax.extend.linear_util as lu
except ImportError:
import jax.linear_util as lu

from jax.interpreters.partial_eval import trace_to_jaxpr_dynamic
from jax.interpreters.pxla import xla_pmap_p
import jax.numpy as jnp
Expand Down

0 comments on commit 4a34fd7

Please sign in to comment.