diff --git a/docs/model_zoo/taskflow.md b/docs/model_zoo/taskflow.md
index b43863bc99c8..974c34ba8009 100644
--- a/docs/model_zoo/taskflow.md
+++ b/docs/model_zoo/taskflow.md
@@ -41,6 +41,7 @@ PaddleNLP provides **out-of-the-box**, industrial-grade preset NLP tasks that require no training
 | [Poetry Generation](#智能写诗) | `Taskflow("poetry_generation")` | ✅ | ✅ | ✅ | | | Writes poetry with CPM, the largest open-source Chinese generative model |
 | [Open-Domain Dialogue](#开放域对话) | `Taskflow("dialogue")` | ✅ | ✅ | ✅ | | | PLATO-Mini, the strongest Chinese chit-chat model, trained on a billion-scale corpus; supports multi-turn dialogue |
 | [Code Generation](#代码生成) | `Taskflow("code_generation")` | ✅ | ✅ | ✅ | | | Large model for code generation |
+| [Text-to-Image Generation](#text-to-image-generation) | `Taskflow("text2image_generation")` | ✅ | ✅ | ✅ | | | Large model for text-to-image generation |
 
 ## QuickStart
 
@@ -1324,6 +1325,40 @@ from paddlenlp import Taskflow
 
 * `output_scores`: Whether to output the decoding scores. Defaults to False.
 
+### Text-to-Image Generation
+
+<details><summary>&emsp;Generate images with a text-to-image generation model</summary><div>
+
+#### Single and batch prediction
+
+```python
+>>> from paddlenlp import Taskflow
+# The default model is pai-painter-painting-base-zh
+>>> text2imagegen = Taskflow("text2image_generation", model="pai-painter-painting-base-zh")
+# Single input
+>>> images = text2imagegen("风阁水帘今在眼,且来先看早梅红")
+# [<PIL.Image.Image image mode=RGB size=1024x256 at 0x...>]
+>>> images[0].save("figure.png")
+# Batch input
+>>> images = text2imagegen(["风阁水帘今在眼,且来先看早梅红", "见说春风偏有贺,露花千朵照庭闹"])
+# [<PIL.Image.Image image mode=RGB size=1024x256 at 0x...>,
+#  <PIL.Image.Image image mode=RGB size=1024x256 at 0x...>]
+>>> for i, image in enumerate(images):
+...     image.save(f"figure_{i}.png")
+```
+
+#### Configurable parameters
+* `model`: The model to use. Defaults to `pai-painter-painting-base-zh`; supported models are `["pai-painter-painting-base-zh", "pai-painter-scenery-base-zh", "pai-painter-commercial-base-zh", "dalle-mini", "dalle-mega-v16", "dalle-mega"]`.
+* `batch_size`: Batch size; adjust it to your hardware. Defaults to 1.
+* `temperature`: Decoding parameter `temperature`. Defaults to 1.0.
+* `top_k`: Decoding parameter `top_k`. Defaults to 32.
+* `top_p`: Decoding parameter `top_p`. Defaults to 1.0.
+* `condition_scale`: Parameter used by the dalle-mini family of models; see [this tweet](https://twitter.com/RiversHaveWings/status/1478093658716966912). Defaults to 10.0.
+* `num_return_images`: Number of images to return per prompt. Defaults to 4, in which case the 4 images are concatenated horizontally into one long image.
+* `use_faster`: Whether to use faster_generation. Defaults to False. Models that currently support faster_generation are `["pai-painter-painting-base-zh", "pai-painter-scenery-base-zh", "pai-painter-commercial-base-zh"]`.
+* `use_fp16_decoding`: Whether to use FP16 to speed up decoding. Defaults to False; only takes effect when `use_faster` is True. A combined-configuration sketch follows this list.
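+
+As a minimal combined-configuration sketch (the model choice and parameter values below are illustrative, not tuned recommendations; `condition_scale` only affects the dalle-mini family, whose own examples use English prompts):
+
+```python
+>>> from paddlenlp import Taskflow
+>>> text2imagegen = Taskflow("text2image_generation",
+...                          model="dalle-mini",
+...                          batch_size=2,
+...                          top_k=16,
+...                          condition_scale=16.0,
+...                          num_return_images=2)
+# The two prompts are processed as one batch; each returned image is the
+# horizontal concatenation of num_return_images (here 2) 256x256 images.
+>>> images = text2imagegen(["graphite sketch of Elon Musk", "Mohanlal graphite sketch"])
+>>> for i, image in enumerate(images):
+...     image.save(f"figure_{i}.png")
+```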
+
+</div></details>
+
 ## PART Ⅱ &emsp; Customized Training
 
 <details><summary>Supported task list</summary><div>
diff --git a/paddlenlp/taskflow/taskflow.py b/paddlenlp/taskflow/taskflow.py
index 5be4842efee4..83270831ab45 100644
--- a/paddlenlp/taskflow/taskflow.py
+++ b/paddlenlp/taskflow/taskflow.py
@@ -37,6 +37,7 @@
 from .dialogue import DialogueTask
 from .information_extraction import UIETask
 from .code_generation import CodeGenerationTask
+from .text2image_generation import Text2ImageGenerationTask
 
 warnings.simplefilter(action='ignore', category=Warning, lineno=0, append=False)
 
@@ -317,6 +318,46 @@
         },
         "default": {
             "model": "Salesforce/codegen-350M-mono",
+        },
+    },
+    "text2image_generation": {
+        "models": {
+            "dalle-mini": {
+                "task_class": Text2ImageGenerationTask,
+                "task_flag": "text2image_generation-dalle-mini",
+                "task_priority_path": "dalle-mini",
+            },
+            "dalle-mega-v16": {
+                "task_class": Text2ImageGenerationTask,
+                "task_flag": "text2image_generation-dalle-mega-v16",
+                "task_priority_path": "dalle-mega-v16",
+            },
+            "dalle-mega": {
+                "task_class": Text2ImageGenerationTask,
+                "task_flag": "text2image_generation-dalle-mega",
+                "task_priority_path": "dalle-mega",
+            },
+            "pai-painter-painting-base-zh": {
+                "task_class": Text2ImageGenerationTask,
+                "task_flag":
+                "text2image_generation-pai-painter-painting-base-zh",
+                "task_priority_path": "pai-painter-painting-base-zh",
+            },
+            "pai-painter-scenery-base-zh": {
+                "task_class": Text2ImageGenerationTask,
+                "task_flag":
+                "text2image_generation-pai-painter-scenery-base-zh",
+                "task_priority_path": "pai-painter-scenery-base-zh",
+            },
+            "pai-painter-commercial-base-zh": {
+                "task_class": Text2ImageGenerationTask,
+                "task_flag":
+                "text2image_generation-pai-painter-commercial-base-zh",
+                "task_priority_path": "pai-painter-commercial-base-zh",
+            },
+        },
+        "default": {
+            "model": "pai-painter-painting-base-zh",
+        }
     }
 }
diff --git a/paddlenlp/taskflow/text2image_generation.py b/paddlenlp/taskflow/text2image_generation.py
new file mode 100644
index 000000000000..9fa3ea7e4e5b
--- /dev/null
+++ b/paddlenlp/taskflow/text2image_generation.py
@@ -0,0 +1,144 @@
+# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import paddle
+import numpy as np
+from PIL import Image
+
+from ..transformers import AutoModelForImageGeneration, AutoTokenizer
+from .task import Task
+
+usage = r"""
+           from paddlenlp import Taskflow
+
+           text2imagegen = Taskflow("text2image_generation")
+           images = text2imagegen("风阁水帘今在眼,且来先看早梅红")
+           images[0].save("figure.png")
+         """
+
+
+class Text2ImageGenerationTask(Task):
+    """
+    The text-to-image generation task, which generates images from text prompts.
+    Args:
+        task (string): The name of task.
+        model (string): The model name in the task.
+        kwargs (dict, optional): Additional keyword arguments passed along to the specific task.
+    """
+
+    def __init__(self, task, model="pai-painter-painting-base-zh", **kwargs):
+        super().__init__(task=task, model=model, **kwargs)
+        self._batch_size = kwargs.get("batch_size", 1)
+        self._temperature = kwargs.get("temperature", 1.)
+        self._top_k = kwargs.get("top_k", 32)
+        self._top_p = kwargs.get("top_p", 1.)
+        self._condition_scale = kwargs.get("condition_scale", 10.)
+        self._num_return_images = kwargs.get("num_return_images", 4)
+        self._use_faster = kwargs.get("use_faster", False)
+        self._use_fp16_decoding = kwargs.get("use_fp16_decoding", False)
+        self._construct_tokenizer(model)
+        self._construct_model(model)
+
+    def _construct_model(self, model):
+        """
+        Construct the inference model for the predictor.
+        """
+        self._model = AutoModelForImageGeneration.from_pretrained(model)
+        self._model.eval()
+
+    def _construct_tokenizer(self, model):
+        """
+        Construct the tokenizer for the predictor.
+        """
+        self._tokenizer = AutoTokenizer.from_pretrained(model)
+
+    def _batchify(self, data, batch_size):
+        """
+        Generate input batches.
+        """
+
+        def _parse_batch(batch_examples):
+            tokenized_inputs = self._tokenizer(batch_examples,
+                                               return_tensors="pd",
+                                               padding="max_length",
+                                               truncation=True)
+            if self._model.base_model_prefix == "dallebart":
+                # Only DalleBart's generate() accepts condition_scale.
+                tokenized_inputs["condition_scale"] = self._condition_scale
+            return tokenized_inputs
+
+        # Separate the data into batches.
+        one_batch = []
+        for example in data:
+            one_batch.append(example)
+            if len(one_batch) == batch_size:
+                yield _parse_batch(one_batch)
+                one_batch = []
+        if one_batch:
+            yield _parse_batch(one_batch)
+
+    def _preprocess(self, inputs):
+        """
+        Transform the raw text to the model inputs, two steps involved:
+        1) Transform the raw text to token ids.
+        2) Generate the other model inputs from the raw text and token ids.
+        """
+        inputs = self._check_input_text(inputs)
+        batches = self._batchify(inputs, self._batch_size)
+        outputs = {'batches': batches, 'text': inputs}
+        return outputs
+
+    def _run_model(self, inputs):
+        """
+        Run the task model from the outputs of the `_preprocess` function.
+        """
+        all_images = []
+        for batch_inputs in inputs["batches"]:
+            images = self._model.generate(
+                **batch_inputs,
+                temperature=self._temperature,
+                top_k=self._top_k,
+                top_p=self._top_p,
+                num_return_sequences=self._num_return_images,
+                use_faster=self._use_faster,
+                use_fp16_decoding=self._use_fp16_decoding)
+            all_images.append(images.numpy())
+        inputs['images'] = np.concatenate(all_images, axis=0)
+        return inputs
+
+    def _postprocess(self, inputs):
+        """
+        The model outputs uint8 image arrays; this function converts each
+        row of generated images into a single PIL Image.
+        """
+        batch_out = []
+        generated_images = inputs['images']
+        # Concatenate the images returned for each prompt horizontally:
+        # [batch_size, num_return_sequences, 256, 256, 3] -> [batch_size, 256, num_return_sequences*256, 3]
+        generated_images = generated_images.transpose([0, 2, 1, 3, 4]).reshape([
+            -1, generated_images.shape[-3],
+            self._num_return_images * generated_images.shape[-2],
+            generated_images.shape[-1]
+        ])
+        for generated_image in generated_images:
+            batch_out.append(Image.fromarray(generated_image))
+
+        return batch_out
+
+    def _construct_input_spec(self):
+        """
+        Construct the input spec for converting the dygraph model to a static model.
+        """
+        self._input_spec = [
+            paddle.static.InputSpec(shape=[None, None],
+                                    dtype="int64",
+                                    name='input_ids'),
+        ]
diff --git a/paddlenlp/transformers/artist/modeling.py b/paddlenlp/transformers/artist/modeling.py
index 7f9946d0bcd6..c86579a76872 100644
--- a/paddlenlp/transformers/artist/modeling.py
+++ b/paddlenlp/transformers/artist/modeling.py
@@ -212,7 +212,7 @@ def generate(self,
         Returns:
             Tensor: Returns tensor `images`, which is the output of :class:`VQGanDetokenizer`.
-                Its data type should be float32 and has a shape of [batch_size, num_return_sequences, 256, 256, 3].
+                Its data type is uint8 and its shape is [batch_size, num_return_sequences, 256, 256, 3].
 
         Example:
             .. code-block::
 
@@ -228,28 +228,18 @@ def generate(self,
 
                 # Prepare the model inputs.
                 prompts = ["风阁水帘今在眼,且来先看早梅红", "见说春风偏有贺,露花千朵照庭闹"]
-                tokenized_inputs = tokenizer(
-                    prompts,
-                    return_tensors="pd",
-                    padding="max_length",
-                    truncation=True,
-                    return_token_type_ids=False,
-                    return_attention_mask=False,
-                    max_length=32,
-                )
+                tokenized_inputs = tokenizer(prompts, return_tensors="pd")
 
                 top_k = 32
                 num_return_sequences = 4
                 images = model.generate(**tokenized_inputs,
                                         top_k=top_k,
                                         num_return_sequences=num_return_sequences)
-                print(images.shape)
-                # [2, 4, 256, 256, 3]
-                images = ((images.cpu().numpy() + 1.0) * 127.5).clip(0, 255).astype("uint8")
+                print(images.shape)  # [2, 4, 256, 256, 3]
                 # [2, 256, 4*256, 3]
-                images = images.transpose([0, 2, 1, 3,
-                                           4]).reshape(-1, images.shape[-3],
+                images = images.numpy().transpose([0, 2, 1, 3,
+                                                   4]).reshape([-1, images.shape[-3],
                                                        num_return_sequences * images.shape[-2],
-                                                       images.shape[-1])
+                                                       images.shape[-1]])
                 for i, image in enumerate(images):
                     image = Image.fromarray(image)
                     image.save(f"figure_{i}.png")
@@ -273,4 +263,5 @@ def generate(self,
             -1, num_return_sequences, images.shape[1], images.shape[2],
             images.shape[3]
         ])
-        return images
+        images = ((images + 1.0) * 127.5).clip(0, 255).astype("uint8")
+        return images
\ No newline at end of file
diff --git a/paddlenlp/transformers/artist/tokenizer.py b/paddlenlp/transformers/artist/tokenizer.py
index befd11849886..5e5a74c3f147 100644
--- a/paddlenlp/transformers/artist/tokenizer.py
+++ b/paddlenlp/transformers/artist/tokenizer.py
@@ -203,11 +203,11 @@ def __call__(
             self,
             text,
             text_pair=None,
-            max_length=None,
+            max_length=32,  # default
             stride=0,
             is_split_into_words=False,
-            padding=False,
-            truncation=False,
+            padding="max_length",  # default
+            truncation=True,  # default
             return_position_ids=False,
             return_token_type_ids=False,  # don't return token_type_ids
             return_attention_mask=False,
diff --git a/paddlenlp/transformers/dallebart/modeling.py b/paddlenlp/transformers/dallebart/modeling.py
index 95b91de30655..ef661acd61bd 100644
--- a/paddlenlp/transformers/dallebart/modeling.py
+++ b/paddlenlp/transformers/dallebart/modeling.py
@@ -1729,30 +1729,23 @@ def generate(self,
                 sequences for each sequence in the batch. Defaults to 1.
         Returns:
             Tensor: Returns tensor `images`, which is the output of :class:`VQGanDetokenizer`.
-                Its data type should be float32 and has a shape of [batch_size, num_return_sequences, 256, 256, 3].
+                Its data type is uint8 and its shape is [batch_size, num_return_sequences, 256, 256, 3].
 
         Example:
             .. code-block::
 
                 import paddle
-                from paddlenlp.transformers import DalleBartForImageGeneration, DalleBartTokenizer
+                from paddlenlp.transformers import AutoModelForImageGeneration, AutoTokenizer
                 from PIL import Image
 
                 # Initialize the model and tokenizer
                 model_name_or_path = 'dalle-mini'
-                model = DalleBartForImageGeneration.from_pretrained(model_name_or_path)
-                tokenizer = DalleBartTokenizer.from_pretrained(model_name_or_path)
+                model = AutoModelForImageGeneration.from_pretrained(model_name_or_path)
+                tokenizer = AutoTokenizer.from_pretrained(model_name_or_path)
                 model.eval()
 
                 # Prepare the model inputs.
prompts = ["graphite sketch of Elon Musk", "Mohanlal graphite sketch"] - tokenized_inputs = tokenizer( - prompts, - return_tensors="pd", - padding="max_length", - truncation=True, - return_attention_mask=True, - max_length=64, - ) + tokenized_inputs = tokenizer(prompts, return_tensors="pd") top_k = 32 condition_scale = 16.0 num_return_sequences = 4 @@ -1760,14 +1753,12 @@ def generate(self, top_k=top_k, condition_scale=condition_scale, num_return_sequences=num_return_sequences) - print(images.shape) - # [2, 4, 256, 256, 3] - images = (images.cpu().numpy().clip(0, 1) * 255).astype("uint8") + print(images.shape) # [2, 4, 256, 256, 3] # [2, 256, 4*256, 3] - images = images.transpose([0, 2, 1, 3, - 4]).reshape(-1, images.shape[-3], + images = images.numpy().transpose([0, 2, 1, 3, + 4]).reshape([-1, images.shape[-3], num_return_sequences * images.shape[-2], - images.shape[-1]) + images.shape[-1]]) for i, image in enumerate(images): image = Image.fromarray(image) image.save(f"figure_{i}.png") @@ -1787,4 +1778,5 @@ def generate(self, -1, num_return_sequences, images.shape[1], images.shape[2], images.shape[3] ]) + images = (images.clip(0, 1) * 255).astype("uint8") return images diff --git a/paddlenlp/transformers/dallebart/tokenizer.py b/paddlenlp/transformers/dallebart/tokenizer.py index 49f745ac852f..9be91793ab21 100644 --- a/paddlenlp/transformers/dallebart/tokenizer.py +++ b/paddlenlp/transformers/dallebart/tokenizer.py @@ -490,27 +490,28 @@ def create_token_type_ids_from_sequences(self, return len(cls + token_ids_0 + sep) * [0] return len(cls + token_ids_0 + sep + sep + token_ids_1 + sep) * [0] - def __call__(self, - text, - text_pair=None, - max_length=None, - stride=0, - is_split_into_words=False, - padding=False, - truncation=False, - return_position_ids=False, - return_token_type_ids=False, - return_attention_mask=False, - return_length=False, - return_overflowing_tokens=False, - return_special_tokens_mask=False, - return_dict=True, - return_offsets_mapping=False, - add_special_tokens=True, - pad_to_multiple_of=None, - return_tensors=None, - verbose: bool = True, - **kwargs): + def __call__( + self, + text, + text_pair=None, + max_length=64, # default + stride=0, + is_split_into_words=False, + padding="max_length", # default + truncation=True, # default + return_position_ids=False, + return_token_type_ids=False, # don't return token_type_ids + return_attention_mask=True, # default + return_length=False, + return_overflowing_tokens=False, + return_special_tokens_mask=False, + return_dict=True, + return_offsets_mapping=False, + add_special_tokens=True, + pad_to_multiple_of=None, + return_tensors=None, + verbose: bool = True, + **kwargs): if self.normalize_text: is_batched = isinstance(text, (list, tuple)) if is_batched: