More changes. #1

Merged
merged 8 commits on Jun 20, 2024
2 changes: 1 addition & 1 deletion README.md
@@ -22,7 +22,7 @@ Preprocess the data by running the following command. Pass in the path to the da
python scripts/preprocess.py -p <path_to_saved_file> -d <folder_path_to_save_outputs>
```

## Training
## Training
We provide our configs for training on the GEOM-DRUGS and the GEOM-QM9 datasets. Once the datasets are preprocessed and the environment is set up, run the following commands:

```bash
Expand Down
105 changes: 105 additions & 0 deletions configs/drugs-base.yaml
@@ -0,0 +1,105 @@
# task name for logging
task_name: flow-drugs/base

# unique seed for experiment reproducibility
seed: 42

# data config
datamodule: BaseDataModule
datamodule_args:
  dataset: EuclideanDataset
  dataset_args:
    dataset_name: geom
    use_ogb_feat: true

  train_indices_path: /nfs/scratch/students/data/geom/preprocessed/drugs_train_0.10.npy
  val_indices_path: /nfs/scratch/students/data/geom/preprocessed/drugs_val_0.10.npy
  test_indices_path: /nfs/scratch/students/data/geom/preprocessed/drugs_val_0.1.npy

  # dataloader args
  dataloader_args:
    batch_size: 48
    num_workers: 4
    pin_memory: false
    persistent_workers: true

# model config
model: BaseFlow
model_args:
  # network args
  network_type: TorchMDDynamics
  hidden_channels: 160
  num_layers: 20
  num_rbf: 64
  rbf_type: expnorm
  trainable_rbf: true
  activation: silu
  neighbor_embedding: true
  cutoff_lower: 0.0
  cutoff_upper: 10.0
  max_z: 100
  node_attr_dim: 10
  edge_attr_dim: 1
  attn_activation: silu
  num_heads: 8
  distance_influence: both
  reduce_op: sum
  qk_norm: true
  so3_equivariant: true
  clip_during_norm: true

  # flow matching specific
  normalize_node_invariants: false
  sigma: 0.1
  prior_type: harmonic
  interpolation_type: linear

  # optimizer args
  optimizer_type: AdamW
  lr: 8.e-4
  weight_decay: 1.e-8

  # lr scheduler args
  lr_scheduler_type: CosineAnnealingWarmupRestarts
  first_cycle_steps: 500_000
  cycle_mult: 1.0
  max_lr: 8.e-4
  min_lr: 1.e-8
  warmup_steps: 0
  gamma: 0.05
  last_epoch: -1
  lr_scheduler_monitor: val/loss
  lr_scheduler_interval: step
  lr_scheduler_frequency: 1

# callbacks
callbacks:
  - callback: ModelCheckpoint
    callback_args:
      dirpath: './checkpoint'
      monitor: val/loss
      mode: min
      save_last: true
      every_n_epochs: 1
      save_top_k: 3

  - callback: LearningRateMonitor
    callback_args:
      log_momentum: false
      logging_interval: null


# logger
logger: WandbLogger
logger_args:
  project: Energy-Aware-MCG
  entity: doms-lab

# trainer
trainer: Trainer
trainer_args:
  max_epochs: 100
  devices: 8
  limit_train_batches: 5000
  strategy: ddp_find_unused_parameters_true
  accelerator: auto
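As a rough sanity check on these numbers (a sketch only, assuming PyYAML is available and the key nesting shown above; the repo's own training entry point may read the file differently): 100 epochs capped at 5,000 train batches each gives 500,000 optimizer steps, which is exactly the scheduler's `first_cycle_steps`, so the cosine cycle spans one full run.

```python
import yaml  # PyYAML

# Load the config added in this PR (path relative to the repo root).
with open("configs/drugs-base.yaml") as f:
    cfg = yaml.safe_load(f)

# The LR schedule is stepped per optimizer step (lr_scheduler_interval: step),
# so one cosine cycle should cover the whole run:
steps_per_epoch = cfg["trainer_args"]["limit_train_batches"]        # 5000
total_steps = cfg["trainer_args"]["max_epochs"] * steps_per_epoch   # 100 * 5000
assert total_steps == cfg["model_args"]["first_cycle_steps"]        # 500_000
```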
106 changes: 106 additions & 0 deletions configs/qm9-base.yaml
@@ -0,0 +1,106 @@
# task name for logging
task_name: flow-qm9/base

# unique seed for experiment reproducibility
seed: 42

# data config
datamodule: BaseDataModule
datamodule_args:
  dataset: EuclideanDataset
  dataset_args:
    dataset_name: geom
    use_ogb_feat: true

  train_indices_path: /nfs/scratch/students/data/geom/preprocessed/qm9_train_0.9.npy
  val_indices_path: /nfs/scratch/students/data/geom/preprocessed/qm9_val_0.9.npy
  test_indices_path: /nfs/scratch/students/data/geom/preprocessed/drugs_val_0.1.npy

  # dataloader args
  dataloader_args:
    batch_size: 128
    num_workers: 4
    pin_memory: false
    persistent_workers: true

# model config
model: BaseFlow
model_args:
  # network args
  network_type: TorchMDDynamics
  hidden_channels: 160
  num_layers: 20
  num_rbf: 64
  rbf_type: expnorm
  trainable_rbf: true
  activation: silu
  neighbor_embedding: true
  cutoff_lower: 0.0
  cutoff_upper: 10.0
  max_z: 100
  node_attr_dim: 8
  edge_attr_dim: 1
  attn_activation: silu
  num_heads: 8
  distance_influence: both
  reduce_op: sum
  qk_norm: true
  clip_during_norm: true
  so3_equivariant: true
  output_layer_norm: false

  # flow matching specific
  normalize_node_invariants: false
  sigma: 0.1
  prior_type: harmonic
  interpolation_type: linear

  # optimizer args
  optimizer_type: AdamW
  lr: 7.e-4
  weight_decay: 1.e-8

  # lr scheduler args
  lr_scheduler_type: CosineAnnealingWarmupRestarts
  first_cycle_steps: 375_000
  cycle_mult: 1.0
  max_lr: 7.e-4
  min_lr: 1.e-8
  warmup_steps: 0
  gamma: 0.05
  last_epoch: -1
  lr_scheduler_monitor: val/loss
  lr_scheduler_interval: step
  lr_scheduler_frequency: 1

# callbacks
callbacks:
  - callback: ModelCheckpoint
    callback_args:
      dirpath: './checkpoint'
      monitor: val/loss
      mode: min
      save_last: true
      every_n_epochs: 1
      save_top_k: 3

  - callback: LearningRateMonitor
    callback_args:
      log_momentum: false
      logging_interval: null


# logger
logger: WandbLogger
logger_args:
  project: Energy-Aware-MCG
  entity: doms-lab

# trainer
trainer: Trainer
trainer_args:
  max_epochs: 250
  devices: 4
  limit_train_batches: 1500
  strategy: ddp_find_unused_parameters_true
  accelerator: auto
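The same bookkeeping holds for this config: 250 epochs at 1,500 train batches per epoch gives 375,000 steps, matching `first_cycle_steps`, so the QM9 schedule also decays over exactly one run. Beyond that, the QM9 config differs from the DRUGS one mainly in `node_attr_dim` (8 vs. 10), the added `output_layer_norm: false`, the learning rate (7e-4 vs. 8e-4), the per-device batch size (128 vs. 48), and the device count (4 vs. 8).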
8 changes: 4 additions & 4 deletions env.yml
@@ -21,13 +21,13 @@ dependencies:

# Chem
- datamol
- rdkit #-pypi #==2022.9.3
- rdkit

# ML
- einops =0.6.0
- pytorch =2.0.1
- pytorch ==2.1.0
- pytorch-cuda =11.8
- lightning =2.0.1
- lightning ==2.1.0
- torchmetrics =1.2.0
- pyg == 2.4.0
- pytorch-sparse == 0.6.18
@@ -47,4 +47,4 @@ dependencies:
- black
- jupyterlab
- pre-commit
- ruff
- ruff
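These pins move `pytorch` and `lightning` from 2.0.1 to 2.1.0 and drop the stale comment on the `rdkit` line. After recreating the environment (typically `conda env create -f env.yml`, or `conda env update -f env.yml --prune` for an existing one), a quick check that the resolved versions match the new pins might look like the sketch below; the import names are the usual ones for these packages and are an assumption, not something this PR exercises.

```python
# Print installed versions next to the pins from env.yml
# (pytorch ==2.1.0, lightning ==2.1.0, torchmetrics =1.2.0, pyg ==2.4.0).
import lightning
import torch
import torch_geometric
import torchmetrics

for name, installed, pinned in [
    ("torch", torch.__version__, "2.1.0"),
    ("lightning", lightning.__version__, "2.1.0"),
    ("torchmetrics", torchmetrics.__version__, "1.2.0"),
    ("torch_geometric", torch_geometric.__version__, "2.4.0"),
]:
    print(f"{name}: installed {installed}, pinned {pinned}")
```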
20 changes: 12 additions & 8 deletions etflow/commons/__init__.py
@@ -1,5 +1,14 @@
from .chirality import get_chiral_tensors, signed_volume
from .edge import compute_edge_index, extend_graph_order_radius
from .covmat import build_conformer
from .featurization import (
    MoleculeFeaturizer,
    atom_to_feature_vector,
    bond_to_feature_vector,
    compute_edge_index,
    extend_graph_order_radius,
    get_atomic_number_and_charge,
    get_chiral_tensors,
    signed_volume,
)
from .io import (
    get_local_cache,
    load_hdf5,
@@ -10,18 +19,13 @@
    save_memmap,
    save_pkl,
)
from .covmat import build_conformer
from .featurization import (
    atom_to_feature_vector,
    get_atomic_number_and_charge,
    bond_to_feature_vector
)
from .sample import batched_sampling
from .utils import Queue

__all__ = [
"atom_to_feature_vector",
"bond_to_feature_vector",
"MoleculeFeaturizer",
"Queue",
"load_json",
"load_pkl",
86 changes: 0 additions & 86 deletions etflow/commons/chirality.py

This file was deleted.
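With `chirality.py` deleted and its helpers (together with the edge utilities) now re-exported through `featurization`, downstream code can keep importing from the package root. An illustrative import, assuming the `etflow` package is installed; the names below are exactly those re-exported by the updated `__init__.py`:

```python
# All of these resolve through etflow/commons/__init__.py after this PR,
# so callers do not need to know which submodule implements them.
from etflow.commons import (
    MoleculeFeaturizer,
    atom_to_feature_vector,
    bond_to_feature_vector,
    build_conformer,
    compute_edge_index,
    get_chiral_tensors,
    signed_volume,
)
```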
