Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

'MDAnalysis.analysis.align' parallelization #4738

Draft
wants to merge 18 commits into
base: develop
Choose a base branch
from
31 changes: 29 additions & 2 deletions package/MDAnalysis/analysis/align.py
Original file line number Diff line number Diff line change
Expand Up @@ -216,7 +216,7 @@
from MDAnalysis.lib.log import ProgressBar
from ..due import due, Doi

from .base import AnalysisBase
from .base import AnalysisBase, ResultsGroup

logger = logging.getLogger('MDAnalysis.analysis.align')

Expand Down Expand Up @@ -678,6 +678,12 @@ class AlignTraj(AnalysisBase):

"""

_analysis_algorithm_is_parallelizable = True

@classmethod
def get_supported_backends(cls):
return ("serial", "dask")

def __init__(self, mobile, reference, select='all', filename=None,
prefix='rmsfit_', weights=None,
tol_mass=0.1, match_atoms=True, strict=False, force=True, in_memory=False,
Expand Down Expand Up @@ -866,6 +872,12 @@ def rmsd(self):
warnings.warn(wmsg, DeprecationWarning)
return self.results.rmsd

def _get_aggregator(self):
return ResultsGroup(
lookup={
"rmsd": ResultsGroup.ndarray_hstack,
}
)

class AverageStructure(AnalysisBase):
"""RMS-align trajectory to a reference structure using a selection,
Expand Down Expand Up @@ -896,6 +908,12 @@ class AverageStructure(AnalysisBase):

"""

_analysis_algorithm_is_parallelizable = True

@classmethod
def get_supported_backends(cls):
return ("serial", "multiprocessing", "dask")

def __init__(self, mobile, reference=None, select='all', filename=None,
weights=None,
tol_mass=0.1, match_atoms=True, strict=False, force=True, in_memory=False,
Expand Down Expand Up @@ -1089,6 +1107,15 @@ def _conclude(self):
if not self._verbose:
logging.disable(logging.NOTSET)

def _get_aggregator(self):
return ResultsGroup(
lookup={
"universe": ResultsGroup.ndarray_vstack,
"positions": ResultsGroup.ndarray_vstack,
"rmsd": ResultsGroup.ndarray_vstack,
}
)

@property
def universe(self):
wmsg = ("The `universe` attribute was deprecated in MDAnalysis 2.0.0 "
Expand Down Expand Up @@ -1669,4 +1696,4 @@ def get_atoms_byres(g, match_mask=np.logical_not(mismatch_mask)):
logger.error(errmsg)
raise SelectionError(errmsg)

return ag1, ag2
return ag1, ag2
13 changes: 13 additions & 0 deletions testsuite/MDAnalysisTests/analysis/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
from MDAnalysis.analysis.hydrogenbonds.hbond_analysis import (
HydrogenBondAnalysis,
)
from MDAnalysis.analysis.align import AverageStructure, AlignTraj
from MDAnalysis.lib.util import is_installed


Expand Down Expand Up @@ -141,3 +142,15 @@ def client_DSSP(request):
@pytest.fixture(scope='module', params=params_for_cls(HydrogenBondAnalysis))
def client_HydrogenBondAnalysis(request):
return request.param


# MDAnalysis.analysis.align

@pytest.fixture(scope="module", params=params_for_cls(AverageStructure))
def client_AverageStructure(request):
return request.param


@pytest.fixture(scope="module", params=params_for_cls(AlignTraj))
def client_AlignTraj(request):
return request.param
136 changes: 99 additions & 37 deletions testsuite/MDAnalysisTests/analysis/test_align.py
Original file line number Diff line number Diff line change
Expand Up @@ -244,10 +244,14 @@ def test_AlignTraj_outfile_default(self, universe, reference, tmpdir):
finally:
x._writer.close()

def test_AlignTraj_outfile_default_exists(self, universe, reference, tmpdir):
def test_AlignTraj_outfile_default_exists(
self, universe, reference, tmpdir, client_AlignTraj
):
reference.trajectory[-1]
outfile = str(tmpdir.join('align_test.dcd'))
align.AlignTraj(universe, reference, filename=outfile).run()
align.AlignTraj(universe, reference, filename=outfile).run(
**client_AlignTraj
)
fitted = mda.Universe(PSF, outfile)

# ensure default file exists
Expand All @@ -262,25 +266,37 @@ def test_AlignTraj_outfile_default_exists(self, universe, reference, tmpdir):
with pytest.raises(IOError):
align.AlignTraj(fitted, reference, force=False)

def test_AlignTraj_step_works(self, universe, reference, tmpdir):
def test_AlignTraj_step_works(
self, universe, reference, tmpdir, client_AlignTraj
):
reference.trajectory[-1]
outfile = str(tmpdir.join('align_test.dcd'))
# this shouldn't throw an exception
align.AlignTraj(universe, reference, filename=outfile).run(step=10)
align.AlignTraj(universe, reference, filename=outfile).run(
step=10, **client_AlignTraj
)

def test_AlignTraj_deprecated_attribute(self, universe, reference, tmpdir):
def test_AlignTraj_deprecated_attribute(
self, universe, reference, tmpdir, client_AlignTraj
):
reference.trajectory[-1]
outfile = str(tmpdir.join('align_test.dcd'))
x = align.AlignTraj(universe, reference, filename=outfile).run(stop=2)
x = align.AlignTraj(universe, reference, filename=outfile).run(
stop=2, **client_AlignTraj
)

wmsg = "The `rmsd` attribute was deprecated in MDAnalysis 2.0.0"
with pytest.warns(DeprecationWarning, match=wmsg):
assert_equal(x.rmsd, x.results.rmsd)

def test_AlignTraj(self, universe, reference, tmpdir):
def test_AlignTraj(
self, universe, reference, tmpdir, client_AlignTraj
):
reference.trajectory[-1]
outfile = str(tmpdir.join('align_test.dcd'))
x = align.AlignTraj(universe, reference, filename=outfile).run()
x = align.AlignTraj(universe, reference, filename=outfile).run(
**client_AlignTraj
)
fitted = mda.Universe(PSF, outfile)

assert_allclose(x.results.rmsd[0], 6.9290, rtol=0, atol=1.5e-3)
Expand All @@ -292,10 +308,14 @@ def test_AlignTraj(self, universe, reference, tmpdir):
self._assert_rmsd(reference, fitted, 0, 6.929083044751061)
self._assert_rmsd(reference, fitted, -1, 0.0)

def test_AlignTraj_weighted(self, universe, reference, tmpdir):
def test_AlignTraj_weighted(
self, universe, reference, tmpdir, client_AlignTraj
):
outfile = str(tmpdir.join('align_test.dcd'))
x = align.AlignTraj(universe, reference,
filename=outfile, weights='mass').run()
filename=outfile, weights='mass').run(
**client_AlignTraj
)
fitted = mda.Universe(PSF, outfile)
assert_allclose(x.results.rmsd[0], 0, rtol=0, atol=1.5e-3)
assert_allclose(x.results.rmsd[-1], 6.9033, rtol=0, atol=1.5e-3)
Expand All @@ -305,25 +325,35 @@ def test_AlignTraj_weighted(self, universe, reference, tmpdir):
self._assert_rmsd(reference, fitted, -1, 6.929083032629219,
weights=universe.atoms.masses)

def test_AlignTraj_custom_weights(self, universe, reference, tmpdir):
def test_AlignTraj_custom_weights(
self, universe, reference, tmpdir, client_AlignTraj
):
weights = np.zeros(universe.atoms.n_atoms)
ca = universe.select_atoms('name CA')
weights[ca.indices] = 1

outfile = str(tmpdir.join('align_test.dcd'))

x = align.AlignTraj(universe, reference,
filename=outfile, select='name CA').run()
filename=outfile, select='name CA').run(
**client_AlignTraj
)
x_weights = align.AlignTraj(universe, reference,
filename=outfile, weights=weights).run()
filename=outfile, weights=weights).run(
**client_AlignTraj
)

assert_allclose(x.results.rmsd, x_weights.results.rmsd, rtol=0, atol=1.5e-7)

def test_AlignTraj_custom_mass_weights(self, universe, reference, tmpdir):
def test_AlignTraj_custom_mass_weights(
self, universe, reference, tmpdir, client_AlignTraj
):
outfile = str(tmpdir.join('align_test.dcd'))
x = align.AlignTraj(universe, reference,
filename=outfile,
weights=reference.atoms.masses).run()
weights=reference.atoms.masses).run(
**client_AlignTraj
)
fitted = mda.Universe(PSF, outfile)
assert_allclose(x.results.rmsd[0], 0, rtol=0, atol=1.5e-3)
assert_allclose(x.results.rmsd[-1], 6.9033, rtol=0, atol=1.5e-3)
Expand All @@ -333,18 +363,24 @@ def test_AlignTraj_custom_mass_weights(self, universe, reference, tmpdir):
self._assert_rmsd(reference, fitted, -1, 6.929083032629219,
weights=universe.atoms.masses)

def test_AlignTraj_partial_fit(self, universe, reference, tmpdir):
def test_AlignTraj_partial_fit(
self, universe, reference, tmpdir, client_AlignTraj
):
outfile = str(tmpdir.join('align_test.dcd'))
# fitting on a partial selection should still write the whole topology
align.AlignTraj(universe, reference, select='resid 1-20',
filename=outfile, weights='mass').run()
filename=outfile, weights='mass').run(
**client_AlignTraj
)
mda.Universe(PSF, outfile)

def test_AlignTraj_in_memory(self, universe, reference, tmpdir):
def test_AlignTraj_in_memory(
self, universe, reference, tmpdir, client_AlignTraj
):
outfile = str(tmpdir.join('align_test.dcd'))
reference.trajectory[-1]
x = align.AlignTraj(universe, reference, filename=outfile,
in_memory=True).run()
in_memory=True).run(**client_AlignTraj)
assert x.filename is None
assert_allclose(x.results.rmsd[0], 6.9290, rtol=0, atol=1.5e-3)
assert_allclose(x.results.rmsd[-1], 5.2797e-07, rtol=0, atol=1.5e-3)
Expand All @@ -353,15 +389,17 @@ def test_AlignTraj_in_memory(self, universe, reference, tmpdir):
self._assert_rmsd(reference, universe, 0, 6.929083044751061)
self._assert_rmsd(reference, universe, -1, 0.0)

def test_AlignTraj_writer_kwargs(self, universe, reference, tmpdir):
def test_AlignTraj_writer_kwargs(
self, universe, reference, tmpdir, client_AlignTraj
):
# Issue 4564
writer_kwargs = dict(precision=2)
with tmpdir.as_cwd():
aligner = align.AlignTraj(universe, reference,
select='protein and name CA',
filename='aligned_traj.xtc',
writer_kwargs=writer_kwargs,
in_memory=False).run()
in_memory=False).run(**client_AlignTraj)
assert_equal(aligner._writer.precision, 2)

def _assert_rmsd(self, reference, fitted, frame, desired, weights=None):
Expand Down Expand Up @@ -427,9 +465,13 @@ def universe(self):
def reference(self):
return mda.Universe(PSF, CRD)

def test_average_structure_deprecated_attrs(self, universe, reference):
def test_average_structure_deprecated_attrs(
self, universe, reference, client_AverageStructure
):
# Issue #3278 - remove in MDAnalysis 3.0.0
avg = align.AverageStructure(universe, reference).run(stop=2)
avg = align.AverageStructure(universe, reference).run(
stop=2, **client_AverageStructure
)

wmsg = "The `universe` attribute was deprecated in MDAnalysis 2.0.0"
with pytest.warns(DeprecationWarning, match=wmsg):
Expand All @@ -444,44 +486,58 @@ def test_average_structure_deprecated_attrs(self, universe, reference):
with pytest.warns(DeprecationWarning, match=wmsg):
assert avg.rmsd == avg.results.rmsd

def test_average_structure(self, universe, reference):
def test_average_structure(
self, universe, reference, client_AverageStructure
):
ref, rmsd = _get_aligned_average_positions(self.ref_files, reference)
avg = align.AverageStructure(universe, reference).run()
avg = align.AverageStructure(universe, reference).run(
**client_AverageStructure
)
assert_allclose(avg.results.universe.atoms.positions, ref, rtol=0, atol=1.5e-4)
assert_allclose(avg.results.rmsd, rmsd, rtol=0, atol=1.5e-7)

def test_average_structure_mass_weighted(self, universe, reference):
def test_average_structure_mass_weighted(
self, universe, reference, client_AverageStructure
):
ref, rmsd = _get_aligned_average_positions(self.ref_files, reference, weights='mass')
avg = align.AverageStructure(universe, reference, weights='mass').run()
avg = align.AverageStructure(universe, reference, weights='mass').run(
**client_AverageStructure
)
assert_allclose(avg.results.universe.atoms.positions, ref,
rtol=0, atol=1.5e-4)
assert_allclose(avg.results.rmsd, rmsd, rtol=0, atol=1.5e-7)

def test_average_structure_select(self, universe, reference):
def test_average_structure_select(
self, universe, reference, client_AverageStructure
):
select = 'protein and name CA and resid 3-5'
ref, rmsd = _get_aligned_average_positions(self.ref_files, reference, select=select)
avg = align.AverageStructure(universe, reference, select=select).run()
avg = align.AverageStructure(universe, reference, select=select).run(
**client_AverageStructure
)
assert_allclose(avg.results.universe.atoms.positions, ref,
rtol=0, atol=1.5e-4)
assert_allclose(avg.results.rmsd, rmsd, rtol=0, atol=1.5e-7)

def test_average_structure_no_ref(self, universe):
def test_average_structure_no_ref(self, universe, client_AverageStructure):
ref, rmsd = _get_aligned_average_positions(self.ref_files, universe)
avg = align.AverageStructure(universe).run()
avg = align.AverageStructure(universe).run(**client_AverageStructure)
assert_allclose(avg.results.universe.atoms.positions, ref,
rtol=0, atol=1.5e-4)
assert_allclose(avg.results.rmsd, rmsd, rtol=0, atol=1.5e-7)

def test_average_structure_no_msf(self, universe):
avg = align.AverageStructure(universe).run()
def test_average_structure_no_msf(self, universe, client_AverageStructure):
avg = align.AverageStructure(universe).run(**client_AverageStructure)
assert not hasattr(avg, 'msf')

def test_mismatch_atoms(self, universe):
u = mda.Merge(universe.atoms[:10])
with pytest.raises(SelectionError):
align.AverageStructure(universe, u)

def test_average_structure_ref_frame(self, universe):
def test_average_structure_ref_frame(
self, universe, client_AverageStructure
):
ref_frame = 3
u = mda.Merge(universe.atoms)

Expand All @@ -492,13 +548,19 @@ def test_average_structure_ref_frame(self, universe):
# back to start
universe.trajectory[0]
ref, rmsd = _get_aligned_average_positions(self.ref_files, u)
avg = align.AverageStructure(universe, ref_frame=ref_frame).run()
avg = align.AverageStructure(universe, ref_frame=ref_frame).run(
**client_AverageStructure
)
assert_allclose(avg.results.universe.atoms.positions, ref,
rtol=0, atol=1.5e-4)
assert_allclose(avg.results.rmsd, rmsd, rtol=0, atol=1.5e-7)

def test_average_structure_in_memory(self, universe):
avg = align.AverageStructure(universe, in_memory=True).run()
def test_average_structure_in_memory(
self, universe, client_AverageStructure
):
avg = align.AverageStructure(universe, in_memory=True).run(
**client_AverageStructure
)
reference_coordinates = universe.trajectory.timeseries().mean(axis=1)
assert_allclose(avg.results.universe.atoms.positions,
reference_coordinates, rtol=0, atol=1.5e-4)
Expand Down
Loading