diff --git a/src/peft/peft_model.py b/src/peft/peft_model.py index eacc5bf009..e7f1ae7f61 100644 --- a/src/peft/peft_model.py +++ b/src/peft/peft_model.py @@ -26,7 +26,7 @@ from accelerate import dispatch_model, infer_auto_device_map from accelerate.hooks import AlignDevicesHook, add_hook_to_module, remove_hook_from_submodules from accelerate.utils import get_balanced_memory -from huggingface_hub import hf_hub_download +from huggingface_hub import ModelCard, ModelCardData, hf_hub_download from safetensors.torch import save_file as safe_save_file from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss from transformers import PreTrainedModel @@ -55,7 +55,6 @@ _prepare_prompt_learning_config, _set_adapter, _set_trainable, - add_library_to_model_card, get_peft_model_state_dict, infer_device, load_peft_weights, @@ -650,13 +649,22 @@ def create_or_update_model_card(self, output_dir: str): Updates or create model card to include information about peft: 1. Adds `peft` library tag 2. Adds peft version - 3. Adds quantization information if it was used + 3. Adds base model info + 4. 
Adds quantization information if it was used """ - # Adds `peft` library tag - add_library_to_model_card(output_dir) - with open(os.path.join(output_dir, "README.md"), "r") as f: - lines = f.readlines() + filename = os.path.join(output_dir, "README.md") + + card = ModelCard.load(filename) if os.path.exists(filename) else ModelCard.from_template(ModelCardData()) + + card.data["library_name"] = "peft" + model_config = self.config + if hasattr(model_config, "to_dict"): + model_config = model_config.to_dict() + if model_config["model_type"] != "custom": + card.data["base_model"] = model_config["_name_or_path"] + + lines = card.text.splitlines() quantization_config = None if hasattr(self.config, "quantization_config"): @@ -681,9 +689,8 @@ def create_or_update_model_card(self, output_dir: str): else: lines.append(f"{framework_block_heading}\n\n- PEFT {__version__}\n") - # write the lines back to README.md - with open(os.path.join(output_dir, "README.md"), "w") as f: - f.writelines(lines) + card.text = "\n".join(lines) + card.save(filename) class PeftModelForSequenceClassification(PeftModel): diff --git a/src/peft/utils/__init__.py b/src/peft/utils/__init__.py index 9c02a911fd..b42d8d070b 100644 --- a/src/peft/utils/__init__.py +++ b/src/peft/utils/__init__.py @@ -30,7 +30,6 @@ WEIGHTS_NAME, SAFETENSORS_WEIGHTS_NAME, _set_trainable, - add_library_to_model_card, bloom_model_postprocess_past_key_value, prepare_model_for_int8_training, prepare_model_for_kbit_training, diff --git a/src/peft/utils/other.py b/src/peft/utils/other.py index 2dbe0a2c5b..35bb0622dc 100644 --- a/src/peft/utils/other.py +++ b/src/peft/utils/other.py @@ -14,7 +14,6 @@ # limitations under the License. 
import copy import inspect -import os import warnings from typing import Optional @@ -39,31 +38,6 @@ def infer_device(): return torch_device -# Add or edit model card to have `library_name: peft` -def add_library_to_model_card(output_dir): - if os.path.exists(os.path.join(output_dir, "README.md")): - with open(os.path.join(output_dir, "README.md"), "r") as f: - lines = f.readlines() - # check if the first line is `---` - if len(lines) > 0 and lines[0].startswith("---"): - for i, line in enumerate(lines[1:]): - # check if line starts with `library_name`, if yes, update it - if line.startswith("library_name"): - lines[i + 1] = "library_name: peft\n" - break - elif line.startswith("---"): - # insert `library_name: peft` before the last `---` - lines.insert(i + 1, "library_name: peft\n") - break - else: - lines = ["---\n", "library_name: peft\n", "---\n"] + lines - else: - lines = ["---\n", "library_name: peft\n", "---\n"] - # write the lines back to README.md - with open(os.path.join(output_dir, "README.md"), "w") as f: - f.writelines(lines) - - # needed for prefix-tuning of bloom model def bloom_model_postprocess_past_key_value(past_key_values): past_key_values = torch.cat(past_key_values) diff --git a/tests/test_custom_models.py b/tests/test_custom_models.py index 18251a2458..7e381f4ccb 100644 --- a/tests/test_custom_models.py +++ b/tests/test_custom_models.py @@ -15,6 +15,7 @@ # See the License for the specific language governing permissions and # limitations under the License. 
import copy +import os import tempfile import unittest @@ -365,6 +366,41 @@ def run_with_disable(config_kwargs, bias): def test_adding_multiple_adapters_with_bias_raises(self, test_name, model_id, config_cls, config_kwargs): self._test_adding_multiple_adapters_with_bias_raises(model_id, config_cls, config_kwargs) + def test_existing_model_card(self): + # ensure that if there is already a model card, it is not overwritten + model = MLP() + config = LoraConfig(target_modules=["lin0"]) + model = get_peft_model(model, config) + + with tempfile.TemporaryDirectory() as tmp_dirname: + # create a model card + text = "---\nmeta: hello\n---\nThis is a model card\n" + with open(os.path.join(tmp_dirname, "README.md"), "w") as f: + f.write(text) + + model.save_pretrained(tmp_dirname) + with open(os.path.join(tmp_dirname, "README.md"), "r") as f: + model_card = f.read() + + self.assertIn("library_name: peft", model_card) + self.assertIn("meta: hello", model_card) + self.assertIn("This is a model card", model_card) + + def test_non_existing_model_card(self): + # ensure that if there is no model card yet, a new one is created and pre-filled + model = MLP() + config = LoraConfig(target_modules=["lin0"]) + model = get_peft_model(model, config) + + with tempfile.TemporaryDirectory() as tmp_dirname: + model.save_pretrained(tmp_dirname) + with open(os.path.join(tmp_dirname, "README.md"), "r") as f: + model_card = f.read() + + self.assertIn("library_name: peft", model_card) + # rough check that the model card is pre-filled + self.assertGreater(len(model_card), 1000) + class TestMultiRankAdapter(unittest.TestCase): """Tests related to multirank LoRA adapters""" diff --git a/tests/testing_common.py b/tests/testing_common.py index b6ac83dcbf..495c6cfc98 100644 --- a/tests/testing_common.py +++ b/tests/testing_common.py @@ -12,13 +12,16 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and # limitations under the License. +import json import os import pickle +import re import tempfile from collections import OrderedDict from dataclasses import replace import torch +import yaml from diffusers import StableDiffusionPipeline from peft import ( @@ -171,6 +174,33 @@ class PeftCommonTester: def prepare_inputs_for_common(self): raise NotImplementedError + def check_modelcard(self, tmp_dirname, model): + # check the generated README.md + filename = os.path.join(tmp_dirname, "README.md") + self.assertTrue(os.path.exists(filename)) + with open(filename, "r", encoding="utf-8") as f: + readme = f.read() + metainfo = re.search(r"---\n(.*?)\n---", readme, re.DOTALL).group(1) + dct = yaml.safe_load(metainfo) + self.assertEqual(dct["library_name"], "peft") + + model_config = model.config if isinstance(model.config, dict) else model.config.to_dict() + if model_config["model_type"] != "custom": + self.assertEqual(dct["base_model"], model_config["_name_or_path"]) + else: + self.assertTrue("base_model" not in dct) + + def check_config_json(self, tmp_dirname, model): + # check the generated config.json + filename = os.path.join(tmp_dirname, "adapter_config.json") + self.assertTrue(os.path.exists(filename)) + with open(filename, "r", encoding="utf-8") as f: + config = json.load(f) + + model_config = model.config if isinstance(model.config, dict) else model.config.to_dict() + if model_config["model_type"] != "custom": + self.assertEqual(config["base_model_name_or_path"], model_config["_name_or_path"]) + def _test_model_attr(self, model_id, config_cls, config_kwargs): model = self.transformers_class.from_pretrained(model_id) config = config_cls( @@ -293,6 +323,9 @@ def _test_save_pretrained(self, model_id, config_cls, config_kwargs): # check if `config.json` is not present self.assertFalse(os.path.exists(os.path.join(tmp_dirname, "config.json"))) + self.check_modelcard(tmp_dirname, model) + 
self.check_config_json(tmp_dirname, model) + def _test_save_pretrained_selected_adapters(self, model_id, config_cls, config_kwargs): if issubclass(config_cls, AdaLoraConfig): # AdaLora does not support adding more than 1 adapter @@ -368,6 +401,9 @@ def _test_save_pretrained_selected_adapters(self, model_id, config_cls, config_k self.assertFalse(os.path.exists(os.path.join(tmp_dirname, "config.json"))) self.assertFalse(os.path.exists(os.path.join(new_adapter_dir, "config.json"))) + self.check_modelcard(tmp_dirname, model) + self.check_config_json(tmp_dirname, model) + with tempfile.TemporaryDirectory() as tmp_dirname: model.save_pretrained(tmp_dirname, selected_adapters=["default"])