Skip to content

Commit

Permalink
CU-8693n892x environment/dependency snapshots (#438)
Browse files Browse the repository at this point in the history
* CU-8693n892x: Save environment/dependency snapshot upon model pack creation

* CU-8693n892x: Fix typing for env snapshot module

* CU-8693n892x: Add test for env file existance in .zip

* CU-8693n892x: Add doc strings

* CU-8693n892x: Centralise env snapshot file name

* CU-8693n892x: Add env snapshot file to exceptions in serialisation tests

* CU-8693n892x: Only list direct dependencies

* CU-8693n892x: Add test that verifies all direct dependencies are listed in environment

* CU-8693n892x: Move requirements to separate file and use that for environment snapshot

* CU-8693n892x: Remove unused constants

* CU-8693n892x: Allow URL based dependencies when using direct dependencies

* CU-8693n892x: Distribute install_requires.txt alongside the package; use correct path in distributed version
  • Loading branch information
mart-r authored Jun 19, 2024
1 parent e11c1da commit e4715ae
Show file tree
Hide file tree
Showing 6 changed files with 225 additions and 27 deletions.
24 changes: 24 additions & 0 deletions install_requires.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1,24 @@
'numpy>=1.22.0,<1.26.0' # 1.22.0 is first to support python 3.11; post 1.26.0 there's issues with scipy
'pandas>=1.4.2' # first to support 3.11
'gensim>=4.3.0,<5.0.0' # 5.3.0 is first to support 3.11; avoid major version bump
'spacy>=3.6.0,<4.0.0' # Some later model packs (e.g HPO) are made with 3.6.0 spacy model; avoid major version bump
'scipy~=1.9.2' # 1.9.2 is first to support 3.11
'transformers>=4.34.0,<5.0.0' # avoid major version bump
'accelerate>=0.23.0' # required by Trainer class in de-id
'torch>=1.13.0,<3.0.0' # 1.13 is first to support 3.11; 2.1.2 has been compatible, but avoid major 3.0.0 for now
'tqdm>=4.27'
'scikit-learn>=1.1.3,<2.0.0' # 1.1.3 is first to supporrt 3.11; avoid major version bump
'dill>=0.3.6,<1.0.0' # stuff saved in 0.3.6/0.3.7 is not always compatible with 0.3.4/0.3.5; avoid major bump
'datasets>=2.2.2,<3.0.0' # avoid major bump
'jsonpickle>=2.0.0' # allow later versions, tested with 3.0.0
'psutil>=5.8.0'
# 0.70.12 uses older version of dill (i.e less than 0.3.5) which is required for datasets
'multiprocess~=0.70.12' # 0.70.14 seemed to work just fine
'aiofiles>=0.8.0' # allow later versions, tested with 22.1.0
'ipywidgets>=7.6.5' # allow later versions, tested with 0.8.0
'xxhash>=3.0.0' # allow later versions, tested with 3.1.0
'blis>=0.7.5' # allow later versions, tested with 0.7.9
'click>=8.0.4' # allow later versions, tested with 8.1.3
'pydantic>=1.10.0,<2.0' # for spacy compatibility; avoid 2.0 due to breaking changes
"humanfriendly~=10.0" # for human readable file / RAM sizes
"peft>=0.8.2"
7 changes: 7 additions & 0 deletions medcat/cat.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,7 @@
from medcat.utils.decorators import deprecated
from medcat.ner.transformers_ner import TransformersNER
from medcat.utils.saving.serializer import SPECIALITY_NAMES, ONE2MANY
from medcat.utils.saving.envsnapshot import get_environment_info, ENV_SNAPSHOT_FILE_NAME
from medcat.stats.stats import get_stats
from medcat.utils.filters import set_project_filters

Expand Down Expand Up @@ -318,6 +319,12 @@ def create_model_pack(self, save_dir_path: str, model_pack_name: str = DEFAULT_M
with open(model_card_path, 'w') as f:
json.dump(self.get_model_card(as_dict=True), f, indent=2)

# add a dependency snapshot
env_info = get_environment_info()
env_info_path = os.path.join(save_dir_path, ENV_SNAPSHOT_FILE_NAME)
with open(env_info_path, 'w') as f:
json.dump(env_info, f)

# Zip everything
shutil.make_archive(os.path.join(_save_dir_path, model_pack_name), 'zip', root_dir=save_dir_path)

Expand Down
73 changes: 73 additions & 0 deletions medcat/utils/saving/envsnapshot.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,73 @@
from typing import List, Dict, Any, Set

import os
import re
import pkg_resources
import platform


ENV_SNAPSHOT_FILE_NAME = "environment_snapshot.json"

INSTALL_REQUIRES_FILE_PATH = os.path.join(os.path.dirname(__file__),
"..", "..", "..",
"install_requires.txt")
# NOTE: The install_requires.txt file is copied into the wheel during build
# so that it can be included in the distributed package.
# However, that means it's 1 folder closer to this file since it'll now
# be in the root of the package rather than the root of the project.
INSTALL_REQUIRES_FILE_PATH_PIP = os.path.join(os.path.dirname(__file__),
"..", "..",
"install_requires.txt")


def get_direct_dependencies() -> Set[str]:
"""Get the set of direct dependeny names.
The current implementation reads install_requires.txt for dependenceies,
removes comments, whitespace, quotes; removes the versions and returns
the names as a set.
Returns:
Set[str]: The set of direct dependeny names.
"""
req_file = INSTALL_REQUIRES_FILE_PATH
if not os.path.exists(req_file):
# When pip-installed. See note above near constant definiation
req_file = INSTALL_REQUIRES_FILE_PATH_PIP
with open(req_file) as f:
# read every line, strip quotes and comments
dep_lines = [line.split("#")[0].replace("'", "").replace('"', "").strip() for line in f.readlines()]
# remove comment-only (or empty) lines
deps = [dep for dep in dep_lines if dep]
return set(re.split("[@<=>~]", dep)[0].strip() for dep in deps)


def get_installed_packages() -> List[List[str]]:
"""Get the installed packages and their versions.
Returns:
List[List[str]]: List of lists. Each item contains of a dependency name and version.
"""
direct_deps = get_direct_dependencies()
installed_packages = []
for package in pkg_resources.working_set:
if package.project_name not in direct_deps:
continue
installed_packages.append([package.project_name, package.version])
return installed_packages


def get_environment_info() -> Dict[str, Any]:
"""Get the current environment information.
This includes dependency versions, the OS, the CPU architecture and the python version.
Returns:
Dict[str, Any]: _description_
"""
return {
"dependencies": get_installed_packages(),
"os": platform.platform(),
"cpu_architecture": platform.machine(),
"python_version": platform.python_version()
}
39 changes: 13 additions & 26 deletions setup.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,18 @@
import setuptools
import shutil

with open("./README.md", "r") as fh:
long_description = fh.read()

# make a copy of install requirements so that it gets distributed with the wheel
shutil.copy('install_requires.txt', 'medcat/install_requires.txt')

with open("install_requires.txt") as f:
# read every line, strip quotes and comments
dep_lines = [l.split("#")[0].replace("'", "").replace('"', "").strip() for l in f.readlines()]
# remove comment-only (or empty) lines
install_requires = [dep for dep in dep_lines if dep]


setuptools.setup(
name="medcat",
Expand All @@ -17,32 +27,9 @@
packages=['medcat', 'medcat.utils', 'medcat.preprocessing', 'medcat.ner', 'medcat.linking', 'medcat.datasets',
'medcat.tokenizers', 'medcat.utils.meta_cat', 'medcat.pipeline', 'medcat.utils.ner', 'medcat.utils.relation_extraction',
'medcat.utils.saving', 'medcat.utils.regression', 'medcat.stats'],
install_requires=[
'numpy>=1.22.0,<1.26.0', # 1.22.0 is first to support python 3.11; post 1.26.0 there's issues with scipy
'pandas>=1.4.2', # first to support 3.11
'gensim>=4.3.0,<5.0.0', # 5.3.0 is first to support 3.11; avoid major version bump
'spacy>=3.6.0,<4.0.0', # Some later model packs (e.g HPO) are made with 3.6.0 spacy model; avoid major version bump
'scipy~=1.9.2', # 1.9.2 is first to support 3.11
'transformers>=4.34.0,<5.0.0', # avoid major version bump
'accelerate>=0.23.0', # required by Trainer class in de-id
'torch>=1.13.0,<3.0.0', # 1.13 is first to support 3.11; 2.1.2 has been compatible, but avoid major 3.0.0 for now
'tqdm>=4.27',
'scikit-learn>=1.1.3,<2.0.0', # 1.1.3 is first to supporrt 3.11; avoid major version bump
'dill>=0.3.6,<1.0.0', # stuff saved in 0.3.6/0.3.7 is not always compatible with 0.3.4/0.3.5; avoid major bump
'datasets>=2.2.2,<3.0.0', # avoid major bump
'jsonpickle>=2.0.0', # allow later versions, tested with 3.0.0
'psutil>=5.8.0',
# 0.70.12 uses older version of dill (i.e less than 0.3.5) which is required for datasets
'multiprocess~=0.70.12', # 0.70.14 seemed to work just fine
'aiofiles>=0.8.0', # allow later versions, tested with 22.1.0
'ipywidgets>=7.6.5', # allow later versions, tested with 0.8.0
'xxhash>=3.0.0', # allow later versions, tested with 3.1.0
'blis>=0.7.5', # allow later versions, tested with 0.7.9
'click>=8.0.4', # allow later versions, tested with 8.1.3
'pydantic>=1.10.0,<2.0', # for spacy compatibility; avoid 2.0 due to breaking changes
"humanfriendly~=10.0", # for human readable file / RAM sizes
"peft>=0.8.2", # allow later versions, tested with 0.10.0
],
install_requires=install_requires,
include_package_data=True,
package_data={"medcat": ["install_requires.txt"]},
classifiers=[
"Programming Language :: Python :: 3",
"Programming Language :: Python :: 3.8",
Expand Down
105 changes: 105 additions & 0 deletions tests/utils/saving/test_envsnapshot.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,105 @@
from typing import Any
import platform
import os
import tempfile
import json
import zipfile

from medcat.cat import CAT
from medcat.utils.saving import envsnapshot

import unittest


def list_zip_contents(zip_file_path):
with zipfile.ZipFile(zip_file_path, 'r') as zip_ref:
return zip_ref.namelist()


class DirectDependenciesTests(unittest.TestCase):

def setUp(self) -> None:
self.direct_deps = envsnapshot.get_direct_dependencies()

def test_nonempty(self):
self.assertTrue(self.direct_deps)

def test_does_not_contain_versions(self, version_starters: str = '<=>~'):
for dep in self.direct_deps:
for vs in version_starters:
with self.subTest(f"DEP '{dep}' check for '{vs}'"):
self.assertNotIn(vs, dep)

def test_deps_are_installed_packages(self):
for dep in self.direct_deps:
with self.subTest(f"Has '{dep}'"):
envsnapshot.pkg_resources.require(dep)


class EnvSnapshotAloneTests(unittest.TestCase):

def setUp(self) -> None:
self.env_info = envsnapshot.get_environment_info()

def test_info_is_dict(self):
self.assertIsInstance(self.env_info, dict)

def test_info_is_not_empty(self):
self.assertTrue(self.env_info)

def assert_has_target(self, target: str, expected: Any):
self.assertIn(target, self.env_info)
py_ver = self.env_info[target]
self.assertEqual(py_ver, expected)

def test_has_os(self):
self.assert_has_target("os", platform.platform())

def test_has_py_ver(self):
self.assert_has_target("python_version", platform.python_version())

def test_has_cpu_arch(self):
self.assert_has_target("cpu_architecture", platform.machine())

def test_has_dependencies(self, name: str = "dependencies"):
# NOTE: just making sure it's a anon-empty list
self.assertIn(name, self.env_info)
deps = self.env_info[name]
self.assertTrue(deps)

def test_all_direct_dependencies_are_installed(self):
deps = self.env_info['dependencies']
direct_deps = envsnapshot.get_direct_dependencies()
self.assertEqual(len(deps), len(direct_deps))


CAT_PATH = os.path.join(os.path.dirname(__file__), "..", "..", "..", "examples")
ENV_SNAPSHOT_FILE_NAME = envsnapshot.ENV_SNAPSHOT_FILE_NAME


class EnvSnapshotInCATTests(unittest.TestCase):
expected_env = envsnapshot.get_environment_info()

@classmethod
def setUpClass(cls) -> None:
cls.cat = CAT.load_model_pack(CAT_PATH)
cls._temp_dir = tempfile.TemporaryDirectory()
mpn = cls.cat.create_model_pack(cls._temp_dir.name)
cls.cat_folder = os.path.join(cls._temp_dir.name, mpn)
cls.envrion_file_path = os.path.join(cls.cat_folder, ENV_SNAPSHOT_FILE_NAME)

def test_has_environment(self):
self.assertTrue(os.path.exists(self.envrion_file_path))

def test_eviron_saved(self):
with open(self.envrion_file_path) as f:
saved_info: dict = json.load(f)
self.assertEqual(saved_info.keys(), self.expected_env.keys())
for k in saved_info:
with self.subTest(k):
v1, v2 = saved_info[k], self.expected_env[k]
self.assertEqual(v1, v2)

def test_zip_has_env_snapshot(self):
filenames = list_zip_contents(self.cat_folder + ".zip")
self.assertIn(ENV_SNAPSHOT_FILE_NAME, filenames)
4 changes: 3 additions & 1 deletion tests/utils/saving/test_serialization.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
from medcat.vocab import Vocab

from medcat.utils.saving.serializer import JsonSetSerializer, CDBSerializer, SPECIALITY_NAMES, ONE2MANY
from medcat.utils.saving.envsnapshot import ENV_SNAPSHOT_FILE_NAME

import medcat.utils.saving.coding as _

Expand Down Expand Up @@ -60,6 +61,7 @@ class ModelCreationTests(unittest.TestCase):
json_model_pack = tempfile.TemporaryDirectory()
EXAMPLES = os.path.join(os.path.dirname(
os.path.realpath(__file__)), "..", "..", "..", "examples")
EXCEPTIONAL_JSONS = ['model_card.json', ENV_SNAPSHOT_FILE_NAME]

@classmethod
def setUpClass(cls) -> None:
Expand Down Expand Up @@ -95,7 +97,7 @@ def test_dill_to_json(self):
SPECIALITY_NAMES) - len(ONE2MANY))
for json in jsons:
with self.subTest(f'JSON {json}'):
if json.endswith('model_card.json'):
if any(json.endswith(exception) for exception in self.EXCEPTIONAL_JSONS):
continue # ignore model card here
if any(name in json for name in ONE2MANY):
# ignore cui2many and name2many
Expand Down

0 comments on commit e4715ae

Please sign in to comment.