From bb182aaf8667c0002f7d96585b86a3a15fc050d8 Mon Sep 17 00:00:00 2001 From: Vyacheslav Morov Date: Tue, 20 Aug 2024 15:42:13 +0200 Subject: [PATCH] Fix missing feature_type in ColumnDriftMetric. (#1258) --- setup.py | 4 ++-- .../metrics/data_drift/column_drift_metric.py | 4 ++-- tests/test_setup.py | 15 +++++++++++++-- 3 files changed, 17 insertions(+), 6 deletions(-) diff --git a/setup.py b/setup.py index dcee2fc1c0..5b9c514d76 100644 --- a/setup.py +++ b/setup.py @@ -55,8 +55,8 @@ "statsmodels>=0.12.2", "scikit-learn>=1.0.1", "pandas[parquet]>=1.3.5", - "numpy>=1.22.0", - "nltk>=3.6.7", + "numpy>=1.22.0,<2.1", + "nltk>=3.6.7,<=3.8.1", "scipy>=1.10.0", "requests>=2.32.0", "PyYAML>=5.4", diff --git a/src/evidently/metrics/data_drift/column_drift_metric.py b/src/evidently/metrics/data_drift/column_drift_metric.py index a8242e5ae6..bfd6bfeca7 100644 --- a/src/evidently/metrics/data_drift/column_drift_metric.py +++ b/src/evidently/metrics/data_drift/column_drift_metric.py @@ -283,8 +283,8 @@ def calculate(self, data: InputData) -> ColumnDataDriftMetrics: if self.column_name.is_main_dataset(): column_type = data.data_definition.get_column(self.column_name.name).column_type else: - if self.column_name._feature_class is not None: - column_type = self.column_name._feature_class.feature_type + if self.column_name.feature_class is not None: + column_type = self.column_name.feature_class.get_type(self.column_name.name) datetime_column = data.data_definition.get_datetime_column() options = DataDriftOptions(all_features_stattest=self.stattest, threshold=self.stattest_threshold) diff --git a/tests/test_setup.py b/tests/test_setup.py index c87918ed2d..8a306c39c8 100644 --- a/tests/test_setup.py +++ b/tests/test_setup.py @@ -7,9 +7,16 @@ def test_minimal_requirements(): path = Path(__file__).parent.parent with open(path / "requirements.min.txt") as f: lines = {line.strip().split("#")[0] for line in f.readlines()} - min_reqs = {k.split("[")[0]: v for line in lines if line.strip() for k, v in (line.strip().split("=="),)} + min_reqs = { + k.split("[")[0]: _get_min_version(v) + for line in lines + if line.strip() + for k, v in (line.strip().split("=="),) + } - install_reqs = {k.split("[")[0]: v for r in setup_args["install_requires"] for k, v in (r.split(">="),)} + install_reqs = { + k.split("[")[0]: _get_min_version(v) for r in setup_args["install_requires"] for k, v in (r.split(">="),) + } extra = [] wrong_version = [] for m, v in install_reqs.items(): @@ -23,3 +30,7 @@ def test_minimal_requirements(): assert ( len(extra) == 0 and len(wrong_version) == 0 ), f"install_requires has extra reqs {extra} and wrong versions of {wrong_version}" + + +def _get_min_version(value): + return value.split(",")[0]