-
Notifications
You must be signed in to change notification settings - Fork 104
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
CU-8693n892x environment/dependency snapshots (#438)
* CU-8693n892x: Save environment/dependency snapshot upon model pack creation * CU-8693n892x: Fix typing for env snapshot module * CU-8693n892x: Add test for env file existance in .zip * CU-8693n892x: Add doc strings * CU-8693n892x: Centralise env snapshot file name * CU-8693n892x: Add env snapshot file to exceptions in serialisation tests * CU-8693n892x: Only list direct dependencies * CU-8693n892x: Add test that verifies all direct dependencies are listed in environment * CU-8693n892x: Move requirements to separate file and use that for environment snapshot * CU-8693n892x: Remove unused constants * CU-8693n892x: Allow URL based dependencies when using direct dependencies * CU-8693n892x: Distribute install_requires.txt alongside the package; use correct path in distributed version
- Loading branch information
Showing
6 changed files
with
225 additions
and
27 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,24 @@ | ||
'numpy>=1.22.0,<1.26.0' # 1.22.0 is first to support python 3.11; post 1.26.0 there's issues with scipy | ||
'pandas>=1.4.2' # first to support 3.11 | ||
'gensim>=4.3.0,<5.0.0' # 4.3.0 is first to support 3.11; avoid major version bump | ||
'spacy>=3.6.0,<4.0.0' # Some later model packs (e.g HPO) are made with 3.6.0 spacy model; avoid major version bump | ||
'scipy~=1.9.2' # 1.9.2 is first to support 3.11 | ||
'transformers>=4.34.0,<5.0.0' # avoid major version bump | ||
'accelerate>=0.23.0' # required by Trainer class in de-id | ||
'torch>=1.13.0,<3.0.0' # 1.13 is first to support 3.11; 2.1.2 has been compatible, but avoid major 3.0.0 for now | ||
'tqdm>=4.27' | ||
'scikit-learn>=1.1.3,<2.0.0' # 1.1.3 is first to support 3.11; avoid major version bump | ||
'dill>=0.3.6,<1.0.0' # stuff saved in 0.3.6/0.3.7 is not always compatible with 0.3.4/0.3.5; avoid major bump | ||
'datasets>=2.2.2,<3.0.0' # avoid major bump | ||
'jsonpickle>=2.0.0' # allow later versions, tested with 3.0.0 | ||
'psutil>=5.8.0' | ||
# 0.70.12 uses older version of dill (i.e less than 0.3.5) which is required for datasets | ||
'multiprocess~=0.70.12' # 0.70.14 seemed to work just fine | ||
'aiofiles>=0.8.0' # allow later versions, tested with 22.1.0 | ||
'ipywidgets>=7.6.5' # allow later versions, tested with 0.8.0 | ||
'xxhash>=3.0.0' # allow later versions, tested with 3.1.0 | ||
'blis>=0.7.5' # allow later versions, tested with 0.7.9 | ||
'click>=8.0.4' # allow later versions, tested with 8.1.3 | ||
'pydantic>=1.10.0,<2.0' # for spacy compatibility; avoid 2.0 due to breaking changes | ||
"humanfriendly~=10.0" # for human readable file / RAM sizes | ||
"peft>=0.8.2" |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,73 @@ | ||
from typing import List, Dict, Any, Set | ||
|
||
import os | ||
import re | ||
import pkg_resources | ||
import platform | ||
|
||
|
||
# File name of the environment snapshot.
ENV_SNAPSHOT_FILE_NAME = "environment_snapshot.json"

# Location of install_requires.txt relative to this module in a source
# checkout: three directories up.
INSTALL_REQUIRES_FILE_PATH = os.path.join(
    os.path.dirname(__file__), *([os.pardir] * 3), "install_requires.txt")
# NOTE: During the build, install_requires.txt is copied into the wheel so
#       that it ships with the distributed package. It then sits in the root
#       of the package rather than the root of the project, i.e. one folder
#       closer to this module.
INSTALL_REQUIRES_FILE_PATH_PIP = os.path.join(
    os.path.dirname(__file__), *([os.pardir] * 2), "install_requires.txt")
|
||
|
||
def get_direct_dependencies() -> Set[str]:
    """Get the set of direct dependency names.

    The current implementation reads install_requires.txt, strips comments,
    whitespace and quotes from every line, and keeps only the leading
    package name. Everything from the first version specifier (including
    ``!=``), extras bracket, environment marker, direct-reference ``@`` or
    whitespace onwards is dropped.

    Returns:
        Set[str]: The set of direct dependency names.

    Raises:
        OSError: If the requirements file cannot be opened at either of
            the two known locations.
    """
    req_file = INSTALL_REQUIRES_FILE_PATH
    if not os.path.exists(req_file):
        # Fall back to the location used when the package is pip-installed
        req_file = INSTALL_REQUIRES_FILE_PATH_PIP
    names: Set[str] = set()
    with open(req_file, encoding="utf-8") as f:
        for line in f:
            # drop the trailing comment, then the quotes around the spec
            dep = line.split("#")[0].replace("'", "").replace('"', "").strip()
            if not dep:
                # comment-only or empty line
                continue
            # The name ends at the first character that can start a version
            # specifier, extras, a marker, a URL reference or a space.
            name = re.split(r"[\s@<=>~!;\[(]", dep, maxsplit=1)[0]
            if name:
                names.add(name)
    return names
|
||
|
||
def get_installed_packages() -> List[List[str]]:
    """Get the installed direct dependencies and their versions.

    Only distributions whose name matches one of the direct dependencies
    are listed. Names are compared in normalised form (lower case, runs of
    ``-``, ``_`` and ``.`` treated as equal) so that a difference in
    spelling between the requirements file and the installed metadata does
    not silently drop a dependency.

    Returns:
        List[List[str]]: List of lists. Each item consists of a dependency
            name (as reported by the installed distribution) and version.
    """
    def _canonical(name: str) -> str:
        # PEP 503 style normalisation of a project name
        return re.sub(r"[-_.]+", "-", name).lower()

    wanted = {_canonical(dep) for dep in get_direct_dependencies()}
    return [
        [package.project_name, package.version]
        for package in pkg_resources.working_set
        if _canonical(package.project_name) in wanted
    ]
|
||
|
||
def get_environment_info() -> Dict[str, Any]:
    """Collect a snapshot of the current environment.

    Returns:
        Dict[str, Any]: A dict with the keys "dependencies" (the result of
            get_installed_packages()), "os", "cpu_architecture" and
            "python_version" (all three taken from the platform module).
    """
    env_info: Dict[str, Any] = {}
    env_info["dependencies"] = get_installed_packages()
    env_info["os"] = platform.platform()
    env_info["cpu_architecture"] = platform.machine()
    env_info["python_version"] = platform.python_version()
    return env_info
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,105 @@ | ||
from typing import Any | ||
import platform | ||
import os | ||
import tempfile | ||
import json | ||
import zipfile | ||
|
||
from medcat.cat import CAT | ||
from medcat.utils.saving import envsnapshot | ||
|
||
import unittest | ||
|
||
|
||
def list_zip_contents(zip_file_path):
    """Return the member names of the zip archive at ``zip_file_path``."""
    archive = zipfile.ZipFile(zip_file_path, 'r')
    try:
        return archive.namelist()
    finally:
        # always release the file handle, even if listing fails
        archive.close()
|
||
|
||
class DirectDependenciesTests(unittest.TestCase):
    """Checks on the dependency names returned by get_direct_dependencies."""

    def setUp(self) -> None:
        self.direct_deps = envsnapshot.get_direct_dependencies()

    def test_nonempty(self):
        self.assertTrue(self.direct_deps)

    def test_does_not_contain_versions(self, version_starters: str = '<=>~'):
        # every (dependency, specifier character) pair is its own sub-test
        pairs = ((dep, vs) for dep in self.direct_deps for vs in version_starters)
        for dep, vs in pairs:
            with self.subTest(f"DEP '{dep}' check for '{vs}'"):
                self.assertNotIn(vs, dep)

    def test_deps_are_installed_packages(self):
        for dep in self.direct_deps:
            with self.subTest(f"Has '{dep}'"):
                # the sub-test fails if require() raises for this name
                envsnapshot.pkg_resources.require(dep)
|
||
|
||
class EnvSnapshotAloneTests(unittest.TestCase):
    """Checks on the dict returned by get_environment_info."""

    def setUp(self) -> None:
        self.env_info = envsnapshot.get_environment_info()

    def test_info_is_dict(self):
        self.assertIsInstance(self.env_info, dict)

    def test_info_is_not_empty(self):
        self.assertTrue(self.env_info)

    def assert_has_target(self, target: str, expected: Any):
        """Assert that ``target`` is a key of the env info mapping to ``expected``."""
        self.assertIn(target, self.env_info)
        actual = self.env_info[target]
        self.assertEqual(actual, expected)

    def test_has_os(self):
        self.assert_has_target("os", platform.platform())

    def test_has_py_ver(self):
        self.assert_has_target("python_version", platform.python_version())

    def test_has_cpu_arch(self):
        self.assert_has_target("cpu_architecture", platform.machine())

    def test_has_dependencies(self, name: str = "dependencies"):
        # NOTE: only verifies that the entry exists and is non-empty
        self.assertIn(name, self.env_info)
        self.assertTrue(self.env_info[name])

    def test_all_direct_dependencies_are_installed(self):
        installed = self.env_info['dependencies']
        direct_deps = envsnapshot.get_direct_dependencies()
        # expect one snapshot entry per direct dependency
        self.assertEqual(len(installed), len(direct_deps))
|
||
|
||
# Folder named "examples", three directories above this test file
CAT_PATH = os.path.join(os.path.dirname(__file__), *([os.pardir] * 3), "examples")
# Local alias for the snapshot file name defined in the envsnapshot module
ENV_SNAPSHOT_FILE_NAME = envsnapshot.ENV_SNAPSHOT_FILE_NAME
|
||
|
||
class EnvSnapshotInCATTests(unittest.TestCase):
    """Checks that creating a model pack stores the environment snapshot."""
    # Environment info computed when the class body runs; compared against
    # what the created model pack saved.
    expected_env = envsnapshot.get_environment_info()

    @classmethod
    def setUpClass(cls) -> None:
        cls.cat = CAT.load_model_pack(CAT_PATH)
        cls._temp_dir = tempfile.TemporaryDirectory()
        # create_model_pack returns the model pack name, used as folder name
        mpn = cls.cat.create_model_pack(cls._temp_dir.name)
        cls.cat_folder = os.path.join(cls._temp_dir.name, mpn)
        cls.envrion_file_path = os.path.join(cls.cat_folder, ENV_SNAPSHOT_FILE_NAME)

    @classmethod
    def tearDownClass(cls) -> None:
        # Remove the temporary model pack explicitly rather than leaving
        # it to the TemporaryDirectory finalizer.
        cls._temp_dir.cleanup()

    def test_has_environment(self):
        self.assertTrue(os.path.exists(self.envrion_file_path))

    def test_eviron_saved(self):
        with open(self.envrion_file_path) as f:
            saved_info: dict = json.load(f)
        self.assertEqual(saved_info.keys(), self.expected_env.keys())
        for k in saved_info:
            with self.subTest(k):
                v1, v2 = saved_info[k], self.expected_env[k]
                self.assertEqual(v1, v2)

    def test_zip_has_env_snapshot(self):
        # the zipped model pack is expected next to the unpacked folder
        filenames = list_zip_contents(self.cat_folder + ".zip")
        self.assertIn(ENV_SNAPSHOT_FILE_NAME, filenames)
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters