diff --git a/src/transformers/pipelines.py b/src/transformers/pipelines.py index f9897ef2dbd3c5..16452de1ab650e 100755 --- a/src/transformers/pipelines.py +++ b/src/transformers/pipelines.py @@ -836,6 +836,8 @@ def __call__( -- The token ids of the generated text. """ + if isinstance(text_inputs, str): + text_inputs = [text_inputs] results = [] for prompt_text in text_inputs: # Manage correct placement of the tensors @@ -2382,6 +2384,8 @@ def __call__( updated generated responses for those containing a new user input. """ + if isinstance(conversations, Conversation): + conversations = [conversations] # Input validation if isinstance(conversations, list): for conversation in conversations: @@ -2398,8 +2402,6 @@ def __call__( assert ( self.tokenizer.pad_token_id is not None or self.tokenizer.eos_token_id is not None ), "Please make sure that the tokenizer has a pad_token_id or eos_token_id when using a batch input" - elif isinstance(conversations, Conversation): - conversations = [conversations] else: raise ValueError("DialoguePipeline expects a Conversation or list of Conversations as an input") diff --git a/tests/test_pipelines_conversational.py b/tests/test_pipelines_conversational.py index 3492283479b718..2c5da9ee249101 100644 --- a/tests/test_pipelines_conversational.py +++ b/tests/test_pipelines_conversational.py @@ -9,26 +9,30 @@ DEFAULT_DEVICE_NUM = -1 if torch_device == "cpu" else 0 -class TextGenerationPipelineTests(MonoInputPipelineCommonMixin, unittest.TestCase): +class ConversationalPipelineTests(MonoInputPipelineCommonMixin, unittest.TestCase): pipeline_task = "conversational" small_models = [] # Models tested without the @slow decorator large_models = ["microsoft/DialoGPT-medium"] # Models tested with the @slow decorator - valid_inputs = [Conversation("Hi there!"), [Conversation("Hi there!"), Conversation("How are you?")]] invalid_inputs = ["Hi there!", Conversation()] - def _test_pipeline( - self, nlp - ): # e overide the default test method to check that the output is a `Conversation` object + def _test_pipeline(self, nlp): + # e overide the default test method to check that the output is a `Conversation` object self.assertIsNotNone(nlp) - mono_result = nlp(self.valid_inputs[0]) + # We need to recreate conversation for successive tests to pass as + # Conversation objects get *consumed* by the pipeline + conversation = Conversation("Hi there!") + mono_result = nlp(conversation) self.assertIsInstance(mono_result, Conversation) - multi_result = nlp(self.valid_inputs[1]) + conversations = [Conversation("Hi there!"), Conversation("How are you?")] + multi_result = nlp(conversations) self.assertIsInstance(multi_result, list) self.assertIsInstance(multi_result[0], Conversation) + # Conversation have been consumed and are not valid anymore # Inactive conversations passed to the pipeline raise a ValueError - self.assertRaises(ValueError, nlp, self.valid_inputs[1]) + self.assertRaises(ValueError, nlp, conversation) + self.assertRaises(ValueError, nlp, conversations) for bad_input in self.invalid_inputs: self.assertRaises(Exception, nlp, bad_input) diff --git a/tests/test_pipelines_text_generation.py b/tests/test_pipelines_text_generation.py index a4ca32551d697a..711b2e10e3773f 100644 --- a/tests/test_pipelines_text_generation.py +++ b/tests/test_pipelines_text_generation.py @@ -1,5 +1,7 @@ import unittest +from transformers import pipeline + from .test_pipelines_common import MonoInputPipelineCommonMixin @@ -8,3 +10,20 @@ class TextGenerationPipelineTests(MonoInputPipelineCommonMixin, unittest.TestCas pipeline_running_kwargs = {"prefix": "This is "} small_models = ["sshleifer/tiny-ctrl"] # Models tested without the @slow decorator large_models = [] # Models tested with the @slow decorator + + def test_simple_generation(self): + nlp = pipeline(task="text-generation", model=self.small_models[0]) + # text-generation is non-deterministic by nature, we can't fully test the output + + outputs = nlp("This is a test") + + self.assertEqual(len(outputs), 1) + self.assertEqual(list(outputs[0].keys()), ["generated_text"]) + self.assertEqual(type(outputs[0]["generated_text"]), str) + + outputs = nlp(["This is a test", "This is a second test"]) + self.assertEqual(len(outputs[0]), 1) + self.assertEqual(list(outputs[0][0].keys()), ["generated_text"]) + self.assertEqual(type(outputs[0][0]["generated_text"]), str) + self.assertEqual(list(outputs[1][0].keys()), ["generated_text"]) + self.assertEqual(type(outputs[1][0]["generated_text"]), str)