jax and pyrosetta incompatibility #60

Open
ronboger opened this issue Oct 17, 2024 · 3 comments

@ronboger

Hello - I first want to congratulate the authors on their great work here.

I initially had a working environment on my server with an A40 GPU when this was launched, but recently ran the install script by accident and wiped my previous environment.

The new install script with python 3.10 would hang on attempts to solve the environment, so I did the following within a slurm node with a BindCraft conda env running python 3.10:

conda install pip pandas matplotlib numpy"<2.0.0" biopython scipy pdbfixer seaborn tqdm jupyter ffmpeg fsspec py3dmol chex dm-haiku dm-tree joblib ml-collections immutabledict optax -c conda-forge -c anaconda -y

Install jax (the conda install in the script would hang on solving environment)
pip install --upgrade "jax[cuda]" -f https://storage.googleapis.com/jax-releases/jax_releases.html

Install pyrosetta from wheel (the conda install also flopped)
pip install pyrosetta-2024.24+release.ca096da-cp310-cp310-linux_x86_64.whl

and so on.
When I run bindcraft.py, I get a seg fault.

When I've tried debugging, I can import jax and pyrosetta separately, but importing both in the same session segfaults:

/global/scratch/users/ronb/BindCraft$ python
Python 3.10.15 | packaged by conda-forge | (main, Oct 16 2024, 01:24:24) [GCC 13.3.0] on linux
Type "help", "copyright", "credits" or "license" for more information.
>>> import jax.numpy as jnp
>>> import pyrosetta
=slt=bcl::util=> caught signal: 11 cleaning! Call Stack
/global/scratch/users/ronb/.conda/BindCraft/lib/python3.10/site-packages/pyrosetta/rosetta.so: bcl::util::CleanableInterface::Cleanables::SignalHandler(int)
/global/scratch/users/ronb/.conda/BindCraft/lib/python3.10/site-packages/pyrosetta/rosetta.so: bind_std_type_traits(std::function<pybind11::module_& (std::string const&)>&)
/global/scratch/users/ronb/.conda/BindCraft/lib/python3.10/site-packages/pyrosetta/rosetta.so: PyInit_rosetta
python: _PyEval_EvalFrameDefault
python: _PyFunction_Vectorcall
python: _PyEval_EvalFrameDefault
python: _PyFunction_Vectorcall
python: _PyEval_EvalFrameDefault
python: _PyFunction_Vectorcall
python: _PyEval_EvalFrameDefault
python: _PyFunction_Vectorcall
python: _PyEval_EvalFrameDefault
python: _PyFunction_Vectorcall
python: _PyEval_EvalFrameDefault
python: _PyFunction_Vectorcall
python: _PyObject_CallMethodIdObjArgs
python: PyImport_ImportModuleLevelObject
python: _PyEval_EvalFrameDefault
python: PyEval_EvalCode
python: _PyEval_EvalFrameDefault

Segmentation fault (core dumped)

Do you have any input? This may also be related to python 3.10, since I didn't have this issue before, but I'm not sure.

If there is a container I can run, that would also be good!

Also adding the exported environment:
bindcraft_gh_issues.txt
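
For anyone trying to reproduce the import crash above, here is a minimal sketch that runs each import order in a fresh interpreter, so a segfault in one order does not take down the whole check. The helper names `try_imports` and `describe` are made up for illustration, and it assumes a POSIX system, where `subprocess` reports a child killed by a signal as a negative return code.

```python
import subprocess
import sys


def try_imports(modules):
    """Import the given modules, in order, in a fresh interpreter.

    Returns the child's exit code. On POSIX, a negative value means the
    child was killed by that signal number (e.g. -11 for SIGSEGV).
    """
    code = "; ".join(f"import {name}" for name in modules)
    result = subprocess.run(
        [sys.executable, "-c", code],
        capture_output=True,
        text=True,
    )
    return result.returncode


def describe(returncode):
    """Turn an exit code into a short human-readable status."""
    if returncode == 0:
        return "ok"
    if returncode < 0:
        return f"killed by signal {-returncode}"
    return f"exited with code {returncode}"


def main():
    # Each order runs in its own process, so one crash cannot affect the next.
    orders = (
        ["jax"],
        ["pyrosetta"],
        ["jax", "pyrosetta"],
        ["pyrosetta", "jax"],
    )
    for order in orders:
        print(" -> ".join(order), ":", describe(try_imports(order)))


if __name__ == "__main__":
    main()
```

If only the combined orders fail while the single imports succeed, that points at an interaction between the two packages rather than a broken install of either one.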

@LasseMiddendorf

I ran into the same issue with the updated installation script, so I used pip to install the packages for which conda failed to resolve the environment. When I then try to run the script, I get the error message below. I'm not sure whether it is caused by the jax and pyrosetta incompatibility mentioned here, because I can import both without a segmentation fault. However, I'm using the pyrosetta release 2024.39+release.59628fb instead of 2024.24+release.ca096da.

Starting trajectory: test
Stage 1: Test Logits
Traceback (most recent call last):
  File "/home/hpc/b114cb/b114cb21/tools/BindCraft/bindcraft.py", line 109, in <module>
    trajectory = binder_hallucination(design_name, target_settings["starting_pdb"], target_settings["chains"],
  File "/home/hpc/b114cb/b114cb21/tools/BindCraft/functions/colabdesign_utils.py", line 97, in binder_hallucination
    af_model.design_logits(iters=50, e_soft=0.9, models=design_models, num_models=1, sample_models=advanced_settings["sample_models"], save_best=True)
  File "/home/atuin/b114cb/b114cb21/software/private/conda/envs/BindCraft/lib/python3.10/site-packages/colabdesign/af/design.py", line 354, in design_logits
    self.design(iters, **kwargs)
  File "/home/atuin/b114cb/b114cb21/software/private/conda/envs/BindCraft/lib/python3.10/site-packages/colabdesign/af/design.py", line 348, in design
    self.step(lr_scale=lr_scale, num_recycles=num_recycles,
  File "/home/atuin/b114cb/b114cb21/software/private/conda/envs/BindCraft/lib/python3.10/site-packages/colabdesign/af/design.py", line 214, in step
    self.run(num_recycles=num_recycles, num_models=num_models, sample_models=sample_models,
  File "/home/atuin/b114cb/b114cb21/software/private/conda/envs/BindCraft/lib/python3.10/site-packages/colabdesign/af/design.py", line 96, in run
    auxs.append(self._recycle(p, num_recycles=num_recycles, backprop=backprop))
  File "/home/atuin/b114cb/b114cb21/software/private/conda/envs/BindCraft/lib/python3.10/site-packages/colabdesign/af/design.py", line 195, in _recycle
    aux = self._single(model_params, backprop=False)
  File "/home/atuin/b114cb/b114cb21/software/private/conda/envs/BindCraft/lib/python3.10/site-packages/colabdesign/af/design.py", line 142, in _single
    loss, aux = self._model["fn"](*flags)
  File "/home/atuin/b114cb/b114cb21/software/private/conda/envs/BindCraft/lib/python3.10/site-packages/colabdesign/af/model.py", line 216, in _model
    self._get_loss(inputs=inputs, outputs=outputs, aux=aux)
  File "/home/atuin/b114cb/b114cb21/software/private/conda/envs/BindCraft/lib/python3.10/site-packages/colabdesign/af/loss.py", line 80, in _loss_binder
    align_fn = get_rmsd_loss(inputs, outputs, L=tL)["align"]
  File "/home/atuin/b114cb/b114cb21/software/private/conda/envs/BindCraft/lib/python3.10/site-packages/colabdesign/af/loss.py", line 443, in get_rmsd_loss
    return _get_rmsd_loss(true, pred, weights=weights, L=L, include_L=include_L, copies=copies)
  File "/home/atuin/b114cb/b114cb21/software/private/conda/envs/BindCraft/lib/python3.10/site-packages/colabdesign/af/loss.py", line 476, in _get_rmsd_loss
    aln = _np_kabsch((P-P_mu)*W, T-T_mu)
  File "/home/atuin/b114cb/b114cb21/software/private/conda/envs/BindCraft/lib/python3.10/site-packages/colabdesign/shared/protein.py", line 132, in _np_kabsch
    u, s, vh = _np.linalg.svd(ab, full_matrices=False)
  File "/home/atuin/b114cb/b114cb21/software/private/conda/envs/BindCraft/lib/python3.10/site-packages/jax/_src/numpy/linalg.py", line 298, in svd
    u, s, vh = lax_linalg.svd(
jax._src.source_info_util.JaxStackTraceBeforeTransformation: RuntimeError: jaxlib/gpu/solver_handle_pool.cc:37: operation gpusolverDnCreate(&handle) failed: cuSolver internal error

The preceding stack trace is the source of the JAX operation that, once transformed by JAX, triggered the following exception.

--------------------

The above exception was the direct cause of the following exception:

jax.errors.SimplifiedTraceback: For simplicity, JAX has removed its internal frames from the traceback of the following exception. Set JAX_TRACEBACK_FILTERING=off to include these.

The above exception was the direct cause of the following exception:

Traceback (most recent call last):
  File "/home/hpc/b114cb/b114cb21/tools/BindCraft/bindcraft.py", line 109, in <module>
    trajectory = binder_hallucination(design_name, target_settings["starting_pdb"], target_settings["chains"],
  File "/home/hpc/b114cb/b114cb21/tools/BindCraft/functions/colabdesign_utils.py", line 97, in binder_hallucination
    af_model.design_logits(iters=50, e_soft=0.9, models=design_models, num_models=1, sample_models=advanced_settings["sample_models"], save_best=True)
  File "/home/atuin/b114cb/b114cb21/software/private/conda/envs/BindCraft/lib/python3.10/site-packages/colabdesign/af/design.py", line 354, in design_logits
    self.design(iters, **kwargs)
  File "/home/atuin/b114cb/b114cb21/software/private/conda/envs/BindCraft/lib/python3.10/site-packages/colabdesign/af/design.py", line 348, in design
    self.step(lr_scale=lr_scale, num_recycles=num_recycles,
  File "/home/atuin/b114cb/b114cb21/software/private/conda/envs/BindCraft/lib/python3.10/site-packages/colabdesign/af/design.py", line 214, in step
    self.run(num_recycles=num_recycles, num_models=num_models, sample_models=sample_models,
  File "/home/atuin/b114cb/b114cb21/software/private/conda/envs/BindCraft/lib/python3.10/site-packages/colabdesign/af/design.py", line 96, in run
    auxs.append(self._recycle(p, num_recycles=num_recycles, backprop=backprop))
  File "/home/atuin/b114cb/b114cb21/software/private/conda/envs/BindCraft/lib/python3.10/site-packages/colabdesign/af/design.py", line 195, in _recycle
    aux = self._single(model_params, backprop=False)
  File "/home/atuin/b114cb/b114cb21/software/private/conda/envs/BindCraft/lib/python3.10/site-packages/colabdesign/af/design.py", line 142, in _single
    loss, aux = self._model["fn"](*flags)
  File "/home/atuin/b114cb/b114cb21/software/private/conda/envs/BindCraft/lib/python3.10/site-packages/jaxlib/gpu_solver.py", line 353, in _gesvd_hlo
    lwork, opaque = gpu_solver.build_gesvdj_descriptor(
RuntimeError: jaxlib/gpu/solver_handle_pool.cc:37: operation gpusolverDnCreate(&handle) failed: cuSolver internal error
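
The traceback above fails inside `jnp.linalg.svd`, so it may help to check whether a tiny SVD works in the same environment without BindCraft or ColabDesign involved. This is a rough sketch, not a fix: `check_svd` is a hypothetical helper, and it assumes the cuSolver handle error would also show up for a small matrix on the default backend, which the thread does not confirm.

```python
def check_svd(n=4):
    """Run a small SVD on JAX's default backend and report what happened.

    Returns (ok, message). Nothing is raised, so the function can be used
    on machines where jax is missing or the GPU stack is broken.
    """
    try:
        import jax
        import jax.numpy as jnp
    except Exception as exc:  # missing package or a broken install
        return False, f"jax is not importable: {type(exc).__name__}: {exc}"

    try:
        backend = jax.default_backend()  # e.g. "cpu" or "gpu"
        matrix = jnp.eye(n)
        u, s, vh = jnp.linalg.svd(matrix, full_matrices=False)
        s.block_until_ready()  # force the computation to actually run
    except Exception as exc:
        return False, f"SVD failed: {type(exc).__name__}: {exc}"

    return True, f"SVD succeeded on backend '{backend}' with jax {jax.__version__}"


if __name__ == "__main__":
    ok, message = check_svd()
    print("OK" if ok else "FAILED", "-", message)
```

If this small check fails with the same `gpusolverDnCreate` message, the problem is likely in the jax/CUDA install itself rather than in anything BindCraft does.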

@ronboger
Author

ronboger commented Oct 17, 2024 via email

@martinpacesa
Owner

Yeah, there is a compatibility problem between jax and pyrosetta if you install them separately. That's why they need to be installed in one conda command, so that conda can find the appropriate versions. Did you try with the latest install script and forcing the CUDA version?
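
Since the advice above hinges on which versions end up installed together, a quick way to compare a working and a broken environment is to list the resolved versions without importing the packages (importing is what crashes). A minimal sketch using the standard library's `importlib.metadata`; the function name `installed_versions` and the default package list are assumptions for illustration, and it assumes each package ships metadata under these distribution names.

```python
from importlib import metadata


def installed_versions(packages=("jax", "jaxlib", "pyrosetta", "numpy")):
    """Return {distribution name: version string, or None if not found}.

    Reads package metadata only, so nothing is imported and a package
    that segfaults on import can still be inspected.
    """
    versions = {}
    for name in packages:
        try:
            versions[name] = metadata.version(name)
        except metadata.PackageNotFoundError:
            versions[name] = None
    return versions


if __name__ == "__main__":
    for name, version in installed_versions().items():
        print(f"{name}: {version if version is not None else 'not installed'}")
```

Posting this output alongside the exported environment makes it easier to see whether pip and conda pulled in different jax or pyrosetta builds.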
