Skip to content

Commit

Permalink
cpu gpu device error
Browse files Browse the repository at this point in the history
  • Loading branch information
yueyericardo authored Jan 27, 2022
1 parent 4f05466 commit 1d1d3f2
Showing 1 changed file with 6 additions and 4 deletions.
10 changes: 6 additions & 4 deletions openmmml/models/anipotential.py
Original file line number Diff line number Diff line change
Expand Up @@ -95,18 +95,20 @@ def __init__(self, model, species, atoms, periodic):
self.indices = None
else:
self.indices = torch.tensor(sorted(atoms), dtype=torch.int64)
if periodic:
self.pbc = torch.tensor([True, True, True], dtype=torch.bool)
else:
self.pbc = None
self.pbc = torch.tensor([True, True, True], dtype=torch.bool)
# comment the following lines if need to use CPU
if topology.getPeriodicBoxVectors() is None:
self.model.aev_computer.use_cuda_extension = True

def forward(self, positions, boxvectors: Optional[torch.Tensor] = None):
positions = positions.to(torch.float32)
self.species = self.species.to(positions.device)
if self.indices is not None:
positions = positions[self.indices]
if boxvectors is None:
_, energy = self.model((self.species, 10.0*positions.unsqueeze(0)))
else:
self.pbc = self.pbc.to(positions.device)
boxvectors = boxvectors.to(torch.float32)
_, energy = self.model((self.species, 10.0*positions.unsqueeze(0)), cell=10.0*boxvectors, pbc=self.pbc)
return self.energyScale*energy
Expand Down

0 comments on commit 1d1d3f2

Please sign in to comment.