diff --git a/swiftpipeline/config.py b/swiftpipeline/config.py index 0253953..51ae211 100644 --- a/swiftpipeline/config.py +++ b/swiftpipeline/config.py @@ -3,7 +3,7 @@ """ import yaml -from typing import List +from typing import List, Union # Items to read directly from the yaml file with their defaults direct_read = { @@ -24,16 +24,19 @@ class Script(object): # Filename of the python script to be ran filename: str - # Caption for displaying in a webpage, etc. - caption: str - # Output file name; required to match up the caption and file in - # postprocessing - output_file: str + # Output file name(s) + # List of strings if the script produces several plots; otherwise, string + # If list, required to have the same size as caption and title + output_file: Union[str, List[str]] + # Caption(s) for displaying in a webpage, etc. + # List of strings if the script produces several plots (one caption per plot); otherwise, string + caption: Union[str, List[str]] # Section heading; used to classify this with similar figures # in the output. section: str - # Plot title; written above the caption - title: str + # Plot title(s); written above the caption + # List of strings if the script makes several plots (one title per plot); otherwise, string + title: Union[str, List[str]] # Show on webpage; Defaults to True but used to disable webpage plotting # in the config file if required. show_on_webpage: bool diff --git a/swiftpipeline/html.py b/swiftpipeline/html.py index c8c2895..27bf88a 100644 --- a/swiftpipeline/html.py +++ b/swiftpipeline/html.py @@ -12,7 +12,7 @@ from jinja2 import Environment, PackageLoader, FileSystemLoader, select_autoescape from time import strftime -from typing import List +from typing import List, Dict from pathlib import Path import unyt @@ -230,16 +230,58 @@ def add_config_metadata(self, config: Config, is_comparison: bool = False): scripts_to_use = config.comparison_scripts if is_comparison else config.scripts for section in sections: - plots = [ - dict( - filename=script.output_file, - title=script.title, - caption=script.caption, - hash=abs(hash(script.caption + script.title)), - ) - for script in scripts_to_use - if script.section == section and script.show_on_webpage - ] + + plots: List[Dict] = [] + + for script in scripts_to_use: + if script.section == section and script.show_on_webpage: + + # Check whether we expect more than one plot (output file) produced by the script + if isinstance(script.output_file, list): + + # If so, check that each plot has its own title and caption + assert isinstance(script.title, list) and isinstance( + script.caption, list + ), ( + f"Check the config parameters for '{script.filename}'. " + f"If 'output_file' is a list object, then 'title' and 'caption' must be too!" + ) + + # Check that the number of plots is the same as the number of their titles and captions + assert ( + len(script.output_file) + == len(script.title) + == len(script.caption) + ), ( + f"Check the config parameters for '{script.filename}'. " + f"Lists 'output_file', 'title' and 'caption' must have the same size!" + ) + + # Loop over plots made by the script + for output_file, title, caption in zip( + script.output_file, script.title, script.caption + ): + + # Save everything into a dict + plot = dict( + filename=output_file, + title=title, + caption=caption, + hash=abs(hash(caption + output_file)), + ) + + # Add collect in a list + plots.append(plot) + + # The script makes just a single plot + else: + plot = dict( + filename=script.output_file, + title=script.title, + caption=script.caption, + hash=abs(hash(script.caption + script.output_file)), + ) + plots.append(plot) current_section_plots = ( self.variables["sections"].get(section, {"plots": []}).get("plots", [])