diff --git a/jax/core.py b/jax/core.py index d7f91477c8af..94de3b4e949a 100644 --- a/jax/core.py +++ b/jax/core.py @@ -1129,10 +1129,10 @@ def check_jaxpr(jaxpr: Jaxpr): exception_type = type(e) msg_context = f"while checking jaxpr:\n\n{jaxpr}\n" if len(e.args) == 0: - exception_args = (msg_context,) + exception_args = [msg_context] else: msg = f"{e.args[0]}\n\n" + msg_context - exception_args = (msg, *e.args[1:]) + exception_args = [msg, *e.args[1:]] raise exception_type(*exception_args) from e def _check_jaxpr(jaxpr: Jaxpr):