add push_to_hub to pipeline (#29172)
* add `push_to_hub` to pipeline

* fix docs

* format with ruff

* update save_pretrained

* update save_pretrained

* remove unnecessary comment

* switch to push_to_hub method in DynamicPipelineTester

* remove unused imports

* update docs for add_new_pipeline

* fix docs for add_new_pipeline

* add comment

* fix italien docs

* changes to token retrieval for pipelines

* Update src/transformers/pipelines/base.py

Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com>

---------

Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com>
2 people authored and ArthurZucker committed Apr 22, 2024
1 parent 5c98053 commit 525a561
Showing 7 changed files with 51 additions and 44 deletions.
8 changes: 2 additions & 6 deletions docs/source/de/add_new_pipeline.md
@@ -208,14 +208,10 @@ from transformers import pipeline
classifier = pipeline("pair-classification", model="sgugger/finetuned-bert-mrpc")
```

Dann können wir sie auf dem Hub mit der Methode `save_pretrained` in einem `Repository` freigeben:
Dann können wir sie auf dem Hub mit der Methode `push_to_hub` freigeben:

```py
from huggingface_hub import Repository

repo = Repository("test-dynamic-pipeline", clone_from="{your_username}/test-dynamic-pipeline")
classifier.save_pretrained("test-dynamic-pipeline")
repo.push_to_hub()
classifier.push_to_hub("test-dynamic-pipeline")
```

Dadurch wird die Datei, in der Sie `PairClassificationPipeline` definiert haben, in den Ordner `"test-dynamic-pipeline"` kopiert,
8 changes: 2 additions & 6 deletions docs/source/en/add_new_pipeline.md
@@ -208,14 +208,10 @@ from transformers import pipeline
classifier = pipeline("pair-classification", model="sgugger/finetuned-bert-mrpc")
```

Then we can share it on the Hub by using the `save_pretrained` method in a `Repository`:
Then we can share it on the Hub by using the `push_to_hub` method:

```py
from huggingface_hub import Repository

repo = Repository("test-dynamic-pipeline", clone_from="{your_username}/test-dynamic-pipeline")
classifier.save_pretrained("test-dynamic-pipeline")
repo.push_to_hub()
classifier.push_to_hub("test-dynamic-pipeline")
```

This will copy the file where you defined `PairClassificationPipeline` inside the folder `"test-dynamic-pipeline"`,
8 changes: 2 additions & 6 deletions docs/source/es/add_new_pipeline.md
@@ -212,14 +212,10 @@ from transformers import pipeline
classifier = pipeline("pair-classification", model="sgugger/finetuned-bert-mrpc")
```

Ahora podemos compartirlo en el Hub usando el método `save_pretrained` (guardar pre-entrenado) en un `Repository`:
Ahora podemos compartirlo en el Hub usando el método `save_pretrained`:

```py
from huggingface_hub import Repository

repo = Repository("test-dynamic-pipeline", clone_from="{your_username}/test-dynamic-pipeline")
classifier.save_pretrained("test-dynamic-pipeline")
repo.push_to_hub()
classifier.push_to_hub("test-dynamic-pipeline")
```

Esto copiará el archivo donde definiste `PairClassificationPipeline` dentro de la carpeta `"test-dynamic-pipeline"`,
8 changes: 2 additions & 6 deletions docs/source/it/add_new_pipeline.md
@@ -202,14 +202,10 @@ from transformers import pipeline
classifier = pipeline("pair-classification", model="sgugger/finetuned-bert-mrpc")
```

Successivamente possiamo condividerlo sull'Hub usando il metodo `save_pretrained` in un `Repository`:
Successivamente possiamo condividerlo sull'Hub usando il metodo `push_to_hub`

```py
from huggingface_hub import Repository

repo = Repository("test-dynamic-pipeline", clone_from="{your_username}/test-dynamic-pipeline")
classifier.save_pretrained("test-dynamic-pipeline")
repo.push_to_hub()
classifier.push_to_hub("test-dynamic-pipeline")
```

Questo codice copierà il file dove è stato definitp `PairClassificationPipeline` all'interno della cartella `"test-dynamic-pipeline"`,
8 changes: 2 additions & 6 deletions docs/source/ko/add_new_pipeline.md
@@ -203,14 +203,10 @@ from transformers import pipeline
classifier = pipeline("pair-classification", model="sgugger/finetuned-bert-mrpc")
```

그런 다음 `Repository``save_pretrained` 메소드를 사용하여 허브에 공유할 수 있습니다:
그런 다음 `push_to_hub` 메소드를 사용하여 허브에 공유할 수 있습니다:

```py
from huggingface_hub import Repository

repo = Repository("test-dynamic-pipeline", clone_from="{your_username}/test-dynamic-pipeline")
classifier.save_pretrained("test-dynamic-pipeline")
repo.push_to_hub()
classifier.push_to_hub("test-dynamic-pipeline")
```

이렇게 하면 "test-dynamic-pipeline" 폴더 내에 `PairClassificationPipeline`을 정의한 파일이 복사되며, 파이프라인의 모델과 토크나이저도 저장한 후, `{your_username}/test-dynamic-pipeline` 저장소에 있는 모든 것을 푸시합니다.
44 changes: 37 additions & 7 deletions src/transformers/pipelines/base.py
@@ -36,7 +36,9 @@
from ..tokenization_utils import PreTrainedTokenizer
from ..utils import (
ModelOutput,
PushToHubMixin,
add_end_docstrings,
copy_func,
infer_framework,
is_tf_available,
is_torch_available,
@@ -781,7 +783,7 @@ def build_pipeline_init_args(


@add_end_docstrings(build_pipeline_init_args(has_tokenizer=True, has_feature_extractor=True, has_image_processor=True))
class Pipeline(_ScikitCompat):
class Pipeline(_ScikitCompat, PushToHubMixin):
"""
The Pipeline class is the class from which all pipelines inherit. Refer to this class for methods shared across
different pipelines.
@@ -908,16 +910,36 @@ def __init__(
# then we should keep working
self.image_processor = self.feature_extractor

def save_pretrained(self, save_directory: str, safe_serialization: bool = True):
def save_pretrained(
self,
save_directory: Union[str, os.PathLike],
safe_serialization: bool = True,
**kwargs,
):
"""
Save the pipeline's model and tokenizer.
Args:
save_directory (`str`):
save_directory (`str` or `os.PathLike`):
A path to the directory where to saved. It will be created if it doesn't exist.
safe_serialization (`str`):
Whether to save the model using `safetensors` or the traditional way for PyTorch or Tensorflow.
kwargs (`Dict[str, Any]`, *optional*):
Additional key word arguments passed along to the [`~utils.PushToHubMixin.push_to_hub`] method.
"""
use_auth_token = kwargs.pop("use_auth_token", None)

if use_auth_token is not None:
warnings.warn(
"The `use_auth_token` argument is deprecated and will be removed in v5 of Transformers. Please use `token` instead.",
FutureWarning,
)
if kwargs.get("token", None) is not None:
raise ValueError(
"`token` and `use_auth_token` are both specified. Please set only the argument `token`."
)
kwargs["token"] = use_auth_token

if os.path.isfile(save_directory):
logger.error(f"Provided path ({save_directory}) should be a directory, not a file")
return
@@ -944,16 +966,17 @@ def save_pretrained(self, save_directory: str, safe_serialization: bool = True):
# Save the pipeline custom code
custom_object_save(self, save_directory)

self.model.save_pretrained(save_directory, safe_serialization=safe_serialization)
kwargs["safe_serialization"] = safe_serialization
self.model.save_pretrained(save_directory, **kwargs)

if self.tokenizer is not None:
self.tokenizer.save_pretrained(save_directory)
self.tokenizer.save_pretrained(save_directory, **kwargs)

if self.feature_extractor is not None:
self.feature_extractor.save_pretrained(save_directory)
self.feature_extractor.save_pretrained(save_directory, **kwargs)

if self.image_processor is not None:
self.image_processor.save_pretrained(save_directory)
self.image_processor.save_pretrained(save_directory, **kwargs)

if self.modelcard is not None:
self.modelcard.save_pretrained(save_directory)
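
Reviewer note: the `use_auth_token` handling added to `save_pretrained` in the hunk above can be read in isolation as a small keyword-argument shim. The sketch below is only an illustration of that pattern under stated assumptions: `resolve_token_kwarg` is a hypothetical helper name that does not exist in the commit, and the warning text is shortened. It pops the deprecated name, warns, refuses ambiguous input, and forwards the value under the new name.

```python
import warnings
from typing import Any, Dict


def resolve_token_kwarg(kwargs: Dict[str, Any]) -> Dict[str, Any]:
    """Hypothetical helper (not part of the commit): map `use_auth_token` onto `token`."""
    kwargs = dict(kwargs)  # work on a copy so the caller's dict is left untouched
    use_auth_token = kwargs.pop("use_auth_token", None)

    if use_auth_token is not None:
        # The old name still works, but the caller is told to migrate.
        warnings.warn(
            "`use_auth_token` is deprecated; please use `token` instead.",
            FutureWarning,
        )
        # Passing both names is ambiguous, so fail loudly instead of guessing.
        if kwargs.get("token") is not None:
            raise ValueError(
                "`token` and `use_auth_token` are both specified. Please set only `token`."
            )
        kwargs["token"] = use_auth_token

    return kwargs
```

In the commit itself the resolved `kwargs` are then passed on to `self.model.save_pretrained` and the tokenizer, feature extractor and image processor `save_pretrained` calls, so the shim runs once before any of them.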
@@ -1234,6 +1257,13 @@ def iterate(self, inputs, preprocess_params, forward_params, postprocess_params)
yield self.run_single(input_, preprocess_params, forward_params, postprocess_params)


Pipeline.push_to_hub = copy_func(Pipeline.push_to_hub)
if Pipeline.push_to_hub.__doc__ is not None:
Pipeline.push_to_hub.__doc__ = Pipeline.push_to_hub.__doc__.format(
object="pipe", object_class="pipeline", object_files="pipeline file"
).replace(".from_pretrained", "")


class ChunkPipeline(Pipeline):
def run_single(self, inputs, preprocess_params, forward_params, postprocess_params):
all_outputs = []
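
Reviewer note: the `Pipeline.push_to_hub = copy_func(Pipeline.push_to_hub)` lines in the hunk above copy the inherited method before formatting its docstring, presumably so that the template docstring on `PushToHubMixin` is not modified for other classes. The sketch below illustrates that idea with hypothetical stand-ins only: `copy_function`, `UploadMixin`, `Pipe` and `push` are invented for this note, and the body of `copy_function` is an assumption about how such a helper can be written, not a quotation of the library's `copy_func`.

```python
import functools
import types


def copy_function(f):
    """Return a separate function object that shares the code of `f` (illustrative helper)."""
    g = types.FunctionType(
        f.__code__, f.__globals__, name=f.__name__, argdefs=f.__defaults__, closure=f.__closure__
    )
    g = functools.update_wrapper(g, f)  # copies __doc__, __name__, __dict__, ...
    g.__kwdefaults__ = f.__kwdefaults__
    return g


class UploadMixin:
    """Hypothetical mixin whose docstring is a template shared by all subclasses."""

    def push(self, repo_id):
        """Upload the {object_files} to the Hub, e.g. `{object}.push("{object_class}-repo")`."""
        return f"pushed to {repo_id}"


class Pipe(UploadMixin):
    pass


# Give Pipe its own copy of the method, then specialise only that copy's docstring.
Pipe.push = copy_function(Pipe.push)
if Pipe.push.__doc__ is not None:
    Pipe.push.__doc__ = Pipe.push.__doc__.format(
        object="pipe", object_class="pipeline", object_files="pipeline file"
    )
```

Without the copy, `Pipe.push` and `UploadMixin.push` would be the same function object, so assigning to `__doc__` would change the text seen by every class that inherits from the mixin.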
11 changes: 4 additions & 7 deletions tests/pipelines/test_pipelines_common.py
@@ -22,7 +22,7 @@

import datasets
import numpy as np
from huggingface_hub import HfFolder, Repository, create_repo, delete_repo
from huggingface_hub import HfFolder, delete_repo
from requests.exceptions import HTTPError

from transformers import (
@@ -846,9 +846,6 @@ def test_push_to_hub_dynamic_pipeline(self):
model = BertForSequenceClassification(config).eval()

with tempfile.TemporaryDirectory() as tmp_dir:
create_repo(f"{USER}/test-dynamic-pipeline", token=self._token)
repo = Repository(tmp_dir, clone_from=f"{USER}/test-dynamic-pipeline", token=self._token)

vocab_file = os.path.join(tmp_dir, "vocab.txt")
with open(vocab_file, "w", encoding="utf-8") as vocab_writer:
vocab_writer.write("".join([x + "\n" for x in self.vocab_tokens]))
@@ -860,7 +857,7 @@ def test_push_to_hub_dynamic_pipeline(self):
del PIPELINE_REGISTRY.supported_tasks["pair-classification"]

classifier.save_pretrained(tmp_dir)
# checks
# checks if the configuration has been added after calling the save_pretrained method
self.assertDictEqual(
classifier.model.config.custom_pipelines,
{
@@ -871,8 +868,8 @@ def test_push_to_hub_dynamic_pipeline(self):
}
},
)

repo.push_to_hub()
# use push_to_hub method to push the pipeline
classifier.push_to_hub(f"{USER}/test-dynamic-pipeline", token=self._token)

# Fails if the user forget to pass along `trust_remote_code=True`
with self.assertRaises(ValueError):