Skip to content

Commit

Permalink
Fix test and docs (#14399)
Browse files Browse the repository at this point in the history
  • Loading branch information
NielsRogge authored Nov 15, 2021
1 parent 4ce74ed commit 74e6111
Show file tree
Hide file tree
Showing 2 changed files with 3 additions and 2 deletions.
3 changes: 2 additions & 1 deletion src/transformers/models/vit/modeling_tf_vit.py
Original file line number Diff line number Diff line change
Expand Up @@ -795,6 +795,7 @@ def call(
Examples::
>>> from transformers import ViTFeatureExtractor, TFViTForImageClassification
>>> import tensorflow as tf
>>> from PIL import Image
>>> import requests
Expand All @@ -809,7 +810,7 @@ def call(
>>> logits = outputs.logits
>>> # model predicts one of the 1000 ImageNet classes
>>> predicted_class_idx = tf.math.argmax(logits, axis=-1)[0]
>>> print("Predicted class:", model.config.id2label[predicted_class_idx])
>>> print("Predicted class:", model.config.id2label[int(predicted_class_idx)])
"""
inputs = input_processing(
func=self.call,
Expand Down
2 changes: 1 addition & 1 deletion tests/test_modeling_tf_vit.py
Original file line number Diff line number Diff line change
Expand Up @@ -371,7 +371,7 @@ def default_feature_extractor(self):

@slow
def test_inference_image_classification_head(self):
model = TFViTForImageClassification.from_pretrained("google/vit-base-patch16-224", from_pt=True)
model = TFViTForImageClassification.from_pretrained("google/vit-base-patch16-224")

feature_extractor = self.default_feature_extractor
image = prepare_img()
Expand Down

0 comments on commit 74e6111

Please sign in to comment.