Skip to content

Commit

Permalink
MAINT: using common test primitive for banned primitives check in loc…
Browse files Browse the repository at this point in the history
…ations (#1867)

* TEST: moved checks for banned primitives into certain folder, out of pkgs

* Generalized for further use of checker for banned primitives

* some fixes

* refactoring
  • Loading branch information
samir-nasibli authored Oct 23, 2024
1 parent f350c0d commit 3936cb9
Show file tree
Hide file tree
Showing 2 changed files with 33 additions and 30 deletions.
34 changes: 25 additions & 9 deletions onedal/tests/test_common.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,27 +14,43 @@
# limitations under the License.
# ==============================================================================

import importlib
import os
from glob import glob

import pytest


def test_sklearn_check_version_ban():
"""This test blocks the use of sklearn_check_version
in onedal files. The versioning should occur in the
sklearnex package for clarity and maintainability.
def _check_primitive_usage_ban(primitive_name, package, allowed_locations=None):
"""This test blocks the usage of the primitive in
in certain files.
"""
from onedal import __file__ as loc

loc = importlib.util.find_spec(package).origin

path = loc.replace("__init__.py", "")
files = [y for x in os.walk(path) for y in glob(os.path.join(x[0], "*.py"))]

output = []

for f in files:
if open(f, "r").read().find("sklearn_check_version") != -1:
output += [f.replace(path, "onedal" + os.sep)]
if open(f, "r").read().find(primitive_name) != -1:
output += [f.replace(path, package + os.sep)]

# remove this file from the list
if allowed_locations:
for allowed in allowed_locations:
output = [i for i in output if allowed not in i]

return output


def test_sklearn_check_version_ban():
"""This test blocks the use of sklearn_check_version
in onedal files. The versioning should occur in the
sklearnex package for clarity and maintainability.
"""
output = _check_primitive_usage_ban(
primitive_name="sklearn_check_version", package="onedal"
)

# remove this file from the list
output = "\n".join([i for i in output if "test_common.py" not in i])
Expand Down
29 changes: 8 additions & 21 deletions sklearnex/tests/test_common.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,15 +21,12 @@
import re
import sys
import trace
from glob import glob

import numpy as np
import pytest
import scipy
import sklearn.utils.validation
from sklearn.utils import all_estimators

from daal4py.sklearn._utils import sklearn_check_version
from onedal.tests.test_common import _check_primitive_usage_ban
from sklearnex.tests.utils import (
PATCHED_MODELS,
SPECIAL_INSTANCES,
Expand All @@ -38,7 +35,7 @@
gen_models_info,
)

ALLOWED_LOCATIONS = [
TARGET_OFFLOAD_ALLOWED_LOCATIONS = [
"_config.py",
"_device_offload.py",
"test",
Expand Down Expand Up @@ -109,23 +106,13 @@ def test_target_offload_ban():
within the architecture of the sklearnex classes. This
is for clarity, traceability and maintainability.
"""
from sklearnex import __file__ as loc

path = loc.replace("__init__.py", "")
files = [y for x in os.walk(path) for y in glob(os.path.join(x[0], "*.py"))]

output = []

for f in files:
if open(f, "r").read().find("target_offload") != -1:
output += [f.replace(path, "sklearnex" + os.sep)]

# remove this file from the list
for allowed in ALLOWED_LOCATIONS:
output = [i for i in output if allowed not in i]

output = _check_primitive_usage_ban(
primitive_name="target_offload",
package="sklearnex",
allowed_locations=TARGET_OFFLOAD_ALLOWED_LOCATIONS,
)
output = "\n".join(output)
assert output == "", f"sklearn versioning is occuring in: \n{output}"
assert output == "", f"target offloading is occuring in: \n{output}"


def _sklearnex_walk(func):
Expand Down

0 comments on commit 3936cb9

Please sign in to comment.