From d1f1371f9726a6c76741723c61f2b841d17fa0be Mon Sep 17 00:00:00 2001 From: Benjamin Bossan Date: Thu, 28 Sep 2023 16:54:18 +0200 Subject: [PATCH 1/4] Add base model metadata to model card Resolves #938 This PR adds the base model metadata, if present, to the model card. On top of this, the code for creating the model card has been refactored to use the huggingface_hub classes instead of doing ad hoc parsing and writing. A consequence of this is that if no model card exists, the default template will now be used, with a lot of placeholder text. LMK if this is not desired. --- src/peft/peft_model.py | 39 ++++++++++++++++++++++++++++---------- src/peft/utils/__init__.py | 1 - src/peft/utils/other.py | 25 ------------------------ tests/testing_common.py | 36 +++++++++++++++++++++++++++++++++++ 4 files changed, 65 insertions(+), 36 deletions(-) diff --git a/src/peft/peft_model.py b/src/peft/peft_model.py index eacc5bf009..d7453f679d 100644 --- a/src/peft/peft_model.py +++ b/src/peft/peft_model.py @@ -17,16 +17,18 @@ import inspect import os +import re import warnings from contextlib import contextmanager from copy import deepcopy from typing import Any, Dict, List, Optional, Union import torch +import yaml from accelerate import dispatch_model, infer_auto_device_map from accelerate.hooks import AlignDevicesHook, add_hook_to_module, remove_hook_from_submodules from accelerate.utils import get_balanced_memory -from huggingface_hub import hf_hub_download +from huggingface_hub import hf_hub_download, ModelCard, ModelCardData from safetensors.torch import save_file as safe_save_file from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss from transformers import PreTrainedModel @@ -55,7 +57,6 @@ _prepare_prompt_learning_config, _set_adapter, _set_trainable, - add_library_to_model_card, get_peft_model_state_dict, infer_device, load_peft_weights, @@ -650,13 +651,32 @@ def create_or_update_model_card(self, output_dir: str): Updates or create model card to include 
information about peft: 1. Adds `peft` library tag 2. Adds peft version - 3. Adds quantization information if it was used + 3. Adds base model info + 4. Adds quantization information if it was used """ - # Adds `peft` library tag - add_library_to_model_card(output_dir) - with open(os.path.join(output_dir, "README.md"), "r") as f: - lines = f.readlines() + filename = os.path.join(output_dir, "README.md") + if not os.path.exists(filename): + # touch an empty model card + with open(filename, "a"): + pass + + with open(filename, "r") as f: + readme = f.read() + + match = re.search(r"---\n(.*?)\n---", readme, re.DOTALL) + metainfo = {} if match is None else yaml.safe_load(match.group(1)) + metainfo["library_name"] = "peft" + model_config = self.config + if hasattr(model_config, "to_dict"): + model_config = model_config.to_dict() + if model_config["model_type"] != "custom": + metainfo["base_model"] = model_config["_name_or_path"] + + card_data = ModelCardData(**metainfo) + + card = ModelCard.from_template(card_data) + lines = card.text.splitlines() quantization_config = None if hasattr(self.config, "quantization_config"): @@ -681,9 +701,8 @@ def create_or_update_model_card(self, output_dir: str): else: lines.append(f"{framework_block_heading}\n\n- PEFT {__version__}\n") - # write the lines back to README.md - with open(os.path.join(output_dir, "README.md"), "w") as f: - f.writelines(lines) + card.text = "\n".join(lines) + card.save(filename) class PeftModelForSequenceClassification(PeftModel): diff --git a/src/peft/utils/__init__.py b/src/peft/utils/__init__.py index 9c02a911fd..b42d8d070b 100644 --- a/src/peft/utils/__init__.py +++ b/src/peft/utils/__init__.py @@ -30,7 +30,6 @@ WEIGHTS_NAME, SAFETENSORS_WEIGHTS_NAME, _set_trainable, - add_library_to_model_card, bloom_model_postprocess_past_key_value, prepare_model_for_int8_training, prepare_model_for_kbit_training, diff --git a/src/peft/utils/other.py b/src/peft/utils/other.py index 2dbe0a2c5b..a28fa88ef1 100644 --- 
a/src/peft/utils/other.py +++ b/src/peft/utils/other.py @@ -39,31 +39,6 @@ def infer_device(): return torch_device -# Add or edit model card to have `library_name: peft` -def add_library_to_model_card(output_dir): - if os.path.exists(os.path.join(output_dir, "README.md")): - with open(os.path.join(output_dir, "README.md"), "r") as f: - lines = f.readlines() - # check if the first line is `---` - if len(lines) > 0 and lines[0].startswith("---"): - for i, line in enumerate(lines[1:]): - # check if line starts with `library_name`, if yes, update it - if line.startswith("library_name"): - lines[i + 1] = "library_name: peft\n" - break - elif line.startswith("---"): - # insert `library_name: peft` before the last `---` - lines.insert(i + 1, "library_name: peft\n") - break - else: - lines = ["---\n", "library_name: peft\n", "---\n"] + lines - else: - lines = ["---\n", "library_name: peft\n", "---\n"] - # write the lines back to README.md - with open(os.path.join(output_dir, "README.md"), "w") as f: - f.writelines(lines) - - # needed for prefix-tuning of bloom model def bloom_model_postprocess_past_key_value(past_key_values): past_key_values = torch.cat(past_key_values) diff --git a/tests/testing_common.py b/tests/testing_common.py index b6ac83dcbf..495c6cfc98 100644 --- a/tests/testing_common.py +++ b/tests/testing_common.py @@ -12,13 +12,16 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
+import json import os import pickle +import re import tempfile from collections import OrderedDict from dataclasses import replace import torch +import yaml from diffusers import StableDiffusionPipeline from peft import ( @@ -171,6 +174,33 @@ class PeftCommonTester: def prepare_inputs_for_common(self): raise NotImplementedError + def check_modelcard(self, tmp_dirname, model): + # check the generated README.md + filename = os.path.join(tmp_dirname, "README.md") + self.assertTrue(os.path.exists(filename)) + with open(filename, "r", encoding="utf-8") as f: + readme = f.read() + metainfo = re.search(r"---\n(.*?)\n---", readme, re.DOTALL).group(1) + dct = yaml.safe_load(metainfo) + self.assertEqual(dct["library_name"], "peft") + + model_config = model.config if isinstance(model.config, dict) else model.config.to_dict() + if model_config["model_type"] != "custom": + self.assertEqual(dct["base_model"], model_config["_name_or_path"]) + else: + self.assertTrue("base_model" not in dct) + + def check_config_json(self, tmp_dirname, model): + # check the generated config.json + filename = os.path.join(tmp_dirname, "adapter_config.json") + self.assertTrue(os.path.exists(filename)) + with open(filename, "r", encoding="utf-8") as f: + config = json.load(f) + + model_config = model.config if isinstance(model.config, dict) else model.config.to_dict() + if model_config["model_type"] != "custom": + self.assertEqual(config["base_model_name_or_path"], model_config["_name_or_path"]) + def _test_model_attr(self, model_id, config_cls, config_kwargs): model = self.transformers_class.from_pretrained(model_id) config = config_cls( @@ -293,6 +323,9 @@ def _test_save_pretrained(self, model_id, config_cls, config_kwargs): # check if `config.json` is not present self.assertFalse(os.path.exists(os.path.join(tmp_dirname, "config.json"))) + self.check_modelcard(tmp_dirname, model) + self.check_config_json(tmp_dirname, model) + def _test_save_pretrained_selected_adapters(self, model_id, config_cls, 
config_kwargs): if issubclass(config_cls, AdaLoraConfig): # AdaLora does not support adding more than 1 adapter @@ -368,6 +401,9 @@ def _test_save_pretrained_selected_adapters(self, model_id, config_cls, config_k self.assertFalse(os.path.exists(os.path.join(tmp_dirname, "config.json"))) self.assertFalse(os.path.exists(os.path.join(new_adapter_dir, "config.json"))) + self.check_modelcard(tmp_dirname, model) + self.check_config_json(tmp_dirname, model) + with tempfile.TemporaryDirectory() as tmp_dirname: model.save_pretrained(tmp_dirname, selected_adapters=["default"]) From 3e19c1ac8df8739c96eb55c983c7a47d1019e59b Mon Sep 17 00:00:00 2001 From: Benjamin Bossan Date: Thu, 28 Sep 2023 17:19:52 +0200 Subject: [PATCH 2/4] Fixes for pre-existing model card Ensure that this works correctly if a model card already exists. --- src/peft/peft_model.py | 14 ++++++++++++-- src/peft/utils/other.py | 1 - tests/test_custom_models.py | 36 ++++++++++++++++++++++++++++++++++++ 3 files changed, 48 insertions(+), 3 deletions(-) diff --git a/src/peft/peft_model.py b/src/peft/peft_model.py index d7453f679d..357b0daf51 100644 --- a/src/peft/peft_model.py +++ b/src/peft/peft_model.py @@ -28,7 +28,7 @@ from accelerate import dispatch_model, infer_auto_device_map from accelerate.hooks import AlignDevicesHook, add_hook_to_module, remove_hook_from_submodules from accelerate.utils import get_balanced_memory -from huggingface_hub import hf_hub_download, ModelCard, ModelCardData +from huggingface_hub import ModelCard, ModelCardData, hf_hub_download from safetensors.torch import save_file as safe_save_file from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss from transformers import PreTrainedModel @@ -656,7 +656,9 @@ def create_or_update_model_card(self, output_dir: str): """ filename = os.path.join(output_dir, "README.md") + card_exists = True if not os.path.exists(filename): + card_exists = False # touch an empty model card with open(filename, "a"): pass @@ -664,6 +666,7 @@ def 
create_or_update_model_card(self, output_dir: str): with open(filename, "r") as f: readme = f.read() + # add metadata match = re.search(r"---\n(.*?)\n---", readme, re.DOTALL) metainfo = {} if match is None else yaml.safe_load(match.group(1)) metainfo["library_name"] = "peft" @@ -675,8 +678,15 @@ def create_or_update_model_card(self, output_dir: str): card_data = ModelCardData(**metainfo) + # add extra data to model card body card = ModelCard.from_template(card_data) - lines = card.text.splitlines() + + # check if there is already a text body on the README.md, if not, use default template + if card_exists: + text = readme.split("\n---\n", 1)[-1] + else: + text = card.text + lines = text.splitlines() quantization_config = None if hasattr(self.config, "quantization_config"): diff --git a/src/peft/utils/other.py b/src/peft/utils/other.py index a28fa88ef1..35bb0622dc 100644 --- a/src/peft/utils/other.py +++ b/src/peft/utils/other.py @@ -14,7 +14,6 @@ # limitations under the License. import copy import inspect -import os import warnings from typing import Optional diff --git a/tests/test_custom_models.py b/tests/test_custom_models.py index 18251a2458..7e381f4ccb 100644 --- a/tests/test_custom_models.py +++ b/tests/test_custom_models.py @@ -15,6 +15,7 @@ # See the License for the specific language governing permissions and # limitations under the License. 
import copy +import os import tempfile import unittest @@ -365,6 +366,41 @@ def run_with_disable(config_kwargs, bias): def test_adding_multiple_adapters_with_bias_raises(self, test_name, model_id, config_cls, config_kwargs): self._test_adding_multiple_adapters_with_bias_raises(model_id, config_cls, config_kwargs) + def test_existing_model_card(self): + # ensure that if there is already a model card, it is not overwritten + model = MLP() + config = LoraConfig(target_modules=["lin0"]) + model = get_peft_model(model, config) + + with tempfile.TemporaryDirectory() as tmp_dirname: + # create a model card + text = "---\nmeta: hello\n---\nThis is a model card\n" + with open(os.path.join(tmp_dirname, "README.md"), "w") as f: + f.write(text) + + model.save_pretrained(tmp_dirname) + with open(os.path.join(tmp_dirname, "README.md"), "r") as f: + model_card = f.read() + + self.assertIn("library_name: peft", model_card) + self.assertIn("meta: hello", model_card) + self.assertIn("This is a model card", model_card) + + def test_non_existing_model_card(self): + # ensure that if no model card exists yet, a default one is created + model = MLP() + config = LoraConfig(target_modules=["lin0"]) + model = get_peft_model(model, config) + + with tempfile.TemporaryDirectory() as tmp_dirname: + model.save_pretrained(tmp_dirname) + with open(os.path.join(tmp_dirname, "README.md"), "r") as f: + model_card = f.read() + + self.assertIn("library_name: peft", model_card) + # rough check that the model card is pre-filled + self.assertGreater(len(model_card), 1000) + class TestMultiRankAdapter(unittest.TestCase): """Tests related to multirank LoRA adapters""" From c6a996fc56c58a07bf4ab963502a682549d12b35 Mon Sep 17 00:00:00 2001 From: Benjamin Bossan Date: Fri, 29 Sep 2023 16:01:00 +0200 Subject: [PATCH 3/4] Reviewer feedback: simplify card loading Co-authored-by: Lucain --- src/peft/peft_model.py | 36 ++++++++++-------------------------- 1 file changed, 10 insertions(+), 26 deletions(-) diff 
--git a/src/peft/peft_model.py b/src/peft/peft_model.py index 357b0daf51..7479582797 100644 --- a/src/peft/peft_model.py +++ b/src/peft/peft_model.py @@ -656,37 +656,21 @@ def create_or_update_model_card(self, output_dir: str): """ filename = os.path.join(output_dir, "README.md") - card_exists = True - if not os.path.exists(filename): - card_exists = False - # touch an empty model card - with open(filename, "a"): - pass - - with open(filename, "r") as f: - readme = f.read() - - # add metadata - match = re.search(r"---\n(.*?)\n---", readme, re.DOTALL) - metainfo = {} if match is None else yaml.safe_load(match.group(1)) - metainfo["library_name"] = "peft" + + card = ( + ModelCard.load(filename) + if os.path.exists(filename) + else ModelCard.from_template(ModelCardData()) + ) + + card.data["library_name"] = "peft" model_config = self.config if hasattr(model_config, "to_dict"): model_config = model_config.to_dict() if model_config["model_type"] != "custom": - metainfo["base_model"] = model_config["_name_or_path"] - - card_data = ModelCardData(**metainfo) + card.data["base_model"] = model_config["_name_or_path"] - # add extra data to model card body - card = ModelCard.from_template(card_data) - - # check if there is already a text body on the README.md, if not, use default template - if card_exists: - text = readme.split("\n---\n", 1)[-1] - else: - text = card.text - lines = text.splitlines() + lines = card.text.splitlines() quantization_config = None if hasattr(self.config, "quantization_config"): From 84d898331f2f7c7bd0237a0aae172d52f2a08bca Mon Sep 17 00:00:00 2001 From: Benjamin Bossan Date: Fri, 29 Sep 2023 16:04:52 +0200 Subject: [PATCH 4/4] Make style --- src/peft/peft_model.py | 10 ++-------- 1 file changed, 2 insertions(+), 8 deletions(-) diff --git a/src/peft/peft_model.py b/src/peft/peft_model.py index 7479582797..e7f1ae7f61 100644 --- a/src/peft/peft_model.py +++ b/src/peft/peft_model.py @@ -17,14 +17,12 @@ import inspect import os -import re import warnings 
from contextlib import contextmanager from copy import deepcopy from typing import Any, Dict, List, Optional, Union import torch -import yaml from accelerate import dispatch_model, infer_auto_device_map from accelerate.hooks import AlignDevicesHook, add_hook_to_module, remove_hook_from_submodules from accelerate.utils import get_balanced_memory @@ -656,12 +654,8 @@ def create_or_update_model_card(self, output_dir: str): """ filename = os.path.join(output_dir, "README.md") - - card = ( - ModelCard.load(filename) - if os.path.exists(filename) - else ModelCard.from_template(ModelCardData()) - ) + + card = ModelCard.load(filename) if os.path.exists(filename) else ModelCard.from_template(ModelCardData()) card.data["library_name"] = "peft" model_config = self.config