Skip to content

Commit

Permalink
device and data type handling
Browse files Browse the repository at this point in the history
  • Loading branch information
RylieWeaver committed Dec 13, 2024
1 parent 112df12 commit ffa7f94
Show file tree
Hide file tree
Showing 2 changed files with 29 additions and 22 deletions.
30 changes: 12 additions & 18 deletions examples/LennardJones/LJ_data.py
Original file line number Diff line number Diff line change
Expand Up @@ -177,11 +177,14 @@ def transform_input_to_data_object_base(self, filepath):
.unsqueeze(0)
.to(torch.float32),
energy=torch.tensor(total_energy).unsqueeze(0).to(torch.float32),
pbc=[
True,
True,
True,
], # LJ example always has periodic boundary conditions
pbc=torch.tensor(
[
True,
True,
True,
],
dtype=torch.bool,
), # LJ example always has periodic boundary conditions
)

# Create pbc edges and lengths
Expand Down Expand Up @@ -337,30 +340,21 @@ def create_configuration(
data.cell = torch.diag(
torch.tensor([supercell_size_x, supercell_size_y, supercell_size_z])
)
data.pbc = [True, True, True]
data.pbc = torch.tensor([True, True, True], dtype=torch.bool)
data.x = torch.cat([atom_types, positions], dim=1)

create_graph_connectivity_pbc = get_radius_graph_pbc(
radius_cutoff, max_num_neighbors
)
data = create_graph_connectivity_pbc(data)

atomic_descriptors = torch.cat(
(
atom_types,
positions,
),
1,
)

data.x = atomic_descriptors

data = atomic_structure_handler.compute(data)

total_energy = torch.sum(data.x[:, 4])
energy_per_atom = total_energy / number_nodes

total_energy_str = numpy.array2string(total_energy.detach().numpy())
energy_per_atom_str = numpy.array2string(energy_per_atom.detach().numpy())
total_energy_str = numpy.array2string(total_energy.detach().cpu().numpy())
energy_per_atom_str = numpy.array2string(energy_per_atom.detach().cpu().numpy())
filetxt = total_energy_str + "\n" + energy_per_atom_str

for index in range(0, 3):
Expand Down
21 changes: 17 additions & 4 deletions hydragnn/preprocess/graph_samples_checks_and_updates.py
Original file line number Diff line number Diff line change
Expand Up @@ -140,8 +140,6 @@ class RadiusGraphPBC(RadiusGraph):
"""

def __call__(self, data):
data.edge_attr = None
data.edge_shifts = None
assert (
"batch" not in data
), "Periodic boundary conditions not currently supported on batches."
Expand All @@ -151,6 +149,22 @@ def __call__(self, data):
assert hasattr(
data, "pbc"
), "The data must contain data.pbc as a bool (True) or list of bools for the dimensions ([True, False, True]) to apply periodic boundary conditions."
# Ensure data consistency
if not isinstance(data.pos, torch.Tensor):
data.pos = torch.tensor(data.pos, dtype=torch.float)
device = (
data.pos.device
) # Have canonical device obtained from `data.pos` in-line with PyG RadiusGraph
if not isinstance(data.cell, torch.Tensor):
data.cell = torch.tensor(data.cell, dtype=torch.float, device=device)
if not isinstance(data.pbc, torch.Tensor):
data.pbc = torch.tensor(data.pbc, dtype=torch.bool, device=device)
# Ensure device consistency
if data.cell.device != device:
data.cell = data.cell.to(device)
if data.pbc.device != device:
data.pbc = data.pbc.to(device)

ase_atom_object = ase.Atoms(
positions=data.pos,
cell=data.cell,
Expand Down Expand Up @@ -190,7 +204,6 @@ def __call__(self, data):
)

# Assign to data
device = get_device(data)
data.edge_index = torch.stack(
[
torch.tensor(edge_src, dtype=torch.long, device=device),
Expand All @@ -206,7 +219,7 @@ def __call__(self, data):
# ASE returns the integer number of cell shifts. Multiply by the cell size to get the shift vector.
data.edge_shifts = torch.matmul(
torch.tensor(edge_cell_shifts, dtype=torch.float, device=device),
data.cell.float(),
data.cell,
) # Shape: [n_edges, 3]

return data
Expand Down

0 comments on commit ffa7f94

Please sign in to comment.