From 4f7d88cb62449c38da0cb96c9a8106f8614cf552 Mon Sep 17 00:00:00 2001 From: gampnico <45390064+gampnico@users.noreply.github.com> Date: Tue, 16 May 2023 12:21:03 +0200 Subject: [PATCH 01/17] docs: modify rtd configuration Refs: ST-2, ST-122 --- readthedocs.yml => .readthedocs.yaml | 11 ++++++++++- docs/source/conf.py | 2 +- pyproject.toml | 2 +- 3 files changed, 12 insertions(+), 3 deletions(-) rename readthedocs.yml => .readthedocs.yaml (60%) diff --git a/readthedocs.yml b/.readthedocs.yaml similarity index 60% rename from readthedocs.yml rename to .readthedocs.yaml index 5ec698c..9c3a5ef 100644 --- a/readthedocs.yml +++ b/.readthedocs.yaml @@ -3,13 +3,22 @@ version: 2 build: os: "ubuntu-22.04" tools: - python: "3.11" + python: "3.8" # Build from the docs/ directory with Sphinx sphinx: + builder: html configuration: docs/source/conf.py # Explicitly set the version of Python and its requirements python: install: - requirements: dev-requirements.txt + - method: pip + path: src/scintillometry + extra_requirements: + - dev + +submodules: + include: all + recursive: true diff --git a/docs/source/conf.py b/docs/source/conf.py index d36b443..2b35451 100644 --- a/docs/source/conf.py +++ b/docs/source/conf.py @@ -49,7 +49,7 @@ def setup(app): project = "Scintillometry" copyright = f"2019-{date.today().year}, Scintillometry Contributors" author = "Scintillometry Contributors" -release = "1.0.0" +release = "1.0.2" # -- General configuration --------------------------------------------------- # https://www.sphinx-doc.org/en/master/usage/configuration.html#general-configuration diff --git a/pyproject.toml b/pyproject.toml index fd435ec..76a4ea8 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -4,7 +4,7 @@ build-backend = "hatchling.build" [project] name = "scintillometry" -version = "1.0.0" +version = "1.0.2" authors = [ { name="Scintillometry Contributors", email="" }, ] From c32a030df520a1c82d4154870ad4c18bde9dfd70 Mon Sep 17 00:00:00 2001 From: gampnico <45390064+gampnico@users.noreply.github.com> Date: Tue, 16 May 2023 12:25:29 +0200 Subject: [PATCH 02/17] docs: fix rtd configuration Corrects package installation path. Refs: ST-2, ST-122 --- .readthedocs.yaml | 3 +-- docs/source/conf.py | 2 +- pyproject.toml | 2 +- 3 files changed, 3 insertions(+), 4 deletions(-) diff --git a/.readthedocs.yaml b/.readthedocs.yaml index 9c3a5ef..59ef9f6 100644 --- a/.readthedocs.yaml +++ b/.readthedocs.yaml @@ -15,10 +15,9 @@ python: install: - requirements: dev-requirements.txt - method: pip - path: src/scintillometry + path: . extra_requirements: - dev - submodules: include: all recursive: true diff --git a/docs/source/conf.py b/docs/source/conf.py index 2b35451..ec6105d 100644 --- a/docs/source/conf.py +++ b/docs/source/conf.py @@ -49,7 +49,7 @@ def setup(app): project = "Scintillometry" copyright = f"2019-{date.today().year}, Scintillometry Contributors" author = "Scintillometry Contributors" -release = "1.0.2" +release = "1.0.3" # -- General configuration --------------------------------------------------- # https://www.sphinx-doc.org/en/master/usage/configuration.html#general-configuration diff --git a/pyproject.toml b/pyproject.toml index 76a4ea8..16219f1 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -4,7 +4,7 @@ build-backend = "hatchling.build" [project] name = "scintillometry" -version = "1.0.2" +version = "1.0.3" authors = [ { name="Scintillometry Contributors", email="" }, ] From 5bc2beba488b68409a0af52533e1dabe99dd7763 Mon Sep 17 00:00:00 2001 From: gampnico <45390064+gampnico@users.noreply.github.com> Date: Tue, 16 May 2023 14:14:19 +0200 Subject: [PATCH 03/17] fix: main interface references incorrect arguments Refs: ST-26, ST-123 --- docs/source/conf.py | 7 ++++--- pyproject.toml | 2 +- src/scintillometry/main.py | 4 ++-- src/scintillometry/wrangler/data_parser.py | 2 +- 4 files changed, 8 insertions(+), 7 deletions(-) diff --git a/docs/source/conf.py b/docs/source/conf.py index ec6105d..7fabb0b 100644 --- a/docs/source/conf.py +++ b/docs/source/conf.py @@ -1,4 +1,4 @@ -"""Copyright 2023 Scintillometry-Tools Contributors. +"""Copyright 2023 Scintillometry Contributors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -16,7 +16,8 @@ Configuration file for the Sphinx documentation builder. -For the full list of built-in configuration values, see the documentation: +For the full list of built-in configuration values, see the +documentation: https://www.sphinx-doc.org/en/master/usage/configuration.html """ @@ -49,7 +50,7 @@ def setup(app): project = "Scintillometry" copyright = f"2019-{date.today().year}, Scintillometry Contributors" author = "Scintillometry Contributors" -release = "1.0.3" +release = "1.0.4" # -- General configuration --------------------------------------------------- # https://www.sphinx-doc.org/en/master/usage/configuration.html#general-configuration diff --git a/pyproject.toml b/pyproject.toml index 16219f1..b55a2c0 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -4,7 +4,7 @@ build-backend = "hatchling.build" [project] name = "scintillometry" -version = "1.0.3" +version = "1.0.4" authors = [ { name="Scintillometry Contributors", email="" }, ] diff --git a/src/scintillometry/main.py b/src/scintillometry/main.py index b2cf757..2ae1058 100644 --- a/src/scintillometry/main.py +++ b/src/scintillometry/main.py @@ -301,13 +301,13 @@ def perform_data_parsing(**kwargs): data_parser = DataParser.WranglerParsing() # Parse BLS, weather, and topographical data - datasets = data_parser.stitch.wrangle_data( + datasets = data_parser.wrangle_data( bls_path=kwargs["input"], transect_path=kwargs["transect_path"], calibrate=kwargs["calibration"], station_id=kwargs["station_id"], tzone=kwargs["timezone"], - source="zamg", + weather_source="zamg", ) # Parse vertical measurements diff --git a/src/scintillometry/wrangler/data_parser.py b/src/scintillometry/wrangler/data_parser.py index c729c64..6dccdf3 100644 --- a/src/scintillometry/wrangler/data_parser.py +++ b/src/scintillometry/wrangler/data_parser.py @@ -1009,7 +1009,7 @@ def wrangle_data( weather_data = self.weather.parse_weather( timestamp=bls_time, source=weather_source, - klima_id=station_id, + station_id=station_id, data_dir=weather_dir, timezone=tzone, ) From 083a52e0c97c3d71b1643489a12ada8cf99becc6 Mon Sep 17 00:00:00 2001 From: gampnico <45390064+gampnico@users.noreply.github.com> Date: Tue, 23 May 2023 17:10:50 +0200 Subject: [PATCH 04/17] feat(metrics): perform regression on datasets Adds method to MetricsFlux to perform regression on labelled data. Refs: ST-7, ST-106 --- docs/source/conf.py | 3 +- pyproject.toml | 2 +- src/scintillometry/metrics/calculations.py | 38 +++++++++++++++++++- tests/test_metrics_calculations.py | 40 ++++++++++++++++++++++ 4 files changed, 80 insertions(+), 3 deletions(-) diff --git a/docs/source/conf.py b/docs/source/conf.py index 7fabb0b..d142dce 100644 --- a/docs/source/conf.py +++ b/docs/source/conf.py @@ -50,7 +50,7 @@ def setup(app): project = "Scintillometry" copyright = f"2019-{date.today().year}, Scintillometry Contributors" author = "Scintillometry Contributors" -release = "1.0.4" +release = "1.0.5" # -- General configuration --------------------------------------------------- # https://www.sphinx-doc.org/en/master/usage/configuration.html#general-configuration @@ -124,6 +124,7 @@ def setup(app): .. |Psi_m| replace:: :math:`\\Psi_{{m}}` .. |Q_0| replace:: :math:`Q_{{0}}` .. |r| replace:: :math:`r` +.. |R^2| replace:: math:`R^{{2}}` .. |R_dry| replace:: :math:`R_{{dry}}` .. |R_v| replace:: :math:`R_{{v}}` .. |rho| replace:: :math:`\\rho` diff --git a/pyproject.toml b/pyproject.toml index b55a2c0..17aba01 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -4,7 +4,7 @@ build-backend = "hatchling.build" [project] name = "scintillometry" -version = "1.0.4" +version = "1.0.5" authors = [ { name="Scintillometry Contributors", email="" }, ] diff --git a/src/scintillometry/metrics/calculations.py b/src/scintillometry/metrics/calculations.py index 4ff9ea3..34009fd 100755 --- a/src/scintillometry/metrics/calculations.py +++ b/src/scintillometry/metrics/calculations.py @@ -19,11 +19,12 @@ import kneed import pandas as pd +from sklearn.linear_model import LinearRegression -from scintillometry.backend.iterations import IterationMost from scintillometry.backend.constants import AtmosConstants from scintillometry.backend.constructions import ProfileConstructor from scintillometry.backend.derivations import DeriveScintillometer +from scintillometry.backend.iterations import IterationMost from scintillometry.backend.transects import TransectParameters from scintillometry.visuals.plotting import FigurePlotter @@ -216,6 +217,41 @@ def match_time_at_threshold(self, series, threshold, lessthan=True, min_time=Non return match_time + def get_regression(self, x_data, y_data, intercept=True): + """Performs regression on labelled data. + + Args: + x_data (pd.Series): Labelled explanatory data. + y_data (pd.Series): Labelled response data. + intercept (bool): If True, calculate intercept (e.g. data is + not centred). Default True. + + Returns: + dict: Contains the fitted estimator for regression data, the + coefficient of determination |R^2|, and predicted values for + a fitted regression line. + """ + + scatter_frame = pd.merge( + x_data, y_data, left_index=True, right_index=True, sort=True + ) + scatter_frame = scatter_frame.dropna(axis=0) + x_fit_data = scatter_frame.iloc[:, 0].values.reshape(-1, 1) + y_fit_data = scatter_frame.iloc[:, 1].values.reshape(-1, 1) + + linear_regressor = LinearRegression(fit_intercept=intercept) + estimator = linear_regressor.fit(x_fit_data, y_fit_data) + score = estimator.score(x_fit_data, y_fit_data) + predictions = linear_regressor.predict(x_fit_data) + + regression = { + "fit": estimator, + "score": score, + "regression_line": predictions, + } + + return regression + def get_elbow_point(self, series, min_index=None, max_index=None): """Calculate elbow point using Kneedle algorithm. diff --git a/tests/test_metrics_calculations.py b/tests/test_metrics_calculations.py index 1a0559d..23e10aa 100644 --- a/tests/test_metrics_calculations.py +++ b/tests/test_metrics_calculations.py @@ -54,6 +54,7 @@ def test_foobar(self, foo_mock, bar_mock): import pandas as pd import pandas.api.types as ptypes import pytest +import sklearn import scintillometry.backend.constants import scintillometry.backend.constructions @@ -300,6 +301,45 @@ def test_match_time_at_threshold( else: assert compare_time is None + @pytest.mark.dependency(name="TestMetricsFlux::test_get_regression") + @pytest.mark.parametrize("arg_intercept", [True, False]) + @pytest.mark.parametrize("arg_mismatch_index", [True, False]) + def test_get_regression( + self, conftest_boilerplate, arg_intercept, arg_mismatch_index + ): + """Perform regression on labelled data.""" + + rng = np.random.default_rng() + test_index = np.arange(0, 1000, 10) + test_data = rng.random(size=len(test_index)) + + test_x = pd.Series(name="obukhov", data=test_data, index=test_index) + test_y = pd.Series(name="other_obukhov", data=test_data + 0.5, index=test_index) + if arg_mismatch_index: + test_y = test_y[:-5] + conftest_boilerplate.index_not_equal(test_x.index, test_y.index) + + assert isinstance(test_x, pd.Series) + assert test_x.shape == (100,) + test_keys = ["fit", "score", "regression_line"] + + compare_regression = self.test_metrics.get_regression( + x_data=test_x, y_data=test_y, intercept=arg_intercept + ) + assert isinstance(compare_regression, dict) + assert all(key in compare_regression for key in test_keys) + assert isinstance( + compare_regression["fit"], sklearn.linear_model.LinearRegression + ) + if not arg_intercept: + assert not compare_regression["fit"].fit_intercept + else: + assert compare_regression["fit"].fit_intercept + assert isinstance(compare_regression["score"], float) + assert isinstance(compare_regression["regression_line"], np.ndarray) + assert not (np.isnan(compare_regression["regression_line"])).any() + assert len(test_y.index) == len(compare_regression["regression_line"]) + @pytest.mark.dependency(name="TestMetricsFlux::test_get_elbow_point") @pytest.mark.parametrize("arg_min_index", [None, 0, 50]) @pytest.mark.parametrize("arg_max_index", [None, 180, 190]) From 743b8a15b9b2304571c811a0b2ab200109685582 Mon Sep 17 00:00:00 2001 From: gampnico <45390064+gampnico@users.noreply.github.com> Date: Tue, 23 May 2023 17:16:06 +0200 Subject: [PATCH 05/17] tests(metrics): generate datetime index for series Refs: ST-3, ST-7, ST-106 --- tests/test_metrics_calculations.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tests/test_metrics_calculations.py b/tests/test_metrics_calculations.py index 23e10aa..1d6ef40 100644 --- a/tests/test_metrics_calculations.py +++ b/tests/test_metrics_calculations.py @@ -310,7 +310,7 @@ def test_get_regression( """Perform regression on labelled data.""" rng = np.random.default_rng() - test_index = np.arange(0, 1000, 10) + test_index = pd.date_range(start=self.test_timestamp, periods=100, freq="T") test_data = rng.random(size=len(test_index)) test_x = pd.Series(name="obukhov", data=test_data, index=test_index) @@ -318,7 +318,7 @@ def test_get_regression( if arg_mismatch_index: test_y = test_y[:-5] conftest_boilerplate.index_not_equal(test_x.index, test_y.index) - + assert test_y.shape == (95,) assert isinstance(test_x, pd.Series) assert test_x.shape == (100,) test_keys = ["fit", "score", "regression_line"] From 6bd371bcb3ac345b63a5bae1b3c8e18b1d9e1619 Mon Sep 17 00:00:00 2001 From: gampnico <45390064+gampnico@users.noreply.github.com> Date: Tue, 23 May 2023 18:21:27 +0200 Subject: [PATCH 06/17] feat(visuals): plot scatter with optional regression line Creates scatter plot with optional regression line and optional annotation for the coefficient of determination. Refs: ST-3, ST-8, ST-106 --- src/scintillometry/visuals/plotting.py | 72 ++++++++++++++++++++++++-- tests/test_visuals_plotting.py | 62 ++++++++++++++++++++++ 2 files changed, 130 insertions(+), 4 deletions(-) diff --git a/src/scintillometry/visuals/plotting.py b/src/scintillometry/visuals/plotting.py index 9b26d9e..bc7e417 100644 --- a/src/scintillometry/visuals/plotting.py +++ b/src/scintillometry/visuals/plotting.py @@ -24,6 +24,7 @@ import matplotlib import matplotlib.pyplot as plt import numpy as np +import pandas as pd class FigureFormat: @@ -91,8 +92,8 @@ def get_site_name(self, site_name, dataframe=None): Args: site_name (str): Location of data collection. - dataframe (pd.DataFrame): Any collected dataset. Default - None. + dataframe (pd.DataFrame or pd.Series): Any collected + dataset. Default None. Returns: str: Location of data collection. Returns empty string if no @@ -193,7 +194,8 @@ def get_date_and_timezone(self, data): """Return first time index and timezone. Args: - data (pd.DataFrame): TZ-aware dataframe with DatetimeIndex. + data (pd.DataFrame or pd.Series): TZ-aware dataframe with + DatetimeIndex. Returns: dict[str, datetime.tzinfo]: Date formatted as @@ -227,7 +229,7 @@ def title_plot(self, title, timestamp, location=""): if not isinstance(timestamp, str): timestamp = timestamp.strftime("%d %B %Y") - title_string = "".join((f"{title}{location},\n{timestamp}")) + title_string = f"{title}{location},\n{timestamp}" plt.title(title_string, fontweight="bold") plt.legend(loc="upper left") @@ -654,6 +656,68 @@ def plot_innflux(self, iter_data, innflux_data, name="obukhov", site=""): return figure, axes + def plot_scatter( + self, x_data, y_data, name, sources, score=None, regression_line=None, site="" + ): + """Plots scatter between two datasets with a regression line. + + Args: + x_data (pd.Series): Labelled explanatory data. + y_data (pd.Series): Labelled response data. + name (str): Name of variable. + sources (list[str, str]): Names of data sources formatted + as: [, ]. + score (float): Coefficient of determination |R^2|. + Default None. + regression_line (np.ndarray): Values for regression line. + Default None. + site (str): Location of data collection. Default empty + string. + + Returns: + tuple[plt.Figure, plt.Axes]: Regression plot of explanatory + and response data, with fitted regression line and + regression score. + """ + + figure = plt.figure(figsize=(8, 8)) + date = self.get_date_and_timezone(data=x_data)["date"] + scatter_frame = pd.merge( + x_data, y_data, left_index=True, right_index=True, sort=True + ) + + scatter_frame = scatter_frame.dropna(axis=0) # drop mismatched index + x_fit_data = scatter_frame.iloc[:, 0].values.reshape(-1, 1) + y_fit_data = scatter_frame.iloc[:, 1].values.reshape(-1, 1) + plt.scatter(x_fit_data, y_fit_data, marker=".", color="gray") + + if regression_line is not None: + plt.plot( + x_fit_data, regression_line, color="black", label="Line of Best Fit" + ) + + axes = plt.gca() + if score is not None: + plt.text( + 0.05, + 0.9, + f"R$^{2}$= {score:.5f}", + horizontalalignment="left", + verticalalignment="bottom", + transform=axes.transAxes, + ) + variable_name = self.label_selector(name) + sources_string = f"{sources[0].title()} and {sources[1].title()}" + title_string = f"{variable_name[0]} Regression Between\n{sources_string}" + site_label = self.get_site_name(site_name=site, dataframe=x_data) + self.title_plot(title=title_string, timestamp=date, location=site_label) + x_label = self.merge_label_with_unit(label=variable_name) + y_label = self.merge_label_with_unit(label=variable_name) + plt.xlabel(f"{sources[0].title()} {x_label}") + plt.ylabel(f"{sources[1].title()} {y_label}") + + return figure, axes + def plot_vertical_profile( self, vertical_data, time_idx, name, site="", y_lim=None, **kwargs ): diff --git a/tests/test_visuals_plotting.py b/tests/test_visuals_plotting.py index 6616ad4..d1039c5 100644 --- a/tests/test_visuals_plotting.py +++ b/tests/test_visuals_plotting.py @@ -683,6 +683,68 @@ def test_plot_innflux( plt.close("all") + @pytest.mark.dependency(name="TestVisualsPlotting::test_plot_scatter") + @pytest.mark.parametrize("arg_site", ["", "Test Location", None]) + @pytest.mark.parametrize("arg_score", [None, 0.561734521]) + @pytest.mark.parametrize("arg_regression", [True, False]) + def test_plot_scatter( + self, conftest_boilerplate, arg_score, arg_site, arg_regression + ): + """Plot scatter between two datasets with regression line.""" + + test_name = "obukhov" + rng = np.random.default_rng() + test_index = pd.date_range(start=self.test_timestamp, periods=100, freq="T") + test_data = rng.random(size=len(test_index)) + test_x = pd.Series(name=test_name, data=test_data, index=test_index) + test_y = pd.Series(name=test_name, data=test_data + 0.5, index=test_index) + for series in [test_x, test_y]: + assert isinstance(series, pd.Series) + assert series.shape == (100,) + assert not (series.isnull()).any() + if arg_site: + test_site = f" at {arg_site}," + else: + test_site = "," + test_title = ( + "Obukhov Length Regression Between", + f"Baseline and Comparison{test_site}", + f"{self.test_date}", + ) + if not arg_regression: + test_line = None + else: + test_line = np.arange(0, 1000, 10) + + compare_fig, compare_ax = self.test_plotting.plot_scatter( + x_data=test_x, + y_data=test_y, + name=test_name, + sources=["Baseline", "Comparison"], + site=arg_site, + score=arg_score, + regression_line=test_line, + ) + compare_params = { + "plot": (compare_fig, compare_ax), + "x_label": "Baseline Obukhov Length, [m]", + "y_label": "Comparison Obukhov Length, [m]", + "title": "\n".join(test_title), + } + conftest_boilerplate.check_plot(plot_params=compare_params) + + if arg_regression: + _, labels = compare_params["plot"][1].get_legend_handles_labels() + assert "Line of Best Fit" in labels + + if arg_score is not None: + assert ( + compare_params["plot"][1].texts[0].get_text() + == f"R$^{2}$= {arg_score:.5f}" + ) + + plt.close("all") + @pytest.mark.dependency( name="TestVisualsPlotting::test_plot_vertical_profile", depends=["TestVisualsFormatting::test_plot_constant_lines"], From c5b47467a5464b695407b8defffec38d66f527f8 Mon Sep 17 00:00:00 2001 From: gampnico <45390064+gampnico@users.noreply.github.com> Date: Wed, 24 May 2023 15:55:56 +0200 Subject: [PATCH 07/17] refactor!: plots from metrics calculations are paired into lists Matplotlib plots generated by methods in metrics_calculations are now structured as list[tuple[plt.Figure, plt.Axes]] instead of stacked into a single tuple. Refs: ST-3, ST-7, ST-8, ST-124, ST-126 --- src/scintillometry/metrics/calculations.py | 79 ++++++++-------- src/scintillometry/visuals/plotting.py | 44 +++++---- tests/conftest.py | 17 +++- tests/test_metrics_calculations.py | 105 ++++++++++++--------- tests/test_visuals_plotting.py | 20 ++-- 5 files changed, 150 insertions(+), 115 deletions(-) diff --git a/src/scintillometry/metrics/calculations.py b/src/scintillometry/metrics/calculations.py index 34009fd..3225a92 100755 --- a/src/scintillometry/metrics/calculations.py +++ b/src/scintillometry/metrics/calculations.py @@ -397,11 +397,11 @@ def plot_lapse_rates( string. Returns: - tuple[plt.Figure, plt.Axes, plt.Figure, plt.Axes]: Vertical - profiles of lapse rates on a single axis, and vertical - profiles of parcel temperatures on a single axis. If a - boundary layer height is provided, vertical lines denoting - its height are added to the figures. + list[tuple[plt.Figure, plt.Axes]]: Vertical profiles of + lapse rates on a single axis, and vertical profiles of + parcel temperatures on a single axis. If a boundary layer + height is provided, vertical lines denoting its height are + added to the figures. """ lapse_rates = { @@ -448,7 +448,7 @@ def plot_lapse_rates( suffix=f"{round_time.strftime('%H%M')}_parcel_temperatures", ) - return fig_lapse, axes_lapse, fig_parcel, axes_parcel + return [(fig_lapse, axes_lapse), (fig_parcel, axes_parcel)] def get_switch_time_vertical(self, data, method="static", ri_crit=0.25): """Gets local time of switch between stability conditions. @@ -617,10 +617,10 @@ def plot_switch_time_stability(self, data, local_time, location="", bl_height=No bl_height (int): Boundary layer height. Default None. Returns: - tuple[plt.Figure, plt.Axes]: Vertical profile of potential - temperature. If the gradient potential temperature is also - provided, the two vertical profiles are placed side-by-side - in separate subplots. + list[tuple[plt.Figure, plt.Axes]]: Vertical profile of + potential temperature. If the gradient potential temperature + is also provided, the two vertical profiles are placed + side-by-side in separate subplots. """ round_time = self.get_nearest_time_index( @@ -670,7 +670,7 @@ def plot_switch_time_stability(self, data, local_time, location="", bl_height=No suffix=f"{mil_time}_gradient_potential_temperature_2km", ) - return fig, ax + return [(fig, ax)] def calculate_switch_time( self, datasets, method="sun", switch_time=None, location="" @@ -739,9 +739,9 @@ def iterate_fluxes( Trades speed from vectorisation for more accurate convergence. Args: - z_parameters (dict[float, float]): Tuples of effective and - mean path height |z_eff| and |z_mean| [m], with - stability conditions as keys. + z_parameters (dict[str, tuple[float, float]): Tuples of + effective and mean path height |z_eff| and |z_mean| [m], + with stability conditions as keys. datasets (dict): Contains parsed, tz-aware dataframes, with at least |CT2|, wind speed, air density, and temperature. @@ -809,8 +809,9 @@ def plot_derived_metrics(self, derived_data, time_id, regime=None, location=""): string. Returns: - tuple[plt.Figure, plt.Axes]: Time series comparing sensible - heat fluxes under free convection to on-board software. + list[tuple[plt.Figure, plt.Axes]]: Time series comparing + sensible heat fluxes under free convection to on-board + software. """ fig_convection, ax_convection = self.plotting.plot_convection( @@ -819,25 +820,24 @@ def plot_derived_metrics(self, derived_data, time_id, regime=None, location=""): self.plotting.save_figure( figure=fig_convection, timestamp=time_id, suffix="free_convection" ) + derived_plots = [(fig_convection, ax_convection)] - return fig_convection, ax_convection + return derived_plots def plot_iterated_metrics(self, iterated_data, time_stamp, site_location=""): """Plot and save time series and comparison of iterated fluxes. Args: - user_args (argparse.Namespace): Namespace of user arguments. - derived_data (pd.DataFrame): Interpolated tz-aware dataframe - with columns for sensible heat fluxes calculated with - MOST and for free convection. - time_id (pd.Timestamp): Start time of scintillometer data + iterated_data (pd.DataFrame): TZ-aware dataframe with + columns for heat fluxes and MOST parameters. + time_stamp (pd.Timestamp): Start time of scintillometer data collection. site_location (str): Name of scintillometer location. Returns: - tuple[plt.Figure, plt.Axes, plt.Figure, plt.Axes]: Time - series of sensible heat flux calculated through MOST, and a - comparison to sensible heat flux under free convection. + list[tuple[plt.Figure, plt.Axes]]: Time series of sensible + heat flux calculated through MOST, and a comparison to + sensible heat flux under free convection. """ plots = self.plotting.plot_iterated_fluxes( @@ -872,7 +872,7 @@ def calculate_standard_metrics( - Calculates effective path heights for all stability conditions. - Derives |CT2| and sensible heat flux for free convection. - - Estimates the time where stability conditions change. + - Estimates the time when stability conditions change. - Calculates sensible heat flux using MOST. - Plots time series comparing sensible heat flux for free convection |H_free| to on-board software, time series of @@ -965,7 +965,7 @@ def calculate_standard_metrics( return data def compare_innflux(self, own_data, innflux_data, location=""): - """Compares SHF and Obukhov lengths to InnFLUX measurements. + """Compares SHF and Obukhov lengths to innFLUX measurements. This wrapper function: @@ -977,7 +977,6 @@ def compare_innflux(self, own_data, innflux_data, location=""): with an argparse.Namespace object. Args: - arguments (argparse.Namespace): User arguments. own_data (pd.DataFrame): Labelled data for SHF and Obukhov length. innflux_data (pd.DataFrame): Eddy covariance data from @@ -986,9 +985,9 @@ def compare_innflux(self, own_data, innflux_data, location=""): string. Returns: - tuple[plt.Figure, plt.Axes, plt.Figure, plt.Axes]: Time - series comparing Obukhov length and sensible heat flux to - innFlux measurements. + list[tuple[plt.Figure, plt.Axes]]: Time series comparing + Obukhov length and sensible heat flux to innFLUX + measurements. """ data_timestamp = own_data.index[0] @@ -1011,7 +1010,9 @@ def compare_innflux(self, own_data, innflux_data, location=""): figure=fig_shf, timestamp=data_timestamp, suffix="innflux_shf" ) - return fig_obukhov, ax_obukhov, fig_shf, ax_shf + plots = [(fig_obukhov, ax_obukhov), (fig_shf, ax_shf)] + + return plots def compare_eddy(self, own_data, ext_data, source="innflux", location=""): """Compares data to an external source of eddy covariance data. @@ -1024,17 +1025,19 @@ def compare_eddy(self, own_data, ext_data, source="innflux", location=""): with an argparse.Namespace object. Args: - arguments (argparse.Namespace): User arguments. own_data (pd.DataFrame): Labelled data. ext_data (pd.DataFrame): Eddy covariance data from an external source. + source (str): Data source of vertical measurements. + Currently supports processed innFLUX data. + Default "innflux". location (str): Location of data collection. Default empty string. Returns: - tuple[plt.Figure, plt.Axes, plt.Figure, plt.Axes]: Time - series comparing Obukhov length and sensible heat flux to - innFlux measurements. + list[tuple[plt.Figure, plt.Axes]]: Time series comparing + Obukhov length and sensible heat flux to innFLUX + measurements. Raises: NotImplementedError: measurements are not @@ -1043,7 +1046,7 @@ def compare_eddy(self, own_data, ext_data, source="innflux", location=""): """ if source.lower() == "innflux": - fig_obukhov, ax_obukhov, fig_shf, ax_shf = self.compare_innflux( + eddy_plots = self.compare_innflux( own_data=own_data, innflux_data=ext_data, location=location ) else: @@ -1052,4 +1055,4 @@ def compare_eddy(self, own_data, ext_data, source="innflux", location=""): ) raise NotImplementedError(error_msg) - return fig_obukhov, ax_obukhov, fig_shf, ax_shf + return eddy_plots diff --git a/src/scintillometry/visuals/plotting.py b/src/scintillometry/visuals/plotting.py index bc7e417..e051adc 100644 --- a/src/scintillometry/visuals/plotting.py +++ b/src/scintillometry/visuals/plotting.py @@ -594,7 +594,8 @@ def plot_iterated_fluxes(self, iteration_data, time_id, location=""): string. Returns: - tuple[plt.Figure, plt.Figure]: Time series and comparison. + list[tuple[plt.Figure, plt.Axes]]: Time series and + comparison. """ fig_shf, ax_shf = self.plot_generic(iteration_data, "shf", site=location) @@ -611,7 +612,7 @@ def plot_iterated_fluxes(self, iteration_data, time_id, location=""): fig_comp = plt.gcf() self.save_figure(figure=fig_comp, timestamp=time_id, suffix="shf_comp") - return fig_shf, ax_shf, fig_comp, ax_comp + return [(fig_shf, ax_shf), (fig_comp, ax_comp)] def plot_innflux(self, iter_data, innflux_data, name="obukhov", site=""): """Plots comparison between scintillometer and InnFLUX data. @@ -724,13 +725,14 @@ def plot_vertical_profile( """Plots vertical profile of variable. Args: - vertical_data (dict[pd.DataFrame]): Contains time series of - vertical profiles. + vertical_data (dict[str, pd.DataFrame]): Contains time + series of vertical measurements. time_idx (pd.Timestamp): The local time for which to plot a vertical profile. name (str): Name of dependent variable, must be key in . - site (str): Location of data collection. Default empty string. + site (str): Location of data collection. Default empty + string. y_lim (float): Y-axis limit. Default None. Keyword Args: @@ -782,8 +784,8 @@ def plot_vertical_profile( location = ",\n" else: location = f"\nat {site_label}, " - time_idx = time_idx.strftime("%H:%M") - time_label = f"{time_data['date']} {time_idx} {time_data['tzone']}" + time_string = time_idx.strftime("%H:%M") + time_label = f"{time_data['date']} {time_string} {time_data['tzone']}" title_string = f"{title}{location}{time_label}" plt.title(title_string, fontweight="bold") axes = plt.gca() @@ -813,16 +815,13 @@ def plot_vertical_comparison(self, dataset, time_index, keys, site="", **kwargs) profiles. """ - key_number = len(keys) + key_length = len(keys) figure, axes = plt.subplots( - nrows=1, ncols=key_number, sharey=False, figsize=(4 * key_number, 8) + nrows=1, ncols=key_length, sharey=False, figsize=(4 * key_length, 8) ) subplot_labels = [] - for i in range(key_number): + for i in range(key_length): vertical_profile = dataset[keys[i]].loc[[time_index]] - time_data = self.get_date_and_timezone( - data=dataset[keys[i]].loc[[time_index]] - ) axes[i].plot( vertical_profile.values[0], vertical_profile.columns, @@ -837,7 +836,6 @@ def plot_vertical_comparison(self, dataset, time_index, keys, site="", **kwargs) if kwargs: self.parse_formatting_kwargs(axis=axes[i], **kwargs) - axes[0].set_ylabel("Height [m]") title_name = self.merge_multiple_labels(labels=subplot_labels) @@ -847,8 +845,9 @@ def plot_vertical_comparison(self, dataset, time_index, keys, site="", **kwargs) location = ",\n" else: location = f"\nat {site_label}, " - time_index = time_index.strftime("%H:%M") - time_label = f"{time_data['date']} {time_index} {time_data['tzone']}" + time_data = self.get_date_and_timezone(data=dataset[keys[-1]].loc[[time_index]]) + time_string = time_index.strftime("%H:%M") + time_label = f"{time_data['date']} {time_string} {time_data['tzone']}" title_string = f"{title}{location}{time_label}" figure.suptitle(title_string, fontweight="bold") @@ -879,15 +878,12 @@ def plot_merged_profiles(self, dataset, time_index, site="", y_lim=None, **kwarg """ keys = list(dataset) - key_number = len(keys) + key_length = len(keys) figure = plt.figure(figsize=(8, 8)) subplot_labels = [] xlims = [] - for i in range(key_number): + for i in range(key_length): vertical_profile = dataset[keys[i]].loc[[time_index]] - time_data = self.get_date_and_timezone( - data=dataset[keys[i]].loc[[time_index]] - ) line_label = self.label_selector(dependent=keys[i]) plt.plot( vertical_profile.values[0], @@ -903,6 +899,7 @@ def plot_merged_profiles(self, dataset, time_index, site="", y_lim=None, **kwarg if xlim_max > 1: xlim_min = math.floor(min(vertical_profile[heights].values[0])) xlims.append(xlim_min) + line_label = self.label_selector(dependent=keys[-1]) x_label = self.merge_label_with_unit(label=line_label) plt.xlabel(x_label) plt.ylabel("Height [m]") @@ -920,8 +917,9 @@ def plot_merged_profiles(self, dataset, time_index, site="", y_lim=None, **kwarg location = ",\n" else: location = f"\nat {site_label}, " - time_index = time_index.strftime("%H:%M") - time_label = f"{time_data['date']} {time_index} {time_data['tzone']}" + time_data = self.get_date_and_timezone(data=dataset[keys[-1]].loc[[time_index]]) + time_string = time_index.strftime("%H:%M") + time_label = f"{time_data['date']} {time_string} {time_data['tzone']}" title_name = self.merge_multiple_labels(labels=subplot_labels) title = f"Vertical Profiles of {title_name}" title_string = f"{title}{location}{time_label}" diff --git a/tests/conftest.py b/tests/conftest.py index 1c4708c..b7467ce 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -227,7 +227,7 @@ def test_check_timezone(self): self.check_timezone(dataframe=test_frame, tzone=arg_timezone) def index_not_equal(self, index_01: pd.Index, index_02: pd.Index): - """Check that two indices are not equal.""" + """Check if two indices are not equal.""" with pytest.raises(AssertionError, match="Index are different"): assert pd.testing.assert_index_equal( @@ -338,6 +338,21 @@ def conftest_mock_save_figure(): mock_exists.return_value = None +# Mock random data +@pytest.fixture(name="conftest_generate_series", scope="function", autouse=False) +def fixture_conftest_generate_series() -> tuple[np.ndarray, pd.Timestamp]: + """Generates Series with random data and DatetimeIndex""" + + rng = np.random.default_rng() + test_timestamp = pd.Timestamp("03 June 2020 05:20", tz="CET") + test_index = pd.date_range(start=test_timestamp, periods=100, freq="T") + assert ptypes.is_datetime64_any_dtype(test_index) + test_data = rng.random(size=len(test_index)) + assert isinstance(test_data, np.ndarray) + + yield test_data, test_index + + # Mock scintillometer data @pytest.fixture(name="conftest_mnd_lines", scope="function", autouse=False) def fixture_conftest_mock_mnd_lines(): diff --git a/tests/test_metrics_calculations.py b/tests/test_metrics_calculations.py index 1d6ef40..8b37236 100644 --- a/tests/test_metrics_calculations.py +++ b/tests/test_metrics_calculations.py @@ -332,9 +332,9 @@ def test_get_regression( compare_regression["fit"], sklearn.linear_model.LinearRegression ) if not arg_intercept: - assert not compare_regression["fit"].fit_intercept - else: assert compare_regression["fit"].fit_intercept + else: + assert not compare_regression["fit"].fit_intercept assert isinstance(compare_regression["score"], float) assert isinstance(compare_regression["regression_line"], np.ndarray) assert not (np.isnan(compare_regression["regression_line"])).any() @@ -653,19 +653,12 @@ def test_plot_switch_time_stability( ) assert isinstance(test_time, pd.Timestamp) test_vertical = test_dataset["vertical"] - if arg_location: + if arg_location is not None: test_location = f"\nat {arg_location}, " else: test_location = ",\n" test_labels = ["Potential Temperature, [K]"] - if not arg_gradient: - test_vertical.pop("grad_potential_temperature", None) - assert "grad_potential_temperature" not in test_vertical - test_title = ( - f"Vertical Profile of Potential Temperature{test_location}", - f"{self.test_date} 05:20 CET", - ) - else: + if arg_gradient: assert "grad_potential_temperature" in test_vertical test_title = ( "Vertical Profiles of Potential Temperature ", @@ -673,23 +666,33 @@ def test_plot_switch_time_stability( f"{self.test_date} 05:20 CET", ) test_labels.append(r"Gradient of Potential Temperature, [K$\cdot$m$^{-1}$]") + else: + test_vertical.pop("grad_potential_temperature", None) + assert "grad_potential_temperature" not in test_vertical + test_title = ( + f"Vertical Profile of Potential Temperature{test_location}", + f"{self.test_date} 05:20 CET", + ) - compare_fig, compare_ax = self.test_metrics.plot_switch_time_stability( + compare_plots = self.test_metrics.plot_switch_time_stability( data=test_vertical, local_time=test_time, location=arg_location ) - assert isinstance(compare_fig, plt.Figure) - - if not arg_gradient: - assert isinstance(compare_ax, plt.Axes) - assert compare_ax.get_title() == "".join(test_title) - assert compare_ax.yaxis.get_label_text() == "Height [m]" - else: - assert isinstance(compare_ax, np.ndarray) - assert all(isinstance(ax, plt.Axes) for ax in compare_ax) - assert compare_fig.texts[0].get_text() == "".join(test_title) - assert compare_ax[0].yaxis.get_label_text() == "Height [m]" - for i in range(len(compare_ax)): - assert compare_ax[i].xaxis.get_label_text() == test_labels[i] + assert isinstance(compare_plots, list) + for compare_tuple in compare_plots: + assert isinstance(compare_tuple, tuple) + assert isinstance(compare_tuple[0], plt.Figure) + compare_ax = compare_tuple[1] + if arg_gradient: + assert isinstance(compare_ax, np.ndarray) + assert all(isinstance(ax, plt.Axes) for ax in compare_ax) + assert compare_tuple[0].texts[0].get_text() == "".join(test_title) + assert compare_ax[0].yaxis.get_label_text() == "Height [m]" + for i in range(len(compare_ax)): + assert compare_ax[i].xaxis.get_label_text() == test_labels[i] + else: + assert isinstance(compare_ax, plt.Axes) + assert compare_ax.get_title() == "".join(test_title) + assert compare_ax.yaxis.get_label_text() == "Height [m]" plt.close("all") @@ -717,32 +720,34 @@ def test_plot_lapse_rates( test_dataset = self.test_metrics.append_vertical_variables(data=test_dataset) for key in ["grad_potential_temperature", "environmental_lapse_rate"]: assert key in test_dataset["vertical"] - if arg_location: + if arg_location is not None: test_location = f"\nat {arg_location}, " else: test_location = ",\n" test_title = f"{test_location}{self.test_date} 05:10 CET" - fig_lapse, ax_lapse, fig_parcel, ax_parcel = self.test_metrics.plot_lapse_rates( + compare_plots = self.test_metrics.plot_lapse_rates( vertical_data=test_dataset["vertical"], dry_adiabat=self.test_metrics.constants.dalr, local_time=self.test_timestamp, location=arg_location, bl_height=arg_height, ) + assert isinstance(compare_plots, list) + assert all(isinstance(compare_tuple, tuple) for compare_tuple in compare_plots) compare_params = { "lapse": { "title": "Temperature Lapse Rates", "x_label": r"Lapse Rate, [Km$^{-1}$]", "y_label": "Height [m]", - "plot": (fig_lapse, ax_lapse), + "plot": (compare_plots[0]), }, "parcel": { "title": "Vertical Profiles of Parcel Temperature", "x_label": "Temperature, [K]", "y_label": "Height [m]", - "plot": (fig_parcel, ax_parcel), + "plot": (compare_plots[1]), }, } @@ -848,7 +853,7 @@ def test_plot_derived_metrics( _ = conftest_mock_save_figure test_frame = conftest_mock_derived_dataframe - if arg_regime: + if arg_regime is not None: test_conditions = f"{arg_regime.capitalize()} Conditions" else: test_conditions = "No Height Dependency" @@ -857,14 +862,17 @@ def test_plot_derived_metrics( f"for Free Convection ({test_conditions}),\n{self.test_date}", ) - compare_fig, compare_ax = self.test_metrics.plot_derived_metrics( + compare_plots = self.test_metrics.plot_derived_metrics( derived_data=test_frame, time_id=test_frame.index[0], regime=arg_regime, location="", ) + assert isinstance(compare_plots, list) + assert all(isinstance(compare_tuple, tuple) for compare_tuple in compare_plots) + compare_params = { - "plot": (compare_fig, compare_ax), + "plot": (compare_plots[0]), "x_label": "Time, CET", "y_label": r"Sensible Heat Flux, [W$\cdot$m$^{-2}$]", "title": " ".join(test_title), @@ -891,33 +899,35 @@ def test_plot_iterated_metrics( if arg_location: test_frame.attrs["name"] = arg_location assert "name" in test_frame.attrs - if arg_location: + if arg_location is not None: test_location = f" at {arg_location}" else: test_location = "" test_title = f"{test_location},\n{self.test_date}" - plot_pairs = self.test_metrics.plot_iterated_metrics( + compare_plots = self.test_metrics.plot_iterated_metrics( iterated_data=test_frame, time_stamp=test_stamp, site_location=arg_location, ) + assert isinstance(compare_plots, list) + assert all(isinstance(compare_tuple, tuple) for compare_tuple in compare_plots) - compare_plots = { + compare_params = { "iteration": { - "plot": (plot_pairs[0], plot_pairs[1]), + "plot": (compare_plots[0]), "title": "Sensible Heat Flux", "x_label": "Time, CET", "ylabel": r"Sensible Heat Flux, [W$\cdot$m$^{-2}$]", }, "comparison": { - "plot": (plot_pairs[2], plot_pairs[3]), + "plot": (compare_plots[1]), "title": "Sensible Heat Flux from Free Convection and Iteration", "x_label": "Time, CET", "ylabel": r"Sensible Heat Flux, [W$\cdot$m$^{-2}$]", }, } - for params in compare_plots.values(): + for params in compare_params.values(): conftest_boilerplate.check_plot(plot_params=params, title=test_title) plt.close("all") @@ -1121,30 +1131,31 @@ def test_compare_innflux( _ = conftest_mock_save_figure - if arg_location: + if arg_location is not None: test_location = f" at {arg_location}" else: test_location = "" test_title = f"{test_location},\n{self.test_date}" - fig_obukhov, ax_obukhov, fig_shf, ax_shf = self.test_workflow.compare_innflux( + compare_plots = self.test_workflow.compare_innflux( innflux_data=conftest_mock_innflux_dataframe_tz, own_data=conftest_mock_iterated_dataframe, location=arg_location, ) - + assert isinstance(compare_plots, list) + assert all(isinstance(compare_tuple, tuple) for compare_tuple in compare_plots) compare_params = { "obukhov": { "title": "Obukhov Length from Scintillometer and innFLUX", "y_label": "Obukhov Length, [m]", "x_label": "Time, CET", - "plot": (fig_obukhov, ax_obukhov), + "plot": (compare_plots[0]), }, "shf": { "title": "Sensible Heat Flux from Scintillometer and innFLUX", "ylabel": r"Sensible Heat Flux, [W$\cdot$m$^{-2}$]", "x_label": "Time, CET", - "plot": (fig_shf, ax_shf), + "plot": (compare_plots[1]), }, } @@ -1196,25 +1207,27 @@ def test_compare_eddy( test_location = "Test Location" test_title = f" at {test_location},\n{self.test_date}" - fig_obukhov, ax_obukhov, fig_shf, ax_shf = self.test_workflow.compare_eddy( + compare_plots = self.test_workflow.compare_eddy( own_data=conftest_mock_iterated_dataframe, ext_data=conftest_mock_innflux_dataframe_tz, source="innflux", location="Test Location", ) + assert isinstance(compare_plots, list) + assert all(isinstance(compare_tuple, tuple) for compare_tuple in compare_plots) compare_params = { "obukhov": { "title": "Obukhov Length from Scintillometer and innFLUX", "y_label": "Obukhov Length, [m]", "x_label": "Time, CET", - "plot": (fig_obukhov, ax_obukhov), + "plot": (compare_plots[0]), }, "shf": { "title": "Sensible Heat Flux from Scintillometer and innFLUX", "ylabel": r"Sensible Heat Flux, [W$\cdot$m$^{-2}$]", "x_label": "Time, CET", - "plot": (fig_shf, ax_shf), + "plot": (compare_plots[1]), }, } for params in compare_params.values(): diff --git a/tests/test_visuals_plotting.py b/tests/test_visuals_plotting.py index d1039c5..0cd193a 100644 --- a/tests/test_visuals_plotting.py +++ b/tests/test_visuals_plotting.py @@ -596,24 +596,26 @@ def test_plot_iterated_fluxes( test_title = f"{test_location},\n{self.test_date}" timestamp = test_data.index[0] - plots = self.test_plotting.plot_iterated_fluxes( + compare_plots = self.test_plotting.plot_iterated_fluxes( iteration_data=test_data, time_id=timestamp, location=arg_location, ) + assert isinstance(compare_plots, list) + assert all(isinstance(compare_tuple, tuple) for compare_tuple in compare_plots) compare_plots = { "shf": { "title": "Sensible Heat Flux", "ylabel": r"Sensible Heat Flux, [W$\cdot$m$^{-2}$]", "xlabel": "Time, CET", - "plot": (plots[0], plots[1]), + "plot": (compare_plots[0]), }, "comparison": { "title": "Sensible Heat Flux from Free Convection and Iteration", "ylabel": r"Sensible Heat Flux, [W$\cdot$m$^{-2}$]", "xlabel": "Time, CET", - "plot": (plots[2], plots[3]), + "plot": (compare_plots[1]), }, } @@ -688,14 +690,18 @@ def test_plot_innflux( @pytest.mark.parametrize("arg_score", [None, 0.561734521]) @pytest.mark.parametrize("arg_regression", [True, False]) def test_plot_scatter( - self, conftest_boilerplate, arg_score, arg_site, arg_regression + self, + conftest_boilerplate, + conftest_generate_series, + arg_score, + arg_site, + arg_regression, ): """Plot scatter between two datasets with regression line.""" test_name = "obukhov" - rng = np.random.default_rng() - test_index = pd.date_range(start=self.test_timestamp, periods=100, freq="T") - test_data = rng.random(size=len(test_index)) + test_data, test_index = conftest_generate_series + test_x = pd.Series(name=test_name, data=test_data, index=test_index) test_y = pd.Series(name=test_name, data=test_data + 0.5, index=test_index) for series in [test_x, test_y]: From 1989f644a71e65700176d1f1f0570408961e8b53 Mon Sep 17 00:00:00 2001 From: gampnico <45390064+gampnico@users.noreply.github.com> Date: Wed, 24 May 2023 16:54:53 +0200 Subject: [PATCH 08/17] feat(metrics): calculate, plot, save regression between data Adds ability to plot and save regression between calculated fluxes and innFLUX data. Fixes some incorrect test conditions which led to tests being skipped. Refs: ST-3, ST-7, ST-8, ST-106, ST-107 --- src/scintillometry/metrics/calculations.py | 45 ++++++++++++-- src/scintillometry/visuals/plotting.py | 6 +- tests/test_metrics_calculations.py | 71 ++++++++++++++-------- tests/test_visuals_plotting.py | 4 +- 4 files changed, 92 insertions(+), 34 deletions(-) diff --git a/src/scintillometry/metrics/calculations.py b/src/scintillometry/metrics/calculations.py index 3225a92..5f34c69 100755 --- a/src/scintillometry/metrics/calculations.py +++ b/src/scintillometry/metrics/calculations.py @@ -991,26 +991,61 @@ def compare_innflux(self, own_data, innflux_data, location=""): """ data_timestamp = own_data.index[0] - fig_obukhov, ax_obukhov = self.plotting.plot_innflux( + obukhov_plot = self.plotting.plot_innflux( iter_data=own_data, innflux_data=innflux_data, name="obukhov", site=location, ) self.plotting.save_figure( - figure=fig_obukhov, timestamp=data_timestamp, suffix="innflux_obukhov" + figure=obukhov_plot[0], timestamp=data_timestamp, suffix="innflux_obukhov" ) - fig_shf, ax_shf = self.plotting.plot_innflux( + obukhov_regression = self.get_regression( + x_data=own_data["obukhov"], y_data=innflux_data["obukhov"], intercept=True + ) + obukhov_regression_plot = self.plotting.plot_scatter( + x_data=own_data, + y_data=innflux_data, + sources=["MOST Iteration", "innFLUX"], + name="obukhov", + score=obukhov_regression["score"], + regression_line=obukhov_regression["regression_line"], + site=location, + ) + self.plotting.save_figure( + figure=obukhov_regression_plot[0], + timestamp=data_timestamp, + suffix="innflux_obukhov_regression", + ) + + shf_plot = self.plotting.plot_innflux( iter_data=own_data, innflux_data=innflux_data, name="shf", site=location, ) self.plotting.save_figure( - figure=fig_shf, timestamp=data_timestamp, suffix="innflux_shf" + figure=shf_plot[0], timestamp=data_timestamp, suffix="innflux_shf" + ) + shf_regression = self.get_regression( + x_data=own_data["obukhov"], y_data=innflux_data["obukhov"], intercept=True + ) + shf_regression_plot = self.plotting.plot_scatter( + x_data=own_data, + y_data=innflux_data, + sources=["MOST Iteration", "innFLUX"], + name="shf", + score=shf_regression["score"], + regression_line=shf_regression["regression_line"], + site=location, + ) + self.plotting.save_figure( + figure=shf_regression_plot[0], + timestamp=data_timestamp, + suffix="innflux_shf_regression", ) - plots = [(fig_obukhov, ax_obukhov), (fig_shf, ax_shf)] + plots = [obukhov_plot, shf_plot, obukhov_regression_plot, shf_regression_plot] return plots diff --git a/src/scintillometry/visuals/plotting.py b/src/scintillometry/visuals/plotting.py index e051adc..9c99210 100644 --- a/src/scintillometry/visuals/plotting.py +++ b/src/scintillometry/visuals/plotting.py @@ -708,14 +708,14 @@ def plot_scatter( transform=axes.transAxes, ) variable_name = self.label_selector(name) - sources_string = f"{sources[0].title()} and {sources[1].title()}" + sources_string = f"{sources[0]} and {sources[1]}" title_string = f"{variable_name[0]} Regression Between\n{sources_string}" site_label = self.get_site_name(site_name=site, dataframe=x_data) self.title_plot(title=title_string, timestamp=date, location=site_label) x_label = self.merge_label_with_unit(label=variable_name) y_label = self.merge_label_with_unit(label=variable_name) - plt.xlabel(f"{sources[0].title()} {x_label}") - plt.ylabel(f"{sources[1].title()} {y_label}") + plt.xlabel(f"{x_label} ({sources[0]})") + plt.ylabel(f"{y_label} ({sources[1]})") return figure, axes diff --git a/tests/test_metrics_calculations.py b/tests/test_metrics_calculations.py index 8b37236..07e471c 100644 --- a/tests/test_metrics_calculations.py +++ b/tests/test_metrics_calculations.py @@ -280,10 +280,10 @@ def test_match_time_at_threshold( test_weather = conftest_mock_weather_dataframe_tz.copy(deep=True) if arg_empty: test_weather["global_irradiance"] = 20 - if not arg_timestamp: - test_timestamp = None - else: + if arg_timestamp: test_timestamp = self.test_timestamp + else: + test_timestamp = None compare_time = self.test_metrics.match_time_at_threshold( series=test_weather["global_irradiance"], @@ -331,7 +331,7 @@ def test_get_regression( assert isinstance( compare_regression["fit"], sklearn.linear_model.LinearRegression ) - if not arg_intercept: + if arg_intercept: assert compare_regression["fit"].fit_intercept else: assert not compare_regression["fit"].fit_intercept @@ -653,7 +653,7 @@ def test_plot_switch_time_stability( ) assert isinstance(test_time, pd.Timestamp) test_vertical = test_dataset["vertical"] - if arg_location is not None: + if arg_location: test_location = f"\nat {arg_location}, " else: test_location = ",\n" @@ -720,7 +720,7 @@ def test_plot_lapse_rates( test_dataset = self.test_metrics.append_vertical_variables(data=test_dataset) for key in ["grad_potential_temperature", "environmental_lapse_rate"]: assert key in test_dataset["vertical"] - if arg_location is not None: + if arg_location: test_location = f"\nat {arg_location}, " else: test_location = ",\n" @@ -899,7 +899,7 @@ def test_plot_iterated_metrics( if arg_location: test_frame.attrs["name"] = arg_location assert "name" in test_frame.attrs - if arg_location is not None: + if arg_location: test_location = f" at {arg_location}" else: test_location = "" @@ -918,13 +918,13 @@ def test_plot_iterated_metrics( "plot": (compare_plots[0]), "title": "Sensible Heat Flux", "x_label": "Time, CET", - "ylabel": r"Sensible Heat Flux, [W$\cdot$m$^{-2}$]", + "y_label": r"Sensible Heat Flux, [W$\cdot$m$^{-2}$]", }, "comparison": { "plot": (compare_plots[1]), "title": "Sensible Heat Flux from Free Convection and Iteration", "x_label": "Time, CET", - "ylabel": r"Sensible Heat Flux, [W$\cdot$m$^{-2}$]", + "y_label": r"Sensible Heat Flux, [W$\cdot$m$^{-2}$]", }, } for params in compare_params.values(): @@ -1122,24 +1122,32 @@ def test_calculate_standard_metrics_no_vertical( def test_compare_innflux( self, conftest_mock_save_figure, - conftest_mock_innflux_dataframe_tz, - conftest_mock_iterated_dataframe, conftest_boilerplate, + conftest_generate_series, arg_location, ): - """Compares input data to InnFLUX data.""" + """Compares input data to innFLUX data.""" _ = conftest_mock_save_figure - if arg_location is not None: + test_data, test_index = conftest_generate_series + test_obukhov = pd.Series(data=test_data, index=test_index) + test_shf = pd.Series(data=test_data, index=test_index) + test_base_dataframe = pd.DataFrame( + data={"obukhov": test_obukhov, "shf": test_shf}, + ) + test_ext_dataframe = test_base_dataframe.add(0.5) + + if arg_location: test_location = f" at {arg_location}" else: test_location = "" test_title = f"{test_location},\n{self.test_date}" + test_regression_string = "Regression Between\nMOST Iteration and innFLUX" compare_plots = self.test_workflow.compare_innflux( - innflux_data=conftest_mock_innflux_dataframe_tz, - own_data=conftest_mock_iterated_dataframe, + own_data=test_base_dataframe, + innflux_data=test_ext_dataframe, location=arg_location, ) assert isinstance(compare_plots, list) @@ -1153,10 +1161,22 @@ def test_compare_innflux( }, "shf": { "title": "Sensible Heat Flux from Scintillometer and innFLUX", - "ylabel": r"Sensible Heat Flux, [W$\cdot$m$^{-2}$]", + "y_label": r"Sensible Heat Flux, [W$\cdot$m$^{-2}$]", "x_label": "Time, CET", "plot": (compare_plots[1]), }, + "obukhov_regression": { + "title": f"Obukhov Length {test_regression_string}", + "y_label": "Obukhov Length, [m] (innFLUX)", + "x_label": "Obukhov Length, [m] (MOST Iteration)", + "plot": (compare_plots[2]), + }, + "shf_regression": { + "title": f"Sensible Heat Flux {test_regression_string}", + "y_label": r"Sensible Heat Flux, [W$\cdot$m$^{-2}$] (innFLUX)", + "x_label": r"Sensible Heat Flux, [W$\cdot$m$^{-2}$] (MOST Iteration)", + "plot": (compare_plots[3]), + }, } for params in compare_params.values(): @@ -1194,22 +1214,25 @@ def test_compare_eddy_error( ], ) def test_compare_eddy( - self, - conftest_mock_save_figure, - conftest_mock_innflux_dataframe_tz, - conftest_mock_iterated_dataframe, - conftest_boilerplate, + self, conftest_mock_save_figure, conftest_generate_series, conftest_boilerplate ): """Compares input data to external eddy covariance data.""" _ = conftest_mock_save_figure + test_data, test_index = conftest_generate_series + test_obukhov = pd.Series(data=test_data, index=test_index) + test_shf = pd.Series(data=test_data, index=test_index) + test_base_dataframe = pd.DataFrame( + data={"obukhov": test_obukhov, "shf": test_shf}, + ) + test_ext_dataframe = test_base_dataframe.add(0.5) test_location = "Test Location" test_title = f" at {test_location},\n{self.test_date}" compare_plots = self.test_workflow.compare_eddy( - own_data=conftest_mock_iterated_dataframe, - ext_data=conftest_mock_innflux_dataframe_tz, + own_data=test_base_dataframe, + ext_data=test_ext_dataframe, source="innflux", location="Test Location", ) @@ -1225,7 +1248,7 @@ def test_compare_eddy( }, "shf": { "title": "Sensible Heat Flux from Scintillometer and innFLUX", - "ylabel": r"Sensible Heat Flux, [W$\cdot$m$^{-2}$]", + "y_label": r"Sensible Heat Flux, [W$\cdot$m$^{-2}$]", "x_label": "Time, CET", "plot": (compare_plots[1]), }, diff --git a/tests/test_visuals_plotting.py b/tests/test_visuals_plotting.py index 0cd193a..15533a7 100644 --- a/tests/test_visuals_plotting.py +++ b/tests/test_visuals_plotting.py @@ -733,8 +733,8 @@ def test_plot_scatter( ) compare_params = { "plot": (compare_fig, compare_ax), - "x_label": "Baseline Obukhov Length, [m]", - "y_label": "Comparison Obukhov Length, [m]", + "x_label": "Obukhov Length, [m] (Baseline)", + "y_label": "Obukhov Length, [m] (Comparison)", "title": "\n".join(test_title), } conftest_boilerplate.check_plot(plot_params=compare_params) From d638315d93eff1b19442dcb335bab8681226426b Mon Sep 17 00:00:00 2001 From: gampnico <45390064+gampnico@users.noreply.github.com> Date: Wed, 24 May 2023 17:06:08 +0200 Subject: [PATCH 09/17] refactor: minor docstring and logic fixes Begins: - ST-126: Deprecate FigurePlotter.plot_iterated_fluxes in favour of plot_iterated_metrics Refs: ST-3, ST-7, ST-8 --- src/scintillometry/metrics/calculations.py | 20 ++++++++--------- src/scintillometry/visuals/plotting.py | 8 +++---- tests/test_metrics_calculations.py | 12 +++++------ tests/test_visuals_plotting.py | 25 +++++++++++----------- 4 files changed, 32 insertions(+), 33 deletions(-) diff --git a/src/scintillometry/metrics/calculations.py b/src/scintillometry/metrics/calculations.py index 5f34c69..e150917 100755 --- a/src/scintillometry/metrics/calculations.py +++ b/src/scintillometry/metrics/calculations.py @@ -56,8 +56,9 @@ def get_path_height_parameters(self, transect, regime=None): regime (str): Target stability condition. Default None. Returns: - dict[float, float]: Tuples of effective and mean path height - |z_eff| and |z_mean| [m], with stability conditions as keys. + dict[str, tuple[np.floating, np.floating]]: Tuples of + effective and mean path height |z_eff| and |z_mean| [m], + with stability conditions as keys. """ z_params = self.transect.get_all_path_heights(path_transect=transect) @@ -106,7 +107,6 @@ def construct_flux_dataframe( . Args: - user_args (argparse.Namespace): Namespace of user arguments. interpolated_data (pd.DataFrame): Dataframe containing parsed and localised weather and scintillometer data with matching temporal resolution. @@ -176,7 +176,7 @@ def append_vertical_variables(self, data): "vertical" is updated with vertical data for water vapour pressure, air pressure, mixing ratio, virtual temperature, mean sea-level pressure, and potential temperature. - Otherwise the dictionary is returned unmodified. + Otherwise, the dictionary is returned unmodified. """ if "vertical" in data: @@ -283,10 +283,10 @@ def get_elbow_point(self, series, min_index=None, max_index=None): ] if series[indices[-1]] < series[indices[0]]: curve_direction = "decreasing" - online_param = "true" + online_param = True else: curve_direction = "increasing" - online_param = "true" + online_param = True knee = kneed.KneeLocator( series[indices], indices, @@ -1004,8 +1004,8 @@ def compare_innflux(self, own_data, innflux_data, location=""): x_data=own_data["obukhov"], y_data=innflux_data["obukhov"], intercept=True ) obukhov_regression_plot = self.plotting.plot_scatter( - x_data=own_data, - y_data=innflux_data, + x_data=own_data["obukhov"], + y_data=innflux_data["obukhov"], sources=["MOST Iteration", "innFLUX"], name="obukhov", score=obukhov_regression["score"], @@ -1031,8 +1031,8 @@ def compare_innflux(self, own_data, innflux_data, location=""): x_data=own_data["obukhov"], y_data=innflux_data["obukhov"], intercept=True ) shf_regression_plot = self.plotting.plot_scatter( - x_data=own_data, - y_data=innflux_data, + x_data=own_data["shf"], + y_data=innflux_data["shf"], sources=["MOST Iteration", "innFLUX"], name="shf", score=shf_regression["score"], diff --git a/src/scintillometry/visuals/plotting.py b/src/scintillometry/visuals/plotting.py index 9c99210..899bbce 100644 --- a/src/scintillometry/visuals/plotting.py +++ b/src/scintillometry/visuals/plotting.py @@ -264,7 +264,7 @@ def merge_multiple_labels(self, labels): labels (list[str]): Labels, which may contain duplicates. Returns: - str: A formatted, puncutated string with no duplicates. + str: A formatted, punctuated string with no duplicates. """ unique_text = list(dict.fromkeys(labels)) @@ -286,9 +286,9 @@ def set_xy_labels(self, ax, timezone, name): name (str): Name or abbreviation of dependent variable. Returns: - plt.Axes: Plot axes with labels for local time on the x axis - and for the dependent variable with units on the y axis. - Ticks on the x axis are formatted at hourly intervals. + plt.Axes: Plot axes with labels for local time on the x-axis + and for the dependent variable with units on the y-axis. + Ticks on the x-axis are formatted at hourly intervals. """ x_label = f"Time, {timezone.zone}" diff --git a/tests/test_metrics_calculations.py b/tests/test_metrics_calculations.py index 07e471c..468ed2c 100644 --- a/tests/test_metrics_calculations.py +++ b/tests/test_metrics_calculations.py @@ -107,7 +107,7 @@ def test_get_path_height_parameters( assert compare_metrics["None"][0] > compare_metrics["None"][1] compare_print = capsys.readouterr() - if not arg_regime: + if arg_regime is None: assert "Selected no height dependency:" in compare_print.out else: assert str(arg_regime) in compare_print.out @@ -172,7 +172,7 @@ def test_construct_flux_dataframe(self, conftest_mock_merged_dataframe, arg_kwar """Compute sensible heat flux for free convection.""" test_frame = conftest_mock_merged_dataframe[["CT2", "H_convection"]] - if arg_kwargs: + if isinstance(arg_kwargs, tuple): test_kwargs = { "beam_wavelength": arg_kwargs[0], "beam_error": arg_kwargs[1], @@ -271,7 +271,7 @@ def test_append_vertical_variables( @pytest.mark.dependency(name="TestMetricsFlux::test_match_time_at_threshold") @pytest.mark.parametrize("arg_lessthan", [True, False]) @pytest.mark.parametrize("arg_empty", [True, False]) - @pytest.mark.parametrize("arg_timestamp", [True, None]) + @pytest.mark.parametrize("arg_timestamp", [True, False]) def test_match_time_at_threshold( self, conftest_mock_weather_dataframe_tz, arg_lessthan, arg_empty, arg_timestamp ): @@ -373,7 +373,7 @@ def test_get_elbow_point(self, arg_min_index, arg_max_index, arg_curve): test_indices, S=1.5, curve="convex", - online="true", + online=True, direction=test_direction, interp_method="interp1d", ) @@ -806,8 +806,8 @@ def test_calculate_switch_time( test_dataset = { "weather": conftest_mock_weather_dataframe_tz.copy(deep=True), "timestamp": self.test_timestamp.replace(hour=5, minute=10), + "vertical": conftest_mock_hatpro_dataset.copy(), } - test_dataset["vertical"] = conftest_mock_hatpro_dataset.copy() test_dataset = self.test_metrics.append_vertical_variables(data=test_dataset) if not arg_potential: @@ -974,7 +974,7 @@ def test_iterate_fluxes( for key in compare_keys: assert not (compare_metrics[key].isnull()).any() assert key in compare_metrics.keys() - assert all(isinstance(x, (mpmath.mpf)) for x in compare_metrics[key]) + assert all(isinstance(x, mpmath.mpf) for x in compare_metrics[key]) plt.close("all") diff --git a/tests/test_visuals_plotting.py b/tests/test_visuals_plotting.py index 15533a7..e4268ac 100644 --- a/tests/test_visuals_plotting.py +++ b/tests/test_visuals_plotting.py @@ -64,10 +64,11 @@ def assert_constant_lines( @pytest.mark.dependency(name="TestVisualsBoilerplate::test_assert_constant_lines") @pytest.mark.parametrize("arg_vlines", [{"v_a": 1}, {"v_a": 1, "v_b": 2.1}, None]) @pytest.mark.parametrize("arg_hlines", [{"h_a": 1}, {"h_a": 1, "h_b": 2.1}, None]) - def test_assert_constant_lines(self, arg_hlines, arg_vlines): - rng = np.random.default_rng() - test_index = np.arange(0, 200, 10) - test_data = rng.random(size=len(test_index)) + def test_assert_constant_lines( + self, conftest_generate_series, arg_hlines, arg_vlines + ): + """Validate test for constant lines existing on axis.""" + test_data, test_index = conftest_generate_series test_series = pd.Series(data=test_data, index=test_index) plt.figure(figsize=(10, 10)) @@ -128,13 +129,11 @@ def test_initialise_formatting(self, arg_offset): @pytest.mark.dependency(name="TestVisualsFormatting::test_parse_formatting_kwargs") @pytest.mark.parametrize("arg_fig", [True, False]) - def test_parse_formatting_kwargs(self, arg_fig): + def test_parse_formatting_kwargs(self, arg_fig, conftest_generate_series): """Parse kwargs when formatting.""" if arg_fig: - rng = np.random.default_rng() - test_index = np.arange(0, 200, 10) - test_data = rng.random(size=len(test_index)) + test_data, test_index = conftest_generate_series test_series = pd.Series(data=test_data, index=test_index) plt.figure(figsize=(10, 10)) plt.plot(test_index, test_series) @@ -333,12 +332,12 @@ def test_set_xy_labels(self, arg_name): ) @pytest.mark.parametrize("arg_vlines", [{"va": None}, {"va": 1, "vb": 2.1}, None]) @pytest.mark.parametrize("arg_hlines", [{"ha": None}, {"ha": 1, "hb": 2.1}, None]) - def test_plot_constant_lines(self, arg_hlines, arg_vlines): + def test_plot_constant_lines( + self, conftest_generate_series, arg_hlines, arg_vlines + ): """Plot horizontal and vertical lines.""" - rng = np.random.default_rng() - test_index = np.arange(0, 200, 10) - test_data = rng.random(size=len(test_index)) + test_data, test_index = conftest_generate_series test_series = pd.Series(data=test_data, index=test_index) plt.figure(figsize=(10, 10)) @@ -368,7 +367,7 @@ class TestVisualsPlotting(TestVisualsBoilerplate): test_date = "03 June 2020" test_timestamp = pd.Timestamp(f"{test_date} 05:20", tz="CET") - def test_visualsplotting_attributes(self): + def test_visuals_plotting_attributes(self): assert isinstance(self.test_timestamp, pd.Timestamp) assert self.test_timestamp.strftime("%Y-%m-%d %H:%M") == "2020-06-03 05:20" assert self.test_timestamp.tz.zone == "CET" From ae5b2ff677bae286784713cf2a419c2ee9e7b26f Mon Sep 17 00:00:00 2001 From: gampnico <45390064+gampnico@users.noreply.github.com> Date: Thu, 25 May 2023 13:55:07 +0200 Subject: [PATCH 10/17] refactor: begin deprecating FigurePlotter.plot_iterated_fluxes Begins deprecation of FigurePlotter.plot_iterated_fluxes in favour of MetricsFlux.plot_iterated_metrics. The former method is redundant. Refs: ST-7, ST-8, ST-126 --- src/scintillometry/metrics/calculations.py | 35 ++++++++++++++++------ src/scintillometry/visuals/plotting.py | 9 ++++++ 2 files changed, 35 insertions(+), 9 deletions(-) diff --git a/src/scintillometry/metrics/calculations.py b/src/scintillometry/metrics/calculations.py index e150917..e5d735f 100755 --- a/src/scintillometry/metrics/calculations.py +++ b/src/scintillometry/metrics/calculations.py @@ -825,14 +825,19 @@ def plot_derived_metrics(self, derived_data, time_id, regime=None, location=""): return derived_plots def plot_iterated_metrics(self, iterated_data, time_stamp, site_location=""): - """Plot and save time series and comparison of iterated fluxes. + """Plots and saves iterated SHF, comparison to free convection. + + .. todo:: + ST-126: Deprecate FigurePlotter.plot_iterated_fluxes in + favour of plot_iterated_metrics. Args: - iterated_data (pd.DataFrame): TZ-aware dataframe with - columns for heat fluxes and MOST parameters. - time_stamp (pd.Timestamp): Start time of scintillometer data - collection. - site_location (str): Name of scintillometer location. + iteration_data (pd.DataFrame): TZ-aware with columns for + sensible heat fluxes calculated for free convection + |H_free|, and by MOST |H|. + time_id (pd.Timestamp): Local time of data collection. + site_location (str): Location of data collection. Default empty + string. Returns: list[tuple[plt.Figure, plt.Axes]]: Time series of sensible @@ -840,11 +845,23 @@ def plot_iterated_metrics(self, iterated_data, time_stamp, site_location=""): sensible heat flux under free convection. """ - plots = self.plotting.plot_iterated_fluxes( - iteration_data=iterated_data, time_id=time_stamp, location=site_location + shf_plot = self.plotting.plot_generic(iterated_data, "shf", site=site_location) + self.plotting.save_figure( + figure=shf_plot[0], timestamp=time_stamp, suffix="shf" ) - return plots + comparison_plot = self.plotting.plot_comparison( + df_01=iterated_data, + df_02=iterated_data, + keys=["H_free", "shf"], + labels=["Free Convection", "Iteration"], + site=site_location, + ) + self.plotting.save_figure( + figure=comparison_plot[0], timestamp=time_stamp, suffix="shf_comp" + ) + + return [shf_plot, comparison_plot] class MetricsWorkflow(MetricsFlux, MetricsTopography): diff --git a/src/scintillometry/visuals/plotting.py b/src/scintillometry/visuals/plotting.py index 899bbce..68f5ccf 100644 --- a/src/scintillometry/visuals/plotting.py +++ b/src/scintillometry/visuals/plotting.py @@ -585,6 +585,15 @@ def plot_comparison(self, df_01, df_02, keys, labels, site=""): def plot_iterated_fluxes(self, iteration_data, time_id, location=""): """Plots and saves iterated SHF, comparison to free convection. + .. note:: Pending deprecation in a future patch release. Use + :func:`MetricsFlux.plot_iterated_metrics() + ` + instead. + + .. todo:: + ST-126: Deprecate FigurePlotter.plot_iterated_fluxes in + favour of plot_iterated_metrics. + Args: iteration_data (pd.DataFrame): TZ-aware with sensible heat fluxes calculated for free convection |H_free|, and From 7ac8ee32aab87f97ba490549a22ac169d32c1b49 Mon Sep 17 00:00:00 2001 From: gampnico <45390064+gampnico@users.noreply.github.com> Date: Thu, 25 May 2023 14:04:18 +0200 Subject: [PATCH 11/17] fix(tests): fix incorrect type hinting in conftest Refs: ST-3 --- tests/conftest.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/conftest.py b/tests/conftest.py index b7467ce..9fedd24 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -340,7 +340,7 @@ def conftest_mock_save_figure(): # Mock random data @pytest.fixture(name="conftest_generate_series", scope="function", autouse=False) -def fixture_conftest_generate_series() -> tuple[np.ndarray, pd.Timestamp]: +def fixture_conftest_generate_series(): """Generates Series with random data and DatetimeIndex""" rng = np.random.default_rng() From a2a4f222675223b89328950ceefdd1d80c9a6be0 Mon Sep 17 00:00:00 2001 From: gampnico <45390064+gampnico@users.noreply.github.com> Date: Thu, 25 May 2023 14:08:13 +0200 Subject: [PATCH 12/17] fix: add scikit-learn to project requirements Refs: ST-1 --- dev-requirements.txt | 1 + pyproject.toml | 1 + requirements.txt | 1 + 3 files changed, 3 insertions(+) diff --git a/dev-requirements.txt b/dev-requirements.txt index 1872c7b..cf5e94a 100644 --- a/dev-requirements.txt +++ b/dev-requirements.txt @@ -3,6 +3,7 @@ tqdm>4.8 scipy>=1.10 mpmath>=1.2.1 numpy +scikit-learn matplotlib kneed pytest>=7.0 diff --git a/pyproject.toml b/pyproject.toml index 17aba01..164da57 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -31,6 +31,7 @@ dependencies = [ "scipy >= 1.10", "mpmath >= 1.2.1", "numpy", + "scikit-learn "matplotlib", "kneed", ] diff --git a/requirements.txt b/requirements.txt index 88f6eb2..c450156 100644 --- a/requirements.txt +++ b/requirements.txt @@ -3,5 +3,6 @@ tqdm>4.8 scipy>=1.10 mpmath>=1.2.1 numpy +scikit-learn matplotlib kneed From 856bcd9a885f3135338a47bf15ccd4e3f1eac2bc Mon Sep 17 00:00:00 2001 From: gampnico <45390064+gampnico@users.noreply.github.com> Date: Thu, 25 May 2023 14:10:28 +0200 Subject: [PATCH 13/17] fix: fix broken syntax in pyproject.toml Refs: ST-1 --- pyproject.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pyproject.toml b/pyproject.toml index 164da57..7ff6d7a 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -31,7 +31,7 @@ dependencies = [ "scipy >= 1.10", "mpmath >= 1.2.1", "numpy", - "scikit-learn + "scikit-learn", "matplotlib", "kneed", ] From f07b6c65ca36776477bc9b7db04804d8784d9e9c Mon Sep 17 00:00:00 2001 From: gampnico <45390064+gampnico@users.noreply.github.com> Date: Thu, 25 May 2023 14:48:55 +0200 Subject: [PATCH 14/17] merge: merge branch feat-ST127-refactor-project with develop (#8) * refactor: minor docstring fixes Refs: ST-2, ST-127 * refactor: update README Refs: ST-2, ST-127 * refactor: minor efficiency fixes Refs: ST-2, ST-5, ST-7, ST-127 * refactor: minor updates to test logic Refs: ST-3, ST-127 * refactor: improve code legibility in main Refs: ST-26, ST-127 * fix: remove unsupported type hinting Refs: ST-3, ST-127 Refs: ST-1, ST-127 --- README.md | 6 ++--- src/scintillometry/backend/constants.py | 12 ++++------ src/scintillometry/backend/constructions.py | 6 ++--- src/scintillometry/backend/iterations.py | 2 ++ src/scintillometry/backend/transects.py | 10 ++++---- src/scintillometry/main.py | 18 +++++++------- src/scintillometry/metrics/calculations.py | 7 +++--- src/scintillometry/wrangler/data_parser.py | 26 ++++++++------------- tests/test_backend_iterations.py | 10 ++++---- tests/test_backend_transects.py | 2 +- tests/test_visuals_plotting.py | 8 +++---- tests/test_wrangler_data_parser.py | 4 ++-- 12 files changed, 52 insertions(+), 59 deletions(-) diff --git a/README.md b/README.md index b0dca28..f1d1d8d 100644 --- a/README.md +++ b/README.md @@ -16,11 +16,11 @@ limitations under the License. --> [![Pytest and Flake8](https://github.com/gampnico/scintillometry/actions/workflows/python-app.yml/badge.svg?branch=main)](https://github.com/gampnico/scintillometry/actions/workflows/python-app.yml) -Analyse data & 2D flux footprints from Scintec's BLS scintillometers. +Development branch: [![Pytest and Linting](https://github.com/gampnico/scintillometry/actions/workflows/python-app.yml/badge.svg?branch=develop)](https://github.com/gampnico/scintillometry/actions/workflows/python-app.yml) -This repository is a complete rewrite of gampnico/ss19-feldkurs. If you have any existing forks or local clones, **please delete them**. The legacy code no longer works. No user features will be lost, but rewriting may take some time. Contributions are always welcome. +Analyse data & 2D flux footprints from Scintec's BLS scintillometers. -This package started life as part of a field course. If you spot any missing citations or licenses please [open an issue](https://github.com/gampnico/scintillometry/issues). +This project started life as part of a field course. If you spot any missing citations or licences please [open an issue](https://github.com/gampnico/scintillometry/issues). Comprehensive documentation is available [via ReadTheDocs](https://scintillometry.readthedocs.io/en/latest/). diff --git a/src/scintillometry/backend/constants.py b/src/scintillometry/backend/constants.py index 0a58757..91c5eaa 100644 --- a/src/scintillometry/backend/constants.py +++ b/src/scintillometry/backend/constants.py @@ -34,10 +34,6 @@ class AtmosConstants(object): BLS type. lamda (float): BLS wavelength, |lamda| [nm]. lamda_error (float): BLS wavelength error, [nm]. - m1_opt (float): Needed for |A_T| and |A_q|, from Owens (1967). - [#owens1967]_ - m2_opt (float): Needed for |A_T| and |A_q|, from Owens (1967). - [#owens1967]_ at_opt (float): |A_T| coefficient for 880 nm & typical atmospheric conditions, from Ward et al. (2013). @@ -60,7 +56,7 @@ class AtmosConstants(object): 20°C [|Jkg^-1|]. r_dry (float): Specific gas constant for dry air, |R_dry| [|JK^-1| |kg^-1|]. - r_vapour (float): Specific gas contstant for water vapour, + r_vapour (float): Specific gas constant for water vapour, |R_v| [|JK^-1| |kg^-1|]. ratio_rmm (float): Ratio of molecular masses of water vapour and dry air i.e. ratio of gas constants |epsilon|. @@ -149,7 +145,7 @@ def convert_pressure(self, pressure, base=True): pressure (Union[pd.DataFrame, pd.Series]): Pressure measurements |P| in pascals [Pa], hectopascals [hPa], or bars [bar]. - base (bool): If True, converts to pascals [Pa]. Otherwise + base (bool): If True, converts to pascals [Pa]. Otherwise, converts to hectopascals [hPa]. Default True. Returns: @@ -186,12 +182,12 @@ def convert_temperature(self, temperature, base=True): - T [°C] < 130 °C This method should therefore only be used on pre-processed data - as a *convenience*. By default converts to kelvins. + as a *convenience*. By default, converts to kelvins. Args: temperature (Union[pd.DataFrame, pd.Series]): Temperature measurements |T| in kelvins [K] or Celsius [°C]. - base (bool): If True, converts to kelvins [K]. Otherwise + base (bool): If True, converts to kelvins [K]. Otherwise, converts to Celsius [°C]. Default True. Returns: diff --git a/src/scintillometry/backend/constructions.py b/src/scintillometry/backend/constructions.py index d7f87ec..644037f 100644 --- a/src/scintillometry/backend/constructions.py +++ b/src/scintillometry/backend/constructions.py @@ -190,7 +190,7 @@ def get_mixing_ratio(self, wv_pressure, d_pressure): # (wv_pressure * self.r_dry) / (d_pressure * self.r_vapour) m_ratio = (wv_pressure.multiply(self.constants.r_dry)).divide( - (d_pressure).multiply(self.constants.r_vapour) + d_pressure.multiply(self.constants.r_vapour) ) return m_ratio @@ -236,7 +236,7 @@ def get_reduced_pressure(self, station_pressure, virtual_temperature, elevation) elevation (float): Station elevation, |z_stn| [m]. Returns: - pd.DataDrame: Derived vertical measurements for mean + pd.DataFrame: Derived vertical measurements for mean sea-level pressure, |P_MSL| [Pa]. """ @@ -495,7 +495,7 @@ def get_gradient(self, data, method="backward"): :math:`\\partial T/\\partial z` for heights |z| with time index t. - By default the gradient is calculated using a 1-D + By default, the gradient is calculated using a 1-D centred-differencing scheme for non-uniform meshes, since vertical measurements are rarely made at uniform intervals. diff --git a/src/scintillometry/backend/iterations.py b/src/scintillometry/backend/iterations.py index c7be7fa..89079b3 100644 --- a/src/scintillometry/backend/iterations.py +++ b/src/scintillometry/backend/iterations.py @@ -329,6 +329,8 @@ def most_iteration(self, dataframe, zm_bls, stable_flag, most_coeffs): stable_flag (bool): Stability conditions. If true, assumes stable conditions, otherwise assumes unstable conditions. + most_coeffs (list): MOST coefficients for unstable and + stable conditions. Returns: pd.DataFrame: Dataframe with additional columns for Obukhov diff --git a/src/scintillometry/backend/transects.py b/src/scintillometry/backend/transects.py index e9a1d71..5bd4059 100644 --- a/src/scintillometry/backend/transects.py +++ b/src/scintillometry/backend/transects.py @@ -85,7 +85,7 @@ def get_b_value(self, stability_name): Returns: float: Constant "b" accounting for height dependence of - |Cn2|. Values of "b" are from Hartogenesis et al. + |Cn2|. Values of "b" are from Hartogensis et al. (2003) [#hartogensis2003]_, and Kleissl et al. (2008) [#kleissl2008]_. @@ -94,7 +94,7 @@ def get_b_value(self, stability_name): stability condition. """ - # Hartogenesis et al. (2003), Kleissl et al. (2008). + # Hartogensis et al. (2003), Kleissl et al. (2008). stability_dict = {"stable": -2 / 3, "unstable": -4 / 3} if not stability_name: @@ -168,9 +168,9 @@ def get_all_path_heights(self, path_transect): path_transect (pd.DataFrame): Parsed path transect data. Returns: - dict[tuple[np.floating, np.floating]]: Effective and mean - path heights of transect |z_eff| and |z_mean| [m], with each - stability condition as key. + dict[str, tuple[np.floating, np.floating]]: Effective and + mean path heights of transect |z_eff| and |z_mean| [m], with + each stability condition as key. """ path_heights_dict = {} diff --git a/src/scintillometry/main.py b/src/scintillometry/main.py index 2ae1058..391a2e6 100644 --- a/src/scintillometry/main.py +++ b/src/scintillometry/main.py @@ -55,8 +55,8 @@ import argparse -import scintillometry.metrics.calculations as MetricsCalculations -import scintillometry.wrangler.data_parser as DataParser +import scintillometry.metrics.calculations as calculations +import scintillometry.wrangler.data_parser as data_parser def user_argumentation(): @@ -298,10 +298,10 @@ def perform_data_parsing(**kwargs): measurements, weather observations, and topography. """ - data_parser = DataParser.WranglerParsing() + parser = data_parser.WranglerParsing() # Parse BLS, weather, and topographical data - datasets = data_parser.wrangle_data( + datasets = parser.wrangle_data( bls_path=kwargs["input"], transect_path=kwargs["transect_path"], calibrate=kwargs["calibration"], @@ -312,7 +312,7 @@ def perform_data_parsing(**kwargs): # Parse vertical measurements if kwargs["profile_prefix"]: - datasets["vertical"] = data_parser.vertical.parse_vertical( + datasets["vertical"] = parser.vertical.parse_vertical( file_path=kwargs["profile_prefix"], source="hatpro", levels=None, @@ -360,11 +360,11 @@ def perform_analysis(datasets, **kwargs): covariance data. """ - metrics_class = MetricsCalculations.MetricsWorkflow() - data_parser = DataParser.WranglerParsing() + metrics_class = calculations.MetricsWorkflow() + parser = data_parser.WranglerParsing() metrics_data = metrics_class.calculate_standard_metrics(data=datasets, **kwargs) if kwargs["eddy_path"]: - eddy_frame = data_parser.eddy.parse_eddy_covariance( + eddy_frame = parser.eddy.parse_eddy_covariance( file_path=kwargs["eddy_path"], tzone=kwargs["timezone"], source="innflux" ) metrics_data["eddy"] = eddy_frame @@ -382,7 +382,7 @@ def main(): """Parses command line arguments and executes analysis. Converts command line arguments into kwargs. Imports and parses - scintillomter, weather, and transect data. If the appropriate + scintillometer, weather, and transect data. If the appropriate arguments are specified: - Parses vertical measurements diff --git a/src/scintillometry/metrics/calculations.py b/src/scintillometry/metrics/calculations.py index e5d735f..4b044d1 100755 --- a/src/scintillometry/metrics/calculations.py +++ b/src/scintillometry/metrics/calculations.py @@ -829,13 +829,14 @@ def plot_iterated_metrics(self, iterated_data, time_stamp, site_location=""): .. todo:: ST-126: Deprecate FigurePlotter.plot_iterated_fluxes in - favour of plot_iterated_metrics. + favour of plot_iterated_metrics. `site_location` should + be deprecated in favour of `location`. Args: - iteration_data (pd.DataFrame): TZ-aware with columns for + iterated_data (pd.DataFrame): TZ-aware with columns for sensible heat fluxes calculated for free convection |H_free|, and by MOST |H|. - time_id (pd.Timestamp): Local time of data collection. + time_stamp (pd.Timestamp): Local time of data collection. site_location (str): Location of data collection. Default empty string. diff --git a/src/scintillometry/wrangler/data_parser.py b/src/scintillometry/wrangler/data_parser.py index 6dccdf3..7e1807b 100644 --- a/src/scintillometry/wrangler/data_parser.py +++ b/src/scintillometry/wrangler/data_parser.py @@ -93,7 +93,7 @@ def parse_iso_date(self, x, date=True): Args: x (str): Timestamp containing ISO-8601 duration and date, i.e. "/". - date (bool): If True, returns date. Otherwise returns + date (bool): If True, returns date. Otherwise, returns duration. Default True. Returns: @@ -163,10 +163,9 @@ def parse_mnd_lines(self, line_list): line_list (list): Lines read from .mnd file in FORMAT-1. Returns: - dict[list, list, str, list]: Contains a list of lines of - parsed BLS data, an ordered list of variable names, the file - timestamp, and any additional header parameters in the file - header. + dict: Contains a list of lines of parsed BLS data, an + ordered list of variable names, the file timestamp, and any + additional header parameters in the file header. Raises: Warning: The input file does not follow FORMAT-1. @@ -251,7 +250,8 @@ def parse_scintillometer(self, file_path, timezone="CET", calibration=None): """Parses .mnd files into dataframes. Args: - filename (str): Path to a raw .mnd data file using FORMAT-1. + file_path (str): Path to a raw .mnd data file using + FORMAT-1. timezone (str): Local timezone during the scintillometer's operation. Default "CET". calibration (list): Contains the incorrect and correct path @@ -680,11 +680,7 @@ def parse_eddy_covariance(self, file_path, source="innflux", tzone=None): """ if source.lower() == "innflux": - eddy_data = self.parse_innflux( - file_name=file_path, - timezone=tzone, - headers=None, - ) + eddy_data = self.parse_innflux(file_name=file_path, timezone=tzone) else: error_msg = ( f"{source.title()} measurements are not supported. Use 'innflux'." @@ -806,8 +802,8 @@ def parse_hatpro( Default 612.0. Returns: - dict[pd.DataFrame, pd.DataFrame]: Vertical measurements from - HATPRO for temperature |T| [K], and absolute humidity + dict[str, pd.DataFrame]: Vertical measurements from HATPRO + for temperature |T| [K], and absolute humidity |rho_v| [|gm^-3|]. """ @@ -828,9 +824,7 @@ def parse_hatpro( station_elevation=elevation, ) - data = {} - data["humidity"] = humidity_data - data["temperature"] = temperature_data + data = {"humidity": humidity_data, "temperature": temperature_data} return data diff --git a/tests/test_backend_iterations.py b/tests/test_backend_iterations.py index e2e4526..2889f38 100644 --- a/tests/test_backend_iterations.py +++ b/tests/test_backend_iterations.py @@ -136,7 +136,7 @@ def test_get_most_coefficients(self): @pytest.mark.dependency(name="TestBackendIterationMost::test_similarity_function") @pytest.mark.parametrize("arg_obukhov", [(-100, False), (0, True), (100, True)]) - def test_similarity_function(self, arg_obukhov): + def test_similarity_function(self, arg_obukhov: tuple): """Compute similarity function.""" test_f_ct2 = self.test_class.similarity_function( @@ -148,7 +148,7 @@ def test_similarity_function(self, arg_obukhov): @pytest.mark.dependency(name="TestBackendIterationMost::test_calc_theta_star") @pytest.mark.parametrize("arg_params", [(1.9e-04, 5.6, True), (2e-03, 3.6, False)]) - def test_calc_theta_star(self, arg_params): + def test_calc_theta_star(self, arg_params: tuple): """Calculate temperature scale.""" test_theta = self.test_class.calc_theta_star( @@ -182,7 +182,7 @@ def test_calc_obukhov_length(self, arg_theta): """Calculate Obukhov length.""" compare_lob = self.test_class.calc_obukhov_length( - temp=295, u_star=0.2, theta_star=mpmath.mpmathify(arg_theta) + temp=np.float64(295.0), u_star=0.2, theta_star=mpmath.mpmathify(arg_theta) ) assert isinstance(compare_lob, mpmath.mpf) assert (compare_lob < 0) == (arg_theta < 0) # obukhov and theta have same sign @@ -224,7 +224,7 @@ def test_check_signs(self, arg_shf, arg_obukhov): scope="class", ) @pytest.mark.parametrize("arg_stable", [(200, True), (-100, False)]) - def test_most_iteration(self, conftest_mock_merged_dataframe, arg_stable): + def test_most_iteration(self, conftest_mock_merged_dataframe, arg_stable: tuple): """Iterate single row of dataframe using MOST.""" test_data = conftest_mock_merged_dataframe.iloc[0].copy(deep=True) @@ -310,7 +310,7 @@ def test_most_method(self, capsys, conftest_mock_merged_dataframe, arg_stable): for key in compare_keys: assert not (compare_most[key].isnull()).any() assert key in compare_most.keys() - assert all(isinstance(x, (mpmath.mpf)) for x in compare_most[key]) + assert all(isinstance(x, mpmath.mpf) for x in compare_most[key]) # signs match stability assert (compare_most["obukhov"] > 0).all() == arg_stable[1] diff --git a/tests/test_backend_transects.py b/tests/test_backend_transects.py index 6c28610..93cddf3 100644 --- a/tests/test_backend_transects.py +++ b/tests/test_backend_transects.py @@ -204,7 +204,7 @@ def test_print_path_heights(self, capsys, arg_stability): test_capture = capsys.readouterr() self.test_transect_parameters.print_path_heights( - z_eff=34, z_mean=31.245, stability=arg_stability + z_eff=np.float64(34), z_mean=np.float64(31.245), stability=arg_stability ) compare_capture = capsys.readouterr() diff --git a/tests/test_visuals_plotting.py b/tests/test_visuals_plotting.py index e4268ac..7fd3f4c 100644 --- a/tests/test_visuals_plotting.py +++ b/tests/test_visuals_plotting.py @@ -606,14 +606,14 @@ def test_plot_iterated_fluxes( compare_plots = { "shf": { "title": "Sensible Heat Flux", - "ylabel": r"Sensible Heat Flux, [W$\cdot$m$^{-2}$]", - "xlabel": "Time, CET", + "y_label": r"Sensible Heat Flux, [W$\cdot$m$^{-2}$]", + "x_label": "Time, CET", "plot": (compare_plots[0]), }, "comparison": { "title": "Sensible Heat Flux from Free Convection and Iteration", - "ylabel": r"Sensible Heat Flux, [W$\cdot$m$^{-2}$]", - "xlabel": "Time, CET", + "y_label": r"Sensible Heat Flux, [W$\cdot$m$^{-2}$]", + "x_label": "Time, CET", "plot": (compare_plots[1]), }, } diff --git a/tests/test_wrangler_data_parser.py b/tests/test_wrangler_data_parser.py index 83d0ff0..77f3d05 100644 --- a/tests/test_wrangler_data_parser.py +++ b/tests/test_wrangler_data_parser.py @@ -891,7 +891,7 @@ def test_vertical_init(self): @pytest.mark.dependency( name="TestWranglerVertical::test_construct_hatpro_levels_error" ) - @pytest.mark.parametrize("arg_levels", [[(0, 1), (0)], [1.0, 30]]) + @pytest.mark.parametrize("arg_levels", [[(0, 1), 0], [1.0, 30]]) def test_construct_hatpro_levels_error(self, arg_levels): """Raise error for incorrectly formatted scanning levels.""" @@ -910,7 +910,7 @@ def test_construct_hatpro_levels(self, arg_levels): compare_scan = self.test_wrangler_vertical.construct_hatpro_levels( levels=arg_levels ) - assert isinstance(compare_scan, (list)) + assert isinstance(compare_scan, list) assert all(isinstance(x, int) for x in compare_scan) @pytest.mark.dependency( From cc9fedcaaa4bf7b34f40439eb3a486bf3944dd84 Mon Sep 17 00:00:00 2001 From: gampnico <45390064+gampnico@users.noreply.github.com> Date: Mon, 29 May 2023 16:20:04 +0200 Subject: [PATCH 15/17] merge: merge feat-ST128-deprecations module (#9) * feat(backed): add decorators for deprecating objects Call the decorator with: from scintillometry.backend.deprecations import Decorators @Decorators.deprecated(stage="...", reason="...", version="...") def some_function(foo=...): ... Refs: ST-3, ST-6, ST-128 * feat(backend): deprecate individual arguments Adds method to deprecate individual function arguments. ``` from scintillometry.backend.deprecations import Decorators @Decorators.deprecated_argument(stage, old_argument="new_argument") def some_function(new_argument): ... ``` Refs: ST-3, ST-6, ST-128 * refactor(visuals): start deprecating plot_iterated_fluxes Marks `plot_iterated_fluxes` as pending deprecation. This function is superseded by `MetricsFlux.plot_iterated_metrics`. Refs: ST-3, ST-8, ST-126 * refactor(metrics): deprecate argument in plot_iterated_metrics Marks `site_location` as pending deprecation. Argument is replaced by `location`. Refs: ST-3, ST-7, ST-126 * docs(metrics): updates docstring noting deprecation Refs: ST-2, ST-126 --- docs/source/scintillometry.backend.rst | 10 + src/scintillometry/backend/deprecations.py | 339 +++++++++++++ src/scintillometry/metrics/calculations.py | 19 +- src/scintillometry/visuals/plotting.py | 6 + tests/test_backend_deprecations.py | 543 +++++++++++++++++++++ tests/test_metrics_calculations.py | 16 +- tests/test_visuals_plotting.py | 11 +- 7 files changed, 930 insertions(+), 14 deletions(-) create mode 100644 src/scintillometry/backend/deprecations.py create mode 100644 tests/test_backend_deprecations.py diff --git a/docs/source/scintillometry.backend.rst b/docs/source/scintillometry.backend.rst index fb56779..97ef5a5 100644 --- a/docs/source/scintillometry.backend.rst +++ b/docs/source/scintillometry.backend.rst @@ -70,6 +70,16 @@ In: scintillometry.backend.iterations.py :undoc-members: :show-inheritance: +Deprecations module +------------------- + +In: scintillometry.backend.deprecations.py + +.. automodule:: scintillometry.backend.deprecations + :members: + :undoc-members: + :show-inheritance: + References ---------- diff --git a/src/scintillometry/backend/deprecations.py b/src/scintillometry/backend/deprecations.py new file mode 100644 index 0000000..2f6cb22 --- /dev/null +++ b/src/scintillometry/backend/deprecations.py @@ -0,0 +1,339 @@ +"""Copyright 2023 Scintillometry Contributors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + https://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. + +===== + +Handles deprecating, deprecated, and defunct code. The deprecation +process takes one patch release cycle, the removal process takes two +major release cycles. + +EOL Cycle +--------- + +Import the decorator +:func:`~scintillometry.backend.deprecations.deprecated` from this +module. Mark deprecating functions like so: + +.. code-block:: python + + @deprecated(stage="deprecated", reason="Some reason", version="1.1.1") + def foobar(...) + ... + +Where ``stage`` is the stage in the function's deprecation cycle, +``reason`` is an optional message, and ``version`` is the release number +when ``stage`` was last updated. + +Update ``stage`` by following this schedule: + +1. A function pending deprecation is marked as **pending** and issues a +PendingDeprecationWarning during patch development. + +2. The function is marked as **deprecated** and issues a +DeprecationWarning from the next patch release. It must still work as +intended. + +3. The function is marked as **eol** and issues a FutureWarning during +the next minor release cycle. It must still work as intended. + +4. The function is marked as **defunct** and throws an error during the +next major release cycle, but is not removed. + +5. The defunct function is entirely removed from release candidates in +preparation for the subsequent major release. +""" + + +import functools +import inspect +import warnings +from typing import Callable + + +class DeprecationHandler: + """Methods for marking and handling deprecation cycles. + + Attributes: + string_types (tuple): Supported string types (str and + byte literals). + """ + + def __init__(self): + super().__init__() + self.string_types = (type(b""), type("")) # support byte literals + + def get_stage(self, name): + """Gets stage in deprecation cycle. + + Args: + name (str): Stage of deprecation cycle. Valid values + are: + + - pending + - deprecated + - eol + - defunct + + Returns: + tuple[str, Exception]: The description of the deprecation stage, + and the Exception subclass matching this stage. + + Raises: + TypeError: is an invalid type. Use str instead. + ValueError: is not a valid deprecation stage. + """ + + if isinstance(name, self.string_types): + stage_lower = name.lower() + if stage_lower == "deprecated": + stage_string = "deprecated." + category = DeprecationWarning + elif stage_lower == "pending": + stage_string = "pending deprecation." + category = PendingDeprecationWarning + elif stage_lower == "eol": + stage_string = "deprecated." + category = FutureWarning + elif stage_lower == "defunct": + stage_string = "defunct." + category = RuntimeError + else: + raise ValueError(f"{stage_lower} is not a valid deprecation stage.") + else: + raise TypeError(f"{type(name)} is an invalid type. Use {str} instead.") + + return stage_string, category + + def get_reason(self, **kwargs): + """Gets reason for deprecation. + + Keyword Args: + reason (str): Reason for deprecation. + + Returns: + str: The reason for deprecation if available, otherwise returns + an empty string. + """ + + reason = kwargs.get("reason", None) + if not isinstance(reason, self.string_types): + reason = "" + + return reason + + def get_version(self, **kwargs): + """Gets release number of latest stage in deprecation. + + Keyword Args: + version (str): Release number of latest stage in deprecation. + + Returns: + str: Formatted version number of latest stage in deprecation if + provided, otherwise returns an empty string. + """ + + version = kwargs.get("version", None) + if not isinstance(version, self.string_types): + version = "" + else: + version = f"Ver. {version}: " + + return version + + def raise_warning(self, obj, stage, details): + """Raises warning or error with informative message. + + Raises a warning or error stating the function or class' stage in + its deprecation cycle. Optionally, lists a release number and + reason. + + Args: + obj (Callable): The function or class being deprecated. + stage (str): The current stage in the deprecation cycle. + details (dict): A dictionary map optionally containing the + keys: + + - **reason**: the reason for deprecation. + - **version**: the release number of the latest + change in the deprecation cycle. + + Raises: + RuntimeError: The function is . + """ + + stage_string, warn_class = self.get_stage(name=stage) + reason = self.get_reason(**details) + version = self.get_version(**details) + suffix = " ".join((stage_string, reason)) + + if inspect.isclass(obj): + warn_string = f"{version}The class {obj.__name__} is {suffix}" + else: + warn_string = f"{version}The function {obj.__name__} is {suffix}" + + if stage.lower() != "defunct": + warnings.warn( + message=warn_string.strip(), + category=warn_class, + stacklevel=2, + ) + else: + raise RuntimeError(warn_string.strip()) + + def rename_arguments(self, obj, stage, kwargs, alias, reason=None, version=None): + """Marks argument as deprecated and redirects to alias. + + The wrapped function's arguments are wrapped into `kwargs` and + are safe from being overwritten by arguments in alias. + + Args: + obj (Callable): The function or class being deprecated. + stage (str): The current stage in the deprecation cycle. + kwargs (dict): Keyword arguments. + alias (dict): A dictionary map optionally containing the + keys: + + reason (str): The reason for deprecation. Default None. + version (str): The release number of the latest change in + the deprecation cycle. Default None. + """ + + stage_string, warn_class = self.get_stage(name=stage) + reason = self.get_reason(**{"reason": reason}) + version = self.get_version(**{"version": version}) + + for old, new in alias.items(): + if old in kwargs: + if new in kwargs: + warn_string = ( + f"{version}{obj.__name__}", + f"received both {old} and {new} as arguments. {old}", + f"is {stage_string}", + f"Use {new} instead.", + f"{reason}", + ) + raise TypeError(" ".join(warn_string).strip()) + else: + warn_string = ( + f"{version}The argument {old} in {obj.__name__}", + f"is {stage_string}", + f"Use {new} instead.", + f"{reason}", + ) + warnings.warn( + message=" ".join(warn_string).strip(), + category=warn_class, + stacklevel=2, + ) + kwargs[new] = kwargs.pop(old) + + +class Decorators: + def __init__(self): + super().__init__() + + @staticmethod + def deprecated(stage="deprecated", **details): + """Decorator for deprecated function and method arguments. + + Example: + + .. code-block:: python + + @deprecated(stage="pending", reason="Foobar", version="1.3.2") + def some_function(foo): + ... + + + Args: + stage(str): Stage of deprecation cycle. Valid values + are: + + - pending + - deprecated + - eol + - defunct + + Default "deprecated". + Keyword Args: + reason (str): Reason for deprecation. + version (str): Release number of latest stage in deprecation. + + Returns: + Callable: Decorator for deprecated argument. + """ + + internals = DeprecationHandler() + + def decorator(f: Callable): + @functools.wraps(f) + def wrapper(*args, **kwargs): + internals.raise_warning(f, stage, details) + return f(*args, **kwargs) + + return wrapper + + return decorator + + @staticmethod + def deprecated_argument(stage="deprecated", reason=None, version=None, **aliases): + """Decorator for deprecated function and method arguments. + + Use as follows: + + .. code-block:: python + + @deprecated_argument(old_argument="new_argument") + def myfunc(new_arg): + ... + + Args: + stage(str): Stage of deprecation cycle. Valid values + are: + + - pending + - deprecated + - eol + - defunct + + Default "deprecated". + reason (str): Reason for deprecation. Default None. + version (str): Release number of latest stage in + deprecation. Default None. + aliases (dict[str, str]): Deprecated argument and its + alternative. + + Returns: + Callable: Decorator for deprecated argument. + + """ + + internals = DeprecationHandler() + + def decorator(f: Callable): + @functools.wraps(f) + def wrapper(*args, **kwargs): + internals.rename_arguments( + obj=f, + stage=stage, + reason=reason, + version=version, + kwargs=kwargs, + alias=aliases, + ) + return f(*args, **kwargs) + + return wrapper + + return decorator diff --git a/src/scintillometry/metrics/calculations.py b/src/scintillometry/metrics/calculations.py index 4b044d1..6eea166 100755 --- a/src/scintillometry/metrics/calculations.py +++ b/src/scintillometry/metrics/calculations.py @@ -27,6 +27,7 @@ from scintillometry.backend.iterations import IterationMost from scintillometry.backend.transects import TransectParameters from scintillometry.visuals.plotting import FigurePlotter +from scintillometry.backend.deprecations import Decorators class MetricsTopography: @@ -824,20 +825,22 @@ def plot_derived_metrics(self, derived_data, time_id, regime=None, location=""): return derived_plots - def plot_iterated_metrics(self, iterated_data, time_stamp, site_location=""): + @Decorators.deprecated_argument( + stage="pending", version="1.0.5", site_location="location" + ) + def plot_iterated_metrics(self, iterated_data, time_stamp, location=""): """Plots and saves iterated SHF, comparison to free convection. .. todo:: - ST-126: Deprecate FigurePlotter.plot_iterated_fluxes in - favour of plot_iterated_metrics. `site_location` should - be deprecated in favour of `location`. + ST-126: Deprecate the argument `site_location` for + `location`. Args: iterated_data (pd.DataFrame): TZ-aware with columns for sensible heat fluxes calculated for free convection |H_free|, and by MOST |H|. time_stamp (pd.Timestamp): Local time of data collection. - site_location (str): Location of data collection. Default empty + location (str): Location of data collection. Default empty string. Returns: @@ -846,7 +849,7 @@ def plot_iterated_metrics(self, iterated_data, time_stamp, site_location=""): sensible heat flux under free convection. """ - shf_plot = self.plotting.plot_generic(iterated_data, "shf", site=site_location) + shf_plot = self.plotting.plot_generic(iterated_data, "shf", site=location) self.plotting.save_figure( figure=shf_plot[0], timestamp=time_stamp, suffix="shf" ) @@ -856,7 +859,7 @@ def plot_iterated_metrics(self, iterated_data, time_stamp, site_location=""): df_02=iterated_data, keys=["H_free", "shf"], labels=["Free Convection", "Iteration"], - site=site_location, + site=location, ) self.plotting.save_figure( figure=comparison_plot[0], timestamp=time_stamp, suffix="shf_comp" @@ -974,7 +977,7 @@ def calculate_standard_metrics( self.plot_iterated_metrics( iterated_data=iterated_dataframe, time_stamp=data_timestamp, - site_location=location, + location=location, ) data["derivation"] = derived_dataframe diff --git a/src/scintillometry/visuals/plotting.py b/src/scintillometry/visuals/plotting.py index 68f5ccf..2819747 100644 --- a/src/scintillometry/visuals/plotting.py +++ b/src/scintillometry/visuals/plotting.py @@ -25,6 +25,7 @@ import matplotlib.pyplot as plt import numpy as np import pandas as pd +from scintillometry.backend.deprecations import Decorators class FigureFormat: @@ -582,6 +583,11 @@ def plot_comparison(self, df_01, df_02, keys, labels, site=""): return figure, axes + @Decorators.deprecated( + stage="pending", + reason="Superseded by MetricsFlux.plot_iterated_metrics.", + version="1.0.5", + ) def plot_iterated_fluxes(self, iteration_data, time_id, location=""): """Plots and saves iterated SHF, comparison to free convection. diff --git a/tests/test_backend_deprecations.py b/tests/test_backend_deprecations.py new file mode 100644 index 0000000..8ed5e6a --- /dev/null +++ b/tests/test_backend_deprecations.py @@ -0,0 +1,543 @@ +"""Copyright 2023 Scintillometry Contributors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + https://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. + +===== + +This module is used to test deprecation decorators, not deprecating +functions. + +Use the `conftest_boilerplate` fixture to avoid duplicating tests. +""" + +import inspect +from typing import Callable + +import pytest + +from scintillometry.backend.deprecations import Decorators, DeprecationHandler + + +class TestBackendDeprecationsMock: + """Test mock class targeted by decorators.""" + + class MockDeprecationsClass: + """Mock class to call when testing decorators.""" + + def add_one(self, a: int): + """Adds 1 to input integer.""" + + a += 1 + + return a + + @pytest.mark.dependency( + name="TestBackendDeprecationsMock::test_mock_deprecations_class" + ) + def test_mock_deprecations_class(self): + """Create mock class for testing deprecation.""" + assert inspect.isclass(self.MockDeprecationsClass) + mock_class = self.MockDeprecationsClass + assert mock_class.__name__ == "MockDeprecationsClass" + mock_class_instance = self.MockDeprecationsClass() + assert mock_class_instance + + @pytest.mark.dependency( + name="TestBackendDeprecationsMock::test_add_one", + depends=["TestBackendDeprecationsMock::test_mock_deprecations_class"], + ) + def test_add_one(self): + """Add 1 to integer.""" + + mock_class_instance = self.MockDeprecationsClass() + test_integer = 3 + compare_integer = mock_class_instance.add_one(a=test_integer) + assert isinstance(compare_integer, int) + assert compare_integer == test_integer + 1 + + +class TestBackendDeprecationsHandler: + """Handles deprecation methods. + + Attributes: + test_stages (dict[str, tuple[str, Exception]]): Names of + deprecation labels and respective text and warning category. + test_details (dict[str, str]): Mock details for decorator. + test_mock (type): Mock class called when testing decorators. + test_mock_instance (MockDeprecationsClass): An instantiated mock + class called when testing methods. + test_handler (DeprecationHandler): Class called by decorator + marking deprecation. + """ + + test_stages = { + "pending": ("pending deprecation.", PendingDeprecationWarning), + "deprecated": ("deprecated.", DeprecationWarning), + "eol": ("deprecated.", FutureWarning), + "defunct": ("defunct.", RuntimeError), + } + test_details = {"version": "1.1.2", "reason": "Some reason."} + test_mock = TestBackendDeprecationsMock.MockDeprecationsClass + test_mock_instance = TestBackendDeprecationsMock.MockDeprecationsClass() + test_handler = DeprecationHandler() + + def setup_warning( + self, + obj: Callable, + stage: str = "deprecated", + reason: bool = False, + version: bool = False, + ): + """Creates regex string for warning message.""" + + test_details = {} + suffix = "" + prefix = "" + if reason: + test_details["reason"] = self.test_details["reason"] + suffix = test_details["reason"] + if version: + test_details["version"] = self.test_details["version"] + prefix = f"Ver. {test_details['version']}: " + + suffix_string = " ".join((self.test_stages[stage][0], suffix)) + if inspect.isclass(obj): + assert obj.__name__ == "MockDeprecationsClass" + object_string = f"class {obj.__name__}" + else: + object_string = f"function {obj.__name__}" + + regex = f"{prefix}The {object_string} is {suffix_string}".strip() + + return regex, prefix, suffix, test_details + + @pytest.mark.parametrize("arg_stage", ["pending", "deprecated"]) + @pytest.mark.parametrize("arg_class", [True, False]) + @pytest.mark.parametrize("arg_reason", [True, False]) + @pytest.mark.parametrize("arg_version", [True, False]) + @pytest.mark.dependency( + name="TestBackendDeprecationsHandler::test_setup_warning", + depends=["TestBackendDeprecationsMock::test_add_one"], + ) + def test_setup_warning(self, arg_class, arg_stage, arg_reason, arg_version): + """Create regex string to test warning message.""" + + if arg_class: + test_object = self.test_mock + assert isinstance(test_object, type) + assert inspect.isclass(test_object) + else: + test_object = self.test_mock_instance.add_one + assert not isinstance(test_object, type) + assert not inspect.isclass(test_object) + assert isinstance(test_object, Callable) + + warn_params = self.setup_warning( + obj=test_object, + stage=arg_stage, + reason=arg_reason, + version=arg_version, + ) + assert isinstance(warn_params[0], str) + if arg_version: + assert warn_params[1] == f"Ver. {self.test_details['version']}: " + assert "version" in warn_params[3] + else: + assert warn_params[1] == "" + assert "version" not in warn_params[3] + if arg_reason: + assert warn_params[2] == self.test_details["reason"] + assert "reason" in warn_params[3] + else: + assert warn_params[2] == "" + assert "reason" not in warn_params[3] + + def get_mock_object(self, is_class: bool = False): + """Gets mock class or function.""" + + if is_class: + obj = self.test_mock + else: + obj = self.test_mock_instance.add_one + + return obj + + @pytest.mark.dependency(name="TestBackendDeprecationsHandler::test_get_mock_object") + @pytest.mark.parametrize("arg_class", [True, False]) + def test_get_mock_object(self, arg_class): + """Get mock class or function.""" + + compare_object = self.get_mock_object(is_class=arg_class) + if arg_class: + assert inspect.isclass(compare_object) + else: + assert not inspect.isclass(compare_object) + + @pytest.mark.dependency(name="TestBackendDeprecationsHandler::test_get_stage_error") + def test_get_stage_error(self): + """Raise error for incorrect or missing stage argument.""" + + test_stage = "incorrect stage" + with pytest.raises( + ValueError, match=f"{test_stage} is not a valid deprecation stage." + ): + self.test_handler.get_stage(name=test_stage) + with pytest.raises( + TypeError, match=f"{int} is an invalid type. Use {str} instead." + ): + self.test_handler.get_stage(name=1) + + @pytest.mark.dependency( + name="TestBackendDeprecationsHandler::test_get_stage", + depends=["TestBackendDeprecationsHandler::test_get_stage_error"], + ) + def test_get_stage(self): + """Get deprecation stage and warning category.""" + + for key, value in self.test_stages.items(): + compare_stage, compare_warning = self.test_handler.get_stage(name=key) + assert compare_stage == value[0] + assert compare_warning == value[1] + + @pytest.mark.dependency( + name="TestBackendDeprecationsHandler::test_get_reason_invalid" + ) + def test_get_reason_invalid(self): + """Return empty string for missing value or invalid type.""" + + test_kwargs = {"stage": "pending"} + compare_reason = self.test_handler.get_reason(**test_kwargs) + assert compare_reason == "" + compare_reason = self.test_handler.get_reason(reason=1) + assert compare_reason == "" + + @pytest.mark.dependency( + name="TestBackendDeprecationsHandler::test_get_reason", + depends=["TestBackendDeprecationsHandler::test_get_reason_invalid"], + ) + def test_get_reason(self): + """Get reason for deprecation.""" + + test_kwargs = {"stage": "pending", "reason": "Some reason."} + + compare_reason = self.test_handler.get_reason(**test_kwargs) + assert compare_reason == test_kwargs["reason"] + compare_reason = self.test_handler.get_reason(reason=test_kwargs["reason"]) + assert compare_reason == test_kwargs["reason"] + + @pytest.mark.dependency( + name="TestBackendDeprecationsHandler::test_get_version_invalid" + ) + def test_get_version_invalid(self): + """Return empty string for missing version or invalid type.""" + + test_kwargs = {"stage": "pending"} + compare_reason = self.test_handler.get_version(**test_kwargs) + assert compare_reason == "" + compare_reason = self.test_handler.get_version(reason=1) + assert compare_reason == "" + + @pytest.mark.dependency( + name="TestBackendDeprecationsHandler::test_get_version", + depends=["TestBackendDeprecationsHandler::test_get_version_invalid"], + ) + def test_get_version(self): + """Get release number of last deprecation update.""" + + test_kwargs = {"stage": "pending", "version": "1.1.0"} + + compare_version = self.test_handler.get_version(**test_kwargs) + assert compare_version == f"Ver. {test_kwargs['version']}: " + compare_version = self.test_handler.get_version(version=test_kwargs["version"]) + assert compare_version == f"Ver. {test_kwargs['version']}: " + + @pytest.mark.dependency( + name="TestBackendDeprecationsHandler::test_raise_warning", + depends=["TestBackendDeprecationsHandler::test_setup_warning"], + ) + @pytest.mark.parametrize("arg_stage", ["pending", "deprecated", "eol"]) + @pytest.mark.parametrize("arg_class", [True, False]) + @pytest.mark.parametrize("arg_reason", [True, False]) + @pytest.mark.parametrize("arg_version", [True, False]) + def test_raise_warning(self, arg_class, arg_stage, arg_reason, arg_version): + """Raise warning for deprecated object.""" + + assert arg_stage in self.test_stages + test_object = self.get_mock_object(is_class=arg_class) + test_warning = self.setup_warning( + obj=test_object, stage=arg_stage, reason=arg_reason, version=arg_version + ) + + with pytest.warns(self.test_stages[arg_stage][1], match=test_warning[0]): + self.test_handler.raise_warning( + obj=test_object, stage=arg_stage, details=test_warning[3] + ) + + @pytest.mark.dependency( + name="TestBackendDeprecationsHandler::test_raise_warning_error", + depends=["TestBackendDeprecationsHandler::test_setup_warning"], + ) + @pytest.mark.parametrize("arg_class", [True, False]) + @pytest.mark.parametrize("arg_reason", [True, False]) + @pytest.mark.parametrize("arg_version", [True, False]) + def test_raise_warning_error(self, arg_reason, arg_version, arg_class): + """Raise RuntimeError when stage="defunct.""" + + assert "defunct" in self.test_stages + test_object = self.get_mock_object(is_class=arg_class) + test_warning = self.setup_warning( + obj=test_object, stage="defunct", reason=arg_reason, version=arg_version + ) + + with pytest.raises(self.test_stages["defunct"][1], match=test_warning[0]): + self.test_handler.raise_warning( + obj=test_object, stage="defunct", details=test_warning[3] + ) + + @pytest.mark.dependency( + name="TestBackendDeprecationsHandler::test_rename_arguments", + depends=["TestBackendDeprecationsHandler::test_setup_warning"], + ) + @pytest.mark.parametrize("arg_stage", ["pending", "deprecated", "eol"]) + @pytest.mark.parametrize("arg_reason", [True, False]) + @pytest.mark.parametrize("arg_version", [True, False]) + def test_rename_arguments(self, arg_stage, arg_reason, arg_version): + """Raise warning and redirect deprecated argument.""" + + assert arg_stage in self.test_stages + test_object = self.get_mock_object(is_class=False) + test_warning = self.setup_warning( + obj=test_object, stage=arg_stage, reason=arg_reason, version=arg_version + ) + test_alias_old = "a" + test_alias_new = "b" + test_regex = ( + f"{test_warning[1]}The argument {test_alias_old} in {test_object.__name__}", + f"is {self.test_stages[arg_stage][0]}", + f"Use {test_alias_new} instead.", + f"{test_warning[3].get('reason','')}", + ) + test_kwargs = {test_alias_old: 2, "extra_arg": 3} + test_alias = {test_alias_old: test_alias_new} + + with pytest.warns( + self.test_stages[arg_stage][1], match=" ".join(test_regex).strip() + ): + self.test_handler.rename_arguments( + obj=test_object, + stage=arg_stage, + kwargs=test_kwargs, + alias=test_alias, + version=test_warning[3].get("version", None), + reason=test_warning[3].get("reason", None), + ) + assert test_alias_old not in test_kwargs # argument is replaced + assert test_alias_new in test_kwargs + assert test_kwargs[test_alias_new] == 2 + assert "extra_arg" in test_kwargs + + @pytest.mark.dependency( + name="TestBackendDeprecationsHandler::test_rename_arguments_error", + depends=["TestBackendDeprecationsHandler::test_setup_warning"], + ) + @pytest.mark.parametrize("arg_stage", ["pending", "deprecated", "eol"]) + @pytest.mark.parametrize("arg_reason", [True, False]) + @pytest.mark.parametrize("arg_version", [True, False]) + def test_rename_arguments_error(self, arg_stage, arg_reason, arg_version): + """Raise warning and redirect deprecated argument.""" + + assert arg_stage in self.test_stages + test_object = self.get_mock_object(is_class=False) + test_warning = self.setup_warning( + obj=test_object, stage=arg_stage, reason=arg_reason, version=arg_version + ) + test_alias_old = "a" + test_alias_new = "b" + test_regex = ( + f"{test_warning[1]}{test_object.__name__}", + f"received both {test_alias_old} and {test_alias_new} as arguments.", + f"{test_alias_old} is {self.test_stages[arg_stage][0]} " + f"Use {test_alias_new} instead.", + f"{test_warning[3].get('reason','')}", + ) + test_kwargs = {test_alias_old: 1, test_alias_new: 3, "extra_arg": 2} + test_alias = {test_alias_old: test_alias_new} + + with pytest.raises(TypeError, match=" ".join(test_regex).strip()): + self.test_handler.rename_arguments( + obj=test_object, + stage=arg_stage, + kwargs=test_kwargs, + alias=test_alias, + version=test_warning[3].get("version", None), + reason=test_warning[3].get("reason", None), + ) + assert test_alias_old in test_kwargs + assert test_alias_new in test_kwargs + assert test_kwargs[test_alias_new] == 3 # argument is not replaced + assert "extra_arg" in test_kwargs + + @pytest.mark.dependency( + name="TestBackendDeprecationsHandler::test_rename_arguments_missing", + depends=["TestBackendDeprecationsHandler::test_setup_warning"], + ) + @pytest.mark.parametrize("arg_stage", ["pending", "deprecated", "eol"]) + @pytest.mark.parametrize("arg_reason", [True, False]) + @pytest.mark.parametrize("arg_version", [True, False]) + def test_rename_arguments_missing(self, arg_stage, arg_reason, arg_version): + """Raise warning and redirect deprecated argument.""" + + assert arg_stage in self.test_stages + test_object = self.get_mock_object(is_class=False) + test_warning = self.setup_warning( + obj=test_object, stage=arg_stage, reason=arg_reason, version=arg_version + ) + test_alias_old = "a" + test_alias_new = "b" + test_kwargs = {test_alias_new: 3, "extra_arg": 2} + test_alias = {test_alias_old: test_alias_new} + + self.test_handler.rename_arguments( + obj=test_object, + stage=arg_stage, + kwargs=test_kwargs, + alias=test_alias, + version=test_warning[3].get("version", None), + reason=test_warning[3].get("reason", None), + ) + assert test_alias_old not in test_kwargs + assert test_alias_new in test_kwargs + assert test_kwargs[test_alias_new] == 3 # argument is not replaced + assert "extra_arg" in test_kwargs + + +class TestBackendDeprecationsDecorator(TestBackendDeprecationsHandler): + """Tests decorators marking deprecated objects.""" + + @pytest.mark.dependency( + name="TestBackendDeprecationsDecorator::test_decorators_init" + ) + def test_decorators_init(self): + test_instance = Decorators() + assert test_instance + assert inspect.isclass(Decorators) + assert inspect.isfunction(Decorators.deprecated) + + @pytest.mark.dependency( + name="TestBackendDeprecationsDecorator::test_deprecated_decorator", + depends=["TestBackendDeprecationsMock::test_add_one"], + ) + @pytest.mark.parametrize("arg_stage", ["pending", "deprecated", "eol"]) + @pytest.mark.parametrize("arg_reason", [True, False]) + @pytest.mark.parametrize("arg_version", [True, False]) + def test_deprecated_decorator(self, arg_stage, arg_reason, arg_version): + """Use decorator to raise warning.""" + + assert arg_stage in self.test_stages + test_details = {} + if arg_version: + test_details["version"] = self.test_details["version"] + if arg_reason: + test_details["reason"] = self.test_details["reason"] + + @Decorators.deprecated(stage=arg_stage, **test_details) + def deprecated_one(x): # pragma: no cover + b = self.test_mock_instance.add_one(a=x) + return b + + test_warning = self.setup_warning( + obj=deprecated_one, stage=arg_stage, reason=arg_reason, version=arg_version + ) + + with pytest.warns(self.test_stages[arg_stage][1], match=test_warning[0]): + y = deprecated_one(x=1) + assert y == 2 + + @pytest.mark.dependency( + name="TestBackendDeprecationsDecorator::test_deprecated_error", + depends=[ + "TestBackendDeprecationsMock::test_add_one", + "TestBackendDeprecationsHandler::test_raise_warning", + "TestBackendDeprecationsHandler::test_raise_warning_error", + ], + ) + @pytest.mark.parametrize("arg_reason", [True, False]) + @pytest.mark.parametrize("arg_version", [True, False]) + def test_deprecated_error(self, arg_reason, arg_version): + """Raise RuntimeError when stage="defunct.""" + + assert "defunct" in self.test_stages + test_details = {} + if arg_version: + test_details["version"] = self.test_details["version"] + if arg_reason: + test_details["reason"] = self.test_details["reason"] + + @Decorators.deprecated(stage="defunct", **test_details) + def deprecated_one(x): # pragma: no cover + b = self.test_mock_instance.add_one(a=x) + return b + + test_warning = self.setup_warning( + obj=deprecated_one, stage="defunct", reason=arg_reason, version=arg_version + ) + + with pytest.raises( + self.test_stages["defunct"][1], + match=test_warning[0], + ): + deprecated_one(x=1) + + @pytest.mark.dependency( + name="TestBackendDeprecationsDecorator::test_deprecated_argument_decorator", + depends=["TestBackendDeprecationsMock::test_add_one"], + ) + @pytest.mark.parametrize("arg_stage", ["pending", "deprecated", "eol"]) + @pytest.mark.parametrize("arg_reason", [True, False]) + @pytest.mark.parametrize("arg_version", [True, False]) + def test_deprecated_argument_decorator(self, arg_stage, arg_reason, arg_version): + """Use decorator to raise warning.""" + + assert arg_stage in self.test_stages + test_details = {} + if arg_version: + test_details["version"] = self.test_details["version"] + if arg_reason: + test_details["reason"] = self.test_details["reason"] + test_alias_new = "a" + + @Decorators.deprecated_argument( + stage=arg_stage, + reason=test_details.get("reason", ""), + version=test_details.get("version", ""), + x=test_alias_new, + ) + def deprecated_one(a): # pragma: no cover + b = self.test_mock_instance.add_one(a=a) + return b + + test_warning = self.setup_warning( + obj=deprecated_one, stage=arg_stage, reason=arg_reason, version=arg_version + ) + test_regex = ( + f"{test_warning[1]}The argument x in {deprecated_one.__name__}", + f"is {self.test_stages[arg_stage][0]}", + f"Use {test_alias_new} instead.", + f"{test_warning[3].get('reason','')}", + ) + + with pytest.warns( + self.test_stages[arg_stage][1], match=" ".join(test_regex).strip() + ): + y = deprecated_one(x=1) # pylint:disable=no-value-for-parameter + assert y == 2 diff --git a/tests/test_metrics_calculations.py b/tests/test_metrics_calculations.py index 468ed2c..98e75c0 100644 --- a/tests/test_metrics_calculations.py +++ b/tests/test_metrics_calculations.py @@ -908,7 +908,7 @@ def test_plot_iterated_metrics( compare_plots = self.test_metrics.plot_iterated_metrics( iterated_data=test_frame, time_stamp=test_stamp, - site_location=arg_location, + location=arg_location, ) assert isinstance(compare_plots, list) assert all(isinstance(compare_tuple, tuple) for compare_tuple in compare_plots) @@ -932,6 +932,20 @@ def test_plot_iterated_metrics( plt.close("all") + # site_location is pending deprecation + with pytest.warns(PendingDeprecationWarning): + compare_plots = self.test_metrics.plot_iterated_metrics( + iterated_data=test_frame, + time_stamp=test_stamp, + site_location=arg_location, # pylint:disable=unexpected-keyword-arg + ) + assert isinstance(compare_plots, list) + assert all(isinstance(compare_tuple, tuple) for compare_tuple in compare_plots) + for params in compare_params.values(): + conftest_boilerplate.check_plot(plot_params=params, title=test_title) + + plt.close("all") + @pytest.mark.dependency( name="TestMetricsFlux::test_iterate_fluxes", depends=["TestMetricsFlux::test_append_vertical_variables"], diff --git a/tests/test_visuals_plotting.py b/tests/test_visuals_plotting.py index 7fd3f4c..043ba21 100644 --- a/tests/test_visuals_plotting.py +++ b/tests/test_visuals_plotting.py @@ -595,11 +595,12 @@ def test_plot_iterated_fluxes( test_title = f"{test_location},\n{self.test_date}" timestamp = test_data.index[0] - compare_plots = self.test_plotting.plot_iterated_fluxes( - iteration_data=test_data, - time_id=timestamp, - location=arg_location, - ) + with pytest.deprecated_call(): + compare_plots = self.test_plotting.plot_iterated_fluxes( + iteration_data=test_data, + time_id=timestamp, + location=arg_location, + ) assert isinstance(compare_plots, list) assert all(isinstance(compare_tuple, tuple) for compare_tuple in compare_plots) From 70f9d9832ee4383dc9690128ec6951b326c5cdb6 Mon Sep 17 00:00:00 2001 From: gampnico <45390064+gampnico@users.noreply.github.com> Date: Mon, 29 May 2023 16:24:09 +0200 Subject: [PATCH 16/17] build: toml license refers to name instead of file Refs: ST-1, ST-2 --- pyproject.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pyproject.toml b/pyproject.toml index 7ff6d7a..ee0c88c 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -11,7 +11,7 @@ authors = [ description = "Analyse data & 2D flux footprints from Scintec's BLS scintillometers." readme = "README.md" requires-python = ">=3.8" -license = {file = "LICENSE"} +license = Apache-2.0 classifiers = [ "Private :: Do Not Upload", "License :: OSI Approved :: Apache Software License", From c948a21661057343ff49513cb09ff226f8996595 Mon Sep 17 00:00:00 2001 From: gampnico <45390064+gampnico@users.noreply.github.com> Date: Mon, 29 May 2023 16:26:35 +0200 Subject: [PATCH 17/17] fix: incorrect toml syntax Refs: ST-1 --- pyproject.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pyproject.toml b/pyproject.toml index ee0c88c..4ddb4c2 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -11,7 +11,7 @@ authors = [ description = "Analyse data & 2D flux footprints from Scintec's BLS scintillometers." readme = "README.md" requires-python = ">=3.8" -license = Apache-2.0 +license = "Apache-2.0" classifiers = [ "Private :: Do Not Upload", "License :: OSI Approved :: Apache Software License",