Commit

Merge pull request #1 from shenoynikhil/checks
More changes.
shenoynikhil authored Jun 20, 2024
2 parents 3c6d68d + 485f89a commit 82a4515
Showing 20 changed files with 625 additions and 880 deletions.
2 changes: 1 addition & 1 deletion README.md
@@ -22,7 +22,7 @@ Preprocess the data by running the following command. Pass in the path to the da
python scripts/preprocess.py -p <path_to_saved_file> -d <folder_path_to_save_outputs>
```

## Training
We provide our configs for training on the GEOM-DRUGS and the GEOM-QM9 datasets. Run the following commands once datasets are preprocessed and the environment is set up:

```bash
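# NOTE: the actual training commands are truncated in this diff view; the two
# lines below are only a hypothetical illustration, assuming a train script
# that takes one of the configs added in this commit via a --config flag.
python train.py --config configs/drugs-base.yaml   # GEOM-DRUGS
python train.py --config configs/qm9-base.yaml     # GEOM-QM9
```
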
105 changes: 105 additions & 0 deletions configs/drugs-base.yaml
@@ -0,0 +1,105 @@
# task name for logging
task_name: flow-drugs/base

# unique seed for experiment reproducibility
seed: 42

# data config
datamodule: BaseDataModule
datamodule_args:
  dataset: EuclideanDataset
  dataset_args:
    dataset_name: geom
    use_ogb_feat: true

  train_indices_path: /nfs/scratch/students/data/geom/preprocessed/drugs_train_0.10.npy
  val_indices_path: /nfs/scratch/students/data/geom/preprocessed/drugs_val_0.10.npy
  test_indices_path: /nfs/scratch/students/data/geom/preprocessed/drugs_val_0.1.npy

  # dataloader args
  dataloader_args:
    batch_size: 48
    num_workers: 4
    pin_memory: false
    persistent_workers: true

# model config
model: BaseFlow
model_args:
  # network args
  network_type: TorchMDDynamics
  hidden_channels: 160
  num_layers: 20
  num_rbf: 64
  rbf_type: expnorm
  trainable_rbf: true
  activation: silu
  neighbor_embedding: true
  cutoff_lower: 0.0
  cutoff_upper: 10.0
  max_z: 100
  node_attr_dim: 10
  edge_attr_dim: 1
  attn_activation: silu
  num_heads: 8
  distance_influence: both
  reduce_op: sum
  qk_norm: true
  so3_equivariant: true
  clip_during_norm: true

  # flow matching specific
  normalize_node_invariants: false
  sigma: 0.1
  prior_type: harmonic
  interpolation_type: linear

  # optimizer args
  optimizer_type: AdamW
  lr: 8.e-4
  weight_decay: 1.e-8

  # lr scheduler args
  lr_scheduler_type: CosineAnnealingWarmupRestarts
  first_cycle_steps: 500_000
  cycle_mult: 1.0
  max_lr: 8.e-4
  min_lr: 1.e-8
  warmup_steps: 0
  gamma: 0.05
  last_epoch: -1
  lr_scheduler_monitor: val/loss
  lr_scheduler_interval: step
  lr_scheduler_frequency: 1

# callbacks
callbacks:
  - callback: ModelCheckpoint
    callback_args:
      dirpath: './checkpoint'
      monitor: val/loss
      mode: min
      save_last: true
      every_n_epochs: 1
      save_top_k: 3

  - callback: LearningRateMonitor
    callback_args:
      log_momentum: false
      logging_interval: null

# logger
logger: WandbLogger
logger_args:
  project: Energy-Aware-MCG
  entity: doms-lab

# trainer
trainer: Trainer
trainer_args:
  max_epochs: 100
  devices: 8
  limit_train_batches: 5000
  strategy: ddp_find_unused_parameters_true
  accelerator: auto
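
The training entrypoint itself is not part of this diff, but the bottom half of this config maps almost one-to-one onto PyTorch Lightning objects. Below is a minimal sketch of how the `callbacks` and `trainer_args` sections could be consumed; the loader shown is an illustration, not the repository's actual builder:

```python
import yaml
from lightning.pytorch import Trainer
from lightning.pytorch.callbacks import LearningRateMonitor, ModelCheckpoint

with open("configs/drugs-base.yaml") as f:
    cfg = yaml.safe_load(f)

# Instantiate each callback listed under `callbacks` from its name and args.
callback_classes = {
    "ModelCheckpoint": ModelCheckpoint,
    "LearningRateMonitor": LearningRateMonitor,
}
callbacks = [
    callback_classes[entry["callback"]](**entry["callback_args"])
    for entry in cfg["callbacks"]
]

# `trainer_args` passes straight through as Lightning Trainer keyword arguments.
trainer = Trainer(callbacks=callbacks, **cfg["trainer_args"])
```

The `datamodule`/`model` entries would be resolved analogously by whatever registry the project uses to map the names `BaseDataModule` and `BaseFlow` to classes.
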
106 changes: 106 additions & 0 deletions configs/qm9-base.yaml
@@ -0,0 +1,106 @@
# task name for logging
task_name: flow-qm9/base

# unique seed for experiment reproducibility
seed: 42

# data config
datamodule: BaseDataModule
datamodule_args:
  dataset: EuclideanDataset
  dataset_args:
    dataset_name: geom
    use_ogb_feat: true

  train_indices_path: /nfs/scratch/students/data/geom/preprocessed/qm9_train_0.9.npy
  val_indices_path: /nfs/scratch/students/data/geom/preprocessed/qm9_val_0.9.npy
  test_indices_path: /nfs/scratch/students/data/geom/preprocessed/drugs_val_0.1.npy

  # dataloader args
  dataloader_args:
    batch_size: 128
    num_workers: 4
    pin_memory: false
    persistent_workers: true

# model config
model: BaseFlow
model_args:
  # network args
  network_type: TorchMDDynamics
  hidden_channels: 160
  num_layers: 20
  num_rbf: 64
  rbf_type: expnorm
  trainable_rbf: true
  activation: silu
  neighbor_embedding: true
  cutoff_lower: 0.0
  cutoff_upper: 10.0
  max_z: 100
  node_attr_dim: 8
  edge_attr_dim: 1
  attn_activation: silu
  num_heads: 8
  distance_influence: both
  reduce_op: sum
  qk_norm: true
  clip_during_norm: true
  so3_equivariant: true
  output_layer_norm: false

  # flow matching specific
  normalize_node_invariants: false
  sigma: 0.1
  prior_type: harmonic
  interpolation_type: linear

  # optimizer args
  optimizer_type: AdamW
  lr: 7.e-4
  weight_decay: 1.e-8

  # lr scheduler args
  lr_scheduler_type: CosineAnnealingWarmupRestarts
  first_cycle_steps: 375_000
  cycle_mult: 1.0
  max_lr: 7.e-4
  min_lr: 1.e-8
  warmup_steps: 0
  gamma: 0.05
  last_epoch: -1
  lr_scheduler_monitor: val/loss
  lr_scheduler_interval: step
  lr_scheduler_frequency: 1

# callbacks
callbacks:
  - callback: ModelCheckpoint
    callback_args:
      dirpath: './checkpoint'
      monitor: val/loss
      mode: min
      save_last: true
      every_n_epochs: 1
      save_top_k: 3

  - callback: LearningRateMonitor
    callback_args:
      log_momentum: false
      logging_interval: null

# logger
logger: WandbLogger
logger_args:
  project: Energy-Aware-MCG
  entity: doms-lab

# trainer
trainer: Trainer
trainer_args:
  max_epochs: 250
  devices: 4
  limit_train_batches: 1500
  strategy: ddp_find_unused_parameters_true
  accelerator: auto
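
The `lr_scheduler_type: CosineAnnealingWarmupRestarts` block uses the same parameter names as the widely used `cosine_annealing_warmup` package (katsura-jp's implementation). Assuming that scheduler, wiring it to the configured AdamW optimizer would look roughly like the sketch below; the `Linear` module is only a stand-in for the flow network:

```python
import torch
from cosine_annealing_warmup import CosineAnnealingWarmupRestarts  # assumed implementation

model = torch.nn.Linear(16, 16)  # placeholder standing in for BaseFlow's network
optimizer = torch.optim.AdamW(model.parameters(), lr=7e-4, weight_decay=1e-8)

# Parameters copied from the lr scheduler section of the config above.
scheduler = CosineAnnealingWarmupRestarts(
    optimizer,
    first_cycle_steps=375_000,
    cycle_mult=1.0,
    max_lr=7e-4,
    min_lr=1e-8,
    warmup_steps=0,
    gamma=0.05,
    last_epoch=-1,
)

# lr_scheduler_interval: step -> the scheduler is advanced after every optimizer step.
for step in range(3):
    optimizer.step()
    scheduler.step()
```
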
8 changes: 4 additions & 4 deletions env.yml
@@ -21,13 +21,13 @@ dependencies:

# Chem
- datamol
- rdkit #-pypi #==2022.9.3
- rdkit

# ML
- einops =0.6.0
- pytorch =2.0.1
- pytorch ==2.1.0
- pytorch-cuda =11.8
- lightning =2.0.1
- lightning ==2.1.0
- torchmetrics =1.2.0
- pyg == 2.4.0
- pytorch-sparse == 0.6.18
@@ -47,4 +47,4 @@ dependencies:
- black
- jupyterlab
- pre-commit
- ruff
20 changes: 12 additions & 8 deletions etflow/commons/__init__.py
@@ -1,5 +1,14 @@
from .chirality import get_chiral_tensors, signed_volume
from .edge import compute_edge_index, extend_graph_order_radius
from .covmat import build_conformer
from .featurization import (
    MoleculeFeaturizer,
    atom_to_feature_vector,
    bond_to_feature_vector,
    compute_edge_index,
    extend_graph_order_radius,
    get_atomic_number_and_charge,
    get_chiral_tensors,
    signed_volume,
)
from .io import (
    get_local_cache,
    load_hdf5,
@@ -10,18 +19,13 @@
    save_memmap,
    save_pkl,
)
from .covmat import build_conformer
from .featurization import (
    atom_to_feature_vector,
    get_atomic_number_and_charge,
    bond_to_feature_vector
)
from .sample import batched_sampling
from .utils import Queue

__all__ = [
    "atom_to_feature_vector",
    "bond_to_feature_vector",
    "MoleculeFeaturizer",
    "Queue",
    "load_json",
    "load_pkl",
86 changes: 0 additions & 86 deletions etflow/commons/chirality.py

This file was deleted.
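
The deletion is consistent with the `__init__.py` change above: `get_chiral_tensors` and `signed_volume` (along with the edge helpers) are now re-exported from `.featurization`, so code importing them through the package namespace should be unaffected. As an illustrative import only, based on the symbols visible in the new `__init__.py`:

```python
# These names are re-exported by etflow/commons/__init__.py after this change.
from etflow.commons import (
    MoleculeFeaturizer,
    compute_edge_index,
    get_chiral_tensors,
    signed_volume,
)
```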
