
symbolic_regression.ipynb examples fails with KeyError #553

Closed
sebastianprobst opened this issue Jan 1, 2025 · 2 comments
Labels
documentation Improvements or additions to documentation

Comments

@sebastianprobst

sebastianprobst commented Jan 1, 2025

diffrax 0.6.2
jax 0.4.38
sympy2jax 0.0.5
The example "symbolic_regression.ipynb" fails with the following errors:

---------------------------------------------------------------------------
KeyError                                  Traceback (most recent call last)
File ...\Python312\Lib\site-packages\sympy2jax\sympy_module.py:244, in _Func.__call__(self, memodict)
    243 try:
--> 244     arg_call = memodict[arg]
    245 except KeyError:

KeyError: _Symbol(_name=str64[])

During handling of the above exception, another exception occurred:

TypeError                                 Traceback (most recent call last)
File ...\Python312\Lib\site-packages\diffrax\_integrate.py:168, in _assert_term_compatible.<locals>._check(term_cls, term, term_contr_kwargs, yi)
    167 try:
--> 168     vf_type = eqx.filter_eval_shape(term.vf, 0.0, yi, args)
    169 except Exception as e:

File ...\Python312\Lib\site-packages\equinox\_eval_shape.py:38, in filter_eval_shape(fun, *args, **kwargs)
     37 dynamic, static = partition((fun, args, kwargs), _filter)
---> 38 dynamic_out, static_out = jax.eval_shape(ft.partial(_fn, static), dynamic)
     39 return combine(dynamic_out, static_out.value)

    [... skipping hidden 1 frame]

File ...\Python312\Lib\site-packages\jax\_src\api.py:2637, in eval_shape(fun, *args, **kwargs)
   2636 except TypeError: fun = partial(fun)
-> 2637 return jit(fun).eval_shape(*args, **kwargs)

    [... skipping hidden 1 frame]

File ...\Python312\Lib\site-packages\jax\_src\pjit.py:487, in _make_jit_wrapper.<locals>.eval_shape(*args, **kwargs)
    485 @api_boundary
    486 def eval_shape(*args, **kwargs):
--> 487   p, _ = _infer_params(fun, jit_info, args, kwargs)
    488   out_s = [None if isinstance(s, UnspecifiedValue) else s for s in p.params['out_shardings']]

File ...\Python312\Lib\site-packages\jax\_src\pjit.py:769, in _infer_params(fun, ji, args, kwargs)
    768 if entry.pjit_params is None:
--> 769   p, args_flat = _infer_params_impl(
    770       fun, ji, pjit_mesh, resource_env, args, kwargs, in_avals=avals)
    771   if p.attrs_tracked:
    772     # If there are attrs_tracked, don't use the cache.

File ...\Python312\Lib\site-packages\jax\_src\pjit.py:651, in _infer_params_impl(***failed resolving arguments***)
    650 with mesh_lib.set_abstract_mesh(abstract_mesh):
--> 651   jaxpr, consts, out_avals, attrs_tracked = _create_pjit_jaxpr(
    652       flat_fun, in_type, attr_token, dbg,
    653       HashableFunction(res_paths, closure=()),
    654       IgnoreKey(ji.inline))
    655 _attr_update(flat_fun, in_type, attr_token, attrs_tracked)

File ...\Python312\Lib\site-packages\jax\_src\linear_util.py:335, in cache.<locals>.memoized_fun(fun, *args)
    334 else:
--> 335   ans = call(fun, *args)
    336   if explain and config.explain_cache_misses.value:

File ...\Python312\Lib\site-packages\jax\_src\pjit.py:1315, in _create_pjit_jaxpr(***failed resolving arguments***)
   1314   else:
-> 1315     jaxpr, global_out_avals, consts, attrs_tracked = pe.trace_to_jaxpr_dynamic(
   1316         fun, in_type, debug_info=pe_debug)
   1317     # assert attr_data is sentinel or attr_data matches attrs_tracked
   1318 
   1319 # TODO(dougalm,mattjj): enable debug info with attrs_tracked

File ...\Python312\Lib\site-packages\jax\_src\profiler.py:333, in annotate_function.<locals>.wrapper(*args, **kwargs)
    332 with TraceAnnotation(name, **decorator_kwargs):
--> 333   return func(*args, **kwargs)
    334 return wrapper

File ...\Python312\Lib\site-packages\jax\_src\interpreters\partial_eval.py:2189, in trace_to_jaxpr_dynamic(fun, in_avals, debug_info, keep_inputs)
   2188 with core.set_current_trace(trace):
-> 2189   ans = fun.call_wrapped(*in_tracers)
   2191 out_tracers = map(trace.to_jaxpr_tracer, ans)

File ...\Python312\Lib\site-packages\jax\_src\linear_util.py:187, in WrappedFun.call_wrapped(self, *args, **kwargs)
    186 """Calls the transformed function"""
--> 187 return self.f_transformed(*args, **kwargs)

File ...\Python312\Lib\site-packages\jax\_src\api_util.py:294, in _argnums_partial(_fun, _dyn_argnums, _fixed_args, *dyn_args, **kwargs)
    293 assert next(fixed_args_, sentinel) is sentinel
--> 294 return _fun(*args, **kwargs)

File ...\Python312\Lib\site-packages\jax\_src\api_util.py:74, in flatten_fun(f, store, in_tree, *args_flat)
     73 py_args, py_kwargs = tree_unflatten(in_tree, args_flat)
---> 74 ans = f(*py_args, **py_kwargs)
     75 ans, out_tree = tree_flatten(ans)

File ...\Python312\Lib\site-packages\jax\_src\api_util.py:691, in result_paths(_fun, _store, *args, **kwargs)
    690 "linear_util transform to get output pytree paths of pre-flattened function."
--> 691 ans = _fun(*args, **kwargs)
    692 _store.store([keystr(path) for path, _ in generate_key_paths(ans)])

File ...\Python312\Lib\site-packages\equinox\_eval_shape.py:33, in filter_eval_shape.<locals>._fn(_static, _dynamic)
     32 _fun, _args, _kwargs = combine(_static, _dynamic)
---> 33 _out = _fun(*_args, **_kwargs)
     34 _dynamic_out, _static_out = partition(_out, _filter)

    [... skipping hidden 1 frame]

File ...\Python312\Lib\site-packages\diffrax\_term.py:194, in ODETerm.vf(self, t, y, args)
    193 def vf(self, t: RealScalarLike, y: Y, args: Args) -> _VF:
--> 194     out = self.vector_field(t, y, args)
    195     if jtu.tree_structure(out) != jtu.tree_structure(y):

Cell In[5], line 16, in Func.__call__(self, t, y, args)
     15 def __call__(self, t, y, args):
---> 16     return self.mlp(y)

Cell In[2], line 8, in Stack.__call__(self, x)
      7 x1 = x[..., 1]
----> 8 return jnp.stack([module(x0=x0, x1=x1) for module in self.modules], axis=-1)

File ...\Python312\Lib\site-packages\sympy2jax\sympy_module.py:325, in SymbolicModule.__call__(self, **symbols)
    324 memodict = symbols
--> 325 return jax.tree_map(lambda n: n(memodict), self.nodes, is_leaf=_is_node)

File ...\Python312\Lib\site-packages\jax\_src\tree_util.py:359, in tree_map(f, tree, is_leaf, *rest)
    358 all_leaves = [leaves] + [treedef.flatten_up_to(r) for r in rest]
--> 359 return treedef.unflatten(f(*xs) for xs in zip(*all_leaves))

File ...\Python312\Lib\site-packages\jax\_src\tree_util.py:359, in <genexpr>(.0)
    358 all_leaves = [leaves] + [treedef.flatten_up_to(r) for r in rest]
--> 359 return treedef.unflatten(f(*xs) for xs in zip(*all_leaves))

File ...\Python312\Lib\site-packages\sympy2jax\sympy_module.py:325, in SymbolicModule.__call__.<locals>.<lambda>(n)
    324 memodict = symbols
--> 325 return jax.tree_map(lambda n: n(memodict), self.nodes, is_leaf=_is_node)

File ...\Python312\Lib\site-packages\sympy2jax\sympy_module.py:246, in _Func.__call__(self, memodict)
    245 except KeyError:
--> 246     arg_call = arg(memodict)
    247     memodict[arg] = arg_call

File ...\Python312\Lib\site-packages\sympy2jax\sympy_module.py:136, in _Symbol.__call__(self, memodict)
    135 try:
--> 136     return memodict[self._name]
    137 except KeyError as e:

TypeError: unhashable type: 'DynamicJaxprTracer'

During handling of the above exception, another exception occurred:

ValueError                                Traceback (most recent call last)
File ...\Python312\Lib\site-packages\diffrax\_integrate.py:194, in _assert_term_compatible(y, args, terms, term_structure, contr_kwargs)
    193     with jax.numpy_dtype_promotion("standard"):
--> 194         jtu.tree_map(_check, term_structure, terms, contr_kwargs, y)
    195 except Exception as e:
    196     # ValueError may also arise from mismatched tree structures

File ...\Python312\Lib\site-packages\jax\_src\tree_util.py:359, in tree_map(f, tree, is_leaf, *rest)
    358 all_leaves = [leaves] + [treedef.flatten_up_to(r) for r in rest]
--> 359 return treedef.unflatten(f(*xs) for xs in zip(*all_leaves))

File ...\Python312\Lib\site-packages\jax\_src\tree_util.py:359, in <genexpr>(.0)
    358 all_leaves = [leaves] + [treedef.flatten_up_to(r) for r in rest]
--> 359 return treedef.unflatten(f(*xs) for xs in zip(*all_leaves))

File ...\Python312\Lib\site-packages\diffrax\_integrate.py:170, in _assert_term_compatible.<locals>._check(term_cls, term, term_contr_kwargs, yi)
    169 except Exception as e:
--> 170     raise ValueError(f"Error while tracing {term}.vf: " + str(e))
    171 vf_type_compatible = eqx.filter_eval_shape(
    172     better_isinstance, vf_type, vf_type_expected
    173 )

ValueError: Error while tracing ODETerm(
  vector_field=Func(
    mlp=Stack(
      modules=[
        SymbolicModule(
          nodes=_Func(
            _func=<function fn_>,
            _args=[
              _Symbol(_name=str64[]),
              _Func(
                _func=<function power>,
                _args=[
                  _Func(
                    _func=<function fn_>,
                    _args=[_Symbol(_name=str64[]), _Float(_value=weak_f32[])]
                  ),
                  _Integer(_value=weak_i32[])
                ]
              )
            ]
          ),
          has_extra_funcs=False
        ),
        SymbolicModule(
          nodes=_Func(
            _func=<function fn_>,
            _args=[
              _Symbol(_name=str64[]),
              _Func(
                _func=<function power>,
                _args=[
                  _Func(
                    _func=<function fn_>,
                    _args=[
                      _Func(
                        _func=<function fn_>,
                        _args=[
                          _Float(_value=weak_f32[]),
                          _Func(
                            _func=<function fn_>,
                            _args=[
                              _Integer(_value=weak_i32[]),
                              _Symbol(_name=str64[])
                            ]
                          )
                        ]
                      ),
                      _Float(_value=weak_f32[])
                    ]
                  ),
                  _Integer(_value=weak_i32[])
                ]
              )
            ]
          ),
          has_extra_funcs=False
        )
      ]
    )
  )
).vf: unhashable type: 'DynamicJaxprTracer'

The above exception was the direct cause of the following exception:

ValueError                                Traceback (most recent call last)
Cell In[53], line 1
----> 1 main2()

Cell In[52], line 85, in main2(symbolic_dataset_size, symbolic_num_populations, symbolic_population_size, symbolic_migration_steps, symbolic_mutation_steps, symbolic_descent_steps, pareto_coefficient, fine_tuning_steps, fine_tuning_lr, quantise_to)
     82     return symbolic_model, opt_state
     84 for _ in range(fine_tuning_steps):
---> 85     symbolic_model, opt_state = make_step(symbolic_model, opt_state)
     87 #
     88 # Finally we round each constant to the nearest multiple of `quantise_to`.
     89 #
     91 trained_expressions = []

    [... skipping hidden 19 frame]

Cell In[52], line 79, in main2.<locals>.make_step(symbolic_model, opt_state)
     77 @eqx.filter_jit
     78 def make_step(symbolic_model, opt_state):
---> 79     grads = grad_loss(symbolic_model)
     80     updates, opt_state = optim.update(grads, opt_state)
     81     symbolic_model = eqx.apply_updates(symbolic_model, updates)

    [... skipping hidden 18 frame]

Cell In[52], line 71, in main2.<locals>.grad_loss(symbolic_model)
     68 @eqx.filter_grad
     69 def grad_loss(symbolic_model):
     70     vmap_model = jax.vmap(symbolic_model, in_axes=(None, 0))
---> 71     pred_ys = vmap_model(ts, ys[:, 0])  # noqa: F821
     72     return jnp.mean((ys - pred_ys) ** 2)

    [... skipping hidden 6 frame]

Cell In[6], line 9, in NeuralODE.__call__(self, ts, y0)
      8 def __call__(self, ts, y0):
----> 9     solution = diffrax.diffeqsolve(
     10         diffrax.ODETerm(self.func),
     11         diffrax.Tsit5(),
     12         t0=ts[0],
     13         t1=ts[-1],
     14         dt0=ts[1] - ts[0],
     15         y0=y0,
     16         stepsize_controller=diffrax.PIDController(rtol=1e-3, atol=1e-6),
     17         saveat=diffrax.SaveAt(ts=ts),
     18     )
     19     return solution.ys

    [... skipping hidden 19 frame]

File ...\Python312\Lib\site-packages\diffrax\_integrate.py:1089, in diffeqsolve(terms, solver, t0, t1, dt0, y0, args, saveat, stepsize_controller, adjoint, event, max_steps, throw, progress_meter, solver_state, controller_state, made_jump, discrete_terminating_event)
   1086         terms = MultiTerm(*terms)
   1088 # Error checking for term compatibility
-> 1089 _assert_term_compatible(
   1090     y0,
   1091     args,
   1092     terms,
   1093     solver.term_structure,
   1094     solver.term_compatible_contr_kwargs,
   1095 )
   1097 if is_sde(terms):
   1098     if not isinstance(solver, (AbstractItoSolver, AbstractStratonovichSolver)):

File ...\Python312\Lib\site-packages\diffrax\_integrate.py:197, in _assert_term_compatible(y, args, terms, term_structure, contr_kwargs)
    194         jtu.tree_map(_check, term_structure, terms, contr_kwargs, y)
    195 except Exception as e:
    196     # ValueError may also arise from mismatched tree structures
--> 197     raise ValueError("Terms are not compatible with solver!") from e

ValueError: Terms are not compatible with solver!

(excellent library btw.!)

@patrick-kidger
Owner

Thanks for the report! Looks like there are a few things going on here:

  • numpy 2 now supports its own string-dtyped arrays.
  • sympy now uses these numpy-strings (rather than regular strings) to represent symbol names.
  • JAX (erroneously?) is willing to create string-dtyped tracers under jax.eval_shape when given a numpy-string, because it thinks it is an array (see jax-ml/jax#25707, "JAX can create str-dtyped tracers under eval_shape with numpy 2").
  • As a result, what was previously a regular str is now a tracer, which is not hashable, and this causes the `TypeError: unhashable type: 'DynamicJaxprTracer'` in the traceback above.
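The last step of that chain can be reproduced in pure Python. The innermost frame of the traceback is `_Symbol.__call__` doing `memodict[self._name]`, and a dictionary lookup has to hash its key. In the sketch below, `FakeTracer` is a hypothetical stand-in for JAX's `DynamicJaxprTracer` (it is simply made unhashable), and `lookup_symbol` is a hypothetical simplification of the sympy2jax lookup, not its actual code:

```python
class FakeTracer:
    """Hypothetical stand-in for a JAX tracer: it cannot be hashed."""

    __hash__ = None  # instances cannot be used as dict keys


def lookup_symbol(memodict, name):
    """Hypothetical simplification of the symbol lookup in sympy2jax."""
    return memodict[name]


# Calling module(x0=..., x1=...) produces a memodict keyed by plain str names.
memodict = {"x0": 1.0, "x1": 2.0}

# When the stored symbol name is a plain str, the lookup succeeds.
print(lookup_symbol(memodict, "x0"))  # 1.0

# When the stored name has been turned into a tracer, hashing it fails.
try:
    lookup_symbol(memodict, FakeTracer())
except TypeError as e:
    print(e)  # unhashable type: 'FakeTracer'
```

The error message has the same shape as the one at the bottom of the first traceback, because in both cases the failure is the `hash()` call implied by the dictionary lookup.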

The fix is to release a new version of sympy2jax, which converts numpy-strings back to regular strings: patrick-kidger/sympy2jax#16
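The idea behind that fix can be sketched as normalising each symbol name to a built-in `str` before it is stored on the module, so that JAX sees an ordinary Python string rather than an array-like numpy scalar. This is a hedged illustration under that assumption, not the code in patrick-kidger/sympy2jax#16; `normalise_name` is a hypothetical helper:

```python
import numpy as np


def normalise_name(name):
    """Hypothetical helper: convert numpy string scalars/arrays to built-in str."""
    if isinstance(name, np.ndarray) and name.ndim == 0 and name.dtype.kind == "U":
        # 0-d unicode array: extract the Python scalar first.
        return str(name.item())
    if isinstance(name, np.str_):
        # np.str_ subclasses str, but numpy and JAX also treat it as a numpy scalar;
        # str() produces a plain built-in str.
        return str(name)
    # Already a plain str (or something else): leave it unchanged.
    return name


np_name = np.str_("x0")
print(type(np_name).__name__, np_name.dtype)  # str_ <U2
clean = normalise_name(np_name)
print(type(clean).__name__, clean)  # str x0
```

Note that `np.str_` passes `isinstance(..., str)`, which is why a plain `isinstance` check on `str` is not enough to detect it and an explicit conversion is needed.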

Whilst I'm here, I've also just updated the symbolic regression example to handle a few changes in both sympy and PySR.

@patrick-kidger patrick-kidger added the documentation Improvements or additions to documentation label Jan 2, 2025
@sebastianprobst
Author

Thank you for the quick solution. It is working now.
