Skip to content

Commit

Permalink
[Fix doc example] FlaxVisionEncoderDecoder (huggingface#15626)
Browse files (browse the repository at this point in the history)
* Fix wrong checkpoint name: vit

* Fix missing import

* Fix more missing import

* make style

* Apply suggestions from code review

Co-authored-by: NielsRogge <48327001+NielsRogge@users.noreply.github.com>

Co-authored-by: ydshieh <ydshieh@users.noreply.github.com>
Co-authored-by: NielsRogge <48327001+NielsRogge@users.noreply.github.com>
  • Loading branch information
3 people authored and ManuelFay committed Mar 31, 2022
1 parent 44e0b88 commit 8884fe9
Showing 1 changed file with 11 additions and 5 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -393,7 +393,7 @@ def encode(
Example:
```python
>>> from transformers import FlaxVisionEncoderDecoderModel
>>> from transformers import ViTFeatureExtractor, FlaxVisionEncoderDecoderModel
>>> from PIL import Image
>>> import requests
Expand All @@ -403,7 +403,9 @@ def encode(
>>> feature_extractor = ViTFeatureExtractor.from_pretrained("google/vit-base-patch16-224-in21k")
>>> # initialize a vit-gpt2 from pretrained ViT and GPT2 models. Note that the cross-attention layers will be randomly initialized
>>> model = FlaxVisionEncoderDecoderModel.from_encoder_decoder_pretrained("vit", "gpt2")
>>> model = FlaxVisionEncoderDecoderModel.from_encoder_decoder_pretrained(
... "google/vit-base-patch16-224-in21k", "gpt2"
... )
>>> pixel_values = feature_extractor(images=image, return_tensors="np").pixel_values
>>> encoder_outputs = model.encode(pixel_values)
Expand Down Expand Up @@ -469,7 +471,7 @@ def decode(
Example:
```python
>>> from transformers import FlaxVisionEncoderDecoderModel
>>> from transformers import ViTFeatureExtractor, FlaxVisionEncoderDecoderModel
>>> import jax.numpy as jnp
>>> from PIL import Image
>>> import requests
Expand All @@ -480,7 +482,9 @@ def decode(
>>> feature_extractor = ViTFeatureExtractor.from_pretrained("google/vit-base-patch16-224-in21k")
>>> # initialize a vit-gpt2 from pretrained ViT and GPT2 models. Note that the cross-attention layers will be randomly initialized
>>> model = FlaxVisionEncoderDecoderModel.from_encoder_decoder_pretrained("vit", "gpt2")
>>> model = FlaxVisionEncoderDecoderModel.from_encoder_decoder_pretrained(
... "google/vit-base-patch16-224-in21k", "gpt2"
... )
>>> pixel_values = feature_extractor(images=image, return_tensors="np").pixel_values
>>> encoder_outputs = model.encode(pixel_values)
Expand Down Expand Up @@ -610,7 +614,9 @@ def __call__(
>>> tokenizer_output = GPT2Tokenizer.from_pretrained("gpt2")
>>> # initialize a vit-gpt2 from pretrained ViT and GPT2 models. Note that the cross-attention layers will be randomly initialized
>>> model = FlaxVisionEncoderDecoderModel.from_encoder_decoder_pretrained("vit", "gpt2")
>>> model = FlaxVisionEncoderDecoderModel.from_encoder_decoder_pretrained(
... "google/vit-base-patch16-224-in21k", "gpt2"
... )
>>> pixel_values = feature_extractor(images=image, return_tensors="np").pixel_values
Expand Down

0 comments on commit 8884fe9

Please sign in to comment.