Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Batched NNs for TorchANI #13

Merged
merged 24 commits into from
Oct 1, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ enable_testing()

add_library(${LIBRARY} SHARED src/ani/CpuANISymmetryFunctions.cpp
src/ani/CudaANISymmetryFunctions.cu
src/pytorch/BatchedNN.cpp
src/pytorch/SymmetryFunctions.cpp
src/schnet/CpuCFConv.cpp
src/schnet/CudaCFConv.cu)
Expand All @@ -29,8 +30,10 @@ foreach(TEST_PATH ${TEST_PATHS})
endforeach()

add_test(TestSymmetryFunctions pytest ${CMAKE_SOURCE_DIR}/src/pytorch/TestSymmetryFunctions.py)
add_test(TestBatchedNN pytest ${CMAKE_SOURCE_DIR}/src/pytorch/TestBatchedNN.py)

install(TARGETS ${LIBRARY} DESTINATION ${Python_SITEARCH}/${NAME})
install(FILES src/pytorch/__init__.py
src/pytorch/BatchedNN.py
src/pytorch/SymmetryFunctions.py
DESTINATION ${Python_SITEARCH}/${NAME})
36 changes: 36 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -60,4 +60,40 @@ $ make install
- Run the tests
```bash
$ ctest
```

## Usage

Accelerated [*TorchANI*](https://aiqm.github.io/torchani/) operations:
- [`torchani.AEVComputer`](https://aiqm.github.io/torchani/api.html?highlight=speciesaev#torchani.AEVComputer)
- [`torchani.neurochem.NeuralNetwork`](https://aiqm.github.io/torchani/api.html#module-torchani.neurochem)

### Example

```python
import mdtraj
import torch
import torchani

from NNPOps.SymmetryFunctions import TorchANISymmetryFunctions
from NNPOps.BatchedNN import TorchANIBatchedNN

device = torch.device('cuda')

# Load a molecule
molecule = mdtraj.load('molecule.mol2')
species = torch.tensor([[atom.element.atomic_number for atom in molecule.top.atoms]], device=device)
positions = torch.tensor(molecule.xyz * 10, dtype=torch.float32, requires_grad=True, device=device)

# Construct ANI-2x and replace its operations with the optimized ones
nnp = torchani.models.ANI2x(periodic_table_index=True).to(device)
nnp.aev_computer = TorchANISymmetryFunctions(nnp.aev_computer).to(device)
nnp.neural_networks = TorchANIBatchedNN(nnp.species_converter, nnp.neural_networks, species).to(device)

# Compute energy and forces
energy = nnp((species, positions)).energies
energy.backward()
forces = -positions.grad.clone()

print(energy, forces)
```
50 changes: 50 additions & 0 deletions src/pytorch/BatchedNN.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,50 @@
/**
* Copyright (c) 2020 Acellera
* Authors: Raimondas Galvelis
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
* in the Software without restriction, including without limitation the rights
* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
* copies of the Software, and to permit persons to whom the Software is
* furnished to do so, subject to the following conditions:
*
* The above copyright notice and this permission notice shall be included in all
* copies or substantial portions of the Software.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
* SOFTWARE.
*/

#include <torch/script.h>

using Context = torch::autograd::AutogradContext;
using Tensor = torch::Tensor;
using tensor_list = torch::autograd::tensor_list;

class BatchedLinearFunction : public torch::autograd::Function<BatchedLinearFunction> {
public:
static Tensor forward(Context* ctx, const Tensor& vectors, const Tensor& weights, const Tensor& biases) {
ctx->save_for_backward({weights});
return torch::matmul(weights, vectors) + biases;
};
static tensor_list backward(Context *ctx, const tensor_list& grads) {
const Tensor grad_in = grads[0].squeeze(-1).unsqueeze(-2);
const Tensor weights = ctx->get_saved_variables()[0];
const Tensor grad_out = torch::matmul(grad_in, weights).squeeze(-2).unsqueeze(-1);
return {grad_out, torch::Tensor(), torch::Tensor()};
};
};

static Tensor BatchedLinear(const Tensor& vector, const Tensor& weights, const Tensor& biases) {
return BatchedLinearFunction::apply(vector, weights, biases);
}

TORCH_LIBRARY(NNPOpsBatchedNN, m) {
m.def("BatchedLinear", BatchedLinear);
}
104 changes: 104 additions & 0 deletions src/pytorch/BatchedNN.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,104 @@
#
# Copyright (c) 2020 Acellera
# Authors: Raimondas Galvelis
#
# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
# in the Software without restriction, including without limitation the rights
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
# copies of the Software, and to permit persons to whom the Software is
# furnished to do so, subject to the following conditions:
#
# The above copyright notice and this permission notice shall be included in all
# copies or substantial portions of the Software.
#
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
# SOFTWARE.
#

import os
import torch
from torch import nn
from torch import Tensor
from torch.nn import functional as F
import torchani
from torchani.nn import ANIModel, Ensemble, SpeciesConverter, SpeciesEnergies
from typing import List, Optional, Tuple, Union

torch.ops.load_library(os.path.join(os.path.dirname(__file__), 'libNNPOpsPyTorch.so'))
batchedLinear = torch.ops.NNPOpsBatchedNN.BatchedLinear


class TorchANIBatchedNN(torch.nn.Module):

def __init__(self, converter: SpeciesConverter, ensemble: Union[ANIModel, Ensemble], atomicNumbers: Tensor):

super().__init__()

# Convert atomic numbers to a list of species
species_list = converter((atomicNumbers, torch.empty(0))).species[0].tolist()

# Handle the case when the ensemble is just one model
ensemble = [ensemble] if type(ensemble) == ANIModel else ensemble

# Convert models to the list of linear layers
models = [list(model.values()) for model in ensemble]

# Extract the weihts and biases of the linear layers
for ilayer in [0, 2, 4, 6]:
layers = [[model[species][ilayer] for species in species_list] for model in models]
weights, biases = self.batchLinearLayers(layers)
self.register_parameter(f'layer{ilayer}_weights', weights)
self.register_parameter(f'layer{ilayer}_biases', biases)

# Disable autograd for the parameters
for parameter in self.parameters():
parameter.requires_grad = False

@staticmethod
def batchLinearLayers(layers: List[List[nn.Linear]]) -> Tuple[nn.Parameter, nn.Parameter]:

num_models = len(layers)
num_atoms = len(layers[0])

# Note: different elements have different size linear layers, so we just find maximum sizes
# and pad with zeros.
max_out = max(layer.out_features for layer in sum(layers, []))
max_in = max(layer.in_features for layer in sum(layers, []))

# Copy weights and biases
weights = torch.zeros((1, num_atoms, num_models, max_out, max_in), dtype=torch.float32)
biases = torch.zeros((1, num_atoms, num_models, max_out, 1), dtype=torch.float32)
for imodel, sublayers in enumerate(layers):
for iatom, layer in enumerate(sublayers):
num_out, num_in = layer.weight.shape
weights[0, iatom, imodel, :num_out, :num_in] = layer.weight
biases [0, iatom, imodel, :num_out, 0] = layer.bias

return nn.Parameter(weights), nn.Parameter(biases)

def forward(self, species_aev: Tuple[Tensor, Tensor]) -> SpeciesEnergies:

species, aev = species_aev

# Reshape: [num_mols, num_atoms, num_features] --> [num_mols, num_atoms, 1, num_features, 1]
vectors = aev.unsqueeze(-2).unsqueeze(-1)

vectors = batchedLinear(vectors, self.layer0_weights, self.layer0_biases) # Linear 0
vectors = F.celu(vectors, alpha=0.1) # CELU 1
vectors = batchedLinear(vectors, self.layer2_weights, self.layer2_biases) # Linear 2
vectors = F.celu(vectors, alpha=0.1) # CELU 3
vectors = batchedLinear(vectors, self.layer4_weights, self.layer4_biases) # Linear 4
vectors = F.celu(vectors, alpha=0.1) # CELU 5
vectors = batchedLinear(vectors, self.layer6_weights, self.layer6_biases) # Linear 6

# Sum: [num_mols, num_atoms, num_models, 1, 1] --> [num_mols, num_models]
# Mean: [num_mols, num_models] --> [num_mols]
energies = torch.mean(torch.sum(vectors, (1, 3, 4)), 1)

return SpeciesEnergies(species, energies)
99 changes: 99 additions & 0 deletions src/pytorch/BenchmarkBatchedNN.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,99 @@
#
# Copyright (c) 2020 Acellera
# Authors: Raimondas Galvelis
#
# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
# in the Software without restriction, including without limitation the rights
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
# copies of the Software, and to permit persons to whom the Software is
# furnished to do so, subject to the following conditions:
#
# The above copyright notice and this permission notice shall be included in all
# copies or substantial portions of the Software.
#
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
# SOFTWARE.
#

import mdtraj
import time
import torch
import torchani

# from NNPOps.SymmetryFunctions import TorchANISymmetryFunctions
from NNPOps.BatchedNN import TorchANIBatchedNN

device = torch.device('cuda')

mol = mdtraj.load('molecules/2iuz_ligand.mol2')
species = torch.tensor([[atom.element.atomic_number for atom in mol.top.atoms]], device=device)
positions = torch.tensor(mol.xyz, dtype=torch.float32, requires_grad=True, device=device)

nnp = torchani.models.ANI2x(periodic_table_index=True, model_index=None).to(device)
print(nnp)

energy_ref = nnp((species, positions)).energies
energy_ref.backward()
grad_ref = positions.grad.clone()

N = 3000
start = time.time()
for _ in range(N):
energy_ref = nnp((species, positions)).energies
delta = time.time() - start
print(f'ANI-2x (forward pass)')
print(f' Duration: {delta} s')
print(f' Speed: {delta/N*1000} ms/it')

N = 1000
start = time.time()
for _ in range(N):
energy_ref = nnp((species, positions)).energies
positions.grad.zero_()
energy_ref.backward()
delta = time.time() - start
print(f'ANI-2x (forward & backward pass)')
print(f' Duration: {delta} s')
print(f' Speed: {delta/N*1000} ms/it')

# nnp.aev_computer = TorchANISymmetryFunctions(nnp.aev_computer).to(device)
nnp.neural_networks = TorchANIBatchedNN(nnp.species_converter, nnp.neural_networks, species).to(device)
print(nnp)

# nnp = torch.jit.script(nnp)
# nnp.save('nnp.pt')
# npp = torch.jit.load('nnp.pt').to(device)

energy = nnp((species, positions)).energies
positions.grad.zero_()
energy.backward()
grad = positions.grad.clone()

N = 15000
start = time.time()
for _ in range(N):
energy = nnp((species, positions)).energies
delta = time.time() - start
print(f'ANI-2x with BatchedNN (forward pass)')
print(f' Duration: {delta} s')
print(f' Speed: {delta/N*1000} ms/it')

N = 7500
start = time.time()
for _ in range(N):
energy = nnp((species, positions)).energies
positions.grad.zero_()
energy.backward()
delta = time.time() - start
print(f'ANI-2x with BatchedNN (forward & backward pass)')
print(f' Duration: {delta} s')
print(f' Speed: {delta/N*1000} ms/it')

# print(float(energy_ref), float(energy), float(energy_ref - energy))
# print(float(torch.max(torch.abs((grad - grad_ref)/grad_ref))))
35 changes: 0 additions & 35 deletions src/pytorch/README.md

This file was deleted.

Loading