Skip to content

Commit

Permalink
Speed up json loader (#1163)
Browse files Browse the repository at this point in the history
* speed up json loader

* include port positions when writing; add a test that loads back a compound with ports

* change name and behavior for include_ports/show_ports

* switch the default back

* fix unit test (update keyword)
  • Loading branch information
daico007 authored Feb 5, 2024
1 parent 9d20e3d commit f50fb61
Show file tree
Hide file tree
Showing 4 changed files with 58 additions and 33 deletions.
32 changes: 16 additions & 16 deletions mbuild/compound.py
Original file line number Diff line number Diff line change
Expand Up @@ -1953,7 +1953,7 @@ def _visualize_py3dmol(
tmp_dir = tempfile.mkdtemp()
cloned.save(
os.path.join(tmp_dir, "tmp.mol2"),
show_ports=show_ports,
include_ports=show_ports,
overwrite=True,
parmed_kwargs={"infer_residues": False},
)
Expand Down Expand Up @@ -1984,7 +1984,7 @@ def _visualize_nglview(
Parameters
----------
show_ports : bool, optional, default=False
include_ports : bool, optional, default=False
Visualize Ports in addition to Particles
"""
nglview = import_("nglview")
Expand All @@ -2001,7 +2001,7 @@ def remove_digits(x):
tmp_dir = tempfile.mkdtemp()
self.save(
os.path.join(tmp_dir, "tmp.mol2"),
show_ports=show_ports,
include_ports=show_ports,
overwrite=True,
)
widget = nglview.show_file(os.path.join(tmp_dir, "tmp.mol2"))
Expand Down Expand Up @@ -2930,7 +2930,7 @@ def _energy_minimize_openbabel(
def save(
self,
filename,
show_ports=False,
include_ports=False,
forcefield_name=None,
forcefield_files=None,
forcefield_debug=False,
Expand All @@ -2952,7 +2952,7 @@ def save(
'hoomdxml', 'gsd', 'gro', 'top', 'lammps', 'lmp', 'mcf', 'pdb', 'xyz',
'json', 'mol2', 'sdf', 'psf'. See parmed/structure.py for more
information on savers.
show_ports : bool, optional, default=False
include_ports : bool, optional, default=False
Save ports contained within the compound.
forcefield_files : str, optional, default=None
Apply a forcefield to the output file using a forcefield provided
Expand Down Expand Up @@ -3024,7 +3024,7 @@ def save(
When saving the compound as a json, only the following arguments are
used:
* filename
* show_ports
* include_ports
See Also
--------
Expand All @@ -3039,7 +3039,7 @@ def save(
conversion.save(
self,
filename,
show_ports,
include_ports,
forcefield_name,
forcefield_files,
forcefield_debug,
Expand Down Expand Up @@ -3232,13 +3232,13 @@ def from_trajectory(
)

def to_trajectory(
self, show_ports=False, chains=None, residues=None, box=None
self, include_ports=False, chains=None, residues=None, box=None
):
"""Convert to an md.Trajectory and flatten the compound.
Parameters
----------
show_ports : bool, optional, default=False
include_ports : bool, optional, default=False
Include all port atoms when converting to trajectory.
chains : mb.Compound or list of mb.Compound
Chain types to add to the topology
Expand All @@ -3261,7 +3261,7 @@ def to_trajectory(
"""
return conversion.to_trajectory(
compound=self,
show_ports=show_ports,
include_ports=include_ports,
chains=chains,
residues=residues,
box=box,
Expand Down Expand Up @@ -3323,7 +3323,7 @@ def to_parmed(
box=None,
title="",
residues=None,
show_ports=False,
include_ports=False,
infer_residues=False,
infer_residues_kwargs={},
):
Expand All @@ -3341,7 +3341,7 @@ def to_parmed(
residues : str of list of str, optional, default=None
Labels of residues in the Compound. Residues are assigned by checking
against Compound.name.
show_ports : boolean, optional, default=False
include_ports : boolean, optional, default=False
Include all port atoms when converting to a `Structure`.
infer_residues : bool, optional, default=False
Attempt to assign residues based on the number of bonds and particles in
Expand All @@ -3364,7 +3364,7 @@ def to_parmed(
box=box,
title=title,
residues=residues,
show_ports=show_ports,
include_ports=include_ports,
infer_residues=infer_residues,
infer_residues_kwargs=infer_residues_kwargs,
)
Expand Down Expand Up @@ -3400,7 +3400,7 @@ def to_pybel(
box=None,
title="",
residues=None,
show_ports=False,
include_ports=False,
infer_residues=False,
):
"""Create a pybel.Molecule from a Compound.
Expand All @@ -3413,7 +3413,7 @@ def to_pybel(
residues : str of list of str
Labels of residues in the Compound. Residues are assigned by
checking against Compound.name.
show_ports : boolean, optional, default=False
include_ports : boolean, optional, default=False
Include all port atoms when converting to a `Structure`.
infer_residues : bool, optional, default=False
Attempt to assign residues based on names of children
Expand All @@ -3438,7 +3438,7 @@ def to_pybel(
box=box,
title=title,
residues=residues,
show_ports=show_ports,
include_ports=include_ports,
)

def to_smiles(self, backend="pybel"):
Expand Down
30 changes: 16 additions & 14 deletions mbuild/conversion.py
Original file line number Diff line number Diff line change
Expand Up @@ -966,7 +966,7 @@ def from_gmso(
def save(
compound,
filename,
show_ports=False,
include_ports=False,
forcefield_name=None,
forcefield_files=None,
forcefield_debug=False,
Expand All @@ -990,7 +990,7 @@ def save(
'hoomdxml', 'gsd', 'gro', 'top', 'lammps', 'lmp', 'mcf', 'xyz', 'pdb',
'sdf', 'mol2', 'psf'. See parmed/structure.py for more information on
savers.
show_ports : bool, optional, default=False
include_ports : bool, optional, default=False
Save ports contained within the compound.
forcefield_files : str, optional, default=None
Apply a forcefield to the output file using a forcefield provided by the
Expand Down Expand Up @@ -1054,7 +1054,7 @@ def save(
-----
When saving the compound as a json, only the following arguments are used:
- filename
- show_ports
- include_ports
See Also
--------
Expand All @@ -1068,7 +1068,9 @@ def save(
extension = os.path.splitext(filename)[-1]

if extension == ".json":
compound_to_json(compound, file_path=filename, include_ports=show_ports)
compound_to_json(
compound, file_path=filename, include_ports=include_ports
)
return

# Savers supported by mbuild.formats
Expand Down Expand Up @@ -1098,7 +1100,7 @@ def save(
structure = compound.to_parmed(
box=box,
residues=residues,
show_ports=show_ports,
include_ports=include_ports,
**parmed_kwargs,
)
# Apply a force field with foyer if specified
Expand Down Expand Up @@ -1301,7 +1303,7 @@ def to_parmed(
box=None,
title="",
residues=None,
show_ports=False,
include_ports=False,
infer_residues=False,
infer_residues_kwargs={},
):
Expand All @@ -1321,7 +1323,7 @@ def to_parmed(
residues : str of list of str, optional, default=None
Labels of residues in the Compound. Residues are assigned by checking
against Compound.name.
show_ports : boolean, optional, default=False
include_ports : boolean, optional, default=False
Include all port atoms when converting to a `Structure`.
infer_residues : bool, optional, default=False
Attempt to assign residues based on the number of bonds and particles in
Expand Down Expand Up @@ -1361,7 +1363,7 @@ def to_parmed(
atom_residue_map = dict()

# Loop through particles and add initialize ParmEd atoms
for atom in compound.particles(include_ports=show_ports):
for atom in compound.particles(include_ports=include_ports):
if atom.port_particle:
current_residue = port_residue
atom_residue_map[atom] = current_residue
Expand Down Expand Up @@ -1458,13 +1460,13 @@ def to_parmed(


def to_trajectory(
compound, show_ports=False, chains=None, residues=None, box=None
compound, include_ports=False, chains=None, residues=None, box=None
):
"""Convert to an md.Trajectory and flatten the compound.
Parameters
----------
show_ports : bool, optional, default=False
include_ports : bool, optional, default=False
Include all port atoms when converting to trajectory.
chains : mb.Compound or list of mb.Compound
Chain types to add to the topology
Expand All @@ -1485,7 +1487,7 @@ def to_trajectory(
_to_topology
"""
md = import_("mdtraj")
atom_list = [particle for particle in compound.particles(show_ports)]
atom_list = [particle for particle in compound.particles(include_ports)]

top = _to_topology(compound, atom_list, chains, residues)

Expand Down Expand Up @@ -1650,7 +1652,7 @@ def to_pybel(
box=None,
title="",
residues=None,
show_ports=False,
include_ports=False,
infer_residues=False,
):
"""Create a pybel.Molecule from a Compound.
Expand All @@ -1665,7 +1667,7 @@ def to_pybel(
residues : str of list of str
Labels of residues in the Compound. Residues are assigned by checking
against Compound.name.
show_ports : boolean, optional, default=False
include_ports : boolean, optional, default=False
Include all port atoms when converting to a `Structure`.
infer_residues : bool, optional, default=False
Attempt to assign residues based on names of children
Expand Down Expand Up @@ -1697,7 +1699,7 @@ def to_pybel(
compound_residue_map = dict()
atom_residue_map = dict()

for i, part in enumerate(compound.particles(include_ports=show_ports)):
for i, part in enumerate(compound.particles(include_ports=include_ports)):
if residues and part.name in residues:
current_residue = mol.NewResidue()
current_residue.SetName(part.name)
Expand Down
20 changes: 17 additions & 3 deletions mbuild/formats/json_formats.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
import ele

import mbuild as mb
from mbuild.bond_graph import BondGraph
from mbuild.exceptions import MBuildError


Expand Down Expand Up @@ -56,6 +57,7 @@ def compound_from_json(json_file):
sub_cmpd = _dict_to_mb(sub_compound)
converted_dict[sub_compound["id"]] = sub_cmpd
sub_cmpd = converted_dict[sub_compound["id"]]
sub_cmpd.bond_graph = None

label_str = sub_compound["label"]
label_list = compound.get("label_list", {})
Expand All @@ -64,7 +66,12 @@ def compound_from_json(json_file):
parent_compound.labels[key] = list()
if sub_compound["id"] in vals:
parent_compound.labels[key].append(sub_cmpd)
parent_compound.add(sub_cmpd, label=label_str)
parent_compound.add(sub_cmpd, check_box_size=False, label=label_str)

parent.bond_graph = BondGraph()
parent.bond_graph.add_nodes_from(
[particle for particle in parent.particles()]
)

_add_ports(compound_dict, converted_dict)
_add_bonds(compound_dict, parent, converted_dict)
Expand Down Expand Up @@ -152,6 +159,7 @@ def _particle_info(cmpd, include_ports=False):
else:
port_info["anchor"] = None
port_info["label"] = None
port_info["pos"] = port.pos.tolist()
# Is this the most efficient way?
for key, val in cmpd.labels.items():
if (val == port) and val.port_particle:
Expand Down Expand Up @@ -236,15 +244,21 @@ def _add_ports(compound_dict, converted_dict):
for port in ports:
label_str = port["label"]
port_to_add = mb.Port(anchor=converted_dict[port["anchor"]])
converted_dict[compound["id"]].add(port_to_add, label_str)
if port.get("pos", None) is not None:
port_to_add.translate_to(port.get("pos"))
converted_dict[compound["id"]].add(
port_to_add, label_str, check_box_size=False
)
# Not necessary to add same port twice
compound["ports"] = None
ports = subcompound.get("ports", None)
if ports:
for port in ports:
label_str = port["label"]
port_to_add = mb.Port(anchor=converted_dict[port["anchor"]])
converted_dict[subcompound["id"]].add(port_to_add, label_str)
converted_dict[subcompound["id"]].add(
port_to_add, label_str, check_box_size=False
)
subcompound["ports"] = None


Expand Down
9 changes: 9 additions & 0 deletions mbuild/tests/test_json_formats.py
Original file line number Diff line number Diff line change
Expand Up @@ -121,3 +121,12 @@ def test_float_64_position(self):
compound_to_json(ethane, "ethane.json", include_ports=True)
ethane_copy = compound_from_json("ethane.json")
assert np.allclose(ethane.xyz, ethane_copy.xyz, atol=10**-6)

def test_compound_with_port(self):
    """Round-trip a CH2 moiety with its ports through the JSON format.

    The compound is saved with ``include_ports=True`` and loaded back;
    the loaded compound must expose both ports, each 0.07 away from its
    anchor as reported by ``Port.separation``.
    """
    ch2 = mb.lib.moieties.CH2()
    ch2.save("ch2.json", include_ports=True, overwrite=True)

    loaded_ch2 = mb.load("ch2.json")
    assert len(loaded_ch2.all_ports()) == 2
    for port in loaded_ch2.all_ports():
        # Port positions are serialized as floats and restored by
        # translation, so compare with a tolerance rather than exact
        # equality (same tolerance as the xyz round-trip test above).
        assert np.isclose(port.separation, 0.07, atol=1e-6)

0 comments on commit f50fb61

Please sign in to comment.