Add AltDiffusion (open-mmlab#1299)
* add conversion script for vae

* up

* up

* some fixes

* add text model

* use the correct config

* add docs

* move model in its own file

* move model in its own file

* pass attention mask to text encoder

* pass attn mask to uncond inputs

* quality

* fix image2image

* add img2img in init

* fix import

* fix one more import

* fix import, dummy objects

* fix copied from

* up

* finish

Co-authored-by: patil-suraj <surajp815@gmail.com>
patrickvonplaten and patil-suraj authored Nov 15, 2022
1 parent 4625f04 commit 8a73064
Showing 17 changed files with 1,486 additions and 12 deletions.
2 changes: 2 additions & 0 deletions docs/source/_toctree.yml
@@ -80,6 +80,8 @@
- sections:
- local: api/pipelines/overview
title: "Overview"
- local: api/pipelines/alt_diffusion
title: "AltDiffusion"
- local: api/pipelines/cycle_diffusion
title: "Cycle Diffusion"
- local: api/pipelines/ddim
83 changes: 83 additions & 0 deletions docs/source/api/pipelines/alt_diffusion.mdx
@@ -0,0 +1,83 @@
<!--Copyright 2022 The HuggingFace Team. All rights reserved.

Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
the License. You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
specific language governing permissions and limitations under the License.
-->

# AltDiffusion

AltDiffusion was proposed in [AltCLIP: Altering the Language Encoder in CLIP for Extended Language Capabilities](https://arxiv.org/abs/2211.06679) by Zhongzhi Chen, Guang Liu, Bo-Wen Zhang, Fulong Ye, Qinghong Yang, and Ledell Wu.

The abstract of the paper is the following:

*In this work, we present a conceptually simple and effective method to train a strong bilingual multimodal representation model. Starting from the pretrained multimodal representation model CLIP released by OpenAI, we switched its text encoder with a pretrained multilingual text encoder XLM-R, and aligned both languages and image representations by a two-stage training schema consisting of teacher learning and contrastive learning. We validate our method through evaluations of a wide range of tasks. We set new state-of-the-art performances on a bunch of tasks including ImageNet-CN, Flicker30k-CN, and COCO-CN. Further, we obtain very close performances with CLIP on almost all tasks, suggesting that one can simply alter the text encoder in CLIP for extended capabilities such as multilingual understanding.*


*Overview*:

| Pipeline | Tasks | Colab | Demo
|---|---|:---:|:---:|
| [pipeline_alt_diffusion.py](https://github.com/huggingface/diffusers/blob/main/src/diffusers/pipelines/alt_diffusion/pipeline_alt_diffusion.py) | *Text-to-Image Generation* | - | -
| [pipeline_alt_diffusion_img2img.py](https://github.com/huggingface/diffusers/blob/main/src/diffusers/pipelines/alt_diffusion/pipeline_alt_diffusion_img2img.py) | *Image-to-Image Text-Guided Generation* | - | -

## Tips

- AltDiffusion is conceptually exactly the same as [Stable Diffusion](./api/pipelines/stable_diffusion).

- *Run AltDiffusion*

AltDiffusion can be tested very easily with the [`AltDiffusionPipeline`], the [`AltDiffusionImg2ImgPipeline`], and the `"BAAI/AltDiffusion"` checkpoint in exactly the same way as shown in the [Conditional Image Generation Guide](./using-diffusers/conditional_image_generation) and the [Image-to-Image Generation Guide](./using-diffusers/img2img). As a minimal sketch (the prompt and the output file name below are illustrative):
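```python
>>> import torch

>>> from diffusers import AltDiffusionPipeline

>>> # fp16 is assumed to be usable with this checkpoint; drop torch_dtype to run in fp32
>>> pipe = AltDiffusionPipeline.from_pretrained("BAAI/AltDiffusion", torch_dtype=torch.float16)
>>> pipe = pipe.to("cuda")

>>> # the XLM-R based text encoder also understands non-English prompts
>>> prompt = "黑暗精灵公主，非常详细，幻想，数字绘画，概念艺术，插图"
>>> image = pipe(prompt).images[0]
>>> image.save("dark_elf_princess.png")
```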

- *How to load and use different schedulers.*

The AltDiffusion pipeline uses the [`DDIMScheduler`] by default, but `diffusers` provides many other schedulers that can be used with it, such as [`PNDMScheduler`], [`LMSDiscreteScheduler`], [`EulerDiscreteScheduler`], [`EulerAncestralDiscreteScheduler`], etc.
To use a different scheduler, you can either change it via the [`ConfigMixin.from_config`] method or pass the `scheduler` argument to the `from_pretrained` method of the pipeline. For example, to use the [`EulerDiscreteScheduler`], you can do the following:

```python
>>> from diffusers import AltDiffusionPipeline, EulerDiscreteScheduler

>>> pipeline = AltDiffusionPipeline.from_pretrained("BAAI/AltDiffusion")
>>> pipeline.scheduler = EulerDiscreteScheduler.from_config(pipeline.scheduler.config)
>>> # or
>>> euler_scheduler = EulerDiscreteScheduler.from_pretrained("BAAI/AltDiffusion", subfolder="scheduler")
>>> pipeline = AltDiffusionPipeline.from_pretrained("BAAI/AltDiffusion", scheduler=euler_scheduler)
```
- *How to cover all use cases with a single checkpoint*
If you want to cover all possible use cases without loading the weights twice, we recommend using the `components` functionality to instantiate all pipeline components in the most memory-efficient way:
```python
>>> from diffusers import (
... AltDiffusionPipeline,
... AltDiffusionImg2ImgPipeline,
... )
>>> text2img = AltDiffusionPipeline.from_pretrained("BAAI/AltDiffusion")
>>> img2img = AltDiffusionImg2ImgPipeline(**text2img.components)
>>> # now you can use text2img(...) and img2img(...) just like the call methods of each respective pipeline
```
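A follow-up sketch of calling the shared-components `img2img` pipeline. The init image URL is illustrative, and the argument name `image` is an assumption (older `diffusers` versions called it `init_image`):
```python
>>> from io import BytesIO

>>> import requests
>>> from PIL import Image

>>> # any RGB image can serve as the init image; this URL is just an example
>>> url = "https://raw.githubusercontent.com/CompVis/stable-diffusion/main/assets/stable-samples/img2img/sketch-mountains-input.jpg"
>>> init_image = Image.open(BytesIO(requests.get(url).content)).convert("RGB").resize((768, 512))

>>> prompt = "A fantasy landscape, trending on artstation"
>>> image = img2img(prompt=prompt, image=init_image, strength=0.75).images[0]
```
Because both pipelines reuse the same components, the second pipeline allocates no additional memory.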

## AltDiffusionPipelineOutput
[[autodoc]] pipelines.alt_diffusion.AltDiffusionPipelineOutput

## AltDiffusionPipeline
[[autodoc]] AltDiffusionPipeline
- __call__
- enable_attention_slicing
- disable_attention_slicing

## AltDiffusionImg2ImgPipeline
[[autodoc]] AltDiffusionImg2ImgPipeline
- __call__
- enable_attention_slicing
- disable_attention_slicing
1 change: 1 addition & 0 deletions docs/source/api/pipelines/overview.mdx
@@ -44,6 +44,7 @@ available a colab notebook to directly try them out.

| Pipeline | Paper | Tasks | Colab
|---|---|:---:|:---:|
| [alt_diffusion](./api/pipelines/alt_diffusion) | [**AltDiffusion**](https://arxiv.org/abs/2211.06679) | Image-to-Image Text-Guided Generation | -
| [cycle_diffusion](./api/pipelines/cycle_diffusion) | [**Cycle Diffusion**](https://arxiv.org/abs/2210.05559) | Image-to-Image Text-Guided Generation |
| [dance_diffusion](./api/pipelines/dance_diffusion) | [**Dance Diffusion**](https://github.com/williamberman/diffusers.git) | Unconditional Audio Generation |
| [ddpm](./api/pipelines/ddpm) | [**Denoising Diffusion Probabilistic Models**](https://arxiv.org/abs/2006.11239) | Unconditional Image Generation |
1 change: 1 addition & 0 deletions docs/source/index.mdx
@@ -34,6 +34,7 @@ available a colab notebook to directly try them out.

| Pipeline | Paper | Tasks | Colab
|---|---|:---:|:---:|
| [alt_diffusion](./api/pipelines/alt_diffusion) | [**AltDiffusion**](https://arxiv.org/abs/2211.06679) | Image-to-Image Text-Guided Generation |
| [cycle_diffusion](./api/pipelines/cycle_diffusion) | [**Cycle Diffusion**](https://arxiv.org/abs/2210.05559) | Image-to-Image Text-Guided Generation |
| [dance_diffusion](./api/pipelines/dance_diffusion) | [**Dance Diffusion**](https://github.com/williamberman/diffusers.git) | Unconditional Audio Generation |
| [ddpm](./api/pipelines/ddpm) | [**Denoising Diffusion Probabilistic Models**](https://arxiv.org/abs/2006.11239) | Unconditional Image Generation |
2 changes: 0 additions & 2 deletions docs/source/using-diffusers/conditional_image_generation.mdx
@@ -44,5 +44,3 @@ You can save the image by simply calling:
```python
>>> image.save("image_of_squirrel_painting.png")
```


2 changes: 2 additions & 0 deletions src/diffusers/__init__.py
@@ -65,6 +65,8 @@

if is_torch_available() and is_transformers_available():
from .pipelines import (
AltDiffusionImg2ImgPipeline,
AltDiffusionPipeline,
CycleDiffusionPipeline,
LDMTextToImagePipeline,
StableDiffusionImg2ImgPipeline,
1 change: 1 addition & 0 deletions src/diffusers/pipelines/__init__.py
@@ -15,6 +15,7 @@
from ..utils.dummy_pt_objects import * # noqa F403

if is_torch_available() and is_transformers_available():
from .alt_diffusion import AltDiffusionImg2ImgPipeline, AltDiffusionPipeline
from .latent_diffusion import LDMTextToImagePipeline
from .stable_diffusion import (
CycleDiffusionPipeline,
34 changes: 34 additions & 0 deletions src/diffusers/pipelines/alt_diffusion/__init__.py
@@ -0,0 +1,34 @@
from dataclasses import dataclass
from typing import List, Optional, Union

import numpy as np

import PIL
from PIL import Image

from ...utils import BaseOutput, is_torch_available, is_transformers_available


@dataclass
# Copied from diffusers.pipelines.stable_diffusion.__init__.StableDiffusionPipelineOutput with Stable->Alt
class AltDiffusionPipelineOutput(BaseOutput):
"""
Output class for Alt Diffusion pipelines.
Args:
images (`List[PIL.Image.Image]` or `np.ndarray`)
List of denoised PIL images of length `batch_size` or numpy array of shape `(batch_size, height, width,
num_channels)`. PIL images or numpy array representing the denoised images of the diffusion pipeline.
nsfw_content_detected (`List[bool]`)
List of flags denoting whether the corresponding generated image likely represents "not-safe-for-work"
(nsfw) content, or `None` if safety checking could not be performed.
"""

images: Union[List[PIL.Image.Image], np.ndarray]
nsfw_content_detected: Optional[List[bool]]


if is_transformers_available() and is_torch_available():
from .modeling_roberta_series import RobertaSeriesModelWithTransformation
from .pipeline_alt_diffusion import AltDiffusionPipeline
from .pipeline_alt_diffusion_img2img import AltDiffusionImg2ImgPipeline
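
For orientation, a short sketch of how these output fields surface at call time (the prompt is illustrative):
```python
>>> from diffusers import AltDiffusionPipeline

>>> pipe = AltDiffusionPipeline.from_pretrained("BAAI/AltDiffusion")
>>> output = pipe("a red panda eating bamboo")  # an AltDiffusionPipelineOutput by default
>>> first_image = output.images[0]  # denoised PIL image
>>> flags = output.nsfw_content_detected  # per-image safety flags, or None if no check ran
```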
110 changes: 110 additions & 0 deletions src/diffusers/pipelines/alt_diffusion/modeling_roberta_series.py
@@ -0,0 +1,110 @@
from dataclasses import dataclass
from typing import Optional, Tuple

import torch
from torch import nn

from transformers import RobertaPreTrainedModel, XLMRobertaConfig, XLMRobertaModel
from transformers.utils import ModelOutput


@dataclass
class TransformationModelOutput(ModelOutput):
"""
Base class for text model's outputs that also contains a pooling of the last hidden states.
Args:
projection_state (`torch.FloatTensor` of shape `(batch_size, sequence_length, project_dim)`):
The projected text embeddings obtained by applying the transformation layer to the last hidden state.
last_hidden_state (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`):
Sequence of hidden-states at the output of the last layer of the model.
hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, +
one for the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`.
Hidden-states of the model at the output of each layer plus the optional initial embedding outputs.
attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
sequence_length)`.
Attentions weights after the attention softmax, used to compute the weighted average in the self-attention
heads.
"""

projection_state: Optional[torch.FloatTensor] = None
last_hidden_state: torch.FloatTensor = None
hidden_states: Optional[Tuple[torch.FloatTensor]] = None
attentions: Optional[Tuple[torch.FloatTensor]] = None


class RobertaSeriesConfig(XLMRobertaConfig):
def __init__(
self,
pad_token_id=1,
bos_token_id=0,
eos_token_id=2,
project_dim=512,
pooler_fn="cls",
learn_encoder=False,
use_attention_mask=True,
**kwargs,
):
super().__init__(pad_token_id=pad_token_id, bos_token_id=bos_token_id, eos_token_id=eos_token_id, **kwargs)
self.project_dim = project_dim
self.pooler_fn = pooler_fn
self.learn_encoder = learn_encoder
self.use_attention_mask = use_attention_mask


class RobertaSeriesModelWithTransformation(RobertaPreTrainedModel):
_keys_to_ignore_on_load_unexpected = [r"pooler"]
_keys_to_ignore_on_load_missing = [r"position_ids", r"predictions.decoder.bias"]
base_model_prefix = "roberta"
config_class = RobertaSeriesConfig

def __init__(self, config):
super().__init__(config)
self.roberta = XLMRobertaModel(config)
self.transformation = nn.Linear(config.hidden_size, config.project_dim)
self.post_init()

def forward(
self,
input_ids: Optional[torch.Tensor] = None,
attention_mask: Optional[torch.Tensor] = None,
token_type_ids: Optional[torch.Tensor] = None,
position_ids: Optional[torch.Tensor] = None,
head_mask: Optional[torch.Tensor] = None,
inputs_embeds: Optional[torch.Tensor] = None,
encoder_hidden_states: Optional[torch.Tensor] = None,
encoder_attention_mask: Optional[torch.Tensor] = None,
output_attentions: Optional[bool] = None,
return_dict: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
):
r""" """

return_dict = return_dict if return_dict is not None else self.config.use_return_dict

outputs = self.base_model(
input_ids=input_ids,
attention_mask=attention_mask,
token_type_ids=token_type_ids,
position_ids=position_ids,
head_mask=head_mask,
inputs_embeds=inputs_embeds,
encoder_hidden_states=encoder_hidden_states,
encoder_attention_mask=encoder_attention_mask,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
)

projection_state = self.transformation(outputs.last_hidden_state)

return TransformationModelOutput(
projection_state=projection_state,
last_hidden_state=outputs.last_hidden_state,
hidden_states=outputs.hidden_states,
attentions=outputs.attentions,
)
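
A hedged sketch of querying this text encoder directly, assuming the `BAAI/AltDiffusion` repository follows the standard `diffusers` layout with `tokenizer` and `text_encoder` subfolders:
```python
>>> from transformers import XLMRobertaTokenizer

>>> from diffusers.pipelines.alt_diffusion.modeling_roberta_series import RobertaSeriesModelWithTransformation

>>> # subfolder names are an assumption based on the standard diffusers checkpoint format
>>> tokenizer = XLMRobertaTokenizer.from_pretrained("BAAI/AltDiffusion", subfolder="tokenizer")
>>> text_encoder = RobertaSeriesModelWithTransformation.from_pretrained("BAAI/AltDiffusion", subfolder="text_encoder")

>>> inputs = tokenizer(["a photo of an astronaut"], padding=True, return_tensors="pt")
>>> outputs = text_encoder(input_ids=inputs.input_ids, attention_mask=inputs.attention_mask)
>>> outputs.projection_state.shape  # (batch_size, sequence_length, project_dim)
```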