diffrax 0.6.2
jax 0.4.38
sympy2jax 0.0.5

The example "symbolic_regression.ipynb" fails with the following errors:
```
---------------------------------------------------------------------------
KeyError                                  Traceback (most recent call last)
File ...\Python312\Lib\site-packages\sympy2jax\sympy_module.py:244, in _Func.__call__(self, memodict)
    243 try:
--> 244     arg_call = memodict[arg]
    245 except KeyError:

KeyError: _Symbol(_name=str64[])

During handling of the above exception, another exception occurred:

TypeError                                 Traceback (most recent call last)
File ...\Python312\Lib\site-packages\diffrax\_integrate.py:168, in _assert_term_compatible.<locals>._check(term_cls, term, term_contr_kwargs, yi)
    167 try:
--> 168     vf_type = eqx.filter_eval_shape(term.vf, 0.0, yi, args)
    169 except Exception as e:

File ...\Python312\Lib\site-packages\equinox\_eval_shape.py:38, in filter_eval_shape(fun, *args, **kwargs)
     37 dynamic, static = partition((fun, args, kwargs), _filter)
---> 38 dynamic_out, static_out = jax.eval_shape(ft.partial(_fn, static), dynamic)
     39 return combine(dynamic_out, static_out.value)

    [... skipping hidden 1 frame]

File ...\Python312\Lib\site-packages\jax\_src\api.py:2637, in eval_shape(fun, *args, **kwargs)
   2636 except TypeError: fun = partial(fun)
-> 2637 return jit(fun).eval_shape(*args, **kwargs)

    [... skipping hidden 1 frame]

File ...\Python312\Lib\site-packages\jax\_src\pjit.py:487, in _make_jit_wrapper.<locals>.eval_shape(*args, **kwargs)
    485 @api_boundary
    486 def eval_shape(*args, **kwargs):
--> 487     p, _ = _infer_params(fun, jit_info, args, kwargs)
    488     out_s = [None if isinstance(s, UnspecifiedValue) else s for s in p.params['out_shardings']]

File ...\Python312\Lib\site-packages\jax\_src\pjit.py:769, in _infer_params(fun, ji, args, kwargs)
    768 if entry.pjit_params is None:
--> 769     p, args_flat = _infer_params_impl(
    770         fun, ji, pjit_mesh, resource_env, args, kwargs, in_avals=avals)
    771 if p.attrs_tracked:
    772     # If there are attrs_tracked, don't use the cache.

File ...\Python312\Lib\site-packages\jax\_src\pjit.py:651, in _infer_params_impl(***failed resolving arguments***)
    650 with mesh_lib.set_abstract_mesh(abstract_mesh):
--> 651     jaxpr, consts, out_avals, attrs_tracked = _create_pjit_jaxpr(
    652         flat_fun, in_type, attr_token, dbg,
    653         HashableFunction(res_paths, closure=()),
    654         IgnoreKey(ji.inline))
    655 _attr_update(flat_fun, in_type, attr_token, attrs_tracked)

File ...\Python312\Lib\site-packages\jax\_src\linear_util.py:335, in cache.<locals>.memoized_fun(fun, *args)
    334 else:
--> 335     ans = call(fun, *args)
    336 if explain and config.explain_cache_misses.value:

File ...\Python312\Lib\site-packages\jax\_src\pjit.py:1315, in _create_pjit_jaxpr(***failed resolving arguments***)
   1314 else:
-> 1315     jaxpr, global_out_avals, consts, attrs_tracked = pe.trace_to_jaxpr_dynamic(
   1316         fun, in_type, debug_info=pe_debug)
   1317 # assert attr_data is sentinel or attr_data matches attrs_tracked
   1318
   1319 # TODO(dougalm,mattjj): enable debug info with attrs_tracked

File ...\Python312\Lib\site-packages\jax\_src\profiler.py:333, in annotate_function.<locals>.wrapper(*args, **kwargs)
    332 with TraceAnnotation(name, **decorator_kwargs):
--> 333     return func(*args, **kwargs)
    334 return wrapper

File ...\Python312\Lib\site-packages\jax\_src\interpreters\partial_eval.py:2189, in trace_to_jaxpr_dynamic(fun, in_avals, debug_info, keep_inputs)
   2188 with core.set_current_trace(trace):
-> 2189     ans = fun.call_wrapped(*in_tracers)
   2191 out_tracers = map(trace.to_jaxpr_tracer, ans)

File ...\Python312\Lib\site-packages\jax\_src\linear_util.py:187, in WrappedFun.call_wrapped(self, *args, **kwargs)
    186 """Calls the transformed function"""
--> 187 return self.f_transformed(*args, **kwargs)

File ...\Python312\Lib\site-packages\jax\_src\api_util.py:294, in _argnums_partial(_fun, _dyn_argnums, _fixed_args, *dyn_args, **kwargs)
    293 assert next(fixed_args_, sentinel) is sentinel
--> 294 return _fun(*args, **kwargs)

File ...\Python312\Lib\site-packages\jax\_src\api_util.py:74, in flatten_fun(f, store, in_tree, *args_flat)
     73 py_args, py_kwargs = tree_unflatten(in_tree, args_flat)
---> 74 ans = f(*py_args, **py_kwargs)
     75 ans, out_tree = tree_flatten(ans)

File ...\Python312\Lib\site-packages\jax\_src\api_util.py:691, in result_paths(_fun, _store, *args, **kwargs)
    690 "linear_util transform to get output pytree paths of pre-flattened function."
--> 691 ans = _fun(*args, **kwargs)
    692 _store.store([keystr(path) for path, _ in generate_key_paths(ans)])

File ...\Python312\Lib\site-packages\equinox\_eval_shape.py:33, in filter_eval_shape.<locals>._fn(_static, _dynamic)
     32 _fun, _args, _kwargs = combine(_static, _dynamic)
---> 33 _out = _fun(*_args, **_kwargs)
     34 _dynamic_out, _static_out = partition(_out, _filter)

    [... skipping hidden 1 frame]

File ...\Python312\Lib\site-packages\diffrax\_term.py:194, in ODETerm.vf(self, t, y, args)
    193 def vf(self, t: RealScalarLike, y: Y, args: Args) -> _VF:
--> 194     out = self.vector_field(t, y, args)
    195     if jtu.tree_structure(out) != jtu.tree_structure(y):

Cell In[5], line 16, in Func.__call__(self, t, y, args)
     15 def __call__(self, t, y, args):
---> 16     return self.mlp(y)

Cell In[2], line 8, in Stack.__call__(self, x)
      7 x1 = x[..., 1]
----> 8 return jnp.stack([module(x0=x0, x1=x1) for module in self.modules], axis=-1)

File ...\Python312\Lib\site-packages\sympy2jax\sympy_module.py:325, in SymbolicModule.__call__(self, **symbols)
    324 memodict = symbols
--> 325 return jax.tree_map(lambda n: n(memodict), self.nodes, is_leaf=_is_node)

File ...\Python312\Lib\site-packages\jax\_src\tree_util.py:359, in tree_map(f, tree, is_leaf, *rest)
    358 all_leaves = [leaves] + [treedef.flatten_up_to(r) for r in rest]
--> 359 return treedef.unflatten(f(*xs) for xs in zip(*all_leaves))

File ...\Python312\Lib\site-packages\jax\_src\tree_util.py:359, in <genexpr>(.0)
    358 all_leaves = [leaves] + [treedef.flatten_up_to(r) for r in rest]
--> 359 return treedef.unflatten(f(*xs) for xs in zip(*all_leaves))

File ...\Python312\Lib\site-packages\sympy2jax\sympy_module.py:325, in SymbolicModule.__call__.<locals>.<lambda>(n)
    324 memodict = symbols
--> 325 return jax.tree_map(lambda n: n(memodict), self.nodes, is_leaf=_is_node)

File ...\Python312\Lib\site-packages\sympy2jax\sympy_module.py:246, in _Func.__call__(self, memodict)
    245 except KeyError:
--> 246     arg_call = arg(memodict)
    247 memodict[arg] = arg_call

File ...\Python312\Lib\site-packages\sympy2jax\sympy_module.py:136, in _Symbol.__call__(self, memodict)
    135 try:
--> 136     return memodict[self._name]
    137 except KeyError as e:

TypeError: unhashable type: 'DynamicJaxprTracer'

During handling of the above exception, another exception occurred:

ValueError                                Traceback (most recent call last)
File ...\Python312\Lib\site-packages\diffrax\_integrate.py:194, in _assert_term_compatible(y, args, terms, term_structure, contr_kwargs)
    193 with jax.numpy_dtype_promotion("standard"):
--> 194     jtu.tree_map(_check, term_structure, terms, contr_kwargs, y)
    195 except Exception as e:
    196     # ValueError may also arise from mismatched tree structures

File ...\Python312\Lib\site-packages\jax\_src\tree_util.py:359, in tree_map(f, tree, is_leaf, *rest)
    358 all_leaves = [leaves] + [treedef.flatten_up_to(r) for r in rest]
--> 359 return treedef.unflatten(f(*xs) for xs in zip(*all_leaves))

File ...\Python312\Lib\site-packages\jax\_src\tree_util.py:359, in <genexpr>(.0)
    358 all_leaves = [leaves] + [treedef.flatten_up_to(r) for r in rest]
--> 359 return treedef.unflatten(f(*xs) for xs in zip(*all_leaves))

File ...\Python312\Lib\site-packages\diffrax\_integrate.py:170, in _assert_term_compatible.<locals>._check(term_cls, term, term_contr_kwargs, yi)
    169 except Exception as e:
--> 170     raise ValueError(f"Error while tracing {term}.vf: " + str(e))
    171 vf_type_compatible = eqx.filter_eval_shape(
    172     better_isinstance, vf_type, vf_type_expected
    173 )

ValueError: Error while tracing ODETerm(
  vector_field=Func(
    mlp=Stack(
      modules=[
        SymbolicModule(
          nodes=_Func(
            _func=<function fn_>,
            _args=[
              _Symbol(_name=str64[]),
              _Func(
                _func=<function power>,
                _args=[
                  _Func(
                    _func=<function fn_>,
                    _args=[_Symbol(_name=str64[]), _Float(_value=weak_f32[])]
                  ),
                  _Integer(_value=weak_i32[])
                ]
              )
            ]
          ),
          has_extra_funcs=False
        ),
        SymbolicModule(
          nodes=_Func(
            _func=<function fn_>,
            _args=[
              _Symbol(_name=str64[]),
              _Func(
                _func=<function power>,
                _args=[
                  _Func(
                    _func=<function fn_>,
                    _args=[
                      _Func(
                        _func=<function fn_>,
                        _args=[
                          _Float(_value=weak_f32[]),
                          _Func(
                            _func=<function fn_>,
                            _args=[
                              _Integer(_value=weak_i32[]),
                              _Symbol(_name=str64[])
                            ]
                          )
                        ]
                      ),
                      _Float(_value=weak_f32[])
                    ]
                  ),
                  _Integer(_value=weak_i32[])
                ]
              )
            ]
          ),
          has_extra_funcs=False
        )
      ]
    )
  )
).vf: unhashable type: 'DynamicJaxprTracer'

The above exception was the direct cause of the following exception:

ValueError                                Traceback (most recent call last)
Cell In[53], line 1
----> 1 main2()

Cell In[52], line 85, in main2(symbolic_dataset_size, symbolic_num_populations, symbolic_population_size, symbolic_migration_steps, symbolic_mutation_steps, symbolic_descent_steps, pareto_coefficient, fine_tuning_steps, fine_tuning_lr, quantise_to)
     82     return symbolic_model, opt_state
     84 for _ in range(fine_tuning_steps):
---> 85     symbolic_model, opt_state = make_step(symbolic_model, opt_state)
     87 #
     88 # Finally we round each constant to the nearest multiple of `quantise_to`.
     89 #
     91 trained_expressions = []

    [... skipping hidden 19 frame]

Cell In[52], line 79, in main2.<locals>.make_step(symbolic_model, opt_state)
     77 @eqx.filter_jit
     78 def make_step(symbolic_model, opt_state):
---> 79     grads = grad_loss(symbolic_model)
     80     updates, opt_state = optim.update(grads, opt_state)
     81     symbolic_model = eqx.apply_updates(symbolic_model, updates)

    [... skipping hidden 18 frame]

Cell In[52], line 71, in main2.<locals>.grad_loss(symbolic_model)
     68 @eqx.filter_grad
     69 def grad_loss(symbolic_model):
     70     vmap_model = jax.vmap(symbolic_model, in_axes=(None, 0))
---> 71     pred_ys = vmap_model(ts, ys[:, 0])  # noqa: F821
     72     return jnp.mean((ys - pred_ys) ** 2)

    [... skipping hidden 6 frame]

Cell In[6], line 9, in NeuralODE.__call__(self, ts, y0)
      8 def __call__(self, ts, y0):
----> 9     solution = diffrax.diffeqsolve(
     10         diffrax.ODETerm(self.func),
     11         diffrax.Tsit5(),
     12         t0=ts[0],
     13         t1=ts[-1],
     14         dt0=ts[1] - ts[0],
     15         y0=y0,
     16         stepsize_controller=diffrax.PIDController(rtol=1e-3, atol=1e-6),
     17         saveat=diffrax.SaveAt(ts=ts),
     18     )
     19     return solution.ys

    [... skipping hidden 19 frame]

File ...\Python312\Lib\site-packages\diffrax\_integrate.py:1089, in diffeqsolve(terms, solver, t0, t1, dt0, y0, args, saveat, stepsize_controller, adjoint, event, max_steps, throw, progress_meter, solver_state, controller_state, made_jump, discrete_terminating_event)
   1086     terms = MultiTerm(*terms)
   1088 # Error checking for term compatibility
-> 1089 _assert_term_compatible(
   1090     y0,
   1091     args,
   1092     terms,
   1093     solver.term_structure,
   1094     solver.term_compatible_contr_kwargs,
   1095 )
   1097 if is_sde(terms):
   1098     if not isinstance(solver, (AbstractItoSolver, AbstractStratonovichSolver)):

File ...\Python312\Lib\site-packages\diffrax\_integrate.py:197, in _assert_term_compatible(y, args, terms, term_structure, contr_kwargs)
    194     jtu.tree_map(_check, term_structure, terms, contr_kwargs, y)
    195 except Exception as e:
    196     # ValueError may also arise from mismatched tree structures
--> 197     raise ValueError("Terms are not compatible with solver!") from e

ValueError: Terms are not compatible with solver!
```
(Excellent library, btw!)
Thanks for the report! Looks like there are a few things going on here:

- Diffrax checks that each term is compatible with the solver by tracing its vector field with `jax.eval_shape` (via `eqx.filter_eval_shape` in `_assert_term_compatible`).
- Under `eval_shape`, anything that looks like an array is swapped out for a tracer. The symbol names inside the `SymbolicModule`s are stored as numpy strings (the `str64[]` leaves in the repr above), which count as array leaves, so they become tracers and can no longer be used as dictionary keys — hence the `TypeError: unhashable type: 'DynamicJaxprTracer'`.
- The names should be plain `str`, which stays static under tracing (see the sketch below).
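To make the `str` vs numpy-string distinction concrete, here is a minimal sketch (the `names` dict and the `x0`/`x1` values are purely illustrative, and it assumes Equinox's default `eqx.is_array` filter treats numpy scalars such as `np.str_` as arrays, which is what the traceback above indicates):

```python
import equinox as eqx
import numpy as np

# Stand-in for a pytree of symbol names: one plain Python str and one
# numpy string scalar (np.str_ is a numpy scalar type, i.e. an np.generic).
names = {"plain": "x0", "numpy": np.str_("x1")}

# eqx.filter_eval_shape partitions its inputs with an is-array filter:
#   True  -> dynamic leaf, replaced by a tracer inside the transformation
#   False -> static, passed through unchanged
dynamic, static = eqx.partition(names, eqx.is_array)

print(dynamic)  # only the numpy string lands here -> it gets traced
print(static)   # the plain str stays static -> still usable as a dict key
```

Inside `jax.eval_shape` that dynamic leaf becomes a `DynamicJaxprTracer`, which is unhashable, so the `memodict[self._name]` lookup in `sympy_module.py` fails exactly as shown in the traceback.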
The fix is to release a new version of sympy2jax, which converts numpy strings back to regular Python strings: patrick-kidger/sympy2jax#16
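For anyone pinned to an older sympy2jax in the meantime, the idea behind the fix is simply to coerce the names back to plain Python strings before they are used as dictionary keys. A hedged sketch of that idea (this is not the actual patch; `as_plain_str` and `lookup` are hypothetical names used only for illustration):

```python
import equinox as eqx
import jax.numpy as jnp
import numpy as np

def as_plain_str(name):
    # Coerce numpy string scalars back to plain str so they stay on the
    # static side of filtered JAX transformations.
    return str(name) if isinstance(name, np.str_) else name

def lookup(table, name):
    # Mirrors the `memodict[self._name]` pattern from sympy2jax.
    return table[name]

table = {"x0": jnp.array(1.0)}

# With a plain-str key the name stays static, so tracing the lookup works:
out = eqx.filter_eval_shape(lookup, table, as_plain_str(np.str_("x0")))
print(out)  # a ShapeDtypeStruct describing table["x0"]
```

Upgrading to the sympy2jax release containing patrick-kidger/sympy2jax#16 makes this kind of workaround unnecessary.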
Whilst I'm here, I've also updated the symbolic regression example to handle a few changes in both sympy and PySR.
Thank you for the quick solution. It is working now.