From 774a04acda454ba589d0daeeb504f90e35476051 Mon Sep 17 00:00:00 2001 From: Niels Rogge Date: Mon, 28 Mar 2022 14:18:53 +0200 Subject: [PATCH 1/2] Fix doc --- src/transformers/utils/doc.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/src/transformers/utils/doc.py b/src/transformers/utils/doc.py index 394d2aaa2fed76..75c5d688243852 100644 --- a/src/transformers/utils/doc.py +++ b/src/transformers/utils/doc.py @@ -269,9 +269,8 @@ def _prepare_output_docstrings(output_type, config_class, min_indent=None): ```python >>> # To train a model on `num_labels` classes, you can pass `num_labels=num_labels` to `.from_pretrained(...)` >>> num_labels = len(model.config.id2label) - >>> model = {model_class}.from_pretrained("{checkpoint}", num_labels=num_labels) + >>> model = {model_class}.from_pretrained("{checkpoint}", num_labels=num_labels, problem_type="multi_label_classification") - >>> num_labels = len(model.config.id2label) >>> labels = torch.nn.functional.one_hot(torch.tensor([predicted_class_id]), num_classes=num_labels).to( ... torch.float ... ) From d77680e1ad6a4b57fa766790383e6b68268bf29b Mon Sep 17 00:00:00 2001 From: Niels Rogge Date: Mon, 28 Mar 2022 14:23:14 +0200 Subject: [PATCH 2/2] Make fixup --- src/transformers/utils/doc.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/src/transformers/utils/doc.py b/src/transformers/utils/doc.py index 75c5d688243852..f81066bea6c2cf 100644 --- a/src/transformers/utils/doc.py +++ b/src/transformers/utils/doc.py @@ -269,7 +269,9 @@ def _prepare_output_docstrings(output_type, config_class, min_indent=None): ```python >>> # To train a model on `num_labels` classes, you can pass `num_labels=num_labels` to `.from_pretrained(...)` >>> num_labels = len(model.config.id2label) - >>> model = {model_class}.from_pretrained("{checkpoint}", num_labels=num_labels, problem_type="multi_label_classification") + >>> model = {model_class}.from_pretrained( + ... "{checkpoint}", num_labels=num_labels, problem_type="multi_label_classification" + ... ) >>> labels = torch.nn.functional.one_hot(torch.tensor([predicted_class_id]), num_classes=num_labels).to( ... torch.float