Batched NNs for TorchANI #13
Conversation
Benchmarks
Molecule: 46 atoms
Profiling of ANI-2x + #5 + BatchedNN (forward pass):
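A profile like this can be collected with PyTorch's autograd profiler. This is only a minimal sketch; it assumes nnp, species and positions are already set up on a CUDA device as in the benchmark:

import torch

# Assumes nnp, species and positions are already defined on a CUDA device.
with torch.autograd.profiler.profile(use_cuda=True) as prof:
    energies = nnp((species, positions)).energies
torch.cuda.synchronize()
print(prof.key_averages().table(sort_by='cuda_time_total', row_limit=20))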
Very nice! And it looks from the profile like there's still room for more improvement.
I have managed to speed up the backward pass further:
For some inexplicable reason, the backward pass performance drops when converting to TorchScript: nnp_ts = torch.jit.script(nnp)
With tracing, it is slow too: nnp_tr = torch.jit.trace(nnp, ((species, positions),))
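For reference, the conversion and a rough backward-pass timing can be reproduced along these lines (only a sketch; it assumes nnp, species and positions are set up as in the benchmark, with positions created with requires_grad=True):

import time
import torch

# TorchScript conversion, both variants discussed above
nnp_ts = torch.jit.script(nnp)
nnp_tr = torch.jit.trace(nnp, ((species, positions),))

# Rough timing of forward + backward through the scripted model
torch.cuda.synchronize()
start = time.perf_counter()
_, energies = nnp_ts((species, positions))
energies.sum().backward()
torch.cuda.synchronize()
print(f'forward + backward: {time.perf_counter() - start:.4f} s')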
I have solved the performance problem regarding TorchScript.
Hi Raimondas, I found that GPU memory usage increases linearly with molecule size when initializing BatchedNN.
Script to run it:

import os
import gc
import torch
import torchani
import pynvml
import numpy as np
from ase.io import read
from NNPOps.BatchedNN import TorchANIBatchedNN

def checkgpu(device=None):
    i = device if device else torch.cuda.current_device()
    real_i = int(os.environ['CUDA_VISIBLE_DEVICES'][0]) if 'CUDA_VISIBLE_DEVICES' in os.environ else i
    pynvml.nvmlInit()
    h = pynvml.nvmlDeviceGetHandleByIndex(real_i)
    info = pynvml.nvmlDeviceGetMemoryInfo(h)
    name = pynvml.nvmlDeviceGetName(h)
    print(' GPU Memory Used (nvidia-smi): {:7.1f}MB / {:.1f}MB ({})'.format(info.used / 1024 / 1024, info.total / 1024 / 1024, name.decode()))

file = '1hz5.pdb'
mol = read(file)
device = torch.device('cuda')
species_ = torch.tensor([mol.get_atomic_numbers()], device=device)
positions = torch.tensor([mol.get_positions()], dtype=torch.float32, requires_grad=False, device=device)
print(f'File: {file}, Molecule size: {species_.shape[-1]}\n')

for N in np.arange(100, 1000, 100):
    torch.cuda.empty_cache()
    gc.collect()
    species = species_[:, :N]
    print(f"species size: {species.shape[1]}")
    nnp = torchani.models.ANI2x(periodic_table_index=True, model_index=None).to(device)
    nnp.neural_networks = TorchANIBatchedNN(nnp.species_converter, nnp.neural_networks, species).to(device)
    checkgpu()
    print('-' * 70 + '\n')

PDB file at: https://raw.githubusercontent.com/yueyericardo/aev_benchmark/master/molecules/1hz5.pdb
That shows a large decrease in memory between 700 and 800? This may not really be measuring what you think it is. Try modifying the script so
I know the memory measurement might not be accurate.
The NN parameters are replicated for each atom to optimize the memory layout for the batched multiplication. So the memory increases linearly with the system size. Effectively, this is trading memory for speed. For small molecules (~100 atoms), GPUs have more than enough memory. For large molecules, we may need a different algorithm.
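To illustrate the idea with a schematic sketch (not the actual TorchANIBatchedNN code; shapes and element count are illustrative): one weight copy is gathered per atom from that atom's element-specific network, so a whole layer runs as a single batched matmul:

import torch

num_atoms, num_elements, in_features, out_features = 1000, 7, 384, 160
aev = torch.randn(num_atoms, 1, in_features, device='cuda')
# One weight matrix per element, as in the element-specific ANI networks
weights = torch.randn(num_elements, in_features, out_features, device='cuda')
element_index = torch.randint(0, num_elements, (num_atoms,), device='cuda')

# Gather a weight copy for every atom: memory grows linearly with the number
# of atoms, but the whole layer is a single batched matmul (one kernel launch).
weight_per_atom = weights[element_index]          # (num_atoms, in_features, out_features)
out = torch.bmm(aev, weight_per_atom)             # (num_atoms, 1, out_features)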
It seems like there ought to be a way to avoid having to duplicate them. We just need to figure out how to get PyTorch to sum/broadcast correctly. Only having one copy of the parameters should help caching efficiency, which would improve performance.
For small molecules, as far as I tested, it was the fastest algorithm, even though it wastes GPU memory and bandwidth, since it needs just one kernel launch. For larger molecules, I guess that if atoms are sorted by element, the multiplication can be carried out for each element separately. This wouldn't replicate the parameters, but would require more kernel launches.
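A rough sketch of that alternative (hypothetical, not part of this PR): with atoms grouped by element, each element's weights are used once in an ordinary matmul, at the cost of one kernel launch per element:

import torch

num_atoms, num_elements, in_features, out_features = 1000, 7, 384, 160
aev = torch.randn(num_atoms, in_features, device='cuda')
weights = torch.randn(num_elements, in_features, out_features, device='cuda')
element_index = torch.randint(0, num_elements, (num_atoms,), device='cuda')

out = torch.empty(num_atoms, out_features, device='cuda')
for e in range(num_elements):
    mask = element_index == e
    # No weight replication, but one kernel launch (plus indexing) per element
    out[mask] = aev[mask] @ weights[e]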
@raimis: Can this be updated and merged?
@peastman this is ready for a review!
Looks good!
A proof of concept to speed up the inference of ANI-like models by using batched matrix operations.
TorchANIBatchedNN
See #11 for details
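Typical usage, following the constructor call from the benchmark script above (species and positions are prepared as in that script; species is a (1, num_atoms) tensor of atomic numbers on the same device):

import torch
import torchani
from NNPOps.BatchedNN import TorchANIBatchedNN

device = torch.device('cuda')
nnp = torchani.models.ANI2x(periodic_table_index=True, model_index=None).to(device)

# Replace the original neural networks with the batched implementation
nnp.neural_networks = TorchANIBatchedNN(nnp.species_converter, nnp.neural_networks, species).to(device)
energies = nnp((species, positions)).energies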