Skip to content

Commit

Permalink
ds stats now calculated with the converted inputs and labels
Browse files Browse the repository at this point in the history
  • Loading branch information
Tetracarbonylnickel committed Oct 2, 2023
1 parent f9cee0a commit 5431694
Show file tree
Hide file tree
Showing 4 changed files with 60 additions and 41 deletions.
34 changes: 17 additions & 17 deletions apax/data/statistics.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,13 +21,13 @@ class PerElementRegressionShift:
dtypes = [float]

@staticmethod
def compute(atoms_list, shift_options) -> np.ndarray:
def compute(inputs, labels, shift_options) -> np.ndarray:
log.info("Computing per element energy regression.")

lambd = shift_options["energy_regularisation"]
energies = [atoms.get_potential_energy() for atoms in atoms_list]
numbers = [atoms.numbers for atoms in atoms_list]
system_sizes = [num.shape[0] for num in numbers]
energies = labels["ragged"]["energy"]
numbers = inputs["ragged"]["numbers"]
system_sizes = inputs["fixed"]["n_atoms"]

energies = np.array(energies)
system_sizes = np.array(system_sizes)
Expand Down Expand Up @@ -64,7 +64,7 @@ class IsolatedAtomEnergyShift:
dtypes = [dict[int:float]]

@staticmethod
def compute(atoms_list, shift_options):
def compute(inputs, labels, shift_options):
n_species = 119
elemental_energies_shift = np.zeros(n_species)
for k, v in shift_options.items():
Expand All @@ -79,11 +79,11 @@ class MeanEnergyRMSScale:
dtypes = []

@staticmethod
def compute(atoms_list, scale_options):
def compute(inputs, labels, scale_options):
# log.info("Computing per element energy regression.")
energies = [atoms.get_potential_energy() for atoms in atoms_list]
numbers = [atoms.numbers for atoms in atoms_list]
system_sizes = [num.shape[0] for num in numbers]
energies = labels["ragged"]["energy"]
numbers = inputs["ragged"]["numbers"]
system_sizes = inputs["fixed"]["n_atoms"]

energies = np.array(energies)
system_sizes = np.array(system_sizes)
Expand All @@ -109,11 +109,11 @@ class PerElementForceRMSScale:
dtypes = []

@staticmethod
def compute(atoms_list, scale_options):
def compute(inputs, labels, scale_options):
n_species = 119

forces = np.concatenate([atoms.get_forces() for atoms in atoms_list], axis=0)
numbers = np.concatenate([atoms.numbers for atoms in atoms_list], axis=0)
forces = np.concatenate(labels["ragged"]["forces"], axis=0)
numbers = np.concatenate(inputs["ragged"]["numbers"], axis=0)

elements = np.unique(numbers)

Expand All @@ -133,7 +133,7 @@ class GlobalCustomScale:
dtypes = [float]

@staticmethod
def compute(atoms_list, scale_options):
def compute(inputs, labels, scale_options):
element_scale = scale_options["factor"]
return element_scale

Expand All @@ -144,7 +144,7 @@ class PerElementCustomScale:
dtypes = [dict[int, float]]

@staticmethod
def compute(atoms_list, scale_options):
def compute(inputs, labels, scale_options):
n_species = 119
element_scale = np.ones(n_species)
for k, v in scale_options["factors"].items():
Expand All @@ -163,7 +163,7 @@ def compute(atoms_list, scale_options):


def compute_scale_shift_parameters(
train_atoms_list, shift_method, scale_method, shift_options, scale_options
inputs, labels, shift_method, scale_method, shift_options, scale_options
):
shift_methods = {method.name: method for method in shift_method_list}
scale_methods = {method.name: method for method in scale_method_list}
Expand All @@ -182,8 +182,8 @@ def compute_scale_shift_parameters(
shift_method = shift_methods[shift_method]
scale_method = scale_methods[scale_method]

shift_parameters = shift_method.compute(train_atoms_list, shift_options)
scale_parameters = scale_method.compute(train_atoms_list, scale_options)
shift_parameters = shift_method.compute(inputs, labels, shift_options)
scale_parameters = scale_method.compute(inputs, labels, scale_options)

ds_stats = DatasetStats(shift_parameters, scale_parameters)
return ds_stats
12 changes: 2 additions & 10 deletions apax/train/eval.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,6 @@
from tqdm import trange

from apax.config import parse_config
from apax.data.statistics import compute_scale_shift_parameters
from apax.model import ModelBuilder
from apax.train.checkpoints import load_params
from apax.train.metrics import initialize_metrics
Expand Down Expand Up @@ -124,16 +123,9 @@ def eval_model(config_path, n_test=-1, log_file="eval.log", log_level="error"):
loss_fn = initialize_loss_fn(config.loss)
Metrics = initialize_metrics(config.metrics)

test_raw_ds = load_test_data(config, model_version_path, eval_path, n_test)
raw_ds = load_test_data(config, model_version_path, eval_path, n_test)

test_ds = initialize_dataset(config, test_raw_ds)
ds_stats = compute_scale_shift_parameters(
test_raw_ds.atoms_list,
config.data.shift_method,
config.data.scale_method,
config.data.shift_options,
config.data.scale_options,
)
test_ds, ds_stats = initialize_dataset(config, raw_ds)

init_input = test_ds.init_input()
init_box = np.array(init_input["box"][0])
Expand Down
36 changes: 23 additions & 13 deletions apax/train/run.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,8 +72,12 @@ def load_data_files(data_config, model_version_path):
return train_raw_ds, val_raw_ds


def initialize_dataset(config, raw_ds):
def initialize_dataset(
config, raw_ds, calc_stats: bool = True
): # have to be compatible with eval and only on ds
# Note(Moritz): external labels are actually not read in anywhere
# Answer(Nico): external labels are read in the file utils/data.py, in
# the load_data() function.
inputs, labels = create_dict_dataset(
raw_ds.atoms_list,
r_max=config.model.r_max,
Expand All @@ -82,14 +86,28 @@ def initialize_dataset(config, raw_ds):
pos_unit=config.data.pos_unit,
energy_unit=config.data.energy_unit,
)

dataset = TFPipeline(
inputs,
labels,
config.n_epochs,
config.data.batch_size,
buffer_size=config.data.shuffle_buffer_size,
)
return dataset

if calc_stats:
ds_stats = compute_scale_shift_parameters(
inputs,
labels,
config.data.shift_method,
config.data.scale_method,
config.data.shift_options,
config.data.scale_options,
)

return dataset, ds_stats
else:
return dataset


def maximize_l2_cache():
Expand Down Expand Up @@ -200,17 +218,9 @@ def run(user_config, log_file="train.log", log_level="error"):
loss_fn = initialize_loss_fn(config.loss)
Metrics = initialize_metrics(config.metrics)

train_raw_ds, val_raw_ds = load_data_files(config.data, model_version_path)
train_ds = initialize_dataset(config, train_raw_ds)
val_ds = initialize_dataset(config, val_raw_ds)

ds_stats = compute_scale_shift_parameters(
train_raw_ds.atoms_list,
config.data.shift_method,
config.data.scale_method,
config.data.shift_options,
config.data.scale_options,
)
raw_ds = load_data_files(config.data, model_version_path)
train_ds, ds_stats = initialize_dataset(config, raw_ds)
val_ds = initialize_dataset(config, raw_ds, calc_stats=False)

log.info("Initializing Model")
init_input = train_ds.init_input()
Expand Down
19 changes: 18 additions & 1 deletion tests/unit_tests/data/test_statistics.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,13 +15,30 @@ def test_energy_per_element():

atoms_list = [atoms1, atoms2, atoms3]
energies = []
n_atoms = []
for atoms in atoms_list:
n_atoms.append(len(atoms))
print(len(atoms))
energy = np.sum(dummy_energies[atoms.numbers])
energies.append(energy)
atoms.calc = SinglePointCalculator(atoms, energy=energy)

labels = {
"ragged": {
"energy": [atoms.get_potential_energy() for atoms in atoms_list],
}
}
inputs = {
"ragged": {
"numbers": [atoms.numbers for atoms in atoms_list],
},
"fixed": {
"n_atoms": n_atoms,
},
}

elemental_shift = PerElementRegressionShift.compute(
atoms_list, {"energy_regularisation": 0.0}
inputs, labels, {"energy_regularisation": 0.0}
)
regression_energies = []
for atoms in atoms_list:
Expand Down

0 comments on commit 5431694

Please sign in to comment.