support for Qwen2-VL #64

Draft
wants to merge 10 commits into base: develop
2 changes: 1 addition & 1 deletion .pre-commit-config.yaml
@@ -29,7 +29,7 @@ repos:
        additional_dependencies: ["bandit[toml]"]

  - repo: https://github.com/astral-sh/ruff-pre-commit
-    rev: v0.6.4
+    rev: v0.6.7
    hooks:
      - id: ruff
        args: [--fix, --exit-non-zero-on-fix]
17 changes: 1 addition & 16 deletions maestro/trainer/common/data_loaders/datasets.py
@@ -3,7 +3,6 @@
from typing import Any

from PIL import Image
-from transformers.pipelines.base import Dataset


class JSONLDataset:
@@ -34,18 +33,4 @@ def __getitem__(self, idx: int) -> tuple[Image.Image, dict[str, Any]]:
        except FileNotFoundError:
            raise FileNotFoundError(f"Image file {image_path} not found.")
        else:
-            return (image, entry)
-
-
-class DetectionDataset(Dataset):
-    def __init__(self, jsonl_file_path: str, image_directory_path: str) -> None:
-        self.dataset = JSONLDataset(jsonl_file_path, image_directory_path)
-
-    def __len__(self) -> int:
-        return len(self.dataset)
-
-    def __getitem__(self, idx):
-        image, data = self.dataset[idx]
-        prefix = data["prefix"]
-        suffix = data["suffix"]
-        return prefix, suffix, image
+            return image, entry
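For reviewers, a minimal usage sketch of JSONLDataset after this change; the paths follow the annotations.jsonl layout that load_split_dataset expects, and the concrete directory names are placeholders:

```python
from maestro.trainer.common.data_loaders.datasets import JSONLDataset

# Placeholder split directory containing annotations.jsonl plus the images it references.
dataset = JSONLDataset(
    jsonl_file_path="dataset/train/annotations.jsonl",
    image_directory_path="dataset/train",
)

image, entry = dataset[0]  # PIL.Image.Image plus the raw JSONL record
print(entry["prefix"], entry["suffix"])
```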
Empty file.
21 changes: 18 additions & 3 deletions maestro/trainer/models/florence_2/loaders.py
@@ -7,8 +7,23 @@
from PIL import Image
from torch.utils.data import DataLoader
from transformers import AutoProcessor
+from transformers.pipelines.base import Dataset

-from maestro.trainer.common.data_loaders.datasets import DetectionDataset
+from maestro.trainer.common.data_loaders.datasets import JSONLDataset


+class Florence2Dataset(Dataset):
+    def __init__(self, jsonl_file_path: str, image_directory_path: str) -> None:
+        self.dataset = JSONLDataset(jsonl_file_path, image_directory_path)
+
+    def __len__(self) -> int:
+        return len(self.dataset)
+
+    def __getitem__(self, idx):
+        image, data = self.dataset[idx]
+        prefix = data["prefix"]
+        suffix = data["suffix"]
+        return prefix, suffix, image
+
+
def create_data_loaders(
@@ -85,7 +100,7 @@ def create_split_data_loader(
def load_split_dataset(
    dataset_location: str,
    split_name: str,
-) -> Optional[DetectionDataset]:
+) -> Optional[Florence2Dataset]:
    image_directory_path = os.path.join(dataset_location, split_name)
    jsonl_file_path = os.path.join(dataset_location, split_name, "annotations.jsonl")
    if not os.path.exists(image_directory_path):
@@ -94,7 +109,7 @@ def load_split_dataset(
    if not os.path.exists(jsonl_file_path):
        logging.warning(f"Could not find JSONL file: {jsonl_file_path}")
        return None
-    return DetectionDataset(
+    return Florence2Dataset(
        jsonl_file_path=jsonl_file_path,
        image_directory_path=image_directory_path,
    )
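The hunk above only relocates the dataset class; how its (prefix, suffix, image) tuples get batched is not shown here. A rough collate sketch, assuming the usual Florence-2 processor call (the checkpoint name, trust_remote_code flag, and padding choice are assumptions, not part of this diff):

```python
from transformers import AutoProcessor

processor = AutoProcessor.from_pretrained("microsoft/Florence-2-base-ft", trust_remote_code=True)

def collate_fn(batch):
    # Florence2Dataset yields (prefix, suffix, image) tuples.
    prefixes, suffixes, images = zip(*batch)
    inputs = processor(text=list(prefixes), images=list(images), return_tensors="pt", padding=True)
    return inputs, list(suffixes)
```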
4 changes: 2 additions & 2 deletions maestro/trainer/models/florence_2/metrics.py
@@ -5,7 +5,7 @@
from PIL import Image
from transformers import AutoProcessor

-from maestro.trainer.common.data_loaders.datasets import DetectionDataset
+from maestro.trainer.models.florence_2.loaders import Florence2Dataset

DETECTION_CLASS_PATTERN = r"([a-zA-Z0-9 -]+)<loc_\d+>"

@@ -59,7 +59,7 @@ def process_output_for_text_metric(
    return predictions


-def get_unique_detection_classes(dataset: DetectionDataset) -> list[str]:
+def get_unique_detection_classes(dataset: Florence2Dataset) -> list[str]:
    class_set = set()
    for i in range(len(dataset)):
        _, suffix, _ = dataset[i]
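To make the regex above concrete, here is a small illustration of how DETECTION_CLASS_PATTERN pulls class names out of a Florence-2 style detection suffix (the suffix string itself is a made-up example):

```python
import re

DETECTION_CLASS_PATTERN = r"([a-zA-Z0-9 -]+)<loc_\d+>"

# Example suffix: each class name is followed by its quantized <loc_*> location tokens.
suffix = "dog<loc_101><loc_202><loc_303><loc_404>cat<loc_11><loc_22><loc_33><loc_44>"
print(re.findall(DETECTION_CLASS_PATTERN, suffix))  # ['dog', 'cat']
```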
Empty file.
1 change: 1 addition & 0 deletions maestro/trainer/models/qwen2_vl/checkpoints.py
@@ -0,0 +1 @@
DEFAULT_QWEN2_VL_MODEL_ID = "Qwen/Qwen2-VL-2B-Instruct"
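For context, loading this default Qwen2-VL checkpoint would look roughly like the snippet below; the class name follows the Qwen2-VL integration that the transformers git pin in pyproject.toml brings in, and the device_map handling is an assumption rather than something this PR shows:

```python
from transformers import AutoProcessor, Qwen2VLForConditionalGeneration

model_id = "Qwen/Qwen2-VL-2B-Instruct"
processor = AutoProcessor.from_pretrained(model_id)
model = Qwen2VLForConditionalGeneration.from_pretrained(model_id, device_map="auto")
```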
75 changes: 75 additions & 0 deletions maestro/trainer/models/qwen2_vl/loaders.py
@@ -0,0 +1,75 @@
from transformers.pipelines.base import Dataset

from maestro.trainer.common.data_loaders.datasets import JSONLDataset

START_TOKEN_1 = 151644
START_TOKEN_2 = 77091
END_TOKEN = 151645


def extract_assistant_content_ranges(token_list: list[int]) -> list[tuple[int, int]]:
    """
    Identify the start and end indexes of assistant content ranges within a list of
    tokens.

    The function searches for sequences that mark the start and end of assistant content
    in the tokenized list, returning the corresponding index ranges.

    Args:
        token_list (list[int]): A list of tokens to search.

    Returns:
        list[tuple[int, int]]: A list of (start_index, end_index) tuples indicating the
            assistant content ranges in the input list.

    Note:
        - Assistant content starts with the sequence [START_TOKEN_1, START_TOKEN_2],
          which corresponds to the tokenized value of `"<|im_start|>assistant"`.
        - Assistant content ends with END_TOKEN, which corresponds to the tokenized
          value of `"<|im_end|>"`.
        - Each start sequence has a corresponding end token.
    """
    start_indexes = []
    end_indexes = []

    for i in range(len(token_list) - 1):
        if token_list[i] == START_TOKEN_1 and token_list[i + 1] == START_TOKEN_2:
            start_indexes.append(i)
            for j in range(i + 2, len(token_list)):
                if token_list[j] == END_TOKEN:
                    end_indexes.append(j)
                    break

    return list(zip(start_indexes, end_indexes))


class Qwen2VLDataset(Dataset):
    def __init__(self, jsonl_file_path: str, image_directory_path: str) -> None:
        self.dataset = JSONLDataset(jsonl_file_path, image_directory_path)

    def __len__(self) -> int:
        return len(self.dataset)

    def __getitem__(self, idx):
        image, data = self.dataset[idx]
        prefix = data["prefix"]
        suffix = data["suffix"]
        # fmt: off
        return {
            "messages": [
                {
                    "role": "user",
                    "content": [
                        {"type": "image", "image": image},
                        {"type": "text", "text": prefix}
                    ]
                },
                {
                    "role": "assistant",
                    "content": [
                        {"type": "text", "text": suffix}
                    ]
                }
            ]
        }
        # fmt: on
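Neither the collation nor the label masking that would consume extract_assistant_content_ranges is part of this file, so the following is only a sketch of how the pieces might fit together; the processor calls and the -100 masking convention are assumptions, not code from this PR:

```python
import torch
from transformers import AutoProcessor

from maestro.trainer.models.qwen2_vl.loaders import extract_assistant_content_ranges

processor = AutoProcessor.from_pretrained("Qwen/Qwen2-VL-2B-Instruct")

def collate_fn(batch):
    # Each Qwen2VLDataset item is {"messages": [...]} with the PIL image embedded in the user turn.
    texts = [processor.apply_chat_template(item["messages"], tokenize=False) for item in batch]
    images = [item["messages"][0]["content"][0]["image"] for item in batch]
    inputs = processor(text=texts, images=images, return_tensors="pt", padding=True)

    # Supervise only the assistant spans; everything else is ignored by the loss via -100.
    labels = torch.full_like(inputs["input_ids"], -100)
    for row, ids in enumerate(inputs["input_ids"].tolist()):
        for start, end in extract_assistant_content_ranges(ids):
            labels[row, start : end + 1] = inputs["input_ids"][row, start : end + 1]
    inputs["labels"] = labels
    return inputs
```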
2 changes: 1 addition & 1 deletion pyproject.toml
@@ -36,7 +36,7 @@ classifiers = [
dependencies = [
    "supervision~=0.24.0rc1",
    "requests>=2.31.0,<=2.32.3",
-    "transformers~=4.44.2",
+    "transformers @ git+https://github.com/huggingface/transformers",
    "torch~=2.4.0",
    "accelerate>=0.33,<0.35",
    "sentencepiece~=0.2.0",