Skip to content

Commit

Permalink
ChainReader now subclasses BaseReader (#3906)
Browse files Browse the repository at this point in the history
* make ChainReader subclass ReaderBase

* deal with kwargs in ChainReader properly

* add test for issue #4008

* stupid fucking linters

* changelog
  • Loading branch information
richardjgowers authored May 26, 2023
1 parent 5ddbc85 commit 7e9c6f8
Show file tree
Hide file tree
Showing 5 changed files with 78 additions and 107 deletions.
2 changes: 2 additions & 0 deletions package/CHANGELOG
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,8 @@ The rules for this file:
* 2.5.0

Fixes
* Fixed ChainReader performing transformations correctly across boundaries in the individual trajectories
(Issue #3657 #4008 PR #3906)
* Fix the `heavy` distance_type for water bridge analysis where distance
is not correctly assigned when more than one hydrogen is bonded to a
heavy atom (Issue #4040, PR #4066).
Expand Down
39 changes: 23 additions & 16 deletions package/MDAnalysis/coordinates/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -122,12 +122,11 @@
:members:
"""
import abc
import numpy as np
import numbers
import copy
import warnings
import weakref
from typing import Union, Optional, List, Dict
from typing import Any, Union, Optional, List, Dict

from .timestep import Timestep
from . import core
Expand Down Expand Up @@ -614,15 +613,15 @@ def __exit__(self, exc_type, exc_val, exc_tb):
return False # do not suppress exceptions


class _Readermeta(type):
class _Readermeta(abc.ABCMeta):
"""Automatic Reader registration metaclass
.. versionchanged:: 1.0.0
Added _format_hint functionality
"""
# Auto register upon class creation
def __init__(cls, name, bases, classdict):
type.__init__(type, name, bases, classdict)
type.__init__(type, name, bases, classdict) # pylint: disable=non-parent-init-called
try:
fmt = asiterable(classdict['format'])
except KeyError:
Expand Down Expand Up @@ -664,14 +663,18 @@ class ProtoReader(IOBase, metaclass=_Readermeta):
#: The appropriate Timestep class, e.g.
#: :class:`MDAnalysis.coordinates.xdrfile.XTC.Timestep` for XTC.
_Timestep = Timestep
_transformations: list
_auxs: dict
_filename: Any
n_frames: int

def __init__(self):
# initialise list to store added auxiliary readers in
# subclasses should now call super
self._auxs = {}
self._transformations=[]

def __len__(self):
def __len__(self) -> int:
return self.n_frames

@classmethod
Expand All @@ -691,7 +694,7 @@ def parse_n_atoms(cls, filename, **kwargs):
raise NotImplementedError("{} cannot deduce the number of atoms"
"".format(cls.__name__))

def next(self):
def next(self) -> Timestep:
"""Forward one step to next frame."""
try:
ts = self._read_next_timestep()
Expand All @@ -706,22 +709,22 @@ def next(self):

return ts

def __next__(self):
def __next__(self) -> Timestep:
"""Forward one step to next frame when using the `next` builtin."""
return self.next()

def rewind(self):
def rewind(self) -> Timestep:
"""Position at beginning of trajectory"""
self._reopen()
self.next()

@property
def dt(self):
def dt(self) -> float:
"""Time between two trajectory frames in picoseconds."""
return self.ts.dt

@property
def totaltime(self):
def totaltime(self) -> float:
"""Total length of the trajectory
The time is calculated as ``(n_frames - 1) * dt``, i.e., we assume that
Expand All @@ -733,7 +736,7 @@ def totaltime(self):
return (self.n_frames - 1) * self.dt

@property
def frame(self):
def frame(self) -> int:
"""Frame number of the current time step.
This is a simple short cut to :attr:`Timestep.frame`.
Expand Down Expand Up @@ -781,20 +784,21 @@ def OtherWriter(self, filename, **kwargs):
pass
return core.writer(filename, **kwargs)

def _read_next_timestep(self, ts=None): # pragma: no cover
@abc.abstractmethod
def _read_next_timestep(self, ts=None):
# Example from DCDReader:
# if ts is None:
# ts = self.ts
# ts.frame = self._read_next_frame(etc)
# return ts
raise NotImplementedError(
"BUG: Override _read_next_timestep() in the trajectory reader!")
...

def __iter__(self):
""" Iterate over trajectory frames. """
self._reopen()
return self

@abc.abstractmethod
def _reopen(self):
"""Should position Reader to just before first frame
Expand Down Expand Up @@ -980,7 +984,7 @@ def __repr__(self):
nframes=self.n_frames,
natoms=self.n_atoms
))

def timeseries(self, asel: Optional['AtomGroup']=None,
start: Optional[int]=None, stop: Optional[int]=None,
step: Optional[int]=None,
Expand Down Expand Up @@ -1704,6 +1708,9 @@ def _reopen(self):
def next(self):
raise StopIteration(self._err.format(self.__class__.__name__))

def _read_next_timestep(self, ts=None):
raise NotImplementedError(self._err.format(self.__class__.__name__))

def __iter__(self):
self.rewind()
yield self.ts
Expand Down
111 changes: 22 additions & 89 deletions package/MDAnalysis/coordinates/chain.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,8 +44,7 @@

import os.path
import bisect
import copy

from typing import Tuple
import numpy as np

from ..lib import util
Expand Down Expand Up @@ -159,7 +158,7 @@ def check_allowed_filetypes(readers, allowed):
"supported for formats: {}".format(allowed))


class ChainReader(base.ProtoReader):
class ChainReader(base.ReaderBase):
"""Reader that concatenates multiple trajectories on the fly.
The :class:`ChainReader` is used by MDAnalysis internally to
Expand Down Expand Up @@ -219,7 +218,8 @@ class ChainReader(base.ProtoReader):
format = 'CHAIN'

@store_init_arguments
def __init__(self, filenames, skip=1, dt=None, continuous=False, **kwargs):
def __init__(self, filenames, skip=1, dt=None, continuous=False,
convert_units=True, **kwargs):
"""Set up the chain reader.
Parameters
Expand Down Expand Up @@ -256,14 +256,18 @@ def __init__(self, filenames, skip=1, dt=None, continuous=False, **kwargs):
unchanged
"""
super(ChainReader, self).__init__()
super(ChainReader, self).__init__(filename='CHAIN',
skip=skip,
convert_units=convert_units,
dt=dt,
**kwargs)

filenames = asiterable(filenames)
# Override here because single frame readers handle this argument as a
# kwarg to a timestep which behaves differently if dt is present or not.
if dt is not None:
kwargs['dt'] = dt
self.readers = [core.reader(filename, **kwargs)
self.readers = [core.reader(filename, convert_units=convert_units, **kwargs)
for filename in filenames]
self.filenames = np.array([fn[0] if isinstance(fn, tuple) else fn
for fn in filenames])
Expand Down Expand Up @@ -377,7 +381,7 @@ def _format_hint(thing):
util.iterable(thing) and
not util.isstream(thing))

def _get_local_frame(self, k):
def _get_local_frame(self, k) -> Tuple[int, int]:
"""Find trajectory index and trajectory frame for chained frame `k`.
Parameters
Expand Down Expand Up @@ -450,21 +454,7 @@ def convert_pos_from_native(self, x):
def convert_pos_to_native(self, x):
return self.active_reader.convert_pos_to_native(x)

def copy(self):
new = self.__class__(**self._kwargs)
# seek the new reader to the same frame we started with
new[self.ts.frame]
# then copy over the current Timestep in case it has
# been modified since initial load
new.ts = self.ts.copy()
return new

# attributes that can change with the current reader
@property
def filename(self):
"""Filename of the currently read trajectory"""
return self.active_reader.filename

# TODO: check that skip_timestep is still supported in all readers
# or should this be removed?
@property
Expand Down Expand Up @@ -548,6 +538,7 @@ def __activate_reader(self, i):
if not (0 <= i < len(self.readers)):
raise IndexError("Reader index must be 0 <= i < {0:d}".format(len(self.readers)))
self.__active_reader_index = i
self.filename = self.filenames[i]

@property
def active_reader(self):
Expand Down Expand Up @@ -582,32 +573,28 @@ def _read_frame(self, frame):
self.__current_frame = frame
return self.ts


def _read_next_timestep(self, ts=None):
if ts is None:
ts = self.ts
ts = self.__next__()
return ts

def rewind(self):
"""Set current frame to the beginning."""
self._rewind()
if self.__current_frame < self.n_frames - 1:
j, f = self._get_local_frame(self.__current_frame + 1)
self.__activate_reader(j)
self.ts = self.active_reader[f]
self.ts.frame = self.__current_frame + 1
self.__current_frame += 1
return self.ts
else:
raise StopIteration()

def _rewind(self):
def _reopen(self):
"""Internal method: Rewind trajectories themselves and trj pointer."""
self.__current_frame = -1
self._apply('rewind')
self.__next__()

def close(self):
self._apply('close')

def __iter__(self):
"""Generator for all frames, starting at frame 0."""
self.__current_frame = -1
# start from first frame
return self

def __repr__(self):
if len(self.filenames) > 3:
fnames = "{fname} and {nfanmes} more".format(
Expand All @@ -621,57 +608,3 @@ def __repr__(self):
fname=fnames,
nframes=self.n_frames,
natoms=self.n_atoms))

def add_transformations(self, *transformations):
""" Add all transformations to be applied to the trajectory.
This function take as list of transformations as an argument. These
transformations are functions that will be called by the Reader and given
a :class:`Timestep` object as argument, which will be transformed and returned
to the Reader.
The transformations can be part of the :mod:`~MDAnalysis.transformations`
module, or created by the user, and are stored as a list `transformations`.
This list can only be modified once, and further calls of this function will
raise an exception.
.. code-block:: python
u = MDAnalysis.Universe(topology, coordinates)
workflow = [some_transform, another_transform, this_transform]
u.trajectory.add_transformations(*workflow)
Parameters
----------
transform_list : list
list of all the transformations that will be applied to the coordinates
See Also
--------
:mod:`MDAnalysis.transformations`
"""
#Overrides :meth:`~MDAnalysis.coordinates.base.ProtoReader.add_transformations`
#to avoid unintended behaviour where the coordinates of each frame are transformed
#multiple times when iterating over the trajectory.
#In this method, the trajectory is modified all at once and once only.

super(ChainReader, self).add_transformations(*transformations)
for r in self.readers:
r.add_transformations(*transformations)

def _apply_transformations(self, ts):
""" Applies the transformations to the timestep."""
# Overrides :meth:`~MDAnalysis.coordinates.base.ProtoReader.add_transformations`
# to avoid applying the same transformations multiple times on each frame

return ts

def __next__(self):
if self.__current_frame < self.n_frames - 1:
j, f = self._get_local_frame(self.__current_frame + 1)
self.__activate_reader(j)
self.ts = self.active_reader[f]
self.ts.frame = self.__current_frame + 1
self.__current_frame += 1
return self.ts
else:
raise StopIteration()
29 changes: 29 additions & 0 deletions testsuite/MDAnalysisTests/coordinates/test_chainreader.py
Original file line number Diff line number Diff line change
Expand Up @@ -412,3 +412,32 @@ def test_multilevel_arg_sort(l, ref):
def test_filter_times(l, ref):
indices = mda.coordinates.chain.filter_times(l, dt=1)
assert_equal(indices, ref)


def test_issue_4008():
# test for transformation on a chainreader trajectory
# in issue #4008 it was observed that the transformation did not work on
# the frames which were "borders" of the chained trajectories
def shift_by_group(ag):
# this just shifts all atoms in the box by their center of mass,
# i.e. afterwards the center of mass of all atoms should be [0,0,0]
def wrapped(ts):
shift = ag.center_of_mass()
ts.positions[:] -= shift
return ts

return wrapped

u = mda.Universe(GRO, [XTC, XTC])

trafos = (shift_by_group(u.atoms),)
u.trajectory.add_transformations(*trafos)
com = np.zeros((len(u.trajectory), 3))

for i, ts in enumerate(u.trajectory):
com[i] = u.atoms.center_of_mass()

# see issue for details, but in general the com shouldn't be far from zero
# in the issue report values of ~25 were seen before the fix
# with the fix all values should be <1e-6
assert np.abs(com).max() < 1e-6
4 changes: 2 additions & 2 deletions testsuite/MDAnalysisTests/coordinates/test_copying.py
Original file line number Diff line number Diff line change
Expand Up @@ -179,10 +179,10 @@ def test_reader_copied_extra_attributes(original_and_copy_extra_args):
# Issue #3664
original, copy = original_and_copy_extra_args

# memory and chain readers subclass protoreader directly and
# memory reader subclass protoreader directly and
# therefore don't have convert_units or _ts_kwargs
if original.__class__.__bases__[0].__name__ != "ProtoReader":
assert original.format not in ('MEMORY', 'CHAIN')
assert original.format is not 'MEMORY'
assert original.convert_units is False
assert copy.convert_units is False
assert original._ts_kwargs['time_offset'] == 10
Expand Down

0 comments on commit 7e9c6f8

Please sign in to comment.