Merge pull request #185 from apax-hub/dev
Accumulated Changes since March 23
M-R-Schaefer authored Oct 25, 2023
2 parents 16e2217 + c6a6fe3 commit f5c9cc3
Showing 83 changed files with 5,136 additions and 3,305 deletions.
2 changes: 1 addition & 1 deletion .github/workflows/documentation.yaml
@@ -14,7 +14,7 @@ jobs:
- name: Set up Python
uses: actions/setup-python@v4
with:
python-version: 3.9
python-version: "3.10"

- name: Run Poetry Image
uses: abatilo/actions-poetry@v2.0.0
4 changes: 2 additions & 2 deletions .github/workflows/linting.yaml
@@ -21,7 +21,7 @@ jobs:
- uses: actions/checkout@v3
- uses: actions/setup-python@v4
with:
python-version: 3.9
python-version: "3.10"

- name: Install isort
run: |
@@ -37,7 +37,7 @@ jobs:
- uses: actions/checkout@v3
- uses: actions/setup-python@v4
with:
python-version: 3.9
python-version: "3.10"

- name: Install flake8
run: |
4 changes: 2 additions & 2 deletions .github/workflows/pytest.yaml
@@ -14,7 +14,7 @@ jobs:
- name: Set up Python
uses: actions/setup-python@v4
with:
python-version: 3.9
python-version: "3.10"

- name: Run Poetry Image
uses: abatilo/actions-poetry@v2.0.0
@@ -28,7 +28,7 @@ jobs:
- name: Unit Tests
run: |
poetry run coverage run -m pytest tests
poetry run coverage run -m pytest -k "not slow"
poetry run coverage report
- name: Coverage Report
2 changes: 1 addition & 1 deletion .gitignore
@@ -57,7 +57,7 @@ main.py
tmp/
.npz
.traj

events.out.*

# Translations
*.mo
15 changes: 11 additions & 4 deletions README.md
@@ -35,9 +35,16 @@ If you want to enable GPU support, please overwrite the jaxlib version:

```bash
pip install --upgrade pip
# Installs the wheel compatible with CUDA 11 and cuDNN 8.6 or newer.
# Note: wheels only available on linux.
pip install --upgrade "jax[cuda]" -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html
```

CUDA 12 installation. Wheels are only available on Linux.
```bash
pip install --upgrade "jax[cuda12_pip]==0.4.14" -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html
```

CUDA 11 installation. Wheels are only available on Linux.
```bash
pip install --upgrade "jax[cuda11_pip]==0.4.14" -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html
```

See the [Jax installation instructions](https://github.com/google/jax#installation) for more details.
@@ -62,7 +69,7 @@ apax template train # use --full for a template with all input options
```

Please refer to the documentation LINK for a detailed explanation of all parameters.
The documentation can convenienty be accessed by runnning `apax docs`.
The documentation can conveniently be accessed by running `apax docs`.

## Molecular Dynamics

5 changes: 0 additions & 5 deletions apax/__init__.py
@@ -1,11 +1,6 @@
import os
import warnings

import tensorflow as tf
from jax.config import config as jax_config

tf.config.set_visible_devices([], "GPU")

os.environ["XLA_PYTHON_CLIENT_PREALLOCATE"] = "false"
warnings.filterwarnings(action="ignore", category=FutureWarning, module=r"jax.*scatter")
jax_config.update("jax_enable_x64", True)
3 changes: 3 additions & 0 deletions apax/bal/__init__.py
@@ -0,0 +1,3 @@
from apax.bal.api import kernel_selection

__all__ = ["kernel_selection"]
99 changes: 99 additions & 0 deletions apax/bal/api.py
@@ -0,0 +1,99 @@
from functools import partial
from typing import List, Union

import jax
import numpy as np
from ase import Atoms
from click import Path
from tqdm import trange

from apax.bal import feature_maps, kernel, selection, transforms
from apax.data.input_pipeline import TFPipeline
from apax.model.builder import ModelBuilder
from apax.model.gmnn import EnergyModel
from apax.train.checkpoints import restore_parameters
from apax.train.run import RawDataset, initialize_dataset


def create_feature_fn(
    model: EnergyModel,
    params,
    base_feature_map,
    feature_transforms=[],
    is_ensemble: bool = False,
):
    """
    Converts a model into a feature map, applies the requested transformations,
    and sets it up for computing the features of a dataset.
    All transformations are applied on the feature function, not on computed features.
    Only the final function is jit compiled.
    """
    feature_fn = base_feature_map.apply(model)

    if is_ensemble:
        feature_fn = transforms.ensemble_features(feature_fn)

    for transform in feature_transforms:
        feature_fn = transform.apply(feature_fn)

    feature_fn = transforms.batch_features(feature_fn)
    feature_fn = partial(feature_fn, params)
    feature_fn = jax.jit(feature_fn)
    return feature_fn


def compute_features(feature_fn, dataset: TFPipeline, processing_batch_size: int):
    """Compute the features of a dataset."""
    features = []
    n_data = dataset.n_data
    ds = dataset.batch(processing_batch_size)

    pbar = trange(n_data, desc="Computing features", ncols=100, leave=True)
    for i, (inputs, _) in enumerate(ds):
        g = feature_fn(inputs)
        features.append(np.asarray(g))
        pbar.update(g.shape[0])
    pbar.close()

    features = np.concatenate(features, axis=0)
    return features


def kernel_selection(
    model_dir: Union[Path, List[Path]],
    train_atoms: List[Atoms],
    pool_atoms: List[Atoms],
    base_fm_options: dict,
    selection_method: str,
    feature_transforms: list = [],
    selection_batch_size: int = 10,
    processing_batch_size: int = 64,
):
    n_models = 1 if isinstance(model_dir, (Path, str)) else len(model_dir)
    is_ensemble = n_models > 1

    selection_fn = {
        "max_dist": selection.max_dist_selection,
    }[selection_method]

    base_feature_map = feature_maps.FeatureMapOptions(base_fm_options)

    config, params = restore_parameters(model_dir)

    n_train = len(train_atoms)
    dataset = initialize_dataset(config, RawDataset(atoms_list=train_atoms + pool_atoms))

    init_box = dataset.init_input()["box"][0]

    builder = ModelBuilder(config.model.get_dict(), n_species=119)
    model = builder.build_energy_model(apply_mask=True, init_box=init_box)

    feature_fn = create_feature_fn(
        model, params, base_feature_map, feature_transforms, is_ensemble
    )
    g = compute_features(feature_fn, dataset, processing_batch_size)
    km = kernel.KernelMatrix(g, n_train)
    new_indices = selection_fn(km, selection_batch_size)

    return new_indices
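For orientation, a minimal usage sketch of the `kernel_selection` API added above. The file names and model directory are placeholders, and the `base_fm_options` values simply mirror the defaults from `feature_maps.py`:

```python
from ase.io import read

from apax.bal import kernel_selection

# Hypothetical paths; substitute your own trained model and datasets.
train_atoms = read("train.extxyz", index=":")
pool_atoms = read("pool.extxyz", index=":")

new_indices = kernel_selection(
    model_dir="models/apax_run",  # a list of directories selects ensemble features
    train_atoms=train_atoms,
    pool_atoms=pool_atoms,
    base_fm_options={"name": "ll_grad", "layer_name": "dense_2"},
    selection_method="max_dist",
    selection_batch_size=10,
    processing_batch_size=64,
)
print(new_indices)  # indices into pool_atoms, in selection order
```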
81 changes: 81 additions & 0 deletions apax/bal/feature_maps.py
@@ -0,0 +1,81 @@
from typing import Literal, Tuple, Union

import jax
import jax.numpy as jnp
from flax.traverse_util import flatten_dict, unflatten_dict
from pydantic import BaseModel, TypeAdapter


def extract_feature_params(params: dict, layer_name: str) -> Tuple[dict, dict]:
    """Separate params into those belonging to a selected layer
    and the remaining ones.
    """
    p_flat = flatten_dict(params)

    feature_layer_params = {k: v for k, v in p_flat.items() if layer_name in k}
    remaining_params = {k: v for k, v in p_flat.items() if layer_name not in k}

    if len(feature_layer_params.keys()) > 2:  # w and b
        print(feature_layer_params.keys())
        raise ValueError("Found more than one layer of the specified name")

    return feature_layer_params, remaining_params


class LastLayerGradientFeatures(BaseModel, extra="forbid"):
    """
    Model transformation which computes the gradient of the output
    wrt. the specified layer.
    https://arxiv.org/pdf/2203.09410
    """

    name: Literal["ll_grad"]
    layer_name: str = "dense_2"

    def apply(self, model):
        def ll_grad(params, inputs):
            ll_params, remaining_params = extract_feature_params(params, self.layer_name)

            def inner(ll_params):
                ll_params.update(remaining_params)
                full_params = unflatten_dict(ll_params)

                # TODO find better abstraction for inputs
                R, Z, idx, box, offsets = (
                    inputs["positions"],
                    inputs["numbers"],
                    inputs["idx"],
                    inputs["box"],
                    inputs["offsets"],
                )
                return model.apply(full_params, R, Z, idx, box, offsets)

            g_ll = jax.grad(inner)(ll_params)
            g_ll = unflatten_dict(g_ll)

            g_flat = jax.tree_map(lambda arr: jnp.reshape(arr, (-1,)), g_ll)
            (gw, gb), _ = jax.tree_util.tree_flatten(g_flat)

            bias_factor = 0.1
            weight_factor = jnp.sqrt(1 / gw.shape[-1])
            g_scaled = [weight_factor * gw, bias_factor * gb]

            g = jnp.concatenate(g_scaled)

            return g

        return ll_grad


class IdentityFeatures(BaseModel, extra="forbid"):
    """Identity feature map. For debugging purposes"""

    name: Literal["identity"]

    def apply(self, model):
        return model.apply


FeatureMapOptions = TypeAdapter(
    Union[LastLayerGradientFeatures, IdentityFeatures]
).validate_python
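As a sketch of how the `FeatureMapOptions` validator is intended to be used: a plain options dict is validated into one of the pydantic models above (the values here are just the defaults defined in this file):

```python
from apax.bal.feature_maps import FeatureMapOptions

fm = FeatureMapOptions({"name": "ll_grad", "layer_name": "dense_2"})
print(type(fm).__name__)  # LastLayerGradientFeatures

fm_id = FeatureMapOptions({"name": "identity"})
print(type(fm_id).__name__)  # IdentityFeatures
```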
24 changes: 24 additions & 0 deletions apax/bal/kernel.py
@@ -0,0 +1,24 @@
import einops
import numpy as np


class KernelMatrix:
    """
    Matrix representation of a kernel defined by a feature map g
    K_{ij} = \\sum_{k} g_{ik} g_{jk}
    """

    def __init__(self, g: np.ndarray, n_train: int):
        self.num_columns = g.shape[0]
        self.g = g
        self.diagonal = einops.einsum(g, g, "s feature, s feature -> s")
        self.n_train = n_train

    def compute_column(self, idx: int) -> np.ndarray:
        return einops.einsum(self.g, self.g[idx, :], "s feature, feature -> s")

    def score(self, idx: int) -> np.ndarray:
        """Computes the distance of sample i from all other samples j as
        K_{ii} + K_{jj} - 2 K_{ij}
        """
        return self.diagonal[idx] + self.diagonal - 2 * self.compute_column(idx)
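A small toy check of the kernel and score definitions above; the feature values are chosen purely for illustration:

```python
import numpy as np

from apax.bal.kernel import KernelMatrix

# Three samples with two-dimensional features; the first acts as "training" data.
g = np.array([[1.0, 0.0], [0.0, 1.0], [3.0, 4.0]])
km = KernelMatrix(g, n_train=1)

print(km.diagonal)           # [ 1.  1. 25.]  -> K_ii = ||g_i||^2
print(km.compute_column(2))  # [ 3.  4. 25.]  -> K_i2 = g_i . g_2
print(km.score(2))           # [20. 18.  0.]  -> ||g_i - g_2||^2
```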
35 changes: 35 additions & 0 deletions apax/bal/selection.py
@@ -0,0 +1,35 @@
import numpy as np

from apax.bal.kernel import KernelMatrix


def max_dist_selection(matrix: KernelMatrix, batch_size: int):
    """
    Iteratively selects samples from the pool which are
    most distant from all previously selected samples.
    \\argmax_{S \\in \\mathbb{X}_{rem}} \\min_{S' \\in \\mathbb{X}_{sel}} d(S, S')
    https://arxiv.org/pdf/2203.09410.pdf
    https://doi.org/10.1039/D2DD00034B
    """
    n_train = matrix.n_train

    min_squared_distances = matrix.diagonal
    min_squared_distances[:n_train] = -np.inf

    # Use max norm for first point
    new_idx = np.argmax(min_squared_distances)
    selected_idxs = list(range(n_train)) + [new_idx]

    for _ in range(1, batch_size):
        squared_distances = matrix.score(new_idx)

        squared_distances[selected_idxs] = -np.inf
        min_squared_distances = np.minimum(min_squared_distances, squared_distances)

        new_idx = np.argmax(min_squared_distances)
        selected_idxs.append(new_idx)

    return (
        np.array(selected_idxs[n_train:]) - n_train
    )  # shift by number of train datapoints
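Continuing the toy example, a sketch of how the selection routine might be called; the returned indices are relative to the pool, i.e. with the training samples stripped off:

```python
import numpy as np

from apax.bal.kernel import KernelMatrix
from apax.bal.selection import max_dist_selection

# One training sample followed by three pool candidates.
g = np.array([[0.0, 0.0], [1.0, 0.0], [0.0, 2.0], [5.0, 5.0]])
km = KernelMatrix(g, n_train=1)

picked = max_dist_selection(km, batch_size=2)
print(picked)  # [2 1]: the farthest candidate first, then the one farthest from both selections
```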
29 changes: 29 additions & 0 deletions apax/bal/transforms.py
@@ -0,0 +1,29 @@
import jax
import jax.numpy as jnp


def ensemble_features(feature_fn):
    ensemble_feature_fn = jax.vmap(feature_fn, (0, None), 0)

    def averaged_feature_fn(params, x):
        g = ensemble_feature_fn(params, x)

        if len(g.shape) != 2:
            # models, features
            raise ValueError(
                "Dimension mismatch for input features. Expected shape (models,"
                f" features), got {g.shape}"
            )

        n_models = g.shape[0]
        # sqrt since the kernel is K = g^T g
        feature_scale_factor = jnp.sqrt(1 / n_models)
        g_ens = feature_scale_factor * jnp.sum(g, axis=0)  # shape: n_features
        return g_ens

    return averaged_feature_fn


def batch_features(feature_fn):
    batched_feature_fn = jax.vmap(feature_fn, (None, 0), 0)
    return batched_feature_fn
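A minimal sketch of what the two wrappers do, using a stand-in feature function (the names and shapes here are purely illustrative):

```python
import jax.numpy as jnp

from apax.bal.transforms import batch_features, ensemble_features

# Stand-in feature map: params is a per-model scale, x a single input vector.
def feature_fn(params, x):
    return params * x  # shape: (n_features,)

params_ens = jnp.array([1.0, 2.0, 3.0])  # three "models", mapped over axis 0
x = jnp.array([1.0, 1.0])

ens_fn = ensemble_features(feature_fn)  # combines per-model features with a 1/sqrt(n_models) factor
print(ens_fn(params_ens, x))            # shape (2,)

batched_fn = batch_features(ens_fn)     # maps over a batch of inputs
xs = jnp.stack([x, 2 * x])
print(batched_fn(params_ens, xs))       # shape (2, 2)
```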