Skip to content

Commit

Permalink
feat: Support text only data.
Browse files (browse the repository at this point in the history)
  • Loading branch information
2U1 committed Oct 29, 2024
1 parent 0731c55 commit 80b59b8
Showing 1 changed file with 26 additions and 12 deletions.
38 changes: 26 additions & 12 deletions src/training/data.py
Original file line number Diff line number Diff line change
Expand Up @@ -155,6 +155,8 @@ def __getitem__(self, i) -> Dict[str, torch.Tensor]:
video_file = os.path.join(video_folder, video_file)
videos.append(get_video_info(video_file, self.max_pixel, self.data_args.fps))
else:
grid_key = None
pixel_key = None
images = None
videos = None

Expand Down Expand Up @@ -224,8 +226,9 @@ def __getitem__(self, i) -> Dict[str, torch.Tensor]:
labels=labels,
)

data_dict[pixel_key] = pixel_values
data_dict[grid_key] = image_thw
if pixel_key and grid_key:
data_dict[pixel_key] = pixel_values
data_dict[grid_key] = image_thw

return data_dict

Expand All @@ -247,32 +250,43 @@ def __call__(self, examples):
grid_key = "video_grid_thw"
pixel_key = "pixel_values_videos"

else:
elif "pixel_values" in sample:
grid_key = "image_grid_thw"
pixel_key = "pixel_values"

pixel_key = "pixel_values"\

else:
grid_key = None
pixel_key = None

for example in examples:
batch_input_ids.append(example["input_ids"])
batch_label_ids.append(example["labels"])
batch_pixel_values.append(example[pixel_key])
batch_image_thw.append(example[grid_key])

if pixel_key and grid_key:
batch_pixel_values.append(example[pixel_key])
batch_image_thw.append(example[grid_key])

input_ids = pad_sequence(
batch_input_ids, padding_side='right', padding_value=self.pad_token_id
)

attention_mask = input_ids != self.pad_token_id
labels = pad_sequence(batch_label_ids, padding_side='right', padding_value=IGNORE_INDEX)
pixel_values = torch.cat(batch_pixel_values, dim=0)
image_thw = torch.cat(batch_image_thw, dim=0)

return {

data_dict = {
'input_ids': input_ids,
'labels': labels,
'attention_mask': attention_mask,
pixel_key: pixel_values,
grid_key: image_thw,
}

if pixel_key and grid_key:
pixel_values = torch.cat(batch_pixel_values, dim=0)
image_thw = torch.cat(batch_image_thw, dim=0)
data_dict[pixel_key] = pixel_values
data_dict[grid_key] = image_thw

return data_dict


def replace_image_tokens(input_string, is_video=False):
Expand Down

0 comments on commit 80b59b8

Please sign in to comment.