Generate: validate model_kwargs on FLAX (and catch typos in generate arguments) (huggingface#18653)
gante authored and oneraghavan committed Sep 26, 2022
1 parent af7f5ed commit 77a58ea
Showing 2 changed files with 44 additions and 1 deletion.
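
In short: `generate()` on Flax models now validates its `model_kwargs` up front, so a misspelled argument raises a `ValueError` instead of being silently dropped. A hedged sketch of the new user-facing behavior (checkpoint names borrowed from the test added below; error text taken from the diff):

```python
from transformers import AutoTokenizer, FlaxAutoModelForCausalLM

tokenizer = AutoTokenizer.from_pretrained("hf-internal-testing/tiny-bert")
model = FlaxAutoModelForCausalLM.from_pretrained("hf-internal-testing/tiny-bert-flax-only")
input_ids = tokenizer("Hello world", return_tensors="np").input_ids

# Before this commit the typo was silently ignored (greedy decoding ran anyway);
# after it, generate() fails fast:
model.generate(input_ids, do_samples=True)
# ValueError: The following `model_kwargs` are not used by the model: ['do_samples']
# (note: typos in the generate arguments will also show up in this list)
```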
24 changes: 23 additions & 1 deletion src/transformers/generation_flax_utils.py
@@ -15,9 +15,10 @@
# limitations under the License.


import inspect
import warnings
from functools import partial
from typing import Dict, Optional
from typing import Any, Dict, Optional

import numpy as np

@@ -160,6 +161,24 @@ def _adapt_logits_for_beam_search(self, logits):
"""
return logits

def _validate_model_kwargs(self, model_kwargs: Dict[str, Any]):
"""Validates model kwargs for generation. Generate argument typos will also be caught here."""
unused_model_args = []
model_args = set(inspect.signature(self.prepare_inputs_for_generation).parameters)
# `kwargs` if often used to handle optional forward pass inputs like `attention_mask`. If
# `prepare_inputs_for_generation` doesn't accept `kwargs`, then a stricter check can be made ;)
if "kwargs" in model_args:
model_args |= set(inspect.signature(self.__call__).parameters)
for key, value in model_kwargs.items():
if value is not None and key not in model_args:
unused_model_args.append(key)

if unused_model_args:
raise ValueError(
f"The following `model_kwargs` are not used by the model: {unused_model_args} (note: typos in the"
" generate arguments will also show up in this list)"
)

def generate(
    self,
    input_ids: jnp.ndarray,
@@ -262,6 +281,9 @@ def generate(
    >>> outputs = model.generate(input_ids=input_ids, max_length=20, top_k=30, do_sample=True)
    >>> tokenizer.batch_decode(outputs, skip_special_tokens=True)
    ```"""
    # Validate model kwargs
    self._validate_model_kwargs(model_kwargs.copy())

    # set init values
    bos_token_id = bos_token_id if bos_token_id is not None else self.config.bos_token_id
    pad_token_id = pad_token_id if pad_token_id is not None else self.config.pad_token_id
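The helper reads the accepted parameter names off the model itself via `inspect.signature`, and `generate()` hands it a copy of `model_kwargs` so validation can never mutate the caller's dict. A minimal standalone replica of the logic with a toy class (a sketch, not the library code):

```python
import inspect
from typing import Any, Dict


class TinyModel:
    def __call__(self, input_ids, attention_mask=None):
        pass

    def prepare_inputs_for_generation(self, input_ids, **kwargs):
        pass

    def _validate_model_kwargs(self, model_kwargs: Dict[str, Any]):
        # Collect the names `prepare_inputs_for_generation` accepts; widen the
        # set with `__call__`'s parameters when a catch-all `kwargs` is present.
        model_args = set(inspect.signature(self.prepare_inputs_for_generation).parameters)
        if "kwargs" in model_args:
            model_args |= set(inspect.signature(self.__call__).parameters)
        unused = [k for k, v in model_kwargs.items() if v is not None and k not in model_args]
        if unused:
            raise ValueError(f"The following `model_kwargs` are not used by the model: {unused}")


model = TinyModel()
model._validate_model_kwargs({"attention_mask": [[1, 1]]})  # passes: a known forward-pass input
model._validate_model_kwargs({"foo": "bar"})  # raises ValueError: ['foo'] is never consumed
```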
21 changes: 21 additions & 0 deletions tests/generation/test_generation_flax_utils.py
@@ -13,6 +13,7 @@
# limitations under the License.

import random
import unittest

import numpy as np

@@ -26,6 +27,7 @@

import jax.numpy as jnp
from jax import jit
from transformers import AutoTokenizer, FlaxAutoModelForCausalLM
from transformers.modeling_flax_pytorch_utils import load_flax_weights_in_pytorch_model

os.environ["XLA_PYTHON_CLIENT_MEM_FRACTION"] = "0.12" # assumed parallelism: 8
@@ -273,3 +275,22 @@ def test_beam_search_generate_attn_mask(self):
        jit_generation_outputs = jit_generate(input_ids, attention_mask=attention_mask).sequences

        self.assertListEqual(generation_outputs.tolist(), jit_generation_outputs.tolist())


@require_flax
class FlaxGenerationIntegrationTests(unittest.TestCase):
    def test_validate_generation_inputs(self):
        tokenizer = AutoTokenizer.from_pretrained("hf-internal-testing/tiny-bert")
        model = FlaxAutoModelForCausalLM.from_pretrained("hf-internal-testing/tiny-bert-flax-only")

        encoder_input_str = "Hello world"
        input_ids = tokenizer(encoder_input_str, return_tensors="np").input_ids

        # typos are quickly detected (the correct argument is `do_sample`)
        with self.assertRaisesRegex(ValueError, "do_samples"):
            model.generate(input_ids, do_samples=True)

        # arbitrary arguments that will not be used anywhere are also not accepted
        with self.assertRaisesRegex(ValueError, "foo"):
            fake_model_kwargs = {"foo": "bar"}
            model.generate(input_ids, **fake_model_kwargs)
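
For contrast, a sketch of the happy path (same tiny checkpoints as the test above; `attention_mask` handling and the `.sequences` attribute are assumptions about the Flax generation API for these models): kwargs the model actually consumes pass validation and generation proceeds.

```python
import numpy as np
from transformers import AutoTokenizer, FlaxAutoModelForCausalLM

tokenizer = AutoTokenizer.from_pretrained("hf-internal-testing/tiny-bert")
model = FlaxAutoModelForCausalLM.from_pretrained("hf-internal-testing/tiny-bert-flax-only")
input_ids = tokenizer("Hello world", return_tensors="np").input_ids

# `attention_mask` is consumed by the forward pass, so no ValueError is raised.
outputs = model.generate(input_ids, attention_mask=np.ones_like(input_ids), max_length=10)
print(outputs.sequences.shape)
```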
