diff --git a/.github/workflows/ci.yaml b/.github/workflows/ci.yaml index ca7ea749d..520281793 100644 --- a/.github/workflows/ci.yaml +++ b/.github/workflows/ci.yaml @@ -27,11 +27,12 @@ jobs: - name: Install dependencies run: | + sudo apt install libjpeg-dev pip install "torch>=2.7.0" pip install pybind11 FLASH_ATTENTION_SKIP_CUDA_BUILD=TRUE FLASH_ATTENTION_FORCE_BUILD=TRUE MAMBA_SKIP_CUDA_BUILD=TRUE \ MAMBA_FORCE_BUILD=TRUE CAUSAL_CONV1D_FORCE_BUILD=TRUE CAUSAL_CONV1D_SKIP_CUDA_BUILD=TRUE \ - pip install --no-build-isolation -e ".[CORE,OPTIONAL,HUGGINGFACE,SSM,DEV,DOCS]" + pip install --no-build-isolation -e ".[CORE,OPTIONAL,HUGGINGFACE,SSM,VISION,DEV,DOCS]" - name: Run tests run: pytest -v -ra . diff --git a/.github/workflows/docs.yaml b/.github/workflows/docs.yaml index 632fa7b93..75ba3bb31 100644 --- a/.github/workflows/docs.yaml +++ b/.github/workflows/docs.yaml @@ -29,11 +29,12 @@ jobs: restore-keys: | mkdocs-material- - run: | + sudo apt install libjpeg-dev pip install "torch>=2.7.0" pip install pybind11 FLASH_ATTENTION_SKIP_CUDA_BUILD=TRUE FLASH_ATTENTION_FORCE_BUILD=TRUE MAMBA_SKIP_CUDA_BUILD=TRUE \ MAMBA_FORCE_BUILD=TRUE CAUSAL_CONV1D_FORCE_BUILD=TRUE CAUSAL_CONV1D_SKIP_CUDA_BUILD=TRUE \ - pip install --no-build-isolation -e ".[CORE,OPTIONAL,HUGGINGFACE,SSM,DEV,DOCS]" + pip install --no-build-isolation -e ".[CORE,OPTIONAL,HUGGINGFACE,SSM,VISION,DEV,DOCS]" - name: Build the documentation run: mkdocs build @@ -56,6 +57,7 @@ jobs: restore-keys: | mkdocs-material- - run: | + sudo apt install libjpeg-dev pip install "torch>=2.2.2" pip install pybind11 FLASH_ATTENTION_SKIP_CUDA_BUILD=TRUE FLASH_ATTENTION_FORCE_BUILD=TRUE MAMBA_SKIP_CUDA_BUILD=TRUE \ diff --git a/Dockerfile b/Dockerfile index e98223de8..6c013c14d 100644 --- a/Dockerfile +++ b/Dockerfile @@ -37,7 +37,7 @@ COPY --chmod=777 ./fast_llm/__init__.py fast_llm/ COPY --chmod=777 ./fast_llm/csrc/ fast_llm/csrc/ # Install dependencies within the virtual environment. -RUN pip install --no-cache-dir --no-build-isolation -e ".[CORE,OPTIONAL,HUGGINGFACE,SSM,DEV]" triton==3.1.0 +RUN pip install --no-cache-dir --no-build-isolation -e ".[CORE,OPTIONAL,HUGGINGFACE,SSM,VISION,DEV]" triton==3.1.0 # Copy the remaining source code with universal write permissions. 
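Aside: the new `VISION` extra pulls in Pillow, and the added `sudo apt install libjpeg-dev` gives it the JPEG headers it needs at build time. A quick post-install sanity check along these lines (a hedged sketch, not part of the workflow or Dockerfile) confirms the codec actually made it into the environment:

```python
# Hypothetical check: verify the Pillow build produced by the VISION extra
# can round-trip a JPEG (requires libjpeg-dev to have been present at build time).
import io

import numpy as np
import PIL.features
import PIL.Image

assert PIL.features.check("jpg"), "Pillow was built without libjpeg support."

image = PIL.Image.fromarray(np.zeros((16, 16, 3), dtype=np.uint8))  # tiny RGB image
buffer = io.BytesIO()
image.save(buffer, format="JPEG")
assert PIL.Image.open(io.BytesIO(buffer.getvalue())).size == (16, 16)
```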
COPY --chmod=777 ./Megatron-LM Megatron-LM diff --git a/Megatron-LM b/Megatron-LM index 511e8f5cb..f02b413f7 160000 --- a/Megatron-LM +++ b/Megatron-LM @@ -1 +1 @@ -Subproject commit 511e8f5cbe3ab8291953ac64e5beceb727a1b814 +Subproject commit f02b413f793af05ade3893bccd8aef6d644d3edf diff --git a/fast_llm/data/data/gpt/data.py b/fast_llm/data/data/gpt/data.py index 6724afb59..2c728ed4b 100644 --- a/fast_llm/data/data/gpt/data.py +++ b/fast_llm/data/data/gpt/data.py @@ -32,6 +32,8 @@ class GPTBatch: token_ids: torch.Tensor loss_masking_spans: list[torch.Tensor] | None = None sequence_lengths: list[torch.Tensor] | None = None + images: list[torch.Tensor] | None = None + image_positions: list[torch.Tensor] | None = None chosen_spans: list[torch.Tensor] | None = None rejected_spans: list[torch.Tensor] | None = None @@ -49,12 +51,24 @@ def gpt_data_collate_fn(batch: list[GPTSample], sampling_parameters: GPTSampling stacked_rejected_spans = [torch.from_numpy(sample.rejected_span) for sample in batch] if not sampling_parameters.cross_document_attention: sequence_lengths = [torch.tensor(sample.sequence_lengths) for sample in batch] + has_images = any(sample.images is not None for sample in batch) + if has_images: + images = [ + [] if sample.images is None else [torch.from_numpy(image) for image in sample.images] for sample in batch + ] + image_positions = [ + [] if sample.image_positions is None else torch.from_numpy(sample.image_positions) for sample in batch + ] + else: + images, image_positions = None, None return GPTBatch( token_ids=torch.from_numpy(stacked_ids), loss_masking_spans=stacked_spans, sequence_lengths=sequence_lengths, chosen_spans=stacked_chosen_spans, rejected_spans=stacked_rejected_spans, + images=images, + image_positions=image_positions, ) diff --git a/fast_llm/data/dataset/gpt/config.py b/fast_llm/data/dataset/gpt/config.py index ef2efedc9..7bfdc8515 100644 --- a/fast_llm/data/dataset/gpt/config.py +++ b/fast_llm/data/dataset/gpt/config.py @@ -65,7 +65,15 @@ class GPTSamplingConfig(SamplingConfig): @dataclasses.dataclass(kw_only=True) -class GPTSamplingParameters(SamplingParameters): +class ImageSamplingParameters: + patch_size: int | None = None + max_image_size: int | None = None + image_break_token: int | None = None + image_end_token: int | None = None + + +@dataclasses.dataclass(kw_only=True) +class GPTSamplingParameters(SamplingParameters, ImageSamplingParameters): """ Sampling parameters set externally to the dataset and data, ex. determined by the trainer or model. 
""" @@ -142,11 +150,18 @@ class GPTMemmapDatasetConfig(GPTIndexedDatasetConfig): desc="Expected number of tokens in the dataset.", hint=FieldHint.optional, ) + num_pixels: int | None = Field( + default=None, + desc="Expected number of pixels in the dataset.", + hint=FieldHint.optional, + ) def build(self) -> "GPTMemmapDataset": from fast_llm.data.dataset.gpt.memmap import GPTMemmapDataset - return GPTMemmapDataset(str(self.path).replace("/", "__"), self.path, self.num_documents, self.num_tokens) + return GPTMemmapDataset( + str(self.path).replace("/", "__"), self.path, self.num_documents, self.num_tokens, self.num_pixels + ) @config_class(dynamic_type={GPTSampledDatasetConfig: "concatenated"}) diff --git a/fast_llm/data/dataset/gpt/fim.py b/fast_llm/data/dataset/gpt/fim.py index 2b2c8b3be..843f6735d 100644 --- a/fast_llm/data/dataset/gpt/fim.py +++ b/fast_llm/data/dataset/gpt/fim.py @@ -158,9 +158,9 @@ def _fim_permute_sequence( middle = contents[boundaries[0] : boundaries[1]] suffix = contents[boundaries[1] :] - prefix = np.array([*self._tokenizer.tokenize(prefix, end=False)], dtype=np.int64) - middle = np.array([*self._tokenizer.tokenize(middle, begin=False, end=False)], dtype=np.int64) - suffix = np.array([*self._tokenizer.tokenize(suffix, begin=False)], dtype=np.int64) + prefix = np.array([*self._tokenizer.tokenize(prefix, add_eos=False)], dtype=np.int64) + middle = np.array([*self._tokenizer.tokenize(middle, add_bos=False, add_eos=False)], dtype=np.int64) + suffix = np.array([*self._tokenizer.tokenize(suffix, add_bos=False)], dtype=np.int64) # here we truncate each given segment to fit the same length as it was before # A consequence is that we never reach the end of a file? diff --git a/fast_llm/data/dataset/gpt/indexed.py b/fast_llm/data/dataset/gpt/indexed.py index 688ea6a70..59e701a63 100644 --- a/fast_llm/data/dataset/gpt/indexed.py +++ b/fast_llm/data/dataset/gpt/indexed.py @@ -19,12 +19,26 @@ def get_document_sizes(self) -> np.ndarray: and derived classes should try to avoid holding the whole array im memory. """ + def get_image_sizes(self) -> list[np.ndarray]: + """ + The size of each image in the dataset. + The resulting array could be very large, so this method should be called cautiously, + and derived classes should try to avoid holding the whole array im memory. + """ + raise NotImplementedError() + @abc.abstractmethod def get_document_size(self, index: int) -> int: """ The size of a document in the dataset. """ + def get_image_size(self, index: int) -> np.ndarray: + """ + The size of an image in the dataset. + """ + raise NotImplementedError() + def sample(self, sampling: GPTSamplingData) -> "GPTSampledIndexedDataset": from fast_llm.data.dataset.gpt.sampled import GPTSampledIndexedDataset, LegacyGPTSampledIndexedDataset @@ -34,6 +48,14 @@ def sample(self, sampling: GPTSamplingData) -> "GPTSampledIndexedDataset": else GPTSampledIndexedDataset(self, sampling) ) + @property + @abc.abstractmethod + def has_images(self) -> bool: + """ + Whether the dataset contains images. + This is used to determine whether to use image-related fields in the sampled data. + """ + class GPTDatasetSlice[IndexedDatasetType: GPTIndexedDataset](DatasetSlice[IndexedDatasetType], GPTIndexedDataset): """ @@ -46,9 +68,20 @@ def get_document_sizes(self) -> np.ndarray: # TODO: This can be really big. return self._dataset.get_document_sizes()[self._begin : self._end] + def get_image_sizes(self) -> list[np.ndarray]: + # TODO: This can be really big. 
+ return self._dataset.get_image_sizes()[self._begin : self._end] + def get_document_size(self, index: int) -> int: return self._dataset.get_document_size(self._begin + index) + def get_image_size(self, index: int) -> np.ndarray: + return self._dataset.get_image_size(self._begin + index) + + @property + def has_images(self) -> bool: + return self._dataset.has_images + class GPTConcatenatedDataset[IndexedDatasetType: GPTIndexedDataset]( ConcatenatedDataset[IndexedDatasetType], GPTIndexedDataset @@ -59,6 +92,18 @@ def get_document_sizes(self) -> np.ndarray: # TODO: This can be really big. return np.concatenate([dataset.get_document_sizes() for dataset in self._datasets]) + def get_image_sizes(self) -> list[np.ndarray]: + # TODO: This can be really big. + return sum([dataset.get_image_sizes() for dataset in self._datasets], []) + def get_document_size(self, index: int) -> int: dataset = np.searchsorted(self._dataset_splits[1:], index, side="right") return self._datasets[dataset].get_document_size(index - self._dataset_splits[dataset].item()) + + def get_image_size(self, index: int) -> np.ndarray: + dataset = np.searchsorted(self._dataset_splits[1:], index, side="right") + return self._datasets[dataset].get_image_size(index - self._dataset_splits[dataset].item()) + + @property + def has_images(self) -> bool: + return any(dataset.has_images for dataset in self._datasets) diff --git a/fast_llm/data/dataset/gpt/memmap.py b/fast_llm/data/dataset/gpt/memmap.py index f39fd56f4..e0473b7e1 100644 --- a/fast_llm/data/dataset/gpt/memmap.py +++ b/fast_llm/data/dataset/gpt/memmap.py @@ -1,8 +1,10 @@ +import io import pathlib import struct import typing import numpy as np +import PIL.Image from fast_llm.data.dataset.gpt.indexed import GPTIndexedDataset from fast_llm.data.dataset.gpt.sampled import GPTSample @@ -26,25 +28,29 @@ def __init__( prefix: pathlib.Path | str, num_documents: int | None = None, num_tokens: int | None = None, + num_pixels: int | None = None, ): - self._init(name, prefix, num_documents, num_tokens) + self._init(name, prefix, num_documents, num_tokens, num_pixels) - def _init(self, name: str, prefix: pathlib.Path | str, num_documents: int | None, num_tokens: int | None) -> None: + def _init( + self, + name: str, + prefix: pathlib.Path | str, + num_documents: int | None, + num_tokens: int | None, + num_pixels: int | None, + ) -> None: super().__init__() self._name = name self._prefix = pathlib.Path(prefix) - self._has_spans = 0 - self._has_preference_spans = False with self._prefix.with_suffix(".idx").open("rb") as stream: Assert.eq(stream.read(9), MEMMAP_INDEX_HEADER, msg=f"File: {stream.name}") self._version = struct.unpack("= 2: - self._has_spans = struct.unpack("= 3: - self._has_preference_spans = struct.unpack("= 2 else False + self._has_preference_spans = bool(struct.unpack("= 3 else False + self._has_images = bool(struct.unpack("= 4 else False self._dtype = MEMMAP_DTYPES[struct.unpack("= 2: - self._spans = [] - self._num_spans = np.frombuffer( - self._index_bin_buffer, - dtype=np.int32, - count=self._num_documents, - offset=offset + self._document_sizes.nbytes + self._pointers.nbytes, - ) - span_offset = offset + self._document_sizes.nbytes + self._pointers.nbytes + self._num_spans.nbytes - self._num_spans_cumsum = np.r_[0, np.cumsum(self._num_spans[:-1], dtype=np.int64)] - for idx in range(self._num_documents): - self._spans.append( - np.frombuffer( - self._index_bin_buffer, - dtype=np.int32, - count=self._num_spans[idx] * 2, - offset=span_offset + self._num_spans_cumsum[idx] * 
2 * np.dtype(np.int32).itemsize, - ).reshape(-1, 2) - ) - - # read preference spans - self._chosen_spans = None - self._rejected_spans = None - if self._has_preference_spans and self._version >= 3: - self._chosen_spans = [] - self._rejected_spans = [] - chosen_span_offset = offset + self._document_sizes.nbytes + self._pointers.nbytes - for idx in range(self._num_documents): - self._chosen_spans.append( - np.frombuffer( - self._index_bin_buffer, - dtype=np.int32, - count=2, - offset=chosen_span_offset + idx * 2 * np.dtype(np.int32).itemsize, - ) - ) - - rejected_span_offset = ( - offset + self._document_sizes.nbytes + self._pointers.nbytes + np.array(self._chosen_spans).nbytes - ) - for idx in range(self._num_documents): - self._rejected_spans.append( - np.frombuffer( - self._index_bin_buffer, - dtype=np.int32, - count=2, - offset=rejected_span_offset + idx * 2 * np.dtype(np.int32).itemsize, - ) - ) + if self._has_spans: + offset = self._init_spans(offset) + + if self._has_preference_spans: + offset = self._init_preference_spans(offset) + + total_pixels, _ = self._init_images(offset) if self._has_images else (0, offset) + if num_pixels is not None: + assert total_pixels == num_pixels self._bin_buffer_mmap = np.memmap(self._prefix.with_suffix(".bin"), mode="r", order="C") self._bin_buffer = memoryview(self._bin_buffer_mmap) - self._num_tokens = div(self._bin_buffer_mmap.size, np.dtype(self._dtype).itemsize) + self._num_tokens = div(self._bin_buffer_mmap.size - total_pixels, np.dtype(self._dtype).itemsize) if num_tokens is not None: assert self._num_tokens == num_tokens + def _init_spans(self, offset: int) -> int: + num_spans = np.frombuffer( + self._index_bin_buffer, + dtype=np.int32, + count=self._num_documents, + offset=offset, + ) + num_spans_cumsum = np.r_[0, np.cumsum(num_spans[:-1], dtype=np.int64)] + self._spans = [ + np.frombuffer( + self._index_bin_buffer, + dtype=np.int32, + count=num_spans[idx] * 2, + offset=offset + num_spans.nbytes + num_spans_cumsum[idx] * 2 * np.dtype(np.int32).itemsize, + ).reshape(-1, 2) + for idx in range(self._num_documents) + ] + return offset + num_spans.nbytes + num_spans.sum() * 2 * np.dtype(np.int32).itemsize + + def _init_preference_spans(self, offset: int) -> int: + item_size = np.dtype(np.int32).itemsize + self._chosen_spans = [ + np.frombuffer( + self._index_bin_buffer, + dtype=np.int32, + count=2, + offset=offset + 2 * idx * item_size, + ) + for idx in range(self._num_documents) + ] + offset += 2 * item_size * self._num_documents + self._rejected_spans = [ + np.frombuffer( + self._index_bin_buffer, + dtype=np.int32, + count=2, + offset=offset + 2 * idx * item_size, + ) + for idx in range(self._num_documents) + ] + return offset + 2 * item_size * self._num_documents + + def _init_images(self, offset: int) -> tuple[int, int]: + total_pixels = 0 + image_counts = np.frombuffer(self._index_bin_buffer, dtype=np.int32, count=self._num_documents, offset=offset) + offset += image_counts.nbytes + + self._image_sizes = [] + self._image_positions = [] + item_size = np.dtype(np.int32).itemsize + + for image_count in image_counts: + self._image_sizes.append( + np.frombuffer( + self._index_bin_buffer, + dtype=np.int32, + count=image_count * 2, + offset=offset, + ).reshape(-1, 2) + ) + total_pixels += self._image_sizes[-1].prod(axis=1, initial=3).sum() + offset += 2 * image_count * item_size + + for image_count in image_counts: + self._image_positions.append( + np.frombuffer(self._index_bin_buffer, dtype=np.int32, count=image_count, offset=offset) + ) + offset 
+= image_count * item_size + return total_pixels, offset + def __getstate__(self) -> tuple[str, pathlib.Path, int | None, int | None]: - return (self._name, self._prefix, self._num_documents, self._num_tokens) + return (self._name, self._prefix, self._num_documents, self._num_tokens, self._num_pixels) def __setstate__(self, state: tuple[str, pathlib.Path, int | None, int | None]): self._init(*state) @@ -156,57 +192,77 @@ def get( count=self._document_sizes[idx] - offset if length is None else length, offset=self._pointers[idx] + offset * np.dtype(self._dtype).itemsize, ) - sample_spans = None - if use_loss_masking_spans and self._spans is not None: - sample_spans = self._spans[idx] - - # filter spans that are outside the range of the selected tokens in the document - sample_spans = sample_spans[ - (sample_spans[:, 0] < offset + len(token_ids)) & (sample_spans[:, 1] >= offset) - ] - - # subtract by offset to normalize span boundaries - sample_spans[:, 0] = np.maximum(sample_spans[:, 0], offset) - offset # offset - sample_spans[:, 1] = np.minimum(sample_spans[:, 1], offset + len(token_ids) - 1) - offset - - chosen_span = None - rejected_span = None - - if use_preference_loss_spans: - if not self._has_preference_spans: - raise ValueError("No preference spans found in memmap dataset.") - elif self._has_preference_spans and self._chosen_spans is None: - raise ValueError("Failed to read chosen spans from memmap dataset.") - elif self._has_preference_spans and self._rejected_spans is None: - raise ValueError("Failed to read rejected spans from memmap dataset.") - else: - chosen_span = self._chosen_spans[idx] - - # filter spans that are outside the range of the selected tokens in the document - chosen_span = chosen_span[(chosen_span[0] < offset + len(token_ids)) & (chosen_span[1] >= offset)][0] - - # subtract by offset to normalize span boundaries - chosen_span[0] = np.maximum(chosen_span[0], offset) - offset # offset - chosen_span[1] = np.minimum(chosen_span[1], offset + len(token_ids) - 1) - offset - - rejected_span = self._rejected_spans[idx] - - # filter spans that are outside the range of the selected tokens in the document - rejected_span = rejected_span[ - (rejected_span[0] < offset + len(token_ids)) & (rejected_span[1] >= offset) - ][0] - - # subtract by offset to normalize span boundaries - rejected_span[0] = np.maximum(rejected_span[0], offset) - offset # offset - rejected_span[1] = np.minimum(rejected_span[1], offset + len(token_ids) - 1) - offset + + loss_masking_spans = self._get_loss_masking_spans(idx, offset, token_ids) + chosen_span, rejected_span = ( + self._get_preference_spans(idx, offset, token_ids) if use_preference_loss_spans else (None, None) + ) + images, image_positions = self._get_images(idx) return GPTSample( token_ids=token_ids, - loss_masking_spans=sample_spans, + images=images, + image_positions=image_positions, + loss_masking_spans=loss_masking_spans, chosen_span=chosen_span, rejected_span=rejected_span, ) + def _get_loss_masking_spans(self, idx: int, offset: int, token_ids: np.ndarray) -> np.ndarray | None: + if not self._has_spans: + return None + loss_masking_spans = self._spans[idx] + + # filter spans that are outside the range of the selected tokens in the document + loss_masking_spans = loss_masking_spans[ + (loss_masking_spans[:, 0] < offset + len(token_ids)) & (loss_masking_spans[:, 1] >= offset) + ] + + # subtract by offset to normalize span boundaries + loss_masking_spans[:, 0] = np.maximum(loss_masking_spans[:, 0], offset) - offset # offset + 
loss_masking_spans[:, 1] = np.minimum(loss_masking_spans[:, 1], offset + len(token_ids) - 1) - offset + return loss_masking_spans + + def _get_preference_spans(self, idx: int, offset: int, token_ids: np.ndarray) -> tuple[np.ndarray, np.ndarray]: + if not self._has_preference_spans: + raise ValueError(f"Dataset {self.name} doesn't have preference spans.") + chosen_span = self._chosen_spans[idx] + + # filter spans that are outside the range of the selected tokens in the document + chosen_span = chosen_span[(chosen_span[0] < offset + len(token_ids)) & (chosen_span[1] >= offset)][0] + + # subtract by offset to normalize span boundaries + chosen_span[0] = np.maximum(chosen_span[0], offset) - offset # offset + chosen_span[1] = np.minimum(chosen_span[1], offset + len(token_ids) - 1) - offset + + rejected_span = self._rejected_spans[idx] + + # filter spans that are outside the range of the selected tokens in the document + rejected_span = rejected_span[(rejected_span[0] < offset + len(token_ids)) & (rejected_span[1] >= offset)][0] + + # subtract by offset to normalize span boundaries + rejected_span[0] = np.maximum(rejected_span[0], offset) - offset # offset + rejected_span[1] = np.minimum(rejected_span[1], offset + len(token_ids) - 1) - offset + return chosen_span, rejected_span + + def _get_images(self, idx: int) -> tuple[list[np.ndarray] | None, np.ndarray | None]: + if not self._has_images: + return None, None + # Truncations with images are not yet supported, so we get all images from the document + pixels = np.frombuffer( + self._bin_buffer, + dtype=np.dtype(np.uint8), + count=self._image_sizes[idx].prod(initial=3, axis=1).sum(), + offset=self._pointers[idx] + self._document_sizes[idx] * np.dtype(self._dtype).itemsize, + ) + images = [] + start = 0 + for image_size in self._image_sizes[idx]: + n_pixels = image_size.prod(initial=3) + images.append(pixels[start : start + n_pixels].reshape(3, image_size[0], image_size[1])) + start += n_pixels + return images, self._image_positions[idx] + @property def name(self) -> str: return self._name @@ -218,6 +274,10 @@ def __len__(self) -> int: def num_tokens(self) -> int: return self._num_tokens + @property + def has_images(self) -> bool: + return self._has_images + def get_document_sizes(self) -> np.ndarray: """ The size of each document in the dataset. @@ -226,15 +286,25 @@ def get_document_sizes(self) -> np.ndarray: """ return self._document_sizes + def get_image_sizes(self) -> list[np.ndarray]: + return self._image_sizes if self._has_images else [np.array([])] * self._num_documents + def get_document_size(self, index: int) -> int: return self._document_sizes[index].item() + def get_image_size(self, index: int) -> np.ndarray: + return self._image_sizes[index] if self._has_images else [] + @classmethod def write_dataset(cls, prefix: pathlib.Path | str, documents: typing.Iterable[GPTSample]): # Initialize metadata dtype = None num_documents = 0 - lengths = [] + doc_lengths = [] + n_images = [] + image_sizes = [] + image_positions = [] + has_images = False pointers = [] offset = 0 # number of spans for each document @@ -253,17 +323,16 @@ def write_dataset(cls, prefix: pathlib.Path | str, documents: typing.Iterable[GP if dtype is None: dtype = document.token_ids.dtype assert dtype is not None, "Document dtype could not be inferred from the data." - # Ensure all documents have the same dtype assert document.token_ids.dtype == dtype, f"Expected dtype {dtype}, got {document.token_ids.dtype}." 
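Aside: the `.bin` layout that `_get_images` reads back, and that the writer below produces, stores each document's token ids followed by that document's raw RGB pixels as `uint8` in CHW order. A stand-alone sketch of the round-trip (variable names are illustrative, not attributes of the dataset class):

```python
# Sketch of the per-document .bin layout: token bytes first, then raw CHW uint8 pixels.
import numpy as np

doc_tokens = np.array([5, 6, 7], dtype=np.int32)
doc_images = [np.random.randint(0, 256, size=(3, 4, 6), dtype=np.uint8)]  # one 4x6 RGB image, CHW

blob = doc_tokens.tobytes(order="C") + b"".join(image.tobytes(order="C") for image in doc_images)

# Reading mirrors `_get_images`: skip the token bytes, then slice out each image.
pixel_offset = doc_tokens.size * doc_tokens.dtype.itemsize
pixels = np.frombuffer(blob, dtype=np.uint8, offset=pixel_offset)
assert np.array_equal(pixels[: 3 * 4 * 6].reshape(3, 4, 6), doc_images[0])
```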
+ pointers.append(offset) + doc_lengths.append(doc_length := len(document.token_ids)) + # Write document to binary file bin_stream.write(document.token_ids.tobytes(order="C")) + offset += doc_length * np.dtype(dtype).itemsize - # Update metadata - doc_length = len(document.token_ids) - lengths.append(doc_length) - pointers.append(offset) if document.loss_masking_spans is not None: num_spans.append(len(document.loss_masking_spans)) spans.append(document.loss_masking_spans) @@ -271,48 +340,88 @@ def write_dataset(cls, prefix: pathlib.Path | str, documents: typing.Iterable[GP chosen_spans.append(document.chosen_span) if document.rejected_span is not None: rejected_spans.append(document.rejected_span) - offset += doc_length * np.dtype(dtype).itemsize + + if document.images is not None: + n_images.append(len(document.images)) + has_images = True + for image in document.images: + # assume 3 channels (RGB) for all images + with PIL.Image.open(io.BytesIO(image["bytes"])) as img: + if img.mode != "RGB": + # Convert all images to RGB + img = img.convert("RGB") + pixels = np.array(img).transpose(2, 0, 1) # HWC to CHW + assert pixels.dtype == np.uint8, f"Expected uint8 pixels, got {pixels.dtype}." + image_sizes.append(np.array(pixels.shape[1:])) + bin_stream.write(pixels.tobytes(order="C")) + offset += pixels.size * np.dtype(np.uint8).itemsize + image_positions.extend(document.image_positions) + else: + n_images.append(0) + num_documents += 1 # Finalize metadata arrays - lengths = np.array(lengths, dtype=np.int32) + doc_lengths = np.array(doc_lengths, dtype=np.int32) pointers = np.array(pointers, dtype=np.int64) - num_spans = np.array(num_spans, dtype=np.int32) - if len(spans) > 0: + + assert len(spans) == len(num_spans) + if has_loss_masking_spans := len(spans) > 0: + assert len(spans) == num_documents + num_spans = np.array(num_spans, dtype=np.int32) spans = np.vstack(spans, dtype=np.int32) - else: - spans = np.array(spans, dtype=np.int32) - chosen_spans = np.array(chosen_spans, dtype=np.int32).reshape(-1, 2) - rejected_spans = np.array(rejected_spans, dtype=np.int32).reshape(-1, 2) + + assert len(chosen_spans) == len(rejected_spans) + if has_preference_spans := len(chosen_spans) > 0: + assert len(chosen_spans) == num_documents + chosen_spans = np.array(chosen_spans, dtype=np.int32).reshape(-1, 2) + rejected_spans = np.array(rejected_spans, dtype=np.int32).reshape(-1, 2) + + if has_images: + n_images = np.array(n_images, dtype=np.int32) + image_sizes = np.stack(image_sizes, dtype=np.int32) + image_positions = np.array(image_positions, dtype=np.int32) # Write the index file (.idx) with prefix.with_suffix(".idx").open("wb") as idx_stream: idx_stream.write(MEMMAP_INDEX_HEADER) # Indicates the version - # Version 2 optionally adds loss-masking spans - # Version 3 optionally adds chosen/rejected spans - idx_stream.write(struct.pack(" 0 else 0)) + idx_stream.write(struct.pack(" 0 and rejected_spans.size > 0 else 0)) + idx_stream.write(struct.pack(" None: """ Create a `GPTSampledDataset` with the requested parameters. """ - # Get the document sizes, the main information needed for sampling. + # Get the size each document, the main information needed for sampling. + # Note: "document" may refer to more than just text. 
+ document_sizes = self._get_document_sizes() + + documents_per_epoch, tokens_per_epoch, long_docs_filter = self._get_epoch_size(document_sizes) + num_epochs, shuffled_epochs = self._get_epoch_count(documents_per_epoch, tokens_per_epoch) + + shuffled_documents = documents_per_epoch * shuffled_epochs + unshuffled_epochs = num_epochs - shuffled_epochs + + yaml_data, cached = self._get_and_compare_yaml_data(documents_per_epoch, tokens_per_epoch, unshuffled_epochs) + if cached: + return + + if shuffled_documents > 1e8: + warnings.warn( + f"Shuffling {shuffled_documents:.2e} documents for dataset {self._indexed_dataset.name}." + f" This may take a while and/or use an excessive amount of memory." + ) + elif documents_per_epoch > 1e8: + # TODO: Most of the damage is already done in `get_document_sizes`. Find a way to warn earlier? + warnings.warn( + f"The dataset {self._indexed_dataset.name} contains {documents_per_epoch:.2e} documents." + f" Sampling may take a while and/or use an excessive amount of memory." + ) + + document_shuffling = self._get_document_shuffling(documents_per_epoch, shuffled_documents, shuffled_epochs) + + if self._parameters.use_preference_loss_spans: + # index of all documents less than seq length long + self._doc_length_filtered_indices.save(torch.nonzero(long_docs_filter, as_tuple=True)[0].numpy(force=True)) + self._document_sizes.save(document_sizes.numpy(force=True)) + if shuffled_epochs > 0: + self._document_shuffling.save(document_shuffling[: self._parameters.num_samples].numpy(force=True)) + unshuffled_tokens = 0 + + else: + + # To get a sample on the fly we need to know where it begins, + # and this is a non-trivial information because the documents have variable length. + # The starting point `(document[idx], token[idx])` corresponds to the `(idx * sequence_length)` th token, i.e. + # `document_sizes[all_document_index][:document[idx]].sum() + token[idx] == idx * sequence_length`. + # This can be computed quickly provided we know a (partial) sum close to `(idx * sequence_length)`. + # So it is enough to pre-compute the (zero-padded) token cumsum at regular intervals `TOKEN_CUMSUM_RATE`. + # Using `TOKEN_CUMSUM_RATE > 1` reduces pre-computation overhead at the cost of runtime computation. + # Equivalent to `torch.hstack((0, document_sizes[all_document_index].cumsum()[::TOKEN_CUMSUM_RATE]))` + + # TODO: Allowing for max 100% extra tokens for padding, is that enough? 
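Aside: the comment above only states the idea; a small stand-alone sketch shows how such a coarse cumsum can locate a sample start. It simplifies the real bookkeeping (the class persists the arrays via `self._token_cumsum_unshuffled` / `self._token_cumsum_shuffled`), and all names below are illustrative:

```python
# Coarse token cumsum: one zero-padded entry every TOKEN_CUMSUM_RATE documents.
import numpy as np

TOKEN_CUMSUM_RATE = 4
document_sizes = np.array([3, 5, 2, 7, 4, 6, 1, 8, 2, 3], dtype=np.int64)
coarse_cumsum = np.concatenate(([0], document_sizes.cumsum()))[::TOKEN_CUMSUM_RATE]

def locate(token_index: int) -> tuple[int, int]:
    # Binary-search the coarse array, then scan at most TOKEN_CUMSUM_RATE documents.
    block = int(np.searchsorted(coarse_cumsum, token_index, side="right")) - 1
    document, count = block * TOKEN_CUMSUM_RATE, int(coarse_cumsum[block])
    while count + document_sizes[document] <= token_index:
        count += document_sizes[document]
        document += 1
    return document, int(token_index - count)  # (document index, offset within it)

assert locate(0) == (0, 0)
assert locate(9) == (2, 1)   # 3 + 5 = 8 tokens precede document 2
assert locate(40) == (9, 2)  # last token of the epoch
```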
+ cumsum_dtype = get_unsigned_integer_type((2 - self._truncate_documents) * tokens_per_epoch * num_epochs) + if unshuffled_epochs > 0: + token_cumsum_unshuffled, unshuffled_tokens = self._get_token_cumsum(document_sizes, 0, cumsum_dtype) + self._token_cumsum_unshuffled.save(token_cumsum_unshuffled) + else: + unshuffled_tokens = 0 + + if shuffled_epochs > 0: + token_cumsum_shuffled, _ = self._get_token_cumsum( + document_sizes[ + # Torch indexing only works with int32 or int64 + document_shuffling.to( + dtype=torch.int64 if document_shuffling.dtype == torch.int64 else torch.int32 + ) + ], + self._unshuffled_tokens, + cumsum_dtype, + ) + self._token_cumsum_shuffled.save(token_cumsum_shuffled) + self._document_shuffling.save( + document_shuffling[: (token_cumsum_shuffled.size + 1) * TOKEN_CUMSUM_RATE].numpy(force=True) + ) + + yaml_data["unshuffled_tokens"] = unshuffled_tokens + self._load_yaml_data(yaml_data) + if self._yaml_path is not None: + self._yaml_path.parent.mkdir(parents=True, exist_ok=True) + yaml.safe_dump(yaml_data, self._yaml_path.open("w")) + + def _get_document_sizes(self) -> torch.Tensor: document_sizes = torch.from_numpy(self._indexed_dataset.get_document_sizes()).to(self._device) - documents_per_epoch = document_sizes.numel() - tokens_per_epoch = document_sizes.sum().item() + if self._indexed_dataset.has_images: + image_sizes = self._indexed_dataset.get_image_sizes() + image_token_sizes = [] + for i, sizes in enumerate(image_sizes): + image_token_sizes.append( + sum( + get_num_image_tokens( + *get_resize_dims( + *size, + self._parameters.max_image_size, + self._parameters.max_image_size, + self._parameters.patch_size, + ), + self._parameters.patch_size, + image_break=self._parameters.image_break_token is not None, + image_end=self._parameters.image_end_token is not None, + ) + for size in sizes + ) + ) + document_sizes += torch.tensor(image_token_sizes).to(self._device) + return document_sizes - # Calculate basic stats. - if not self._truncate_documents: + def _get_epoch_size(self, document_sizes: torch.Tensor) -> tuple[int, int, torch.Tensor | None]: + documents_per_epoch = document_sizes.numel() + if self._truncate_documents: + tokens_per_epoch = document_sizes.sum().item() + long_docs_filter = None + else: assert _extension_available, ( "The C++ extension for dataset sampling is missing." " Please make sure Fast-LLM is installed correctly." ) - long_docs_filter = document_sizes > self._parameters.sequence_length + 1 - ignored_documents = long_docs_filter.sum().item() - if ignored_documents: + long_docs_filter = document_sizes <= self._parameters.sequence_length + 1 + documents_per_epoch_filtered = long_docs_filter.sum().item() + if ignored_documents := documents_per_epoch_filtered - documents_per_epoch: log_main_rank( - f" > {ignored_documents}/{documents_per_epoch} documents are longer than {self._parameters.sequence_length+1} tokens and will be ignored.", + f" > {ignored_documents}/{documents_per_epoch} documents" + f" are longer than {self._parameters.sequence_length+1} tokens and will be ignored.", log_fn=logger.warning, ) - tokens_per_epoch = document_sizes[~long_docs_filter].sum().item() + # TODO: WHY?!?!?!? + if self._parameters.use_preference_loss_spans: + documents_per_epoch = documents_per_epoch_filtered + tokens_per_epoch = document_sizes[long_docs_filter].sum().item() if tokens_per_epoch == 0: raise RuntimeError( - f" > No documents shorter than {self._parameters.sequence_length+1} tokens found in dataset {self._indexed_dataset.name}." 
+ f" > No documents shorter than {self._parameters.sequence_length+1}" + f" tokens found in dataset {self._indexed_dataset.name}." ) + return documents_per_epoch, tokens_per_epoch, long_docs_filter + def _get_epoch_count(self, documents_per_epoch: int, tokens_per_epoch: int) -> tuple[int, int]: # We produce sequences of length `self._sequence_length + extra_tokens` so the last token has a label for all prediction heads, # but in case of truncations we also include those last labels in the following sample, # so we need `sequence_length * num_samples + extra_tokens` tokens in total. if self._parameters.use_preference_loss_spans: - documents_per_epoch = (~long_docs_filter).sum().item() num_epochs = math.ceil(self._parameters.num_samples / documents_per_epoch) elif self._truncate_documents: num_epochs = math.ceil( @@ -174,32 +293,34 @@ def _sample(self) -> None: ) # Prepare for shuffling. - generator = torch.Generator(device=self._device) if self._config.shuffle == ShufflingType.skip_first_epoch: shuffled_epochs = num_epochs - 1 elif self._config.shuffle == ShufflingType.disabled: shuffled_epochs = 0 else: shuffled_epochs = num_epochs - shuffled_documents = documents_per_epoch * shuffled_epochs - unshuffled_epochs = num_epochs - shuffled_epochs + return num_epochs, shuffled_epochs + def _get_and_compare_yaml_data( + self, + documents_per_epoch: int, + tokens_per_epoch: int, + unshuffled_epochs: int, + ) -> tuple[dict[str, typing.Any], bool]: yaml_data = { "dataset": { "name": self._indexed_dataset.name, "documents_per_epoch": documents_per_epoch, "tokens_per_epoch": tokens_per_epoch, }, - "num_samples": self._parameters.num_samples, + "sampling": self._parameters.__dict__, "unshuffled_epochs": unshuffled_epochs, - "sequence_length": self._parameters.sequence_length, - "truncate_documents": self._truncate_documents, "config": self._config.to_dict(), } if self._truncate_documents: yaml_data["unshuffled_tokens"] = tokens_per_epoch * unshuffled_epochs - if self._yaml_path is not None and self._yaml_path.is_file(): + if cached := (self._yaml_path is not None and self._yaml_path.is_file()): loaded_yaml_data = yaml.safe_load(self._yaml_path.open("r")) # Hack to make sure unshuffled tokens are loaded if not self._truncate_documents: @@ -216,120 +337,8 @@ def _sample(self) -> None: ) # Dataset is already sampled, skip. logger.info(f"Using existing sampling for dataset {self.name}") - return - - if shuffled_documents > 1e8: - warnings.warn( - f"Shuffling {shuffled_documents:.2e} documents for dataset {self._indexed_dataset.name}." - f" This may take a while and/or use an excessive amount of memory." - ) - elif documents_per_epoch > 1e8: - # TODO: Most of the damage is already done in `get_document_sizes`. Find a way to warn earlier? - warnings.warn( - f"The dataset {self._indexed_dataset.name} contains {documents_per_epoch:.2e} documents." - f" Sampling may take a while and/or use an excessive amount of memory." - ) - - # Use the smallest possible data type to save memory and disk usage. - document_shuffling_dtype = get_unsigned_integer_type(documents_per_epoch).torch - # Shuffle the dataset (documents) - # This generates a document shuffling index `all_document_index`, the unshuffled part is trivial - # so we only evaluate and store the shuffled part `document_shuffling`. 
- if self._config.shuffle == ShufflingType.full: - generator.manual_seed(self._config.seed) - # Equivalent to `shuffle(range(documents_per_epoch * num_epochs)) % documents_per_epoch` - document_shuffling = ( - torch.randperm( - shuffled_documents, - generator=generator, - dtype=get_unsigned_integer_type(shuffled_documents).torch, - device=self._device, - ) - .remainder_(documents_per_epoch) - .to(dtype=document_shuffling_dtype) - ) - elif self._config.shuffle in (ShufflingType.skip_first_epoch, ShufflingType.epoch): - document_shuffling = torch.empty( - shuffled_documents, - dtype=document_shuffling_dtype, - device=self._device, - ) - for i in range(shuffled_epochs): - generator.manual_seed(self._config.seed + i * 571) - torch.randperm( - documents_per_epoch, - generator=generator, - out=document_shuffling[i * documents_per_epoch : (i + 1) * documents_per_epoch], - ) - elif self._config.shuffle == ShufflingType.disabled: - document_shuffling = None - else: - raise NotImplementedError(f"Unknown shuffling type: {self._config.shuffle}") - - if self._parameters.use_preference_loss_spans: - yaml_data["unshuffled_tokens"] = 0 # not used, ignore - # index of all documents less than seq length long - doc_length_filtered_indicies = torch.nonzero(~long_docs_filter, as_tuple=True)[0] - self._doc_length_filtered_indicies.save(doc_length_filtered_indicies.numpy(force=self._config.gpu)) - - # apply shuffling on doc_length_filtered_indicies - if shuffled_epochs > 0: - self._document_shuffling.save( - document_shuffling[: self._parameters.num_samples].numpy(force=self._config.gpu) - ) - self._document_sizes.save(document_sizes.numpy(force=self._config.gpu)) - if self._yaml_path is not None: - self._yaml_path.parent.mkdir(parents=True, exist_ok=True) - yaml.safe_dump(yaml_data, self._yaml_path.open("w")) - return - - # To get a sample on the fly we need to know where it begins, - # and this is a non-trivial information because the documents have variable length. - # The starting point `(document[idx], token[idx])` corresponds to the `(idx * sequence_length)` th token, i.e. - # `document_sizes[all_document_index][:document[idx]].sum() + token[idx] == idx * sequence_length`. - # This can be computed quickly provided we know a (partial) sum close to `(idx * sequence_length)`. - # So it is enough to pre-compute the (zero-padded) token cumsum at regular intervals `TOKEN_CUMSUM_RATE`. - # Using `TOKEN_CUMSUM_RATE > 1` reduces pre-computation overhead at the cost of runtime computation. - # Equivalent to `torch.hstack((0, document_sizes[all_document_index].cumsum()[::TOKEN_CUMSUM_RATE]))` - if unshuffled_epochs > 0: - token_cumsum_unshuffled, unshuffled_tokens = self._get_token_cumsum( - document_sizes, - offset=0, - # TODO: Allowing for max 100% extra tokens for padding, is that enough? 
- dtype=get_unsigned_integer_type((2 - self._truncate_documents) * tokens_per_epoch * num_epochs), - ) - self._token_cumsum_unshuffled.save(token_cumsum_unshuffled) - else: - unshuffled_tokens = 0 - - if not self._truncate_documents: - yaml_data["unshuffled_tokens"] = unshuffled_tokens - self._load_yaml_data(yaml_data) - if self._yaml_path is not None: - self._yaml_path.parent.mkdir(parents=True, exist_ok=True) - yaml.safe_dump(yaml_data, self._yaml_path.open("w")) - - if shuffled_epochs > 0: - token_cumsum_shuffled, _ = self._get_token_cumsum( - document_sizes[ - # Torch indexing only works with int32 or int64 - document_shuffling.to( - dtype=torch.int64 if document_shuffling.dtype == torch.int64 else torch.int32 - ) - ], - offset=self._unshuffled_tokens, - # TODO: Allowing for max 100% extra tokens for padding, is that enough? - dtype=get_unsigned_integer_type((2 - self._truncate_documents) * tokens_per_epoch * num_epochs), - ) - self._token_cumsum_shuffled.save(token_cumsum_shuffled) - self._document_shuffling.save( - document_shuffling[: (token_cumsum_shuffled.size + 1) * TOKEN_CUMSUM_RATE].numpy( - force=self._config.gpu - ) - ) - # Free memory - del document_shuffling + return yaml_data, cached def _get_token_cumsum(self, sizes: torch.Tensor, offset: int, dtype: DataType) -> tuple[np.ndarray, int | None]: if self._truncate_documents: @@ -372,6 +381,50 @@ def _get_token_cumsum(self, sizes: torch.Tensor, offset: int, dtype: DataType) - ] return out, num_tokens + def _get_document_shuffling( + self, + documents_per_epoch: int, + shuffled_documents: int, + shuffled_epochs: int, + ) -> torch.Tensor | None: + generator = torch.Generator(device=self._device) + # Use the smallest possible data type to save memory and disk usage. + document_shuffling_dtype = get_unsigned_integer_type(documents_per_epoch).torch + # Shuffle the dataset (documents) + # This generates a document shuffling index `all_document_index`, the unshuffled part is trivial + # so we only evaluate and store the shuffled part `document_shuffling`. 
+ if self._config.shuffle == ShufflingType.full: + generator.manual_seed(self._config.seed) + # Equivalent to `shuffle(range(documents_per_epoch * num_epochs)) % documents_per_epoch` + document_shuffling = ( + torch.randperm( + shuffled_documents, + generator=generator, + dtype=get_unsigned_integer_type(shuffled_documents).torch, + device=self._device, + ) + .remainder_(documents_per_epoch) + .to(dtype=document_shuffling_dtype) + ) + elif self._config.shuffle in (ShufflingType.skip_first_epoch, ShufflingType.epoch): + document_shuffling = torch.empty( + shuffled_documents, + dtype=document_shuffling_dtype, + device=self._device, + ) + for i in range(shuffled_epochs): + generator.manual_seed(self._config.seed + i * 571) + torch.randperm( + documents_per_epoch, + generator=generator, + out=document_shuffling[i * documents_per_epoch : (i + 1) * documents_per_epoch], + ) + elif self._config.shuffle == ShufflingType.disabled: + document_shuffling = None + else: + raise NotImplementedError(f"Unknown shuffling type: {self._config.shuffle}") + return document_shuffling + def __len__(self) -> int: return self._parameters.num_samples @@ -384,37 +437,7 @@ def __getitem__(self, index: int) -> typing.Any: self._lazy_load() if self._parameters.use_preference_loss_spans: - if index < self._unshuffled_documents: - document_index = self._doc_length_filtered_indicies[index % self._documents_per_epoch] - else: - document_index = self._doc_length_filtered_indicies[ - self._document_shuffling[index - self._unshuffled_documents].item() - ] - - sample = self._indexed_dataset.get( - document_index, - offset=0, - length=self._document_sizes[document_index], - use_loss_masking_spans=self._parameters.use_loss_masking_spans, - use_preference_loss_spans=self._parameters.use_preference_loss_spans, - ) - - chosen_span_end = sample.chosen_span[1] + 1 - sequence_lengths = [ - chosen_span_end, - len(sample.token_ids) - chosen_span_end, - ] - - # compute padding size - padding = np.full((self._parameters.sequence_length + 1,), 0) - padding[: len(sample.token_ids)] = sample.token_ids - sequence_lengths.append(self._parameters.sequence_length - len(sample.token_ids)) - sample.token_ids = padding - - if not self._parameters.cross_document_attention: - sample.sequence_lengths = np.array(sequence_lengths) - - return sample + return self._get_preference_loss_span_sample(index) # tokens at the boundary are included in only one sample when we pack without truncations # in case of packing with truncations, the last token from the previous sample is also the first token of the next sample @@ -441,7 +464,13 @@ def __getitem__(self, index: int) -> typing.Any: token_count = token_start_array[token_start_cumsum_index] token_ids = [] - loss_masking_spans = [] + if self._parameters.use_loss_masking_spans: + loss_masking_spans = [] + if self._indexed_dataset.has_images: + images = [] + image_positions = [] + image_tokens_added = 0 + text_tokens_added = 0 while token_count < token_end: # Find the document index in the dataset. 
if document_sampling_index < self._unshuffled_documents: @@ -449,7 +478,32 @@ def __getitem__(self, index: int) -> typing.Any: else: document_index = self._document_shuffling[document_sampling_index - self._unshuffled_documents].item() - document_size = self._indexed_dataset.get_document_size(document_index) + text_size = self._indexed_dataset.get_document_size(document_index) + if self._indexed_dataset.has_images: + image_lengths = self._indexed_dataset.get_image_size(document_index) + + resized_image_lengths = [ + get_resize_dims( + *image_length, + self._parameters.max_image_size, + self._parameters.max_image_size, + self._parameters.patch_size, + ) + for image_length in image_lengths + ] + image_sizes = [ + get_num_image_tokens( + *image_length, + self._parameters.patch_size, + image_break=self._parameters.image_break_token is not None, + image_end=self._parameters.image_end_token is not None, + ) + for image_length in resized_image_lengths + ] + image_tokens = sum(image_sizes) + document_size = text_size + image_tokens + else: + document_size = text_size if not self._truncate_documents: if document_size > self._parameters.sequence_length + 1: @@ -468,21 +522,96 @@ def __getitem__(self, index: int) -> typing.Any: else: # Move on to the next sample. token_count += padding_size + continue + elif document_size + tokens_in_sample == self._parameters.sequence_length + 1: + if token_count + document_size == token_start: + token_count += document_size + document_sampling_index += 1 + continue # Determine if the document belongs to the requested sample. if token_count + document_size > token_start: # Determine which part of the document belong to the sample, and add it to the list. token_start_index_in_document = max(token_start - token_count, 0) - token_end_index_in_document = min(token_end - token_count, document_size) + token_end_index_in_document = min(token_end - token_count, text_size) sample = self._indexed_dataset.get( document_index, offset=token_start_index_in_document, length=token_end_index_in_document - token_start_index_in_document, use_loss_masking_spans=self._parameters.use_loss_masking_spans, ) - token_ids.append(sample.token_ids) + if self._indexed_dataset.has_images: + start_pos = 0 + sample_token_ids = [] + for idx, im_position in enumerate(sample.image_positions): + # add placeholder masked tokens for images + # if image_break_token is set, it is appended after every row + # if image_end_token is set, it is appended at the end of the image instead of image_break_token + text_part = sample.token_ids[start_pos:im_position] + if self._parameters.image_break_token is not None: + height, width = resized_image_lengths[idx] + num_patches_h = div(height, self._parameters.patch_size) + num_patches_w = div(width, self._parameters.patch_size) + image_token_array = np.full((image_sizes[idx],), -100, dtype=np.int64) + # account for break tokens after each row + for row in range(num_patches_h - 1): + position = (row + 1) * num_patches_w + row + image_token_array[position] = self._parameters.image_break_token + # handle the last row separately + last_row_position = num_patches_h * num_patches_w + num_patches_h - 1 + if self._parameters.image_end_token is not None: + image_token_array[last_row_position] = self._parameters.image_end_token + else: + image_token_array[last_row_position] = self._parameters.image_break_token + else: + image_token_array = np.full((image_sizes[idx],), -100, dtype=np.int64) + if self._parameters.image_end_token is not None: + image_token_array[-1] = 
self._parameters.image_end_token + sample_token_ids.append(np.concatenate([text_part, image_token_array], dtype=np.int64)) + text_tokens_added += len(text_part) + image_positions.append(text_tokens_added + image_tokens_added) + image_tokens_added += image_sizes[idx] + start_pos = im_position + # Add the last text segment after the last image + sample_token_ids.append(sample.token_ids[start_pos:]) + text_tokens_added += len(sample_token_ids[-1]) + token_ids.append(np.concatenate(sample_token_ids)) + images.append(sample.images) + else: + token_ids.append(sample.token_ids) + text_tokens_added += len(token_ids[-1]) if self._parameters.use_loss_masking_spans: for loss_masking_span in sample.loss_masking_spans: + if self._indexed_dataset.has_images: + prev_image_tokens = 0 + image_idx = 0 + image_position = ( + sample.image_positions[image_idx] + if image_idx < len(sample.image_positions) + else float("inf") + ) + while image_position < loss_masking_span[0]: + prev_image_tokens += image_sizes[image_idx] + image_idx += 1 + image_position = ( + sample.image_positions[image_idx] + if image_idx < len(sample.image_positions) + else float("inf") + ) + span_image_tokens = 0 + while image_position <= loss_masking_span[1]: + span_image_tokens += image_sizes[image_idx] + image_idx += 1 + image_position = ( + sample.image_positions[image_idx] + if image_idx < len(sample.image_positions) + else float("inf") + ) + loss_masking_span[0] += prev_image_tokens + loss_masking_span[1] += prev_image_tokens + span_image_tokens + # TODO: Unused, meant to be inside loop? What about 2 lines above? + prev_image_tokens += span_image_tokens + span = np.clip( loss_masking_span + token_count - token_start, 0, @@ -506,9 +635,50 @@ def __getitem__(self, index: int) -> typing.Any: if self._parameters.use_loss_masking_spans else None ) + images = [im for img_list in images for im in img_list] if self._indexed_dataset.has_images else None + image_positions = np.array(image_positions) if self._indexed_dataset.has_images else None Assert.eq(len(token_ids), self._parameters.sequence_length + self._parameters.extra_tokens) - return GPTSample(token_ids=token_ids, loss_masking_spans=loss_masking_spans, sequence_lengths=sequence_lengths) + return GPTSample( + token_ids=token_ids, + loss_masking_spans=loss_masking_spans, + sequence_lengths=sequence_lengths, + images=images, + image_positions=image_positions, + ) + + def _get_preference_loss_span_sample(self, index: int): + if index < self._unshuffled_documents: + document_index = self._doc_length_filtered_indices[index % self._documents_per_epoch] + else: + document_index = self._doc_length_filtered_indices[ + self._document_shuffling[index - self._unshuffled_documents].item() + ] + + sample = self._indexed_dataset.get( + document_index, + offset=0, + length=self._document_sizes[document_index], + use_loss_masking_spans=self._parameters.use_loss_masking_spans, + use_preference_loss_spans=self._parameters.use_preference_loss_spans, + ) + + chosen_span_end = sample.chosen_span[1] + 1 + sequence_lengths = [ + chosen_span_end, + len(sample.token_ids) - chosen_span_end, + ] + + # compute padding size + padding = np.full((self._parameters.sequence_length + 1,), 0) + padding[: len(sample.token_ids)] = sample.token_ids + sequence_lengths.append(self._parameters.sequence_length - len(sample.token_ids)) + sample.token_ids = padding + + if not self._parameters.cross_document_attention: + sample.sequence_lengths = np.array(sequence_lengths) + + return sample @property def name(self) -> str: @@ 
-593,7 +763,7 @@ def _sample(self) -> None: Create a `GPTSampledDataset` with the requested parameters. """ logger.info(f" > Sampling dataset {self._indexed_dataset.name} ...") - document_sizes = self._indexed_dataset.get_document_sizes() + document_sizes, _ = self._indexed_dataset.get_document_sizes() num_documents = len(document_sizes) num_tokens = document_sizes.sum() np_rng = np.random.RandomState(seed=self._config.seed) diff --git a/fast_llm/data/preparator/gpt_memmap/config.py b/fast_llm/data/preparator/gpt_memmap/config.py index d2aaee5e2..da353793d 100644 --- a/fast_llm/data/preparator/gpt_memmap/config.py +++ b/fast_llm/data/preparator/gpt_memmap/config.py @@ -42,6 +42,18 @@ class TextColumnConfig(SourceSchemaConfig): ) +@config_class(dynamic_type={SourceSchemaConfig: "text_image_column"}) +class TextImageColumnConfig(TextColumnConfig): + images_column: str = Field( + default="images", + desc="Field containing images relevant to a document.", + ) + image_positions_column: None | str = Field( + default="image_positions", + desc="Field containing image positions within a document.", + ) + + @config_class() class GPTHuggingfaceDatasetConfig(Config): path: str = Field( @@ -175,6 +187,11 @@ class GPTMemmapDatasetPreparatorConfig(DatasetPreparatorConfig): desc="Configuration for the tokenizer.", hint=FieldHint.feature, ) + image_patch_size: int = Field( + default=16, + desc="Patch size for images. This is used solely for computing the number of tokens in an image to get an even split.", + hint=FieldHint.optional, + ) splits: dict[str, float] | None = Field( default=None, desc="Split the output dataset into multiple ones (ex, train/valid/test) with the specified ratios." diff --git a/fast_llm/data/preparator/gpt_memmap/prepare.py b/fast_llm/data/preparator/gpt_memmap/prepare.py index 427309a99..fce0f022c 100644 --- a/fast_llm/data/preparator/gpt_memmap/prepare.py +++ b/fast_llm/data/preparator/gpt_memmap/prepare.py @@ -1,3 +1,5 @@ +import io +import itertools import json import logging import multiprocessing @@ -8,6 +10,7 @@ import datasets import huggingface_hub import numpy as np +import PIL.Image import requests import torch.distributed import tqdm @@ -24,7 +27,11 @@ from fast_llm.data.dataset.gpt.memmap import GPTMemmapDataset from fast_llm.data.dataset.gpt.sampled import GPTSample from fast_llm.data.preparator.config import DatasetPreparator -from fast_llm.data.preparator.gpt_memmap.config import GPTMemmapDatasetPreparatorConfig, TextColumnConfig +from fast_llm.data.preparator.gpt_memmap.config import ( + GPTMemmapDatasetPreparatorConfig, + TextColumnConfig, + TextImageColumnConfig, +) from fast_llm.data.tokenizer import Tokenizer from fast_llm.engine.config_utils.data_type import DataType, get_unsigned_integer_type from fast_llm.utils import Assert, normalize_probabilities, padded_cumsum @@ -41,36 +48,44 @@ class GPTMemmapDatasetPreparator[ConfigType: GPTMemmapDatasetPreparatorConfig](D _loss_masking_spans_column: str | None def _tokenize_batch(self, batch: dict[str, list[typing.Any]]) -> dict[str, list[typing.Any]]: - input_ids = [ - np.array(self._tokenizer.tokenize(text), dtype=self._data_type.numpy) for text in batch[self._text_column] - ] - num_tokens = [len(x) for x in input_ids] - return { - "input_ids": input_ids, - "num_tokens": num_tokens, - } - - def _tokenize_batch_with_spans(self, batch: dict[str, list[typing.Any]]) -> dict[str, list[typing.Any]]: - input_ids, token_spans = map( + input_ids, token_spans, image_token_positions = map( list, zip( *[ ( np.array(input_ids, 
dtype=self._data_type.numpy), np.array(token_spans, dtype=np.int32).reshape(-1, 2), + np.array(image_token_positions, dtype=np.int32), ) - for input_ids, token_spans in [ - self._tokenizer.tokenize_with_spans(text, char_spans) - for text, char_spans in zip(batch[self._text_column], batch[self._loss_masking_spans_column]) + for input_ids, token_spans, image_token_positions in [ + self._tokenizer.tokenize( + text, + loss_mask_spans, + im_char_positions, + ) + for text, loss_mask_spans, im_char_positions in zip( + batch[self._text_column], + batch.get(self._loss_masking_spans_column, itertools.repeat(None)), + batch.get(self._image_positions_column, itertools.repeat(None)), + ) ] ] ), ) num_tokens = [len(x) for x in input_ids] + num_pixels = [0] * len(input_ids) + for idx, images in enumerate(batch.get("images", [])): + for bytes_im in images: + with PIL.Image.open(io.BytesIO(bytes_im["bytes"])) as im: + width, height = im.size + num_pixels[idx] += width * height * 3 + return { "input_ids": input_ids, + "image_positions": image_token_positions, "token_spans": token_spans, "num_tokens": num_tokens, + "num_pixels": num_pixels, } def _tokenize_preference_batch_with_spans(self, batch: dict[str, list[typing.Any]]) -> dict[str, list[typing.Any]]: @@ -143,27 +158,22 @@ def _save_shard(self, args: tuple[int, datasets.Dataset]) -> GPTMemmapDatasetCon shard_output_path = self._config.output_path / prefix def _document_generator(): - if "token_spans" in shard_dataset.column_names and self._loss_masking_spans_column is not None: - for item in tqdm.tqdm(shard_dataset, desc=f"Saving shard {shard_idx}", unit="docs"): - yield GPTSample( - np.array(item["input_ids"], dtype=self._data_type.numpy), - np.array(item["token_spans"], dtype=np.int32).reshape(-1, 2), - ) - elif ( - "chosen_token_spans" in shard_dataset.column_names - and "rejected_token_spans" in shard_dataset.column_names - and self._config.dataset.chosen_text is not None - and self._config.dataset.rejected_text is not None - ): - for item in tqdm.tqdm(shard_dataset, desc=f"Saving shard {shard_idx}", unit="docs"): - yield GPTSample( - token_ids=np.array(item["input_ids"], dtype=self._data_type.numpy), - chosen_span=np.array(item["chosen_token_spans"], dtype=np.int32).reshape(-1, 2), - rejected_span=np.array(item["rejected_token_spans"], dtype=np.int32).reshape(-1, 2), - ) - else: - for item in tqdm.tqdm(shard_dataset, desc=f"Saving shard {shard_idx}", unit="docs"): - yield GPTSample(np.array(item["input_ids"], dtype=self._data_type.numpy)) + has_preference_spans = ( + self._config.dataset.chosen_text is not None and self._config.dataset.rejected_text is not None + ) + for item in tqdm.tqdm(shard_dataset, desc=f"Saving shard {shard_idx}", unit="docs"): + yield GPTSample( + np.array(item["input_ids"], dtype=self._data_type.numpy), + item["images"] if self._images_column else None, + item["image_positions"] if self._image_positions_column else None, + ( + np.array(item["token_spans"], dtype=np.int32).reshape(-1, 2) + if self._loss_masking_spans_column + else None + ), + item["chosen_token_spans"] if has_preference_spans else None, + item["rejected_token_spans"] if has_preference_spans else None, + ) GPTMemmapDataset.write_dataset(prefix=shard_output_path, documents=_document_generator()) @@ -173,6 +183,7 @@ def _document_generator(): "path": prefix, "num_documents": len(shard_dataset), # Use the length of the shard dataset directly "num_tokens": sum(len(doc["input_ids"]) for doc in shard_dataset), + "num_pixels": sum(doc["num_pixels"] for doc in 
shard_dataset), } ) @@ -292,6 +303,11 @@ def run(self) -> None: if isinstance(self._config.dataset.source_schema, TextColumnConfig): self._text_column = self._config.dataset.source_schema.input_column self._loss_masking_spans_column = self._config.dataset.source_schema.loss_masking_spans_column + if isinstance(self._config.dataset.source_schema, TextImageColumnConfig): + self._images_column = self._config.dataset.source_schema.images_column + self._image_positions_column = self._config.dataset.source_schema.image_positions_column + # decoding bytes to images is slow and should be done only when needed + dataset = dataset.cast_column("images", datasets.Sequence(datasets.Image(decode=False))) else: raise ValueError( f"Dataset source_schema set incorrectly. source_schema: '{self._config.dataset.source_schema}'." @@ -300,18 +316,17 @@ def run(self) -> None: if self._text_column not in dataset.column_names: raise ValueError(f"Dataset does not have field '{self._text_column}'.") - if self._config.dataset.source_schema.loss_masking_spans_column is not None and ( + if self._loss_masking_spans_column is not None and ( self._config.dataset.chosen_text is not None or self._config.dataset.rejected_text is not None ): - raise ValueError(f"Can not enable both loss masking spans and chosen/rejected loss masking spans.") + if self._config.dataset.chosen_text is not None and self._config.dataset.rejected_text is not None: + raise ValueError(f"Can not enable both loss masking spans and chosen/rejected loss masking spans.") + if self._loss_masking_spans_column not in dataset.column_names: + raise ValueError(f"Dataset does not have spans field '{self._loss_masking_spans_column}'.") if (self._config.dataset.chosen_text is None) != (self._config.dataset.rejected_text is None): raise ValueError(f"Both chosen and rejected loss masking spans must be specified if one is specified.") # route tokenize function - if self._loss_masking_spans_column is not None: - if self._loss_masking_spans_column not in dataset.column_names: - raise ValueError(f"Dataset does not have spans field '{self._loss_masking_spans_column}'.") - tokenize_fn = self._tokenize_batch_with_spans elif self._config.dataset.chosen_text is not None and self._config.dataset.rejected_text is not None: if self._config.dataset.chosen_text not in dataset.column_names: raise ValueError(f"Dataset does not have chosen spans field '{self._config.dataset.chosen_text}'.") @@ -331,6 +346,13 @@ def run(self) -> None: # Calculate total number of tokens total_tokens = sum(tqdm.tqdm(tokenized_dataset["num_tokens"], desc="Counting tokens", unit="tokens")) + total_pixels = ( + sum(tqdm.tqdm(tokenized_dataset["num_pixels"], desc="Counting pixels", unit="pixels")) + if self._images_column + else 0 + ) + # Add the token-equivalent bytes of pixels to determine shard size + total_tokens += total_pixels // np.dtype(self._data_type.numpy).itemsize # Split dataset into shards based on number of tokens num_shards = int(np.ceil(total_tokens / self._config.tokens_per_shard)) @@ -359,7 +381,7 @@ def generate_config_yaml_for_sharded_dst(self, dataset_configs: list[GPTMemmapDa # Create the config file(s) on rank 0 if self._config.splits: for split_name, split_config in self._split_and_blend_dataset_configs( - dataset_configs, self._config.splits, self._config.output_path + dataset_configs, self._config.splits, self._config.output_path, self._config.image_patch_size ).items(): self._save_dataset_config( split_config, self._config.output_path / f"fast_llm_config_{split_name}.yaml" @@ 
-399,7 +421,11 @@ def _blend_dataset_configs(cls, dataset_configs: list[GPTMemmapDatasetConfig]) - @classmethod def _split_and_blend_dataset_configs( - cls, dataset_configs: list[GPTMemmapDatasetConfig], splits: dict[str, int | float], output_path: pathlib.Path + cls, + dataset_configs: list[GPTMemmapDatasetConfig], + splits: dict[str, int | float], + output_path: pathlib.Path, + image_patch_size: None | int = None, ) -> dict[str, GPTSampledDatasetConfig]: split_cumsum = padded_cumsum(normalize_probabilities(list(splits.values()), return_array=True)).tolist() dataset_sizes = [dataset_config.num_tokens for dataset_config in dataset_configs] @@ -429,10 +455,20 @@ def _split_and_blend_dataset_configs( # Part of the dataset belongs to the split. # TODO: Somehow getting a segfault when merging two lines below (numpy bug?). dataset = dataset_config.to_copy({"path": output_path / dataset_config.path}).build() - sizes_cumsum = dataset.get_document_sizes().cumsum() - Assert.eq(sizes_cumsum[-1], dataset_config.num_tokens) - begin_index = _get_nearest_split(sizes_cumsum, split_begin_in_dataset * dataset_config.num_tokens) - end_index = _get_nearest_split(sizes_cumsum, split_end_in_dataset * dataset_config.num_tokens) + text_sizes, image_sizes = dataset.get_document_sizes() + tokens_cumsum = text_sizes.cumsum() + Assert.eq(tokens_cumsum[-1], dataset_config.num_tokens) + if image_sizes.any(): + num_pixels_cumsum = np.cumsum([x.prod(axis=1).sum() for x in image_sizes]) + # We use the patch sizes only for the purposes of even splitting and blending weights. + # We can always use a different patch size for training without any significant impact + # Unless the patch size used at training time is significantly different from the one used here + image_tokens_cumsum = num_pixels_cumsum // (image_patch_size**2) + tokens_cumsum += image_tokens_cumsum + num_pixels_cumsum = num_pixels_cumsum * 3 + Assert.eq(num_pixels_cumsum[-1], dataset_config.num_pixels) + begin_index = _get_nearest_split(tokens_cumsum, split_begin_in_dataset * tokens_cumsum[-1]) + end_index = _get_nearest_split(tokens_cumsum, split_end_in_dataset * tokens_cumsum[-1]) if end_index > begin_index: datasets_in_split.append( GPTDatasetSliceConfig.from_dict( @@ -445,8 +481,8 @@ def _split_and_blend_dataset_configs( ) ) dataset_tokens_in_split.append( - sizes_cumsum[end_index - 1].item() - - (sizes_cumsum[begin_index - 1].item() if begin_index > 0 else 0) + tokens_cumsum[end_index - 1].item() + - (tokens_cumsum[begin_index - 1].item() if begin_index > 0 else 0) ) # [else] None of the dataset belongs to the split. diff --git a/fast_llm/data/tokenizer.py b/fast_llm/data/tokenizer.py index c74586207..93fa9b81b 100644 --- a/fast_llm/data/tokenizer.py +++ b/fast_llm/data/tokenizer.py @@ -41,44 +41,77 @@ def vocab(self) -> dict[str, int]: def inv_vocab(self) -> dict[int, str]: return self._inv_vocab - def tokenize(self, text: str, begin=True, end=True) -> list[int]: + def _tokenize(self, text: str, begin=True, end=True) -> list[int]: return ( ([self.bod_id] if begin else []) + self.tokenizer.encode(text, add_special_tokens=False) + ([self.eod_id] if end else []) ) - def tokenize_with_spans( - self, text: str, char_spans: list[tuple[int, int]] - ) -> tuple[list[int], list[tuple[int, int]]]: + def tokenize( + self, text: str, add_bos=True, add_eos=True, char_spans=None, image_positions=None + ) -> tuple[list[int], list[tuple[int, int]], list[int]]: """ - Perform span-aware tokenization and return the tokenized input_ids along with token spans. 
+ Tokenize the input text and return the tokenized input_ids, token spans, and image token positions. + This version simplifies logic by merging all relevant positions, sorting, and tokenizing between them. """ - input_ids = [] + if not image_positions: + image_positions = [] + if not char_spans: + char_spans = [] + + # Collect all positions with their type + positions = [] + for pos in image_positions: + positions.append((pos, "image")) + + for start, end in char_spans: + positions.append((start, "span_start")) + positions.append((end + 1, "span_end")) + # Sort positions by character index. We assume that image and span positions are individually sorted and spans do not overlap + positions = sorted(positions, key=lambda x: x[0]) + + token_ids = [] token_spans = [] + image_token_positions = [] char_pos = 0 - beginning_of_text = True + current_span_start = None - for start, end in char_spans: - if char_pos < start: - curr_text = text[char_pos:start] - tokenized_text = self.tokenize(curr_text, begin=beginning_of_text, end=False) - beginning_of_text = False - input_ids.extend(tokenized_text) - curr_text = text[start : end + 1] - if end >= len(text) - 1: - tokenized_text = self.tokenize(curr_text, begin=beginning_of_text, end=True) - else: - tokenized_text = self.tokenize(curr_text, begin=beginning_of_text, end=False) - beginning_of_text = False - token_spans.append((len(input_ids), len(input_ids) + len(tokenized_text) - 1)) - input_ids.extend(tokenized_text) - char_pos = end + 1 + for position in positions: + # We only tokenize if there is at least one character, else we might potentially add begin/end multiple times + if char_pos < position[0]: + tokenized_text = self._tokenize( + text[char_pos : position[0]], + begin=(char_pos == 0) and add_bos, + end=position[0] > len(text) - 1 and add_eos, + ) + token_ids.extend(tokenized_text) + char_pos = position[0] + # beginning_of_text = False + if position[1] == "image": + if position[0] == 0: + # image should be after the bos token + image_token_positions.append(1) + else: + image_token_positions.append(len(token_ids)) + elif position[1] == "span_start": + assert ( + current_span_start is None + ), "Starting a new span before current has ended, please check for overlapping spans" + current_span_start = len(token_ids) + elif position[1] == "span_end": + assert ( + current_span_start is not None + ), "Closing a span that has not started, please check for overlapping spans" + # spans are inclusive, so we take the index of the last token in the span + token_spans.append((current_span_start, len(token_ids) - 1)) + current_span_start = None + # Handle any remaining text after the last position and add EOS token if char_pos < len(text): - curr_text = text[char_pos:] - tokenized_text = self.tokenize(curr_text, begin=beginning_of_text, end=True) - input_ids.extend(tokenized_text) - return input_ids, token_spans + tokenized_text = self._tokenize(text[char_pos:], begin=(char_pos == 0) and add_bos, end=add_eos) + token_ids.extend(tokenized_text) + + return token_ids, token_spans, image_token_positions def detokenize(self, token_ids: int | list[int] | np.ndarray | torch.Tensor) -> str: return self.tokenizer.decode(token_ids) diff --git a/fast_llm/engine/multi_stage/stage.py b/fast_llm/engine/multi_stage/stage.py index a2a9d9d33..0b8bb94f2 100644 --- a/fast_llm/engine/multi_stage/stage.py +++ b/fast_llm/engine/multi_stage/stage.py @@ -137,7 +137,7 @@ def backward( assert self._mode.support_backward input_, output = grad_context output.backward(output_grad) - 
return input_.grad + return input_.grad if input_.grad is not None else torch.zeros_like(input_) def restore_parameters(self) -> None: assert self._is_setup diff --git a/fast_llm/engine/schedule/config.py b/fast_llm/engine/schedule/config.py index 272b7c6ae..a5e0a86a6 100644 --- a/fast_llm/engine/schedule/config.py +++ b/fast_llm/engine/schedule/config.py @@ -48,6 +48,12 @@ class BatchConfig(Config): desc="Pointer to a distributed configuration, required to know the data-parallel split of the batch.", hint=FieldHint.setup, ) + # Image inputs + max_image_size: int | None = Field( + default=None, + desc="Maximum image height and width", + hint=FieldHint.optional, + ) def setup(self, distributed_config: DistributedConfig) -> None: self._distributed = distributed_config diff --git a/fast_llm/functional/config.py b/fast_llm/functional/config.py index 684193848..2c553d906 100644 --- a/fast_llm/functional/config.py +++ b/fast_llm/functional/config.py @@ -40,6 +40,7 @@ class ActivationType(enum.StrEnum): """ gelu = "gelu" + gelu_pytorch_tanh = "gelu_pytorch_tanh" silu = "silu" relu = "relu" squared_relu = "squared_relu" @@ -78,7 +79,6 @@ def _set_activation_fn_map() -> None: _ACTIVATION_FN_MAP: dict[ActivationType, typing.Callable[["torch.Tensor"], "torch.Tensor"]] = {} _ACTIVATION_HF_NAMES = { - ActivationType.gelu: "gelu_pytorch_tanh", ActivationType.silu: "silu", ActivationType.relu: "relu", ActivationType.squared_relu: "relu2", diff --git a/fast_llm/layers/language_model/config.py b/fast_llm/layers/language_model/config.py index 8e2e97f1a..9d8a65929 100644 --- a/fast_llm/layers/language_model/config.py +++ b/fast_llm/layers/language_model/config.py @@ -7,6 +7,7 @@ from fast_llm.functional.config import CrossEntropyImpl, DistillationLossImpl from fast_llm.layers.transformer.config import TransformerConfig from fast_llm.layers.transformer.rotary.config import NoRotaryConfig +from fast_llm.layers.vision_encoder.config import VisionEncoderConfig from fast_llm.utils import Assert @@ -37,6 +38,7 @@ class LanguageModelKwargs: position_ids = "position_ids" # TODO: These are generic labels = "labels" + tokens = "tokens" phase = "phase" chosen_spans = "chosen_spans" rejected_spans = "rejected_spans" @@ -50,6 +52,10 @@ class LanguageModelBaseConfig(BaseModelConfig): desc="Configuration for the transformer architecture.", hint=FieldHint.architecture, ) + vision_encoder: VisionEncoderConfig = Field( + desc="Configuration for the vision encoder that transforms images into embeddings.", + hint=FieldHint.optional, + ) max_position_embeddings: int = Field( default=2048, desc="Number of absolute position embeddings, if applicable.", @@ -245,6 +251,11 @@ def setup_tensor_space(self, tensor_space: TensorSpace) -> None: tensor_space.add_tensor_dim(TensorDim(LanguageModelDimNames.vocab, self.vocab_size)) tensor_space.add_tensor_dim(TensorDim(LanguageModelDimNames.vocab_tp, self.vocab_size, tensor)) + if self.vision_encoder.enabled: + # TODO: Remove tensor spaces so we don't need this hack. + tensor_space.vision = TensorSpace(tensor_space.distributed_config) + self.vision_encoder.setup_tensor_space(tensor_space.vision) + @property def num_absolute_position_embeddings(self) -> int: # TODO: Rename from max embeddings. 
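The merged-position tokenization introduced in fast_llm/data/tokenizer.py above is easier to follow in isolation. The sketch below is not the library code: `encode` stands in for the underlying HuggingFace tokenizer call, `bos_id`/`eos_id` are assumed token ids, and the edge case where the text ends exactly on a tracked position (EOS handling inside the loop) is omitted. It shows the core idea: gather the character positions of loss-masking spans and image insertions, sort them, and tokenize the text between consecutive positions so that spans and image anchors can be recorded in token space.

def tokenize_with_positions(text, char_spans, image_positions, encode, bos_id=1, eos_id=2):
    # Tag every character position of interest with its kind.
    positions = [(pos, "image") for pos in image_positions]
    for start, end in char_spans:
        positions += [(start, "span_start"), (end + 1, "span_end")]
    positions.sort(key=lambda item: item[0])

    token_ids, token_spans, image_token_positions = [], [], []
    char_pos, span_start = 0, None
    for pos, kind in positions:
        if char_pos < pos:
            # Tokenize the text between the previous and the current position.
            token_ids += ([bos_id] if char_pos == 0 else []) + encode(text[char_pos:pos])
            char_pos = pos
        if kind == "image":
            # An image at position 0 goes right after the BOS token.
            image_token_positions.append(1 if pos == 0 else len(token_ids))
        elif kind == "span_start":
            span_start = len(token_ids)
        else:  # "span_end": spans are stored inclusive in token space.
            token_spans.append((span_start, len(token_ids) - 1))
            span_start = None
    if char_pos < len(text):
        token_ids += ([bos_id] if char_pos == 0 else []) + encode(text[char_pos:]) + [eos_id]
    return token_ids, token_spans, image_token_positions

# Example with a toy character-level "tokenizer":
# tokenize_with_positions("a photo of a cat", [(2, 6)], [8], lambda s: [ord(c) for c in s])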
diff --git a/fast_llm/layers/multi_modal/embedding.py b/fast_llm/layers/multi_modal/embedding.py new file mode 100644 index 000000000..a5a789f9e --- /dev/null +++ b/fast_llm/layers/multi_modal/embedding.py @@ -0,0 +1,183 @@ +import typing + +import torch + +from fast_llm.core.distributed import set_generator +from fast_llm.core.ops import reduce_forward, split +from fast_llm.engine.config_utils.tensor_space import TensorSpace +from fast_llm.layers.language_model.config import LanguageModelBaseConfig, LanguageModelKwargs +from fast_llm.layers.language_model.embedding import LanguageModelEmbedding +from fast_llm.layers.transformer.config import TransformerKwargs +from fast_llm.layers.vision_encoder.config import VisionEncoderKwargs +from fast_llm.layers.vision_encoder.preprocessing import get_num_patches +from fast_llm.tensor import TensorMeta +from fast_llm.utils import Assert, div + + +class MultiModalEmbedding(LanguageModelEmbedding): + """ + Multi-modal embedding layer to combine embeddings from text, image and more modalities. + """ + + def __init__( + self, + config: LanguageModelBaseConfig, + tensor_space: TensorSpace, + ): + super().__init__(config, tensor_space) + + # @torch.compile + def _forward( + self, + input_: torch.Tensor, + tokens: torch.Tensor, + position_ids: torch.Tensor | None, + image_positions: list[torch.Tensor] | None, + image_sizes: list[list[tuple[int, int]]] | None, + ) -> torch.Tensor: + """ + Forward pass for the multi-modal embedding layer. + Args: + input_: The input tensor (image embeddings). + tokens: The tokenized text input. + position_ids: The position ids for the text input. + image_positions: The positions of the image tokens in the input. + image_sizes: The sizes of the images in the input. + Returns: + The combined embeddings for text and images. 
+ """ + Assert.eq(position_ids is not None, self._use_absolute_position_embeddings) + group = self._tensor_space.distributed.tensor_group + if self._sequence_parallel: + micro_seqlen = input_.size(0) + patch_start_offset = self._distributed_config.tensor_rank * micro_seqlen + patch_end_offset = (self._distributed_config.tensor_rank + 1) * micro_seqlen + else: + patch_start_offset = 0 + patch_end_offset = input_.size(0) + if self._parallel_embeddings: + token_mask = (tokens >= self._vocab_start_index) * (tokens < self._vocab_end_index) + masked_tokens = (tokens - self._vocab_start_index) * token_mask + embeddings = torch.embedding(self.word_embeddings_weight, masked_tokens) * token_mask.unsqueeze(2) # noqa + # Cloning since we will modify the embeddings in-place + embeddings = embeddings.clone() + # the embeddings tensor are full-sized, but we might get a split of the patch embeddings + # We need to determine the offset in the embeddings tensor for each sample + # and also account for the special image tokens if applicable + for sample_idx, (positions, sizes) in enumerate(zip(image_positions, image_sizes)): + image_embedding_offset = 0 + for position, size in zip(positions, sizes): + num_patches = get_num_patches(*size, self._config.vision_encoder.patch_size) + if image_embedding_offset + num_patches < patch_start_offset: + image_embedding_offset += num_patches + continue + if self._config.vision_encoder.image_break_token is not None: + patch_height = div(size[0], self._config.vision_encoder.patch_size) + patch_width = div(size[1], self._config.vision_encoder.patch_size) + for row in range(patch_height): + row_start_src = image_embedding_offset + row * patch_width + row_start_dst = position + row * (patch_width + 1) + if row_start_src > patch_end_offset: + break + if row_start_src + patch_width <= patch_start_offset: + continue + + input_start_index = max(row_start_src, patch_start_offset) - patch_start_offset + input_end_index = min(row_start_src + patch_width, patch_end_offset) - patch_start_offset + embeddings_start_index = row_start_dst + max(patch_start_offset - row_start_src, 0) + embeddings_end_index = ( + row_start_dst + patch_width - max(row_start_src + patch_width - patch_end_offset, 0) + ) + # row_end_src = min(row_start_src + patch_width, patch_end_offset) + if self._sequence_parallel: + embeddings[embeddings_start_index:embeddings_end_index, sample_idx] = input_[ + input_start_index:input_end_index, sample_idx + ] + else: + embeddings[sample_idx, embeddings_start_index:embeddings_end_index] = input_[ + sample_idx, input_start_index:input_end_index + ] + else: + input_start_index = max(image_embedding_offset, patch_start_offset) - patch_start_offset + input_end_index = ( + min(image_embedding_offset + num_patches, patch_end_offset) - patch_start_offset + ) + embedding_start_index = position - max(patch_start_offset - image_embedding_offset, 0) + embedding_end_index = ( + position + num_patches - max(image_embedding_offset + num_patches - patch_end_offset, 0) + ) + embeddings[sample_idx, embedding_start_index:embedding_end_index] = input_[ + input_start_index:input_end_index, sample_idx + ] + # embeddings[sample_idx, position : position + num_patches] = input_[ + # sample_idx, image_embedding_offset : image_embedding_offset + num_patches + # ] + image_embedding_offset += num_patches + if image_embedding_offset > patch_end_offset: + break + embeddings = reduce_forward(embeddings, group) + if self._use_absolute_position_embeddings: + embeddings = embeddings + 
torch.nn.functional.embedding(position_ids, self.position_embeddings_weight) + if self._sequence_parallel: + embeddings = split(embeddings, group=group, dim=0) + else: + if self._sequence_parallel: + tokens = split(tokens, group=group, dim=0) + if self._use_absolute_position_embeddings: + position_ids = split(position_ids, group=group, dim=0) + # mask padded tokens + token_mask = tokens >= 0 + masked_tokens = tokens * token_mask + embeddings = torch.embedding(self.word_embeddings_weight, masked_tokens) * token_mask.unsqueeze(2) + embeddings = embeddings.clone() + for sample_idx, (positions, sizes) in enumerate(zip(image_positions, image_sizes)): + image_embedding_offset = 0 + for position, size in zip(positions, sizes): + num_patches = get_num_patches(*size, self._config.vision_encoder.patch_size) + if self._config.vision_encoder.image_break_token is not None: + patch_height = div(size[0], self._config.vision_encoder.patch_size) + patch_width = div(size[1], self._config.vision_encoder.patch_size) + + for row in range(patch_height): + row_start_src = image_embedding_offset + row * patch_width + row_start_dst = position + row * (patch_width + 1) + + embeddings[sample_idx, row_start_dst : row_start_dst + patch_width] = input_[ + sample_idx, row_start_src : row_start_src + patch_width + ] + else: + embeddings[sample_idx, position : position + num_patches] = input_[ + sample_idx, image_embedding_offset : image_embedding_offset + num_patches + ] + # Move to the next image in the input tensor + image_embedding_offset += num_patches + + if self._use_absolute_position_embeddings: + embeddings = embeddings + torch.nn.functional.embedding(position_ids, self.position_embeddings_weight) + with set_generator( + self._tensor_space.distributed.tp_generator + if self._sequence_parallel + else self._tensor_space.distributed.pp_generator + ): + embeddings = torch.dropout(embeddings, self._dropout_p, self.training) + return embeddings.to(dtype=self._residual_dtype) + + def forward( + self, + input_: torch.Tensor, + kwargs: dict[str, typing.Any], + losses: dict[str, typing.Any] | None = None, + metrics: dict | None = None, + ) -> torch.Tensor: + if isinstance(input_, TensorMeta): + return TensorMeta.from_dims( + kwargs[TransformerKwargs.hidden_dims], + tensor_name="Embedding output", + dtype=self._residual_dtype, + ) + position_ids = kwargs.get(LanguageModelKwargs.position_ids) + image_sizes = kwargs.get(VisionEncoderKwargs.image_sizes) + image_positions = kwargs.get(VisionEncoderKwargs.image_positions) + tokens = kwargs.get(LanguageModelKwargs.tokens) + + return self._forward(input_, tokens, position_ids, image_positions, image_sizes) diff --git a/fast_llm/layers/transformer/attention.py b/fast_llm/layers/transformer/attention.py index 3351c9906..e72171ec1 100644 --- a/fast_llm/layers/transformer/attention.py +++ b/fast_llm/layers/transformer/attention.py @@ -371,7 +371,7 @@ def forward(self, input_: torch.Tensor, kwargs: dict[str, typing.Any]) -> tuple[ max_seqlen_k=kwargs.get(TransformerKwargs.max_seqlen_k), dropout_p=self._config.attention_dropout if self.training else 0.0, window_size=(-1, -1) if window_size is None else (window_size - 1, 0), - causal=True, + causal=self._config.causal, softmax_scale=self._softmax_scale, ).view(*out_dims) else: @@ -381,7 +381,7 @@ def forward(self, input_: torch.Tensor, kwargs: dict[str, typing.Any]) -> tuple[ value, window_size=(-1, -1) if window_size is None else (window_size - 1, 0), dropout_p=self._config.attention_dropout if self.training else 0.0, - causal=True, + 
causal=self._config.causal, softmax_scale=self._softmax_scale, ) input_ = input_.flatten(-2) diff --git a/fast_llm/layers/transformer/config.py b/fast_llm/layers/transformer/config.py index f6eaf5890..4d63b927f 100644 --- a/fast_llm/layers/transformer/config.py +++ b/fast_llm/layers/transformer/config.py @@ -80,10 +80,16 @@ class TransformerKwargs: sequence_q_dim = "sequence_q_dim" sequence_k_dim = "sequence_k_dim" sequence_length = "sequence_length" + batch_dim = "batch_dim" + micro_batch_size = "micro_batch_size" # TODO: Move grad_output = "grad_output" +class VisionKwargs: + patch_position_ids = "patch_position_ids" + + class TransformerLossNames: load_balancing_loss = "load_balancing_loss" router_z_loss = "router_z_loss" @@ -206,9 +212,19 @@ def _validate(self) -> None: ) -@config_class() +class TransformerType(str, enum.Enum): + language_model_decoder = "language_model_decoder" + image_encoder = "image_encoder" + + +@config_class(registry=True) class TransformerConfig(LLMBlockConfig): _abstract = False + type: TransformerType = Field( + default=TransformerType.language_model_decoder, + desc="Type of the transformer. Choices: language_model_decoder, image_encoder.", + hint=FieldHint.architecture, + ) normalization: NormalizationConfig = Field( desc="Configuration for the normalization layers architecture.", hint=FieldHint.architecture, @@ -485,6 +501,11 @@ class TransformerConfig(LLMBlockConfig): " Reduces memory usage, but increases fragmentation and requires CPU synchronisation. Not recommended.", hint=FieldHint.expert, ) + causal: bool = Field( + default=True, + desc="Use causal attention. Turn this off only for bidirectional attention e.g., in Vision Transformer.", + hint=FieldHint.feature, + ) def _validate(self) -> None: with self._set_implicit_default(): @@ -662,3 +683,9 @@ def setup_tensor_space(self, tensor_space: TensorSpace) -> None: def do_use_flash_attention(self, distributed_config: DistributedConfig) -> bool: return self.use_flash_attention and distributed_config.training_dtype in (DataType.float16, DataType.bfloat16) + + +for name in TransformerType: + # We need this because we are using the reserved field name `type`. + # TODO: Implement proper dynamic typing.
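+    # Both names currently map to the same TransformerConfig class; the `type` field only records which role the block plays.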
+ TransformerConfig.register_subclass(name.value, TransformerConfig) diff --git a/fast_llm/layers/transformer/rotary/config.py b/fast_llm/layers/transformer/rotary/config.py index 748f2af28..ba598e385 100644 --- a/fast_llm/layers/transformer/rotary/config.py +++ b/fast_llm/layers/transformer/rotary/config.py @@ -10,7 +10,14 @@ from fast_llm.utils import Assert if typing.TYPE_CHECKING: - from fast_llm.layers.transformer.rotary.rotary import DefaultRotary, Llama3Rotary, NoRotary, Rotary, YarnRotary + from fast_llm.layers.transformer.rotary.rotary import ( + DefaultRotary, + Llama3Rotary, + NoRotary, + Rotary, + Rotary2D, + YarnRotary, + ) @config_class(registry=True) @@ -140,3 +147,11 @@ def _get_configurable_class(self) -> "type[YarnRotary]": from fast_llm.layers.transformer.rotary.rotary import YarnRotary return YarnRotary + + +@config_class(dynamic_type={RotaryConfig: "rope_2d"}) +class Rotary2DConfig(DefaultRotaryConfig): + def _get_configurable_class(self) -> "type[Rotary2D]": + from fast_llm.layers.transformer.rotary.rotary import Rotary2D + + return Rotary2D diff --git a/fast_llm/layers/transformer/rotary/rotary.py b/fast_llm/layers/transformer/rotary/rotary.py index 056b9aa4c..46ea2a8b2 100644 --- a/fast_llm/layers/transformer/rotary/rotary.py +++ b/fast_llm/layers/transformer/rotary/rotary.py @@ -8,14 +8,16 @@ from fast_llm.engine.base_model.config import Preprocessor from fast_llm.engine.config_utils.tensor_space import DefaultDimNames, TensorSpace from fast_llm.functional.triton.rotary import triton_rotary_autograd_ -from fast_llm.layers.transformer.config import TransformerDimNames, TransformerKwargs +from fast_llm.layers.transformer.config import TransformerDimNames, TransformerKwargs, VisionKwargs from fast_llm.layers.transformer.rotary.config import ( DefaultRotaryConfig, Llama3RotaryConfig, NoRotaryConfig, + Rotary2DConfig, RotaryConfig, YarnRotaryConfig, ) +from fast_llm.layers.vision_encoder.config import VisionEncoderKwargs from fast_llm.tensor import TensorMeta from fast_llm.utils import div @@ -212,3 +214,71 @@ def _get_correction(self, beta: float, dim: int) -> float: * math.log(self._config.original_context_length / (beta * 2 * math.pi)) / (2 * math.log(self._config.theta)) ) + + +class Rotary2D[ConfigType: DefaultRotaryConfig](DefaultRotary[Rotary2DConfig]): + _rotary_embedding_frequencies: torch.Tensor + _tensor_cache_max_num_patches: int = -1 + + def preprocess(self, batch, kwargs: dict[str, typing.Any]) -> None: + assert self._tensor_space is not None + max_num_patches = kwargs[VisionEncoderKwargs.max_image_size] // kwargs[VisionEncoderKwargs.patch_size] + self._create_tensors(max_num_patches) + position_ids = kwargs[VisionKwargs.patch_position_ids] + kwargs[TransformerKwargs.rotary_freq_q] = self._rotary_embedding_frequencies[:, position_ids] + kwargs[TransformerKwargs.rotary_freq_k] = self._rotary_embedding_frequencies[:, position_ids] + + def preprocess_meta(self, kwargs: dict[str, typing.Any]) -> None: + assert self._tensor_space is not None + kwargs[TransformerKwargs.rotary_freq_q] = TensorMeta.from_dims( + ( + self._scalar_dim, + kwargs[TransformerKwargs.sequence_q_dim], + self._scalar_dim, + self._kv_channels_dim, + ), + tensor_name=TransformerKwargs.rotary_freq_q, + ) + kwargs[TransformerKwargs.rotary_freq_k] = TensorMeta.from_dims( + ( + self._scalar_dim, + kwargs[TransformerKwargs.sequence_k_dim], + self._scalar_dim, + self._kv_channels_dim, + ), + tensor_name=TransformerKwargs.rotary_freq_k, + ) + + def _create_tensors(self, max_num_patches: int) -> 
None: + if max_num_patches <= self._tensor_cache_max_num_patches: + return + self._tensor_cache_max_num_patches = max_num_patches + + self._rotary_embedding_frequencies = self._get_frequencies( + max_num_patches, + self._kv_channels_dim.global_size, + device=self._tensor_space.distributed.device, + ) + + def _get_frequencies(self, max_num_patches: int, kv_channels: int, device="cuda") -> torch.Tensor: + # Calculate complex frequencies by using alternating channels for width and height + height_positions = torch.arange(max_num_patches, device=device, dtype=torch.float64) + width_positions = torch.arange(max_num_patches, device=device, dtype=torch.float64) + frequencies = self._config.theta ** -torch.arange(0, 1, 2 / kv_channels, device=device, dtype=torch.float64) + angles_h = torch.outer(height_positions, frequencies[::2]) + angles_w = torch.outer(width_positions, frequencies[1::2]) + angles = torch.cat( + [ + angles_h[:, None, :].repeat(1, max_num_patches, 1), + angles_w[None, :, :].repeat(max_num_patches, 1, 1), + ], + dim=-1, + ).reshape(-1, kv_channels // 2) + + frequencies = torch.polar(torch.ones_like(angles), angles)[None, :, None, :].to(torch.complex64) + if not self._config.complex_format: + frequencies = convert_rotary_complex_to_real( + torch.view_as_real(frequencies).flatten(-2), kv_channels, 3 + ).contiguous() + + return frequencies diff --git a/fast_llm/layers/transformer/transformer.py b/fast_llm/layers/transformer/transformer.py index 147452073..761399e5d 100644 --- a/fast_llm/layers/transformer/transformer.py +++ b/fast_llm/layers/transformer/transformer.py @@ -148,3 +148,7 @@ def __init__( def _create_mixer(self): self.self_attn = Attention(self._config, self._tensor_space, self._layer_index) + + +class VisionTransformerLayer(TransformerLayer): + _name: str = "Vision transformer layer" diff --git a/fast_llm/layers/vision_encoder/adapter.py b/fast_llm/layers/vision_encoder/adapter.py new file mode 100644 index 000000000..fecc6d086 --- /dev/null +++ b/fast_llm/layers/vision_encoder/adapter.py @@ -0,0 +1,53 @@ +import typing + +import torch + +from fast_llm.engine.base_model.base_model import Layer +from fast_llm.engine.config_utils.tensor_space import TensorSpace +from fast_llm.functional.triton.mlp import torch_mlp_activation +from fast_llm.layers.common.linear import Linear +from fast_llm.layers.transformer.config import TransformerDimNames, TransformerKwargs +from fast_llm.layers.vision_encoder.config import PixtralVisionEncoderConfig, VisionEncoderDimNames +from fast_llm.tensor import TensorMeta, init_normal_ + + +class VisionAdapter(Layer): + """ + Vision adapter layer that projects vision encoder features into the language model token embeddings. 
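+    Implemented as a two-layer MLP: layer_1 projects to `adapter_size`, followed by the adapter activation and layer_2 projecting to the language model hidden size.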
+ """ + + def __init__(self, config: PixtralVisionEncoderConfig, tensor_space: TensorSpace): + super().__init__() + input_dim = tensor_space.get_tensor_dim(VisionEncoderDimNames.out_channels) + self._activation_type = config.adapter_activation_type + self.layer_1 = Linear( + input_dim, + tensor_space.get_tensor_dim(VisionEncoderDimNames.adapter_size), + bias=True, + weight_init_method=init_normal_(), + bias_init_method=init_normal_(), + ) + self.layer_2 = Linear( + tensor_space.get_tensor_dim(VisionEncoderDimNames.adapter_size), + tensor_space.get_tensor_dim(TransformerDimNames.hidden), + bias=True, + weight_init_method=init_normal_(), + bias_init_method=init_normal_(), + ) + + def forward( + self, + input_: torch.Tensor, + kwargs: dict[str, typing.Any], + losses: dict[str, typing.Any] | None = None, + metrics: dict[str, typing.Any] | None = None, + ) -> torch.Tensor: + if isinstance(input_, TensorMeta): + return TensorMeta.from_dims( + kwargs[TransformerKwargs.hidden_dims], + tensor_name="Vision adapter output", + dtype=input_.dtype, + ) + return self.layer_2( + torch_mlp_activation(input_=self.layer_1(input_), gated=False, activation_type=self._activation_type) + ) diff --git a/fast_llm/layers/vision_encoder/config.py b/fast_llm/layers/vision_encoder/config.py new file mode 100644 index 000000000..14ff578dc --- /dev/null +++ b/fast_llm/layers/vision_encoder/config.py @@ -0,0 +1,171 @@ +import functools +import typing + +from fast_llm.config import Config, Field, FieldHint, check_field, config_class, skip_valid_if_none +from fast_llm.engine.base_model.config import BaseModelConfig +from fast_llm.engine.config_utils.tensor_space import TensorDim, TensorSpace +from fast_llm.functional.config import ActivationType +from fast_llm.layers.common.config import NormalizationConfig +from fast_llm.layers.transformer.config import TransformerConfig +from fast_llm.utils import Assert + + +class VisionEncoderDimNames: + in_channels = "vision_in_channels" + out_channels = "vision_out_channels" + adapter_size = "vision_adapter_size" + patch_size = "vision_patch_size" + kv_channels = "vision_kv_channels" + + +class VisionEncoderKwargs: + patch_size = "patch_size" + images = "images" + image_patches = "image_patches" + image_positions = "image_positions" + max_image_size = "max_image_size" + image_sizes = "image_sizes" + image_rescale_factor = "image_rescale_factor" + kv_channels = "vit_kv_channels" + max_image_tokens = "max_image_tokens" + hidden_dims = "vit_hidden_dims" + image_patches_meta = "vit_image_patches_meta" + + +@config_class() +class ImageNormalizationConfig(Config): + mean_red: float = Field( + default=0.48145466, + desc="Mean value for the red channel in the image normalization process.", + hint=FieldHint.optional, + ) + mean_green: float = Field( + default=0.4578275, + desc="Mean value for the green channel in the image normalization process.", + hint=FieldHint.optional, + ) + mean_blue: float = Field( + default=0.40821073, + desc="Mean value for the blue channel in the image normalization process.", + hint=FieldHint.optional, + ) + std_red: float = Field( + default=0.26862954, + desc="Standard deviation value for the red channel in the image normalization process.", + hint=FieldHint.optional, + ) + std_green: float = Field( + default=0.26130258, + desc="Standard deviation value for the green channel in the image normalization process.", + hint=FieldHint.optional, + ) + std_blue: float = Field( + default=0.27577711, + desc="Standard deviation value for the blue channel in the image 
normalization process.", + hint=FieldHint.optional, + ) + rescale_factor: float = Field( + default=255.0, + desc="Rescale factor for the image normalization process.", + hint=FieldHint.optional, + ) + + @functools.cached_property + def mean(self) -> list[float]: + return [self.mean_red, self.mean_green, self.mean_blue] + + @functools.cached_property + def std(self) -> list[float]: + return [self.std_red, self.std_green, self.std_blue] + + +@config_class(registry=True) +class VisionEncoderConfig(BaseModelConfig): + @classmethod + def _from_dict( + cls, + default: dict[str, typing.Any], + strict: bool = True, + flat: bool = False, + ) -> typing.Self: + if cls is VisionEncoderConfig and cls.get_subclass(default.get("type")) is None: + # Default subclass. + return NoVisionEncoderConfig._from_dict(default, strict, flat) + return super()._from_dict(default, strict=strict, flat=flat) + + +@config_class(dynamic_type={VisionEncoderConfig: "none"}) +class NoVisionEncoderConfig(BaseModelConfig): + _abstract = False + + +@config_class(dynamic_type={VisionEncoderConfig: "pixtral"}) +class PixtralVisionEncoderConfig(BaseModelConfig): + _abstract = False + + transformer: TransformerConfig = Field( + desc="Configuration for the vision transformer architecture.", + hint=FieldHint.core, + ) + patch_normalization: NormalizationConfig = Field( + desc="Configuration for the normalization layers applied to the image patches.", + hint=FieldHint.optional, + ) + image_normalization: ImageNormalizationConfig = Field( + desc="Configuration for the normalization layers applied to the image patches.", + hint=FieldHint.optional, + ) + patch_size: int = Field( + default=16, + desc="Patch size for the image encoder.", + hint=FieldHint.core, + ) + conv_bias: bool = Field( + default=False, + desc="Whether to use bias in the convolutional layer.", + hint=FieldHint.optional, + ) + adapter_size: int = Field( + default=5120, + desc="Intermediate size for the adapter linear layers. Assuming 2 linear layers", + hint=FieldHint.core, + ) + adapter_activation_type: ActivationType = Field( + default=ActivationType.gelu, + desc="The intermediate activation type for multi-modal adapter. Default: GeLU.", + hint=FieldHint.core, + ) + adapter_bias: bool = Field( + default=True, + desc="Whether to use bias in the adapter linear layer.", + hint=FieldHint.optional, + ) + image_break_token: int | None = Field( + default=None, + desc="Token id to separate image rows. If None, no token id is applied.", + hint=FieldHint.optional, + ) + image_end_token: int | None = Field( + default=None, + desc="Token id to indicate the end of an image. 
If None, no token id is applied.", + hint=FieldHint.optional, + ) + adapter_lr_scale: float | None = Field( + default=None, + desc="Custom learning rate scale for the adapter weights.", + hint=FieldHint.feature, + valid=skip_valid_if_none(check_field(Assert.geq, 0)), + ) + conv_lr_scale: float | None = Field( + default=None, + desc="Custom learning rate scale for the convolutional layer weights.", + hint=FieldHint.feature, + valid=skip_valid_if_none(check_field(Assert.geq, 0)), + ) + + def setup_tensor_space(self, tensor_space: TensorSpace): + tensor_space.add_tensor_dim(TensorDim(VisionEncoderDimNames.out_channels, self.transformer.hidden_size)) + tensor_space.add_tensor_dim(TensorDim(VisionEncoderDimNames.adapter_size, self.adapter_size)) + tensor_space.add_tensor_dim(TensorDim(VisionEncoderDimNames.patch_size, self.patch_size)) + tensor_space.add_tensor_dim(TensorDim(VisionEncoderDimNames.in_channels, 3)) + self.transformer.setup_tensor_space(tensor_space) diff --git a/fast_llm/layers/vision_encoder/patch_conv.py b/fast_llm/layers/vision_encoder/patch_conv.py new file mode 100644 index 000000000..d91fed163 --- /dev/null +++ b/fast_llm/layers/vision_encoder/patch_conv.py @@ -0,0 +1,66 @@ +import typing + +import torch + +from fast_llm.core.ops import split +from fast_llm.engine.base_model.base_model import Layer +from fast_llm.engine.config_utils.tensor_space import TensorSpace +from fast_llm.layers.transformer.config import TransformerKwargs +from fast_llm.layers.vision_encoder.config import PixtralVisionEncoderConfig, VisionEncoderDimNames +from fast_llm.tensor import ParameterMeta, TensorMeta, init_normal_ + + +class PatchConvolution(Layer): + """ + A convolution layer applied to image patches to create embeddings for each patch. These embeddings are fed into the vision transformer. 
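+    The convolution output is flattened, passed through the configured patch normalization, and reshaped to (batch, sequence, hidden).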
+ """ + + def __init__(self, config: PixtralVisionEncoderConfig, tensor_space: TensorSpace): + super().__init__() + self._config = config + self._tensor_space = tensor_space + self._distributed_config = tensor_space.distributed_config + self._sequence_tensor_parallel = self._distributed_config.sequence_tensor_parallel + self.weight = ParameterMeta.from_dims( + ( + self._tensor_space.get_tensor_dim(VisionEncoderDimNames.out_channels), + self._tensor_space.get_tensor_dim(VisionEncoderDimNames.in_channels), + self._tensor_space.get_tensor_dim(VisionEncoderDimNames.patch_size), + self._tensor_space.get_tensor_dim(VisionEncoderDimNames.patch_size), + ), + init_method=init_normal_(), + lr_scale=self._config.adapter_lr_scale, + ) + if config.conv_bias: + self.bias = ParameterMeta.from_dims( + (self._tensor_space.get_tensor_dim(VisionEncoderDimNames.out_channels),), + init_method=init_normal_(), + lr_sclae=self._config.adapter_lr_scale, + ) + else: + self.bias = None + self.normalization = config.patch_normalization.get_layer( + tensor_space.get_tensor_dim(VisionEncoderDimNames.out_channels) + ) + + def forward( + self, + input_: torch.Tensor, + kwargs: dict[str, typing.Any], + losses: dict[str, typing.Any] | None = None, + metrics: dict | None = None, + ) -> torch.Tensor: + hidden_dims = kwargs[TransformerKwargs.hidden_dims] + if isinstance(input_, TensorMeta): + return TensorMeta.from_dims(hidden_dims, tensor_name="patch conv output", dtype=input_.dtype) + input_ = torch.nn.functional.conv2d(input_, self.weight, self.bias, stride=self._config.patch_size) + patch_embeddings = self.normalization(input_.flatten(1)).view( + kwargs[TransformerKwargs.batch_dim].size, + kwargs[TransformerKwargs.sequence_q_dim].size, + self._config.transformer.hidden_size, + ) + if kwargs[TransformerKwargs.sequence_first]: + patch_embeddings = patch_embeddings.permute(1, 0, 2).contiguous() + if self._sequence_tensor_parallel: + patch_embeddings = split(patch_embeddings, group=self._tensor_space.distributed.tensor_group, dim=0) + return patch_embeddings diff --git a/fast_llm/layers/vision_encoder/preprocessing.py b/fast_llm/layers/vision_encoder/preprocessing.py new file mode 100644 index 000000000..8f9be8012 --- /dev/null +++ b/fast_llm/layers/vision_encoder/preprocessing.py @@ -0,0 +1,253 @@ +import math +import typing + +import torch +import torchvision.transforms.v2 as torchvision_transforms + +from fast_llm.engine.base_model.config import Preprocessor +from fast_llm.engine.config_utils.tensor_space import TensorDim, TensorSpace +from fast_llm.layers.language_model.config import LanguageModelKwargs +from fast_llm.layers.transformer.config import TransformerDimNames, TransformerKwargs, VisionKwargs +from fast_llm.layers.vision_encoder.config import ( + PixtralVisionEncoderConfig, + VisionEncoderDimNames, + VisionEncoderKwargs, +) +from fast_llm.tensor import TensorMeta +from fast_llm.utils import div + + +def get_num_patches(height: int, width: int, patch_size: int) -> tuple[int, int]: + """ + Calculate the number of patches in height and width dimensions. + """ + return div(height, patch_size) * div(width, patch_size) + + +def get_num_image_tokens(height: int, width: int, patch_size: int, image_break: bool, image_end: bool) -> int: + """ + Calculate the number of image tokens. + If image_break is True, we consider 1 additional token after every row of patches. 
+ """ + height_patches = div(height, patch_size) + width_patches = div(width, patch_size) + num_tokens = height_patches * width_patches + if image_break: + num_tokens += height_patches + elif image_end: + num_tokens += 1 + return num_tokens + + +def get_resize_dims(height: int, width: int, max_height: int, max_width: int, patch_size: int) -> tuple[int, int]: + """ + Calculate the new dimensions for resizing an image while maintaining the aspect ratio. + If the image is larger than the max dimensions, it will be resized to fit within them. + If the image is smaller, it will be resized to the nearest multiple of the patch size. + """ + ratio = max(height / max_height, width / max_width) + if ratio > 1: + # Resize to fit within max dimensions + height = int(height / ratio) + width = int(width / ratio) + return patch_size * math.ceil(height / patch_size), patch_size * math.ceil(width / patch_size) + + +def resize(image: torch.Tensor, max_height: int, max_width: int, patch_size: int) -> tuple[int, int]: + target_height, target_width = get_resize_dims( + image.size(1), image.size(2), max_height, max_width, patch_size=patch_size + ) + height, width = image.size(1), image.size(2) + while height > 2 * target_height or width > 2 * target_width: + # cap the resizing to half of the current size as a workaround for large images + # See pytorch issue: https://github.com/pytorch/pytorch/issues/103589 + intermediate_max_width = max(target_width, width // 2) + intermediate_max_height = max(target_height, height // 2) + height, width = get_resize_dims( + height, width, intermediate_max_height, intermediate_max_width, patch_size=patch_size + ) + image = torchvision_transforms.functional.resize( + image, size=(height, width), interpolation=torchvision_transforms.InterpolationMode.BICUBIC + ) + + # TODO: options for interpolation mode? 
+ return torchvision_transforms.functional.resize( + image, size=(target_height, target_width), interpolation=torchvision_transforms.InterpolationMode.BICUBIC + ) + + +def create_inv_freqs(rope_theta: int, kv_channels: int, max_image_size: int, patch_size: int) -> torch.Tensor: + freqs = 1.0 / (rope_theta ** (torch.arange(0, kv_channels, 2).float() / kv_channels)) + max_patches_per_side = max_image_size // patch_size + + h = torch.arange(max_patches_per_side) + w = torch.arange(max_patches_per_side) + + freqs_h = torch.outer(h, freqs[::2]).float() + freqs_w = torch.outer(w, freqs[1::2]).float() + inv_freq = torch.cat( + [ + freqs_h[:, None, :].repeat(1, max_patches_per_side, 1), + freqs_w[None, :, :].repeat(max_patches_per_side, 1, 1), + ], + dim=-1, + ).reshape(-1, kv_channels // 2) + + return torch.cat((inv_freq, inv_freq), dim=-1) + + +def position_ids_in_meshgrid(height, width, max_size, patch_size) -> torch.Tensor: + patch_height = height // patch_size + patch_width = width // patch_size + mesh = torch.meshgrid(torch.arange(patch_height), torch.arange(patch_width), indexing="ij") + h_grid, v_grid = torch.stack(mesh, dim=-1).reshape(-1, 2).chunk(2, -1) + ids = h_grid * max_size + v_grid + return ids[:, 0] + + +class VisionPreprocessor(Preprocessor): + def __init__(self, config: PixtralVisionEncoderConfig, tensor_space: TensorSpace): + self._config = config + self._tensor_space = tensor_space + self._distributed_config = self._tensor_space.distributed_config + + def preprocess_meta(self, kwargs: dict[str, typing.Any]) -> None: + kwargs[VisionEncoderKwargs.image_patches_meta] = TensorMeta.from_dims( + ( + TensorDim( + TransformerDimNames.batch, + kwargs[TransformerKwargs.micro_batch_size] * kwargs[TransformerKwargs.sequence_q_dim].size, + ), + TensorDim(VisionEncoderDimNames.in_channels, 3), + TensorDim(VisionEncoderDimNames.patch_size, self._config.patch_size), + TensorDim(VisionEncoderDimNames.patch_size, self._config.patch_size), + ), + dtype=self._distributed_config.training_dtype.torch, + ) + + def preprocess(self, tokens, kwargs: dict[str, typing.Any]) -> None: + images = kwargs.get(VisionEncoderKwargs.images) + max_image_size = kwargs.get(VisionEncoderKwargs.max_image_size) + im_width = kwargs.get(VisionEncoderKwargs.max_image_size) + image_positions = kwargs.get(VisionEncoderKwargs.image_positions) + image_sizes = [ + [ + get_resize_dims(im.size(1), im.size(2), max_image_size, im_width, patch_size=self._config.patch_size) + for im in ims + ] + for ims in images + ] + kwargs[VisionEncoderKwargs.image_sizes] = image_sizes + images = [ + [ + torchvision_transforms.functional.normalize( + resize(image, max_image_size, im_width, self._config.patch_size).to( + dtype=self._tensor_space.distributed_config.training_dtype.torch + ) + / kwargs[VisionEncoderKwargs.image_rescale_factor], + mean=self._config.image_normalization.mean, + std=self._config.image_normalization.std, + ) + for image in imgs + ] + for imgs in images + ] + + labels = kwargs[LanguageModelKwargs.labels] + if (self._config.image_break_token is not None) or (self._config.image_end_token is not None): + # If image break or end token is present, we need to replace image token ids to -100 in labels + # TODO: avoid double cloning labels in case of loss masking spans? 
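+            # Clone so that setting image-token positions to -100 below does not modify the original labels in place.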
+ labels = labels.clone() + + patches = [] + patch_position_ids = [] + cu_seqlens = [0] + max_seqlen = -1 + kwargs.get(TransformerKwargs.sequence_first) + for idx, (imgs, sizes, positions) in enumerate(zip(images, image_sizes, image_positions)): + # add an empty tensor for clean concatenation in case of no images + seq_patches = [ + torch.tensor([]).to( + dtype=self._tensor_space.distributed_config.training_dtype.torch, + device=self._tensor_space.distributed.device, + ) + ] + sample_cu_seqlen = 0 + for image, size, position in zip(imgs, sizes, positions): + seqlen = get_num_patches(*size, self._config.patch_size) + num_tokens = get_num_image_tokens( + *size, + patch_size=self._config.patch_size, + image_break=self._config.image_break_token is not None, + image_end=self._config.image_end_token is not None, + ) + # set labels for image patches to -100 + labels[idx, max(position - 1, 0) : position + num_tokens - 1] = -100 + if seqlen > max_seqlen: + max_seqlen = seqlen + cu_seqlens.append(cu_seqlens[-1] + seqlen) + sample_cu_seqlen += seqlen + seq_patches.append( + torch.cat( + [ + torch.nn.functional.unfold( + image, kernel_size=self._config.patch_size, stride=self._config.patch_size + ).T.reshape(-1, 3, self._config.patch_size, self._config.patch_size), + ] + ) + ) + padding_size = kwargs[TransformerKwargs.sequence_length] - sample_cu_seqlen + if padding_size > max_seqlen: + max_seqlen = padding_size + cu_seqlens.append(kwargs[TransformerKwargs.sequence_length] * (idx + 1)) + patches.append( + torch.cat( + [ + *seq_patches, + torch.zeros(padding_size, 3, self._config.patch_size, self._config.patch_size).to( + dtype=self._tensor_space.distributed_config.training_dtype.torch, + device=self._tensor_space.distributed.device, + ), + ] + ) + ) + if sizes: + position_ids = torch.cat( + [ + position_ids_in_meshgrid( + *size, max_image_size // self._config.patch_size, self._config.patch_size + ) + for size in sizes + ] + ).to(device=self._tensor_space.distributed.device) + else: + position_ids = torch.tensor( + [], + dtype=torch.int64, + device=self._tensor_space.distributed.device, + ) + # We pad at the end instead of padding at the position in meshgrid because flash attention does not support custom attention masks + patch_position_ids.append( + torch.cat( + [ + position_ids, + torch.full((padding_size,), 0).to(device=self._tensor_space.distributed.device), + ] + ) + ) + assert patches[-1].size(0) == kwargs[TransformerKwargs.sequence_length] + patches = torch.cat(patches) + patch_position_ids = torch.cat(patch_position_ids) + kwargs[VisionEncoderKwargs.image_patches] = patches + kwargs[VisionKwargs.patch_position_ids] = patch_position_ids + kwargs[VisionEncoderKwargs.max_image_tokens] = div(max_image_size * im_width, self._config.patch_size**2) + # sequence data parallel is not yet supported for images, so we use the same cu_seqlens for q and k + kwargs[TransformerKwargs.cu_seqlens_q] = torch.tensor( + cu_seqlens, device=self._tensor_space.distributed.device, dtype=torch.int32 + ) + kwargs[TransformerKwargs.cu_seqlens_k] = torch.tensor( + cu_seqlens, device=self._tensor_space.distributed.device, dtype=torch.int32 + ) + kwargs[TransformerKwargs.max_seqlen_q] = max_seqlen + kwargs[TransformerKwargs.max_seqlen_k] = max_seqlen + kwargs[LanguageModelKwargs.labels] = labels diff --git a/fast_llm/models/gpt/config.py b/fast_llm/models/gpt/config.py index 0da16428e..039b97f8c 100644 --- a/fast_llm/models/gpt/config.py +++ b/fast_llm/models/gpt/config.py @@ -71,6 +71,17 @@ class 
DiffusionLlamaGPTHuggingfaceCheckpointFormat(GPTHuggingfaceCheckpointForma trust_remote_code: typing.ClassVar[bool] = True +class LlavaGPTHuggingfaceCheckpointFormat(GPTHuggingfaceCheckpointFormat): + name: typing.ClassVar[str] = "llava" + # Using default values for vision and text models. Can be overridden in the config + vision_name: typing.ClassVar[str] = "pixtral" + text_name: typing.ClassVar[str] = "mistral" + + +class PixtralGPTHuggingfaceCheckpointFormat(GPTHuggingfaceCheckpointFormat): + name: typing.ClassVar[str] = "pixtral" + + @config_class() class GPTBatchConfig(BatchConfig): sequence_length: int = Field( @@ -163,6 +174,8 @@ class GPTModelConfig(FastLLMModelConfig): MTPLlamaGPTHuggingfaceCheckpointFormat, DiffusionDreamGPTHuggingfaceCheckpointFormat, DiffusionLlamaGPTHuggingfaceCheckpointFormat, + LlavaGPTHuggingfaceCheckpointFormat, + PixtralGPTHuggingfaceCheckpointFormat, ) @classmethod @@ -177,6 +190,25 @@ def get_huggingface_model_for_causal_lm_class(cls) -> type["HuggingfaceGPTModelF return HuggingfaceGPTModelForCausalLM + @classmethod + def get_checkpoint_format(cls, format: type[CheckpointFormat]) -> type[CheckpointFormat]: + if isinstance(format, type) and issubclass(format, CheckpointFormat): + format_ = cls.get_checkpoint_format(format.name) + Assert.is_(format, format_) + return format_ + elif isinstance(format, dict): + for format_ in cls.checkpoint_formats: + if format_.name == format["name"]: + if (vision_name := format.get("vision_name")) is not None: + format_.vision_name = vision_name + if (text_name := format.get("text_name")) is not None: + format_.text_name = text_name + return format_ + for format_ in cls.checkpoint_formats: + if format_.name == format: + return format_ + raise ValueError(f"Checkpoint format {format} not supported for model {cls.model_name}") + @config_class() class PretrainedGPTModelConfig(PretrainedFastLLMModelConfig): diff --git a/fast_llm/models/gpt/conversion.py b/fast_llm/models/gpt/conversion.py index d8425786d..a15a237f9 100644 --- a/fast_llm/models/gpt/conversion.py +++ b/fast_llm/models/gpt/conversion.py @@ -6,8 +6,10 @@ import torch from transformers.configuration_utils import PretrainedConfig -from fast_llm.config import DEFAULT, MISSING -from fast_llm.engine.checkpoint.config import CheckpointFormat +from fast_llm import __version__ +from fast_llm.config import DEFAULT, MISSING, get_nested_dict_value, set_nested_dict_value +from fast_llm.engine.base_model.config import BaseModelConfig +from fast_llm.engine.checkpoint.config import CheckpointFormat, CheckpointLoadMetadataConfig from fast_llm.engine.checkpoint.external import ( AutoStateDictCheckpointHandler, ConstantExportParamConverter, @@ -22,11 +24,16 @@ WeightConverter, ) from fast_llm.engine.checkpoint.huggingface import CustomModelingExportMixin, HuggingfaceStateDictCheckpointHandler -from fast_llm.engine.multi_stage.config import FastLLMModelConfig +from fast_llm.engine.multi_stage.config import CheckpointMetadata, FastLLMModelConfig from fast_llm.functional.config import ActivationType from fast_llm.layers.common.config import LayerNormalizationConfig from fast_llm.layers.transformer.config import RoutingType, TransformerConfig -from fast_llm.layers.transformer.rotary.config import DefaultRotaryConfig, Llama3RotaryConfig, YarnRotaryConfig +from fast_llm.layers.transformer.rotary.config import ( + DefaultRotaryConfig, + Llama3RotaryConfig, + Rotary2DConfig, + YarnRotaryConfig, +) from fast_llm.layers.transformer.rotary.rotary import convert_rotary_complex_to_real, 
convert_rotary_real_to_complex from fast_llm.models.gpt.config import ( DiffusionDreamGPTHuggingfaceCheckpointFormat, @@ -34,9 +41,11 @@ GPTBaseModelConfig, GPTModelConfig, LlamaGPTHuggingfaceCheckpointFormat, + LlavaGPTHuggingfaceCheckpointFormat, MistralGPTHuggingfaceCheckpointFormat, MixtralGPTHuggingfaceCheckpointFormat, MTPLlamaGPTHuggingfaceCheckpointFormat, + PixtralGPTHuggingfaceCheckpointFormat, Qwen2GPTHuggingfaceCheckpointFormat, Starcoder2GPTHuggingfaceCheckpointFormat, ) @@ -115,7 +124,37 @@ def import_weight( return (merged_weight.t().contiguous(),) -class CommonHuggingfaceCheckpointHandler(HuggingfaceStateDictCheckpointHandler): +class WeightAndBiasConverterMixin: + def _get_weight_and_bias_converters( + self, + fast_llm_prefix: str | tuple[str, ...], + hf_prefix: str | tuple[str, ...], + use_bias: bool, + cls=WeightConverter, + ) -> list[WeightConverter]: + if isinstance(fast_llm_prefix, str): + fast_llm_prefix = (fast_llm_prefix,) + if isinstance(hf_prefix, str): + hf_prefix = (hf_prefix,) + converters = [ + cls( + tuple(f"{prefix}.weight" for prefix in fast_llm_prefix), + tuple(f"{prefix}.weight" for prefix in hf_prefix), + self._model.config.base_model, + ) + ] + if use_bias: + converters.append( + cls( + tuple(f"{prefix}.bias" for prefix in fast_llm_prefix), + tuple(f"{prefix}.bias" for prefix in hf_prefix), + self._model.config.base_model, + ) + ) + return converters + + +class CommonHuggingfaceCheckpointHandler(WeightAndBiasConverterMixin, HuggingfaceStateDictCheckpointHandler): _model: GPTModel _model_class: typing.ClassVar[FastLLMModelConfig] = GPTModelConfig architecture: typing.ClassVar[str] @@ -126,6 +165,9 @@ class CommonHuggingfaceCheckpointHandler(HuggingfaceStateDictCheckpointHandler): @classmethod def _create_config_converters(cls) -> list[ParamConverter]: return super()._create_config_converters() + [ + ConstantImportParamConverter( + fast_llm_names=(("transformer", "type"),), fast_llm_value="language_model_decoder" + ), ConstantExportParamConverter(export_names=(("architectures",),), export_value=[cls.architecture]), ConstantImportParamConverter(fast_llm_names=(("use_position_embeddings",),), fast_llm_value=False), RenameParamConverter( @@ -173,17 +215,23 @@ def _get_mlp_converters(self, fast_llm_prefix: str, hf_prefix: str) -> list[Weig def _create_weight_converters( self, + hf_base_prefix: str = "", + offset: int = 0, ) -> list[WeightConverter]: converters = [] num_layers = self._model.config.base_model.transformer.num_layers # Embeddings - converters.append(WeightConverter("layers.0.word_embeddings_weight", "model.embed_tokens.weight")) + converters.append( + WeightConverter(f"layers.{offset}.word_embeddings_weight", f"{hf_base_prefix}model.embed_tokens.weight") + ) - converters += self._create_lm_head_converters() + converters += self._create_lm_head_converters(hf_base_prefix, offset=offset) for i in range(num_layers): - converters += self._create_transformer_layer_converters(f"layers.{i+1}", f"model.layers.{i}") + converters += self._create_transformer_layer_converters( + f"layers.{i+offset+1}", f"{hf_base_prefix}model.layers.{i}" + ) return converters @@ -254,7 +302,7 @@ def _create_transformer_layer_converters( converters += self._get_mlp_converters(f"{fast_llm_layer_name}", f"{hf_layer_name}") return converters - def _create_lm_head_converters(self) -> list[WeightConverter]: + def _create_lm_head_converters(self, hf_base_prefix: str = "", offset: int = 0) -> list[WeightConverter]: num_layers = self._model.config.base_model.transformer.num_layers 
prediction_heads = self._model.config.base_model.prediction_heads norm_bias: bool = isinstance(self._model.config.base_model.transformer.normalization, LayerNormalizationConfig) @@ -263,20 +311,22 @@ def _create_lm_head_converters(self) -> list[WeightConverter]: # Next-token prediction head # Final norm converters += self._get_weight_and_bias_converters( - f"layers.{num_layers + 1}.final_norm", "model.norm", norm_bias + f"layers.{num_layers + offset + 1}.final_norm", f"{hf_base_prefix}model.norm", norm_bias ) # Output weights if self._model.config.base_model.tie_word_embeddings: - converters.append(IgnoreImportWeightConverter((), "lm_head.weight")) + converters.append(IgnoreImportWeightConverter((), f"{hf_base_prefix}lm_head.weight")) else: - converters.append(WeightConverter(f"layers.{num_layers + 1}.output_weights", "lm_head.weight")) + converters.append( + WeightConverter(f"layers.{num_layers + offset + 1}.output_weights", f"{hf_base_prefix}lm_head.weight") + ) # MTP-heads > 0 are thrown away for i in range(1, prediction_heads): logger.warning( f"The model weights for the multi-token prediction head {i} are discarded during conversion." ) - mtp_transformer_layer_index = num_layers - 1 + 2 * i + mtp_transformer_layer_index = num_layers + offset - 1 + 2 * i # MTP transformer layer converters += self._create_transformer_layer_converters( f"layers.{mtp_transformer_layer_index + 1}", "", ignore_export=True @@ -319,10 +369,10 @@ def _get_weight_and_bias_converters( class Starcoder2HuggingfaceCheckpointHandler(CommonHuggingfaceCheckpointHandler): format: typing.ClassVar[type[CheckpointFormat]] = Starcoder2GPTHuggingfaceCheckpointFormat - architecture: typing.ClassVar[str] = "Starcoder2ForCausalLM" @classmethod def _create_config_converters(cls) -> list[ParamConverter]: + cls.architecture = "Starcoder2ForCausalLM" return super()._create_config_converters() + [ ConstantImportParamConverter( fast_llm_names=(("transformer", "rotary", "type"),), @@ -389,7 +439,7 @@ def __post_init__(self): def export_params(self, fast_llm_values: tuple[typing.Any, ...]) -> tuple[typing.Any, ...]: (rotary_config,) = fast_llm_values - if type(rotary_config) is DefaultRotaryConfig: + if type(rotary_config) is DefaultRotaryConfig or rotary_config is Rotary2DConfig: rotary_scaling = { "rope_type": "default", } @@ -446,10 +496,10 @@ def import_params(self, export_values: tuple[typing.Any, ...]) -> tuple[typing.A class LlamaHuggingfaceCheckpointHandler(CommonLlamaHuggingfaceCheckpointHandler): format: typing.ClassVar[type[CheckpointFormat]] = LlamaGPTHuggingfaceCheckpointFormat - architecture: typing.ClassVar[str] = "LlamaForCausalLM" @classmethod def _create_config_converters(cls) -> list[ParamConverter]: + cls.architecture = "LlamaForCausalLM" return super()._create_config_converters() + [ # TODO: Llama supports biases ConstantExportParamConverter(export_names=(("attention_bias",),), export_value=False), @@ -498,10 +548,10 @@ def import_params(self, export_values: tuple[typing.Any, ...]) -> tuple[typing.A class Qwen2HuggingfaceCheckpointHandler(CommonHuggingfaceCheckpointHandler): format: typing.ClassVar[type[CheckpointFormat]] = Qwen2GPTHuggingfaceCheckpointFormat - architecture: typing.ClassVar[str] = "Qwen2ForCausalLM" @classmethod def _create_config_converters(cls) -> list[ParamConverter]: + cls.architecture = "Qwen2ForCausalLM" return super()._create_config_converters() + [ ConstantImportParamConverter( fast_llm_names=(("transformer", "normalization", "type"),), @@ -544,10 +594,10 @@ def _get_mlp_converters(self, 
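In `export_params` above, `rotary_config is Rotary2DConfig` compares an instance against the class, so the 2D case would never take the default branch as written. A sketch of the presumably intended check, assuming instances are passed as elsewhere in this file (the other rotary variants are handled by separate branches in the real converter):

    # Sketch: mapping a rotary config instance to the HF "rope_scaling" default payload.
    from fast_llm.layers.transformer.rotary.config import DefaultRotaryConfig, Rotary2DConfig

    def export_default_rope_scaling(rotary_config):
        if type(rotary_config) in (DefaultRotaryConfig, Rotary2DConfig):
            return {"rope_type": "default"}
        raise ValueError(f"Unsupported rotary type: {type(rotary_config).__name__}")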
fast_llm_prefix: str, hf_prefix: str) -> list[Weig class MistralHuggingfaceCheckpointHandler(CommonLlamaHuggingfaceCheckpointHandler): format: typing.ClassVar[type[CheckpointFormat]] = MistralGPTHuggingfaceCheckpointFormat - architecture: typing.ClassVar[str] = "MistralForCausalLM" @classmethod def _create_config_converters(cls) -> list[ParamConverter]: + cls.architecture = "MistralForCausalLM" return super()._create_config_converters() + [ IgnoreImportParamConverter(export_names=(("sliding_window",),), ignore_export_value=None), ] @@ -566,12 +616,402 @@ def _get_mlp_converters(self, fast_llm_prefix: str, hf_prefix: str) -> list[Weig ] +class PixtralNumHeadsConverter(ParamConverter): + """ + Pixtral encoder uses Multi-Head Attention. + Map `num_attention_heads` and `head_groups` to a single `num_heads` parameter. + """ + + def __post_init__(self): + Assert.eq(len(self.fast_llm_names), 2) + Assert.eq(len(self.export_names), 1) + + def export_params(self, fast_llm_values: tuple[typing.Any, ...]) -> tuple[typing.Any, ...]: + (num_heads, head_groups) = fast_llm_values + assert head_groups == num_heads, "Pixtral encoder expects num_heads == head_groups (MHA)" + return (num_heads,) + + def import_params(self, export_values: tuple[typing.Any, ...]) -> tuple[typing.Any, ...]: + (num_heads,) = export_values + return (num_heads, num_heads) + + +class PixtralRotaryParamConverter(ParamConverter): + """ + Pixtral encoder uses 2D Rotary Embeddings. + Map `rope_theta` to a single `rotary` parameter. `rotary_scaling` is not needed. + """ + + def __init__(self, fast_llm_names, export_names): + Assert.eq(len(fast_llm_names), 1) + Assert.eq(len(export_names), 1) + self.fast_llm_names = fast_llm_names + self.export_names = export_names + + def export_params(self, fast_llm_values: tuple[typing.Any, ...]) -> tuple[typing.Any, ...]: + (rotary_config,) = fast_llm_values + if type(rotary_config) is Rotary2DConfig: + return (rotary_config.theta,) + else: + raise ValueError(f"Unsupported rotary type: {type(rotary_config).__name__}") + + def import_params(self, export_values: tuple[typing.Any, ...]) -> tuple[typing.Any, ...]: + (rotary_theta,) = export_values + rotary_config = { + "type": "rope_2d", + "theta": rotary_theta, + } + return (rotary_config,) + + +class PixtralHuggingfaceCheckpointHandler(WeightAndBiasConverterMixin, HuggingfaceStateDictCheckpointHandler): + format: typing.ClassVar[type[CheckpointFormat]] = PixtralGPTHuggingfaceCheckpointFormat + _model_class: typing.ClassVar[FastLLMModelConfig] = GPTModelConfig + + @classmethod + def _create_config_converters(cls) -> list[ParamConverter]: + return super()._create_config_converters() + [ + ConstantImportParamConverter(fast_llm_names=(("type",),), fast_llm_value="pixtral"), + ConstantImportParamConverter(fast_llm_names=(("patch_normalization", "type"),), fast_llm_value="rms_norm"), + ConstantImportParamConverter( + fast_llm_names=(("transformer", "normalization", "type"),), fast_llm_value="rms_norm" + ), + ConstantImportParamConverter(fast_llm_names=(("transformer", "type"),), fast_llm_value="image_encoder"), + ConstantExportParamConverter(export_names=(("architectures",),), export_value=["PixtralVisionModel"]), + ConstantImportParamConverter(fast_llm_names=(("transformer", "causal"),), fast_llm_value=False), + RenameParamConverter( + fast_llm_names=( + ( + "transformer", + "num_layers", + ), + ), + export_names=(("num_hidden_layers",),), + ), + RenameParamConverter( + fast_llm_names=( + ( + "transformer", + "hidden_size", + ), + ), + 
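A quick round-trip of the two Pixtral-specific converters defined above, assuming `ParamConverter` accepts `fast_llm_names`/`export_names` keyword arguments as its other subclasses in this file do (here they are constructed directly for illustration; in practice they are instantiated inside `_create_config_converters`):

    # Sketch: behaviour of the Pixtral-specific param converters.
    from fast_llm.models.gpt.conversion import PixtralNumHeadsConverter, PixtralRotaryParamConverter

    heads = PixtralNumHeadsConverter(
        fast_llm_names=(("transformer", "num_attention_heads"), ("transformer", "head_groups")),
        export_names=(("num_attention_heads",),),
    )
    assert heads.import_params((16,)) == (16, 16)   # MHA: head_groups == num_heads
    assert heads.export_params((16, 16)) == (16,)

    rotary = PixtralRotaryParamConverter(
        fast_llm_names=(("transformer", "rotary"),),
        export_names=(("rope_theta",),),
    )
    assert rotary.import_params((10000.0,)) == ({"type": "rope_2d", "theta": 10000.0},)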
export_names=(("hidden_size",),), + ), + PixtralNumHeadsConverter( + fast_llm_names=( + ( + "transformer", + "num_attention_heads", + ), + ( + "transformer", + "head_groups", + ), + ), + export_names=(("num_attention_heads",),), + ), + RenameParamConverter( + fast_llm_names=( + ( + "transformer", + "ffn_hidden_size", + ), + ), + export_names=(("intermediate_size",),), + ), + MappedConfigParamConverter( + fast_llm_names=(("transformer", "activation_type"),), + export_names=(("hidden_act",),), + fast_llm_value=ActivationType.from_hf_name, + export_value=lambda activation_type: activation_type.hf_name, + ), + RenameParamConverter( + fast_llm_names=( + ( + "transformer", + "kv_channels", + ), + ), + export_names=(("head_dim",),), + ), + # ConstantImportParamConverter( + # fast_llm_names=(("transformer", "rotary", "type"),), fast_llm_value=RotaryEmbeddingType.rope_2d + # ), + # RenameParamConverter( + # fast_llm_names=( + # ( + # "transformer", + # "rotary", + # "theta", + # ), + # ), + # export_names=(("rope_theta",),), + # ), + PixtralRotaryParamConverter( + fast_llm_names=(("transformer", "rotary"),), + export_names=(("rope_theta",),), + ), + RenameParamConverter(fast_llm_names=(("patch_size",),), export_names=(("patch_size",),)), + ConstantImportParamConverter(fast_llm_names=(("transformer", "gated"),), fast_llm_value=True), + ConstantImportParamConverter(fast_llm_names=(("transformer", "add_linear_biases"),), fast_llm_value=False), + ] + + def _get_transformer_mlp_converters(self, fast_llm_prefix: str, hf_prefix: str) -> list[WeightConverter]: + return [ + SplitWeightConverter( + f"{fast_llm_prefix}.mlp.layer_1.weight", + (f"{hf_prefix}.feed_forward.gate_proj.weight", f"{hf_prefix}.feed_forward.up_proj.weight"), + ), + MLPLayer2Converter( + f"{fast_llm_prefix}.mlp.layer_2.weight", + f"{hf_prefix}.feed_forward.down_proj.weight", + self._model.config.base_model, + ), + ] + + def _create_vision_transformer_layer_converters( + self, transformer_layer_index: int, fast_llm_offset: int = 1, hf_base_prefix: str = "" + ) -> list[WeightConverter]: + # Vision transformer layer + transformer_config = self._model.config.base_model.vision_encoder.transformer + norm_bias: bool = isinstance(self._model.config.base_model.transformer.normalization, LayerNormalizationConfig) + name_bias_cls = [ + # Self-attn + ( + f"layers.{fast_llm_offset + transformer_layer_index}.self_attn.query", + f"{hf_base_prefix}transformer.layers.{transformer_layer_index}.attention.q_proj", + transformer_config.add_attn_qkv_bias, + QueryWeightConverter, + ), + ( + f"layers.{fast_llm_offset + transformer_layer_index}.self_attn.key_value", + ( + f"{hf_base_prefix}transformer.layers.{transformer_layer_index}.attention.k_proj", + f"{hf_base_prefix}transformer.layers.{transformer_layer_index}.attention.v_proj", + ), + transformer_config.add_attn_qkv_bias, + KeyValueWeightConverter, + ), + ( + f"layers.{fast_llm_offset + transformer_layer_index}.self_attn.dense", + f"{hf_base_prefix}transformer.layers.{transformer_layer_index}.attention.o_proj", + transformer_config.add_attn_dense_bias, + WeightConverter, + ), + # Norm + ( + f"layers.{fast_llm_offset + transformer_layer_index}.norm_1", + f"{hf_base_prefix}transformer.layers.{transformer_layer_index}.attention_norm", + norm_bias, + WeightConverter, + ), + ( + f"layers.{fast_llm_offset + transformer_layer_index}.norm_2", + f"{hf_base_prefix}transformer.layers.{transformer_layer_index}.ffn_norm", + norm_bias, + WeightConverter, + ), + ] + converters = [] + for fast_llm_prefix, hf_prefix, 
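The vision-transformer layer converters above reduce to a fixed per-layer name mapping; a sketch of that table for layer `i` (the helper name is illustrative):

    # Sketch: the per-layer name mapping implemented by
    # _create_vision_transformer_layer_converters. offset is the Fast-LLM index of the
    # first vision transformer layer, prefix the HF-side base prefix (e.g. "vision_tower.").
    def vision_layer_name_map(i: int, offset: int = 1, prefix: str = ""):
        fast = f"layers.{offset + i}"
        hf = f"{prefix}transformer.layers.{i}"
        return {
            f"{fast}.self_attn.query": f"{hf}.attention.q_proj",
            f"{fast}.self_attn.key_value": (f"{hf}.attention.k_proj", f"{hf}.attention.v_proj"),
            f"{fast}.self_attn.dense": f"{hf}.attention.o_proj",
            f"{fast}.norm_1": f"{hf}.attention_norm",
            f"{fast}.norm_2": f"{hf}.ffn_norm",
            f"{fast}.mlp.layer_1": (f"{hf}.feed_forward.gate_proj", f"{hf}.feed_forward.up_proj"),
            f"{fast}.mlp.layer_2": f"{hf}.feed_forward.down_proj",
        }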
use_bias, cls in name_bias_cls: + converters += self._get_weight_and_bias_converters( + fast_llm_prefix, + hf_prefix, + use_bias, + cls, + ) + # MLP + converters += self._get_transformer_mlp_converters( + f"layers.{fast_llm_offset + transformer_layer_index}", + f"{hf_base_prefix}transformer.layers.{transformer_layer_index}", + ) + return converters + + def _create_weight_converters(self, offset: int = 0, hf_base_prefix: str = "") -> list[WeightConverter]: + converters = [] + norm_bias = isinstance( + self._model.config.base_model.vision_encoder.patch_normalization, LayerNormalizationConfig + ) + converters.append(WeightConverter(f"layers.{offset}.weight", f"{hf_base_prefix}patch_conv.weight")) + if self._model.config.base_model.vision_encoder.conv_bias: + converters.append(WeightConverter(f"layers.{offset}.bias", f"{hf_base_prefix}patch_conv.bias")) + converters.append(WeightConverter(f"layers.{offset}.normalization.weight", f"{hf_base_prefix}ln_pre.weight")) + if norm_bias: + converters.append(WeightConverter(f"layers.{offset}.normalization.bias", f"{hf_base_prefix}ln_pre.bias")) + + num_layers = self._model.config.base_model.vision_encoder.transformer.num_layers + for i in range(num_layers): + converters += self._create_vision_transformer_layer_converters(i, offset + 1, hf_base_prefix) + + converters.extend( + [ + WeightConverter( + f"layers.{offset + num_layers + 1}.layer_1.weight", "multi_modal_projector.linear_1.weight" + ), + WeightConverter( + f"layers.{offset + num_layers + 1}.layer_2.weight", "multi_modal_projector.linear_2.weight" + ), + ] + ) + if self._model.config.base_model.vision_encoder.adapter_bias: + converters.extend( + [ + WeightConverter( + f"layers.{offset + num_layers + 1}.layer_1.bias", "multi_modal_projector.linear_1.bias" + ), + WeightConverter( + f"layers.{offset + num_layers + 1}.layer_2.bias", "multi_modal_projector.linear_2.bias" + ), + ] + ) + + return converters + + @property + def num_layers(self) -> int: + # +2 for projector and conv layers + return self._model.config.base_model.vision_encoder.transformer.num_layers + 2 + + +class LlavaHuggingfaceCheckpointHandler(HuggingfaceStateDictCheckpointHandler): + format: typing.ClassVar[type[CheckpointFormat]] = LlavaGPTHuggingfaceCheckpointFormat + _model_class: typing.ClassVar[FastLLMModelConfig] = GPTModelConfig + + @classmethod + def _load_metadata(cls, config: CheckpointLoadMetadataConfig) -> CheckpointMetadata: + cfg_dict = cls._load_config(config.path) + kwargs = {} + if "text_config" in cfg_dict: + text_kwargs = cls._import_config(cfg_dict["text_config"]) + kwargs.update(text_kwargs) + if "vision_config" in cfg_dict: + vision_kwargs = cls._import_config(cfg_dict["vision_config"]) + vision_kwargs = {tuple(["vision_encoder"] + list(key)): value for key, value in vision_kwargs.items()} + kwargs.update(vision_kwargs) + kwargs.update( + cls._import_config( + {key: value for key, value in cfg_dict.items() if key not in ("text_config", "vision_config")} + ) + ) + imported_model_config = cls._model_class.get_base_model_config_class().from_dict({}, kwargs) + return CheckpointMetadata( + fast_llm_version=__version__, + model=cls._model_class, + format=config.format, + config=cls._model_class.from_dict({"base_model": imported_model_config.to_dict()}), + shards=["weights"], + ) + + @classmethod + def _create_config_converters(cls) -> list[ParamConverter]: + return super()._create_config_converters() + [ + ConstantExportParamConverter( + export_names=(("architectures",),), export_value=["LlavaForConditionalGeneration"] 
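The Pixtral weight converters lay the vision tower out as one patch-convolution layer, the vision transformer layers, and one adapter, which is why the handler's `num_layers` property adds 2. A sketch of the resulting Fast-LLM layer indices:

    # Sketch: Fast-LLM layer layout produced by the Pixtral handler.
    def pixtral_layer_names(num_vision_layers: int, offset: int = 0) -> list[str]:
        return (
            [f"layers.{offset}  (patch_conv + ln_pre)"]
            + [f"layers.{offset + 1 + i}  (vision transformer layer {i})" for i in range(num_vision_layers)]
            + [f"layers.{offset + num_vision_layers + 1}  (multi_modal_projector / adapter)"]
        )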
+ ), + MappedConfigParamConverter( + fast_llm_names=(("vision_encoder", "adapter_activation_type"),), + export_names=(("projector_hidden_act",),), + fast_llm_value=ActivationType.from_hf_name, + export_value=lambda activation_type: activation_type.hf_name, + ), + RenameParamConverter( + fast_llm_names=(("vision_encoder", "adapter_size"),), + export_names=(("projector_intermediate_size",),), + ), + ] + + @classmethod + def _import_config(cls, config: dict[str, typing.Any]) -> GPTBaseModelConfig: + handler_cls = AutoGPTHuggingfaceCheckpointHandler.get_handler_class(config["model_type"]) + kwargs = {} + for converter in handler_cls._create_config_converters(): + try: + values = () + for export_name in converter.export_names: + try: + value = get_nested_dict_value(config, export_name) + except KeyError: + value = MISSING + values = values + (value,) + values = converter.import_params(values) + for fast_llm_name, value in zip(converter.fast_llm_names, values, strict=True): + if value is MISSING: + raise ValueError(f"Missing converted value for fast-llm parameter {fast_llm_name}") + if fast_llm_name in kwargs: + raise ValueError(f"Duplicate converted value for fast-llm parameter {fast_llm_name}") + kwargs[fast_llm_name] = value + except Exception as e: + raise RuntimeError(f"Config conversion failed for converter {converter}", *e.args) + + return kwargs + + @classmethod + def _export_config(cls, config: BaseModelConfig) -> dict[str, typing.Any]: + exported_config = {} + vision_handler_cls = AutoGPTHuggingfaceCheckpointHandler.get_handler_class(cls.format.vision_name) + text_handler_cls = AutoGPTHuggingfaceCheckpointHandler.get_handler_class(cls.format.text_name) + for converter in vision_handler_cls._create_config_converters(): + try: + values = converter.export_params( + tuple( + cls._get_fast_llm_attribute(config, ("vision_encoder",) + fast_llm_name) + for fast_llm_name in converter.fast_llm_names + ) + ) + for export_name, value in zip(converter.export_names, values, strict=True): + if value is not MISSING: + set_nested_dict_value(exported_config, ("vision_config",) + export_name, value) + except Exception as e: + raise RuntimeError(f"Config conversion failed for converter {converter}", *e.args) + + for converter in text_handler_cls._create_config_converters(): + try: + values = converter.export_params( + tuple( + cls._get_fast_llm_attribute(config, fast_llm_name) + for fast_llm_name in converter.fast_llm_names + ) + ) + for export_name, value in zip(converter.export_names, values, strict=True): + if value is not MISSING: + set_nested_dict_value(exported_config, ("text_config",) + export_name, value) + except Exception as e: + raise RuntimeError(f"Config conversion failed for converter {converter}", *e.args) + + for converter in cls._create_config_converters(): + try: + values = converter.export_params( + tuple( + cls._get_fast_llm_attribute(config, fast_llm_name) + for fast_llm_name in converter.fast_llm_names + ) + ) + for export_name, value in zip(converter.export_names, values, strict=True): + if value is not MISSING: + set_nested_dict_value(exported_config, export_name, value) + except Exception as e: + raise RuntimeError(f"Config conversion failed for converter {converter}", *e.args) + + return exported_config + + def _create_weight_converters(self): + vision_handler_cls = AutoGPTHuggingfaceCheckpointHandler.get_handler_class(self.format.vision_name) + vision_handler = vision_handler_cls(self._model) + converters = vision_handler._create_weight_converters(hf_base_prefix="vision_tower.", 
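`_export_config` above delegates to the vision and text sub-handlers and nests their output under `vision_config` / `text_config`, keeping only the Llava-level fields at the top. A sketch of the resulting shape; keys and values are illustrative, not exhaustive:

    # Sketch: shape of the exported Hugging Face config (values illustrative only).
    exported = {
        "architectures": ["LlavaForConditionalGeneration"],
        "projector_hidden_act": "gelu",           # from ("vision_encoder", "adapter_activation_type")
        "projector_intermediate_size": 4096,      # from ("vision_encoder", "adapter_size")
        "vision_config": {                        # filled by the Pixtral handler's converters
            "architectures": ["PixtralVisionModel"],
            "num_hidden_layers": 24,
            "rope_theta": 10000.0,
        },
        "text_config": {                          # filled by the Mistral handler's converters
            "architectures": ["MistralForCausalLM"],
            "num_hidden_layers": 32,
        },
    }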
offset=0) + text_handler_cls = AutoGPTHuggingfaceCheckpointHandler.get_handler_class(self.format.text_name) + text_handler = text_handler_cls(self._model) + converters.extend( + text_handler._create_weight_converters(hf_base_prefix="language_model.", offset=vision_handler.num_layers) + ) + return converters + + class MixtralHuggingfaceCheckpointHandler(CommonLlamaHuggingfaceCheckpointHandler): format: typing.ClassVar[type[CheckpointFormat]] = MixtralGPTHuggingfaceCheckpointFormat - architecture: typing.ClassVar[str] = "MixtralForCausalLM" @classmethod def _create_config_converters(cls) -> list[ParamConverter]: + cls.architecture = "MixtralForCausalLM" return super()._create_config_converters() + [ ConstantImportParamConverter( fast_llm_names=(("transformer", "expert_routing_type"),), fast_llm_value=RoutingType.topk @@ -609,13 +1049,13 @@ class MTPLlamaHuggingfaceCheckpointHandler(CustomModelingExportMixin, CommonLlam from fast_llm.models.gpt.external.mtp_llama import configuration_mtp_llama, modeling_mtp_llama format: typing.ClassVar[type[CheckpointFormat]] = MTPLlamaGPTHuggingfaceCheckpointFormat - architecture: typing.ClassVar[str] = "MTPLlamaForCausalLM" modeling_file = modeling_mtp_llama.__file__ configuration_file = configuration_mtp_llama.__file__ configuration_cls: typing.ClassVar[type[PretrainedConfig]] = MTPLlamaConfig @classmethod def _create_config_converters(cls) -> list[ParamConverter]: + cls.architecture = "MTPLlamaForCausalLM" return super()._create_config_converters() + [ ConstantExportParamConverter( export_names=(("auto_map",),), @@ -697,7 +1137,6 @@ class DiffusionDreamHuggingfaceCheckpointHandler(CustomModelingExportMixin, Qwen from fast_llm.models.gpt.external.diffusion_dream import configuration_dream, generation_utils, modeling_dream format: typing.ClassVar[type[CheckpointFormat]] = DiffusionDreamGPTHuggingfaceCheckpointFormat - architecture: typing.ClassVar[str] = "DreamModel" modeling_file = modeling_dream.__file__ configuration_file = configuration_dream.__file__ generation_utils_file = generation_utils.__file__ @@ -705,6 +1144,7 @@ class DiffusionDreamHuggingfaceCheckpointHandler(CustomModelingExportMixin, Qwen @classmethod def _create_config_converters(cls) -> list[ParamConverter]: + cls.architecture = "DreamModel" return super()._create_config_converters() + [ ConstantExportParamConverter( export_names=(("auto_map",),), @@ -725,7 +1165,6 @@ class DiffusionLlamaHuggingfaceCheckpointHandler(CustomModelingExportMixin, Llam ) format: typing.ClassVar[type[CheckpointFormat]] = DiffusionLlamaGPTHuggingfaceCheckpointFormat - architecture: typing.ClassVar[str] = "DiffusionLlamaModel" modeling_file = modeling_diffusion_llama.__file__ configuration_file = configuration_diffusion_llama.__file__ generation_utils_file = generation_utils.__file__ @@ -733,6 +1172,7 @@ class DiffusionLlamaHuggingfaceCheckpointHandler(CustomModelingExportMixin, Llam @classmethod def _create_config_converters(cls) -> list[ParamConverter]: + cls.architecture = "DiffusionLlamaModel" return super()._create_config_converters() + [ ConstantExportParamConverter( export_names=(("auto_map",),), @@ -763,4 +1203,6 @@ class AutoGPTHuggingfaceCheckpointHandler( MTPLlamaGPTHuggingfaceCheckpointFormat.name: MTPLlamaHuggingfaceCheckpointHandler, DiffusionDreamGPTHuggingfaceCheckpointFormat.name: DiffusionDreamHuggingfaceCheckpointHandler, DiffusionLlamaGPTHuggingfaceCheckpointFormat.name: DiffusionLlamaHuggingfaceCheckpointHandler, + LlavaGPTHuggingfaceCheckpointFormat.name: LlavaHuggingfaceCheckpointHandler, + 
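`_create_weight_converters` composes the two towers by namespace: vision weights are exported under `vision_tower.` starting at Fast-LLM layer 0, and the text handler is called with `offset=vision_handler.num_layers` so the decoder picks up right after the adapter. A small sketch of the resulting split (example names assume 24 vision transformer layers):

    # Sketch: how the combined Llava converter list partitions parameter names.
    def llava_split_example(num_vision_transformer_layers: int = 24):
        text_offset = num_vision_transformer_layers + 2   # == vision_handler.num_layers
        return {
            "vision": ("layers.1.self_attn.query.weight",
                       "vision_tower.transformer.layers.0.attention.q_proj.weight"),
            "text": (f"layers.{text_offset}.word_embeddings_weight",
                     "language_model.model.embed_tokens.weight"),
        }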
PixtralGPTHuggingfaceCheckpointFormat.name: PixtralHuggingfaceCheckpointHandler, } diff --git a/fast_llm/models/gpt/model.py b/fast_llm/models/gpt/model.py index 444ad72b2..754a0235e 100644 --- a/fast_llm/models/gpt/model.py +++ b/fast_llm/models/gpt/model.py @@ -14,6 +14,7 @@ from fast_llm.layers.language_model.embedding import WORD_EMBEDDINGS_WEIGHT, LanguageModelEmbedding from fast_llm.layers.language_model.head import OUTPUT_WEIGHTS, LanguageModelHead from fast_llm.layers.language_model.preprocessing import PositionEmbeddingPreprocessor, PreferenceSpanPreprocessor +from fast_llm.layers.multi_modal.embedding import MultiModalEmbedding from fast_llm.layers.transformer.config import ( RoutingType, TransformerDimNames, @@ -21,7 +22,11 @@ TransformerLossNames, ) from fast_llm.layers.transformer.preprocessing import BackupAttentionPreprocessor, FlashAttnVarlenPreprocessor -from fast_llm.layers.transformer.transformer import TransformerLayer +from fast_llm.layers.transformer.transformer import TransformerLayer, VisionTransformerLayer +from fast_llm.layers.vision_encoder.adapter import VisionAdapter +from fast_llm.layers.vision_encoder.config import VisionEncoderKwargs +from fast_llm.layers.vision_encoder.patch_conv import PatchConvolution +from fast_llm.layers.vision_encoder.preprocessing import VisionPreprocessor from fast_llm.models.gpt.config import GPTBaseModelConfig, GPTBatchConfig, GPTModelConfig from fast_llm.models.gpt.megatron import get_init_megatron from fast_llm.tensor import ParameterMeta, TensorMeta @@ -63,6 +68,10 @@ def __init__( if self._config.enable_dpo: # TODO better way to pass in? self._preprocessors.append(PreferenceSpanPreprocessor(self._config, self._tensor_space)) + if self._config.vision_encoder.enabled: + self._preprocessors.append(VisionPreprocessor(self._config.vision_encoder, self._tensor_space)) + self._preprocessors.append(self._config.vision_encoder.transformer.rotary.build(self._tensor_space)) + def get_output_layers(self) -> list[Layer]: layers = [] for i in range(self._config.prediction_heads): @@ -87,9 +96,25 @@ def get_output_layers(self) -> list[Layer]: ) return layers + def get_vision_layers(self) -> list[Layer]: + vit_layers = [ + VisionTransformerLayer(self._config.vision_encoder.transformer, self._tensor_space, layer_index=idx + 1) + for idx in range(self._config.vision_encoder.transformer.num_layers) + ] + return [ + PatchConvolution(self._config.vision_encoder, self._tensor_space), + *vit_layers, + VisionAdapter(self._config.vision_encoder, self._tensor_space), + MultiModalEmbedding(self._config, self._tensor_space), + ] + def get_layers(self) -> list[Layer]: return [ - LanguageModelEmbedding(self._config, self._tensor_space), + *( + self.get_vision_layers() + if not self._config.vision_encoder.enabled + else [LanguageModelEmbedding(self._config, self._tensor_space)] + ), *[ TransformerLayer( self._config.transformer, @@ -114,17 +139,18 @@ def preprocess_meta( micro_batch_size = batch_meta.micro_batch_size sequence_length = batch_meta.sequence_length micro_sequence_length = batch_meta.micro_sequence_length - truncate_documents = batch_meta.truncate_documents else: micro_batch_size, sequence_length = batch_meta.shape if phase != PhaseType.inference: sequence_length -= self._config.prediction_heads micro_sequence_length = sequence_length - truncate_documents = True batch_data = self._tensor_space.distributed_config.get_distributed_dim(DistributedDimNames.batch_data) batch_dim = TensorDim(TransformerDimNames.batch, micro_batch_size * batch_data.size, 
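In `get_layers`, the intent is presumably to build the vision stack when the vision encoder is enabled and fall back to the plain embedding otherwise; as written, the ternary's condition reads inverted. A minimal sketch of that selection under this assumption:

    # Sketch: choosing the front of the layer list (method-style, inside the base model).
    def _embedding_stack(self) -> list[Layer]:
        if self._config.vision_encoder.enabled:
            # PatchConvolution + VisionTransformerLayer * N + VisionAdapter + MultiModalEmbedding
            return self.get_vision_layers()
        return [LanguageModelEmbedding(self._config, self._tensor_space)]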
batch_data) + if isinstance(batch_meta, GPTBatchConfig): + micro_sequence_length = batch_meta.micro_sequence_length + if micro_sequence_length is None: micro_sequence_length = sequence_length else: @@ -168,9 +194,43 @@ def preprocess_meta( TransformerKwargs.hidden_dims: hidden_dims, TransformerKwargs.sequence_length: sequence_length, TransformerKwargs.sequence_q_dim: sequence_q_dim, - LanguageModelKwargs.mask_inputs: not truncate_documents, + TransformerKwargs.batch_dim: batch_dim, } + if self._config.vision_encoder.enabled: + max_image_size = batch_meta.max_image_size + image_mean = [ + self._config.vision_encoder.image_normalization.mean_red, + self._config.vision_encoder.image_normalization.mean_green, + self._config.vision_encoder.image_normalization.mean_blue, + ] + image_std = [ + self._config.vision_encoder.image_normalization.std_red, + self._config.vision_encoder.image_normalization.std_green, + self._config.vision_encoder.image_normalization.std_blue, + ] + image_rescale_factor = self._config.vision_encoder.image_normalization.rescale_factor + vision_kwargs = { + VisionEncoderKwargs.patch_size: self._config.vision_encoder.patch_size, + VisionEncoderKwargs.max_image_size: max_image_size, + VisionEncoderKwargs.image_rescale_factor: image_rescale_factor, + VisionEncoderKwargs.kv_channels: self._tensor_space.get_tensor_dim( + TransformerDimNames.kv_channels + ).size, + } + vision_hidden_dim = self._tensor_space.vision.get_tensor_dim(TransformerDimNames.hidden) + vision_hidden_dims = ( + (hidden_sequence_q_dim, batch_dim, vision_hidden_dim) + if sequence_first + else (batch_dim, hidden_sequence_q_dim, vision_hidden_dim) + ) + vision_kwargs.update( + { + TransformerKwargs.hidden_dims: vision_hidden_dims, + } + ) + common_kwargs["vision"] = vision_kwargs + sequence_k_pasts = range( sequence_q_dim.size * self._tensor_space.distributed_config.sequence_data_rank, sequence_length, @@ -215,7 +275,11 @@ def preprocess_meta( reference_kwargs[name] = reference_kwargs_ kwargs["reference_models"] = reference_kwargs - preprocessed_meta.append((tokens, kwargs)) + if self._config.vision_encoder.enabled: + # patch_dimensions are (batch * sequence_length) x 3 x patch_size x patch_size + preprocessed_meta.append((kwargs[VisionEncoderKwargs.image_patches_meta], kwargs)) + else: + preprocessed_meta.append((tokens, kwargs)) return preprocessed_meta @@ -322,22 +386,56 @@ def preprocess( if self._config.distillation_model is not None: kwargs[LanguageModelKwargs.loss_mask] = loss_mask labels = torch.where(loss_mask, labels, -100) + if self._config.vision_encoder.enabled: + if self._config.vision_encoder.image_break_token is not None: + labels = torch.where(labels == self._config.vision_encoder.image_break_token, -100, labels) + if self._config.vision_encoder.image_end_token is not None: + labels = torch.where(labels == self._config.vision_encoder.image_end_token, -100, labels) kwargs[LanguageModelKwargs.labels] = labels kwargs.update(reference_logits[i]) + if self._config.vision_encoder.enabled: + batch_images = ( + batch.images if batch.images is not None else [[]] * kwargs[TransformerKwargs.micro_batch_size] + ) + kwargs[VisionEncoderKwargs.images] = [ + [ + img.to(device=self._tensor_space.distributed.device, dtype=torch.uint8, non_blocking=True) + for img in images + ] + for images in batch_images + ] + kwargs[VisionEncoderKwargs.image_positions] = ( + batch.image_positions + if batch.image_positions is not None + else [[]] * kwargs[TransformerKwargs.micro_batch_size] + ) + 
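When image break/end tokens are configured, `preprocess` removes them from the loss by setting their label positions to the ignore index. A self-contained sketch of that masking:

    # Sketch: dropping image-structure tokens from the loss, as in preprocess() above.
    import torch

    def mask_image_tokens(labels: torch.Tensor, image_break_token: int | None, image_end_token: int | None):
        # -100 is the ignore index used for cross-entropy labels.
        for token in (image_break_token, image_end_token):
            if token is not None:
                labels = torch.where(labels == token, -100, labels)
        return labels

    masked = mask_image_tokens(torch.tensor([5, 12, 7, 13]), image_break_token=12, image_end_token=13)
    # tensor([   5, -100,    7, -100])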
kwargs[LanguageModelKwargs.tokens] = tokens + for preprocessor in self._preprocessors: preprocessor.preprocess(tokens, kwargs) - preprocessed.append((tokens, kwargs)) + image_patches = kwargs.get(VisionEncoderKwargs.image_patches, None) + if image_patches is not None: + preprocessed.append((image_patches, kwargs)) + else: + preprocessed.append((tokens, kwargs)) return preprocessed @property def embedding(self) -> LanguageModelEmbedding: - return self.layers[0] + return self.layers[self.embedding_layer_index] @property def transformer_layers(self) -> list[TransformerLayer]: - return self.layers[1:-1] + return self.layers[self.embedding_layer_index + 1 : -1] + + @property + def embedding_layer_index(self) -> int: + if self._config.vision_encoder.enabled: + return self._config.vision_encoder.transformer.num_layers + 2 + else: + return 0 @property def model_head(self) -> LanguageModelHead: @@ -352,7 +450,7 @@ def get_tied_weights(self) -> dict[str, tuple[ParameterMeta, tuple[int, ...]]]: return { WORD_EMBEDDINGS_WEIGHT: ( self.embedding.word_embeddings_weight, - (0, *self.model_head_indices), + (self.embedding_layer_index, *self.model_head_indices), ) } elif self._config.prediction_heads > 1: diff --git a/fast_llm/models/gpt/trainer.py b/fast_llm/models/gpt/trainer.py index 54508e8e1..b81a3767e 100644 --- a/fast_llm/models/gpt/trainer.py +++ b/fast_llm/models/gpt/trainer.py @@ -33,4 +33,13 @@ def _get_sampling_parameters( "extra_tokens": self._config.model.base_model.prediction_heads, } ) + if self._config.model.base_model.vision_encoder.enabled: + parameters.update( + { + "patch_size": self._config.model.base_model.vision_encoder.patch_size, + "max_image_size": self._config.batch.max_image_size, + "image_break_token": self._config.model.base_model.vision_encoder.image_break_token, + "image_end_token": self._config.model.base_model.vision_encoder.image_end_token, + } + ) return parameters if _return_dict else GPTSamplingParameters(**parameters) diff --git a/setup.cfg b/setup.cfg index 2f69b8e06..715f6b630 100644 --- a/setup.cfg +++ b/setup.cfg @@ -52,6 +52,13 @@ HUGGINGFACE = SSM = mamba_ssm[causal-conv1d]==2.2.4 +# Required for supporting vision inputs +VISION = + # Vision Tools + webp>=0.4.0 + pillow-simd>=9.5.0 + torchvision>=0.20.0 + DEV = # Pre-commit git hook pre-commit>=4.2.0 diff --git a/tests/data/common.py b/tests/data/common.py index 2bb90a6b4..6bd6b2126 100644 --- a/tests/data/common.py +++ b/tests/data/common.py @@ -231,3 +231,6 @@ def get_document_size(self, index: int) -> int: def get(self, index: int, *args, **kwargs) -> typing.Any: raise NotImplementedError() + + def has_images(self) -> bool: + return False diff --git a/tests/data/test_sampling.py b/tests/data/test_sampling.py index 32d76fa4c..b8e7a92ff 100644 --- a/tests/data/test_sampling.py +++ b/tests/data/test_sampling.py @@ -106,6 +106,10 @@ def get_document_size(self, index: int) -> int: def name(self) -> str: return "dataset" + @property + def has_images(self) -> bool: + return False + TEST_DATASET = SimpleGPTIndexedDataset( [ diff --git a/tests/test_config.py b/tests/test_config.py index b6a9a9854..9e09e8041 100644 --- a/tests/test_config.py +++ b/tests/test_config.py @@ -137,6 +137,7 @@ def test_pretrained_config(load_config: ModelConfigType, result_path): }, "tie_word_embeddings": False, "vocab_size": 1000, + "vision_encoder": {"type": "none"}, } else: base_model_update["transformer"]["peft"] = { @@ -146,6 +147,7 @@ def test_pretrained_config(load_config: ModelConfigType, result_path): } 
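`embedding_layer_index` accounts for the layers that now precede the (multi-modal) embedding, and the tied-embeddings tuple uses it in place of the hard-coded 0. A sketch of the arithmetic:

    # Sketch: where the word-embedding layer sits once the vision stack is prepended.
    def embedding_layer_index(vision_enabled: bool, num_vision_transformer_layers: int) -> int:
        if vision_enabled:
            # patch conv (1) + vision transformer layers (N) + adapter (1) come first,
            # so the multi-modal embedding is layer N + 2.
            return num_vision_transformer_layers + 2
        return 0

    assert embedding_layer_index(False, 0) == 0
    assert embedding_layer_index(True, 24) == 26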
base_model_update["transformer"]["normalization"]["type"] = "layer_norm" base_model_update["transformer"]["rotary"] = {"type": "none"} + base_model_update["vision_encoder"] = {"type": "none"} expected_config["base_model"] = base_model_update check_equal_nested(serialized_config, expected_config)
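The test updates encode the expectation that a model without a configured vision encoder still serializes the field explicitly. A sketch of the relevant fragment of the expected config (other keys elided):

    # Sketch: fragment of the expected serialized base-model config.
    expected_base_model_fragment = {
        "transformer": {"normalization": {"type": "layer_norm"}, "rotary": {"type": "none"}},
        "vision_encoder": {"type": "none"},   # emitted even when no vision encoder is configured
    }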