diff --git a/pisa/core/detectors.py b/pisa/core/detectors.py index 6819bdb00..9c3bade82 100644 --- a/pisa/core/detectors.py +++ b/pisa/core/detectors.py @@ -13,7 +13,6 @@ from argparse import ArgumentParser, ArgumentDefaultsHelpFormatter from collections import OrderedDict import inspect -from itertools import product import os from tabulate import tabulate from copy import deepcopy @@ -21,7 +20,6 @@ import numpy as np from pisa import ureg -from pisa.core.map import MapSet from pisa.core.pipeline import Pipeline from pisa.core.distribution_maker import DistributionMaker from pisa.core.param import ParamSet, Param @@ -66,7 +64,7 @@ def __init__(self, pipelines, label=None, set_livetime_from_data=True, profile=F self._distribution_makers , self.det_names = [] , [] for pipeline in pipelines: if not isinstance(pipeline, Pipeline): - pipeline = Pipeline(pipeline) + pipeline = Pipeline(pipeline, profile=profile) name = pipeline.detector_name if name in self.det_names: @@ -111,7 +109,17 @@ def tabulate(self, tablefmt="plain"): def __iter__(self): return iter(self._distribution_makers) - + + def report_profile(self, detailed=False, format_num_kwargs=None): + """Report timing information on contained distribution makers. + See `Pipeline.report_profile` for details. + """ + for distribution_maker in self._distribution_makers: + print(distribution_maker.detector_name + ':') + distribution_maker.report_profile( + detailed=detailed, format_num_kwargs=format_num_kwargs + ) + @property def profile(self): return self._profile diff --git a/pisa/core/distribution_maker.py b/pisa/core/distribution_maker.py index 69aa1529f..24ffce248 100755 --- a/pisa/core/distribution_maker.py +++ b/pisa/core/distribution_maker.py @@ -11,7 +11,6 @@ from collections import OrderedDict from collections.abc import Mapping import inspect -from itertools import product import os from tabulate import tabulate @@ -104,6 +103,11 @@ def __init__(self, pipelines, label=None, set_livetime_from_data=True, profile=F for pipeline in pipelines: if not isinstance(pipeline, Pipeline): pipeline = Pipeline(pipeline, profile=profile) + else: + if profile: + # Only propagate if set to `True` (don't allow negative + # default to negate any original choice for the instance) + pipeline.profile = profile self._pipelines.append(pipeline) data_run_livetime = None @@ -215,6 +219,15 @@ def tabulate(self, tablefmt="plain"): def __iter__(self): return iter(self._pipelines) + def report_profile(self, detailed=False, format_num_kwargs=None): + """Report timing information on contained pipelines. + See `Pipeline.report_profile` for details. + """ + for pipeline in self.pipelines: + pipeline.report_profile( + detailed=detailed, format_num_kwargs=format_num_kwargs + ) + @property def profile(self): return self._profile @@ -225,8 +238,8 @@ def profile(self, value): pipeline.profile = value self._profile = value - def run(self): + """Run all pipelines""" for pipeline in self: pipeline.run() @@ -541,6 +554,22 @@ def test_DistributionMaker(): #current_hier = new_hier #current_mat = new_mat + # test profile flag + p_cfg = 'settings/pipeline/example.cfg' + p = Pipeline(p_cfg, profile=True) + dm = DistributionMaker(pipelines=p) + # default init using Pipeline instance shouldn't negate + assert dm.pipelines[0].profile + # but explicit request should + dm.profile = False + assert not dm.pipelines[0].profile + # now init from cfg path and request profile + dm = DistributionMaker(pipelines=p_cfg, profile=True) + assert dm.pipelines[0].profile + # explicitly request no profile + dm = DistributionMaker(pipelines=p_cfg, profile=False) + assert not dm.pipelines[0].profile + def parse_args(): """Get command line arguments""" @@ -626,4 +655,4 @@ def main(return_outputs=False): if __name__ == '__main__': - distribution_maker, outputs = main(return_outputs=True) # pylint: disable=invalid-name + distribution_maker, outputs = main(return_outputs=True) diff --git a/pisa/core/pipeline.py b/pisa/core/pipeline.py index 311e21ee9..9e0c952c6 100755 --- a/pisa/core/pipeline.py +++ b/pisa/core/pipeline.py @@ -10,6 +10,7 @@ from argparse import ArgumentParser from collections import OrderedDict +from collections.abc import Mapping from configparser import NoSectionError from copy import deepcopy from importlib import import_module @@ -17,6 +18,7 @@ from inspect import getsource import os from tabulate import tabulate +from time import time import traceback import numpy as np @@ -30,6 +32,7 @@ from pisa.core.binning import MultiDimBinning from pisa.utils.config_parser import PISAConfigParser, parse_pipeline_config from pisa.utils.fileio import mkdir +from pisa.utils.format import format_times from pisa.utils.hash import hash_obj from pisa.utils.log import logging, set_verbosity from pisa.utils.profiler import profile @@ -107,6 +110,9 @@ def __init__(self, config, profile=False): self.output_key = config['pipeline']['output_key'] self._profile = profile + self._setup_times = [] + self._run_times = [] + self._get_outputs_times = [] self._stages = [] self._config = config @@ -140,9 +146,46 @@ def tabulate(self, tablefmt="plain"): table[-1] += [len(s.params.fixed), len(s.params.free)] return tabulate(table, headers, tablefmt=tablefmt, colalign=colalign) - def report_profile(self, detailed=False): + def report_profile(self, detailed=False, format_num_kwargs=None): + """Report timing information on pipeline and contained services + + Parameters + ---------- + detailed : bool, default False + Whether to increase level of detail + format_num_kwargs : dict, optional + Dictionary containing arguments passed to `utils.format.format_num`. + Will display each number with three decimal digits by default. + + """ + if not self.profile: + # Report warning only at the pipeline level, which is what the + # typical user should come across. Assume that users calling + # `report_profile` on a `Stage` instance directly know what they're + # doing. + logging.warn( + '`profile` is set to False. Results may not show the expected ' + 'numbers of function calls.' + ) + if format_num_kwargs is None: + format_num_kwargs = { + 'precision': 1e-3, 'fmt': 'full', 'trailing_zeros': True + } + assert isinstance(format_num_kwargs, Mapping) + print(f'Pipeline: {self.name}') + for func_str, times in [ + ('- setup: ', self._setup_times), + ('- run: ', self._run_times), + ('- get_outputs: ', self._get_outputs_times) + ]: + print(func_str, + format_times(times=times, + nindent_detailed=len(func_str) + 1, + detailed=detailed, **format_num_kwargs) + ) + print('Individual services:') for stage in self.stages: - stage.report_profile(detailed=detailed) + stage.report_profile(detailed=detailed, **format_num_kwargs) @property def profile(self): @@ -315,7 +358,19 @@ def _init_stages(self): self.setup() - def get_outputs(self, output_binning=None, output_key=None): + def get_outputs(self, **get_outputs_kwargs): + """Wrapper around `_get_outputs`. The latter might + have quite some overhead compared to `run` alone""" + if self.profile: + start_t = time() + outputs = self._get_outputs(**get_outputs_kwargs) + end_t = time() + self._get_outputs_times.append(end_t - start_t) + else: + outputs = self._get_outputs(**get_outputs_kwargs) + return outputs + + def _get_outputs(self, output_binning=None, output_key=None): """Get MapSet output""" @@ -393,12 +448,32 @@ def _add_rotated(self, paramset:ParamSet, suppress_warning=False): return success def run(self): + """Wrapper around `_run_function`""" + if self.profile: + start_t = time() + self._run_function() + end_t = time() + self._run_times.append(end_t - start_t) + else: + self._run_function() + + def _run_function(self): """Run the pipeline to compute""" for stage in self.stages: logging.debug(f"Working on stage {stage.stage_name}.{stage.service_name}") stage.run() def setup(self): + """Wrapper around `_setup_function`""" + if self.profile: + start_t = time() + self._setup_function() + end_t = time() + self._setup_times.append(end_t - start_t) + else: + self._setup_function() + + def _setup_function(self): """Setup (reset) all stages""" self.data = ContainerSet(self.name) for stage in self.stages: diff --git a/pisa/core/stage.py b/pisa/core/stage.py index 6d865f7f5..3574bfae0 100644 --- a/pisa/core/stage.py +++ b/pisa/core/stage.py @@ -12,6 +12,7 @@ import numpy as np from pisa.core.container import ContainerSet +from pisa.utils.format import format_times from pisa.utils.log import logging from pisa.core.param import ParamSelector from pisa.utils.format import arg_str_seq_none @@ -138,23 +139,20 @@ def __init__( def __repr__(self): return 'Stage "%s"'%(self.__class__.__name__) - def report_profile(self, detailed=False): - def format(times): - tot = np.sum(times) - n = len(times) - ave = 0. if n == 0 else tot/n - return 'Total time %.5f s, n calls: %i, time/call: %.5f s'%(tot, n, ave) - + def report_profile(self, detailed=False, **format_num_kwargs): + """Report timing information on calls to setup, compute, and apply + """ print(self.stage_name, self.service_name) - print('- setup: ', format(self.setup_times)) - if detailed: - print(' Individual runs: ', ', '.join(['%i: %.3f s' % (i, t) for i, t in enumerate(self.setup_times)])) - print('- calc: ', format(self.calc_times)) - if detailed: - print(' Individual runs: ', ', '.join(['%i: %.3f s' % (i, t) for i, t in enumerate(self.calc_times)])) - print('- apply: ', format(self.apply_times)) - if detailed: - print(' Individual runs: ', ', '.join(['%i: %.3f s' % (i, t) for i, t in enumerate(self.apply_times)])) + for func_str, times in [ + ('- setup: ', self.setup_times), + ('- compute: ', self.calc_times), + ('- apply: ', self.apply_times) + ]: + print(func_str, + format_times(times=times, + nindent_detailed=len(func_str) + 1, + detailed=detailed, **format_num_kwargs) + ) def select_params(self, selections, error_on_missing=False): """Apply the `selections` to contained ParamSet. diff --git a/pisa/stages/utils/kde.py b/pisa/stages/utils/kde.py index 8cf9e9351..02bd5c076 100644 --- a/pisa/stages/utils/kde.py +++ b/pisa/stages/utils/kde.py @@ -5,6 +5,7 @@ from copy import deepcopy import numpy as np +from time import time from pisa.core.stage import Stage from pisa.core.binning import MultiDimBinning, OneDimBinning @@ -132,11 +133,21 @@ def setup_function(self): @profile def apply(self): - # this is special, we want the actual event weights in the kde - # therefor we're overwritting the apply function - # normally in a stage you would implement the `apply_function` method - # and not the `apply` method! + """This is special, we want the actual event weights in the kde + therefor we're overwritting the apply function + normally in a stage you would implement the `apply_function` method + and not the `apply` method! We also have to reimplement the profiling + functionality in apply of the Base class""" + + if self.profile: + start_t = time() + self.apply_function() + end_t = time() + self.apply_times.append(end_t - start_t) + else: + self.apply_function() + def apply_function(self): for container in self.data: if self.stash_valid: diff --git a/pisa/utils/format.py b/pisa/utils/format.py index 1c48b4e55..48ac3f517 100644 --- a/pisa/utils/format.py +++ b/pisa/utils/format.py @@ -247,18 +247,22 @@ def split(string, sep=',', force_case=None, parse_func=None): def arg_str_seq_none(inputs, name): """Simple input handler. + Parameters ---------- inputs : None, string, or iterable of strings Input value(s) provided by caller name : string Name of input, used for producing a meaningful error message + Returns ------- inputs : None, or list of strings + Raises ------ TypeError if unrecognized type + """ if isinstance(inputs, str): inputs = [inputs] @@ -1267,6 +1271,52 @@ def format_num( return left_delimiter + num_str + right_delimiter +def format_times(times, nindent_detailed=0, detailed=False, **format_num_kwargs): + """Report statistics derived from a sample of run times, + whose size may represent the number of calls to some function, + using a custom number format. + + Parameters + ---------- + times : Sequence of float + Sequence of run times + nindent_detailed : int, optional + Number of spaces for indentation of detailed info + detailed : bool, default False + Whether to output every individual run time also + **format_num_kwargs : dict, optional + Arguments to `format_num`: refer to its documentation for + the list of all possible arguments. + + Returns + ------- + formatted : str + + """ + assert isinstance(times, Sequence) + tot = np.sum(times) + n = len(times) + if n == 0: + return 'n calls: 0' + ave = format_num(tot/n, **format_num_kwargs) + tot = format_num(tot, **format_num_kwargs) + max_time = format_num(np.max(times), **format_num_kwargs) + min_time = format_num(np.min(times), **format_num_kwargs) + formatted = f'Total time (s): {tot}, n calls: {n}' + if n > 1: + formatted += ( + f', time/call (s): mean {ave}, max. {max_time}, min. {min_time}' + ) + if detailed: + assert isinstance(nindent_detailed, int) + formatted += '\n' + ' ' * nindent_detailed + 'Individual runs: ' + for i, t in enumerate(times): + formatted += '%i: %s s, ' % ( + i, format_num(t, **format_num_kwargs) + ) + return formatted + + def test_format_num(): """Unit tests for the `format_num` function""" # sci_thresh