Add conversion for interleave llava (#31858)
* add conversion for interleave llava

* remove debug lines

* remove unused imports

* Update src/transformers/models/llava/convert_llava_weights_to_hf.py

Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com>

* small changes + docs

---------

Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com>
zucchini-nlp and amyeroberts authored Jul 10, 2024
1 parent ad35309 commit 97aa3e2
Showing 2 changed files with 82 additions and 18 deletions.
14 changes: 13 additions & 1 deletion docs/source/en/model_doc/llava.md
@@ -40,8 +40,20 @@ The original code can be found [here](https://github.com/haotian-liu/LLaVA/tree/

- Note that the model has not been explicitly trained to process multiple images in the same prompt; although this is technically possible, you may experience inaccurate results.

- For better results, we recommend that users prompt the model with the correct prompt format:
- For better results, we recommend that users prompt the model with the correct prompt format. Below is a list of prompt formats accepted by each llava checkpoint, followed by a minimal inference sketch:

[llava-interleave models](https://huggingface.co/collections/llava-hf/llava-interleave-668e19a97da0036aad4a2f19) require the following format:
```bash
"<|im_start|>user <image>\nWhat is shown in this image?<|im_end|><|im_start|>assistant"
```

For multi-turn conversations:

```bash
"<|im_start|>user <image>\n<prompt1><|im_end|><|im_start|>assistant <answer1><|im_end|><|im_start|>user <image>\n<prompt1><|im_end|><|im_start|>assistant "
```

[llava-1.5 models](https://huggingface.co/collections/llava-hf/llava-15-65f762d5b6941db5c2ba07e0) require the following format:
```bash
"USER: <image>\n<prompt> ASSISTANT:"
```
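For reference, here is a minimal inference sketch using an interleave checkpoint with the single-turn format above. The checkpoint name (`llava-hf/llava-interleave-qwen-0.5b-hf`) and the example image URL are illustrative assumptions rather than values taken from this commit:

```python
# Minimal sketch: prompting a llava-interleave checkpoint with the single-turn format above.
# The checkpoint name and image URL below are assumptions, not part of this commit.
import requests
import torch
from PIL import Image

from transformers import AutoProcessor, LlavaForConditionalGeneration

model_id = "llava-hf/llava-interleave-qwen-0.5b-hf"  # assumed hub repo name
model = LlavaForConditionalGeneration.from_pretrained(model_id, torch_dtype=torch.float16, device_map="auto")
processor = AutoProcessor.from_pretrained(model_id)

prompt = "<|im_start|>user <image>\nWhat is shown in this image?<|im_end|><|im_start|>assistant"
url = "http://images.cocodataset.org/val2017/000000039769.jpg"  # example image, assumed reachable
image = Image.open(requests.get(url, stream=True).raw)

inputs = processor(images=image, text=prompt, return_tensors="pt").to(model.device, torch.float16)
output = model.generate(**inputs, max_new_tokens=50)
print(processor.decode(output[0], skip_special_tokens=True))
```

The same pattern applies to the llava-1.5 checkpoints; only the checkpoint name and the prompt string change.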
86 changes: 69 additions & 17 deletions src/transformers/models/llava/convert_llava_weights_to_hf.py
@@ -12,18 +12,21 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import argparse
import glob

import torch
from huggingface_hub import hf_hub_download
from huggingface_hub import hf_hub_download, snapshot_download
from safetensors import safe_open

from transformers import (
    AddedToken,
    AutoConfig,
    AutoImageProcessor,
    AutoTokenizer,
    CLIPImageProcessor,
    LlavaConfig,
    LlavaForConditionalGeneration,
    LlavaProcessor,
    SiglipVisionConfig,
)


@@ -48,6 +51,7 @@

KEYS_TO_MODIFY_MAPPING = {
"model.vision_tower.": "",
".vision_resampler": "", # all lmms-lab models do avg pooling, so no vision_resampler
"model.mm_projector": "multi_modal_projector",
"model": "model.model",
"vision_model.model": "vision_model",
@@ -58,6 +62,26 @@
}


def load_original_state_dict(model_id):
    directory_path = snapshot_download(repo_id=model_id, allow_patterns=["*.safetensors"])

    original_state_dict = {}
    for path in glob.glob(f"{directory_path}/*"):
        if path.endswith(".safetensors"):
            with safe_open(path, framework="pt", device="cpu") as f:
                for key in f.keys():
                    original_state_dict[key] = f.get_tensor(key)

    # tied weights, so lm_head is not saved. Clone it to load the state dict
    if "lm_head.weight" not in original_state_dict:
        original_state_dict["lm_head.weight"] = original_state_dict["model.embed_tokens.weight"].clone()

    del original_state_dict["model.image_newline"]  # not used in the original implementation because "merge_type=flat"
    return original_state_dict


# used only for llava-interleave
# e.g.: Qwen/Qwen1.5-0.5B-Chat google/siglip-so400m-patch14-384 lmms-lab/llava-next-interleave-qwen-0.5b
def convert_state_dict_to_hf(state_dict):
    new_state_dict = {}
    for key, value in state_dict.items():
@@ -77,24 +101,48 @@ def convert_llava_llama_to_hf(text_model_id, vision_model_id, output_hub_path, o

    tokenizer = AutoTokenizer.from_pretrained(text_model_id)
    tokenizer.add_tokens(AddedToken("<image>", special=True, normalized=False), special_tokens=True)
    tokenizer.add_special_tokens({"pad_token": "<pad>"})

    image_processor = CLIPImageProcessor.from_pretrained(vision_model_id)
    if "Qwen" not in text_model_id:  # qwen already has a pad token
        tokenizer.add_special_tokens({"pad_token": "<pad>"})

    image_processor = AutoImageProcessor.from_pretrained(vision_model_id)
    processor = LlavaProcessor(tokenizer=tokenizer, image_processor=image_processor)

    config = LlavaConfig(text_config=text_config)
    config.pad_token_id = 32001
    if "Qwen" in text_model_id:
        vision_config = SiglipVisionConfig(
            hidden_size=1152,
            image_size=384,
            intermediate_size=4304,
            num_attention_heads=16,
            num_hidden_layers=26,
            patch_size=14,
            vision_use_head=False,
        ).to_dict()
    else:
        vision_config = None

    config = LlavaConfig(
        text_config=text_config,
        vision_config=vision_config,
    )

    # lmms-lab interleave models do not use any selection strategy except for the last hidden state
if "Qwen" in text_model_id:
config.image_token_index = 151646
config.vision_feature_select_strategy = "full"
config.vision_feature_layer = -1
else:
config.pad_token_id = 32001
config.image_token_index = 32000

with torch.device("meta"):
model = LlavaForConditionalGeneration(config)

# Pad to 64 for performance reasons
pad_shape = 64
if "Qwen" in text_model_id:
state_dict = load_original_state_dict(old_state_dict_id)
else:
state_dict_path = hf_hub_download(old_state_dict_id, "model_state_dict.bin")
state_dict = torch.load(state_dict_path, map_location="cpu")

state_dict_path = hf_hub_download(old_state_dict_id, "model_state_dict.bin")

state_dict = torch.load(state_dict_path, map_location="cpu")
state_dict = convert_state_dict_to_hf(state_dict)
model.load_state_dict(state_dict, strict=True, assign=True)

@@ -104,14 +152,18 @@ def convert_llava_llama_to_hf(text_model_id, vision_model_id, output_hub_path, o
    sigma = ((pre_expansion_embeddings - mu).T @ (pre_expansion_embeddings - mu)) / n
    dist = torch.distributions.multivariate_normal.MultivariateNormal(mu, covariance_matrix=1e-5 * sigma)

    # We add an image token so we resize the model
    # We add an image token so we resize the model and pad to 64 for performance reasons
    pad_shape = 64
    vocab_size = config.text_config.vocab_size
    model.resize_token_embeddings(config.text_config.vocab_size + 2, pad_shape)
    model.language_model.model.embed_tokens.weight.data[32000:] = torch.stack(
        tuple((dist.sample() for _ in range(model.language_model.model.embed_tokens.weight.data[32000:].shape[0]))),
    model.language_model.model.embed_tokens.weight.data[vocab_size:] = torch.stack(
        tuple(
            (dist.sample() for _ in range(model.language_model.model.embed_tokens.weight.data[vocab_size:].shape[0]))
        ),
        dim=0,
    )
    model.language_model.lm_head.weight.data[32000:] = torch.stack(
        tuple((dist.sample() for _ in range(model.language_model.lm_head.weight.data[32000:].shape[0]))),
    model.language_model.lm_head.weight.data[vocab_size:] = torch.stack(
        tuple((dist.sample() for _ in range(model.language_model.lm_head.weight.data[vocab_size:].shape[0]))),
        dim=0,
    )

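For context, a sketch of how the new interleave path in `convert_llava_llama_to_hf` might be exercised, using the repos named in the comment above `convert_state_dict_to_hf`. The keyword names follow the signature shown in the hunk headers; the `output_hub_path` value is a hypothetical placeholder:

```python
# Sketch only: converting an lmms-lab interleave checkpoint with the updated converter.
# Model ids are taken from the comment above convert_state_dict_to_hf; the output repo is a placeholder.
from transformers.models.llava.convert_llava_weights_to_hf import convert_llava_llama_to_hf

convert_llava_llama_to_hf(
    text_model_id="Qwen/Qwen1.5-0.5B-Chat",
    vision_model_id="google/siglip-so400m-patch14-384",
    output_hub_path="your-org/llava-interleave-qwen-0.5b-hf",  # hypothetical destination repo
    old_state_dict_id="lmms-lab/llava-next-interleave-qwen-0.5b",
)
```

Running the script from the command line with equivalent arguments should behave the same way, assuming its argparse flags mirror these parameter names; the converted model and processor presumably end up under `output_hub_path`.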
