
Generate: Fix GIT batched captioning #21738

Merged · 1 commit · Feb 23, 2023
18 changes: 11 additions & 7 deletions src/transformers/generation/tf_utils.py
@@ -1217,31 +1217,29 @@ def _prepare_model_inputs(
                 # In this case, `input_ids` is moved to the `model_kwargs`, so a few automations (like the creation of
                 # the attention mask) can rely on the actual model input.
                 model_kwargs["input_ids"] = self._maybe_initialize_input_ids_for_generation(
-                    inputs, bos_token_id, batch_size=model_kwargs["inputs_embeds"].shape[0]
+                    inputs, bos_token_id, model_kwargs=model_kwargs
                 )
             else:
                 if inputs is not None:
                     raise ValueError("You passed `inputs_embeds` and `input_ids` to `.generate()`. Please pick one.")
                 inputs, input_name = model_kwargs["inputs_embeds"], "inputs_embeds"

         # 4. if `inputs` is still None, try to create `input_ids` from BOS token
-        inputs = self._maybe_initialize_input_ids_for_generation(
-            inputs, bos_token_id, model_kwargs.get("encoder_outputs")
-        )
+        inputs = self._maybe_initialize_input_ids_for_generation(inputs, bos_token_id, model_kwargs)

         return inputs, input_name, model_kwargs

     def _maybe_initialize_input_ids_for_generation(
         self,
         inputs: Optional[tf.Tensor] = None,
         bos_token_id: Optional[int] = None,
-        encoder_outputs: Optional[ModelOutput] = None,
-        batch_size: Optional[int] = None,
+        model_kwargs: Optional[Dict[str, tf.Tensor]] = None,
     ) -> tf.Tensor:
         """Initializes input ids for generation, if necessary."""
         if inputs is not None:
             return inputs

+        encoder_outputs = model_kwargs.get("encoder_outputs")
         if self.config.is_encoder_decoder and encoder_outputs is not None:
             # make dummy input_ids with value -100, as a sanity check ensuring that they won't be used for encoding
             shape = encoder_outputs.last_hidden_state.shape[:-1]
@@ -1250,7 +1248,13 @@ def _maybe_initialize_input_ids_for_generation(
         if bos_token_id is None:
             raise ValueError("`bos_token_id` has to be defined when no `input_ids` are provided.")

-        batch_size = batch_size if batch_size is not None else 1
+        # If there is some tensor in `model_kwargs`, we can infer the batch size from it. This is helpful with
+        # soft-prompting or in multimodal implementations built on top of decoder-only language models.
+        batch_size = 1
+        for value in model_kwargs.values():
+            if isinstance(value, tf.Tensor):
+                batch_size = value.shape[0]
+                break
         return tf.ones((batch_size, 1), dtype=tf.int32) * bos_token_id

     @staticmethod
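For reference, a minimal sketch of the new TF-side batch-size inference, run on a toy `model_kwargs` of the kind a captioning call produces (no `input_ids`); the tensor shape and BOS id below are made-up values for illustration:

import tensorflow as tf

# Toy stand-in for what `generate()` holds when only images are passed.
model_kwargs = {
    "output_attentions": False,                  # non-tensor entries are skipped
    "pixel_values": tf.zeros((4, 3, 224, 224)),  # first tensor found -> batch size 4
}

# Same loop as the new code path above: default to 1, then take the batch
# dimension of the first tensor found in `model_kwargs`.
batch_size = 1
for value in model_kwargs.values():
    if isinstance(value, tf.Tensor):
        batch_size = value.shape[0]
        break

bos_token_id = 101  # hypothetical BOS id
input_ids = tf.ones((batch_size, 1), dtype=tf.int32) * bos_token_id
print(input_ids.shape)  # (4, 1)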
18 changes: 11 additions & 7 deletions src/transformers/generation/utils.py
@@ -544,17 +544,15 @@ def _prepare_model_inputs(
                 # In this case, `input_ids` is moved to the `model_kwargs`, so a few automations (like the creation of
                 # the attention mask) can rely on the actual model input.
                 model_kwargs["input_ids"] = self._maybe_initialize_input_ids_for_generation(
-                    inputs, bos_token_id, batch_size=model_kwargs["inputs_embeds"].shape[0]
+                    inputs, bos_token_id, model_kwargs=model_kwargs
                 )
             else:
                 if inputs is not None:
                     raise ValueError("You passed `inputs_embeds` and `input_ids` to `.generate()`. Please pick one.")
                 inputs, input_name = model_kwargs["inputs_embeds"], "inputs_embeds"

         # 4. if `inputs` is still None, try to create `input_ids` from BOS token
-        inputs = self._maybe_initialize_input_ids_for_generation(
-            inputs, bos_token_id, model_kwargs.get("encoder_outputs")
-        )
+        inputs = self._maybe_initialize_input_ids_for_generation(inputs, bos_token_id, model_kwargs)
         return inputs, input_name, model_kwargs

     def adjust_logits_during_generation(self, logits: torch.FloatTensor, **kwargs) -> torch.FloatTensor:
@@ -567,13 +565,13 @@ def _maybe_initialize_input_ids_for_generation(
         self,
         inputs: Optional[torch.Tensor] = None,
         bos_token_id: Optional[int] = None,
-        encoder_outputs: Optional[ModelOutput] = None,
-        batch_size: Optional[int] = None,
+        model_kwargs: Optional[Dict[str, torch.Tensor]] = None,
     ) -> torch.LongTensor:
         """Initializes input ids for generation, if necessary."""
         if inputs is not None:
             return inputs

+        encoder_outputs = model_kwargs.get("encoder_outputs")
         if self.config.is_encoder_decoder and encoder_outputs is not None:
             # make dummy input_ids with value -100, as a sanity check ensuring that they won't be used for encoding
             shape = encoder_outputs.last_hidden_state.size()[:-1]
@@ -582,7 +580,13 @@ def _maybe_initialize_input_ids_for_generation(
         if bos_token_id is None:
             raise ValueError("`bos_token_id` has to be defined when no `input_ids` are provided.")

-        batch_size = batch_size if batch_size is not None else 1
+        # If there is some tensor in `model_kwargs`, we can infer the batch size from it. This is helpful with
+        # soft-prompting or in multimodal implementations built on top of decoder-only language models.
+        batch_size = 1
+        for value in model_kwargs.values():
+            if isinstance(value, torch.Tensor):
+                batch_size = value.shape[0]
+                break
         return torch.ones((batch_size, 1), dtype=torch.long, device=self.device) * bos_token_id

     def _prepare_attention_mask_for_generation(
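To see why this matters for GIT: during captioning, `generate()` is called with `pixel_values` only, so neither `input_ids` nor `inputs_embeds` is around to size the BOS prompt. A small PyTorch sketch of the before/after behaviour, with a hypothetical kwargs dict standing in for what GIT forwards:

import torch

# Hypothetical kwargs for captioning a batch of 4 images: `pixel_values` is
# the only tensor present (no `input_ids`, no `inputs_embeds`).
model_kwargs = {"pixel_values": torch.zeros(4, 3, 224, 224), "use_cache": True}

# Old logic: batch size came from `inputs_embeds` alone, else defaulted to 1,
# so batched captioning built a single BOS row for all 4 images.
old_batch_size = 1  # no `inputs_embeds` -> default

# New logic: the first tensor found in `model_kwargs` supplies the batch size.
new_batch_size = next(
    (v.shape[0] for v in model_kwargs.values() if isinstance(v, torch.Tensor)), 1
)

assert (old_batch_size, new_batch_size) == (1, 4)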
22 changes: 22 additions & 0 deletions tests/models/git/test_modeling_git.py
@@ -340,6 +340,24 @@ def _test_beam_search_generate(self, config, input_ids, input_mask, pixel_values

         self.parent.assertEqual(generated_ids.shape, (self.batch_size * 2, 20))

+    def _test_batched_generate_captioning(self, config, input_ids, input_mask, pixel_values):
+        model = GitForCausalLM(config=config)
+        model.to(torch_device)
+        model.eval()
+
+        # generate
+        generated_ids = model.generate(
+            input_ids=None,  # captioning -> no input_ids
+            attention_mask=None,
+            pixel_values=pixel_values,
+            do_sample=False,
+            max_length=20,
+            num_beams=2,
+            num_return_sequences=2,
+        )
+
+        self.parent.assertEqual(generated_ids.shape, (self.batch_size * 2, 20))
+
     def prepare_config_and_inputs_for_common(self):
         config_and_inputs = self.prepare_config_and_inputs()

@@ -398,6 +416,10 @@ def test_beam_search_generate(self):
         config_and_inputs = self.model_tester.prepare_config_and_inputs()
         self.model_tester._test_beam_search_generate(*config_and_inputs)

+    def test_batched_generate_captioning(self):
+        config_and_inputs = self.model_tester.prepare_config_and_inputs()
+        self.model_tester._test_batched_generate_captioning(*config_and_inputs)
+
     def test_model_various_embeddings(self):
         config_and_inputs = self.model_tester.prepare_config_and_inputs()
         for type in ["absolute", "relative_key", "relative_key_query"]:
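End to end, this is the user-facing workflow the fix unblocks: batched captioning with no text prompt. A sketch assuming the `microsoft/git-base-coco` checkpoint and two local image files (the file names are placeholders):

import torch
from PIL import Image
from transformers import AutoProcessor, GitForCausalLM

processor = AutoProcessor.from_pretrained("microsoft/git-base-coco")
model = GitForCausalLM.from_pretrained("microsoft/git-base-coco")

# Two hypothetical local images batched together.
images = [Image.open("cat.png"), Image.open("dog.png")]
pixel_values = processor(images=images, return_tensors="pt").pixel_values

# Captioning: no `input_ids`, so generation starts from BOS with one row per
# image in the batch (before this fix, the prompt had a single row regardless).
with torch.no_grad():
    generated_ids = model.generate(pixel_values=pixel_values, max_length=20)

print(processor.batch_decode(generated_ids, skip_special_tokens=True))  # one caption per image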