Skip to content

Commit

Permalink
[FIX] TextGenerationPipeline is currently broken. (#8256)
Browse files Browse the repository at this point in the history
* [FIX] TextGenerationPipeline is currently broken.

It's most likely due to #8180.
What's missing is a multi vs single string handler at the beginning of
the pipe.
And also there was no testing of this pipeline.

* Fixing Conversational tests too.
  • Loading branch information
Narsil authored Nov 3, 2020
1 parent a1bbcf3 commit c66ffa3
Show file tree
Hide file tree
Showing 3 changed files with 35 additions and 10 deletions.
6 changes: 4 additions & 2 deletions src/transformers/pipelines.py
Original file line number Diff line number Diff line change
Expand Up @@ -836,6 +836,8 @@ def __call__(
-- The token ids of the generated text.
"""

if isinstance(text_inputs, str):
text_inputs = [text_inputs]
results = []
for prompt_text in text_inputs:
# Manage correct placement of the tensors
Expand Down Expand Up @@ -2382,6 +2384,8 @@ def __call__(
updated generated responses for those containing a new user input.
"""

if isinstance(conversations, Conversation):
conversations = [conversations]
# Input validation
if isinstance(conversations, list):
for conversation in conversations:
Expand All @@ -2398,8 +2402,6 @@ def __call__(
assert (
self.tokenizer.pad_token_id is not None or self.tokenizer.eos_token_id is not None
), "Please make sure that the tokenizer has a pad_token_id or eos_token_id when using a batch input"
elif isinstance(conversations, Conversation):
conversations = [conversations]
else:
raise ValueError("DialoguePipeline expects a Conversation or list of Conversations as an input")

Expand Down
20 changes: 12 additions & 8 deletions tests/test_pipelines_conversational.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,26 +9,30 @@
DEFAULT_DEVICE_NUM = -1 if torch_device == "cpu" else 0


class TextGenerationPipelineTests(MonoInputPipelineCommonMixin, unittest.TestCase):
class ConversationalPipelineTests(MonoInputPipelineCommonMixin, unittest.TestCase):
pipeline_task = "conversational"
small_models = [] # Models tested without the @slow decorator
large_models = ["microsoft/DialoGPT-medium"] # Models tested with the @slow decorator
valid_inputs = [Conversation("Hi there!"), [Conversation("Hi there!"), Conversation("How are you?")]]
invalid_inputs = ["Hi there!", Conversation()]

def _test_pipeline(
self, nlp
): # e overide the default test method to check that the output is a `Conversation` object
def _test_pipeline(self, nlp):
# e overide the default test method to check that the output is a `Conversation` object
self.assertIsNotNone(nlp)

mono_result = nlp(self.valid_inputs[0])
# We need to recreate conversation for successive tests to pass as
# Conversation objects get *consumed* by the pipeline
conversation = Conversation("Hi there!")
mono_result = nlp(conversation)
self.assertIsInstance(mono_result, Conversation)

multi_result = nlp(self.valid_inputs[1])
conversations = [Conversation("Hi there!"), Conversation("How are you?")]
multi_result = nlp(conversations)
self.assertIsInstance(multi_result, list)
self.assertIsInstance(multi_result[0], Conversation)
# Conversation have been consumed and are not valid anymore
# Inactive conversations passed to the pipeline raise a ValueError
self.assertRaises(ValueError, nlp, self.valid_inputs[1])
self.assertRaises(ValueError, nlp, conversation)
self.assertRaises(ValueError, nlp, conversations)

for bad_input in self.invalid_inputs:
self.assertRaises(Exception, nlp, bad_input)
Expand Down
19 changes: 19 additions & 0 deletions tests/test_pipelines_text_generation.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
import unittest

from transformers import pipeline

from .test_pipelines_common import MonoInputPipelineCommonMixin


Expand All @@ -8,3 +10,20 @@ class TextGenerationPipelineTests(MonoInputPipelineCommonMixin, unittest.TestCas
pipeline_running_kwargs = {"prefix": "This is "}
small_models = ["sshleifer/tiny-ctrl"] # Models tested without the @slow decorator
large_models = [] # Models tested with the @slow decorator

def test_simple_generation(self):
nlp = pipeline(task="text-generation", model=self.small_models[0])
# text-generation is non-deterministic by nature, we can't fully test the output

outputs = nlp("This is a test")

self.assertEqual(len(outputs), 1)
self.assertEqual(list(outputs[0].keys()), ["generated_text"])
self.assertEqual(type(outputs[0]["generated_text"]), str)

outputs = nlp(["This is a test", "This is a second test"])
self.assertEqual(len(outputs[0]), 1)
self.assertEqual(list(outputs[0][0].keys()), ["generated_text"])
self.assertEqual(type(outputs[0][0]["generated_text"]), str)
self.assertEqual(list(outputs[1][0].keys()), ["generated_text"])
self.assertEqual(type(outputs[1][0]["generated_text"]), str)

0 comments on commit c66ffa3

Please sign in to comment.