add support for image generation and interleaved image-text generation with Chameleon & its finetunes like Anole
leloykun committed Aug 25, 2024
1 parent b252643 commit dae439c
Showing 12 changed files with 955 additions and 73 deletions.
148 changes: 147 additions & 1 deletion docs/source/en/model_doc/chameleon.md
@@ -50,13 +50,18 @@ The original code can be found [here](https://github.com/facebookresearch/chamel

- We advise users to use `padding_side="left"` when running batched generation, as it leads to more accurate results. Simply set `processor.tokenizer.padding_side = "left"` before generating.

- When generating images, we advise users to load the model in `bfloat16` for better results. Simply set `torch_dtype=torch.bfloat16` when loading the model.

- Note that Chameleon was tuned for safety alignment. If the model is refusing to answer, consider asking a more concrete question instead of an open-ended one.

- Chameleon generates in chat format, which means the generated text is always the "assistant's turn". You can switch to plain text completion by passing `return_for_text_completion=True` when calling the processor, as sketched below.
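
A minimal sketch of text-completion mode, assuming a `model` and `processor` loaded as in the usage examples below; the prompt is illustrative only:

```python
# Hedged sketch: reuse the processor/model objects from the examples below.
inputs = processor(
    "A snowman is built by",
    return_for_text_completion=True,  # do not wrap the prompt as an assistant chat turn
    padding=True,
    return_tensors="pt",
).to(model.device, dtype=model.dtype)

generate_ids = model.generate(**inputs, max_new_tokens=40)
print(processor.batch_decode(generate_ids, skip_special_tokens=True)[0])
```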

> [!NOTE]
> The Chameleon implementation in Transformers uses a special image token to indicate where to merge image embeddings. We did not add a new token for this; instead, one of the reserved tokens, `<reserved08707>`, is used. You have to add `<image>` to your prompt where the image should be embedded for correct generation.

> [!NOTE]
> The official model checkpoint currently only supports text generation. To generate images and interleaved text-image responses, you can use finetuned versions such as [Anole](https://arxiv.org/abs/2407.06135). Note, however, that Anole has a bias for "empty" or background patches, so it is recommended to use sampling when generating images (i.e. setting `do_sample=True` during generation) to reduce the likelihood of generating a blank image.

## Usage example

### Single image inference
@@ -117,13 +122,154 @@ prompts = [

```python

# We can simply feed images in the order they have to be used in the text prompt
# Each "<image>" token uses one image leaving the next for the subsequent "<image>" tokens
inputs = processor(
text=prompts,
images=[image_stop, image_cats, image_snowman],
padding=True,
return_tensors="pt",
).to(device="cuda", dtype=torch.bfloat16)

# Generate
generate_ids = model.generate(**inputs, max_new_tokens=50)
processor.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)
```

### Text to image generation

Chameleon can also generate images. However, the official model checkpoint currently only supports text generation, so you need to use a finetuned version such as [Anole](https://arxiv.org/abs/2407.06135) for image generation. Here is how you can do it:

```python
import torch
from transformers import ChameleonProcessor, ChameleonForConditionalGeneration

processor = ChameleonProcessor.from_pretrained("leloy/Anole-7b-v0.1-hf")
model = ChameleonForConditionalGeneration.from_pretrained(
"leloy/Anole-7b-v0.1-hf",
device_map="auto",
torch_dtype=torch.bfloat16,
)

# Prepare a prompt
prompt = "Generate an image of a snowman."

# Preprocess the prompt
inputs = processor(prompt, padding=True, return_tensors="pt").to(model.device, dtype=model.dtype)

# Generate discrete image tokens
generate_ids = model.generate(
**inputs,
multimodal_generation_mode="image-only",
# Note: We need to set `max_new_tokens` to 1026 since the model generates the `image_start_token` marker token first, then 1024 image tokens, and finally the `image_end_token` marker token.
max_new_tokens=1026,
# This is important because most of the image tokens during training were for "empty" patches, so greedy decoding of image tokens will likely result in a blank image.
do_sample=True,
)

# Only keep the tokens from the response
response_ids = generate_ids[:, inputs["input_ids"].shape[-1]:]

# Decode the generated image tokens
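# The image tokens are wrapped by the `image_start_token` and `image_end_token` markers,
# so we drop the first and last token before decoding.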
pixel_values = model.decode_image_tokens(response_ids[:, 1:-1])
images = processor.postprocess_pixel_values(pixel_values)

# Save the image
images[0].save("snowman.png")
```

### Text-image to image generation

We can also interleave text and images in the prompt to generate images. Here is how you can do it:

```python
import requests

import torch
from PIL import Image
from transformers import ChameleonProcessor, ChameleonForConditionalGeneration
from transformers.image_transforms import to_pil_image

processor = ChameleonProcessor.from_pretrained("leloy/Anole-7b-v0.1-hf")
model = ChameleonForConditionalGeneration.from_pretrained(
"leloy/Anole-7b-v0.1-hf",
device_map="auto",
torch_dtype=torch.bfloat16,
)

# Get image of a snowman
url = "https://huggingface.co/microsoft/kosmos-2-patch14-224/resolve/main/snowman.jpg"
image_snowman = Image.open(requests.get(url, stream=True).raw)

# Prepare a prompt
prompt = "Generate a variation of this image.<image>"

# Preprocess the prompt
inputs = processor(
prompt,
images=[image_snowman],
padding=True,
return_tensors="pt",
).to(model.device, dtype=model.dtype)

# Generate discrete image tokens
generate_ids = model.generate(
**inputs,
multimodal_generation_mode="image-only",
# As in the previous example, we need 1026 new tokens: the `image_start_token` marker, 1024 image tokens, and the `image_end_token` marker.
max_new_tokens=1026,
# This is important because most of the image tokens during training were for "empty" patches, so greedy decoding of image tokens will likely result in a blank image.
do_sample=True,
)

# Only keep the tokens from the response
response_ids = generate_ids[:, inputs["input_ids"].shape[-1]:]

# The generated image tokens are wrapped by the `image_start_token` and `image_end_token` tokens. We need to remove them before decoding the image tokens.
image_token_ids = response_ids[:, 1:-1]

# Decode the generated image tokens
pixel_values = model.decode_image_tokens(image_token_ids)
pixel_values = processor.postprocess_pixel_values(pixel_values)

# Save the image
image = to_pil_image(pixel_values[0].detach().cpu())
image.save("snowman.png")
```

### Interleaved text-image generation

We can also generate interleaved text and images in the output. Here is how you can do it:

```python
import torch
from transformers import ChameleonProcessor, ChameleonForConditionalGeneration

processor = ChameleonProcessor.from_pretrained("leloy/Anole-7b-v0.1-hf")
model = ChameleonForConditionalGeneration.from_pretrained(
"leloy/Anole-7b-v0.1-hf",
device_map="auto",
torch_dtype=torch.bfloat16,
)

# Prepare a prompt
prompt = "Can you draw a snowman and explain how to build one?"

# Preprocess the prompt
inputs = processor(prompt, padding=True, return_tensors="pt").to(model.device, dtype=model.dtype)

# Generate interleaved text and discrete image tokens
generate_ids = model.generate(
**inputs,
multimodal_generation_mode="interleaved-text-image",
# Note: We will need a larger `max_new_tokens` value since we are generating both text and image tokens.
max_new_tokens=4096,
# This is important because most of the image tokens during training were for "empty" patches, so greedy decoding of image tokens will likely result in a blank image.
do_sample=True,
)

# Only keep the tokens from the response
response_ids = generate_ids[:, inputs["input_ids"].shape[-1]:]
```

From here, you can split the response tokens into text and image token segments, decode them separately as shown in the previous examples, and finally render the resulting text and images together. You can also use [MMSG](https://github.com/leloykun/mmsg) to do this more easily.
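
If you prefer to do the split manually, here is a minimal sketch. The image-delimiter token strings passed to `convert_tokens_to_ids` are assumptions (check your checkpoint's processor/generation config for the actual `image_start_token` and `image_end_token`); `model`, `processor`, and `response_ids` are reused from the example above.

```python
import torch

# Hypothetical token strings; replace with the actual delimiters for your checkpoint.
boi_token_id = processor.tokenizer.convert_tokens_to_ids("<image_start_token>")  # assumed
eoi_token_id = processor.tokenizer.convert_tokens_to_ids("<image_end_token>")    # assumed

def split_response(ids: torch.Tensor):
    """Split a 1D tensor of response ids into ("text", ids) and ("image", ids) segments."""
    ids = ids.tolist()
    segments, cursor = [], 0
    while cursor < len(ids):
        if ids[cursor] == boi_token_id:
            # Image segment: the 1024 image tokens between the start and end markers.
            # Assumes every image segment was properly closed by the model.
            end = ids.index(eoi_token_id, cursor)
            segments.append(("image", torch.tensor(ids[cursor + 1 : end])))
            cursor = end + 1
        else:
            # Text segment: everything up to the next image start marker (or the end of the response).
            end = ids.index(boi_token_id, cursor) if boi_token_id in ids[cursor:] else len(ids)
            segments.append(("text", torch.tensor(ids[cursor:end])))
            cursor = end
    return segments

for kind, segment_ids in split_response(response_ids[0]):
    if kind == "text":
        print(processor.decode(segment_ids, skip_special_tokens=True))
    else:
        pixel_values = model.decode_image_tokens(segment_ids[None, :].to(model.device))
        pixel_values = processor.postprocess_pixel_values(pixel_values)
        # Convert to PIL / save as shown in the previous examples.
```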

## Model optimization

### Quantization using Bitsandbytes
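
A minimal sketch of loading the model in 4-bit with bitsandbytes, assuming the `bitsandbytes` package is installed; the exact quantization settings below are illustrative rather than prescriptive:

```python
import torch
from transformers import BitsAndBytesConfig, ChameleonForConditionalGeneration

# Illustrative 4-bit quantization config; tune these settings for your hardware.
quantization_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_compute_dtype=torch.bfloat16,
)

model = ChameleonForConditionalGeneration.from_pretrained(
    "leloy/Anole-7b-v0.1-hf",
    quantization_config=quantization_config,
    device_map="auto",
)
```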
164 changes: 143 additions & 21 deletions src/transformers/generation/logits_process.py
@@ -1750,7 +1750,38 @@ def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> to
return scores_processed


class SuppressTokensInIndexRangeLogitsProcessor(LogitsProcessor):
r"""
[`SuppressTokensInIndexRangeLogitsProcessor`] suppresses a list of tokens from `start_index` to `end_index` (exclusive).
Args:
suppress_tokens (`List[int]`):
List of token ids to suppress during generation.
start_index (`int`):
The index at which to start suppressing tokens.
end_index (`int`, *optional*):
The index at which to end suppressing tokens. If `None`, it will suppress tokens indefinitely.
device (`str`, *optional*, defaults to `"cpu"`):
The device to allocate the tensors.
"""

def __init__(
self, suppress_tokens: List[int], start_index: int, end_index: Optional[int] = None, device: str = "cpu"
):
self.suppress_tokens = torch.tensor(suppress_tokens, device=device)
self.start_index = start_index
self.end_index = end_index if end_index is not None else math.inf

def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor:
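# `input_ids.shape[1]` is the position of the token about to be generated; the suppression is
# only applied while that position falls within the configured index range.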
current_index = input_ids.shape[1]
if self.start_index > current_index or current_index > self.end_index:
return scores
suppress_tokens_mask = torch.zeros_like(scores, dtype=torch.bool)
suppress_tokens_mask[:, self.suppress_tokens] = True
return scores.masked_fill(suppress_tokens_mask, torch.finfo(scores.dtype).min)


class SuppressTokensAtBeginLogitsProcessor(SuppressTokensInIndexRangeLogitsProcessor):
r"""
[`SuppressTokensAtBeginLogitsProcessor`] suppresses a list of tokens as soon as the `generate` function starts
generating using `begin_index` tokens. This should ensure that the tokens defined by `begin_suppress_tokens` are
@@ -1786,24 +1817,17 @@ class SuppressTokensAtBeginLogitsProcessor(LogitsProcessor):
"""

def __init__(self, begin_suppress_tokens, begin_index, device: str = "cpu"):
super().__init__(begin_suppress_tokens, begin_index, begin_index + 1, device=device)
self.begin_index = begin_index

def set_begin_index(self, begin_index):
self.start_index = begin_index
self.end_index = begin_index + 1
# Keeping this here for backwards compatibility
self.begin_index = begin_index



class SuppressTokensLogitsProcessor(SuppressTokensInIndexRangeLogitsProcessor):
r"""
This processor can be used to suppress a list of tokens. The processor will set their log probs to `-inf` so
that they are not generated. Originally created for
@@ -1833,14 +1857,7 @@ class SuppressTokensLogitsProcessor(LogitsProcessor):
"""

def __init__(self, suppress_tokens, device: str = "cpu"):
super().__init__(suppress_tokens, 0, device=device)


class WhisperTimeStampLogitsProcessor(LogitsProcessor):
@@ -2449,3 +2466,108 @@ def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> to
scores_processed[b_idx, greenlist_ids] = scores_processed[b_idx, greenlist_ids] + self.bias

return scores_processed


class AllowOnlyTokensAtRelativeOffsetLogitsProcessor(LogitsProcessor):
r"""
[`AllowOnlyTokensAtRelativeOffsetLogitsProcessor`] suppresses all tokens except a specific set that may be generated
at a fixed offset after a trigger token (e.g. the begin-image token). If `exclusive` is set to `True`, the allowed
set is additionally suppressed everywhere else. This is useful for enforcing multimodal generation constraints with
begin and end marker tokens.
Originally created for [Chameleon](https://huggingface.co/docs/transformers/model_doc/chameleon).
Args:
trigger_token_id (`int`):
The token id that triggers the offset check.
allowed_token_ids (`List[int]`):
The list of token ids that are allowed at the specified offset.
offset (`int`):
The relative offset from the trigger token.
exclusive (`bool`, *optional*, defaults to `False`):
If `True`, the set of tokens allowed at this offset will not be allowed anywhere else.
device (`str`, *optional*, defaults to `cpu`):
The device to allocate the util tensor on.
"""

def __init__(
self,
trigger_token_id: int,
allowed_token_ids: List[int],
offset: int,
exclusive: bool = False,
device: str = "cpu",
):
self.trigger_token_id = trigger_token_id
self.allowed_token_ids = torch.tensor(allowed_token_ids, device=device)
self.offset = offset
self.exclusive = exclusive

def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor:
if input_ids.shape[1] < self.offset and not self.exclusive:
return scores

disallowed_tokens_mask = torch.ones_like(scores, dtype=torch.bool)
disallowed_tokens_mask[:, self.allowed_token_ids] = False

if input_ids.shape[1] < self.offset:
return scores.masked_fill(~disallowed_tokens_mask, torch.finfo(scores.dtype).min)

trigger_positions = (input_ids[:, -self.offset] == self.trigger_token_id).unsqueeze(-1)

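# In exclusive mode, the allowed tokens may appear *only* at the offset position after a trigger
# token (everything else is suppressed there) and are themselves suppressed at all other positions.
# In non-exclusive mode, the constraint is applied only at trigger positions.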
if self.exclusive:
return scores.masked_fill(~(disallowed_tokens_mask ^ trigger_positions), torch.finfo(scores.dtype).min)
return scores.masked_fill(disallowed_tokens_mask & trigger_positions, torch.finfo(scores.dtype).min)


class AllowOnlyTokensInRelativeWindowLogitsProcessor(LogitsProcessor):
r"""
[`AllowOnlyTokensInRelativeWindowLogitsProcessor`] suppresses all tokens except a specific set that may be generated
within a trailing window after a trigger token (e.g. the begin-image token). If `exclusive` is set to `True`, the
allowed set is additionally suppressed everywhere else. This is useful for enforcing multimodal generation
constraints.
Originally created for [Chameleon](https://huggingface.co/docs/transformers/model_doc/chameleon).
Args:
trigger_token_id (`int`):
The token id that triggers the window check.
allowed_token_ids (`List[int]`):
The list of token ids that are allowed at the specified relative window.
window_width (`int`):
The width of the trailing window, i.e. the number of most recent tokens in which to look for the trigger token.
exclusive (`bool`, *optional*, defaults to `False`):
If `True`, the set of tokens allowed at this window will not be allowed anywhere else.
device (`str`, *optional*, defaults to `cpu`):
The device to allocate the util tensor on.
"""

def __init__(
self,
trigger_token_id: int,
allowed_token_ids: List[int],
window_width: int,
exclusive: bool = False,
device: str = "cpu",
):
self.trigger_token_id = trigger_token_id
self.allowed_token_ids = torch.tensor(allowed_token_ids, device=device).unsqueeze(0)
self.window_width = window_width
self.exclusive = exclusive

def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor:
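# Clamp the window to the current sequence length, then flag sequences where the trigger token
# appears anywhere within the last `window_width` generated tokens.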
window_width = min(self.window_width, input_ids.shape[1])
trigger_positions = (input_ids[:, -window_width:] == self.trigger_token_id).any(dim=1).unsqueeze(-1)

disallowed_tokens_mask = torch.ones_like(scores, dtype=torch.bool)
disallowed_tokens_mask[:, self.allowed_token_ids] = False

if self.exclusive:
return scores.masked_fill(
~(disallowed_tokens_mask ^ trigger_positions),
torch.finfo(scores.dtype).min,
)
return scores.masked_fill(
disallowed_tokens_mask & trigger_positions,
torch.finfo(scores.dtype).min,
)
14 changes: 7 additions & 7 deletions src/transformers/generation/utils.py
@@ -1276,13 +1276,13 @@ def _validate_generated_length(self, generation_config, input_ids_length, has_de

def _prepare_generated_length(
self,
generation_config: GenerationConfig,
has_default_max_length: bool,
has_default_min_length: bool,
model_input_name: str,
input_ids_length: int,
inputs_tensor: torch.Tensor,
) -> GenerationConfig:
"""Prepare max and min lengths in generation configs to avoid clashes between similar attributes"""

if generation_config.max_new_tokens is not None: