From f60e40f4dec7f343f204bdeaa281821829c1e21b Mon Sep 17 00:00:00 2001 From: pciturri Date: Wed, 11 Sep 2024 16:17:00 +0200 Subject: [PATCH] tests: improved coverage of unit tests. Added docstrings to engine functions. --- floatcsep/infrastructure/engine.py | 126 ++++++++++++++++++------- tests/unit/test_commands.py | 109 ++++++++++++++++++++++ tests/unit/test_engine.py | 86 +++++++++++++++++ tests/unit/test_plot_handler.py | 142 +++++++++++++++++++++++++++++ tests/unit/test_reporting.py | 64 +++++++++++++ 5 files changed, 493 insertions(+), 34 deletions(-) create mode 100644 tests/unit/test_commands.py create mode 100644 tests/unit/test_engine.py create mode 100644 tests/unit/test_plot_handler.py create mode 100644 tests/unit/test_reporting.py diff --git a/floatcsep/infrastructure/engine.py b/floatcsep/infrastructure/engine.py index e75ede9..6a99afb 100644 --- a/floatcsep/infrastructure/engine.py +++ b/floatcsep/infrastructure/engine.py @@ -2,38 +2,49 @@ class Task: + """ + Represents a unit of work to be executed later as part of a task graph. + + A Task wraps an object instance, a method, and its arguments to allow for deferred + execution. This is useful in workflows where tasks need to be executed in a specific order, + often dictated by dependencies on other tasks. + + For instance, can wrap a floatcsep.model.Model, its method 'create_forecast' and the + argument 'time_window', which can be executed later with Task.call() when, for example, + task dependencies (parent nodes) have been completed. + """ def __init__(self, instance, method, **kwargs): """ - Base node of the workload distribution. Wraps lazily objects, methods and their - arguments for them to be executed later. For instance, can wrap a floatcsep.Model, its - method 'create_forecast' and the argument 'time_window', which can be executed later - with Task.call() when, for example, task dependencies (parent nodes) have been completed. 
Args: - instance: can be floatcsep.Experiment, floatcsep.Model, floatcsep.Evaluation - method: the instance's method to be lazily created - **kwargs: keyword arguments passed to method. + instance: The object instance whose method will be executed later. + method (str): The method of the instance that will be called. + **kwargs: Arguments to pass to the method when it is invoked. + """ self.obj = instance self.method = method self.kwargs = kwargs - self.store = None # Bool for nested tasks. DEPRECATED + self.store = None # Bool for nested tasks. def sign_match(self, obj=None, met=None, kw_arg=None): """ - Checks if the Task matches a given signature for simplicity. + Checks whether the task matches a given function signature. + + This method is used to verify if a task belongs to a given object, method, or if it + uses a specific keyword argument. Useful for identifying tasks in a graph based on + partial matches of their attributes. - Purpose is to check from the outside if the Task is from a given object - (Model, Experiment, etc.), matching either name or object or description Args: - obj: Instance or instance's name str. Instance is preferred - met: Name of the method - kw_arg: Only the value (not key) of the kwargs dictionary + obj: The object instance or its name (str) to match against. + met: The method name to match against. + kw_arg: A specific keyword argument value to match against in the task's arguments. Returns: + bool: True if the task matches the provided signature, False otherwise. """ if self.obj == obj or obj == getattr(self.obj, "name", None): @@ -43,6 +54,13 @@ def sign_match(self, obj=None, met=None, kw_arg=None): return False def __str__(self): + """ + Returns a string representation of the task, including the instance name, method, and + arguments. Useful for debugging purposes. + + Returns: + str: A formatted string describing the task. 
+ """ task_str = f"{self.__class__}\n\t" f"Instance: {self.obj.__class__.__name__}\n" a = getattr(self.obj, "name", None) if a: @@ -54,6 +72,16 @@ def __str__(self): return task_str[:-2] def run(self): + """ + Executes the task by calling the method on the object instance with the stored + arguments. If the instance has a `store` attribute, it will use that instead of the + instance itself. Once executed, the result is stored in the `store` attribute if any + output is produced. + + Returns: + The output of the method execution, or None if the method does not return anything. + """ + if hasattr(self.obj, "store"): self.obj = self.obj.store output = getattr(self.obj, self.method)(**self.kwargs) @@ -65,6 +93,12 @@ def run(self): return output def __call__(self, *args, **kwargs): + """ + A callable alias for the `run` method. Allows the task to be invoked directly. + + Returns: + The result of the `run` method. + """ return self.run() def check_exist(self): @@ -73,21 +107,35 @@ def check_exist(self): class TaskGraph: """ - Context manager of floatcsep workload distribution. - - Assign tasks to a node and defines their dependencies (parent nodes). - Contains a 'tasks' dictionary whose dict_keys are the Task to be - executed with dict_values as the Task's dependencies. + Context manager of floatcsep workload distribution. A TaskGraph is responsible for adding + tasks, managing dependencies between tasks, and executing tasks in the correct order. + Tasks in the graph can depend on one another, and the graph ensures that each task is run + after all of its dependencies have been satisfied. Contains a 'tasks' dictionary whose + dict_keys are the Task to be executed with dict_values as the Task's dependencies. + + Attributes: + tasks (OrderedDict): A dictionary where the keys are Task objects and the values are + lists of the Task objects that each key depends on. + _ntasks (int): The current number of tasks in the graph. + name (str): A name identifier for the task graph.
""" def __init__(self): - + """ + Initializes the TaskGraph with an empty task dictionary and task count. + """ self.tasks = OrderedDict() self._ntasks = 0 - self.name = "floatcsep.utils.TaskGraph" + self.name = "floatcsep.infrastructure.engine.TaskGraph" @property def ntasks(self): + """ + Returns the number of tasks currently in the graph. + + Returns: + int: The total number of tasks in the graph. + """ return self._ntasks @ntasks.setter @@ -96,31 +144,32 @@ def ntasks(self, n): def add(self, task): """ - Simply adds a defined task to the graph. + Adds a new task to the task graph. - Args: - task: floatcsep.utils.Task + The task is added to the dictionary of tasks with no dependencies by default. - Returns: + Args: + task (Task): The task to be added to the graph. """ self.tasks[task] = [] self.ntasks += 1 def add_dependency(self, task, dinst=None, dmeth=None, dkw=None): """ - Adds a dependency to a task already inserted to the TaskGraph. + Adds a dependency to a task already in the graph. - Searchs - within the pre-added tasks a signature match by their name/instance, - method and keyword_args. + Searches for other tasks within the graph whose signature matches the provided + object instance, method name, or keyword argument. Any matches are added as + dependencies to the provided task. Args: - task: Task to which a dependency will be asigned - dinst: object/name of the dependency - dmeth: method of the dependency - dkw: keyword argument of the dependency + task (Task): The task to which dependencies will be added. + dinst: The object instance or name of the dependency. + dmeth: The method name of the dependency. + dkw: A specific keyword argument value of the dependency. Returns: + None """ deps = [] for i, other_tasks in enumerate(self.tasks.keys()): @@ -131,15 +180,24 @@ def add_dependency(self, task, dinst=None, dmeth=None, dkw=None): def run(self): """ - Iterates through all the graph tasks and runs them. 
+ Executes all tasks in the task graph, in the order in which they were added. + + Iterates over each task in the graph and runs it. The dependencies stored for a task + are not inspected here. Returns: + None """ for task, deps in self.tasks.items(): task.run() def __call__(self, *args, **kwargs): + """ + A callable alias for the `run` method. Allows the task graph to be invoked directly. + Returns: + None + """ return self.run() def check_exist(self): diff --git a/tests/unit/test_commands.py b/tests/unit/test_commands.py new file mode 100644 index 0000000..e4401e3 --- /dev/null +++ b/tests/unit/test_commands.py @@ -0,0 +1,109 @@ +import unittest +from unittest.mock import patch, MagicMock +import floatcsep.commands.main as main_module + + +class TestMainModule(unittest.TestCase): + + @patch('floatcsep.commands.main.Experiment') + @patch('floatcsep.commands.main.plot_catalogs') + @patch('floatcsep.commands.main.plot_forecasts') + @patch('floatcsep.commands.main.plot_results') + @patch('floatcsep.commands.main.plot_custom') + @patch('floatcsep.commands.main.generate_report') + def test_run(self, mock_generate_report, mock_plot_custom, mock_plot_results, + mock_plot_forecasts, mock_plot_catalogs, mock_experiment): + # Mock Experiment instance and its methods + mock_exp_instance = MagicMock() + mock_experiment.from_yml.return_value = mock_exp_instance + + # Call the function + main_module.run(config='dummy_config') + + # Verify the calls to the Experiment class methods + mock_experiment.from_yml.assert_called_once_with(config_yml='dummy_config') + mock_exp_instance.stage_models.assert_called_once() + mock_exp_instance.set_tasks.assert_called_once() + mock_exp_instance.run.assert_called_once() + + # Verify that plotting and report generation functions were called + mock_plot_catalogs.assert_called_once_with(experiment=mock_exp_instance) + mock_plot_forecasts.assert_called_once_with(experiment=mock_exp_instance) +
mock_plot_results.assert_called_once_with(experiment=mock_exp_instance) + mock_plot_custom.assert_called_once_with(experiment=mock_exp_instance) + mock_generate_report.assert_called_once_with(experiment=mock_exp_instance) + + @patch('floatcsep.commands.main.Experiment') + def test_stage(self, mock_experiment): + # Mock Experiment instance and its methods + mock_exp_instance = MagicMock() + mock_experiment.from_yml.return_value = mock_exp_instance + + # Call the function + main_module.stage(config='dummy_config') + + # Verify the calls to the Experiment class methods + mock_experiment.from_yml.assert_called_once_with(config_yml='dummy_config') + mock_exp_instance.stage_models.assert_called_once() + + @patch('floatcsep.commands.main.Experiment') + @patch('floatcsep.commands.main.plot_catalogs') + @patch('floatcsep.commands.main.plot_forecasts') + @patch('floatcsep.commands.main.plot_results') + @patch('floatcsep.commands.main.plot_custom') + @patch('floatcsep.commands.main.generate_report') + def test_plot(self, mock_generate_report, mock_plot_custom, mock_plot_results, + mock_plot_forecasts, mock_plot_catalogs, mock_experiment): + # Mock Experiment instance and its methods + mock_exp_instance = MagicMock() + mock_experiment.from_yml.return_value = mock_exp_instance + + # Call the function + main_module.plot(config='dummy_config') + + # Verify the calls to the Experiment class methods + mock_experiment.from_yml.assert_called_once_with(config_yml='dummy_config') + mock_exp_instance.stage_models.assert_called_once() + mock_exp_instance.set_tasks.assert_called_once() + + # Verify that plotting and report generation functions were called + mock_plot_catalogs.assert_called_once_with(experiment=mock_exp_instance) + mock_plot_forecasts.assert_called_once_with(experiment=mock_exp_instance) + mock_plot_results.assert_called_once_with(experiment=mock_exp_instance) + mock_plot_custom.assert_called_once_with(experiment=mock_exp_instance) + 
mock_generate_report.assert_called_once_with(experiment=mock_exp_instance) + + @patch('floatcsep.commands.main.Experiment') + @patch('floatcsep.commands.main.ExperimentComparison') + @patch('floatcsep.commands.main.reproducibility_report') + def test_reproduce(self, mock_reproducibility_report, mock_exp_comparison, mock_experiment): + # Mock Experiment instances and methods + mock_reproduced_exp = MagicMock() + mock_original_exp = MagicMock() + mock_experiment.from_yml.side_effect = [mock_reproduced_exp, mock_original_exp] + + mock_comp_instance = MagicMock() + mock_exp_comparison.return_value = mock_comp_instance + + # Call the function + main_module.reproduce(config='dummy_config') + + # Verify the calls to the Experiment class methods + mock_experiment.from_yml.assert_any_call('dummy_config', repr_dir="reproduced") + mock_reproduced_exp.stage_models.assert_called_once() + mock_reproduced_exp.set_tasks.assert_called_once() + mock_reproduced_exp.run.assert_called_once() + + mock_experiment.from_yml.assert_any_call(mock_reproduced_exp.original_config, + rundir=mock_reproduced_exp.original_run_dir) + mock_original_exp.stage_models.assert_called_once() + mock_original_exp.set_tasks.assert_called_once() + + # Verify comparison and reproducibility report calls + mock_exp_comparison.assert_called_once_with(mock_original_exp, mock_reproduced_exp) + mock_comp_instance.compare_results.assert_called_once() + mock_reproducibility_report.assert_called_once_with(exp_comparison=mock_comp_instance) + + +if __name__ == '__main__': + unittest.main() diff --git a/tests/unit/test_engine.py b/tests/unit/test_engine.py new file mode 100644 index 0000000..d8b03c5 --- /dev/null +++ b/tests/unit/test_engine.py @@ -0,0 +1,86 @@ +import unittest + +from floatcsep.infrastructure.engine import Task, TaskGraph + + +class DummyClass: + def __init__(self, name): + self.name = name + + def dummy_method(self, value): + return value * 2 + + +class TestTask(unittest.TestCase): + + def setUp(self): 
+ self.obj = DummyClass("TestObj") + self.task = Task(instance=self.obj, method="dummy_method", value=10) + + def test_init(self): + self.assertEqual(self.task.obj, self.obj) + self.assertEqual(self.task.method, "dummy_method") + self.assertEqual(self.task.kwargs["value"], 10) + + def test_sign_match(self): + self.assertTrue(self.task.sign_match(obj=self.obj, met="dummy_method", kw_arg=10)) + self.assertFalse(self.task.sign_match(obj="NonMatching", met="dummy_method", kw_arg=10)) + + def test___str__(self): + task_str = str(self.task) + self.assertIn("TestObj", task_str) + self.assertIn("dummy_method", task_str) + self.assertIn("value", task_str) + + def test_run(self): + result = self.task.run() + self.assertEqual(result, 20) + self.assertEqual(self.task.store, 20) + + def test___call__(self): + result = self.task() + self.assertEqual(result, 20) + + def test_check_exist(self): + self.assertIsNone(self.task.check_exist()) + + +class TestTaskGraph(unittest.TestCase): + + def setUp(self): + self.graph = TaskGraph() + self.obj = DummyClass("TestObj") + self.task_a = Task(instance=self.obj, method='dummy_method', value=10) + self.task_b = Task(instance=self.obj, method='dummy_method', value=20) + + def test_init(self): + self.assertEqual(self.graph.ntasks, 0) + self.assertEqual(self.graph.name, "floatcsep.infrastructure.engine.TaskGraph") + + def test_add(self): + self.graph.add(self.task_a) + self.assertIn(self.task_a, self.graph.tasks) + self.assertEqual(self.graph.ntasks, 1) + + def test_add_dependency(self): + self.graph.add(self.task_a) + self.graph.add(self.task_b) + self.graph.add_dependency(self.task_b, dinst=self.obj, dmeth='dummy_method', dkw=10) + self.assertIn(self.task_a, self.graph.tasks[self.task_b]) + + def test_run(self): + self.graph.add(self.task_a) + self.graph.run() + self.assertEqual(self.task_a.store, 20) + + def test___call__(self): + self.graph.add(self.task_a) + self.graph() + self.assertEqual(self.task_a.store, 20) + + def 
test_check_exist(self): + self.assertIsNone(self.graph.check_exist()) + + +if __name__ == '__main__': + unittest.main() diff --git a/tests/unit/test_plot_handler.py b/tests/unit/test_plot_handler.py new file mode 100644 index 0000000..d1fd852 --- /dev/null +++ b/tests/unit/test_plot_handler.py @@ -0,0 +1,142 @@ +import unittest +from unittest.mock import patch, MagicMock +import floatcsep.postprocess.plot_handler as plot_handler + + +class TestPlotHandler(unittest.TestCase): + + @patch("matplotlib.pyplot.savefig") + @patch("floatcsep.postprocess.plot_handler.timewindow2str") + def test_plot_results(self, mock_timewindow2str, mock_savefig): + mock_experiment = MagicMock() + mock_test = MagicMock() + mock_experiment.tests = [mock_test] + mock_timewindow2str.return_value = ["2021-01-01", "2021-12-31"] + + plot_handler.plot_results(mock_experiment) + + mock_timewindow2str.assert_called_once_with(mock_experiment.timewindows) + mock_test.plot_results.assert_called_once_with( + ["2021-01-01", "2021-12-31"], mock_experiment.models, mock_experiment.registry + ) + + @patch("matplotlib.pyplot.savefig") + @patch("floatcsep.postprocess.plot_handler.parse_plot_config") + @patch("floatcsep.postprocess.plot_handler.parse_projection") + def test_plot_forecasts(self, mock_parse_projection, mock_parse_plot_config, mock_savefig): + mock_experiment = MagicMock() + mock_model = MagicMock() + mock_experiment.models = [mock_model] + mock_parse_plot_config.return_value = {"projection": "Mercator"} + mock_parse_projection.return_value = MagicMock() + mock_experiment.postprocess.get.return_value = True + + plot_handler.plot_forecasts(mock_experiment) + + mock_parse_plot_config.assert_called_once_with( + mock_experiment.postprocess.get("plot_forecasts", {}) + ) + mock_model.get_forecast().plot.assert_called() + + # Verify that pyplot.savefig was called to save the plot + mock_savefig.assert_called() + + @patch("matplotlib.pyplot.Figure.savefig") # Mocking savefig on the Figure object + 
@patch("floatcsep.postprocess.plot_handler.parse_plot_config") + @patch("floatcsep.postprocess.plot_handler.parse_projection") + def test_plot_catalogs( + self, mock_parse_projection, mock_parse_plot_config, mock_savefig + ): + # Mock the experiment and its components + mock_experiment = MagicMock() + mock_catalog = MagicMock() + mock_plot = MagicMock() + mock_ax = MagicMock() + mock_figure = MagicMock() + + mock_experiment.catalog_repo.get_test_cat = MagicMock(return_value=mock_catalog) + mock_catalog.plot = mock_plot + mock_plot.return_value = mock_ax + mock_ax.get_figure.return_value = ( + mock_figure + ) + + mock_parse_plot_config.return_value = {"projection": "Mercator"} + mock_parse_projection.return_value = MagicMock() + mock_experiment.registry.get_figure.return_value = "cat.png" + + plot_handler.plot_catalogs(mock_experiment) + + mock_parse_plot_config.assert_called_once_with( + mock_experiment.postprocess.get("plot_catalog", {}) + ) + + mock_plot.assert_called_once_with( + plot_args=mock_parse_plot_config.return_value, + ) + + mock_figure.savefig.assert_called_once_with( + "cat.png", dpi=300 + ) + + mock_savefig.assert_called() + + @patch("os.path.isfile", return_value=True) + @patch("os.path.realpath", return_value="dir") + @patch("os.path.dirname", return_value="dir") + @patch("importlib.util.spec_from_file_location") + @patch("importlib.util.module_from_spec") + def test_plot_custom(self, mock_module_from_spec, mock_spec_from_file_location, + mock_dirname, mock_realpath, mock_isfile): + mock_experiment = MagicMock() + mock_spec = MagicMock() + mock_module = MagicMock() + mock_func = MagicMock() + + mock_spec_from_file_location.return_value = mock_spec + mock_module_from_spec.return_value = mock_module + mock_module.plot_function = mock_func + mock_experiment.postprocess.get.return_value = "custom_script.py:plot_function" + mock_experiment.registry.abs.return_value = "custom_script.py" + + plot_handler.plot_custom(mock_experiment) + + 
mock_spec_from_file_location.assert_called_once_with("custom_script", "custom_script.py") + mock_module_from_spec.assert_called_once_with(mock_spec) + mock_func.assert_called_once_with(mock_experiment) + + def test_parse_plot_config(self): + # Test True case + result = plot_handler.parse_plot_config(True) + self.assertEqual(result, {}) + + # Test False case + result = plot_handler.parse_plot_config(False) + self.assertIsNone(result) + + # Test dict case + mock_dict = {"key": "value"} + result = plot_handler.parse_plot_config(mock_dict) + self.assertEqual(result, mock_dict) + + # Test string case with valid script and function + result = plot_handler.parse_plot_config("script.py:plot_func") + self.assertEqual(result, ("script.py", "plot_func")) + + def test_parse_projection(self): + # Test None case + result = plot_handler.parse_projection(None) + self.assertEqual(result.__class__.__name__, "PlateCarree") + + # Test dict case with valid projection + mock_config = {"Mercator": {"central_longitude": 0.0}} + result = plot_handler.parse_projection(mock_config) + self.assertEqual(result.__class__.__name__, "Mercator") + + # Test invalid projection case + result = plot_handler.parse_projection("InvalidProjection") + self.assertEqual(result.__class__.__name__, "PlateCarree") + + +if __name__ == "__main__": + unittest.main() diff --git a/tests/unit/test_reporting.py b/tests/unit/test_reporting.py new file mode 100644 index 0000000..709f745 --- /dev/null +++ b/tests/unit/test_reporting.py @@ -0,0 +1,64 @@ +import unittest +from unittest.mock import patch, MagicMock +import floatcsep.postprocess.reporting as reporting + + +class TestReporting(unittest.TestCase): + + @patch("floatcsep.postprocess.reporting.custom_report") + @patch("floatcsep.postprocess.reporting.MarkdownReport") + def test_generate_report_with_custom_function( + self, mock_markdown_report, mock_custom_report + ): + # Mock experiment with a custom report function + mock_experiment = MagicMock() + 
mock_experiment.postprocess.get.return_value = "custom_report_function" + + # Call the generate_report function + reporting.generate_report(mock_experiment) + + # Assert that custom_report was called with the experiment + mock_custom_report.assert_called_once_with("custom_report_function", mock_experiment) + + @patch("floatcsep.postprocess.reporting.MarkdownReport") + def test_generate_standard_report(self, mock_markdown_report): + # Mock experiment without a custom report function + mock_experiment = MagicMock() + mock_experiment.postprocess.get.return_value = None + mock_experiment.registry.get_figure.return_value = "figure_path" + mock_experiment.magnitudes = [0, 1] + # Call the generate_report function + reporting.generate_report(mock_experiment) + + # Ensure the MarkdownReport methods are called + mock_instance = mock_markdown_report.return_value + mock_instance.add_title.assert_called_once() + mock_instance.add_heading.assert_called() + mock_instance.add_figure.assert_called() + + +class TestMarkdownReport(unittest.TestCase): + + def test_add_title(self): + report = reporting.MarkdownReport() + report.add_title("Test Title", "Subtitle") + self.assertIn("# Test Title", report.markdown[0]) + + def test_add_table_of_contents(self): + report = reporting.MarkdownReport() + report.toc = [("Title", 1, "locator")] + report.table_of_contents() + self.assertIn("# Table of Contents", report.markdown[0]) + + def test_save_report(self): + report = reporting.MarkdownReport() + report.markdown = [["# Test Title\n", "Some content\n"]] + with patch("builtins.open", unittest.mock.mock_open()) as mock_file: + report.save("/path/to/save") + mock_file.assert_called_with("/path/to/save/report.md", "w") + mock_file().writelines.assert_called_with(["# Test Title\n", "Some content\n"]) + + + +if __name__ == "__main__": + unittest.main()