diff --git a/pyproject.toml b/pyproject.toml index 8682dcee..0b4e7dbd 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -107,7 +107,7 @@ ignore = [ [tool.bandit] exclude_dirs = ["examples", "tests"] -skips = ["B101"] +skips = ["B101", "B404"] [tool.pytest.ini_options] log_cli = true diff --git a/src/nnbench/context.py b/src/nnbench/context.py index 612a9a05..7bf28e9b 100644 --- a/src/nnbench/context.py +++ b/src/nnbench/context.py @@ -17,3 +17,100 @@ def cpuarch() -> dict[str, str]: def python_version() -> dict[str, str]: return {"python_version": platform.python_version()} + + +class PythonPackageInfo: + """ + A context helper returning version info for requested installed packages. + + If a requested package is not installed, an empty string is returned instead. + + Parameters + ---------- + *packages: str + Names of the requested packages under which they exist in the current environment. + For packages installed through ``pip``, this equals the PyPI package name. + """ + + def __init__(self, *packages: str): + self.packages = packages + + def __call__(self) -> dict[str, str]: + from importlib.metadata import PackageNotFoundError, version + + result: dict[str, str] = {} + for pkg in self.packages: + try: + result[pkg] = version(pkg) + except PackageNotFoundError: + result[pkg] = "" + return result + + +class GitEnvironmentInfo: + """ + A context helper providing the current git commit, latest tag, and upstream repository name. + + Parameters + ---------- + remote: str + Remote name for which to provide info, by default ``"origin"``. + """ + + def __init__(self, remote: str = "origin"): + self.remote = remote + + def __call__(self) -> dict[str, str]: + import subprocess + + def git_subprocess(args: list[str]) -> subprocess.CompletedProcess: + if platform.system() == "Windows": + git = "git.exe" + else: + git = "git" + + return subprocess.run( # nosec: B603 + [git, *args], stdout=subprocess.PIPE, stderr=subprocess.PIPE, encoding="utf-8" + ) + + result: dict[str, str] = { + "commit": "", + "provider": "", + "repository": "", + "tag": "", + } + + # first, check if inside a repo. + rc = subprocess.call(["git", "rev-parse", "--is-inside-work-tree"]) # nosec: B603, B607 + # if not, return empty info. + if rc: + return result + + # secondly: get the current commit. + p = git_subprocess(["rev-parse", "HEAD"]) + if not p.returncode: + result["commit"] = p.stdout.strip() + + # thirdly, get the latest tag, without a short commit SHA attached. + p = git_subprocess(["describe", "--tags", "--abbrev=0"]) + if not p.returncode: + result["tag"] = p.stdout.strip() + + # and finally, get the remote repo name pointed to by the given remote. + p = git_subprocess(["remote", "get-url", self.remote]) + if not p.returncode: + remotename: str = p.stdout.strip() + # it's an SSH remote. + if "@" in remotename: + prefix, sep = "git@", ":" + else: + # it is HTTPS. + prefix, sep = "https://", "/" + + remotename = remotename.removeprefix(prefix) + provider, reponame = remotename.split(sep, 1) + + result["provider"] = provider + result["repository"] = reponame.removesuffix(".git") + + return result diff --git a/tests/test_context.py b/tests/test_context.py new file mode 100644 index 00000000..41b18e87 --- /dev/null +++ b/tests/test_context.py @@ -0,0 +1,24 @@ +from nnbench.context import GitEnvironmentInfo, PythonPackageInfo + + +def test_python_package_info() -> None: + p = PythonPackageInfo("pre-commit", "pyyaml")() + + for v in p.values(): + assert v != "" + + # for a bogus package, it should not fail but produce an empty string. + p = PythonPackageInfo("asdfghjkl")() + + for v in p.values(): + assert v == "" + + +def test_git_info_provider() -> None: + g = GitEnvironmentInfo()() + + # tag might not be available in a shallow checkout in CI, + # but commit, provider and repo are. + assert g["commit"] != "" + assert g["provider"] == "github.com" + assert g["repository"] == "aai-institute/nnbench"