[Do not Merge] Different SO3 Variant #2

Merged (4 commits) on Jun 24, 2024
.gitignore — 1 change: 1 addition & 0 deletions

@@ -9,6 +9,7 @@ wandb/
 logs/
 output/
 results/
+notebooks/

 # cache
 .pytest_cache/
configs/drugs-base.yaml — 6 changes: 3 additions & 3 deletions

@@ -63,7 +63,7 @@ model_args:
   lr_scheduler_type: CosineAnnealingWarmupRestarts
   first_cycle_steps: 500_000
   cycle_mult: 1.0
-  max_lr: 8.e-4
+  max_lr: 5.e-4
   min_lr: 1.e-8
   warmup_steps: 0
   gamma: 0.05
@@ -98,8 +98,8 @@ logger_args:
 # trainer
 trainer: Trainer
 trainer_args:
-  max_epochs: 100
+  max_epochs: 200
   devices: 8
   limit_train_batches: 5000
-  strategy: ddp_find_unused_parameters_true
+  strategy: ddp
   accelerator: auto
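The lowered `max_lr` can be sanity-checked against the other scheduler fields with a minimal pure-Python sketch of a cosine-annealing-with-restarts schedule. This is an illustrative re-implementation derived only from the config field names (`first_cycle_steps`, `cycle_mult`, `max_lr`, `min_lr`, `warmup_steps`, `gamma`), not the actual `CosineAnnealingWarmupRestarts` class; it assumes `cycle_mult: 1.0` (fixed-length cycles) and that `gamma` decays the peak learning rate at each restart.

```python
import math

def cosine_annealing_lr(step, first_cycle_steps=500_000, cycle_mult=1.0,
                        max_lr=5e-4, min_lr=1e-8, warmup_steps=0, gamma=0.05):
    """Hypothetical sketch of a cosine-annealing-with-warmup-restarts
    schedule using the drugs-base.yaml values; the real library class
    may differ in details. Assumes cycle_mult == 1.0 (equal-length cycles)."""
    cycle = int(step // first_cycle_steps)
    step_in_cycle = step - cycle * first_cycle_steps
    # assumption: each restart scales the peak LR down by gamma
    peak = max_lr * (gamma ** cycle)
    if step_in_cycle < warmup_steps:
        # linear warmup from min_lr up to this cycle's peak
        return min_lr + (peak - min_lr) * step_in_cycle / warmup_steps
    # cosine decay from the peak back down to min_lr over the cycle
    progress = (step_in_cycle - warmup_steps) / (first_cycle_steps - warmup_steps)
    return min_lr + (peak - min_lr) * 0.5 * (1.0 + math.cos(math.pi * progress))
```

With `warmup_steps: 0` the schedule starts at the new `max_lr` of 5e-4, reaches roughly half of it mid-cycle, and restarts at `max_lr * gamma` after 500k steps.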
configs/qm9-base.yaml — 10 changes: 5 additions & 5 deletions

@@ -38,7 +38,7 @@ model_args:
   cutoff_lower: 0.0
   cutoff_upper: 10.0
   max_z: 100
-  node_attr_dim: 8
+  node_attr_dim: 10
   edge_attr_dim: 1
   attn_activation: silu
   num_heads: 8
@@ -47,12 +47,12 @@ model_args:
   qk_norm: true
   clip_during_norm: true
   so3_equivariant: true
-  output_layer_norm: false
+  output_layer_norm: true

   # flow matching specific
   normalize_node_invariants: false
   sigma: 0.1
-  prior_type: harmonic
+  prior_type: gaussian
   interpolation_type: linear

 # optimizer args
@@ -99,8 +99,8 @@ logger_args:
 # trainer
 trainer: Trainer
 trainer_args:
-  max_epochs: 250
+  max_epochs: 500
   devices: 4
   limit_train_batches: 1500
-  strategy: ddp_find_unused_parameters_true
+  strategy: ddp
   accelerator: auto
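The `prior_type: harmonic` → `gaussian` switch replaces a structured (graph-dependent) prior with the simplest possible one. A hedged sketch of what sampling from an isotropic Gaussian coordinate prior can look like is below; the function name, the zero-center-of-mass convention, and the use of a standalone `sigma` argument are all illustrative assumptions, not etflow's actual implementation.

```python
import random

def sample_gaussian_prior(num_atoms, sigma=1.0, seed=None):
    # Hypothetical sketch of a `prior_type: gaussian` sampler: draw
    # i.i.d. isotropic Gaussian 3D coordinates per atom, then subtract
    # the mean so the sample has zero center of mass (a common convention
    # for molecular priors; etflow's actual code may differ).
    rng = random.Random(seed)
    coords = [[rng.gauss(0.0, sigma) for _ in range(3)]
              for _ in range(num_atoms)]
    means = [sum(c[d] for c in coords) / num_atoms for d in range(3)]
    return [[c[d] - means[d] for d in range(3)] for c in coords]
```

A harmonic prior, by contrast, couples bonded atoms (e.g. via the graph Laplacian), so samples already look loosely molecule-shaped; the Gaussian prior leaves all of that work to the flow.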
etflow/commons/__init__.py — 22 changes: 7 additions & 15 deletions

@@ -1,14 +1,5 @@
 from .covmat import build_conformer
-from .featurization import (
-    MoleculeFeaturizer,
-    atom_to_feature_vector,
-    bond_to_feature_vector,
-    compute_edge_index,
-    extend_graph_order_radius,
-    get_atomic_number_and_charge,
-    get_chiral_tensors,
-    signed_volume,
-)
+from .featurization import MoleculeFeaturizer
 from .io import (
     get_local_cache,
     load_hdf5,
@@ -20,11 +11,14 @@
     save_pkl,
 )
 from .sample import batched_sampling
-from .utils import Queue
+from .utils import (
+    Queue,
+    extend_graph_order_radius,
+    get_atomic_number_and_charge,
+    signed_volume,
+)

 __all__ = [
-    "atom_to_feature_vector",
-    "bond_to_feature_vector",
     "MoleculeFeaturizer",
     "Queue",
     "load_json",
@@ -34,10 +28,8 @@
     "load_memmap",
     "load_hdf5",
     "save_memmap",
-    "get_chiral_tensors",
     "get_local_cache",
     "get_atomic_number_and_charge",
-    "compute_edge_index",
     "build_conformer",
     "extend_graph_order_radius",
     "batched_sampling",
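The `__init__.py` change above moves `extend_graph_order_radius`, `get_atomic_number_and_charge`, and `signed_volume` from `featurization` to `utils` while keeping the package-level import path stable via re-exports. The pattern can be demonstrated self-contained with a throwaway package (`minipkg` is a hypothetical name used only for this sketch, not part of etflow):

```python
import importlib
import os
import sys
import tempfile

# Build a tiny package on disk: utils.py now owns signed_volume
# (as if it had just been moved out of featurization.py).
tmp = tempfile.mkdtemp()
pkg = os.path.join(tmp, "minipkg")
os.makedirs(pkg)

with open(os.path.join(pkg, "utils.py"), "w") as f:
    f.write("def signed_volume(x):\n    return x\n")

# __init__.py re-exports it, so `from minipkg import signed_volume`
# keeps working no matter which submodule the function lives in.
with open(os.path.join(pkg, "__init__.py"), "w") as f:
    f.write("from .utils import signed_volume\n"
            "__all__ = ['signed_volume']\n")

sys.path.insert(0, tmp)
minipkg = importlib.import_module("minipkg")
print(minipkg.signed_volume(3))  # -> 3
```

This is why the PR's callers only need `from etflow.commons import signed_volume` and are unaffected by the featurization/utils reshuffle, while helpers dropped from `__all__` (like `compute_edge_index`) are no longer part of the public surface.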