Improve support for image generation with Chameleon & Anole #32013

Open · wants to merge 9 commits into base: main

Conversation

@leloykun leloykun commented Jul 17, 2024

What does this PR do?

  • Adds modelling for the VQVAE decoder & also includes it in the conversion script.
  • Adds support for decoding the BPE tokens -> discrete image tokens -> pixel values
  • Moves masking of image tokens in text-only generation mode to a LogitsProcessor (see the sketch after this list).
  • Adds masking of non-image tokens for image-only generation mode.
  • Reimplements Chameleon's finite-state machine (the FSM it uses to dynamically switch between text- and image-generation modes) as logits processors, making it more compatible with Transformers and Outlines (for structured generation). We can now support interleaved text-image generation natively.
    • This PR does not add the FSM itself; it just makes it easier for external libraries like Outlines & MMSG to integrate with Transformers and add the interleaved generation mode back in.
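
A minimal sketch of the masking idea behind these mode-specific logits processors (the class name and the way the image-token ids are obtained are placeholders; the processors actually added in this PR may differ in naming and details):

import torch
from transformers import LogitsProcessor


class SuppressImageTokensLogitsProcessor(LogitsProcessor):
    """Masks out image-token logits so that only text tokens can be sampled."""

    def __init__(self, image_token_ids):
        self.image_token_ids = list(image_token_ids)

    def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor:
        # Assigning the smallest representable value gives the image tokens
        # effectively zero probability after the softmax.
        scores[:, self.image_token_ids] = torch.finfo(scores.dtype).min
        return scores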

Required TODOs:

  • Improve docs
    • Write docs for image-only generation with Chameleon/Anole
      • Provide minimal example on how to use
      • Remove the need for passing max_length or max_new_tokens on image-only generation mode
      • Force the model to generate at least one image in image-only generation mode. Chameleon doesn't officially support image generation yet, so it just immediately closes begin-image tokens with either an end-image token or an EOS token. Finetunes like Anole haven't fully removed this issue yet, so they occasionally still do that.
    • Write docs for interleaved text-image generation
      • Provide minimal example on how to use
      • Show how to split token sequences by modality in the example for interleaved text-image generation
    • Improve docs (as comments) for newly-added logits processors
    • Add sample usage for newly-added logits processors
  • Add tests
    • Add test for image postprocessing
    • Add tests for each generation mode
      • Add tests for text-only generation mode
      • Add tests for image-only generation mode
        • Note: don't rely on hashing pytorch tensors to compare arrays
        • Add test where max_new_tokens is unset on image-only generation mode
      • Add tests for interleaved-text-image generation mode
      • Add tests for unrestricted generation mode
      • Add tests for invalid generation mode
    • Add tests for VQVAE decoder
    • Add tests for each newly-added logits processors
    • Add tests for multi-GPU model sharding
  • Improve modelling of VQVAE encoder & decoder
  • VQVAE: dynamically compute quant_state_flattened_dims which scales with the resolution instead of hardcoding it in the configs
  • Improve postprocessing: only accept and return pytorch tensors
  • Logits processors
    • Add new logits processors to import structure

Optional TODOs or for future PRs:

  • Run a hyperparameter search for image generation
  • Fix bugs caused by sharding the model into multiple GPUs
  • Implement features that were in the Chameleon paper but are not crucial here
  • Refactor VQVAE (sub-)modules
    • Convert mid, down, and up blocks into explicit subclasses of nn.Module() (as suggested by @amyeroberts)
    • Make it clearer which attention types are allowed in the VQVAE (sub-)modules. Though Chameleon currently only supports "vanilla" attention.
  • Implement support for other finetunes of Chameleon

Links:

(partially) Implements # (issue)

@ArthurZucker @zucchini-nlp @JoyBoy-Su

@leloykun
Contributor Author

@zucchini-nlp @ArthurZucker this should now be ready for review

The test errors seem to be related to huggingface_hub & bert, and I'm not sure how they relate to this PR.

Member

@zucchini-nlp zucchini-nlp left a comment

Great job! Looks good to me in general; the only things left are to make generation happy by moving code to the correct location, and to check whether we can guide users through interleaved generation with an external FSM library.

Also, we need tests for the different generation modes, to make sure everything works correctly. This can be added as a slow IntegrationTest in tests/models/chameleon/test_modeling_chameleon.py

docs/source/en/model_doc/chameleon.md
src/transformers/image_transforms.py
src/transformers/models/chameleon/modeling_chameleon.py
Comment on lines 333 to 373
> Parameters specific to vision-language generation models such as [Chameleon](https://arxiv.org/abs/2405.09818v1)

multimodal_generation_mode (`Literal["text-only", "image-only", "interleaved-text-image", "free"]`, *optional*, defaults to `None`):
Chameleon can generate text, images, or both in an interleaved manner. However, only text generation is
supported by the official model checkpoint. This flag enables the other modes for use with finetuned versions
of the model such as [Anole](https://arxiv.org/abs/2407.06135).
- If set to `"text-only"`, logits for image tokens will be masked out during generation.
- If set to `"image-only"`, logits for non-image tokens will be masked out during generation.
- If set to `"free"`, the logits are left as-is.
- For `"interleaved-text-image"`, Chameleon implements a finite state machine to dynamically switch between text and image modalities.
This library does not support this mode yet.

Member

Sorry if I wasn't clear, I meant Chameleon's generation config that is on the hub. Adding args to the general generation config is not a good idea if it's going to be used by only one model.

Let's see how we can make generate() happy. We can:

  • When saving the model, add a field model.generation_config="text-only" by setting the default to text mode and saving it on the hub
  • Move Chameleon logits processor to the other processors, more comments below
  • Add a generate() method in the ConditionalGeneration module that does model-specific preparation; in our case it takes the generation mode and makes a LogitsProcessor out of it, then calls super().generate() with all kwargs (see the sketch below)
  • Optionally, take the generation output and run decode_tokens if in image mode. If you can make an example with an external FSM library for the interleaved mode, do the separation of image from text here. And return a custom GenerationDecoderOnlyOutput, which will have an extra field for pixel values

Usually I would find a way w/o custom generate, but Chameleon can be an exception given that it's the only model that generates images. Also, custom generate in this way is less prone to bugs from refactoring general generate(), as we only prepare and pass a model-specific processor.
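
A rough sketch of the custom generate() suggested above (illustrative only: the subclass name is hypothetical, vocabulary_mapping follows the code added in this PR, and SuppressTokensLogitsProcessor is the existing processor in transformers):

import torch
from transformers import (
    ChameleonForConditionalGeneration,
    LogitsProcessorList,
    SuppressTokensLogitsProcessor,
)


class ChameleonWithMultimodalGenerate(ChameleonForConditionalGeneration):
    @torch.no_grad()
    def generate(self, *args, multimodal_generation_mode="text-only", logits_processor=None, **kwargs):
        # Model-specific preparation: turn the generation mode into a logits processor.
        logits_processor = logits_processor if logits_processor is not None else LogitsProcessorList()
        image_tokens = self.model.vocabulary_mapping.image_tokens
        if multimodal_generation_mode == "text-only":
            # Mask out image tokens so that only text can be sampled.
            logits_processor.append(SuppressTokensLogitsProcessor(image_tokens))
        elif multimodal_generation_mode == "image-only":
            # Mask out everything that is not an image token (marker/EOS handling omitted here).
            image_token_set = set(image_tokens)
            non_image = [t for t in range(self.config.vocab_size) if t not in image_token_set]
            logits_processor.append(SuppressTokensLogitsProcessor(non_image))
        # Defer to the generic generation loop with all remaining kwargs.
        return super().generate(*args, logits_processor=logits_processor, **kwargs)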

Contributor Author

Hi @zucchini-nlp!

I've:

  1. Implemented a new generation config for Chameleon
  2. Modified the utils to support custom generation config classes
  3. Added a custom generate func to ChameleonForConditionalGeneration
class ChameleonGenerationConfig(GenerationConfig):
    """Generation Config for [Chameleon](https://huggingface.co/docs/transformers/model_doc/chameleon)

    Args:
        multimodal_generation_mode (`Literal["text-only", "image-only", "interleaved-text-image", "unrestricted"]`, *optional*, defaults to `None`):
            Chameleon can generate text, images, or both in an interleaved manner. However, only text generation is
            supported by the official model checkpoint. This flag enables the other modes for use with finetuned versions
            of the model such as [Anole](https://arxiv.org/abs/2407.06135).
            - If set to `"unrestricted"`, the logits are left as-is.
            - If set to `"text-only"`, logits for image tokens will be masked out during generation.
            - If set to `"image-only"`, logits for non-image tokens will be masked out during generation.
            - For `"interleaved-text-image"`, Chameleon implements a finite state machine to dynamically switch between text and image modalities.
                Here, we simply use logits processors that exclusively allow image tokens to be generated within a relative window after the
                begin image token and disallow them elsewhere.
    """

    def __init__(self, **kwargs):
        super().__init__(**kwargs)
        self.multimodal_generation_mode = kwargs.pop("multimodal_generation_mode", "text-only")

Contributor Author

multimodal_generation_mode is now also in the converted model here: https://huggingface.co/leloy/Anole-7b-v0.1-hf/blob/main/generation_config.json

[image]

Member

awesome, I guess that will be the official checkpoint after merging the PR right? We can then use it in IntegrationTests

@leloykun
Contributor Author

@zucchini-nlp this should now be finished I think

the failing test seems to have been caused by the issue here: #32094
which is unrelated to this PR

Member

@zucchini-nlp zucchini-nlp left a comment

Thanks for adding this model ❤️

Looks good to me, but we still don't have tests for image-only generation and interleaved-generation modes. We have to make sure the added model works correctly. I'm approving the PR, and will request review from core maintainers

@thaoshibe

Hi @leloykun, thank you for the awesome work!

I ran your example, but all the outputs are black... Is there anything missing here?
[image]

Thank you!!

@leloykun
Contributor Author

Thanks for the feedback, @thaoshibe!

Apparently, we just need to enable sampling during generation (by passing do_sample=True to .generate). If I'm not mistaken, this is because most of the image tokens during training were for "empty" patches. So, greedy decoding of image tokens wouldn't work well.

I've also just updated the docs. Please let me know if you encounter any more issues.

One more thing: loading the model in bfloat16 (the dtype used for finetuning Anole) also seems to improve generation. See:

import torch
from transformers import ChameleonForConditionalGeneration

model_id = "leloy/Anole-7b-v0.1-hf"  # the Anole checkpoint discussed in this thread
model = ChameleonForConditionalGeneration.from_pretrained(
    model_id,
    torch_dtype=torch.bfloat16,
    low_cpu_mem_usage=True,
    attn_implementation="flash_attention_2",
    device_map="auto",
)

@zucchini-nlp
Member

zucchini-nlp commented Jul 23, 2024

We can actually add those to the generation config after uploading the model to the hub. We did the same for official Chameleon; it performs best with do_sample=True, temperature=0.7, top_p=0.9, repetition_penalty=1.2

Same for bf16; it can go in the usage tips section of the docs. Then we can change all example snippets to bf16
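
Illustrative only: one way those sampling defaults could be recorded in the checkpoint's generation_config.json so users don't have to pass them by hand (the save path is a placeholder):

from transformers import GenerationConfig

generation_config = GenerationConfig(
    do_sample=True,
    temperature=0.7,
    top_p=0.9,
    repetition_penalty=1.2,
)
# Save next to the model weights, or push it to the hub checkpoint directly.
generation_config.save_pretrained("path/to/anole-checkpoint")
# generation_config.push_to_hub("leloy/Anole-7b-v0.1-hf")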

@thaoshibe

[image]

Got it -- Thank you @leloykun -- I ran your code and I got the correct output :D

@leloykun
Contributor Author

We can actually add those to the generation config after uploading the model to the hub. We did the same for official Chameleon; it performs best with do_sample=True, temperature=0.7, top_p=0.9, repetition_penalty=1.2

Same for bf16; it can go in the usage tips section of the docs. Then we can change all example snippets to bf16

@zucchini-nlp , increasing the repetition penalty might not be good for image generation cuz a lot of the image tokens are repeating (e.g. snow tokens when generating a snowman).

I'll only include the others for now then run a hp search

@leloykun
Contributor Author

The test errors are all huggingface hub related hnng

@amyeroberts
Collaborator

@leloykun Just triggered a re-run. Hopefully just transient issues!

Collaborator

@amyeroberts amyeroberts left a comment

Thanks for all the work on adding this capability and adding examples!

Main comments are to do with properly testing this new feature, taking and returning torch tensors for the post processing method, and the max_new_tokens behaviour for generating images

src/transformers/image_transforms.py
if (
    config.attn_resolutions is not None
    and curr_res in config.attn_resolutions
    and config.attn_type == "vanilla"
Collaborator

What are the possible values of config.attn_type? I'm a bit worried this can be confused with attn_implementation, a standard config param

if i_level != 0:
    up.upsample = ChameleonVQVAEDecoderConvUpsample(block_in)
    curr_res = curr_res * 2
self.up.insert(0, up)  # prepend to get consistent order
Collaborator

Wouldn't appending also give a consistent order? Inserting at the 0th index will be more expensive, and then we don't need to reverse things in the forward

Comment on lines +1059 to +1066
up = nn.Module()
up.block = block
up.attn = attn
Collaborator

This is an indication to me that we should have another class, ChameleonVQVAEBlock, which sets block and attn within its init (see the sketch below)
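
A possible shape for the suggested class (illustrative only; the concrete block/attention module types come from the surrounding VQVAE code):

from typing import Optional

import torch.nn as nn


class ChameleonVQVAEBlock(nn.Module):
    def __init__(self, block: nn.ModuleList, attn: nn.ModuleList, upsample: Optional[nn.Module] = None):
        super().__init__()
        self.block = block        # resnet blocks for this resolution level
        self.attn = attn          # optional attention blocks
        self.upsample = upsample  # optional upsampling applied at the end of the level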

self,
generation_config: Optional[GenerationConfig] = None,
multimodal_generation_mode: Optional[
Literal["text-only", "image-only", "interleaved-text-image", "unrestricted"]
Collaborator

This is good for documentation, but it doesn't provide validation of the input. It would be good to add a check to make sure it's one of the accepted values, if specified

Contributor Author

@amyeroberts

we do raise an error a few lines further down if the value isn't recognized:

    else:
        raise ValueError(
            f"Unknown multimodal generation mode: {generation_config.multimodal_generation_mode}. Please choose one of 'unrestricted', 'text-only', 'image-only', or 'interleaved-text-image'."
        )

that'd suffice, right? Or should I move this to the start of the func?

return generation_config, model_kwargs

@torch.no_grad()
def generate(
Collaborator

All of the generation modes ("text-only", "image-only", "interleaved-text-image", "unrestricted") should be tested in the model tests

@minostauros minostauros left a comment

Some comments about running the example in the model doc.

docs/source/en/model_doc/chameleon.md
processor = ChameleonProcessor.from_pretrained("leloy/Anole-7b-v0.1-hf")
model = ChameleonForConditionalGeneration.from_pretrained(
    "leloy/Anole-7b-v0.1-hf",
    device_map="auto",


This example failed in my environment with 4 GPUs, complaining about a device mismatch.

Contributor Author

@minostauros can you provide the script you used for this? The complete error message would also help.

Thank you!

@minostauros minostauros Aug 6, 2024

I needed to remove device_map="auto" and manually send the model to a specific CUDA device to run the code properly.

>>> import accelerate
>>> accelerate.__version__
'0.30.1'
>>> import torch
>>> from transformers import ChameleonProcessor, ChameleonForConditionalGeneration
>>> from PIL import Image
>>> 
>>> processor = ChameleonProcessor.from_pretrained("leloy/Anole-7b-v0.1-hf")
Some kwargs in processor config are unused and will not have any effect: image_token, image_seq_length. 
>>> model = ChameleonForConditionalGeneration.from_pretrained(
...     "leloy/Anole-7b-v0.1-hf",
...     device_map="auto",
...     torch_dtype=torch.bfloat16,
... )
Loading checkpoint shards: 100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████| 3/3 [00:02<00:00,  1.11it/s]
>>> model.device
device(type='cuda', index=0)
>>> url = "https://huggingface.co/microsoft/kosmos-2-patch14-224/resolve/main/snowman.jpg"
>>> image_snowman = Image.open(requests.get(url, stream=True).raw)
>>> prompt = "Generate a variation of this image.<image>"
>>> inputs = processor(
...     prompt,
...     images=[image_snowman],
...     padding=True,
...     return_tensors="pt",
... ).to(model.device, dtype=model.dtype)
>>> generate_ids = model.generate(
...     **inputs,
...     multimodal_generation_mode="image-only",
...     # Note: We need to set `max_new_tokens` to 1026 since the model generates the `image_start_token` marker token first, then 1024 image tokens, and finally the `image_end_token` marker token.
...     max_new_tokens=1026,
...     # This is important because most of the image tokens during training were for "empty" patches, so greedy decoding of image tokens will likely result in a blank image.
...     do_sample=True,
... )
Setting `pad_token_id` to `eos_token_id`:2 for open-end generation.
Traceback (most recent call last):
  File "<stdin>", line 1, in <module>
  File "/usr/local/lib/python3.10/dist-packages/torch/utils/_contextlib.py", line 115, in decorate_context
    return func(*args, **kwargs)
  File "/workspace/Github/transformers_anole/src/transformers/models/chameleon/modeling_chameleon.py", line 1821, in generate
    return super().generate(
  File "/usr/local/lib/python3.10/dist-packages/torch/utils/_contextlib.py", line 115, in decorate_context
    return func(*args, **kwargs)
  File "/workspace/Github/transformers_anole/src/transformers/generation/utils.py", line 1989, in generate
    result = self._sample(
  File "/workspace/Github/transformers_anole/src/transformers/generation/utils.py", line 2932, in _sample
    outputs = self(**model_inputs, return_dict=True)
  File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1511, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1520, in _call_impl
    return forward_call(*args, **kwargs)
  File "/usr/local/lib/python3.10/dist-packages/accelerate/hooks.py", line 166, in new_forward
    output = module._old_forward(*args, **kwargs)
  File "/workspace/Github/transformers_anole/src/transformers/models/chameleon/modeling_chameleon.py", line 1881, in forward
    outputs = self.model(
  File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1511, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1520, in _call_impl
    return forward_call(*args, **kwargs)
  File "/workspace/Github/transformers_anole/src/transformers/models/chameleon/modeling_chameleon.py", line 1491, in forward
    image_tokens = self.get_image_tokens(pixel_values)
  File "/workspace/Github/transformers_anole/src/transformers/models/chameleon/modeling_chameleon.py", line 1427, in get_image_tokens
    return self.img2bpe_mapping_tensor[image_toks]
RuntimeError: indices should be either on cpu or on the same device as the indexed tensor (cpu)
>>> inputs.input_ids.device
device(type='cuda', index=0)
>>> inputs.keys()
dict_keys(['input_ids', 'attention_mask', 'pixel_values'])
>>> inputs.pixel_values.device
device(type='cuda', index=0)
>>> model = model.cuda()
You shouldn't move a model that is dispatched using accelerate hooks.
>>> generate_ids = model.generate(
...     **inputs,
...     multimodal_generation_mode="image-only",
...     # Note: We need to set `max_new_tokens` to 1026 since the model generates the `image_start_token` marker token first, then 1024 image tokens, and finally the `image_end_token` marker token.
...     max_new_tokens=1026,
...     # This is important because most of the image tokens during training were for "empty" patches, so greedy decoding of image tokens will likely result in a blank image.
...     do_sample=True,
... )
Setting `pad_token_id` to `eos_token_id`:2 for open-end generation.
Traceback (most recent call last):
  File "<stdin>", line 1, in <module>
  File "/usr/local/lib/python3.10/dist-packages/torch/utils/_contextlib.py", line 115, in decorate_context
    return func(*args, **kwargs)
  File "/workspace/Github/transformers_anole/src/transformers/models/chameleon/modeling_chameleon.py", line 1821, in generate
    return super().generate(
  File "/usr/local/lib/python3.10/dist-packages/torch/utils/_contextlib.py", line 115, in decorate_context
    return func(*args, **kwargs)
  File "/workspace/Github/transformers_anole/src/transformers/generation/utils.py", line 1989, in generate
    result = self._sample(
  File "/workspace/Github/transformers_anole/src/transformers/generation/utils.py", line 2932, in _sample
    outputs = self(**model_inputs, return_dict=True)
  File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1511, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1520, in _call_impl
    return forward_call(*args, **kwargs)
  File "/usr/local/lib/python3.10/dist-packages/accelerate/hooks.py", line 166, in new_forward
    output = module._old_forward(*args, **kwargs)
  File "/workspace/Github/transformers_anole/src/transformers/models/chameleon/modeling_chameleon.py", line 1881, in forward
    outputs = self.model(
  File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1511, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1520, in _call_impl
    return forward_call(*args, **kwargs)
  File "/workspace/Github/transformers_anole/src/transformers/models/chameleon/modeling_chameleon.py", line 1491, in forward
    image_tokens = self.get_image_tokens(pixel_values)
  File "/workspace/Github/transformers_anole/src/transformers/models/chameleon/modeling_chameleon.py", line 1426, in get_image_tokens
    _, _, image_toks = self.vqmodel.encode(pixel_values)
  File "/workspace/Github/transformers_anole/src/transformers/models/chameleon/modeling_chameleon.py", line 1159, in encode
    hidden_states = self.encoder(pixel_values)
  File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1511, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1520, in _call_impl
    return forward_call(*args, **kwargs)
  File "/usr/local/lib/python3.10/dist-packages/accelerate/hooks.py", line 166, in new_forward
    output = module._old_forward(*args, **kwargs)
  File "/workspace/Github/transformers_anole/src/transformers/models/chameleon/modeling_chameleon.py", line 979, in forward
    hidden_states = [self.conv_in(pixel_values)]
  File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1511, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1520, in _call_impl
    return forward_call(*args, **kwargs)
  File "/usr/local/lib/python3.10/dist-packages/accelerate/hooks.py", line 166, in new_forward
    output = module._old_forward(*args, **kwargs)
  File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/conv.py", line 460, in forward
    return self._conv_forward(input, self.weight, self.bias)
  File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/conv.py", line 456, in _conv_forward
    return F.conv2d(input, weight, bias, self.stride,
RuntimeError: Expected all tensors to be on the same device, but found at least two devices, cuda:3 and cuda:0! (when checking argument for argument weight in method wrapper_CUDA__cudnn_convolution)
>>> model = ChameleonForConditionalGeneration.from_pretrained(
...     "leloy/Anole-7b-v0.1-hf",
...     torch_dtype=torch.bfloat16,
... ).to(device=0)
Loading checkpoint shards: 100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 3/3 [00:01<00:00,  2.02it/s]
>>> generate_ids = model.generate(
...     **inputs,
...     multimodal_generation_mode="image-only",
...     # Note: We need to set `max_new_tokens` to 1026 since the model generates the `image_start_token` marker token first, then 1024 image tokens, and finally the `image_end_token` marker token.
...     max_new_tokens=1026,
...     # This is important because most of the image tokens during training were for "empty" patches, so greedy decoding of image tokens will likely result in a blank image.
...     do_sample=True,
... )
Setting `pad_token_id` to `eos_token_id`:2 for open-end generation.
>>> generate_ids.shape
torch.Size([1, 2062])


Updating accelerate to 0.33.0 did not help.

Contributor Author

@minostauros does this happen with the base Chameleon model? I.e. without this PR?

The issue with F.conv2d may be unrelated to this PR but the issue with return self.img2bpe_mapping_tensor[image_toks] definitely is
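
One possible direction for that particular failure (a sketch only, not necessarily the fix this PR ends up with): move the index tensor onto the mapping tensor's device before indexing, e.g. inside get_image_tokens:

# with device_map="auto", the mapping tensor and the encoder output can end up
# on different devices; advanced indexing requires them to match
image_toks = image_toks.to(self.img2bpe_mapping_tensor.device)
return self.img2bpe_mapping_tensor[image_toks]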

Contributor Author

hm, I've never seen this happen before but I suspect it's because of the .float() (iirc, to_pil_image rescales the numpy array if it's of float type). What happens if you remove it or cast the array to uint8?

btw, I wouldn't be able to run tests myself for the next few hours as I'm still traveling


What happens if you remove it or cast the array to uint8?

Great point!

# Decode the generated image tokens
pixel_values = model.decode_image_tokens(response_ids[:, 1:-1])
images = processor.postprocess_pixel_values(pixel_values)

# Save the image
from torchvision.transforms.functional import to_pil_image
images = [to_pil_image(img.detach().cpu()) for img in images]
images[0].save("snowman.png")

[image: snowman2]

Perhaps just removing the 255 scaling and type casting in ChameleonImageProcessor.postprocess() would also support torchvision.utils.save_image().

Contributor Author

The output after postprocessing should have the same shape, range, and dtype as the original image so it's better to keep it this way IMO

Contributor Author

I've also just added a test for model sharding btw

pls check it out!


The code now works like a charm! Thanks a lot for your contribution.
That said, the output does not seem as good as what the Anole paper shows.

prompt: 'A piece of paper with word like "Anole" written on it, and a drawing of an Anole.'

  • from the paper: [image]
  • from "leloy/Anole-7b-v0.1-hf": [image]

How may I improve the results?

@leloykun
Contributor Author

leloykun commented Aug 5, 2024

Some comments about running the example in the model doc.

Thanks @minostauros! I'll make sure to fix these in the next commit & tag you when it's ready

@leloykun leloykun marked this pull request as draft August 8, 2024 18:06
@leloykun leloykun requested a review from minostauros August 9, 2024 12:48
@minostauros

Supporting Lumina-mGPT may be the next PR!

@leloykun
Contributor Author

leloykun commented Aug 9, 2024

Supporting Lumina-mGPT may be the next PR!

dang, this looks cool

does it support interleaved text-image generation too?

@minostauros

minostauros commented Aug 9, 2024

does it support interleaved text-image generation too?

Sadly, none of their examples show interleaved text-image generation, even though it's mentioned that the model was trained on interleaved data.

[image]

One interesting approach is that the prompt is aware of image resolutions.

[image]

And here's Lumina-mGPT's unique handling of generation parameters:
[image]

@leloykun
Contributor Author

hmmm

one crucial difference is that Chameleon uses classifier free guidance while this doesn't

I'll look into implementing it, but I think I'm gonna need help with that

@leloykun
Contributor Author

leloykun commented Sep 2, 2024

Hi, I find a bug, if u use ChameleonMoeForConditionalGeneration and calculate ce loss, u should use new special image tokens encoded by vqmodel to update the labels for ce loss calculation~ @leloykun

That doesn't sound right...

We should calculate the CE loss using the BPE-compatible tokens (i.e. the tokens compatible with Chameleon's tokenizer). That's because those are the outputs of the decoder model.

Pls check the img2bpe & bpe2img converter utils

@YeLuoSuiYou

YeLuoSuiYou commented Sep 2, 2024

Hi, I find a bug, if u use ChameleonForConditionalGeneration and calculate ce loss, u should use new special image tokens encoded by vqmodel to update the labels for ce loss calculation~ @leloykun

That doesn't sound right...

We should calculate the CE loss using the BPE-compatible tokens (i.e. the tokens compatible with Chameleon's tokenizer). That's because those are the outputs of the decoder model.

Pls check the img2bpe & bpe2img converter utils

Thanks for the reply.
Yes, we should calculate the CE loss using BPE tokens instead of the "image" token (id 8711), but we can only get the BPE tokens inside ChameleonModel, and the expanded input ids are not returned in the output, so we can't update the labels in ChameleonForConditionalGeneration. So my suggestion is that, when using ChameleonForConditionalGeneration, we get the BPE tokens encoded by the vqmodel in the ChameleonForConditionalGeneration forward function to update input_ids and labels, instead of in ChameleonModel.

@leloykun
Contributor Author

leloykun commented Sep 2, 2024

Hi, I find a bug, if u use ChameleonForConditionalGeneration and calculate ce loss, u should use new special image tokens encoded by vqmodel to update the labels for ce loss calculation~ @leloykun

That doesn't sound right...

We should calculate the CE loss using the BPE-compatible tokens (i.e. the tokens compatible with Chameleon's tokenizer). That's because those are the outputs of the decoder model.

Pls check the img2bpe & bpe2img converter utils

Thanks for reply.
Yes, we should calculate CE loss using BPE tokens instead of "image" (id is 8711), but we only can get BPE token in the ChameleonModel and not return the input ids in the output to update the labels in ChameleonForConditionalGeneration. So my suggestion is that when using ChameleonForConditionalGeneration, we can get BPE tokens encoded by vqmodel in ChameleonForConditionalGeneration forward function to update input_ids and labels instead of in ChameleonModel

Apologies, I'm a bit confused

Can you share your code so we can debug it together?

# Disallow image tokens, which do not include the special begin-image and end-image tokens
image_tokens = self.model.vocabulary_mapping.image_tokens
logits[:, :, image_tokens] = torch.finfo(logits.dtype).min


Here the labels are not updated; maybe we should update the labels here to calculate the CE loss. In addition, to deal with different numbers of images per sample in a batch, my suggestion is to reshape input_ids and fill in the image_tokens via the view method.

Below is a draft of the code I used:

if pixel_values is not None:
    batch_size, sequence_length = input_ids.shape
    input_ids = input_ids.view(batch_size * sequence_length)
    image_tokens = self.model.get_image_tokens(pixel_values)
    special_image_mask = input_ids == self.vocabulary_mapping.image_token_id
    image_tokens = image_tokens.to(input_ids.device, input_ids.dtype)
    input_ids = input_ids.masked_scatter(special_image_mask, image_tokens)
    input_ids = input_ids.view(batch_size, sequence_length)

if labels is not None:
    # update labels with new input_ids
    mask = labels != -100
    labels = torch.where(mask, input_ids, labels)

YeLuoSuiYou pushed a commit to YeLuoSuiYou/transformers that referenced this pull request Sep 3, 2024
@YeLuoSuiYou

Hi, I find a bug, if u use ChameleonForConditionalGeneration and calculate ce loss, u should use new special image tokens encoded by vqmodel to update the labels for ce loss calculation~ @leloykun

That doesn't sound right...
We should calculate the CE loss using the BPE-compatible tokens (i.e. the tokens compatible with Chameleon's tokenizer). That's because those are the outputs of the decoder model.
Pls check the img2bpe & bpe2img converter utils

Thanks for reply.
Yes, we should calculate CE loss using BPE tokens instead of "image" (id is 8711), but we only can get BPE token in the ChameleonModel and not return the input ids in the output to update the labels in ChameleonForConditionalGeneration. So my suggestion is that when using ChameleonForConditionalGeneration, we can get BPE tokens encoded by vqmodel in ChameleonForConditionalGeneration forward function to update input_ids and labels instead of in ChameleonModel

Apologies, I'm a bit confused

Can you share your code so we can debug it together?

Hi, I have submitted a PR to your repository; feel free to refer to it and modify the code.

@leloykun
Contributor Author

leloykun commented Sep 3, 2024

Hi @YeLuoSuiYou! Thanks for the PR!

My current understanding of the matter is:

  1. Internally, if pixel_values is not None, then we use the vqmodel to tokenize them and add the tokens to the input_ids
  2. But we never touch the labels, regardless of whether pixel_values is None or not. We keep it as-is during training.

(2) isn't actually a bug, and your fix doesn't fit the library imo. We don't want the inputs and the outputs to interact before calculating the loss, in order to minimize bugs. cc @zucchini-nlp

What you should do, instead, is to pass labels with the image tokens already added to it. Here's what I do in my finetuning scripts (sketched after the list):

  1. Pass text and images to the processor. Get input_ids and pixel_values in return.
  2. Clone the input_ids as labels (i.e. labels = input_ids.clone())
  3. Use the vqmodel to tokenize the pixel_values
  4. Add the tokens to labels
  5. Pass input_ids, pixel_values, & labels to the model
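
A minimal sketch of those five steps (it assumes the method and attribute names from this PR, e.g. model.model.get_image_tokens and vocabulary_mapping.image_token_id, and leaves out padding/batching concerns):

import requests
import torch
from PIL import Image
from transformers import ChameleonForConditionalGeneration, ChameleonProcessor

processor = ChameleonProcessor.from_pretrained("leloy/Anole-7b-v0.1-hf")
model = ChameleonForConditionalGeneration.from_pretrained("leloy/Anole-7b-v0.1-hf", torch_dtype=torch.bfloat16)

url = "https://huggingface.co/microsoft/kosmos-2-patch14-224/resolve/main/snowman.jpg"
image = Image.open(requests.get(url, stream=True).raw)
prompt = "Describe this image.<image>"

# 1. Pass text and images to the processor; get input_ids and pixel_values back.
inputs = processor(prompt, images=[image], return_tensors="pt").to(model.device, dtype=model.dtype)
# 2. Clone the input_ids as labels.
labels = inputs["input_ids"].clone()
# 3. Use the vqmodel to tokenize the pixel_values, 4. add the tokens to labels.
with torch.no_grad():
    image_tokens = model.model.get_image_tokens(inputs["pixel_values"])
special_image_mask = labels == model.model.vocabulary_mapping.image_token_id
labels = labels.masked_scatter(special_image_mask, image_tokens.to(labels.device, labels.dtype))
# 5. Pass input_ids, pixel_values & labels to the model; the forward pass computes the CE loss.
outputs = model(**inputs, labels=labels)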

@zucchini-nlp
Member

What you should do, instead, is to pass labels with the image tokens already added to it.

I couldn't locate the PR, but this is exactly what is expected. We (transformers) return input_ids already expanded to account for the image tokens, and the user only has to clone input_ids and mask pad token ids, similar to a language modeling task. And Chameleon, as one of the latest VLMs added to the library, follows this. Older VLMs like LLaVA are still in progress and might need more intervention from the user.

@zucchini-nlp
Member

@leloykun btw, seems like the most important thing left to add on this PR is the tests. Let me know if you need any help with that, would be super nice to have this PR merged soon :)

@YeLuoSuiYou

Hi @YeLuoSuiYou! Thanks for the PR!

My current understanding of the matter is:

  1. Internally, if pixel_values is not None, then we use the vqmodel to tokenize them and add the tokens to the input_ids
  2. But we never touch the labels, regardless of whether pixel_values is None or not. We keep it as-is during training.

(2) isn't actually a bug and your fix doesn't fit the library imo. We don't want the inputs and the outputs to interact before the calculating the loss in order to minimize bugs. cc @zucchini-nlp

What you should do, instead, is to pass labels with the image tokens already added to it. Here's what I do in my finetuning scripts:

  1. Pass text and images to the processor. Get input_ids and pixel_values in return.
  2. Clone the input_ids as labels (i.e. labels = input_ids.clone())
  3. Use the vqmodel to tokenize the pixel_values
  4. Add the tokens to labels
  5. Pass input_ids, pixel_values, & labels to the model

Thanks for the reply, I got it

@minostauros minostauros left a comment

I'm already using some of the features added by this PR and hope it goes upstream soon.
Thanks for your work!

@ArthurZucker
Collaborator

@leloykun feel free to ping @zucchini-nlp again for a review, I'll do the final one afterwards!

@ArthurZucker ArthurZucker removed their request for review November 19, 2024 15:20
@zucchini-nlp
Member

@leloykun hey, I can take over and write tests if you are busy, so we can merge faster. I think everything else was approved earlier

@leloykun
Contributor Author

Hi @zucchini-nlp ! I'd really appreciate it as I don't see myself being able to continue working on it for the next few weeks.

@zucchini-nlp zucchini-nlp marked this pull request as ready for review December 3, 2024 12:06
@zucchini-nlp
Member

@ArthurZucker ready for review! Added some more tests and removed unused logits processors to reduce the maintenance burden. Actually, I believe the generation might be doable with a simple prefix constraint and one new logits processor, but I didn't have time to check it out. It should be very similar to Emu3 inference. (A rough sketch of the prefix-constraint idea follows.)
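
For reference, a rough sketch of the prefix-constraint idea (illustrative only; the token ids below are placeholders, and prefix_allowed_tokens_fn is the existing hook in generate()):

# Placeholder ids; the real values come from Chameleon's vocabulary mapping.
BEGIN_IMAGE_ID = 1001
END_IMAGE_ID = 1002
IMAGE_TOKEN_IDS = list(range(2000, 2000 + 1024))
ALL_TOKEN_IDS = list(range(32000))


def prefix_allowed_tokens_fn(batch_id, input_ids):
    ids = input_ids.tolist()
    inside_image = ids.count(BEGIN_IMAGE_ID) > ids.count(END_IMAGE_ID)
    if inside_image:
        # Between the begin-image and end-image markers, only image tokens
        # (plus the closing marker) may be sampled.
        return IMAGE_TOKEN_IDS + [END_IMAGE_ID]
    return ALL_TOKEN_IDS


# generate_ids = model.generate(**inputs, prefix_allowed_tokens_fn=prefix_allowed_tokens_fn, do_sample=True)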
