Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix decode_input_ids to bare T5Model and improve doc #18791

Merged
merged 8 commits into from
Sep 6, 2022
Merged
Show file tree
Hide file tree
Changes from 5 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 5 additions & 2 deletions docs/source/en/model_doc/t5.mdx
Original file line number Diff line number Diff line change
Expand Up @@ -187,12 +187,15 @@ ignored. The code example below illustrates all of this.

>>> # encode the targets
>>> target_encoding = tokenizer(
... [output_sequence_1, output_sequence_2], padding="longest", max_length=max_target_length, truncation=True
... [output_sequence_1, output_sequence_2],
... padding="longest",
... max_length=max_target_length,
... truncation=True,
... return_tensors="pt",
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

thanks!

... )
>>> labels = target_encoding.input_ids

>>> # replace padding token id's of the labels by -100 so it's ignored by the loss
>>> labels = torch.tensor(labels)
>>> labels[labels == tokenizer.pad_token_id] = -100

>>> # forward pass
Expand Down
4 changes: 4 additions & 0 deletions src/transformers/models/t5/modeling_flax_t5.py
Original file line number Diff line number Diff line change
Expand Up @@ -1388,6 +1388,10 @@ class FlaxT5Model(FlaxT5PreTrainedModel):
... ).input_ids
>>> decoder_input_ids = tokenizer("Studies show that", return_tensors="np").input_ids

>>> # preprocess: Prepend decoder_input_ids with start token which is pad token for T5Model.
>>> # This is not needed for T5ForConditionalGeneration as it does this internally using labels arg.
ekagra-ranjan marked this conversation as resolved.
Show resolved Hide resolved
>>> decoder_input_ids = model._shift_right(decoder_input_ids)

>>> # forward pass
>>> outputs = model(input_ids=input_ids, decoder_input_ids=decoder_input_ids)
>>> last_hidden_states = outputs.last_hidden_state
Expand Down
4 changes: 4 additions & 0 deletions src/transformers/models/t5/modeling_t5.py
Original file line number Diff line number Diff line change
Expand Up @@ -1376,6 +1376,10 @@ def forward(
... ).input_ids # Batch size 1
>>> decoder_input_ids = tokenizer("Studies show that", return_tensors="pt").input_ids # Batch size 1

>>> # preprocess: Prepend decoder_input_ids with start token which is pad token for T5Model.
>>> # This is not needed for T5ForConditionalGeneration as it does this internally using labels arg.
ekagra-ranjan marked this conversation as resolved.
Show resolved Hide resolved
>>> decoder_input_ids = model._shift_right(decoder_input_ids)

>>> # forward pass
>>> outputs = model(input_ids=input_ids, decoder_input_ids=decoder_input_ids)
>>> last_hidden_states = outputs.last_hidden_state
Expand Down
4 changes: 4 additions & 0 deletions src/transformers/models/t5/modeling_tf_t5.py
Original file line number Diff line number Diff line change
Expand Up @@ -1180,6 +1180,10 @@ def call(
... ).input_ids # Batch size 1
>>> decoder_input_ids = tokenizer("Studies show that", return_tensors="tf").input_ids # Batch size 1

>>> # preprocess: Prepend decoder_input_ids with start token which is pad token for T5Model.
>>> # This is not needed for T5ForConditionalGeneration as it does this internally using labels arg.
ekagra-ranjan marked this conversation as resolved.
Show resolved Hide resolved
>>> decoder_input_ids = model._shift_right(decoder_input_ids)

>>> # forward pass
>>> outputs = model(input_ids, decoder_input_ids=decoder_input_ids)
>>> last_hidden_states = outputs.last_hidden_state
Expand Down