Multimodal collater with interleaved image, cross-attention mask padding #1156
@@ -3,7 +3,7 @@
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
-from typing import Dict, List, Tuple, Union
+from typing import Any, Dict, List, Tuple, Union

import torch
import torch.nn.functional as F
@@ -214,6 +214,169 @@ def padded_collate_sft(
    return {"tokens": input_ids.long(), "labels": labels.long()}

# TODO: Generalize this to support any type of encoder input, right now this assumes
# a specific encoder_input signature
def padded_collate_tiled_images_with_cross_attention(
    batch: List[Dict[str, Any]],
    padding_idx: int = 0,
    ignore_idx: int = CROSS_ENTROPY_IGNORE_IDX,
) -> Dict[str, torch.Tensor]:
    """Pad a batch of text sequences, tiled image tensors, aspect ratios,
    and cross attention masks.

    ``batch`` is expected to be a list of sample dicts containing the following:
        - "tokens": List[int] of length text_seq_len, varies across samples
        - "labels": List[int] of length text_seq_len, varies across samples
        - "encoder_input": Dict[str, List[torch.Tensor]]
            - "images": List[torch.Tensor], each with shape (n_tiles, c, h, w)
            - "aspect_ratio": List[torch.Tensor], each with shape (2, ) to indicate h_ratio, w_ratio
        - "encoder_mask": List[Tensor], each with shape (text_seq_len, image_seq_len)

    where c = channel dim, h = height dim, w = width dim. For each element in the batch,
    len(images) == len(encoder_mask) == len(aspect_ratio).

    This collater does the following:
        (1) Pad text sequence and encoder mask to the longest sequence length in the batch
        (2) Pad image tensors in the tile dimension with zeros to the largest number
            of tiles in the batch
        (3) Add empty images of zeros to samples up to max number of images in the batch
        (4) Pad aspect ratios with (1,1) for all added padding images

    Args:
        batch (List[Dict[str, Any]]): A list of sample dicts containing tokens,
            labels, images, encoder_mask, and aspect_ratio.
        padding_idx (int): Padding index for input token ids. Defaults to 0.
        ignore_idx (int): Padding index for labels. Defaults to -100.

    Returns:
        Dict[str, Tensor]: Collated tokens, labels, images, encoder_mask, aspect_ratio tensors.
            - tokens: Tensor of shape (bsz, max_seq_len)
            - labels: Tensor of shape (bsz, max_seq_len)
            - images: Tensor of shape (bsz, max_num_images, max_num_tiles, c, h, w)
            - encoder_mask: Tensor of shape (bsz, max_seq_len, tokens_per_tile * max_num_tiles * max_num_images)
            - aspect_ratio: Tensor of shape (bsz, max_num_images, 2)

    Example:
        >>> image_id = 1
        >>> tokens_per_tile = 5
        >>> c, h, w = 1, 1, 1
        >>> batch = [
        ...     {
        ...         "tokens": [1, 2, 1, 3], "labels": [4, 5, 6, 7],
        ...         "encoder_input": {
        ...             # One image with two tiles, one image with three tiles
        ...             "images": [torch.ones(2, c, h, w), torch.ones(3, c, h, w)],
        ...             "aspect_ratio": [torch.tensor([1, 2]), torch.tensor([1, 3])],
        ...         },
        ...         # Mask is shape (text_seq_len, tokens_per_tile * n_tiles)
        ...         "encoder_mask": [torch.ones(4, 5 * 2), torch.ones(4, 5 * 3)],
        ...     },
        ...     {
        ...         "tokens": [1, 4], "labels": [8, 9],
        ...         "encoder_input": {
        ...             # One image with four tiles
        ...             "images": [torch.ones(4, c, h, w)],
        ...             "aspect_ratio": [torch.tensor([2, 2])],
        ...         },
        ...         # Mask is shape (text_seq_len, tokens_per_tile * n_tiles)
        ...         "encoder_mask": [torch.ones(2, 5 * 4)],
        ...     },
        ... ]
        >>> model_inputs = padded_collate_tiled_images_with_cross_attention(batch=batch)
Review comment: This name doesn't match the current name. I actually prefer padded_collate_vision_text as it's more straightforward, and we can either generalize this function or split and rename as we get more vision_text models in the future.
        >>> print(model_inputs["tokens"])
        tensor([[1, 2, 1, 3],
                [1, 4, 0, 0]])
        >>> print(model_inputs["labels"])
        tensor([[4, 5, 6, 7],
                [8, 9, -100, -100]])
        >>> print(model_inputs["encoder_input"]["images"].shape)  # (bsz, max_num_images, max_num_tiles, c, h, w)
        torch.Size([2, 2, 4, 1, 1, 1])
        >>> print(model_inputs["encoder_mask"].shape)  # (bsz, max_text_seq_len, tokens_per_tile * max_num_tiles * max_num_images)
        torch.Size([2, 4, 40])
        >>> print(model_inputs["encoder_input"]["aspect_ratio"].shape)  # (bsz, max_num_images, 2)
        torch.Size([2, 2, 2])

Review comment: I'm not sure if this should be [2, 2, 2] or [2, 4]. @felipemello1?
Reply: aspect_ratio should be (bsz, max_num_images, 2), and then in the CLIP we reshape: aspect_ratio = aspect_ratio.reshape(bsz_and_n_imgs, 2)
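For concreteness, a minimal sketch of the reshape described in the reply, using the shapes from this example (the variable names here are illustrative and not taken from the encoder code):

import torch
bsz, max_num_images = 2, 2
aspect_ratio = torch.ones(bsz, max_num_images, 2, dtype=torch.long)  # (bsz, max_num_images, 2)
bsz_and_n_imgs = bsz * max_num_images
flat_aspect_ratio = aspect_ratio.reshape(bsz_and_n_imgs, 2)  # (4, 2): one (h_ratio, w_ratio) row per image slot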
        >>> print(model_inputs["encoder_input"]["images"][0, 0, ...])  # Image with two tiles got padded to four
        tensor([[[[1.]]], [[[1.]]], [[[0.]]], [[[0.]]]])
        >>> print(model_inputs["encoder_input"]["images"][0, 1, ...])  # Image with three tiles got padded to four
        tensor([[[[1.]]], [[[1.]]], [[[1.]]], [[[0.]]]])
        >>> print(model_inputs["encoder_input"]["images"][1, 0, ...])  # Image with four tiles did not get padded
        tensor([[[[1.]]], [[[1.]]], [[[1.]]], [[[1.]]]])
        >>> print(model_inputs["encoder_input"]["images"][1, 1, ...])  # Extra padding image was added to second sample
        tensor([[[[0.]]], [[[0.]]], [[[0.]]], [[[0.]]]])
    """
    # Text tokens can be handled independently by existing collater
    text_only = [
        {"tokens": sample["tokens"], "labels": sample["labels"]} for sample in batch
    ]
    collated_text = padded_collate_sft(text_only, padding_idx, ignore_idx)
    max_seq_len = collated_text["tokens"].shape[-1]
    bsz = len(batch)
    # TODO: Figure out how to make this more efficient or vectorized. Setting
    # max_num_tiles beforehand will save one nested for loop but may incur more
    # memory and compute costs in attention if max_num_tiles > batch_max_num_tiles

Review comment: Didn't think too much about it, but maybe:
Reply: Pre-allocating would definitely simplify the code. I would still need to loop through each individual image though.
Reply: I will leave this as a follow-up in the interest of time.
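A rough sketch of the pre-allocation idea from this thread (illustrative only, not part of this PR; the helper name is hypothetical and it assumes the sample schema described in the docstring above):

# Illustrative pre-allocation sketch (not part of this PR): allocate the padded image
# tensor up front and copy each image's tiles into place; untouched tiles stay zero.
def _preallocate_images(batch, max_num_images, max_num_tiles, c, h, w):
    images = torch.zeros(len(batch), max_num_images, max_num_tiles, c, h, w)
    for i, sample in enumerate(batch):
        for j, image in enumerate(sample["encoder_input"]["images"]):
            images[i, j, : image.shape[0]] = image
    return images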
    # First loop: get max number of tiles in batch
    max_num_tiles = max(
        image.shape[0]
        for sample in batch
        for image in sample["encoder_input"]["images"]
    )
    # Second loop: pad images and masks to max number of tiles, max text seq len in batch
    batch_images = []
    batch_masks = []
    batch_aspect_ratios = []
    for sample in batch:
        sample_images = []
        sample_masks = []
        for image, mask in zip(
            sample["encoder_input"]["images"], sample["encoder_mask"]
        ):
            # Single image in each sample has shape (n_tiles, c, h, w)
            n_tiles = image.shape[0]
            # Single mask in each sample corresponds to a single image and has shape (text_seq_len, image_seq_len)
            # where image_seq_len = n_tiles * tokens_per_tile
            text_seq_len, image_seq_len = mask.shape
            tokens_per_tile = image_seq_len // n_tiles
            padding_tiles = max_num_tiles - n_tiles
            padding_text = max_seq_len - text_seq_len
            # Image should now have shape (max_num_tiles, c, h, w)
            padded_image = F.pad(image, (0, 0, 0, 0, 0, 0, 0, padding_tiles), value=0)
            # Mask should now have shape (max_seq_len, max_image_seq_len), where
            # max_image_seq_len = max_num_tiles * tokens_per_tile
            padded_mask = F.pad(
                mask, (0, padding_tiles * tokens_per_tile, 0, padding_text), value=0
            )
            sample_images.append(padded_image)
            sample_masks.append(padded_mask)
        # Stack multiple images and masks per sample in num_images dimension
        batch_images.append(torch.stack(sample_images))
        batch_masks.append(torch.stack(sample_masks))
        batch_aspect_ratios.append(torch.stack(sample["encoder_input"]["aspect_ratio"]))
    # Finally, pad images, masks, aspect ratios to max number of images in batch
    # (bsz, max_num_images, max_num_tiles, c, h, w)
    collated_images = pad_sequence(batch_images, batch_first=True, padding_value=0)
    # (bsz, max_num_images, max_seq_len, max_image_seq_len)
    collated_masks = pad_sequence(batch_masks, batch_first=True, padding_value=0)
    # (bsz, max_num_images, 2)
    collated_aspect_ratios = pad_sequence(
        batch_aspect_ratios, batch_first=True, padding_value=1
    )

    # Concatenate masks for multiple images across image_seq_len dimension
    concat_masks = collated_masks.view(bsz, max_seq_len, -1)

    return {
        "tokens": collated_text["tokens"],
        "labels": collated_text["labels"],
        "encoder_input": {
            "images": collated_images,
            "aspect_ratio": collated_aspect_ratios,
        },
        "encoder_mask": concat_masks,
    }
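For context, a minimal sketch of how a collater like this is typically wired into a DataLoader (my_multimodal_dataset and the batch size are placeholders, not part of this PR):

# Hypothetical usage sketch; the dataset variable is a placeholder.
from functools import partial
from torch.utils.data import DataLoader

dataloader = DataLoader(
    my_multimodal_dataset,
    batch_size=2,
    collate_fn=partial(
        padded_collate_tiled_images_with_cross_attention,
        padding_idx=0,
        ignore_idx=-100,
    ),
)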

def padded_collate_dpo(
    batch: List[Dict[str, List[int]]],
    padding_idx: int = 0,
Review comment: Can you say somewhere that c, h, w = channel, height, width?