Skip to content

Commit

Permalink
linting
Browse files Browse the repository at this point in the history
  • Loading branch information
M-R-Schaefer committed Apr 5, 2024
1 parent 2c08110 commit 4aa8eae
Show file tree
Hide file tree
Showing 4 changed files with 12 additions and 13 deletions.
4 changes: 1 addition & 3 deletions apax/data/input_pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,9 +3,8 @@
from collections import deque
from pathlib import Path
from random import shuffle
from typing import Dict, Iterator, List
from typing import Dict, Iterator

from ase import Atoms
import jax
import jax.numpy as jnp
import numpy as np
Expand Down Expand Up @@ -52,7 +51,6 @@ def __init__(
ignore_labels=False,
cache_path=".",
) -> None:

self.n_epochs = n_epochs
self.cutoff = cutoff
self.n_jit_steps = n_jit_steps
Expand Down
6 changes: 1 addition & 5 deletions apax/data/preprocessing.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,11 +34,7 @@ def compute_nl(positions, box, r_max):
else:
positions = positions @ box
idxs_i, idxs_j, offsets = neighbour_list(
"ijS",
positions=positions,
cutoff=r_max,
cell=box,
pbc=[True, True, True]
"ijS", positions=positions, cutoff=r_max, cell=box, pbc=[True, True, True]
)
neighbor_idxs = np.array([idxs_i, idxs_j], dtype=np.int16)
offsets = np.matmul(offsets, box)
Expand Down
10 changes: 6 additions & 4 deletions apax/train/trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -107,10 +107,12 @@ def fit(
epoch_loss["val_loss"] /= val_steps_per_epoch
epoch_loss["val_loss"] = float(epoch_loss["val_loss"])

epoch_metrics.update({
f"val_{key}": float(val)
for key, val in val_batch_metrics.compute().items()
})
epoch_metrics.update(
{
f"val_{key}": float(val)
for key, val in val_batch_metrics.compute().items()
}
)

epoch_metrics.update({**epoch_loss})

Expand Down
5 changes: 4 additions & 1 deletion apax/utils/convert.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,12 +38,15 @@ def prune_dict(data_dict):
pruned = {key: val for key, val in data_dict.items() if len(val) != 0}
return pruned


def is_periodic(box):
pbc_dims = np.any(np.abs(box) > 1e-6)
if np.all(pbc_dims == True) or np.all(pbc_dims == False):
return pbc_dims
else:
msg = f"Only 3D periodic and gas phase system supported at the moment. Found {box}"
msg = (
f"Only 3D periodic and gas phase system supported at the moment. Found {box}"
)
raise ValueError(msg)


Expand Down

0 comments on commit 4aa8eae

Please sign in to comment.