Skip to content

Commit

Permalink
Merge pull request #2 from shenoynikhil/so3-variants
Browse files Browse the repository at this point in the history
[Do not Merge] Different SO3 Variant
  • Loading branch information
shenoynikhil authored Jun 24, 2024
2 parents 82a4515 + a641dc6 commit 4c9b46b
Show file tree
Hide file tree
Showing 10 changed files with 510 additions and 517 deletions.
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ wandb/
logs/
output/
results/
notebooks/

# cache
.pytest_cache/
Expand Down
6 changes: 3 additions & 3 deletions configs/drugs-base.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -63,7 +63,7 @@ model_args:
lr_scheduler_type: CosineAnnealingWarmupRestarts
first_cycle_steps: 500_000
cycle_mult: 1.0
max_lr: 8.e-4
max_lr: 5.e-4
min_lr: 1.e-8
warmup_steps: 0
gamma: 0.05
Expand Down Expand Up @@ -98,8 +98,8 @@ logger_args:
# trainer
trainer: Trainer
trainer_args:
max_epochs: 100
max_epochs: 200
devices: 8
limit_train_batches: 5000
strategy: ddp_find_unused_parameters_true
strategy: ddp
accelerator: auto
10 changes: 5 additions & 5 deletions configs/qm9-base.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@ model_args:
cutoff_lower: 0.0
cutoff_upper: 10.0
max_z: 100
node_attr_dim: 8
node_attr_dim: 10
edge_attr_dim: 1
attn_activation: silu
num_heads: 8
Expand All @@ -47,12 +47,12 @@ model_args:
qk_norm: true
clip_during_norm: true
so3_equivariant: true
output_layer_norm: false
output_layer_norm: true

# flow matching specific
normalize_node_invariants: false
sigma: 0.1
prior_type: harmonic
prior_type: gaussian
interpolation_type: linear

# optimizer args
Expand Down Expand Up @@ -99,8 +99,8 @@ logger_args:
# trainer
trainer: Trainer
trainer_args:
max_epochs: 250
max_epochs: 500
devices: 4
limit_train_batches: 1500
strategy: ddp_find_unused_parameters_true
strategy: ddp
accelerator: auto
22 changes: 7 additions & 15 deletions etflow/commons/__init__.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,5 @@
from .covmat import build_conformer
from .featurization import (
MoleculeFeaturizer,
atom_to_feature_vector,
bond_to_feature_vector,
compute_edge_index,
extend_graph_order_radius,
get_atomic_number_and_charge,
get_chiral_tensors,
signed_volume,
)
from .featurization import MoleculeFeaturizer
from .io import (
get_local_cache,
load_hdf5,
Expand All @@ -20,11 +11,14 @@
save_pkl,
)
from .sample import batched_sampling
from .utils import Queue
from .utils import (
Queue,
extend_graph_order_radius,
get_atomic_number_and_charge,
signed_volume,
)

__all__ = [
"atom_to_feature_vector",
"bond_to_feature_vector",
"MoleculeFeaturizer",
"Queue",
"load_json",
Expand All @@ -34,10 +28,8 @@
"load_memmap",
"load_hdf5",
"save_memmap",
"get_chiral_tensors",
"get_local_cache",
"get_atomic_number_and_charge",
"compute_edge_index",
"build_conformer",
"extend_graph_order_radius",
"batched_sampling",
Expand Down
Loading

0 comments on commit 4c9b46b

Please sign in to comment.