Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Writes compressed output of given format #2221

Merged
merged 22 commits into from
Apr 5, 2019
Merged
Show file tree
Hide file tree
Changes from 9 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion package/MDAnalysis/coordinates/GRO.py
Original file line number Diff line number Diff line change
Expand Up @@ -326,7 +326,7 @@ def __init__(self, filename, convert_units=None, n_atoms=None, **kwargs):
w.write(u.atoms)

"""
self.filename = util.filename(filename, ext='gro')
self.filename = util.filename(filename, ext='gro', keep = True)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

PEP 8: no spaces around `=` in keyword arguments.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Are we also restricting the number of characters per line to 79/80, as per PEP 8?

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

See the Style Guide: yes, we try to keep within 79 characters. (Not all our code rigorously adheres to it, but that's what we're striving for)

self.n_atoms = n_atoms
self.reindex = kwargs.pop('reindex', True)

Expand Down
16 changes: 4 additions & 12 deletions package/MDAnalysis/core/groups.py
Original file line number Diff line number Diff line change
Expand Up @@ -2975,7 +2975,7 @@ def improper(self):
"improper only makes sense for a group with exactly 4 atoms")
return topologyobjects.ImproperDihedral(self.ix, self.universe)

def write(self, filename=None, file_format="PDB",
def write(self, filename=None, file_format=None,
filenamefmt="{trjname}_{frame}", frames=None, **kwargs):
"""Write `AtomGroup` to a file.

Expand Down Expand Up @@ -3054,8 +3054,7 @@ def write(self, filename=None, file_format="PDB",
if filename is None:
trjname, ext = os.path.splitext(os.path.basename(trj.filename))
filename = filenamefmt.format(trjname=trjname, frame=trj.frame)
filename = util.filename(filename, ext=file_format.lower(), keep=True)

filename = util.filename(filename, ext= file_format or 'PDB', keep=True)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

PEP 8! (No space after `=` in keyword arguments: `ext=file_format`, not `ext= file_format`.)

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Did you test that you can omit `.lower()`?

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I don't like `file_format or 'PDB'`: it silently replaces every falsy value (e.g. an empty string), not just `None`. Better to be explicit so that only `None` really gets defaulted:

file_format if file_format is not None else 'PDB'

# Some writer behave differently when they are given a "multiframe"
# argument. It is the case of the PDB writer tht writes models when
# "multiframe" is True.
Expand All @@ -3080,14 +3079,7 @@ def write(self, filename=None, file_format="PDB",
# Try and select a Class using get_ methods (becomes `writer`)
# Once (and if!) class is selected, use it in with block
try:
# format keyword works differently in get_writer and get_selection_writer
# here it overrides everything, in get_sel it is just a default
# apply sparingly here!
format = os.path.splitext(filename)[1][1:] # strip initial dot!
format = format or file_format
format = format.strip().upper()

writer = get_writer_for(filename, format=format, multiframe=multiframe)
writer = get_writer_for(filename, format=file_format, multiframe=multiframe)
except (ValueError, TypeError):
pass
else:
Expand All @@ -3106,7 +3098,7 @@ def write(self, filename=None, file_format="PDB",
try:
# here `file_format` is only used as default,
# anything pulled off `filename` will be used preferentially
writer = get_selection_writer_for(filename, file_format)
writer = get_selection_writer_for(filename, file_format or 'PDB')
fenilsuchak marked this conversation as resolved.
Show resolved Hide resolved
except (TypeError, NotImplementedError):
pass
else:
Expand Down
1 change: 1 addition & 0 deletions package/MDAnalysis/lib/util.py
Original file line number Diff line number Diff line change
Expand Up @@ -249,6 +249,7 @@ def filename(name, ext=None, keep=False):
Also permits :class:`NamedStream` to pass through.
"""
if ext is not None:
ext = ext.lower()
if not ext.startswith(os.path.extsep):
ext = os.path.extsep + ext
root, origext = os.path.splitext(name)
Expand Down
30 changes: 30 additions & 0 deletions testsuite/MDAnalysisTests/core/test_atomgroup.py
Original file line number Diff line number Diff line change
Expand Up @@ -167,7 +167,27 @@ def test_write_frame_none(self, u, tmpdir, extension):
u.atoms.positions[None, ...], new_positions, decimal=2
)

@pytest.mark.parametrize('extension', ('xtc', 'dcd', 'pdb', 'xyz'))
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We should test capitalized extensions, too.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'm afraid I don't understand. Should I capitalize all the extensions? I've reused the parameters from the previous, already-written test.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I've added a test for a capitalized extension 'PDB'.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Sorry, clearly I should proof-read what I am typing... yes, test for capitalized extensions was what I was looking for.

def test_compressed_write_frame_none(self, u, tmpdir, extension):
for ext in ('.gz', '.bz2'):
destination = str(tmpdir / 'test.' + extension + ext)
u.atoms.write(destination, frames=None)
u_new = mda.Universe(destination)
new_positions = np.stack([ts.positions for ts in u_new.trajectory])
assert_array_almost_equal(
u.atoms.positions[None, ...], new_positions, decimal=2
fenilsuchak marked this conversation as resolved.
Show resolved Hide resolved
)

def test_write_frames_all(self, u, tmpdir):
for ext in ('.gz', '.bz2'):
destination = str(tmpdir / 'test.dcd') + ext
u.atoms.write(destination, frames='all')
u_new = mda.Universe(destination)
ref_positions = np.stack([ts.positions for ts in u.trajectory])
new_positions = np.stack([ts.positions for ts in u_new.trajectory])
assert_array_almost_equal(new_positions, ref_positions)

def test_compressed_rite_frames_all(self, u, tmpdir):
fenilsuchak marked this conversation as resolved.
Show resolved Hide resolved
destination = str(tmpdir / 'test.dcd')
u.atoms.write(destination, frames='all')
u_new = mda.Universe(destination)
Expand Down Expand Up @@ -238,6 +258,16 @@ def test_write_atoms(self, universe, outfile):
err_msg=("atom coordinate mismatch between original and {0!s} file"
"".format(self.ext)))

def test_compressed_write_atoms(self, universe, outfile):
for compressed_ext in ('.gz', '.bz2'):
universe.atoms.write(outfile + compressed_ext)
u2 = self.universe_from_tmp(outfile + compressed_ext)
assert_almost_equal(
universe.atoms.positions, u2.atoms.positions,
self.precision,
err_msg=("atom coordinate mismatch between original and {0!s} file"
"".format(self.ext)))

def test_write_empty_atomgroup(self, universe, outfile):
sel = universe.select_atoms('name doesntexist')
with pytest.raises(IndexError):
Expand Down