Skip to content

Commit

Permalink
Refactoring test_unsupported_object
Browse files Browse the repository at this point in the history
  • Loading branch information
SiddhantSadangi committed Jan 18, 2024
1 parent 04eb8b5 commit 549a798
Showing 1 changed file with 10 additions and 16 deletions.
26 changes: 10 additions & 16 deletions tests/test_e2e.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,19 +4,15 @@
init_run,
)
except ImportError:
from neptune.new import init_run, Run
from neptune.new import Run, init_run

import pytest
from sklearn import datasets
from sklearn.cluster import KMeans
from sklearn.dummy import (
DummyClassifier,
DummyRegressor,
)
from sklearn.model_selection import (
GridSearchCV,
train_test_split,
)
from sklearn.model_selection import GridSearchCV

import neptune_sklearn as npt_utils

Expand Down Expand Up @@ -60,28 +56,26 @@ def test_kmeans_summary(iris):


@pytest.mark.filterwarnings("error::neptune.common.warnings.NeptuneUnsupportedType")
def test_unsupported_object():
def test_unsupported_object(diabetes):
"""This method checks if Neptune throws a `NeptuneUnsupportedType` warning if expected metadata
is not found or skips trying to log such metadata"""

with init_run() as run:

X, y = datasets.load_diabetes(return_X_y=True)
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.5)

model = DummyRegressor()
model.fit(diabetes.x_train, diabetes.y_train)

param_grid = {
"strategy": ["mean", "median", "quantile"],
"quantile": [0.1, 0.5, 1.0],
}

X, y = datasets.fetch_california_housing(return_X_y=True)[:10]
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2)

grid_cv = GridSearchCV(model, param_grid, scoring="neg_mean_absolute_error", cv=2).fit(X_train, y_train)
grid_cv = GridSearchCV(model, param_grid, scoring="neg_mean_absolute_error", cv=2).fit(
diabetes.x_train, diabetes.y_train
)

run["regressor_summary"] = npt_utils.create_regressor_summary(grid_cv, X_train, X_test, y_train, y_test)
run["regressor_summary"] = npt_utils.create_regressor_summary(
grid_cv, diabetes.x_train, diabetes.x_test, diabetes.y_train, diabetes.y_test
)

run.wait()

Expand Down

0 comments on commit 549a798

Please sign in to comment.