From 71504fa4f155809dd8a93b973c69707c598e70d7 Mon Sep 17 00:00:00 2001 From: Deathn0t Date: Tue, 8 Oct 2024 14:30:26 +0200 Subject: [PATCH] adding pytest mark for db requirement --- publications/2023-neurips/pytest.ini | 4 +- .../2023-neurips/test/test_exctraction.py | 55 +++++++++++-------- 2 files changed, 34 insertions(+), 25 deletions(-) diff --git a/publications/2023-neurips/pytest.ini b/publications/2023-neurips/pytest.ini index cd61f8d..ef2aa17 100644 --- a/publications/2023-neurips/pytest.ini +++ b/publications/2023-neurips/pytest.ini @@ -1,4 +1,6 @@ [pytest] -norecursedirs = .git +norecursedirs = .git +markers = + db: marks to define a test that requires a local database from LCDB data. filterwarnings = ignore:The objective has been evaluated at this point before.:UserWarning \ No newline at end of file diff --git a/publications/2023-neurips/test/test_exctraction.py b/publications/2023-neurips/test/test_exctraction.py index 2510ffe..f9b3823 100644 --- a/publications/2023-neurips/test/test_exctraction.py +++ b/publications/2023-neurips/test/test_exctraction.py @@ -1,23 +1,28 @@ -from lcdb.db import LCDB - -from parameterized import parameterized +import pytest import unittest import numpy as np from lcdb.analysis.util import LearningCurveExtractor, merge_curves +from lcdb.db import LCDB +from parameterized import parameterized +@pytest.mark.db class TestExtractors(unittest.TestCase): - @parameterized.expand([ - (6, "lcdb.workflow.sklearn.KNNWorkflow", 0, 0, 42), - (3, "lcdb.workflow.sklearn.LibLinearWorkflow", 0, 0, 42), - (3, "lcdb.workflow.sklearn.LibSVMWorkflow", 0, 0, 42), - (6, "lcdb.workflow.sklearn.TreesEnsembleWorkflow", 0, 1, 42) - ]) - def test_learning_curve_extraction(self, openmlid, workflow, val_seed, test_seed, workflow_seed): + @parameterized.expand( + [ + (6, "lcdb.workflow.sklearn.KNNWorkflow", 0, 0, 42), + (3, "lcdb.workflow.sklearn.LibLinearWorkflow", 0, 0, 42), + (3, "lcdb.workflow.sklearn.LibSVMWorkflow", 0, 0, 42), + (6, "lcdb.workflow.sklearn.TreesEnsembleWorkflow", 0, 1, 42), + ] + ) + def test_learning_curve_extraction( + self, openmlid, workflow, val_seed, test_seed, workflow_seed + ): - metrics = ["error_rate"]#, "balanced_error_rate"] + metrics = ["error_rate"] # , "balanced_error_rate"] lcdb = LCDB() df = lcdb.query( @@ -29,11 +34,10 @@ def test_learning_curve_extraction(self, openmlid, workflow, val_seed, test_seed workflow_seeds=[workflow_seed], processors={ "learning_curve": LearningCurveExtractor( - metrics=metrics, - folds=["train", "val", "test", "oob"] + metrics=metrics, folds=["train", "val", "test", "oob"] ) }, - show_progress=True + show_progress=True, ) oob_fold_expected = "TreesEnsembleWorkflow" in workflow @@ -53,12 +57,14 @@ def test_learning_curve_extraction(self, openmlid, workflow, val_seed, test_seed self.assertTrue(not oob_fold_expected or num_oob_not_nan > 0) - @parameterized.expand([ - #(6, "lcdb.workflow.sklearn.KNNWorkflow"), - (3, "lcdb.workflow.sklearn.LibLinearWorkflow"), - (3, "lcdb.workflow.sklearn.LibSVMWorkflow"), - #(6, "lcdb.workflow.sklearn.TreesEnsembleWorkflow") - ]) + @parameterized.expand( + [ + # (6, "lcdb.workflow.sklearn.KNNWorkflow"), + (3, "lcdb.workflow.sklearn.LibLinearWorkflow"), + (3, "lcdb.workflow.sklearn.LibSVMWorkflow"), + # (6, "lcdb.workflow.sklearn.TreesEnsembleWorkflow") + ] + ) def test_learning_curve_grouping_after_extraction(self, openmlid, workflow): lcdb = LCDB() @@ -71,16 +77,17 @@ def test_learning_curve_grouping_after_extraction(self, openmlid, workflow): validation_seeds=validation_seeds, processors={ "learning_curve": LearningCurveExtractor( - metrics=["error_rate"], - folds=["train", "val", "test", "oob"] + metrics=["error_rate"], folds=["train", "val", "test", "oob"] ) }, - show_progress=True + show_progress=True, ) if df is not None: config_cols = [c for c in df.columns if c.startswith("p:")] len_before = len(df) - len_after = len(df.groupby(config_cols).agg({"learning_curve": merge_curves})) + len_after = len( + df.groupby(config_cols).agg({"learning_curve": merge_curves}) + ) self.assertEqual(len_before, len_after * len(validation_seeds))