Skip to content

Commit

Permalink
Refactored test
Browse files Browse the repository at this point in the history
  • Loading branch information
SiddhantSadangi committed Jan 16, 2024
1 parent be80c06 commit 4e9d2cd
Showing 1 changed file with 17 additions and 16 deletions.
33 changes: 17 additions & 16 deletions tests/test_e2e.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,10 @@
LinearRegression,
LogisticRegression,
)
from sklearn.model_selection import train_test_split
from sklearn.model_selection import (
GridSearchCV,
train_test_split,
)

import neptune_sklearn as npt_utils

Expand Down Expand Up @@ -67,28 +70,26 @@ def test_unsupported_object():
"""This method checks if Neptune throws a `NeptuneUnsupportedType` warning if expected metadata
is not found or skips trying to log such metadata"""

from sklearn.model_selection import GridSearchCV
with init_run() as run:

run = init_run()
X, y = datasets.load_diabetes(return_X_y=True)
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.5)

X, y = datasets.load_diabetes(return_X_y=True)
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.5)
model = LinearRegression()

model = LinearRegression()

param_grid = {
"copy_X": [True, False],
"fit_intercept": [True, False],
}
param_grid = {
"copy_X": [True, False],
"fit_intercept": [True, False],
}

X, y = datasets.fetch_california_housing(return_X_y=True)[:10]
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2)
X, y = datasets.fetch_california_housing(return_X_y=True)[:10]
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2)

grid_cv = GridSearchCV(model, param_grid, scoring="neg_mean_absolute_error", cv=2).fit(X_train, y_train)
grid_cv = GridSearchCV(model, param_grid, scoring="neg_mean_absolute_error", cv=2).fit(X_train, y_train)

run["regressor_summary"] = npt_utils.create_regressor_summary(grid_cv, X_train, X_test, y_train, y_test)
run["regressor_summary"] = npt_utils.create_regressor_summary(grid_cv, X_train, X_test, y_train, y_test)

run.wait()
run.wait()


def validate_run(run, log_charts):
Expand Down

0 comments on commit 4e9d2cd

Please sign in to comment.