Commit
Chameleon: minor fixes after shipping (#32037)
* fix merging

* make chameleon conditional
zucchini-nlp authored Jul 18, 2024
1 parent 765732e commit 673d30b
Showing 7 changed files with 38 additions and 31 deletions.
20 changes: 10 additions & 10 deletions docs/source/en/model_doc/chameleon.md
@@ -64,13 +64,13 @@ The original code can be found [here](https://github.com/facebookresearch/chamel
Here's how to load the model and perform inference in half-precision (`torch.float16`):

```python
-from transformers import ChameleonProcessor, ChameleonForCausalLM
+from transformers import ChameleonProcessor, ChameleonForConditionalGeneration
import torch
from PIL import Image
import requests

processor = ChameleonProcessor.from_pretrained("meta-chameleon")
-model = ChameleonForCausalLM.from_pretrained("meta-chameleon", torch_dtype=torch.float16, device_map="auto")
+model = ChameleonForConditionalGeneration.from_pretrained("meta-chameleon", torch_dtype=torch.float16, device_map="auto")

# prepare image and text prompt
url = "https://bjiujitsu.com/wp-content/uploads/2021/01/jiu_jitsu_belt_white_1.jpg"
@@ -89,13 +89,13 @@ print(processor.decode(output[0], skip_special_tokens=True))
Chameleon can perform inference with multiple images as input, where the images belong either to the same prompt or to different prompts (in batched inference). Here is how you can do it:

```python
-from transformers import ChameleonProcessor, ChameleonForCausalLM
+from transformers import ChameleonProcessor, ChameleonForConditionalGeneration
import torch
from PIL import Image
import requests

processor = ChameleonProcessor.from_pretrained("meta-chameleon")
-model = ChameleonForCausalLM.from_pretrained("meta-chameleon", torch_dtype=torch.float16, device_map="auto")
+model = ChameleonForConditionalGeneration.from_pretrained("meta-chameleon", torch_dtype=torch.float16, device_map="auto")

# Get three different images
url = "https://www.ilankelman.org/stopsigns/australia.jpg"
@@ -129,7 +129,7 @@ processor.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokeniza
The model can be loaded in 8-bit or 4-bit precision, greatly reducing the memory requirements while maintaining the performance of the original model. First make sure to install bitsandbytes (`pip install bitsandbytes`) and that you have access to a CUDA-compatible GPU device. Simply change the snippet above to:

```python
-from transformers import ChameleonForCausalLM, BitsAndBytesConfig
+from transformers import ChameleonForConditionalGeneration, BitsAndBytesConfig

# specify how to quantize the model
quantization_config = BitsAndBytesConfig(
@@ -138,17 +138,17 @@ quantization_config = BitsAndBytesConfig(
bnb_4bit_compute_dtype=torch.float16,
)

model = ChameleonForCausalLM.from_pretrained("meta-chameleon", quantization_config=quantization_config, device_map="auto")
model = ChameleonForConditionalGeneration.from_pretrained("meta-chameleon", quantization_config=quantization_config, device_map="auto")
```
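The snippet as shown relies on imports and config lines from the earlier examples (the collapsed diff hides part of the `BitsAndBytesConfig`). A self-contained hedged version, assuming 4-bit loading as the prose suggests and using an illustrative checkpoint name:

```python
import torch
from transformers import ChameleonForConditionalGeneration, BitsAndBytesConfig

# Assumed config: 4-bit quantization with fp16 compute, matching the visible
# bnb_4bit_compute_dtype line above; other fields are illustrative.
quantization_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_compute_dtype=torch.float16,
)

model = ChameleonForConditionalGeneration.from_pretrained(
    "facebook/chameleon-7b", quantization_config=quantization_config, device_map="auto"
)
```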

### Use Flash-Attention 2 and SDPA to further speed up generation

The model supports both Flash-Attention 2 and PyTorch's [`torch.nn.functional.scaled_dot_product_attention`](https://pytorch.org/docs/master/generated/torch.nn.functional.scaled_dot_product_attention.html), which can be enabled for optimization. SDPA is the default option when you load the model. If you want to switch to Flash Attention 2, first make sure to install flash-attn; refer to the [original repository](https://github.com/Dao-AILab/flash-attention) for installation instructions. Simply change the snippet above to:

```python
-from transformers import ChameleonForCausalLM
+from transformers import ChameleonForConditionalGeneration

-model = ChameleonForCausalLM.from_pretrained(
+model = ChameleonForConditionalGeneration.from_pretrained(
model_id,
torch_dtype=torch.float16,
low_cpu_mem_usage=True,
@@ -183,7 +183,7 @@ model = ChameleonForCausalLM.from_pretrained(
[[autodoc]] ChameleonModel
- forward

-## ChameleonForCausalLM
+## ChameleonForConditionalGeneration

-[[autodoc]] ChameleonForCausalLM
+[[autodoc]] ChameleonForConditionalGeneration
- forward
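For reference, a minimal end-to-end sketch of the renamed API, stitched together from the doc snippets above (the collapsed diff hides the prompt-preparation and generation lines); the checkpoint name, image URL, and prompt are illustrative rather than taken from the commit:

```python
import torch
import requests
from PIL import Image
from transformers import ChameleonProcessor, ChameleonForConditionalGeneration

processor = ChameleonProcessor.from_pretrained("facebook/chameleon-7b")
model = ChameleonForConditionalGeneration.from_pretrained(
    "facebook/chameleon-7b", torch_dtype=torch.float16, device_map="auto"
)

# prepare an image and a text prompt; "<image>" marks where image tokens go
url = "https://www.ilankelman.org/stopsigns/australia.jpg"
image = Image.open(requests.get(url, stream=True).raw)
prompt = "What do you see in this image?<image>"

inputs = processor(text=prompt, images=image, return_tensors="pt").to(model.device, torch.float16)
output = model.generate(**inputs, max_new_tokens=50)
print(processor.decode(output[0], skip_special_tokens=True))
```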
4 changes: 2 additions & 2 deletions src/transformers/__init__.py
@@ -1616,7 +1616,7 @@
)
_import_structure["models.chameleon"].extend(
[
"ChameleonForCausalLM",
"ChameleonForConditionalGeneration",
"ChameleonModel",
"ChameleonPreTrainedModel",
"ChameleonProcessor",
@@ -6276,7 +6276,7 @@
load_tf_weights_in_canine,
)
from .models.chameleon import (
-ChameleonForCausalLM,
+ChameleonForConditionalGeneration,
ChameleonModel,
ChameleonPreTrainedModel,
ChameleonProcessor,
2 changes: 1 addition & 1 deletion src/transformers/models/auto/modeling_auto.py
@@ -446,7 +446,6 @@
("blenderbot-small", "BlenderbotSmallForCausalLM"),
("bloom", "BloomForCausalLM"),
("camembert", "CamembertForCausalLM"),
("chameleon", "ChameleonForCausalLM"),
("code_llama", "LlamaForCausalLM"),
("codegen", "CodeGenForCausalLM"),
("cohere", "CohereForCausalLM"),
@@ -703,6 +702,7 @@
[
("blip", "BlipForConditionalGeneration"),
("blip-2", "Blip2ForConditionalGeneration"),
("chameleon", "ChameleonForConditionalGeneration"),
("git", "GitForCausalLM"),
("idefics2", "Idefics2ForConditionalGeneration"),
("instructblip", "InstructBlipForConditionalGeneration"),
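The effect of this mapping move, as a hedged sketch: Chameleon no longer resolves through `AutoModelForCausalLM` but through the conditional-generation mapping. Assuming the edited dict is the vision-to-sequence mapping (its neighbors are Blip, Git, and Idefics2), the auto class below should now return the renamed model; the checkpoint name is illustrative:

```python
from transformers import AutoModelForVision2Seq

# After this change, the auto mapping should resolve Chameleon to the
# conditional-generation class instead of the removed causal-LM entry.
model = AutoModelForVision2Seq.from_pretrained("facebook/chameleon-7b")
assert type(model).__name__ == "ChameleonForConditionalGeneration"
```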
4 changes: 2 additions & 2 deletions src/transformers/models/chameleon/__init__.py
@@ -36,7 +36,7 @@
pass
else:
_import_structure["modeling_chameleon"] = [
"ChameleonForCausalLM",
"ChameleonForConditionalGeneration",
"ChameleonModel",
"ChameleonPreTrainedModel",
"ChameleonVQVAE",
@@ -62,7 +62,7 @@
pass
else:
from .modeling_chameleon import (
-ChameleonForCausalLM,
+ChameleonForConditionalGeneration,
ChameleonModel,
ChameleonPreTrainedModel,
ChameleonVQVAE,
9 changes: 5 additions & 4 deletions src/transformers/models/chameleon/modeling_chameleon.py
@@ -1279,7 +1279,8 @@ def forward(
if pixel_values is not None:
image_tokens = self.get_image_tokens(pixel_values)
special_image_mask = input_ids == self.vocabulary_mapping.image_token_id
-input_ids[special_image_mask] = image_tokens.flatten().to(input_ids.device, input_ids.dtype)
+image_tokens = image_tokens.to(input_ids.device, input_ids.dtype)
+input_ids = input_ids.masked_scatter(special_image_mask, image_tokens)

if inputs_embeds is None:
inputs_embeds = self.embed_tokens(input_ids)
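The "fix merging" part of this commit is the change above: the old code wrote image tokens into `input_ids` with in-place boolean-mask assignment, mutating the caller's tensor, while `masked_scatter` returns a new tensor and leaves the input intact. A toy illustration (the ids are made up):

```python
import torch

input_ids = torch.tensor([[5, 99, 99, 7]])  # 99 stands in for the image-token placeholder id
image_tokens = torch.tensor([801, 802])     # quantized image token ids from the VQ-VAE
special_image_mask = input_ids == 99

# Old behavior: input_ids[special_image_mask] = ... mutates input_ids in place.
# New behavior: masked_scatter fills the masked positions, in order, from the source tensor.
merged = input_ids.masked_scatter(special_image_mask, image_tokens.to(input_ids.dtype))

print(merged)     # tensor([[  5, 801, 802,   7]])
print(input_ids)  # unchanged: tensor([[ 5, 99, 99,  7]])
```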
@@ -1445,7 +1446,7 @@ def _update_causal_mask(
"Chameleon Model with a head on top used for outputting logits for next token prediction.",
CHAMELEON_START_DOCSTRING,
)
-class ChameleonForCausalLM(ChameleonPreTrainedModel):
+class ChameleonForConditionalGeneration(ChameleonPreTrainedModel):
_tied_weights_keys = ["lm_head.weight"]

def __init__(self, config):
@@ -1504,12 +1505,12 @@ def forward(
Example:
```python
>>> from transformers import ChameleonProcessor, ChameleonForCausalLM
>>> from transformers import ChameleonProcessor, ChameleonForConditionalGeneration
>>> import torch
>>> import requests
>>> from PIL import Image
>>> model = ChameleonForCausalLM.from_pretrained("facebook/chameleon-7b", torch_dtype=torch.bfloat16)
>>> model = ChameleonForConditionalGeneration.from_pretrained("facebook/chameleon-7b", torch_dtype=torch.bfloat16)
>>> processor = ChameleonProcessor.from_pretrained("facebook/chameleon-7b")
>>> prompt = "I used to know a lot about constellations when I was younger, but as I grew older, I forgot most of what I knew. These are the only two constellations that I really remember now.<image><image>I would like for you to tell me about 3 more constellations and give me a little bit of history about the constellation."
2 changes: 1 addition & 1 deletion src/transformers/utils/dummy_pt_objects.py
@@ -1835,7 +1835,7 @@ def load_tf_weights_in_canine(*args, **kwargs):
requires_backends(load_tf_weights_in_canine, ["torch"])


-class ChameleonForCausalLM(metaclass=DummyObject):
+class ChameleonForConditionalGeneration(metaclass=DummyObject):
_backends = ["torch"]

def __init__(self, *args, **kwargs):
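For context, a hedged sketch of what the renamed dummy object does: in an environment without PyTorch, the import still succeeds but instantiation raises, pointing the user at the missing backend. The behavior is inferred from the `requires_backends` pattern used throughout this file:

```python
# In an environment without PyTorch installed, the dummy class is what
# `from transformers import ChameleonForConditionalGeneration` returns.
from transformers import ChameleonForConditionalGeneration

try:
    model = ChameleonForConditionalGeneration()
except ImportError as e:
    # requires_backends raises, telling the user to install torch
    print(e)
```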
28 changes: 17 additions & 11 deletions tests/models/chameleon/test_modeling_chameleon.py
@@ -44,7 +44,7 @@
import torch

from transformers import (
-ChameleonForCausalLM,
+ChameleonForConditionalGeneration,
ChameleonModel,
ChameleonProcessor,
)
@@ -191,7 +191,7 @@ def create_and_check_for_causal_lm(
encoder_hidden_states,
encoder_attention_mask,
):
-model = ChameleonForCausalLM(config=config)
+model = ChameleonForConditionalGeneration(config=config)
model.to(torch_device)
model.eval()
result = model(input_ids, attention_mask=input_mask, labels=token_labels)
@@ -209,7 +209,7 @@ def create_and_check_decoder_model_past_large_inputs(
encoder_attention_mask,
):
config.is_decoder = True
-model = ChameleonForCausalLM(config=config)
+model = ChameleonForConditionalGeneration(config=config)
model.to(torch_device)
model.eval()

@@ -273,12 +273,12 @@ def prepare_config_and_inputs_for_common(self):

@require_torch
class ChameleonModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixin, unittest.TestCase):
-all_model_classes = (ChameleonModel, ChameleonForCausalLM) if is_torch_available() else ()
-all_generative_model_classes = (ChameleonForCausalLM,) if is_torch_available() else ()
+all_model_classes = (ChameleonModel, ChameleonForConditionalGeneration) if is_torch_available() else ()
+all_generative_model_classes = (ChameleonForConditionalGeneration,) if is_torch_available() else ()
pipeline_model_mapping = (
{
"feature-extraction": ChameleonModel,
"text-generation": ChameleonForCausalLM,
"text-generation": ChameleonForConditionalGeneration,
}
if is_torch_available()
else {}
@@ -339,7 +339,7 @@ def test_flash_attn_2_generate_padding_right(self):
"""
Overwriting the common test as the test is flaky on tiny models
"""
-model = ChameleonForCausalLM.from_pretrained(
+model = ChameleonForConditionalGeneration.from_pretrained(
"facebook/chameleon-7b",
load_in_4bit=True,
device_map={"": 0},
@@ -355,7 +355,7 @@ def test_flash_attn_2_generate_padding_right(self):
output_native = model.generate(**inputs, max_new_tokens=20, do_sample=False)
output_native = processor.tokenizer.batch_decode(output_native)

-model = ChameleonForCausalLM.from_pretrained(
+model = ChameleonForConditionalGeneration.from_pretrained(
"facebook/chameleon-7b",
load_in_4bit=True,
attn_implementation="flash_attention_2",
@@ -377,7 +377,9 @@ class ChameleonIntegrationTest(unittest.TestCase):
@require_bitsandbytes
@require_read_token
def test_model_7b(self):
model = ChameleonForCausalLM.from_pretrained("facebook/chameleon-7b", load_in_4bit=True, device_map="auto")
model = ChameleonForConditionalGeneration.from_pretrained(
"facebook/chameleon-7b", load_in_4bit=True, device_map="auto"
)
processor = ChameleonProcessor.from_pretrained("facebook/chameleon-7b")

image = Image.open(
@@ -397,7 +399,9 @@ def test_model_7b(self):
@require_bitsandbytes
@require_read_token
def test_model_7b_batched(self):
model = ChameleonForCausalLM.from_pretrained("facebook/chameleon-7b", load_in_4bit=True, device_map="auto")
model = ChameleonForConditionalGeneration.from_pretrained(
"facebook/chameleon-7b", load_in_4bit=True, device_map="auto"
)
processor = ChameleonProcessor.from_pretrained("facebook/chameleon-7b")

image = Image.open(
@@ -428,7 +432,9 @@ def test_model_7b_batched(self):
@require_bitsandbytes
@require_read_token
def test_model_7b_multi_image(self):
model = ChameleonForCausalLM.from_pretrained("facebook/chameleon-7b", load_in_4bit=True, device_map="auto")
model = ChameleonForConditionalGeneration.from_pretrained(
"facebook/chameleon-7b", load_in_4bit=True, device_map="auto"
)
processor = ChameleonProcessor.from_pretrained("facebook/chameleon-7b")

image = Image.open(
