Skip to content

Commit

Permalink
Preserve card metadata format/ordering on load->save
Browse files Browse the repository at this point in the history
  • Loading branch information
hlky committed Sep 26, 2024
1 parent f984cdc commit 680d7a5
Show file tree
Hide file tree
Showing 3 changed files with 20 additions and 5 deletions.
4 changes: 3 additions & 1 deletion src/huggingface_hub/repocard.py
Original file line number Diff line number Diff line change
Expand Up @@ -109,7 +109,9 @@ def content(self, content: str):
data_dict = {}
self.text = content

self.data = self.card_data_class(**data_dict, ignore_metadata_errors=self.ignore_metadata_errors)
self.data = self.card_data_class(
**data_dict, ignore_metadata_errors=self.ignore_metadata_errors, original_order=list(data_dict.keys())
)

def __str__(self):
# The string form of the card is whatever `self.content` evaluates to.
return self.content
Expand Down
15 changes: 11 additions & 4 deletions src/huggingface_hub/repocard_data.py
Original file line number Diff line number Diff line change
Expand Up @@ -172,8 +172,12 @@ class CardData:
inherit from `dict` to allow this export step.
"""

def __init__(self, ignore_metadata_errors: bool = False, **kwargs):
def __init__(self, ignore_metadata_errors: bool = False, original_order: Optional[List[str]] = None, **kwargs):
    """Store arbitrary metadata as instance attributes, optionally in a given key order.

    Args:
        ignore_metadata_errors (`bool`, *optional*, defaults to `False`):
            Accepted for signature compatibility with subclasses; not used here.
        original_order (`List[str]`, *optional*):
            Key order to restore on the instance `__dict__` (e.g. the order the keys
            had in the loaded metadata). Entries that do not match an attribute are
            ignored, and duplicates are only considered once.
        **kwargs:
            Extra metadata; each entry becomes an instance attribute.
    """
    self.__dict__.update(kwargs)
    if original_order:
        attributes = self.__dict__
        # A subclass may have stored a loaded key under a different attribute name
        # (or dropped it) before calling this constructor, so a key listed in
        # `original_order` is not guaranteed to exist here: skip it instead of
        # raising `KeyError`. `dict.fromkeys` de-duplicates while keeping order.
        reordered = {key: attributes[key] for key in dict.fromkeys(original_order) if key in attributes}
        # Append every remaining attribute in its existing insertion order. A set
        # difference would make the order of these keys arbitrary between runs.
        for key, value in attributes.items():
            reordered.setdefault(key, value)
        self.__dict__ = reordered

def to_dict(self) -> Dict[str, Any]:
"""Converts CardData to a dict.
Expand Down Expand Up @@ -316,6 +320,7 @@ def __init__(
pipeline_tag: Optional[str] = None,
tags: Optional[List[str]] = None,
ignore_metadata_errors: bool = False,
original_order: Optional[List[str]] = None,
**kwargs,
):
self.base_model = base_model
Expand Down Expand Up @@ -347,7 +352,7 @@ def __init__(
" some information will be lost. Use it at your own risk."
)

super().__init__(**kwargs)
super().__init__(**kwargs, original_order=original_order)

if self.eval_results:
if isinstance(self.eval_results, EvalResult):
Expand Down Expand Up @@ -419,6 +424,7 @@ def __init__(
train_eval_index: Optional[Dict] = None,
config_names: Optional[Union[str, List[str]]] = None,
ignore_metadata_errors: bool = False,
original_order: Optional[List[str]] = None,
**kwargs,
):
self.annotations_creators = annotations_creators
Expand All @@ -436,7 +442,7 @@ def __init__(

# TODO - maybe handle this similarly to EvalResult?
self.train_eval_index = train_eval_index or kwargs.pop("train-eval-index", None)
super().__init__(**kwargs)
super().__init__(**kwargs, original_order=original_order)

def _to_dict(self, data_dict):
data_dict["train-eval-index"] = data_dict.pop("train_eval_index")
Expand Down Expand Up @@ -507,6 +513,7 @@ def __init__(
datasets: Optional[List[str]] = None,
tags: Optional[List[str]] = None,
ignore_metadata_errors: bool = False,
original_order: Optional[List[str]] = None,
**kwargs,
):
self.title = title
Expand All @@ -520,7 +527,7 @@ def __init__(
self.models = models
self.datasets = datasets
self.tags = _to_unique_list(tags)
super().__init__(**kwargs)
super().__init__(**kwargs, original_order=original_order)


def model_index_to_eval_results(model_index: List[Dict[str, Any]]) -> Tuple[str, List[EvalResult]]:
Expand Down
6 changes: 6 additions & 0 deletions tests/test_repocard.py
Original file line number Diff line number Diff line change
Expand Up @@ -847,6 +847,12 @@ def test_model_card_from_template_eval_results(self):
self.assertTrue(card.data.to_dict().get("eval_results") is None)
self.assertEqual(str(card)[: len(DUMMY_MODELCARD_EVAL_RESULT)], DUMMY_MODELCARD_EVAL_RESULT)

def test_preserve_order_load_save(self):
# Load a card, change one metadata value, and check that the regenerated
# content lists `license` before `datasets`.
# Presumably that is the key order used in DUMMY_MODELCARD (not shown here) - TODO confirm.
model_card = ModelCard(DUMMY_MODELCARD)
model_card.data.license = "test"
# NOTE(review): bare attribute access whose result is discarded; it looks
# redundant with the assertion below - confirm the getter has no side
# effect this test relies on before removing it.
model_card.content
self.assertEqual(model_card.content, "---\nlicense: test\ndatasets:\n- foo\n- bar\n---\n\nHello\n")


class DatasetCardTest(TestCaseWithHfApi):
def test_load_datasetcard_from_file(self):
Expand Down

0 comments on commit 680d7a5

Please sign in to comment.