From 8f1ebd58960b7eed708cf5eebcc88a7e6a3bf8a6 Mon Sep 17 00:00:00 2001 From: Ali Cowen-Rivers Date: Tue, 10 Jan 2023 15:08:44 -0800 Subject: [PATCH] Fix GPU relax for longer chains by pinning large memory ops to cpu. PiperOrigin-RevId: 501105389 Change-Id: I6c981d1d3231e008ebae192edb4586479eb5eb34 --- alphafold/relax/amber_minimize.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/alphafold/relax/amber_minimize.py b/alphafold/relax/amber_minimize.py index 4694f4402..e21a0dc30 100644 --- a/alphafold/relax/amber_minimize.py +++ b/alphafold/relax/amber_minimize.py @@ -26,6 +26,7 @@ from alphafold.relax import utils import ml_collections import numpy as np +import jax from simtk import openmm from simtk import unit from simtk.openmm import app as openmm_app @@ -486,7 +487,9 @@ def run_pipeline( pdb_string = clean_protein(prot, checks=True) else: pdb_string = ret["min_pdb"] - ret.update(get_violation_metrics(prot)) + # Calculation of violations can cause CUDA errors for some JAX versions. + with jax.default_device(jax.devices("cpu")[0]): + ret.update(get_violation_metrics(prot)) ret.update({ "num_exclusions": len(exclude_residues), "iteration": iteration,