Commit 78e9ff8
Merge branch 'dev' into torch
M-R-Schaefer committed Apr 8, 2024
2 parents 844a393 + 0b59ae2
Showing 31 changed files with 3,984 additions and 749 deletions.
2 changes: 1 addition & 1 deletion .github/workflows/pytest.yaml
@@ -24,7 +24,7 @@ jobs:
- name: Install package
run: |
poetry --version
poetry install
poetry install --all-extras
- name: Unit Tests
run: |
2 changes: 1 addition & 1 deletion .pre-commit-config.yaml
@@ -11,7 +11,7 @@ repos:
- id: trailing-whitespace

- repo: https://github.com/psf/black
rev: 24.1.1
rev: 24.3.0
hooks:
- id: black
exclude: ^apax/utils/jax_md_reduced/
2 changes: 2 additions & 0 deletions apax/bal/api.py
@@ -91,6 +91,8 @@ def kernel_selection(
bs=processing_batch_size,
n_epochs=1,
ignore_labels=True,
pos_unit=config.data.pos_unit,
energy_unit=config.data.energy_unit,
)

_, init_box = dataset.init_input()
1 change: 0 additions & 1 deletion apax/cli/templates/train_config_full.yaml
@@ -78,4 +78,3 @@ checkpoints:

progress_bar:
disable_epoch_pbar: false
disable_nl_pbar: false
15 changes: 15 additions & 0 deletions apax/config/common.py
@@ -1,5 +1,6 @@
import logging
import os
from collections.abc import MutableMapping
from typing import Union

import yaml
@@ -28,3 +29,17 @@ def parse_config(config: Union[str, os.PathLike, dict], mode: str = "train") ->
config = MDConfig.model_validate(config)

return config


def flatten(dictionary, parent_key="", separator="_"):
"""https://stackoverflow.com/questions/6027558/
flatten-nested-dictionaries-compressing-keys
"""
items = []
for key, value in dictionary.items():
new_key = parent_key + separator + key if parent_key else key
if isinstance(value, MutableMapping):
items.extend(flatten(value, new_key, separator=separator).items())
else:
items.append((new_key, value))
return dict(items)
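
For illustration, a short usage sketch of the new flatten helper; the nested dictionary below is hypothetical and only demonstrates the key-joining behaviour with the default "_" separator.

from apax.config.common import flatten

# Hypothetical nested config fragment; nested keys are joined with "_" while flattening.
nested = {"optimizer": {"name": "adam", "emb_lr": 0.02}, "n_epochs": 100}
flat = flatten(nested)
# -> {"optimizer_name": "adam", "optimizer_emb_lr": 0.02, "n_epochs": 100}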
52 changes: 44 additions & 8 deletions apax/config/train_config.py
@@ -1,18 +1,20 @@
import logging
import os
from pathlib import Path
from typing import List, Literal, Optional
from typing import List, Literal, Optional, Union

import yaml
from pydantic import (
BaseModel,
ConfigDict,
Field,
NonNegativeFloat,
PositiveFloat,
PositiveInt,
create_model,
model_validator,
)
from typing_extensions import Annotated

from apax.data.statistics import scale_method_list, shift_method_list

@@ -235,16 +237,47 @@ class LossConfig(BaseModel, extra="forbid"):
parameters: dict = {}


class CallbackConfig(BaseModel, frozen=True, extra="forbid"):
class CSVCallback(BaseModel, frozen=True, extra="forbid"):
"""
Configuration of the training callbacks.
Configuration of the CSVCallback.
Parameters
----------
name: Keyword of the callback used. Currently we implement "csv" and "tensorboard".
name: Keyword of the callback used.
"""

name: str
name: Literal["csv"]


class TBCallback(BaseModel, frozen=True, extra="forbid"):
"""
Configuration of the TensorBoard callback.
Parameters
----------
name: Keyword of the callback used.
"""

name: Literal["tensorboard"]


class MLFlowCallback(BaseModel, frozen=True, extra="forbid"):
"""
Configuration of the MLFlow callback.
Parameters
----------
name: Keyword of the callback used.
experiment: Path to the MLFlow experiment, e.g. /Users/<user>/<my_experiment>
"""

name: Literal["mlflow"]
experiment: str


CallBack = Annotated[
Union[CSVCallback, TBCallback, MLFlowCallback], Field(discriminator="name")
]


class TrainProgressbarConfig(BaseModel, extra="forbid"):
@@ -254,11 +287,11 @@ class TrainProgressbarConfig(BaseModel, extra="forbid"):
Parameters
----------
disable_epoch_pbar: Set to True to disable the epoch progress bar.
disable_nl_pbar: Set to True to disable the NL precomputation progress bar.
disable_batch_pbar: Set to True to disable the batch progress bar.
"""

disable_epoch_pbar: bool = False
disable_nl_pbar: bool = False
disable_batch_pbar: bool = True


class CheckpointConfig(BaseModel, extra="forbid"):
@@ -298,20 +331,23 @@ class Config(BaseModel, frozen=True, extra="forbid"):
callbacks: List of :class: `callback` <config.CallbackConfig> configurations.
progress_bar: Progressbar configuration.
checkpoints: Checkpoint configuration.
data_parallel: Automatically uses all available GPUs for data parallel training.
Set to false to force single device training.
"""

n_epochs: PositiveInt
patience: Optional[PositiveInt] = None
seed: int = 1
n_models: int = 1
n_jitted_steps: int = 1
data_parallel: bool = True

data: DataConfig
model: ModelConfig = ModelConfig()
metrics: List[MetricsConfig] = []
loss: List[LossConfig]
optimizer: OptimizerConfig = OptimizerConfig()
callbacks: List[CallbackConfig] = [CallbackConfig(name="csv")]
callbacks: List[CallBack] = [CSVCallback(name="csv")]
progress_bar: TrainProgressbarConfig = TrainProgressbarConfig()
checkpoints: CheckpointConfig = CheckpointConfig()

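
As an illustration of the new discriminated callback union, a minimal sketch using pydantic v2's TypeAdapter (an assumption for demonstration; apax's own parsing goes through Config/parse_config). The "name" field routes each entry to the matching model.

from pydantic import TypeAdapter

from apax.config.train_config import CallBack

adapter = TypeAdapter(CallBack)

# "name" is the discriminator, so each dict is validated against the matching model.
csv_cb = adapter.validate_python({"name": "csv"})
mlflow_cb = adapter.validate_python(
    {"name": "mlflow", "experiment": "/Users/<user>/my_experiment"}
)
print(type(csv_cb).__name__, type(mlflow_cb).__name__)  # CSVCallback MLFlowCallback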
72 changes: 40 additions & 32 deletions apax/data/input_pipeline.py
@@ -11,7 +11,7 @@
import tensorflow as tf

from apax.data.preprocessing import compute_nl, prefetch_to_single_device
from apax.utils.convert import atoms_to_inputs, atoms_to_labels
from apax.utils.convert import atoms_to_inputs, atoms_to_labels, unit_dict

log = logging.getLogger(__name__)

@@ -23,12 +23,13 @@ def pad_nl(idx, offsets, max_neighbors):
return idx, offsets


def find_largest_system(inputs: dict[str, np.ndarray], r_max) -> tuple[int]:
def find_largest_system(inputs, r_max) -> tuple[int]:
positions, boxes = inputs["positions"], inputs["box"]
max_atoms = np.max(inputs["n_atoms"])

max_nbrs = 0
for position, box in zip(inputs["positions"], inputs["box"]):
neighbor_idxs, _ = compute_nl(position, box, r_max)
for pos, box in zip(positions, boxes):
neighbor_idxs, _ = compute_nl(pos, box, r_max)
n_neighbors = neighbor_idxs.shape[1]
max_nbrs = max(max_nbrs, n_neighbors)

@@ -38,39 +39,41 @@ def find_largest_system(inputs: dict[str, np.ndarray], r_max) -> tuple[int]:
class InMemoryDataset:
def __init__(
self,
atoms,
atoms_list,
cutoff,
bs,
n_epochs,
buffer_size=1000,
n_jit_steps=1,
pos_unit: str = "Ang",
energy_unit: str = "eV",
pre_shuffle=False,
ignore_labels=False,
cache_path=".",
) -> None:
if pre_shuffle:
shuffle(atoms)
self.sample_atoms = atoms[0]
self.inputs = atoms_to_inputs(atoms)

self.n_epochs = n_epochs
self.cutoff = cutoff
self.n_jit_steps = n_jit_steps
self.buffer_size = buffer_size
self.n_data = len(atoms_list)
self.batch_size = self.validate_batch_size(bs)
self.pos_unit = pos_unit

max_atoms, max_nbrs = find_largest_system(self.inputs, cutoff)
if pre_shuffle:
shuffle(atoms_list)
self.sample_atoms = atoms_list[0]
self.inputs = atoms_to_inputs(atoms_list, pos_unit)

max_atoms, max_nbrs = find_largest_system(self.inputs, self.cutoff)
self.max_atoms = max_atoms
self.max_nbrs = max_nbrs

if atoms[0].calc and not ignore_labels:
self.labels = atoms_to_labels(atoms)
if atoms_list[0].calc and not ignore_labels:
self.labels = atoms_to_labels(atoms_list, pos_unit, energy_unit)
else:
self.labels = None

self.n_data = len(atoms)
self.count = 0
self.cutoff = cutoff
self.buffer = deque()
self.batch_size = self.validate_batch_size(bs)
self.n_jit_steps = n_jit_steps
self.file = Path(cache_path) / str(uuid.uuid4())

self.enqueue(min(self.buffer_size, self.n_data))
@@ -105,9 +108,6 @@ def prepare_data(self, i):
inputs["numbers"] = np.pad(
inputs["numbers"], (0, zeros_to_add), "constant"
).astype(np.int16)
inputs["n_atoms"] = np.pad(
inputs["n_atoms"], (0, zeros_to_add), "constant"
).astype(np.int16)

if not self.labels:
return inputs
@@ -117,7 +117,6 @@ def prepare_data(self, i):
labels["forces"] = np.pad(
labels["forces"], ((0, zeros_to_add), (0, 0)), "constant"
)

inputs = {k: tf.constant(v) for k, v in inputs.items()}
labels = {k: tf.constant(v) for k, v in labels.items()}
return (inputs, labels)
@@ -164,8 +163,9 @@ def make_signature(self) -> tf.TensorSpec:

def init_input(self) -> Dict[str, np.ndarray]:
"""Returns first batch of inputs and labels to init the model."""
positions = self.sample_atoms.positions
box = self.sample_atoms.cell.array
positions = self.sample_atoms.positions * unit_dict[self.pos_unit]
box = self.sample_atoms.cell.array * unit_dict[self.pos_unit]
# For an input sample, it does not matter whether pos is fractional or cartesian
idx, offsets = compute_nl(positions, box, self.cutoff)
inputs = (
positions,
@@ -201,7 +201,7 @@ def __iter__(self):
space = self.n_data - self.count
self.enqueue(space)

def shuffle_and_batch(self):
def shuffle_and_batch(self, sharding=None):
"""Shuffles and batches the inputs/labels. This function prepares the
inputs and labels for the whole training and prefetches the data.
@@ -223,10 +223,12 @@ def shuffle_and_batch(self):
).batch(batch_size=self.batch_size)
if self.n_jit_steps > 1:
ds = ds.batch(batch_size=self.n_jit_steps)
ds = prefetch_to_single_device(ds.as_numpy_iterator(), 2)
ds = prefetch_to_single_device(
ds.as_numpy_iterator(), 2, sharding, n_step_jit=self.n_jit_steps > 1
)
return ds

def batch(self) -> Iterator[jax.Array]:
def batch(self, sharding=None) -> Iterator[jax.Array]:
ds = (
tf.data.Dataset.from_generator(
lambda: self, output_signature=self.make_signature()
@@ -235,7 +237,9 @@ def batch(self) -> Iterator[jax.Array]:
.repeat(self.n_epochs)
)
ds = ds.batch(batch_size=self.batch_size)
ds = prefetch_to_single_device(ds.as_numpy_iterator(), 2)
ds = prefetch_to_single_device(
ds.as_numpy_iterator(), 2, sharding, n_step_jit=self.n_jit_steps > 1
)
return ds

def cleanup(self):
@@ -261,7 +265,7 @@ def __iter__(self):
self.count = 0
self.enqueue(space)

def shuffle_and_batch(self):
def shuffle_and_batch(self, sharding=None):
"""Shuffles and batches the inputs/labels. This function prepares the
inputs and labels for the whole training and prefetches the data.
@@ -279,15 +283,19 @@ def shuffle_and_batch(self):
).batch(batch_size=self.batch_size)
if self.n_jit_steps > 1:
ds = ds.batch(batch_size=self.n_jit_steps)
ds = prefetch_to_single_device(ds.as_numpy_iterator(), 2)
ds = prefetch_to_single_device(
ds.as_numpy_iterator(), 2, sharding, n_step_jit=self.n_jit_steps > 1
)
return ds

def batch(self) -> Iterator[jax.Array]:
def batch(self, sharding=None) -> Iterator[jax.Array]:
ds = tf.data.Dataset.from_generator(
lambda: self, output_signature=self.make_signature()
)
ds = ds.batch(batch_size=self.batch_size)
ds = prefetch_to_single_device(ds.as_numpy_iterator(), 2)
ds = prefetch_to_single_device(
ds.as_numpy_iterator(), 2, sharding, n_step_jit=self.n_jit_steps > 1
)
return ds


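
A hedged end-to-end sketch of the reworked dataset API shown above: positions and energies are converted according to pos_unit/energy_unit on load, and batches can be prefetched against an optional device sharding. The ASE file, cutoff, batch size, and PositionalSharding setup are illustrative assumptions, not code from this commit.

import jax
from ase.io import read

from apax.data.input_pipeline import InMemoryDataset

atoms_list = read("dataset.extxyz", index=":")  # hypothetical training structures

ds = InMemoryDataset(
    atoms_list,
    cutoff=6.0,
    bs=32,
    n_epochs=10,
    pos_unit="Ang",
    energy_unit="eV",
    pre_shuffle=True,
)

# Optional data-parallel sharding; any jax.sharding.Sharding is assumed to work here.
sharding = jax.sharding.PositionalSharding(jax.devices())
for inputs, labels in ds.shuffle_and_batch(sharding=sharding):
    ...  # one training step per batch

ds.cleanup()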
(Diffs for the remaining changed files are not shown here.)
