diff --git a/alphafold/relax/amber_minimize.py b/alphafold/relax/amber_minimize.py index 4694f4402..e21a0dc30 100644 --- a/alphafold/relax/amber_minimize.py +++ b/alphafold/relax/amber_minimize.py @@ -26,6 +26,7 @@ from alphafold.relax import utils import ml_collections import numpy as np +import jax from simtk import openmm from simtk import unit from simtk.openmm import app as openmm_app @@ -486,7 +487,9 @@ def run_pipeline( pdb_string = clean_protein(prot, checks=True) else: pdb_string = ret["min_pdb"] - ret.update(get_violation_metrics(prot)) + # Calculation of violations can cause CUDA errors for some JAX versions. + with jax.default_device(jax.devices("cpu")[0]): + ret.update(get_violation_metrics(prot)) ret.update({ "num_exclusions": len(exclude_residues), "iteration": iteration,