Commit

Merge branch 'dtensor' of https://github.com/apax-hub/apax into dtensor
M-R-Schaefer committed Nov 17, 2024
2 parents 073aba8 + 945c745 commit 9fdd618
Showing 10 changed files with 56 additions and 68 deletions.
11 changes: 5 additions & 6 deletions apax/config/model_config.py
@@ -108,32 +108,31 @@ class ExponentialRepulsion(Correction, extra="forbid"):
     name: Literal["exponential"]
     r_max: NonNegativeFloat = 2.0
 
+
 class LatentEwald(Correction, extra="forbid"):
     name: Literal["latent_ewald"]
     kgrid: list
-    sigma: float=1.0
+    sigma: float = 1.0
 
 
 EmpiricalCorrection = Union[ZBLRepulsion, ExponentialRepulsion, LatentEwald]
 
 
-
 class PropertyHead(BaseModel, extra="forbid"):
-    """
-    """
+    """ """
+
     name: str
     aggregation: str = "none"
     mode: str = "l0"
 
     nn: List[PositiveInt] = [128, 128]
-    n_shallow_members : int = 0
+    n_shallow_members: int = 0
     w_init: Literal["normal", "lecun"] = "lecun"
     b_init: Literal["normal", "zeros"] = "zeros"
     use_ntk: bool = False
     dtype: Literal["fp32", "fp64"] = "fp32"
 
 
-
 class BaseModelConfig(BaseModel, extra="forbid"):
     """
     Configuration for the model.
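As a usage sketch (hypothetical values; the field names and import path are taken from the diff above), the reformatted models can be instantiated like this:

```python
from apax.config.model_config import LatentEwald, PropertyHead

# Hypothetical instantiation; the values are illustrative only.
correction = LatentEwald(name="latent_ewald", kgrid=[2, 2, 2], sigma=1.0)
head = PropertyHead(name="charge", aggregation="none", mode="l0", nn=[128, 128])
```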
31 changes: 15 additions & 16 deletions apax/data/input_pipeline.py
@@ -87,7 +87,7 @@ def __init__(
         n_jit_steps=1,
         pos_unit: str = "Ang",
         energy_unit: str = "eV",
-        additional_properties: list[tuple]= [],
+        additional_properties: list[tuple] = [],
         pre_shuffle=False,
         shuffle_buffer_size=1000,
         ignore_labels=False,
@@ -111,15 +111,16 @@ def __init__(
         self.max_atoms = max_atoms
         self.max_nbrs = max_nbrs
         if atoms_list[0].calc and not ignore_labels:
-            self.labels = atoms_to_labels(atoms_list, pos_unit, energy_unit, additional_properties)
+            self.labels = atoms_to_labels(
+                atoms_list, pos_unit, energy_unit, additional_properties
+            )
         else:
             self.labels = None
 
         self.count = 0
         self.buffer = deque()
         self.file = Path(cache_path) / str(uuid.uuid4())
 
-
         self.enqueue(min(self.buffer_size, self.n_data))
 
     def steps_per_epoch(self) -> int:
@@ -165,10 +166,8 @@ def prepare_data(self, i):
             for prop in self.additional_properties:
                 name, shape = prop
                 if shape[0] == "natoms":
-                    pad_shape = [(0, zeros_to_add)] + [(0,0)] * (len(shape)-1)
-                    labels[name] = np.pad(
-                        labels[name], pad_shape, "constant"
-                    )
+                    pad_shape = [(0, zeros_to_add)] + [(0, 0)] * (len(shape) - 1)
+                    labels[name] = np.pad(labels[name], pad_shape, "constant")
 
         inputs = {k: tf.constant(v) for k, v in inputs.items()}
         labels = {k: tf.constant(v) for k, v in labels.items()}
@@ -218,9 +217,7 @@ def make_signature(self) -> tf.TensorSpec:
             if shape[0] == "natoms":
                 shape[0] = self.max_atoms
 
-            sig = tf.TensorSpec(
-                tuple(shape), dtype=tf.float64, name=name
-            )
+            sig = tf.TensorSpec(tuple(shape), dtype=tf.float64, name=name)
             label_signature[name] = sig
         signature = (input_signature, label_signature)
         return signature
@@ -382,7 +379,9 @@ def next_power_of_two(x):
 
 
 class BatchProcessor:
-    def __init__(self, cutoff, forces=True, stress=False, additional_properties = []) -> None:
+    def __init__(
+        self, cutoff, forces=True, stress=False, additional_properties=[]
+    ) -> None:
         self.cutoff = cutoff
         self.forces = forces
         self.stress = stress
@@ -451,10 +450,8 @@ def __call__(self, samples: list[dict]):
             for prop in self.additional_properties:
                 name, shape = prop
                 if shape[0] == "natoms":
-                    pad_shape = [(0, zeros_to_add)] + [(0,0)] * (len(shape)-1)
-                    labels[name] = np.pad(
-                        labels[name], pad_shape, "constant"
-                    )
+                    pad_shape = [(0, zeros_to_add)] + [(0, 0)] * (len(shape) - 1)
+                    labels[name] = np.pad(labels[name], pad_shape, "constant")
 
         inputs = {k: np.array(v) for k, v in inputs.items()}
         labels = {k: np.array(v) for k, v in labels.items()}
@@ -511,7 +508,9 @@ def __init__(
         self.sample_atoms = atoms_list[0]
         self.inputs = atoms_to_inputs(atoms_list, pos_unit)
 
-        self.labels = atoms_to_labels(atoms_list, pos_unit, energy_unit, additional_properties)
+        self.labels = atoms_to_labels(
+            atoms_list, pos_unit, energy_unit, additional_properties
+        )
         label_keys = self.labels.keys()
 
         self.data = list(
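The two padding hunks apply the same pattern: any label whose leading axis is "natoms" is zero-padded up to the batch maximum while the trailing axes keep their size. A minimal standalone sketch (hypothetical shapes):

```python
import numpy as np

# Pad a per-atom label of shape (n_atoms, 3) up to (max_atoms, 3) with zero
# rows; only the leading "natoms" axis grows, trailing axes stay fixed.
label = np.ones((3, 3))
zeros_to_add = 5 - label.shape[0]  # assume max_atoms = 5
pad_shape = [(0, zeros_to_add)] + [(0, 0)] * (label.ndim - 1)
padded = np.pad(label, pad_shape, "constant")
assert padded.shape == (5, 3)
```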
34 changes: 15 additions & 19 deletions apax/layers/empirical.py
@@ -127,14 +127,14 @@ def __call__(self, R, dr_vec, Z, idx, box, properties):
 
 
 class DirectCoulomb(EmpiricalEnergyTerm):
-
     # apply_mask: bool = True
     # TODO scale
 
-
     def __call__(self, R, dr_vec, Z, idx, box, properties):
         if "charge" not in properties:
-            raise KeyError("property 'charge' not found. Make sure to predict it in the model section")
+            raise KeyError(
+                "property 'charge' not found. Make sure to predict it in the model section"
+            )
 
         q = properties["charge"]
         idx_i, idx_j = idx[0], idx[1]
@@ -146,51 +146,47 @@ def __call__(self, R, dr_vec, Z, idx, box, properties):
 
     # TODO mask, cutoff
 
-
     return Ec
 
 
-
 class LatentEwald(EmpiricalEnergyTerm):
-    """
-    """
-    kgrid: list[int] = field(default_factory=lambda: [2,2,2])
+    """ """
+
+    kgrid: list[int] = field(default_factory=lambda: [2, 2, 2])
     sigma: float = 1.0
     apply_mask: bool = True
 
-
     def __call__(self, R, dr_vec, Z, idx, box, properties):
         # Z shape n_atoms
         if "charge" not in properties:
-            raise KeyError("property 'charge' not found. Make sure to predict it in the model section")
+            raise KeyError(
+                "property 'charge' not found. Make sure to predict it in the model section"
+            )
 
         q = properties["charge"]
 
         V = jnp.linalg.det(box)
         Lx, Ly, Lz = jnp.linalg.norm(box, axis=1)
 
-        k_range_x = 2 * np.pi * jnp.arange(1,self.kgrid[0]) / Lx
-        k_range_y = 2 * np.pi * jnp.arange(1,self.kgrid[1]) / Ly
-        k_range_z = 2 * np.pi * jnp.arange(1,self.kgrid[2]) / Lz
+        k_range_x = 2 * np.pi * jnp.arange(1, self.kgrid[0]) / Lx
+        k_range_y = 2 * np.pi * jnp.arange(1, self.kgrid[1]) / Ly
+        k_range_z = 2 * np.pi * jnp.arange(1, self.kgrid[2]) / Lz
 
         kx, ky, kz = jnp.meshgrid(k_range_x, k_range_y, k_range_z)
         k = jnp.reshape(jnp.stack((kx, ky, kz), axis=-1), (-1, 3))
 
-
         k2 = jnp.sum(k**2, axis=-1)
 
-        sf_k =q * jnp.exp(1j * jnp.einsum('id,jd->ij', R, k))
+        sf_k = q * jnp.exp(1j * jnp.einsum("id,jd->ij", R, k))
         sf = jnp.sum(sf_k, axis=0)
-        S2 = jnp.abs(sf)**2
-
+        S2 = jnp.abs(sf) ** 2
 
         # TODO mask by atom
-        E_lr = - jnp.sum(jnp.exp(-k2 * (self.sigma**2)/2) / k2 * S2) / V
+        E_lr = -jnp.sum(jnp.exp(-k2 * (self.sigma**2) / 2) / k2 * S2) / V
 
         return E_lr
 
 
-
 all_corrections = {
     "zbl": ZBLRepulsion,
     "exponential": ExponentialRepulsion,
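For reference, the reformatted LatentEwald term evaluates (transcribing the code as written, keeping its sign and prefactor conventions):

$$E_\text{lr} = -\frac{1}{V} \sum_{\mathbf{k}} \frac{e^{-\sigma^2 k^2 / 2}}{k^2} \left| S(\mathbf{k}) \right|^2, \qquad S(\mathbf{k}) = \sum_i q_i \, e^{i \mathbf{k} \cdot \mathbf{r}_i},$$

where the wavevectors run over the grid $\mathbf{k} = 2\pi \, (n_x/L_x, \; n_y/L_y, \; n_z/L_z)$ with $n_\alpha = 1, \ldots, \text{kgrid}_\alpha - 1$, and the $q_i$ are the latent charges predicted by the model.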
8 changes: 2 additions & 6 deletions apax/layers/properties.py
@@ -1,4 +1,3 @@
-
 import flax.linen as nn
 import jax
 import jax.numpy as jnp
@@ -54,9 +53,7 @@ class PropertyHead(nn.Module):
     mode: str = "l0"
     apply_mask: bool = True
 
-
     def setup(self):
-
         n_species = 119
         scale_init = nn.initializers.constant(1.0)
         self.scale = self.param(
@@ -87,15 +84,14 @@ def __call__(self, g, R, dr_vec, Z, idx, box):
             p_i = p_i
         elif self.mode == "l1":
             Rc = R - jnp.mean(R, axis=0, keepdims=True)
-            r_hat = Rc / jnp.linalg.norm(Rc, axis=1)[:,None]
+            r_hat = Rc / jnp.linalg.norm(Rc, axis=1)[:, None]
             p_i = p_i * R
         elif self.mode == "symmetric_traceless_l2":
             Rc = R - jnp.mean(R, axis=0, keepdims=True)
-            r_hat = Rc / jnp.linalg.norm(Rc, axis=1)[:,None]
+            r_hat = Rc / jnp.linalg.norm(Rc, axis=1)[:, None]
             r_rt = jnp.einsum("ni, nj -> nij", r_hat, r_hat)
             I = jnp.eye(3)
             symmetrized = 3*r_rt - I
-            print(symmetrized.shape)
             p_i = p_i[...,None] * symmetrized
         else:
             raise KeyError("unknown symmetry option")
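A quick sanity check on the symmetric_traceless_l2 branch above: since r_hat is a unit vector, trace(r_hat r_hat^T) = |r_hat|^2 = 1, so 3 * r_rt - I is symmetric with zero trace. A standalone sketch (hypothetical input):

```python
import jax.numpy as jnp

# For a unit vector, 3 * (r_hat r_hat^T) - I has trace 3 * 1 - 3 = 0.
r_hat = jnp.array([[0.0, 0.0, 1.0]])
r_rt = jnp.einsum("ni, nj -> nij", r_hat, r_hat)
symmetrized = 3 * r_rt - jnp.eye(3)
assert jnp.allclose(jnp.trace(symmetrized, axis1=1, axis2=2), 0.0)
```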
6 changes: 2 additions & 4 deletions apax/nn/builder.py
@@ -82,7 +82,6 @@ def build_descriptor(
         raise NotImplementedError("use a subclass to facilitate this")
 
     def build_readout(self, head_config, is_feature_fn=False):
-
         has_ensemble = "ensemble" in head_config.keys() and head_config["ensemble"]
         if has_ensemble and head_config["ensemble"]["kind"] == "shallow":
             n_shallow_ensemble = head_config["ensemble"]["n_members"]
@@ -119,7 +118,7 @@ def build_scale_shift(self, scale, shift):
         )
         return scale_shift
 
-    def build_property_heads(self, apply_mask: bool=True):
+    def build_property_heads(self, apply_mask: bool = True):
         property_heads = []
         for head in self.config["property_heads"]:
             readout = self.build_readout(head)
@@ -133,7 +132,7 @@ def build_property_heads(self, apply_mask: bool=True):
             property_heads.append(phead)
         return property_heads
 
-    def build_corrections(self, apply_mask: bool=True):
+    def build_corrections(self, apply_mask: bool = True):
         corrections = []
         for correction in self.config["empirical_corrections"]:
             correction = correction.copy()
@@ -147,7 +146,6 @@ def build_corrections(self, apply_mask: bool=True):
 
         return corrections
 
-
     def build_energy_model(
         self,
         scale=1.0,
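For context on the has_ensemble branch in build_readout, a hypothetical head_config selecting a shallow ensemble would be read like this (only the key names are taken from the code above; the dict itself is illustrative):

```python
# Hypothetical config dict mirroring the lookups in build_readout.
head_config = {"ensemble": {"kind": "shallow", "n_members": 4}}

has_ensemble = "ensemble" in head_config.keys() and head_config["ensemble"]
if has_ensemble and head_config["ensemble"]["kind"] == "shallow":
    n_shallow_ensemble = head_config["ensemble"]["n_members"]  # -> 4
```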
11 changes: 8 additions & 3 deletions apax/nn/models.py
@@ -113,7 +113,7 @@ def __call__(
 
         # check for shallow ensemble
         is_shallow_ensemble = E_i.shape[1] > 1
-        if is_shallow_ensemble: # is this necessary or is using sum with axis=0 enough?
+        if is_shallow_ensemble:  # is this necessary or is using sum with axis=0 enough?
             total_energies_ensemble = fp64_sum(E_i, axis=0)
             # shape Nensemble
             energy = total_energies_ensemble
@@ -159,7 +159,12 @@ def __call__(
 
         if self.calc_stress:
             stress = stress_times_vol(
-                make_energy_only_model(self.energy_model), R, box, Z=Z, neighbor=neighbor, offsets=offsets
+                make_energy_only_model(self.energy_model),
+                R,
+                box,
+                Z=Z,
+                neighbor=neighbor,
+                offsets=offsets,
             )
             prediction["stress"] = stress
 
@@ -187,7 +192,7 @@ def energy_chunk_fn(R, Z, neighbor, box, offsets):
         Ei = energy_model(R, Z, neighbor, box, offsets)[start:end]
         return Ei
 
-    grad_i_fn = jax.jacrev(energy_chunk_fn) # TODO
+    grad_i_fn = jax.jacrev(energy_chunk_fn)  # TODO
     return grad_i_fn
 
 
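The shallow-ensemble branch sums per-atom, per-member energies over the atom axis only, leaving one total energy per ensemble member. A standalone sketch, with plain jnp.sum standing in for the project's fp64_sum and hypothetical shapes:

```python
import jax.numpy as jnp

# E_i: per-atom energies of shape (n_atoms, n_members).
E_i = jnp.ones((10, 4))
is_shallow_ensemble = E_i.shape[1] > 1
if is_shallow_ensemble:
    total_energies_ensemble = jnp.sum(E_i, axis=0)  # shape (n_members,)
```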
12 changes: 5 additions & 7 deletions apax/train/loss.py
@@ -104,7 +104,7 @@ def force_angle_loss(
     """
     label, prediction = label[name], prediction[name]
     dotp = normed_dotp(label, prediction)
-    return (1.0 - dotp)
+    return 1.0 - dotp
 
 
 def force_angle_div_force_label(
@@ -184,14 +184,12 @@ def __post_init__(self):
     def __call__(self, inputs: dict, prediction: dict, label: dict) -> float:
         # TODO we may want to insert an additional `mask` argument for this method
 
-        divisor = inputs["n_atoms"]**self.atoms_exponent
-        batch_losses = self.loss_fn(
-            label, prediction, self.name, self.parameters
-        )
+        divisor = inputs["n_atoms"] ** self.atoms_exponent
+        batch_losses = self.loss_fn(label, prediction, self.name, self.parameters)
 
-        axes_to_add = len(batch_losses.shape) -1
+        axes_to_add = len(batch_losses.shape) - 1
         for _ in range(axes_to_add):
-            divisor = divisor[...,None]
+            divisor = divisor[..., None]
 
         arg = batch_losses / divisor
         loss = self.weight * jnp.sum(jnp.mean(arg, axis=0))
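The divisor logic above broadcasts the per-structure normalization n_atoms ** atoms_exponent against losses of arbitrary trailing rank by appending singleton axes. A standalone sketch (hypothetical shapes):

```python
import jax.numpy as jnp

# Per-atom force losses for a batch of 4 structures: (batch, atoms, 3).
batch_losses = jnp.ones((4, 8, 3))
divisor = jnp.array([8.0, 8.0, 6.0, 7.0]) ** 1.0  # n_atoms ** atoms_exponent

axes_to_add = len(batch_losses.shape) - 1
for _ in range(axes_to_add):
    divisor = divisor[..., None]  # final shape (4, 1, 1)

loss = jnp.sum(jnp.mean(batch_losses / divisor, axis=0))
```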
5 changes: 1 addition & 4 deletions apax/train/run.py
@@ -94,7 +94,7 @@ def compute_property_shapes(config: Config):
         if pconf["aggregation"] == "none":
             shape.append("natoms")
 
-        feature_shapes = {"l0": [1], "l1": [3], "symmetric_traceless_l2": [3,3]}
+        feature_shapes = {"l0": [1], "l1": [3], "symmetric_traceless_l2": [3, 3]}
 
         shape.extend(feature_shapes[pconf["mode"]])
 
@@ -103,9 +103,6 @@ def compute_property_shapes(config: Config):
     return additional_properties
 
 
-
-
-
 def initialize_datasets(config: Config):
     """
     Initialize training and validation datasets based on the provided configuration.
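The shape assembly above prepends "natoms" for unaggregated properties and then appends the feature shape of the chosen mode. For a hypothetical per-atom l1 property:

```python
feature_shapes = {"l0": [1], "l1": [3], "symmetric_traceless_l2": [3, 3]}

shape = ["natoms"]  # aggregation == "none"
shape.extend(feature_shapes["l1"])
# shape -> ["natoms", 3], i.e. one 3-vector per atom
```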
4 changes: 3 additions & 1 deletion apax/train/trainer.py
@@ -87,7 +87,9 @@ def fit(
 
     state, start_epoch = load_state(state, latest_dir)
     if start_epoch >= n_epochs:
-        print(f"Training has already completed ({start_epoch} >= {n_epochs}). Nothing to be done")
+        print(
+            f"Training has already completed ({start_epoch} >= {n_epochs}). Nothing to be done"
+        )
         return
 
     devices = len(jax.devices())
2 changes: 0 additions & 2 deletions apax/utils/transform.py
@@ -1,5 +1,3 @@
-
-
 def make_energy_only_model(energy_properties_model):
     energy_model = lambda *args, **kwargs: energy_properties_model(*args, **kwargs)[0]
     return energy_model
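
A usage sketch for the helper, assuming a toy stand-in model that returns an (energy, properties) tuple:

```python
from apax.utils.transform import make_energy_only_model

# Toy stand-in for a model returning (energy, properties).
def toy_model(R):
    return sum(R), {"charge": R}

energy_model = make_energy_only_model(toy_model)
print(energy_model([1.0, 2.0, 3.0]))  # 6.0
```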
