diff --git a/.gitignore b/.gitignore index ed9f4aa..64aa1d4 100644 --- a/.gitignore +++ b/.gitignore @@ -9,6 +9,7 @@ wandb/ logs/ output/ results/ +notebooks/ # cache .pytest_cache/ diff --git a/configs/drugs-base.yaml b/configs/drugs-base.yaml index 49c0b33..6194714 100644 --- a/configs/drugs-base.yaml +++ b/configs/drugs-base.yaml @@ -63,7 +63,7 @@ model_args: lr_scheduler_type: CosineAnnealingWarmupRestarts first_cycle_steps: 500_000 cycle_mult: 1.0 - max_lr: 8.e-4 + max_lr: 5.e-4 min_lr: 1.e-8 warmup_steps: 0 gamma: 0.05 @@ -98,8 +98,8 @@ logger_args: # trainer trainer: Trainer trainer_args: - max_epochs: 100 + max_epochs: 200 devices: 8 limit_train_batches: 5000 - strategy: ddp_find_unused_parameters_true + strategy: ddp accelerator: auto diff --git a/configs/qm9-base.yaml b/configs/qm9-base.yaml index 94d8b25..bbc4ee6 100644 --- a/configs/qm9-base.yaml +++ b/configs/qm9-base.yaml @@ -38,7 +38,7 @@ model_args: cutoff_lower: 0.0 cutoff_upper: 10.0 max_z: 100 - node_attr_dim: 8 + node_attr_dim: 10 edge_attr_dim: 1 attn_activation: silu num_heads: 8 @@ -47,12 +47,12 @@ model_args: qk_norm: true clip_during_norm: true so3_equivariant: true - output_layer_norm: false + output_layer_norm: true # flow matching specific normalize_node_invariants: false sigma: 0.1 - prior_type: harmonic + prior_type: gaussian interpolation_type: linear # optimizer args @@ -99,8 +99,8 @@ logger_args: # trainer trainer: Trainer trainer_args: - max_epochs: 250 + max_epochs: 500 devices: 4 limit_train_batches: 1500 - strategy: ddp_find_unused_parameters_true + strategy: ddp accelerator: auto diff --git a/etflow/commons/__init__.py b/etflow/commons/__init__.py index 3e17543..1f7bd6a 100644 --- a/etflow/commons/__init__.py +++ b/etflow/commons/__init__.py @@ -1,14 +1,5 @@ from .covmat import build_conformer -from .featurization import ( - MoleculeFeaturizer, - atom_to_feature_vector, - bond_to_feature_vector, - compute_edge_index, - extend_graph_order_radius, - get_atomic_number_and_charge, - get_chiral_tensors, - signed_volume, -) +from .featurization import MoleculeFeaturizer from .io import ( get_local_cache, load_hdf5, @@ -20,11 +11,14 @@ save_pkl, ) from .sample import batched_sampling -from .utils import Queue +from .utils import ( + Queue, + extend_graph_order_radius, + get_atomic_number_and_charge, + signed_volume, +) __all__ = [ - "atom_to_feature_vector", - "bond_to_feature_vector", "MoleculeFeaturizer", "Queue", "load_json", @@ -34,10 +28,8 @@ "load_memmap", "load_hdf5", "save_memmap", - "get_chiral_tensors", "get_local_cache", "get_atomic_number_and_charge", - "compute_edge_index", "build_conformer", "extend_graph_order_radius", "batched_sampling", diff --git a/etflow/commons/featurization.py b/etflow/commons/featurization.py index 7acd971..a1f79d0 100644 --- a/etflow/commons/featurization.py +++ b/etflow/commons/featurization.py @@ -1,122 +1,85 @@ # allowable multiple choice node and edge features -from copy import deepcopy -from typing import Tuple +from collections import defaultdict +from typing import Callable, Tuple import datamol as dm -import numpy as np import torch from datamol.types import Mol -from rdkit import Chem -from rdkit.Chem.rdchem import BondType as BT -from rdkit.Chem.rdchem import ChiralType -from torch_cluster import radius_graph -from torch_geometric.utils import dense_to_sparse, to_dense_adj -from torch_sparse import coalesce from .covmat import build_conformer +from .utils import atom_to_feature_vector, compute_edge_index, get_chiral_tensors -# similar to GeoMol -BOND_TYPES = {t: i for i, t in enumerate(BT.names.values())} -chirality = { - ChiralType.CHI_TETRAHEDRAL_CW: -1.0, - ChiralType.CHI_TETRAHEDRAL_CCW: 1.0, - ChiralType.CHI_UNSPECIFIED: 0, - ChiralType.CHI_OTHER: 0, -} -allowable_features = { - "possible_atomic_num_list": list(range(1, 119)) + ["misc"], - "possible_chirality_list": [ - "CHI_UNSPECIFIED", - "CHI_TETRAHEDRAL_CW", - "CHI_TETRAHEDRAL_CCW", - "CHI_OTHER", - "misc", - ], - "possible_degree_list": [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, "misc"], - "possible_formal_charge_list": [-5, -4, -3, -2, -1, 0, 1, 2, 3, 4, 5, "misc"], - "possible_numH_list": [0, 1, 2, 3, 4, 5, 6, 7, 8, "misc"], - "possible_implicit_valence": [0, 1, 2, 3, 4, 5, 6, "misc"], - "possible_number_radical_e_list": [0, 1, 2, 3, 4, "misc"], - "possible_hybridization_list": ["SP", "SP2", "SP3", "SP3D", "SP3D2", "misc"], - "possible_is_aromatic_list": [False, True], - "possible_is_in_ring_list": [False, True], - "possible_bond_type_list": ["SINGLE", "DOUBLE", "TRIPLE", "AROMATIC", "misc"], - "possible_bond_stereo_list": [ - "STEREONONE", - "STEREOZ", - "STEREOE", - "STEREOCIS", - "STEREOTRANS", - "STEREOANY", - ], - "possible_is_conjugated_list": [False, True], -} +def cache_decorator(func: Callable): + """Decorator to handle caching logic.""" + + def wrapper(self, smiles: str, *args, **kwargs): + cache_key = func.__name__ + if smiles in self.cache and cache_key in self.cache[smiles]: + return self.cache[smiles][cache_key] + result = func(self, smiles, *args, **kwargs) + self.cache[smiles][cache_key] = result + return result + + return wrapper class MoleculeFeaturizer: """A Featurizer Class for Molecules. - Give smiles, get mol objects, atom features, bond features, etc. - - Caching to avoid recomputation. - - Parameters - ---------- - use_ogb_features: bool, default=True - If True, 10-dimensional atom features based on OGB are computed, - Otherwise, atomic charges are used. - use_edge_feat: bool, default=False - If True, edge features are computed. + - Smiles-based Caching to avoid recomputation. """ - def __init__(self, use_ogb_feat: bool = True, use_edge_feat: bool = False): + def __init__(self): # smiles based cache - self.cache = {} - self.use_ogb_feat = use_ogb_feat - self.use_edge_feat = use_edge_feat + self.cache = defaultdict(dict) def get_mol(self, smiles: str) -> Mol: return dm.to_mol(smiles, remove_hs=False, ordered=True) - def get_atom_features(self, smiles: str) -> torch.Tensor: - # check if cached - if smiles in self.cache and "atom_features" in self.cache[smiles]: - return self.cache[smiles]["atom_features"] - + @cache_decorator + def get_atom_features(self, smiles: str, use_ogb_feat: bool = True) -> torch.Tensor: # compute atom features mol = self.get_mol(smiles) + atom_features = self.get_atom_features_from_mol(mol, use_ogb_feat=use_ogb_feat) + return atom_features - if self.use_ogb_feat: + @cache_decorator + def get_atomic_numbers(self, smiles: str) -> torch.Tensor: + # compute atomic numbers + mol = self.get_mol(smiles) + atomic_numbers = self.get_atomic_numbers_from_mol(mol) + return atomic_numbers + + def get_atomic_numbers_from_mol(self, mol: Mol) -> torch.Tensor: + atomic_numbers = torch.tensor( + [atom.GetAtomicNum() for atom in mol.GetAtoms()], + dtype=torch.int32, + ) + return atomic_numbers + + def get_atom_features_from_mol( + self, mol: Mol, use_ogb_feat: bool = True + ) -> torch.Tensor: + if use_ogb_feat: atom_features = torch.tensor( [atom_to_feature_vector(atom) for atom in mol.GetAtoms()], dtype=torch.float32, - ) # (n_atoms, 10) + ) else: atom_features = torch.tensor( [atom.GetFormalCharge() for atom in mol.GetAtoms()], dtype=torch.float32, - ).view( - -1, 1 - ) # (n_atoms, 1) - - # add smiles to cache - if smiles not in self.cache: - self.cache[smiles] = {} - - self.cache[smiles]["atom_features"] = atom_features + ).view(-1, 1) return atom_features + @cache_decorator def get_chiral_centers(self, smiles: str) -> torch.Tensor: - # check if cached - if smiles in self.cache and "chiral_centers" in self.cache[smiles]: - return self.cache[smiles]["chiral_centers"] - # compute chiral centers mol = self.get_mol(smiles) - chiral_index, chiral_nbr_index, chiral_tag = get_chiral_tensors(mol) - - # add smiles to cache - if smiles not in self.cache: - self.cache[smiles] = {} + chiral_index, chiral_nbr_index, chiral_tag = self.get_chiral_centers_from_mol( + mol + ) self.cache[smiles]["chiral_centers"] = ( chiral_index, @@ -125,341 +88,34 @@ def get_chiral_centers(self, smiles: str) -> torch.Tensor: ) return chiral_index, chiral_nbr_index, chiral_tag + def get_chiral_centers_from_mol(self, mol: Mol) -> torch.Tensor: + chiral_index, chiral_nbr_index, chiral_tag = get_chiral_tensors(mol) + return chiral_index, chiral_nbr_index, chiral_tag + + @cache_decorator def get_mol_with_conformer(self, smiles: str, positions: torch.Tensor) -> Mol: mol = self.get_mol(smiles) mol.AddConformer(build_conformer(positions)) return mol - def get_edge_index(self, smiles: str) -> Tuple[torch.Tensor, torch.Tensor]: + @cache_decorator + def get_edge_index( + self, smiles: str, use_edge_feat: bool + ) -> Tuple[torch.Tensor, torch.Tensor]: """Returns edge index and edge attributes for a given smiles.""" - # check if cached - if smiles in self.cache and "edge_index" in self.cache[smiles]: - return self.cache[smiles]["edge_index"], self.cache[smiles]["edge_attr"] - # compute edge index mol = self.get_mol(smiles) - edge_index, edge_attr = compute_edge_index( - mol, with_edge_attr=self.use_edge_feat + edge_index, edge_attr = self.get_edge_index_from_mol( + mol, use_edge_feat=use_edge_feat ) - # add smiles to cache - if smiles not in self.cache: - self.cache[smiles] = {} - self.cache[smiles]["edge_index"] = edge_index self.cache[smiles]["edge_attr"] = edge_attr return edge_index, edge_attr - -def get_atomic_number_and_charge(mol: Chem.Mol): - """Returns atoms number and charge for rdkit molecule""" - return np.array( - [[atom.GetAtomicNum(), atom.GetFormalCharge()] for atom in mol.GetAtoms()] - ) - - -def GetNumRings(atom): - return sum([atom.IsInRingSize(i) for i in range(3, 7)]) - - -def atom_to_feature_vector(atom): - """Node Invariant Features for an Atom.""" - atom_feature = [ - safe_index( - allowable_features["possible_chirality_list"], - chirality[atom.GetChiralTag()], - ), - safe_index(allowable_features["possible_degree_list"], atom.GetTotalDegree()), - safe_index( - allowable_features["possible_formal_charge_list"], atom.GetFormalCharge() - ), - safe_index( - allowable_features["possible_implicit_valence"], atom.GetImplicitValence() - ), - safe_index(allowable_features["possible_numH_list"], atom.GetTotalNumHs()), - safe_index( - allowable_features["possible_hybridization_list"], - str(atom.GetHybridization()), - ), - safe_index( - allowable_features["possible_number_radical_e_list"], - atom.GetNumRadicalElectrons(), - ), - allowable_features["possible_is_aromatic_list"].index(atom.GetIsAromatic()), - allowable_features["possible_is_in_ring_list"].index(atom.IsInRing()), - GetNumRings(atom), - ] - return atom_feature - - -def safe_index(l, e): - """ - Return index of element e in list l. If e is not present, return the last index - """ - try: - return l.index(e) - except Exception as e: - return len(l) - 1 - - -def bond_to_feature_vector(bond): - """ - Converts rdkit bond object to feature list of indices - :param mol: rdkit bond object - :return: list - """ - bond_feature = [ - safe_index( - allowable_features["possible_bond_type_list"], str(bond.GetBondType()) - ), - # allowable_features['possible_bond_stereo_list'].index(str(bond.GetStereo())), - # allowable_features['possible_is_conjugated_list'].index(bond.GetIsConjugated()), - ] - return bond_feature - - -def compute_edge_index( - mol, no_reverse: bool = False, with_edge_attr=False -) -> torch.Tensor: - """Computes edge index from mol object""" - edge_list = [] - bond_types = [] - for bond in mol.GetBonds(): - i, j = bond.GetBeginAtomIdx(), bond.GetEndAtomIdx() - edge_list.append((i, j)) - bond_types.append(bond_to_feature_vector(bond)) - - if not no_reverse: - edge_list.append((j, i)) - bond_types.append(bond_to_feature_vector(bond)) - - if len(edge_list) == 0: - return torch.empty((2, 0)).long() - - edge_index = torch.from_numpy(np.array(edge_list).T).long() - - if with_edge_attr: - edge_attr = torch.tensor(bond_types, dtype=torch.float32) # (num_edges, 1) + def get_edge_index_from_mol( + self, mol: Mol, use_edge_feat: bool = False + ) -> Tuple[torch.Tensor, torch.Tensor]: + """Returns edge index and edge attributes for a given mol object.""" + edge_index, edge_attr = compute_edge_index(mol, with_edge_attr=use_edge_feat) return edge_index, edge_attr - - return edge_index, None - - -def _extend_graph_order( - num_nodes: int, edge_index: torch.Tensor, edge_type: torch.Tensor, order=3 -): - """ - Extends order of the existing bond index. - - For instance, if atom-1-atom-2-atom-3 form an bond angle, then atom-1-atom-3 - will be added to the bond index for order=3. - - The importance of this is highlighted in section 2.1 of the paper: - https://arxiv.org/abs/1909.11459 - - Parameters - ---------- - num_nodes: int - Number of atoms. - edge_index: torch.Tensor - Bond indices of the original graph. - edge_type: torch.Tensor - Bond types of the original graph. - order: int - Extension order. - - Returns - ------- - new_edge_index: torch.Tensor - Extended edge indices. - new_edge_type: torch.Tensor - Extended edge types. - """ - - def binarize(x): - return torch.where(x > 0, torch.ones_like(x), torch.zeros_like(x)) - - def get_higher_order_adj_matrix(adj, order): - """ - Args: - adj: (N, N) - type_mat: (N, N) - Returns: - Following attributes will be updated: - - edge_index - - edge_type - Following attributes will be added to the data object: - - bond_edge_index: Original edge_index. - """ - adj_mats = [ - torch.eye(adj.size(0), dtype=torch.long, device=adj.device), - binarize(adj + torch.eye(adj.size(0), dtype=torch.long, device=adj.device)), - ] - - for i in range(2, order + 1): - adj_mats.append(binarize(adj_mats[i - 1] @ adj_mats[1])) - order_mat = torch.zeros_like(adj) - - for i in range(1, order + 1): - order_mat += (adj_mats[i] - adj_mats[i - 1]) * i - - return order_mat - - num_types = len(BOND_TYPES) - - N = num_nodes - adj = to_dense_adj(edge_index).squeeze(0) - adj_order = get_higher_order_adj_matrix(adj, order) # (N, N) - - type_mat = to_dense_adj(edge_index, edge_attr=edge_type).squeeze(0) # (N, N) - type_highorder = torch.where( - adj_order > 1, num_types + adj_order - 1, torch.zeros_like(adj_order) - ) - assert (type_mat * type_highorder == 0).all() - type_new = type_mat + type_highorder - - new_edge_index, new_edge_type = dense_to_sparse(type_new) - - # data.bond_edge_index = data.edge_index # Save original edges - new_edge_index, new_edge_type = coalesce( - new_edge_index, new_edge_type.long(), N, N - ) # modify data - - return new_edge_index, new_edge_type - - -def _extend_to_radius_graph( - pos: torch.Tensor, - edge_index: torch.Tensor, - edge_type: torch.Tensor, - batch: torch.Tensor, - cutoff: float = 10.0, - max_num_neighbors: int = 32, - unspecified_type_number=0, -): - assert edge_type.dim() == 1 - N = pos.size(0) - - bgraph_adj = torch.sparse.LongTensor(edge_index, edge_type, torch.Size([N, N])) - rgraph_edge_index = radius_graph( - pos, r=cutoff, batch=batch, max_num_neighbors=max_num_neighbors - ) # (2, E_r) - - rgraph_adj = torch.sparse.LongTensor( - rgraph_edge_index, - torch.ones(rgraph_edge_index.size(1)).long().to(pos.device) - * unspecified_type_number, - torch.Size([N, N]), - ) - - composed_adj = (bgraph_adj + rgraph_adj).coalesce() # Sparse (N, N, T) - - new_edge_index = composed_adj.indices() - new_edge_type = composed_adj.values().long() - - return new_edge_index, new_edge_type - - -def extend_graph_order_radius( - pos: torch.Tensor, - edge_index: torch.Tensor, - edge_type: torch.Tensor, - batch: torch.Tensor, - cutoff: float = 10.0, - order: int = 3, - max_num_neighbors: int = 32, - extend_radius: bool = True, - extend_order: bool = True, -): - """Extends bond index""" - num_nodes = pos.shape[0] - if extend_order: - edge_index, edge_type = _extend_graph_order( - num_nodes=num_nodes, edge_index=edge_index, edge_type=edge_type, order=order - ) - - if extend_radius: - edge_index, edge_type = _extend_to_radius_graph( - pos=pos, - edge_index=edge_index, - edge_type=edge_type, - cutoff=cutoff, - batch=batch, - max_num_neighbors=max_num_neighbors, - ) - - return edge_index, edge_type - - -def get_neighbor_ids(data): - """ - Takes the edge indices and returns dictionary mapping atom index to neighbor indices - Note: this only includes atoms with degree > 1 - """ - batch_nbrs = deepcopy(data.neighbors) - batch_nbrs = [obj[0] for obj in batch_nbrs] - neighbors = batch_nbrs.pop(0) # get first element - n_atoms_per_mol = data.batch.bincount() # get atom count per graph - n_atoms_prev_mol = 0 - - for i, n_dict in enumerate(batch_nbrs): - new_dict = {} - n_atoms_prev_mol += n_atoms_per_mol[i].item() - for k, v in n_dict.items(): - new_dict[k + n_atoms_prev_mol] = v + n_atoms_prev_mol - neighbors.update(new_dict) - - return neighbors - - -def signed_volume(local_coords): - """ - Compute signed volume given ordered neighbor local coordinates - From GeoMol - - :param local_coords: (n_tetrahedral_chiral_centers, 4, n_generated_confs, 3) - :return: signed volume of each tetrahedral center - (n_tetrahedral_chiral_centers, n_generated_confs) - """ - v1 = local_coords[:, 0] - local_coords[:, 3] - v2 = local_coords[:, 1] - local_coords[:, 3] - v3 = local_coords[:, 2] - local_coords[:, 3] - cp = v2.cross(v3, dim=-1) - vol = torch.sum(v1 * cp, dim=-1) - return torch.sign(vol) - - -def get_chiral_tensors(mol): - """Only consider chiral atoms with 4 neighbors""" - chiral_index = torch.tensor( - [ - i - for i, atom in enumerate(mol.GetAtoms()) - if (chirality[atom.GetChiralTag()] != 0 and len(atom.GetNeighbors()) == 4) - ], - dtype=torch.int32, - ).view( - 1, -1 - ) # (1, n_chiral_centers) - # (n_chiral_centers, 4) - chiral_nbr_index = torch.tensor( - [ - [n.GetIdx() for n in atom.GetNeighbors()] - for atom in mol.GetAtoms() - if (chirality[atom.GetChiralTag()] != 0 and len(atom.GetNeighbors()) == 4) - ], - dtype=torch.int32, - ).view( - 1, -1 - ) # (1, n_chiral_centers * 4) - # (n_chiral_centers,) - chiral_tag = torch.tensor( - [ - chirality[atom.GetChiralTag()] - for atom in mol.GetAtoms() - if (chirality[atom.GetChiralTag()] != 0 and len(atom.GetNeighbors()) == 4) - ], - dtype=torch.float32, - ) - - return chiral_index, chiral_nbr_index, chiral_tag diff --git a/etflow/commons/utils.py b/etflow/commons/utils.py index 4a6e48d..c47e47f 100644 --- a/etflow/commons/utils.py +++ b/etflow/commons/utils.py @@ -1,4 +1,52 @@ +# allowable multiple choice node and edge features +from copy import deepcopy + import numpy as np +import torch +from rdkit import Chem +from rdkit.Chem.rdchem import BondType as BT +from rdkit.Chem.rdchem import ChiralType +from torch_cluster import radius_graph +from torch_geometric.utils import dense_to_sparse, to_dense_adj +from torch_sparse import coalesce + +# similar to GeoMol +BOND_TYPES = {t: i for i, t in enumerate(BT.names.values())} +chirality = { + ChiralType.CHI_TETRAHEDRAL_CW: -1.0, + ChiralType.CHI_TETRAHEDRAL_CCW: 1.0, + ChiralType.CHI_UNSPECIFIED: 0, + ChiralType.CHI_OTHER: 0, +} + +allowable_features = { + "possible_atomic_num_list": list(range(1, 119)) + ["misc"], + "possible_chirality_list": [ + "CHI_UNSPECIFIED", + "CHI_TETRAHEDRAL_CW", + "CHI_TETRAHEDRAL_CCW", + "CHI_OTHER", + "misc", + ], + "possible_degree_list": [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, "misc"], + "possible_formal_charge_list": [-5, -4, -3, -2, -1, 0, 1, 2, 3, 4, 5, "misc"], + "possible_numH_list": [0, 1, 2, 3, 4, 5, 6, 7, 8, "misc"], + "possible_implicit_valence": [0, 1, 2, 3, 4, 5, 6, "misc"], + "possible_number_radical_e_list": [0, 1, 2, 3, 4, "misc"], + "possible_hybridization_list": ["SP", "SP2", "SP3", "SP3D", "SP3D2", "misc"], + "possible_is_aromatic_list": [False, True], + "possible_is_in_ring_list": [False, True], + "possible_bond_type_list": ["SINGLE", "DOUBLE", "TRIPLE", "AROMATIC", "misc"], + "possible_bond_stereo_list": [ + "STEREONONE", + "STEREOZ", + "STEREOE", + "STEREOCIS", + "STEREOTRANS", + "STEREOANY", + ], + "possible_is_conjugated_list": [False, True], +} # Gradient clipping @@ -20,3 +68,317 @@ def mean(self): def std(self): return np.std(self.items) + + +def get_atomic_number_and_charge(mol: Chem.Mol): + """Returns atoms number and charge for rdkit molecule""" + return np.array( + [[atom.GetAtomicNum(), atom.GetFormalCharge()] for atom in mol.GetAtoms()] + ) + + +def GetNumRings(atom): + return sum([atom.IsInRingSize(i) for i in range(3, 7)]) + + +def atom_to_feature_vector(atom): + """Node Invariant Features for an Atom.""" + atom_feature = [ + safe_index( + allowable_features["possible_chirality_list"], + chirality[atom.GetChiralTag()], + ), + safe_index(allowable_features["possible_degree_list"], atom.GetTotalDegree()), + safe_index( + allowable_features["possible_formal_charge_list"], atom.GetFormalCharge() + ), + safe_index( + allowable_features["possible_implicit_valence"], atom.GetImplicitValence() + ), + safe_index(allowable_features["possible_numH_list"], atom.GetTotalNumHs()), + safe_index( + allowable_features["possible_hybridization_list"], + str(atom.GetHybridization()), + ), + safe_index( + allowable_features["possible_number_radical_e_list"], + atom.GetNumRadicalElectrons(), + ), + allowable_features["possible_is_aromatic_list"].index(atom.GetIsAromatic()), + allowable_features["possible_is_in_ring_list"].index(atom.IsInRing()), + GetNumRings(atom), + ] + return atom_feature + + +def safe_index(l, e): + """ + Return index of element e in list l. If e is not present, return the last index + """ + try: + return l.index(e) + except Exception as e: + return len(l) - 1 + + +def bond_to_feature_vector(bond): + """ + Converts rdkit bond object to feature list of indices + :param mol: rdkit bond object + :return: list + """ + bond_feature = [ + safe_index( + allowable_features["possible_bond_type_list"], str(bond.GetBondType()) + ), + # allowable_features['possible_bond_stereo_list'].index(str(bond.GetStereo())), + # allowable_features['possible_is_conjugated_list'].index(bond.GetIsConjugated()), + ] + return bond_feature + + +def compute_edge_index( + mol, no_reverse: bool = False, with_edge_attr=False +) -> torch.Tensor: + """Computes edge index from mol object""" + edge_list = [] + bond_types = [] + for bond in mol.GetBonds(): + i, j = bond.GetBeginAtomIdx(), bond.GetEndAtomIdx() + edge_list.append((i, j)) + bond_types.append(bond_to_feature_vector(bond)) + + if not no_reverse: + edge_list.append((j, i)) + bond_types.append(bond_to_feature_vector(bond)) + + if len(edge_list) == 0: + return torch.empty((2, 0)).long() + + edge_index = torch.from_numpy(np.array(edge_list).T).long() + + if with_edge_attr: + edge_attr = torch.tensor(bond_types, dtype=torch.float32) # (num_edges, 1) + return edge_index, edge_attr + + return edge_index, None + + +def _extend_graph_order( + num_nodes: int, edge_index: torch.Tensor, edge_type: torch.Tensor, order=3 +): + """ + Extends order of the existing bond index. + + For instance, if atom-1-atom-2-atom-3 form an bond angle, then atom-1-atom-3 + will be added to the bond index for order=3. + + The importance of this is highlighted in section 2.1 of the paper: + https://arxiv.org/abs/1909.11459 + + Parameters + ---------- + num_nodes: int + Number of atoms. + edge_index: torch.Tensor + Bond indices of the original graph. + edge_type: torch.Tensor + Bond types of the original graph. + order: int + Extension order. + + Returns + ------- + new_edge_index: torch.Tensor + Extended edge indices. + new_edge_type: torch.Tensor + Extended edge types. + """ + + def binarize(x): + return torch.where(x > 0, torch.ones_like(x), torch.zeros_like(x)) + + def get_higher_order_adj_matrix(adj, order): + """ + Args: + adj: (N, N) + type_mat: (N, N) + Returns: + Following attributes will be updated: + - edge_index + - edge_type + Following attributes will be added to the data object: + - bond_edge_index: Original edge_index. + """ + adj_mats = [ + torch.eye(adj.size(0), dtype=torch.long, device=adj.device), + binarize(adj + torch.eye(adj.size(0), dtype=torch.long, device=adj.device)), + ] + + for i in range(2, order + 1): + adj_mats.append(binarize(adj_mats[i - 1] @ adj_mats[1])) + order_mat = torch.zeros_like(adj) + + for i in range(1, order + 1): + order_mat += (adj_mats[i] - adj_mats[i - 1]) * i + + return order_mat + + num_types = len(BOND_TYPES) + + N = num_nodes + adj = to_dense_adj(edge_index).squeeze(0) + adj_order = get_higher_order_adj_matrix(adj, order) # (N, N) + + type_mat = to_dense_adj(edge_index, edge_attr=edge_type).squeeze(0) # (N, N) + type_highorder = torch.where( + adj_order > 1, num_types + adj_order - 1, torch.zeros_like(adj_order) + ) + assert (type_mat * type_highorder == 0).all() + type_new = type_mat + type_highorder + + new_edge_index, new_edge_type = dense_to_sparse(type_new) + + # data.bond_edge_index = data.edge_index # Save original edges + new_edge_index, new_edge_type = coalesce( + new_edge_index, new_edge_type.long(), N, N + ) # modify data + + return new_edge_index, new_edge_type + + +def _extend_to_radius_graph( + pos: torch.Tensor, + edge_index: torch.Tensor, + edge_type: torch.Tensor, + batch: torch.Tensor, + cutoff: float = 10.0, + max_num_neighbors: int = 32, + unspecified_type_number=0, +): + assert edge_type.dim() == 1 + N = pos.size(0) + + bgraph_adj = torch.sparse.LongTensor(edge_index, edge_type, torch.Size([N, N])) + rgraph_edge_index = radius_graph( + pos, r=cutoff, batch=batch, max_num_neighbors=max_num_neighbors + ) # (2, E_r) + + rgraph_adj = torch.sparse.LongTensor( + rgraph_edge_index, + torch.ones(rgraph_edge_index.size(1)).long().to(pos.device) + * unspecified_type_number, + torch.Size([N, N]), + ) + + composed_adj = (bgraph_adj + rgraph_adj).coalesce() # Sparse (N, N, T) + + new_edge_index = composed_adj.indices() + new_edge_type = composed_adj.values().long() + + return new_edge_index, new_edge_type + + +def extend_graph_order_radius( + pos: torch.Tensor, + edge_index: torch.Tensor, + edge_type: torch.Tensor, + batch: torch.Tensor, + cutoff: float = 10.0, + order: int = 3, + max_num_neighbors: int = 32, + extend_radius: bool = True, + extend_order: bool = True, +): + """Extends bond index""" + num_nodes = pos.shape[0] + if extend_order: + edge_index, edge_type = _extend_graph_order( + num_nodes=num_nodes, edge_index=edge_index, edge_type=edge_type, order=order + ) + + if extend_radius: + edge_index, edge_type = _extend_to_radius_graph( + pos=pos, + edge_index=edge_index, + edge_type=edge_type, + cutoff=cutoff, + batch=batch, + max_num_neighbors=max_num_neighbors, + ) + + return edge_index, edge_type + + +def get_neighbor_ids(data): + """ + Takes the edge indices and returns dictionary mapping atom index to neighbor indices + Note: this only includes atoms with degree > 1 + """ + batch_nbrs = deepcopy(data.neighbors) + batch_nbrs = [obj[0] for obj in batch_nbrs] + neighbors = batch_nbrs.pop(0) # get first element + n_atoms_per_mol = data.batch.bincount() # get atom count per graph + n_atoms_prev_mol = 0 + + for i, n_dict in enumerate(batch_nbrs): + new_dict = {} + n_atoms_prev_mol += n_atoms_per_mol[i].item() + for k, v in n_dict.items(): + new_dict[k + n_atoms_prev_mol] = v + n_atoms_prev_mol + neighbors.update(new_dict) + + return neighbors + + +def signed_volume(local_coords): + """ + Compute signed volume given ordered neighbor local coordinates + From GeoMol + + :param local_coords: (n_tetrahedral_chiral_centers, 4, n_generated_confs, 3) + :return: signed volume of each tetrahedral center + (n_tetrahedral_chiral_centers, n_generated_confs) + """ + v1 = local_coords[:, 0] - local_coords[:, 3] + v2 = local_coords[:, 1] - local_coords[:, 3] + v3 = local_coords[:, 2] - local_coords[:, 3] + cp = v2.cross(v3, dim=-1) + vol = torch.sum(v1 * cp, dim=-1) + return torch.sign(vol) + + +def get_chiral_tensors(mol): + """Only consider chiral atoms with 4 neighbors""" + chiral_index = torch.tensor( + [ + i + for i, atom in enumerate(mol.GetAtoms()) + if (chirality[atom.GetChiralTag()] != 0 and len(atom.GetNeighbors()) == 4) + ], + dtype=torch.int32, + ).view( + 1, -1 + ) # (1, n_chiral_centers) + # (n_chiral_centers, 4) + chiral_nbr_index = torch.tensor( + [ + [n.GetIdx() for n in atom.GetNeighbors()] + for atom in mol.GetAtoms() + if (chirality[atom.GetChiralTag()] != 0 and len(atom.GetNeighbors()) == 4) + ], + dtype=torch.int32, + ).view( + 1, -1 + ) # (1, n_chiral_centers * 4) + # (n_chiral_centers,) + chiral_tag = torch.tensor( + [ + chirality[atom.GetChiralTag()] + for atom in mol.GetAtoms() + if (chirality[atom.GetChiralTag()] != 0 and len(atom.GetNeighbors()) == 4) + ], + dtype=torch.float32, + ) + + return chiral_index, chiral_nbr_index, chiral_tag diff --git a/etflow/data/dataset.py b/etflow/data/dataset.py index 1bf3277..4403fa3 100644 --- a/etflow/data/dataset.py +++ b/etflow/data/dataset.py @@ -28,9 +28,9 @@ def __init__( # instantiate dataset self.dataset_name = dataset_name self.dataset = DATASET_MAPPING[dataset_name]() - self.mol_feat = MoleculeFeaturizer( - use_ogb_feat=use_ogb_feat, use_edge_feat=use_edge_feat - ) + self.mol_feat = MoleculeFeaturizer() + self.use_ogb_feat = use_ogb_feat + self.use_edge_feat = use_edge_feat self.cache = {} def len(self): @@ -45,11 +45,11 @@ def get(self, idx): smiles = data_bunch["smiles"] # featurize molecule - node_attr = self.mol_feat.get_atom_features(smiles) + node_attr = self.mol_feat.get_atom_features(smiles, self.use_ogb_feat) chiral_index, chiral_nbr_index, chiral_tag = self.mol_feat.get_chiral_centers( smiles ) - edge_index, edge_attr = self.mol_feat.get_edge_index(smiles) + edge_index, edge_attr = self.mol_feat.get_edge_index(smiles, self.use_edge_feat) mol = self.mol_feat.get_mol_with_conformer(smiles, pos) graph = Data( diff --git a/etflow/eval.py b/etflow/eval.py index 4e09901..47dc7f3 100644 --- a/etflow/eval.py +++ b/etflow/eval.py @@ -11,17 +11,18 @@ import datetime import os import os.path as osp +import time import numpy as np import pandas as pd import torch -import wandb # from lightning import seed_everything from loguru import logger as log from torch_geometric.data import Batch, Data from tqdm import tqdm +import wandb from etflow.commons import load_pkl, save_pkl from etflow.models import BaseFlow from etflow.utils import instantiate_dataset, instantiate_model, read_yaml @@ -47,12 +48,8 @@ def main( nsteps: int, batch_size: int, eps: float, - debug: bool, subset_type: str, ): - # seed = config.get("seed", 42) - # seed_everything(seed) - if cuda_available(): log.info("CUDA is available. Using GPU for sampling.") device = torch.device("cuda") @@ -94,6 +91,7 @@ def main( # we would want (num_samples, num_nodes, 3) generated_positions = [] + times = [] for batch_start in range(0, num_samples, max_batch_size): # get batch_size @@ -103,17 +101,26 @@ def main( batched_data = Batch.from_data_list([data] * batch_size) # get one_hot, edge_index, batch - z, edge_index, batch, node_attr = ( + ( + z, + edge_index, + batch, + node_attr, + chiral_index, + chiral_nbr_index, + chiral_tag, + ) = ( batched_data["atomic_numbers"].to(device), batched_data["edge_index"].to(device), batched_data["batch"].to(device), batched_data["node_attr"].to(device), + batched_data["chiral_index"].to(device), + batched_data["chiral_nbr_index"].to(device), + batched_data["chiral_tag"].to(device), ) - chiral_index = batched_data["chiral_index"].to(device) - chiral_nbr_index = batched_data["chiral_nbr_index"].to(device) - chiral_tag = batched_data["chiral_tag"].to(device) - + # get time-estimate + start = time.time() with torch.no_grad(): if model_type == "BaseSFM": # generate samples @@ -139,6 +146,8 @@ def main( chiral_tag=chiral_tag, eps=eps, ) + end = time.time() + times.append((end - start) / batch_size) # store time per conformer # reshape to (num_samples, num_atoms, 3) using batch pos = pos.view(batch_size, -1, 3).cpu().detach().numpy() @@ -146,26 +155,13 @@ def main( # append to generated_positions generated_positions.append(pos) - # if debug mode, break after first batch - if debug: - break - - # if debug mode, break after first molecule - if debug: - break - - # concatenate generated_positions - generated_positions = np.concatenate( - generated_positions, axis=0 - ) # (num_samples, num_atoms, 3) + # concatenate generated_positions: (num_samples, num_atoms, 3) + generated_positions = np.concatenate(generated_positions, axis=0) # save to file - if not debug: - path = osp.join(output_dir, f"{idx}.pkl") - log.info( - f"Saving generated positions to file for smiles {smiles} at {path}" - ) - save_pkl(path, generated_positions) + path = osp.join(output_dir, f"{idx}.pkl") + log.info(f"Saving generated positions to file for smiles {smiles} at {path}") + save_pkl(path, generated_positions) # compile all generate pkl into a single file log.info("Compile all generated pickle files into a single file") @@ -178,6 +174,10 @@ def main( lambda row: dataset[row["index"]].pos.unsqueeze(0).numpy(), axis=1 ) + # log time per conformer + wandb.log({"time_per_conformer": np.mean(times)}) + save_pkl(os.path.join(output_dir, "times.pkl"), times) + # create pos_gen data_list = {} for index in tqdm(l): @@ -276,6 +276,5 @@ def main( nsteps=args.nsteps, batch_size=args.batch_size, eps=args.eps, - debug=debug, subset_type=args.dataset_type, ) diff --git a/etflow/eval_xl.py b/etflow/eval_xl.py index 6628e06..5964f04 100644 --- a/etflow/eval_xl.py +++ b/etflow/eval_xl.py @@ -19,44 +19,29 @@ import numpy as np import pandas as pd import torch -import wandb from lightning import seed_everything from loguru import logger as log from torch_geometric.data import Batch, Data from tqdm import tqdm -from etflow.commons import ( - atom_to_feature_vector, - compute_edge_index, - get_atomic_number_and_charge, - get_chiral_tensors, - get_lpe, - load_pkl, - save_pkl, -) +import wandb +from etflow.commons import MoleculeFeaturizer, load_pkl, save_pkl from etflow.utils import instantiate_model, read_yaml torch.set_float32_matmul_precision("high") +mol_feat = MoleculeFeaturizer() -def get_data(mol, use_lpe=False, lpe_k=4): + +def get_data(mol, use_ogb_feat: bool, use_edge_feat: bool): """Convert mol object to Data object""" - x = get_atomic_number_and_charge(mol) - atomic_numbers = torch.from_numpy(x[:, 0]).int() - edge_index, _ = compute_edge_index(mol) - edge_index = edge_index.long() - node_attr = torch.tensor( - [atom_to_feature_vector(atom) for atom in mol.GetAtoms()], - dtype=torch.float32, + atomic_numbers = mol_feat.get_atomic_numbers_from_mol(mol) + edge_index, _ = mol_feat.get_edge_index_from_mol(mol, use_edge_feat) + node_attr = mol_feat.get_atom_features_from_mol(mol, use_ogb_feat) + chiral_index, chiral_nbr_index, chiral_tag = mol_feat.get_chiral_centers_from_mol( + mol ) - chiral_index, chiral_nbr_index, chiral_tag = get_chiral_tensors(mol) - if use_lpe: - lpe = get_lpe( - Data(edge_index=edge_index), - num_nodes=x.shape[0], - k=lpe_k, - ) - node_attr = torch.cat([node_attr, lpe], dim=-1) + return Data( atomic_numbers=atomic_numbers, edge_index=edge_index, @@ -118,15 +103,13 @@ def main( model.load_state_dict(state_dict) # check if we need to use lpe - lpe_k = None - use_lpe = False - if ( - "use_lpe" in config["datamodule_args"]["dataset_args"] - and config["datamodule_args"]["dataset_args"]["use_lpe"] - ): - use_lpe = True - lpe_k = config["datamodule_args"]["dataset_args"].get("lpe_k", 4) - log.info(f"Using LPE with k={lpe_k}") + use_ogb_feat = config["datamodule_args"]["dataset_args"].get("use_ogb_feat", True) + use_edge_feat = config["datamodule_args"]["dataset_args"].get( + "use_edge_feat", False + ) + log.info( + f"Using OGB features: {use_ogb_feat}, Using edge features: {use_edge_feat}" + ) # move to device model = model.to(device) @@ -146,7 +129,7 @@ def main( # get molecular graph mol_obj = mols[smiles][0] - data = get_data(mol_obj, use_lpe=use_lpe, lpe_k=lpe_k) + data = get_data(mol_obj, use_ogb_feat, use_edge_feat) # we would want (num_samples, num_nodes, 3) generated_positions = [] diff --git a/etflow/networks/torchmd_net/model_dynamics.py b/etflow/networks/torchmd_net/model_dynamics.py index a8abb9f..e35b981 100644 --- a/etflow/networks/torchmd_net/model_dynamics.py +++ b/etflow/networks/torchmd_net/model_dynamics.py @@ -83,16 +83,15 @@ def __init__( else: self.q_proj = nn.Linear(hidden_channels, hidden_channels) self.k_proj = nn.Linear(hidden_channels, hidden_channels) - self.v_proj = nn.Linear(hidden_channels, hidden_channels * 3) - self.o_proj = nn.Linear( + self.v_proj = nn.Linear( hidden_channels, hidden_channels * (3 + int(so3_equivariant)) ) - + self.o_proj = nn.Linear(hidden_channels, hidden_channels * 3) self.vec_proj = nn.Linear(hidden_channels, hidden_channels * 3, bias=False) # projection linear layers for edge attributes self.dk_proj = nn.Linear(num_rbf, hidden_channels) - self.dv_proj = nn.Linear(num_rbf, hidden_channels * 3) + self.dv_proj = nn.Linear(num_rbf, hidden_channels * (3 + int(so3_equivariant))) self.reset_parameters() @@ -131,7 +130,9 @@ def forward(self, x, vec, edge_index, r_ij, f_ij, d_ij, t, node_attr): q = self.q_proj(x).reshape(-1, self.num_heads, self.head_dim) k = self.k_proj(x).reshape(-1, self.num_heads, self.head_dim) # value features: (num_atoms, num_heads, 3 * head_dim) - v = self.v_proj(x).reshape(-1, self.num_heads, self.head_dim * 3) + v = self.v_proj(x).reshape( + -1, self.num_heads, self.head_dim * (3 + int(self.so3_equivariant)) + ) # vec features: (num_atoms, 3, hidden_channels) (all invariant) vec1, vec2, vec3 = torch.split(self.vec_proj(vec), self.hidden_channels, dim=-1) @@ -142,7 +143,9 @@ def forward(self, x, vec, edge_index, r_ij, f_ij, d_ij, t, node_attr): # into dk and dv vectors with shape (num_edges, num_heads, head_dim) # and (num_edges, num_heads, 3 * head_dim) respectively dk = self.act(self.dk_proj(f_ij)).reshape(-1, self.num_heads, self.head_dim) - dv = self.act(self.dv_proj(f_ij)).reshape(-1, self.num_heads, self.head_dim * 3) + dv = self.act(self.dv_proj(f_ij)).reshape( + -1, self.num_heads, self.head_dim * (3 + int(self.so3_equivariant)) + ) # Message Passing Propagate x, vec = self.propagate( @@ -164,15 +167,9 @@ def forward(self, x, vec, edge_index, r_ij, f_ij, d_ij, t, node_attr): # normalize the vec if norm_coors is True vec = self.coors_norm(vec) - if self.so3_equivariant: - o1, o2, o3, o4 = torch.split(self.o_proj(x), self.hidden_channels, dim=1) - vec3 = vec3 + vec3.cross(vec1) * o4.unsqueeze(1) - else: - o1, o2, o3 = torch.split(self.o_proj(x), self.hidden_channels, dim=1) - - dx = vec_dot * o2 + o3 + o1, o2, o3 = torch.split(self.o_proj(x), self.hidden_channels, dim=1) dvec = vec3 * o1.unsqueeze(1) + vec - + dx = vec_dot * o2 + o3 return dx, dvec def message( @@ -194,14 +191,17 @@ def message( # value pathway v_j = v_j * dv # multiply with edge attr features - x, vec1, vec2 = torch.split(v_j, self.head_dim, dim=2) + x, vec1, vec2, vec3 = torch.split(v_j, self.head_dim, dim=2) # update scalar features x = x * attn.unsqueeze(2) # (num_edges, num_heads, head_dim) # update vector features (num_edges, 3, num_heads, head_dim) - vec = vec_j * vec1.unsqueeze(1) + vec2.unsqueeze(1) * d_ij.unsqueeze( - 2 - ).unsqueeze(3) + vec = ( + vec_j * vec1.unsqueeze(1) + + vec2.unsqueeze(1) * d_ij.unsqueeze(2).unsqueeze(3) + + vec3.unsqueeze(1) + * torch.cross(d_ij.unsqueeze(2).unsqueeze(3), vec_j, dim=1) + ) return x, vec def aggregate(