diff --git a/qcengine/programs/torchani.py b/qcengine/programs/torchani.py index bcd1b9bd1..2e81a5d9e 100644 --- a/qcengine/programs/torchani.py +++ b/qcengine/programs/torchani.py @@ -100,7 +100,7 @@ def compute(self, input_data: "AtomicInput", config: "TaskConfig") -> "AtomicRes import torch import torchani - device = torch.device("cpu") + device = torch.device("cuda" if torch.cuda.is_available() else "cpu") # Failure flag ret_data = {"success": False} @@ -126,6 +126,7 @@ def compute(self, input_data: "AtomicInput", config: "TaskConfig") -> "AtomicRes # Build coord array geom_array = input_data.molecule.geometry.reshape(1, -1, 3) * ureg.conversion_factor("bohr", "angstrom") coordinates = torch.tensor(geom_array.tolist(), requires_grad=True, device=device) + model.to(device) _, energy_array = model((species, coordinates)) energy = energy_array.mean() @@ -139,11 +140,11 @@ def compute(self, input_data: "AtomicInput", config: "TaskConfig") -> "AtomicRes elif input_data.driver == "gradient": derivative = torch.autograd.grad(energy.sum(), coordinates)[0].squeeze() ret_data["return_result"] = ( - np.asarray(derivative * ureg.conversion_factor("angstrom", "bohr")).ravel().tolist() + np.asarray(derivative.cpu() * ureg.conversion_factor("angstrom", "bohr")).ravel().tolist() ) elif input_data.driver == "hessian": hessian = torchani.utils.hessian(coordinates, energies=energy) - ret_data["return_result"] = np.asarray(hessian) + ret_data["return_result"] = np.asarray(hessian.cpu()) else: raise InputError( f"TorchANI can only compute energy, gradient, and hessian driver methods. Found {input_data.driver}." @@ -172,7 +173,7 @@ def compute(self, input_data: "AtomicInput", config: "TaskConfig") -> "AtomicRes ret_data["extras"] = input_data.extras.copy() ret_data["extras"].update( { - "ensemble_energies": energy_array.detach().numpy(), + "ensemble_energies": energy_array.cpu().detach().numpy(), "ensemble_energy_avg": energy.item(), "ensemble_energy_std": ensemble_std.item(), "ensemble_per_root_atom_disagreement": ensemble_scaled_std.item(),