diff --git a/jax/numpy/lax_numpy.py b/jax/numpy/lax_numpy.py index b4ced0dfaed0..df61fded92bf 100644 --- a/jax/numpy/lax_numpy.py +++ b/jax/numpy/lax_numpy.py @@ -2558,7 +2558,7 @@ def tensordot(a, b, axes=2, precision=None): @_wraps(onp.einsum, lax_description=_PRECISION_DOC) def einsum(*operands, **kwargs): - optimize = kwargs.pop('optimize', 'auto') + optimize = kwargs.pop('optimize', True) optimize = 'greedy' if optimize is True else optimize precision = kwargs.pop('precision', None) if kwargs: