From e0d3ef71aa0987143d57a6e6d1f3ee95a78844f6 Mon Sep 17 00:00:00 2001 From: Pringled Date: Thu, 31 Oct 2024 14:01:45 +0100 Subject: [PATCH] Fixed bug --- model2vec/utils.py | 18 ++++++++++-------- pyproject.toml | 1 + uv.lock | 4 +++- 3 files changed, 14 insertions(+), 9 deletions(-) diff --git a/model2vec/utils.py b/model2vec/utils.py index 073f5d0..b77c45a 100644 --- a/model2vec/utils.py +++ b/model2vec/utils.py @@ -26,17 +26,19 @@ def get_tensor(self, key: str) -> np.ndarray: def get_package_extras(package: str, extra: str) -> Iterator[str]: """Get the extras of the package.""" - message = metadata(package) + try: + message = metadata(package) + except Exception as e: + raise ImportError(f"Could not retrieve metadata for package '{package}': {e}") + all_packages = message.get_all("Requires-Dist") or [] for package in all_packages: name, *rest = package.split(";", maxsplit=1) - if not rest: - continue - _, found_extra = rest[0].split("==", maxsplit=1) - # Strip off quotes - found_extra = found_extra.strip(' "') - if found_extra == extra: - yield name + if rest: + # Extract and clean the extra requirement + found_extra = rest[0].split("==")[-1].strip(" \"'") + if found_extra == extra: + yield name.strip() def importable(module: str, extra: str) -> None: diff --git a/pyproject.toml b/pyproject.toml index 5385728..7209a95 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -27,6 +27,7 @@ dependencies = [ "rich", "tqdm", "tokenizers>=0.20", + "safetensors", "setuptools", ] diff --git a/uv.lock b/uv.lock index 5c7056d..b331640 100644 --- a/uv.lock +++ b/uv.lock @@ -505,11 +505,12 @@ wheels = [ [[package]] name = "model2vec" -version = "0.2.4" +version = "0.3.1" source = { editable = "." } dependencies = [ { name = "numpy" }, { name = "rich" }, + { name = "safetensors" }, { name = "setuptools" }, { name = "tokenizers" }, { name = "tqdm" }, @@ -542,6 +543,7 @@ requires-dist = [ { name = "pytest-coverage", marker = "extra == 'dev'" }, { name = "rich" }, { name = "ruff", marker = "extra == 'dev'" }, + { name = "safetensors" }, { name = "scikit-learn", marker = "extra == 'distill'" }, { name = "setuptools" }, { name = "tokenizers", specifier = ">=0.20" },