
Performance improvements to tokenization, deterministic sorting #76

Merged (9 commits) on Oct 12, 2022
gufe/chemicalsystem.py (1 addition, 4 deletions)
@@ -68,9 +68,6 @@ def __repr__(self):
f"{self.__class__.__name__}(name={self.name}, components={self.components})"
)

-    def __lt__(self, other):
-        return hash(self) < hash(other)
-
def __eq__(self, other):
if not isinstance(other, self.__class__):
return False
@@ -96,7 +93,7 @@ def __hash__(self):
def _to_dict(self):
return {
"components": {
-                key: value for key, value in self.components.items()
+                key: value for key, value in sorted(self.components.items())
},
"box_vectors": self.box_vectors.tolist(),
"name": self.name,
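
Sorting the component items here makes `_to_dict` (and therefore the token derived from it) independent of the order in which components were supplied. A minimal standalone sketch of the idea, using plain dicts rather than gufe's actual serialization path:

```python
import json

def canonical(d: dict) -> str:
    """Serialize with keys in sorted order, so the result is insertion-order independent."""
    return json.dumps({k: v for k, v in sorted(d.items())})

# Hypothetical component mappings built in different insertion orders.
a = {"ligand": "SmallMoleculeComponent-abc", "solvent": "SolventComponent-def"}
b = {"solvent": "SolventComponent-def", "ligand": "SmallMoleculeComponent-abc"}

assert json.dumps(a) != json.dumps(b)   # plain dicts remember insertion order
assert canonical(a) == canonical(b)     # sorted items give one deterministic form
```
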
gufe/network.py (2 additions, 2 deletions)
@@ -83,8 +83,8 @@ def name(self):
return self._name

def _to_dict(self) -> dict:
return {"nodes": list(self.nodes),
"edges": list(self.edges),
return {"nodes": sorted(self.nodes),
"edges": sorted(self.edges),
"name": self.name}

@classmethod
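
`sorted(self.nodes)` and `sorted(self.edges)` only work because the objects being sorted define an ordering; the `gufe/tokenization.py` change further down supplies that by comparing `key` strings rather than `hash()` values. A rough illustration with a stand-in class (not the real `GufeTokenizable`):

```python
class Node:
    """Stand-in for a GufeTokenizable whose key is derived from its contents."""

    def __init__(self, key: str):
        self.key = key

    def __lt__(self, other: "Node") -> bool:
        # Comparing key strings gives the same order in every Python session;
        # hash(str) is randomized per process unless PYTHONHASHSEED is fixed.
        return self.key < other.key

nodes = [Node("ChemicalSystem-b0f"), Node("ChemicalSystem-a17")]
print([n.key for n in sorted(nodes)])  # ['ChemicalSystem-a17', 'ChemicalSystem-b0f'] every run
```
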
gufe/tests/test_chemicalsystem.py (1 addition)
@@ -131,6 +131,7 @@ def test_sorting(solvated_complex, solvated_ligand):
class TestChemicalSystem(GufeTokenizableTestsMixin):

cls = ChemicalSystem
key = "ChemicalSystem-e1cb9ce41e88ee474cf5b962c9388159"

@pytest.fixture
def instance(self, solv_comp, toluene_ligand_comp):
gufe/tests/test_network.py (10 additions, 1 deletion)
@@ -7,6 +7,7 @@
from gufe import AlchemicalNetwork, ChemicalSystem, Transformation

from .test_protocol import DummyProtocol, DummyProtocolResult
+from .test_tokenization import GufeTokenizableTestsMixin


@pytest.fixture
@@ -70,7 +71,15 @@ def benzene_variants_star_map(
)


-class TestAlchemicalNetwork:
+class TestAlchemicalNetwork(GufeTokenizableTestsMixin):
+
+    cls = AlchemicalNetwork
+    key = "AlchemicalNetwork-1229d9766b685039f7581c3207f68b22"
+
+    @pytest.fixture
+    def instance(self, benzene_variants_star_map):
+        return benzene_variants_star_map
+
def test_init(self, benzene_variants_star_map):
alnet = benzene_variants_star_map

gufe/tests/test_proteincomponent.py (1 addition)
@@ -57,6 +57,7 @@ def assert_same_pdb_lines(in_file_path, out_file_path):
class TestProteinComponent(GufeTokenizableTestsMixin):

cls = ProteinComponent
key = "ProteinComponent-d8c07b1d44e93c4cae31c9deadf1fec4"

@pytest.fixture
def instance(self, PDB_181L_path):
gufe/tests/test_protocol.py (14 additions)
@@ -189,6 +189,7 @@ def _create(
class TestProtocol(GufeTokenizableTestsMixin):

cls = DummyProtocol
key = "DummyProtocol-5660965464c9afdaac0ac4486a9566b3"

@pytest.fixture
def instance(self):
@@ -298,8 +299,14 @@ def test_graph(self, instance):
for neighbor in instance.graph.neighbors(node):
assert neighbor in node.dependencies

+    def test_key_stable(self, instance):
+        # for the DAG system, keys for `ProtocolUnit`s are based on UUIDs,
+        # so keys aren't stable up through `ProtocolDAG`s
+        pass
+
class TestProtocolDAG(ProtocolDAGTestsMixin):
cls = ProtocolDAG
key = "..."

@pytest.fixture
def instance(self, protocol_dag):
@@ -308,6 +315,7 @@ def instance(self, protocol_dag):

class TestProtocolDAGResult(ProtocolDAGTestsMixin):
cls = ProtocolDAGResult
key = "..."

@pytest.fixture
def instance(self, protocol_dag):
@@ -354,6 +362,7 @@ def test_protocol_unit_failures(self, instance: ProtocolDAGResult):

class TestProtocolDAGResultFailure(ProtocolDAGTestsMixin):
cls = ProtocolDAGResult
key = "..."

@pytest.fixture
def instance(self, protocol_dag_broken):
@@ -382,6 +391,7 @@ def test_protocol_unit_failure_traceback(self, instance: ProtocolDAGResult):

class TestProtocolUnit(GufeTokenizableTestsMixin):
cls = SimulationUnit
key = "..."

@pytest.fixture
def instance(self, vacuum_ligand, solvated_ligand):
@@ -399,6 +409,10 @@ def instance(self, vacuum_ligand, solvated_ligand):

return SimulationUnit(name=f"simulation", initialization=alpha)

+    def test_key_stable(self, instance):
+        # for the DAG system, keys for `ProtocolUnit`s are based on UUIDs,
+        # so keys aren't stable up through `ProtocolDAG`s
+        pass

class NoDepUnit(ProtocolUnit):
@staticmethod
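
The two `test_key_stable` overrides above simply pass because, as their comments note, `ProtocolUnit` keys incorporate a per-object UUID and so cannot be pinned to a constant. A toy illustration of why a hard-coded expected key cannot work there (hypothetical `Unit` class, not gufe's):

```python
import uuid

class Unit:
    """Toy stand-in: the key depends on a random UUID assigned at construction."""

    def __init__(self, name: str):
        self.key = f"Unit-{uuid.uuid4().hex}"

# Two otherwise-identical units get different keys on every run, so a fixed
# `key` class attribute (as in GufeTokenizableTestsMixin) cannot apply here.
print(Unit("simulation").key == Unit("simulation").key)  # False
```
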
gufe/tests/test_smallmoleculecomponent.py (1 addition)
@@ -77,6 +77,7 @@ def test_ensure_ofe_version():
class TestSmallMoleculeComponent(GufeTokenizableTestsMixin):

cls = SmallMoleculeComponent
key = "SmallMoleculeComponent-3a1b343b46ec93300bc74d83c133637a"

@pytest.fixture
def instance(self, named_ethane):
gufe/tests/test_solvents.py (1 addition)
@@ -93,6 +93,7 @@ def test_bad_inputs(pos, neg):
class TestSolventComponent(GufeTokenizableTestsMixin):

cls = SolventComponent
key = "SolventComponent-187d235ef3c2035d8505083c8ad7d0a0"

@pytest.fixture
def instance(self):
gufe/tests/test_tokenization.py (6 additions)
@@ -70,6 +70,7 @@ class GufeTokenizableTestsMixin(abc.ABC):

# set this to the `GufeTokenizable` subclass you are testing
cls: type[GufeTokenizable]
+    key: str

@pytest.fixture
def instance(self):
@@ -126,10 +127,14 @@ def test_to_shallow_dict_roundtrip(self, instance):
# include `np.nan`s
#assert ser == reser

+    def test_key_stable(self, instance):
+        assert self.key == instance.key
+

class TestGufeTokenizable(GufeTokenizableTestsMixin):

cls = Container
key = "Container-262ecded6cd03a619b99d667ded94c9e"

@pytest.fixture
def instance(self):
@@ -206,6 +211,7 @@ def test_notequal_different_type(self):
assert l1 != l2



class Outer:
class Inner:
pass
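
The new `key` attribute and `test_key_stable` pin each class's token against an expected constant, which only makes sense because the key is a deterministic function of the object's contents. A toy sketch of such a content-derived key (not gufe's actual `tokenize` implementation):

```python
import hashlib
import json

def toy_token(contents: dict) -> str:
    """Deterministic content hash: sorted-key serialization, then a digest."""
    return hashlib.md5(json.dumps(contents, sort_keys=True).encode()).hexdigest()

# The same logical contents always map to the same token, which is what lets
# test classes hard-code strings like "Container-262ecded6cd03a619b99d667ded94c9e".
assert toy_token({"name": "ethane", "charge": 0}) == toy_token({"charge": 0, "name": "ethane"})
print("Container-" + toy_token({"name": "ethane", "charge": 0}))
```
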
gufe/tests/test_transformation.py (19 additions, 2 deletions)
@@ -7,6 +7,7 @@
from gufe.protocols.protocoldag import execute

from .test_protocol import DummyProtocol, DummyProtocolResult
+from .test_tokenization import GufeTokenizableTestsMixin


@pytest.fixture
@@ -24,7 +25,15 @@ def complex_equilibrium(solvated_complex):
return NonTransformation(solvated_complex, protocol=DummyProtocol(settings=None))


-class TestTransformation:
+class TestTransformation(GufeTokenizableTestsMixin):
+
+    cls = Transformation
+    key = "Transformation-432a08368b0d8779397177ec25058543"
+
+    @pytest.fixture
+    def instance(self, absolute_transformation):
+        return absolute_transformation
+
def test_init(self, absolute_transformation, solvated_ligand, solvated_complex):
tnf = absolute_transformation

@@ -72,7 +81,15 @@ def test_dict_roundtrip(self):
...


-class TestNonTransformation:
+class TestNonTransformation(GufeTokenizableTestsMixin):
+
+    cls = NonTransformation
+    key = "NonTransformation-7e7a724f1b41d03f0f00dcc876172bad"
+
+    @pytest.fixture
+    def instance(self, complex_equilibrium):
+        return complex_equilibrium
+
def test_init(self, complex_equilibrium, solvated_complex):

ntnf = complex_equilibrium
gufe/tokenization.py (11 additions, 4 deletions)
@@ -56,7 +56,7 @@ class GufeTokenizable(abc.ABC, metaclass=_ABCGufeClassMeta):
*across different Python sessions*.
"""
def __lt__(self, other):
-        return hash(self) < hash(other)
+        return self.key < other.key

def __eq__(self, other):
if not isinstance(other, self.__class__):
@@ -71,7 +71,7 @@ def _gufe_tokenize(self):
"""Return a list of normalized inputs for `gufe.base.tokenize`.

"""
-        return normalize(self.to_dict(include_defaults=False))
+        return normalize(self.to_keyed_dict(include_defaults=False))

@property
def key(self):
@@ -167,7 +167,7 @@ def from_dict(cls, dct: Dict):
"""
return dict_decode_dependencies(dct)

-    def to_keyed_dict(self) -> Dict:
+    def to_keyed_dict(self, include_defaults=True) -> Dict:
"""Generate keyed dict representation, with all referenced
`GufeTokenizable` objects given in keyed representations.

@@ -186,7 +186,14 @@ def to_keyed_dict(self) -> Dict:
:meth:`GufeTokenizable.to_shallow_dict`

"""
-        return key_encode_dependencies(self)
+        dct = key_encode_dependencies(self)
+
+        if not include_defaults:
+            for key, value in self.defaults.items():
+                if dct.get(key) == value:
+                    dct.pop(key)
+
+        return dct

@classmethod
def from_keyed_dict(cls, dct: Dict):
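
The new `include_defaults=False` path drops entries that still hold their class defaults before the keyed dict is normalized and hashed in `_gufe_tokenize`, so the token reflects the values that were explicitly set. A standalone sketch of the trimming step, using plain dicts in place of the keyed-dict encoding:

```python
def trim_defaults(dct: dict, defaults: dict) -> dict:
    """Remove entries whose value equals the class default, mirroring the
    to_keyed_dict(include_defaults=False) behaviour above."""
    out = dict(dct)
    for key, value in defaults.items():
        if out.get(key) == value:
            out.pop(key)
    return out

defaults = {"name": "", "box_vectors": None}
full = {"components": {"ligand": "SmallMoleculeComponent-abc"},
        "name": "",
        "box_vectors": None}
print(trim_defaults(full, defaults))
# {'components': {'ligand': 'SmallMoleculeComponent-abc'}}
```
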
gufe/transformations/transformation.py (3 deletions)
@@ -88,9 +88,6 @@ def name(self):
"""User-specified for the transformation; used as part of its hash."""
return self._name

-    def __lt__(self, other):
-        return hash(self) < hash(other)
-
def __eq__(self, other):
if not isinstance(other, self.__class__):
return False