Merge pull request #185 from apax-hub/dev
Accumulated Changes since March 23
Showing 83 changed files with 5,136 additions and 3,305 deletions.
@@ -57,7 +57,7 @@ main.py
tmp/
.npz
.traj

events.out.*

# Translations
*.mo
@@ -1,11 +1,6 @@
import os
import warnings

import tensorflow as tf
from jax.config import config as jax_config

tf.config.set_visible_devices([], "GPU")

os.environ["XLA_PYTHON_CLIENT_PREALLOCATE"] = "false"
warnings.filterwarnings(action="ignore", category=FutureWarning, module=r"jax.*scatter")
jax_config.update("jax_enable_x64", True)
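A side note on the import above: recent JAX releases deprecated (and current ones have removed) the importable jax.config submodule, so `from jax.config import config` may fail there. To the best of my knowledge, the equivalent under current JAX is to update the config attribute on the top-level module:

import jax

jax.config.update("jax_enable_x64", True)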
@@ -0,0 +1,3 @@
from apax.bal.api import kernel_selection

__all__ = ["kernel_selection"]
@@ -0,0 +1,99 @@
from functools import partial
from typing import List, Union

import jax
import numpy as np
from ase import Atoms
from click import Path
from tqdm import trange

from apax.bal import feature_maps, kernel, selection, transforms
from apax.data.input_pipeline import TFPipeline
from apax.model.builder import ModelBuilder
from apax.model.gmnn import EnergyModel
from apax.train.checkpoints import restore_parameters
from apax.train.run import RawDataset, initialize_dataset


def create_feature_fn(
    model: EnergyModel,
    params,
    base_feature_map,
    feature_transforms=[],
    is_ensemble: bool = False,
):
    """
    Converts a model into a feature map, transforms it as needed, and
    sets it up for use in computing the features of a dataset.
    All transformations are applied to the feature function, not to computed features.
    Only the final function is jit compiled.
    """
    feature_fn = base_feature_map.apply(model)

    if is_ensemble:
        feature_fn = transforms.ensemble_features(feature_fn)

    for transform in feature_transforms:
        feature_fn = transform.apply(feature_fn)

    feature_fn = transforms.batch_features(feature_fn)
    feature_fn = partial(feature_fn, params)
    feature_fn = jax.jit(feature_fn)
    return feature_fn


def compute_features(feature_fn, dataset: TFPipeline, processing_batch_size: int):
    """Compute the features of a dataset."""
    features = []
    n_data = dataset.n_data
    ds = dataset.batch(processing_batch_size)

    pbar = trange(n_data, desc="Computing features", ncols=100, leave=True)
    for inputs, _ in ds:
        g = feature_fn(inputs)
        features.append(np.asarray(g))
        pbar.update(g.shape[0])
    pbar.close()

    features = np.concatenate(features, axis=0)
    return features


def kernel_selection(
    model_dir: Union[Path, List[Path]],
    train_atoms: List[Atoms],
    pool_atoms: List[Atoms],
    base_fm_options: dict,
    selection_method: str,
    feature_transforms: list = [],
    selection_batch_size: int = 10,
    processing_batch_size: int = 64,
):
    n_models = 1 if isinstance(model_dir, (Path, str)) else len(model_dir)
    is_ensemble = n_models > 1

    selection_fn = {
        "max_dist": selection.max_dist_selection,
    }[selection_method]

    base_feature_map = feature_maps.FeatureMapOptions(base_fm_options)

    config, params = restore_parameters(model_dir)

    n_train = len(train_atoms)
    dataset = initialize_dataset(config, RawDataset(atoms_list=train_atoms + pool_atoms))

    init_box = dataset.init_input()["box"][0]

    builder = ModelBuilder(config.model.get_dict(), n_species=119)
    model = builder.build_energy_model(apply_mask=True, init_box=init_box)

    feature_fn = create_feature_fn(
        model, params, base_feature_map, feature_transforms, is_ensemble
    )
    g = compute_features(feature_fn, dataset, processing_batch_size)
    km = kernel.KernelMatrix(g, n_train)
    new_indices = selection_fn(km, selection_batch_size)

    return new_indices
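For orientation, here is a minimal usage sketch of this API. The model directory and trajectory file names are placeholders, and the base_fm_options dict mirrors the feature-map options defined in the next file; treat it as an illustration under those assumptions, not documented usage.

from ase.io import read

from apax.bal import kernel_selection

# placeholder files and model path, for illustration only
train_atoms = read("train.extxyz", index=":")
pool_atoms = read("pool.extxyz", index=":")

new_indices = kernel_selection(
    "models/my_model",  # or a list of directories for an ensemble
    train_atoms,
    pool_atoms,
    base_fm_options={"name": "ll_grad", "layer_name": "dense_2"},
    selection_method="max_dist",
    selection_batch_size=10,
)
# new_indices are 0-based positions into pool_atoms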
@@ -0,0 +1,81 @@
from typing import Literal, Tuple, Union

import jax
import jax.numpy as jnp
from flax.traverse_util import flatten_dict, unflatten_dict
from pydantic import BaseModel, TypeAdapter


def extract_feature_params(params: dict, layer_name: str) -> Tuple[dict, dict]:
    """Separate params into those belonging to a selected layer
    and the remaining ones.
    """
    p_flat = flatten_dict(params)

    feature_layer_params = {k: v for k, v in p_flat.items() if layer_name in k}
    remaining_params = {k: v for k, v in p_flat.items() if layer_name not in k}

    if len(feature_layer_params.keys()) > 2:  # w and b
        print(feature_layer_params.keys())
        raise ValueError("Found more than one layer of the specified name")

    return feature_layer_params, remaining_params


class LastLayerGradientFeatures(BaseModel, extra="forbid"):
    """
    Model transformation which computes the gradient of the output
    w.r.t. the specified layer.
    https://arxiv.org/pdf/2203.09410
    """

    name: Literal["ll_grad"]
    layer_name: str = "dense_2"

    def apply(self, model):
        def ll_grad(params, inputs):
            ll_params, remaining_params = extract_feature_params(params, self.layer_name)

            def inner(ll_params):
                ll_params.update(remaining_params)
                full_params = unflatten_dict(ll_params)

                # TODO find better abstraction for inputs
                R, Z, idx, box, offsets = (
                    inputs["positions"],
                    inputs["numbers"],
                    inputs["idx"],
                    inputs["box"],
                    inputs["offsets"],
                )
                return model.apply(full_params, R, Z, idx, box, offsets)

            g_ll = jax.grad(inner)(ll_params)
            g_ll = unflatten_dict(g_ll)

            g_flat = jax.tree_map(lambda arr: jnp.reshape(arr, (-1,)), g_ll)
            (gw, gb), _ = jax.tree_util.tree_flatten(g_flat)

            bias_factor = 0.1
            weight_factor = jnp.sqrt(1 / gw.shape[-1])
            g_scaled = [weight_factor * gw, bias_factor * gb]

            g = jnp.concatenate(g_scaled)

            return g

        return ll_grad


class IdentityFeatures(BaseModel, extra="forbid"):
    """Identity feature map. For debugging purposes."""

    name: Literal["identity"]

    def apply(self, model):
        return model.apply


FeatureMapOptions = TypeAdapter(
    Union[LastLayerGradientFeatures, IdentityFeatures]
).validate_python
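As a quick illustration of the validator defined on the last line: FeatureMapOptions takes a plain dict and returns the matching pydantic model, discriminated by the `name` literal. A minimal sketch, using only the field defaults from the classes above:

ll_grad_map = FeatureMapOptions({"name": "ll_grad", "layer_name": "dense_2"})
identity_map = FeatureMapOptions({"name": "identity"})
# either object then exposes .apply(model), as used in create_feature_fn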
@@ -0,0 +1,24 @@
import einops
import numpy as np


class KernelMatrix:
    """
    Matrix representation of a kernel defined by a feature map g
    K_{ij} = \\sum_{k} g_{ik} g_{jk}
    """

    def __init__(self, g: np.ndarray, n_train: int):
        self.num_columns = g.shape[0]
        self.g = g
        self.diagonal = einops.einsum(g, g, "s feature, s feature -> s")
        self.n_train = n_train

    def compute_column(self, idx: int) -> np.ndarray:
        return einops.einsum(self.g, self.g[idx, :], "s feature, feature -> s")

    def score(self, idx: int) -> np.ndarray:
        """Computes the squared distance of sample i from all other samples j as
        K_{ii} + K_{jj} - 2 K_{ij}
        """
        return self.diagonal[idx] + self.diagonal - 2 * self.compute_column(idx)
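Since K_{ii} + K_{jj} - 2 K_{ij} = ||g_i - g_j||^2, score(i) returns squared feature-space distances. A small self-contained check with random features (the import path matches the one used in the next file):

import numpy as np

from apax.bal.kernel import KernelMatrix

rng = np.random.default_rng(0)
g = rng.normal(size=(5, 3))  # 5 samples with 3 features each
km = KernelMatrix(g, n_train=2)

d2 = km.score(0)  # squared distances of sample 0 to every sample
assert np.allclose(d2, ((g - g[0]) ** 2).sum(axis=1))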
@@ -0,0 +1,35 @@
import numpy as np

from apax.bal.kernel import KernelMatrix


def max_dist_selection(matrix: KernelMatrix, batch_size: int):
    """
    Iteratively selects samples from the pool which are
    most distant from all previously selected samples.
    \\argmax_{S \\in \\mathbb{X}_{rem}} \\min_{S' \\in \\mathbb{X}_{sel}} d(S, S')
    https://arxiv.org/pdf/2203.09410.pdf
    https://doi.org/10.1039/D2DD00034B
    """
    n_train = matrix.n_train

    # copy so the -inf masking below does not mutate matrix.diagonal in place
    min_squared_distances = matrix.diagonal.copy()
    min_squared_distances[:n_train] = -np.inf

    # Use max norm for first point
    new_idx = np.argmax(min_squared_distances)
    selected_idxs = list(range(n_train)) + [new_idx]

    for _ in range(1, batch_size):
        squared_distances = matrix.score(new_idx)

        squared_distances[selected_idxs] = -np.inf
        min_squared_distances = np.minimum(min_squared_distances, squared_distances)

        new_idx = np.argmax(min_squared_distances)
        selected_idxs.append(new_idx)

    # shift by the number of training datapoints so the indices refer to the pool
    return np.array(selected_idxs[n_train:]) - n_train
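To make the index convention concrete, a toy run with random features; everything here is synthetic, and only the two apax.bal imports come from the code above:

import numpy as np

from apax.bal.kernel import KernelMatrix
from apax.bal.selection import max_dist_selection

rng = np.random.default_rng(0)
g = rng.normal(size=(20, 8))  # features for 4 train + 16 pool samples
km = KernelMatrix(g, n_train=4)

picked = max_dist_selection(km, batch_size=3)
print(picked)  # three 0-based indices into the 16-sample pool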
@@ -0,0 +1,29 @@
import jax
import jax.numpy as jnp


def ensemble_features(feature_fn):
    ensemble_feature_fn = jax.vmap(feature_fn, (0, None), 0)

    def averaged_feature_fn(params, x):
        g = ensemble_feature_fn(params, x)

        if len(g.shape) != 2:
            # expected axes: (models, features)
            raise ValueError(
                "Dimension mismatch for input features. Expected shape (models,"
                f" features), got {g.shape}"
            )

        n_models = g.shape[0]
        # sqrt since the kernel is K = g^T g
        feature_scale_factor = jnp.sqrt(1 / n_models)
        g_ens = feature_scale_factor * jnp.sum(g, axis=0)  # shape: n_features
        return g_ens

    return averaged_feature_fn


def batch_features(feature_fn):
    batched_feature_fn = jax.vmap(feature_fn, (None, 0), 0)
    return batched_feature_fn
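The two vmap calls differ only in which argument carries the mapped axis: ensemble_features maps over a leading axis of stacked ensemble parameters, while batch_features maps over a leading batch axis of the inputs. A toy stand-in illustrating the axis conventions (feature_fn below is a placeholder with the same (params, x) signature, not the apax feature map):

import jax
import jax.numpy as jnp


def feature_fn(params, x):
    return params * x  # toy feature map, returns shape (features,)


batched = jax.vmap(feature_fn, (None, 0), 0)  # map over a batch of inputs
params = jnp.ones(3)
xs = jnp.arange(6.0).reshape(2, 3)  # batch of two inputs
print(batched(params, xs).shape)  # (2, 3): one feature vector per input

ensembled = jax.vmap(feature_fn, (0, None), 0)  # map over stacked ensemble params
stacked_params = jnp.stack([jnp.ones(3), 2 * jnp.ones(3)])
print(ensembled(stacked_params, xs[0]).shape)  # (2, 3): one vector per model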