diff --git a/jax/interpreters/invertible_ad.py b/jax/interpreters/invertible_ad.py index 8a606810581b..93f83bbb4491 100644 --- a/jax/interpreters/invertible_ad.py +++ b/jax/interpreters/invertible_ad.py @@ -13,7 +13,6 @@ # limitations under the License. from functools import partial -import itertools as it from typing import Dict, Any, Callable import jax @@ -21,8 +20,7 @@ from jax import linear_util as lu from . import ad from . import partial_eval as pe -from .partial_eval import (PartialVal, partial_eval_jaxpr, - JaxprTracer, ConstVar, convert_constvars_jaxpr, +from .partial_eval import (PartialVal, partial_eval_jaxpr, JaxprTracer, new_eqn_recipe, _partition_knowns) from ..core import raise_to_shaped, get_aval, Literal, Jaxpr from ..custom_derivatives import _initial_style_jaxpr, _resolve_kwargs