Fix: modeling_vlm.py & README.md #34

Open · wants to merge 2 commits into base: main
README.md: 7 changes (3 additions, 4 deletions)
@@ -17,7 +17,6 @@
   <a href="https://www.deepseek.com/" target="_blank">
     <img alt="Homepage" src="images/badge.svg" />
   </a>
-  </a>
   <a href="https://huggingface.co/deepseek-ai" target="_blank">
     <img alt="Hugging Face" src="https://img.shields.io/badge/%F0%9F%A4%97%20Hugging%20Face-DeepSeek%20AI-ffc107?color=ffc107&logoColor=white" />
   </a>
@@ -66,7 +65,7 @@

 **2024.11.13**: JanusFlow is released, a new unified model with rectified flow for image generation. See [paper](https://arxiv.org/abs/2411.07975), [demo](https://huggingface.co/spaces/deepseek-ai/JanusFlow-1.3B) and [usage](https://github.com/deepseek-ai/Janus?tab=readme-ov-file#janusflow).

-**2024.10.23**: Evaluation code for reproducing the multimodal understanding results from the paper has been added to VLMEvalKit. Please refer to [this link]( https://github.com/open-compass/VLMEvalKit/pull/541).
+**2024.10.23**: Evaluation code for reproducing the multimodal understanding results from the paper has been added to VLMEvalKit. Please refer to [this link](https://github.com/open-compass/VLMEvalKit/pull/541).

 **2024.10.20**: (1) Fix a bug in [tokenizer_config.json](https://huggingface.co/deepseek-ai/Janus-1.3B/blob/main/tokenizer_config.json). The previous version caused classifier-free guidance to not function properly, resulting in relatively poor visual generation quality. (2) Release Gradio demo ([online demo](https://huggingface.co/spaces/deepseek-ai/Janus-1.3B) and [local](#gradio-demo)).
@@ -165,10 +164,10 @@ prepare_inputs = vl_chat_processor(
     conversations=conversation, images=pil_images, force_batchify=True
 ).to(vl_gpt.device)

-# # run image encoder to get the image embeddings
+# run image encoder to get the image embeddings
 inputs_embeds = vl_gpt.prepare_inputs_embeds(**prepare_inputs)

-# # run the model to get the response
+# run the model to get the response
 outputs = vl_gpt.language_model.generate(
     inputs_embeds=inputs_embeds,
     attention_mask=prepare_inputs.attention_mask,
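The hunk above truncates the `generate` call. For reference, a minimal sketch of how it completes, modeled on the repo's quick-start and assuming `tokenizer = vl_chat_processor.tokenizer` from earlier in the README snippet (argument values mirror the upstream example and are not changed by this PR):

```python
# Sketch of the remainder of the quick-start call; values follow the upstream README.
outputs = vl_gpt.language_model.generate(
    inputs_embeds=inputs_embeds,
    attention_mask=prepare_inputs.attention_mask,
    pad_token_id=tokenizer.eos_token_id,
    bos_token_id=tokenizer.bos_token_id,
    eos_token_id=tokenizer.eos_token_id,
    max_new_tokens=512,
    do_sample=False,
    use_cache=True,
)
answer = tokenizer.decode(outputs[0].cpu().tolist(), skip_special_tokens=True)
print(answer)
```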
janus/models/modeling_vlm.py: 218 changes (92 additions, 126 deletions)
@@ -26,23 +26,21 @@
     LlamaConfig,
     LlamaForCausalLM,
     PreTrainedModel,
+    PretrainedConfig,
 )
-from transformers.configuration_utils import PretrainedConfig

 from janus.models.clip_encoder import CLIPVisionTower
 from janus.models.projector import MlpProjector


-class vision_head(torch.nn.Module):
+class VisionHead(torch.nn.Module):
+    """
+    Vision head module for processing visual embeddings.
+    """
     def __init__(self, params):
         super().__init__()
-        self.output_mlp_projector = torch.nn.Linear(
-            params.n_embed, params.image_token_embed
-        )
+        self.output_mlp_projector = torch.nn.Linear(params.n_embed, params.image_token_embed)
         self.vision_activation = torch.nn.GELU()
-        self.vision_head = torch.nn.Linear(
-            params.image_token_embed, params.image_token_size
-        )
+        self.vision_head = torch.nn.Linear(params.image_token_embed, params.image_token_size)

     def forward(self, x):
         x = self.output_mlp_projector(x)
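The renamed `VisionHead` is a plain two-layer MLP: project from the transformer width to the image-token width, apply GELU, then score each position against the image-token vocabulary. A standalone sketch with hypothetical sizes (the real values come from the checkpoint's `gen_head_config.params`):

```python
import torch
from attrdict import AttrDict  # the module reads params via attribute access

from janus.models.modeling_vlm import VisionHead  # name as introduced by this PR

# Hypothetical sizes for illustration only; the checkpoint config supplies the real ones.
params = AttrDict({"n_embed": 2048, "image_token_embed": 2048, "image_token_size": 16384})
head = VisionHead(params)

hidden = torch.randn(2, 576, params.n_embed)  # [batch, tokens, n_embed]
logits = head(hidden)                         # [batch, tokens, image_token_size]
assert logits.shape == (2, 576, params.image_token_size)
```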
@@ -52,171 +50,132 @@ def forward(self, x):


 def model_name_to_cls(cls_name):
-    if "MlpProjector" in cls_name:
-        cls = MlpProjector
-
-    elif "CLIPVisionTower" in cls_name:
-        cls = CLIPVisionTower
-
-    elif "VQ" in cls_name:
+    """
+    Maps a class name to its corresponding class.
+    """
+    mapping = {
+        "MlpProjector": MlpProjector,
+        "CLIPVisionTower": CLIPVisionTower,
+        "vision_head": VisionHead,
+    }
+
+    if "VQ" in cls_name:
         from janus.models.vq_model import VQ_models
-
-        cls = VQ_models[cls_name]
-    elif "vision_head" in cls_name:
-        cls = vision_head
-    else:
-        raise ValueError(f"class_name {cls_name} is invalid.")
-
-    return cls
+        return VQ_models[cls_name]
+
+    cls = mapping.get(cls_name)
+    if cls is None:
+        raise ValueError(f"Invalid class name: {cls_name}")
+    return cls

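One behavioral note on this refactor: the old chain used substring checks (`"MlpProjector" in cls_name`), while `mapping.get` requires an exact key; only the `VQ` family keeps its substring match. In practice the shipped configs use the exact names, so the lookup behaves as sketched below (assertions are illustrative):

```python
# Illustrative expectations for the refactored lookup; names come from the diff above.
from janus.models.clip_encoder import CLIPVisionTower
from janus.models.modeling_vlm import model_name_to_cls
from janus.models.projector import MlpProjector

assert model_name_to_cls("MlpProjector") is MlpProjector
assert model_name_to_cls("CLIPVisionTower") is CLIPVisionTower

try:
    model_name_to_cls("UnknownHead")
except ValueError as err:
    print(err)  # Invalid class name: UnknownHead
```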

-class VisionConfig(PretrainedConfig):
-    model_type = "vision"
+class BaseConfig(PretrainedConfig):
+    """
+    Base configuration class for multi-modality components.
+    """
+    model_type = ""
     cls: str = ""
     params: AttrDict = {}

     def __init__(self, **kwargs):
         super().__init__(**kwargs)

         self.cls = kwargs.get("cls", "")
         if not isinstance(self.cls, str):
             self.cls = self.cls.__name__

         self.params = AttrDict(kwargs.get("params", {}))


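`BaseConfig` now holds the one behavior all five sub-configs shared: `cls` may arrive as a string or a class object and is normalized to a class name, and `params` is wrapped in an `AttrDict`. A small sketch of that normalization (the `image_size` field is a made-up example):

```python
from janus.models.clip_encoder import CLIPVisionTower
from janus.models.modeling_vlm import VisionConfig

cfg_from_str = VisionConfig(cls="CLIPVisionTower", params={"image_size": 384})
cfg_from_cls = VisionConfig(cls=CLIPVisionTower, params={"image_size": 384})

assert cfg_from_str.cls == cfg_from_cls.cls == "CLIPVisionTower"
assert cfg_from_str.params.image_size == 384  # AttrDict allows attribute access
```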
-class AlignerConfig(PretrainedConfig):
-    model_type = "aligner"
-    cls: str = ""
-    params: AttrDict = {}
-
-    def __init__(self, **kwargs):
-        super().__init__(**kwargs)
+class VisionConfig(BaseConfig):
+    model_type = "vision"

-        self.cls = kwargs.get("cls", "")
-        if not isinstance(self.cls, str):
-            self.cls = self.cls.__name__
-
-        self.params = AttrDict(kwargs.get("params", {}))
+class AlignerConfig(BaseConfig):
+    model_type = "aligner"


-class GenVisionConfig(PretrainedConfig):
+class GenVisionConfig(BaseConfig):
     model_type = "gen_vision"
-    cls: str = ""
-    params: AttrDict = {}
-
-    def __init__(self, **kwargs):
-        super().__init__(**kwargs)
-
-        self.cls = kwargs.get("cls", "")
-        if not isinstance(self.cls, str):
-            self.cls = self.cls.__name__
-
-        self.params = AttrDict(kwargs.get("params", {}))


-class GenAlignerConfig(PretrainedConfig):
+class GenAlignerConfig(BaseConfig):
     model_type = "gen_aligner"
-    cls: str = ""
-    params: AttrDict = {}
-
-    def __init__(self, **kwargs):
-        super().__init__(**kwargs)
-
-        self.cls = kwargs.get("cls", "")
-        if not isinstance(self.cls, str):
-            self.cls = self.cls.__name__
-
-        self.params = AttrDict(kwargs.get("params", {}))


-class GenHeadConfig(PretrainedConfig):
+class GenHeadConfig(BaseConfig):
     model_type = "gen_head"
-    cls: str = ""
-    params: AttrDict = {}
-
-    def __init__(self, **kwargs):
-        super().__init__(**kwargs)
-
-        self.cls = kwargs.get("cls", "")
-        if not isinstance(self.cls, str):
-            self.cls = self.cls.__name__
-
-        self.params = AttrDict(kwargs.get("params", {}))


 class MultiModalityConfig(PretrainedConfig):
+    """
+    Configuration for the multi-modality model.
+    """
     model_type = "multi_modality"
     vision_config: VisionConfig
     aligner_config: AlignerConfig

     gen_vision_config: GenVisionConfig
     gen_aligner_config: GenAlignerConfig
     gen_head_config: GenHeadConfig

     language_config: LlamaConfig

     def __init__(self, **kwargs):
         super().__init__(**kwargs)
-        vision_config = kwargs.get("vision_config", {})
-        self.vision_config = VisionConfig(**vision_config)
-
-        aligner_config = kwargs.get("aligner_config", {})
-        self.aligner_config = AlignerConfig(**aligner_config)
-
-        gen_vision_config = kwargs.get("gen_vision_config", {})
-        self.gen_vision_config = GenVisionConfig(**gen_vision_config)
-
-        gen_aligner_config = kwargs.get("gen_aligner_config", {})
-        self.gen_aligner_config = GenAlignerConfig(**gen_aligner_config)
-
-        gen_head_config = kwargs.get("gen_head_config", {})
-        self.gen_head_config = GenHeadConfig(**gen_head_config)
+        self.vision_config = VisionConfig(**kwargs.get("vision_config", {}))
+        self.aligner_config = AlignerConfig(**kwargs.get("aligner_config", {}))
+        self.gen_vision_config = GenVisionConfig(**kwargs.get("gen_vision_config", {}))
+        self.gen_aligner_config = GenAlignerConfig(**kwargs.get("gen_aligner_config", {}))
+        self.gen_head_config = GenHeadConfig(**kwargs.get("gen_head_config", {}))

         language_config = kwargs.get("language_config", {})
-        if isinstance(language_config, LlamaConfig):
-            self.language_config = language_config
-        else:
-            self.language_config = LlamaConfig(**language_config)
+        self.language_config = (
+            language_config if isinstance(language_config, LlamaConfig) else LlamaConfig(**language_config)
+        )

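`MultiModalityConfig.__init__` coerces each nested dict into its typed sub-config, so a checkpoint's `config.json` can be splatted straight in as keyword arguments. A minimal sketch with illustrative values (real checkpoints carry many more fields):

```python
from transformers import LlamaConfig

from janus.models.modeling_vlm import MultiModalityConfig

config = MultiModalityConfig(
    vision_config={"cls": "CLIPVisionTower", "params": {}},
    aligner_config={"cls": "MlpProjector", "params": {}},
    gen_vision_config={"cls": "VQ-16", "params": {}},
    gen_aligner_config={"cls": "MlpProjector", "params": {}},
    gen_head_config={"cls": "vision_head", "params": {}},
    language_config={},  # a plain dict is coerced into a LlamaConfig
)
assert isinstance(config.language_config, LlamaConfig)
assert config.gen_vision_config.cls == "VQ-16"
```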

 class MultiModalityPreTrainedModel(PreTrainedModel):
+    """
+    Base class for multi-modality pre-trained models.
+    """
     config_class = MultiModalityConfig
     base_model_prefix = "multi_modality"
     _no_split_modules = []
     _skip_keys_device_placement = "past_key_values"


 class MultiModalityCausalLM(MultiModalityPreTrainedModel):
+    """
+    Multi-modality causal language model combining vision and language components.
+    """
     def __init__(self, config: MultiModalityConfig):
         super().__init__(config)

-        vision_config = config.vision_config
-        vision_cls = model_name_to_cls(vision_config.cls)
-        self.vision_model = vision_cls(**vision_config.params)
+        # Initialize vision model
+        vision_cls = model_name_to_cls(config.vision_config.cls)
+        self.vision_model = vision_cls(**config.vision_config.params)

-        aligner_config = config.aligner_config
-        aligner_cls = model_name_to_cls(aligner_config.cls)
-        self.aligner = aligner_cls(aligner_config.params)
+        # Initialize aligner
+        aligner_cls = model_name_to_cls(config.aligner_config.cls)
+        self.aligner = aligner_cls(config.aligner_config.params)

-        gen_vision_config = config.gen_vision_config
-        gen_vision_cls = model_name_to_cls(gen_vision_config.cls)
+        # Initialize generative vision model
+        gen_vision_cls = model_name_to_cls(config.gen_vision_config.cls)
         self.gen_vision_model = gen_vision_cls()

-        gen_aligner_config = config.gen_aligner_config
-        gen_aligner_cls = model_name_to_cls(gen_aligner_config.cls)
-        self.gen_aligner = gen_aligner_cls(gen_aligner_config.params)
+        # Initialize generative aligner
+        gen_aligner_cls = model_name_to_cls(config.gen_aligner_config.cls)
+        self.gen_aligner = gen_aligner_cls(config.gen_aligner_config.params)

-        gen_head_config = config.gen_head_config
-        gen_head_cls = model_name_to_cls(gen_head_config.cls)
-        self.gen_head = gen_head_cls(gen_head_config.params)
+        # Initialize generative head
+        gen_head_cls = model_name_to_cls(config.gen_head_config.cls)
+        self.gen_head = gen_head_cls(config.gen_head_config.params)

+        # Generative embedding layer
         self.gen_embed = torch.nn.Embedding(
-            gen_vision_config.params.image_token_size, gen_vision_config.params.n_embed
+            config.gen_vision_config.params.image_token_size,
+            config.gen_vision_config.params.n_embed,
         )

-        language_config = config.language_config
-        self.language_model = LlamaForCausalLM(language_config)
+        # Language model
+        self.language_model = LlamaForCausalLM(config.language_config)

     def prepare_inputs_embeds(
         self,
@@ -227,46 +186,53 @@ def prepare_inputs_embeds(
         **kwargs,
     ):
         """
+        Prepares input embeddings by combining text and image embeddings.

         Args:
             input_ids (torch.LongTensor): [b, T]
-            pixel_values (torch.FloatTensor):   [b, n_images, 3, h, w]
+            pixel_values (torch.FloatTensor): [b, n_images, 3, h, w]
             images_seq_mask (torch.BoolTensor): [b, T]
             images_emb_mask (torch.BoolTensor): [b, n_images, n_image_tokens]

             assert torch.sum(images_seq_mask) == torch.sum(images_emb_mask)

         Returns:
             input_embeds (torch.Tensor): [b, T, D]
         """

-        bs, n = pixel_values.shape[0:2]
+        bs, n = pixel_values.shape[:2]
         images = rearrange(pixel_values, "b n c h w -> (b n) c h w")
-        # [b x n, T2, D]
-        images_embeds = self.aligner(self.vision_model(images))
-
-        # [b x n, T2, D] -> [b, n x T2, D]
-        images_embeds = rearrange(images_embeds, "(b n) t d -> b (n t) d", b=bs, n=n)
-        # [b, n, T2] -> [b, n x T2]
-        images_emb_mask = rearrange(images_emb_mask, "b n t -> b (n t)")
+        # Process images through vision model and aligner
+        images_embeds = self.aligner(self.vision_model(images))  # [b x n, T2, D]
+        images_embeds = rearrange(images_embeds, "(b n) t d -> b (n t) d", b=bs, n=n)  # [b, n x T2, D]
+        images_emb_mask = rearrange(images_emb_mask, "b n t -> b (n t)")  # [b, n x T2]

-        # [b, T, D]
-        input_ids[input_ids < 0] = 0  # ignore the image embeddings
-        inputs_embeds = self.language_model.get_input_embeddings()(input_ids)
+        # Prepare text embeddings
+        input_ids[input_ids < 0] = 0  # Ignore negative IDs
+        inputs_embeds = self.language_model.get_input_embeddings()(input_ids)  # [b, T, D]

-        # replace with the image embeddings
+        # Replace text embeddings with image embeddings where applicable
         inputs_embeds[images_seq_mask] = images_embeds[images_emb_mask]

         return inputs_embeds

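The boolean-mask merge relies on the invariant kept in the docstring: the number of `True` entries in `images_seq_mask` (image slots in the token sequence) must equal the number of `True` entries in `images_emb_mask` (valid image embeddings), because the indexed assignment pairs them one-to-one. A toy reproduction of the mechanic:

```python
import torch

# Toy shapes: batch 1, 6 text positions, of which positions 2..4 hold image tokens.
inputs_embeds = torch.zeros(1, 6, 4)                 # [b, T, D]
images_embeds = torch.arange(12.0).reshape(1, 3, 4)  # [b, n x T2, D], 3 image tokens
images_seq_mask = torch.tensor([[False, False, True, True, True, False]])
images_emb_mask = torch.ones(1, 3, dtype=torch.bool)

assert images_seq_mask.sum() == images_emb_mask.sum()  # the docstring's invariant
inputs_embeds[images_seq_mask] = images_embeds[images_emb_mask]
print(inputs_embeds[0, 2:5])  # these rows now carry the image embeddings
```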
     def prepare_gen_img_embeds(self, image_ids: torch.LongTensor):
+        """
+        Prepares generative image embeddings.
+
+        Args:
+            image_ids (torch.LongTensor): Image token IDs.
+
+        Returns:
+            torch.Tensor: Generated image embeddings.
+        """
         return self.gen_aligner(self.gen_embed(image_ids))

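`prepare_gen_img_embeds` is the generation-side mirror of the text lookup: discrete image token IDs pass through `gen_embed` and then `gen_aligner` into the language model's embedding space. A shape sketch, assuming `vl_gpt` is a loaded `MultiModalityCausalLM` and the hypothetical sizes used earlier:

```python
import torch

# 576 tokens for one image; IDs are bounded by gen_vision_config.params.image_token_size.
image_ids = torch.randint(0, 16384, (1, 576))
img_embeds = vl_gpt.prepare_gen_img_embeds(image_ids)
print(img_embeds.shape)  # expected: [1, 576, <language model hidden size>]
```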

+# Register configurations with Hugging Face's AutoConfig
 AutoConfig.register("vision", VisionConfig)
 AutoConfig.register("aligner", AlignerConfig)
 AutoConfig.register("gen_vision", GenVisionConfig)
 AutoConfig.register("gen_aligner", GenAlignerConfig)
 AutoConfig.register("gen_head", GenHeadConfig)
 AutoConfig.register("multi_modality", MultiModalityConfig)

+# Register the multi-modality causal LM model
 AutoModelForCausalLM.register(MultiModalityConfig, MultiModalityCausalLM)
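With these registrations, the stock Auto classes resolve the custom `multi_modality` model type, so loading stays a one-liner. A minimal loading sketch along the lines of the repo's README (`deepseek-ai/Janus-1.3B` is one published checkpoint):

```python
from transformers import AutoModelForCausalLM

from janus.models import VLChatProcessor

model_path = "deepseek-ai/Janus-1.3B"
vl_chat_processor = VLChatProcessor.from_pretrained(model_path)
vl_gpt = AutoModelForCausalLM.from_pretrained(model_path, trust_remote_code=True)
```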