Skip to content

Commit

Permalink
Add support for retrieving license files for transformers models (mlf…
Browse files Browse the repository at this point in the history
…low#10871)

Signed-off-by: Ben Wilson <benjamin.wilson@databricks.com>
  • Loading branch information
BenWilson2 authored Jan 24, 2024
1 parent 4d0af6f commit 26653cd
Show file tree
Hide file tree
Showing 3 changed files with 119 additions and 11 deletions.
9 changes: 9 additions & 0 deletions docs/source/llms/transformers/guide/index.rst
Original file line number Diff line number Diff line change
Expand Up @@ -464,6 +464,15 @@ the stored model object.
In addition to the ``ModelCard``, the components that comprise any Pipeline (or the individual components if saving a dictionary of named components) will have their source types
stored. The model type, pipeline type, task, and classes of any supplementary component (such as a ``Tokenizer`` or ``ImageProcessor``) will be stored in the ``MLmodel`` file as well.

To preserve any legal requirements attached to the use of a model hosted on the huggingface hub, a "best effort" attempt
is made when logging a transformers model to retrieve and persist any license information. A file (``LICENSE.txt``) will be generated within the root of
the model directory. Within this file you will find either a copy of a declared license, the name of a common license type that applies to the model's use (e.g., 'apache-2.0', 'mit'),
or, in the event that license information was never submitted to the huggingface hub when the model repository was uploaded, a link to the repository
so that you can determine what restrictions apply to the use of the model.

.. note::
Model license information was introduced in **MLflow 2.10.0**. Previous versions do not include license information for models.

Automatic Signature inference
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

Expand Down
76 changes: 69 additions & 7 deletions mlflow/transformers/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
import os
import pathlib
import re
import shutil
import string
import sys
from functools import lru_cache
Expand Down Expand Up @@ -93,6 +94,8 @@
_IMAGE_PROCESSOR_TYPE_KEY = "image_processor_type"
_INFERENCE_CONFIG_BINARY_KEY = "inference_config.txt"
_INSTANCE_TYPE_KEY = "instance_type"
_LICENSE_FILE_NAME = "LICENSE.txt"
_LICENSE_FILE_PATTERN = re.compile(r"license(\.[a-z]+|$)", re.IGNORECASE)
_MODEL_KEY = "model"
_MODEL_BINARY_KEY = "model_binary"
_MODEL_BINARY_FILE_NAME = "model"
Expand Down Expand Up @@ -532,12 +535,17 @@ def save_model(
inference_config=inference_config,
)

model_name = transformers_model.model.name_or_path

# Get the model card from either the argument or the HuggingFace marketplace
card_data = model_card if model_card is not None else _fetch_model_card(transformers_model)
card_data = model_card or _fetch_model_card(model_name)

# If the card data can be acquired, save the text and the data separately
_write_card_data(card_data, path)

# Write the license information (or guidance) along with the model
_write_license_information(model_name, card_data, path)

model_bin_kwargs = {_MODEL_BINARY_KEY: _MODEL_BINARY_FILE_NAME}

# Only allow a subset of task types to have a pyfunc definition.
Expand Down Expand Up @@ -1096,7 +1104,7 @@ def _deserialize_torch_dtype_if_exists(flavor_config):
return _torch_dype_mapping()[flavor_config["torch_dtype"]]


def _fetch_model_card(model_or_pipeline):
def _fetch_model_card(model_name):
"""
Attempts to retrieve the model card for the specified model architecture iff the
`huggingface_hub` library is installed. If a card cannot be found in the registry or
Expand All @@ -1112,18 +1120,16 @@ def _fetch_model_card(model_or_pipeline):
)
return

model = model_or_pipeline.model

if hasattr(hub, "ModelCard"):
try:
return hub.ModelCard.load(model.name_or_path)
return hub.ModelCard.load(model_name)
except Exception as e:
_logger.warning(f"The model card could not be retrieved from the hub due to {e}")
else:
_logger.warning(
f"The version of huggingface_hub that is installed does not provide "
"The version of huggingface_hub that is installed does not provide "
f"ModelCard functionality. You have version {hub.__version__} installed. "
f"Update huggingface_hub to >= '0.10.0' to retrieve the ModelCard data."
"Update huggingface_hub to >= '0.10.0' to retrieve the ModelCard data."
)


Expand All @@ -1143,6 +1149,62 @@ def _write_card_data(card_data, path):
)


def _extract_license_file_from_repository(model_name):
    """Return the name of the first top-level file in the huggingface hub repository that
    matches the license file pattern (e.g. ``LICENSE``, ``LICENSE.txt``, ``license.md``).

    Returns None when ``huggingface_hub`` is not installed, when the repository has no
    license-like file, or when the repository listing cannot be retrieved.
    """
    try:
        import huggingface_hub as hub
    except ImportError:
        _logger.debug(
            f"Unable to list repository contents for the model repo {model_name}. In order "
            "to enable repository listing functionality, please install the huggingface_hub "
            "package by running `pip install huggingface_hub>0.10.0`"
        )
        return None
    try:
        files = hub.list_repo_files(model_name)
        # Supplying a default to `next` distinguishes "repo has no license file"
        # (plain None return) from an actual listing failure (logged below).
        return next((file for file in files if _LICENSE_FILE_PATTERN.search(file)), None)
    except Exception as e:
        _logger.debug(
            f"Failed to retrieve repository file listing data for {model_name} due to {e}"
        )
        return None


def _write_license_information(model_name, card_data, path):
    """Best-effort persistence of license information for a hub-hosted model.

    Tries, in order:
      1. Download the repository's own license file and copy it into ``path``, keeping
         its original file name.
      2. Write a fallback ``LICENSE.txt`` that links to the repository, appending the
         license type declared in the model card when a meaningful one is available.

    Args:
        model_name: The hub repository id (e.g. ``"org/model"``).
        card_data: The fetched ``ModelCard`` (or None); its ``data.license`` may name a
            common license type such as ``"apache-2.0"``.
        path: A ``pathlib.Path`` to the root of the saved model directory.
    """

    fallback = (
        f"A license file could not be found for the '{model_name}' repository. \n"
        "To ensure that you are in compliance with the license requirements for this "
        f"model, please visit the model repository here: https://huggingface.co/{model_name}"
    )

    if license_file := _extract_license_file_from_repository(model_name):
        try:
            import huggingface_hub as hub

            license_location = hub.hf_hub_download(repo_id=model_name, filename=license_file)
        except Exception as e:
            _logger.warning(f"Failed to download the license file due to: {e}")
        else:
            local_license_path = pathlib.Path(license_location)
            target_path = path.joinpath(local_license_path.name)
            try:
                shutil.copy(local_license_path, target_path)
                return
            except Exception as e:
                _logger.warning(f"The license file could not be copied due to: {e}")

    # Fallback or card data license info. Guard against a card with no declared
    # license (attribute missing or None) so the fallback text never embeds the
    # literal string 'None' as a license type.
    declared_license = getattr(card_data.data, "license", None) if card_data else None
    if declared_license and declared_license != "other":
        fallback = f"{fallback}\nThe declared license type is: '{declared_license}'"
    else:
        _logger.warning(
            "Unable to find license information for this model. Please verify "
            "permissible usage for the model you are storing prior to use."
        )
    path.joinpath(_LICENSE_FILE_NAME).write_text(fallback, encoding="utf-8")


def _build_pipeline_from_model_input(model, task: str):
"""
Utility for generating a pipeline from component parts. If required components are not
Expand Down
45 changes: 41 additions & 4 deletions tests/transformers/test_transformers_model_export.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,7 @@
_TransformersWrapper,
_validate_transformers_task_type,
_write_card_data,
_write_license_information,
get_default_conda_env,
get_default_pip_requirements,
)
Expand Down Expand Up @@ -310,11 +311,33 @@ def test_saving_with_invalid_dict_as_model(model_path):


def test_model_card_acquisition_vision_model(small_vision_model):
    # `_fetch_model_card` now takes the repository name rather than the pipeline object;
    # the stale duplicated call passing the pipeline itself has been removed.
    model_provided_card = _fetch_model_card(small_vision_model.model.name_or_path)
    assert model_provided_card.data.to_dict()["tags"] == ["vision", "image-classification"]
    assert len(model_provided_card.text) > 0


@pytest.mark.parametrize(
    ("repo_id", "license_file"),
    [
        ("google/mobilenet_v2_1.0_224", "LICENSE.txt"),  # no license declared
        ("csarron/mobilebert-uncased-squad-v2", "LICENSE.txt"),  # mit license
        ("codellama/CodeLlama-34b-hf", "LICENSE"),  # custom license
        ("openai/whisper-tiny", "LICENSE.txt"),  # apache license
        ("stabilityai/stable-code-3b", "LICENSE"),  # custom
        ("mistralai/Mixtral-8x7B-Instruct-v0.1", "LICENSE.txt"),  # apache
    ],
)
def test_license_acquisition(repo_id, license_file, tmp_path):
    # Fetching the card and writing license info for a real repo must produce a
    # non-empty license file under the expected name.
    _write_license_information(repo_id, _fetch_model_card(repo_id), tmp_path)
    written = tmp_path.joinpath(license_file)
    assert written.stat().st_size > 0


def test_license_fallback(tmp_path):
    # An unresolvable repo id must still yield a non-empty fallback LICENSE.txt.
    _write_license_information("not a real repo", None, tmp_path)
    fallback = tmp_path / "LICENSE.txt"
    assert fallback.stat().st_size > 0


def test_vision_model_save_pipeline_with_defaults(small_vision_model, model_path):
mlflow.transformers.save_model(transformers_model=small_vision_model, path=model_path)
# validate inferred pip requirements
Expand All @@ -325,6 +348,9 @@ def test_vision_model_save_pipeline_with_defaults(small_vision_model, model_path
# validate inferred model card data
card_data = yaml.safe_load(model_path.joinpath("model_card_data.yaml").read_bytes())
assert card_data["tags"] == ["vision", "image-classification"]
# verify the license file has been written
license_file = model_path.joinpath("LICENSE.txt").read_text()
assert len(license_file) > 0
# Validate inferred model card text
with model_path.joinpath("model_card.md").open() as file:
card_text = file.read()
Expand Down Expand Up @@ -356,7 +382,9 @@ def test_vision_model_save_model_for_task_and_card_inference(small_vision_model,
# Validate inferred model card text
card_text = model_path.joinpath("model_card.md").read_text(encoding="utf-8")
assert len(card_text) > 0

# verify the license file has been written
license_file = model_path.joinpath("LICENSE.txt").read_text()
assert len(license_file) > 0
# Validate the MLModel file
mlmodel = yaml.safe_load(model_path.joinpath("MLmodel").read_bytes())
flavor_config = mlmodel["flavors"]["transformers"]
Expand Down Expand Up @@ -385,6 +413,9 @@ def test_qa_model_save_model_for_task_and_card_inference(small_seq2seq_pipeline,
assert card_data["datasets"] == ["emo"]
# The creator of this model did not include tag data in the card. Ensure it is missing.
assert "tags" not in card_data
# verify the license file has been written
license_file = model_path.joinpath("LICENSE.txt").read_text()
assert len(license_file) > 0
# Validate inferred model card text
with model_path.joinpath("model_card.md").open() as file:
card_text = file.read()
Expand Down Expand Up @@ -422,6 +453,9 @@ def test_qa_model_save_and_override_card(small_qa_pipeline, model_path):
# Validate inferred model card text
with model_path.joinpath("model_card.md").open() as file:
card_text = file.read()
# verify the license file has been written
license_file = model_path.joinpath("LICENSE.txt").read_text()
assert len(license_file) > 0
assert card_text.startswith("\n# I made a new model!")
# validate MLmodel files
mlmodel = yaml.safe_load(model_path.joinpath("MLmodel").read_bytes())
Expand Down Expand Up @@ -919,7 +953,7 @@ def test_non_existent_model_card_entry(small_seq2seq_pipeline, model_path):

def test_huggingface_hub_not_installed(small_seq2seq_pipeline, model_path):
with mock.patch.dict("sys.modules", {"huggingface_hub": None}):
result = mlflow.transformers._fetch_model_card(small_seq2seq_pipeline)
result = mlflow.transformers._fetch_model_card(small_seq2seq_pipeline.model.name_or_path)

assert result is None

Expand All @@ -928,6 +962,9 @@ def test_huggingface_hub_not_installed(small_seq2seq_pipeline, model_path):
contents = {item.name for item in model_path.iterdir()}
assert not contents.intersection({"model_card.txt", "model_card_data.yaml"})

license_data = model_path.joinpath("LICENSE.txt").read_text()
assert license_data.rstrip().endswith("mobilebert")


def test_save_pipeline_without_defined_components(small_conversational_model, model_path):
# This pipeline type explicitly does not have a configuration for an image_processor
Expand Down Expand Up @@ -3818,7 +3855,7 @@ def _calculate_expected_size(path_or_dir):
expected_size = 0
for folder in [model_dir, tokenizer_dir]:
expected_size += _calculate_expected_size(folder)
other_files = ["model_card.md", "model_card_data.yaml"]
other_files = ["model_card.md", "model_card_data.yaml", "LICENSE.txt"]
for file in other_files:
path = tmp_path.joinpath(file)
expected_size += _calculate_expected_size(path)
Expand Down

0 comments on commit 26653cd

Please sign in to comment.