
Merge pull request #176 from apax-hub/130-unit-conversion-is-broken
130 unit conversion is broken
Tetracarbonylnickel authored Oct 6, 2023
2 parents 079608f + ee0b36e commit 0ff6b56
Showing 12 changed files with 236 additions and 44 deletions.
2 changes: 1 addition & 1 deletion .github/workflows/pytest.yaml
@@ -28,7 +28,7 @@ jobs:
- name: Unit Tests
run: |
- poetry run coverage run -m pytest tests
+ poetry run coverage run -m pytest -k "not slow"
poetry run coverage report
- name: Coverage Report
34 changes: 17 additions & 17 deletions apax/data/statistics.py
@@ -21,13 +21,13 @@ class PerElementRegressionShift:
dtypes = [float]

@staticmethod
- def compute(atoms_list, shift_options) -> np.ndarray:
+ def compute(inputs, labels, shift_options) -> np.ndarray:
log.info("Computing per element energy regression.")

lambd = shift_options["energy_regularisation"]
- energies = [atoms.get_potential_energy() for atoms in atoms_list]
- numbers = [atoms.numbers for atoms in atoms_list]
- system_sizes = [num.shape[0] for num in numbers]
+ energies = labels["fixed"]["energy"]
+ numbers = inputs["ragged"]["numbers"]
+ system_sizes = inputs["fixed"]["n_atoms"]

energies = np.array(energies)
system_sizes = np.array(system_sizes)
@@ -64,7 +64,7 @@ class IsolatedAtomEnergyShift:
dtypes = [dict[int:float]]

@staticmethod
- def compute(atoms_list, shift_options):
+ def compute(inputs, labels, shift_options):
n_species = 119
elemental_energies_shift = np.zeros(n_species)
for k, v in shift_options.items():
@@ -79,11 +79,11 @@ class MeanEnergyRMSScale:
dtypes = []

@staticmethod
- def compute(atoms_list, scale_options):
+ def compute(inputs, labels, scale_options):
# log.info("Computing per element energy regression.")
- energies = [atoms.get_potential_energy() for atoms in atoms_list]
- numbers = [atoms.numbers for atoms in atoms_list]
- system_sizes = [num.shape[0] for num in numbers]
+ energies = labels["fixed"]["energy"]
+ numbers = inputs["ragged"]["numbers"]
+ system_sizes = inputs["fixed"]["n_atoms"]

energies = np.array(energies)
system_sizes = np.array(system_sizes)
@@ -109,11 +109,11 @@ class PerElementForceRMSScale:
dtypes = []

@staticmethod
- def compute(atoms_list, scale_options):
+ def compute(inputs, labels, scale_options):
n_species = 119

- forces = np.concatenate([atoms.get_forces() for atoms in atoms_list], axis=0)
- numbers = np.concatenate([atoms.numbers for atoms in atoms_list], axis=0)
+ forces = np.concatenate(labels["ragged"]["forces"], axis=0)
+ numbers = np.concatenate(inputs["ragged"]["numbers"], axis=0)

elements = np.unique(numbers)

@@ -133,7 +133,7 @@ class GlobalCustomScale:
dtypes = [float]

@staticmethod
- def compute(atoms_list, scale_options):
+ def compute(inputs, labels, scale_options):
element_scale = scale_options["factor"]
return element_scale

@@ -144,7 +144,7 @@ class PerElementCustomScale:
dtypes = [dict[int, float]]

@staticmethod
- def compute(atoms_list, scale_options):
+ def compute(inputs, labels, scale_options):
n_species = 119
element_scale = np.ones(n_species)
for k, v in scale_options["factors"].items():
@@ -163,7 +163,7 @@ def compute(atoms_list, scale_options):


def compute_scale_shift_parameters(
- train_atoms_list, shift_method, scale_method, shift_options, scale_options
+ inputs, labels, shift_method, scale_method, shift_options, scale_options
):
shift_methods = {method.name: method for method in shift_method_list}
scale_methods = {method.name: method for method in scale_method_list}
@@ -182,8 +182,8 @@ def compute_scale_shift_parameters(
shift_method = shift_methods[shift_method]
scale_method = scale_methods[scale_method]

- shift_parameters = shift_method.compute(train_atoms_list, shift_options)
- scale_parameters = scale_method.compute(train_atoms_list, scale_options)
+ shift_parameters = shift_method.compute(inputs, labels, shift_options)
+ scale_parameters = scale_method.compute(inputs, labels, scale_options)

ds_stats = DatasetStats(shift_parameters, scale_parameters)
return ds_stats
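
For orientation (not part of the diff): after this change the statistics helpers consume the dict-based arrays produced by create_dict_dataset instead of an ASE atoms list. A minimal call sketch in Python; the shift method name and options are taken from the regression config further below, while the scale method name and empty scale options are assumptions:

# Sketch only, not part of the commit.
# Assumes `inputs` and `labels` follow the layout shown above, e.g.
# inputs["ragged"]["numbers"], inputs["fixed"]["n_atoms"], labels["fixed"]["energy"].
from apax.data.statistics import compute_scale_shift_parameters

ds_stats = compute_scale_shift_parameters(
    inputs,
    labels,
    shift_method="per_element_regression_shift",  # name used in the regression config below
    scale_method="per_element_force_rms_scale",   # assumed registry name for PerElementForceRMSScale
    shift_options={"energy_regularisation": 1.0},
    scale_options={},                             # assumed: the RMS scale needs no options
)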
12 changes: 2 additions & 10 deletions apax/train/eval.py
@@ -8,7 +8,6 @@
from tqdm import trange

from apax.config import parse_config
- from apax.data.statistics import compute_scale_shift_parameters
from apax.model import ModelBuilder
from apax.train.checkpoints import load_params
from apax.train.metrics import initialize_metrics
@@ -124,16 +123,9 @@ def eval_model(config_path, n_test=-1, log_file="eval.log", log_level="error"):
loss_fn = initialize_loss_fn(config.loss)
Metrics = initialize_metrics(config.metrics)

- test_raw_ds = load_test_data(config, model_version_path, eval_path, n_test)
+ raw_ds = load_test_data(config, model_version_path, eval_path, n_test)

- test_ds = initialize_dataset(config, test_raw_ds)
- ds_stats = compute_scale_shift_parameters(
- test_raw_ds.atoms_list,
- config.data.shift_method,
- config.data.scale_method,
- config.data.shift_options,
- config.data.scale_options,
- )
+ test_ds, ds_stats = initialize_dataset(config, raw_ds)

init_input = test_ds.init_input()
init_box = np.array(init_input["box"][0])
32 changes: 19 additions & 13 deletions apax/train/run.py
@@ -72,8 +72,7 @@ def load_data_files(data_config, model_version_path):
return train_raw_ds, val_raw_ds


- def initialize_dataset(config, raw_ds):
- # Note(Moritz): external labels are actually not read in anywhere
+ def initialize_dataset(config, raw_ds, calc_stats: bool = True):
inputs, labels = create_dict_dataset(
raw_ds.atoms_list,
r_max=config.model.r_max,
@@ -82,14 +81,29 @@ def initialize_dataset(config, raw_ds):
pos_unit=config.data.pos_unit,
energy_unit=config.data.energy_unit,
)

+ if calc_stats:
+ ds_stats = compute_scale_shift_parameters(
+ inputs,
+ labels,
+ config.data.shift_method,
+ config.data.scale_method,
+ config.data.shift_options,
+ config.data.scale_options,
+ )

dataset = TFPipeline(
inputs,
labels,
config.n_epochs,
config.data.batch_size,
buffer_size=config.data.shuffle_buffer_size,
)
- return dataset

+ if calc_stats:
+ return dataset, ds_stats
+ else:
+ return dataset


def maximize_l2_cache():
@@ -201,16 +215,8 @@ def run(user_config, log_file="train.log", log_level="error"):
Metrics = initialize_metrics(config.metrics)

train_raw_ds, val_raw_ds = load_data_files(config.data, model_version_path)
- train_ds = initialize_dataset(config, train_raw_ds)
- val_ds = initialize_dataset(config, val_raw_ds)

- ds_stats = compute_scale_shift_parameters(
- train_raw_ds.atoms_list,
- config.data.shift_method,
- config.data.scale_method,
- config.data.shift_options,
- config.data.scale_options,
- )
+ train_ds, ds_stats = initialize_dataset(config, train_raw_ds)
+ val_ds = initialize_dataset(config, val_raw_ds, calc_stats=False)

log.info("Initializing Model")
init_input = train_ds.init_input()
1 change: 0 additions & 1 deletion apax/utils/convert.py
@@ -109,7 +109,6 @@ def atoms_to_arrays(

inputs["ragged"]["numbers"].append(atoms.numbers)
inputs["fixed"]["n_atoms"].append(len(atoms))

for key, val in atoms.calc.results.items():
if key == "forces":
labels["ragged"][key].append(
2 changes: 1 addition & 1 deletion poetry.lock

Some generated files are not rendered by default.

3 changes: 3 additions & 0 deletions pytest.ini
@@ -0,0 +1,3 @@
[pytest]
markers =
slow: mark a test as slow and should only run explicitly
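
Not part of the diff: the new slow marker pairs with the CI change above, where pytest -k "not slow" deselects marked tests because marker names count toward pytest's keyword matching. A hypothetical opt-in test (name and body assumed) would look like:

import pytest

@pytest.mark.slow  # deselected in CI via `pytest -k "not slow"`; run locally with plain `pytest`
def test_full_regression_training():
    ...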
33 changes: 33 additions & 0 deletions tests/conftest.py
@@ -1,3 +1,6 @@
+ import os
+ import urllib
+ import zipfile
from typing import List

import numpy as np
@@ -74,3 +77,33 @@ def example_atoms(num_data: int, pbc: bool, calc_results: List[str]) -> Atoms:
def get_tmp_path(tmp_path_factory):
test_path = tmp_path_factory.mktemp("apax_tests")
return test_path


@pytest.fixture(scope="session")
def get_md22_stachyose(get_tmp_path):
url = "http://www.quantum-machine.org/gdml/repo/static/md22_stachyose.zip"
data_path = get_tmp_path / "data"
file_path = data_path / "md22_stachyose.zip"

os.makedirs(data_path, exist_ok=True)
urllib.request.urlretrieve(url, file_path)

with zipfile.ZipFile(file_path, "r") as zip_ref:
zip_ref.extractall(data_path)

file_path = modify_xyz_file(
file_path.with_suffix(".xyz"), target_string="Energy", replacement_string="energy"
)

return file_path


def modify_xyz_file(file_path, target_string, replacement_string):
new_file_path = file_path.with_name(file_path.stem + "_mod" + file_path.suffix)

with open(file_path, "r") as input_file, open(new_file_path, "w") as output_file:
for line in input_file:
# Replace all occurrences of the target string with the replacement string
modified_line = line.replace(target_string, replacement_string)
output_file.write(modified_line)
return new_file_path
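
For illustration only (not in the commit): the session-scoped fixture downloads the MD22 stachyose set once, renames the "Energy" key so ASE can parse it, and returns the path of the modified file. A hypothetical consumer (test name and assertions assumed):

import pytest
from ase.io import read

@pytest.mark.slow
def test_stachyose_dataset_loads(get_md22_stachyose):
    # the fixture returns the path of the modified .xyz file
    atoms_list = read(get_md22_stachyose, index=":")
    assert len(atoms_list) > 0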
Empty file.
89 changes: 89 additions & 0 deletions tests/regression_tests/apax_config.yaml
@@ -0,0 +1,89 @@
n_epochs: 200
seed: 0

data:
directory: models/
experiment: test

data_path: <PATH>

n_train: 1000
n_valid: 100

batch_size: 32
valid_batch_size: 100

shift_method: "per_element_regression_shift"
shift_options: {"energy_regularisation": 1.0}
shuffle_buffer_size: 1000

pos_unit: Ang
energy_unit: eV

model:
n_basis: 7
n_radial: 5
nn: [512, 512]

r_max: 6.5
r_min: 0.5

calc_stress: false
use_zbl: true

b_init: normal
descriptor_dtype: fp32
readout_dtype: fp32
scale_shift_dtype: fp32

metrics:
- name: energy
reductions:
- mae
- name: forces
reductions:
- mae
- mse
# - name: stress
# reductions:
# - mae
# - mse

loss:
- loss_type: structures
name: energy
weight: 1.0
- loss_type: structures
name: forces
weight: 8.0
- loss_type: cosine_sim
name: forces
weight: 0.1
# - loss_type: structures
# name: stress
# weight: 1.0

optimizer:
opt_name: adam
opt_kwargs: {}
emb_lr: 0.02
nn_lr: 0.03
scale_lr: 0.001
shift_lr: 0.05
zbl_lr: 0.001
transition_begin: 0

callbacks:
- name: csv

checkpoints:
ckpt_interval: 1
# The options below are used for transfer learning
# base_model_checkpoint: null
# reset_layers: []

progress_bar:
disable_epoch_pbar: true
disable_nl_pbar: true

maximize_l2_cache: false
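
A rough sketch of how this config might drive a regression run (not in the commit; whether run() accepts a config dictionary, and the file substituted for the <PATH> placeholder, are assumptions):

# Hypothetical driver for the regression config above.
import yaml
from apax.train.run import run

with open("tests/regression_tests/apax_config.yaml") as f:
    config = yaml.safe_load(f)

config["data"]["data_path"] = "md22_stachyose_mod.xyz"  # assumed stand-in for <PATH>
run(config, log_file="train.log", log_level="error")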