Skip to content

Update scikit-learn requirement in SklearnIntegration #3551

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 6 commits into from
Apr 22, 2025
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions src/zenml/integrations/sklearn/__init__.py
Original file line number Diff line number Diff line change
@@ -21,12 +21,12 @@ class SklearnIntegration(Integration):
"""Definition of sklearn integration for ZenML."""

NAME = SKLEARN
REQUIREMENTS = ["scikit-learn", "scikit-image"]
REQUIREMENTS = ["scikit-learn"]

@classmethod
def activate(cls) -> None:
"""Activates the integration."""
from zenml.integrations.sklearn import materializers # noqa

SklearnIntegration.check_installation()

SklearnIntegration.check_installation()
26 changes: 17 additions & 9 deletions tests/integration/examples/mlflow/steps/dynamic_importer_step.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
"""This step downloads the latest data from a mock API and returns it as a numpy array."""

# Copyright (c) ZenML GmbH 2022. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
@@ -14,13 +16,17 @@
import numpy as np # type: ignore [import]
import pandas as pd # type: ignore [import]
import requests # type: ignore [import]
from skimage.transform import resize
from typing_extensions import Annotated

from zenml import step


def get_data_from_api():
def get_data_from_api() -> Annotated[np.ndarray, "api_data"]:
"""Downloads the latest data from a mock API.
Returns:
Annotated[np.ndarray, "api_data"]: Downsampled image data as a numpy array.
"""
url = (
"https://storage.googleapis.com/zenml-public-bucket/mnist"
"/mnist_handwritten_test.json"
@@ -30,12 +36,10 @@ def get_data_from_api():
data = df["image"].map(lambda x: np.array(x)).values
data = np.array(
[
resize(
x.reshape(28, 28).astype("uint8"),
(8, 8),
anti_aliasing=False,
preserve_range=True,
)
# Pad the image to 32x32 to enable downsampling to 8x8
np.pad(x.reshape(28, 28).astype("float64"), 2)[
::4, ::4
] # Downsample to 8x8 by taking every 4th pixel
for x in data
]
)
@@ -44,6 +48,10 @@ def get_data_from_api():

@step(enable_cache=False)
def dynamic_importer() -> Annotated[np.ndarray, "data"]:
"""Downloads the latest data from a mock API."""
"""Downloads the latest data from a mock API.
Returns:
Annotated[np.ndarray, "data"]: Downsampled image data as a numpy array.
"""
data = get_data_from_api()
return data