Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add model card loading #45

Merged
merged 2 commits into from
Sep 28, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 4 additions & 2 deletions model2vec/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -154,9 +154,11 @@ def from_pretrained(
:param token: The huggingface token to use.
:return: A StaticEmbedder
"""
embeddings, tokenizer, config = load_pretrained(path, token=token)
embeddings, tokenizer, config, metadata = load_pretrained(path, token=token)

return cls(embeddings, tokenizer, config)
return cls(
embeddings, tokenizer, config, base_model_name=metadata.get("base_model"), language=metadata.get("language")
)

def encode(
self,
Expand Down
29 changes: 26 additions & 3 deletions model2vec/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -102,7 +102,7 @@

def load_pretrained(
folder_or_repo_path: str | Path, token: str | None = None
) -> tuple[np.ndarray, Tokenizer, dict[str, Any]]:
) -> tuple[np.ndarray, Tokenizer, dict[str, Any], dict[str, Any]]:
"""
Loads a pretrained model from a folder.

Expand All @@ -111,7 +111,7 @@
- If the local path is not found, we will attempt to load from the huggingface hub.
:param token: The huggingface token to use.
:raises: FileNotFoundError if the folder exists, but the file does not exist locally.
:return: The embeddings, tokenizer, and config.
:return: The embeddings, tokenizer, config, and metadata.

"""
folder_or_repo_path = Path(folder_or_repo_path)
Expand All @@ -133,6 +133,10 @@
if not tokenizer_path.exists():
raise FileNotFoundError(f"Tokenizer file does not exist in {folder_or_repo_path}")

# README is optional, so this is a bit finicky.
readme_path = folder_or_repo_path / "README.md"
metadata = _get_metadata_from_readme(readme_path)

else:
logger.info("Folder does not exist locally, attempting to use huggingface hub.")
try:
Expand All @@ -148,6 +152,13 @@
# Raise original exception.
raise e

try:
readme_path = huggingface_hub.hf_hub_download(folder_or_repo_path.as_posix(), "README.md", token=token)
metadata = _get_metadata_from_readme(Path(readme_path))
except huggingface_hub.utils.EntryNotFoundError:
logger.info("No README found in the model folder. No model card loaded.")
metadata = {}

Check warning on line 160 in model2vec/utils.py

View check run for this annotation

Codecov / codecov/patch

model2vec/utils.py#L155-L160

Added lines #L155 - L160 were not covered by tests

config_path = huggingface_hub.hf_hub_download(folder_or_repo_path.as_posix(), "config.json", token=token)
tokenizer_path = huggingface_hub.hf_hub_download(folder_or_repo_path.as_posix(), "tokenizer.json", token=token)

Expand All @@ -162,7 +173,19 @@
f"Number of tokens does not match number of embeddings: `{len(tokenizer.get_vocab())}` vs `{len(embeddings)}`"
)

return embeddings, tokenizer, config
return embeddings, tokenizer, config, metadata


def _get_metadata_from_readme(readme_path: Path) -> dict[str, Any]:
    """
    Read the model card metadata stored in a README file.

    :param readme_path: The path to the README file.
    :return: The model card data as a dictionary. An empty dictionary is returned when the file does not exist.
    """
    if readme_path.exists():
        card = ModelCard.load(readme_path)
        metadata: dict[str, Any] = card.data.to_dict()
        if not metadata:
            # The file is there, but it carries no model card data.
            logger.info("File README.md exists, but was empty. No model card loaded.")
        return metadata

    logger.info(f"README file not found in {readme_path}. No model card loaded.")
    return {}


def push_folder_to_hub(folder_path: Path, repo_id: str, private: bool, token: str | None) -> None:
Expand Down
25 changes: 25 additions & 0 deletions tests/test_utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,25 @@
from pathlib import Path
from tempfile import NamedTemporaryFile, TemporaryDirectory

from model2vec.utils import _get_metadata_from_readme


def test__get_metadata_from_readme_not_exists() -> None:
    """Test that a README path that does not exist yields empty metadata."""
    missing_readme = Path("zzz")
    metadata = _get_metadata_from_readme(missing_readme)
    assert metadata == {}


def test__get_metadata_from_readme_mocked_file() -> None:
    """Test that front-matter keys in a README are returned as metadata."""
    # Write the README into a temporary directory instead of re-opening a
    # still-open NamedTemporaryFile by name, which is not possible on Windows.
    with TemporaryDirectory() as tmp_dir:
        readme_path = Path(tmp_dir) / "README.md"
        readme_path.write_bytes(b"---\nkey: value\n---\n")
        assert _get_metadata_from_readme(readme_path)["key"] == "value"


def test__get_metadata_from_readme_mocked_file_keys() -> None:
    """Test that an existing but empty README yields metadata without any keys."""
    # Write the README into a temporary directory instead of re-opening a
    # still-open NamedTemporaryFile by name, which is not possible on Windows.
    with TemporaryDirectory() as tmp_dir:
        readme_path = Path(tmp_dir) / "README.md"
        readme_path.write_bytes(b"")
        assert set(_get_metadata_from_readme(readme_path)) == set()