Add base model metadata to model card #975

Merged
src/peft/peft_model.py (17 additions, 10 deletions)
@@ -26,7 +26,7 @@
 from accelerate import dispatch_model, infer_auto_device_map
 from accelerate.hooks import AlignDevicesHook, add_hook_to_module, remove_hook_from_submodules
 from accelerate.utils import get_balanced_memory
-from huggingface_hub import hf_hub_download
+from huggingface_hub import ModelCard, ModelCardData, hf_hub_download
 from safetensors.torch import save_file as safe_save_file
 from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
 from transformers import PreTrainedModel
@@ -55,7 +55,6 @@
     _prepare_prompt_learning_config,
     _set_adapter,
     _set_trainable,
-    add_library_to_model_card,
     get_peft_model_state_dict,
     infer_device,
     load_peft_weights,
@@ -650,13 +649,22 @@ def create_or_update_model_card(self, output_dir: str):
         Updates or creates the model card to include information about peft:
         1. Adds `peft` library tag
         2. Adds peft version
-        3. Adds quantization information if it was used
+        3. Adds base model info
+        4. Adds quantization information if it was used
         """
-        # Adds `peft` library tag
-        add_library_to_model_card(output_dir)
-
-        with open(os.path.join(output_dir, "README.md"), "r") as f:
-            lines = f.readlines()
+        filename = os.path.join(output_dir, "README.md")
+
+        card = ModelCard.load(filename) if os.path.exists(filename) else ModelCard.from_template(ModelCardData())
+
+        card.data["library_name"] = "peft"
+        model_config = self.config
+        if hasattr(model_config, "to_dict"):
+            model_config = model_config.to_dict()
+        if model_config["model_type"] != "custom":
+            card.data["base_model"] = model_config["_name_or_path"]
+
+        lines = card.text.splitlines()
 
         quantization_config = None
         if hasattr(self.config, "quantization_config"):
@@ -681,9 +689,8 @@
         else:
             lines.append(f"{framework_block_heading}\n\n- PEFT {__version__}\n")
 
-        # write the lines back to README.md
-        with open(os.path.join(output_dir, "README.md"), "w") as f:
-            f.writelines(lines)
+        card.text = "\n".join(lines)
+        card.save(filename)
 
 
 class PeftModelForSequenceClassification(PeftModel):
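
For readers unfamiliar with it, below is a minimal sketch (not part of the diff; the directory name and base-model id are made-up examples) of the huggingface_hub card API that the rewritten create_or_update_model_card builds on:

import os

from huggingface_hub import ModelCard, ModelCardData

output_dir = "my-lora-adapter"  # hypothetical save directory
filename = os.path.join(output_dir, "README.md")

# Reuse an existing card if present, otherwise start from an empty template.
if os.path.exists(filename):
    card = ModelCard.load(filename)
else:
    card = ModelCard.from_template(ModelCardData())

# Keys set on card.data end up as YAML front matter at the top of README.md.
card.data["library_name"] = "peft"
card.data["base_model"] = "facebook/opt-350m"  # example value

# card.text holds the free-form markdown body below the front matter.
card.text += "\n## Training procedure\n"

os.makedirs(output_dir, exist_ok=True)
card.save(filename)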
src/peft/utils/__init__.py (0 additions, 1 deletion)

@@ -30,7 +30,6 @@
     WEIGHTS_NAME,
     SAFETENSORS_WEIGHTS_NAME,
     _set_trainable,
-    add_library_to_model_card,
     bloom_model_postprocess_past_key_value,
     prepare_model_for_int8_training,
     prepare_model_for_kbit_training,
src/peft/utils/other.py (0 additions, 26 deletions)

@@ -14,7 +14,6 @@
 # limitations under the License.
 import copy
 import inspect
-import os
 import warnings
 from typing import Optional
 
@@ -39,31 +38,6 @@ def infer_device():
     return torch_device
 
 
-# Add or edit model card to have `library_name: peft`
-def add_library_to_model_card(output_dir):
-    if os.path.exists(os.path.join(output_dir, "README.md")):
-        with open(os.path.join(output_dir, "README.md"), "r") as f:
-            lines = f.readlines()
-        # check if the first line is `---`
-        if len(lines) > 0 and lines[0].startswith("---"):
-            for i, line in enumerate(lines[1:]):
-                # check if line starts with `library_name`, if yes, update it
-                if line.startswith("library_name"):
-                    lines[i + 1] = "library_name: peft\n"
-                    break
-                elif line.startswith("---"):
-                    # insert `library_name: peft` before the last `---`
-                    lines.insert(i + 1, "library_name: peft\n")
-                    break
-        else:
-            lines = ["---\n", "library_name: peft\n", "---\n"] + lines
-    else:
-        lines = ["---\n", "library_name: peft\n", "---\n"]
-    # write the lines back to README.md
-    with open(os.path.join(output_dir, "README.md"), "w") as f:
-        f.writelines(lines)
-
-
 # needed for prefix-tuning of bloom model
 def bloom_model_postprocess_past_key_value(past_key_values):
     past_key_values = torch.cat(past_key_values)
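
For concreteness, this is the shape of README.md front matter that both the removed helper and the new ModelCard-based flow are responsible for. An illustrative sketch only; the base_model value, headings, and version string are made-up examples, not output copied from PEFT:

EXAMPLE_README = """\
---
library_name: peft
base_model: facebook/opt-350m
---

## Training procedure

### Framework versions

- PEFT 0.6.0
"""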
tests/test_custom_models.py (36 additions, 0 deletions)

@@ -15,6 +15,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 import copy
+import os
 import tempfile
 import unittest
 
@@ -365,6 +366,41 @@ def run_with_disable(config_kwargs, bias):
     def test_adding_multiple_adapters_with_bias_raises(self, test_name, model_id, config_cls, config_kwargs):
         self._test_adding_multiple_adapters_with_bias_raises(model_id, config_cls, config_kwargs)
 
+    def test_existing_model_card(self):
+        # ensure that if there is already a model card, it is not overwritten
+        model = MLP()
+        config = LoraConfig(target_modules=["lin0"])
+        model = get_peft_model(model, config)
+
+        with tempfile.TemporaryDirectory() as tmp_dirname:
+            # create a model card
+            text = "---\nmeta: hello\n---\nThis is a model card\n"
+            with open(os.path.join(tmp_dirname, "README.md"), "w") as f:
+                f.write(text)
+
+            model.save_pretrained(tmp_dirname)
+            with open(os.path.join(tmp_dirname, "README.md"), "r") as f:
+                model_card = f.read()
+
+            self.assertIn("library_name: peft", model_card)
+            self.assertIn("meta: hello", model_card)
+            self.assertIn("This is a model card", model_card)
+
+    def test_non_existing_model_card(self):
+        # ensure that a model card is created when none exists yet
+        model = MLP()
+        config = LoraConfig(target_modules=["lin0"])
+        model = get_peft_model(model, config)
+
+        with tempfile.TemporaryDirectory() as tmp_dirname:
+            model.save_pretrained(tmp_dirname)
+            with open(os.path.join(tmp_dirname, "README.md"), "r") as f:
+                model_card = f.read()
+
+            self.assertIn("library_name: peft", model_card)
+            # rough check that the model card is pre-filled
+            self.assertGreater(len(model_card), 1000)
+
 
 class TestMultiRankAdapter(unittest.TestCase):
     """Tests related to multirank LoRA adapters"""
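
These tests hinge on the merge behavior of ModelCard: because the YAML front matter is parsed into card.data, setting one key leaves user-provided keys intact. A small sketch, assuming only that huggingface_hub is installed:

from huggingface_hub import ModelCard

# Construct a card directly from README content with existing front matter.
card = ModelCard("---\nmeta: hello\n---\nThis is a model card\n")
card.data["library_name"] = "peft"
print(card.data.to_dict())  # {'meta': 'hello', 'library_name': 'peft'}
print(card.text)            # This is a model card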
tests/testing_common.py (36 additions, 0 deletions)

@@ -12,13 +12,16 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
+import json
 import os
 import pickle
+import re
 import tempfile
 from collections import OrderedDict
 from dataclasses import replace
 
 import torch
+import yaml
 from diffusers import StableDiffusionPipeline
 
 from peft import (
@@ -171,6 +174,33 @@ class PeftCommonTester:
     def prepare_inputs_for_common(self):
         raise NotImplementedError
 
+    def check_modelcard(self, tmp_dirname, model):
+        # check the generated README.md
+        filename = os.path.join(tmp_dirname, "README.md")
+        self.assertTrue(os.path.exists(filename))
+        with open(filename, "r", encoding="utf-8") as f:
+            readme = f.read()
+        metainfo = re.search(r"---\n(.*?)\n---", readme, re.DOTALL).group(1)
+        dct = yaml.safe_load(metainfo)
+        self.assertEqual(dct["library_name"], "peft")
+
+        model_config = model.config if isinstance(model.config, dict) else model.config.to_dict()
+        if model_config["model_type"] != "custom":
+            self.assertEqual(dct["base_model"], model_config["_name_or_path"])
+        else:
+            self.assertTrue("base_model" not in dct)
+
+    def check_config_json(self, tmp_dirname, model):
+        # check the generated adapter_config.json
+        filename = os.path.join(tmp_dirname, "adapter_config.json")
+        self.assertTrue(os.path.exists(filename))
+        with open(filename, "r", encoding="utf-8") as f:
+            config = json.load(f)
+
+        model_config = model.config if isinstance(model.config, dict) else model.config.to_dict()
+        if model_config["model_type"] != "custom":
+            self.assertEqual(config["base_model_name_or_path"], model_config["_name_or_path"])
+
     def _test_model_attr(self, model_id, config_cls, config_kwargs):
         model = self.transformers_class.from_pretrained(model_id)
         config = config_cls(
@@ -293,6 +323,9 @@ def _test_save_pretrained(self, model_id, config_cls, config_kwargs):
             # check if `config.json` is not present
             self.assertFalse(os.path.exists(os.path.join(tmp_dirname, "config.json")))
 
+            self.check_modelcard(tmp_dirname, model)
+            self.check_config_json(tmp_dirname, model)
+
     def _test_save_pretrained_selected_adapters(self, model_id, config_cls, config_kwargs):
         if issubclass(config_cls, AdaLoraConfig):
             # AdaLora does not support adding more than 1 adapter
@@ -368,6 +401,9 @@ def _test_save_pretrained_selected_adapters(self, model_id, config_cls, config_kwargs):
             self.assertFalse(os.path.exists(os.path.join(tmp_dirname, "config.json")))
             self.assertFalse(os.path.exists(os.path.join(new_adapter_dir, "config.json")))
 
+            self.check_modelcard(tmp_dirname, model)
+            self.check_config_json(tmp_dirname, model)
+
         with tempfile.TemporaryDirectory() as tmp_dirname:
             model.save_pretrained(tmp_dirname, selected_adapters=["default"])
 
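
Taken together, the change means a plain save_pretrained call now records the base model in the card metadata. An end-to-end sketch of the behavior these checks pin down (the model id is only an example, and fetching it requires Hub access):

from transformers import AutoModelForCausalLM

from peft import LoraConfig, get_peft_model

base = AutoModelForCausalLM.from_pretrained("facebook/opt-125m")  # example base model
peft_model = get_peft_model(base, LoraConfig(task_type="CAUSAL_LM"))
peft_model.save_pretrained("opt-125m-lora")

# The saved opt-125m-lora/README.md should now start with front matter like:
# ---
# library_name: peft
# base_model: facebook/opt-125m
# ---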