From 1966ea36a6150ec962a75209c5ec57f9e0fc8993 Mon Sep 17 00:00:00 2001 From: Josh Horton Date: Thu, 14 Nov 2024 16:45:04 +0000 Subject: [PATCH 1/6] save the atom hybridization --- gufe/components/smallmoleculecomponent.py | 17 ++++++++++++++++- gufe/tests/test_smallmoleculecomponent.py | 12 ++++++++++++ 2 files changed, 28 insertions(+), 1 deletion(-) diff --git a/gufe/components/smallmoleculecomponent.py b/gufe/components/smallmoleculecomponent.py index b51f2b7d..81a67e7f 100644 --- a/gufe/components/smallmoleculecomponent.py +++ b/gufe/components/smallmoleculecomponent.py @@ -64,6 +64,20 @@ 5: Chem.rdchem.BondStereo.STEREOTRANS} _BONDSTEREO_TO_INT = {v: k for k, v in _INT_TO_BONDSTEREO.items()} +# following the numbering in rdkit +_INT_TO_HYBRIDIZATION = { + 0: Chem.rdchem.HybridizationType.UNSPECIFIED, + 1: Chem.rdchem.HybridizationType.S, + 2: Chem.rdchem.HybridizationType.SP, + 3: Chem.rdchem.HybridizationType.SP2, + 4: Chem.rdchem.HybridizationType.SP3, + 5: Chem.rdchem.HybridizationType.SP2D, + 6: Chem.rdchem.HybridizationType.SP3D, + 7: Chem.rdchem.HybridizationType.SP3D2, + 8: Chem.rdchem.HybridizationType.OTHER +} +_HYBRIDIZATION_TO_INT = {v: k for k, v in _INT_TO_HYBRIDIZATION.items()} + def _setprops(obj, d: dict) -> None: # add props onto rdkit "obj" (atom/bond/mol/conformer) @@ -213,7 +227,7 @@ def _to_dict(self) -> dict: atoms.append(( atom.GetAtomicNum(), atom.GetIsotope(), atom.GetFormalCharge(), atom.GetIsAromatic(), _ATOMCHIRAL_TO_INT[atom.GetChiralTag()], atom.GetAtomMapNum(), - atom.GetPropsAsDict(includePrivate=False), + atom.GetPropsAsDict(includePrivate=False), _HYBRIDIZATION_TO_INT[atom.GetHybridization()] )) output['atoms'] = atoms @@ -247,6 +261,7 @@ def _from_dict(cls, d: dict): a.SetChiralTag(_INT_TO_ATOMCHIRAL[atom[4]]) a.SetAtomMapNum(atom[5]) _setprops(a, atom[6]) + a.SetHybridization(_INT_TO_HYBRIDIZATION[atom[7]]) em.AddAtom(a) for bond in d['bonds']: diff --git a/gufe/tests/test_smallmoleculecomponent.py b/gufe/tests/test_smallmoleculecomponent.py index 354771c9..2f4e17be 100644 --- a/gufe/tests/test_smallmoleculecomponent.py +++ b/gufe/tests/test_smallmoleculecomponent.py @@ -330,6 +330,18 @@ def test_to_dict(self, phenol): assert isinstance(d, dict) + def test_to_dict_hybridization(self, phenol): + """ + Make sure dict round trip saves the hybridization + + """ + phenol_dict = phenol.to_dict() + TOKENIZABLE_REGISTRY.clear() + new_phenol = SmallMoleculeComponent.from_dict(phenol_dict) + for atom in new_phenol.to_rdkit().GetAtoms(): + if atom.GetAtomicNum() == 6: + assert atom.GetHybridization() == Chem.rdchem.HybridizationType.SP2 + @pytest.mark.skipif(not HAS_OFFTK, reason="no openff toolkit available") def test_deserialize_roundtrip(self, toluene, phenol): roundtrip = SmallMoleculeComponent.from_dict(phenol.to_dict()) From 70540f557da8897a5e169f6aa973483794c3807f Mon Sep 17 00:00:00 2001 From: Josh Horton Date: Fri, 22 Nov 2024 12:13:43 +0000 Subject: [PATCH 2/6] use try except for missing hybridization data --- gufe/components/smallmoleculecomponent.py | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/gufe/components/smallmoleculecomponent.py b/gufe/components/smallmoleculecomponent.py index 81a67e7f..53914d88 100644 --- a/gufe/components/smallmoleculecomponent.py +++ b/gufe/components/smallmoleculecomponent.py @@ -261,7 +261,11 @@ def _from_dict(cls, d: dict): a.SetChiralTag(_INT_TO_ATOMCHIRAL[atom[4]]) a.SetAtomMapNum(atom[5]) _setprops(a, atom[6]) - a.SetHybridization(_INT_TO_HYBRIDIZATION[atom[7]]) + try: + a.SetHybridization(_INT_TO_HYBRIDIZATION[atom[7]]) + except IndexError: + pass + em.AddAtom(a) for bond in d['bonds']: From e4867fc52128c5943842f85a615f6af3d3425380 Mon Sep 17 00:00:00 2001 From: Josh Horton Date: Fri, 22 Nov 2024 12:36:15 +0000 Subject: [PATCH 3/6] fix and expand tests --- gufe/tests/data/ligand_network.graphml | 6 +++--- gufe/tests/test_smallmoleculecomponent.py | 18 ++++++++++++++++++ 2 files changed, 21 insertions(+), 3 deletions(-) diff --git a/gufe/tests/data/ligand_network.graphml b/gufe/tests/data/ligand_network.graphml index 31fe9e56..d331b2f6 100644 --- a/gufe/tests/data/ligand_network.graphml +++ b/gufe/tests/data/ligand_network.graphml @@ -4,13 +4,13 @@ - {":version:": 1, "__module__": "gufe.components.smallmoleculecomponent", "__qualname__": "SmallMoleculeComponent", "atoms": [[6, 0, 0, false, 0, 0, {}], [6, 0, 0, false, 0, 0, {}]], "bonds": [[0, 1, 1, 0, {}]], "conformer": ["\u0093NUMPY\u0001\u0000v\u0000{'descr': '<f8', 'fortran_order': False, 'shape': (2, 3), } \n\u0000\u0000\u0000\u0000\u0000\u0000\u00e8\u00bf\u0000\u0000\u0000\u0000\u0000\u0000\u0090<\u0000\u0000\u0000\u0000\u0000\u0000\u0000\u0000\u0000\u0000\u0000\u0000\u0000\u0000\u00e8?\u0000\u0000\u0000\u0000\u0000\u0000\u0090\u00bc\u0000\u0000\u0000\u0000\u0000\u0000\u0000\u0000", {}], "molprops": {"ofe-name": ""}} + {":version:": 1, "__module__": "gufe.components.smallmoleculecomponent", "__qualname__": "SmallMoleculeComponent", "atoms": [[6, 0, 0, false, 0, 0, {}, 4], [6, 0, 0, false, 0, 0, {}, 4]], "bonds": [[0, 1, 1, 0, {}]], "conformer": ["\u0093NUMPY\u0001\u0000v\u0000{'descr': '<f8', 'fortran_order': False, 'shape': (2, 3), } \n\u0000\u0000\u0000\u0000\u0000\u0000\u00e8\u00bf\u0000\u0000\u0000\u0000\u0000\u0000\u0090<\u0000\u0000\u0000\u0000\u0000\u0000\u0000\u0000\u0000\u0000\u0000\u0000\u0000\u0000\u00e8?\u0000\u0000\u0000\u0000\u0000\u0000\u0090\u00bc\u0000\u0000\u0000\u0000\u0000\u0000\u0000\u0000", {}], "molprops": {"ofe-name": ""}} - {":version:": 1, "__module__": "gufe.components.smallmoleculecomponent", "__qualname__": "SmallMoleculeComponent", "atoms": [[6, 0, 0, false, 0, 0, {}], [6, 0, 0, false, 0, 0, {}], [8, 0, 0, false, 0, 0, {}]], "bonds": [[0, 1, 1, 0, {}], [1, 2, 1, 0, {}]], "conformer": ["\u0093NUMPY\u0001\u0000v\u0000{'descr': '<f8', 'fortran_order': False, 'shape': (3, 3), } \n\u00809B.\u00dc\u00c8\u00f4\u00bf\u00f5\u00ff\u00ff\u00ff\u00ff\u00ff\u00cf\u00bf\u0000\u0000\u0000\u0000\u0000\u0000\u0000\u0000\u0000\u0000\u0000\u0000\u0000\u0000\u0000\u0000\u0001\u0000\u0000\u0000\u0000\u0000\u00e0?\u0000\u0000\u0000\u0000\u0000\u0000\u0000\u0000\u00809B.\u00dc\u00c8\u00f4?\u0006\u0000\u0000\u0000\u0000\u0000\u00d0\u00bf\u0000\u0000\u0000\u0000\u0000\u0000\u0000\u0000", {}], "molprops": {"ofe-name": ""}} + {":version:": 1, "__module__": "gufe.components.smallmoleculecomponent", "__qualname__": "SmallMoleculeComponent", "atoms": [[6, 0, 0, false, 0, 0, {}, 4], [6, 0, 0, false, 0, 0, {}, 4], [8, 0, 0, false, 0, 0, {}, 4]], "bonds": [[0, 1, 1, 0, {}], [1, 2, 1, 0, {}]], "conformer": ["\u0093NUMPY\u0001\u0000v\u0000{'descr': '<f8', 'fortran_order': False, 'shape': (3, 3), } \n\u00809B.\u00dc\u00c8\u00f4\u00bf\u00f5\u00ff\u00ff\u00ff\u00ff\u00ff\u00cf\u00bf\u0000\u0000\u0000\u0000\u0000\u0000\u0000\u0000\u0000\u0000\u0000\u0000\u0000\u0000\u0000\u0000\u0001\u0000\u0000\u0000\u0000\u0000\u00e0?\u0000\u0000\u0000\u0000\u0000\u0000\u0000\u0000\u00809B.\u00dc\u00c8\u00f4?\u0006\u0000\u0000\u0000\u0000\u0000\u00d0\u00bf\u0000\u0000\u0000\u0000\u0000\u0000\u0000\u0000", {}], "molprops": {"ofe-name": ""}} - {":version:": 1, "__module__": "gufe.components.smallmoleculecomponent", "__qualname__": "SmallMoleculeComponent", "atoms": [[6, 0, 0, false, 0, 0, {}], [8, 0, 0, false, 0, 0, {}]], "bonds": [[0, 1, 1, 0, {}]], "conformer": ["\u0093NUMPY\u0001\u0000v\u0000{'descr': '<f8', 'fortran_order': False, 'shape': (2, 3), } \n\u0000\u0000\u0000\u0000\u0000\u0000\u00e8\u00bf\u0000\u0000\u0000\u0000\u0000\u0000\u0090<\u0000\u0000\u0000\u0000\u0000\u0000\u0000\u0000\u0000\u0000\u0000\u0000\u0000\u0000\u00e8?\u0000\u0000\u0000\u0000\u0000\u0000\u0090\u00bc\u0000\u0000\u0000\u0000\u0000\u0000\u0000\u0000", {}], "molprops": {"ofe-name": ""}} + {":version:": 1, "__module__": "gufe.components.smallmoleculecomponent", "__qualname__": "SmallMoleculeComponent", "atoms": [[6, 0, 0, false, 0, 0, {}, 4], [8, 0, 0, false, 0, 0, {}, 4]], "bonds": [[0, 1, 1, 0, {}]], "conformer": ["\u0093NUMPY\u0001\u0000v\u0000{'descr': '<f8', 'fortran_order': False, 'shape': (2, 3), } \n\u0000\u0000\u0000\u0000\u0000\u0000\u00e8\u00bf\u0000\u0000\u0000\u0000\u0000\u0000\u0090<\u0000\u0000\u0000\u0000\u0000\u0000\u0000\u0000\u0000\u0000\u0000\u0000\u0000\u0000\u00e8?\u0000\u0000\u0000\u0000\u0000\u0000\u0090\u00bc\u0000\u0000\u0000\u0000\u0000\u0000\u0000\u0000", {}], "molprops": {"ofe-name": ""}} [[0, 0]] diff --git a/gufe/tests/test_smallmoleculecomponent.py b/gufe/tests/test_smallmoleculecomponent.py index 2f4e17be..449f8194 100644 --- a/gufe/tests/test_smallmoleculecomponent.py +++ b/gufe/tests/test_smallmoleculecomponent.py @@ -342,6 +342,24 @@ def test_to_dict_hybridization(self, phenol): if atom.GetAtomicNum() == 6: assert atom.GetHybridization() == Chem.rdchem.HybridizationType.SP2 + def test_from_dict_missing_hybridization(self, phenol): + """ + For backwards compatibility make sure we can create an SMC with missing hybridization info. + """ + phenol_dict = phenol.to_dict() + new_atoms = [] + for atom in phenol_dict["atoms"]: + # remove the hybridization atomic info which should be at index 7 + new_atoms.append(tuple([atom_info for i, atom_info in enumerate(atom) if i != 7])) + phenol_dict["atoms"] = new_atoms + new_phenol = SmallMoleculeComponent.from_dict(phenol_dict) + # they should be different objects due to the missing hybridization info + assert new_phenol != phenol + # make sure the rdkit objects are different + for atom_hybrid, atom_no_hybrid in zip(phenol.to_rdkit().GetAtoms(), new_phenol.to_rdkit().GetAtoms()): + assert atom_hybrid.GetHybridization() != atom_no_hybrid.GetHybridization() + + @pytest.mark.skipif(not HAS_OFFTK, reason="no openff toolkit available") def test_deserialize_roundtrip(self, toluene, phenol): roundtrip = SmallMoleculeComponent.from_dict(phenol.to_dict()) From 1e40b08ebbc72c1815a370297b3dbf971df846fc Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Mon, 25 Nov 2024 12:01:10 +0000 Subject: [PATCH 4/6] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- .../pull_request_template.md | 1 - gufe/__init__.py | 12 +-- gufe/chemicalsystem.py | 8 +- gufe/components/explicitmoleculecomponent.py | 21 ++--- gufe/components/proteincomponent.py | 31 ++----- gufe/components/smallmoleculecomponent.py | 2 +- gufe/components/solventcomponent.py | 13 +-- gufe/custom_codecs.py | 7 +- gufe/ligandnetwork.py | 27 ++---- gufe/mapping/atom_mapper.py | 4 +- gufe/mapping/ligandatommapping.py | 24 ++--- gufe/network.py | 14 +-- gufe/protocols/protocol.py | 17 +--- gufe/protocols/protocoldag.py | 27 ++---- gufe/settings/__init__.py | 8 +- gufe/settings/models.py | 24 ++--- gufe/storage/externalresource/filestorage.py | 4 +- .../storage/externalresource/memorystorage.py | 4 +- gufe/tests/conftest.py | 39 +++----- gufe/tests/storage/test_externalresource.py | 5 +- gufe/tests/test_chemicalsystem.py | 16 +--- gufe/tests/test_ligand_network.py | 42 +++------ gufe/tests/test_ligandatommapping.py | 20 ++--- gufe/tests/test_mapping.py | 12 +-- gufe/tests/test_mapping_visualization.py | 4 +- gufe/tests/test_models.py | 6 +- gufe/tests/test_proteincomponent.py | 20 ++--- gufe/tests/test_protocol.py | 88 +++++-------------- gufe/tests/test_protocoldag.py | 8 +- gufe/tests/test_protocolunit.py | 7 +- gufe/tests/test_serialization_migration.py | 6 +- gufe/tests/test_smallmoleculecomponent.py | 1 - gufe/tests/test_solvents.py | 16 +--- gufe/tests/test_tokenization.py | 13 +-- gufe/tests/test_transformation.py | 36 ++------ gufe/tokenization.py | 61 ++++--------- gufe/transformations/transformation.py | 18 +--- gufe/utils.py | 3 +- gufe/vendor/pdb_file/PdbxContainers.py | 53 +++-------- gufe/vendor/pdb_file/PdbxReader.py | 25 ++---- gufe/vendor/pdb_file/element.py | 4 +- gufe/vendor/pdb_file/pdbfile.py | 78 +++++----------- gufe/vendor/pdb_file/pdbstructure.py | 46 +++------- gufe/vendor/pdb_file/pdbxfile.py | 17 +--- gufe/vendor/pdb_file/topology.py | 50 +++-------- gufe/vendor/pdb_file/unitcell.py | 6 +- gufe/visualization/mapping_visualization.py | 12 +-- 47 files changed, 235 insertions(+), 725 deletions(-) diff --git a/.github/PULL_REQUEST_TEMPLATE/pull_request_template.md b/.github/PULL_REQUEST_TEMPLATE/pull_request_template.md index f66a59f2..d34f4787 100644 --- a/.github/PULL_REQUEST_TEMPLATE/pull_request_template.md +++ b/.github/PULL_REQUEST_TEMPLATE/pull_request_template.md @@ -20,4 +20,3 @@ Checklist ## Developers certificate of origin - [ ] I certify that this contribution is covered by the MIT License [here](https://github.com/OpenFreeEnergy/openfe/blob/main/LICENSE) and the **Developer Certificate of Origin** at . - diff --git a/gufe/__init__.py b/gufe/__init__.py index ef3107b6..c943abf9 100644 --- a/gufe/__init__.py +++ b/gufe/__init__.py @@ -5,12 +5,7 @@ from . import tokenization, visualization from .chemicalsystem import ChemicalSystem -from .components import ( - Component, - ProteinComponent, - SmallMoleculeComponent, - SolventComponent, -) +from .components import Component, ProteinComponent, SmallMoleculeComponent, SolventComponent from .ligandnetwork import LigandNetwork from .mapping import AtomMapper # more specific to atom based components from .mapping import ComponentMapping # how individual Components relate @@ -21,10 +16,7 @@ from .protocols import ProtocolDAGResult # the collected result of a DAG from .protocols import ProtocolUnit # the individual step within a method from .protocols import ProtocolUnitResult # the result of a single Unit -from .protocols import ( # potentially many DAGs together, giving an estimate - Context, - ProtocolResult, -) +from .protocols import Context, ProtocolResult # potentially many DAGs together, giving an estimate from .settings import Settings from .transformations import NonTransformation, Transformation diff --git a/gufe/chemicalsystem.py b/gufe/chemicalsystem.py index 9bab5515..c4eb4add 100644 --- a/gufe/chemicalsystem.py +++ b/gufe/chemicalsystem.py @@ -40,15 +40,11 @@ def __init__( self._name = name def __repr__(self): - return ( - f"{self.__class__.__name__}(name={self.name}, components={self.components})" - ) + return f"{self.__class__.__name__}(name={self.name}, components={self.components})" def _to_dict(self): return { - "components": { - key: value for key, value in sorted(self.components.items()) - }, + "components": {key: value for key, value in sorted(self.components.items())}, "name": self.name, } diff --git a/gufe/components/explicitmoleculecomponent.py b/gufe/components/explicitmoleculecomponent.py index 1d325e1d..88f1f196 100644 --- a/gufe/components/explicitmoleculecomponent.py +++ b/gufe/components/explicitmoleculecomponent.py @@ -59,16 +59,12 @@ def _check_partial_charges(mol: RDKitMol) -> None: p_chgs = np.array(mol.GetProp("atom.dprop.PartialCharge").split(), dtype=float) if len(p_chgs) != mol.GetNumAtoms(): - errmsg = ( - f"Incorrect number of partial charges: {len(p_chgs)} " - f" were provided for {mol.GetNumAtoms()} atoms" - ) + errmsg = f"Incorrect number of partial charges: {len(p_chgs)} " f" were provided for {mol.GetNumAtoms()} atoms" raise ValueError(errmsg) if (sum(p_chgs) - Chem.GetFormalCharge(mol)) > 0.01: errmsg = ( - f"Sum of partial charges {sum(p_chgs)} differs from " - f"RDKit formal charge {Chem.GetFormalCharge(mol)}" + f"Sum of partial charges {sum(p_chgs)} differs from " f"RDKit formal charge {Chem.GetFormalCharge(mol)}" ) raise ValueError(errmsg) @@ -81,16 +77,12 @@ def _check_partial_charges(mol: RDKitMol) -> None: atom_charge = atom.GetDoubleProp("PartialCharge") if not np.isclose(atom_charge, charge): errmsg = ( - f"non-equivalent partial charges between atom and " - f"molecule properties: {atom_charge} {charge}" + f"non-equivalent partial charges between atom and " f"molecule properties: {atom_charge} {charge}" ) raise ValueError(errmsg) if np.all(np.isclose(p_chgs, 0.0)): - wmsg = ( - f"Partial charges provided all equal to " - "zero. These may be ignored by some Protocols." - ) + wmsg = f"Partial charges provided all equal to " "zero. These may be ignored by some Protocols." warnings.warn(wmsg) else: wmsg = ( @@ -121,10 +113,7 @@ def __init__(self, rdkit: RDKitMol, name: str = ""): n_confs = len(conformers) if n_confs > 1: - warnings.warn( - f"Molecule provided with {n_confs} conformers. " - f"Only the first will be used." - ) + warnings.warn(f"Molecule provided with {n_confs} conformers. " f"Only the first will be used.") if not any(atom.GetAtomicNum() == 1 for atom in rdkit.GetAtoms()): warnings.warn( diff --git a/gufe/components/proteincomponent.py b/gufe/components/proteincomponent.py index 6f788701..a23ccf09 100644 --- a/gufe/components/proteincomponent.py +++ b/gufe/components/proteincomponent.py @@ -144,9 +144,7 @@ def from_pdbx_file(cls, pdbx_file: str, name=""): return cls._from_openmmPDBFile(openmm_PDBFile=openmm_PDBxFile, name=name) @classmethod - def _from_openmmPDBFile( - cls, openmm_PDBFile: Union[PDBFile, PDBxFile], name: str = "" - ): + def _from_openmmPDBFile(cls, openmm_PDBFile: Union[PDBFile, PDBxFile], name: str = ""): """Converts to our internal representation (rdkit Mol) Parameters @@ -201,9 +199,7 @@ def _from_openmmPDBFile( # Set Positions rd_mol = editable_rdmol.GetMol() - positions = np.array( - openmm_PDBFile.positions.value_in_unit(omm_unit.angstrom), ndmin=3 - ) + positions = np.array(openmm_PDBFile.positions.value_in_unit(omm_unit.angstrom), ndmin=3) for frame_id, frame in enumerate(positions): conf = Conformer(frame_id) @@ -218,9 +214,7 @@ def _from_openmmPDBFile( atomic_num = a.GetAtomicNum() atom_name = a.GetMonomerInfo().GetName() - connectivity = sum( - _BONDORDER_TO_ORDER[bond.GetBondType()] for bond in a.GetBonds() - ) + connectivity = sum(_BONDORDER_TO_ORDER[bond.GetBondType()] for bond in a.GetBonds()) default_valence = periodicTable.GetDefaultValence(atomic_num) if connectivity == 0: # ions: @@ -364,9 +358,7 @@ def chainkey(m): if (new_resid := reskey(mi)) != current_resid: _, resname, resnum, icode = new_resid - r = top.addResidue( - name=resname, chain=c, id=str(resnum), insertionCode=icode - ) + r = top.addResidue(name=resname, chain=c, id=str(resnum), insertionCode=icode) current_resid = new_resid a = top.addAtom( @@ -381,9 +373,7 @@ def chainkey(m): for bond in self._rdkit.GetBonds(): a1 = atom_lookup[bond.GetBeginAtomIdx()] a2 = atom_lookup[bond.GetEndAtomIdx()] - top.addBond( - a1, a2, order=_BONDORDERS_RDKIT_TO_OPENMM.get(bond.GetBondType(), None) - ) + top.addBond(a1, a2, order=_BONDORDERS_RDKIT_TO_OPENMM.get(bond.GetBondType(), None)) return top @@ -405,9 +395,7 @@ def to_openmm_positions(self) -> omm_unit.Quantity: return openmm_pos - def to_pdb_file( - self, out_path: Union[str, bytes, PathLike[str], PathLike[bytes], io.TextIOBase] - ) -> str: + def to_pdb_file(self, out_path: Union[str, bytes, PathLike[str], PathLike[bytes], io.TextIOBase]) -> str: """ serialize protein to pdb file. @@ -449,9 +437,7 @@ def to_pdb_file( return out_path - def to_pdbx_file( - self, out_path: Union[str, bytes, PathLike[str], PathLike[bytes], io.TextIOBase] - ) -> str: + def to_pdbx_file(self, out_path: Union[str, bytes, PathLike[str], PathLike[bytes], io.TextIOBase]) -> str: """ serialize protein to pdbx file. @@ -529,8 +515,7 @@ def _to_dict(self) -> dict: ] conformers = [ - serialize_numpy(conf.GetPositions()) # .m_as(unit.angstrom) - for conf in self._rdkit.GetConformers() + serialize_numpy(conf.GetPositions()) for conf in self._rdkit.GetConformers() # .m_as(unit.angstrom) ] # Result diff --git a/gufe/components/smallmoleculecomponent.py b/gufe/components/smallmoleculecomponent.py index ca33a78e..019212e3 100644 --- a/gufe/components/smallmoleculecomponent.py +++ b/gufe/components/smallmoleculecomponent.py @@ -78,7 +78,7 @@ 5: Chem.rdchem.HybridizationType.SP2D, 6: Chem.rdchem.HybridizationType.SP3D, 7: Chem.rdchem.HybridizationType.SP3D2, - 8: Chem.rdchem.HybridizationType.OTHER + 8: Chem.rdchem.HybridizationType.OTHER, } _HYBRIDIZATION_TO_INT = {v: k for k, v in _INT_TO_HYBRIDIZATION.items()} diff --git a/gufe/components/solventcomponent.py b/gufe/components/solventcomponent.py index ef5d408d..8b9b01e1 100644 --- a/gufe/components/solventcomponent.py +++ b/gufe/components/solventcomponent.py @@ -76,17 +76,10 @@ def __init__( self._neutralize = neutralize - if not isinstance( - ion_concentration, unit.Quantity - ) or not ion_concentration.is_compatible_with(unit.molar): - raise ValueError( - f"ion_concentration must be given in units of" - f" concentration, got: {ion_concentration}" - ) + if not isinstance(ion_concentration, unit.Quantity) or not ion_concentration.is_compatible_with(unit.molar): + raise ValueError(f"ion_concentration must be given in units of" f" concentration, got: {ion_concentration}") if ion_concentration.m < 0: - raise ValueError( - f"ion_concentration must be positive, " f"got: {ion_concentration}" - ) + raise ValueError(f"ion_concentration must be positive, " f"got: {ion_concentration}") self._ion_concentration = ion_concentration diff --git a/gufe/custom_codecs.py b/gufe/custom_codecs.py index 36a13b92..7499ccd6 100644 --- a/gufe/custom_codecs.py +++ b/gufe/custom_codecs.py @@ -96,9 +96,7 @@ def is_openff_quantity_dict(dct): "shape": list(obj.shape), "bytes": obj.tobytes(), }, - from_dict=lambda dct: np.frombuffer( - dct["bytes"], dtype=np.dtype(dct["dtype"]) - ).reshape(dct["shape"]), + from_dict=lambda dct: np.frombuffer(dct["bytes"], dtype=np.dtype(dct["dtype"])).reshape(dct["shape"]), ) @@ -118,8 +116,7 @@ def is_openff_quantity_dict(dct): ":is_custom:": True, "pint_unit_registry": "openff_units", }, - from_dict=lambda dct: dct["magnitude"] - * DEFAULT_UNIT_REGISTRY.Quantity(dct["unit"]), + from_dict=lambda dct: dct["magnitude"] * DEFAULT_UNIT_REGISTRY.Quantity(dct["unit"]), is_my_obj=lambda obj: isinstance(obj, DEFAULT_UNIT_REGISTRY.Quantity), is_my_dict=is_openff_quantity_dict, ) diff --git a/gufe/ligandnetwork.py b/gufe/ligandnetwork.py index ff87f567..45eac118 100644 --- a/gufe/ligandnetwork.py +++ b/gufe/ligandnetwork.py @@ -39,9 +39,7 @@ def __init__( nodes = [] self._edges = frozenset(edges) - edge_nodes = set( - chain.from_iterable((e.componentA, e.componentB) for e in edges) - ) + edge_nodes = set(chain.from_iterable((e.componentA, e.componentB) for e in edges)) self._nodes = frozenset(edge_nodes) | frozenset(nodes) self._graph = None @@ -70,9 +68,7 @@ def graph(self) -> nx.MultiDiGraph: for node in sorted(self._nodes): graph.add_node(node) for edge in sorted(self._edges): - graph.add_edge( - edge.componentA, edge.componentB, object=edge, **edge.annotations - ) + graph.add_edge(edge.componentA, edge.componentB, object=edge, **edge.annotations) self._graph = nx.freeze(graph) @@ -116,14 +112,10 @@ def _serializable_graph(self) -> nx.Graph: # from here, we just build the graph serializable_graph = nx.MultiDiGraph() for mol, label in mol_to_label.items(): - serializable_graph.add_node( - label, moldict=json.dumps(mol.to_dict(), sort_keys=True) - ) + serializable_graph.add_node(label, moldict=json.dumps(mol.to_dict(), sort_keys=True)) for molA, molB, mapping, annotation in edge_data: - serializable_graph.add_edge( - molA, molB, mapping=mapping, annotations=annotation - ) + serializable_graph.add_edge(molA, molB, mapping=mapping, annotations=annotation) return serializable_graph @@ -134,8 +126,7 @@ def _from_serializable_graph(cls, graph: nx.Graph): This is the inverse of ``_serializable_graph``. """ label_to_mol = { - node: SmallMoleculeComponent.from_dict(json.loads(d)) - for node, d in graph.nodes(data="moldict") + node: SmallMoleculeComponent.from_dict(json.loads(d)) for node, d in graph.nodes(data="moldict") } edges = [ @@ -242,9 +233,7 @@ def sys_from_dict(component): """ syscomps = {alchemical_label: component} other_labels = set(labels) - {alchemical_label} - syscomps.update( - {label: components[label] for label in other_labels} - ) + syscomps.update({label: components[label] for label in other_labels}) if autoname: name = f"{component.name}_{leg_name}" @@ -261,9 +250,7 @@ def sys_from_dict(component): else: name = "" - transformation = gufe.Transformation( - sysA, sysB, protocol, mapping=edge, name=name - ) + transformation = gufe.Transformation(sysA, sysB, protocol, mapping=edge, name=name) transformations.append(transformation) diff --git a/gufe/mapping/atom_mapper.py b/gufe/mapping/atom_mapper.py index 76e49626..c2844979 100644 --- a/gufe/mapping/atom_mapper.py +++ b/gufe/mapping/atom_mapper.py @@ -19,9 +19,7 @@ class AtomMapper(GufeTokenizable): """ @abc.abstractmethod - def suggest_mappings( - self, A: gufe.Component, B: gufe.Component - ) -> Iterator[AtomMapping]: + def suggest_mappings(self, A: gufe.Component, B: gufe.Component) -> Iterator[AtomMapping]: """Suggests possible mappings between two Components Suggests zero or more :class:`.AtomMapping` objects, which are possible diff --git a/gufe/mapping/ligandatommapping.py b/gufe/mapping/ligandatommapping.py index f9077ef7..4d322000 100644 --- a/gufe/mapping/ligandatommapping.py +++ b/gufe/mapping/ligandatommapping.py @@ -60,13 +60,9 @@ def __init__( nB = self.componentB.to_rdkit().GetNumAtoms() for i, j in componentA_to_componentB.items(): if not (0 <= i < nA): - raise ValueError( - f"Got invalid index for ComponentA ({i}); " f"must be 0 <= n < {nA}" - ) + raise ValueError(f"Got invalid index for ComponentA ({i}); " f"must be 0 <= n < {nA}") if not (0 <= j < nB): - raise ValueError( - f"Got invalid index for ComponentB ({i}); " f"must be 0 <= n < {nB}" - ) + raise ValueError(f"Got invalid index for ComponentB ({i}); " f"must be 0 <= n < {nB}") self._compA_to_compB = componentA_to_componentB @@ -89,19 +85,11 @@ def componentB_to_componentA(self) -> dict[int, int]: @property def componentA_unique(self): - return ( - i - for i in range(self.componentA.to_rdkit().GetNumAtoms()) - if i not in self._compA_to_compB - ) + return (i for i in range(self.componentA.to_rdkit().GetNumAtoms()) if i not in self._compA_to_compB) @property def componentB_unique(self): - return ( - i - for i in range(self.componentB.to_rdkit().GetNumAtoms()) - if i not in self._compA_to_compB.values() - ) + return (i for i in range(self.componentB.to_rdkit().GetNumAtoms()) if i not in self._compA_to_compB.values()) @property def annotations(self): @@ -118,9 +106,7 @@ def _to_dict(self): "componentA": self.componentA, "componentB": self.componentB, "componentA_to_componentB": self._compA_to_compB, - "annotations": json.dumps( - self._annotations, sort_keys=True, cls=JSON_HANDLER.encoder - ), + "annotations": json.dumps(self._annotations, sort_keys=True, cls=JSON_HANDLER.encoder), } @classmethod diff --git a/gufe/network.py b/gufe/network.py index 65b19ab9..18bb06d9 100644 --- a/gufe/network.py +++ b/gufe/network.py @@ -48,11 +48,7 @@ def __init__( else: self._nodes = frozenset(nodes) - self._nodes = ( - self._nodes - | frozenset(e.stateA for e in self._edges) - | frozenset(e.stateB for e in self._edges) - ) + self._nodes = self._nodes | frozenset(e.stateA for e in self._edges) | frozenset(e.stateB for e in self._edges) self._graph = None @@ -61,9 +57,7 @@ def _generate_graph(edges, nodes) -> nx.MultiDiGraph: g = nx.MultiDiGraph() for transformation in edges: - g.add_edge( - transformation.stateA, transformation.stateB, object=transformation - ) + g.add_edge(transformation.stateA, transformation.stateB, object=transformation) g.add_nodes_from(nodes) @@ -108,9 +102,7 @@ def _to_dict(self) -> dict: @classmethod def _from_dict(cls, d: dict) -> Self: - return cls( - nodes=frozenset(d["nodes"]), edges=frozenset(d["edges"]), name=d.get("name") - ) + return cls(nodes=frozenset(d["nodes"]), edges=frozenset(d["edges"]), name=d.get("name")) @classmethod def _defaults(cls): diff --git a/gufe/protocols/protocol.py b/gufe/protocols/protocol.py index 2089e80b..0ab9d5bb 100644 --- a/gufe/protocols/protocol.py +++ b/gufe/protocols/protocol.py @@ -177,9 +177,7 @@ def create( *, stateA: ChemicalSystem, stateB: ChemicalSystem, - mapping: Optional[ - Union[ComponentMapping, list[ComponentMapping], dict[str, ComponentMapping]] - ], + mapping: Optional[Union[ComponentMapping, list[ComponentMapping], dict[str, ComponentMapping]]], extends: Optional[ProtocolDAGResult] = None, name: Optional[str] = None, transformation_key: Optional[GufeKey] = None, @@ -224,10 +222,7 @@ def create( """ if isinstance(mapping, dict): warnings.warn( - ( - "mapping input as a dict is deprecated, " - "instead use either a single Mapping or list" - ), + ("mapping input as a dict is deprecated, " "instead use either a single Mapping or list"), DeprecationWarning, ) mapping = list(mapping.values()) @@ -244,9 +239,7 @@ def create( extends_key=extends.key if extends is not None else None, ) - def gather( - self, protocol_dag_results: Iterable[ProtocolDAGResult] - ) -> ProtocolResult: + def gather(self, protocol_dag_results: Iterable[ProtocolDAGResult]) -> ProtocolResult: """Gather multiple ProtocolDAGResults into a single ProtocolResult. Parameters @@ -263,9 +256,7 @@ def gather( return self.result_cls(**self._gather(protocol_dag_results)) @abc.abstractmethod - def _gather( - self, protocol_dag_results: Iterable[ProtocolDAGResult] - ) -> dict[str, Any]: + def _gather(self, protocol_dag_results: Iterable[ProtocolDAGResult]) -> dict[str, Any]: """Method to override in custom Protocol subclasses. This method should take any number of ``ProtocolDAGResult``s produced diff --git a/gufe/protocols/protocoldag.py b/gufe/protocols/protocoldag.py index 05ca65e3..74e7afbd 100644 --- a/gufe/protocols/protocoldag.py +++ b/gufe/protocols/protocoldag.py @@ -45,9 +45,7 @@ def _build_graph(nodes): @staticmethod def _iterate_dag_order(graph): - return reversed( - list(nx.lexicographical_topological_sort(graph, key=lambda pu: pu.key)) - ) + return reversed(list(nx.lexicographical_topological_sort(graph, key=lambda pu: pu.key))) @property def name(self) -> Optional[str]: @@ -121,9 +119,7 @@ def __init__( self._protocol_units = protocol_units self._protocol_unit_results = protocol_unit_results - self._transformation_key = ( - GufeKey(transformation_key) if transformation_key is not None else None - ) + self._transformation_key = GufeKey(transformation_key) if transformation_key is not None else None self._extends_key = GufeKey(extends_key) if extends_key is not None else None # build graph from protocol units @@ -229,9 +225,7 @@ def unit_to_result(self, protocol_unit: ProtocolUnit) -> ProtocolUnitResult: else: raise KeyError("No success for `protocol_unit` found") - def unit_to_all_results( - self, protocol_unit: ProtocolUnit - ) -> list[ProtocolUnitResult]: + def unit_to_all_results(self, protocol_unit: ProtocolUnit) -> list[ProtocolUnitResult]: """Return all results (sucess and failure) for a given Unit. Returns @@ -257,10 +251,7 @@ def result_to_unit(self, protocol_unit_result: ProtocolUnitResult) -> ProtocolUn def ok(self) -> bool: # ensure that for every protocol unit, there is an OK result object - return all( - any(pur.ok() for pur in self._unit_result_mapping[pu]) - for pu in self._protocol_units - ) + return all(any(pur.ok() for pur in self._unit_result_mapping[pu]) for pu in self._protocol_units) @property def terminal_protocol_unit_results(self) -> list[ProtocolUnitResult]: @@ -272,11 +263,7 @@ def terminal_protocol_unit_results(self) -> list[ProtocolUnitResult]: All ProtocolUnitResults which do not have a ProtocolUnitResult that follows on (depends) on them. """ - return [ - u - for u in self._protocol_unit_results - if not nx.ancestors(self._result_graph, u) - ] + return [u for u in self._protocol_unit_results if not nx.ancestors(self._result_graph, u)] class ProtocolDAG(GufeTokenizable, DAGMixin): @@ -334,9 +321,7 @@ def __init__( self._name = name self._protocol_units = protocol_units - self._transformation_key = ( - GufeKey(transformation_key) if transformation_key is not None else None - ) + self._transformation_key = GufeKey(transformation_key) if transformation_key is not None else None self._extends_key = GufeKey(extends_key) if extends_key is not None else None # build graph from protocol units diff --git a/gufe/settings/__init__.py b/gufe/settings/__init__.py index ac3a2d0a..be4c3fde 100644 --- a/gufe/settings/__init__.py +++ b/gufe/settings/__init__.py @@ -1,10 +1,4 @@ # This code is part of OpenFE and is licensed under the MIT license. # For details, see https://github.com/OpenFreeEnergy/gufe """General models for defining the parameters that protocols use""" -from .models import ( - BaseForceFieldSettings, - OpenMMSystemGeneratorFFSettings, - Settings, - SettingsBaseModel, - ThermoSettings, -) +from .models import BaseForceFieldSettings, OpenMMSystemGeneratorFFSettings, Settings, SettingsBaseModel, ThermoSettings diff --git a/gufe/settings/models.py b/gufe/settings/models.py index 31ee6cc7..734fcb56 100644 --- a/gufe/settings/models.py +++ b/gufe/settings/models.py @@ -53,9 +53,7 @@ def frozen_copy(self): def freeze_model(model): submodels = ( - mod - for field in model.__fields__ - if isinstance(mod := getattr(model, field), SettingsBaseModel) + mod for field in model.__fields__ if isinstance(mod := getattr(model, field), SettingsBaseModel) ) for mod in submodels: freeze_model(mod) @@ -76,9 +74,7 @@ def unfrozen_copy(self): def unfreeze_model(model): submodels = ( - mod - for field in model.__fields__ - if isinstance(mod := getattr(model, field), SettingsBaseModel) + mod for field in model.__fields__ if isinstance(mod := getattr(model, field), SettingsBaseModel) ) for mod in submodels: unfreeze_model(mod) @@ -112,16 +108,12 @@ class ThermoSettings(SettingsBaseModel): possible. """ - temperature: FloatQuantity["kelvin"] = Field( - None, description="Simulation temperature, default units kelvin" - ) + temperature: FloatQuantity["kelvin"] = Field(None, description="Simulation temperature, default units kelvin") pressure: FloatQuantity["standard_atmosphere"] = Field( None, description="Simulation pressure, default units standard atmosphere (atm)" ) ph: Union[PositiveFloat, None] = Field(None, description="Simulation pH") - redox_potential: Optional[float] = Field( - None, description="Simulation redox potential" - ) + redox_potential: Optional[float] = Field(None, description="Simulation redox potential") class BaseForceFieldSettings(SettingsBaseModel, abc.ABC): @@ -171,9 +163,7 @@ class Config: ] """List of force field paths for all components except :class:`SmallMoleculeComponent` """ - small_molecule_forcefield: str = ( - "openff-2.1.1" # other default ideas 'openff-2.0.0', 'gaff-2.11', 'espaloma-0.2.0' - ) + small_molecule_forcefield: str = "openff-2.1.1" # other default ideas 'openff-2.0.0', 'gaff-2.11', 'espaloma-0.2.0' """Name of the force field to be used for :class:`SmallMoleculeComponent` """ nonbonded_method = "PME" @@ -198,9 +188,7 @@ def allowed_nonbonded(cls, v): def is_positive_distance(cls, v): # these are time units, not simulation steps if not v.is_compatible_with(unit.nanometer): - raise ValueError( - "nonbonded_cutoff must be in distance units " "(i.e. nanometers)" - ) + raise ValueError("nonbonded_cutoff must be in distance units " "(i.e. nanometers)") if v < 0: errmsg = "nonbonded_cutoff must be a positive value" raise ValueError(errmsg) diff --git a/gufe/storage/externalresource/filestorage.py b/gufe/storage/externalresource/filestorage.py index cebe0343..e2b5a54b 100644 --- a/gufe/storage/externalresource/filestorage.py +++ b/gufe/storage/externalresource/filestorage.py @@ -49,9 +49,7 @@ def _delete(self, location): if self.exists(location): path.unlink() else: - raise MissingExternalResourceError( - f"Unable to delete '{str(path)}': File does not exist" - ) + raise MissingExternalResourceError(f"Unable to delete '{str(path)}': File does not exist") def _as_path(self, location): return self.root_dir / pathlib.Path(location) diff --git a/gufe/storage/externalresource/memorystorage.py b/gufe/storage/externalresource/memorystorage.py index 6f36ee2e..e8ee07d3 100644 --- a/gufe/storage/externalresource/memorystorage.py +++ b/gufe/storage/externalresource/memorystorage.py @@ -20,9 +20,7 @@ def _delete(self, location): try: del self._data[location] except KeyError: - raise MissingExternalResourceError( - f"Unable to delete '{location}': key does not exist" - ) + raise MissingExternalResourceError(f"Unable to delete '{location}': key does not exist") def __eq__(self, other): return self is other diff --git a/gufe/tests/conftest.py b/gufe/tests/conftest.py index 5a7faa01..63dd30a0 100644 --- a/gufe/tests/conftest.py +++ b/gufe/tests/conftest.py @@ -58,12 +58,13 @@ def get_test_filename(filename): ] -_pl_benchmark_url_pattern = "https://github.com/OpenFreeEnergy/openfe-benchmarks/blob/main/openfe_benchmarks/data/{name}.pdb?raw=true" +_pl_benchmark_url_pattern = ( + "https://github.com/OpenFreeEnergy/openfe-benchmarks/blob/main/openfe_benchmarks/data/{name}.pdb?raw=true" +) PDB_BENCHMARK_LOADERS = { - name: URLFileLike(url=_pl_benchmark_url_pattern.format(name=name)) - for name in _benchmark_pdb_names + name: URLFileLike(url=_pl_benchmark_url_pattern.format(name=name)) for name in _benchmark_pdb_names } PDB_FILE_LOADERS = {name: lambda: get_test_filename(name) for name in ["181l.pdb"]} @@ -196,9 +197,7 @@ def prot_comp(PDB_181L_path): @pytest.fixture def solv_comp(): - yield gufe.SolventComponent( - positive_ion="K", negative_ion="Cl", ion_concentration=0.0 * unit.molar - ) + yield gufe.SolventComponent(positive_ion="K", negative_ion="Cl", ion_concentration=0.0 * unit.molar) @pytest.fixture @@ -270,9 +269,7 @@ def benzene_variants_star_map_transformations( solvated_ligands = {} solvated_ligand_transformations = {} - solvated_ligands["benzene"] = gufe.ChemicalSystem( - {"solvent": solv_comp, "ligand": benzene}, name="benzene-solvent" - ) + solvated_ligands["benzene"] = gufe.ChemicalSystem({"solvent": solv_comp, "ligand": benzene}, name="benzene-solvent") for ligand in variants: solvated_ligands[ligand.name] = gufe.ChemicalSystem( @@ -300,28 +297,20 @@ def benzene_variants_star_map_transformations( {"protein": prot_comp, "solvent": solv_comp, "ligand": ligand}, name=f"{ligand.name}-complex", ) - solvated_complex_transformations[("benzene", ligand.name)] = ( - gufe.Transformation( - solvated_complexes["benzene"], - solvated_complexes[ligand.name], - protocol=DummyProtocol(settings=DummyProtocol.default_settings()), - mapping=None, - ) + solvated_complex_transformations[("benzene", ligand.name)] = gufe.Transformation( + solvated_complexes["benzene"], + solvated_complexes[ligand.name], + protocol=DummyProtocol(settings=DummyProtocol.default_settings()), + mapping=None, ) - return list(solvated_ligand_transformations.values()), list( - solvated_complex_transformations.values() - ) + return list(solvated_ligand_transformations.values()), list(solvated_complex_transformations.values()) @pytest.fixture def benzene_variants_star_map(benzene_variants_star_map_transformations): - solvated_ligand_transformations, solvated_complex_transformations = ( - benzene_variants_star_map_transformations - ) - return gufe.AlchemicalNetwork( - solvated_ligand_transformations + solvated_complex_transformations - ) + solvated_ligand_transformations, solvated_complex_transformations = benzene_variants_star_map_transformations + return gufe.AlchemicalNetwork(solvated_ligand_transformations + solvated_complex_transformations) @pytest.fixture diff --git a/gufe/tests/storage/test_externalresource.py b/gufe/tests/storage/test_externalresource.py index 80d67e73..adfe9170 100644 --- a/gufe/tests/storage/test_externalresource.py +++ b/gufe/tests/storage/test_externalresource.py @@ -5,10 +5,7 @@ import pytest -from gufe.storage.errors import ( - ChangedExternalResourceError, - MissingExternalResourceError, -) +from gufe.storage.errors import ChangedExternalResourceError, MissingExternalResourceError from gufe.storage.externalresource import FileStorage, MemoryStorage # NOTE: Tests for the abstract base are just part of the tests of its diff --git a/gufe/tests/test_chemicalsystem.py b/gufe/tests/test_chemicalsystem.py index 00ced19d..414cd96b 100644 --- a/gufe/tests/test_chemicalsystem.py +++ b/gufe/tests/test_chemicalsystem.py @@ -48,13 +48,9 @@ def test_complex_construction(prot_comp, solv_comp, toluene_ligand_comp): def test_hash_and_eq(prot_comp, solv_comp, toluene_ligand_comp): - c1 = ChemicalSystem( - {"protein": prot_comp, "solvent": solv_comp, "ligand": toluene_ligand_comp} - ) + c1 = ChemicalSystem({"protein": prot_comp, "solvent": solv_comp, "ligand": toluene_ligand_comp}) - c2 = ChemicalSystem( - {"solvent": solv_comp, "ligand": toluene_ligand_comp, "protein": prot_comp} - ) + c2 = ChemicalSystem({"solvent": solv_comp, "ligand": toluene_ligand_comp, "protein": prot_comp}) assert c1 == c2 assert hash(c1) == hash(c2) @@ -66,9 +62,7 @@ def test_chemical_system_neq_1(solvated_complex, prot_comp): assert hash(solvated_complex) != hash(prot_comp) -def test_chemical_system_neq_2( - solvated_complex, prot_comp, solv_comp, toluene_ligand_comp -): +def test_chemical_system_neq_2(solvated_complex, prot_comp, solv_comp, toluene_ligand_comp): # names are different complex2 = ChemicalSystem( {"protein": prot_comp, "solvent": solv_comp, "ligand": toluene_ligand_comp}, @@ -85,9 +79,7 @@ def test_chemical_system_neq_4(solvated_complex, solvated_ligand): assert hash(solvated_complex) != hash(solvated_ligand) -def test_chemical_system_neq_5( - solvated_complex, prot_comp, solv_comp, phenol_ligand_comp -): +def test_chemical_system_neq_5(solvated_complex, prot_comp, solv_comp, phenol_ligand_comp): # same component keys, but different components complex2 = ChemicalSystem( {"protein": prot_comp, "solvent": solv_comp, "ligand": phenol_ligand_comp}, diff --git a/gufe/tests/test_ligand_network.py b/gufe/tests/test_ligand_network.py index ce025132..a7116a8e 100644 --- a/gufe/tests/test_ligand_network.py +++ b/gufe/tests/test_ligand_network.py @@ -51,13 +51,9 @@ def mols(): @pytest.fixture def std_edges(mols): mol1, mol2, mol3 = mols - edge12 = LigandAtomMapping( - mol1, mol2, {0: 0, 1: 1}, {"score": 0.0, "length": 1.0 * unit.angstrom} - ) + edge12 = LigandAtomMapping(mol1, mol2, {0: 0, 1: 1}, {"score": 0.0, "length": 1.0 * unit.angstrom}) edge23 = LigandAtomMapping(mol2, mol3, {0: 0}, {"score": 1.0}) - edge13 = LigandAtomMapping( - mol1, mol3, {0: 0, 2: 1}, {"score": 0.5, "time": 2.0 * unit.second} - ) + edge13 = LigandAtomMapping(mol1, mol3, {0: 0, 2: 1}, {"score": 0.5, "time": 2.0 * unit.second}) return edge12, edge23, edge13 @@ -260,16 +256,12 @@ def test_enlarge_graph_add_duplicate_edge(self, mols, simple_network): # Adding a duplicate of an existing edge should create a new network # with the same edges and nodes as the previous one. mol1, _, mol3 = mols - duplicate = LigandAtomMapping( - mol1, mol3, {0: 0, 2: 1}, {"score": 0.5, "time": 2.0 * unit.second} - ) + duplicate = LigandAtomMapping(mol1, mol3, {0: 0, 2: 1}, {"score": 0.5, "time": 2.0 * unit.second}) network = simple_network.network existing = network.edges assert duplicate in existing # matches by == - assert any( - duplicate is edge for edge in existing - ) # one edge *is* the duplicate + assert any(duplicate is edge for edge in existing) # one edge *is* the duplicate new_network = network.enlarge_graph(edges=[duplicate]) assert len(new_network.nodes) == len(network.nodes) @@ -289,9 +281,7 @@ def test_to_graphml(self, simple_network, ligandnetwork_graphml): assert simple_network.network.to_graphml() == ligandnetwork_graphml def test_from_graphml(self, simple_network, ligandnetwork_graphml): - assert ( - LigandNetwork.from_graphml(ligandnetwork_graphml) == simple_network.network - ) + assert LigandNetwork.from_graphml(ligandnetwork_graphml) == simple_network.network def test_is_connected(self, simple_network): assert simple_network.network.is_connected() @@ -342,10 +332,7 @@ def test_to_rbfe_alchemical_network( if with_cofactor: labels.add("cofactor") else: # -no-cov- - raise RuntimeError( - "Something went weird in testing. Unable " - f"to get leg for edge {edge}" - ) + raise RuntimeError("Something went weird in testing. Unable " f"to get leg for edge {edge}") assert set(compsA) == labels assert set(compsB) == labels @@ -360,9 +347,7 @@ def test_to_rbfe_alchemical_network( assert isinstance(edge.mapping, gufe.ComponentMapping) assert edge.mapping in real_molecules_network.edges - def test_to_rbfe_alchemical_network_autoname_false( - self, real_molecules_network, prot_comp, solv_comp - ): + def test_to_rbfe_alchemical_network_autoname_false(self, real_molecules_network, prot_comp, solv_comp): rbfe = real_molecules_network.to_rbfe_alchemical_network( solvent=solv_comp, protein=prot_comp, @@ -374,9 +359,7 @@ def test_to_rbfe_alchemical_network_autoname_false( for sys in [edge.stateA, edge.stateB]: assert sys.name == "" - def test_to_rbfe_alchemical_network_autoname_true( - self, real_molecules_network, prot_comp, solv_comp - ): + def test_to_rbfe_alchemical_network_autoname_true(self, real_molecules_network, prot_comp, solv_comp): rbfe = real_molecules_network.to_rbfe_alchemical_network( solvent=solv_comp, protein=prot_comp, @@ -398,9 +381,7 @@ def test_to_rhfe_alchemical_network(self, real_molecules_network, solv_comp): others = {} protocol = DummyProtocol(DummyProtocol.default_settings()) - rhfe = real_molecules_network.to_rhfe_alchemical_network( - solvent=solv_comp, protocol=protocol, **others - ) + rhfe = real_molecules_network.to_rhfe_alchemical_network(solvent=solv_comp, protocol=protocol, **others) expected_names = { "easy_rhfe_benzene_vacuum_toluene_vacuum", @@ -423,10 +404,7 @@ def test_to_rhfe_alchemical_network(self, real_molecules_network, solv_comp): elif "solvent" in edge.name: labels = {"ligand", "solvent"} else: # -no-cov- - raise RuntimeError( - "Something went weird in testing. Unable " - f"to get leg for edge {edge}" - ) + raise RuntimeError("Something went weird in testing. Unable " f"to get leg for edge {edge}") labels |= set(others) diff --git a/gufe/tests/test_ligandatommapping.py b/gufe/tests/test_ligandatommapping.py index fdd5ff69..63ae076e 100644 --- a/gufe/tests/test_ligandatommapping.py +++ b/gufe/tests/test_ligandatommapping.py @@ -231,9 +231,7 @@ def test_draw_mapping_svg(tmpdir, other_mapping): class TestLigandAtomMappingSerialization: - def test_deserialize_roundtrip( - self, benzene_phenol_mapping, benzene_anisole_mapping - ): + def test_deserialize_roundtrip(self, benzene_phenol_mapping, benzene_anisole_mapping): roundtrip = LigandAtomMapping.from_dict(benzene_phenol_mapping.to_dict()) @@ -299,27 +297,19 @@ def molB(self): def test_too_large_A(self, molA, molB): with pytest.raises(ValueError, match="invalid index for ComponentA"): - LigandAtomMapping( - componentA=molA, componentB=molB, componentA_to_componentB={9: 5} - ) + LigandAtomMapping(componentA=molA, componentB=molB, componentA_to_componentB={9: 5}) def test_too_small_A(self, molA, molB): with pytest.raises(ValueError, match="invalid index for ComponentA"): - LigandAtomMapping( - componentA=molA, componentB=molB, componentA_to_componentB={-2: 5} - ) + LigandAtomMapping(componentA=molA, componentB=molB, componentA_to_componentB={-2: 5}) def test_too_large_B(self, molA, molB): with pytest.raises(ValueError, match="invalid index for ComponentB"): - LigandAtomMapping( - componentA=molA, componentB=molB, componentA_to_componentB={5: 11} - ) + LigandAtomMapping(componentA=molA, componentB=molB, componentA_to_componentB={5: 11}) def test_too_small_B(self, molA, molB): with pytest.raises(ValueError, match="invalid index for ComponentB"): - LigandAtomMapping( - componentA=molA, componentB=molB, componentA_to_componentB={5: -1} - ) + LigandAtomMapping(componentA=molA, componentB=molB, componentA_to_componentB={5: -1}) class TestLigandAtomMapping(GufeTokenizableTestsMixin): diff --git a/gufe/tests/test_mapping.py b/gufe/tests/test_mapping.py index 85b034e0..620c2f55 100644 --- a/gufe/tests/test_mapping.py +++ b/gufe/tests/test_mapping.py @@ -41,18 +41,10 @@ def componentB_to_componentA(self): return {v: k for k, v in self._mapping} def componentA_unique(self): - return ( - i - for i in range(self._molA.to_rdkit().GetNumAtoms()) - if i not in self._mapping - ) + return (i for i in range(self._molA.to_rdkit().GetNumAtoms()) if i not in self._mapping) def componentB_unique(self): - return ( - i - for i in range(self._molB.to_rdkit().GetNumAtoms()) - if i not in self._mapping.values() - ) + return (i for i in range(self._molB.to_rdkit().GetNumAtoms()) if i not in self._mapping.values()) class TestMappingAbstractClass(GufeTokenizableTestsMixin): diff --git a/gufe/tests/test_mapping_visualization.py b/gufe/tests/test_mapping_visualization.py index 8f541b55..756ad13a 100644 --- a/gufe/tests/test_mapping_visualization.py +++ b/gufe/tests/test_mapping_visualization.py @@ -132,9 +132,7 @@ def benzene_phenol_mapping(benzene_transforms, maps): ], ], ) -def test_benzene_to_phenol_uniques( - molname, atoms, elems, bond_changes, bond_deletions, benzene_transforms, maps -): +def test_benzene_to_phenol_uniques(molname, atoms, elems, bond_changes, bond_deletions, benzene_transforms, maps): mol1 = benzene_transforms["benzene"] mol2 = benzene_transforms[molname] diff --git a/gufe/tests/test_models.py b/gufe/tests/test_models.py index 71b3353a..887cbdd7 100644 --- a/gufe/tests/test_models.py +++ b/gufe/tests/test_models.py @@ -9,11 +9,7 @@ import pytest from openff.units import unit -from gufe.settings.models import ( - OpenMMSystemGeneratorFFSettings, - Settings, - ThermoSettings, -) +from gufe.settings.models import OpenMMSystemGeneratorFFSettings, Settings, ThermoSettings def test_model_schema(): diff --git a/gufe/tests/test_proteincomponent.py b/gufe/tests/test_proteincomponent.py index 2ac80eea..3efce664 100644 --- a/gufe/tests/test_proteincomponent.py +++ b/gufe/tests/test_proteincomponent.py @@ -48,12 +48,8 @@ def assert_same_pdb_lines(in_file_path, out_file_path): if must_close: out_file.close() - in_lines = [ - l for l in in_lines if not l.startswith(("REMARK", "CRYST", "# Created with")) - ] - out_lines = [ - l for l in out_lines if not l.startswith(("REMARK", "CRYST", "# Created with")) - ] + in_lines = [l for l in in_lines if not l.startswith(("REMARK", "CRYST", "# Created with"))] + out_lines = [l for l in out_lines if not l.startswith(("REMARK", "CRYST", "# Created with"))] assert in_lines == out_lines @@ -147,9 +143,7 @@ def _test_file_output(self, input_path, output_path, input_type, output_func): assert_same_pdb_lines(in_file_path=str(input_path), out_file_path=output_path) - @pytest.mark.parametrize( - "input_type", ["filename", "Path", "StringIO", "TextIOWrapper"] - ) + @pytest.mark.parametrize("input_type", ["filename", "Path", "StringIO", "TextIOWrapper"]) def test_to_pdbx_file(self, PDBx_181L_openMMClean_path, tmp_path, input_type): p = self.cls.from_pdbx_file(str(PDBx_181L_openMMClean_path), name="Bob") out_file_name = "tmp_181L_pdbx.cif" @@ -162,9 +156,7 @@ def test_to_pdbx_file(self, PDBx_181L_openMMClean_path, tmp_path, input_type): output_func=p.to_pdbx_file, ) - @pytest.mark.parametrize( - "input_type", ["filename", "Path", "StringIO", "TextIOWrapper"] - ) + @pytest.mark.parametrize("input_type", ["filename", "Path", "StringIO", "TextIOWrapper"]) def test_to_pdb_input_types(self, PDB_181L_OpenMMClean_path, tmp_path, input_type): p = self.cls.from_pdb_file(str(PDB_181L_OpenMMClean_path), name="Bob") @@ -192,9 +184,7 @@ def test_to_pdb_round_trip(self, in_pdb_path, tmp_path): out_ref_file_name = "tmp_" + in_pdb_path + "_openmm_ref.pdb" out_ref_file = tmp_path / out_ref_file_name - pdbfile.PDBFile.writeFile( - openmm_pdb.topology, openmm_pdb.positions, file=open(str(out_ref_file), "w") - ) + pdbfile.PDBFile.writeFile(openmm_pdb.topology, openmm_pdb.positions, file=open(str(out_ref_file), "w")) assert_same_pdb_lines(in_file_path=str(out_ref_file), out_file_path=out_file) def test_io_pdb_comparison(self, PDB_181L_OpenMMClean_path, tmp_path): diff --git a/gufe/tests/test_protocol.py b/gufe/tests/test_protocol.py index ccfc1546..b6675e2b 100644 --- a/gufe/tests/test_protocol.py +++ b/gufe/tests/test_protocol.py @@ -60,9 +60,7 @@ def _execute(ctx, *, simulations, **inputs): output = [s.outputs["log"] for s in simulations] output.append("assembling_results") - key_results = { - str(s.inputs["window"]): s.outputs["key_result"] for s in simulations - } + key_results = {str(s.inputs["window"]): s.outputs["key_result"] for s in simulations} return dict(log=output, key_results=key_results) @@ -136,23 +134,17 @@ def _create( # create several units that would each run an independent simulation simulations: list[ProtocolUnit] = [ - SimulationUnit( - settings=self.settings, name=f"sim {i}", window=i, initialization=alpha - ) + SimulationUnit(settings=self.settings, name=f"sim {i}", window=i, initialization=alpha) for i in range(self.settings.n_repeats) # type: ignore ] # gather results from simulations, finalize outputs - omega = FinishUnit( - settings=self.settings, name="the end", simulations=simulations - ) + omega = FinishUnit(settings=self.settings, name="the end", simulations=simulations) # return all `ProtocolUnit`s we created return [alpha, *simulations, omega] - def _gather( - self, protocol_dag_results: Iterable[ProtocolDAGResult] - ) -> dict[str, Any]: + def _gather(self, protocol_dag_results: Iterable[ProtocolDAGResult]) -> dict[str, Any]: outputs = defaultdict(list) for pdr in protocol_dag_results: @@ -190,10 +182,7 @@ def _create( # create several units that would each run an independent simulation simulations: list[ProtocolUnit] = [ - SimulationUnit( - settings=self.settings, name=f"sim {i}", window=i, initialization=alpha - ) - for i in range(21) + SimulationUnit(settings=self.settings, name=f"sim {i}", window=i, initialization=alpha) for i in range(21) ] # introduce a broken ProtocolUnit @@ -207,9 +196,7 @@ def _create( ) # gather results from simulations, finalize outputs - omega = FinishUnit( - settings=self.settings, name="the end", simulations=simulations - ) + omega = FinishUnit(settings=self.settings, name="the end", simulations=simulations) # return all `ProtocolUnit`s we created return [alpha, *simulations, omega] @@ -240,9 +227,7 @@ def protocol_dag(self, solvated_ligand, vacuum_ligand, tmpdir): scratch = pathlib.Path("scratch") scratch.mkdir(parents=True) - dagresult: ProtocolDAGResult = execute_DAG( - dag, shared_basedir=shared, scratch_basedir=scratch - ) + dagresult: ProtocolDAGResult = execute_DAG(dag, shared_basedir=shared, scratch_basedir=scratch) return protocol, dag, dagresult @@ -279,9 +264,7 @@ def test_dag_execute(self, protocol_dag): # gather SimulationUnits simulationresults = [ - dagresult.unit_to_result(pu) - for pu in dagresult.protocol_units - if isinstance(pu, SimulationUnit) + dagresult.unit_to_result(pu) for pu in dagresult.protocol_units if isinstance(pu, SimulationUnit) ] # check that we have dependency information in results @@ -291,19 +274,13 @@ def test_dag_execute(self, protocol_dag): assert len(dagresult.graph) == 23 # check that each simulation has its own shared directory - assert len({i.outputs["shared"] for i in simulationresults}) == len( - simulationresults - ) + assert len({i.outputs["shared"] for i in simulationresults}) == len(simulationresults) # check that each simulation has its own scratch directory - assert len({i.outputs["scratch"] for i in simulationresults}) == len( - simulationresults - ) + assert len({i.outputs["scratch"] for i in simulationresults}) == len(simulationresults) # check that shared and scratch not the same for each simulation - assert all( - [i.outputs["scratch"] != i.outputs["shared"] for i in simulationresults] - ) + assert all([i.outputs["scratch"] != i.outputs["shared"] for i in simulationresults]) def test_terminal_units(self, protocol_dag): prot, dag, res = protocol_dag @@ -333,9 +310,7 @@ def test_dag_execute_failure(self, protocol_dag_broken): assert len(succeeded_units) > 0 - def test_dag_execute_failure_raise_error( - self, solvated_ligand, vacuum_ligand, tmpdir - ): + def test_dag_execute_failure_raise_error(self, solvated_ligand, vacuum_ligand, tmpdir): protocol = BrokenProtocol(settings=BrokenProtocol.default_settings()) dag = protocol.create( stateA=solvated_ligand, @@ -371,15 +346,11 @@ def test_create_execute_gather(self, protocol_dag): assert protocolresult.get_estimate() == 95500.0 - def test_deprecation_warning_on_dict_mapping( - self, instance, vacuum_ligand, solvated_ligand - ): + def test_deprecation_warning_on_dict_mapping(self, instance, vacuum_ligand, solvated_ligand): lig = solvated_ligand.components["ligand"] mapping = gufe.LigandAtomMapping(lig, lig, componentA_to_componentB={}) - with pytest.warns( - DeprecationWarning, match="mapping input as a dict is deprecated" - ): + with pytest.warns(DeprecationWarning, match="mapping input as a dict is deprecated"): instance.create( stateA=solvated_ligand, stateB=vacuum_ligand, @@ -474,10 +445,7 @@ def test_protocol_unit_failures(self, instance: ProtocolDAGResult): def test_protocol_unit_successes(self, instance: ProtocolDAGResult): assert len(instance.protocol_unit_successes) == 23 - assert all( - isinstance(i, ProtocolUnitResult) - for i in instance.protocol_unit_successes - ) + assert all(isinstance(i, ProtocolUnitResult) for i in instance.protocol_unit_successes) class TestProtocolDAGResultFailure(ProtocolDAGTestsMixin): cls = ProtocolDAGResult @@ -493,12 +461,7 @@ def test_protocol_unit_failures(self, instance: ProtocolDAGResult): # protocolunitfailures should have no dependents for puf in instance.protocol_unit_failures: - assert all( - [ - puf not in pu.dependencies - for pu in instance.protocol_unit_results - ] - ) + assert all([puf not in pu.dependencies for pu in instance.protocol_unit_results]) for node in instance.result_graph.nodes: with pytest.raises(KeyError): @@ -578,8 +541,7 @@ def _gather(self, dag_results): return { "vals": list( itertools.chain.from_iterable( - (d.outputs["local"] for d in dag.protocol_unit_results) - for dag in dag_results + (d.outputs["local"] for d in dag.protocol_unit_results) for dag in dag_results ) ), } @@ -593,12 +555,8 @@ def protocol(self): @pytest.fixture() def dag(self, protocol): return protocol.create( - stateA=ChemicalSystem( - components={"solvent": gufe.SolventComponent(positive_ion="Na")} - ), - stateB=ChemicalSystem( - components={"solvent": gufe.SolventComponent(positive_ion="Li")} - ), + stateA=ChemicalSystem(components={"solvent": gufe.SolventComponent(positive_ion="Na")}), + stateB=ChemicalSystem(components={"solvent": gufe.SolventComponent(positive_ion="Li")}), mapping=None, ) @@ -613,9 +571,7 @@ def test_gather(self, protocol, dag, tmpdir): scratch = pathlib.Path("scratch") scratch.mkdir(parents=True) - dag_result = execute_DAG( - dag, shared_basedir=shared, scratch_basedir=scratch - ) + dag_result = execute_DAG(dag, shared_basedir=shared, scratch_basedir=scratch) assert dag_result.ok() @@ -633,9 +589,7 @@ def test_terminal_units(self, protocol, dag, tmpdir): scratch.mkdir(parents=True) # we have no dependencies, so this should be all three Unit results - dag_result = execute_DAG( - dag, shared_basedir=shared, scratch_basedir=scratch - ) + dag_result = execute_DAG(dag, shared_basedir=shared, scratch_basedir=scratch) terminal_results = dag_result.terminal_protocol_unit_results diff --git a/gufe/tests/test_protocoldag.py b/gufe/tests/test_protocoldag.py index fb3a963d..5c898492 100644 --- a/gufe/tests/test_protocoldag.py +++ b/gufe/tests/test_protocoldag.py @@ -52,9 +52,7 @@ def _defaults(cls): return {} def _create(self, stateA, stateB, mapping, extends=None) -> list[gufe.ProtocolUnit]: - return [ - WriterUnit(identity=i) for i in range(self.settings.n_repeats) # type: ignore - ] + return [WriterUnit(identity=i) for i in range(self.settings.n_repeats)] # type: ignore def _gather(self, results): return {} @@ -94,9 +92,7 @@ def test_execute_dag(tmpdir, keep_shared, keep_scratch, writefile_dag): # will have produced 4 files in scratch and shared directory for pu in writefile_dag.protocol_units: identity = pu.inputs["identity"] - shared_file = os.path.join( - shared, f"shared_{str(pu.key)}_attempt_0", f"unit_{identity}_shared.txt" - ) + shared_file = os.path.join(shared, f"shared_{str(pu.key)}_attempt_0", f"unit_{identity}_shared.txt") scratch_file = os.path.join( scratch, f"scratch_{str(pu.key)}_attempt_0", diff --git a/gufe/tests/test_protocolunit.py b/gufe/tests/test_protocolunit.py index 880f1e90..51a37a47 100644 --- a/gufe/tests/test_protocolunit.py +++ b/gufe/tests/test_protocolunit.py @@ -3,12 +3,7 @@ import pytest -from gufe.protocols.protocolunit import ( - Context, - ProtocolUnit, - ProtocolUnitFailure, - ProtocolUnitResult, -) +from gufe.protocols.protocolunit import Context, ProtocolUnit, ProtocolUnitFailure, ProtocolUnitResult from gufe.tests.test_tokenization import GufeTokenizableTestsMixin diff --git a/gufe/tests/test_serialization_migration.py b/gufe/tests/test_serialization_migration.py index b7057d28..69d71d4f 100644 --- a/gufe/tests/test_serialization_migration.py +++ b/gufe/tests/test_serialization_migration.py @@ -263,8 +263,4 @@ class TestNestedKeyMoved(MigrationTester): @pytest.fixture def instance(self): - return self.cls( - GrandparentSettings( - son=SonSettings(), daughter=DaughterSettings(daughter_child=10) - ) - ) + return self.cls(GrandparentSettings(son=SonSettings(), daughter=DaughterSettings(daughter_child=10))) diff --git a/gufe/tests/test_smallmoleculecomponent.py b/gufe/tests/test_smallmoleculecomponent.py index 159a2b40..2e64623c 100644 --- a/gufe/tests/test_smallmoleculecomponent.py +++ b/gufe/tests/test_smallmoleculecomponent.py @@ -368,7 +368,6 @@ def test_from_dict_missing_hybridization(self, phenol): for atom_hybrid, atom_no_hybrid in zip(phenol.to_rdkit().GetAtoms(), new_phenol.to_rdkit().GetAtoms()): assert atom_hybrid.GetHybridization() != atom_no_hybrid.GetHybridization() - @pytest.mark.skipif(not HAS_OFFTK, reason="no openff toolkit available") def test_deserialize_roundtrip(self, toluene, phenol): roundtrip = SmallMoleculeComponent.from_dict(phenol.to_dict()) diff --git a/gufe/tests/test_solvents.py b/gufe/tests/test_solvents.py index 70a8cecd..e8530777 100644 --- a/gufe/tests/test_solvents.py +++ b/gufe/tests/test_solvents.py @@ -43,17 +43,13 @@ def test_neq(): @pytest.mark.parametrize("conc", [0.0 * unit.molar, 1.75 * unit.molar]) def test_from_dict(conc): - s1 = SolventComponent( - positive_ion="Na", negative_ion="Cl", ion_concentration=conc, neutralize=False - ) + s1 = SolventComponent(positive_ion="Na", negative_ion="Cl", ion_concentration=conc, neutralize=False) assert SolventComponent.from_dict(s1.to_dict()) == s1 def test_conc(): - s = SolventComponent( - positive_ion="Na", negative_ion="Cl", ion_concentration=1.75 * unit.molar - ) + s = SolventComponent(positive_ion="Na", negative_ion="Cl", ion_concentration=1.75 * unit.molar) assert s.ion_concentration == unit.Quantity("1.75 M") @@ -68,15 +64,11 @@ def test_conc(): ) # negative conc def test_bad_conc(conc): with pytest.raises(ValueError): - _ = SolventComponent( - positive_ion="Na", negative_ion="Cl", ion_concentration=conc - ) + _ = SolventComponent(positive_ion="Na", negative_ion="Cl", ion_concentration=conc) def test_solvent_charge(): - s = SolventComponent( - positive_ion="Na", negative_ion="Cl", ion_concentration=1.75 * unit.molar - ) + s = SolventComponent(positive_ion="Na", negative_ion="Cl", ion_concentration=1.75 * unit.molar) assert s.total_charge is None diff --git a/gufe/tests/test_tokenization.py b/gufe/tests/test_tokenization.py index f6519bd2..fc2b49ce 100644 --- a/gufe/tests/test_tokenization.py +++ b/gufe/tests/test_tokenization.py @@ -279,9 +279,7 @@ def test_to_json_string(self): assert json.loads(raw_json, cls=JSON_HANDLER.decoder) == expected_key_chain def test_from_json_string(self): - recreated = self.cls.from_json( - content=json.dumps(self.expected_keyed_chain, cls=JSON_HANDLER.encoder) - ) + recreated = self.cls.from_json(content=json.dumps(self.expected_keyed_chain, cls=JSON_HANDLER.encoder)) assert recreated == self.cont assert recreated is self.cont @@ -292,10 +290,7 @@ def test_to_json_file(self, tmpdir): # tuples are converted to lists in JSON so fix the expected result to use lists expected_key_chain = [list(tok) for tok in self.expected_keyed_chain] - assert ( - json.load(file_path.open(mode="r"), cls=JSON_HANDLER.decoder) - == expected_key_chain - ) + assert json.load(file_path.open(mode="r"), cls=JSON_HANDLER.decoder) == expected_key_chain def test_from_json_file(self, tmpdir): file_path = tmpdir / "container.json" @@ -469,9 +464,7 @@ def test_token(self): def test_gufe_to_digraph(solvated_complex): graph = gufe_to_digraph(solvated_complex) - connected_objects = gufe_objects_from_shallow_dict( - solvated_complex.to_shallow_dict() - ) + connected_objects = gufe_objects_from_shallow_dict(solvated_complex.to_shallow_dict()) assert len(graph.nodes) == 4 assert len(graph.edges) == 3 diff --git a/gufe/tests/test_transformation.py b/gufe/tests/test_transformation.py index 3aef1496..445ff651 100644 --- a/gufe/tests/test_transformation.py +++ b/gufe/tests/test_transformation.py @@ -60,9 +60,7 @@ def test_protocol(self, absolute_transformation, tmpdir): scratch = pathlib.Path("scratch") scratch.mkdir(parents=True) - protocoldagresult = execute_DAG( - protocoldag, shared_basedir=shared, scratch_basedir=scratch - ) + protocoldagresult = execute_DAG(protocoldag, shared_basedir=shared, scratch_basedir=scratch) protocolresult = tnf.gather([protocoldagresult]) @@ -85,14 +83,10 @@ def test_protocol_extend(self, absolute_transformation, tmpdir): scratch.mkdir(parents=True) protocoldag = tnf.create() - protocoldagresult = execute_DAG( - protocoldag, shared_basedir=shared, scratch_basedir=scratch - ) + protocoldagresult = execute_DAG(protocoldag, shared_basedir=shared, scratch_basedir=scratch) protocoldag2 = tnf.create(extends=protocoldagresult) - protocoldagresult2 = execute_DAG( - protocoldag2, shared_basedir=shared, scratch_basedir=scratch - ) + protocoldagresult2 = execute_DAG(protocoldag2, shared_basedir=shared, scratch_basedir=scratch) protocolresult = tnf.gather([protocoldagresult, protocoldagresult2]) @@ -133,16 +127,12 @@ def test_dump_load_roundtrip(self, absolute_transformation): recreated = Transformation.load(string) assert absolute_transformation == recreated - def test_deprecation_warning_on_dict_mapping( - self, solvated_ligand, solvated_complex - ): + def test_deprecation_warning_on_dict_mapping(self, solvated_ligand, solvated_complex): lig = solvated_complex.components["ligand"] # this mapping makes no sense, but it'll trigger the dep warning we want mapping = gufe.LigandAtomMapping(lig, lig, componentA_to_componentB={}) - with pytest.warns( - DeprecationWarning, match="mapping input as a dict is deprecated" - ): + with pytest.warns(DeprecationWarning, match="mapping input as a dict is deprecated"): Transformation( solvated_complex, solvated_ligand, @@ -180,9 +170,7 @@ def test_protocol(self, complex_equilibrium, tmpdir): scratch = pathlib.Path("scratch") scratch.mkdir(parents=True) - protocoldagresult = execute_DAG( - protocoldag, shared_basedir=shared, scratch_basedir=scratch - ) + protocoldagresult = execute_DAG(protocoldag, shared_basedir=shared, scratch_basedir=scratch) protocolresult = ntnf.gather([protocoldagresult]) @@ -205,14 +193,10 @@ def test_protocol_extend(self, complex_equilibrium, tmpdir): scratch.mkdir(parents=True) protocoldag = ntnf.create() - protocoldagresult = execute_DAG( - protocoldag, shared_basedir=shared, scratch_basedir=scratch - ) + protocoldagresult = execute_DAG(protocoldag, shared_basedir=shared, scratch_basedir=scratch) protocoldag2 = ntnf.create(extends=protocoldagresult) - protocoldagresult2 = execute_DAG( - protocoldag2, shared_basedir=shared, scratch_basedir=scratch - ) + protocoldagresult2 = execute_DAG(protocoldag2, shared_basedir=shared, scratch_basedir=scratch) protocolresult = ntnf.gather([protocoldagresult, protocoldagresult2]) @@ -223,9 +207,7 @@ def test_protocol_extend(self, complex_equilibrium, tmpdir): def test_equality(self, complex_equilibrium, solvated_ligand, solvated_complex): s = DummyProtocol.default_settings() s.n_repeats = 4031 - different_protocol_settings = NonTransformation( - solvated_complex, protocol=DummyProtocol(settings=s) - ) + different_protocol_settings = NonTransformation(solvated_complex, protocol=DummyProtocol(settings=s)) assert complex_equilibrium != different_protocol_settings identical = NonTransformation( diff --git a/gufe/tokenization.py b/gufe/tokenization.py index 138b500d..9ace1bf0 100644 --- a/gufe/tokenization.py +++ b/gufe/tokenization.py @@ -195,8 +195,7 @@ def old_key_removed(dct, old_key, should_warn): # TODO: this should be put elsewhere so that the warning can be more # meaningful (somewhere that knows what class we're recreating) warnings.warn( - f"Outdated serialization: '{old_key}', with value " - f"'{dct[old_key]}' is no longer used in this object" + f"Outdated serialization: '{old_key}', with value " f"'{dct[old_key]}' is no longer used in this object" ) del dct[old_key] @@ -619,9 +618,7 @@ def copy_with_replacements(self, **replacements): """ dct = self._to_dict() if invalid := set(replacements) - set(dct): - raise TypeError( - f"Invalid replacement keys: {invalid}. " f"Allowed keys are: {set(dct)}" - ) + raise TypeError(f"Invalid replacement keys: {invalid}. " f"Allowed keys are: {set(dct)}") dct.update(replacements) return self._from_dict(dct) @@ -684,9 +681,7 @@ def to_json(self, file: Optional[PathLike | TextIO] = None) -> None | str: return None @classmethod - def from_json( - cls, file: Optional[PathLike | TextIO] = None, content: Optional[str] = None - ): + def from_json(cls, file: Optional[PathLike | TextIO] = None, content: Optional[str] = None): """ Generate an instance from JSON keyed chain representation. @@ -705,9 +700,7 @@ def from_json( """ if content is not None and file is not None: - raise ValueError( - "Cannot specify both `content` and `file`; only one input allowed" - ) + raise ValueError("Cannot specify both `content` and `file`; only one input allowed") elif content is None and file is None: raise ValueError("Must specify either `content` and `file` for JSON input") @@ -741,9 +734,7 @@ def token(self) -> str: return self.split("-")[1] -def gufe_objects_from_shallow_dict( - obj: Union[list, dict, GufeTokenizable] -) -> list[GufeTokenizable]: +def gufe_objects_from_shallow_dict(obj: Union[list, dict, GufeTokenizable]) -> list[GufeTokenizable]: """Find GufeTokenizables within a shallow dict. This function recursively looks through the list/dict structures encoding @@ -768,16 +759,10 @@ def gufe_objects_from_shallow_dict( return [obj] elif isinstance(obj, list): - return list( - chain.from_iterable([gufe_objects_from_shallow_dict(item) for item in obj]) - ) + return list(chain.from_iterable([gufe_objects_from_shallow_dict(item) for item in obj])) elif isinstance(obj, dict): - return list( - chain.from_iterable( - [gufe_objects_from_shallow_dict(item) for item in obj.values()] - ) - ) + return list(chain.from_iterable([gufe_objects_from_shallow_dict(item) for item in obj.values()])) return [] @@ -906,8 +891,7 @@ def gufe_to_keyed_chain_rep( """ key_and_keyed_dicts = [ - (str(gt.key), gt.to_keyed_dict()) - for gt in nx.topological_sort(gufe_to_digraph(gufe_object)) + (str(gt.key), gt.to_keyed_dict()) for gt in nx.topological_sort(gufe_to_digraph(gufe_object)) ][::-1] return key_and_keyed_dicts @@ -932,9 +916,7 @@ def __getitem__(self, index): # TOKENIZABLE_REGISTRY: Dict[str, weakref.ref[GufeTokenizable]] = {} -TOKENIZABLE_REGISTRY: weakref.WeakValueDictionary[str, GufeTokenizable] = ( - weakref.WeakValueDictionary() -) +TOKENIZABLE_REGISTRY: weakref.WeakValueDictionary[str, GufeTokenizable] = weakref.WeakValueDictionary() """Registry of tokenizable objects. Used to avoid duplication of tokenizable `gufe` objects in memory when @@ -969,8 +951,7 @@ def is_gufe_key_dict(dct: Any): def import_qualname(modname: str, qualname: str, remappings=REMAPPED_CLASSES): if (qualname is None) or (modname is None): raise ValueError( - "`__qualname__` or `__module__` cannot be None; " - f"unable to identify object {modname}.{qualname}" + "`__qualname__` or `__module__` cannot be None; " f"unable to identify object {modname}.{qualname}" ) if (modname, qualname) in remappings: @@ -1018,16 +999,10 @@ def modify_dependencies(obj: Union[dict, list], modifier, is_mine, mode, top=Tru obj = modifier(obj) if isinstance(obj, dict): - obj = { - key: modify_dependencies(value, modifier, is_mine, mode=mode, top=False) - for key, value in obj.items() - } + obj = {key: modify_dependencies(value, modifier, is_mine, mode=mode, top=False) for key, value in obj.items()} elif isinstance(obj, list): - obj = [ - modify_dependencies(item, modifier, is_mine, mode=mode, top=False) - for item in obj - ] + obj = [modify_dependencies(item, modifier, is_mine, mode=mode, top=False) for item in obj] if is_mine(obj) and not top and mode == "decode": obj = modifier(obj) @@ -1044,9 +1019,7 @@ def to_dict(obj: GufeTokenizable) -> dict: def dict_encode_dependencies(obj: GufeTokenizable) -> dict: - return modify_dependencies( - obj.to_shallow_dict(), to_dict, is_gufe_obj, mode="encode", top=True - ) + return modify_dependencies(obj.to_shallow_dict(), to_dict, is_gufe_obj, mode="encode", top=True) def key_encode_dependencies(obj: GufeTokenizable) -> dict: @@ -1088,14 +1061,10 @@ def _from_dict(dct: dict) -> GufeTokenizable: def dict_decode_dependencies(dct: dict) -> GufeTokenizable: - return from_dict( - modify_dependencies(dct, from_dict, is_gufe_dict, mode="decode", top=True) - ) + return from_dict(modify_dependencies(dct, from_dict, is_gufe_dict, mode="decode", top=True)) -def key_decode_dependencies( - dct: dict, registry=TOKENIZABLE_REGISTRY -) -> GufeTokenizable: +def key_decode_dependencies(dct: dict, registry=TOKENIZABLE_REGISTRY) -> GufeTokenizable: # this version requires that all dependent objects are already registered # responsibility of the storage system that uses this to do so dct = modify_dependencies( diff --git a/gufe/transformations/transformation.py b/gufe/transformations/transformation.py index e7520e2f..e1e4dd8b 100644 --- a/gufe/transformations/transformation.py +++ b/gufe/transformations/transformation.py @@ -25,9 +25,7 @@ def __init__( stateA: ChemicalSystem, stateB: ChemicalSystem, protocol: Protocol, - mapping: Optional[ - Union[ComponentMapping, list[ComponentMapping], dict[str, ComponentMapping]] - ] = None, + mapping: Optional[Union[ComponentMapping, list[ComponentMapping], dict[str, ComponentMapping]]] = None, name: Optional[str] = None, ): r"""Two chemical states with a method for estimating free energy difference @@ -52,10 +50,7 @@ def __init__( """ if isinstance(mapping, dict): warnings.warn( - ( - "mapping input as a dict is deprecated, " - "instead use either a single Mapping or list" - ), + ("mapping input as a dict is deprecated, " "instead use either a single Mapping or list"), DeprecationWarning, ) mapping = list(mapping.values()) @@ -72,10 +67,7 @@ def _defaults(cls): return super()._defaults() def __repr__(self): - return ( - f"{self.__class__.__name__}(stateA={self.stateA}, " - f"stateB={self.stateB}, protocol={self.protocol})" - ) + return f"{self.__class__.__name__}(stateA={self.stateA}, " f"stateB={self.stateB}, protocol={self.protocol})" @property def stateA(self) -> ChemicalSystem: @@ -145,9 +137,7 @@ def create( transformation_key=self.key, ) - def gather( - self, protocol_dag_results: Iterable[ProtocolDAGResult] - ) -> ProtocolResult: + def gather(self, protocol_dag_results: Iterable[ProtocolDAGResult]) -> ProtocolResult: """ Gather multiple ``ProtocolDAGResult`` into a single ``ProtocolResult``. diff --git a/gufe/utils.py b/gufe/utils.py index f267591a..ea9ad861 100644 --- a/gufe/utils.py +++ b/gufe/utils.py @@ -28,8 +28,7 @@ def __init__(self, fn, mode=None, force_close=False): if isinstance(fn, filelikes): if mode is not None: warnings.warn( - f"mode='{mode}' specified with {fn.__class__.__name__}." - " User-specified mode will be ignored." + f"mode='{mode}' specified with {fn.__class__.__name__}." " User-specified mode will be ignored." ) self.to_open = None self.do_close = force_close diff --git a/gufe/vendor/pdb_file/PdbxContainers.py b/gufe/vendor/pdb_file/PdbxContainers.py index c1fbb798..bb8c6ca1 100644 --- a/gufe/vendor/pdb_file/PdbxContainers.py +++ b/gufe/vendor/pdb_file/PdbxContainers.py @@ -133,8 +133,7 @@ def replace(self, obj): def printIt(self, fh=sys.stdout, type="brief"): fh.write( - "+ %s container: %30s contains %4d categories\n" - % (self.getType(), self.getName(), len(self.__objNameList)) + "+ %s container: %30s contains %4d categories\n" % (self.getType(), self.getName(), len(self.__objNameList)) ) for nm in self.__objNameList: fh.write("--------------------------------------------\n") @@ -187,10 +186,7 @@ def isAttribute(self): return False def printIt(self, fh=sys.stdout, type="brief"): - fh.write( - "Definition container: %30s contains %4d categories\n" - % (self.getName(), len(self.getObjNameList())) - ) + fh.write("Definition container: %30s contains %4d categories\n" % (self.getName(), len(self.getObjNameList()))) if self.isCategory(): fh.write("Definition type: category\n") elif self.isAttribute(): @@ -301,9 +297,7 @@ def __init__(self, name, attributeNameList=None, rowList=None): self.__dqWsRe = re.compile(r'("\s)|(\s")') # self.__intRe = re.compile(r"^[0-9]+$") - self.__floatRe = re.compile( - r"^-?(([0-9]+)[.]?|([0-9]*[.][0-9]+))([(][0-9]+[)])?([eE][+-]?[0-9]+)?$" - ) + self.__floatRe = re.compile(r"^-?(([0-9]+)[.]?|([0-9]*[.][0-9]+))([(][0-9]+[)])?([eE][+-]?[0-9]+)?$") # self.__dataTypeList = [ "DT_NULL_VALUE", @@ -586,10 +580,7 @@ def renameAttribute(self, curAttributeName, newAttributeName): def printIt(self, fh=sys.stdout): fh.write("--------------------------------------------\n") - fh.write( - " Category: %s attribute list length: %d\n" - % (self._name, len(self._attributeNameList)) - ) + fh.write(" Category: %s attribute list length: %d\n" % (self._name, len(self._attributeNameList))) for at in self._attributeNameList: fh.write(f" Category: {self._name} attribute: {at}\n") @@ -599,10 +590,7 @@ def printIt(self, fh=sys.stdout): # if len(row) == len(self._attributeNameList): for ii, v in enumerate(row): - fh.write( - " %30s: %s ...\n" - % (self._attributeNameList[ii], str(v)[:30]) - ) + fh.write(" %30s: %s ...\n" % (self._attributeNameList[ii], str(v)[:30])) else: fh.write( "+WARNING - %s data length %d attribute name length %s mismatched\n" @@ -611,10 +599,7 @@ def printIt(self, fh=sys.stdout): def dumpIt(self, fh=sys.stdout): fh.write("--------------------------------------------\n") - fh.write( - " Category: %s attribute list length: %d\n" - % (self._name, len(self._attributeNameList)) - ) + fh.write(" Category: %s attribute list length: %d\n" % (self._name, len(self._attributeNameList))) for at in self._attributeNameList: fh.write(f" Category: {self._name} attribute: {at}\n") @@ -656,16 +641,12 @@ def __formatPdbx(self, inp): else: if self.__avoidEmbeddedQuoting: # change priority to choose double quoting where possible. - if not self.__dqRe.search(inp) and not self.__sqWsRe.search( - inp - ): + if not self.__dqRe.search(inp) and not self.__sqWsRe.search(inp): return ( self.__doubleQuotedList(inp), "DT_DOUBLE_QUOTED_STRING", ) - elif not self.__sqRe.search(inp) and not self.__dqWsRe.search( - inp - ): + elif not self.__sqRe.search(inp) and not self.__dqWsRe.search(inp): return ( self.__singleQuotedList(inp), "DT_SINGLE_QUOTED_STRING", @@ -785,15 +766,10 @@ def getValueFormatted(self, attributeName=None, rowIndex=None): if isinstance(attribute, str) and isinstance(rowI, int): try: - list, type = self.__formatPdbx( - self._rowList[rowI][self._attributeNameList.index(attribute)] - ) + list, type = self.__formatPdbx(self._rowList[rowI][self._attributeNameList.index(attribute)]) return "".join(list) except IndexError: - self.__lfh.write( - "attributeName %s rowI %r rowdata %r\n" - % (attributeName, rowI, self._rowList[rowI]) - ) + self.__lfh.write("attributeName {} rowI {!r} rowdata {!r}\n".format(attributeName, rowI, self._rowList[rowI])) raise IndexError raise TypeError(attribute) @@ -814,9 +790,7 @@ def getAttributeValueMaxLengthList(self, steps=1): def getFormatTypeList(self, steps=1): try: - curDataTypeList = [ - "DT_NULL_VALUE" for i in range(len(self._attributeNameList)) - ] + curDataTypeList = ["DT_NULL_VALUE" for i in range(len(self._attributeNameList))] for row in self._rowList[::steps]: for indx in range(len(self._attributeNameList)): val = row[indx] @@ -836,10 +810,7 @@ def getFormatTypeList(self, steps=1): ii = self.__dataTypeList.index(dt) curFormatTypeList.append(self.__formatTypeList[ii]) except: - self.__lfh.write( - "PdbxDataCategory(getFormatTypeList) ++Index error at index %d in row %r\n" - % (indx, row) - ) + self.__lfh.write("PdbxDataCategory(getFormatTypeList) ++Index error at index %d in row %r\n" % (indx, row)) return curFormatTypeList, curDataTypeList diff --git a/gufe/vendor/pdb_file/PdbxReader.py b/gufe/vendor/pdb_file/PdbxReader.py index 7dc24f84..96f161e0 100644 --- a/gufe/vendor/pdb_file/PdbxReader.py +++ b/gufe/vendor/pdb_file/PdbxReader.py @@ -175,9 +175,7 @@ def __parser(self, tokenizer, containerList): try: curRow = curCategory[0] except IndexError: - self.__syntaxError( - "Internal index error accessing category data" - ) + self.__syntaxError("Internal index error accessing category data") return # Check for duplicate attributes and add attribute to table. @@ -191,9 +189,7 @@ def __parser(self, tokenizer, containerList): tCat, tAtt, curQuotedString, curWord = next(tokenizer) if tCat is not None or (curQuotedString is None and curWord is None): - self.__syntaxError( - f"Missing data for item _{curCatName}.{curAttName}" - ) + self.__syntaxError(f"Missing data for item _{curCatName}.{curAttName}") if curWord is not None: # @@ -201,9 +197,7 @@ def __parser(self, tokenizer, containerList): # reservedWord, state = self.__getState(curWord) if reservedWord is not None: - self.__syntaxError( - "Unexpected reserved word: %s" % (reservedWord) - ) + self.__syntaxError("Unexpected reserved word: %s" % (reservedWord)) curRow.append(curWord) @@ -239,9 +233,7 @@ def __parser(self, tokenizer, containerList): try: curContainer.append(curCategory) except AttributeError: - self.__syntaxError( - "loop_ declaration outside of data_ block or save_ frame" - ) + self.__syntaxError("loop_ declaration outside of data_ block or save_ frame") return curCategory.appendAttribute(curAttName) @@ -266,10 +258,7 @@ def __parser(self, tokenizer, containerList): if reservedWord == "stop": return else: - self.__syntaxError( - "Unexpected reserved word after loop declaration: %s" - % (reservedWord) - ) + self.__syntaxError("Unexpected reserved word after loop declaration: %s" % (reservedWord)) # Read the table of data for this loop_ - while True: @@ -282,9 +271,7 @@ def __parser(self, tokenizer, containerList): elif curQuotedString is not None: curRow.append(curQuotedString) - curCatName, curAttName, curQuotedString, curWord = next( - tokenizer - ) + curCatName, curAttName, curQuotedString, curWord = next(tokenizer) # loop_ data processing ends if - diff --git a/gufe/vendor/pdb_file/element.py b/gufe/vendor/pdb_file/element.py index e66dc276..3b675233 100644 --- a/gufe/vendor/pdb_file/element.py +++ b/gufe/vendor/pdb_file/element.py @@ -130,9 +130,7 @@ def getByMass(mass): # since the last call), re-generate the ordered by-mass dict cache if Element._elements_by_mass is None: Element._elements_by_mass = OrderedDict() - for elem in sorted( - Element._elements_by_symbol.values(), key=lambda x: x.mass - ): + for elem in sorted(Element._elements_by_symbol.values(), key=lambda x: x.mass): Element._elements_by_mass[elem.mass.value_in_unit(daltons)] = elem diff = mass diff --git a/gufe/vendor/pdb_file/pdbfile.py b/gufe/vendor/pdb_file/pdbfile.py index a9f2131e..9eb26111 100644 --- a/gufe/vendor/pdb_file/pdbfile.py +++ b/gufe/vendor/pdb_file/pdbfile.py @@ -176,9 +176,7 @@ def __init__(self, file, extraParticleIdentifier="EP"): resName = residue.get_name() if resName in PDBFile._residueNameReplacements: resName = PDBFile._residueNameReplacements[resName] - r = top.addResidue( - resName, c, str(residue.number), residue.insertion_code - ) + r = top.addResidue(resName, c, str(residue.number), residue.insertion_code) if resName in PDBFile._atomNameReplacements: atomReplacements = PDBFile._atomNameReplacements[resName] else: @@ -186,10 +184,7 @@ def __init__(self, file, extraParticleIdentifier="EP"): processedAtomNames = set() for atom in residue.atoms_by_name.values(): atomName = atom.get_name() - if ( - atomName in processedAtomNames - or atom.residue_name != residue.get_name() - ): + if atomName in processedAtomNames or atom.residue_name != residue.get_name(): continue processedAtomNames.add(atomName) if atomName in atomReplacements: @@ -220,9 +215,7 @@ def __init__(self, file, extraParticleIdentifier="EP"): element = elem.zinc elif len(residue) == 1 and upper.startswith("CA"): element = elem.calcium - elif upper.startswith("D") and any( - a.name == atomName[1:] for a in residue.iter_atoms() - ): + elif upper.startswith("D") and any(a.name == atomName[1:] for a in residue.iter_atoms()): pass # A Drude particle else: try: @@ -238,10 +231,7 @@ def __init__(self, file, extraParticleIdentifier="EP"): for residue in chain.iter_residues(): processedAtomNames = set() for atom in residue.atoms_by_name.values(): - if ( - atom.get_name() in processedAtomNames - or atom.residue_name != residue.get_name() - ): + if atom.get_name() in processedAtomNames or atom.residue_name != residue.get_name(): continue processedAtomNames.add(atom.get_name()) pos = atom.get_position().value_in_unit(nanometers) @@ -261,10 +251,7 @@ def __init__(self, file, extraParticleIdentifier="EP"): i = connect[0] for j in connect[1:]: if i in atomByNumber and j in atomByNumber: - if ( - atomByNumber[i].element is not None - and atomByNumber[j].element is not None - ): + if atomByNumber[i].element is not None and atomByNumber[j].element is not None: if ( atomByNumber[i].element.symbol not in metalElements and atomByNumber[j].element.symbol not in metalElements @@ -272,14 +259,12 @@ def __init__(self, file, extraParticleIdentifier="EP"): connectBonds.append((atomByNumber[i], atomByNumber[j])) elif ( atomByNumber[i].element.symbol in metalElements - and atomByNumber[j].residue.name - not in PDBFile._standardResidues + and atomByNumber[j].residue.name not in PDBFile._standardResidues ): connectBonds.append((atomByNumber[i], atomByNumber[j])) elif ( atomByNumber[j].element.symbol in metalElements - and atomByNumber[i].residue.name - not in PDBFile._standardResidues + and atomByNumber[i].residue.name not in PDBFile._standardResidues ): connectBonds.append((atomByNumber[i], atomByNumber[j])) else: @@ -288,10 +273,7 @@ def __init__(self, file, extraParticleIdentifier="EP"): # Only add bonds that don't already exist. existingBonds = set(top.bonds()) for bond in connectBonds: - if ( - bond not in existingBonds - and (bond[1], bond[0]) not in existingBonds - ): + if bond not in existingBonds and (bond[1], bond[0]) not in existingBonds: top.addBond(bond[0], bond[1]) existingBonds.add(bond) @@ -329,9 +311,7 @@ def getPositions(self, asNumpy=False, frame=0): def _loadNameReplacementTables(): """Load the list of atom and residue name replacements.""" if len(PDBFile._residueNameReplacements) == 0: - tree = etree.parse( - os.path.join(os.path.dirname(__file__), "data", "pdbNames.xml") - ) + tree = etree.parse(os.path.join(os.path.dirname(__file__), "data", "pdbNames.xml")) allResidues = {} proteinResidues = {} nucleicAcidResidues = {} @@ -512,32 +492,25 @@ def writeModel( symbol = atom.element.symbol else: symbol = extraParticleIdentifier - if ( - len(atom.name) < 4 - and atom.name[:1].isalpha() - and len(symbol) < 2 - ): + if len(atom.name) < 4 and atom.name[:1].isalpha() and len(symbol) < 2: atomName = " " + atom.name elif len(atom.name) > 4: atomName = atom.name[:4] else: atomName = atom.name coords = positions[posIndex] - line = ( - "%s%5s %-4s %3s %s%4s%1s %s%s%s 1.00 0.00 %2s " - % ( - recordName, - _formatIndex(atomIndex, 5), - atomName, - resName, - chainName, - resId, - resIC, - _format_83(coords[0]), - _format_83(coords[1]), - _format_83(coords[2]), - symbol, - ) + line = "%s%5s %-4s %3s %s%4s%1s %s%s%s 1.00 0.00 %2s " % ( + recordName, + _formatIndex(atomIndex, 5), + atomName, + resName, + chainName, + resId, + resIC, + _format_83(coords[0]), + _format_83(coords[1]), + _format_83(coords[2]), + symbol, ) if len(line) != 80: raise ValueError("Fixed width overflow detected") @@ -546,8 +519,7 @@ def writeModel( atomIndex += 1 if resIndex == len(residues) - 1: print( - "TER %5s %3s %s%4s" - % (_formatIndex(atomIndex, 5), resName, chainName, resId), + "TER %5s %3s %s%4s" % (_formatIndex(atomIndex, 5), resName, chainName, resId), file=file, ) atomIndex += 1 @@ -641,9 +613,7 @@ def _format_83(f): return "%8.3f" % f if -9999999 < f < 99999999: return ("%8.3f" % f)[:8] - raise ValueError( - 'coordinate "%s" could not be represented ' "in a width-8 field" % f - ) + raise ValueError('coordinate "%s" could not be represented ' "in a width-8 field" % f) def _formatIndex(index, places): diff --git a/gufe/vendor/pdb_file/pdbstructure.py b/gufe/vendor/pdb_file/pdbstructure.py index dbd9953f..e07c3e23 100644 --- a/gufe/vendor/pdb_file/pdbstructure.py +++ b/gufe/vendor/pdb_file/pdbstructure.py @@ -125,9 +125,7 @@ class PdbStructure: methods. """ - def __init__( - self, input_stream, load_all_models=False, extraParticleIdentifier="EP" - ): + def __init__(self, input_stream, load_all_models=False, extraParticleIdentifier="EP"): """Create a PDB model from a PDB file stream. Parameters @@ -465,8 +463,7 @@ def _add_atom(self, atom): else: # Residue name does not match # Only residue name does not match warnings.warn( - "WARNING: two consecutive residues with same number (%s, %s)" - % (atom, self._current_residue.atoms[-1]) + "WARNING: two consecutive residues with same number ({}, {})".format(atom, self._current_residue.atoms[-1]) ) self._add_residue( Residue( @@ -547,9 +544,7 @@ def _finalize(self): class Residue: - def __init__( - self, name, number, insertion_code=" ", primary_alternate_location_indicator=" " - ): + def __init__(self, name, number, insertion_code=" ", primary_alternate_location_indicator=" "): alt_loc = primary_alternate_location_indicator self.primary_location_id = alt_loc self.locations = {} @@ -567,9 +562,7 @@ def _add_atom(self, atom): """ """ alt_loc = atom.alternate_location_indicator if alt_loc not in self.locations: - self.locations[alt_loc] = Residue.Location( - alt_loc, atom.residue_name_with_spaces - ) + self.locations[alt_loc] = Residue.Location(alt_loc, atom.residue_name_with_spaces) assert atom.residue_number == self.number assert atom.insertion_code == self.insertion_code # Check whether this is an existing atom with another position @@ -581,9 +574,7 @@ def _add_atom(self, atom): "WARNING: duplicate atom (%s, %s)" % ( atom, - old_atom._pdb_string( - old_atom.serial_number, atom.alternate_location_indicator - ), + old_atom._pdb_string(old_atom.serial_number, atom.alternate_location_indicator), ) ) else: @@ -812,15 +803,12 @@ def __init__(self, pdb_line, pdbstructure=None, extraParticleIdentifier="EP"): if ( pdbstructure._current_model is None or pdbstructure._current_model._current_chain is None - or pdbstructure._current_model._current_chain._current_residue - is None + or pdbstructure._current_model._current_chain._current_residue is None ): # This is the first residue in the model. self.residue_number = pdbstructure._next_residue_number else: - currentRes = ( - pdbstructure._current_model._current_chain._current_residue - ) + currentRes = pdbstructure._current_model._current_chain._current_residue if currentRes.name_with_spaces != self.residue_name_with_spaces: # The residue name has changed. self.residue_number = pdbstructure._next_residue_number @@ -839,9 +827,7 @@ def __init__(self, pdb_line, pdbstructure=None, extraParticleIdentifier="EP"): except: occupancy = 1.0 try: - temperature_factor = unit.Quantity( - float(pdb_line[60:66]), unit.angstroms**2 - ) + temperature_factor = unit.Quantity(float(pdb_line[60:66]), unit.angstroms**2) except: temperature_factor = unit.Quantity(0.0, unit.angstroms**2) self.locations = {} @@ -1045,9 +1031,7 @@ class Location: Inner class of Atom for holding alternate locations """ - def __init__( - self, alt_loc, position, occupancy, temperature_factor, residue_name - ): + def __init__(self, alt_loc, position, occupancy, temperature_factor, residue_name): self.alternate_location_indicator = alt_loc self.position = position self.occupancy = occupancy @@ -1096,9 +1080,7 @@ def _parse_atom_index(index): import time # Test atom line parsing - pdb_line = ( - "ATOM 2209 CB TYR A 299 6.167 22.607 20.046 1.00 8.12 C" - ) + pdb_line = "ATOM 2209 CB TYR A 299 6.167 22.607 20.046 1.00 8.12 C" a = Atom(pdb_line) assert a.record_name == "ATOM" assert a.serial_number == 2209 @@ -1124,16 +1106,12 @@ def _parse_atom_index(index): # misaligned residue name - bad try: - a = Atom( - "ATOM 2209 CB TYRA 299 6.167 22.607 20.046 1.00 8.12 C" - ) + a = Atom("ATOM 2209 CB TYRA 299 6.167 22.607 20.046 1.00 8.12 C") assert False except ValueError: pass # four character residue name -- not so bad - a = Atom( - "ATOM 2209 CB NTYRA 299 6.167 22.607 20.046 1.00 8.12 C" - ) + a = Atom("ATOM 2209 CB NTYRA 299 6.167 22.607 20.046 1.00 8.12 C") atom_count = 0 residue_count = 0 diff --git a/gufe/vendor/pdb_file/pdbxfile.py b/gufe/vendor/pdb_file/pdbxfile.py index 6edd05f2..e5d6433f 100644 --- a/gufe/vendor/pdb_file/pdbxfile.py +++ b/gufe/vendor/pdb_file/pdbxfile.py @@ -144,9 +144,7 @@ def __init__(self, file): insertionCode = row[resInsertionCol] if insertionCode in (".", "?"): insertionCode = "" - if lastChainId != row[chainIdCol] or ( - altChainIdCol != -1 and lastAltChainId != row[altChainIdCol] - ): + if lastChainId != row[chainIdCol] or (altChainIdCol != -1 and lastAltChainId != row[altChainIdCol]): # The start of a new chain. chain = top.addChain(row[chainIdCol]) lastChainId = row[chainIdCol] @@ -189,9 +187,7 @@ def __init__(self, file): "Atom %s for model %s does not match the order of atoms for model %s" % (row[atomIdCol], model, models[0]) ) - self._positions[modelIndex].append( - np.array([float(row[xCol]), float(row[yCol]), float(row[zCol])]) * 0.1 - ) + self._positions[modelIndex].append(np.array([float(row[xCol]), float(row[yCol]), float(row[zCol])]) * 0.1) for i in range(len(self._positions)): self._positions[i] = self._positions[i] * nanometers ## The atom positions read from the PDBx/mmCIF file. If the file contains multiple frames, these are the positions in the first frame. @@ -212,9 +208,7 @@ def __init__(self, file): float(row[cell.getAttributeIndex(attribute)]) * math.pi / 180.0 for attribute in ("angle_alpha", "angle_beta", "angle_gamma") ) - self.topology.setPeriodicBoxVectors( - computePeriodicBoxVectors(a, b, c, alpha, beta, gamma) - ) + self.topology.setPeriodicBoxVectors(computePeriodicBoxVectors(a, b, c, alpha, beta, gamma)) # Add bonds based on struct_conn records. @@ -239,10 +233,7 @@ def __init__(self, file): # Only add bonds that don't already exist. existingBonds = set(top.bonds()) for bond in connectBonds: - if ( - bond not in existingBonds - and (bond[1], bond[0]) not in existingBonds - ): + if bond not in existingBonds and (bond[1], bond[0]) not in existingBonds: top.addBond(bond[0], bond[1]) existingBonds.add(bond) diff --git a/gufe/vendor/pdb_file/topology.py b/gufe/vendor/pdb_file/topology.py index 5fa5e0a8..5d5e0d88 100644 --- a/gufe/vendor/pdb_file/topology.py +++ b/gufe/vendor/pdb_file/topology.py @@ -176,10 +176,7 @@ def addResidue(self, name, chain, id=None, insertionCode=""): Residue the newly created Residue """ - if ( - len(chain._residues) > 0 - and self._numResidues != chain._residues[-1].index + 1 - ): + if len(chain._residues) > 0 and self._numResidues != chain._residues[-1].index + 1: raise ValueError("All residues within a chain must be contiguous") if id is None: id = str(self._numResidues + 1) @@ -342,9 +339,7 @@ def createStandardBonds(self): if not Topology._hasLoadedStandardBonds: # Load the standard bond definitions. - Topology.loadBondDefinitions( - os.path.join(os.path.dirname(__file__), "data", "residues.xml") - ) + Topology.loadBondDefinitions(os.path.join(os.path.dirname(__file__), "data", "residues.xml")) Topology._hasLoadedStandardBonds = True for chain in self._chains: # First build a map of atom names to atoms. @@ -384,33 +379,22 @@ def createStandardBonds(self): bond_type = bond[2] bond_order = bond[3] - if ( - fromAtom in atomMaps[fromResidue] - and toAtom in atomMaps[toResidue] - ): + if fromAtom in atomMaps[fromResidue] and toAtom in atomMaps[toResidue]: # Histidine bond order correction depending on Protonation state of actual HIS # HD1-ND1-CE1=ND2 <-> ND1=CE1-NE2-HE2 - avoid "charged" resonance structure bond_atoms = (fromAtom, toAtom) - if ( - name == "HIS" - and "CE1" in bond_atoms - and any([N in bond_atoms for N in ["ND1", "NE2"]]) - ): + if name == "HIS" and "CE1" in bond_atoms and any([N in bond_atoms for N in ["ND1", "NE2"]]): atoms = atomMaps[i] ND1_protonated = "HD1" in atoms NE2_protonated = "HE2" in atoms - if ( - ND1_protonated and not NE2_protonated - ): # HD1-ND1-CE1=ND2 + if ND1_protonated and not NE2_protonated: # HD1-ND1-CE1=ND2 if "ND1" in bond_atoms: bond_order = 1 else: bond_order = 2 - elif ( - not ND1_protonated and NE2_protonated - ): # ND1=CE1-NE2-HE2 + elif not ND1_protonated and NE2_protonated: # ND1=CE1-NE2-HE2 if "ND1" in bond_atoms: bond_order = 2 else: @@ -458,9 +442,7 @@ def isDisulfideBonded(atom): sg2 = cyx[j]._atoms[atomNames[j].index("SG")] pos2 = positions[sg2.index] delta = [x - y for (x, y) in zip(pos1, pos2)] - distance = sqrt( - delta[0] * delta[0] + delta[1] * delta[1] + delta[2] * delta[2] - ) + distance = sqrt(delta[0] * delta[0] + delta[1] * delta[1] + delta[2] * delta[2]) if distance < candidate_distance and not isDisulfideBonded(sg2): candidate_distance = distance candidate_atom = sg2 @@ -521,27 +503,15 @@ def atoms(self): def bonds(self): """Iterate over all Bonds involving any atom in this residue.""" - return ( - bond - for bond in self.chain.topology.bonds() - if ((bond[0] in self._atoms) or (bond[1] in self._atoms)) - ) + return (bond for bond in self.chain.topology.bonds() if ((bond[0] in self._atoms) or (bond[1] in self._atoms))) def internal_bonds(self): """Iterate over all internal Bonds.""" - return ( - bond - for bond in self.chain.topology.bonds() - if ((bond[0] in self._atoms) and (bond[1] in self._atoms)) - ) + return (bond for bond in self.chain.topology.bonds() if ((bond[0] in self._atoms) and (bond[1] in self._atoms))) def external_bonds(self): """Iterate over all Bonds to external atoms.""" - return ( - bond - for bond in self.chain.topology.bonds() - if ((bond[0] in self._atoms) != (bond[1] in self._atoms)) - ) + return (bond for bond in self.chain.topology.bonds() if ((bond[0] in self._atoms) != (bond[1] in self._atoms))) def __len__(self): return len(self._atoms) diff --git a/gufe/vendor/pdb_file/unitcell.py b/gufe/vendor/pdb_file/unitcell.py index 51630270..bc65cec7 100644 --- a/gufe/vendor/pdb_file/unitcell.py +++ b/gufe/vendor/pdb_file/unitcell.py @@ -63,11 +63,7 @@ def computePeriodicBoxVectors(a_length, b_length, c_length, alpha, beta, gamma): a = [a_length, 0, 0] b = [b_length * math.cos(gamma), b_length * math.sin(gamma), 0] cx = c_length * math.cos(beta) - cy = ( - c_length - * (math.cos(alpha) - math.cos(beta) * math.cos(gamma)) - / math.sin(gamma) - ) + cy = c_length * (math.cos(alpha) - math.cos(beta) * math.cos(gamma)) / math.sin(gamma) cz = math.sqrt(c_length * c_length - cx * cx - cy * cy) c = [cx, cy, cz] diff --git a/gufe/visualization/mapping_visualization.py b/gufe/visualization/mapping_visualization.py index a0173293..ac110b8a 100644 --- a/gufe/visualization/mapping_visualization.py +++ b/gufe/visualization/mapping_visualization.py @@ -39,9 +39,7 @@ def _match_elements(mol1: Chem.Mol, idx1: int, mol2: Chem.Mol, idx2: int) -> boo return elem_mol1 == elem_mol2 -def _get_unique_bonds_and_atoms( - mapping: dict[int, int], mol1: Chem.Mol, mol2: Chem.Mol -) -> dict: +def _get_unique_bonds_and_atoms(mapping: dict[int, int], mol1: Chem.Mol, mol2: Chem.Mol) -> dict: """ Given an input mapping, returns new atoms, element changes, and involved bonds. @@ -156,9 +154,7 @@ def _draw_molecules( # mol alignments if atom_mapping present for (i, j), atomMap in atom_mapping.items(): - AllChem.AlignMol( - copies[j], copies[i], atomMap=[(k, v) for v, k in atomMap.items()] - ) + AllChem.AlignMol(copies[j], copies[i], atomMap=[(k, v) for v, k in atomMap.items()]) # standard settings for our visualization d2d.drawOptions().useBWAtomPalette() @@ -177,9 +173,7 @@ def _draw_molecules( return d2d.GetDrawingText() -def draw_mapping( - mol1_to_mol2: dict[int, int], mol1: Chem.Mol, mol2: Chem.Mol, d2d=None -): +def draw_mapping(mol1_to_mol2: dict[int, int], mol1: Chem.Mol, mol2: Chem.Mol, d2d=None): """ Method to visualise the atom map correspondence between two rdkit molecules given an input mapping. From 72f6bf57fae17298dfd67f7c1612b890467f6d38 Mon Sep 17 00:00:00 2001 From: Josh Horton Date: Tue, 3 Dec 2024 17:17:34 +0000 Subject: [PATCH 5/6] add warning --- gufe/components/smallmoleculecomponent.py | 4 ++++ gufe/tests/test_smallmoleculecomponent.py | 3 ++- 2 files changed, 6 insertions(+), 1 deletion(-) diff --git a/gufe/components/smallmoleculecomponent.py b/gufe/components/smallmoleculecomponent.py index 019212e3..3727f7bf 100644 --- a/gufe/components/smallmoleculecomponent.py +++ b/gufe/components/smallmoleculecomponent.py @@ -2,6 +2,7 @@ # For details, see https://github.com/OpenFreeEnergy/gufe import logging +import warnings # openff complains about oechem being missing, shhh logger = logging.getLogger("openff.toolkit") @@ -282,6 +283,9 @@ def _from_dict(cls, d: dict): try: a.SetHybridization(_INT_TO_HYBRIDIZATION[atom[7]]) except IndexError: + warnings.warn("The atom hybridization data was not found and has been set to unspecified. This can be" + " fixed by recreating the SmallMoleculeComponent from the rdkit molecule after running " + "sanitization.") pass em.AddAtom(a) diff --git a/gufe/tests/test_smallmoleculecomponent.py b/gufe/tests/test_smallmoleculecomponent.py index 2e64623c..e224227d 100644 --- a/gufe/tests/test_smallmoleculecomponent.py +++ b/gufe/tests/test_smallmoleculecomponent.py @@ -361,7 +361,8 @@ def test_from_dict_missing_hybridization(self, phenol): # remove the hybridization atomic info which should be at index 7 new_atoms.append(tuple([atom_info for i, atom_info in enumerate(atom) if i != 7])) phenol_dict["atoms"] = new_atoms - new_phenol = SmallMoleculeComponent.from_dict(phenol_dict) + with pytest.warns(match="The atom hybridization data was not found and has been set to unspecified."): + new_phenol = SmallMoleculeComponent.from_dict(phenol_dict) # they should be different objects due to the missing hybridization info assert new_phenol != phenol # make sure the rdkit objects are different From c14f17cd8864acc9ed9d36c62e4e72a3b1ae2194 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Tue, 3 Dec 2024 17:19:30 +0000 Subject: [PATCH 6/6] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- gufe/components/smallmoleculecomponent.py | 8 +++++--- 1 file changed, 5 insertions(+), 3 deletions(-) diff --git a/gufe/components/smallmoleculecomponent.py b/gufe/components/smallmoleculecomponent.py index 3727f7bf..814c473b 100644 --- a/gufe/components/smallmoleculecomponent.py +++ b/gufe/components/smallmoleculecomponent.py @@ -283,9 +283,11 @@ def _from_dict(cls, d: dict): try: a.SetHybridization(_INT_TO_HYBRIDIZATION[atom[7]]) except IndexError: - warnings.warn("The atom hybridization data was not found and has been set to unspecified. This can be" - " fixed by recreating the SmallMoleculeComponent from the rdkit molecule after running " - "sanitization.") + warnings.warn( + "The atom hybridization data was not found and has been set to unspecified. This can be" + " fixed by recreating the SmallMoleculeComponent from the rdkit molecule after running " + "sanitization." + ) pass em.AddAtom(a)