Skip to content

Commit

Permalink
Merge pull request #21 from neptune-ai/ss/handle_NeptuneUnsupportedTy…
Browse files Browse the repository at this point in the history
…pe_error

Supressing `NeptuneUnsupportedType` warning if expected metadata not found
  • Loading branch information
SiddhantSadangi authored Jan 17, 2024
2 parents 1e91d11 + bb1b212 commit 0a96ba1
Show file tree
Hide file tree
Showing 4 changed files with 81 additions and 27 deletions.
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
## [UNRELEASED] neptune-sklearn 2.1.1

### Fixes
- `create_*_summary()` now does not throw a `NeptuneUnsupportedType` error if expected metadata is not found ([#21](https://github.com/neptune-ai/neptune-sklearn/pull/21))
- Fixed method names in docstrings ([#18](https://github.com/neptune-ai/neptune-sklearn/pull/18))

## neptune-sklearn 2.1.0
Expand Down
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,7 @@ name = "neptune-sklearn"
readme = "README.md"
version = "0.0.0"
classifiers = [
"Development Status :: 4 - Beta",
"Development Status :: 5 - Production/Stable",
"Environment :: Console",
"Intended Audience :: Developers",
"Intended Audience :: Science/Research",
Expand Down
71 changes: 46 additions & 25 deletions src/neptune_sklearn/impl/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -110,8 +110,8 @@ def create_regressor_summary(regressor, X_train, X_test, y_train, y_test, nrows=
log_charts (`bool`, optional): Whether to calculate and log chart visualizations.
Note: Calculating visualizations is potentially expensive depending on input data and regressor,
and may take some time to finish. This is equivalent to calling the following functions from
this module: `create_learning_curve_chart()`, `create_feature_importance_chart()`, `create_residuals_chart()`,
`create_prediction_error_chart()`, and `create_cooks_distance_chart()`.
this module: `create_learning_curve_chart()`, `create_feature_importance_chart()`,
`create_residuals_chart()`, `create_prediction_error_chart()`, and `create_cooks_distance_chart()`.
Returns:
`dict` with all summary items.
Expand Down Expand Up @@ -144,17 +144,26 @@ def create_regressor_summary(regressor, X_train, X_test, y_train, y_test, nrows=
"scores": get_scores(regressor, X_test, y_test, y_pred=y_pred),
}

if log_charts:
reg_summary["diagnostics_charts"] = {
"learning_curve": create_learning_curve_chart(regressor, X_train, y_train),
"feature_importance": create_feature_importance_chart(regressor, X_train, y_train),
"residuals": create_residuals_chart(regressor, X_train, X_test, y_train, y_test),
"prediction_error": create_prediction_error_chart(regressor, X_train, X_test, y_train, y_test),
"cooks_distance": create_cooks_distance_chart(regressor, X_train, y_train),
}

reg_summary["integration/about/neptune-sklearn"] = __version__

if log_charts:
learning_curve = create_learning_curve_chart(regressor, X_train, y_train)
feature_importance = create_feature_importance_chart(regressor, X_train, y_train)
residuals = create_residuals_chart(regressor, X_train, X_test, y_train, y_test)
prediction_error = create_prediction_error_chart(regressor, X_train, X_test, y_train, y_test)
cooks_distance = create_cooks_distance_chart(regressor, X_train, y_train)

if learning_curve:
reg_summary["diagnostics_charts/learning_curve"] = learning_curve
if feature_importance:
reg_summary["diagnostics_charts/feature_importance"] = feature_importance
if residuals:
reg_summary["diagnostics_charts/residuals"] = residuals
if prediction_error:
reg_summary["diagnostics_charts/prediction_error"] = prediction_error
if cooks_distance:
reg_summary["diagnostics_charts/cooks_distance"] = cooks_distance

return reg_summary


Expand Down Expand Up @@ -217,17 +226,26 @@ def create_classifier_summary(classifier, X_train, X_test, y_train, y_test, nrow
"scores": get_scores(classifier, X_test, y_test, y_pred=y_pred),
}

if log_charts:
cls_summary["diagnostics_charts"] = {
"classification_report": create_classification_report_chart(classifier, X_train, X_test, y_train, y_test),
"confusion_matrix": create_confusion_matrix_chart(classifier, X_train, X_test, y_train, y_test),
"ROC_AUC": create_roc_auc_chart(classifier, X_train, X_test, y_train, y_test),
"precision_recall": create_precision_recall_chart(classifier, X_test, y_test),
"class_prediction_error": create_class_prediction_error_chart(classifier, X_train, X_test, y_train, y_test),
}

cls_summary["integration/about/neptune-sklearn"] = __version__

if log_charts:
classification_report = create_classification_report_chart(classifier, X_train, X_test, y_train, y_test)
confusion_matrix = create_confusion_matrix_chart(classifier, X_train, X_test, y_train, y_test)
roc_auc = create_roc_auc_chart(classifier, X_train, X_test, y_train, y_test)
precision_recall = create_precision_recall_chart(classifier, X_test, y_test)
class_prediction_error = create_class_prediction_error_chart(classifier, X_train, X_test, y_train, y_test)

if classification_report:
cls_summary["diagnostics_charts/classification_report"] = classification_report
if confusion_matrix:
cls_summary["diagnostics_charts/confusion_matrix"] = confusion_matrix
if roc_auc:
cls_summary["diagnostics_charts/ROC_AUC"] = roc_auc
if precision_recall:
cls_summary["diagnostics_charts/precision_recall"] = precision_recall
if class_prediction_error:
cls_summary["diagnostics_charts/class_prediction_error"] = class_prediction_error

return cls_summary


Expand Down Expand Up @@ -272,13 +290,16 @@ def create_kmeans_summary(model, X, nrows=1000, **kwargs):
kmeans_summary["all_params"] = stringify_unsupported(get_estimator_params(model))
kmeans_summary["pickled_model"] = get_pickled_model(model)
kmeans_summary["cluster_labels"] = get_cluster_labels(model, X, nrows=nrows, **kwargs)
kmeans_summary["diagnostics_charts"] = {
"kelbow": create_kelbow_chart(model, X, **kwargs),
"silhouette": create_silhouette_chart(model, X, **kwargs),
}

kmeans_summary["integration/about/neptune-sklearn"] = __version__

kelbow = create_kelbow_chart(model, X, **kwargs)
silhouette = create_silhouette_chart(model, X, **kwargs)

if kelbow:
kmeans_summary["diagnostics_charts/kelbow"] = kelbow
if silhouette:
kmeans_summary["diagnostics_charts/silhouette"] = silhouette

return kmeans_summary


Expand Down
34 changes: 33 additions & 1 deletion tests/test_e2e.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,13 +3,18 @@
except ImportError:
from neptune.new import init_run

import pytest
from sklearn import datasets
from sklearn.cluster import KMeans
from sklearn.dummy import DummyRegressor
from sklearn.linear_model import (
LinearRegression,
LogisticRegression,
)
from sklearn.model_selection import train_test_split
from sklearn.model_selection import (
GridSearchCV,
train_test_split,
)

import neptune_sklearn as npt_utils

Expand Down Expand Up @@ -61,6 +66,33 @@ def test_kmeans_summary():
validate_run(run, log_charts=True)


@pytest.mark.filterwarnings("error::neptune.common.warnings.NeptuneUnsupportedType")
def test_unsupported_object():
"""This method checks if Neptune throws a `NeptuneUnsupportedType` warning if expected metadata
is not found or skips trying to log such metadata"""

with init_run() as run:

X, y = datasets.load_diabetes(return_X_y=True)
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.5)

model = DummyRegressor()

param_grid = {
"strategy": ["mean", "median", "quantile"],
"quantile": [0.1, 0.5, 1.0],
}

X, y = datasets.fetch_california_housing(return_X_y=True)[:10]
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2)

grid_cv = GridSearchCV(model, param_grid, scoring="neg_mean_absolute_error", cv=2).fit(X_train, y_train)

run["regressor_summary"] = npt_utils.create_regressor_summary(grid_cv, X_train, X_test, y_train, y_test)

run.wait()


def validate_run(run, log_charts):
assert run.exists("summary/all_params")
assert run.exists("summary/pickled_model")
Expand Down

0 comments on commit 0a96ba1

Please sign in to comment.