diff --git a/docs/source/en/model_doc/t5.mdx b/docs/source/en/model_doc/t5.mdx index 5a1928923476cb..92cd753b645767 100644 --- a/docs/source/en/model_doc/t5.mdx +++ b/docs/source/en/model_doc/t5.mdx @@ -187,12 +187,15 @@ ignored. The code example below illustrates all of this. >>> # encode the targets >>> target_encoding = tokenizer( -... [output_sequence_1, output_sequence_2], padding="longest", max_length=max_target_length, truncation=True +... [output_sequence_1, output_sequence_2], +... padding="longest", +... max_length=max_target_length, +... truncation=True, +... return_tensors="pt", ... ) >>> labels = target_encoding.input_ids >>> # replace padding token id's of the labels by -100 so it's ignored by the loss ->>> labels = torch.tensor(labels) >>> labels[labels == tokenizer.pad_token_id] = -100 >>> # forward pass diff --git a/src/transformers/models/t5/modeling_flax_t5.py b/src/transformers/models/t5/modeling_flax_t5.py index 918a605fc4813a..2732bf591690f7 100644 --- a/src/transformers/models/t5/modeling_flax_t5.py +++ b/src/transformers/models/t5/modeling_flax_t5.py @@ -1388,6 +1388,10 @@ class FlaxT5Model(FlaxT5PreTrainedModel): ... ).input_ids >>> decoder_input_ids = tokenizer("Studies show that", return_tensors="np").input_ids + >>> # preprocess: Prepend decoder_input_ids with start token which is pad token for T5Model. + >>> # This is not needed for torch's T5ForConditionalGeneration as it does this internally using labels arg. + >>> decoder_input_ids = model._shift_right(decoder_input_ids) + >>> # forward pass >>> outputs = model(input_ids=input_ids, decoder_input_ids=decoder_input_ids) >>> last_hidden_states = outputs.last_hidden_state diff --git a/src/transformers/models/t5/modeling_t5.py b/src/transformers/models/t5/modeling_t5.py index e4c36109bd7709..0ec4c74d1d8da5 100644 --- a/src/transformers/models/t5/modeling_t5.py +++ b/src/transformers/models/t5/modeling_t5.py @@ -1376,6 +1376,10 @@ def forward( ... ).input_ids # Batch size 1 >>> decoder_input_ids = tokenizer("Studies show that", return_tensors="pt").input_ids # Batch size 1 + >>> # preprocess: Prepend decoder_input_ids with start token which is pad token for T5Model. + >>> # This is not needed for torch's T5ForConditionalGeneration as it does this internally using labels arg. + >>> decoder_input_ids = model._shift_right(decoder_input_ids) + >>> # forward pass >>> outputs = model(input_ids=input_ids, decoder_input_ids=decoder_input_ids) >>> last_hidden_states = outputs.last_hidden_state diff --git a/src/transformers/models/t5/modeling_tf_t5.py b/src/transformers/models/t5/modeling_tf_t5.py index 091cb9d63eb42d..dc909c8d8f3349 100644 --- a/src/transformers/models/t5/modeling_tf_t5.py +++ b/src/transformers/models/t5/modeling_tf_t5.py @@ -1180,6 +1180,10 @@ def call( ... ).input_ids # Batch size 1 >>> decoder_input_ids = tokenizer("Studies show that", return_tensors="tf").input_ids # Batch size 1 + >>> # preprocess: Prepend decoder_input_ids with start token which is pad token for T5Model. + >>> # This is not needed for torch's T5ForConditionalGeneration as it does this internally using labels arg. + >>> decoder_input_ids = model._shift_right(decoder_input_ids) + >>> # forward pass >>> outputs = model(input_ids, decoder_input_ids=decoder_input_ids) >>> last_hidden_states = outputs.last_hidden_state