Skip to content

Commit

Permalink
Add doctests to Perceiver examples (#19129)
Browse files Browse the repository at this point in the history
* Fix bug in example and add to tests

* Fix failing tests

* Check the size of logits

* Code style

* Try again...

* Add expected loss for PerceiverForMaskedLM doctest

Co-authored-by: Steven Anton <antonstv@amazon.com>
Co-authored-by: ydshieh <ydshieh@users.noreply.github.com>
  • Loading branch information
3 people authored Sep 23, 2022
1 parent fe01ec3 commit 49bf569
Show file tree
Hide file tree
Showing 2 changed files with 37 additions and 1 deletion.
37 changes: 36 additions & 1 deletion src/transformers/models/perceiver/modeling_perceiver.py
Original file line number Diff line number Diff line change
Expand Up @@ -801,6 +801,8 @@ def forward(
>>> with torch.no_grad():
... outputs = model(inputs=inputs)
>>> logits = outputs.logits
>>> list(logits.shape)
[1, 2]
>>> # to train, one can train the model using standard cross-entropy:
>>> criterion = torch.nn.CrossEntropyLoss()
Expand All @@ -810,6 +812,7 @@ def forward(
>>> # EXAMPLE 2: using the Perceiver to classify images
>>> # - we define an ImagePreprocessor, which can be used to embed images
>>> config = PerceiverConfig(image_size=224)
>>> preprocessor = PerceiverImagePreprocessor(
... config,
... prep_type="conv1x1",
Expand Down Expand Up @@ -844,6 +847,8 @@ def forward(
>>> with torch.no_grad():
... outputs = model(inputs=inputs)
>>> logits = outputs.logits
>>> list(logits.shape)
[1, 2]
>>> # to train, one can train the model using standard cross-entropy:
>>> criterion = torch.nn.CrossEntropyLoss()
Expand Down Expand Up @@ -1017,7 +1022,12 @@ def forward(
>>> outputs = model(**inputs, labels=labels)
>>> loss = outputs.loss
>>> round(loss.item(), 2)
19.87
>>> logits = outputs.logits
>>> list(logits.shape)
[1, 2048, 262]
>>> # inference
>>> text = "This is an incomplete sentence where some words are missing."
Expand All @@ -1030,6 +1040,8 @@ def forward(
>>> with torch.no_grad():
... outputs = model(**encoding)
>>> logits = outputs.logits
>>> list(logits.shape)
[1, 2048, 262]
>>> masked_tokens_predictions = logits[0, 52:61].argmax(dim=-1).tolist()
>>> tokenizer.decode(masked_tokens_predictions)
Expand Down Expand Up @@ -1128,6 +1140,8 @@ def forward(
>>> inputs = tokenizer(text, return_tensors="pt").input_ids
>>> outputs = model(inputs=inputs)
>>> logits = outputs.logits
>>> list(logits.shape)
[1, 2]
```"""
if inputs is not None and input_ids is not None:
raise ValueError("You cannot use both `inputs` and `input_ids`")
Expand Down Expand Up @@ -1265,9 +1279,13 @@ def forward(
>>> inputs = feature_extractor(images=image, return_tensors="pt").pixel_values
>>> outputs = model(inputs=inputs)
>>> logits = outputs.logits
>>> list(logits.shape)
[1, 1000]
>>> # model predicts one of the 1000 ImageNet classes
>>> predicted_class_idx = logits.argmax(-1).item()
>>> print("Predicted class:", model.config.id2label[predicted_class_idx])
Predicted class: tabby, tabby cat
```"""
if inputs is not None and pixel_values is not None:
raise ValueError("You cannot use both `inputs` and `pixel_values`")
Expand Down Expand Up @@ -1402,9 +1420,13 @@ def forward(
>>> inputs = feature_extractor(images=image, return_tensors="pt").pixel_values
>>> outputs = model(inputs=inputs)
>>> logits = outputs.logits
>>> list(logits.shape)
[1, 1000]
>>> # model predicts one of the 1000 ImageNet classes
>>> predicted_class_idx = logits.argmax(-1).item()
>>> print("Predicted class:", model.config.id2label[predicted_class_idx])
Predicted class: tabby, tabby cat
```"""
if inputs is not None and pixel_values is not None:
raise ValueError("You cannot use both `inputs` and `pixel_values`")
Expand Down Expand Up @@ -1539,9 +1561,13 @@ def forward(
>>> inputs = feature_extractor(images=image, return_tensors="pt").pixel_values
>>> outputs = model(inputs=inputs)
>>> logits = outputs.logits
>>> list(logits.shape)
[1, 1000]
>>> # model predicts one of the 1000 ImageNet classes
>>> predicted_class_idx = logits.argmax(-1).item()
>>> print("Predicted class:", model.config.id2label[predicted_class_idx])
Predicted class: tabby, tabby cat
```"""
if inputs is not None and pixel_values is not None:
raise ValueError("You cannot use both `inputs` and `pixel_values`")
Expand Down Expand Up @@ -1689,6 +1715,8 @@ def forward(
>>> patches = torch.randn(1, 2, 27, 368, 496)
>>> outputs = model(inputs=patches)
>>> logits = outputs.logits
>>> list(logits.shape)
[1, 368, 496, 2]
```"""
return_dict = return_dict if return_dict is not None else self.config.use_return_dict

Expand Down Expand Up @@ -1915,6 +1943,14 @@ def forward(
>>> outputs = model(inputs=inputs, subsampled_output_points=subsampling)
>>> logits = outputs.logits
>>> list(logits["audio"].shape)
[1, 240]
>>> list(logits["image"].shape)
[1, 6272, 3]
>>> list(logits["label"].shape)
[1, 700]
```"""
return_dict = return_dict if return_dict is not None else self.config.use_return_dict

Expand Down Expand Up @@ -2925,7 +2961,6 @@ def __init__(self, config: PerceiverConfig, in_channels: int, postproc_type: str
self.classifier = nn.Linear(in_channels, config.samples_per_patch)

def forward(self, inputs: torch.Tensor, pos: Optional[torch.Tensor] = None, modality_sizes=None) -> torch.Tensor:

logits = self.classifier(inputs)
return torch.reshape(logits, [inputs.shape[0], -1])

Expand Down
1 change: 1 addition & 0 deletions utils/documentation_tests.txt
Original file line number Diff line number Diff line change
Expand Up @@ -58,6 +58,7 @@ src/transformers/models/opt/modeling_opt.py
src/transformers/models/opt/modeling_tf_opt.py
src/transformers/models/owlvit/modeling_owlvit.py
src/transformers/models/pegasus/modeling_pegasus.py
src/transformers/models/perceiver/modeling_perceiver.py
src/transformers/models/plbart/modeling_plbart.py
src/transformers/models/poolformer/modeling_poolformer.py
src/transformers/models/reformer/modeling_reformer.py
Expand Down

0 comments on commit 49bf569

Please sign in to comment.