Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

MAINT: using common test primitive for banned primitives check in locations #1867

Merged
Merged
Show file tree
Hide file tree
Changes from 5 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
38 changes: 29 additions & 9 deletions onedal/tests/test_common.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,26 +15,46 @@
# ==============================================================================

import os
import pkgutil
samir-nasibli marked this conversation as resolved.
Show resolved Hide resolved
from glob import glob

import pytest


def test_sklearn_check_version_ban():
"""This test blocks the use of sklearn_check_version
in onedal files. The versioning should occur in the
sklearnex package for clarity and maintainability.
def _check_primitive_usage_ban(primitive_name, package, allowed_locations=None):
"""This test blocks the usage of the primitive in
in certain files.
"""
from onedal import __file__ as loc

# TODO:
# Address deprecation warning.
# The function "get_loader" is deprecated Use importlib.util.find_spec() instead.
# Will be removed in Python 3.14.
loc = pkgutil.get_loader(package).get_filename()
samir-nasibli marked this conversation as resolved.
Show resolved Hide resolved

path = loc.replace("__init__.py", "")
files = [y for x in os.walk(path) for y in glob(os.path.join(x[0], "*.py"))]

output = []

for f in files:
if open(f, "r").read().find("sklearn_check_version") != -1:
output += [f.replace(path, "onedal" + os.sep)]
if open(f, "r").read().find(primitive_name) != -1:
output += [f.replace(path, package + os.sep)]

# remove this file from the list
if allowed_locations:
for allowed in allowed_locations:
output = [i for i in output if allowed not in i]

return output


def test_sklearn_check_version_ban():
"""This test blocks the use of sklearn_check_version
in onedal files. The versioning should occur in the
sklearnex package for clarity and maintainability.
"""
output = _check_primitive_usage_ban(
primitive_name="sklearn_check_version", package="onedal"
)

# remove this file from the list
output = "\n".join([i for i in output if "test_common.py" not in i])
Expand Down
29 changes: 8 additions & 21 deletions sklearnex/tests/test_common.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,15 +21,12 @@
import re
import sys
import trace
from glob import glob

import numpy as np
import pytest
import scipy
import sklearn.utils.validation
samir-nasibli marked this conversation as resolved.
Show resolved Hide resolved
from sklearn.utils import all_estimators

from daal4py.sklearn._utils import sklearn_check_version
from onedal.tests.test_common import _check_primitive_usage_ban
from sklearnex.tests.utils import (
PATCHED_MODELS,
SPECIAL_INSTANCES,
Expand All @@ -38,7 +35,7 @@
gen_models_info,
)

ALLOWED_LOCATIONS = [
TARGET_OFFLOAD_ALLOWED_LOCATIONS = [
"_config.py",
"_device_offload.py",
"test",
Expand Down Expand Up @@ -109,23 +106,13 @@ def test_target_offload_ban():
within the architecture of the sklearnex classes. This
is for clarity, traceability and maintainability.
"""
from sklearnex import __file__ as loc

path = loc.replace("__init__.py", "")
files = [y for x in os.walk(path) for y in glob(os.path.join(x[0], "*.py"))]

output = []

for f in files:
if open(f, "r").read().find("target_offload") != -1:
output += [f.replace(path, "sklearnex" + os.sep)]

# remove this file from the list
for allowed in ALLOWED_LOCATIONS:
output = [i for i in output if allowed not in i]

output = _check_primitive_usage_ban(
primitive_name="target_offload",
package="sklearnex",
allowed_locations=TARGET_OFFLOAD_ALLOWED_LOCATIONS,
)
output = "\n".join(output)
assert output == "", f"sklearn versioning is occuring in: \n{output}"
assert output == "", f"target offloading is occuring in: \n{output}"
samir-nasibli marked this conversation as resolved.
Show resolved Hide resolved


def _sklearnex_walk(func):
Expand Down