Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add Text2Image into Taskflow #2988

Merged
merged 8 commits into from
Aug 10, 2022
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
35 changes: 35 additions & 0 deletions docs/model_zoo/taskflow.md
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,7 @@ PaddleNLP提供**开箱即用**的产业级NLP预置任务能力,无需训练
| [智能写诗](#智能写诗) | `Taskflow("poetry_generation")` | ✅ | ✅ | ✅ | | | 使用最大中文开源CPM模型完成写诗 |
| [开放域对话](#开放域对话) | `Taskflow("dialogue")` | ✅ | ✅ | ✅ | | | 十亿级语料训练最强中文闲聊模型PLATO-Mini,支持多轮对话 |
| [代码生成](#代码生成) | `Taskflow("code_generation")` | ✅ | ✅ | ✅ | | | 代码生成大模型 |
| [文图生成](#文图生成) | `Taskflow("text2image_generation")` | ✅ | ✅ | ✅ | | | 文图生成大模型 |


## QuickStart
Expand Down Expand Up @@ -1324,6 +1325,40 @@ from paddlenlp import Taskflow
* `output_scores`:是否要输出解码得分,请默认为False。
</div></details>

### 文图生成
<details><summary>&emsp; 通过文图生成模型来生成图片 </summary><div>

#### 支持单条、批量预测

```python
>>> from paddlenlp import Taskflow
# 默认模型为 pai-painter-painting-base-zh
>>> text2imagegen = Taskflow("text2image_generation", model="pai-painter-painting-base-zh")
# 单条输入
>>> images = text2imagegen("风阁水帘今在眼,且来先看早梅红")
# [<PIL.Image.Image image mode=RGB size=1024x256>]
>>> images[0].save("figure.png")
# 多条输入
>>> images = text2imagegen(["风阁水帘今在眼,且来先看早梅红", "见说春风偏有贺,露花千朵照庭闹"])
# [<PIL.Image.Image image mode=RGB size=1024x256>,
# <PIL.Image.Image image mode=RGB size=1024x256>]
>>> for i, image in enumerate(images):
>>> image.save(f"figure_{i}.png")
```

#### 可配置参数说明
* `model`:可选模型,默认为`pai-painter-painting-base-zh`,支持的模型有`["pai-painter-painting-base-zh", "pai-painter-scenery-base-zh", "pai-painter-commercial-base-zh", "dalle-mini", "dalle-mega-v16", "dalle-mega"]`。
* `batch_size`:批处理大小,请结合机器情况进行调整,默认为1。
* `temperature`:解码参数temperature,默认为1.0。
* `top_k`:解码参数top_k,默认为32。
* `top_p`:解码参数top_p,默认为1.0。
* `conditional_scale`:dalle-mini模型使用的参数,可参考[推特](https://twitter.com/RiversHaveWings/status/1478093658716966912),默认为10.0。
* `num_return_images`:返回图片的数量,默认为4,即4张图片水平拼接形成一张长图。
* `use_faster`:是否使用faster_generation,默认为False,目前artist模型支持,而dalle-mini模型不支持。
Copy link
Contributor

@guoshengCS guoshengCS Aug 9, 2022

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

这里也尽量隐去 artist 吧,这个模型名确实不见得有认可度,先用 pai这些模型权重来体现吧

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

use_faster:是否使用faster_generation,默认为False,目前支持faster_generation的模型有["pai-painter-painting-base-zh", "pai-painter-scenery-base-zh", "pai-painter-commercial-base-zh"]

* `use_fp16_decoding`:是否使用fp16加速解码过程,默认为False,只有当use_faster为True的时候才有效。

</div></details>

## PART Ⅱ &emsp; 定制化训练

<details><summary>适配任务列表</summary><div>
Expand Down
41 changes: 41 additions & 0 deletions paddlenlp/taskflow/taskflow.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,7 @@
from .dialogue import DialogueTask
from .information_extraction import UIETask
from .code_generation import CodeGenerationTask
from .text2image_generation import Text2ImageGenerationTask

warnings.simplefilter(action='ignore', category=Warning, lineno=0, append=False)

Expand Down Expand Up @@ -317,6 +318,46 @@
},
"default": {
"model": "Salesforce/codegen-350M-mono",
},
},
"text2image_generation": {
"models": {
"dalle-mini": {
"task_class": Text2ImageGenerationTask,
"task_flag": "text2image_generation-dalle-mini",
"task_priority_path": "dalle-mini",
},
"dalle-mega-v16": {
"task_class": Text2ImageGenerationTask,
"task_flag": "text2image_generation-dalle-mega-v16",
"task_priority_path": "dalle-mega-v16",
},
"dalle-mega": {
"task_class": Text2ImageGenerationTask,
"task_flag": "text2image_generation-dalle-mega",
"task_priority_path": "dalle-mega",
},
"pai-painter-painting-base-zh": {
"task_class": Text2ImageGenerationTask,
"task_flag":
"text2image_generation-pai-painter-painting-base-zh",
"task_priority_path": "pai-painter-painting-base-zh",
},
"pai-painter-scenery-base-zh": {
"task_class": Text2ImageGenerationTask,
"task_flag":
"text2image_generation-pai-painter-scenery-base-zh",
"task_priority_path": "pai-painter-scenery-base-zh",
},
"pai-painter-commercial-base-zh": {
"task_class": Text2ImageGenerationTask,
"task_flag":
"text2image_generation-pai-painter-commercial-base-zh",
"task_priority_path": "pai-painter-commercial-base-zh",
},
},
"default": {
"model": "pai-painter-painting-base-zh",
}
}
}
Expand Down
161 changes: 161 additions & 0 deletions paddlenlp/taskflow/text2image_generation.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,161 @@
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License"
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import paddle
from PIL import Image
from ..transformers import AutoModelForImageGeneration, AutoTokenizer
from .task import Task

usage = r"""
from paddlenlp import Taskflow

text2imagegen = Taskflow("text2image_generation")
images = text2imagegen("风阁水帘今在眼,且来先看早梅红")
images[0].save("figure.png")

"""

tokenizer_kwargs = {
"dallebart": {
"max_length": 64,
"return_token_typd_ids": False,
"return_attention_mask": True
},
"gpt": {
"max_length": 32,
"return_token_typd_ids": False,
"return_attention_mask": False
},
}


class Text2ImageGenerationTask(Task):
"""
The text2image generation model to generate the image.
Args:
task(string): The name of task.
model(string): The model name in the task.
kwargs (dict, optional): Additional keyword arguments passed along to the specific task.
"""

def __init__(self, task, model="pai-painter-painting-base-zh", **kwargs):
super().__init__(task=task, model=model, **kwargs)
self._batch_size = kwargs.get("batch_size", 1)
self._temperature = kwargs.get("temperature", 1.)
self._top_k = kwargs.get("top_k", 32)
self._top_p = kwargs.get("top_p", 1.)
self._condition_scale = kwargs.get("condition_scale", 10.)
self._num_return_images = kwargs.get("num_return_images", 4)
self._use_faster = kwargs.get("use_faster", False)
self._use_fp16_decoding = kwargs.get("use_fp16_decoding", False)
self._construct_tokenizer(model)
self._construct_model(model)

def _construct_model(self, model):
"""
Construct the inference model for the predictor.
"""
self._model = AutoModelForImageGeneration.from_pretrained(model)
self._model.eval()

def _construct_tokenizer(self, model):
"""
Construct the tokenizer for the predictor.
"""
self._tokenizer = AutoTokenizer.from_pretrained(model)

def _batchify(self, data, batch_size):
"""
Generate input batches.
"""

def _parse_batch(batch_examples):
tokenizerd_inputs = self._tokenizer(
batch_examples,
return_tensors="pd",
padding="max_length",
truncation=True,
**tokenizer_kwargs[self._model.base_model_prefix])
if self._model.base_model_prefix == "dallebart":
tokenizerd_inputs["condition_scale"] = self._condition_scale
return tokenizerd_inputs

# Seperates data into some batches.
one_batch = []
for example in data:
one_batch.append(example)
if len(one_batch) == batch_size:
yield _parse_batch(one_batch)
one_batch = []
if one_batch:
yield _parse_batch(one_batch)

def _preprocess(self, inputs):
"""
Transform the raw text to the model inputs, two steps involved:
1) Transform the raw text to token ids.
2) Generate the other model inputs from the raw text and token ids.
"""
inputs = self._check_input_text(inputs)
batches = self._batchify(inputs, self._batch_size)
outputs = {'batches': batches, 'text': inputs}
return outputs

def _run_model(self, inputs):
"""
Run the task model from the outputs of the `_preprocess` function.
"""
all_images = []

for batch_inputs in inputs["batches"]:
images = self._model.generate(
**batch_inputs,
temperature=self._temperature,
top_k=self._top_k,
top_p=self._top_p,
num_return_sequences=self._num_return_images,
use_faster=self._use_faster,
use_fp16_decoding=self._use_fp16_decoding).cpu().numpy()
if self._model.base_model_prefix == "dallebart":
images = (images.clip(0, 1) * 255).astype("uint8")
elif self._model.base_model_prefix == "gpt":
images = ((images + 1.0) * 127.5).clip(0, 255).astype("uint8")
Copy link
Contributor

@guoshengCS guoshengCS Aug 9, 2022

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

这些后处理如果是和模型绑定的话后面也可以看看将这些方法放在transformers模型代码中提供出来,这里进行统一调用,方便更多模型扩展

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

done, 现已添加到model的generate内部。

for image in images:
all_images.append(image)
inputs['images'] = all_images
return inputs

def _postprocess(self, inputs):
"""
The model output is images, this function will convert the model output to PIL Image.
"""
batch_out = []
generated_images = inputs['images']
for generated_image in generated_images:
generated_image = generated_image.transpose([1, 0, 2, 3]).reshape(
generated_image.shape[-3],
self._num_return_images * generated_image.shape[-2],
generated_image.shape[-1])
batch_out.append(Image.fromarray(generated_image))

return batch_out

def _construct_input_spec(self):
"""
Construct the input spec for the convert dygraph model to static model.
"""
self._input_spec = [
paddle.static.InputSpec(shape=[None, None],
dtype="int64",
name='input_ids'),
]