Skip to content

Commit

Permalink
Fix issue #2462: bands export error
Browse files Browse the repository at this point in the history
  • Loading branch information
unkcpz committed Apr 12, 2020
1 parent c2599f5 commit 68621fc
Show file tree
Hide file tree
Showing 2 changed files with 134 additions and 172 deletions.
247 changes: 122 additions & 125 deletions aiida/orm/nodes/data/array/bands.py
Original file line number Diff line number Diff line change
Expand Up @@ -825,10 +825,7 @@ def _prepare_mpl_singlefile(self, *args, **kwargs):

s_header = matplotlib_header_template.substitute()
s_import = matplotlib_import_data_inline_template.substitute(all_data_json=json.dumps(all_data, indent=2))
if len(all_data['paths']) == 1:
s_body = matplotlib_body_template.substitute(plot_code=single_kp)
else:
s_body = matplotlib_body_template.substitute(plot_code=multi_kp)
s_body = self._get_mpl_body_template(all_data['paths'])
s_footer = matplotlib_footer_template_show.substitute()

s = s_header + s_import + s_body + s_footer
Expand Down Expand Up @@ -857,125 +854,13 @@ def _prepare_mpl_withjson(self, main_file_name='', *args, **kwargs):

s_header = matplotlib_header_template.substitute()
s_import = matplotlib_import_data_fromfile_template.substitute(json_fname=json_fname)
if len(all_data['paths']) == 1:
s_body = matplotlib_body_template.substitute(plot_code=single_kp)
else:
s_body = matplotlib_body_template.substitute(plot_code=multi_kp)
s_body = self._get_mpl_body_template(all_data['paths'])
s_footer = matplotlib_footer_template_show.substitute()

s = s_header + s_import + s_body + s_footer

return s.encode('utf-8'), ext_files

def _prepare_gnuplot(self,
main_file_name='',
title='',
comments=True,
prettify_format=None,
y_max_lim=None,
y_min_lim=None,
y_origin=0.):
"""
Prepare an gnuplot script to plot the bands, with the .dat file
returned as an independent file.
:param main_file_name: if the user asks to write the main content on a
file, this contains the filename. This should be used to infer a
good filename for the additional files.
In this case, we remove the extension, and add '_data.dat'
:param title: if specified, add a title to the plot
:param comments: if True, print comments (if it makes sense for the given
format)
:param prettify_format: if None, use the default prettify format. Otherwise
specify a string with the prettifier to use.
"""
import os

if main_file_name is not None:
dat_filename = os.path.splitext(main_file_name)[0] + '_data.dat'
else:
dat_filename = 'band_data.dat'

if prettify_format is None:
# Default. Specified like this to allow caller functions to pass 'None'
prettify_format = 'gnuplot_seekpath'

plot_info = self._get_bandplot_data(
cartesian=True, prettify_format=prettify_format, join_symbol='|', y_origin=y_origin)

bands = plot_info['y']
x = plot_info['x']
labels = plot_info['labels']

num_labels = len(labels)
num_bands = bands.shape[1]

# axis limits
if y_max_lim is None:
y_max_lim = bands.max()
if y_min_lim is None:
y_min_lim = bands.min()
x_min_lim = min(x) # this isn't a numpy array, but a list
x_max_lim = max(x)

# first prepare the xy coordinates of the sets
raw_data, _ = self._prepare_dat_blocks(plot_info, comments=comments)

xtics_string = ', '.join('"{}" {}'.format(label, pos) for pos, label in plot_info['labels'])

script = []
# Start with some useful comments

if comments:
script.append(prepare_header_comment(self.uuid, plot_info=plot_info, comment_char='# '))
script.append('')

script.append(u"""## Uncomment the next two lines to write directly to PDF
## Note: You need to have gnuplot installed with pdfcairo support!
#set term pdfcairo
#set output 'out.pdf'
### Uncomment one of the options below to change font
### For the LaTeX fonts, you can download them from here:
### https://sourceforge.net/projects/cm-unicode/
### And then install them in your system
## LaTeX Serif font, if installed
#set termopt font "CMU Serif, 12"
## LaTeX Sans Serif font, if installed
#set termopt font "CMU Sans Serif, 12"
## Classical Times New Roman
#set termopt font "Times New Roman, 12"
""")

# Actual logic
script.append('set termopt enhanced') # Properly deals with e.g. subscripts
script.append('set encoding utf8') # To deal with Greek letters
script.append('set xtics ({})'.format(xtics_string))

script.append('unset key')


script.append('set yrange [{}:{}]'.format(y_min_lim, y_max_lim))

script.append('set ylabel "{}"'.format('Dispersion ({})'.format(self.units)))

if title:
script.append('set title "{}"'.format(title.replace('"', '\"')))

# Plot, escaping filename
if len(x) > 1:
script.append('set xrange [{}:{}]'.format(x_min_lim, x_max_lim))
script.append('set grid xtics lt 1 lc rgb "#888888"')
script.append('plot "{}" with l lc rgb "#000000"'.format(os.path.basename(dat_filename).replace('"', '\"')))
else:
script.append('set xrange [-1.0:1.0]')
script.append('plot "{}" using ($1-0.25):($2):(0.5):(0) with vectors nohead lc rgb "#000000"'.format(os.path.basename(dat_filename).replace('"', '\"')))

script_data = '\n'.join(script) + '\n'
extra_files = {dat_filename: raw_data}

return script_data.encode('utf-8'), extra_files

def _prepare_mpl_pdf(self, main_file_name='', *args, **kwargs):
"""
Prepare a python script using matplotlib to plot the bands, with the JSON
Expand All @@ -996,10 +881,7 @@ def _prepare_mpl_pdf(self, main_file_name='', *args, **kwargs):
# Use the Agg backend
s_header = matplotlib_header_agg_template.substitute()
s_import = matplotlib_import_data_inline_template.substitute(all_data_json=json.dumps(all_data, indent=2))
if len(all_data['paths']) == 1:
s_body = matplotlib_body_template.substitute(plot_code=single_kp)
else:
s_body = matplotlib_body_template.substitute(plot_code=multi_kp)
s_body = self._get_mpl_body_template(all_data['paths'])

# I get a temporary file name
handle, filename = tempfile.mkstemp()
Expand Down Expand Up @@ -1050,10 +932,7 @@ def _prepare_mpl_png(self, main_file_name='', *args, **kwargs):
# Use the Agg backend
s_header = matplotlib_header_agg_template.substitute()
s_import = matplotlib_import_data_inline_template.substitute(all_data_json=json.dumps(all_data, indent=2))
if len(all_data['paths']) == 1:
s_body = matplotlib_body_template.substitute(plot_code=single_kp)
else:
s_body = matplotlib_body_template.substitute(plot_code=multi_kp)
s_body = self._get_mpl_body_template(all_data['paths'])

# I get a temporary file name
handle, filename = tempfile.mkstemp()
Expand Down Expand Up @@ -1084,6 +963,17 @@ def _prepare_mpl_png(self, main_file_name='', *args, **kwargs):

return imgdata, {}

@staticmethod
def _get_mpl_body_template(paths):
"""
:param paths: paths of k-points
"""
if len(paths) == 1:
s_body = matplotlib_body_template.substitute(plot_code=single_kp)
else:
s_body = matplotlib_body_template.substitute(plot_code=multi_kp)
return s_body

def show_mpl(self, **kwargs):
"""
Call a show() command for the band structure using matplotlib.
Expand All @@ -1094,6 +984,113 @@ def show_mpl(self, **kwargs):
"""
exec(*self._exportcontent(fileformat='mpl_singlefile', main_file_name='', **kwargs)) # pylint: disable=exec-used

def _prepare_gnuplot(self,
main_file_name=None,
title='',
comments=True,
prettify_format=None,
y_max_lim=None,
y_min_lim=None,
y_origin=0.):
"""
Prepare an gnuplot script to plot the bands, with the .dat file
returned as an independent file.
:param main_file_name: if the user asks to write the main content on a
file, this contains the filename. This should be used to infer a
good filename for the additional files.
In this case, we remove the extension, and add '_data.dat'
:param title: if specified, add a title to the plot
:param comments: if True, print comments (if it makes sense for the given
format)
:param prettify_format: if None, use the default prettify format. Otherwise
specify a string with the prettifier to use.
"""
import os

main_file_name = main_file_name or 'band.dat'
dat_filename = os.path.splitext(main_file_name)[0] + '_data.dat'

if prettify_format is None:
# Default. Specified like this to allow caller functions to pass 'None'
prettify_format = 'gnuplot_seekpath'

plot_info = self._get_bandplot_data(
cartesian=True, prettify_format=prettify_format, join_symbol='|', y_origin=y_origin)

bands = plot_info['y']
x = plot_info['x']
labels = plot_info['labels']

num_labels = len(labels)
num_bands = bands.shape[1]

# axis limits
if y_max_lim is None:
y_max_lim = bands.max()
if y_min_lim is None:
y_min_lim = bands.min()
x_min_lim = min(x) # this isn't a numpy array, but a list
x_max_lim = max(x)

# first prepare the xy coordinates of the sets
raw_data, _ = self._prepare_dat_blocks(plot_info, comments=comments)

xtics_string = ', '.join('"{}" {}'.format(label, pos) for pos, label in plot_info['labels'])

script = []
# Start with some useful comments

if comments:
script.append(prepare_header_comment(self.uuid, plot_info=plot_info, comment_char='# '))
script.append('')

script.append(u"""## Uncomment the next two lines to write directly to PDF
## Note: You need to have gnuplot installed with pdfcairo support!
#set term pdfcairo
#set output 'out.pdf'
### Uncomment one of the options below to change font
### For the LaTeX fonts, you can download them from here:
### https://sourceforge.net/projects/cm-unicode/
### And then install them in your system
## LaTeX Serif font, if installed
#set termopt font "CMU Serif, 12"
## LaTeX Sans Serif font, if installed
#set termopt font "CMU Sans Serif, 12"
## Classical Times New Roman
#set termopt font "Times New Roman, 12"
""")

# Actual logic
script.append('set termopt enhanced') # Properly deals with e.g. subscripts
script.append('set encoding utf8') # To deal with Greek letters
script.append('set xtics ({})'.format(xtics_string))

script.append('unset key')


script.append('set yrange [{}:{}]'.format(y_min_lim, y_max_lim))

script.append('set ylabel "{}"'.format('Dispersion ({})'.format(self.units)))

if title:
script.append('set title "{}"'.format(title.replace('"', '\"')))

# Plot, escaping filename
if len(x) > 1:
script.append('set xrange [{}:{}]'.format(x_min_lim, x_max_lim))
script.append('set grid xtics lt 1 lc rgb "#888888"')
script.append('plot "{}" with l lc rgb "#000000"'.format(os.path.basename(dat_filename).replace('"', '\"')))
else:
script.append('set xrange [-1.0:1.0]')
script.append('plot "{}" using ($1-0.25):($2):(0.5):(0) with vectors nohead lc rgb "#000000"'.format(os.path.basename(dat_filename).replace('"', '\"')))

script_data = '\n'.join(script) + '\n'
extra_files = {dat_filename: raw_data}

return script_data.encode('utf-8'), extra_files

def _prepare_agr(self,
main_file_name='',
comments=True,
Expand Down
59 changes: 12 additions & 47 deletions tests/cmdline/commands/test_data.py
Original file line number Diff line number Diff line change
Expand Up @@ -344,51 +344,16 @@ def test_bandsexport(self):
self.assertIn(b'[1.0, 3.0]', res.stdout_bytes, 'The string [1.0, 3.0] was not found in the bands' 'export')

def test_bandsexport_single_kp(self):
""" test issue #2462 """

# Create bands structure object.
alat = 4. # angstrom
cell = [
[
alat,
0.,
0.,
],
[
0.,
alat,
0.,
],
[
0.,
0.,
alat,
],
]
strct = StructureData(cell=cell)
strct.append_atom(position=(0., 0., 0.), symbols='Fe')
strct.append_atom(position=(alat / 2., alat / 2., alat / 2.), symbols='O')
strct.store()

@calcfunction
def connect_structure_bands(strct): # pylint: disable=unused-argument
alat = 4.
cell = np.array([
[alat, 0., 0.],
[0., alat, 0.],
[0., 0., alat],
])

kpnts = KpointsData()
kpnts.set_cell(cell)
kpnts.set_kpoints([[0., 0., 0.]])

bands = BandsData()
bands.set_kpointsdata(kpnts)
bands.set_bands([[1.0, 2.0]])
return bands
"""
Plot band for single k-point (issue #2462).
"""
kpnts = KpointsData()
kpnts.set_kpoints([[0., 0., 0.]])

bands = connect_structure_bands(strct)
bands = BandsData()
bands.set_kpointsdata(kpnts)
bands.set_bands([[1.0, 2.0]])
bands.store()

# matplotlib
options = [str(bands.id), '--format', 'mpl_singlefile']
Expand All @@ -399,9 +364,9 @@ def connect_structure_bands(strct): # pylint: disable=unused-argument
with self.cli_runner.isolated_filesystem():
options = [str(bands.id), '--format', 'gnuplot', '-o', 'bands.gnu']
self.cli_runner.invoke(cmd_bands.bands_export, options, catch_exceptions=False)
with open('bands.gnu', 'r') as f:
res = f.read()
self.assertIn('vectors nohead', res, 'The string vectors nohead was not found in the gnuplot script')
with open('bands.gnu', 'r') as gnu_file:
res = gnu_file.read()
self.assertIn('vectors nohead', res, 'The string "vectors nohead" was not found in the gnuplot script')


class TestVerdiDataDict(AiidaTestCase):
Expand Down

0 comments on commit 68621fc

Please sign in to comment.