Language stream changes (#264)
* fix padding side when generating

* clean up language stream forward pass (less for-looping)

* expose BLIP model

* fixes for forward pass without images

* restore for-looping
anas-awadalla authored Sep 21, 2023
1 parent 5ad05c4 commit 939d460
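
For context on the first bullet: decoder-only language models append generated tokens at the right end of the sequence, so batched prompts of unequal length must be left-padded during generation; with right padding, pad tokens would sit between a prompt and its continuation. A minimal, hypothetical illustration in plain torch:

    import torch

    pad_id = 0
    a = torch.tensor([5, 6, 7])  # longer prompt
    b = torch.tensor([8, 9])     # shorter prompt

    # right padding leaves pad_id between prompt b and any generated tokens
    right_padded = torch.stack([a, torch.cat([b, torch.tensor([pad_id])])])
    print(right_padded)  # tensor([[5, 6, 7], [8, 9, 0]])

    # left padding keeps every prompt flush against the decoding positions
    left_padded = torch.stack([a, torch.cat([torch.tensor([pad_id]), b])])
    print(left_padded)   # tensor([[5, 6, 7], [0, 8, 9]])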
Showing 2 changed files with 17 additions and 6 deletions.
1 change: 1 addition & 0 deletions open_flamingo/__init__.py
@@ -1,3 +1,4 @@
 from .src.flamingo import Flamingo
 from .src.kosmos import Kosmos
+from .src.blip import BLIP
 from .src.factory import create_model_and_transforms, SUPPORTED_MODEL_FAMILIES
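
With this line in place, BLIP can be imported from the package root like the other model families:

    # usage after this change
    from open_flamingo import BLIP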
22 changes: 16 additions & 6 deletions open_flamingo/src/vlm.py
@@ -250,11 +250,11 @@ def generate(
             torch.Tensor: lang_x with generated tokens appended to it
         """
         num_beams = kwargs.pop("num_beams", 1)
-        if num_beams > 1:
-            vision_x = vision_x.repeat_interleave(num_beams, dim=0)
 
         # convert pixels to vision tokens
         if vision_x is not None:
+            if num_beams > 1:
+                vision_x = vision_x.repeat_interleave(num_beams, dim=0)
             vision_features = self._encode_vision_x(vision_x=vision_x)
             vision_tokens = self.vision_tokenizer(vision_features)
         else:
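
A note on why the repeat moved: generate() with num_beams > 1 expands the language batch to batch_size * num_beams, so vision_x must be expanded to match, but only when it exists; hoisting the repeat above the None check (as before) breaks text-only generation. A minimal sketch of the shape logic, assuming the OpenFlamingo-style (B, T_img, F, C, H, W) layout for vision_x:

    import torch

    batch_size, num_beams = 2, 3
    vision_x = torch.randn(batch_size, 1, 1, 3, 224, 224)  # (B, T_img, F, C, H, W)

    # each beam decodes against its own copy of the vision input, so the
    # expanded language batch (B * num_beams) lines up with vision_x row-for-row
    vision_x = vision_x.repeat_interleave(num_beams, dim=0)
    assert vision_x.shape[0] == batch_size * num_beams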
@@ -268,6 +268,7 @@ def generate(
             past_key_values=past_key_values,
             past_media_locations=past_media_locations,
             past_vision_tokens=past_vision_tokens,
+            generating=True,
         )
         output = self.lang_model.generate(
             **new_inputs,
@@ -402,6 +403,7 @@ def _prepare_inputs_for_forward(
         past_key_values=None,
         past_media_locations: torch.Tensor = None,
         past_vision_tokens: torch.Tensor = None,
+        generating: bool = False,  # Not used for cross-attention models
     ):
         """Each xattn layer needs to save the vision tokens and the locations of the media tokens in the language sequence"""
         self.lang_model._condition_media_before_forward(
@@ -532,7 +534,7 @@ def __init__(
             block._use_gradient_checkpointing = gradient_checkpointing
         assert (
             self.vis_embedding_dim == self.lang_embedding_dim
-        ), "To place visual tokens direclty in the language stream, the visual and language tokens need to be the same dim."
+        ), "To place visual tokens directly in the language stream, the visual and language tokens need to be the same dim."

@@ -543,6 +545,7 @@ def _prepare_inputs_for_forward(
         past_key_values=None,
         past_media_locations: torch.Tensor = None,
         past_vision_tokens: torch.Tensor = None,
+        generating: bool = False,  # whether we're generating to decide on padding side
     ):
         """
         Insert the vision tokens directly into the language stream.
@@ -567,6 +570,13 @@ def _prepare_inputs_for_forward(
         for i in range(B):
             # get index of <image> tokens in lang_x[i]
             image_token_idxs = torch.where(lang_x[i] == self.media_token_id)[0]
+
+            if len(image_token_idxs) == 0:
+                multimodal_embeds.append(lang_embeds[i].clone())
+                multimodal_attention_mask.append(attention_mask[i].clone())
+                if has_labels:
+                    multimodal_labels.append(labels[i].clone())
+                continue
 
             # since an image is represented by self.num_tokens_per_vis tokens, we need to offset the image_token_idxs
             for j, img_idx in enumerate(image_token_idxs):
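
The added early-continue handles samples with no <image> token (the "fixes for forward pass without images" bullet): their embeddings, attention mask, and labels pass through unchanged. The offset comment above the inner loop refers to the fact that each <image> token is replaced by num_tokens_per_vis vision embeddings, which shifts every subsequent <image> position to the right. A small hypothetical sketch of that index arithmetic:

    # the j-th <image> token moves right by j * (num_tokens_per_vis - 1)
    # positions once earlier image tokens are expanded into vision embeddings
    num_tokens_per_vis = 64
    image_token_idxs = [3, 10]
    offset_idxs = [idx + j * (num_tokens_per_vis - 1) for j, idx in enumerate(image_token_idxs)]
    print(offset_idxs)  # [3, 73]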
@@ -618,14 +628,14 @@ def _prepare_inputs_for_forward(
 
         # stack
         multimodal_embeds = stack_with_padding(
-            multimodal_embeds, padding_value=self.pad_token_id
+            multimodal_embeds, padding_value=self.pad_token_id, padding_side="left" if generating else "right"
         )
         multimodal_attention_mask = stack_with_padding(
-            multimodal_attention_mask, padding_value=0
+            multimodal_attention_mask, padding_value=0, padding_side="left" if generating else "right"
         )
         if has_labels:
             multimodal_labels = stack_with_padding(
-                multimodal_labels, padding_value=-100
+                multimodal_labels, padding_value=-100, padding_side="left" if generating else "right"
             )
 
         return {
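stack_with_padding itself is not part of this diff; judging only from the calls above, a minimal sketch of such a helper (signature assumed, not the repo's implementation) could look like:

    import torch
    from typing import List

    def stack_with_padding(
        tensors: List[torch.Tensor], padding_value=0, padding_side: str = "right"
    ) -> torch.Tensor:
        # pad each tensor along dim 0 to the longest length, then stack into a batch
        max_len = max(t.shape[0] for t in tensors)
        padded = []
        for t in tensors:
            pad = torch.full(
                (max_len - t.shape[0], *t.shape[1:]),
                padding_value,
                dtype=t.dtype,
                device=t.device,
            )
            padded.append(torch.cat([pad, t]) if padding_side == "left" else torch.cat([t, pad]))
        return torch.stack(padded)

With generating=True the batch is left-padded, which is the side generate() needs for decoder-only models; training batches stay right-padded so labels keep their usual causal alignment.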
