Skip to content

Commit

Permalink
make ChainReader subclass ReaderBase
Browse files Browse the repository at this point in the history
  • Loading branch information
richardjgowers committed May 25, 2023
1 parent 5ab95ec commit 6147f26
Show file tree
Hide file tree
Showing 3 changed files with 43 additions and 106 deletions.
37 changes: 22 additions & 15 deletions package/MDAnalysis/coordinates/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -122,12 +122,11 @@
:members:
"""
import abc
import numpy as np
import numbers
import copy
import warnings
import weakref
from typing import Union, Optional, List, Dict
from typing import Any, Union, Optional, List, Dict

from .timestep import Timestep
from . import core
Expand Down Expand Up @@ -614,7 +613,7 @@ def __exit__(self, exc_type, exc_val, exc_tb):
return False # do not suppress exceptions


class _Readermeta(type):
class _Readermeta(abc.ABCMeta):
"""Automatic Reader registration metaclass
.. versionchanged:: 1.0.0
Expand Down Expand Up @@ -664,14 +663,18 @@ class ProtoReader(IOBase, metaclass=_Readermeta):
#: The appropriate Timestep class, e.g.
#: :class:`MDAnalysis.coordinates.xdrfile.XTC.Timestep` for XTC.
_Timestep = Timestep
_transformations: list
_auxs: dict
_filename: Any
n_frames: int

def __init__(self):
# initialise list to store added auxiliary readers in
# subclasses should now call super
self._auxs = {}
self._transformations=[]

def __len__(self):
def __len__(self) -> int:
return self.n_frames

@classmethod
Expand All @@ -691,7 +694,7 @@ def parse_n_atoms(cls, filename, **kwargs):
raise NotImplementedError("{} cannot deduce the number of atoms"
"".format(cls.__name__))

def next(self):
def next(self) -> Timestep:
"""Forward one step to next frame."""
try:
ts = self._read_next_timestep()
Expand All @@ -706,22 +709,22 @@ def next(self):

return ts

def __next__(self):
def __next__(self) -> Timestep:
"""Forward one step to next frame when using the `next` builtin."""
return self.next()

def rewind(self):
def rewind(self) -> Timestep:
"""Position at beginning of trajectory"""
self._reopen()
self.next()

@property
def dt(self):
def dt(self) -> float:
"""Time between two trajectory frames in picoseconds."""
return self.ts.dt

@property
def totaltime(self):
def totaltime(self) -> float:
"""Total length of the trajectory
The time is calculated as ``(n_frames - 1) * dt``, i.e., we assume that
Expand All @@ -733,7 +736,7 @@ def totaltime(self):
return (self.n_frames - 1) * self.dt

@property
def frame(self):
def frame(self) -> int:
"""Frame number of the current time step.
This is a simple short cut to :attr:`Timestep.frame`.
Expand Down Expand Up @@ -781,20 +784,21 @@ def OtherWriter(self, filename, **kwargs):
pass
return core.writer(filename, **kwargs)

def _read_next_timestep(self, ts=None): # pragma: no cover
@abc.abstractmethod
def _read_next_timestep(self, ts=None):
# Example from DCDReader:
# if ts is None:
# ts = self.ts
# ts.frame = self._read_next_frame(etc)
# return ts
raise NotImplementedError(
"BUG: Override _read_next_timestep() in the trajectory reader!")
...

def __iter__(self):
""" Iterate over trajectory frames. """
self._reopen()
return self

@abc.abstractmethod
def _reopen(self):
"""Should position Reader to just before first frame
Expand Down Expand Up @@ -980,7 +984,7 @@ def __repr__(self):
nframes=self.n_frames,
natoms=self.n_atoms
))

def timeseries(self, asel: Optional['AtomGroup']=None,
start: Optional[int]=None, stop: Optional[int]=None,
step: Optional[int]=None,
Expand Down Expand Up @@ -1704,6 +1708,9 @@ def _reopen(self):
def next(self):
raise StopIteration(self._err.format(self.__class__.__name__))

def _read_next_timestep(self, ts=None):
raise NotImplementedError(self._err.format(self.__class__.__name__))

def __iter__(self):
self.rewind()
yield self.ts
Expand Down
108 changes: 19 additions & 89 deletions package/MDAnalysis/coordinates/chain.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,8 +44,7 @@

import os.path
import bisect
import copy

from typing import Tuple
import numpy as np

from ..lib import util
Expand Down Expand Up @@ -159,7 +158,7 @@ def check_allowed_filetypes(readers, allowed):
"supported for formats: {}".format(allowed))


class ChainReader(base.ProtoReader):
class ChainReader(base.ReaderBase):
"""Reader that concatenates multiple trajectories on the fly.
The :class:`ChainReader` is used by MDAnalysis internally to
Expand Down Expand Up @@ -219,7 +218,8 @@ class ChainReader(base.ProtoReader):
format = 'CHAIN'

@store_init_arguments
def __init__(self, filenames, skip=1, dt=None, continuous=False, **kwargs):
def __init__(self, filenames, skip=1, dt=None, continuous=False,
convert_units=True, **kwargs):
"""Set up the chain reader.
Parameters
Expand Down Expand Up @@ -256,14 +256,15 @@ def __init__(self, filenames, skip=1, dt=None, continuous=False, **kwargs):
unchanged
"""
super(ChainReader, self).__init__()
super(ChainReader, self).__init__(filename='CHAIN',
convert_units=convert_units)

filenames = asiterable(filenames)
# Override here because single frame readers handle this argument as a
# kwarg to a timestep which behaves differently if dt is present or not.
if dt is not None:
kwargs['dt'] = dt
self.readers = [core.reader(filename, **kwargs)
self.readers = [core.reader(filename, convert_units=convert_units, **kwargs)
for filename in filenames]
self.filenames = np.array([fn[0] if isinstance(fn, tuple) else fn
for fn in filenames])
Expand Down Expand Up @@ -377,7 +378,7 @@ def _format_hint(thing):
util.iterable(thing) and
not util.isstream(thing))

def _get_local_frame(self, k):
def _get_local_frame(self, k) -> Tuple[int, int]:
"""Find trajectory index and trajectory frame for chained frame `k`.
Parameters
Expand Down Expand Up @@ -450,21 +451,7 @@ def convert_pos_from_native(self, x):
def convert_pos_to_native(self, x):
return self.active_reader.convert_pos_to_native(x)

def copy(self):
new = self.__class__(**self._kwargs)
# seek the new reader to the same frame we started with
new[self.ts.frame]
# then copy over the current Timestep in case it has
# been modified since initial load
new.ts = self.ts.copy()
return new

# attributes that can change with the current reader
@property
def filename(self):
"""Filename of the currently read trajectory"""
return self.active_reader.filename

# TODO: check that skip_timestep is still supported in all readers
# or should this be removed?
@property
Expand Down Expand Up @@ -548,6 +535,7 @@ def __activate_reader(self, i):
if not (0 <= i < len(self.readers)):
raise IndexError("Reader index must be 0 <= i < {0:d}".format(len(self.readers)))
self.__active_reader_index = i
self.filename = self.filenames[i]

@property
def active_reader(self):
Expand Down Expand Up @@ -582,32 +570,28 @@ def _read_frame(self, frame):
self.__current_frame = frame
return self.ts


def _read_next_timestep(self, ts=None):
if ts is None:
ts = self.ts
ts = self.__next__()
return ts

def rewind(self):
"""Set current frame to the beginning."""
self._rewind()
if self.__current_frame < self.n_frames - 1:
j, f = self._get_local_frame(self.__current_frame + 1)
self.__activate_reader(j)
self.ts = self.active_reader[f]
self.ts.frame = self.__current_frame + 1
self.__current_frame += 1
return self.ts
else:
raise StopIteration()

def _rewind(self):
def _reopen(self):
"""Internal method: Rewind trajectories themselves and trj pointer."""
self.__current_frame = -1
self._apply('rewind')
self.__next__()

def close(self):
self._apply('close')

def __iter__(self):
"""Generator for all frames, starting at frame 0."""
self.__current_frame = -1
# start from first frame
return self

def __repr__(self):
if len(self.filenames) > 3:
fnames = "{fname} and {nfanmes} more".format(
Expand All @@ -621,57 +605,3 @@ def __repr__(self):
fname=fnames,
nframes=self.n_frames,
natoms=self.n_atoms))

def add_transformations(self, *transformations):
""" Add all transformations to be applied to the trajectory.
This function take as list of transformations as an argument. These
transformations are functions that will be called by the Reader and given
a :class:`Timestep` object as argument, which will be transformed and returned
to the Reader.
The transformations can be part of the :mod:`~MDAnalysis.transformations`
module, or created by the user, and are stored as a list `transformations`.
This list can only be modified once, and further calls of this function will
raise an exception.
.. code-block:: python
u = MDAnalysis.Universe(topology, coordinates)
workflow = [some_transform, another_transform, this_transform]
u.trajectory.add_transformations(*workflow)
Parameters
----------
transform_list : list
list of all the transformations that will be applied to the coordinates
See Also
--------
:mod:`MDAnalysis.transformations`
"""
#Overrides :meth:`~MDAnalysis.coordinates.base.ProtoReader.add_transformations`
#to avoid unintended behaviour where the coordinates of each frame are transformed
#multiple times when iterating over the trajectory.
#In this method, the trajectory is modified all at once and once only.

super(ChainReader, self).add_transformations(*transformations)
for r in self.readers:
r.add_transformations(*transformations)

def _apply_transformations(self, ts):
""" Applies the transformations to the timestep."""
# Overrides :meth:`~MDAnalysis.coordinates.base.ProtoReader.add_transformations`
# to avoid applying the same transformations multiple times on each frame

return ts

def __next__(self):
if self.__current_frame < self.n_frames - 1:
j, f = self._get_local_frame(self.__current_frame + 1)
self.__activate_reader(j)
self.ts = self.active_reader[f]
self.ts.frame = self.__current_frame + 1
self.__current_frame += 1
return self.ts
else:
raise StopIteration()
4 changes: 2 additions & 2 deletions testsuite/MDAnalysisTests/coordinates/test_copying.py
Original file line number Diff line number Diff line change
Expand Up @@ -179,10 +179,10 @@ def test_reader_copied_extra_attributes(original_and_copy_extra_args):
# Issue #3664
original, copy = original_and_copy_extra_args

# memory and chain readers subclass protoreader directly and
# memory reader subclass protoreader directly and
# therefore don't have convert_units or _ts_kwargs
if original.__class__.__bases__[0].__name__ != "ProtoReader":
assert original.format not in ('MEMORY', 'CHAIN')
assert original.format is not 'MEMORY'
assert original.convert_units is False
assert copy.convert_units is False
assert original._ts_kwargs['time_offset'] == 10
Expand Down

0 comments on commit 6147f26

Please sign in to comment.