Cleaner and faster active learning #59

Merged 18 commits on Dec 19, 2024
Changes from all commits
README.md: 28 changes (19 additions, 9 deletions)

# Enerzyme

Towards next-generation machine learning force fields for enzymatic catalysis.

Currently supported model architectures:

| Model | Type | Energy and force prediction | Charge and dipole prediction | Fully modularized | Shallow ensemble | Reference paper | Reference code |
| :---: | :---: | :---: | :---: | :---: | :---: | :---: | :---: |
| PhysNet | internal | ✅ | ✅ | ✅ | ✅ | [J. Chem. Theory Comput. 2019, 15, 3678–3693](https://pubs.acs.org/doi/full/10.1021/acs.jctc.9b00181) | [GitHub](https://github.com/MMunibas/PhysNet) |
| SpookyNet | internal | ✅ | ✅ | ✅ | ✅ | [Nat. Commun. 2021, 12(1), 7273](https://www.nature.com/articles/s41467-021-27504-0) | [GitHub](https://github.com/OUnke/SpookyNet) |
| LEFTNet | internal | ✅ | ✅ | ❌ | ❌ | [NeurIPS 2023, arXiv:2304.04757](https://arxiv.org/abs/2304.04757) | [GitHub](https://github.com/yuanqidu/M2Hub) |
| MACE | external | ✅ | ❌ | ❌ | ❌ | [NeurIPS 2022, arXiv:2206.07697](https://arxiv.org/abs/2206.07697) | [GitHub](https://github.com/ACEsuit/mace) |
| NequIP | external | ✅ | ❌ | ❌ | ❌ | [Nat. Commun. 2022, 13(1), 2453](https://www.nature.com/articles/s41467-022-29939-5) | [GitHub](https://github.com/mir-group/nequip) |
| XPaiNN | external | ✅ | ❌ | ❌ | ❌ | [J. Chem. Theory Comput. 2024, 20, 21, 9500–9511](https://pubs.acs.org/doi/10.1021/acs.jctc.4c01151) | [GitHub](https://github.com/X1X1010/XequiNet) |
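
The new "Shallow ensemble" column marks models that can produce several predictions from one shared backbone, which is presumably what the active-learning picking later in this PR uses for cheap force-uncertainty estimates. A minimal sketch of the idea (shapes and names are illustrative assumptions, not Enerzyme's API):

```python
import torch

class ShallowEnsembleHead(torch.nn.Module):
    """Sketch of a shallow ensemble: one shared feature backbone feeds K
    parallel readouts, so K energy predictions (and their spread, usable as
    an uncertainty estimate) cost roughly one forward pass."""
    def __init__(self, n_features: int = 128, n_heads: int = 8):
        super().__init__()
        self.heads = torch.nn.Linear(n_features, n_heads)  # K readouts in one matmul

    def forward(self, atom_features: torch.Tensor):
        # atom_features: (n_atoms, n_features) from a shared message-passing backbone
        E_heads = self.heads(atom_features).sum(dim=0)  # per-head total energy: (n_heads,)
        return E_heads.mean(), E_heads.var()            # prediction and its spread ("E_var")

E, E_var = ShallowEnsembleHead()(torch.randn(10, 128))
```

Because the heads share all feature computation, K predictions cost barely more than one, unlike a deep ensemble of K independently trained models.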

# Usage

## Installation

Recommended environment for internal force fields

```
python==3.10.12
pip==23.2.1
...
e3nn==0.4.4
```

To test PhysNet, you also need

```
tensorflow==2.13.0
```

To invoke MACE, you need

```
mace-torch==0.3.6
```

To invoke NequIP, you need

```
nequip==0.6.1
```

To invoke XPaiNN, you need

```
XequiNet==0.3.6
scipy==1.11.2
...
pydantic==1.10.12
```

Then install the package

```bash
pip install -e .
```

Energy (force) / Atomic Charge / Dipole moment fitting.

```bash
enerzyme train -c <configuration yaml file> -o <output directory>
```

Please see `enerzyme/config/train.yaml` for details and recommended configurations.

Enerzyme saves the preprocessed dataset, split indices, final `<configuration yaml file>`, and the best/last model to the `<output directory>`.
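
For orientation, a bare skeleton of the top-level sections such a training configuration contains, inferred from the config excerpts later in this PR (keys elided; the comments are my gloss, not the shipped file):

```yaml
Datahub: # dataset path, format, and feature mapping
  ...
Trainer: # optimizer, batch sizes, epochs, device, dtype
  ...
Splitter: # how the dataset is partitioned (e.g. training / validation / withheld)
  ...
Metric: # metrics tracked for validation and model selection
  ...
```
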
Enerzyme reads the `<model directory>` for the model configuration, loads the models, runs the prediction, and reports the results in the `<output directory>`.
## Simulation

Supported simulation types:

- Flexible scanning on the distance between two atoms.
- Constrained Langevin MD

```bash
enerzyme simulate -c <configuration yaml file> -o <output directory> -m <model directory>
```

Enerzyme reads the `<model directory>` for the model configuration, loads the models, runs the simulation, and reports the results in the `<output directory>`.
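
To make the constrained Langevin MD listed above concrete, here is one generic step of overdamped Langevin dynamics with a fixed distance between two atoms re-imposed by projection; a textbook sketch with assumed shapes and units, not Enerzyme's integrator:

```python
import numpy as np

def constrained_langevin_step(R, F, i, j, d0, gamma=1.0, kT=0.0259, dt=1e-3, rng=None):
    """One overdamped Langevin (Brownian) step on positions R: (n_atoms, 3) with
    forces F: (n_atoms, 3), then re-impose the constraint |R[i] - R[j]| = d0.
    kT = 0.0259 is roughly room temperature in eV."""
    if rng is None:
        rng = np.random.default_rng()
    # Euler–Maruyama update: deterministic drift plus thermal noise.
    R = R + F * dt / gamma + np.sqrt(2.0 * kT * dt / gamma) * rng.standard_normal(R.shape)
    # SHAKE-like projection: symmetrically rescale the i-j bond back to length d0.
    bond = R[i] - R[j]
    length = np.linalg.norm(bond)
    shift = 0.5 * (length - d0) * bond / length
    R[i] -= shift
    R[j] += shift
    return R
```

A flexible distance scan would then repeat this loop over a ladder of `d0` values, equilibrating at each one.
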
enerzyme/config/concurrent_train.yaml: 14 changes (11 additions, 3 deletions)

```yaml
Trainer:
  data_source: withheld # data source for active learning; withheld means the Splitter's withheld partition
  picking_method: max_Fa_norm_std # picking method for active learning
  # random means random picking
  # mean_Fa_std means picking according to the mean of the standard deviations of the force errors
  # (see J. Chem. Inf. Model. 2024, 64, 16, 6377–6387 for more details)
  # max_Fa_norm_std means picking according to the maximum standard deviation of the force errors
  # (see Comput. Phys. Commun. 253, 107206 for more details)
  picking_params:
    relative_bound_on_validation: true # if true, the bounds are relative to the validation set (using relative_error_lower/upper_bound)
    relative_error_lower_bound: 0.5
    relative_error_upper_bound: 5
    # error_lower_bound and error_upper_bound are used to filter out the samples
    # whose predicted force errors are larger than the upper bound or smaller than the lower bound
    # (see Comput. Phys. Commun. 253, 107206 for more details)
    error_lower_bound: 0.01 # lower bound of the force error for active learning
    error_upper_bound: 0.05 # upper bound of the force error for active learning
  sample_size: 10000 # number of samples to be picked in each iteration
  max_epoch_per_iter: 10000 # maximum epochs of training in each iteration
  max_iter: 100 # maximum number of iterations
  resume: true # if true, resume from the last active learning checkpoint
Splitter:
  method: random # splitting method: random
  parts: # partition names; at least training and validation are needed
```
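
The picking rule configured above can be read directly off these comments. A minimal NumPy sketch of `max_Fa_norm_std` with absolute bounds, assuming an ensemble of force predictions of shape `(n_models, n_structures, n_atoms, 3)` (names are illustrative, not Enerzyme's internals):

```python
import numpy as np

def pick_max_Fa_norm_std(F_ensemble, error_lower_bound, error_upper_bound, sample_size):
    # Per-component std over the ensemble: (n_structures, n_atoms, 3)
    F_std = F_ensemble.std(axis=0)
    # Norm of each atom's std vector, then the worst atom per structure: (n_structures,)
    score = np.linalg.norm(F_std, axis=-1).max(axis=-1)
    # Keep structures between the bounds: below the lower bound the ensemble
    # already agrees; above the upper bound the structure is likely unphysical.
    candidates = np.where((score >= error_lower_bound) & (score <= error_upper_bound))[0]
    # Take the most uncertain candidates first (one reasonable tie-break).
    return candidates[np.argsort(score[candidates])[::-1]][:sample_size]
```

With `relative_bound_on_validation: true`, the two bounds would instead be derived by multiplying a validation-set force error by `relative_error_lower_bound` and `relative_error_upper_bound`.
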
enerzyme/config/predict.yaml: 8 changes (6 additions, 2 deletions)

```yaml
Datahub: # Datahub configurations in the training configuration file will be overridden
  # except the neighbor list and transforms, to ensure consistency
  data_path: "/my/dataset.pkl" # external test set; if not provided, the dataset will be loaded from the training configuration file
  data_format: pickle
  features:
    Ra: coord
  # ...
Metric: # Metric used for this prediction
  Qa:
    rmse: 1
  E:
    rmse: 1
Trainer:
  non_target_features: # non-target features to be saved in the prediction artifact
    - E_var
    - Qa_var
```
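
A hypothetical sketch of what `non_target_features` implies when the prediction artifact is assembled: keep the fitted targets plus any requested extra outputs, such as the ensemble variances `E_var` and `Qa_var` (function and dictionary names are assumptions, not Enerzyme's code):

```python
import pickle

def save_artifact(predictions: dict, targets: list, non_target_features: list, path: str):
    """Write fitted targets plus requested extras (e.g. uncertainty columns)
    to the prediction artifact."""
    artifact = {k: v for k, v in predictions.items()
                if k in targets or k in non_target_features}
    with open(path, "wb") as f:
        pickle.dump(artifact, f)

save_artifact({"E": [-1.0], "E_var": [0.02], "Qa": [[0.1]], "Qa_var": [[1e-4]]},
              targets=["E", "Qa"], non_target_features=["E_var", "Qa_var"],
              path="prediction.pkl")
```
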
enerzyme/config/train.yaml: 1 change (1 addition, 0 deletions)

```yaml
Trainer:
  cuda: true # if true, CUDA is searched for and used when possible
  weight_decay: 0 # weight decay rate of the Adam optimizer, default 0
  batch_size: 8 # batch size of the dataloader, default 8
  inference_batch_size: 4 # batch size of the dataloader for inference, defaults to batch_size
  max_epochs: 10000 # maximum epochs of training if early stopping isn't triggered, default 1000
  dtype: float32 # pytorch data type throughout the computation: float32 (single) / float64 (double), default float32
  use_ema: true # if exponential moving average is used, default true
```
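
The `inference_batch_size` default described above amounts to a simple fallback; a sketch, assuming a parsed `Trainer` config dict (names hypothetical):

```python
def resolve_batch_sizes(trainer_cfg: dict):
    batch_size = trainer_cfg.get("batch_size", 8)  # documented default: 8
    # inference_batch_size falls back to batch_size when not given.
    return batch_size, trainer_cfg.get("inference_batch_size", batch_size)

assert resolve_batch_sizes({"batch_size": 8}) == (8, 8)
assert resolve_batch_sizes({"batch_size": 8, "inference_batch_size": 4}) == (8, 4)
```
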
enerzyme/data/datatype.py: 8 changes (7 additions, 1 deletion)

```python
def is_idx(k):
    ...  # unchanged context; body elided in the diff

def is_target(k):
    return bool(DATA_TYPES.get(k, 0) & IS_TARGET)

def is_target_uq(k):
    # Uncertainty columns are named <target>_var or <target>_std; strip the
    # 4-character suffix and check whether the base name is a fitting target.
    if k.endswith("_var") or k.endswith("_std"):
        target = k[:-4]
        return is_target(target)
    return False

def get_tensor_rank(k):
    return bool(DATA_TYPES.get(k, 0) >> TENSOR_RANK_BIT)

__all__ = ["is_int", "is_rounded", "is_atomic", "requires_grad", "is_idx", "get_tensor_rank", "is_target", "is_target_uq"]
```
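
Illustrative behavior of the new helper, assuming `DATA_TYPES` registers `E` and `Qa` as targets (consistent with the `E_var`/`Qa_var` features in predict.yaml above) but not `Ra`:

```python
from enerzyme.data.datatype import is_target_uq

print(is_target_uq("E_var"))   # True, provided "E" is a registered target
print(is_target_uq("Qa_std"))  # True, provided "Qa" is a registered target
print(is_target_uq("Ra_var"))  # False: "Ra" (coordinates) is not a fitting target
print(is_target_uq("E"))       # False: a plain target is not an uncertainty column
```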