diff --git a/.github/workflows/ci.yaml b/.github/workflows/ci.yaml index 59cdd51b8..737f6fa3c 100644 --- a/.github/workflows/ci.yaml +++ b/.github/workflows/ci.yaml @@ -27,11 +27,12 @@ jobs: - name: Install dependencies run: | + sudo apt install libjpeg-dev pip install "torch>=2.7.0" pip install pybind11 FLASH_ATTENTION_SKIP_CUDA_BUILD=TRUE FLASH_ATTENTION_FORCE_BUILD=TRUE MAMBA_SKIP_CUDA_BUILD=TRUE \ MAMBA_FORCE_BUILD=TRUE CAUSAL_CONV1D_FORCE_BUILD=TRUE CAUSAL_CONV1D_SKIP_CUDA_BUILD=TRUE \ - pip install --no-build-isolation -e ".[CORE,OPTIONAL,HUGGINGFACE,SSM,GENERATION,DEV,DOCS]" + pip install --no-build-isolation -e ".[CORE,OPTIONAL,HUGGINGFACE,SSM,VISION,GENERATION,DEV,DOCS]" - name: Run tests run: pytest -v -ra . diff --git a/.github/workflows/docs.yaml b/.github/workflows/docs.yaml index 0eef80a3e..bb3547496 100644 --- a/.github/workflows/docs.yaml +++ b/.github/workflows/docs.yaml @@ -29,11 +29,12 @@ jobs: restore-keys: | mkdocs-material- - run: | + sudo apt install libjpeg-dev pip install "torch>=2.7.0" pip install pybind11 FLASH_ATTENTION_SKIP_CUDA_BUILD=TRUE FLASH_ATTENTION_FORCE_BUILD=TRUE MAMBA_SKIP_CUDA_BUILD=TRUE \ MAMBA_FORCE_BUILD=TRUE CAUSAL_CONV1D_FORCE_BUILD=TRUE CAUSAL_CONV1D_SKIP_CUDA_BUILD=TRUE \ - pip install --no-build-isolation -e ".[CORE,OPTIONAL,HUGGINGFACE,SSM,GENERATION,DEV,DOCS]" + pip install --no-build-isolation -e ".[CORE,OPTIONAL,HUGGINGFACE,SSM,VISION,GENERATION,DEV,DOCS]" - name: Build the documentation run: mkdocs build @@ -56,6 +57,7 @@ jobs: restore-keys: | mkdocs-material- - run: | + sudo apt install libjpeg-dev pip install "torch>=2.2.2" pip install pybind11 FLASH_ATTENTION_SKIP_CUDA_BUILD=TRUE FLASH_ATTENTION_FORCE_BUILD=TRUE MAMBA_SKIP_CUDA_BUILD=TRUE \ diff --git a/Dockerfile b/Dockerfile index 71f59fffe..7cf951017 100644 --- a/Dockerfile +++ b/Dockerfile @@ -29,15 +29,16 @@ ENV PIP_CONSTRAINT="" # There is no pre-build mamba image for pytorch 2.8, we build it before the rest to avoid rebuilds. # We need to compile from the repo because of https://github.com/state-spaces/mamba/issues/720 (same for causal-conv1d) # We set the number of workers to avoid OOM when compiling on laptop. (TODO: Can we make it configurable?) +# Using varlen_mamba for variable length sequence support RUN MAX_JOBS=2 pip install --no-build-isolation "causal-conv1d@git+https://github.com/Dao-AILab/causal-conv1d@2a288a1" -RUN MAX_JOBS=2 pip install --no-build-isolation "mamba_ssm[causal-conv1d]@git+https://github.com/state-spaces/mamba@4a8a2a2" +RUN MAX_JOBS=2 pip install --no-build-isolation "mamba_ssm[causal-conv1d]@git+https://github.com/jxiw/varlen_mamba@varlen_mamba" # Copy dependency files with universal write permissions for all users. COPY --chmod=777 setup.py setup.cfg pyproject.toml ./ COPY --chmod=777 ./fast_llm/__init__.py fast_llm/ COPY --chmod=777 ./fast_llm/csrc/ fast_llm/csrc/ # Install dependencies within the virtual environment. -RUN pip install --no-cache-dir --no-build-isolation -e ".[CORE,OPTIONAL,HUGGINGFACE,SSM,GENERATION,DEV]" triton==3.1.0 +RUN pip install --no-cache-dir --no-build-isolation -e ".[CORE,OPTIONAL,HUGGINGFACE,SSM,VISION,GENERATION,DEV]" triton==3.1.0 # Copy the remaining source code with universal write permissions. 
COPY --chmod=777 ./Megatron-LM Megatron-LM diff --git a/Megatron-LM b/Megatron-LM index 511e8f5cb..75b0d9787 160000 --- a/Megatron-LM +++ b/Megatron-LM @@ -1 +1 @@ -Subproject commit 511e8f5cbe3ab8291953ac64e5beceb727a1b814 +Subproject commit 75b0d97876006c4b6b23fce302100d18dbf7db37 diff --git a/fast_llm/data/data/gpt/data.py b/fast_llm/data/data/gpt/data.py index 6724afb59..9df9b9b86 100644 --- a/fast_llm/data/data/gpt/data.py +++ b/fast_llm/data/data/gpt/data.py @@ -32,6 +32,8 @@ class GPTBatch: token_ids: torch.Tensor loss_masking_spans: list[torch.Tensor] | None = None sequence_lengths: list[torch.Tensor] | None = None + images: list[torch.Tensor] | None = None + image_positions: list[torch.Tensor] | None = None chosen_spans: list[torch.Tensor] | None = None rejected_spans: list[torch.Tensor] | None = None @@ -49,12 +51,28 @@ def gpt_data_collate_fn(batch: list[GPTSample], sampling_parameters: GPTSampling stacked_rejected_spans = [torch.from_numpy(sample.rejected_span) for sample in batch] if not sampling_parameters.cross_document_attention: sequence_lengths = [torch.tensor(sample.sequence_lengths) for sample in batch] + has_images = False + batch_images = [] + for sample in batch: + if sample.images is not None: + batch_images.append([torch.from_numpy(image) for image in sample.images]) + has_images = True + else: + batch_images.append([]) + batch_image_positions = [] + for sample in batch: + if sample.image_positions is not None: + batch_image_positions.append(torch.from_numpy(sample.image_positions)) + else: + batch_image_positions.append([]) return GPTBatch( token_ids=torch.from_numpy(stacked_ids), loss_masking_spans=stacked_spans, sequence_lengths=sequence_lengths, chosen_spans=stacked_chosen_spans, rejected_spans=stacked_rejected_spans, + images=batch_images if has_images else None, + image_positions=batch_image_positions if has_images else None, ) diff --git a/fast_llm/data/dataset/gpt/config.py b/fast_llm/data/dataset/gpt/config.py index ef2efedc9..692776a24 100644 --- a/fast_llm/data/dataset/gpt/config.py +++ b/fast_llm/data/dataset/gpt/config.py @@ -76,6 +76,10 @@ class GPTSamplingParameters(SamplingParameters): use_preference_loss_spans: bool = False cross_document_attention: bool = True truncate_documents: bool = True + patch_size: int | None = None + max_image_size: int | None = None + image_break_token: int | None = None + image_end_token: int | None = None # How many extra tokens to add to the sequence length. # This is used to provide labels even for the last tokens in the sequence. 
extra_tokens: int = 1 @@ -142,11 +146,18 @@ class GPTMemmapDatasetConfig(GPTIndexedDatasetConfig): desc="Expected number of tokens in the dataset.", hint=FieldHint.optional, ) + num_pixels: int | None = Field( + default=None, + desc="Expected number of pixels in the dataset.", + hint=FieldHint.optional, + ) def build(self) -> "GPTMemmapDataset": from fast_llm.data.dataset.gpt.memmap import GPTMemmapDataset - return GPTMemmapDataset(str(self.path).replace("/", "__"), self.path, self.num_documents, self.num_tokens) + return GPTMemmapDataset( + str(self.path).replace("/", "__"), self.path, self.num_documents, self.num_tokens, self.num_pixels + ) @config_class(dynamic_type={GPTSampledDatasetConfig: "concatenated"}) diff --git a/fast_llm/data/dataset/gpt/fim.py b/fast_llm/data/dataset/gpt/fim.py index 2b2c8b3be..b05b79b24 100644 --- a/fast_llm/data/dataset/gpt/fim.py +++ b/fast_llm/data/dataset/gpt/fim.py @@ -158,9 +158,9 @@ def _fim_permute_sequence( middle = contents[boundaries[0] : boundaries[1]] suffix = contents[boundaries[1] :] - prefix = np.array([*self._tokenizer.tokenize(prefix, end=False)], dtype=np.int64) - middle = np.array([*self._tokenizer.tokenize(middle, begin=False, end=False)], dtype=np.int64) - suffix = np.array([*self._tokenizer.tokenize(suffix, begin=False)], dtype=np.int64) + prefix = np.array([*self._tokenizer._tokenize(prefix, end=False)], dtype=np.int64) + middle = np.array([*self._tokenizer._tokenize(middle, begin=False, end=False)], dtype=np.int64) + suffix = np.array([*self._tokenizer._tokenize(suffix, begin=False)], dtype=np.int64) # here we truncate each given segment to fit the same length as it was before # A consequence is that we never reach the end of a file? diff --git a/fast_llm/data/dataset/gpt/indexed.py b/fast_llm/data/dataset/gpt/indexed.py index 688ea6a70..8a4440ae4 100644 --- a/fast_llm/data/dataset/gpt/indexed.py +++ b/fast_llm/data/dataset/gpt/indexed.py @@ -34,6 +34,14 @@ def sample(self, sampling: GPTSamplingData) -> "GPTSampledIndexedDataset": else GPTSampledIndexedDataset(self, sampling) ) + @property + @abc.abstractmethod + def has_images(self) -> bool: + """ + Whether the dataset contains images. + This is used to determine whether to use image-related fields in the sampled data. + """ + class GPTDatasetSlice[IndexedDatasetType: GPTIndexedDataset](DatasetSlice[IndexedDatasetType], GPTIndexedDataset): """ @@ -44,11 +52,16 @@ class GPTDatasetSlice[IndexedDatasetType: GPTIndexedDataset](DatasetSlice[Indexe def get_document_sizes(self) -> np.ndarray: # TODO: This can be really big. - return self._dataset.get_document_sizes()[self._begin : self._end] + doc_sizes, im_sizes = self._dataset.get_document_sizes() + return doc_sizes[self._begin : self._end], im_sizes[self._begin : self._end] if im_sizes else np.array([]) def get_document_size(self, index: int) -> int: return self._dataset.get_document_size(self._begin + index) + @property + def has_images(self) -> bool: + return self._dataset.has_images + class GPTConcatenatedDataset[IndexedDatasetType: GPTIndexedDataset]( ConcatenatedDataset[IndexedDatasetType], GPTIndexedDataset @@ -57,8 +70,17 @@ class GPTConcatenatedDataset[IndexedDatasetType: GPTIndexedDataset]( def get_document_sizes(self) -> np.ndarray: # TODO: This can be really big. 
- return np.concatenate([dataset.get_document_sizes() for dataset in self._datasets]) + # return np.concatenate([dataset.get_document_sizes() for dataset in self._datasets]) + sizes = [dataset.get_document_sizes() for dataset in self._datasets] + return ( + np.concatenate([size[0] for size in sizes]), + np.concatenate([size[1] for size in sizes]) if sizes[0][1] is not None else np.array([]), + ) def get_document_size(self, index: int) -> int: dataset = np.searchsorted(self._dataset_splits[1:], index, side="right") return self._datasets[dataset].get_document_size(index - self._dataset_splits[dataset].item()) + + @property + def has_images(self) -> bool: + return any(dataset.has_images for dataset in self._datasets) diff --git a/fast_llm/data/dataset/gpt/memmap.py b/fast_llm/data/dataset/gpt/memmap.py index f39fd56f4..4f62561a8 100644 --- a/fast_llm/data/dataset/gpt/memmap.py +++ b/fast_llm/data/dataset/gpt/memmap.py @@ -1,8 +1,10 @@ +import io import pathlib import struct import typing import numpy as np +import PIL.Image from fast_llm.data.dataset.gpt.indexed import GPTIndexedDataset from fast_llm.data.dataset.gpt.sampled import GPTSample @@ -26,32 +28,46 @@ def __init__( prefix: pathlib.Path | str, num_documents: int | None = None, num_tokens: int | None = None, + num_pixels: int | None = None, ): - self._init(name, prefix, num_documents, num_tokens) + self._init(name, prefix, num_documents, num_tokens, num_pixels) - def _init(self, name: str, prefix: pathlib.Path | str, num_documents: int | None, num_tokens: int | None) -> None: + def _init( + self, + name: str, + prefix: pathlib.Path | str, + num_documents: int | None, + num_tokens: int | None, + num_pixels: int | None, + ) -> None: super().__init__() self._name = name self._prefix = pathlib.Path(prefix) self._has_spans = 0 + self._has_images = 0 self._has_preference_spans = False with self._prefix.with_suffix(".idx").open("rb") as stream: Assert.eq(stream.read(9), MEMMAP_INDEX_HEADER, msg=f"File: {stream.name}") self._version = struct.unpack("= 2: self._has_spans = struct.unpack("= 3: self._has_preference_spans = struct.unpack("= 4: + self._has_images = struct.unpack("= 2: @@ -77,9 +94,8 @@ def _init(self, name: str, prefix: pathlib.Path | str, num_documents: int | None self._index_bin_buffer, dtype=np.int32, count=self._num_documents, - offset=offset + self._document_sizes.nbytes + self._pointers.nbytes, + offset=offset, ) - span_offset = offset + self._document_sizes.nbytes + self._pointers.nbytes + self._num_spans.nbytes self._num_spans_cumsum = np.r_[0, np.cumsum(self._num_spans[:-1], dtype=np.int64)] for idx in range(self._num_documents): self._spans.append( @@ -87,30 +103,29 @@ def _init(self, name: str, prefix: pathlib.Path | str, num_documents: int | None self._index_bin_buffer, dtype=np.int32, count=self._num_spans[idx] * 2, - offset=span_offset + self._num_spans_cumsum[idx] * 2 * np.dtype(np.int32).itemsize, + offset=offset + + self._num_spans.nbytes + + self._num_spans_cumsum[idx] * 2 * np.dtype(np.int32).itemsize, ).reshape(-1, 2) ) - + offset += self._num_spans.nbytes + self._num_spans.sum() * 2 * np.dtype(np.int32).itemsize # read preference spans self._chosen_spans = None self._rejected_spans = None if self._has_preference_spans and self._version >= 3: self._chosen_spans = [] self._rejected_spans = [] - chosen_span_offset = offset + self._document_sizes.nbytes + self._pointers.nbytes for idx in range(self._num_documents): self._chosen_spans.append( np.frombuffer( self._index_bin_buffer, dtype=np.int32, count=2, - 
offset=chosen_span_offset + idx * 2 * np.dtype(np.int32).itemsize, + offset=offset + idx * 2 * np.dtype(np.int32).itemsize, ) ) - rejected_span_offset = ( - offset + self._document_sizes.nbytes + self._pointers.nbytes + np.array(self._chosen_spans).nbytes - ) + rejected_span_offset = offset + np.array(self._chosen_spans).nbytes for idx in range(self._num_documents): self._rejected_spans.append( np.frombuffer( @@ -120,16 +135,53 @@ def _init(self, name: str, prefix: pathlib.Path | str, num_documents: int | None offset=rejected_span_offset + idx * 2 * np.dtype(np.int32).itemsize, ) ) + offset += np.array(self._chosen_spans).nbytes + np.array(self._rejected_spans).nbytes + + self._num_pixels = 0 + self._image_sizes = [] + self._image_positions = None + if self._has_images and self._version >= 4: + self._n_images = np.frombuffer( + self._index_bin_buffer, dtype=np.int32, count=self._num_documents, offset=offset + ) + self._image_sizes = [] + self._image_positions = [] + images_seen = 0 + num_total_images = self._n_images.sum() + for n_images in self._n_images: + self._image_sizes.append( + np.frombuffer( + self._index_bin_buffer, + dtype=np.int32, + count=n_images * 2, + offset=offset + self._n_images.nbytes + 2 * images_seen * np.dtype(np.int32).itemsize, + ).reshape(-1, 2) + ) + self._num_pixels += self._image_sizes[-1].prod(axis=1, initial=3).sum() + self._image_positions.append( + np.frombuffer( + self._index_bin_buffer, + dtype=np.int32, + count=n_images, + offset=offset + + self._n_images.nbytes + + 2 * num_total_images * np.dtype(np.int32).itemsize + + +images_seen * np.dtype(np.int32).itemsize, + ) + ) + images_seen += n_images self._bin_buffer_mmap = np.memmap(self._prefix.with_suffix(".bin"), mode="r", order="C") self._bin_buffer = memoryview(self._bin_buffer_mmap) - self._num_tokens = div(self._bin_buffer_mmap.size, np.dtype(self._dtype).itemsize) + self._num_tokens = div(self._bin_buffer_mmap.size - self._num_pixels, np.dtype(self._dtype).itemsize) + if num_pixels is not None: + assert self._num_pixels == num_pixels if num_tokens is not None: assert self._num_tokens == num_tokens def __getstate__(self) -> tuple[str, pathlib.Path, int | None, int | None]: - return (self._name, self._prefix, self._num_documents, self._num_tokens) + return (self._name, self._prefix, self._num_documents, self._num_tokens, self._num_pixels) def __setstate__(self, state: tuple[str, pathlib.Path, int | None, int | None]): self._init(*state) @@ -156,6 +208,24 @@ def get( count=self._document_sizes[idx] - offset if length is None else length, offset=self._pointers[idx] + offset * np.dtype(self._dtype).itemsize, ) + images = None + image_positions = None + if self._has_images: + image_positions = self._image_positions[idx] + + # Truncations with images are not yet supported, so we get all images from the document + pixels = np.frombuffer( + self._bin_buffer, + dtype=np.dtype(np.uint8), + count=self._image_sizes[idx].prod(initial=3, axis=1).sum(), + offset=self._pointers[idx] + self._document_sizes[idx] * np.dtype(self._dtype).itemsize, + ) + images = [] + start = 0 + for image_size in self._image_sizes[idx]: + n_pixels = image_size.prod(initial=3) + images.append(pixels[start : start + n_pixels].reshape(3, image_size[0], image_size[1])) + start += n_pixels sample_spans = None if use_loss_masking_spans and self._spans is not None: sample_spans = self._spans[idx] @@ -202,6 +272,8 @@ def get( return GPTSample( token_ids=token_ids, + images=images, + image_positions=image_positions, 
loss_masking_spans=sample_spans, chosen_span=chosen_span, rejected_span=rejected_span, @@ -218,23 +290,31 @@ def __len__(self) -> int: def num_tokens(self) -> int: return self._num_tokens - def get_document_sizes(self) -> np.ndarray: + @property + def has_images(self) -> bool: + return self._has_images + + def get_document_sizes(self) -> tuple[np.ndarray, np.ndarray]: """ The size of each document in the dataset. The resulting array could be very large, so this method should be called cautiously, and derived classes should try to avoid holding the whole array im memory. """ - return self._document_sizes + return self._document_sizes, self._image_sizes def get_document_size(self, index: int) -> int: - return self._document_sizes[index].item() + return self._document_sizes[index].item(), self._image_sizes[index] if self._has_images else [] @classmethod def write_dataset(cls, prefix: pathlib.Path | str, documents: typing.Iterable[GPTSample]): # Initialize metadata dtype = None num_documents = 0 - lengths = [] + doc_lengths = [] + n_images = [] + image_sizes = [] + im_positions = [] + total_images = 0 pointers = [] offset = 0 # number of spans for each document @@ -259,10 +339,28 @@ def write_dataset(cls, prefix: pathlib.Path | str, documents: typing.Iterable[GP # Write document to binary file bin_stream.write(document.token_ids.tobytes(order="C")) + total_im_size = 0 + if document.images: + n_images.append(len(document.images)) + total_images += len(document.images) + for image in document.images: + # assume 3 channels (RGB) for all images + with PIL.Image.open(io.BytesIO(image["bytes"])) as img: + if img.mode != "RGB": + # Convert all images to RGB + img = img.convert("RGB") + pixels = np.array(img).transpose(2, 0, 1) # HWC to CHW + assert pixels.dtype == np.uint8, f"Expected uint8 pixels, got {pixels.dtype}." 
+ image_sizes.append(np.array(pixels.shape[1:])) + bin_stream.write(pixels.tobytes(order="C")) + total_im_size += pixels.size + im_positions.extend(document.image_positions) + else: + n_images.append(0) # Update metadata doc_length = len(document.token_ids) - lengths.append(doc_length) + doc_lengths.append(doc_length) pointers.append(offset) if document.loss_masking_spans is not None: num_spans.append(len(document.loss_masking_spans)) @@ -271,11 +369,11 @@ def write_dataset(cls, prefix: pathlib.Path | str, documents: typing.Iterable[GP chosen_spans.append(document.chosen_span) if document.rejected_span is not None: rejected_spans.append(document.rejected_span) - offset += doc_length * np.dtype(dtype).itemsize + offset += doc_length * np.dtype(dtype).itemsize + total_im_size * np.dtype(np.uint8).itemsize num_documents += 1 # Finalize metadata arrays - lengths = np.array(lengths, dtype=np.int32) + doc_lengths = np.array(doc_lengths, dtype=np.int32) pointers = np.array(pointers, dtype=np.int64) num_spans = np.array(num_spans, dtype=np.int32) if len(spans) > 0: @@ -285,25 +383,37 @@ def write_dataset(cls, prefix: pathlib.Path | str, documents: typing.Iterable[GP chosen_spans = np.array(chosen_spans, dtype=np.int32).reshape(-1, 2) rejected_spans = np.array(rejected_spans, dtype=np.int32).reshape(-1, 2) + if total_images: + n_images = np.array(n_images, dtype=np.int32) + image_sizes = np.stack(image_sizes, dtype=np.int32) + im_positions = np.array(im_positions, dtype=np.int32) + else: + n_images = np.array([]) + image_sizes = np.array([]) + im_positions = np.array([]) + # Write the index file (.idx) with prefix.with_suffix(".idx").open("wb") as idx_stream: idx_stream.write(MEMMAP_INDEX_HEADER) # Indicates the version - # Version 2 optionally adds loss-masking spans + # Version 2 onwards optionally add loss-masking spans # Version 3 optionally adds chosen/rejected spans - idx_stream.write(struct.pack(" 0 else 0)) # Flag to indicate whether preference loss-masking spans are present idx_stream.write(struct.pack(" 0 and rejected_spans.size > 0 else 0)) + # Flag to indicate whether images are present + idx_stream.write(struct.pack(" 0 else 0)) # Data type idx_stream.write(struct.pack(" None: Create a `GPTSampledDataset` with the requested parameters. """ # Get the document sizes, the main information needed for sampling. - document_sizes = torch.from_numpy(self._indexed_dataset.get_document_sizes()).to(self._device) + document_sizes, image_sizes = self._indexed_dataset.get_document_sizes() + document_sizes = torch.from_numpy(document_sizes).to(self._device) + if image_sizes: + image_token_sizes = [] + for i, sizes in enumerate(image_sizes): + image_token_sizes.append( + sum( + get_num_image_tokens( + *get_resize_dims( + *size, + self._parameters.max_image_size, + self._parameters.max_image_size, + self._parameters.patch_size, + ), + self._parameters.patch_size, + image_break=self._parameters.image_break_token is not None, + image_end=self._parameters.image_end_token is not None, + ) + for size in sizes + ) + ) + image_token_sizes = torch.tensor(image_token_sizes).to(self._device) + else: + image_token_sizes = torch.zeros_like(document_sizes) + documents_per_epoch = document_sizes.numel() - tokens_per_epoch = document_sizes.sum().item() + tokens_per_epoch = document_sizes.sum().item() + image_token_sizes.sum().item() # Calculate basic stats. if not self._truncate_documents: @@ -143,14 +175,14 @@ def _sample(self) -> None: "The C++ extension for dataset sampling is missing." 
" Please make sure Fast-LLM is installed correctly." ) - long_docs_filter = document_sizes > self._parameters.sequence_length + 1 + long_docs_filter = document_sizes + image_token_sizes > self._parameters.sequence_length + 1 ignored_documents = long_docs_filter.sum().item() if ignored_documents: log_main_rank( f" > {ignored_documents}/{documents_per_epoch} documents are longer than {self._parameters.sequence_length+1} tokens and will be ignored.", log_fn=logger.warning, ) - tokens_per_epoch = document_sizes[~long_docs_filter].sum().item() + tokens_per_epoch = (document_sizes[~long_docs_filter] + image_token_sizes[~long_docs_filter]).sum().item() if tokens_per_epoch == 0: raise RuntimeError( f" > No documents shorter than {self._parameters.sequence_length+1} tokens found in dataset {self._indexed_dataset.name}." @@ -193,7 +225,10 @@ def _sample(self) -> None: "num_samples": self._parameters.num_samples, "unshuffled_epochs": unshuffled_epochs, "sequence_length": self._parameters.sequence_length, + "patch_size": self._parameters.patch_size, "truncate_documents": self._truncate_documents, + "image_break_token": self._parameters.image_break_token, + "image_end_token": self._parameters.image_end_token, "config": self._config.to_dict(), } if self._truncate_documents: @@ -294,7 +329,7 @@ def _sample(self) -> None: # Equivalent to `torch.hstack((0, document_sizes[all_document_index].cumsum()[::TOKEN_CUMSUM_RATE]))` if unshuffled_epochs > 0: token_cumsum_unshuffled, unshuffled_tokens = self._get_token_cumsum( - document_sizes, + document_sizes + image_token_sizes, offset=0, # TODO: Allowing for max 100% extra tokens for padding, is that enough? dtype=get_unsigned_integer_type((2 - self._truncate_documents) * tokens_per_epoch * num_epochs), @@ -317,6 +352,9 @@ def _sample(self) -> None: document_shuffling.to( dtype=torch.int64 if document_shuffling.dtype == torch.int64 else torch.int32 ) + ] + + image_token_sizes[ + document_shuffling.to(torch.int64 if document_shuffling.dtype == torch.int64 else torch.int32) ], offset=self._unshuffled_tokens, # TODO: Allowing for max 100% extra tokens for padding, is that enough? @@ -442,6 +480,10 @@ def __getitem__(self, index: int) -> typing.Any: token_ids = [] loss_masking_spans = [] + images = [] + image_positions = [] + image_tokens_added = 0 + text_tokens_added = 0 while token_count < token_end: # Find the document index in the dataset. if document_sampling_index < self._unshuffled_documents: @@ -449,7 +491,28 @@ def __getitem__(self, index: int) -> typing.Any: else: document_index = self._document_shuffling[document_sampling_index - self._unshuffled_documents].item() - document_size = self._indexed_dataset.get_document_size(document_index) + text_size, image_lengths = self._indexed_dataset.get_document_size(document_index) + + resized_image_lengths = [ + get_resize_dims( + *image_length, + self._parameters.max_image_size, + self._parameters.max_image_size, + self._parameters.patch_size, + ) + for image_length in image_lengths + ] + image_sizes = [ + get_num_image_tokens( + *image_length, + self._parameters.patch_size, + image_break=self._parameters.image_break_token is not None, + image_end=self._parameters.image_end_token is not None, + ) + for image_length in resized_image_lengths + ] + image_tokens = sum(image_sizes) + document_size = text_size + image_tokens if not self._truncate_documents: if document_size > self._parameters.sequence_length + 1: @@ -468,21 +531,97 @@ def __getitem__(self, index: int) -> typing.Any: else: # Move on to the next sample. 
token_count += padding_size + continue + elif document_size + tokens_in_sample == self._parameters.sequence_length + 1: + if token_count + document_size == token_start: + token_count += document_size + document_sampling_index += 1 + continue # Determine if the document belongs to the requested sample. if token_count + document_size > token_start: # Determine which part of the document belong to the sample, and add it to the list. token_start_index_in_document = max(token_start - token_count, 0) - token_end_index_in_document = min(token_end - token_count, document_size) + token_end_index_in_document = min(token_end - token_count, text_size) sample = self._indexed_dataset.get( document_index, offset=token_start_index_in_document, length=token_end_index_in_document - token_start_index_in_document, use_loss_masking_spans=self._parameters.use_loss_masking_spans, ) - token_ids.append(sample.token_ids) + start_pos = 0 + has_images = sample.image_positions is not None + if has_images: + sample_token_ids = [] + for idx, im_position in enumerate(sample.image_positions): + # add placeholder masked tokens for images + # if image_break_token is set, it is appended after every row + # if image_end_token is set, it is appended at the end of the image instead of image_break_token + text_part = sample.token_ids[start_pos:im_position] + if self._parameters.image_break_token is not None: + height, width = resized_image_lengths[idx] + num_patches_h = div(height, self._parameters.patch_size) + num_patches_w = div(width, self._parameters.patch_size) + image_token_array = np.full((image_sizes[idx],), -100, dtype=np.int64) + # account for break tokens after each row + for row in range(num_patches_h - 1): + position = (row + 1) * num_patches_w + row + image_token_array[position] = self._parameters.image_break_token + # handle the last row separately + last_row_position = num_patches_h * num_patches_w + num_patches_h - 1 + if self._parameters.image_end_token is not None: + image_token_array[last_row_position] = self._parameters.image_end_token + else: + image_token_array[last_row_position] = self._parameters.image_break_token + else: + image_token_array = np.full((image_sizes[idx],), -100, dtype=np.int64) + if self._parameters.image_end_token is not None: + image_token_array[-1] = self._parameters.image_end_token + sample_token_ids.append(np.concatenate([text_part, image_token_array], dtype=np.int64)) + text_tokens_added += len(text_part) + image_positions.append(text_tokens_added + image_tokens_added) + image_tokens_added += image_sizes[idx] + start_pos = im_position + # Add the last text segment after the last image + sample_token_ids.append(sample.token_ids[start_pos:]) + text_tokens_added += len(sample_token_ids[-1]) + token_ids.append(np.concatenate(sample_token_ids)) + else: + token_ids.append(sample.token_ids[start_pos:]) + text_tokens_added += len(token_ids[-1]) + if sample.images: + images.append(sample.images) + else: + images.append([]) if self._parameters.use_loss_masking_spans: for loss_masking_span in sample.loss_masking_spans: + prev_image_tokens = 0 + image_idx = 0 + image_position = ( + sample.image_positions[image_idx] + if has_images and image_idx < len(sample.image_positions) + else float("inf") + ) + while image_position < loss_masking_span[0]: + prev_image_tokens += image_sizes[image_idx] + image_idx += 1 + image_position = ( + sample.image_positions[image_idx] + if has_images and image_idx < len(sample.image_positions) + else float("inf") + ) + span_image_tokens = 0 + while image_position <= 
loss_masking_span[1]: + span_image_tokens += image_sizes[image_idx] + image_idx += 1 + image_position = ( + sample.image_positions[image_idx] + if has_images and image_idx < len(sample.image_positions) + else float("inf") + ) + loss_masking_span[0] += prev_image_tokens + loss_masking_span[1] += prev_image_tokens + span_image_tokens + prev_image_tokens += span_image_tokens span = np.clip( loss_masking_span + token_count - token_start, 0, @@ -506,9 +645,17 @@ def __getitem__(self, index: int) -> typing.Any: if self._parameters.use_loss_masking_spans else None ) + images = [im for img_list in images for im in img_list] if images else None + image_positions = np.array(image_positions) if image_positions else None Assert.eq(len(token_ids), self._parameters.sequence_length + self._parameters.extra_tokens) - return GPTSample(token_ids=token_ids, loss_masking_spans=loss_masking_spans, sequence_lengths=sequence_lengths) + return GPTSample( + token_ids=token_ids, + loss_masking_spans=loss_masking_spans, + sequence_lengths=sequence_lengths, + images=images, + image_positions=image_positions, + ) @property def name(self) -> str: @@ -593,7 +740,7 @@ def _sample(self) -> None: Create a `GPTSampledDataset` with the requested parameters. """ logger.info(f" > Sampling dataset {self._indexed_dataset.name} ...") - document_sizes = self._indexed_dataset.get_document_sizes() + document_sizes, _ = self._indexed_dataset.get_document_sizes() num_documents = len(document_sizes) num_tokens = document_sizes.sum() np_rng = np.random.RandomState(seed=self._config.seed) diff --git a/fast_llm/data/preparator/gpt_memmap/config.py b/fast_llm/data/preparator/gpt_memmap/config.py index d2aaee5e2..da353793d 100644 --- a/fast_llm/data/preparator/gpt_memmap/config.py +++ b/fast_llm/data/preparator/gpt_memmap/config.py @@ -42,6 +42,18 @@ class TextColumnConfig(SourceSchemaConfig): ) +@config_class(dynamic_type={SourceSchemaConfig: "text_image_column"}) +class TextImageColumnConfig(TextColumnConfig): + images_column: str = Field( + default="images", + desc="Field containing images relevant to a document.", + ) + image_positions_column: None | str = Field( + default="image_positions", + desc="Field containing image positions within a document.", + ) + + @config_class() class GPTHuggingfaceDatasetConfig(Config): path: str = Field( @@ -175,6 +187,11 @@ class GPTMemmapDatasetPreparatorConfig(DatasetPreparatorConfig): desc="Configuration for the tokenizer.", hint=FieldHint.feature, ) + image_patch_size: int = Field( + default=16, + desc="Patch size for images. This is used solely for computing the number of tokens in an image to get an even split.", + hint=FieldHint.optional, + ) splits: dict[str, float] | None = Field( default=None, desc="Split the output dataset into multiple ones (ex, train/valid/test) with the specified ratios." 
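
The sampling changes above lean on two helpers, `get_resize_dims` and `get_num_image_tokens`, whose implementations are not part of this diff. The sketch below is an approximation reconstructed from their call sites; the names `resize_dims` / `num_image_tokens` and the exact rounding behaviour are assumptions, not the actual Fast-LLM API. The accounting it illustrates is the one visible in `GPTSampledIndexedDataset.__getitem__`: one placeholder token per patch, plus one break token per patch row when `image_break_token` is set, with the last break replaced by `image_end_token` when that is configured.

```python
import math


def resize_dims(height: int, width: int, max_height: int, max_width: int, patch_size: int) -> tuple[int, int]:
    # Assumed behaviour: scale the image to fit within the maximum size, then snap the
    # result to a multiple of the patch size (the sampled.py code divides by patch_size
    # with an exact-division check, so the real helper must return divisible sizes).
    scale = min(max_height / height, max_width / width, 1.0)
    return (
        patch_size * max(round(height * scale / patch_size), 1),
        patch_size * max(round(width * scale / patch_size), 1),
    )


def num_image_tokens(height: int, width: int, patch_size: int, image_break: bool, image_end: bool) -> int:
    # One placeholder token per patch; with image_break, one extra token per patch row
    # (the last one is overwritten by image_end_token when both are configured).
    patches_h, patches_w = height // patch_size, width // patch_size
    tokens = patches_h * patches_w
    if image_break:
        tokens += patches_h
    elif image_end:
        tokens += 1  # assumption: a single trailing end-of-image token
    return tokens


# A 512x768 image with 16x16 patches and a break token after each row:
print(num_image_tokens(*resize_dims(512, 768, 1024, 1024, 16), 16, image_break=True, image_end=False))
# 32 * 48 patches + 32 row breaks = 1568 placeholder tokens in the token stream
```
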
diff --git a/fast_llm/data/preparator/gpt_memmap/prepare.py b/fast_llm/data/preparator/gpt_memmap/prepare.py index 427309a99..d6d473838 100644 --- a/fast_llm/data/preparator/gpt_memmap/prepare.py +++ b/fast_llm/data/preparator/gpt_memmap/prepare.py @@ -1,3 +1,5 @@ +import io +import itertools import json import logging import multiprocessing @@ -8,6 +10,7 @@ import datasets import huggingface_hub import numpy as np +import PIL.Image import requests import torch.distributed import tqdm @@ -24,7 +27,11 @@ from fast_llm.data.dataset.gpt.memmap import GPTMemmapDataset from fast_llm.data.dataset.gpt.sampled import GPTSample from fast_llm.data.preparator.config import DatasetPreparator -from fast_llm.data.preparator.gpt_memmap.config import GPTMemmapDatasetPreparatorConfig, TextColumnConfig +from fast_llm.data.preparator.gpt_memmap.config import ( + GPTMemmapDatasetPreparatorConfig, + TextColumnConfig, + TextImageColumnConfig, +) from fast_llm.data.tokenizer import Tokenizer from fast_llm.engine.config_utils.data_type import DataType, get_unsigned_integer_type from fast_llm.utils import Assert, normalize_probabilities, padded_cumsum @@ -41,36 +48,44 @@ class GPTMemmapDatasetPreparator[ConfigType: GPTMemmapDatasetPreparatorConfig](D _loss_masking_spans_column: str | None def _tokenize_batch(self, batch: dict[str, list[typing.Any]]) -> dict[str, list[typing.Any]]: - input_ids = [ - np.array(self._tokenizer.tokenize(text), dtype=self._data_type.numpy) for text in batch[self._text_column] - ] - num_tokens = [len(x) for x in input_ids] - return { - "input_ids": input_ids, - "num_tokens": num_tokens, - } - - def _tokenize_batch_with_spans(self, batch: dict[str, list[typing.Any]]) -> dict[str, list[typing.Any]]: - input_ids, token_spans = map( + input_ids, token_spans, image_token_positions = map( list, zip( *[ ( np.array(input_ids, dtype=self._data_type.numpy), np.array(token_spans, dtype=np.int32).reshape(-1, 2), + np.array(image_token_positions, dtype=np.int32), ) - for input_ids, token_spans in [ - self._tokenizer.tokenize_with_spans(text, char_spans) - for text, char_spans in zip(batch[self._text_column], batch[self._loss_masking_spans_column]) + for input_ids, token_spans, image_token_positions in [ + self._tokenizer.tokenize( + text, + loss_mask_spans, + im_char_positions, + ) + for text, loss_mask_spans, im_char_positions in zip( + batch[self._text_column], + batch.get(self._loss_masking_spans_column, itertools.repeat(None)), + batch.get(self._image_positions_column, itertools.repeat(None)), + ) ] ] ), ) num_tokens = [len(x) for x in input_ids] + num_pixels = [0] * len(input_ids) + for idx, images in enumerate(batch.get("images", [])): + for bytes_im in images: + with PIL.Image.open(io.BytesIO(bytes_im["bytes"])) as im: + width, height = im.size + num_pixels[idx] += width * height * 3 + return { "input_ids": input_ids, + "image_positions": image_token_positions, "token_spans": token_spans, "num_tokens": num_tokens, + "num_pixels": num_pixels, } def _tokenize_preference_batch_with_spans(self, batch: dict[str, list[typing.Any]]) -> dict[str, list[typing.Any]]: @@ -143,27 +158,22 @@ def _save_shard(self, args: tuple[int, datasets.Dataset]) -> GPTMemmapDatasetCon shard_output_path = self._config.output_path / prefix def _document_generator(): - if "token_spans" in shard_dataset.column_names and self._loss_masking_spans_column is not None: - for item in tqdm.tqdm(shard_dataset, desc=f"Saving shard {shard_idx}", unit="docs"): - yield GPTSample( - np.array(item["input_ids"], 
dtype=self._data_type.numpy), - np.array(item["token_spans"], dtype=np.int32).reshape(-1, 2), - ) - elif ( - "chosen_token_spans" in shard_dataset.column_names - and "rejected_token_spans" in shard_dataset.column_names - and self._config.dataset.chosen_text is not None - and self._config.dataset.rejected_text is not None - ): - for item in tqdm.tqdm(shard_dataset, desc=f"Saving shard {shard_idx}", unit="docs"): - yield GPTSample( - token_ids=np.array(item["input_ids"], dtype=self._data_type.numpy), - chosen_span=np.array(item["chosen_token_spans"], dtype=np.int32).reshape(-1, 2), - rejected_span=np.array(item["rejected_token_spans"], dtype=np.int32).reshape(-1, 2), - ) - else: - for item in tqdm.tqdm(shard_dataset, desc=f"Saving shard {shard_idx}", unit="docs"): - yield GPTSample(np.array(item["input_ids"], dtype=self._data_type.numpy)) + has_preference_spans = ( + self._config.dataset.chosen_text is not None and self._config.dataset.rejected_text is not None + ) + for item in tqdm.tqdm(shard_dataset, desc=f"Saving shard {shard_idx}", unit="docs"): + yield GPTSample( + np.array(item["input_ids"], dtype=self._data_type.numpy), + item["images"] if self._images_column else None, + item["image_positions"] if self._image_positions_column else None, + ( + np.array(item["token_spans"], dtype=np.int32).reshape(-1, 2) + if self._loss_masking_spans_column + else None + ), + item["chosen_token_spans"] if has_preference_spans else None, + item["rejected_token_spans"] if has_preference_spans else None, + ) GPTMemmapDataset.write_dataset(prefix=shard_output_path, documents=_document_generator()) @@ -173,6 +183,7 @@ def _document_generator(): "path": prefix, "num_documents": len(shard_dataset), # Use the length of the shard dataset directly "num_tokens": sum(len(doc["input_ids"]) for doc in shard_dataset), + "num_pixels": sum(doc["num_pixels"] for doc in shard_dataset), } ) @@ -292,6 +303,11 @@ def run(self) -> None: if isinstance(self._config.dataset.source_schema, TextColumnConfig): self._text_column = self._config.dataset.source_schema.input_column self._loss_masking_spans_column = self._config.dataset.source_schema.loss_masking_spans_column + if isinstance(self._config.dataset.source_schema, TextImageColumnConfig): + self._images_column = self._config.dataset.source_schema.images_column + self._image_positions_column = self._config.dataset.source_schema.image_positions_column + # decoding bytes to images is slow and should be done only when needed + dataset = dataset.cast_column("images", datasets.Sequence(datasets.Image(decode=False))) else: raise ValueError( f"Dataset source_schema set incorrectly. source_schema: '{self._config.dataset.source_schema}'." 
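
The per-document pixel counts gathered in `_tokenize_batch` above feed into the shard-size calculation in the `run()` hunk that follows: raw RGB bytes stored in the `.bin` file are converted into token-equivalents by dividing by the token dtype's item size. A worked example of that accounting (the numbers are illustrative):

```python
# Worked example of the shard-budget accounting used in run() below.
import numpy as np

num_pixels = 384 * 512 * 3          # one RGB image, stored as raw uint8 bytes in the .bin file
token_dtype = np.int32              # dtype used for token ids in this dataset
token_equivalent = num_pixels // np.dtype(token_dtype).itemsize

print(token_equivalent)             # 147456 "tokens" worth of bytes added to total_tokens
# Shards are then cut so that total_tokens (text tokens + byte-equivalent of pixels)
# stays close to tokens_per_shard.
```
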
@@ -300,18 +316,17 @@ def run(self) -> None: if self._text_column not in dataset.column_names: raise ValueError(f"Dataset does not have field '{self._text_column}'.") - if self._config.dataset.source_schema.loss_masking_spans_column is not None and ( + if self._loss_masking_spans_column is not None and ( self._config.dataset.chosen_text is not None or self._config.dataset.rejected_text is not None ): - raise ValueError(f"Can not enable both loss masking spans and chosen/rejected loss masking spans.") + if self._config.dataset.chosen_text is not None and self._config.dataset.rejected_text is not None: + raise ValueError(f"Can not enable both loss masking spans and chosen/rejected loss masking spans.") + if self._loss_masking_spans_column not in dataset.column_names: + raise ValueError(f"Dataset does not have spans field '{self._loss_masking_spans_column}'.") if (self._config.dataset.chosen_text is None) != (self._config.dataset.rejected_text is None): raise ValueError(f"Both chosen and rejected loss masking spans must be specified if one is specified.") # route tokenize function - if self._loss_masking_spans_column is not None: - if self._loss_masking_spans_column not in dataset.column_names: - raise ValueError(f"Dataset does not have spans field '{self._loss_masking_spans_column}'.") - tokenize_fn = self._tokenize_batch_with_spans elif self._config.dataset.chosen_text is not None and self._config.dataset.rejected_text is not None: if self._config.dataset.chosen_text not in dataset.column_names: raise ValueError(f"Dataset does not have chosen spans field '{self._config.dataset.chosen_text}'.") @@ -331,6 +346,13 @@ def run(self) -> None: # Calculate total number of tokens total_tokens = sum(tqdm.tqdm(tokenized_dataset["num_tokens"], desc="Counting tokens", unit="tokens")) + total_pixels = ( + sum(tqdm.tqdm(tokenized_dataset["num_pixels"], desc="Counting pixels", unit="pixels")) + if self._images_column + else 0 + ) + # Add the token-equivalent bytes of pixels to determine shard size + total_tokens += total_pixels // np.dtype(self._data_type.numpy).itemsize # Split dataset into shards based on number of tokens num_shards = int(np.ceil(total_tokens / self._config.tokens_per_shard)) @@ -359,7 +381,7 @@ def generate_config_yaml_for_sharded_dst(self, dataset_configs: list[GPTMemmapDa # Create the config file(s) on rank 0 if self._config.splits: for split_name, split_config in self._split_and_blend_dataset_configs( - dataset_configs, self._config.splits, self._config.output_path + dataset_configs, self._config.splits, self._config.output_path, self._config.image_patch_size ).items(): self._save_dataset_config( split_config, self._config.output_path / f"fast_llm_config_{split_name}.yaml" @@ -399,7 +421,11 @@ def _blend_dataset_configs(cls, dataset_configs: list[GPTMemmapDatasetConfig]) - @classmethod def _split_and_blend_dataset_configs( - cls, dataset_configs: list[GPTMemmapDatasetConfig], splits: dict[str, int | float], output_path: pathlib.Path + cls, + dataset_configs: list[GPTMemmapDatasetConfig], + splits: dict[str, int | float], + output_path: pathlib.Path, + image_patch_size: None | int = None, ) -> dict[str, GPTSampledDatasetConfig]: split_cumsum = padded_cumsum(normalize_probabilities(list(splits.values()), return_array=True)).tolist() dataset_sizes = [dataset_config.num_tokens for dataset_config in dataset_configs] @@ -429,10 +455,20 @@ def _split_and_blend_dataset_configs( # Part of the dataset belongs to the split. 
# TODO: Somehow getting a segfault when merging two lines below (numpy bug?). dataset = dataset_config.to_copy({"path": output_path / dataset_config.path}).build() - sizes_cumsum = dataset.get_document_sizes().cumsum() - Assert.eq(sizes_cumsum[-1], dataset_config.num_tokens) - begin_index = _get_nearest_split(sizes_cumsum, split_begin_in_dataset * dataset_config.num_tokens) - end_index = _get_nearest_split(sizes_cumsum, split_end_in_dataset * dataset_config.num_tokens) + text_sizes, image_sizes = dataset.get_document_sizes() + tokens_cumsum = text_sizes.cumsum() + Assert.eq(tokens_cumsum[-1], dataset_config.num_tokens) + if image_sizes: + num_pixels_cumsum = np.cumsum([x.prod(axis=1).sum() for x in image_sizes]) + # We use the patch sizes only for the purposes of even splitting and blending weights. + # We can always use a different patch size for training without any significant impact + # Unless the patch size used at training time is significantly different from the one used here + image_tokens_cumsum = num_pixels_cumsum // (image_patch_size**2) + tokens_cumsum += image_tokens_cumsum + num_pixels_cumsum = num_pixels_cumsum * 3 + Assert.eq(num_pixels_cumsum[-1], dataset_config.num_pixels) + begin_index = _get_nearest_split(tokens_cumsum, split_begin_in_dataset * tokens_cumsum[-1]) + end_index = _get_nearest_split(tokens_cumsum, split_end_in_dataset * tokens_cumsum[-1]) if end_index > begin_index: datasets_in_split.append( GPTDatasetSliceConfig.from_dict( @@ -445,8 +481,8 @@ def _split_and_blend_dataset_configs( ) ) dataset_tokens_in_split.append( - sizes_cumsum[end_index - 1].item() - - (sizes_cumsum[begin_index - 1].item() if begin_index > 0 else 0) + tokens_cumsum[end_index - 1].item() + - (tokens_cumsum[begin_index - 1].item() if begin_index > 0 else 0) ) # [else] None of the dataset belongs to the split. diff --git a/fast_llm/data/tokenizer.py b/fast_llm/data/tokenizer.py index c74586207..d46e38935 100644 --- a/fast_llm/data/tokenizer.py +++ b/fast_llm/data/tokenizer.py @@ -41,44 +41,75 @@ def vocab(self) -> dict[str, int]: def inv_vocab(self) -> dict[int, str]: return self._inv_vocab - def tokenize(self, text: str, begin=True, end=True) -> list[int]: + def _tokenize(self, text: str, begin=True, end=True) -> list[int]: return ( ([self.bod_id] if begin else []) + self.tokenizer.encode(text, add_special_tokens=False) + ([self.eod_id] if end else []) ) - def tokenize_with_spans( - self, text: str, char_spans: list[tuple[int, int]] - ) -> tuple[list[int], list[tuple[int, int]]]: + def tokenize( + self, text: str, char_spans=None, image_positions=None + ) -> tuple[list[int], list[tuple[int, int]], list[int]]: """ - Perform span-aware tokenization and return the tokenized input_ids along with token spans. + Tokenize the input text and return the tokenized input_ids, token spans, and image token positions. + This version simplifies logic by merging all relevant positions, sorting, and tokenizing between them. """ - input_ids = [] + if not image_positions: + image_positions = [] + if not char_spans: + char_spans = [] + + # Collect all positions with their type + positions = [] + for pos in image_positions: + positions.append((pos, "image")) + + for start, end in char_spans: + positions.append((start, "span_start")) + positions.append((end + 1, "span_end")) + # Sort positions by character index. 
We assume that image and span positions are individually sorted and spans do not overlap + positions = sorted(positions, key=lambda x: x[0]) + + token_ids = [] token_spans = [] + image_token_positions = [] char_pos = 0 - beginning_of_text = True + current_span_start = None - for start, end in char_spans: - if char_pos < start: - curr_text = text[char_pos:start] - tokenized_text = self.tokenize(curr_text, begin=beginning_of_text, end=False) - beginning_of_text = False - input_ids.extend(tokenized_text) - curr_text = text[start : end + 1] - if end >= len(text) - 1: - tokenized_text = self.tokenize(curr_text, begin=beginning_of_text, end=True) - else: - tokenized_text = self.tokenize(curr_text, begin=beginning_of_text, end=False) - beginning_of_text = False - token_spans.append((len(input_ids), len(input_ids) + len(tokenized_text) - 1)) - input_ids.extend(tokenized_text) - char_pos = end + 1 + for position in positions: + # We only tokenize if there is at least one character, else we might potentially add begin/end multiple times + if char_pos < position[0]: + tokenized_text = self._tokenize( + text[char_pos : position[0]], begin=(char_pos == 0), end=position[0] > len(text) - 1 + ) + token_ids.extend(tokenized_text) + char_pos = position[0] + # beginning_of_text = False + if position[1] == "image": + if position[0] == 0: + # image should be after the bos token + image_token_positions.append(1) + else: + image_token_positions.append(len(token_ids)) + elif position[1] == "span_start": + assert ( + current_span_start is None + ), "Starting a new span before current has ended, please check for overlapping spans" + current_span_start = len(token_ids) + elif position[1] == "span_end": + assert ( + current_span_start is not None + ), "Closing a span that has not started, please check for overlapping spans" + # spans are inclusive, so we take the index of the last token in the span + token_spans.append((current_span_start, len(token_ids) - 1)) + current_span_start = None + # Handle any remaining text after the last position and add EOS token if char_pos < len(text): - curr_text = text[char_pos:] - tokenized_text = self.tokenize(curr_text, begin=beginning_of_text, end=True) - input_ids.extend(tokenized_text) - return input_ids, token_spans + tokenized_text = self._tokenize(text[char_pos:], begin=(char_pos == 0), end=True) + token_ids.extend(tokenized_text) + + return token_ids, token_spans, image_token_positions def detokenize(self, token_ids: int | list[int] | np.ndarray | torch.Tensor) -> str: return self.tokenizer.decode(token_ids) diff --git a/fast_llm/engine/checkpoint/external.py b/fast_llm/engine/checkpoint/external.py index 72db80f6a..f8a42b31a 100644 --- a/fast_llm/engine/checkpoint/external.py +++ b/fast_llm/engine/checkpoint/external.py @@ -283,7 +283,7 @@ def _export_config(cls, config: BaseModelConfig) -> dict[str, typing.Any]: return exported_config # Noqa @classmethod - def _import_config(cls, config: dict[str, typing.Any]) -> BaseModelConfig: # noqa + def _import_config_dict(cls, config: dict[str, typing.Any]) -> dict[str | tuple[str, ...], typing.Any]: kwargs = {} for converter in cls._get_config_converters(): try: @@ -306,7 +306,11 @@ def _import_config(cls, config: dict[str, typing.Any]) -> BaseModelConfig: # no kwargs[fast_llm_name] = value except Exception as e: raise RuntimeError(f"Config conversion failed for converter {converter}", *e.args) + return kwargs + @classmethod + def _import_config(cls, config: dict[str, typing.Any]) -> BaseModelConfig: # noqa + kwargs = 
cls._import_config_dict(config) return cls._model_class.get_base_model_config_class().from_dict({}, kwargs) def _convert_state_dict( diff --git a/fast_llm/engine/checkpoint/huggingface.py b/fast_llm/engine/checkpoint/huggingface.py index 16b3e005f..4cfff4afa 100644 --- a/fast_llm/engine/checkpoint/huggingface.py +++ b/fast_llm/engine/checkpoint/huggingface.py @@ -134,6 +134,7 @@ class CustomModelingExportMixin: configuration_file: typing.ClassVar[str] configuration_cls: typing.ClassVar[type[PretrainedConfig]] generation_utils_file: str | None = None + additional_files: typing.ClassVar[list[str]] = [] # Use custom config instead of relying on the transformers library @classmethod @@ -159,3 +160,5 @@ def _copy_modeling_files(self, config: CheckpointSaveConfig) -> None: gen_config = pathlib.Path(self.generation_utils_file).parent / "generation_config.json" if gen_config.exists(): shutil.copy(gen_config, config.path) + for file in self.additional_files: + shutil.copy(file, config.path) diff --git a/fast_llm/engine/config_utils/run.py b/fast_llm/engine/config_utils/run.py index 7ab5b8e41..b23037e84 100644 --- a/fast_llm/engine/config_utils/run.py +++ b/fast_llm/engine/config_utils/run.py @@ -130,7 +130,7 @@ def __init__( self._distributed.config.data_rank == 0 and self._distributed.config.tensor_rank == 0 ) config_dict = config.to_dict() - config_dict_verbose = config.to_dict(verbose=FieldVerboseLevel.performance) + config_dict_verbose = config.to_dict(verbose=FieldVerboseLevel.debug) if self._config.experiment_dir is not None: self._experiment_directory = self._config.experiment_dir.resolve() diff --git a/fast_llm/engine/config_utils/tensor_space.py b/fast_llm/engine/config_utils/tensor_space.py index 99c1bcf70..6c4b95b20 100644 --- a/fast_llm/engine/config_utils/tensor_space.py +++ b/fast_llm/engine/config_utils/tensor_space.py @@ -1,3 +1,4 @@ +import logging import math import typing @@ -5,9 +6,13 @@ from fast_llm.utils import Assert, div if typing.TYPE_CHECKING: + import torch + from fast_llm.core.distributed import ProcessGroup from fast_llm.engine.distributed.distributed import Distributed +logger = logging.getLogger(__name__) + class TensorDim: def __init__(self, name: str, global_size: int | None, parallel_dim: DistributedDim | None = None): @@ -19,11 +24,11 @@ def __init__(self, name: str, global_size: int | None, parallel_dim: Distributed def __repr__(self) -> str: return ( - f"TensorDim(" + f"{type(self).__name__}(" f"name={self._name}," f" size={self._size}," f" global_size={self._global_size}," - f" parallel_dim={None if self.parallel_dim is None else self._parallel_dim}" + f" parallel_dim={self._parallel_dim}" f")" ) @@ -38,83 +43,180 @@ def name(self) -> str: def size(self) -> int: return self._size - @property - def expanded_shape(self) -> tuple[int, ...]: - return (self._size,) - - @property - def ndim(self) -> int: - return 1 - @property def global_size(self) -> int: return self._global_size @property - def global_expanded_shape(self) -> tuple[int, ...]: - return (self._size if self._parallel_dim is None else self._size * self._parallel_dim.size,) + def is_parallel(self) -> bool: + return self._parallel_dim is not None and self._parallel_dim.size > 1 @property def parallel_dim(self) -> DistributedDim | None: + # TODO: Make more flexible for derived classes? 
return self._parallel_dim - @property - def parallel_dim_index(self) -> int | None: - return None if self._parallel_dim is None else 0 - @property def parallel_group(self) -> "ProcessGroup|None": + # TODO: Make more flexible for derived classes? return None if self._parallel_dim is None else self._parallel_dim.group def replace_parallel_dim(self, distributed_dim: DistributedDim) -> typing.Self: - assert self.parallel_dim is not None + assert self.is_parallel return TensorDim(self.name, self.size * distributed_dim.size, distributed_dim) + def local_to_global(self, tensor: "torch.Tensor", dim: int = 0) -> "torch.Tensor": + if self.is_parallel: + from fast_llm.core.ops import gather_op + + return gather_op(tensor, self.parallel_group, dim) + else: + return tensor + + def local_to_global_partial( + self, tensor: "torch.Tensor", dim: int = 0, fill_value: float | int = -1 + ) -> "torch.Tensor": + if self.is_parallel: + output = tensor.new_full((*tensor.shape[:dim], self.parallel_dim.size, *tensor.shape[dim:]), fill_value) + output.narrow(dim, self.parallel_dim.rank, 1).copy_(tensor.unsqueeze(dim)).squeeze(dim) + return output.flatten(dim, dim + 1) + else: + return tensor + + def global_to_local(self, tensor: "torch.Tensor", dim: int = 0, expand: bool = False) -> "torch.Tensor": + return ( + tensor.chunk(self.parallel_dim.size, dim)[self.parallel_dim.rank] + if self.parallel_dim is not None and self.parallel_dim.size > 1 + else tensor + ) + class CompositeTensorDim(TensorDim): - def __init__(self, name: str, dims: tuple[TensorDim, ...]): - # TODO: Recursive composition?? - parallel_dims = [(i, dim.parallel_dim) for i, dim in enumerate(dims) if dim.parallel_dim] - Assert.leq(len(parallel_dims), 1) + def __init__(self, name: str, tensor_dims: tuple[TensorDim, ...]): + parallel_dim = None + for dim, tensor_dim in enumerate(tensor_dims): + if tensor_dim.parallel_dim is not None: + # TODO: Allow more than one parallel subdim? 
+ assert parallel_dim is None + parallel_dim = tensor_dim.parallel_dim + self._parallel_dim_index = dim super().__init__( name=name, - global_size=math.prod(dim.global_size for dim in dims), - parallel_dim=parallel_dims[0][1] if parallel_dims else None, - ) - self._dims = dims - self._parallel_dim_index = ( - sum(dim.ndim for dim in self._dims[: parallel_dims[0][0]]) - + self._dims[parallel_dims[0][0]].parallel_dim_index - if parallel_dims - else None + global_size=math.prod(dim.global_size for dim in tensor_dims), + parallel_dim=parallel_dim, ) + self._tensor_dims = tensor_dims - @property - def dims(self) -> tuple[TensorDim, ...]: - return self._dims + def replace_parallel_dim(self, distributed_dim: DistributedDim) -> typing.Self: + assert self._parallel_dim_index is not None + dims = list(self._tensor_dims) + dims[self._parallel_dim_index] = dims[self._parallel_dim_index].replace_parallel_dim(distributed_dim) + return CompositeTensorDim(self.name, tuple(dims)) - @property - def ndim(self) -> int: - return sum(dim.ndim for dim in self._dims) + def local_to_global(self, tensor: "torch.Tensor", dim: int = 0) -> "torch.Tensor": + tensor = tensor.unflatten(dim, [tensor_dim.size for tensor_dim in self._tensor_dims]) + for i, tensor_dim in enumerate(self._tensor_dims): + tensor = tensor_dim.local_to_global(tensor, dim + i) - @property - def expanded_shape(self) -> tuple[int, ...]: - return sum((dim.expanded_shape for dim in self._dims), ()) + return tensor.flatten(dim, dim + len(self._tensor_dims) - 1) - @property - def global_expanded_shape(self) -> tuple[int, ...]: - return sum((dim.global_expanded_shape for dim in self._dims), ()) + def local_to_global_partial( + self, tensor: "torch.Tensor", dim: int = 0, fill_value: float | int = -1 + ) -> "torch.Tensor": + tensor = tensor.unflatten(dim, [tensor_dim.size for tensor_dim in self._tensor_dims]) + for i, tensor_dim in enumerate(self._tensor_dims): + tensor = tensor_dim.local_to_global_partial(tensor, dim + i) + + return tensor.flatten(dim, dim + len(self._tensor_dims) - 1) + + def global_to_local(self, tensor: "torch.Tensor", dim: int = 0, expand: bool = False) -> "torch.Tensor": + tensor = tensor.unflatten(dim, [tensor_dim.global_size for tensor_dim in self._tensor_dims]) + for i, tensor_dim in reversed(list(enumerate(self._tensor_dims))): + tensor = tensor_dim.global_to_local(tensor, dim + i) + return tensor if expand else tensor.flatten(dim, dim + len(self._tensor_dims) - 1) - @property - def parallel_dim_index(self) -> int | None: - return self._parallel_dim_index + +class ConcatenatedTensorDim(TensorDim): + def __init__(self, name: str, tensor_dims: tuple[TensorDim, ...]): + parallel_dim = tensor_dims[0].parallel_dim + for dim, tensor_dim in enumerate(tensor_dims[1:]): + # TODO: Allow more flexibility? 
+ Assert.is_(tensor_dim.parallel_dim, parallel_dim) + + super().__init__( + name=name, + global_size=sum(dim.global_size for dim in tensor_dims), + parallel_dim=parallel_dim, + ) + self._tensor_dims = tensor_dims def replace_parallel_dim(self, distributed_dim: DistributedDim) -> typing.Self: - assert self.parallel_dim_index is not None - dims = list(self.dims) - dims[self.parallel_dim_index] = dims[self.parallel_dim_index].replace_parallel_dim(distributed_dim) - return CompositeTensorDim(self.name, tuple(dims)) + assert self.is_parallel + return ConcatenatedTensorDim( + self.name, tuple(tensor_dim.replace_parallel_dim(distributed_dim) for tensor_dim in self._tensor_dims) + ) + + def local_to_global(self, tensor: "torch.Tensor", dim: int = 0) -> "torch.Tensor": + import torch + + return ( + torch.concatenate( + [ + tensor_dim.local_to_global(tensor_, dim) + for tensor_, tensor_dim in zip( + tensor.split([tensor_dim.size for tensor_dim in self._tensor_dims], dim), + self._tensor_dims, + strict=True, + ) + ], + dim, + ) + if self.is_parallel + else tensor + ) + + def local_to_global_partial( + self, tensor: "torch.Tensor", dim: int = 0, fill_value: float | int = -1 + ) -> "torch.Tensor": + import torch + + return ( + torch.concatenate( + [ + tensor_dim.local_to_global_partial(tensor_, dim) + for tensor_, tensor_dim in zip( + tensor.split([tensor_dim.size for tensor_dim in self._tensor_dims], dim), + self._tensor_dims, + strict=True, + ) + ], + dim, + ) + if self.is_parallel + else tensor + ) + + def global_to_local(self, tensor: "torch.Tensor", dim: int = 0, expand: bool = False) -> "torch.Tensor": + if self.is_parallel and expand: + raise NotImplementedError() + import torch + + return ( + torch.concatenate( + [ + tensor_dim.global_to_local(tensor_, dim) + for tensor_, tensor_dim in zip( + tensor.split([tensor_dim.global_size for tensor_dim in self._tensor_dims], dim), + self._tensor_dims, + strict=True, + ) + ], + dim, + ) + if self.is_parallel + else tensor + ) class DefaultDimNames: @@ -147,21 +249,19 @@ def distributed(self) -> "Distributed": assert self._is_setup return self._distributed - def add_tensor_dim(self, dim: TensorDim) -> None: - if isinstance(dim, CompositeTensorDim): - for dim_ in dim.dims: - Assert.incl(dim_.name, self._tensor_dims) - Assert.eq(dim_, self._tensor_dims[dim_.name]) - if dim.name in self._tensor_dims: - Assert.eq(dim, self._tensor_dims[dim.name]) + def add_tensor_dim(self, tensor_dim: TensorDim) -> None: + if tensor_dim.name in self._tensor_dims: + Assert.eq(tensor_dim, self._tensor_dims[tensor_dim.name]) else: - if dim.parallel_dim is not None: - assert dim.parallel_dim.name in self._distributed_config.distributed_dims, dim.parallel_dim.name + if tensor_dim.parallel_dim is not None: + assert ( + tensor_dim.parallel_dim.name in self._distributed_config.distributed_dims + ), tensor_dim.parallel_dim.name Assert.eq( - dim.parallel_dim.__dict__, - self._distributed_config.distributed_dims[dim.parallel_dim.name].__dict__, + tensor_dim.parallel_dim.__dict__, + self._distributed_config.distributed_dims[tensor_dim.parallel_dim.name].__dict__, ) - self._tensor_dims[dim.name] = dim + self._tensor_dims[tensor_dim.name] = tensor_dim - def get_tensor_dim(self, name: str) -> TensorDim: + def __getitem__(self, name: str) -> TensorDim: return self._tensor_dims[name] diff --git a/fast_llm/engine/multi_stage/config.py b/fast_llm/engine/multi_stage/config.py index 6ac157dfe..719088057 100644 --- a/fast_llm/engine/multi_stage/config.py +++ b/fast_llm/engine/multi_stage/config.py 
@@ -31,6 +31,7 @@ if typing.TYPE_CHECKING: from fast_llm.engine.inference.huggingface import HuggingfaceBaseModelForCausalLM + from fast_llm.engine.inference.runner import InferenceRunner from fast_llm.engine.multi_stage.fast_llm_model import FastLLMModel logger = logging.getLogger(__name__) @@ -241,6 +242,10 @@ def get_checkpoint_handler_class(cls, format: type[CheckpointFormat] | str) -> t def get_model_class(cls) -> type["FastLLMModel"]: raise NotImplementedError + @classmethod + def get_inference_runner_class(cls) -> type["InferenceRunner"]: + raise NotImplementedError + @classmethod def get_huggingface_model_for_causal_lm_class(cls) -> type["HuggingfaceBaseModelForCausalLM"]: raise NotImplementedError diff --git a/fast_llm/engine/multi_stage/fsdp.py b/fast_llm/engine/multi_stage/fsdp.py index 5b44bf14b..be15cd37a 100644 --- a/fast_llm/engine/multi_stage/fsdp.py +++ b/fast_llm/engine/multi_stage/fsdp.py @@ -441,39 +441,21 @@ def _get_parameter_shard_indices_in_full_weight( where it is located in the shard if it exists, or -1 if it's not in the shard. Used to determine the location of each entry in a different distributed configuration. """ - - # Create an empty index for the global parameter. - index = torch.full( - parameter_meta.global_shape, - -1, - dtype=torch.int64, - device=device, - ) # Set the shard slice of the global parameter to corresponding indices of the parameter slice of the shard begin, end = self._get_parameter_range_in_shard(parameter_name) - buffer_index = parameter_meta.global_to_local(index, expand=True) - # Copying directly into `buffer_index` requires a view of the tensor, which may not be feasible. - # In that case, we work with a separate tensor to be copied back into `buffer_index`. - try: - buffer_index_flat = buffer_index.view(-1) - is_view = True - except RuntimeError: - buffer_index_flat = buffer_index.new_full((buffer_index.numel(),), -1) - is_view = False - - # Copy the shard indices at their respective positions in the flat buffer index. - buffer_index_flat[ + # Create an empty local index to hold the local shard indices. + buffer_index = torch.full_like(parameter_meta, -1, dtype=torch.int64, device=device) + + # Copy the shard indices at their respective positions in the buffer index. + buffer_index.flatten()[ self._index_buffer_to_param( self._fsdp_dim.rank * self._shard_size, parameter_name ) : self._index_buffer_to_param((self._fsdp_dim.rank + 1) * self._shard_size, parameter_name) ].copy_(torch.arange(begin, end, dtype=torch.int64, device=device)) - # If needed, copy the flat buffer index back into the index. - if not is_view: - buffer_index.copy_(buffer_index_flat.view_as(buffer_index)) - - return index + # Create a global index from the local one. 
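# Editor's note: a 1-D toy version (illustrative only, invented sizes) of the index
# construction above: the local buffer index starts at -1 everywhere, the slice owned
# by this shard receives consecutive positions from the shard, and
# local_to_global_partial then lays the local index out in the global shape, filling
# everything this rank does not own with -1.
import torch

local_size, begin, end, slice_start = 8, 0, 3, 2  # hypothetical shard geometry
buffer_index = torch.full((local_size,), -1, dtype=torch.int64)
buffer_index[slice_start : slice_start + (end - begin)] = torch.arange(begin, end)
print(buffer_index)  # tensor([-1, -1,  0,  1,  2, -1, -1, -1])
# The real code finishes with parameter_meta.local_to_global_partial(buffer_index, -1).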
+ return parameter_meta.local_to_global_partial(buffer_index, -1) def copy_shard_overlaps( self, diff --git a/fast_llm/engine/multi_stage/stage.py b/fast_llm/engine/multi_stage/stage.py index 87eac31c4..df9259abd 100644 --- a/fast_llm/engine/multi_stage/stage.py +++ b/fast_llm/engine/multi_stage/stage.py @@ -138,7 +138,7 @@ def backward( assert self._mode.support_backward input_, output = grad_context output.backward(output_grad) - return input_.grad + return input_.grad if input_.grad is not None else torch.zeros_like(input_) def restore_parameters(self) -> None: assert self._is_setup diff --git a/fast_llm/engine/multi_stage/stage_base.py b/fast_llm/engine/multi_stage/stage_base.py index 2f18f1360..3218a1963 100644 --- a/fast_llm/engine/multi_stage/stage_base.py +++ b/fast_llm/engine/multi_stage/stage_base.py @@ -185,12 +185,15 @@ def initialize_weights(self) -> None: # Multi-gpu init may be different because of TP or FSDP (different shape), or PP (not on device) global_shape = meta.global_shape - if self._distributed_config.reproducible_init and ( - global_shape.numel() != parameter.numel() or not self._mode.on_device + if meta.requires_global_initialization or ( + self._distributed_config.reproducible_init + and (global_shape.numel() != parameter.numel() or not self._mode.on_device) ): # Initialize all global weights on every gpu, then select the appropriate slice if applicable. global_param = parameter.new_empty(global_shape, device=self._distributed.device) meta.init_parameter(global_param, distributed=self._distributed) + # It happens. + Assert.eq(global_param.shape, global_shape) if self._mode.on_device: parameter.copy_(fsdp.parameter_global_to_shard(global_param, meta.tensor_name)) elif self._mode.on_device: diff --git a/fast_llm/engine/schedule/config.py b/fast_llm/engine/schedule/config.py index 272b7c6ae..a5e0a86a6 100644 --- a/fast_llm/engine/schedule/config.py +++ b/fast_llm/engine/schedule/config.py @@ -48,6 +48,12 @@ class BatchConfig(Config): desc="Pointer to a distributed configuration, required to know the data-parallel split of the batch.", hint=FieldHint.setup, ) + # Image inputs + max_image_size: int | None = Field( + default=None, + desc="Maximum image height and width", + hint=FieldHint.optional, + ) def setup(self, distributed_config: DistributedConfig) -> None: self._distributed = distributed_config diff --git a/fast_llm/engine/training/config.py b/fast_llm/engine/training/config.py index 4b8d805b8..9372ad7fb 100644 --- a/fast_llm/engine/training/config.py +++ b/fast_llm/engine/training/config.py @@ -32,7 +32,6 @@ from fast_llm.utils import Assert if typing.TYPE_CHECKING: - from fast_llm.engine.inference.runner import InferenceRunner from fast_llm.engine.training.trainer import Trainer, TrainingEvaluator @@ -388,7 +387,7 @@ def _validate(self) -> None: # TODO: Add support. Assert.eq(self.model.distributed.pipeline_parallel, 1) # TODO: Check if these work. 
- Assert.eq(self.model.distributed.tensor_parallel, 1) + # Assert.eq(self.model.distributed.tensor_parallel, 1) Assert.eq(self.model.distributed.sequence_data_parallel, 1) if self.run.experiment_dir is None: assert not self.training.checkpoint.enabled() @@ -403,10 +402,6 @@ def _setup(self): def get_trainer_class(cls) -> type["Trainer"]: raise NotImplementedError - @classmethod - def get_inference_runner_class(cls) -> type["InferenceRunner"]: - raise NotImplementedError - def _get_runnable(self) -> typing.Callable[[], None]: from fast_llm.engine.distributed.distributed import Distributed diff --git a/fast_llm/engine/training/trainer.py b/fast_llm/engine/training/trainer.py index 5f5511a15..ec3c4cebe 100644 --- a/fast_llm/engine/training/trainer.py +++ b/fast_llm/engine/training/trainer.py @@ -142,7 +142,7 @@ def __init__(self, config: TrainerConfig): self._reference_models = {} for name, reference_config in self._config.reference_models.items(): log_main_rank(f"Creating `{name} reference model...") - self._reference_models[name] = self._config.get_inference_runner_class()( + self._reference_models[name] = reference_config.model.get_inference_runner_class()( reference_config.model.get_model_class()(reference_config.model) ) self._multi_stage.base_model.add_reference_model(name, self._reference_models[name]) diff --git a/fast_llm/functional/config.py b/fast_llm/functional/config.py index 684193848..5c8d75a6f 100644 --- a/fast_llm/functional/config.py +++ b/fast_llm/functional/config.py @@ -40,6 +40,7 @@ class ActivationType(enum.StrEnum): """ gelu = "gelu" + gelu_pytorch_tanh = "gelu_pytorch_tanh" silu = "silu" relu = "relu" squared_relu = "squared_relu" @@ -67,7 +68,8 @@ def _set_activation_fn_map() -> None: global _ACTIVATION_FN_MAP _ACTIVATION_FN_MAP = { - ActivationType.gelu: lambda x: torch.nn.functional.gelu(x, approximate="tanh"), + ActivationType.gelu: torch.nn.functional.gelu, + ActivationType.gelu_pytorch_tanh: lambda x: torch.nn.functional.gelu(x, approximate="tanh"), ActivationType.silu: torch.nn.functional.silu, ActivationType.relu: torch.nn.functional.relu, ActivationType.squared_relu: lambda x: torch.pow(torch.nn.functional.relu(x), 2), @@ -78,7 +80,8 @@ def _set_activation_fn_map() -> None: _ACTIVATION_FN_MAP: dict[ActivationType, typing.Callable[["torch.Tensor"], "torch.Tensor"]] = {} _ACTIVATION_HF_NAMES = { - ActivationType.gelu: "gelu_pytorch_tanh", + ActivationType.gelu: "gelu", + ActivationType.gelu_pytorch_tanh: "gelu_pytorch_tanh", ActivationType.silu: "silu", ActivationType.relu: "relu", ActivationType.squared_relu: "relu2", @@ -86,9 +89,16 @@ def _set_activation_fn_map() -> None: } _ACTIVATION_HF_NAMES_INV = {value: key for key, value in _ACTIVATION_HF_NAMES.items()} + MAX_DROPLESS_BLOCK_SIZE_ROW = 128 +class ReverseKLImpl(str, enum.Enum): + tp = "tp" + stp = "stp" + no_tp = "no_tp" + + class CrossEntropyImpl(str, enum.Enum): auto = "auto" torch = "torch" diff --git a/fast_llm/functional/cross_entropy.py b/fast_llm/functional/cross_entropy.py index d56dce98d..d9ca547a7 100644 --- a/fast_llm/functional/cross_entropy.py +++ b/fast_llm/functional/cross_entropy.py @@ -1,7 +1,7 @@ import torch from fast_llm.core.distributed import ProcessGroup, ReduceOp, all_reduce -from fast_llm.functional.config import CrossEntropyImpl, TargetFormat +from fast_llm.functional.config import CrossEntropyImpl, ReverseKLImpl, TargetFormat from fast_llm.functional.triton.cross_entropy import triton_cross_entropy_forward_backward from fast_llm.utils import Assert @@ -49,6 +49,19 @@ def 
_torch_cross_entropy_forward_backward( return loss.detach_(), grad +def distributed_log_softmax(logits: torch.Tensor, group: ProcessGroup, dim: int = -1): + logits = logits.float() + local_max = logits.max(dim=dim, keepdim=True)[0] + all_reduce(local_max, op=ReduceOp.MAX, group=group) + + logits_shifted = logits - local_max + exp_logits = torch.exp(logits_shifted) + sum_exp = exp_logits.sum(dim=dim, keepdim=True) + all_reduce(sum_exp, op=ReduceOp.SUM, group=group) + + return logits_shifted - sum_exp.log() # log_softmax + + @torch.compile def _fused_softmax_base( logits: torch.Tensor, logits_scale_factor: float = 1.0, group: ProcessGroup | None = None, dim: int = -1 @@ -151,7 +164,8 @@ def _fused_cross_entropy_forward_backward( loss = per_sample_loss.mean() if target_format != TargetFormat.labels and group is not None: - all_reduce(loss, op=ReduceOp.MEAN, group=group) + all_reduce(loss, op=ReduceOp.SUM, group=group) + loss /= group.size() return loss, grad @@ -213,20 +227,30 @@ def cross_entropy_forward_backward( ) -def _torch_reverse_kl_forward_backward( +def _torch_reverse_kl_forward_backward_vocab_parallel( logits: torch.Tensor, target: torch.Tensor, loss_mask: torch.Tensor | None, grad_output: float | None, - logits_scale_factor: float, target_format: TargetFormat, group: ProcessGroup | None = None, + logits_scale_factor: float = 1.0, teacher_softmax_temperature: float = 1.0, + **kwargs, ) -> tuple[torch.Tensor, torch.Tensor | None]: """ Reverse KL using PyTorch's native kl_div function. - Much simpler and more reliable than custom implementation! + This is used for TP version where we split accross vocab dimantion. + This works with sequence-tensor-parallel (distributing over the sequence dimention) as well as a non-TP case. + In sequence-tensor-parallel, where we split along sequence dim., we compute per split loss and then average the loss. 
""" + Assert.eq( + teacher_softmax_temperature, + 1, + msg="Teacher softmax temperature must be 1 for sequence-tensor-parallel reverse KL", + ) + Assert.eq(logits_scale_factor, 1, msg="Logits scale factor must be 1 for sequence-tensor-parallel reverse KL") + # TODO: merge into single function _torch_reverse_kl_forward_backward Assert.eq(target_format, TargetFormat.logits, msg="Reverse KL only supports logits format") Assert.eq(target.shape, logits.shape) assert target.dtype.is_floating_point, target.dtype @@ -234,32 +258,78 @@ def _torch_reverse_kl_forward_backward( Assert.eq(loss_mask.shape, logits.shape[:-1]) # Compute log probabilities - let _fused_softmax handle scaling internally - # teacher_probs = _fused_softmax(target, logits_scale_factor * (1 / teacher_softmax_temperature), group) - # # teacher_log_probs = torch.log(teacher_probs + 1e-8) # log(p) - # teacher_probs = torch.clamp(teacher_probs, min=1e-7) # or even 1e-6 - # teacher_log_probs = torch.log(teacher_probs) + teacher_log_probs = distributed_log_softmax(target.float(), group=group) + batch_size = logits.shape[0] + with torch.enable_grad(): + logits_ = logits.float().detach().requires_grad_(grad_output is not None) + student_log_probs = distributed_log_softmax(logits_, group=group) + + # Reverse KL: input=teacher_log_probs, target=student_probs + if loss_mask is None: + loss = torch.nn.functional.kl_div( + teacher_log_probs, # input = log(p) + student_log_probs, # target = log(q) + reduction="sum", + log_target=True, + ) + else: + # Apply loss mask - this requires some reshaping + raise NotImplementedError("Loss mask not implemented with TP for reverse KL , it must be doublechecked") + loss_per_sample = torch.nn.functional.kl_div( + teacher_log_probs, student_log_probs, reduction="none", log_target=True + ).sum(dim=-1) + loss = (loss_per_sample * loss_mask).sum() + + if group is not None and target_format != TargetFormat.labels: + all_reduce(loss, op=ReduceOp.SUM, group=group) + loss /= batch_size + + if grad_output is not None: + loss.backward(torch.full_like(loss, grad_output)) + grad = logits_.grad.to(logits.dtype) + else: + grad = None + + return loss.detach_(), grad + +def _torch_reverse_kl_forward_backward_no_tp( + logits: torch.Tensor, + target: torch.Tensor, + loss_mask: torch.Tensor | None, + grad_output: float | None, + logits_scale_factor: float, + target_format: TargetFormat, + teacher_softmax_temperature: float = 1.0, + **kwargs, +) -> tuple[torch.Tensor, torch.Tensor | None]: + """ + Reverse KL using PyTorch's native kl_div function. + THis is only used for no-TP case. 
+ """ + Assert.eq(target_format, TargetFormat.logits, msg="Reverse KL only supports logits format") + Assert.eq(target.shape, logits.shape) + assert target.dtype.is_floating_point, target.dtype + if loss_mask is not None: + Assert.eq(loss_mask.shape, logits.shape[:-1]) # Scale target logits more carefully scaled_target = target * (logits_scale_factor / teacher_softmax_temperature) + # Clamp to prevent extreme values that cause NaNs in log_softmax + scaled_target = torch.clamp(scaled_target, min=-100.0, max=100.0) - # Clamp to prevent extreme values before log_softmax - scaled_target = torch.clamp(scaled_target, min=-50, max=50) - teacher_log_probs = torch.log_softmax(scaled_target, dim=-1) + teacher_log_probs = torch.log_softmax(scaled_target.float(), dim=-1) # For reverse KL: KL(q||p) = Σ q * log(q/p) = Σ q * (log(q) - log(p)) # Use kl_div with: input=log(p), target=q, log_target=False # This gives: Σ q * (log(q) - log(p)) = exactly what we want! with torch.enable_grad(): - logits_ = logits.detach().requires_grad_(grad_output is not None) + logits_ = logits.float().detach().requires_grad_(grad_output is not None) - # Use log_softmax for consistency instead of _fused_softmax scaled_logits = logits_ * logits_scale_factor - scaled_logits = torch.clamp(scaled_logits, min=-50, max=50) - student_log_probs = torch.log_softmax(scaled_logits, dim=-1) - - # Convert to probabilities for kl_div - # student_probs_ = torch.exp(student_log_probs) + # Clamp to prevent extreme values that cause NaNs in log_softmax + scaled_logits = torch.clamp(scaled_logits, min=-100.0, max=100.0) + student_log_probs = torch.log_softmax(scaled_logits.float(), dim=-1) # Reverse KL: input=teacher_log_probs, target=student_probs if loss_mask is None: @@ -274,12 +344,85 @@ def _torch_reverse_kl_forward_backward( loss_per_sample = torch.nn.functional.kl_div( teacher_log_probs, student_log_probs, reduction="none", log_target=True ).sum(dim=-1) - loss = (loss_per_sample * loss_mask).mean() + loss = (loss_per_sample * loss_mask).sum() / loss_mask.sum() - if group is not None and target_format != TargetFormat.labels: - all_reduce(loss, op=ReduceOp.MEAN, group=group) + if grad_output is not None: + # note, we never get here in TP over seq. dim. + loss.backward(torch.full_like(loss, grad_output)) + grad = logits_.grad.to(logits.dtype) + else: + grad = None + + return loss.detach_(), grad + + +def _torch_reverse_kl_forward_backward_sequence_tensor_parallel( + logits: torch.Tensor, + target: torch.Tensor, + loss_mask: torch.Tensor | None, + grad_output: float | None, + logits_scale_factor: float, + target_format: TargetFormat, + teacher_softmax_temperature: float = 1.0, + total_valid_tokens: int | None = None, # total number of unmasked tokens in the batch + **kwargs, +) -> tuple[torch.Tensor, torch.Tensor | None]: + """ + Reverse KL using PyTorch's native kl_div function. + THis is only used for sequence-tensor-parallel case where we split over sequence dimension. 
+ """ + Assert.eq( + total_valid_tokens is not None, + msg="Total valid tokens must be provided for sequence-tensor-parallel reverse KL", + ) + Assert.eq( + teacher_softmax_temperature, + 1, + msg="Teacher softmax temperature must be 1 for sequence-tensor-parallel reverse KL", + ) + Assert.eq(logits_scale_factor, 1, msg="Logits scale factor must be 1 for sequence-tensor-parallel reverse KL") + Assert.eq(target_format, TargetFormat.logits, msg="Reverse KL only supports logits format") + Assert.eq(target.shape, logits.shape) + assert target.dtype.is_floating_point, target.dtype + if loss_mask is not None: + Assert.eq(loss_mask.shape, logits.shape[:-1]) + # Scale target logits more carefully + scaled_target = target * (logits_scale_factor / teacher_softmax_temperature) + # Clamp to prevent extreme values that cause NaNs in log_softmax + scaled_target = torch.clamp(scaled_target, min=-100.0, max=100.0) + + teacher_log_probs = torch.log_softmax(scaled_target.float(), dim=-1) + + # For reverse KL: KL(q||p) = Σ q * log(q/p) = Σ q * (log(q) - log(p)) + # Use kl_div with: input=log(p), target=q, log_target=False + # This gives: Σ q * (log(q) - log(p)) = exactly what we want! + + with torch.enable_grad(): + logits_ = logits.float().detach().requires_grad_(grad_output is not None) + + scaled_logits = logits_ * logits_scale_factor + # Clamp to prevent extreme values that cause NaNs in log_softmax + scaled_logits = torch.clamp(scaled_logits, min=-100.0, max=100.0) + student_log_probs = torch.log_softmax(scaled_logits.float(), dim=-1) + + # Reverse KL: input=teacher_log_probs, target=student_probs + if loss_mask is None: + loss = torch.nn.functional.kl_div( + teacher_log_probs, # input = log(p) + student_log_probs, # target = log(q) + reduction="sum", + log_target=True, + ) + else: + # Apply loss mask - this requires some reshaping + loss_per_sample = torch.nn.functional.kl_div( + teacher_log_probs, student_log_probs, reduction="none", log_target=True + ).sum(dim=-1) + loss = (loss_per_sample * loss_mask).sum() # this can be 0.0 if all tokens are masked if grad_output is not None: + # note, if we compute gradient w.r.t sum of losses, + # and grad_output should reflect the scaling by 1/valid samples loss.backward(torch.full_like(loss, grad_output)) grad = logits_.grad.to(logits.dtype) else: @@ -288,6 +431,13 @@ def _torch_reverse_kl_forward_backward( return loss.detach_(), grad +REVERSE_KL_IMPLEMENTATIONS = { + ReverseKLImpl.no_tp: _torch_reverse_kl_forward_backward_no_tp, + ReverseKLImpl.tp: _torch_reverse_kl_forward_backward_vocab_parallel, + ReverseKLImpl.stp: _torch_reverse_kl_forward_backward_sequence_tensor_parallel, +} + + def reverse_kl_forward_backward( logits: torch.Tensor, target: torch.Tensor, @@ -297,6 +447,8 @@ def reverse_kl_forward_backward( logits_scale_factor: float = 1.0, teacher_softmax_temperature: float = 1.0, target_format: TargetFormat = TargetFormat.labels, + reverse_kl_impl: ReverseKLImpl = ReverseKLImpl.no_tp, + total_valid_tokens: int | None = None, ) -> tuple[torch.Tensor, torch.Tensor | None]: """ Compute reverse KL divergence: KL(q||p) where q is the predicted distribution (student) and p is the target (teacher). @@ -339,7 +491,15 @@ def reverse_kl_forward_backward( assert target.dtype.is_floating_point, target.dtype if loss_mask is not None: Assert.eq(loss_mask.shape, logits.shape[:-1]) - # TODO: implement fused? 
- return _torch_reverse_kl_forward_backward( - logits, target, loss_mask, grad_output, logits_scale_factor, target_format, group, teacher_softmax_temperature + # TODO: implement fused reverse KL? + return REVERSE_KL_IMPLEMENTATIONS[reverse_kl_impl]( + logits=logits, + target=target, + loss_mask=loss_mask, + grad_output=grad_output, + logits_scale_factor=logits_scale_factor, + target_format=target_format, + teacher_softmax_temperature=teacher_softmax_temperature, + group=group, + total_valid_tokens=total_valid_tokens, ) diff --git a/fast_llm/functional/triton/mlp.py b/fast_llm/functional/triton/mlp.py index ab408368f..f3d9d7d0c 100644 --- a/fast_llm/functional/triton/mlp.py +++ b/fast_llm/functional/triton/mlp.py @@ -47,8 +47,7 @@ def triton_mlp_activation_forward_kernel( input_ = tl.load(input_ptr, mask=mask).to(tl.float32) - # Triton doesn't like enums, so we use str instead of ActivationType. - if activation_type == "gelu": + if activation_type == "gelu" or activation_type == "gelu_pytorch_tanh": tanh_input = 0.79788456 * input_ * (1 + 0.044715 * input_ * input_) tanh = 1 - 2 / (1 + tl.exp(2 * tanh_input)) out = input_ * 0.5 * (1.0 + tanh) @@ -98,8 +97,7 @@ def triton_mlp_activation_backward_kernel( input_ = tl.load(input_ptr, mask=mask).to(tl.float32) output_grad = tl.load(grad_output_ptr + output_offsets, mask=mask).to(tl.float32) - # Triton doesn't like enums, so we use str instead of ActivationType. - if activation_type == "gelu": + if activation_type == "gelu" or activation_type == "gelu_pytorch_tanh": tanh_input = 0.79788456 * input_ * (1 + 0.044715 * input_ * input_) tanh = 1 - 2 / (1 + tl.exp(2 * tanh_input)) grad = 0.5 * input_ * ((1 - tanh * tanh) * (0.79788456 + 0.1070322243 * input_ * input_)) + 0.5 * (1 + tanh) diff --git a/fast_llm/layers/common/config.py b/fast_llm/layers/common/config.py index 9f32ac689..07dadbc22 100644 --- a/fast_llm/layers/common/config.py +++ b/fast_llm/layers/common/config.py @@ -99,7 +99,7 @@ class LayerNormalizationBaseConfig(NormalizationConfig): ) def get_layer(self, hidden_dim: "TensorDim", lr_scale: float | None = None) -> "LayerNorm | RMSNorm": - from fast_llm.tensor import init_uniform_ + from fast_llm.tensor import init_uniform_centered_ kwargs = { "hidden_dim": hidden_dim, @@ -110,9 +110,7 @@ def get_layer(self, hidden_dim: "TensorDim", lr_scale: float | None = None) -> " } if self.initialization_range: mean = 0 if self.zero_centered else 1 - kwargs["weight_init_method"] = init_uniform_( - mean - self.initialization_range, mean + self.initialization_range - ) + kwargs["weight_init_method"] = init_uniform_centered_(self.initialization_range, mean=mean) return self.module_class(**kwargs) @property diff --git a/fast_llm/layers/common/linear.py b/fast_llm/layers/common/linear.py index cd19a47a5..7249ef569 100644 --- a/fast_llm/layers/common/linear.py +++ b/fast_llm/layers/common/linear.py @@ -94,8 +94,8 @@ def __init__( transposed_weight: bool = False, lr_scale: float | None | tuple[float | None, ...] = None, ): - assert in_dim.parallel_dim is None - assert out_dim.parallel_dim is None + assert not in_dim.is_parallel + assert not out_dim.is_parallel super().__init__( in_dim, out_dim, @@ -132,7 +132,7 @@ def __init__( sequence_parallel: bool = False, lr_scale: float | None | tuple[float | None, ...] 
= None, ): - assert in_dim.parallel_dim is None + assert not in_dim.is_parallel self._group_size = 1 if out_dim.parallel_dim is None else out_dim.parallel_dim.size self._sequence_parallel = sequence_parallel and self._group_size > 1 super().__init__( @@ -176,7 +176,7 @@ def __init__( transposed_weight: bool = False, lr_scale: float | None | tuple[float | None, ...] = None, ): - assert out_dim.parallel_dim is None + assert not out_dim.is_parallel self._group_size = 1 if in_dim.parallel_dim is None else in_dim.parallel_dim.size self._sequence_parallel = sequence_parallel and self._group_size > 1 super().__init__( diff --git a/fast_llm/layers/common/normalization.py b/fast_llm/layers/common/normalization.py index 5f30beaef..bccc1d627 100644 --- a/fast_llm/layers/common/normalization.py +++ b/fast_llm/layers/common/normalization.py @@ -158,7 +158,7 @@ def __init__( lr_scale: float | None = None, ): super().__init__() - assert hidden_dim.parallel_dim is None + assert not hidden_dim.is_parallel self._eps = eps self._zero_centered = zero_centered if implementation == NormalizationImplementation.auto: @@ -242,7 +242,7 @@ def __init__( lr_scale: float | None = None, ): super().__init__() - assert hidden_dim.parallel_dim is None + assert not hidden_dim.is_parallel self._eps = eps self._zero_centered = zero_centered if implementation == NormalizationImplementation.auto: diff --git a/fast_llm/layers/common/peft.py b/fast_llm/layers/common/peft.py index 3a1966e51..08f3e535b 100644 --- a/fast_llm/layers/common/peft.py +++ b/fast_llm/layers/common/peft.py @@ -19,12 +19,12 @@ def lora_linear( ): layer.weight.requires_grad = False in_dim = layer._in_dim + assert not in_dim.is_parallel, "LoRA not supported with tensor parallelism." if in_dim.parallel_dim is not None: - assert in_dim.parallel_dim.size == 1, "LoRA not supported with tensor parallelism." in_dim = TensorDim(in_dim.name, in_dim.global_size) out_dim = layer._out_dim + assert not out_dim.is_parallel, "LoRA not supported with tensor parallelism." if out_dim.parallel_dim is not None: - assert out_dim.parallel_dim.size == 1, "LoRA not supported with tensor parallelism." 
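# Editor's note: the `is_parallel` property asserted in these layers is not shown in
# this diff; judging from the LoRA change above (the size-1 parallel_dim case is still
# handled by the `parallel_dim is not None` branch), it presumably behaves like this
# hypothetical helper, i.e. "actually split over a tensor-parallel group of size > 1":
def is_parallel(tensor_dim) -> bool:  # illustrative assumption, not the real property
    return tensor_dim.parallel_dim is not None and tensor_dim.parallel_dim.size > 1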
out_dim = TensorDim(out_dim.name, out_dim.global_size) if out_channel_begin is not None or out_channel_end is not None: if out_channel_begin is None: diff --git a/fast_llm/layers/language_model/config.py b/fast_llm/layers/language_model/config.py index 8e2e97f1a..b0bb6ec6f 100644 --- a/fast_llm/layers/language_model/config.py +++ b/fast_llm/layers/language_model/config.py @@ -7,6 +7,7 @@ from fast_llm.functional.config import CrossEntropyImpl, DistillationLossImpl from fast_llm.layers.transformer.config import TransformerConfig from fast_llm.layers.transformer.rotary.config import NoRotaryConfig +from fast_llm.layers.vision_encoder.config import VisionEncoderConfig from fast_llm.utils import Assert @@ -37,6 +38,7 @@ class LanguageModelKwargs: position_ids = "position_ids" # TODO: These are generic labels = "labels" + tokens = "tokens" phase = "phase" chosen_spans = "chosen_spans" rejected_spans = "rejected_spans" @@ -50,6 +52,10 @@ class LanguageModelBaseConfig(BaseModelConfig): desc="Configuration for the transformer architecture.", hint=FieldHint.architecture, ) + vision_encoder: VisionEncoderConfig = Field( + desc="Configuration for the vision encoder that transforms images into embeddings.", + hint=FieldHint.optional, + ) max_position_embeddings: int = Field( default=2048, desc="Number of absolute position embeddings, if applicable.", @@ -244,6 +250,8 @@ def setup_tensor_space(self, tensor_space: TensorSpace) -> None: # TODO: Need both? tensor_space.add_tensor_dim(TensorDim(LanguageModelDimNames.vocab, self.vocab_size)) tensor_space.add_tensor_dim(TensorDim(LanguageModelDimNames.vocab_tp, self.vocab_size, tensor)) + if self.vision_encoder.enabled: + self.vision_encoder.setup_tensor_space(tensor_space) @property def num_absolute_position_embeddings(self) -> int: diff --git a/fast_llm/layers/language_model/embedding.py b/fast_llm/layers/language_model/embedding.py index 7036a1e97..f6f43d199 100644 --- a/fast_llm/layers/language_model/embedding.py +++ b/fast_llm/layers/language_model/embedding.py @@ -46,10 +46,10 @@ def __init__( self._dropout_p = config.transformer.hidden_dropout self._use_absolute_position_embeddings = config.use_absolute_position_embeddings - hidden_dim = tensor_space.get_tensor_dim(TransformerDimNames.hidden) - vocab_dim = tensor_space.get_tensor_dim( + hidden_dim = tensor_space[TransformerDimNames.hidden] + vocab_dim = tensor_space[ LanguageModelDimNames.vocab_tp if self._parallel_embeddings else LanguageModelDimNames.vocab - ) + ] if self._parallel_embeddings: self._vocab_start_index = self._distributed_config.tensor_rank * vocab_dim.size @@ -66,7 +66,7 @@ def __init__( ) if self._use_absolute_position_embeddings: self.position_embeddings_weight = ParameterMeta.from_dims( - (tensor_space.get_tensor_dim(LanguageModelDimNames.position_embed), hidden_dim), + (tensor_space[LanguageModelDimNames.position_embed], hidden_dim), init_method=init_normal_( std=config.init_method_std_embed, min_val=config.init_method_min_embed, diff --git a/fast_llm/layers/language_model/head.py b/fast_llm/layers/language_model/head.py index 25fc2b28d..b1f3564b9 100644 --- a/fast_llm/layers/language_model/head.py +++ b/fast_llm/layers/language_model/head.py @@ -11,7 +11,13 @@ from fast_llm.engine.config_utils.tensor_space import DefaultDimNames, TensorDim, TensorSpace from fast_llm.engine.distributed.config import DistributedDimNames from fast_llm.functional.autograd import grad_is_context, wrap_forward_backward -from fast_llm.functional.config import CrossEntropyImpl, DistillationLossImpl, 
TargetFormat, TritonConfig +from fast_llm.functional.config import ( + CrossEntropyImpl, + DistillationLossImpl, + ReverseKLImpl, + TargetFormat, + TritonConfig, +) from fast_llm.functional.cross_entropy import cross_entropy_forward_backward, reverse_kl_forward_backward from fast_llm.functional.dpo import compute_dpo_loss from fast_llm.functional.linear import output_parallel_linear_backward, output_parallel_linear_forward @@ -61,7 +67,7 @@ def __init__( if self._cross_entropy_splits is not None and self._sequence_parallel: assert not self._parallel_embeddings - hidden_dim = self._tensor_space.get_tensor_dim(TransformerDimNames.hidden) + hidden_dim = self._tensor_space[TransformerDimNames.hidden] self._loss_coefficient = ( config.prediction_loss_coefficient[prediction_distance] if config.prediction_loss_coefficient else 1.0 @@ -108,9 +114,9 @@ def _init_output_weights(self, hidden_dim: TensorDim, config) -> None: if self._tie_word_embeddings or self._prediction_distance > 0: return # untie embedding weights - vocab_dim = self._tensor_space.get_tensor_dim( + vocab_dim = self._tensor_space[ LanguageModelDimNames.vocab_tp if self._parallel_embeddings else LanguageModelDimNames.vocab - ) + ] self.output_weights = ParameterMeta.from_dims( (vocab_dim, hidden_dim), init_method=init_normal_( @@ -125,12 +131,16 @@ def forward( self, input_: torch.Tensor, kwargs: dict, losses: dict | None = None, metrics: dict | None = None ) -> torch.Tensor: if isinstance(input_, TensorMeta): - return TensorMeta.from_tensor_space( - (DefaultDimNames.scalar,), - self._tensor_space, - tensor_name="Loss", - reductions=((DistributedDimNames.data, ReduceOp.AVG),), # noqa - ) + if self._is_last_head: + return TensorMeta.from_tensor_space( + (DefaultDimNames.scalar,), + self._tensor_space, + tensor_name="Loss", + reductions=((DistributedDimNames.data, ReduceOp.AVG),), # noqa + ) + else: + return TensorMeta.from_dims(input_.dims[1:], tensor_name="Shared hidden") + if not self._is_last_head: # MTP: split the stacked input shared_hidden, input_ = torch.unbind(input_, dim=0) @@ -233,13 +243,24 @@ def _get_targets( ).flatten() else: lm_target = None - - targets = (dpo_target, lm_target, distillation_target, loss_mask) - if self._sequence_parallel_logits: + targets = (dpo_target, lm_target, distillation_target) + # If we do distillation, no need to split it here as it has already been split in the embedding layer! + # if we do CPT/language modeling, we need to split the targets here! + if ( + self._config.distillation_model is not None + and self._sequence_parallel_logits + and not self._parallel_embeddings + and not self._sequence_parallel + ) or (self._config.distillation_model is None and self._sequence_parallel_logits): + # We dont split targets if they already have been split in the embedding layer! targets = [ None if target is None else split_op(target, self._tensor_space.distributed.tensor_group, 0) for target in targets ] + # Loss mask may need to be split. It was not split in the embedding layer as it is not used there. + if loss_mask is not None and self._sequence_parallel_logits: + loss_mask = split_op(loss_mask, self._tensor_space.distributed.tensor_group, 0) + targets = (*targets, loss_mask) if not any(target is not None for target in targets): # Simplify so we don't have to check every time. 
targets = None @@ -298,12 +319,13 @@ def _logits_cross_entropy_forward_backward_split( logit_input_grad_.copy_(grad_) loss = loss_ if loss is None else loss + loss_ del grad_, loss_ - loss_count = (self._cross_entropy_splits or 1) * (self._group_size if self._sequence_parallel_logits else 1) - if loss_count != 1: - loss.div_(loss_count) - if self._sequence_parallel_logits: - # TODO: Async - all_reduce(loss, group=self._tensor_space.distributed.tensor_group) + assert self._cross_entropy_splits is None, "This is not supported for now" + # loss_count = (self._cross_entropy_splits or 1) * (self._group_size if self._sequence_parallel_logits else 1) + # if loss_count != 1: + # loss.div_(loss_count) + # if self._sequence_parallel_logits: + # # TODO: Async + # all_reduce(loss, group=self._tensor_space.distributed.tensor_group) return loss, logit_input_grad.view_as(input_) if logit_input_grad is not None else None def _logits_cross_entropy_forward_backward( @@ -334,9 +356,9 @@ def _logits_cross_entropy_forward_backward( logits_scale_factor=self._logits_scale_factor, ) if self._debug_transformer and self._cross_entropy_splits is None: - vocab_dim = self._tensor_space.get_tensor_dim( + vocab_dim = self._tensor_space[ LanguageModelDimNames.vocab if self._sequence_parallel_logits else LanguageModelDimNames.vocab_tp - ) + ] dims = [*kwargs[TransformerKwargs.hidden_dims][:-1], vocab_dim] sequence_index = 1 - int(kwargs[TransformerKwargs.sequence_first]) dims[sequence_index] = ( @@ -397,6 +419,29 @@ def _logits_cross_entropy_forward_backward( if distillation_target is not None and self._distillation_loss_factor > 0.0: if self._distillation_loss_implementation == DistillationLossImpl.reverse_kl: + local_valid_tokens = total_valid_tokens = logits.shape[0] + if logits.shape[-1] != self._config.vocab_size: + reverse_kl_impl = ReverseKLImpl.tp + assert loss_mask is None, "Loss mask is not implemented for TP (vocab dim) reverse KL yet" + elif self._sequence_parallel_logits: + # grad_output already reflects scaling 1/ number of ranks (group_size), see _forward_backward + reverse_kl_impl = ReverseKLImpl.stp + if loss_mask is not None: + local_valid_tokens = loss_mask.sum() + total_valid_tokens = local_valid_tokens.clone() + all_reduce( + total_valid_tokens, op=ReduceOp.SUM, group=self._tensor_space.distributed.tensor_group + ) + else: + local_valid_tokens = logits.shape[0] + total_valid_tokens = local_valid_tokens * self._group_size + # in the loss function we compute grads w.r.t sum of losses, + # so we need to multiply back by the group size and divide by the number of valid tokens to get the correct scaling + # note, the function returns the sum of local losses, so we need to handle this properly for reporting + grad_output *= self._group_size / total_valid_tokens # multiply back by the group size + else: + reverse_kl_impl = ReverseKLImpl.no_tp + distillation_loss, distillation_grad = reverse_kl_forward_backward( logits.flatten(0, -2), distillation_target, @@ -408,7 +453,14 @@ def _logits_cross_entropy_forward_backward( target_format=( TargetFormat.labels if self._config.distillation_model is None else TargetFormat.logits ), + reverse_kl_impl=reverse_kl_impl, + total_valid_tokens=total_valid_tokens, ) + if self._sequence_parallel_logits: + # distillation_loss is local sum, so we need to divide by the number of valid tokens to get the correct scaling + all_reduce(distillation_loss, op=ReduceOp.SUM, group=self._tensor_space.distributed.tensor_group) + distillation_loss /= total_valid_tokens # final global loss + 
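# Editor's note: a toy numeric walk-through (made-up numbers) of the scaling used above
# for sequence-tensor-parallel reverse KL. Each rank returns a *sum* of per-token losses
# over its local slice, so both the gradients and the reported loss must end up divided
# by the total number of unmasked tokens across the tensor group.
group_size = 2
local_valid = [2, 3]                      # unmasked tokens per rank (invented)
total_valid = sum(local_valid)            # all_reduce(SUM) -> 5
incoming_grad_output = 1.0 / group_size   # already carries 1/group_size, per the comment above
grad_output = incoming_grad_output * group_size / total_valid
assert grad_output == 1.0 / total_valid   # gradient of the summed loss is scaled by 1/total_valid

local_loss_sums = [1.2, 0.8]              # per-rank local sums (invented)
reported_loss = sum(local_loss_sums) / total_valid  # all_reduce(SUM), then /= total_valid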
elif self._distillation_loss_implementation == DistillationLossImpl.cross_entropy: distillation_loss, distillation_grad = cross_entropy_forward_backward( logits.flatten(0, -2), diff --git a/fast_llm/layers/language_model/preprocessing.py b/fast_llm/layers/language_model/preprocessing.py index d719bef3d..c8d53a789 100644 --- a/fast_llm/layers/language_model/preprocessing.py +++ b/fast_llm/layers/language_model/preprocessing.py @@ -28,7 +28,7 @@ def __init__( assert config.use_absolute_position_embeddings self._tensor_space = tensor_space self._distributed_config = self._tensor_space.distributed_config - self._scalar_dim = self._tensor_space.get_tensor_dim(DefaultDimNames.scalar) + self._scalar_dim = self._tensor_space[DefaultDimNames.scalar] def _create_tensors(self, sequence_length: int) -> None: if sequence_length <= self._tensor_cache_max_sequence_length: @@ -76,7 +76,7 @@ def __init__(self, config: LanguageModelBaseConfig, tensor_space: TensorSpace): self._config = config self._tensor_space = tensor_space self._distributed_config = self._tensor_space.distributed_config - self._scalar_dim = self._tensor_space.get_tensor_dim(DefaultDimNames.scalar) + self._scalar_dim = self._tensor_space[DefaultDimNames.scalar] def preprocess_meta(self, kwargs: dict[str, typing.Any]) -> None: return diff --git a/fast_llm/layers/multi_modal/embedding.py b/fast_llm/layers/multi_modal/embedding.py new file mode 100644 index 000000000..a5a789f9e --- /dev/null +++ b/fast_llm/layers/multi_modal/embedding.py @@ -0,0 +1,183 @@ +import typing + +import torch + +from fast_llm.core.distributed import set_generator +from fast_llm.core.ops import reduce_forward, split +from fast_llm.engine.config_utils.tensor_space import TensorSpace +from fast_llm.layers.language_model.config import LanguageModelBaseConfig, LanguageModelKwargs +from fast_llm.layers.language_model.embedding import LanguageModelEmbedding +from fast_llm.layers.transformer.config import TransformerKwargs +from fast_llm.layers.vision_encoder.config import VisionEncoderKwargs +from fast_llm.layers.vision_encoder.preprocessing import get_num_patches +from fast_llm.tensor import TensorMeta +from fast_llm.utils import Assert, div + + +class MultiModalEmbedding(LanguageModelEmbedding): + """ + Multi-modal embedding layer to combine embeddings from text, image and more modalities. + """ + + def __init__( + self, + config: LanguageModelBaseConfig, + tensor_space: TensorSpace, + ): + super().__init__(config, tensor_space) + + # @torch.compile + def _forward( + self, + input_: torch.Tensor, + tokens: torch.Tensor, + position_ids: torch.Tensor | None, + image_positions: list[torch.Tensor] | None, + image_sizes: list[list[tuple[int, int]]] | None, + ) -> torch.Tensor: + """ + Forward pass for the multi-modal embedding layer. + Args: + input_: The input tensor (image embeddings). + tokens: The tokenized text input. + position_ids: The position ids for the text input. + image_positions: The positions of the image tokens in the input. + image_sizes: The sizes of the images in the input. + Returns: + The combined embeddings for text and images. 
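# Editor's note: a small worked example (invented sizes) of the index arithmetic in
# _forward's body below. get_num_patches is assumed here to return
# (height / patch_size) * (width / patch_size); with an image_break_token, every row of
# patches is followed by one extra token in the token stream, hence the
# (patch_width + 1) stride on the destination side versus patch_width on the source side.
patch_size = 16
height, width = 32, 48                        # a 2 x 3 patch grid
patch_height, patch_width = height // patch_size, width // patch_size
position = 10                                 # first image-token position in the sample

for row in range(patch_height):
    src = row * patch_width                   # offset inside the image-embedding block
    dst = position + row * (patch_width + 1)  # skips one break token per completed row
    print(f"row {row}: patches [{src}, {src + patch_width}) -> tokens [{dst}, {dst + patch_width})")
# row 0: patches [0, 3) -> tokens [10, 13)
# row 1: patches [3, 6) -> tokens [14, 17)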
+ """ + Assert.eq(position_ids is not None, self._use_absolute_position_embeddings) + group = self._tensor_space.distributed.tensor_group + if self._sequence_parallel: + micro_seqlen = input_.size(0) + patch_start_offset = self._distributed_config.tensor_rank * micro_seqlen + patch_end_offset = (self._distributed_config.tensor_rank + 1) * micro_seqlen + else: + patch_start_offset = 0 + patch_end_offset = input_.size(0) + if self._parallel_embeddings: + token_mask = (tokens >= self._vocab_start_index) * (tokens < self._vocab_end_index) + masked_tokens = (tokens - self._vocab_start_index) * token_mask + embeddings = torch.embedding(self.word_embeddings_weight, masked_tokens) * token_mask.unsqueeze(2) # noqa + # Cloning since we will modify the embeddings in-place + embeddings = embeddings.clone() + # the embeddings tensor are full-sized, but we might get a split of the patch embeddings + # We need to determine the offset in the embeddings tensor for each sample + # and also account for the special image tokens if applicable + for sample_idx, (positions, sizes) in enumerate(zip(image_positions, image_sizes)): + image_embedding_offset = 0 + for position, size in zip(positions, sizes): + num_patches = get_num_patches(*size, self._config.vision_encoder.patch_size) + if image_embedding_offset + num_patches < patch_start_offset: + image_embedding_offset += num_patches + continue + if self._config.vision_encoder.image_break_token is not None: + patch_height = div(size[0], self._config.vision_encoder.patch_size) + patch_width = div(size[1], self._config.vision_encoder.patch_size) + for row in range(patch_height): + row_start_src = image_embedding_offset + row * patch_width + row_start_dst = position + row * (patch_width + 1) + if row_start_src > patch_end_offset: + break + if row_start_src + patch_width <= patch_start_offset: + continue + + input_start_index = max(row_start_src, patch_start_offset) - patch_start_offset + input_end_index = min(row_start_src + patch_width, patch_end_offset) - patch_start_offset + embeddings_start_index = row_start_dst + max(patch_start_offset - row_start_src, 0) + embeddings_end_index = ( + row_start_dst + patch_width - max(row_start_src + patch_width - patch_end_offset, 0) + ) + # row_end_src = min(row_start_src + patch_width, patch_end_offset) + if self._sequence_parallel: + embeddings[embeddings_start_index:embeddings_end_index, sample_idx] = input_[ + input_start_index:input_end_index, sample_idx + ] + else: + embeddings[sample_idx, embeddings_start_index:embeddings_end_index] = input_[ + sample_idx, input_start_index:input_end_index + ] + else: + input_start_index = max(image_embedding_offset, patch_start_offset) - patch_start_offset + input_end_index = ( + min(image_embedding_offset + num_patches, patch_end_offset) - patch_start_offset + ) + embedding_start_index = position - max(patch_start_offset - image_embedding_offset, 0) + embedding_end_index = ( + position + num_patches - max(image_embedding_offset + num_patches - patch_end_offset, 0) + ) + embeddings[sample_idx, embedding_start_index:embedding_end_index] = input_[ + input_start_index:input_end_index, sample_idx + ] + # embeddings[sample_idx, position : position + num_patches] = input_[ + # sample_idx, image_embedding_offset : image_embedding_offset + num_patches + # ] + image_embedding_offset += num_patches + if image_embedding_offset > patch_end_offset: + break + embeddings = reduce_forward(embeddings, group) + if self._use_absolute_position_embeddings: + embeddings = embeddings + 
torch.nn.functional.embedding(position_ids, self.position_embeddings_weight) + if self._sequence_parallel: + embeddings = split(embeddings, group=group, dim=0) + else: + if self._sequence_parallel: + tokens = split(tokens, group=group, dim=0) + if self._use_absolute_position_embeddings: + position_ids = split(position_ids, group=group, dim=0) + # mask padded tokens + token_mask = tokens >= 0 + masked_tokens = tokens * token_mask + embeddings = torch.embedding(self.word_embeddings_weight, masked_tokens) * token_mask.unsqueeze(2) + embeddings = embeddings.clone() + for sample_idx, (positions, sizes) in enumerate(zip(image_positions, image_sizes)): + image_embedding_offset = 0 + for position, size in zip(positions, sizes): + num_patches = get_num_patches(*size, self._config.vision_encoder.patch_size) + if self._config.vision_encoder.image_break_token is not None: + patch_height = div(size[0], self._config.vision_encoder.patch_size) + patch_width = div(size[1], self._config.vision_encoder.patch_size) + + for row in range(patch_height): + row_start_src = image_embedding_offset + row * patch_width + row_start_dst = position + row * (patch_width + 1) + + embeddings[sample_idx, row_start_dst : row_start_dst + patch_width] = input_[ + sample_idx, row_start_src : row_start_src + patch_width + ] + else: + embeddings[sample_idx, position : position + num_patches] = input_[ + sample_idx, image_embedding_offset : image_embedding_offset + num_patches + ] + # Move to the next image in the input tensor + image_embedding_offset += num_patches + + if self._use_absolute_position_embeddings: + embeddings = embeddings + torch.nn.functional.embedding(position_ids, self.position_embeddings_weight) + with set_generator( + self._tensor_space.distributed.tp_generator + if self._sequence_parallel + else self._tensor_space.distributed.pp_generator + ): + embeddings = torch.dropout(embeddings, self._dropout_p, self.training) + return embeddings.to(dtype=self._residual_dtype) + + def forward( + self, + input_: torch.Tensor, + kwargs: dict[str, typing.Any], + losses: dict[str, typing.Any] | None = None, + metrics: dict | None = None, + ) -> torch.Tensor: + if isinstance(input_, TensorMeta): + return TensorMeta.from_dims( + kwargs[TransformerKwargs.hidden_dims], + tensor_name="Embedding output", + dtype=self._residual_dtype, + ) + position_ids = kwargs.get(LanguageModelKwargs.position_ids) + image_sizes = kwargs.get(VisionEncoderKwargs.image_sizes) + image_positions = kwargs.get(VisionEncoderKwargs.image_positions) + tokens = kwargs.get(LanguageModelKwargs.tokens) + + return self._forward(input_, tokens, position_ids, image_positions, image_sizes) diff --git a/fast_llm/layers/ssm/config.py b/fast_llm/layers/ssm/config.py index 46d629aa8..194063a26 100644 --- a/fast_llm/layers/ssm/config.py +++ b/fast_llm/layers/ssm/config.py @@ -1,28 +1,59 @@ import enum +import typing from fast_llm.config import Field, FieldHint, check_field, config_class, skip_valid_if_none +from fast_llm.engine.config_utils.tensor_space import CompositeTensorDim, ConcatenatedTensorDim, TensorDim, TensorSpace +from fast_llm.engine.distributed.config import DistributedDimNames from fast_llm.functional.config import ActivationType from fast_llm.layers.common.config import LLMBlockConfig, NormalizationConfig -from fast_llm.utils import Assert +from fast_llm.utils import Assert, div + +if typing.TYPE_CHECKING: + from fast_llm.tensor import Initializer + + +class BaseSSMKwargs: + _kwargs_attributes = { + "cu_seqlens": "cu_seqlens", + "seq_idx": "seq_idx", + 
"ssm_position_ids": "ssm_position_ids", + } + + _prefix = "" + + def __init_subclass__(cls, prefix="", **kwargs): + super().__init_subclass__(**kwargs) + cls._prefix = prefix + for attr, value in BaseSSMKwargs._kwargs_attributes.items(): + setattr(cls, value, f"{cls._prefix}_{value}" if cls._prefix else value) + + +class SSMKwargs(BaseSSMKwargs, prefix=""): + pass class SSMDimNames: - model_dim = "model_dim" # Model dimension (D) - state_dim = "state_dim" # State dimension (N) - conv_dim = "conv_dim" # Dimension of the conv1d input in mamba layers - inner_dim = "inner_dim" # Inner dimension after expansion - dt_rank = "dt_rank" # Rank of Δ - inner_proj_mamba = "inner_proj_mamba" # Inner projection dimension for mamba - inner_proj_discrete_mamba2 = "inner_proj_discrete_mamba2" # Inner projection dimension for discrete mamba2 - inner_proj_mamba2 = "inner_proj_mamba2" # Inner projection dimension for mamba2 - x_proj_dim = "x_proj_dim" # X projection dimension - head_dim = "head_dim" # Dimension of the mamba2 head (P) - conv_kernel_size = "conv_kernel_size" # Kernel size of the conv1d in mamba layers - qk_heads = "qk_heads" # Number of QK heads - v_heads = "v_heads" # Number of V heads + # TODO: Use separate tensor space for different mixers so there is no risk of name conflict. + state = "ssm_state" # State dimension (N), aka head size / num channels + head_dim = "ssm_head_dim" + head_groups = "ssm_head_groups" + group_heads = "ssm_group_heads" # Mamba 2 x_proj_dim_2 = "x_proj_dim_2" # d_xb + convolution_kernel = "ssm_convolution_kernel" # Kernel dimension of the conv1d in mamba layers + + dt_rank = "ssm_dt_rank" + + # Composite dimensions + composite_heads = "ssm_composite_heads" + composite_heads_and_head_dim = "ssm_composite_heads_and_head_dim" + composite_head_groups_and_state = "ssm_composite_head_groups_and_state" + + # Concatenated dimensions + concatenated_convolution = "ssm_concatenated_convolution" + concatenated_x_projection = "ssm_x_concatenated_x_projection" + concatenated_inner_projection = "ssm_concatenated_inner_projection" class SSMBlockType(enum.StrEnum): @@ -35,6 +66,32 @@ class SSMBlockType(enum.StrEnum): mamba2 = "m2" transformer = "t" + def get_mixer_class(self): + if self == SSMBlockType.mamba: + from fast_llm.layers.ssm.mamba_layer import MambaLayer + + return MambaLayer + elif self == SSMBlockType.mamba2: + from fast_llm.layers.ssm.mamba2 import Mamba2 + + return Mamba2 + elif self == SSMBlockType.mamba2_discrete: + from fast_llm.layers.ssm.discrete_mamba2 import DiscreteMamba2 + + return DiscreteMamba2 + else: + raise NotImplementedError(self) + + +class DTInitType(enum.StrEnum): + constant = "constant" + random = "random" + + def get_init_method(self, scale: float) -> "Initializer": + from fast_llm.tensor import init_fill_, init_uniform_centered_ + + return init_fill_(scale) if self == DTInitType.constant else init_uniform_centered_(scale) + @config_class() class SSMConfig(LLMBlockConfig): @@ -45,79 +102,87 @@ class SSMConfig(LLMBlockConfig): desc="Configuration for the normalization layers architecture.", hint=FieldHint.architecture, ) + + # Model dimensions + # TODO: Remove (redundant default) expansion_factor: int = Field( default=2, desc="Expansion factor for Mamba blocks.", hint=FieldHint.architecture, valid=check_field(Assert.gt, 0), ) + # head_size [MambaLayer, Mamba2, DiscreteMamba2] state_size: int = Field( default=16, desc="State size for Mamba blocks.", hint=FieldHint.architecture, valid=check_field(Assert.gt, 0), ) + # [MambaLayer, Mamba2, 
DiscreteMamba2] conv_kernel_dimension: int = Field( default=4, desc="Conv kernel dimension for Mamba blocks.", hint=FieldHint.architecture, valid=check_field(Assert.gt, 0), ) - # Layer parameters - add_bias_linear: bool = Field( - default=False, - desc="Whether to use bias in SSM layers", - hint=FieldHint.architecture, - ) - + # [MambaLayer, Mamba2] dt_rank: None | int = Field( default=None, desc="Rank of the Δ projection matrix. If 'None', will be set to ceil(hidden_size/16)", hint=FieldHint.architecture, ) - chunk_size: int = Field( - default=256, - desc="Chunk size for Mamba2 blocks.", - hint=FieldHint.architecture, - ) + # head_groups [DiscreteMamba2] n_qk_heads: int = Field( default=32, desc="Number of QK heads for Mamba2 blocks.", hint=FieldHint.architecture, ) + # heads [DiscreteMamba2]# TODO: Remove? (redundant) n_v_heads: int = Field( default=32, desc="Number of V heads for Mamba2 blocks.", hint=FieldHint.architecture, ) - activation_type: ActivationType = Field( + # c_size [MambaLayer, Mamba2, DiscreteMamba2]? + d_inner: None | int = Field( default=None, - desc="The MLP intermediate activation type. Default: SiLU for gated MLP, GeLU otherwise.", + desc="Inner dimension for Mamba2 blocks.", + hint=FieldHint.core, + ) + # xb_size [Mamba2] + d_xb: int = Field( + default=None, + desc="Dimension of the xB in Mamba2 blocks.", hint=FieldHint.architecture, ) - debug_ssm: bool = Field( + + # Model options + # add_bias_linear [Mamba2, DiscreteMamba2] [hard-coded to False in MambaLayer] + add_bias_linear: bool = Field( default=False, - desc="debug_ssm", - hint=FieldHint.optional, + desc="Whether to use bias in SSM layers", + hint=FieldHint.architecture, ) - dt_min: float = Field( - default=0.001, - desc="Minimum step size for discretization", - hint=FieldHint.core, - valid=check_field(Assert.gt, 0), + # activation_type [DiscreteMamba2] [hard-coded to silu in MambaLayer, Mamba2] + activation_type: ActivationType = Field( + default=None, + hint=FieldHint.architecture, ) - dt_init_floor: float = Field( - default=1e-4, - desc="Minimum value for initializing dt", - hint=FieldHint.core, - valid=check_field(Assert.gt, 0), + # repeat_xb_before_conv [Mamba2] + repeat_kv_before_conv: bool = Field( + default=True, + desc="Whether to repeat x and B before (True) or after (False) the conv1d in Mamba2 blocks.", + hint=FieldHint.architecture, ) - - d_inner: None | int = Field( - default=None, - desc="Inner dimension for Mamba2 blocks.", - hint=FieldHint.core, + # chunk_size [DiscreteMamba2] + chunk_size: int = Field( + default=256, + desc="Chunk size for Mamba2 blocks.", + hint=FieldHint.architecture, ) + + # Learning rate + # lr_scale [MambaLayer, Mamba2, DiscreteMamba2] mamba_lr_scale: float | None = Field( default=None, desc="Learning rate scale for Mamba blocks.", @@ -125,43 +190,38 @@ class SSMConfig(LLMBlockConfig): valid=skip_valid_if_none(check_field(Assert.geq, 0)), ) - # Mamba 2 - repeat_kv_before_conv: bool = Field( - default=True, - desc="Whether to repeat the KV before the conv1d in Mamba2 blocks.", - hint=FieldHint.architecture, - ) - d_xb: int = Field( - default=None, - desc="Dimension of the xB in Mamba2 blocks.", - hint=FieldHint.architecture, - ) - dt_init: str = Field( - default="random", + # Initialization + # dt_weight_initialization_method [Mamba2] + dt_init: DTInitType = Field( + default=DTInitType.random, desc="Initialization method for dt", hint=FieldHint.core, ) - dt_max: float = Field( - default=0.1, - desc="Maximum step size for discretization", + # 
dt_weight_initialization_scale [Mamba2] + dt_scale: float = Field( + default=1.0, + desc="Scale for dt", hint=FieldHint.core, valid=check_field(Assert.gt, 0), ) + # dt_bias_initialization_min [MambaLayer, Mamba2] dt_min: float = Field( default=0.001, desc="Minimum step size for discretization", hint=FieldHint.core, valid=check_field(Assert.gt, 0), ) - dt_init_floor: float = Field( - default=1e-4, - desc="Minimum value for initializing dt", + # dt_bias_initialization_max [MambaLayer, Mamba2] + dt_max: float = Field( + default=0.1, + desc="Maximum step size for discretization", hint=FieldHint.core, valid=check_field(Assert.gt, 0), ) - dt_scale: float = Field( - default=1.0, - desc="Scale for dt", + # dt_bias_initialization_floor [MambaLayer, Mamba2] + dt_init_floor: float = Field( + default=1e-4, + desc="Minimum value for initializing dt", hint=FieldHint.core, valid=check_field(Assert.gt, 0), ) @@ -172,3 +232,79 @@ def _validate(self) -> None: self.activation_type = ActivationType.silu super()._validate() Assert.geq(self.dt_max, self.dt_min) + + def setup_tensor_space(self, tensor_space: TensorSpace, block_type: SSMBlockType) -> None: + tensor = tensor_space.distributed_config.get_distributed_dim(DistributedDimNames.tensor) + + # Head groups are configured differently depending on the block type. + if block_type == SSMBlockType.mamba: + num_heads = div(self.d_inner, self.state_size) + num_head_groups = num_heads + elif block_type == SSMBlockType.mamba2: + num_heads = div(self.d_inner, self.state_size) + num_head_groups = div(self.d_xb, self.state_size) + elif block_type == SSMBlockType.mamba2_discrete: + # TODO: Use different variables? + num_heads = self.n_v_heads + num_head_groups = self.n_qk_heads + else: + raise NotImplementedError(block_type) + + tensor_space.add_tensor_dim(state := TensorDim(SSMDimNames.state, self.state_size)) + if block_type == SSMBlockType.mamba2_discrete: + tensor_space.add_tensor_dim(head_dim := TensorDim(SSMDimNames.head_dim, div(self.d_inner, num_heads))) + else: + head_dim = state + + tensor_space.add_tensor_dim(head_groups := TensorDim(SSMDimNames.head_groups, num_head_groups, tensor)) + tensor_space.add_tensor_dim(group_heads := TensorDim(SSMDimNames.group_heads, div(num_heads, num_head_groups))) + tensor_space.add_tensor_dim( + heads := CompositeTensorDim(SSMDimNames.composite_heads, (head_groups, group_heads)) + ) + tensor_space.add_tensor_dim( + heads_and_head_dim := CompositeTensorDim( + SSMDimNames.composite_heads_and_head_dim, (head_groups, group_heads, head_dim) + ) + ) + tensor_space.add_tensor_dim( + head_groups_and_state := CompositeTensorDim( + SSMDimNames.composite_head_groups_and_state, (head_groups, state) + ) + ) + tensor_space.add_tensor_dim(TensorDim(SSMDimNames.convolution_kernel, self.conv_kernel_dimension)) + + # DT projection + if block_type in (SSMBlockType.mamba, SSMBlockType.mamba2): + tensor_space.add_tensor_dim(dt_rank := TensorDim(SSMDimNames.dt_rank, self.dt_rank)) + + if block_type == SSMBlockType.mamba: + tensor_space.add_tensor_dim( + ConcatenatedTensorDim(SSMDimNames.concatenated_x_projection, (dt_rank, state, state)) + ) + # TODO: Use composition instead + tensor_space.add_tensor_dim( + ConcatenatedTensorDim( + SSMDimNames.concatenated_inner_projection, (heads_and_head_dim, heads_and_head_dim) + ) + ) + elif block_type == SSMBlockType.mamba2: + # TODO: Factor out state? 
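# Editor's note: a quick size check (invented numbers) for the mamba2 branch just below.
# With head_dim == state_size, the concatenated inner projection stacks four chunks of
# sizes d_inner, d_xb, d_xb and d_inner, so its total width is 2 * (d_inner + d_xb).
d_inner, d_xb, state_size = 1024, 256, 16
num_heads = d_inner // state_size            # 64 heads
num_head_groups = d_xb // state_size         # 16 head groups
group_heads = num_heads // num_head_groups   # 4 heads per group
inner_projection = d_inner + d_xb + d_xb + d_inner
assert inner_projection == 2 * (d_inner + d_xb)  # 2560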
+ tensor_space.add_tensor_dim( + ConcatenatedTensorDim( + SSMDimNames.concatenated_inner_projection, + (heads_and_head_dim, head_groups_and_state, head_groups_and_state, heads_and_head_dim), + ) + ) + elif block_type == SSMBlockType.mamba2_discrete: + tensor_space.add_tensor_dim( + ConcatenatedTensorDim( + SSMDimNames.concatenated_inner_projection, + (heads_and_head_dim, head_groups_and_state, head_groups_and_state, heads_and_head_dim, heads), + ) + ) + tensor_space.add_tensor_dim( + ConcatenatedTensorDim( + SSMDimNames.concatenated_convolution, + (heads_and_head_dim, head_groups_and_state, head_groups_and_state), + ) + ) diff --git a/fast_llm/layers/ssm/discrete_mamba2.py b/fast_llm/layers/ssm/discrete_mamba2.py index 934cd2b5d..c9d555de9 100644 --- a/fast_llm/layers/ssm/discrete_mamba2.py +++ b/fast_llm/layers/ssm/discrete_mamba2.py @@ -1,14 +1,16 @@ import logging -import math +import typing import einops import torch -from fast_llm.engine.config_utils.tensor_space import TensorDim, TensorSpace -from fast_llm.layers.common.linear import Linear +from fast_llm.engine.config_utils.tensor_space import DefaultDimNames, TensorSpace +from fast_llm.functional.config import ActivationType +from fast_llm.layers.common.linear import InputParallelLinear, OutputParallelLinear from fast_llm.layers.ssm.config import SSMConfig, SSMDimNames -from fast_llm.layers.transformer.config import TransformerKwargs -from fast_llm.tensor import ParameterMeta, init_ones_, init_uniform_, init_zeros_, kaiming_init_ +from fast_llm.layers.transformer.config import TransformerConfig, TransformerDimNames, TransformerKwargs +from fast_llm.layers.transformer.transformer import Mixer +from fast_llm.tensor import ParameterMeta, init_kaiming_, init_ones_, init_uniform_centered_, init_zeros_ from fast_llm.utils import get_lr_scale logger = logging.getLogger(__name__) @@ -30,229 +32,195 @@ _causal_conv1d_available = False -def bias_init_method(conv_weight): - fan_in, _ = torch.nn.init._calculate_fan_in_and_fan_out(conv_weight) - bound = 1 / math.sqrt(fan_in) if fan_in > 0 else 0 - return init_uniform_(-bound, bound) - - -class DiscreteMamba2(torch.nn.Module): +class DiscreteMamba2(Mixer): """DiscreteMamba2 (This code is adapted from https://github.com/cartesia-ai/edge/blob/main/cartesia-pytorch/cartesia_pytorch/Llamba/mixers/discrete_mamba2.py).""" + _mixer_name: typing.ClassVar[str] = "discrete_mamba_2" + def __init__( self, config: SSMConfig, - layer_idx: int, + block_index: int, tensor_space: TensorSpace, - return_input: bool = False, + transformer_config: TransformerConfig, ): - """ - See the class .kernel.SSKernel for the kernel constructor which accepts kernel_args. - Other options are all experimental and should not need to be configured. 
- """ - # factory_kwargs = {"device": "meta"} # , "dtype": torch.bfloat16} - super().__init__() - self.config: SSMConfig = config - bias = config.add_bias_linear - self.layer_idx = layer_idx - self._return_input = return_input - layer_lr_scale = config.per_layer_lr_scale[layer_idx] if config.per_layer_lr_scale else None - mamba_layer_lr_scale = get_lr_scale(self.config.mamba_lr_scale, layer_lr_scale) - logger.info(f"Setting lr_scale for layer {layer_idx} of type {type(self)}: {mamba_layer_lr_scale}") - - td_inner = tensor_space.get_tensor_dim(SSMDimNames.inner_dim) - td_state = tensor_space.get_tensor_dim(SSMDimNames.state_dim) - td_model = tensor_space.get_tensor_dim(SSMDimNames.model_dim) - td_conv = tensor_space.get_tensor_dim(SSMDimNames.conv_dim) - td_n_qk_heads = tensor_space.get_tensor_dim(SSMDimNames.qk_heads) - td_n_v_heads = tensor_space.get_tensor_dim(SSMDimNames.v_heads) - td_conv_kernel = tensor_space.get_tensor_dim(SSMDimNames.conv_kernel_size) - td_inner_proj = tensor_space.get_tensor_dim(SSMDimNames.inner_proj_discrete_mamba2) - - self.d_model = td_model.size - self.d_inner = td_inner.size - self.d_state = td_state.size - self.chunk_size = config.chunk_size - self.n_qk_heads = td_n_qk_heads.size - self.n_v_heads = td_n_v_heads.size - self.conv_kernel_size = td_conv_kernel.size - - self.act = config.activation_type.activation_fn - self.activation_name = config.activation_type.name + super().__init__(tensor_space, block_index, debug_level=transformer_config.debug_transformer) + self._config: SSMConfig = config + layer_lr_scale = config.per_layer_lr_scale[block_index] if config.per_layer_lr_scale else None + lr_scale = get_lr_scale(self._config.mamba_lr_scale, layer_lr_scale) + + inner_dim = tensor_space[SSMDimNames.composite_heads_and_head_dim] + hidden_dim = tensor_space[TransformerDimNames.hidden] + conv1d_dim = tensor_space[SSMDimNames.concatenated_convolution] + heads_dim = tensor_space[SSMDimNames.composite_heads] + + # local_head_groups = head_groups / TP + self._local_head_groups = tensor_space[SSMDimNames.head_groups].size + # local_heads = local_head_groups * group_heads + self._local_heads = heads_dim.size + # local_inner_size = local_heads * head_size + self._local_inner_size = inner_dim.size + # local_bc_size = local_head_groups * state + self._local_bc_size = tensor_space[SSMDimNames.composite_head_groups_and_state].size # TODO: double check initializations # Projections - self.in_proj = Linear( - td_model, - td_inner_proj, - bias=bias, - weight_init_method=kaiming_init_(td_model.size), - lr_scale=mamba_layer_lr_scale, + self.in_proj = OutputParallelLinear( + hidden_dim, + tensor_space[SSMDimNames.concatenated_inner_projection], + bias=config.add_bias_linear, + weight_init_method=init_kaiming_(transformer_config.hidden_size), + sequence_parallel=self._sequence_parallel, + lr_scale=lr_scale, ) - self.z_bias = ( - ParameterMeta.from_dims( - (td_inner,), + if not config.add_bias_linear: + self.z_bias = ParameterMeta.from_dims( + (inner_dim,), weight_decay=False, init_method=init_zeros_, - lr_scale=mamba_layer_lr_scale, + lr_scale=lr_scale, ) - if not bias - else 0.0 - ) - self.conv1d_weight = ParameterMeta.from_dims( - (td_conv, TensorDim("1", 1), td_conv_kernel), - init_method=init_uniform_( - 1 / math.sqrt(td_conv.size * td_conv_kernel.size), 1 / math.sqrt(td_conv.size * td_conv_kernel.size) - ), # see https://github.com/pytorch/pytorch/blob/1eba9b3aa3c43f86f4a2c807ac8e12c4a7767340/torch/nn/modules/conv.py#L180C53-L180C67 - lr_scale=mamba_layer_lr_scale, + ( + 
conv1d_dim, + tensor_space[DefaultDimNames.scalar], + tensor_space[SSMDimNames.convolution_kernel], + ), + init_method=init_uniform_centered_((conv1d_dim.global_size * self._config.conv_kernel_dimension) ** -0.5), + lr_scale=lr_scale, ) self.conv1d_bias = ParameterMeta.from_dims( - (td_conv,), init_method=bias_init_method(self.conv1d_weight), lr_scale=mamba_layer_lr_scale + (conv1d_dim,), + init_method=init_uniform_centered_(self._config.conv_kernel_dimension**-0.5), + lr_scale=lr_scale, ) - # D "skip" parameter self.D = ParameterMeta.from_dims( - (td_n_qk_heads,), + (heads_dim,), weight_decay=False, init_method=init_ones_, - lr_scale=mamba_layer_lr_scale, + lr_scale=lr_scale, ) - - # out_proj - self.out_proj = Linear( - td_inner, - td_model, - bias=bias, - weight_init_method=kaiming_init_(td_inner.size), - lr_scale=mamba_layer_lr_scale, + self.out_proj = InputParallelLinear( + inner_dim, + hidden_dim, + bias=config.add_bias_linear, + weight_init_method=init_kaiming_(self._config.d_inner), + sequence_parallel=self._sequence_parallel, + lr_scale=lr_scale, ) - def forward(self, hidden_states, kwargs): - """ - ON variable names and pep8: keeping some variable names as in the original code for clarity. - - Args: - u: (B, L, D), - - Returns: - outputs: dict. - outputs["hidden_states"]: (B, L, D). - outputs["state"]: inference cache. - """ - if kwargs[TransformerKwargs.sequence_first]: - raise NotImplementedError(f"Sequence-first not supported for SSMs.") - + def forward(self, input_: torch.Tensor, kwargs: dict[str, typing.Any]) -> tuple[torch.Tensor, torch.Tensor | None]: assert _mamba_available - input_ = hidden_states - outputs = {} - # assert state is None - batch, seqlen, dim = input_.shape - - state = None - # Hacky way to initialize state during inference - chunk_size = self.chunk_size if state is None else seqlen + sequence_length = kwargs[TransformerKwargs.sequence_q_dim].global_size # Pad input to nearest multiple of chunklen - padded_len = (1 + (seqlen - 1) // chunk_size) * chunk_size - u = torch.nn.functional.pad(input_, (0, 0, 0, padded_len - seqlen)) - - # Project input - xBCzA_log = self.in_proj(u) + padded_length = (1 + (sequence_length - 1) // self._config.chunk_size) * self._config.chunk_size + if padded_length != sequence_length: + assert not kwargs[TransformerKwargs.sequence_first] and input_.size(1) == sequence_length + input_ = torch.nn.functional.pad(input_, (0, 0, 0, padded_length - sequence_length)) + + # inner_projection : (batch/local_or_padded_sequence, local_sequence/batch, hidden) + # -> (batch/local_or_padded_sequence, local_sequence/batch, inner_projection) + # inner_projection: (batch, local_or_padded_sequence, hidden) -> (batch, padded_sequence, local_inner_size) + inner_projection = self.in_proj(input_) + # Standardize to (batch, padded_sequence, inner_projection) + if kwargs[TransformerKwargs.sequence_first]: + inner_projection = inner_projection.transpose(0, 1) - ( - xBC, - z, - A_log, - ) = torch.split( - xBCzA_log, + xBC, z, A_log = torch.split( + inner_projection, [ - self.d_inner + 2 * self.n_qk_heads * self.d_state, - self.d_inner, - self.n_v_heads, + self._local_inner_size + 2 * self._local_bc_size, + self._local_inner_size, + self._local_heads, ], dim=-1, ) - - if state is not None: - # If we just take xBC[:, :, -self.d_conv :], it will error if seqlen < self.d_conv - # Instead torch.nn.functional.pad will pad with zeros if seqlen < self.d_conv, and truncate otherwise.
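As a minimal, standalone sketch of the chunk-size padding computed above (names and sizes are illustrative):

import torch

def pad_to_chunk_multiple(hidden: torch.Tensor, chunk_size: int) -> torch.Tensor:
    # hidden: (batch, sequence, hidden_size); pad the sequence dimension on the right.
    sequence_length = hidden.size(1)
    padded_length = (1 + (sequence_length - 1) // chunk_size) * chunk_size
    if padded_length == sequence_length:
        return hidden
    # torch.nn.functional.pad pads from the last dimension backwards:
    # (hidden_left, hidden_right, sequence_left, sequence_right).
    return torch.nn.functional.pad(hidden, (0, 0, 0, padded_length - sequence_length))

assert pad_to_chunk_multiple(torch.zeros(2, 100, 8), chunk_size=64).shape == (2, 128, 8)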
- xBC_t = einops.rearrange(xBC[:, :seqlen, :], "b l d -> b d l") - state["conv"].copy_( - torch.nn.functional.pad(xBC_t, (self.conv_kernel_size - xBC_t.shape[-1], 0)) - ) # Update state (B D W) # Convolutional layer - xBC = self.convolutional_forward(xBC, padded_len) + # xbc: (batch, padded_sequence, local_heads * head_size + 2 * local_head_groups * state) + xBC = self.convolutional_forward(xBC, padded_length) x, B, C = torch.split( xBC, [ - self.d_inner, - self.n_qk_heads * self.d_state, - self.n_qk_heads * self.d_state, + self._local_inner_size, + self._local_bc_size, + self._local_bc_size, ], dim=-1, ) - x = einops.rearrange(x, "b l (h n) -> b l h n", h=self.n_v_heads) - B = einops.rearrange(B, "b l (h n) -> b l h n", h=self.n_qk_heads) - C = einops.rearrange(C, "b l (h n) -> b l h n", h=self.n_qk_heads) + # x: (batch, padded_sequence, local_heads * head_size) -> (batch, padded_sequence, local_heads, head_size) + x = einops.rearrange(x, "b l (h n) -> b l h n", h=self._local_heads) + + # b,c: (batch, padded_sequence, local_head_groups * state) -> (batch, padded_sequence, local_head_groups, state) + B = einops.rearrange(B, "b l (h n) -> b l h n", h=self._local_head_groups) + C = einops.rearrange(C, "b l (h n) -> b l h n", h=self._local_head_groups) # SSM forward - result = _mamba_chunk_scan_combined( - x=x / torch.nn.functional.softplus(A_log).to(x.dtype).unsqueeze(-1), + y = _mamba_chunk_scan_combined( + x=self._apply_a_log(x, A_log), dt=A_log, dt_softplus=True, - A=-torch.ones(self.n_v_heads, device=A_log.device), + A=-torch.ones(self._local_heads, device=A_log.device), B=B, C=C, - chunk_size=chunk_size, - # initial_states=(state["ssm"] if state is not None else None), # currently not supported by mamba_ssm.utils.generation - return_final_states=(state is not None), + chunk_size=self._config.chunk_size, + return_final_states=False, ) - - if state is not None: - y, ssm_state = result - state["ssm"].copy_(ssm_state) - else: - y = result - Du = torch.einsum("h,blhp->blhp", self.D, x) - y = einops.rearrange(y + Du, "b l h p -> b l (h p)") # Norm and gate - out = self.out_proj(y * torch.nn.functional.silu(z + self.z_bias)) - outputs["hidden_states"] = out[:, :seqlen, :].contiguous() - - if self._return_input: - return torch.stack([input_, outputs["hidden_states"]], dim=0) + if not self._config.add_bias_linear: + z = z + self.z_bias - # TODO: since we do not support inference for now, we only return the hidden states for now. - return outputs["hidden_states"], None + # y: (batch, padded_sequence, local_heads, head_size) -> (batch, sequence, local_heads * head_size) + y = ((y + Du).flatten(2, 3) * torch.nn.functional.silu(z))[:, :sequence_length] + if kwargs[TransformerKwargs.sequence_first]: + # TODO: Is contiguous needed?
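A shape-only sketch of the skip connection and SiLU gating applied to the chunked-scan output above, with made-up sizes (illustrative, not module code):

import torch

batch, padded_sequence, heads, head_size = 2, 8, 4, 16
y = torch.randn(batch, padded_sequence, heads, head_size)   # chunked-scan output
x = torch.randn(batch, padded_sequence, heads, head_size)   # pre-scan input
z = torch.randn(batch, padded_sequence, heads * head_size)  # gate branch
D = torch.randn(heads)                                      # per-head "skip" weight

Du = torch.einsum("h,blhp->blhp", D, x)  # broadcast D over batch, sequence and head_size
out = (y + Du).flatten(2, 3) * torch.nn.functional.silu(z)
assert out.shape == (batch, padded_sequence, heads * head_size)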
+ y = y.transpose(0, 1).contiguous() + # out_proj: (batch/sequence, sequence/batch, local_heads * head_size) + # -> (batch/local_sequence, local_sequence/batch, hidden) + return self.out_proj(y) + + @torch.compile + def _apply_a_log(self, x: torch.Tensor, A_log: torch.Tensor) -> torch.Tensor: + return x / torch.nn.functional.softplus(A_log).to(x.dtype).unsqueeze(-1) def convolutional_forward(self, xBC, padded_len): """Convolutional layer forward pass for the full sequence.""" - if _causal_conv1d_available and self.activation_name in ( - "silu", - "swish", - "identity", + if _causal_conv1d_available and self._config.activation_type in ( + ActivationType.silu, + ActivationType.identity, ): xBC = _causal_conv1d_fn( xBC.transpose(1, 2), - einops.rearrange(self.conv1d_weight, "d 1 w -> d w"), + self.conv1d_weight.squeeze(1), self.conv1d_bias, - activation=None if self.activation_name == "identity" else self.activation_name, + activation=( + None + if self._config.activation_type == ActivationType.identity + else self._config.activation_type.value + ), ).transpose(1, 2) else: - xBC = self.act( + xBC = self._config.activation_type.activation_fn( torch.nn.functional.conv1d( xBC.transpose(1, 2), self.conv1d_weight, bias=self.conv1d_bias, groups=self.conv1d_weight.shape[0], - padding=self.conv_kernel_size - 1, + padding=self._config.conv_kernel_dimension - 1, )[..., :padded_len].transpose(1, 2) ) return xBC diff --git a/fast_llm/layers/ssm/llamba_block.py b/fast_llm/layers/ssm/llamba_block.py index ee222d6d2..986606634 100644 --- a/fast_llm/layers/ssm/llamba_block.py +++ b/fast_llm/layers/ssm/llamba_block.py @@ -1,6 +1,6 @@ import typing -from fast_llm.layers.transformer.transformer import BaseBlock +from fast_llm.layers.transformer.transformer import BaseBlock, Mixer if typing.TYPE_CHECKING: from fast_llm.engine.config_utils.tensor_space import TensorSpace @@ -8,27 +8,30 @@ from fast_llm.layers.transformer.config import TransformerConfig -class LlambaBlock(BaseBlock): +class SSMBlock(BaseBlock): """ A transformer-like decoder block with a SSM mixer, see https://arxiv.org/abs/2502.14458 """ _name = "Llamba block" - _mixer_module_name = "mixer" def __init__( self, - config_transformer: "TransformerConfig", - config_ssm: "SSMConfig", + transformer_config: "TransformerConfig", + ssm_config: "SSMConfig", tensor_space: "TensorSpace", - mixer_cls, - layer_index: int, + mixer_cls: type[Mixer], + block_index: int, return_input: bool = False, ): - self.mixer_cls = mixer_cls - self._config_ssm = config_ssm - self._debug_mode = self._config_ssm.debug_ssm - super().__init__(config_transformer, tensor_space, layer_index, return_input) + self._ssm_config = ssm_config + self._mixer_cls = mixer_cls + super().__init__(transformer_config, tensor_space, block_index, return_input) - def _create_mixer(self): - self.mixer = self.mixer_cls(self._config_ssm, layer_idx=self._layer_index, tensor_space=self._tensor_space) + def _create_mixer(self) -> Mixer: + return self._mixer_cls( + self._ssm_config, + tensor_space=self._tensor_space, + block_index=self._block_index, + transformer_config=self._config, + ) diff --git a/fast_llm/layers/ssm/mamba2.py b/fast_llm/layers/ssm/mamba2.py index a03509abb..5ed689a73 100644 --- a/fast_llm/layers/ssm/mamba2.py +++ b/fast_llm/layers/ssm/mamba2.py @@ -1,19 +1,34 @@ -import math +import inspect +import logging import typing -import einops import torch -from
fast_llm.engine.config_utils.tensor_space import TensorDim, TensorSpace -from fast_llm.layers.common.linear import Linear -from fast_llm.layers.ssm.config import SSMConfig, SSMDimNames -from fast_llm.tensor import ParameterMeta, init_fill_, init_ones_, init_uniform_, kaiming_init_ -from fast_llm.utils import get_lr_scale - +from fast_llm.engine.config_utils.tensor_space import DefaultDimNames, TensorDim, TensorSpace +from fast_llm.functional.config import ActivationType +from fast_llm.layers.common.linear import InputParallelLinear, Linear, OutputParallelLinear +from fast_llm.layers.ssm.config import SSMConfig, SSMDimNames, SSMKwargs +from fast_llm.layers.ssm.mamba_layer import init_A, init_dtprojbias +from fast_llm.layers.transformer.config import TransformerConfig, TransformerDimNames, TransformerKwargs +from fast_llm.layers.transformer.transformer import Mixer +from fast_llm.tensor import ParameterMeta, init_kaiming_, init_ones_, init_uniform_centered_ +from fast_llm.utils import Assert, div, get_lr_scale + +_mamba_varlen = False try: from mamba_ssm.ops.selective_scan_interface import selective_scan_fn # noqa _mamba_available = True + sig = inspect.signature(selective_scan_fn) + if "position_indices" in sig.parameters: + _mamba_varlen = True + logging.warning("Using selective_scan_fn from varlen_mamba that supports packing") + else: + _mamba_varlen = False + logging.warning("Using selective_scan_fn from original mamba without packing support") + # for training with packing install https://github.com/jxiw/varlen_mamba + # see https://github.com/jxiw/M1/blob/main/HYBRID_PACK.md + except (ImportError, RuntimeError): _mamba_available = False @@ -24,236 +39,238 @@ except (ImportError, RuntimeError): _causal_conv1d_available = False +logger = logging.getLogger(__name__) -def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor: - """ - This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). 
The hidden states go from (batch, - num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim) - """ - batch, num_key_value_heads, slen, head_dim = hidden_states.shape - if n_rep == 1: - return hidden_states - hidden_states = hidden_states[:, :, None, :, :].expand(batch, num_key_value_heads, n_rep, slen, head_dim) - return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim) - - -def bias_init_method(conv_weight): - fan_in, _ = torch.nn.init._calculate_fan_in_and_fan_out(conv_weight) - bound = 1 / math.sqrt(fan_in) if fan_in > 0 else 0 - return init_uniform_(-bound, bound) - -class Mamba2(torch.nn.Module): +class Mamba2(Mixer): """ This code is adapted from https://github.com/jxiw/M1/blob/537a1ca5407a786a99dc6c721873493cf8750d5e/mamba/hybrid_mamba_layer.py """ + _mixer_name: typing.ClassVar[str] = "mamba_2" + + _XZ_DIMS = ( + TransformerDimNames.batch, + SSMDimNames.composite_heads_and_head_dim, + TransformerDimNames.sequence_q, + ) + _BC_DIMS = ( + TransformerDimNames.batch, + SSMDimNames.composite_heads, + SSMDimNames.state, + TransformerDimNames.sequence_q, + ) + def __init__( self, config: SSMConfig, - layer_idx: int, tensor_space: TensorSpace, - return_input: bool = False, + block_index: int, + transformer_config: TransformerConfig, ): - super().__init__() - self.config: SSMConfig = config - bias: bool = config.add_bias_linear - self.layer_idx = layer_idx - self._return_input = return_input - layer_lr_scale: float | None = config.per_layer_lr_scale[layer_idx] if config.per_layer_lr_scale else None - mamba_layer_lr_scale: float | tuple[float | None, ...] | None = get_lr_scale( - self.config.mamba_lr_scale, layer_lr_scale + super().__init__(tensor_space, block_index, debug_level=transformer_config.debug_transformer) + self._config: SSMConfig = config + Assert.eq(self._config.activation_type, ActivationType.silu) + layer_lr_scale: float | None = config.per_layer_lr_scale[block_index] if config.per_layer_lr_scale else None + lr_scale: float | tuple[float | None, ...] 
| None = get_lr_scale(self._config.mamba_lr_scale, layer_lr_scale) + + inner_dim: TensorDim = tensor_space[SSMDimNames.composite_heads_and_head_dim] + xb_dim = tensor_space[SSMDimNames.composite_head_groups_and_state] + hidden_dim: TensorDim = tensor_space[TransformerDimNames.hidden] + dt_rank_dim = tensor_space[SSMDimNames.dt_rank] + + self._local_heads = tensor_space[SSMDimNames.composite_heads].size + self._local_head_groups = tensor_space[SSMDimNames.head_groups].size + self._group_heads = div(self._local_heads, self._local_head_groups) + self._local_inner_size = inner_dim.size + self._local_xb_size = xb_dim.size + + conv1d_dim = inner_dim if self._config.repeat_kv_before_conv else xb_dim + self.conv1d_weight = ParameterMeta.from_dims( + ( + conv1d_dim, + tensor_space[DefaultDimNames.scalar], + tensor_space[SSMDimNames.convolution_kernel], + ), + init_method=init_uniform_centered_((conv1d_dim.global_size * self._config.conv_kernel_dimension) ** -0.5), + lr_scale=lr_scale, ) - - td_inner: TensorDim = tensor_space.get_tensor_dim(name=SSMDimNames.inner_dim) - td_state: TensorDim = tensor_space.get_tensor_dim(name=SSMDimNames.state_dim) - td_model: TensorDim = tensor_space.get_tensor_dim(name=SSMDimNames.model_dim) - tdt_rank: TensorDim = tensor_space.get_tensor_dim(name=SSMDimNames.dt_rank) - td_xb: TensorDim = tensor_space.get_tensor_dim(name=SSMDimNames.x_proj_dim_2) - td_inner_proj: TensorDim = tensor_space.get_tensor_dim(name=SSMDimNames.inner_proj_mamba2) - td_conv_kernel: TensorDim = tensor_space.get_tensor_dim(name=SSMDimNames.conv_kernel_size) - - self.repeat_kv_before_conv = config.repeat_kv_before_conv - - self.d_state = td_state.size - self.d_model = td_model.size - self.d_xb = td_xb.size - self.d_inner = td_inner.size - self.dt_rank = tdt_rank.size - - if self.repeat_kv_before_conv: - self.conv1d_weight = ParameterMeta.from_dims( - (td_inner, TensorDim("1", 1), td_conv_kernel), - init_method=init_uniform_( - 1 / math.sqrt(td_inner.size * td_conv_kernel.size), - 1 / math.sqrt(td_inner.size * td_conv_kernel.size), - ), # see https://github.com/pytorch/pytorch/blob/1eba9b3aa3c43f86f4a2c807ac8e12c4a7767340/torch/nn/modules/conv.py#L180C53-L180C67 - lr_scale=mamba_layer_lr_scale, - ) - - self.conv1d_bias = ParameterMeta.from_dims( - (td_inner,), init_method=bias_init_method(self.conv1d_weight), lr_scale=mamba_layer_lr_scale - ) - else: - self.conv1d_weight = ParameterMeta.from_dims( - (td_xb, TensorDim("1", 1), td_conv_kernel), - init_method=init_uniform_( - 1 / math.sqrt(td_xb.size * td_conv_kernel.size), - 1 / math.sqrt(td_xb.size * td_conv_kernel.size), - ), - ) - self.conv1d_bias = ParameterMeta.from_dims( - (td_xb,), init_method=bias_init_method(self.conv1d_weight), lr_scale=mamba_layer_lr_scale - ) - - self.activation = "silu" - - self.num_xb_head = td_xb.size // td_state.size - self.num_C_head = td_inner.size // td_state.size - self.repeat_group = self.num_C_head // self.num_xb_head - - self.in_proj = Linear( - td_model, - td_inner_proj, - bias=bias, - weight_init_method=kaiming_init_(td_model.size), - lr_scale=mamba_layer_lr_scale, + self.conv1d_bias = ParameterMeta.from_dims( + (conv1d_dim,), + init_method=init_uniform_centered_(self._config.conv_kernel_dimension**-0.5), + lr_scale=lr_scale, ) - - # Initialize special dt projection to preserve variance at initialization - dt_scale = config.dt_scale # 1.0 - dt_init_std = self.dt_rank**-0.5 * dt_scale - if config.dt_init == "constant": - dt_init = init_fill_(dt_init_std) - elif config.dt_init == "random": - dt_init = 
init_uniform_(-dt_init_std, dt_init_std) - else: - raise NotImplementedError - - # Initialize dt bias so that F.softplus(dt_bias) is between dt_min and dt_max - dt_max = config.dt_max # or 0.1 - dt_min = config.dt_min # or 0.001 - dt_init_floor = config.dt_init_floor # or 1e-4 - dt = torch.exp(torch.rand(self.d_inner) * (math.log(dt_max) - math.log(dt_min)) + math.log(dt_min)).clamp( - min=dt_init_floor + self.in_proj = OutputParallelLinear( + hidden_dim, + tensor_space[SSMDimNames.concatenated_inner_projection], + bias=config.add_bias_linear, + weight_init_method=init_kaiming_(transformer_config.hidden_size), + sequence_parallel=self._sequence_parallel, + lr_scale=lr_scale, ) - # Inverse of softplus: https://github.com/pytorch/pytorch/issues/72759 - inv_dt = dt + torch.log(-torch.expm1(-dt)) - def init_from_tensor_( - value: torch.Tensor, - ) -> typing.Callable[[ParameterMeta, torch.Tensor, torch.Generator], torch.Tensor]: - def init_(meta: ParameterMeta, tensor: torch.Tensor, generator: torch.Generator): # noqa - return tensor.copy_(value) - - return init_ - - self.dt_proj = Linear( - tdt_rank, - td_inner, + self.dt_in_proj = Linear( + hidden_dim, + dt_rank_dim, + bias=config.add_bias_linear, + weight_init_method=init_kaiming_(transformer_config.hidden_size), + lr_scale=lr_scale, + ) + self.dt_proj = OutputParallelLinear( + dt_rank_dim, + inner_dim, bias=False, - weight_init_method=dt_init, - lr_scale=mamba_layer_lr_scale, + # Initialize special dt projection to preserve variance at initialization + weight_init_method=self._config.dt_init.get_init_method( + self._config.dt_rank**-0.5 * self._config.dt_scale + ), + sequence_parallel=self._sequence_parallel, + lr_scale=lr_scale, ) - # define bias outside the linear layer since its also used in the selective_scan_fn + # define bias outside the linear layer since it's also used in the selective_scan_fn self.dt_proj_bias = ParameterMeta.from_dims( - (td_inner,), init_method=init_from_tensor_(inv_dt), lr_scale=mamba_layer_lr_scale + (inner_dim,), + init_method=init_dtprojbias(self._config.dt_max, self._config.dt_min, self._config.dt_init_floor), + lr_scale=lr_scale, ) - - A = einops.repeat( - torch.arange(1, self.d_state + 1, dtype=torch.float32), - "n -> d n", - d=self.d_inner, - ).contiguous() - A_log = torch.log(A).flatten() # Keep A_log in fp32 self.A_log = ParameterMeta.from_dims( - (td_inner, td_state), - init_method=init_from_tensor_(A_log), - lr_scale=mamba_layer_lr_scale, + (inner_dim, tensor_space[SSMDimNames.state]), + init_method=init_A(self._config.state_size, self._config.d_inner), + lr_scale=lr_scale, weight_decay=False, ) - self.D = ParameterMeta.from_dims( - (td_inner,), + (inner_dim,), weight_decay=False, init_method=init_ones_, - lr_scale=mamba_layer_lr_scale, + lr_scale=lr_scale, ) - - self.out_proj = Linear( - td_inner, - td_model, - bias=bias, - weight_init_method=kaiming_init_(td_inner.size), + self.out_proj = InputParallelLinear( + inner_dim, + hidden_dim, + bias=config.add_bias_linear, + weight_init_method=init_kaiming_(self._config.d_inner), + sequence_parallel=self._sequence_parallel, + lr_scale=lr_scale, ) - def forward(self, hidden_states, kwargs): + def forward(self, input_: torch.Tensor, kwargs: dict[str, typing.Any]) -> tuple[torch.Tensor, torch.Tensor | None]: """ - hidden_states: (B, L, D) - Returns: same shape as hidden_states + Note, we are not doing "real" sequence-tensor parallel training here, since inner_projection is gathered over all GPUs.
+ This is also desired, since the currently used mamba kernel does not support STP. + TODO: use correct kernel from Mamba2! """ assert _mamba_available - batch, seqlen, dim = hidden_states.shape - outputs = {} - - conv_state, ssm_state = None, None - - A = -torch.exp(self.A_log.float()) # (d_inner, d_state) - - zxbcdt = self.in_proj(hidden_states) - z, x, B, C, dt = torch.split(zxbcdt, [self.d_inner, self.d_xb, self.d_xb, self.d_inner, self.dt_rank], dim=-1) - - x = einops.rearrange(x, "b l d -> b d l") - z = einops.rearrange(z, "b l d -> b d l") - - B = einops.rearrange(B, "b l (n_group dstate) -> b n_group l dstate", dstate=self.d_state) - B = repeat_kv(B, self.repeat_group) # B, n_group, L, H - B = einops.rearrange(B, "b n_group l dstate -> b n_group dstate l").contiguous() - C = einops.rearrange(C, "b l (n_group dstate) -> b n_group dstate l", dstate=self.d_state).contiguous() + assert _causal_conv1d_available + cu_seqlens = kwargs.get(SSMKwargs.cu_seqlens) + seq_idx = kwargs.get(SSMKwargs.seq_idx) + position_indices = kwargs.get(SSMKwargs.ssm_position_ids) + + # inner_projection : (batch/local_sequence, local_sequence/batch, hidden) + # -> (batch/sequence, sequence/batch, inner_projection) + inner_projection = self.in_proj(input_) + dt = self.dt_proj(self.dt_in_proj(input_)) + self.dt_proj_bias + # Standardize to (batch, sequence, inner_projection) + if kwargs[TransformerKwargs.sequence_first]: + inner_projection = inner_projection.transpose(0, 1) + dt = dt.transpose(0, 1) + + sequence_length = inner_projection.size(1) + + z, x, b, c = torch.split( + inner_projection, + [self._local_inner_size, self._local_xb_size, self._local_xb_size, self._local_inner_size], + dim=2, + ) - dt = self.dt_proj(dt) + self.dt_proj_bias # B, L, d_inner - dt = einops.rearrange(dt, "b l d -> b d l") # B, d_inner, L + # z: (batch, sequence, local_heads * state) -> (batch, local_heads * state, sequence) + z = z.transpose(1, 2) - if self.repeat_kv_before_conv: - assert self.repeat_group > 0 - x = einops.rearrange(x, "b (n_group dstate) l -> b n_group l dstate", dstate=self.d_state) - x = repeat_kv(x, self.repeat_group) - x = einops.rearrange(x, "b n_group l dstate -> b (n_group dstate) l") + # x: (batch, sequence, local_head_groups * state) -> (batch, local_heads * state, sequence) + x = x.transpose(1, 2) + if self._config.repeat_kv_before_conv: + x = ( + x.unflatten(1, (self._local_head_groups, self._config.state_size)) + .repeat_interleave(self._group_heads, 1, output_size=self._local_heads) + .flatten(1, 2) + ) - assert self.activation in ["silu", "swish"] - if _causal_conv1d_available: + if cu_seqlens is not None: + # from https://github.com/jxiw/M1/blob/d92b53faa640f8ebf624d3e9e771fe24648ef014/rl/verl/verl/models/mamba/hybrid_wrapper.py#L152 x = _causal_conv1d_fn( - x=x, - weight=einops.rearrange(self.conv1d_weight, "d 1 w -> d w"), + x=x.transpose(1, 2).contiguous().transpose(1, 2), + weight=self.conv1d_weight.squeeze(1), bias=self.conv1d_bias, - activation=self.activation, - ) # B, L, D + seq_idx=seq_idx, + activation="silu", + ) else: - raise RuntimeError("Causal conv1d is not available. 
Please install causal_conv1d.") - - if not self.repeat_kv_before_conv: - x = einops.rearrange(x, "b (n_group dstate) l -> b n_group l dstate", dstate=self.d_state) - x = repeat_kv(x, self.repeat_group) - x = einops.rearrange(x, "b n_group l dstate -> b (n_group dstate) l") - - y = selective_scan_fn( - x, - dt, - A, - B, - C, - self.D.float(), - z=z, - delta_bias=self.dt_proj_bias.float(), # self.dt_proj.bias.float(), - delta_softplus=True, - return_last_state=False, + x = _causal_conv1d_fn(x=x, weight=self.conv1d_weight.squeeze(1), bias=self.conv1d_bias, activation="silu") + + if not self._config.repeat_kv_before_conv: + x = ( + x.unflatten(1, (self._local_head_groups, self._config.state_size)) + .repeat_interleave(self._group_heads, 1, output_size=self._local_heads) + .flatten(1, 2) + ) + + # b: (batch, sequence, local_head_groups * state) -> (batch, local_heads, state, sequence) + b = ( + b.transpose(1, 2) + .unflatten(1, (self._local_head_groups, self._config.state_size)) + .repeat_interleave(self._group_heads, 1, output_size=self._local_heads) ) - if ssm_state is not None: - y, last_state = y - ssm_state.copy_(einops.rearrange(last_state, "b (h d) n -> b h d n", h=self.num_C_head)) + # c: (batch, sequence, heads * state) -> (batch, heads, state, sequence) + c = c.transpose(1, 2).unflatten(1, (self._local_heads, self._config.state_size)) + + # dt: (batch, sequence, heads * state) -> (batch, heads * state, sequence) + dt = dt.transpose(1, 2) + + if self._debug_level: + self._debug_log(z, "z", self._XZ_DIMS, kwargs) + self._debug_log(x, "x", self._XZ_DIMS, kwargs) + self._debug_log(b, "b", self._BC_DIMS, kwargs) + self._debug_log(c, "c", self._BC_DIMS, kwargs) + self._debug_log(dt, "dt", self._XZ_DIMS, kwargs) + + if not _mamba_varlen: + Assert.eq(cu_seqlens, None, msg="This version of Mamba2 does not support cu_seqlens, install verlen mamba") + y = selective_scan_fn( + x, + dt, + -torch.exp(self.A_log.float()), + b, + c, + self.D.float(), + z, + delta_bias=self.dt_proj_bias.float(), + delta_softplus=True, + ) + else: + position_indices = position_indices if cu_seqlens is not None else None + + y = selective_scan_fn( + x, + dt, + -torch.exp(self.A_log.float()), + b, + c, + self.D.float(), + z, + delta_bias=self.dt_proj_bias.float(), + delta_softplus=True, + position_indices=position_indices, + ) - y = einops.rearrange(y, "b d l -> b l d") - out = self.out_proj(y) - outputs["hidden_states"] = out[:, :seqlen, :].contiguous() - return outputs["hidden_states"], None + if self._debug_level: + self._debug_log(y, "y", self._XZ_DIMS, kwargs) + + # y: (batch, local_heads * state, sequence) -> (batch, sequence, local_heads * state) + y = y.transpose(1, 2)[:, :sequence_length] + if kwargs[TransformerKwargs.sequence_first]: + # TODO: Is contiguous needed? 
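The unflatten / repeat_interleave / flatten chain above plays the role of the removed repeat_kv helper; a standalone check with illustrative sizes:

import torch

batch, sequence, head_groups, group_heads, state = 2, 6, 2, 3, 4
x = torch.randn(batch, head_groups * state, sequence)

expanded = (
    x.unflatten(1, (head_groups, state))  # (batch, head_groups, state, sequence)
    .repeat_interleave(group_heads, 1, output_size=head_groups * group_heads)
    .flatten(1, 2)  # (batch, heads * state, sequence)
)
assert expanded.shape == (batch, head_groups * group_heads * state, sequence)
# Each group's channels are repeated contiguously, one copy per head in the group.
torch.testing.assert_close(expanded[:, : 2 * state], x[:, :state].repeat(1, 2, 1))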
+ y = y.transpose(0, 1).contiguous() + # (batch/sequence, sequence/batch, local_heads * state) + # -> (batch/local_sequence, local_sequence/batch, hidden) + return self.out_proj(y) diff --git a/fast_llm/layers/ssm/mamba_layer.py b/fast_llm/layers/ssm/mamba_layer.py index 7c824d235..9343ef1b8 100644 --- a/fast_llm/layers/ssm/mamba_layer.py +++ b/fast_llm/layers/ssm/mamba_layer.py @@ -1,14 +1,17 @@ +import logging import math -from typing import Callable +import typing -import einops import torch -from fast_llm.engine.config_utils.tensor_space import TensorDim, TensorSpace +from fast_llm.engine.config_utils.tensor_space import DefaultDimNames, TensorSpace +from fast_llm.functional.config import ActivationType from fast_llm.layers.common.linear import Linear from fast_llm.layers.ssm.config import SSMConfig, SSMDimNames -from fast_llm.tensor import ParameterMeta, init_ones_, kaiming_init_ -from fast_llm.utils import get_lr_scale +from fast_llm.layers.transformer.config import TransformerConfig, TransformerDimNames, TransformerKwargs +from fast_llm.layers.transformer.transformer import Mixer +from fast_llm.tensor import LambdaInitializer, ParameterMeta, init_kaiming_, init_ones_ +from fast_llm.utils import Assert, get_lr_scale try: from mamba_ssm.ops.selective_scan_interface import mamba_inner_fn as _mamba_inner_fn # noqa @@ -17,6 +20,8 @@ except (ImportError, RuntimeError): _mamba_available = False +logger = logging.getLogger(__name__) + """ Note: this is mostly adapted from https://github.com/Zyphra/Zamba2, similar code is also in https://github.com/state-spaces/mamba. For now it only supports training and not inference. @@ -24,171 +29,137 @@ """ -def init_A(d_state, d_inner) -> Callable[[ParameterMeta, torch.Tensor, torch.Generator], torch.Tensor]: - def init_(meta: ParameterMeta, tensor: torch.Tensor, generator: torch.Generator): # noqa - # S4D real initialization - # TODO: adopt this initialization to work for tensor parallel setting! 
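A hedged sketch of the signature probing this PR relies on to choose between the upstream and varlen mamba kernels; the helper name here is invented for illustration, only the inspect.signature check mirrors the import logic above:

import inspect

def supports_varlen(selective_scan_fn) -> bool:
    # The varlen_mamba fork adds a `position_indices` parameter; the upstream
    # kernel does not, so the signature tells us whether packing is supported.
    return "position_indices" in inspect.signature(selective_scan_fn).parameters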
- A = einops.repeat(torch.arange(1, d_state + 1, dtype=torch.float32), "n -> d n", d=d_inner).contiguous() - A_log = torch.log(A) # Keep A_log in fp32 - if tensor.shape != A_log.shape: - if tensor.numel() == A_log.numel(): - tensor_view = tensor.view(d_inner, d_state) - tensor_view.copy_(A_log) - else: - raise ValueError(f"Tensor size {tensor.numel()} doesn't match expected size {A_log.numel()}") - else: - tensor.copy_(A_log) - return tensor - - return init_ - - -def init_dtprojbias( - d_inner: int, dt_max: float, dt_min: float, dt_init_floor: float, factory_kwargs: dict -) -> Callable[[ParameterMeta, torch.Tensor, torch.Generator], torch.Tensor]: +def init_A(d_state, d_inner) -> LambdaInitializer: + def init_(meta: ParameterMeta, tensor: torch.Tensor, generator: torch.Generator) -> None: # noqa + if tensor.numel() != d_state * d_inner: + raise ValueError("_init_A requires not supported for tensor slices.") + torch.log( + torch.arange(1, d_state + 1, dtype=torch.float32, device=tensor.device) + .unsqueeze(0) + .expand(d_inner, d_state), + out=tensor, + ) + + return LambdaInitializer(init_, requires_global_initialization=True) + + +def init_dtprojbias(dt_max: float, dt_min: float, dt_init_floor: float) -> LambdaInitializer: def init_(meta: ParameterMeta, tensor: torch.Tensor, generator: torch.Generator): # noqa - dt = torch.exp( - torch.rand(d_inner, **factory_kwargs) * (math.log(dt_max) - math.log(dt_min)) + math.log(dt_min) - ).clamp(min=dt_init_floor) + tensor.uniform_(math.log(dt_min), math.log(dt_max), generator=generator).exp_().clamp_min_(dt_init_floor) # Inverse of softplus: https://github.com/pytorch/pytorch/issues/72759 - inv_dt = dt + torch.log(-torch.expm1(-dt)) - tensor.copy_(inv_dt) - return tensor + tensor.add_(torch.log(-torch.expm1(-tensor))) - return init_ + return LambdaInitializer(init_) -class MambaLayer(torch.nn.Module): +class MambaLayer(Mixer): + _mixer_name: typing.ClassVar[str] = "mamba" + def __init__( self, config: SSMConfig, - layer_idx: int, + block_index: int, tensor_space: TensorSpace, - return_input: bool = False, + transformer_config: TransformerConfig, ): - factory_kwargs = {} - super().__init__() - self.config: SSMConfig = config - self.layer_idx = layer_idx - - self._debug_mode = config.debug_ssm + super().__init__(tensor_space, block_index, debug_level=transformer_config.debug_transformer) + assert tensor_space.distributed_config.tensor_parallel == 1, "Tensor-parallel not supported for MambaLayer" + self._config = config + # TODO: It's not silu? 
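A standalone sketch of what the two initializers above produce (the S4D A_log values and the inverse-softplus dt bias), with illustrative shapes and the default bounds:

import math
import torch

d_inner, d_state = 8, 4
dt_min, dt_max, dt_init_floor = 0.001, 0.1, 1e-4

# init_A: every channel gets the same real S4D spectrum, stored as log(1..d_state).
A_log = torch.log(torch.arange(1, d_state + 1, dtype=torch.float32)).unsqueeze(0).expand(d_inner, d_state)

# init_dtprojbias: sample dt log-uniformly, floor it, then store softplus^-1(dt),
# so that softplus(dt_bias) recovers dt at the start of training.
dt = torch.empty(d_inner).uniform_(math.log(dt_min), math.log(dt_max)).exp_().clamp_min_(dt_init_floor)
dt_bias = dt + torch.log(-torch.expm1(-dt))  # inverse softplus: log(exp(dt) - 1)
torch.testing.assert_close(torch.nn.functional.softplus(dt_bias), dt)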
+ Assert.eq(self._config.activation_type, ActivationType.silu) # Tensor dims: - td_inner = tensor_space.get_tensor_dim(SSMDimNames.inner_dim) - td_inner_proj = tensor_space.get_tensor_dim( - SSMDimNames.inner_proj_mamba - ) # TensorDim("D_inner_2", self.d_inner * 2) - tdt_rank = tensor_space.get_tensor_dim(SSMDimNames.dt_rank) - td_x_proj = tensor_space.get_tensor_dim(SSMDimNames.x_proj_dim) - td_state = tensor_space.get_tensor_dim(SSMDimNames.state_dim) - td_model = tensor_space.get_tensor_dim(SSMDimNames.model_dim) - td_conv_kernel = tensor_space.get_tensor_dim(SSMDimNames.conv_kernel_size) - self.d_conv = td_conv_kernel.size - self.d_inner = td_inner.size - self.d_state = td_state.size - self.d_model = td_model.size - self.dt_rank = tdt_rank.size - layer_lr_scale = config.per_layer_lr_scale[layer_idx] if config.per_layer_lr_scale else None - mamba_layer_lr_scale = get_lr_scale(self.config.mamba_lr_scale, layer_lr_scale) - - self.in_proj_weight = ParameterMeta.from_dims( - (td_inner_proj, td_model), - init_method=kaiming_init_(td_model.size), + inner_dim = tensor_space[SSMDimNames.composite_heads_and_head_dim] + hidden_dim = tensor_space[TransformerDimNames.hidden] + layer_lr_scale = config.per_layer_lr_scale[block_index] if config.per_layer_lr_scale else None + lr_scale = get_lr_scale(self._config.mamba_lr_scale, layer_lr_scale) + + # TODO: Backward compatibility? + # TODO: lr_scale? + self.in_proj = Linear( + hidden_dim, + tensor_space[SSMDimNames.concatenated_inner_projection], + bias=False, + weight_init_method=init_kaiming_(hidden_dim.size), ) self.conv1d_weight = ParameterMeta.from_dims( - (td_inner, TensorDim("D_inner_2", self.d_inner // self.d_inner), td_conv_kernel), - init_method=kaiming_init_(td_inner.size), - lr_scale=mamba_layer_lr_scale, + ( + inner_dim, + tensor_space[DefaultDimNames.scalar], + tensor_space[SSMDimNames.convolution_kernel], + ), + init_method=init_kaiming_(inner_dim.size), + lr_scale=lr_scale, ) - self.conv1d_bias = None - - self.activation = "silu" - self.act = torch.nn.SiLU() - self.x_proj = Linear( - td_inner, - td_x_proj, - weight_init_method=kaiming_init_(td_inner.size), + inner_dim, + tensor_space[SSMDimNames.concatenated_x_projection], + weight_init_method=init_kaiming_(inner_dim.size), bias=False, - lr_scale=mamba_layer_lr_scale, - **factory_kwargs, + lr_scale=lr_scale, ) self.x_proj.weight.auto_grad_accumulation = True # TODO: the weights are initialized a bit differently here https://github.com/state-spaces/mamba/blob/0cce0fa645f100f00620ddf2333c2b7712abfdec/mamba_ssm/modules/mamba_simple.py#L82 self.dt_proj_weight = ParameterMeta.from_dims( - (td_inner, tdt_rank), - init_method=kaiming_init_(tdt_rank.size), - lr_scale=mamba_layer_lr_scale, + (inner_dim, tensor_space[SSMDimNames.dt_rank]), + init_method=init_kaiming_(self._config.dt_rank), + lr_scale=lr_scale, ) self.dt_proj_bias = ParameterMeta.from_dims( - (td_inner,), - init_method=init_dtprojbias( - self.d_inner, self.config.dt_max, self.config.dt_min, self.config.dt_init_floor, factory_kwargs - ), - lr_scale=mamba_layer_lr_scale, + (inner_dim,), + init_method=init_dtprojbias(self._config.dt_max, self._config.dt_min, self._config.dt_init_floor), + lr_scale=lr_scale, ) self.A_log = ParameterMeta.from_dims( - (td_inner, td_state), + (inner_dim, tensor_space[SSMDimNames.state]), weight_decay=False, - init_method=init_A(self.d_state, self.d_inner), - lr_scale=mamba_layer_lr_scale, + init_method=init_A(self._config.state_size, inner_dim.size), + lr_scale=lr_scale, ) # D "skip" parameter self.D = 
ParameterMeta.from_dims( - (td_inner,), + (inner_dim,), weight_decay=False, init_method=init_ones_, - lr_scale=mamba_layer_lr_scale, + lr_scale=lr_scale, ) self.out_proj = Linear( - td_inner, - td_model, + inner_dim, + hidden_dim, bias=False, # TODO: note, if bias is used there is a problem in the MambaInnerFn.backward for the bias grads. I think this bias is not used in other mamba repos. - weight_init_method=kaiming_init_(td_model.size), - lr_scale=mamba_layer_lr_scale, - **factory_kwargs, + weight_init_method=init_kaiming_(hidden_dim.size), + lr_scale=lr_scale, ) self.out_proj.weight.auto_grad_accumulation = True - self._return_input = return_input - def forward(self, hidden_states, kwargs): + def forward(self, input_: torch.Tensor, kwargs: dict[str, typing.Any]) -> tuple[torch.Tensor, torch.Tensor | None]: assert _mamba_available - batch, seqlen, dim = hidden_states.shape - - # We do matmul and transpose BLH -> HBL at the same time - xz = einops.rearrange( - self.in_proj_weight @ einops.rearrange(hidden_states, "b l d -> d (b l)"), - "d (b l) -> b d l", - l=seqlen, - ) - if self._debug_mode: - print("XZ: ", xz.shape) + in_proj = self.in_proj(input_).permute((1, 2, 0) if kwargs[TransformerKwargs.sequence_first] else (0, 2, 1)) - A = -torch.exp(self.A_log.float()) # (d_inner, d_state) # In the backward pass we write dx and dz next to each other to avoid torch.cat # not, if we wanbt to support inference, we would need to imp.lement slow path here, see https://github.com/Zyphra/Zamba2/blob/1b182f40f2257f822cc06dd785df53d67d691a15/mamba_layer.py#L172s out = _mamba_inner_fn( - xz, + in_proj, self.conv1d_weight, - self.conv1d_bias, + None, self.x_proj.weight, self.dt_proj_weight, self.out_proj.weight, self.out_proj.bias, # is None here - A, + -torch.exp(self.A_log.float()), None, # input-dependent B None, # input-dependent C self.D.float(), delta_bias=self.dt_proj_bias.float(), delta_softplus=True, ) - if self._return_input: - out = torch.stack((hidden_states, out), dim=0) + if kwargs[TransformerKwargs.sequence_first]: + out = out.transpose(0, 1) return out, None diff --git a/fast_llm/layers/ssm/preprocessing.py b/fast_llm/layers/ssm/preprocessing.py new file mode 100644 index 000000000..343f0bb28 --- /dev/null +++ b/fast_llm/layers/ssm/preprocessing.py @@ -0,0 +1,68 @@ +import logging +import typing + +import torch + +from fast_llm.engine.base_model.config import Preprocessor +from fast_llm.engine.config_utils.tensor_space import TensorSpace +from fast_llm.layers.ssm.config import SSMKwargs +from fast_llm.layers.transformer.config import TransformerKwargs +from fast_llm.models.ssm.config import HybridSSMBaseModelConfig +from fast_llm.utils import Assert + +logger = logging.getLogger(__name__) + + +class Mamba2Preprocessor(Preprocessor): + def __init__(self, config: HybridSSMBaseModelConfig, tensor_space: TensorSpace): + self._config = config + self._tensor_space = tensor_space + self._distributed_config = self._tensor_space.distributed_config + self._transformer_dim_names = config.transformer._transformer_dim_names + + def preprocess(self, batch, kwargs: dict[str, typing.Any]) -> None: + """ + Simplified preprocessor that does not take into account micro-sequences. 
+ """ + if TransformerKwargs.sequence_lengths not in kwargs: + return + sequence_lengths = kwargs[TransformerKwargs.sequence_lengths] + if TransformerKwargs.cu_seqlens_k in kwargs: + # already set this in the transformer preprocessor, so we can use it here + cu_seqlens_k = kwargs[TransformerKwargs.cu_seqlens_k] + cu_seqlens_q = kwargs[TransformerKwargs.cu_seqlens_q] + Assert.eq( + cu_seqlens_k.shape[0], + cu_seqlens_q.shape[0], + msg="cu_seqlens_k and cu_seqlens_q have different lengths, is micro_sequence_length being used? This is currently not supported for Mamba.", + ) + Assert.all_equal(cu_seqlens_k, cu_seqlens_q) + cu_seqlens = cu_seqlens_k + else: + seqlens = torch.cat(sequence_lengths) + cu_seqlens = torch.cat( + ( + torch.zeros(1, dtype=torch.int32, device=self._tensor_space.distributed.device), + torch.cumsum(seqlens, dim=0, dtype=torch.int32).to(self._tensor_space.distributed.device), + ) + ) + kwargs[SSMKwargs.cu_seqlens] = cu_seqlens + # from https://github.com/jxiw/M1/blob/d92b53faa640f8ebf624d3e9e771fe24648ef014/rl/verl/verl/models/mamba/hybrid_wrapper.py#L152 + kwargs[SSMKwargs.seq_idx] = torch.cat( + [ + torch.full((s,), i, dtype=torch.int32, device=cu_seqlens.device) + for i, s in enumerate(cu_seqlens[1:] - cu_seqlens[:-1]) + ], + dim=0, + ).unsqueeze(0) + + sequence_lengths = kwargs.get(TransformerKwargs.sequence_lengths) + sequence_k = kwargs[TransformerKwargs.sequence_k_dim].size + sequence_q = kwargs[TransformerKwargs.sequence_q_dim].size + position_ids = torch.stack( + [torch.cat([torch.arange(x) for x in sample_lens]) for sample_lens in sequence_lengths] + ).to(self._tensor_space.distributed.device, dtype=torch.int64) + position_ids = position_ids[ + :, sequence_k - sequence_q : sequence_k + ] # this is only needed if we do micro-sequences? + kwargs[SSMKwargs.ssm_position_ids] = position_ids.to(torch.int32) diff --git a/fast_llm/layers/transformer/attention.py b/fast_llm/layers/transformer/attention.py index 3351c9906..c03aeed8e 100644 --- a/fast_llm/layers/transformer/attention.py +++ b/fast_llm/layers/transformer/attention.py @@ -7,15 +7,10 @@ from fast_llm.engine.config_utils.tensor_space import TensorSpace from fast_llm.functional.autograd import wrap_forward_backward from fast_llm.layers.common.linear import InputParallelLinear, OutputParallelLinear -from fast_llm.layers.transformer.config import ( - TransformerConfig, - TransformerDimNames, - TransformerKwargs, - TransformerSubLayerName, -) -from fast_llm.logging import log_distributed_grad, log_distributed_tensor -from fast_llm.tensor import TensorMeta, init_normal_, init_zeros_ -from fast_llm.utils import Assert, get_lr_scale +from fast_llm.layers.transformer.config import TransformerConfig, TransformerKwargs, TransformerSubLayerName +from fast_llm.layers.transformer.transformer import Mixer +from fast_llm.tensor import init_normal_, init_zeros_ +from fast_llm.utils import get_lr_scale try: from flash_attn.flash_attn_interface import flash_attn_func as _flash_attn_func # noqa @@ -50,42 +45,20 @@ def backward(ctx, grad_output: torch.Tensor) -> tuple[torch.Tensor, None]: # no return grad, None -class Attention(torch.nn.Module): +class Attention(Mixer): """ A self-attention layer. 
""" - _QUERY_DIMS = ( - TransformerDimNames.batch, - TransformerDimNames.sequence_q, - TransformerDimNames.composite_heads, - TransformerDimNames.kv_channels, - ) - _KV_DIMS = ( - TransformerDimNames.batch, - TransformerDimNames.sequence_q, - TransformerDimNames.group_heads, - TransformerDimNames.kv_channels, - ) - _CONTEXT_DIMS = ( - TransformerDimNames.batch, - TransformerDimNames.sequence_q, - TransformerDimNames.composite_dense, - ) - - def __init__( - self, - config: TransformerConfig, - tensor_space: TensorSpace, - layer_index, - ): - super().__init__() + _mixer_name: typing.ClassVar[str] = "attn" + + def __init__(self, config: TransformerConfig, tensor_space: TensorSpace, block_index: int): + super().__init__(tensor_space, block_index, config.debug_transformer) + self._transformer_dim_names = config._transformer_dim_names + self._transformer_kwargs = config._transformer_kwargs self._config = config - self._tensor_space = tensor_space - # Assert.in_range_incl(layer_index, 1, max(self._config.num_layers, 1)) - self._layer_index = layer_index - self._sequence_parallel = self._tensor_space.distributed_config.sequence_tensor_parallel - self._debug_transformer = self._config.debug_transformer + + self._causal = self._config.causal self._use_flash_attention = self._config.do_use_flash_attention(self._tensor_space.distributed_config) init_method_qkv = init_normal_( @@ -99,22 +72,22 @@ def __init__( max_val=self._config.init_method_max_attn_proj, ) - self._kv_channels = self._tensor_space.get_tensor_dim(TransformerDimNames.kv_channels).size - self._head_groups = self._tensor_space.get_tensor_dim(TransformerDimNames.head_groups).global_size - self._local_head_groups = self._tensor_space.get_tensor_dim(TransformerDimNames.head_groups).size - self._local_heads_per_group = self._tensor_space.get_tensor_dim(TransformerDimNames.group_heads).size + self._kv_channels = self._tensor_space[self._transformer_dim_names.kv_channels].size + self._head_groups = self._tensor_space[self._transformer_dim_names.head_groups].global_size + self._local_head_groups = self._tensor_space[self._transformer_dim_names.head_groups].size + self._local_heads_per_group = self._tensor_space[self._transformer_dim_names.group_heads].size self._local_heads = self._local_head_groups * self._local_heads_per_group self._softmax_scale = self._kv_channels ** (-self._config.attention_softmax_scale_power) - hidden_dim = self._tensor_space.get_tensor_dim(TransformerDimNames.hidden) + hidden_dim = self._tensor_space[self._transformer_dim_names.hidden] - layer_lr_scale = config.per_layer_lr_scale[layer_index] if config.per_layer_lr_scale else None + layer_lr_scale = config.per_layer_lr_scale[block_index] if config.per_layer_lr_scale else None attention_lr_scale = get_lr_scale(self._config.attention_lr_scale, layer_lr_scale) # TODO: Merge the query and key-value computations? (harder with sequence parallel.) 
self.query = OutputParallelLinear( hidden_dim, - self._tensor_space.get_tensor_dim(TransformerDimNames.composite_query), + self._tensor_space[self._transformer_dim_names.composite_query], bias=self._config.add_attn_qkv_bias, weight_init_method=init_method_qkv, bias_init_method=init_method_qkv if self._config.random_bias_init else init_zeros_, @@ -123,7 +96,7 @@ def __init__( ) self.key_value = OutputParallelLinear( hidden_dim, - self._tensor_space.get_tensor_dim(TransformerDimNames.composite_key_value), + self._tensor_space[self._transformer_dim_names.composite_key_value], bias=self._config.add_attn_qkv_bias, weight_init_method=init_method_qkv, bias_init_method=init_method_qkv if self._config.random_bias_init else init_zeros_, @@ -137,7 +110,7 @@ def __init__( # Output. self.dense = InputParallelLinear( - self._tensor_space.get_tensor_dim(TransformerDimNames.composite_dense), + self._tensor_space[self._transformer_dim_names.composite_dense], hidden_dim, bias=self._config.add_attn_dense_bias, weight_init_method=init_method_std_attn_proj, @@ -178,10 +151,10 @@ def _attn_fused( query, key, beta=0, - alpha=self._softmax_scale / self._layer_index, + alpha=self._softmax_scale / self._block_index, ).view(b, self._local_head_groups, sq, self._local_heads_per_group, sk) - attn_weights = attn_weights.to(torch.float32) * self._layer_index + attn_weights = attn_weights.to(torch.float32) * self._block_index attn_weights = torch.where(mask, attn_weights, mask_value) attn_weights = torch.nn.functional.softmax(attn_weights, dim=-1).to(query.dtype) @@ -200,39 +173,31 @@ def _attn_fused( .flatten(2) ) - def _get_meta( - self, input_: torch.Tensor, name: str, dim_names: tuple[str, ...], kwargs: dict[str, typing.Any] - ) -> TensorMeta: - hidden_dims = {dim.name: dim for dim in kwargs[TransformerKwargs.hidden_dims]} - return TensorMeta.from_dims( - tuple( - hidden_dims[dim_name] if dim_name in hidden_dims else self._tensor_space.get_tensor_dim(dim_name) - for dim_name in dim_names - ), - tensor_name=f"transformer layer {self._layer_index} attn {name}", - dtype=input_.dtype, + @property + def _query_dims(self): + return ( + self._transformer_dim_names.batch, + self._transformer_dim_names.sequence_q, + self._transformer_dim_names.composite_heads, + self._transformer_dim_names.kv_channels, ) - def _debug_log( - self, tensor: torch.Tensor, name: str, dim_names: tuple[str, ...], kwargs: dict[str, typing.Any] - ) -> None: - # TODO: Local vs global - Assert.gt(self._debug_transformer, 0) - log_distributed_tensor( - "", - tensor, - level=self._debug_transformer, - meta=self._get_meta(tensor, name, dim_names, kwargs), - distributed=self._tensor_space.distributed, + @property + def _kv_dims(self): + return ( + self._transformer_dim_names.batch, + self._transformer_dim_names.sequence_q, + self._transformer_dim_names.group_heads, + self._transformer_dim_names.kv_channels, + ) + + @property + def _context_dims(self): + return ( + self._transformer_dim_names.batch, + self._transformer_dim_names.sequence_q, + self._transformer_dim_names.composite_dense, ) - if tensor.requires_grad: - log_distributed_grad( - "", - tensor, - level=self._debug_transformer, - meta=self._get_meta(tensor, name + " grad", dim_names, kwargs), - distributed=self._tensor_space.distributed, - ) def _query_key_value_forward( self, input_: torch.Tensor, sequence_first: bool @@ -300,7 +265,7 @@ def _decide_window_size(self) -> int | None: # 
https://github.com/huggingface/transformers/blob/5e2183f344911aa82aba0b83778a4f196cff378e/src/transformers/models/qwen2/modular_qwen2.py#L71 # TODO: make universal per layer config window_size = self._config.window_size - if self._config.max_window_layers is not None and self._layer_index < self._config.max_window_layers: + if self._config.max_window_layers is not None and self._block_index < self._config.max_window_layers: window_size = None return window_size @@ -311,12 +276,12 @@ def forward(self, input_: torch.Tensor, kwargs: dict[str, typing.Any]) -> tuple[ # TODO: Move the rest to function. - if (past_key_values := kwargs.get(TransformerKwargs.past_key_values)) is not None: + if (past_key_values := kwargs.get(self._transformer_kwargs.past_key_values)) is not None: assert sequence_first # Clear the lists so tensors can be de-allocated key_value = torch.cat((past_key_values.pop(0), key_value), dim=0) - if (presents := kwargs.get(TransformerKwargs.presents)) is not None: + if (presents := kwargs.get(self._transformer_kwargs.presents)) is not None: # Return the presents as a leaf tensors so the gradients from later micro-sequences # don't propagate to this one. presents.append(present := key_value.detach().requires_grad_()) @@ -341,7 +306,7 @@ def forward(self, input_: torch.Tensor, kwargs: dict[str, typing.Any]) -> tuple[ key = key.view(*key.shape[:2], self._local_head_groups, self._kv_channels) value = value.view(*value.shape[:2], self._local_head_groups, self._kv_channels) - if self._debug_transformer: + if self._debug_level: self._debug_log(query, "query_rotary_input", self._QUERY_DIMS, kwargs) self._debug_log( key, @@ -356,7 +321,7 @@ def forward(self, input_: torch.Tensor, kwargs: dict[str, typing.Any]) -> tuple[ if self._use_flash_attention: assert _flash_available with set_generator(self._tensor_space.distributed.tp_generator): - if (cu_seqlens_q := kwargs.get(TransformerKwargs.cu_seqlens_q, None)) is not None: + if (cu_seqlens_q := kwargs.get(self._transformer_kwargs.cu_seqlens_q, None)) is not None: out_dims = query.size() query = query.view(-1, query.size(-2), query.size(-1)) key = key.view(-1, key.size(-2), key.size(-1)) @@ -366,12 +331,12 @@ def forward(self, input_: torch.Tensor, kwargs: dict[str, typing.Any]) -> tuple[ key, value, cu_seqlens_q=cu_seqlens_q, - cu_seqlens_k=kwargs.get(TransformerKwargs.cu_seqlens_k), - max_seqlen_q=kwargs.get(TransformerKwargs.max_seqlen_q), - max_seqlen_k=kwargs.get(TransformerKwargs.max_seqlen_k), + cu_seqlens_k=kwargs.get(self._transformer_kwargs.cu_seqlens_k), + max_seqlen_q=kwargs.get(self._transformer_kwargs.max_seqlen_q), + max_seqlen_k=kwargs.get(self._transformer_kwargs.max_seqlen_k), dropout_p=self._config.attention_dropout if self.training else 0.0, window_size=(-1, -1) if window_size is None else (window_size - 1, 0), - causal=True, + causal=self._causal, softmax_scale=self._softmax_scale, ).view(*out_dims) else: @@ -381,7 +346,7 @@ def forward(self, input_: torch.Tensor, kwargs: dict[str, typing.Any]) -> tuple[ value, window_size=(-1, -1) if window_size is None else (window_size - 1, 0), dropout_p=self._config.attention_dropout if self.training else 0.0, - causal=True, + causal=self._causal, softmax_scale=self._softmax_scale, ) input_ = input_.flatten(-2) @@ -391,25 +356,25 @@ def forward(self, input_: torch.Tensor, kwargs: dict[str, typing.Any]) -> tuple[ query.flatten(-2), key.flatten(-2), value.flatten(-2), - kwargs[TransformerKwargs.attention_mask], - kwargs[TransformerKwargs.attention_mask_value], + 
kwargs[self._transformer_kwargs.attention_mask], + kwargs[self._transformer_kwargs.attention_mask_value], ) - if self._debug_transformer: - self._debug_log(query, "query", self._QUERY_DIMS, kwargs) + if self._debug_level: + self._debug_log(query, "query", self._query_dims, kwargs) self._debug_log( key, "key", - self._KV_DIMS, + self._kv_dims, kwargs, ) self._debug_log( value, "value", - self._KV_DIMS, + self._kv_dims, kwargs, ) - self._debug_log(input_, "context", self._CONTEXT_DIMS, kwargs) + self._debug_log(input_, "context", self._context_dims, kwargs) if sequence_first: # TODO: Optimize (is contiguous avoidable? Transpose dense output?) diff --git a/fast_llm/layers/transformer/config.py b/fast_llm/layers/transformer/config.py index f6eaf5890..4d83215da 100644 --- a/fast_llm/layers/transformer/config.py +++ b/fast_llm/layers/transformer/config.py @@ -29,59 +29,86 @@ class RoutingType(str, enum.Enum): sinkhorn = "sinkhorn" -class TransformerDimNames: - # A set of common tensor dim names packed into a namespace. - # Input dimensions (variable) - # TODO: Does batch belong here? - batch = "batch" - # TODO: Distinguish micro-sequence? - sequence_q = "sequence_q" - sequence_q_tp = "sequence_q_tp" - sequence_k = "sequence_k" - hidden = "hidden" - # Self-attention dimensions - head_groups = "head_groups" - group_heads = "group_heads" - key_and_value = "key_value" - kv_channels = "kv_channels" - composite_heads = "composite_heads" - composite_query = "composite_query" - composite_key_value = "composite_key_value" - composite_dense = "composite_dense" - # MLP dimensions - mlp = "mlp" - gate_and_up = "gate_and_up" - composite_gated_mlp = "composite_gated_mlp" - experts = "experts" - top_experts = "top_experts" - shared_experts = "shared_experts" - unshared_experts = "unshared_experts" - composite_expert_mlp = "composite_expert_mlp" - composite_gated_expert_mlp = "composite_gated_expert_mlp" - composite_shared_expert_mlp = "composite_shared_expert_mlp" - composite_gated_shared_expert_mlp = "composite_gated_shared_expert_mlp" - - -class TransformerKwargs: - rotary_freq_q = "rotary_freq_q" - rotary_freq_k = "rotary_freq_k" - attention_mask = "attention_mask" - attention_mask_value = "attention_mask_value" - sequence_lengths = "sequence_lengths" - cu_seqlens_q = "cu_seqlens_q" - cu_seqlens_k = "cu_seqlens_k" - max_seqlen_q = "max_seqlen_q" - max_seqlen_k = "max_seqlen_k" - # TODO: Review these - presents = "presents" - past_key_values = "past_key_values" - sequence_first = "sequence_first" - hidden_dims = "hidden_dims" - sequence_q_dim = "sequence_q_dim" - sequence_k_dim = "sequence_k_dim" - sequence_length = "sequence_length" - # TODO: Move - grad_output = "grad_output" +class BaseTransformerDimNames: + _kwargs_attributes = { + "batch": "batch", + "sequence_q": "sequence_q", + "sequence_q_tp": "sequence_q_tp", + "sequence_k": "sequence_k", + "hidden": "hidden", + "head_groups": "head_groups", + "group_heads": "group_heads", + "key_and_value": "key_value", + "kv_channels": "kv_channels", + "composite_heads": "composite_heads", + "composite_query": "composite_query", + "composite_key_value": "composite_key_value", + "composite_dense": "composite_dense", + "mlp": "mlp", + "gate_and_up": "gate_and_up", + "composite_gated_mlp": "composite_gated_mlp", + "experts": "experts", + "top_experts": "top_experts", + "shared_experts": "shared_experts", + "unshared_experts": "unshared_experts", + "composite_expert_mlp": "composite_expert_mlp", + "composite_gated_expert_mlp": "composite_gated_expert_mlp", + 
"composite_shared_expert_mlp": "composite_shared_expert_mlp", + "composite_gated_shared_expert_mlp": "composite_gated_shared_expert_mlp", + } + + def __init_subclass__(cls, prefix="", **kwargs): + super().__init_subclass__(**kwargs) + cls._prefix = prefix + for attr, value in BaseTransformerDimNames._kwargs_attributes.items(): + setattr(cls, attr, f"{cls._prefix}_{value}") + + +class TransformerDimNames(BaseTransformerDimNames, prefix=""): + pass + + +class VisionTransformerDimNames(BaseTransformerDimNames, prefix="image_encoder"): + pass + + +class BaseTransformerKwargs: + _kwargs_attributes = { + "rotary_freq_q": "rotary_freq_q", + "rotary_freq_k": "rotary_freq_k", + "attention_mask": "attention_mask", + "attention_mask_value": "attention_mask_value", + "sequence_lengths": "sequence_lengths", + "cu_seqlens_q": "cu_seqlens_q", + "cu_seqlens_k": "cu_seqlens_k", + "max_seqlen_q": "max_seqlen_q", + "max_seqlen_k": "max_seqlen_k", + "presents": "presents", + "past_key_values": "past_key_values", + "sequence_first": "sequence_first", + "hidden_dims": "hidden_dims", + "sequence_q_dim": "sequence_q_dim", + "sequence_k_dim": "sequence_k_dim", + "sequence_length": "sequence_length", + "micro_batch_size": "micro_batch_size", + "grad_output": "grad_output", + } + + _prefix = "" + + def __init_subclass__(cls, prefix="", **kwargs): + super().__init_subclass__(**kwargs) + cls._prefix = prefix + for attr, value in BaseTransformerKwargs._kwargs_attributes.items(): + setattr(cls, value, f"{cls._prefix}_{value}" if cls._prefix else value) + + +class TransformerKwargs(BaseTransformerKwargs, prefix=""): + pass + + +class VisionTransformerKwargs(BaseTransformerKwargs, prefix="image_encoder"): + patch_position_ids = "patch_position_ids" class TransformerLossNames: @@ -206,9 +233,19 @@ def _validate(self) -> None: ) -@config_class() +class TransformerType(str, enum.Enum): + lm_decoder = "lm_decoder" + image_encoder = "image_encoder" + + +@config_class(registry=True) class TransformerConfig(LLMBlockConfig): _abstract = False + type: TransformerType = Field( + default=TransformerType.lm_decoder, + desc="Type of the transformer. Choices: lm_decoder, image_encoder.", + hint=FieldHint.architecture, + ) normalization: NormalizationConfig = Field( desc="Configuration for the normalization layers architecture.", hint=FieldHint.architecture, @@ -485,6 +522,11 @@ class TransformerConfig(LLMBlockConfig): " Reduces memory usage, but increases fragmentation and requires CPU synchronisation. Not recommended.", hint=FieldHint.expert, ) + causal: bool = Field( + default=True, + desc="Use causal attention. 
Turn this off only for bidirectional attention e.g., in Vision Transformer.", + hint=FieldHint.feature, + ) def _validate(self) -> None: with self._set_implicit_default(): @@ -604,61 +646,91 @@ def setup_tensor_space(self, tensor_space: TensorSpace) -> None: tensor = tensor_space.distributed_config.get_distributed_dim(DistributedDimNames.tensor) # Hidden dimension - tensor_space.add_tensor_dim(TensorDim(TransformerDimNames.hidden, self.hidden_size)) + tensor_space.add_tensor_dim(TensorDim(self._transformer_dim_names.hidden, self.hidden_size)) # Self-attention dimensions tensor_space.add_tensor_dim( head_groups := TensorDim( - TransformerDimNames.head_groups, self.head_groups, tensor if self.head_groups > 1 else None + self._transformer_dim_names.head_groups, self.head_groups, tensor if self.head_groups > 1 else None ) ) tensor_space.add_tensor_dim( group_heads := TensorDim( - TransformerDimNames.group_heads, + self._transformer_dim_names.group_heads, div(self.num_attention_heads, self.head_groups), None if self.head_groups > 1 else tensor, ) ) - tensor_space.add_tensor_dim(key_and_value := TensorDim(TransformerDimNames.key_and_value, 2)) - tensor_space.add_tensor_dim(kv_channels := TensorDim(TransformerDimNames.kv_channels, self.kv_channels)) + tensor_space.add_tensor_dim(key_and_value := TensorDim(self._transformer_dim_names.key_and_value, 2)) + tensor_space.add_tensor_dim( + kv_channels := TensorDim(self._transformer_dim_names.kv_channels, self.kv_channels) + ) tensor_space.add_tensor_dim( - CompositeTensorDim(TransformerDimNames.composite_heads, (head_groups, group_heads)) + CompositeTensorDim(self._transformer_dim_names.composite_heads, (head_groups, group_heads)) ) tensor_space.add_tensor_dim( - CompositeTensorDim(TransformerDimNames.composite_query, (head_groups, group_heads, kv_channels)) + CompositeTensorDim(self._transformer_dim_names.composite_query, (head_groups, group_heads, kv_channels)) ) tensor_space.add_tensor_dim( - CompositeTensorDim(TransformerDimNames.composite_key_value, (key_and_value, head_groups, kv_channels)) + CompositeTensorDim( + self._transformer_dim_names.composite_key_value, (key_and_value, head_groups, kv_channels) + ) ) tensor_space.add_tensor_dim( - CompositeTensorDim(TransformerDimNames.composite_dense, (head_groups, group_heads, kv_channels)) + CompositeTensorDim(self._transformer_dim_names.composite_dense, (head_groups, group_heads, kv_channels)) ) # MLP dimensions - tensor_space.add_tensor_dim(mlp := TensorDim(TransformerDimNames.mlp, self.ffn_hidden_size, tensor)) - tensor_space.add_tensor_dim(gate_and_up := TensorDim(TransformerDimNames.gate_and_up, 2 if self.gated else 1)) - tensor_space.add_tensor_dim(CompositeTensorDim(TransformerDimNames.composite_gated_mlp, (gate_and_up, mlp))) - tensor_space.add_tensor_dim(experts := TensorDim(TransformerDimNames.experts, self.num_experts)) - tensor_space.add_tensor_dim(CompositeTensorDim(TransformerDimNames.composite_expert_mlp, (experts, mlp))) + tensor_space.add_tensor_dim(mlp := TensorDim(self._transformer_dim_names.mlp, self.ffn_hidden_size, tensor)) tensor_space.add_tensor_dim( - CompositeTensorDim(TransformerDimNames.composite_gated_expert_mlp, (experts, gate_and_up, mlp)) + gate_and_up := TensorDim(self._transformer_dim_names.gate_and_up, 2 if self.gated else 1) ) - tensor_space.add_tensor_dim(TensorDim(TransformerDimNames.top_experts, self.num_experts_per_token)) - tensor_space.add_tensor_dim(TensorDim(TransformerDimNames.unshared_experts, self.num_unshared_experts)) + tensor_space.add_tensor_dim( + 
CompositeTensorDim(self._transformer_dim_names.composite_gated_mlp, (gate_and_up, mlp)) + ) + tensor_space.add_tensor_dim(experts := TensorDim(self._transformer_dim_names.experts, self.num_experts)) + tensor_space.add_tensor_dim( + CompositeTensorDim(self._transformer_dim_names.composite_expert_mlp, (experts, mlp)) + ) + tensor_space.add_tensor_dim( + CompositeTensorDim(self._transformer_dim_names.composite_gated_expert_mlp, (experts, gate_and_up, mlp)) + ) + tensor_space.add_tensor_dim(TensorDim(self._transformer_dim_names.top_experts, self.num_experts_per_token)) + tensor_space.add_tensor_dim(TensorDim(self._transformer_dim_names.unshared_experts, self.num_unshared_experts)) # shared_experts if self.num_shared_experts: tensor_space.add_tensor_dim( - shared_experts := TensorDim(TransformerDimNames.shared_experts, self.num_shared_experts) + shared_experts := TensorDim(self._transformer_dim_names.shared_experts, self.num_shared_experts) ) tensor_space.add_tensor_dim( - CompositeTensorDim(TransformerDimNames.composite_shared_expert_mlp, (shared_experts, mlp)) + CompositeTensorDim(self._transformer_dim_names.composite_shared_expert_mlp, (shared_experts, mlp)) ) tensor_space.add_tensor_dim( CompositeTensorDim( - TransformerDimNames.composite_gated_shared_expert_mlp, (shared_experts, gate_and_up, mlp) + self._transformer_dim_names.composite_gated_shared_expert_mlp, (shared_experts, gate_and_up, mlp) ) ) def do_use_flash_attention(self, distributed_config: DistributedConfig) -> bool: return self.use_flash_attention and distributed_config.training_dtype in (DataType.float16, DataType.bfloat16) + + @property + def _transformer_kwargs(self) -> TransformerKwargs: + if self.type == TransformerType.image_encoder: + return VisionTransformerKwargs + else: + return TransformerKwargs + + @property + def _transformer_dim_names(self) -> TransformerDimNames: + if self.type == TransformerType.image_encoder: + return VisionTransformerDimNames + else: + return TransformerDimNames + + +for name in TransformerType: + # We need this because we are using the reserved field name `type`. + # TODO: Implement proper dynamic typing. + TransformerConfig.register_subclass(name.value, TransformerConfig) diff --git a/fast_llm/layers/transformer/mixture_of_experts.py b/fast_llm/layers/transformer/mixture_of_experts.py index a46af1387..4fd2844d5 100644 --- a/fast_llm/layers/transformer/mixture_of_experts.py +++ b/fast_llm/layers/transformer/mixture_of_experts.py @@ -40,11 +40,11 @@ class MixtureOfExpertMLP(MLPBase): _group: ProcessGroup - def __init__(self, config: TransformerConfig, tensor_space: TensorSpace, name: str = "mlp", layer_index: int = 0): + def __init__(self, config: TransformerConfig, tensor_space: TensorSpace, name: str = "mlp", block_index: int = 0): Assert.gt(config.num_experts, 1) # TODO: Implement? assert not config.add_linear_biases, "Biases not supported for MoE." 
- super().__init__(config, tensor_space, name, layer_index) + super().__init__(config, tensor_space, name, block_index) self._config = config self._tensor_space = tensor_space self._debug_mode = self._config.debug_transformer or self._config.debug_transformer_memory @@ -59,12 +59,12 @@ def __init__(self, config: TransformerConfig, tensor_space: TensorSpace, name: s self._z_loss_factor = config.expert_z_loss_coefficient self._moe_jitter_eps = config.moe_jitter_eps - layer_lr_scale = config.per_layer_lr_scale[layer_index] if config.per_layer_lr_scale else None + layer_lr_scale = config.per_layer_lr_scale[block_index] if config.per_layer_lr_scale else None router_lr_scale = get_lr_scale(config.router_lr_scale, layer_lr_scale) self.router = Linear( - tensor_space.get_tensor_dim(TransformerDimNames.hidden), - tensor_space.get_tensor_dim(TransformerDimNames.unshared_experts), + tensor_space[TransformerDimNames.hidden], + tensor_space[TransformerDimNames.unshared_experts], bias=False, weight_init_method=init_normal_( std=config.init_method_std, min_val=config.init_method_min, max_val=config.init_method_max @@ -255,7 +255,7 @@ def _debug_log( def _get_meta(self, tensor: torch.Tensor, name: str, dim_name: str, kwargs: dict[str, typing.Any]) -> TensorMeta: return TensorMeta.from_dims( - kwargs[TransformerKwargs.hidden_dims][:-1] + (self._tensor_space.get_tensor_dim(dim_name),), + kwargs[TransformerKwargs.hidden_dims][:-1] + (self._tensor_space[dim_name],), tensor_name=f"{self._name} {name}", dtype=tensor.dtype, ) diff --git a/fast_llm/layers/transformer/mlp.py b/fast_llm/layers/transformer/mlp.py index b01eb2aa5..5dee4e077 100644 --- a/fast_llm/layers/transformer/mlp.py +++ b/fast_llm/layers/transformer/mlp.py @@ -8,16 +8,19 @@ from fast_llm.functional.config import TritonConfig from fast_llm.functional.triton.mlp import mlp_autograd, torch_mlp_activation, triton_mlp_activation_autograd from fast_llm.layers.common.linear import LinearBase -from fast_llm.layers.transformer.config import TransformerConfig, TransformerDimNames, TransformerSubLayerName +from fast_llm.layers.transformer.config import TransformerConfig, TransformerSubLayerName from fast_llm.tensor import init_normal_, init_zeros_ from fast_llm.utils import Assert, get_lr_scale class MLPBase(Layer, ABC): - def __init__(self, config: TransformerConfig, tensor_space: TensorSpace, name: str = "mlp", layer_index: int = 0): + def __init__(self, config: TransformerConfig, tensor_space: TensorSpace, name: str = "mlp", block_index: int = 0): super().__init__() self._name = name - self._layer_index = layer_index + self._block_index = block_index + + self._transformer_dim_names = config._transformer_dim_names + self._transformer_kwargs = config._transformer_kwargs init_method_1 = init_normal_( std=config.init_method_std_mlp_1, @@ -30,8 +33,8 @@ def __init__(self, config: TransformerConfig, tensor_space: TensorSpace, name: s max_val=config.init_method_max_mlp_2, ) - hidden_dim = tensor_space.get_tensor_dim(TransformerDimNames.hidden) - self._intermediate_dim = tensor_space.get_tensor_dim(TransformerDimNames.composite_expert_mlp) + hidden_dim = tensor_space[self._transformer_dim_names.hidden] + self._intermediate_dim = tensor_space[self._transformer_dim_names.composite_expert_mlp] self._sequence_parallel = tensor_space.distributed_config.sequence_tensor_parallel self._recompute_level = config.mlp_recompute_level @@ -39,14 +42,14 @@ def __init__(self, config: TransformerConfig, tensor_space: TensorSpace, name: s self._activation_type = 
config.activation_type self._activation_fn = triton_mlp_activation_autograd if TritonConfig.TRITON_ENABLED else torch_mlp_activation - layer_lr_scale = config.per_layer_lr_scale[layer_index] if config.per_layer_lr_scale else None + layer_lr_scale = config.per_layer_lr_scale[block_index] if config.per_layer_lr_scale else None lr_scale = tuple(config.mlp_lr_scale) if isinstance(config.mlp_lr_scale, list) else config.mlp_lr_scale lr_scale = get_lr_scale(lr_scale, layer_lr_scale) # So both layers' weights have shape (num_experts [* gate_up] * ffn, hidden_size) self.layer_1 = LinearBase( hidden_dim, - tensor_space.get_tensor_dim(TransformerDimNames.composite_gated_expert_mlp), + tensor_space[self._transformer_dim_names.composite_gated_expert_mlp], bias=config.add_mlp_bias, weight_init_method=init_method_1, bias_init_method=init_method_1 if config.random_bias_init else init_zeros_, @@ -69,9 +72,9 @@ def __init__(self, config: TransformerConfig, tensor_space: TensorSpace, name: s class MLP(MLPBase): - def __init__(self, config: TransformerConfig, tensor_space: TensorSpace, name: str = "mlp", layer_index: int = 0): + def __init__(self, config: TransformerConfig, tensor_space: TensorSpace, name: str = "mlp", block_index: int = 0): Assert.eq(config.num_experts, 1) - super().__init__(config, tensor_space, name, layer_index) + super().__init__(config, tensor_space, name, block_index) def forward( self, diff --git a/fast_llm/layers/transformer/preprocessing.py b/fast_llm/layers/transformer/preprocessing.py index dc3ddeb52..ee30112d7 100644 --- a/fast_llm/layers/transformer/preprocessing.py +++ b/fast_llm/layers/transformer/preprocessing.py @@ -24,11 +24,13 @@ def __init__( config: TransformerConfig, tensor_space: TensorSpace, ): + self._transformer_dim_names = config._transformer_dim_names + self._transformer_kwargs = config._transformer_kwargs self._config = config self._tensor_space = tensor_space self._distributed_config = self._tensor_space.distributed_config assert not self._config.do_use_flash_attention(self._distributed_config) - self._scalar_dim = self._tensor_space.get_tensor_dim(DefaultDimNames.scalar) + self._scalar_dim = self._tensor_space[DefaultDimNames.scalar] def _create_tensors(self, sequence_length: int) -> None: if sequence_length <= self._tensor_cache_max_sequence_length: @@ -54,10 +56,10 @@ def preprocess(self, batch, kwargs: dict[str, typing.Any]) -> None: self._create_tensors(kwargs[TransformerKwargs.sequence_length]) sequence_k = kwargs[TransformerKwargs.sequence_k_dim].size sequence_q = kwargs[TransformerKwargs.sequence_q_dim].size - kwargs[TransformerKwargs.attention_mask] = self._mask[ + kwargs[self._transformer_kwargs.attention_mask] = self._mask[ None, None, sequence_k - sequence_q : sequence_k, None, :sequence_k ] - if (sequence_lengths := kwargs.get(TransformerKwargs.sequence_lengths, None)) is not None: + if (sequence_lengths := kwargs.get(self._transformer_kwargs.sequence_lengths, None)) is not None: seq_ids = torch.stack( [ torch.cat([torch.full((x,), i) for i, x in enumerate(sample_lens)]) @@ -65,14 +67,14 @@ def preprocess(self, batch, kwargs: dict[str, typing.Any]) -> None: ] ) document_mask = (seq_ids[:, None, :] == seq_ids[:, :, None]).to(self._tensor_space.distributed.device) - kwargs[TransformerKwargs.attention_mask] = ( - kwargs[TransformerKwargs.attention_mask] + kwargs[self._transformer_kwargs.attention_mask] = ( + kwargs[self._transformer_kwargs.attention_mask] & document_mask[:, None, sequence_k - sequence_q : sequence_k, None, :sequence_k] ) - 
kwargs[TransformerKwargs.attention_mask_value] = self._mask_value + kwargs[self._transformer_kwargs.attention_mask_value] = self._mask_value def preprocess_meta(self, kwargs: dict[str, typing.Any]) -> None: - kwargs[TransformerKwargs.attention_mask] = TensorMeta.from_dims( + kwargs[self._transformer_kwargs.attention_mask] = TensorMeta.from_dims( ( self._scalar_dim, self._scalar_dim, @@ -80,12 +82,12 @@ def preprocess_meta(self, kwargs: dict[str, typing.Any]) -> None: self._scalar_dim, kwargs[TransformerKwargs.sequence_k_dim], ), - tensor_name=TransformerKwargs.attention_mask, + tensor_name=self._transformer_kwargs.attention_mask, dtype=torch.bool, ) - kwargs[TransformerKwargs.attention_mask_value] = TensorMeta.from_dims( + kwargs[self._transformer_kwargs.attention_mask_value] = TensorMeta.from_dims( (self._scalar_dim,), - tensor_name=TransformerKwargs.attention_mask_value, + tensor_name=self._transformer_kwargs.attention_mask_value, dtype=self._tensor_space.distributed_config.training_dtype.torch, ) @@ -96,6 +98,8 @@ def __init__(self, config: TransformerConfig, tensor_space: TensorSpace): self._tensor_space = tensor_space self._distributed_config = self._tensor_space.distributed_config assert self._config.do_use_flash_attention(self._distributed_config) + self._transformer_dim_names = config._transformer_dim_names + self._transformer_kwargs = config._transformer_kwargs def preprocess(self, batch, kwargs: dict[str, typing.Any]) -> None: """ @@ -146,17 +150,17 @@ def preprocess(self, batch, kwargs: dict[str, typing.Any]) -> None: else: seqlens_q = torch.cat(sequence_lengths) seqlens_k = torch.cat(sequence_lengths) - kwargs[TransformerKwargs.cu_seqlens_q] = torch.cat( + kwargs[self._transformer_kwargs.cu_seqlens_q] = torch.cat( ( torch.zeros(1, dtype=torch.int32, device=self._tensor_space.distributed.device), torch.cumsum(seqlens_q, dim=0, dtype=torch.int32).to(self._tensor_space.distributed.device), ) ) - kwargs[TransformerKwargs.cu_seqlens_k] = torch.cat( + kwargs[self._transformer_kwargs.cu_seqlens_k] = torch.cat( ( torch.zeros(1, dtype=torch.int32, device=self._tensor_space.distributed.device), torch.cumsum(seqlens_k, dim=0, dtype=torch.int32).to(self._tensor_space.distributed.device), ) ) - kwargs[TransformerKwargs.max_seqlen_q] = seqlens_q.max() - kwargs[TransformerKwargs.max_seqlen_k] = seqlens_k.max() + kwargs[self._transformer_kwargs.max_seqlen_q] = seqlens_q.max() + kwargs[self._transformer_kwargs.max_seqlen_k] = seqlens_k.max() diff --git a/fast_llm/layers/transformer/rotary/config.py b/fast_llm/layers/transformer/rotary/config.py index 748f2af28..eb739e5c4 100644 --- a/fast_llm/layers/transformer/rotary/config.py +++ b/fast_llm/layers/transformer/rotary/config.py @@ -140,3 +140,11 @@ def _get_configurable_class(self) -> "type[YarnRotary]": from fast_llm.layers.transformer.rotary.rotary import YarnRotary return YarnRotary + + +@config_class(dynamic_type={RotaryConfig: "rope_2d"}) +class Rotary2DConfig(DefaultRotaryConfig): + def _get_configurable_class(self) -> "type[Rotary2D]": + from fast_llm.layers.transformer.rotary.rotary import Rotary2D + + return Rotary2D diff --git a/fast_llm/layers/transformer/rotary/preprocessing.py b/fast_llm/layers/transformer/rotary/preprocessing.py index cc83dae02..c357411b6 100644 --- a/fast_llm/layers/transformer/rotary/preprocessing.py +++ b/fast_llm/layers/transformer/rotary/preprocessing.py @@ -25,8 +25,8 @@ def __init__( self._config = config self._tensor_space = tensor_space self._distributed_config = self._tensor_space.distributed_config - 
self._scalar_dim = self._tensor_space.get_tensor_dim(DefaultDimNames.scalar) - self._kv_channels_dim = self._tensor_space.get_tensor_dim(TransformerDimNames.kv_channels) + self._scalar_dim = self._tensor_space[DefaultDimNames.scalar] + self._kv_channels_dim = self._tensor_space[TransformerDimNames.kv_channels] def preprocess(self, batch, kwargs: dict[str, typing.Any]) -> None: self._create_tensors(kwargs[TransformerKwargs.sequence_length]) diff --git a/fast_llm/layers/transformer/rotary/rotary.py b/fast_llm/layers/transformer/rotary/rotary.py index 056b9aa4c..6b4b81415 100644 --- a/fast_llm/layers/transformer/rotary/rotary.py +++ b/fast_llm/layers/transformer/rotary/rotary.py @@ -8,14 +8,16 @@ from fast_llm.engine.base_model.config import Preprocessor from fast_llm.engine.config_utils.tensor_space import DefaultDimNames, TensorSpace from fast_llm.functional.triton.rotary import triton_rotary_autograd_ -from fast_llm.layers.transformer.config import TransformerDimNames, TransformerKwargs +from fast_llm.layers.transformer.config import TransformerDimNames, TransformerKwargs, VisionTransformerKwargs from fast_llm.layers.transformer.rotary.config import ( DefaultRotaryConfig, Llama3RotaryConfig, NoRotaryConfig, + Rotary2DConfig, RotaryConfig, YarnRotaryConfig, ) +from fast_llm.layers.vision_encoder.config import VisionEncoderKwargs from fast_llm.tensor import TensorMeta from fast_llm.utils import div @@ -82,8 +84,8 @@ def __init__( super().__init__(config, tensor_space) self._tensor_space = tensor_space if self._tensor_space is not None: - self._scalar_dim = self._tensor_space.get_tensor_dim(DefaultDimNames.scalar) - self._kv_channels_dim = self._tensor_space.get_tensor_dim(TransformerDimNames.kv_channels) + self._scalar_dim = self._tensor_space[DefaultDimNames.scalar] + self._kv_channels_dim = self._tensor_space[TransformerDimNames.kv_channels] def preprocess(self, batch, kwargs: dict[str, typing.Any]) -> None: assert self._tensor_space is not None @@ -212,3 +214,71 @@ def _get_correction(self, beta: float, dim: int) -> float: * math.log(self._config.original_context_length / (beta * 2 * math.pi)) / (2 * math.log(self._config.theta)) ) + + +class Rotary2D[ConfigType: DefaultRotaryConfig](DefaultRotary[Rotary2DConfig]): + _rotary_embedding_frequencies: torch.Tensor + _tensor_cache_max_num_patches: int = -1 + + def preprocess(self, batch, kwargs: dict[str, typing.Any]) -> None: + assert self._tensor_space is not None + max_num_patches = kwargs[VisionEncoderKwargs.max_image_size] // kwargs[VisionEncoderKwargs.patch_size] + self._create_tensors(max_num_patches) + position_ids = kwargs[VisionTransformerKwargs.patch_position_ids] + kwargs[VisionTransformerKwargs.rotary_freq_q] = self._rotary_embedding_frequencies[:, position_ids] + kwargs[VisionTransformerKwargs.rotary_freq_k] = self._rotary_embedding_frequencies[:, position_ids] + + def preprocess_meta(self, kwargs: dict[str, typing.Any]) -> None: + assert self._tensor_space is not None + kwargs[VisionTransformerKwargs.rotary_freq_q] = TensorMeta.from_dims( + ( + self._scalar_dim, + kwargs[TransformerKwargs.sequence_q_dim], + self._scalar_dim, + self._kv_channels_dim, + ), + tensor_name=VisionTransformerKwargs.rotary_freq_q, + ) + kwargs[VisionTransformerKwargs.rotary_freq_k] = TensorMeta.from_dims( + ( + self._scalar_dim, + kwargs[TransformerKwargs.sequence_k_dim], + self._scalar_dim, + self._kv_channels_dim, + ), + tensor_name=VisionTransformerKwargs.rotary_freq_k, + ) + + def _create_tensors(self, max_num_patches: int) -> None: + if 
max_num_patches <= self._tensor_cache_max_num_patches: + return + self._tensor_cache_max_num_patches = max_num_patches + + self._rotary_embedding_frequencies = self._get_frequencies( + max_num_patches, + self._kv_channels_dim.global_size, + device=self._tensor_space.distributed.device, + ) + + def _get_frequencies(self, max_num_patches: int, kv_channels: int, device="cuda") -> torch.Tensor: + # Calculate complex frequencies by using alternating channels for width and height + height_positions = torch.arange(max_num_patches, device=device, dtype=torch.float64) + width_positions = torch.arange(max_num_patches, device=device, dtype=torch.float64) + frequencies = self._config.theta ** -torch.arange(0, 1, 2 / kv_channels, device=device, dtype=torch.float64) + angles_h = torch.outer(height_positions, frequencies[::2]) + angles_w = torch.outer(width_positions, frequencies[1::2]) + angles = torch.cat( + [ + angles_h[:, None, :].repeat(1, max_num_patches, 1), + angles_w[None, :, :].repeat(max_num_patches, 1, 1), + ], + dim=-1, + ).reshape(-1, kv_channels // 2) + + frequencies = torch.polar(torch.ones_like(angles), angles)[None, :, None, :].to(torch.complex64) + if not self._config.complex_format: + frequencies = convert_rotary_complex_to_real( + torch.view_as_real(frequencies).flatten(-2), kv_channels, 3 + ).contiguous() + + return frequencies diff --git a/fast_llm/layers/transformer/transformer.py b/fast_llm/layers/transformer/transformer.py index 147452073..9289dccfb 100644 --- a/fast_llm/layers/transformer/transformer.py +++ b/fast_llm/layers/transformer/transformer.py @@ -8,45 +8,108 @@ from fast_llm.engine.base_model.base_model import Layer from fast_llm.engine.config_utils.run import log_pipeline_parallel_main_rank from fast_llm.engine.config_utils.tensor_space import TensorDim, TensorSpace -from fast_llm.layers.transformer.attention import Attention -from fast_llm.layers.transformer.config import TransformerConfig, TransformerDimNames, TransformerKwargs +from fast_llm.layers.transformer.config import TransformerConfig from fast_llm.layers.transformer.mixture_of_experts import MixtureOfExpertMLP from fast_llm.layers.transformer.mlp import MLP from fast_llm.logging import log_distributed_grad, log_distributed_tensor, log_memory_usage from fast_llm.tensor import TensorMeta +from fast_llm.utils import Assert logger = logging.getLogger(__name__) +class Mixer(torch.nn.Module, abc.ABC): + """ + Base class for mixer modules. + """ + + _mixer_name: typing.ClassVar[str] + + def __init__(self, tensor_space: TensorSpace, block_index: int, debug_level: int = 0): + super().__init__() + self._tensor_space = tensor_space + self._sequence_parallel = self._tensor_space.distributed_config.sequence_tensor_parallel + self._block_index = block_index + self._debug_level = debug_level + + @abc.abstractmethod + def forward(self, input_: torch.Tensor, kwargs: dict[str, typing.Any]) -> tuple[torch.Tensor, torch.Tensor | None]: + """ + Mixer module forward. Returns the output hidden states and an optional bias, + in case its addition can be made more efficient in `_bias_dropout_add`. 
+ """ + + def _get_meta( + self, input_: torch.Tensor, name: str, dim_names: tuple[str, ...], kwargs: dict[str, typing.Any] + ) -> TensorMeta: + hidden_dims = { + dim.name: dim + for dim in kwargs[TransformerKwargs.hidden_dims] + (kwargs[TransformerKwargs.sequence_q_dim],) + } + return TensorMeta.from_dims( + tuple( + hidden_dims[dim_name] if dim_name in hidden_dims else self._tensor_space[dim_name] + for dim_name in dim_names + ), + tensor_name=f"Block {self._block_index} {self._mixer_name} {name}", + dtype=input_.dtype, + ) + + def _debug_log( + self, tensor: torch.Tensor, name: str, dim_names: tuple[str, ...], kwargs: dict[str, typing.Any] + ) -> None: + # TODO: Local vs global + Assert.gt(self._debug_level, 0) + log_distributed_tensor( + "", + tensor, + level=self._debug_level, + meta=self._get_meta(tensor, name, dim_names, kwargs), + distributed=self._tensor_space.distributed, + ) + if tensor.requires_grad: + log_distributed_grad( + "", + tensor, + level=self._debug_level, + meta=self._get_meta(tensor, name + " grad", dim_names, kwargs), + distributed=self._tensor_space.distributed, + ) + + class BaseBlock(Layer, abc.ABC): """ A transformer-like decoder base block with abstract mixer. """ - _mixer_module_name = "self_attn" + # TODO: Standardize to `mixer` + _mixer_module_name: typing.ClassVar[str] = "mixer" def __init__( - self, config: TransformerConfig, tensor_space: TensorSpace, layer_index: int, return_input: bool = False + self, config: TransformerConfig, tensor_space: TensorSpace, block_index: int, return_input: bool = False ): super().__init__() + self._transformer_dim_names = config._transformer_dim_names + self._transformer_kwargs = config._transformer_kwargs self._config: TransformerConfig = config self._tensor_space: TensorSpace = tensor_space self._dropout_p: float = self._config.hidden_dropout # For multi-token prediction, return a stack of shared_hidden and transformer_output. self._return_input: bool = return_input - self._layer_index = layer_index + self._block_index = block_index self._debug_mode = self._config.debug_transformer or self._config.debug_transformer_memory - hidden_dim = self._tensor_space.get_tensor_dim(TransformerDimNames.hidden) + hidden_dim = self._tensor_space[self._transformer_dim_names.hidden] # Note, layer_lr_scale does not impact the norms - # TODO: add a seperate norm_lr_scale + # TODO: add a separate norm_lr_scale self.norm_1 = self._config.normalization.get_layer(hidden_dim) self.norm_2 = self._config.normalization.get_layer(hidden_dim) - self._create_mixer() + # The mixer needs to be created here for backward-compatible weight ordering. + setattr(self, self._mixer_module_name, self._create_mixer()) self.mlp = (MixtureOfExpertMLP if self._config.num_experts > 1 else MLP)( - self._config, self._tensor_space, f"{self.name} mlp", layer_index=layer_index + self._config, self._tensor_space, f"{self.name} mlp", block_index=block_index ) # PEFT. 
@@ -54,7 +117,7 @@ def __init__( self.norm_2 = self._config.peft.apply_other(self.norm_2) @abc.abstractmethod - def _create_mixer(self): + def _create_mixer(self) -> Mixer: pass @torch.compile @@ -67,10 +130,10 @@ def _bias_dropout_add( @property def name(self) -> str: - return f"{self._name} {self._layer_index}" + return f"{self._name} {self._block_index}" def _get_meta(self, tensor: torch.Tensor, name: str, kwargs: dict): - dims = kwargs[TransformerKwargs.hidden_dims] + dims = kwargs[self._transformer_kwargs.hidden_dims] if self._return_input: dims = (TensorDim("stacked_input_output", 2),) + dims return TensorMeta.from_dims(dims, tensor_name=f"{self.name} {name}", dtype=tensor.dtype) @@ -137,14 +200,21 @@ def forward( return hidden_states -class TransformerLayer(BaseBlock): +class TransformerBlock(BaseBlock): _name = "Transformer layer" - _mixer_module_name = "self_attn" + # TODO: Standardize to `mixer` + _mixer_module_name: typing.ClassVar[str] = "self_attn" def __init__( - self, config: TransformerConfig, tensor_space: TensorSpace, layer_index: int, return_input: bool = False + self, config: TransformerConfig, tensor_space: TensorSpace, block_index: int, return_input: bool = False ): - super().__init__(config, tensor_space, layer_index, return_input) + super().__init__(config, tensor_space, block_index, return_input) + + def _create_mixer(self) -> Mixer: + from fast_llm.layers.transformer.attention import Attention + + return Attention(self._config, self._tensor_space, self._block_index) + - def _create_mixer(self): - self.self_attn = Attention(self._config, self._tensor_space, self._layer_index) +class VisionTransformerBlock(TransformerBlock): + _name: str = "Vision transformer layer" diff --git a/fast_llm/layers/vision_encoder/adapter.py b/fast_llm/layers/vision_encoder/adapter.py new file mode 100644 index 000000000..7ec50dfee --- /dev/null +++ b/fast_llm/layers/vision_encoder/adapter.py @@ -0,0 +1,55 @@ +import typing + +import torch + +from fast_llm.engine.base_model.base_model import Layer +from fast_llm.engine.config_utils.tensor_space import TensorSpace +from fast_llm.functional.triton.mlp import torch_mlp_activation +from fast_llm.layers.common.linear import Linear +from fast_llm.layers.transformer.config import TransformerDimNames, TransformerKwargs +from fast_llm.layers.vision_encoder.config import VisionEncoderConfig, VisionEncoderDimNames +from fast_llm.tensor import TensorMeta, init_normal_ + + +class VisionAdapter(Layer): + """ + Vision adapter layer for the LLM. 
+ """ + + def __init__(self, config: VisionEncoderConfig, tensor_space: TensorSpace): + super().__init__() + input_dim = tensor_space[VisionEncoderDimNames.out_channels] + self._activation_type = config.adapter_activation_type + self.layer_1 = Linear( + input_dim, + tensor_space[VisionEncoderDimNames.adapter_size], + bias=True, + weight_init_method=init_normal_(std=config.adapter_init_method_std), + bias_init_method=init_normal_(std=config.adapter_init_method_std), + lr_scale=config.adapter_lr_scale, + ) + self.layer_2 = Linear( + tensor_space[VisionEncoderDimNames.adapter_size], + tensor_space[TransformerDimNames.hidden], + bias=True, + weight_init_method=init_normal_(std=config.adapter_init_method_std), + bias_init_method=init_normal_(std=config.adapter_init_method_std), + lr_scale=config.adapter_lr_scale, + ) + + def forward( + self, + input_: torch.Tensor, + kwargs: dict[str, typing.Any], + losses: dict[str, typing.Any] | None = None, + metrics: dict[str, typing.Any] | None = None, + ) -> torch.Tensor: + if isinstance(input_, TensorMeta): + return TensorMeta.from_dims( + kwargs[TransformerKwargs.hidden_dims], + tensor_name="Vision adapter output", + dtype=input_.dtype, + ) + return self.layer_2( + torch_mlp_activation(input_=self.layer_1(input_), gated=False, activation_type=self._activation_type) + ) diff --git a/fast_llm/layers/vision_encoder/config.py b/fast_llm/layers/vision_encoder/config.py new file mode 100644 index 000000000..a705d948a --- /dev/null +++ b/fast_llm/layers/vision_encoder/config.py @@ -0,0 +1,181 @@ +import enum + +from fast_llm.config import Config, Field, FieldHint, check_field, config_class, skip_valid_if_none +from fast_llm.engine.base_model.config import BaseModelConfig +from fast_llm.engine.config_utils.tensor_space import TensorDim, TensorSpace +from fast_llm.functional.config import ActivationType +from fast_llm.layers.common.config import NormalizationConfig +from fast_llm.layers.transformer.config import TransformerConfig +from fast_llm.utils import Assert + + +class VisionEncoderDimNames: + in_channels = "vision_in_channels" + out_channels = "vision_out_channels" + adapter_size = "vision_adapter_size" + patch_size = "vision_patch_size" + kv_channels = "vision_kv_channels" + + +class VisionEncoderKwargs: + patch_size = "patch_size" + images = "images" + image_patches = "image_patches" + image_positions = "image_positions" + max_image_size = "max_image_size" + image_sizes = "image_sizes" + image_mean = "image_normalization_mean" + image_std = "image_normalization_std" + image_rescale_factor = "image_rescale_factor" + rope_theta = "vit_rope_theta" + rotary_inv_freq = "vit_rotary_inv_freq" + kv_channels = "vit_kv_channels" + max_image_tokens = "max_image_tokens" + patch_embeddings = "patch_embeddings" + hidden_dims = "vit_hidden_dims" + image_patches_meta = "vit_image_patches_meta" + out_channels = "vit_out_channels" + + +@config_class() +class ImageNormalizationConfig(Config): + mean_r: float = Field( + default=0.48145466, + desc="Mean value for the red channel in the image normalization process.", + hint=FieldHint.optional, + ) + mean_g: float = Field( + default=0.4578275, + desc="Mean value for the green channel in the image normalization process.", + hint=FieldHint.optional, + ) + mean_b: float = Field( + default=0.40821073, + desc="Mean value for the blue channel in the image normalization process.", + hint=FieldHint.optional, + ) + std_r: float = Field( + default=0.26862954, + desc="Standard deviation value for the red channel in the image 
normalization process.", + hint=FieldHint.optional, + ) + std_g: float = Field( + default=0.26130258, + desc="Standard deviation value for the green channel in the image normalization process.", + hint=FieldHint.optional, + ) + std_b: float = Field( + default=0.27577711, + desc="Standard deviation value for the blue channel in the image normalization process.", + hint=FieldHint.optional, + ) + rescale_factor: float = Field( + default=255.0, + desc="Rescale factor for the image normalization process.", + hint=FieldHint.optional, + ) + + +class VisionEncoderType(str, enum.Enum): + none = "none" + # TODO: better name? normalization, patch size, adapter can change based on implementation, no standard way currently. + pixtral = "pixtral" + + +@config_class(registry=True) +class VisionEncoderConfig(BaseModelConfig): + _abstract = False + + type: VisionEncoderType = Field( + default=VisionEncoderType.none, + desc="Type of the vision encoder. Choices: none, pixtral.", + hint=FieldHint.architecture, + ) + transformer: TransformerConfig = Field( + desc="Configuration for the vision transformer architecture.", + hint=FieldHint.core, + ) + patch_size: int = Field( + default=16, + desc="Patch size for the image encoder.", + hint=FieldHint.core, + ) + conv_bias: bool = Field( + default=False, + desc="Whether to use bias in the convolutional layer.", + hint=FieldHint.optional, + ) + patch_norm: NormalizationConfig = Field( + desc="Configuration for the normalization layers applied to the image patches.", + hint=FieldHint.optional, + ) + adapter_size: int = Field( + default=5120, + desc="Intermediate size for the adapter linear layers. Assuming 2 linear layers", + hint=FieldHint.core, + ) + adapter_activation_type: ActivationType = Field( + default=ActivationType.gelu, + desc="The intermediate activation type for multi-modal adapter. Default: GeLU.", + hint=FieldHint.core, + ) + adapter_bias: bool = Field( + default=True, + desc="Whether to use bias in the adapter linear layer.", + hint=FieldHint.optional, + ) + image_normalization: ImageNormalizationConfig = Field( + desc="Configuration for the normalization layers applied to the image patches.", + hint=FieldHint.optional, + ) + image_break_token: int | None = Field( + default=None, + desc="Token id to separate image rows. If None, no token id is applied.", + hint=FieldHint.optional, + ) + image_end_token: int | None = Field( + default=None, + desc="Token id to indicate the end of an image. If None, no token id is applied.", + hint=FieldHint.optional, + ) + adapter_lr_scale: float | None = Field( + default=None, + desc="Custom learning rate scale for the adapter weights.", + hint=FieldHint.feature, + valid=skip_valid_if_none(check_field(Assert.geq, 0)), + ) + conv_lr_scale: float | None = Field( + default=None, + desc="Custom learning rate scale for the convolutional layer weights.", + hint=FieldHint.feature, + valid=skip_valid_if_none(check_field(Assert.geq, 0)), + ) + adapter_init_method_std: float = Field( + default=None, + desc="Standard deviation for the normal initialization of the adapter weights. 
Default: adapter_size ** -0.5.", + hint=FieldHint.optional, + valid=check_field(Assert.geq, 0), + ) + + def _validate(self) -> None: + with self._set_implicit_default(): + if self.adapter_init_method_std is None: + self.adapter_init_method_std = self.adapter_size**-0.5 + super()._validate() + + def setup_tensor_space(self, tensor_space: TensorSpace): + tensor_space.add_tensor_dim(TensorDim(VisionEncoderDimNames.out_channels, self.transformer.hidden_size)) + tensor_space.add_tensor_dim(TensorDim(VisionEncoderDimNames.adapter_size, self.adapter_size)) + tensor_space.add_tensor_dim(TensorDim(VisionEncoderDimNames.patch_size, self.patch_size)) + tensor_space.add_tensor_dim(TensorDim(VisionEncoderDimNames.in_channels, 3)) + self.transformer.setup_tensor_space(tensor_space) + + @property + def enabled(self) -> bool: + return self.type != VisionEncoderType.none + + +for name in VisionEncoderType: + # We need this because we are using the reserved field name `type`. + # TODO: Implement proper dynamic typing. + VisionEncoderConfig.register_subclass(name.value, VisionEncoderConfig) diff --git a/fast_llm/layers/vision_encoder/patch_conv.py b/fast_llm/layers/vision_encoder/patch_conv.py new file mode 100644 index 000000000..6c2a70930 --- /dev/null +++ b/fast_llm/layers/vision_encoder/patch_conv.py @@ -0,0 +1,62 @@ +import typing + +import torch + +from fast_llm.core.ops import split +from fast_llm.engine.base_model.base_model import Layer +from fast_llm.engine.config_utils.tensor_space import TensorSpace +from fast_llm.layers.transformer.config import TransformerKwargs, VisionTransformerKwargs +from fast_llm.layers.vision_encoder.config import VisionEncoderConfig, VisionEncoderDimNames, VisionEncoderKwargs +from fast_llm.tensor import ParameterMeta, TensorMeta, init_normal_ + + +class PatchConv(Layer): + def __init__(self, config: VisionEncoderConfig, tensor_space: TensorSpace): + super().__init__() + self._tensor_space = tensor_space + self._distributed_config = tensor_space.distributed_config + self._sequence_parallel = self._distributed_config.sequence_tensor_parallel + self._lr_scale = config.adapter_lr_scale + self.weight = ParameterMeta.from_dims( + ( + self._tensor_space[VisionEncoderDimNames.out_channels], + self._tensor_space[VisionEncoderDimNames.in_channels], + self._tensor_space[VisionEncoderDimNames.patch_size], + self._tensor_space[VisionEncoderDimNames.patch_size], + ), + init_method=init_normal_(), + lr_scale=self._lr_scale, + ) + if config.conv_bias: + self.bias = ParameterMeta.from_dims( + (self._tensor_space[VisionEncoderDimNames.out_channels],), + init_method=init_normal_(), + lr_scale=self._lr_scale, + ) + else: + self.bias = None + self.norm = config.patch_norm.get_layer(tensor_space[VisionEncoderDimNames.out_channels]) + self.stride = config.patch_size + + def forward( + self, + input_: torch.Tensor, + kwargs: dict[str, typing.Any], + losses: dict[str, typing.Any] | None = None, + metrics: dict | None = None, + ) -> torch.Tensor: + hidden_dims = kwargs[VisionTransformerKwargs.hidden_dims] + if isinstance(input_, TensorMeta): + return TensorMeta.from_dims(hidden_dims, tensor_name="patch conv output", dtype=input_.dtype) + micro_batch_size = kwargs[TransformerKwargs.micro_batch_size] + sequence_length = kwargs[TransformerKwargs.sequence_length] + out_channels = kwargs[VisionEncoderKwargs.out_channels] + reshape_dims = (micro_batch_size, sequence_length, out_channels) + group = self._tensor_space.distributed.tensor_group + input_ = torch.nn.functional.conv2d(input_, self.weight, 
self.bias, stride=self.stride) + patch_embeddings = self.norm(input_.flatten(1)) + patch_embeddings = patch_embeddings.view(reshape_dims) + if self._sequence_parallel: + patch_embeddings = patch_embeddings.permute(1, 0, 2).contiguous() + patch_embeddings = split(patch_embeddings, group=group, dim=0) + return patch_embeddings diff --git a/fast_llm/layers/vision_encoder/preprocessing.py b/fast_llm/layers/vision_encoder/preprocessing.py new file mode 100644 index 000000000..adacd380c --- /dev/null +++ b/fast_llm/layers/vision_encoder/preprocessing.py @@ -0,0 +1,281 @@ +import math +import typing + +import torch +import torchvision.transforms.v2 as torchvision_transforms + +from fast_llm.engine.base_model.config import Preprocessor +from fast_llm.engine.config_utils.tensor_space import TensorDim, TensorSpace +from fast_llm.layers.language_model.config import LanguageModelKwargs +from fast_llm.layers.transformer.config import TransformerKwargs, VisionTransformerDimNames, VisionTransformerKwargs +from fast_llm.layers.vision_encoder.config import VisionEncoderConfig, VisionEncoderDimNames, VisionEncoderKwargs +from fast_llm.tensor import TensorMeta +from fast_llm.utils import div + + +def get_num_patches(height: int, width: int, patch_size: int) -> tuple[int, int]: + """ + Calculate the number of patches in height and width dimensions. + """ + return div(height, patch_size) * div(width, patch_size) + + +def get_num_image_tokens(height: int, width: int, patch_size: int, image_break: bool, image_end: bool) -> int: + """ + Calculate the number of image tokens. + If image_break is True, we consider 1 additional token after every row of patches. + """ + height_patches = div(height, patch_size) + width_patches = div(width, patch_size) + num_tokens = height_patches * width_patches + if image_break: + num_tokens += height_patches + elif image_end: + num_tokens += 1 + return num_tokens + + +def get_resize_dims(height: int, width: int, max_height: int, max_width: int, patch_size: int) -> tuple[int, int]: + """ + Calculate the new dimensions for resizing an image while maintaining the aspect ratio. + If the image is larger than the max dimensions, it will be resized to fit within them. + If the image is smaller, it will be resized to the nearest multiple of the patch size. + """ + ratio = max(height / max_height, width / max_width) + if ratio > 1: + # Resize to fit within max dimensions + height = int(height / ratio) + width = int(width / ratio) + return patch_size * math.ceil(height / patch_size), patch_size * math.ceil(width / patch_size) + + +def resize(image: torch.Tensor, max_height: int, max_width: int, patch_size: int) -> tuple[int, int]: + target_height, target_width = get_resize_dims( + image.size(1), image.size(2), max_height, max_width, patch_size=patch_size + ) + height, width = image.size(1), image.size(2) + while height > 2 * target_height or width > 2 * target_width: + # cap the resizing to half of the current size as a workaround for large images + # See pytorch issue: https://github.com/pytorch/pytorch/issues/103589 + intermediate_max_width = max(target_width, width // 2) + intermediate_max_height = max(target_height, height // 2) + height, width = get_resize_dims( + height, width, intermediate_max_height, intermediate_max_width, patch_size=patch_size + ) + image = torchvision_transforms.functional.resize( + image, size=(height, width), interpolation=torchvision_transforms.InterpolationMode.BICUBIC + ) + + # TODO: options for interpolation mode? 
+    return torchvision_transforms.functional.resize(
+        image, size=(target_height, target_width), interpolation=torchvision_transforms.InterpolationMode.BICUBIC
+    )
+
+
+def normalize(image: torch.Tensor, mean: list[float], std: list[float]) -> torch.Tensor:
+    """
+    Normalize the image using the specified mean and standard deviation.
+    """
+    return torchvision_transforms.functional.normalize(image, mean=mean, std=std)
+
+
+def pad(image: torch.Tensor, max_height, max_width) -> torch.Tensor:
+    """
+    Pad images on the right and bottom with 0s until max_height and max_width
+    """
+    width_padding = max(0, max_height - image.size(1))
+    depth_padding = max(0, max_width - image.size(2))
+    return torchvision_transforms.functional.pad(image, (0, 0, depth_padding, width_padding), 0)
+
+
+def create_inv_freqs(rope_theta: int, kv_channels: int, max_image_size: int, patch_size: int) -> torch.Tensor:
+    freqs = 1.0 / (rope_theta ** (torch.arange(0, kv_channels, 2).float() / kv_channels))
+    max_patches_per_side = max_image_size // patch_size
+
+    h = torch.arange(max_patches_per_side)
+    w = torch.arange(max_patches_per_side)
+
+    freqs_h = torch.outer(h, freqs[::2]).float()
+    freqs_w = torch.outer(w, freqs[1::2]).float()
+    inv_freq = torch.cat(
+        [
+            freqs_h[:, None, :].repeat(1, max_patches_per_side, 1),
+            freqs_w[None, :, :].repeat(max_patches_per_side, 1, 1),
+        ],
+        dim=-1,
+    ).reshape(-1, kv_channels // 2)
+
+    return torch.cat((inv_freq, inv_freq), dim=-1)
+
+
+def position_ids_in_meshgrid(height, width, max_size, patch_size) -> torch.Tensor:
+    patch_height = height // patch_size
+    patch_width = width // patch_size
+    mesh = torch.meshgrid(torch.arange(patch_height), torch.arange(patch_width), indexing="ij")
+    h_grid, v_grid = torch.stack(mesh, dim=-1).reshape(-1, 2).chunk(2, -1)
+    ids = h_grid * max_size + v_grid
+    return ids[:, 0]
+
+
+class VisionPreprocessor(Preprocessor):
+    def __init__(self, config: VisionEncoderConfig, tensor_space: TensorSpace):
+        self._config = config
+        self._tensor_space = tensor_space
+        self._distributed_config = self._tensor_space.distributed_config
+
+    def preprocess_meta(self, kwargs: dict[str, typing.Any]) -> None:
+        kwargs[VisionEncoderKwargs.image_patches_meta] = TensorMeta.from_dims(
+            (
+                TensorDim(
+                    VisionTransformerDimNames.batch,
+                    kwargs[TransformerKwargs.micro_batch_size] * kwargs[TransformerKwargs.sequence_q_dim].size,
+                ),
+                TensorDim(VisionEncoderDimNames.in_channels, 3),
+                TensorDim(VisionEncoderDimNames.patch_size, kwargs[VisionEncoderKwargs.patch_size]),
+                TensorDim(VisionEncoderDimNames.patch_size, kwargs[VisionEncoderKwargs.patch_size]),
+            ),
+            dtype=self._distributed_config.training_dtype.torch,
+        )
+
+    def preprocess(self, tokens, kwargs: dict[str, typing.Any]) -> None:
+        images = kwargs.get(VisionEncoderKwargs.images)
+        max_image_size = kwargs.get(VisionEncoderKwargs.max_image_size)
+        im_width = kwargs.get(VisionEncoderKwargs.max_image_size)
+        patch_size = kwargs[VisionEncoderKwargs.patch_size]
+        image_positions = kwargs.get(VisionEncoderKwargs.image_positions)
+        image_sizes = [
+            [get_resize_dims(im.size(1), im.size(2), max_image_size, im_width, patch_size=patch_size) for im in ims]
+            for ims in images
+        ]
+        kwargs[VisionEncoderKwargs.image_sizes] = image_sizes
+        images = [
+            [
+                normalize(
+                    resize(image, max_image_size, im_width, patch_size).to(
+                        dtype=self._tensor_space.distributed_config.training_dtype.torch
+                    )
+                    / kwargs[VisionEncoderKwargs.image_rescale_factor],
+                    mean=kwargs[VisionEncoderKwargs.image_mean],
+
std=kwargs[VisionEncoderKwargs.image_std], + ) + for image in imgs + ] + for imgs in images + ] + + if LanguageModelKwargs.labels in kwargs: + labels = kwargs[LanguageModelKwargs.labels] + if (self._config.image_break_token is not None) or (self._config.image_end_token is not None): + # If image break or end token is present, we need to replace image token ids to -100 in labels + # TODO: avoid double cloning labels in case of loss masking spans? + labels = labels.clone() + + patches = [] + patch_position_ids = [] + cu_seqlens = [0] + max_seqlen = -1 + kwargs.get(TransformerKwargs.sequence_first) + for idx, (imgs, sizes, positions) in enumerate(zip(images, image_sizes, image_positions)): + # add an empty tensor for clean concatenation in case of no images + seq_patches = [ + torch.tensor([]).to( + dtype=self._tensor_space.distributed_config.training_dtype.torch, + device=self._tensor_space.distributed.device, + ) + ] + sample_cu_seqlen = 0 + for image, size, position in zip(imgs, sizes, positions): + seqlen = get_num_patches(*size, patch_size) + num_tokens = get_num_image_tokens( + *size, + patch_size=patch_size, + image_break=self._config.image_break_token is not None, + image_end=self._config.image_end_token is not None, + ) + if LanguageModelKwargs.labels in kwargs: + # set labels for image patches to -100 + labels[idx, max(position - 1, 0) : position + num_tokens - 1] = -100 + if seqlen > max_seqlen: + max_seqlen = seqlen + cu_seqlens.append(cu_seqlens[-1] + seqlen) + sample_cu_seqlen += seqlen + seq_patches.append( + torch.cat( + [ + torch.nn.functional.unfold(image, kernel_size=patch_size, stride=patch_size).T.reshape( + -1, 3, patch_size, patch_size + ), + ] + ) + ) + padding_size = kwargs[TransformerKwargs.sequence_length] - sample_cu_seqlen + if padding_size > max_seqlen: + max_seqlen = padding_size + cu_seqlens.append(kwargs[TransformerKwargs.sequence_length] * (idx + 1)) + patches.append( + torch.cat( + [ + *seq_patches, + torch.zeros(padding_size, 3, patch_size, patch_size).to( + dtype=self._tensor_space.distributed_config.training_dtype.torch, + device=self._tensor_space.distributed.device, + ), + ] + ) + ) + if sizes: + position_ids = torch.cat( + [position_ids_in_meshgrid(*size, max_image_size // patch_size, patch_size) for size in sizes] + ).to(device=self._tensor_space.distributed.device) + else: + position_ids = torch.tensor( + [], + dtype=torch.int64, + device=self._tensor_space.distributed.device, + ) + # We pad at the end instead of padding at the position in meshgrid because flash attention does not support custom attention masks + patch_position_ids.append( + torch.cat( + [ + position_ids, + torch.full((padding_size,), 0).to(device=self._tensor_space.distributed.device), + ] + ) + ) + assert patches[-1].size(0) == kwargs[TransformerKwargs.sequence_length] + patches = torch.cat(patches) + patch_position_ids = torch.cat(patch_position_ids) + kwargs[VisionEncoderKwargs.image_patches] = patches + kwargs[VisionTransformerKwargs.patch_position_ids] = patch_position_ids + kwargs[VisionEncoderKwargs.rotary_inv_freq] = create_inv_freqs( + kwargs[VisionEncoderKwargs.rope_theta], + kwargs[VisionEncoderKwargs.kv_channels], + max_image_size, + patch_size, + ).to(device=self._tensor_space.distributed.device) + kwargs[VisionEncoderKwargs.max_image_tokens] = div(max_image_size * im_width, patch_size**2) + # sequence data parallel is not yet supported for images, so we use the same cu_seqlens for q and k + kwargs[VisionTransformerKwargs.cu_seqlens_q] = torch.tensor( + cu_seqlens, 
device=self._tensor_space.distributed.device, dtype=torch.int32 + ) + kwargs[VisionTransformerKwargs.cu_seqlens_k] = torch.tensor( + cu_seqlens, device=self._tensor_space.distributed.device, dtype=torch.int32 + ) + kwargs[VisionTransformerKwargs.max_seqlen_q] = max_seqlen + kwargs[VisionTransformerKwargs.max_seqlen_k] = max_seqlen + if LanguageModelKwargs.labels in kwargs: + kwargs[LanguageModelKwargs.labels] = labels + + # TODO: add proper preprocessing for attention-mask when not using flash attention + # Following is just a dummy code to run the tests. + kwargs[self._config.transformer._transformer_kwargs.attention_mask] = torch.ones( + (1, 1, kwargs[TransformerKwargs.sequence_length], 1, kwargs[TransformerKwargs.sequence_length]), + dtype=torch.bool, + device=self._tensor_space.distributed.device, + ) + kwargs[self._config.transformer._transformer_kwargs.attention_mask_value] = torch.full( + [], + torch.finfo(self._distributed_config.training_dtype.torch).min, + dtype=self._distributed_config.training_dtype.torch, + device=self._tensor_space.distributed.device, + ) diff --git a/fast_llm/logging.py b/fast_llm/logging.py index e8334de6e..6d555a0bb 100644 --- a/fast_llm/logging.py +++ b/fast_llm/logging.py @@ -138,6 +138,8 @@ def log_tensor[ if level < 1: return tensor = tensor.detach() + if tensor.ndim == 0: + tensor = tensor[None] save_stats = TensorLogs.config.save shape = tuple(tensor.shape) _, dtype = str(tensor.dtype).split("torch.") diff --git a/fast_llm/models/custom/model.py b/fast_llm/models/custom/model.py index c206ef406..534d813ff 100644 --- a/fast_llm/models/custom/model.py +++ b/fast_llm/models/custom/model.py @@ -7,7 +7,7 @@ from fast_llm.engine.distributed.config import DistributedConfig, PhaseType from fast_llm.engine.schedule.config import BatchConfig from fast_llm.layers.language_model.embedding import LanguageModelEmbedding -from fast_llm.layers.transformer.transformer import TransformerLayer +from fast_llm.layers.transformer.transformer import TransformerBlock from fast_llm.models.custom.config import CustomBaseModelConfig, CustomModelConfig from fast_llm.models.custom.head import CustomHead from fast_llm.models.gpt.config import GPTBaseModelConfig @@ -31,10 +31,10 @@ def get_layers(self) -> list[Layer]: return [ LanguageModelEmbedding(self._config, self._tensor_space), *[ - TransformerLayer( + TransformerBlock( self._config.transformer, self._tensor_space, - layer_index=i + 1, + block_index=i + 1, ) for i in range(self._config.transformer.num_layers) ], diff --git a/fast_llm/models/gpt/config.py b/fast_llm/models/gpt/config.py index 0da16428e..182ad1712 100644 --- a/fast_llm/models/gpt/config.py +++ b/fast_llm/models/gpt/config.py @@ -71,6 +71,17 @@ class DiffusionLlamaGPTHuggingfaceCheckpointFormat(GPTHuggingfaceCheckpointForma trust_remote_code: typing.ClassVar[bool] = True +class LlavaGPTHuggingfaceCheckpointFormat(GPTHuggingfaceCheckpointFormat): + name: typing.ClassVar[str] = "llava" + # Using default values for vision and text models. 
Can be overridden in the config + vision_name: typing.ClassVar[str] = "pixtral" + text_name: typing.ClassVar[str] = "mistral" + + +class PixtralGPTHuggingfaceCheckpointFormat(GPTHuggingfaceCheckpointFormat): + name: typing.ClassVar[str] = "pixtral" + + @config_class() class GPTBatchConfig(BatchConfig): sequence_length: int = Field( @@ -163,6 +174,8 @@ class GPTModelConfig(FastLLMModelConfig): MTPLlamaGPTHuggingfaceCheckpointFormat, DiffusionDreamGPTHuggingfaceCheckpointFormat, DiffusionLlamaGPTHuggingfaceCheckpointFormat, + LlavaGPTHuggingfaceCheckpointFormat, + PixtralGPTHuggingfaceCheckpointFormat, ) @classmethod @@ -171,12 +184,37 @@ def get_model_class(cls) -> type["GPTModel"]: return GPTModel + @classmethod + def get_inference_runner_class(cls) -> type["GPTInferenceRunner"]: + from fast_llm.models.gpt.model import GPTInferenceRunner + + return GPTInferenceRunner + @classmethod def get_huggingface_model_for_causal_lm_class(cls) -> type["HuggingfaceGPTModelForCausalLM"]: from fast_llm.models.gpt.huggingface import HuggingfaceGPTModelForCausalLM return HuggingfaceGPTModelForCausalLM + @classmethod + def get_checkpoint_format(cls, format: type[CheckpointFormat]) -> type[CheckpointFormat]: + if isinstance(format, type) and issubclass(format, CheckpointFormat): + format_ = cls.get_checkpoint_format(format.name) + Assert.is_(format, format_) + return format_ + elif isinstance(format, dict): + for format_ in cls.checkpoint_formats: + if format_.name == format["name"]: + if (vision_name := format.get("vision_name")) is not None: + format_.vision_name = vision_name + if (text_name := format.get("text_name")) is not None: + format_.text_name = text_name + return format_ + for format_ in cls.checkpoint_formats: + if format_.name == format: + return format_ + raise ValueError(f"Checkpoint format {format} not supported for model {cls.model_name}") + @config_class() class PretrainedGPTModelConfig(PretrainedFastLLMModelConfig): @@ -225,6 +263,9 @@ def _validate(self) -> None: Assert.none(reference_model.model.base_model.cross_entropy_splits) Assert.eq(reference_model.model.base_model.parallel_embeddings, self.model.base_model.parallel_embeddings) Assert.geq(reference_model.model.base_model.prediction_heads, self.model.base_model.prediction_heads) + if self.model.base_model.vision_encoder.enabled: + assert self.batch.max_image_size is not None, "max_image_size must be set when using vision encoder" + Assert.gt(self.batch.max_image_size, 0) @classmethod def _from_dict( @@ -254,9 +295,3 @@ def get_trainer_class(cls) -> type["GPTTrainer"]: from fast_llm.models.gpt.trainer import GPTTrainer return GPTTrainer - - @classmethod - def get_inference_runner_class(cls) -> type["GPTInferenceRunner"]: - from fast_llm.models.gpt.model import GPTInferenceRunner - - return GPTInferenceRunner diff --git a/fast_llm/models/gpt/conversion.py b/fast_llm/models/gpt/conversion.py index d8425786d..fb1801067 100644 --- a/fast_llm/models/gpt/conversion.py +++ b/fast_llm/models/gpt/conversion.py @@ -6,12 +6,15 @@ import torch from transformers.configuration_utils import PretrainedConfig -from fast_llm.config import DEFAULT, MISSING -from fast_llm.engine.checkpoint.config import CheckpointFormat +from fast_llm import __version__ +from fast_llm.config import DEFAULT, MISSING, get_nested_dict_value, set_nested_dict_value +from fast_llm.engine.base_model.config import BaseModelConfig +from fast_llm.engine.checkpoint.config import CheckpointFormat, CheckpointLoadMetadataConfig from fast_llm.engine.checkpoint.external import ( 
AutoStateDictCheckpointHandler, ConstantExportParamConverter, ConstantImportParamConverter, + ExternalStateDictCheckpointHandler, IgnoreExportWeightConverter, IgnoreImportParamConverter, IgnoreImportWeightConverter, @@ -22,11 +25,16 @@ WeightConverter, ) from fast_llm.engine.checkpoint.huggingface import CustomModelingExportMixin, HuggingfaceStateDictCheckpointHandler -from fast_llm.engine.multi_stage.config import FastLLMModelConfig +from fast_llm.engine.multi_stage.config import CheckpointMetadata, FastLLMModelConfig from fast_llm.functional.config import ActivationType from fast_llm.layers.common.config import LayerNormalizationConfig from fast_llm.layers.transformer.config import RoutingType, TransformerConfig -from fast_llm.layers.transformer.rotary.config import DefaultRotaryConfig, Llama3RotaryConfig, YarnRotaryConfig +from fast_llm.layers.transformer.rotary.config import ( + DefaultRotaryConfig, + Llama3RotaryConfig, + Rotary2DConfig, + YarnRotaryConfig, +) from fast_llm.layers.transformer.rotary.rotary import convert_rotary_complex_to_real, convert_rotary_real_to_complex from fast_llm.models.gpt.config import ( DiffusionDreamGPTHuggingfaceCheckpointFormat, @@ -34,9 +42,11 @@ GPTBaseModelConfig, GPTModelConfig, LlamaGPTHuggingfaceCheckpointFormat, + LlavaGPTHuggingfaceCheckpointFormat, MistralGPTHuggingfaceCheckpointFormat, MixtralGPTHuggingfaceCheckpointFormat, MTPLlamaGPTHuggingfaceCheckpointFormat, + PixtralGPTHuggingfaceCheckpointFormat, Qwen2GPTHuggingfaceCheckpointFormat, Starcoder2GPTHuggingfaceCheckpointFormat, ) @@ -115,7 +125,37 @@ def import_weight( return (merged_weight.t().contiguous(),) -class CommonHuggingfaceCheckpointHandler(HuggingfaceStateDictCheckpointHandler): +class WeightAndBiasConverterMixin: + def _get_weight_and_bias_converters( + self, + fast_llm_prefix: str | tuple[str, ...], + hf_prefix: str | tuple[str, ...], + use_bias: bool, + cls=WeightConverter, + ) -> list[WeightConverter]: + if isinstance(fast_llm_prefix, str): + fast_llm_prefix = (fast_llm_prefix,) + if isinstance(hf_prefix, str): + hf_prefix = (hf_prefix,) + converters = [ + cls( + tuple(f"{prefix}.weight" for prefix in fast_llm_prefix), + tuple(f"{prefix}.weight" for prefix in hf_prefix), + self._model.config.base_model, + ) + ] + if use_bias: + converters.append( + cls( + tuple(f"{prefix}.bias" for prefix in fast_llm_prefix), + tuple(f"{prefix}.bias" for prefix in hf_prefix), + self._model.config.base_model, + ) + ) + return converters + + +class CommonHuggingfaceCheckpointHandler(WeightAndBiasConverterMixin, HuggingfaceStateDictCheckpointHandler): _model: GPTModel _model_class: typing.ClassVar[FastLLMModelConfig] = GPTModelConfig architecture: typing.ClassVar[str] @@ -126,6 +166,7 @@ class CommonHuggingfaceCheckpointHandler(HuggingfaceStateDictCheckpointHandler): @classmethod def _create_config_converters(cls) -> list[ParamConverter]: return super()._create_config_converters() + [ + ConstantImportParamConverter(fast_llm_names=(("transformer", "type"),), fast_llm_value="lm_decoder"), ConstantExportParamConverter(export_names=(("architectures",),), export_value=[cls.architecture]), ConstantImportParamConverter(fast_llm_names=(("use_position_embeddings",),), fast_llm_value=False), RenameParamConverter( @@ -173,17 +214,23 @@ def _get_mlp_converters(self, fast_llm_prefix: str, hf_prefix: str) -> list[Weig def _create_weight_converters( self, + hf_base_prefix: str = "", + offset: int = 0, ) -> list[WeightConverter]: converters = [] num_layers = self._model.config.base_model.transformer.num_layers 
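# A minimal sketch of what the new `hf_base_prefix` / `offset` arguments are for (the helper
# name and the offset value below are hypothetical, not part of the handler): when the GPT
# decoder is nested inside a Llava checkpoint, its Fast-LLM layer indices are shifted past the
# vision layers and its Hugging Face names gain a prefix such as "language_model.", while the
# standalone case keeps hf_base_prefix="" and offset=0.
def _example_embedding_names(hf_base_prefix: str = "language_model.", offset: int = 42) -> tuple[str, str]:
    # With the hypothetical offset 42 this returns
    # ("layers.42.word_embeddings_weight", "language_model.model.embed_tokens.weight").
    return f"layers.{offset}.word_embeddings_weight", f"{hf_base_prefix}model.embed_tokens.weight"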
# Embeddings - converters.append(WeightConverter("layers.0.word_embeddings_weight", "model.embed_tokens.weight")) + converters.append( + WeightConverter(f"layers.{offset}.word_embeddings_weight", f"{hf_base_prefix}model.embed_tokens.weight") + ) - converters += self._create_lm_head_converters() + converters += self._create_lm_head_converters(hf_base_prefix, offset=offset) for i in range(num_layers): - converters += self._create_transformer_layer_converters(f"layers.{i+1}", f"model.layers.{i}") + converters += self._create_transformer_layer_converters( + f"layers.{i+offset+1}", f"{hf_base_prefix}model.layers.{i}" + ) return converters @@ -254,7 +301,7 @@ def _create_transformer_layer_converters( converters += self._get_mlp_converters(f"{fast_llm_layer_name}", f"{hf_layer_name}") return converters - def _create_lm_head_converters(self) -> list[WeightConverter]: + def _create_lm_head_converters(self, hf_base_prefix: str = "", offset: int = 0) -> list[WeightConverter]: num_layers = self._model.config.base_model.transformer.num_layers prediction_heads = self._model.config.base_model.prediction_heads norm_bias: bool = isinstance(self._model.config.base_model.transformer.normalization, LayerNormalizationConfig) @@ -263,20 +310,22 @@ def _create_lm_head_converters(self) -> list[WeightConverter]: # Next-token prediction head # Final norm converters += self._get_weight_and_bias_converters( - f"layers.{num_layers + 1}.final_norm", "model.norm", norm_bias + f"layers.{num_layers + offset + 1}.final_norm", f"{hf_base_prefix}model.norm", norm_bias ) # Output weights if self._model.config.base_model.tie_word_embeddings: - converters.append(IgnoreImportWeightConverter((), "lm_head.weight")) + converters.append(IgnoreImportWeightConverter((), f"{hf_base_prefix}lm_head.weight")) else: - converters.append(WeightConverter(f"layers.{num_layers + 1}.output_weights", "lm_head.weight")) + converters.append( + WeightConverter(f"layers.{num_layers + offset + 1}.output_weights", f"{hf_base_prefix}lm_head.weight") + ) # MTP-heads > 0 are thrown away for i in range(1, prediction_heads): logger.warning( f"The model weights for the multi-token prediction head {i} are discarded during conversion." ) - mtp_transformer_layer_index = num_layers - 1 + 2 * i + mtp_transformer_layer_index = num_layers + offset - 1 + 2 * i # MTP transformer layer converters += self._create_transformer_layer_converters( f"layers.{mtp_transformer_layer_index + 1}", "", ignore_export=True @@ -389,7 +438,7 @@ def __post_init__(self): def export_params(self, fast_llm_values: tuple[typing.Any, ...]) -> tuple[typing.Any, ...]: (rotary_config,) = fast_llm_values - if type(rotary_config) is DefaultRotaryConfig: + if type(rotary_config) is DefaultRotaryConfig or rotary_config is Rotary2DConfig: rotary_scaling = { "rope_type": "default", } @@ -566,6 +615,403 @@ def _get_mlp_converters(self, fast_llm_prefix: str, hf_prefix: str) -> list[Weig ] +class PixtralNumHeadsConverter(ParamConverter): + """ + Pixtral encoder uses Multi-Head Attention. + Map `num_attention_heads` and `head_groups` to a single `num_heads` parameter. 
+ """ + + def __post_init__(self): + Assert.eq(len(self.fast_llm_names), 2) + Assert.eq(len(self.export_names), 1) + + def export_params(self, fast_llm_values: tuple[typing.Any, ...]) -> tuple[typing.Any, ...]: + (num_heads, head_groups) = fast_llm_values + assert head_groups == num_heads, "Pixtral encoder expects num_heads == head_groups (MHA)" + return (num_heads,) + + def import_params(self, export_values: tuple[typing.Any, ...]) -> tuple[typing.Any, ...]: + (num_heads,) = export_values + return (num_heads, num_heads) + + +class PixtralRotaryParamConverter(ParamConverter): + """ + Pixtral encoder uses 2D Rotary Embeddings. + Map `rope_theta` to a single `rotary` parameter. `rotary_scaling` is not needed. + """ + + def __init__(self, fast_llm_names, export_names): + Assert.eq(len(fast_llm_names), 1) + Assert.eq(len(export_names), 1) + self.fast_llm_names = fast_llm_names + self.export_names = export_names + + def export_params(self, fast_llm_values: tuple[typing.Any, ...]) -> tuple[typing.Any, ...]: + (rotary_config,) = fast_llm_values + if type(rotary_config) is Rotary2DConfig: + return (rotary_config.theta,) + else: + raise ValueError(f"Unsupported rotary type: {type(rotary_config).__name__}") + + def import_params(self, export_values: tuple[typing.Any, ...]) -> tuple[typing.Any, ...]: + (rotary_theta,) = export_values + rotary_config = { + "type": "rope_2d", + "theta": rotary_theta, + } + return (rotary_config,) + + +class PixtralHuggingfaceCheckpointHandler(WeightAndBiasConverterMixin, HuggingfaceStateDictCheckpointHandler): + format: typing.ClassVar[type[CheckpointFormat]] = PixtralGPTHuggingfaceCheckpointFormat + _model_class: typing.ClassVar[FastLLMModelConfig] = FastLLMModelConfig + + @classmethod + def _create_config_converters(cls) -> list[ParamConverter]: + return super()._create_config_converters() + [ + ConstantImportParamConverter(fast_llm_names=(("type",),), fast_llm_value="pixtral"), + ConstantImportParamConverter(fast_llm_names=(("patch_norm", "type"),), fast_llm_value="rms_norm"), + ConstantImportParamConverter( + fast_llm_names=(("transformer", "normalization", "type"),), fast_llm_value="rms_norm" + ), + ConstantImportParamConverter(fast_llm_names=(("transformer", "type"),), fast_llm_value="image_encoder"), + ConstantExportParamConverter(export_names=(("architectures",),), export_value=["PixtralVisionModel"]), + ConstantImportParamConverter(fast_llm_names=(("transformer", "causal"),), fast_llm_value=False), + RenameParamConverter( + fast_llm_names=( + ( + "transformer", + "num_layers", + ), + ), + export_names=(("num_hidden_layers",),), + ), + RenameParamConverter( + fast_llm_names=( + ( + "transformer", + "hidden_size", + ), + ), + export_names=(("hidden_size",),), + ), + PixtralNumHeadsConverter( + fast_llm_names=( + ( + "transformer", + "num_attention_heads", + ), + ( + "transformer", + "head_groups", + ), + ), + export_names=(("num_attention_heads",),), + ), + RenameParamConverter( + fast_llm_names=( + ( + "transformer", + "ffn_hidden_size", + ), + ), + export_names=(("intermediate_size",),), + ), + MappedConfigParamConverter( + fast_llm_names=(("transformer", "activation_type"),), + export_names=(("hidden_act",),), + fast_llm_value=ActivationType.from_hf_name, + export_value=lambda activation_type: activation_type.hf_name, + ), + RenameParamConverter( + fast_llm_names=( + ( + "transformer", + "kv_channels", + ), + ), + export_names=(("head_dim",),), + ), + # ConstantImportParamConverter( + # fast_llm_names=(("transformer", "rotary", "type"),), 
fast_llm_value=RotaryEmbeddingType.rope_2d + # ), + # RenameParamConverter( + # fast_llm_names=( + # ( + # "transformer", + # "rotary", + # "theta", + # ), + # ), + # export_names=(("rope_theta",),), + # ), + PixtralRotaryParamConverter( + fast_llm_names=(("transformer", "rotary"),), + export_names=(("rope_theta",),), + ), + RenameParamConverter(fast_llm_names=(("patch_size",),), export_names=(("patch_size",),)), + ConstantImportParamConverter(fast_llm_names=(("transformer", "gated"),), fast_llm_value=True), + ConstantImportParamConverter(fast_llm_names=(("transformer", "add_linear_biases"),), fast_llm_value=False), + ] + + def _get_transformer_mlp_converters(self, fast_llm_prefix: str, hf_prefix: str) -> list[WeightConverter]: + return [ + SplitWeightConverter( + f"{fast_llm_prefix}.mlp.layer_1.weight", + (f"{hf_prefix}.feed_forward.gate_proj.weight", f"{hf_prefix}.feed_forward.up_proj.weight"), + ), + MLPLayer2Converter( + f"{fast_llm_prefix}.mlp.layer_2.weight", + f"{hf_prefix}.feed_forward.down_proj.weight", + self._model.config.base_model, + ), + ] + + def _create_vision_transformer_layer_converters( + self, transformer_layer_index: int, fast_llm_offset: int = 1, hf_base_prefix: str = "" + ) -> list[WeightConverter]: + # Vision transformer layer + transformer_config = self._model.config.base_model.vision_encoder.transformer + norm_bias: bool = isinstance(self._model.config.base_model.transformer.normalization, LayerNormalizationConfig) + name_bias_cls = [ + # Self-attn + ( + f"layers.{fast_llm_offset + transformer_layer_index}.self_attn.query", + f"{hf_base_prefix}transformer.layers.{transformer_layer_index}.attention.q_proj", + transformer_config.add_attn_qkv_bias, + QueryWeightConverter, + ), + ( + f"layers.{fast_llm_offset + transformer_layer_index}.self_attn.key_value", + ( + f"{hf_base_prefix}transformer.layers.{transformer_layer_index}.attention.k_proj", + f"{hf_base_prefix}transformer.layers.{transformer_layer_index}.attention.v_proj", + ), + transformer_config.add_attn_qkv_bias, + KeyValueWeightConverter, + ), + ( + f"layers.{fast_llm_offset + transformer_layer_index}.self_attn.dense", + f"{hf_base_prefix}transformer.layers.{transformer_layer_index}.attention.o_proj", + transformer_config.add_attn_dense_bias, + WeightConverter, + ), + # Norm + ( + f"layers.{fast_llm_offset + transformer_layer_index}.norm_1", + f"{hf_base_prefix}transformer.layers.{transformer_layer_index}.attention_norm", + norm_bias, + WeightConverter, + ), + ( + f"layers.{fast_llm_offset + transformer_layer_index}.norm_2", + f"{hf_base_prefix}transformer.layers.{transformer_layer_index}.ffn_norm", + norm_bias, + WeightConverter, + ), + ] + converters = [] + for fast_llm_prefix, hf_prefix, use_bias, cls in name_bias_cls: + converters += self._get_weight_and_bias_converters( + fast_llm_prefix, + hf_prefix, + use_bias, + cls, + ) + # MLP + converters += self._get_transformer_mlp_converters( + f"layers.{fast_llm_offset + transformer_layer_index}", + f"{hf_base_prefix}transformer.layers.{transformer_layer_index}", + ) + return converters + + def _create_weight_converters(self, offset: int = 0, hf_base_prefix: str = "") -> list[WeightConverter]: + converters = [] + norm_bias = isinstance(self._model.config.base_model.vision_encoder.patch_norm, LayerNormalizationConfig) + converters.append(WeightConverter(f"layers.{offset}.weight", f"{hf_base_prefix}patch_conv.weight")) + if self._model.config.base_model.vision_encoder.conv_bias: + converters.append(WeightConverter(f"layers.{offset}.bias", 
f"{hf_base_prefix}patch_conv.bias")) + converters.append(WeightConverter(f"layers.{offset}.norm.weight", f"{hf_base_prefix}ln_pre.weight")) + if norm_bias: + converters.append(WeightConverter(f"layers.{offset}.norm.bias", f"{hf_base_prefix}ln_pre.bias")) + + num_layers = self._model.config.base_model.vision_encoder.transformer.num_layers + for i in range(num_layers): + converters += self._create_vision_transformer_layer_converters(i, offset + 1, hf_base_prefix) + + converters.extend( + [ + WeightConverter( + f"layers.{offset + num_layers + 1}.layer_1.weight", "multi_modal_projector.linear_1.weight" + ), + WeightConverter( + f"layers.{offset + num_layers + 1}.layer_2.weight", "multi_modal_projector.linear_2.weight" + ), + ] + ) + if self._model.config.base_model.vision_encoder.adapter_bias: + converters.extend( + [ + WeightConverter( + f"layers.{offset + num_layers + 1}.layer_1.bias", "multi_modal_projector.linear_1.bias" + ), + WeightConverter( + f"layers.{offset + num_layers + 1}.layer_2.bias", "multi_modal_projector.linear_2.bias" + ), + ] + ) + + return converters + + @property + def num_layers(self) -> int: + # +2 for projector and conv layers + return self._model.config.base_model.vision_encoder.transformer.num_layers + 2 + + +class LlavaHuggingfaceCheckpointHandler(HuggingfaceStateDictCheckpointHandler): + format: typing.ClassVar[type[CheckpointFormat]] = LlavaGPTHuggingfaceCheckpointFormat + architecture: typing.ClassVar[str] = "LlavaForConditionalGeneration" + _model_class: typing.ClassVar[FastLLMModelConfig] = GPTModelConfig + + @classmethod + def get_vision_handler_class(cls) -> type[ExternalStateDictCheckpointHandler]: + return AutoGPTHuggingfaceCheckpointHandler.get_handler_class(cls.format.vision_name) + + @classmethod + def get_text_handler_class(cls) -> type[ExternalStateDictCheckpointHandler]: + return AutoGPTHuggingfaceCheckpointHandler.get_handler_class(cls.format.text_name) + + @classmethod + def _load_metadata(cls, config: CheckpointLoadMetadataConfig) -> CheckpointMetadata: + vision_handler_cls = cls.get_vision_handler_class() + text_handler_cls = cls.get_text_handler_class() + cfg_dict = cls._load_config(config.path) + kwargs = {} + if "text_config" in cfg_dict: + text_kwargs = text_handler_cls._import_config_dict(cfg_dict["text_config"]) + kwargs.update(text_kwargs) + if "vision_config" in cfg_dict: + vision_kwargs = vision_handler_cls._import_config_dict(cfg_dict["vision_config"]) + vision_kwargs = {tuple(["vision_encoder"] + list(key)): value for key, value in vision_kwargs.items()} + kwargs.update(vision_kwargs) + kwargs.update( + cls._import_config( + {key: value for key, value in cfg_dict.items() if key not in ("text_config", "vision_config")} + ) + ) + imported_model_config = cls._model_class.get_base_model_config_class().from_dict({}, kwargs) + return CheckpointMetadata( + fast_llm_version=__version__, + model=cls._model_class, + format=config.format, + config=cls._model_class.from_dict({"base_model": imported_model_config.to_dict()}), + shards=["weights"], + ) + + @classmethod + def _create_config_converters(cls) -> list[ParamConverter]: + return super()._create_config_converters() + [ + ConstantExportParamConverter(export_names=(("architectures",),), export_value=[cls.architecture]), + MappedConfigParamConverter( + fast_llm_names=(("vision_encoder", "adapter_activation_type"),), + export_names=(("projector_hidden_act",),), + fast_llm_value=ActivationType.from_hf_name, + export_value=lambda activation_type: activation_type.hf_name, + ), + 
RenameParamConverter( + fast_llm_names=(("vision_encoder", "adapter_size"),), + export_names=(("projector_intermediate_size",),), + ), + ] + + @classmethod + def _import_config(cls, config: dict[str, typing.Any]) -> GPTBaseModelConfig: + # handler_cls = AutoGPTHuggingfaceCheckpointHandler.get_handler_class(config["model_type"]) + kwargs = {} + for converter in cls._create_config_converters(): + try: + values = () + for export_name in converter.export_names: + try: + value = get_nested_dict_value(config, export_name) + except KeyError: + value = MISSING + values = values + (value,) + values = converter.import_params(values) + for fast_llm_name, value in zip(converter.fast_llm_names, values, strict=True): + if value is MISSING: + raise ValueError(f"Missing converted value for fast-llm parameter {fast_llm_name}") + if fast_llm_name in kwargs: + raise ValueError(f"Duplicate converted value for fast-llm parameter {fast_llm_name}") + kwargs[fast_llm_name] = value + except Exception as e: + raise RuntimeError(f"Config conversion failed for converter {converter}", *e.args) + + return kwargs + + @classmethod + def _export_config(cls, config: BaseModelConfig) -> dict[str, typing.Any]: + exported_config = {} + vision_handler_cls = cls.get_vision_handler_class() + text_handler_cls = cls.get_text_handler_class() + for converter in vision_handler_cls._create_config_converters(): + try: + values = converter.export_params( + tuple( + cls._get_fast_llm_attribute(config, ("vision_encoder",) + fast_llm_name) + for fast_llm_name in converter.fast_llm_names + ) + ) + for export_name, value in zip(converter.export_names, values, strict=True): + if value is not MISSING: + set_nested_dict_value(exported_config, ("vision_config",) + export_name, value) + except Exception as e: + raise RuntimeError(f"Config conversion failed for converter {converter}", *e.args) + + for converter in text_handler_cls._create_config_converters(): + try: + values = converter.export_params( + tuple( + cls._get_fast_llm_attribute(config, fast_llm_name) + for fast_llm_name in converter.fast_llm_names + ) + ) + for export_name, value in zip(converter.export_names, values, strict=True): + if value is not MISSING: + set_nested_dict_value(exported_config, ("text_config",) + export_name, value) + except Exception as e: + raise RuntimeError(f"Config conversion failed for converter {converter}", *e.args) + + for converter in cls._create_config_converters(): + try: + values = converter.export_params( + tuple( + cls._get_fast_llm_attribute(config, fast_llm_name) + for fast_llm_name in converter.fast_llm_names + ) + ) + for export_name, value in zip(converter.export_names, values, strict=True): + if value is not MISSING: + set_nested_dict_value(exported_config, export_name, value) + except Exception as e: + raise RuntimeError(f"Config conversion failed for converter {converter}", *e.args) + + return exported_config + + def _create_weight_converters(self): + vision_handler_cls = self.get_vision_handler_class() + vision_handler = vision_handler_cls(self._model) + converters = vision_handler._create_weight_converters(hf_base_prefix="vision_tower.", offset=0) + text_handler_cls = self.get_text_handler_class() + text_handler = text_handler_cls(self._model) + converters.extend( + text_handler._create_weight_converters(hf_base_prefix="language_model.", offset=vision_handler.num_layers) + ) + return converters + + class MixtralHuggingfaceCheckpointHandler(CommonLlamaHuggingfaceCheckpointHandler): format: typing.ClassVar[type[CheckpointFormat]] = 
MixtralGPTHuggingfaceCheckpointFormat architecture: typing.ClassVar[str] = "MixtralForCausalLM" @@ -763,4 +1209,6 @@ class AutoGPTHuggingfaceCheckpointHandler( MTPLlamaGPTHuggingfaceCheckpointFormat.name: MTPLlamaHuggingfaceCheckpointHandler, DiffusionDreamGPTHuggingfaceCheckpointFormat.name: DiffusionDreamHuggingfaceCheckpointHandler, DiffusionLlamaGPTHuggingfaceCheckpointFormat.name: DiffusionLlamaHuggingfaceCheckpointHandler, + LlavaGPTHuggingfaceCheckpointFormat.name: LlavaHuggingfaceCheckpointHandler, + PixtralGPTHuggingfaceCheckpointFormat.name: PixtralHuggingfaceCheckpointHandler, } diff --git a/fast_llm/models/gpt/megatron.py b/fast_llm/models/gpt/megatron.py index e7379e61e..20ed8e828 100644 --- a/fast_llm/models/gpt/megatron.py +++ b/fast_llm/models/gpt/megatron.py @@ -14,8 +14,8 @@ def get_init_megatron( meta: "ParameterMeta", config: TransformerConfig -) -> typing.Callable[["torch.Tensor", "Distributed"], "torch.Tensor"]: - def init_megatron(tensor: "torch.Tensor", distributed: "Distributed"): +) -> typing.Callable[["torch.Tensor", "Distributed"], None]: + def init_megatron(tensor: "torch.Tensor", distributed: "Distributed") -> None: Assert.eq(distributed.config.world_size, 1) if "bias" in meta.tensor_name: # Generator unused. @@ -29,11 +29,11 @@ def init_megatron(tensor: "torch.Tensor", distributed: "Distributed"): elif config.num_experts > 1 and "mlp.layer_" in meta.tensor_name: tensor_ = _init_moe_mlp_megatron(config, meta, tensor, distributed) elif "mlp.layer_2" in meta.tensor_name: - tensor_ = _init_transposed_mlp_weight_megatron(config, meta, tensor, distributed) + tensor_ = _init_transposed_mlp_weight_megatron(meta, tensor, distributed) else: # Word embedding (override generator), layer norm (generator unused), other mlp weights. return meta.param_init_method(meta, tensor, distributed.tp_init_generator) - return tensor.copy_(tensor_.reshape_as(tensor)) + tensor.copy_(tensor_.reshape_as(tensor)) return init_megatron @@ -58,9 +58,9 @@ def _init_attention_megatron( generator = distributed.tp_init_generator state = generator.get_state() # Initialize a mock dense layer to advance the random state - dense_tensor_ = meta.param_init_method( + meta.param_init_method( meta, - tensor.new_empty( + dense_tensor_ := tensor.new_empty( config.kv_channels * config.num_attention_heads, config.hidden_size, ), @@ -68,9 +68,9 @@ def _init_attention_megatron( ) # QKV is split differently. (Assuming no tensor-parallel.) heads_per_group = div(config.num_attention_heads, config.head_groups) - qkv_tensor_ = meta.param_init_method( + meta.param_init_method( meta, - tensor.new_empty( + qkv_tensor_ := tensor.new_empty( config.head_groups, heads_per_group + 2, config.kv_channels, @@ -110,18 +110,19 @@ def _init_position_embeddings_megatron( # Megatron initializes the position embeddings on cpu twice. 
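# The changes in this file assume `param_init_method` now fills the given buffer in place and
# returns None, so callers keep a handle to the buffer with the walrus operator instead of
# using a return value. A minimal, self-contained sketch of that calling pattern (the
# initializer below is a stand-in, not the real `param_init_method`):
import torch

def _init_in_place(tensor: torch.Tensor, generator: torch.Generator) -> None:
    # Stand-in initializer: modifies `tensor` in place and returns nothing.
    tensor.normal_(generator=generator)

def _make_initialized(like: torch.Tensor, generator: torch.Generator) -> torch.Tensor:
    # Previously: tensor_ = init(...); now: capture the buffer with `:=` and return it.
    _init_in_place(tensor_ := torch.empty_like(like), generator)
    return tensor_

# Example usage: _make_initialized(torch.empty(4, 8), torch.Generator().manual_seed(0))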
assert meta.param_init_method is not None generator = distributed.default_cpu_generator - tensor_ = meta.param_init_method(meta, torch.empty(tensor.shape, dtype=tensor.dtype), generator) - return meta.param_init_method(meta, tensor_, generator) + meta.param_init_method(meta, tensor_ := torch.empty(tensor.shape, dtype=tensor.dtype), generator) + meta.param_init_method(meta, tensor_, generator) + return tensor_ def _init_transposed_mlp_weight_megatron( - config: TransformerConfig, meta: "ParameterMeta", tensor: "torch.Tensor", distributed: "Distributed" + meta: "ParameterMeta", tensor: "torch.Tensor", distributed: "Distributed" ) -> "torch.Tensor": import torch # Megatron never transposes the mlp layer 2 weight. assert meta.param_init_method is not None - tensor_ = meta.param_init_method(meta, torch.empty_like(tensor), distributed.tp_init_generator) + meta.param_init_method(meta, tensor_ := torch.empty_like(tensor), distributed.tp_init_generator) return tensor_.view(meta.size(1), meta.size(0)).t() @@ -132,8 +133,8 @@ def _init_moe_router_megatron( # Megatron initializes the router on cpu. assert meta.param_init_method is not None - tensor_ = meta.param_init_method( - meta, torch.empty(tensor.shape, dtype=tensor.dtype), distributed.default_cpu_generator + meta.param_init_method( + meta, tensor_ := torch.empty(tensor.shape, dtype=tensor.dtype), distributed.default_cpu_generator ) return tensor_ diff --git a/fast_llm/models/gpt/model.py b/fast_llm/models/gpt/model.py index 444ad72b2..da07e5291 100644 --- a/fast_llm/models/gpt/model.py +++ b/fast_llm/models/gpt/model.py @@ -14,14 +14,21 @@ from fast_llm.layers.language_model.embedding import WORD_EMBEDDINGS_WEIGHT, LanguageModelEmbedding from fast_llm.layers.language_model.head import OUTPUT_WEIGHTS, LanguageModelHead from fast_llm.layers.language_model.preprocessing import PositionEmbeddingPreprocessor, PreferenceSpanPreprocessor +from fast_llm.layers.multi_modal.embedding import MultiModalEmbedding from fast_llm.layers.transformer.config import ( RoutingType, TransformerDimNames, TransformerKwargs, TransformerLossNames, + VisionTransformerDimNames, + VisionTransformerKwargs, ) from fast_llm.layers.transformer.preprocessing import BackupAttentionPreprocessor, FlashAttnVarlenPreprocessor -from fast_llm.layers.transformer.transformer import TransformerLayer +from fast_llm.layers.transformer.transformer import TransformerBlock, VisionTransformerBlock +from fast_llm.layers.vision_encoder.adapter import VisionAdapter +from fast_llm.layers.vision_encoder.config import VisionEncoderDimNames, VisionEncoderKwargs +from fast_llm.layers.vision_encoder.patch_conv import PatchConv +from fast_llm.layers.vision_encoder.preprocessing import VisionPreprocessor from fast_llm.models.gpt.config import GPTBaseModelConfig, GPTBatchConfig, GPTModelConfig from fast_llm.models.gpt.megatron import get_init_megatron from fast_llm.tensor import ParameterMeta, TensorMeta @@ -63,16 +70,20 @@ def __init__( if self._config.enable_dpo: # TODO better way to pass in? 
self._preprocessors.append(PreferenceSpanPreprocessor(self._config, self._tensor_space)) + if self._config.vision_encoder.enabled: + self._preprocessors.append(VisionPreprocessor(self._config.vision_encoder, self._tensor_space)) + self._preprocessors.append(self._config.vision_encoder.transformer.rotary.build(self._tensor_space)) + def get_output_layers(self) -> list[Layer]: layers = [] for i in range(self._config.prediction_heads): if i > 0: layers.append( - TransformerLayer( + TransformerBlock( self._config.transformer, self._tensor_space, # TODO MTP: which index? - layer_index=max(self._config.transformer.num_layers + i, 1), + block_index=max(self._config.transformer.num_layers + i, 1), # The last layer only returns the transformer output. # The previous layers return a stack of shared_hidden and transformer_output. return_input=i < self._config.prediction_heads - 1, @@ -87,14 +98,32 @@ def get_output_layers(self) -> list[Layer]: ) return layers + def get_vision_layers(self) -> list[Layer]: + vit_layers = [ + VisionTransformerBlock(self._config.vision_encoder.transformer, self._tensor_space, block_index=idx + 1) + for idx in range(self._config.vision_encoder.transformer.num_layers) + ] + return [ + PatchConv(self._config.vision_encoder, self._tensor_space), + *vit_layers, + VisionAdapter(self._config.vision_encoder, self._tensor_space), + MultiModalEmbedding(self._config, self._tensor_space), + ] + + def get_embedding_layers(self) -> list[Layer]: + if self._config.vision_encoder.enabled: + return self.get_vision_layers() + else: + return [LanguageModelEmbedding(self._config, self._tensor_space)] + def get_layers(self) -> list[Layer]: return [ - LanguageModelEmbedding(self._config, self._tensor_space), + *(self.get_embedding_layers()), *[ - TransformerLayer( + TransformerBlock( self._config.transformer, self._tensor_space, - layer_index=i + 1, + block_index=i + 1, # The last layer only returns the transformer output. # The previous layers return a stack of shared_hidden and transformer_output. 
return_input=self._config.prediction_heads > 1 and i == self._config.transformer.num_layers - 1, @@ -122,9 +151,42 @@ def preprocess_meta( micro_sequence_length = sequence_length truncate_documents = True + if self._config.vision_encoder.enabled: + try: + max_image_size = batch_meta.max_image_size + except AttributeError: + max_image_size = 256 + logger.warning("Inference mode: max_image_size not provided, defaulting to 256") + image_mean = [ + self._config.vision_encoder.image_normalization.mean_r, + self._config.vision_encoder.image_normalization.mean_g, + self._config.vision_encoder.image_normalization.mean_b, + ] + image_std = [ + self._config.vision_encoder.image_normalization.std_r, + self._config.vision_encoder.image_normalization.std_g, + self._config.vision_encoder.image_normalization.std_b, + ] + image_rescale_factor = self._config.vision_encoder.image_normalization.rescale_factor + vision_kwargs = { + VisionEncoderKwargs.patch_size: self._config.vision_encoder.patch_size, + VisionEncoderKwargs.max_image_size: max_image_size, + VisionEncoderKwargs.image_mean: image_mean, + VisionEncoderKwargs.image_std: image_std, + VisionEncoderKwargs.image_rescale_factor: image_rescale_factor, + VisionEncoderKwargs.rope_theta: self._config.vision_encoder.transformer.rotary.theta, + VisionEncoderKwargs.kv_channels: self._tensor_space[VisionTransformerDimNames.kv_channels].size, + VisionEncoderKwargs.out_channels: self._tensor_space[VisionEncoderDimNames.out_channels].size, + } + else: + vision_kwargs = {} + batch_data = self._tensor_space.distributed_config.get_distributed_dim(DistributedDimNames.batch_data) batch_dim = TensorDim(TransformerDimNames.batch, micro_batch_size * batch_data.size, batch_data) + if isinstance(batch_meta, GPTBatchConfig): + micro_sequence_length = batch_meta.micro_sequence_length + if micro_sequence_length is None: micro_sequence_length = sequence_length else: @@ -155,12 +217,24 @@ def preprocess_meta( sequence_first = self._config.sequence_first assert not (need_sequence_first and not sequence_first) - hidden_dim = self._tensor_space.get_tensor_dim(TransformerDimNames.hidden) + hidden_dim = self._tensor_space[TransformerDimNames.hidden] hidden_dims = ( (hidden_sequence_q_dim, batch_dim, hidden_dim) if sequence_first else (batch_dim, hidden_sequence_q_dim, hidden_dim) ) + if self._config.vision_encoder.enabled: + vision_hidden_dim = self._tensor_space[VisionTransformerDimNames.hidden] + vision_hidden_dims = ( + (hidden_sequence_q_dim, batch_dim, vision_hidden_dim) + if sequence_first + else (batch_dim, hidden_sequence_q_dim, vision_hidden_dim) + ) + vision_kwargs.update( + { + VisionTransformerKwargs.hidden_dims: vision_hidden_dims, + } + ) common_kwargs = { LanguageModelKwargs.phase: phase, @@ -168,8 +242,10 @@ def preprocess_meta( TransformerKwargs.hidden_dims: hidden_dims, TransformerKwargs.sequence_length: sequence_length, TransformerKwargs.sequence_q_dim: sequence_q_dim, + TransformerKwargs.micro_batch_size: micro_batch_size, LanguageModelKwargs.mask_inputs: not truncate_documents, } + common_kwargs.update(vision_kwargs) sequence_k_pasts = range( sequence_q_dim.size * self._tensor_space.distributed_config.sequence_data_rank, @@ -215,7 +291,11 @@ def preprocess_meta( reference_kwargs[name] = reference_kwargs_ kwargs["reference_models"] = reference_kwargs - preprocessed_meta.append((tokens, kwargs)) + if self._config.vision_encoder.enabled: + # patch_dimensions are (batch * sequence_length) x 3 x patch_size x patch_size + 
preprocessed_meta.append((kwargs[VisionEncoderKwargs.image_patches_meta], kwargs)) + else: + preprocessed_meta.append((tokens, kwargs)) return preprocessed_meta @@ -261,19 +341,20 @@ def preprocess( reference_model.forward(reference_tokens, reference_kwargs, iteration=iteration) reference_logits[i][f"{name}_logits"] = reference_kwargs["logits"] + token_ids = batch.token_ids if sequence_first: # Move the sequence dimension first to make sequence parallel ops more efficient. - batch.token_ids = batch.token_ids.transpose(0, 1).contiguous() + token_ids = token_ids.transpose(0, 1).contiguous() preprocessed = [] presents = None for i, (_, kwargs_meta) in enumerate(preprocessed_meta): sequence_k = kwargs_meta[TransformerKwargs.sequence_k_dim].size if sequence_first: - tokens = batch.token_ids[sequence_k - sequence_q : sequence_k] + tokens = token_ids[sequence_k - sequence_q : sequence_k] else: # TODO: Avoid multiple contiguous calls? - tokens = batch.token_ids[:, sequence_k - sequence_q : sequence_k].contiguous() + tokens = token_ids[:, sequence_k - sequence_q : sequence_k].contiguous() if batch.sequence_lengths is not None: kwargs_meta[TransformerKwargs.sequence_lengths] = batch.sequence_lengths if batch.chosen_spans is not None: @@ -293,16 +374,18 @@ def preprocess( if phase != PhaseType.inference: sequence_offset = sequence_k - sequence_q + 1 # +1 for shift in labels if sequence_first: - labels = batch.token_ids[sequence_offset : sequence_k + prediction_heads] + labels = token_ids[sequence_offset : sequence_k + prediction_heads] else: # TODO: Avoid multiple contiguous calls? - labels = batch.token_ids[:, sequence_offset : sequence_k + prediction_heads].contiguous() + labels = token_ids[:, sequence_offset : sequence_k + prediction_heads].contiguous() # We set label indices to -100 for masked spans, inline with ignore_index in torch.nn.CrossEntropyLoss # TODO: take ignore_index from config + labels_cloned = False if batch.loss_masking_spans is not None: # avoid changing input tokens labels = labels.clone() - for idx, spans in enumerate(batch.loss_masking_spans): + labels_cloned = True + for i, spans in enumerate(batch.loss_masking_spans): if not spans.numel(): continue valid_spans = spans[ @@ -313,31 +396,72 @@ def preprocess( valid_spans[:, 0].clamp_(min=sequence_offset) valid_spans[:, 1].clamp_(max=sequence_k + prediction_heads - 1) valid_spans -= sequence_offset - loss_mask = torch.ones_like(labels, dtype=torch.bool) for start, end in valid_spans: if sequence_first: - loss_mask[start : end + 1, idx] = False + labels[start : end + 1, i] = -100 else: - loss_mask[idx, start : end + 1] = False - if self._config.distillation_model is not None: - kwargs[LanguageModelKwargs.loss_mask] = loss_mask - labels = torch.where(loss_mask, labels, -100) + labels[i, start : end + 1] = -100 + if self._config.vision_encoder.enabled: + if self._config.vision_encoder.image_break_token is not None: + if not labels_cloned: + labels = labels.clone() + labels_cloned = True + labels = torch.where(labels == self._config.vision_encoder.image_break_token, -100, labels) + if self._config.vision_encoder.image_end_token is not None: + if not labels_cloned: + labels = labels.clone() + labels_cloned = True + labels = torch.where(labels == self._config.vision_encoder.image_end_token, -100, labels) + # Loss-masking for distillation losses + if self._config.distillation_model is not None: + loss_mask = torch.ones_like(labels, dtype=torch.bool) + loss_mask = torch.where(labels == -100, False, loss_mask) + 
kwargs[LanguageModelKwargs.loss_mask] = loss_mask kwargs[LanguageModelKwargs.labels] = labels kwargs.update(reference_logits[i]) + if self._config.vision_encoder.enabled: + batch_images = ( + batch.images if batch.images is not None else [[]] * kwargs[TransformerKwargs.micro_batch_size] + ) + kwargs[VisionEncoderKwargs.images] = [ + [ + img.to(device=self._tensor_space.distributed.device, dtype=torch.uint8, non_blocking=True) + for img in images + ] + for images in batch_images + ] + kwargs[VisionEncoderKwargs.image_positions] = ( + batch.image_positions + if batch.image_positions is not None + else [[]] * kwargs[TransformerKwargs.micro_batch_size] + ) + kwargs[LanguageModelKwargs.tokens] = tokens + for preprocessor in self._preprocessors: preprocessor.preprocess(tokens, kwargs) - preprocessed.append((tokens, kwargs)) + image_patches = kwargs.get(VisionEncoderKwargs.image_patches, None) + if image_patches is not None: + preprocessed.append((image_patches, kwargs)) + else: + preprocessed.append((tokens, kwargs)) return preprocessed @property def embedding(self) -> LanguageModelEmbedding: - return self.layers[0] + return self.layers[self.embedding_layer_index] @property - def transformer_layers(self) -> list[TransformerLayer]: - return self.layers[1:-1] + def transformer_layers(self) -> list[TransformerBlock]: + return self.layers[self.embedding_layer_index + 1 : -1] + + @property + def embedding_layer_index(self) -> int: + if self._config.vision_encoder.enabled: + return self._config.vision_encoder.transformer.num_layers + 2 + else: + return 0 @property def model_head(self) -> LanguageModelHead: @@ -352,7 +476,7 @@ def get_tied_weights(self) -> dict[str, tuple[ParameterMeta, tuple[int, ...]]]: return { WORD_EMBEDDINGS_WEIGHT: ( self.embedding.word_embeddings_weight, - (0, *self.model_head_indices), + (self.embedding_layer_index, *self.model_head_indices), ) } elif self._config.prediction_heads > 1: diff --git a/fast_llm/models/gpt/trainer.py b/fast_llm/models/gpt/trainer.py index 54508e8e1..b81a3767e 100644 --- a/fast_llm/models/gpt/trainer.py +++ b/fast_llm/models/gpt/trainer.py @@ -33,4 +33,13 @@ def _get_sampling_parameters( "extra_tokens": self._config.model.base_model.prediction_heads, } ) + if self._config.model.base_model.vision_encoder.enabled: + parameters.update( + { + "patch_size": self._config.model.base_model.vision_encoder.patch_size, + "max_image_size": self._config.batch.max_image_size, + "image_break_token": self._config.model.base_model.vision_encoder.image_break_token, + "image_end_token": self._config.model.base_model.vision_encoder.image_end_token, + } + ) return parameters if _return_dict else GPTSamplingParameters(**parameters) diff --git a/fast_llm/models/ssm/config.py b/fast_llm/models/ssm/config.py index cc83f11be..34f3151a6 100644 --- a/fast_llm/models/ssm/config.py +++ b/fast_llm/models/ssm/config.py @@ -4,34 +4,37 @@ from fast_llm.config import Field, FieldHint, FieldUpdate, config_class from fast_llm.data.data.gpt.config import GPTDataConfig -from fast_llm.engine.checkpoint.config import CheckpointFormat, CheckpointHandler +from fast_llm.engine.checkpoint.config import CheckpointHandler from fast_llm.engine.config_utils.runnable import RunnableConfig -from fast_llm.engine.config_utils.tensor_space import TensorDim, TensorSpace +from fast_llm.engine.config_utils.tensor_space import TensorSpace from fast_llm.engine.multi_stage.config import FastLLMModelConfig, PretrainedFastLLMModelConfig from fast_llm.engine.training.config import TrainerConfig -from 
fast_llm.layers.language_model.config import LanguageModelBaseConfig -from fast_llm.layers.ssm.config import SSMBlockType, SSMConfig, SSMDimNames -from fast_llm.models.gpt.config import GPTBatchConfig, PretrainedGPTModelConfig +from fast_llm.layers.ssm.config import SSMBlockType, SSMConfig +from fast_llm.models.gpt.config import ( + GPTBaseModelConfig, + GPTBatchConfig, + GPTHuggingfaceCheckpointFormat, + PretrainedGPTModelConfig, +) from fast_llm.utils import Assert if typing.TYPE_CHECKING: - from fast_llm.models.gpt.model import GPTInferenceRunner from fast_llm.models.ssm.huggingface import HuggingfaceHybridSSMModelForCausalLM - from fast_llm.models.ssm.model import HybridSSMModel + from fast_llm.models.ssm.model import HybridSSMInferenceRunner, HybridSSMModel from fast_llm.models.ssm.trainer import HybridSSMTrainer logger = logging.getLogger(__name__) @config_class() -class HybridSSMBaseModelConfig(LanguageModelBaseConfig): +class HybridSSMBaseModelConfig(GPTBaseModelConfig): _abstract = False ssm: SSMConfig = Field( desc="Configuration for the transformer architecture.", hint=FieldHint.architecture, ) - hybrid_block_layout: list[str] | None = Field( + hybrid_block_layout: list[SSMBlockType] | None = Field( default=None, desc=f"Pattern of blocks to use in the model. Available types: {SSMBlockType.__members__.values()}", hint=FieldHint.core, @@ -41,48 +44,16 @@ class HybridSSMBaseModelConfig(LanguageModelBaseConfig): desc="Multi-token prediction mixer to use in the model. If None, will use the last block type in `hybrid_block_layout`.", hint=FieldHint.optional, ) - use_megatron_initialization: bool = Field( - default=False, desc="Exactly match the initialization of a Megatron model.", hint=FieldHint.testing - ) # TODO: is this needed? + # TODO: Support combination of different SSM block types. + ssm_block_type: SSMBlockType | None = Field(init=False) def setup_tensor_space(self, tensor_space: TensorSpace) -> None: """ Setup the tensor space for the model. - Some of these can be setup directly in the layer config, but keeping them here for clarity. 
""" super().setup_tensor_space(tensor_space) - d_inner: int = self.ssm.d_inner - - # Hidden dimension - tensor_space.add_tensor_dim(TensorDim(SSMDimNames.model_dim, self.transformer.hidden_size)) - # Mamba-specific dimensions - tensor_space.add_tensor_dim(TensorDim(SSMDimNames.inner_dim, d_inner)) - tensor_space.add_tensor_dim(TensorDim(SSMDimNames.state_dim, self.ssm.state_size)) - tensor_space.add_tensor_dim(TensorDim(SSMDimNames.dt_rank, self.ssm.dt_rank)) - tensor_space.add_tensor_dim(TensorDim(SSMDimNames.x_proj_dim, self.ssm.dt_rank + self.ssm.state_size * 2)) - tensor_space.add_tensor_dim(TensorDim(SSMDimNames.conv_kernel_size, self.ssm.conv_kernel_dimension)) - tensor_space.add_tensor_dim(TensorDim(SSMDimNames.inner_proj_mamba, d_inner * 2)) - - if SSMBlockType.mamba2_discrete.value in self.hybrid_block_layout: - # Mamba2 specific dimensions - # as per https://github.com/cartesia-ai/edge/blob/a0e121ebed3d2324c6d762b0e211a08d62583681/cartesia-pytorch/cartesia_pytorch/Llamba/mixers/discrete_mamba2.py#L66C3-L66C4 - headdim = d_inner // self.ssm.n_v_heads - Assert.eq(self.ssm.n_v_heads, d_inner // headdim) - Assert.eq(d_inner % headdim, 0) - Assert.eq(self.ssm.n_v_heads % self.ssm.n_qk_heads, 0) - - conv_dim = d_inner + 2 * self.ssm.n_qk_heads * self.ssm.state_size - inner_proj_dim = 2 * d_inner + 2 * self.ssm.n_qk_heads * self.ssm.state_size + self.ssm.n_v_heads - - tensor_space.add_tensor_dim(TensorDim(SSMDimNames.head_dim, headdim)) - tensor_space.add_tensor_dim(TensorDim(SSMDimNames.qk_heads, self.ssm.n_qk_heads)) - tensor_space.add_tensor_dim(TensorDim(SSMDimNames.v_heads, self.ssm.n_v_heads)) - tensor_space.add_tensor_dim(TensorDim(SSMDimNames.inner_proj_discrete_mamba2, inner_proj_dim)) - tensor_space.add_tensor_dim(TensorDim(SSMDimNames.conv_dim, conv_dim)) - elif SSMBlockType.mamba2.value in self.hybrid_block_layout: - inner_proj_dim: int = 2 * self.ssm.d_xb + 2 * d_inner + self.ssm.dt_rank - tensor_space.add_tensor_dim(TensorDim(SSMDimNames.inner_proj_mamba2, inner_proj_dim)) - tensor_space.add_tensor_dim(TensorDim(SSMDimNames.x_proj_dim_2, self.ssm.d_xb)) + if self.ssm_block_type is not None: + self.ssm.setup_tensor_space(tensor_space, self.ssm_block_type) def _validate(self): with self._set_implicit_default(None): @@ -96,34 +67,24 @@ def _validate(self): if self.hybrid_block_layout is None: with self._set_implicit_default(): - self.hybrid_block_layout = [SSMBlockType.mamba2_discrete.value] + self.hybrid_block_layout = [SSMBlockType.mamba2_discrete] * self.transformer.num_layers if len(self.hybrid_block_layout) != self.transformer.num_layers: + message = f"hybrid_block_layout length {len(self.hybrid_block_layout)} does not match num_layers {self.transformer.num_layers}" if self.transformer.num_layers % len(self.hybrid_block_layout) != 0: - raise ValueError( - f"hybrid_block_layout length {len(self.hybrid_block_layout)} does not match num_layers {self.transformer.num_layers}" - ) - num_repeats = int(self.transformer.num_layers // len(self.hybrid_block_layout)) - logger.warning( - f"hybrid_block_layout length {len(self.hybrid_block_layout)} does not match num_layers {self.transformer.num_layers}, will repeat {self.hybrid_block_layout} {num_repeats} times" - ) + raise ValueError(message) + num_repeats = self.transformer.num_layers // len(self.hybrid_block_layout) + logger.warning(f"{message}, will repeat {self.hybrid_block_layout} {num_repeats} times.") self.hybrid_block_layout = self.hybrid_block_layout * num_repeats - Assert.eq(len(self.hybrid_block_layout), 
self.transformer.num_layers) - Assert.custom( - lambda _: all(block_type in SSMBlockType.__members__.values() for block_type in self.hybrid_block_layout), - f"Invalid block type: {self.hybrid_block_layout}. Must be one of {SSMBlockType.__members__.values()}", - ) - Assert.custom( - lambda _: self.default_mtp_type in SSMBlockType.__members__.values() or self.default_mtp_type is None, - f"Invalid MTP type: {self.default_mtp_type}. Must be one of {SSMBlockType.__members__.values()} or None", - ) - super()._validate() + ssm_block_types = set(self.hybrid_block_layout) - {SSMBlockType.transformer} + # TODO: Support combination of different SSM block types. + Assert.leq(len(ssm_block_types), 1) + self.ssm_block_type = ssm_block_types.pop() if ssm_block_types else None -class LLambaHuggingfaceCheckpointFormat(CheckpointFormat): - support_optimizer: typing.ClassVar[bool] = False +class LLambaHuggingfaceCheckpointFormat(GPTHuggingfaceCheckpointFormat): name: typing.ClassVar[str] = "llamba" @classmethod @@ -133,8 +94,7 @@ def get_handler_class(cls) -> type[CheckpointHandler]: return LLambaHuggingfaceCheckpointHandler -class AprielSSMHuggingfaceCheckpointFormat(CheckpointFormat): - support_optimizer: typing.ClassVar[bool] = False +class AprielSSMHuggingfaceCheckpointFormat(GPTHuggingfaceCheckpointFormat): name: typing.ClassVar[str] = "apriel_ssm" @classmethod @@ -144,8 +104,7 @@ def get_handler_class(cls) -> type[CheckpointHandler]: return AprielSSMHuggingfaceCheckpointHandler -class AprielSSMHHybridHuggingfaceCheckpointFormat(CheckpointFormat): - support_optimizer: typing.ClassVar[bool] = False +class AprielSSMHHybridHuggingfaceCheckpointFormat(GPTHuggingfaceCheckpointFormat): name: typing.ClassVar[str] = "apriel_ssm_hybrid" @classmethod @@ -155,9 +114,9 @@ def get_handler_class(cls) -> type[CheckpointHandler]: return AprielSSMHHybridHuggingfaceCheckpointHandler -class AprielThinkerSSMHHybridHuggingfaceCheckpointFormat(CheckpointFormat): - support_optimizer: typing.ClassVar[bool] = False +class AprielThinkerSSMHHybridHuggingfaceCheckpointFormat(GPTHuggingfaceCheckpointFormat): name: typing.ClassVar[str] = "apriel_ssm_thinker_hybrid" + trust_remote_code: typing.ClassVar[bool] = True @classmethod def get_handler_class(cls) -> type[CheckpointHandler]: @@ -166,6 +125,26 @@ def get_handler_class(cls) -> type[CheckpointHandler]: return AprielThinkerSSMHHybridHuggingfaceCheckpointHandler +# class LlavaGPTHuggingfaceCheckpointFormat(GPTHuggingfaceCheckpointFormat): +# name: typing.ClassVar[str] = "llava" +# # Using default values for vision and text models. 
Can be overridden in the config +# vision_name: typing.ClassVar[str] = "pixtral" +# text_name: typing.ClassVar[str] = "mistral" + + +class LlavaHybridHuggingfaceCheckpointFormat(GPTHuggingfaceCheckpointFormat): + name: typing.ClassVar[str] = "llava_hybrid" + vision_name: typing.ClassVar[str] = "pixtral" + text_name: typing.ClassVar[str] = "apriel_ssm_thinker_hybrid" + trust_remote_code: typing.ClassVar[bool] = True + + @classmethod + def get_handler_class(cls) -> type[CheckpointHandler]: + from fast_llm.models.ssm.conversion import LlavaHybridHuggingfaceCheckpointHandler + + return LlavaHybridHuggingfaceCheckpointHandler + + @config_class(dynamic_type={FastLLMModelConfig: "hybrid_ssm"}) class HybridSSMModelConfig(FastLLMModelConfig): _abstract = False @@ -176,6 +155,7 @@ class HybridSSMModelConfig(FastLLMModelConfig): AprielSSMHuggingfaceCheckpointFormat, AprielSSMHHybridHuggingfaceCheckpointFormat, AprielThinkerSSMHHybridHuggingfaceCheckpointFormat, + LlavaHybridHuggingfaceCheckpointFormat, ) @classmethod @@ -185,7 +165,17 @@ def get_model_class(cls) -> type["HybridSSMModel"]: return HybridSSMModel @classmethod - def get_huggingface_model_class(cls) -> type["HuggingfaceHybridSSMModelForCausalLM"]: + def get_inference_runner_class(cls) -> type["HybridSSMInferenceRunner"]: + from fast_llm.models.ssm.model import HybridSSMInferenceRunner + + logger.warning( + "HybridSSMInferenceRunner only supports training-style forward pass. Use generate with cache disabled." + ) + + return HybridSSMInferenceRunner + + @classmethod + def get_huggingface_model_for_causal_lm_class(cls) -> type["HuggingfaceHybridSSMModelForCausalLM"]: from fast_llm.models.ssm.huggingface import HuggingfaceHybridSSMModelForCausalLM return HuggingfaceHybridSSMModelForCausalLM @@ -194,12 +184,6 @@ def _validate(self): logger.warning( "HybridSSMModelConfig is being instantiated. This model is experimental and may not work as expected." ) - if ( - self.base_model.sequence_first - or self.distributed.sequence_data_parallel > 1 - or self.distributed.sequence_tensor_parallel - ): - raise NotImplementedError(f"Sequence-first not supported for SSMs.") super()._validate() @@ -223,6 +207,11 @@ def get_trainer_class(cls) -> type["HybridSSMTrainer"]: def _validate(self) -> None: super()._validate() + Assert.eq( + self.batch.micro_sequence_length, + self.batch.sequence_length, + msg="Micro-sequences are not supported for SSMs at this point.", + ) if (name := self.model.base_model.distillation_model) is None: Assert.empty(self.reference_models) else: @@ -238,14 +227,3 @@ def _validate(self) -> None: Assert.none(reference_model.model.base_model.cross_entropy_splits) Assert.eq(reference_model.model.base_model.parallel_embeddings, self.model.base_model.parallel_embeddings) Assert.geq(reference_model.model.base_model.prediction_heads, self.model.base_model.prediction_heads) - - @classmethod - def get_inference_runner_class(cls) -> type["GPTInferenceRunner"]: - from fast_llm.models.gpt.model import GPTInferenceRunner - - # TODO: we dont have inference runner for SSM/Hybrid yet, should return None?
- logger.warning( - "No inference runner for SSM/Hybrid yet, using GPTInferenceRunner for now, which does not support SSM/Hybrid" - ) - - return GPTInferenceRunner diff --git a/fast_llm/models/ssm/conversion.py b/fast_llm/models/ssm/conversion.py index d57300252..64afbea06 100644 --- a/fast_llm/models/ssm/conversion.py +++ b/fast_llm/models/ssm/conversion.py @@ -3,10 +3,14 @@ import pathlib import typing +from transformers.configuration_utils import PretrainedConfig + +from fast_llm.config import MISSING from fast_llm.engine.checkpoint.config import CheckpointFormat from fast_llm.engine.checkpoint.external import ( ConstantExportParamConverter, ConstantImportParamConverter, + ExternalStateDictCheckpointHandler, IgnoreImportParamConverter, IgnoreImportWeightConverter, MappedConfigParamConverter, @@ -15,19 +19,29 @@ SplitWeightConverter, WeightConverter, ) -from fast_llm.engine.checkpoint.huggingface import HuggingfaceStateDictCheckpointHandler +from fast_llm.engine.checkpoint.huggingface import CustomModelingExportMixin, HuggingfaceStateDictCheckpointHandler from fast_llm.engine.multi_stage.config import FastLLMModelConfig from fast_llm.functional.config import ActivationType from fast_llm.layers.common.config import RMSNormalizationConfig -from fast_llm.layers.ssm.config import SSMBlockType -from fast_llm.models.gpt.conversion import CommonLlamaHuggingfaceCheckpointHandler, MLPLayer2Converter +from fast_llm.layers.ssm.config import DTInitType, SSMBlockType +from fast_llm.models.gpt.conversion import ( + CommonLlamaHuggingfaceCheckpointHandler, + LlavaHuggingfaceCheckpointHandler, + MLPLayer2Converter, +) from fast_llm.models.ssm.config import ( AprielSSMHHybridHuggingfaceCheckpointFormat, AprielSSMHuggingfaceCheckpointFormat, AprielThinkerSSMHHybridHuggingfaceCheckpointFormat, HybridSSMModelConfig, LLambaHuggingfaceCheckpointFormat, + LlavaHybridHuggingfaceCheckpointFormat, +) +from fast_llm.models.ssm.external.apriel_15b_hybrid import ( + configuration_ssm_hybrid_apriel15b, + modeling_ssm_hybrid_apriel15b, ) +from fast_llm.models.ssm.external.llava_hybrid import configuration_llava_hybrid, modeling_llava_hybrid from fast_llm.models.ssm.model import HybridSSMModel from fast_llm.utils import Assert @@ -42,11 +56,11 @@ class HybridModelCheckpointHandler(HuggingfaceStateDictCheckpointHandler): @classmethod def _create_config_converters(cls) -> list[ParamConverter]: - block_converter = RenameParamConverter( + block_converter = MappedConfigParamConverter( fast_llm_names=(("hybrid_block_layout",),), export_names=(("hybrid_block_layout",),), - ignore_missing=True, - default_value=[cls._default_block_type], + fast_llm_value=lambda x: [cls._default_block_type] if x == MISSING else x, + export_value=lambda x: [x_.value for x_ in x], ) return super()._create_config_converters() + [block_converter] @@ -202,7 +216,7 @@ def _create_config_converters(cls) -> list[ParamConverter]: ignore_missing=True, default_value=4, ), - RenameParamConverter( + MappedConfigParamConverter( fast_llm_names=(("ssm", "dt_init"),), export_names=( ( @@ -210,13 +224,23 @@ def _create_config_converters(cls) -> list[ParamConverter]: "dt_init", ), ), - ignore_missing=True, - default_value="random", + fast_llm_value=lambda x: DTInitType.random if x == MISSING else DTInitType(x), + export_value=lambda x: x.value, ), ] - def _create_weight_converters(self) -> list[WeightConverter]: - converters = super()._create_weight_converters() or [] + def _create_weight_converters( + self, + hf_base_prefix: str = "", + offset: int = 0, + ) -> 
list[WeightConverter]: + converters = ( + super()._create_weight_converters( + hf_base_prefix=hf_base_prefix, + offset=offset, + ) + or [] + ) num_layers = self._model.config.base_model.transformer.num_layers ssm_bias: bool = self._model.config.base_model.ssm.add_bias_linear @@ -224,55 +248,68 @@ def _create_weight_converters(self) -> list[WeightConverter]: for i in range(num_layers): # SSM converters += self._get_weight_and_bias_converters( - f"layers.{i+1}.mixer.in_proj", f"model.layers.{i}.mixer.in_proj", ssm_bias + f"layers.{offset+i+1}.mixer.in_proj", f"{hf_base_prefix}model.layers.{i}.mixer.in_proj", ssm_bias ) converters += self._get_weight_and_bias_converters( - f"layers.{i+1}.mixer.out_proj", f"model.layers.{i}.mixer.out_proj", ssm_bias + f"layers.{offset+i+1}.mixer.out_proj", f"{hf_base_prefix}model.layers.{i}.mixer.out_proj", ssm_bias ) converters.append( - WeightConverter(f"layers.{i+1}.mixer.D", f"model.layers.{i}.mixer.D", self._model.config.base_model) + WeightConverter( + f"layers.{offset+i+1}.mixer.D", + f"{hf_base_prefix}model.layers.{i}.mixer.D", + self._model.config.base_model, + ) ) converters.append( WeightConverter( - f"layers.{i+1}.mixer.z_bias", f"model.layers.{i}.mixer.z_bias", self._model.config.base_model + f"layers.{offset+i+1}.mixer.z_bias", + f"{hf_base_prefix}model.layers.{i}.mixer.z_bias", + self._model.config.base_model, ) ) converters.append( WeightConverter( - f"layers.{i+1}.mixer.z_bias", f"model.layers.{i}.mixer.z_bias", self._model.config.base_model + f"layers.{offset+i+1}.mixer.z_bias", + f"{hf_base_prefix}model.layers.{i}.mixer.z_bias", + self._model.config.base_model, ) ) converters.append( WeightConverter( - f"layers.{i+1}.mixer.conv1d_weight", - f"model.layers.{i}.mixer.conv1d.weight", + f"layers.{offset+i+1}.mixer.conv1d_weight", + f"{hf_base_prefix}model.layers.{i}.mixer.conv1d.weight", self._model.config.base_model, ) ) converters.append( WeightConverter( - f"layers.{i+1}.mixer.conv1d_bias", - f"model.layers.{i}.mixer.conv1d.bias", + f"layers.{offset+i+1}.mixer.conv1d_bias", + f"{hf_base_prefix}model.layers.{i}.mixer.conv1d.bias", self._model.config.base_model, ) ) # ================================================ # Mamba2 specific parameters converters += self._get_weight_and_bias_converters( - f"layers.{i+1}.mixer.dt_proj", f"model.layers.{i}.mixer.dt_proj", False + f"layers.{offset+i+1}.mixer.dt_in_proj", f"{hf_base_prefix}model.layers.{i}.mixer.dt_in_proj", ssm_bias + ) + converters += self._get_weight_and_bias_converters( + f"layers.{offset+i+1}.mixer.dt_proj", f"{hf_base_prefix}model.layers.{i}.mixer.dt_proj", False ) # bias is treated separately in Mamba2 and must always exist (https://github.com/jxiw/M1/blob/537a1ca5407a786a99dc6c721873493cf8750d5e/mamba/hybrid_mamba_layer.py) converters.append( WeightConverter( - f"layers.{i+1}.mixer.dt_proj_bias", - f"model.layers.{i}.mixer.dt_proj.bias", + f"layers.{offset+i+1}.mixer.dt_proj_bias", + f"{hf_base_prefix}model.layers.{i}.mixer.dt_proj.bias", self._model.config.base_model, ) ) converters.append( WeightConverter( - f"layers.{i+1}.mixer.A_log", f"model.layers.{i}.mixer.A_log", self._model.config.base_model + f"layers.{offset+i+1}.mixer.A_log", + f"{hf_base_prefix}model.layers.{i}.mixer.A_log", + self._model.config.base_model, ) ) @@ -566,11 +603,16 @@ def _create_config_converters(cls) -> list[ParamConverter]: ), ] - def _create_weight_converters(self) -> list[WeightConverter]: - converters = super()._create_weight_converters() + def _create_weight_converters( + self, + hf_base_prefix: 
str = "", + offset: int = 0, + ) -> list[WeightConverter]: + converters = super()._create_weight_converters(hf_base_prefix, offset) num_layers = self._model.config.base_model.transformer.num_layers norm_bias: bool = False + # TODO: use hf_base_prefix and offset # Embedding and output if self._model.config.base_model.tie_word_embeddings: converters.append(WeightConverter("layers.0.word_embeddings_weight", "model.embed_tokens.weight")) @@ -689,6 +731,7 @@ def _save_config(cls, directory: pathlib.Path | str, config: dict[str, typing.An class AprielThinkerSSMHHybridHuggingfaceCheckpointHandler( + CustomModelingExportMixin, HybridModelCheckpointHandler, # handles the block structure parameter CommonSSMHuggingfaceCheckpointHandler, # handles the SSM layers CommonLlamaHuggingfaceCheckpointHandler, # handles the LLama layers @@ -703,9 +746,18 @@ class AprielThinkerSSMHHybridHuggingfaceCheckpointHandler( _default_block_type: str = SSMBlockType.mamba2_discrete.value _hf_prefix: str = "model" architecture: typing.ClassVar[str] = "AprielThinkerSSMHybridForCausalLM" + modeling_file = modeling_ssm_hybrid_apriel15b.__file__ + configuration_file = configuration_ssm_hybrid_apriel15b.__file__ + configuration_cls: typing.ClassVar[type[PretrainedConfig]] = ( + configuration_ssm_hybrid_apriel15b.AprielSSMHybridConfig + ) - def _create_weight_converters(self) -> list[WeightConverter]: - converters = super()._create_weight_converters() + def _create_weight_converters( + self, + hf_base_prefix: str = "", + offset: int = 0, + ) -> list[WeightConverter]: + converters = super()._create_weight_converters(hf_base_prefix, offset) # num_layers = self._model.config.base_model.transformer.num_layers # # Embedding and output # if self._model.config.base_model.tie_word_embeddings: @@ -725,6 +777,14 @@ def _create_weight_converters(self) -> list[WeightConverter]: @classmethod def _create_config_converters(cls) -> list[ParamConverter]: return super()._create_config_converters() + [ + ConstantExportParamConverter( + export_names=(("auto_map",),), + export_value={ + "AutoConfig": "configuration_ssm_hybrid_apriel15b.AprielSSMHybridConfig", + "AutoModel": "modeling_ssm_hybrid_apriel15b.AprielThinkerSSMHybridModel", + "AutoModelForCausalLM": "modeling_ssm_hybrid_apriel15b.AprielThinkerSSMHybridForCausalLM", + }, + ), RenameParamConverter( fast_llm_names=(("ssm", "d_inner"),), export_names=(("ssm_cfg", "d_inner"),), @@ -749,16 +809,49 @@ def _get_mlp_converters(self, fast_llm_prefix: str, hf_prefix: str) -> list[Weig ), ] + # @classmethod + # def _load_config(cls, directory: pathlib.Path | str) -> dict: + # if not os.path.exists(directory / "config.json"): + # raise FileNotFoundError(f"config.json not found in {directory}") + # with open(directory / "config.json") as f: + # config = json.load(f) + # Assert.eq(config["model_type"], cls.get_huggingface_model_type()) + # return config + + # @classmethod + # def _save_config(cls, directory: pathlib.Path | str, config: dict[str, typing.Any]) -> None: + # with open(directory / "config.json", "w") as f: + # json.dump(config, f) + + +class LlavaHybridHuggingfaceCheckpointHandler(CustomModelingExportMixin, LlavaHuggingfaceCheckpointHandler): + format: typing.ClassVar[type[CheckpointFormat]] = LlavaHybridHuggingfaceCheckpointFormat + architecture: typing.ClassVar[str] = "LlavaHybridForConditionalGeneration" + modeling_file = modeling_llava_hybrid.__file__ + configuration_file = configuration_llava_hybrid.__file__ + configuration_cls: typing.ClassVar[type[PretrainedConfig]] = 
configuration_llava_hybrid.LlavaHybridConfig + _model_class: typing.ClassVar[FastLLMModelConfig] = HybridSSMModelConfig + additional_files = [ + modeling_ssm_hybrid_apriel15b.__file__, + configuration_ssm_hybrid_apriel15b.__file__, + ] + @classmethod - def _load_config(cls, directory: pathlib.Path | str) -> dict: - if not os.path.exists(directory / "config.json"): - raise FileNotFoundError(f"config.json not found in {directory}") - with open(directory / "config.json") as f: - config = json.load(f) - Assert.eq(config["model_type"], cls.get_huggingface_model_type()) - return config + def get_text_handler_class(cls) -> type[ExternalStateDictCheckpointHandler]: + from fast_llm.models.ssm.conversion import AprielThinkerSSMHHybridHuggingfaceCheckpointHandler + + return AprielThinkerSSMHHybridHuggingfaceCheckpointHandler @classmethod - def _save_config(cls, directory: pathlib.Path | str, config: dict[str, typing.Any]) -> None: - with open(directory / "config.json", "w") as f: - json.dump(config, f) + def _create_config_converters(cls) -> list[ParamConverter]: + return super()._create_config_converters() + [ + ConstantExportParamConverter( + export_names=(("auto_map",),), + export_value={ + "AutoConfig": "configuration_llava_hybrid.LlavaHybridConfig", + "AutoModel": "modeling_llava_hybrid.LlavaHybridModel", + "AutoModelForVision2Seq": "modeling_llava_hybrid.LlavaHybridForConditionalGeneration", + "AutoModelForCausalLM": "modeling_llava_hybrid.LlavaHybridForConditionalGeneration", + }, + ), + ] diff --git a/fast_llm/models/ssm/external/15B_hybrid.ipynb b/fast_llm/models/ssm/external/15B_hybrid.ipynb new file mode 100644 index 000000000..a8f0c33b7 --- /dev/null +++ b/fast_llm/models/ssm/external/15B_hybrid.ipynb @@ -0,0 +1,1562 @@ +{ + "cells": [ + { + "cell_type": "code", + "execution_count": 1, + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "/home/toolkit/.local/lib/python3.12/site-packages/tqdm/auto.py:21: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. 
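The auto_map entries exported by these handlers are what let the resulting checkpoints be loaded back through the Hugging Face Auto classes with trust_remote_code, which is also how the notebooks below load them. A minimal loading sketch (the export path is a placeholder, not a real checkpoint):

from transformers import AutoConfig, AutoModelForCausalLM

export_path = "/path/to/exported/apriel_ssm_thinker_hybrid"  # placeholder path
config = AutoConfig.from_pretrained(export_path, trust_remote_code=True)
model = AutoModelForCausalLM.from_pretrained(export_path, config=config, trust_remote_code=True)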
See https://ipywidgets.readthedocs.io/en/stable/user_install.html\n", + " from .autonotebook import tqdm as notebook_tqdm\n" + ] + } + ], + "source": [ + "from transformers import AutoConfig, AutoModelForCausalLM\n", + "# from transformers import MistralForCausalLM\n", + "# from fast_llm.models.ssm.external.apriel_15b_hybrid.configuration_ssm_hybrid_apriel15b import AprielSSMHybridConfig\n", + "# from fast_llm.models.ssm.external.apriel_15b_hybrid.modeling_ssm_hybrid_apriel15b import AprielSSMHybridForCausalLM\n", + "# autoreload changes to the code\n", + "%reload_ext autoreload\n", + "%autoreload 2\n" + ] + }, + { + "cell_type": "code", + "execution_count": 2, + "metadata": {}, + "outputs": [], + "source": [ + "# model_path = \"/mnt/checkpoints/fast_llm_exp/slam_ssm_distill/15bch-ifrhyb20l32h-bs768-lr0.0003-lrs0-0-0-0-sl4096_ti1000_lm2/export/apriel_ssm_thinker_hybrid/1000\"\n", + "# AutoConfig.from_pretrained(model_path, trust_remote_code=True)" + ] + }, + { + "cell_type": "code", + "execution_count": 3, + "metadata": {}, + "outputs": [], + "source": [ + "# model_path = \"/mnt/checkpoints/fast_llm_exp/slam_ssm_distill/15bch-ifrhyb20l32h-bs768-lr0.0003-lrs0-0-0-0-sl4096_ti1000_lm2/export/apriel_ssm_thinker_hybrid/1000\"\n", + "# m = AutoModelForCausalLM.from_pretrained(\n", + "# model_path, trust_remote_code=True,\n", + "# config=AutoConfig.from_pretrained(model_path, trust_remote_code=True),\n", + "# )" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Slam 15B upcycled" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + " Lead the weights of https://huggingface.co/ServiceNow-AI/Slam-15B-Upcycled/ into Thiked modeling, it shoudl work" + ] + }, + { + "cell_type": "code", + "execution_count": 2, + "metadata": {}, + "outputs": [], + "source": [ + "import sys\n", + "sys.path.append(\"/home/toolkit/dev/fml-ops/__oo_playground\")\n", + "from results_analysis.results_loader import ResultsLoader\n", + "layer_importance_path = \"/mnt/evaluations/training_evaluation/model_runs/lm_eval_runner/apriel_ssm_importance/\"\n", + "results_loader = ResultsLoader(layer_importance_path)\n", + "\n", + "results_loader.deserialize_results()\n", + "results_df = results_loader.to_df()\n", + "results_df[\"layer_index\"] = results_df.apply(lambda row: int(row[\"model_name_sanitized\"].split(\"_\")[-1] if \"layers_\" in row[\"model_name_sanitized\"] else -1), axis=1)\n", + "results_df = results_df[results_df[\"metric\"] == \"acc_norm\"]\n", + "columns_to_keep = [\"layer_index\", \"metric_value\"]\n", + "results_df = results_df[columns_to_keep]\n", + "layer_importance = results_df.groupby(\"layer_index\").mean()\n", + "layer_importance = layer_importance.sort_values(by=\"metric_value\", ascending=False).reset_index()\n", + "layer_importance = layer_importance[layer_importance[\"layer_index\"]!= -1]\n", + "layer_importance = list(layer_importance[\"layer_index\"])" + ] + }, + { + "cell_type": "code", + "execution_count": 3, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "[22,\n", + " 25,\n", + " 20,\n", + " 31,\n", + " 29,\n", + " 46,\n", + " 23,\n", + " 26,\n", + " 33,\n", + " 24,\n", + " 47,\n", + " 27,\n", + " 21,\n", + " 41,\n", + " 17,\n", + " 18,\n", + " 34,\n", + " 42,\n", + " 44,\n", + " 30,\n", + " 16,\n", + " 8,\n", + " 43,\n", + " 35,\n", + " 19,\n", + " 38,\n", + " 15,\n", + " 28,\n", + " 32,\n", + " 45,\n", + " 37,\n", + " 
40,\n", + " 7,\n", + " 36,\n", + " 13,\n", + " 10,\n", + " 5,\n", + " 39,\n", + " 6,\n", + " 14,\n", + " 4,\n", + " 12,\n", + " 9,\n", + " 48,\n", + " 1,\n", + " 3,\n", + " 11,\n", + " 49,\n", + " 0]" + ] + }, + "execution_count": 3, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "layer_importance" + ] + }, + { + "cell_type": "code", + "execution_count": 4, + "metadata": {}, + "outputs": [], + "source": [ + "path_thinker = \"/mnt/checkpoints/upstream/Apriel-Nemotron-15b-Thinker\"\n", + "n_ssm = 25\n", + "\n", + "config_thinker = AutoConfig.from_pretrained(path_thinker)\n", + "hybrid_block_layout = [\"t\"] * config_thinker.num_hidden_layers\n", + "\n", + "for i in range(n_ssm):\n", + " hybrid_block_layout[layer_importance[i]] = \"m2d\"\n", + "\n", + "config_hybrid = AprielSSMHybridConfig(\n", + " **config_thinker.to_dict(),\n", + " hybrid_block_layout=hybrid_block_layout,\n", + " ssm_cfg = {\n", + " \"d_state\": 64,\n", + " \"n_v_heads\": 32,\n", + " \"n_qk_heads\": 32,\n", + " \"expand\": 1,\n", + " \"chunk_size\": 128,\n", + " \"activation\": \"identity\",\n", + " \"bias\": False,\n", + " \"d_conv\": 4,\n", + " \"d_inner\": 32 * 128\n", + " }\n", + ")\n", + "model_hybrid = AprielSSMHybridForCausalLM(config_hybrid)\n" + ] + }, + { + "cell_type": "code", + "execution_count": 5, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "['t',\n", + " 't',\n", + " 't',\n", + " 't',\n", + " 't',\n", + " 't',\n", + " 't',\n", + " 't',\n", + " 'm2d',\n", + " 't',\n", + " 't',\n", + " 't',\n", + " 't',\n", + " 't',\n", + " 't',\n", + " 't',\n", + " 'm2d',\n", + " 'm2d',\n", + " 'm2d',\n", + " 'm2d',\n", + " 'm2d',\n", + " 'm2d',\n", + " 'm2d',\n", + " 'm2d',\n", + " 'm2d',\n", + " 'm2d',\n", + " 'm2d',\n", + " 'm2d',\n", + " 't',\n", + " 'm2d',\n", + " 'm2d',\n", + " 'm2d',\n", + " 't',\n", + " 'm2d',\n", + " 'm2d',\n", + " 'm2d',\n", + " 't',\n", + " 't',\n", + " 't',\n", + " 't',\n", + " 't',\n", + " 'm2d',\n", + " 'm2d',\n", + " 'm2d',\n", + " 'm2d',\n", + " 't',\n", + " 'm2d',\n", + " 'm2d',\n", + " 't',\n", + " 't']" + ] + }, + "execution_count": 5, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "hybrid_block_layout" + ] + }, + { + "cell_type": "code", + "execution_count": 6, + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "You are using a model of type llama to instantiate a model of type mistral. 
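The cell above builds the hybrid layout by switching the n_ssm least-important layers to "m2d". A toy version of the same selection, with an invented 6-layer ranking, makes the indexing explicit:

# Toy example (invented ranking, least important first): the first n_ssm entries
# of the importance list become discrete-Mamba2 ("m2d") blocks, the rest stay "t".
layer_importance_toy = [3, 5, 1, 0, 2, 4]
n_ssm_toy = 2
layout = ["t"] * 6
for i in range(n_ssm_toy):
    layout[layer_importance_toy[i]] = "m2d"
print(layout)  # ['t', 't', 't', 'm2d', 't', 'm2d']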
This is not supported for all configurations of models and can yield errors.\n", + "Loading checkpoint shards: 0%| | 0/4 [00:00 v, B -> k, C -> q\n", + " mamba_encoder.mixer.in_proj.weight.data[\n", + " mamba_config.ssm_cfg[\"d_inner\"] : mamba_config.ssm_cfg[\"d_inner\"] + mamba_config.ssm_cfg[\"d_xb\"], :\n", + " ].copy_(layer_module.self_attn.v_proj.weight.data)\n", + " mamba_encoder.mixer.in_proj.weight.data[\n", + " mamba_config.ssm_cfg[\"d_inner\"] + mamba_config.ssm_cfg[\"d_xb\"] : mamba_config.ssm_cfg[\"d_inner\"] + 2 * mamba_config.ssm_cfg[\"d_xb\"], :\n", + " ].copy_(layer_module.self_attn.k_proj.weight.data)\n", + " mamba_encoder.mixer.in_proj.weight.data[\n", + " mamba_config.ssm_cfg[\"d_inner\"] + 2 * mamba_config.ssm_cfg[\"d_xb\"] : 2 * mamba_config.ssm_cfg[\"d_inner\"] + 2 * mamba_config.ssm_cfg[\"d_xb\"], :\n", + " ].copy_(layer_module.self_attn.q_proj.weight.data)\n", + "\n", + " print(\"Init Mamba using Attention\")\n", + "\n", + " transformer.model.layers[layer_idx] = mamba_encoder\n", + "\n", + " # elif type == \"m2d\":\n", + " # print(\"Converting layer %d...\" % layer_idx)\n", + " # mamba_encoder = AprielSSMDecoderLayer(\n", + " # mamba_config,\n", + " # layer_idx,\n", + " # device=\"cpu\",\n", + " # dtype=torch_dtype,\n", + " # )\n", + " # mamba_encoder.mlp.load_state_dict(layer_module.mlp.state_dict())\n", + " # mamba_encoder.input_layernorm.load_state_dict(layer_module.input_layernorm.state_dict())\n", + " # mamba_encoder.post_attention_layernorm.load_state_dict(layer_module.post_attention_layernorm.state_dict())\n", + " # mamba_encoder.mixer.out_proj.load_state_dict(layer_module.self_attn.o_proj.state_dict())\n", + "\n", + " # if init_with_kqvo:\n", + " \n", + "\n", + "\n", + " \n", + " else:\n", + " raise ValueError(f\"Invalid layer type: {type}\")" + ] + }, + { + "cell_type": "code", + "execution_count": 12, + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "Loading checkpoint shards: 100%|██████████| 7/7 [00:05<00:00, 1.22it/s]\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Converting layer %d... 0\n", + "Skipping transformer layer 0...\n", + "Converting layer %d... 1\n", + "Skipping transformer layer 1...\n", + "Converting layer %d... 2\n", + "Skipping transformer layer 2...\n", + "Converting layer %d... 3\n", + "Skipping transformer layer 3...\n", + "Converting layer %d... 4\n", + "Skipping transformer layer 4...\n", + "Converting layer %d... 5\n", + "Skipping transformer layer 5...\n", + "Converting layer %d... 6\n", + "Skipping transformer layer 6...\n", + "Converting layer %d... 7\n", + "Skipping transformer layer 7...\n", + "Converting layer %d... 8\n", + "Skipping transformer layer 8...\n", + "Converting layer %d... 9\n", + "Skipping transformer layer 9...\n", + "Converting layer %d... 10\n", + "Skipping transformer layer 10...\n", + "Converting layer %d... 11\n", + "Skipping transformer layer 11...\n", + "Converting layer %d... 12\n", + "Skipping transformer layer 12...\n", + "Converting layer %d... 13\n", + "Skipping transformer layer 13...\n", + "Converting layer %d... 14\n", + "Skipping transformer layer 14...\n", + "Converting layer %d... 15\n", + "Skipping transformer layer 15...\n", + "Converting layer %d... 16\n", + "Skipping transformer layer 16...\n", + "Converting layer %d... 17\n", + "Skipping transformer layer 17...\n", + "Converting layer %d... 18\n", + "Skipping transformer layer 18...\n", + "Converting layer %d... 
19\n", + "Skipping transformer layer 19...\n", + "Converting layer %d... 20\n", + "Converting layer 20...\n", + "Init Mamba using Attention\n", + "Converting layer %d... 21\n", + "Converting layer 21...\n", + "Init Mamba using Attention\n", + "Converting layer %d... 22\n", + "Converting layer 22...\n", + "Init Mamba using Attention\n", + "Converting layer %d... 23\n", + "Converting layer 23...\n", + "Init Mamba using Attention\n", + "Converting layer %d... 24\n", + "Converting layer 24...\n", + "Init Mamba using Attention\n", + "Converting layer %d... 25\n", + "Converting layer 25...\n", + "Init Mamba using Attention\n", + "Converting layer %d... 26\n", + "Skipping transformer layer 26...\n", + "Converting layer %d... 27\n", + "Converting layer 27...\n", + "Init Mamba using Attention\n", + "Converting layer %d... 28\n", + "Skipping transformer layer 28...\n", + "Converting layer %d... 29\n", + "Skipping transformer layer 29...\n", + "Converting layer %d... 30\n", + "Converting layer 30...\n", + "Init Mamba using Attention\n", + "Converting layer %d... 31\n", + "Converting layer 31...\n", + "Init Mamba using Attention\n", + "Converting layer %d... 32\n", + "Converting layer 32...\n", + "Init Mamba using Attention\n", + "Converting layer %d... 33\n", + "Converting layer 33...\n", + "Init Mamba using Attention\n", + "Converting layer %d... 34\n", + "Skipping transformer layer 34...\n", + "Converting layer %d... 35\n", + "Converting layer 35...\n", + "Init Mamba using Attention\n", + "Converting layer %d... 36\n", + "Converting layer 36...\n", + "Init Mamba using Attention\n", + "Converting layer %d... 37\n", + "Converting layer 37...\n", + "Init Mamba using Attention\n", + "Converting layer %d... 38\n", + "Converting layer 38...\n", + "Init Mamba using Attention\n", + "Converting layer %d... 39\n", + "Converting layer 39...\n", + "Init Mamba using Attention\n", + "Converting layer %d... 40\n", + "Converting layer 40...\n", + "Init Mamba using Attention\n", + "Converting layer %d... 41\n", + "Converting layer 41...\n", + "Init Mamba using Attention\n", + "Converting layer %d... 42\n", + "Converting layer 42...\n", + "Init Mamba using Attention\n", + "Converting layer %d... 43\n", + "Converting layer 43...\n", + "Init Mamba using Attention\n", + "Converting layer %d... 44\n", + "Converting layer 44...\n", + "Init Mamba using Attention\n", + "Converting layer %d... 45\n", + "Converting layer 45...\n", + "Init Mamba using Attention\n", + "Converting layer %d... 46\n", + "Converting layer 46...\n", + "Init Mamba using Attention\n", + "Converting layer %d... 47\n", + "Converting layer 47...\n", + "Init Mamba using Attention\n", + "Converting layer %d... 48\n", + "Skipping transformer layer 48...\n", + "Converting layer %d... 
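The slice arithmetic used by convert_layers above is the error-prone part of the MIL-style initialization, so here is a self-contained schematic with toy sizes (the dimensions are illustrative only): in_proj packs [z, x, B, C, dt] along its output dimension, and the attention projections are copied in as v -> x, k -> B, q -> C.

import torch

d_model, d_inner, d_xb, dt_rank = 8, 8, 4, 2
in_proj_weight = torch.zeros(2 * d_inner + 2 * d_xb + dt_rank, d_model)  # rows: [z | x | B | C | dt]

v_proj = torch.randn(d_xb, d_model)     # toy stand-ins for self_attn.{v,k,q}_proj.weight
k_proj = torch.randn(d_xb, d_model)
q_proj = torch.randn(d_inner, d_model)

in_proj_weight[d_inner : d_inner + d_xb].copy_(v_proj)                      # x slice <- v_proj
in_proj_weight[d_inner + d_xb : d_inner + 2 * d_xb].copy_(k_proj)           # B slice <- k_proj
in_proj_weight[d_inner + 2 * d_xb : 2 * d_inner + 2 * d_xb].copy_(q_proj)   # C slice <- q_proj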
49\n", + "Converting layer 49...\n", + "Init Mamba using Attention\n" + ] + } + ], + "source": [ + "transformer = AutoModelForCausalLM.from_pretrained(path_thinker)\n", + "init_with_kqvo = True\n", + "torch_dtype = torch.bfloat16\n", + "attn_bias = True\n", + "convert_layers(transformer, config_hybrid, hybrid_block_layout, init_with_kqvo, attn_bias, torch_dtype)" + ] + }, + { + "cell_type": "code", + "execution_count": 13, + "metadata": {}, + "outputs": [], + "source": [ + "transformer.config = config_hybrid" + ] + }, + { + "cell_type": "code", + "execution_count": 14, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "AprielSSMHybridConfig {\n", + " \"architectures\": [\n", + " \"MistralForCausalLM\"\n", + " ],\n", + " \"attention_dropout\": 0.0,\n", + " \"bos_token_id\": 1,\n", + " \"eos_token_id\": 2,\n", + " \"head_dim\": 128,\n", + " \"hidden_act\": \"silu\",\n", + " \"hidden_size\": 5120,\n", + " \"hybrid_block_layout\": [\n", + " \"t\",\n", + " \"t\",\n", + " \"t\",\n", + " \"t\",\n", + " \"t\",\n", + " \"t\",\n", + " \"t\",\n", + " \"t\",\n", + " \"t\",\n", + " \"t\",\n", + " \"t\",\n", + " \"t\",\n", + " \"t\",\n", + " \"t\",\n", + " \"t\",\n", + " \"t\",\n", + " \"t\",\n", + " \"t\",\n", + " \"t\",\n", + " \"t\",\n", + " \"m2\",\n", + " \"m2\",\n", + " \"m2\",\n", + " \"m2\",\n", + " \"m2\",\n", + " \"m2\",\n", + " \"t\",\n", + " \"m2\",\n", + " \"t\",\n", + " \"t\",\n", + " \"m2\",\n", + " \"m2\",\n", + " \"m2\",\n", + " \"m2\",\n", + " \"t\",\n", + " \"m2\",\n", + " \"m2\",\n", + " \"m2\",\n", + " \"m2\",\n", + " \"m2\",\n", + " \"m2\",\n", + " \"m2\",\n", + " \"m2\",\n", + " \"m2\",\n", + " \"m2\",\n", + " \"m2\",\n", + " \"m2\",\n", + " \"m2\",\n", + " \"t\",\n", + " \"m2\"\n", + " ],\n", + " \"initializer_range\": 0.02,\n", + " \"intermediate_size\": 14336,\n", + " \"max_position_embeddings\": 65536,\n", + " \"model_type\": \"apriel_ssm_thinker_hybrid\",\n", + " \"num_attention_heads\": 32,\n", + " \"num_hidden_layers\": 50,\n", + " \"num_key_value_heads\": 8,\n", + " \"rms_norm_eps\": 1e-05,\n", + " \"rope_theta\": 1000000.0,\n", + " \"sliding_window\": null,\n", + " \"ssm_cfg\": {\n", + " \"activation\": \"identity\",\n", + " \"bias\": false,\n", + " \"chunk_size\": 128,\n", + " \"conv_bias\": true,\n", + " \"d_conv\": 4,\n", + " \"d_inner\": 4096,\n", + " \"d_state\": 16,\n", + " \"d_xb\": 1024,\n", + " \"dt_init\": \"random\",\n", + " \"dt_init_floor\": 0.0001,\n", + " \"dt_max\": 0.1,\n", + " \"dt_min\": 0.001,\n", + " \"dt_rank\": \"auto\",\n", + " \"dt_scale\": 1.0,\n", + " \"expand\": 1,\n", + " \"n_qk_heads\": 32,\n", + " \"n_v_heads\": 32\n", + " },\n", + " \"tie_word_embeddings\": false,\n", + " \"torch_dtype\": \"bfloat16\",\n", + " \"transformers_version\": \"4.53.2\",\n", + " \"use_cache\": true,\n", + " \"vocab_size\": 131072\n", + "}" + ] + }, + "execution_count": 14, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "transformer.config" + ] + }, + { + "cell_type": "code", + "execution_count": 12, + "metadata": {}, + "outputs": [], + "source": [ + "transformer.config.architectures=[\"AprielThinkerSSMHybridForCausalLM\"]" + ] + }, + { + "cell_type": "code", + "execution_count": 11, + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "Loading checkpoint shards: 100%|██████████| 3/3 [00:00<00:00, 427.77it/s]\n" + ] + } + ], + "source": [ + "# load state dict from existing pretrained SSM?\n", + "path_25hyb = 
\"/mnt/checkpoints/ssm/apriel_ssm_thinker5l_hybrid_1ssm_init_rand_debug_tpformat\" #\"/mnt/checkpoints/fast_llm_exp/slam_ssm_distill/15b-oshyb25lmil-bs768-lr0.0003-lrs0-0-0-0-sl4096_ti5000_lm6/export/apriel_ssm_thinker_hybrid/5000_new\"\n", + "model = AprielThinkerSSMHybridForCausalLM.from_pretrained(path_25hyb)\n", + "state_dict = model.state_dict()\n", + "\n", + "# missing, unexpected = transformer.load_state_dict(state_dict, strict=False)\n", + "# print(missing)\n", + "# print(unexpected)\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Note: saving as transformer wilkl still keep architectures[\"Mistral....\"]. So currently need to manually update the checkpoints architectures list to have AprielThinkerSSMHybridForCausalLM" + ] + }, + { + "cell_type": "code", + "execution_count": 15, + "metadata": {}, + "outputs": [], + "source": [ + "\n", + "# mamba2, state 16, expand 1, i.e. same as M1, but with discrete mamba2 and MIL\n", + "# transformer.save_pretrained(\"/mnt/checkpoints/ssm/apriel_ssm_thinker15b_hybrid_1ssm_leastimportant_m2_16hexp1_init_mil\") # 1 ssm\n", + "# transformer.save_pretrained(\"/mnt/checkpoints/ssm/apriel_ssm_thinker15b_hybrid_25ssm_leastimportant_m2_16hexp1_init_mil\") # 25 ssm\n", + "transformer.save_pretrained(\"/mnt/checkpoints/ssm/apriel_ssm_thinker15b_hybrid_25ssm_leastimportant_m2_16hexp1_init_mil_tpformat\") # 25 ssm\n", + "\n", + "\n", + "# transformer.save_pretrained(\"/mnt/checkpoints/ssm/apriel_ssm_thinker15b_hybrid_40ssm_leastimportant_m2_16hexp1_init_mil_uniform_from_25h5000lm6\") # 40 ssm" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# Data mixing" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "torch.Size([])\n", + "KL (global, F.kl_div) = 0.738795\n", + "KL (sum of shards, manual) = 0.738795\n" + ] + } + ], + "source": [] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [] + } + ], + "metadata": { + "kernelspec": { + "display_name": "fast_llm", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.12.8" + } + }, + "nbformat": 4, + "nbformat_minor": 2 +} diff --git a/fast_llm/models/ssm/external/5B_hybrid.ipynb b/fast_llm/models/ssm/external/5B_hybrid.ipynb new file mode 100644 index 000000000..9a33f577e --- /dev/null +++ b/fast_llm/models/ssm/external/5B_hybrid.ipynb @@ -0,0 +1,416 @@ +{ + "cells": [ + { + "cell_type": "code", + "execution_count": 1, + "metadata": {}, + "outputs": 
[ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "/home/toolkit/.local/lib/python3.12/site-packages/tqdm/auto.py:21: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html\n", + " from .autonotebook import tqdm as notebook_tqdm\n" + ] + } + ], + "source": [ + "\n", + "import torch\n", + "import random\n", + "from transformers import AutoConfig, AutoModelForCausalLM, AutoTokenizer\n", + "\n", + "fast_llm_path = \"/home/toolkit/dev/Fast-LLM\"\n", + "\n", + "# add fast_llm to the python path\n", + "import sys\n", + "sys.path.append(fast_llm_path)\n", + "from fast_llm.models.ssm.external.apriel_hybrid.modeling_ssm_hybrid_apriel import AprielSSMHybridConfig\n", + "from fast_llm.models.ssm.external.apriel_hybrid.modeling_ssm_hybrid_apriel import AprielSSMHybridModel, AprielSSMDecoderLayer, AprielSSMHybridForCausalLM\n", + "\n", + "%load_ext autoreload\n", + "%autoreload 2\n" + ] + }, + { + "cell_type": "code", + "execution_count": 2, + "metadata": {}, + "outputs": [], + "source": [ + "base = 0.612615\n", + "layer_scores = {\n", + " \"22\": 0.607389,\n", + " \"24\": 0.603498,\n", + " \"19\": 0.597907,\n", + " \"27\": 0.597173,\n", + " \"20\": 0.590442,\n", + " \"5\": 0.578949,\n", + " \"4\": 0.576852,\n", + " \"9\": 0.576484,\n", + " \"23\": 0.574833,\n", + " \"7\": 0.571860,\n", + " \"8\": 0.571790,\n", + " \"6\": 0.571614,\n", + " \"2\": 0.571330,\n", + " \"26\": 0.570205,\n", + " \"11\": 0.567128,\n", + " \"14\": 0.566175,\n", + " \"15\": 0.566076,\n", + " \"3\": 0.562861,\n", + " \"1\": 0.560154,\n", + " \"13\": 0.559304,\n", + " \"16\": 0.559017,\n", + " \"10\": 0.558789,\n", + " \"12\": 0.555186,\n", + " \"17\": 0.554236,\n", + " \"25\": 0.549215,\n", + " \"18\": 0.537257,\n", + " \"0\": 0.233085,\n", + "}\n", + "layer_scores = {k: base - v for k, v in layer_scores.items()}\n", + "layer_importanfce = sorted(layer_scores.items(), key=lambda x: x[1])\n", + "layer_importanfce_rand = random.sample(layer_importanfce, len(layer_importanfce))" + ] + }, + { + "cell_type": "code", + "execution_count": 3, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "[('22', 0.005226000000000064),\n", + " ('24', 0.009117000000000042),\n", + " ('19', 0.014708000000000054),\n", + " ('27', 0.015442000000000067),\n", + " ('20', 0.022173),\n", + " ('5', 0.033665999999999974),\n", + " ('4', 0.03576299999999999),\n", + " ('9', 0.036131000000000024),\n", + " ('23', 0.03778199999999998),\n", + " ('7', 0.040754999999999986),\n", + " ('8', 0.040825),\n", + " ('6', 0.041001000000000065),\n", + " ('2', 0.041285000000000016),\n", + " ('26', 0.04241000000000006),\n", + " ('11', 0.045487000000000055),\n", + " ('14', 0.04644000000000004),\n", + " ('15', 0.046539),\n", + " ('3', 0.049754000000000076),\n", + " ('1', 0.05246099999999998),\n", + " ('13', 0.053311),\n", + " ('16', 0.053598000000000035),\n", + " ('10', 0.05382600000000004),\n", + " ('12', 0.05742900000000006),\n", + " ('17', 0.05837900000000007),\n", + " ('25', 0.06340000000000001),\n", + " ('18', 0.07535800000000004),\n", + " ('0', 0.37953000000000003)]" + ] + }, + "execution_count": 3, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "layer_importanfce" + ] + }, + { + "cell_type": "code", + "execution_count": 4, + "metadata": {}, + "outputs": [], + "source": [ + "layer_importanfce = layer_importanfce_rand" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Create hybrid with any number of SSM 
layers" + ] + }, + { + "cell_type": "code", + "execution_count": 5, + "metadata": {}, + "outputs": [], + "source": [ + "checkpoint = \"ServiceNow-AI/Apriel-5B-Instruct\"\n", + "config = AutoConfig.from_pretrained(checkpoint, trust_remote_code=True)\n", + "device = \"cuda\"\n", + "n_hybrid = 0\n", + "\n", + "index_swaped = []\n", + "hybrid_block_layout = [\"t\"] * config.num_hidden_layers\n", + "for i in range(n_hybrid):\n", + " hybrid_block_layout[int(layer_importanfce[i][0])] = \"m2d\"\n", + " index_swaped.append(int(layer_importanfce[i][0]))\n", + "\n", + "hybrdif_apriel_config = AprielSSMHybridConfig(**config.to_dict(),\n", + " hybrid_block_layout=hybrid_block_layout,\n", + " ssm_cfg={\n", + " \"d_state\": 64,\n", + " \"n_v_heads\": 24,\n", + " \"n_qk_heads\": 24,\n", + " \"expand\": 1,\n", + " \"chunk_size\": 128,\n", + " \"activation\": \"identity\",\n", + " \"bias\": False,\n", + " \"d_inner\": 24 * 128, # num_heads * head_dim\n", + " })" + ] + }, + { + "cell_type": "code", + "execution_count": 6, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "['t',\n", + " 't',\n", + " 't',\n", + " 't',\n", + " 't',\n", + " 't',\n", + " 't',\n", + " 't',\n", + " 't',\n", + " 't',\n", + " 't',\n", + " 't',\n", + " 't',\n", + " 't',\n", + " 't',\n", + " 't',\n", + " 't',\n", + " 't',\n", + " 't',\n", + " 't',\n", + " 't',\n", + " 't',\n", + " 't',\n", + " 't',\n", + " 't',\n", + " 't',\n", + " 't',\n", + " 't']" + ] + }, + "execution_count": 6, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "hybrdif_apriel_config.hybrid_block_layout" + ] + }, + { + "cell_type": "code", + "execution_count": 7, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "AprielSSMHybridForCausalLM(\n", + " (model): AprielSSMHybridModel(\n", + " (embed_tokens): Embedding(131072, 4096)\n", + " (layers): ModuleList(\n", + " (0-27): 28 x AprielDecoderLayer(\n", + " (self_attn): AprielAttention(\n", + " (q_proj): Linear(in_features=4096, out_features=3072, bias=False)\n", + " (k_proj): Linear(in_features=4096, out_features=1024, bias=False)\n", + " (v_proj): Linear(in_features=4096, out_features=1024, bias=False)\n", + " (o_proj): Linear(in_features=3072, out_features=4096, bias=False)\n", + " )\n", + " (mlp): AprielMLP(\n", + " (gate_proj): Linear(in_features=4096, out_features=8192, bias=False)\n", + " (up_proj): Linear(in_features=4096, out_features=8192, bias=False)\n", + " (down_proj): Linear(in_features=8192, out_features=4096, bias=False)\n", + " (act_fn): SiLU()\n", + " )\n", + " (input_layernorm): AprielRMSNorm((4096,), eps=1e-05)\n", + " (post_attention_layernorm): AprielRMSNorm((4096,), eps=1e-05)\n", + " )\n", + " )\n", + " (norm): AprielRMSNorm((4096,), eps=1e-05)\n", + " (rotary_emb): AprielRotaryEmbedding()\n", + " )\n", + " (lm_head): Linear(in_features=4096, out_features=131072, bias=False)\n", + ")" + ] + }, + "execution_count": 7, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "hybrid_apriel_model = AprielSSMHybridForCausalLM(hybrdif_apriel_config)\n", + "hybrid_apriel_model.to(dtype=torch.bfloat16)" + ] + }, + { + "cell_type": "code", + "execution_count": 8, + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "Loading checkpoint shards: 100%|██████████| 2/2 [00:00<00:00, 2.22it/s]\n" + ] + } + ], + "source": [ + "\n", + "config = AutoConfig.from_pretrained(checkpoint, trust_remote_code=True)\n", + "apriel_model = AutoModelForCausalLM.from_pretrained(checkpoint, 
torch_dtype=torch.bfloat16, trust_remote_code=True)\n", + "apriel_state_dict = apriel_model.state_dict()" + ] + }, + { + "cell_type": "code", + "execution_count": 9, + "metadata": {}, + "outputs": [], + "source": [ + "missing, unexpected = hybrid_apriel_model.load_state_dict(apriel_state_dict, strict=False)" + ] + }, + { + "cell_type": "code", + "execution_count": 10, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Missing keys: []\n", + "Unexpected keys: []\n" + ] + } + ], + "source": [ + "# unexpected will contain keys from the SSM layers we added\n", + "print(\"Missing keys:\", missing)\n", + "# unexpected will contain keys from the transformer layers we replaced\n", + "print(\"Unexpected keys:\", unexpected)\n", + "\n" + ] + }, + { + "cell_type": "code", + "execution_count": 10, + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "Loading checkpoint shards: 100%|██████████| 5/5 [00:04<00:00, 1.22it/s]\n" + ] + } + ], + "source": [ + "from fast_llm.models.ssm.external.apriel_ssm.modeling_ssm_apriel import AprielSSMModel, AprielSSMForCausalLM\n", + "\n", + "mohawk_path = \"/mnt/checkpoints/ssm/mohawk_distributed_stage2_apriel_8GPU_16ksteps_lr0.0_layernorm/final\"\n", + "# config = AutoConfig.from_pretrained(mohawk_path, trust_remote_code=True)\n", + "apriel_model = AprielSSMForCausalLM.from_pretrained(mohawk_path, torch_dtype=torch.bfloat16, trust_remote_code=True)\n", + "apriel_state_dict = apriel_model.state_dict()" + ] + }, + { + "cell_type": "code", + "execution_count": 11, + "metadata": {}, + "outputs": [], + "source": [ + "missing, unexpected = hybrid_apriel_model.load_state_dict(apriel_state_dict, strict=False)" + ] + }, + { + "cell_type": "code", + "execution_count": 12, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Missing keys: ['model.layers.0.self_attn.q_proj.weight', 'model.layers.0.self_attn.k_proj.weight', 'model.layers.0.self_attn.v_proj.weight', 'model.layers.0.self_attn.o_proj.weight', 'model.layers.1.self_attn.q_proj.weight', 'model.layers.1.self_attn.k_proj.weight', 'model.layers.1.self_attn.v_proj.weight', 'model.layers.1.self_attn.o_proj.weight', 'model.layers.3.self_attn.q_proj.weight', 'model.layers.3.self_attn.k_proj.weight', 'model.layers.3.self_attn.v_proj.weight', 'model.layers.3.self_attn.o_proj.weight', 'model.layers.10.self_attn.q_proj.weight', 'model.layers.10.self_attn.k_proj.weight', 'model.layers.10.self_attn.v_proj.weight', 'model.layers.10.self_attn.o_proj.weight', 'model.layers.11.self_attn.q_proj.weight', 'model.layers.11.self_attn.k_proj.weight', 'model.layers.11.self_attn.v_proj.weight', 'model.layers.11.self_attn.o_proj.weight', 'model.layers.12.self_attn.q_proj.weight', 'model.layers.12.self_attn.k_proj.weight', 'model.layers.12.self_attn.v_proj.weight', 'model.layers.12.self_attn.o_proj.weight', 'model.layers.13.self_attn.q_proj.weight', 'model.layers.13.self_attn.k_proj.weight', 'model.layers.13.self_attn.v_proj.weight', 'model.layers.13.self_attn.o_proj.weight', 'model.layers.14.self_attn.q_proj.weight', 'model.layers.14.self_attn.k_proj.weight', 'model.layers.14.self_attn.v_proj.weight', 'model.layers.14.self_attn.o_proj.weight', 'model.layers.15.self_attn.q_proj.weight', 'model.layers.15.self_attn.k_proj.weight', 'model.layers.15.self_attn.v_proj.weight', 'model.layers.15.self_attn.o_proj.weight', 'model.layers.16.self_attn.q_proj.weight', 'model.layers.16.self_attn.k_proj.weight', 
'model.layers.16.self_attn.v_proj.weight', 'model.layers.16.self_attn.o_proj.weight', 'model.layers.17.self_attn.q_proj.weight', 'model.layers.17.self_attn.k_proj.weight', 'model.layers.17.self_attn.v_proj.weight', 'model.layers.17.self_attn.o_proj.weight', 'model.layers.18.self_attn.q_proj.weight', 'model.layers.18.self_attn.k_proj.weight', 'model.layers.18.self_attn.v_proj.weight', 'model.layers.18.self_attn.o_proj.weight', 'model.layers.21.self_attn.q_proj.weight', 'model.layers.21.self_attn.k_proj.weight', 'model.layers.21.self_attn.v_proj.weight', 'model.layers.21.self_attn.o_proj.weight', 'model.layers.25.self_attn.q_proj.weight', 'model.layers.25.self_attn.k_proj.weight', 'model.layers.25.self_attn.v_proj.weight', 'model.layers.25.self_attn.o_proj.weight']\n", + "Unexpected keys: ['model.layers.0.mixer.z_bias', 'model.layers.0.mixer.D', 'model.layers.0.mixer.in_proj.weight', 'model.layers.0.mixer.conv1d.weight', 'model.layers.0.mixer.conv1d.bias', 'model.layers.0.mixer.out_proj.weight', 'model.layers.1.mixer.z_bias', 'model.layers.1.mixer.D', 'model.layers.1.mixer.in_proj.weight', 'model.layers.1.mixer.conv1d.weight', 'model.layers.1.mixer.conv1d.bias', 'model.layers.1.mixer.out_proj.weight', 'model.layers.3.mixer.z_bias', 'model.layers.3.mixer.D', 'model.layers.3.mixer.in_proj.weight', 'model.layers.3.mixer.conv1d.weight', 'model.layers.3.mixer.conv1d.bias', 'model.layers.3.mixer.out_proj.weight', 'model.layers.10.mixer.z_bias', 'model.layers.10.mixer.D', 'model.layers.10.mixer.in_proj.weight', 'model.layers.10.mixer.conv1d.weight', 'model.layers.10.mixer.conv1d.bias', 'model.layers.10.mixer.out_proj.weight', 'model.layers.11.mixer.z_bias', 'model.layers.11.mixer.D', 'model.layers.11.mixer.in_proj.weight', 'model.layers.11.mixer.conv1d.weight', 'model.layers.11.mixer.conv1d.bias', 'model.layers.11.mixer.out_proj.weight', 'model.layers.12.mixer.z_bias', 'model.layers.12.mixer.D', 'model.layers.12.mixer.in_proj.weight', 'model.layers.12.mixer.conv1d.weight', 'model.layers.12.mixer.conv1d.bias', 'model.layers.12.mixer.out_proj.weight', 'model.layers.13.mixer.z_bias', 'model.layers.13.mixer.D', 'model.layers.13.mixer.in_proj.weight', 'model.layers.13.mixer.conv1d.weight', 'model.layers.13.mixer.conv1d.bias', 'model.layers.13.mixer.out_proj.weight', 'model.layers.14.mixer.z_bias', 'model.layers.14.mixer.D', 'model.layers.14.mixer.in_proj.weight', 'model.layers.14.mixer.conv1d.weight', 'model.layers.14.mixer.conv1d.bias', 'model.layers.14.mixer.out_proj.weight', 'model.layers.15.mixer.z_bias', 'model.layers.15.mixer.D', 'model.layers.15.mixer.in_proj.weight', 'model.layers.15.mixer.conv1d.weight', 'model.layers.15.mixer.conv1d.bias', 'model.layers.15.mixer.out_proj.weight', 'model.layers.16.mixer.z_bias', 'model.layers.16.mixer.D', 'model.layers.16.mixer.in_proj.weight', 'model.layers.16.mixer.conv1d.weight', 'model.layers.16.mixer.conv1d.bias', 'model.layers.16.mixer.out_proj.weight', 'model.layers.17.mixer.z_bias', 'model.layers.17.mixer.D', 'model.layers.17.mixer.in_proj.weight', 'model.layers.17.mixer.conv1d.weight', 'model.layers.17.mixer.conv1d.bias', 'model.layers.17.mixer.out_proj.weight', 'model.layers.18.mixer.z_bias', 'model.layers.18.mixer.D', 'model.layers.18.mixer.in_proj.weight', 'model.layers.18.mixer.conv1d.weight', 'model.layers.18.mixer.conv1d.bias', 'model.layers.18.mixer.out_proj.weight', 'model.layers.21.mixer.z_bias', 'model.layers.21.mixer.D', 'model.layers.21.mixer.in_proj.weight', 'model.layers.21.mixer.conv1d.weight', 'model.layers.21.mixer.conv1d.bias', 
'model.layers.21.mixer.out_proj.weight', 'model.layers.25.mixer.z_bias', 'model.layers.25.mixer.D', 'model.layers.25.mixer.in_proj.weight', 'model.layers.25.mixer.conv1d.weight', 'model.layers.25.mixer.conv1d.bias', 'model.layers.25.mixer.out_proj.weight']\n" + ] + } + ], + "source": [ + "# unexpected will contain keys from the SSM layers we added\n", + "print(\"Missing keys:\", missing)\n", + "# unexpected will contain keys from the transformer layers we replaced\n", + "print(\"Unexpected keys:\", unexpected)\n", + "\n" + ] + }, + { + "cell_type": "code", + "execution_count": 11, + "metadata": {}, + "outputs": [], + "source": [ + "# hybrid_apriel_model.save_pretrained(\"/mnt/checkpoints/ssm/apriel_ssm_instruct_hybrid_14ssm_leastimportant_init_MOHAWK\")\n", + "# hybrid_apriel_model.save_pretrained(\"/mnt/checkpoints/ssm/apriel_ssm_instruct_hybrid_20ssm_leastimportant_init_rand\")\n", + "# hybrid_apriel_model.save_pretrained(\"/mnt/checkpoints/ssm/apriel_ssm_instruct_hybrid_14ssm_randplacement_init_rand\")\n", + "hybrid_apriel_model.save_pretrained(\"/mnt/checkpoints/ssm/apriel_ssm_instruct_hybrid_0ssm_full_transformer_debug\")" + ] + }, + { + "cell_type": "code", + "execution_count": 18, + "metadata": {}, + "outputs": [], + "source": [ + "# save the hybrid model\n", + "output_path = \"/mnt/checkpoints/ssm/iterative_hybrids_5b\"\n", + "assert len(index_swaped) == 1\n", + "layer_swaped = index_swaped[0]\n", + "hybrid_apriel_model.save_pretrained(\n", + " f\"{output_path}/apriel_ssm_instruct5b_hybrid_{layer_swaped+1}ssm_leastimportant_32h_init_rand\"\n", + " )\n", + "print(f\"Hybrid model saved to {output_path}/apriel_ssm_instruct5b_hybrid_{layer_swaped+1}ssm_leastimportant_32h_init_rand\")\n" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "fast_llm", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.12.8" + } + }, + "nbformat": 4, + "nbformat_minor": 2 +} diff --git a/fast_llm/models/ssm/external/apriel_15b_hybrid/modeling_ssm_hybrid_apriel15b.py b/fast_llm/models/ssm/external/apriel_15b_hybrid/modeling_ssm_hybrid_apriel15b.py index f8f6a0520..9f4588a29 100644 --- a/fast_llm/models/ssm/external/apriel_15b_hybrid/modeling_ssm_hybrid_apriel15b.py +++ b/fast_llm/models/ssm/external/apriel_15b_hybrid/modeling_ssm_hybrid_apriel15b.py @@ -18,7 +18,7 @@ from transformers.modeling_utils import PreTrainedModel from transformers.models.mistral.modeling_mistral import MistralDecoderLayer, MistralMLP, MistralModel, MistralRMSNorm from transformers.processing_utils import Unpack -from transformers.utils import LossKwargs, logging +from transformers.utils import LossKwargs, can_return_tuple, logging from transformers.utils.generic import ModelOutput from fast_llm.models.ssm.external.apriel_15b_hybrid.configuration_ssm_hybrid_apriel15b import AprielSSMHybridConfig @@ -357,7 +357,13 @@ def get_seq_length(self, layer_idx: Optional[int] = 0) -> int: layer_idx = self.transformer_layers[0] if layer_idx not in self.transformer_layers else layer_idx if len(self.key_cache) <= layer_idx: return 0 - return self.key_cache[layer_idx].shape[-2] + is_empty_layer = ( + len(self.key_cache) == 0 # no cache in any layer + or len(self.key_cache) <= layer_idx # skipped `layer_idx` and hasn't run a layer with cache after it + or not self.key_cache[layer_idx].numel() # 
the layer has no cache + ) + return self.key_cache[layer_idx].shape[-2] if not is_empty_layer else 0 + # return self.key_cache[layer_idx].shape[-2] def reset(self): self.conv_states.zero_() @@ -843,9 +849,8 @@ def __init__( self.num_C_head = self.d_inner // self.d_state self.repeat_group = self.num_C_head // self.num_xb_head - self.in_proj = nn.Linear( - self.d_model, 2 * self.d_xb + 2 * self.d_inner + self.dt_rank, bias=bias, **factory_kwargs - ) + self.in_proj = nn.Linear(self.d_model, 2 * self.d_xb + 2 * self.d_inner, bias=bias, **factory_kwargs) + self.dt_in_proj = nn.Linear(self.d_model, self.dt_rank, bias=bias, **factory_kwargs) self.dt_proj = nn.Linear(self.dt_rank, self.d_inner, bias=dt_proj_bias, **factory_kwargs) # Initialize special dt projection to preserve variance at initialization @@ -888,7 +893,7 @@ def forward( self, hidden_states: torch.Tensor, past_key_value: Optional[HybridMambaAttentionDynamicCache] = None, - attention_mask: Optional[torch.Tensor] = None, + mamba_mask: Optional[torch.Tensor] = None, return_mixer_matrix=False, **kwargs, ): @@ -900,6 +905,10 @@ def forward( assert is_fast_path_available and "cuda" in self.in_proj.weight.device.type, "Only support fast path on cuda" cache_position = kwargs.get("cache_position", None) batch, seqlen, dim = hidden_states.shape + # mamba_mask = ( + # None if seqlen == 1 else mamba_mask + # ) # prevent that hidden_states are expanded to mask's seq. dimention., i.e. we do not need apply_mask_to_padding_states when generating single token at a time + # hidden_states = apply_mask_to_padding_states(hidden_states, mamba_mask) ssm_state, conv_state = None, None use_precomputed_states = False @@ -933,8 +942,17 @@ def forward( outputs = {} A = -torch.exp(self.A_log.float()) # (d_inner, d_state) - zxbcdt = self.in_proj(hidden_states) - z, x, B, C, dt = torch.split(zxbcdt, [self.d_inner, self.d_xb, self.d_xb, self.d_inner, self.dt_rank], dim=-1) + zxbc = self.in_proj(hidden_states) + z, x, B, C = torch.split( + zxbc, + [ + self.d_inner, + self.d_xb, + self.d_xb, + self.d_inner, + ], + dim=-1, + ) x = rearrange(x, "b l d -> b d l") z = rearrange(z, "b l d -> b d l") @@ -944,7 +962,7 @@ def forward( B = rearrange(B, "b n_group l dstate -> b n_group dstate l").contiguous() C = rearrange(C, "b l (n_group dstate) -> b n_group dstate l", dstate=self.d_state).contiguous() - dt = self.dt_proj(dt) # B, L, d_inner + dt = self.dt_proj(self.dt_in_proj(hidden_states)) # B, L, d_inner dt = rearrange(dt, "b l d -> b d l") # B, d_inner, L if self.repeat_kv_before_conv: @@ -971,7 +989,7 @@ def forward( # Update state (B D W) conv_state.copy_(F.pad(x, (self.d_conv - x.shape[-1], 0))) if causal_conv1d_fn is None: - x = self.act(self.conv1d(x)[..., :seqlen]) + x = self.act(self.conv1d(x)[..., :seqlen]).transpose(1, 2) else: assert self.activation in ["silu", "swish"] x = causal_conv1d_fn( @@ -979,7 +997,10 @@ def forward( weight=rearrange(self.conv1d.weight, "d 1 w -> d w"), bias=self.conv1d.bias, activation=self.activation, - ) + ) # .transpose(1, 2) + # x = apply_mask_to_padding_states(x, mamba_mask).transpose( + # 1, 2 + # ) # zero out everything that comes from padding tokens if not self.repeat_kv_before_conv: x = rearrange(x, "b (n_group dstate) l -> b n_group l dstate", dstate=self.d_state) @@ -1034,14 +1055,14 @@ def step(self, hidden_states, conv_state, ssm_state): A = -torch.exp(self.A_log.float()) # (d_inner, d_state) - zxbcdt = self.in_proj(hidden_states_input) - z, x, B, C, dt = torch.split(zxbcdt, [self.d_inner, self.d_xb, self.d_xb, 
self.d_inner, self.dt_rank], dim=-1) + zxbc = self.in_proj(hidden_states_input) + z, x, B, C = torch.split(zxbc, [self.d_inner, self.d_xb, self.d_xb, self.d_inner], dim=-1) B = rearrange(B, "b (n_group dstate) -> b n_group dstate", dstate=self.d_state) B = torch.repeat_interleave(B, dim=1, repeats=self.repeat_group) C = rearrange(C, "b (n_group dstate) -> b n_group dstate", dstate=self.d_state).contiguous() - dt = self.dt_proj(dt) # B, d_inner + dt = self.dt_proj(self.dt_in_proj(hidden_states_input)) # B, d_inner if self.repeat_kv_before_conv: x = rearrange(x, "b (n_group dstate) -> b n_group dstate", dstate=self.d_state) @@ -1209,6 +1230,42 @@ def __init__(self, config: AprielSSMHybridConfig, **kwargs): # Initialize weights and apply final processing self.post_init() + @can_return_tuple + def forward( + self, + input_ids: Optional[torch.LongTensor] = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[Cache] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + cache_position: Optional[torch.LongTensor] = None, + **flash_attn_kwargs: Unpack[FlashAttentionKwargs], + ) -> BaseModelOutputWithPast: + use_cache = use_cache if use_cache is not None else self.config.use_cache + if use_cache and past_key_values is None: + # for the case where prepare_inputs_for_generation is not called to create the cache (as in fast-llm test) + batch_size = input_ids.shape[0] if input_ids is not None else inputs_embeds.shape[0] + past_key_values = HybridMambaAttentionDynamicCache(self.config, batch_size, self.dtype, device=self.device) + output = super().forward( + input_ids=input_ids, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_values=past_key_values, + inputs_embeds=inputs_embeds, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + cache_position=cache_position, + **flash_attn_kwargs, + ) + past_key_values: HybridMambaAttentionDynamicCache = output.past_key_values + if past_key_values and not past_key_values.has_previous_state: + past_key_values.has_previous_state = True + return output + class KwargsForCausalLM(FlashAttentionKwargs, LossKwargs): ... 
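To make the projection change above concrete: dt is no longer carved out of the in_proj output but produced by a separate low-rank dt_in_proj followed by dt_proj. A minimal sketch with made-up dimensions (not the real module):

import torch
import torch.nn as nn

d_model, d_inner, d_xb, dt_rank = 64, 64, 16, 4
in_proj = nn.Linear(d_model, 2 * d_xb + 2 * d_inner, bias=False)  # now packs only [z, x, B, C]
dt_in_proj = nn.Linear(d_model, dt_rank, bias=False)              # new: low-rank dt input projection
dt_proj = nn.Linear(dt_rank, d_inner, bias=True)

hidden_states = torch.randn(2, 8, d_model)                        # (batch, seqlen, d_model)
z, x, B, C = torch.split(in_proj(hidden_states), [d_inner, d_xb, d_xb, d_inner], dim=-1)
dt = dt_proj(dt_in_proj(hidden_states))                           # (batch, seqlen, d_inner)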
@@ -1390,6 +1447,7 @@ def forward( output_attentions=output_attentions, output_hidden_states=output_hidden_states, cache_position=cache_position, + mamba_mask=attention_mask, # non-expended mask **kwargs, ) diff --git a/fast_llm/models/ssm/external/llava_hybrid/configuration_llava_hybrid.py b/fast_llm/models/ssm/external/llava_hybrid/configuration_llava_hybrid.py new file mode 100644 index 000000000..b8e822d9f --- /dev/null +++ b/fast_llm/models/ssm/external/llava_hybrid/configuration_llava_hybrid.py @@ -0,0 +1,117 @@ +from transformers import MistralConfig +from transformers.configuration_utils import PretrainedConfig +from transformers.models.auto import CONFIG_MAPPING +from transformers.utils import logging + +logger = logging.get_logger(__name__) + +# Copied from configuration_ssm_hybrid_apriel15b.py +# TODO: split into mamba 2 and discrete mamba 2 configs with a base dict +ssm_config_default = { + # discrete mamba2 + "d_state": 64, + "n_v_heads": 32, + "n_qk_heads": 32, + "expand": 1, + "chunk_size": 128, + "activation": "identity", + "bias": False, + "d_conv": 4, + "d_inner": 32 * 128, + # mamba2 + "d_xb": None, # will be set to model dim + "dt_rank": "auto", + "dt_min": 0.001, + "dt_max": 0.1, + "dt_init": "random", + "dt_scale": 1.0, + "dt_init_floor": 1e-4, + "conv_bias": True, +} + + +class AprielSSMHybridConfig(MistralConfig): + model_type = "apriel_ssm_thinker_hybrid" + + def __init__(self, hybrid_block_layout=["m2d"], ssm_cfg=None, **kwargs): + super().__init__(**kwargs) + self.hybrid_block_layout = hybrid_block_layout + self.head_dim = self.head_dim or self.hidden_size // self.num_attention_heads # as in transformers 4.51.3 + self.ssm_cfg = ssm_cfg or ssm_config_default + + for k, v in ssm_config_default.items(): + if k not in self.ssm_cfg: + self.ssm_cfg[k] = v # to make sure all elements are present in the config + + +class LlavaHybridConfig(PretrainedConfig): + """ + Configuration class for Llava SSM-Hybrid-decoder model. + """ + + model_type = "llava_hybrid" + + def __init__( + self, + vision_config=None, + text_config=None, + image_token_index=32000, + projector_hidden_act="gelu", + projector_intermediate_size=4096, + vision_feature_select_strategy="default", + vision_feature_layer=-2, + image_seq_length=576, + multimodal_projector_bias=True, + **kwargs, + ): + self.image_token_index = image_token_index + self.projector_hidden_act = projector_hidden_act + # projector_intermediate_size is an addition to the original Llava config + self.projector_intermediate_size = projector_intermediate_size + self.image_seq_length = image_seq_length + + if vision_feature_select_strategy not in ["default", "full"]: + raise ValueError( + "vision_feature_select_strategy should be one of 'default', 'full'." 
+ f"Got: {vision_feature_select_strategy}" + ) + + self.vision_feature_select_strategy = vision_feature_select_strategy + self.vision_feature_layer = vision_feature_layer + + if isinstance(vision_config, dict): + vision_config["model_type"] = ( + vision_config["model_type"] if "model_type" in vision_config else "clip_vision_model" + ) + vision_config = CONFIG_MAPPING[vision_config["model_type"]](**vision_config) + elif vision_config is None: + vision_config = CONFIG_MAPPING["clip_vision_model"]( + intermediate_size=4096, + hidden_size=1024, + patch_size=14, + image_size=336, + num_hidden_layers=24, + num_attention_heads=16, + vocab_size=32000, + projection_dim=768, + ) + + self.vision_config = vision_config + + if isinstance(text_config, dict): + # Load the custom SSM hybrid config if specified + if text_config.get("model_type") == "apriel_ssm_thinker_hybrid": + text_config = AprielSSMHybridConfig(**text_config) + else: + text_config["model_type"] = text_config["model_type"] if "model_type" in text_config else "llama" + text_config = CONFIG_MAPPING[text_config["model_type"]](**text_config) + elif text_config is None: + text_config = CONFIG_MAPPING["llama"]() + + self.text_config = text_config + self.multimodal_projector_bias = multimodal_projector_bias + + super().__init__(**kwargs) + + +__all__ = ["LlavaHybridConfig"] diff --git a/fast_llm/models/ssm/external/llava_hybrid/modeling_llava_hybrid.py b/fast_llm/models/ssm/external/llava_hybrid/modeling_llava_hybrid.py new file mode 100644 index 000000000..68073f9cd --- /dev/null +++ b/fast_llm/models/ssm/external/llava_hybrid/modeling_llava_hybrid.py @@ -0,0 +1,132 @@ +from torch import nn +from transformers import AutoModel, LlavaForConditionalGeneration, LlavaModel +from transformers.activations import ACT2FN + +from .configuration_llava_hybrid import LlavaHybridConfig + +try: + # In the fast-llm repo, import from the SSM modeling file + from fast_llm.models.ssm.external.apriel_15b_hybrid.modeling_ssm_hybrid_apriel15b import ( + AprielThinkerSSMHybridModel, + HybridMambaAttentionDynamicCache, + ) +except ImportError: + # In the exported checkpoint, import from local file + from .modeling_ssm_hybrid_apriel15b import AprielThinkerSSMHybridModel, HybridMambaAttentionDynamicCache + + +class LlavaMultiModalProjector(nn.Module): + def __init__(self, config: LlavaHybridConfig): + super().__init__() + # We have hidden_size * the number of vision feature layers + num_feature_layers = 1 if isinstance(config.vision_feature_layer, int) else len(config.vision_feature_layer) + self.linear_1 = nn.Linear( + config.vision_config.hidden_size * num_feature_layers, + config.projector_intermediate_size, + bias=config.multimodal_projector_bias, + ) + self.act = ACT2FN[config.projector_hidden_act] + self.linear_2 = nn.Linear( + config.projector_intermediate_size, config.text_config.hidden_size, bias=config.multimodal_projector_bias + ) + + def forward(self, image_features): + hidden_states = self.linear_1(image_features) + hidden_states = self.act(hidden_states) + hidden_states = self.linear_2(hidden_states) + return hidden_states + + +class LlavaHybridModel(LlavaModel): + """ + Llava SSM-Hybrid-decoder model. 
+ """ + + config_class = LlavaHybridConfig + + def __init__(self, config: LlavaHybridConfig): + super(LlavaModel, self).__init__(config) + self.vision_tower = AutoModel.from_config(config.vision_config) + + self.multi_modal_projector = LlavaMultiModalProjector(config) + assert ( + config.text_config.model_type == "apriel_ssm_thinker_hybrid" + ), "Only Apriel SSM Hybrid model is supported in LlavaHybridModel" + + self.language_model = AprielThinkerSSMHybridModel(config.text_config) + self.post_init() + + +class LlavaHybridForConditionalGeneration(LlavaForConditionalGeneration): + config_class = LlavaHybridConfig + + def __init__(self, config: LlavaHybridConfig): + super(LlavaForConditionalGeneration, self).__init__(config) + self.model = LlavaHybridModel(config) + self.lm_head = nn.Linear(config.text_config.hidden_size, config.text_config.vocab_size, bias=False) + self.post_init() + + def prepare_inputs_for_generation( + self, + input_ids, + past_key_values=None, + attention_mask=None, + inputs_embeds=None, + output_router_logits=False, + cache_position=None, + position_ids=None, + use_cache=True, + pixel_values=None, + **kwargs, + ): + # Copy of the method from `AprielThinkerSSMHybridForCausalLM` + # Overwritten -- has a unique cache type, `HybridMambaAttentionDynamicCache` + + empty_past_kv = past_key_values is None or not isinstance(past_key_values, HybridMambaAttentionDynamicCache) + + # If we have cache: let's slice `input_ids` through `cache_position`, to keep only the unprocessed tokens + # Exception 1: when passing input_embeds, input_ids may be missing entries + # Exception 2: some generation methods do special slicing of input_ids, so we don't need to do it here + # Exception 3: with synced GPUs cache_position may go out of bounds, but we only want dummy token in that case. 
+ # (we can't check exception 3 while compiling) + if not empty_past_kv: + if inputs_embeds is not None or cache_position[-1] >= input_ids.shape[1]: # Exception 1 # Exception 3 + input_ids = input_ids[:, -cache_position.shape[0] :] + elif input_ids.shape[1] != cache_position.shape[0]: # Default case (the "else", a no op, is Exception 2) + input_ids = input_ids[:, cache_position] + else: + past_key_values = HybridMambaAttentionDynamicCache( + self.config.text_config, input_ids.shape[0], self.dtype, device=self.device + ) + + if attention_mask is not None and position_ids is None: + # create position_ids on the fly for batch generation + position_ids = attention_mask.long().cumsum(-1) - 1 + position_ids.masked_fill_(attention_mask == 0, 1) + if not empty_past_kv: + position_ids = position_ids[:, -input_ids.shape[1] :] + + # if `inputs_embeds` are passed, we only want to use them in the 1st generation step + if inputs_embeds is not None and empty_past_kv: + model_inputs = {"inputs_embeds": inputs_embeds} + else: + model_inputs = {"input_ids": input_ids.contiguous()} # `contiguous()` needed for compilation use cases + + # Copy from `LlavaForConditionalGeneration.prepare_inputs_for_generation` + if cache_position[0] == 0: + # If we're in cached decoding stage, pixel values should be None because input ids do not contain special image token anymore + # Otherwise we need pixel values to be passed to model + model_inputs["pixel_values"] = pixel_values + + model_inputs.update( + { + "position_ids": position_ids, + "past_key_values": past_key_values, + "use_cache": use_cache, + "attention_mask": attention_mask, + "output_router_logits": output_router_logits, + # "logits_to_keep": self.config.num_logits_to_keep, + "cache_position": cache_position, + } + ) + return model_inputs diff --git a/fast_llm/models/ssm/external/make_hybrid_checkpoint.py b/fast_llm/models/ssm/external/make_hybrid_checkpoint.py new file mode 100644 index 000000000..8a21c906f --- /dev/null +++ b/fast_llm/models/ssm/external/make_hybrid_checkpoint.py @@ -0,0 +1,163 @@ +import gc + +import click +import torch +from transformers import AutoModelForCausalLM + +from fast_llm.models.ssm.external.apriel_15b_hybrid.configuration_ssm_hybrid_apriel15b import AprielSSMHybridConfig +from fast_llm.models.ssm.external.apriel_15b_hybrid.modeling_ssm_hybrid_apriel15b import ( + AprielSSMM2DecoderLayer, + AprielThinkerSSMHybridForCausalLM, +) + +device = "cuda" if torch.cuda.is_available() else "cpu" + +dstate = 16 +expand = 1 +# Calculate derived dimensions for the Mamba1 configuration +# d_model = config_base.text_config.hidden_size +d_inner = 4096 # hard code to match thinker #expand * d_model +d_xb = 1024 # hard code to match thinker #config_thinker.num_key_value_heads * (config_thinker.hidden_size // config_thinker.num_attention_heads) + + +def convert_layers( + transformer_config, + transformer_model, + mamba_config, + hybrid_block_layout, + init_with_kqvo, + torch_dtype=torch.bfloat16, +): + config = transformer_config + embed_dim = config.hidden_size + num_heads = config.num_attention_heads + num_heads_kv = config.num_key_value_heads + head_dim = embed_dim // num_heads + head_dim * num_heads + head_dim * num_heads_kv + + for layer_idx, type in enumerate(hybrid_block_layout): + print("Converting layer %d...", layer_idx) + # Fetch the layer module for easier access + layer_module = transformer_model.layers._modules[f"{layer_idx}"] + if type == "t": + print("Skipping transformer layer %d..." 
% layer_idx) + elif type == "m2": + print("Converting layer %d..." % layer_idx) + # Use MambaDecoderLayer for the remaining layers + mamba_encoder = AprielSSMM2DecoderLayer( + mamba_config, + layer_idx, + device="cpu", + dtype=torch_dtype, + ) + + mamba_encoder.mlp.load_state_dict(layer_module.mlp.state_dict()) + mamba_encoder.input_layernorm.load_state_dict(layer_module.input_layernorm.state_dict()) + mamba_encoder.post_attention_layernorm.load_state_dict(layer_module.post_attention_layernorm.state_dict()) + mamba_encoder.mixer.out_proj.load_state_dict(layer_module.self_attn.o_proj.state_dict()) + + if init_with_kqvo: + # Copy weights: [z, x, B, C, dt], x -> v, B -> k, C -> q + mamba_encoder.mixer.in_proj.weight.data[ + mamba_config.ssm_cfg["d_inner"] : mamba_config.ssm_cfg["d_inner"] + mamba_config.ssm_cfg["d_xb"], : + ].copy_(layer_module.self_attn.v_proj.weight.data) + mamba_encoder.mixer.in_proj.weight.data[ + mamba_config.ssm_cfg["d_inner"] + + mamba_config.ssm_cfg["d_xb"] : mamba_config.ssm_cfg["d_inner"] + + 2 * mamba_config.ssm_cfg["d_xb"], + :, + ].copy_(layer_module.self_attn.k_proj.weight.data) + mamba_encoder.mixer.in_proj.weight.data[ + mamba_config.ssm_cfg["d_inner"] + + 2 * mamba_config.ssm_cfg["d_xb"] : 2 * mamba_config.ssm_cfg["d_inner"] + + 2 * mamba_config.ssm_cfg["d_xb"], + :, + ].copy_(layer_module.self_attn.q_proj.weight.data) + + print("Init Mamba using Attention") + + transformer_model.layers[layer_idx] = mamba_encoder + + else: + raise ValueError(f"Invalid layer type: {type}") + + +def make_hybrid_config(transformer): + config_dict = transformer.config.to_dict() + config_dict["hybrid_block_layout"] = ["t"] * transformer.config.num_hidden_layers + config_dict["model_type"] = "apriel_ssm_thinker_hybrid" + config_dict["ssm_cfg"] = { + "activation": "silu", + "d_state": dstate, + "d_xb": d_xb, + "expand": expand, + "d_conv": 4, + "d_inner": d_inner, + "conv_bias": True, + "bias": False, + } + hybrid_config = AprielSSMHybridConfig.from_dict(**config_dict) + return hybrid_config + + +@click.command() +@click.option( + "--base_checkpoint", type=str, required=False, default="/mnt/checkpoints/upstream/Apriel-Nemotron-15b-Thinker" +) +@click.option("--m2_indices", type=int, multiple=True, required=True) +@click.option("--hybrid_checkpoint", type=str, required=True) +@click.option("--save_dir", type=str, required=True) +def main(base_checkpoint: str, m2_indices: list, hybrid_checkpoint: str, save_dir: str): + """ + base_checkpoint: path to base transformer-model (teacher model) + m2_indices: indices of layers to convert to mamba layers with MiL init + hybrid_checkpoint: path to hybrid model (student model). + save_dir: directory to save the converted model. + + TODO: base_checkpoint can actually be a hybrid. 
Rename transformer variable to a better name + """ + m2_indices = list(m2_indices) # convert tuple -> list + transformer = AutoModelForCausalLM.from_pretrained(base_checkpoint, trust_remote_code=True) + if hybrid_checkpoint == "none": + print("No hybrid checkpoint provided, creating new config from base model.") + hybrid_config = make_hybrid_config(transformer) + else: + hybrid_config = AprielSSMHybridConfig.from_pretrained(hybrid_checkpoint) + + hybrid_block_layout = hybrid_config.hybrid_block_layout + for m2_index in m2_indices: + hybrid_block_layout[m2_index] = "m2" + print(hybrid_block_layout) + + convert_layers( + transformer.config, + transformer.model, + hybrid_config, + hybrid_block_layout, + init_with_kqvo=True, + torch_dtype=torch.bfloat16, + ) + hybrid_config.ssm_cfg["activation"] = "silu" + + # load all existing ssm layers + if hybrid_checkpoint != "none": + hybrid_model = AprielThinkerSSMHybridForCausalLM.from_pretrained(hybrid_checkpoint) + state_dict = hybrid_model.state_dict() + missing, unexpected = transformer.load_state_dict(state_dict, strict=False) + for m2_index in m2_indices: + assert f"model.layers.{m2_index}.mixer.A_log" in missing + assert f"model.layers.{m2_index}.self_attn.q_proj.weight" in unexpected + print("MISSING", missing) + print("UNEXPECTED", unexpected) + + # Save state-dict + transformer.save_pretrained(save_dir) + + hybrid_config.save_pretrained(save_dir) + + gc.collect() + + +if __name__ == "__main__": + main() diff --git a/fast_llm/models/ssm/external/make_hybrid_checkpoint_with_importance_15b_mil.py b/fast_llm/models/ssm/external/make_hybrid_checkpoint_with_importance_15b_mil.py deleted file mode 100644 index dde11cfbc..000000000 --- a/fast_llm/models/ssm/external/make_hybrid_checkpoint_with_importance_15b_mil.py +++ /dev/null @@ -1,176 +0,0 @@ -import click -import torch -import transformers -from transformers import AutoConfig, AutoModelForCausalLM - -from fast_llm.models.ssm.external.apriel_15b_hybrid.configuration_ssm_hybrid_apriel15b import AprielSSMHybridConfig -from fast_llm.models.ssm.external.apriel_15b_hybrid.modeling_ssm_hybrid_apriel15b import ( - AprielSSMM2DecoderLayer, - AprielThinkerSSMHybridForCausalLM, -) - -device = "cuda" if torch.cuda.is_available() else "cpu" - -print("Transformers version:", transformers.__version__) - - -def convert_layers(transformer, mamba_config, hybrid_block_layout, init_with_kqvo, torch_dtype): - - for layer_idx, type in enumerate(hybrid_block_layout): - # print("Converting layer %d...", layer_idx) - # Fetch the layer module for easier access - layer_module = transformer.model.layers._modules[f"{layer_idx}"] - if type == "t": - print("Skipping transformer layer %d..." % layer_idx) - elif type == "m2": - print("Converting layer %d to Mamba2 with MIL init..." 
% layer_idx) - # Use MambaDecoderLayer for the remaining layers - mamba_encoder = AprielSSMM2DecoderLayer( - mamba_config, - layer_idx, - device="cpu", - dtype=torch_dtype, - ) - - mamba_encoder.mlp.load_state_dict(layer_module.mlp.state_dict()) - mamba_encoder.input_layernorm.load_state_dict(layer_module.input_layernorm.state_dict()) - mamba_encoder.post_attention_layernorm.load_state_dict(layer_module.post_attention_layernorm.state_dict()) - mamba_encoder.mixer.out_proj.load_state_dict(layer_module.self_attn.o_proj.state_dict()) - - if init_with_kqvo: - # Copy weights: [z, x, B, C, dt], x -> v, B -> k, C -> q - mamba_encoder.mixer.in_proj.weight.data[ - mamba_config.ssm_cfg["d_inner"] : mamba_config.ssm_cfg["d_inner"] + mamba_config.ssm_cfg["d_xb"], : - ].copy_(layer_module.self_attn.v_proj.weight.data) - mamba_encoder.mixer.in_proj.weight.data[ - mamba_config.ssm_cfg["d_inner"] - + mamba_config.ssm_cfg["d_xb"] : mamba_config.ssm_cfg["d_inner"] - + 2 * mamba_config.ssm_cfg["d_xb"], - :, - ].copy_(layer_module.self_attn.k_proj.weight.data) - mamba_encoder.mixer.in_proj.weight.data[ - mamba_config.ssm_cfg["d_inner"] - + 2 * mamba_config.ssm_cfg["d_xb"] : 2 * mamba_config.ssm_cfg["d_inner"] - + 2 * mamba_config.ssm_cfg["d_xb"], - :, - ].copy_(layer_module.self_attn.q_proj.weight.data) - - print("Init Mamba using Attention") - - transformer.model.layers[layer_idx] = mamba_encoder - - elif type == "m2d": - raise NotImplementedError("Discrete Mamba2 not implemented") - else: - raise ValueError(f"Invalid layer type: {type}") - - -@click.command() -@click.option("--index_to_swap", type=int, required=True) -@click.option("--checkpoint", type=str, required=True) -@click.option("--output_model_path", type=str, required=True) -@click.option("--layer_type", type=str, default="m2") -@click.option("--mil_init", type=bool, default=True) -def main( - index_to_swap: int, - checkpoint=None, - output_model_path="/mnt/checkpoints/ssm/iterative_hybrids_15b_rkl_m2/apriel_ssm_thinker_15b_hybrid", - layer_type="m2", - mil_init=True, -): - print(f"index_to_swap: {index_to_swap}, checkpoint: {checkpoint}") - - layer_importance = [ - 47, - 39, - 24, - 36, - 31, - 43, - 32, - 20, - 38, - 37, - 30, - 33, - 22, - 23, - 40, - 42, - 44, - 35, - 41, - 27, - 21, - 46, - 45, - 49, - 25, - 34, - 29, - 28, - 19, - 26, - 18, - 17, - 16, - 13, - 15, - 14, - 8, - 9, - 12, - 6, - 11, - 5, - 48, - 7, - 10, - 3, - 4, - 1, - 0, - ] - path_base = "/mnt/checkpoints/upstream/Apriel-Nemotron-15b-Thinker" - config_base = AutoConfig.from_pretrained(path_base) - hybrid_block_layout = ["t"] * config_base.num_hidden_layers - - for i in range(index_to_swap + 1): - layer_idx = int(layer_importance[i]) - print(f"Swapping layer {layer_idx} to {layer_type}") - hybrid_block_layout[layer_idx] = layer_type - - transformer = AutoModelForCausalLM.from_pretrained(path_base) - model_hybrid_prev = AprielThinkerSSMHybridForCausalLM.from_pretrained(checkpoint, trust_remote_code=True).to( - torch.bfloat16 - ) - config_hybrid = AprielSSMHybridConfig(**model_hybrid_prev.config.to_dict()) - config_hybrid.hybrid_block_layout = hybrid_block_layout - convert_layers(transformer, config_hybrid, hybrid_block_layout, mil_init, torch.bfloat16) - - missing, unexpected = transformer.load_state_dict( - model_hybrid_prev.state_dict(), strict=False - ) # will not load the newly innitialized layer (will stay MIL), but will overwrite previous layers - if missing: - print("Missing keys:", missing) - if unexpected: - print("Unexpected keys:", unexpected) - 
transformer.to(torch.bfloat16) - model_hybrid_prev = None - print(transformer) - model_hybrid = AprielThinkerSSMHybridForCausalLM(config_hybrid) - missing, unexpected = model_hybrid.load_state_dict(transformer.state_dict()) - assert len(missing) == 0, "Missing keys: " + str(missing) - assert len(unexpected) == 0, "Unexpected keys: " + str(unexpected) - - model_hybrid.save_pretrained(f"{output_model_path}") - # config_hybrid.save_pretrained(f"{output_model_path}") - - -if __name__ == "__main__": - main() - # main( - # index_to_swap=1, - # checkpoint="/mnt/checkpoints/fast_llm_exp/slam_ssm_distill/15b-ihyb1lrklm216mil-bs768-lr0.0003-lrs0-0-0-0-sl4096_ti1000_lm2/export/apriel_ssm_thinker_hybrid/1000", - # layer_type="m2", - # ) diff --git a/fast_llm/models/ssm/external/make_hybrid_checkpoint_with_mil.py b/fast_llm/models/ssm/external/make_hybrid_checkpoint_with_mil.py new file mode 100644 index 000000000..6ce283525 --- /dev/null +++ b/fast_llm/models/ssm/external/make_hybrid_checkpoint_with_mil.py @@ -0,0 +1,107 @@ +import gc + +import click +import torch +from transformers import AutoModelForCausalLM + +from fast_llm.models.ssm.external.apriel_15b_hybrid.configuration_ssm_hybrid_apriel15b import AprielSSMHybridConfig +from fast_llm.models.ssm.external.apriel_15b_hybrid.modeling_ssm_hybrid_apriel15b import ( + AprielSSMM2DecoderLayer, + AprielThinkerSSMHybridForCausalLM, +) + +device = "cuda" if torch.cuda.is_available() else "cpu" + + +def convert_layers(transformer, mamba_config, hybrid_block_layout, init_with_kqvo, torch_dtype=torch.bfloat16): + config = transformer.config + embed_dim = config.hidden_size + num_heads = config.num_attention_heads + num_heads_kv = config.num_key_value_heads + head_dim = embed_dim // num_heads + head_dim * num_heads + head_dim * num_heads_kv + + for layer_idx, type in enumerate(hybrid_block_layout): + print("Converting layer %d...", layer_idx) + # Fetch the layer module for easier access + layer_module = transformer.model.layers._modules[f"{layer_idx}"] + if type == "t": + print("Skipping transformer layer %d..." % layer_idx) + elif type == "m2": + print("Converting layer %d..." 
% layer_idx) + # Use MambaDecoderLayer for the remaining layers + mamba_encoder = AprielSSMM2DecoderLayer( + mamba_config, + layer_idx, + device="cpu", + dtype=torch_dtype, + ) + + mamba_encoder.mlp.load_state_dict(layer_module.mlp.state_dict()) + mamba_encoder.input_layernorm.load_state_dict(layer_module.input_layernorm.state_dict()) + mamba_encoder.post_attention_layernorm.load_state_dict(layer_module.post_attention_layernorm.state_dict()) + mamba_encoder.mixer.out_proj.load_state_dict(layer_module.self_attn.o_proj.state_dict()) + + if init_with_kqvo: + # Copy weights: [z, x, B, C, dt], x -> v, B -> k, C -> q + mamba_encoder.mixer.in_proj.weight.data[ + mamba_config.ssm_cfg["d_inner"] : mamba_config.ssm_cfg["d_inner"] + mamba_config.ssm_cfg["d_xb"], : + ].copy_(layer_module.self_attn.v_proj.weight.data) + mamba_encoder.mixer.in_proj.weight.data[ + mamba_config.ssm_cfg["d_inner"] + + mamba_config.ssm_cfg["d_xb"] : mamba_config.ssm_cfg["d_inner"] + + 2 * mamba_config.ssm_cfg["d_xb"], + :, + ].copy_(layer_module.self_attn.k_proj.weight.data) + mamba_encoder.mixer.in_proj.weight.data[ + mamba_config.ssm_cfg["d_inner"] + + 2 * mamba_config.ssm_cfg["d_xb"] : 2 * mamba_config.ssm_cfg["d_inner"] + + 2 * mamba_config.ssm_cfg["d_xb"], + :, + ].copy_(layer_module.self_attn.q_proj.weight.data) + + print("Init Mamba using Attention") + + transformer.model.layers[layer_idx] = mamba_encoder + + else: + raise ValueError(f"Invalid layer type: {type}") + + +@click.command() +@click.option("--m2_indexes", type=int, nargs="-1", required=True) +@click.option("--hybrid_checkpoint", type=str, required=True) +@click.option("--save_dir", type=str, required=True) +def main(m2_indexes: list, hybrid_checkpoint: str, save_dir: str): + m2_indexes = list(m2_indexes) # convert tuple -> list + path_base = "/mnt/checkpoints/upstream/Apriel-Nemotron-15b-Thinker" + transformer = AutoModelForCausalLM.from_pretrained(path_base, trust_remote_code=True) + hybrid_config = AprielSSMHybridConfig.from_pretrained(hybrid_checkpoint) + + hybrid_block_layout = hybrid_config.hybrid_block_layout + for m2_index in m2_indexes: + hybrid_block_layout[m2_index] = "m2" + print(hybrid_block_layout) + + convert_layers(transformer, hybrid_config, hybrid_block_layout, True, torch.bfloat16) + hybrid_config.ssm_cfg["activation"] = "silu" + + # load all existing ssm layers + hybrid_model = AprielThinkerSSMHybridForCausalLM.from_pretrained(hybrid_checkpoint) + state_dict = hybrid_model.state_dict() + missing, unexpected = transformer.load_state_dict(state_dict, strict=False) + for m2_index in m2_indexes: + assert f"model.layers.{m2_index}.mixer.A_log" in missing + assert f"model.layers.{m2_index}.self_attn.q_proj.weight" in unexpected + print(missing) + print(unexpected) + transformer.save_pretrained(save_dir) + + hybrid_config.save_pretrained(save_dir) + + gc.collect() + + +if __name__ == "__main__": + main() diff --git a/fast_llm/models/ssm/external/make_llava_hybrid_checkpoint.py b/fast_llm/models/ssm/external/make_llava_hybrid_checkpoint.py new file mode 100644 index 000000000..1f9808f1b --- /dev/null +++ b/fast_llm/models/ssm/external/make_llava_hybrid_checkpoint.py @@ -0,0 +1,153 @@ +import gc +import json +import shutil + +import click +import torch +from transformers import AutoModelForVision2Seq + +from fast_llm.models.ssm.external.apriel_15b_hybrid import modeling_ssm_hybrid_apriel15b +from fast_llm.models.ssm.external.llava_hybrid import configuration_llava_hybrid, modeling_llava_hybrid +from 
fast_llm.models.ssm.external.llava_hybrid.configuration_llava_hybrid import LlavaHybridConfig +from fast_llm.models.ssm.external.llava_hybrid.modeling_llava_hybrid import LlavaHybridForConditionalGeneration +from fast_llm.models.ssm.external.make_hybrid_checkpoint import convert_layers + +device = "cuda" if torch.cuda.is_available() else "cpu" + +dstate = 16 +expand = 1 +# Calculate derived dimensions for the Mamba1 configuration +# d_model = config_base.text_config.hidden_size +d_inner = 4096 # hard code to match thinker #expand * d_model +d_xb = 1024 # hard code to match thinker #config_thinker.num_key_value_heads * (config_thinker.hidden_size // config_thinker.num_attention_heads) + + +def make_hybrid_llava_config(transformer): + config_dict = transformer.config.to_dict() + config_dict["text_config"]["hybrid_block_layout"] = ["t"] * transformer.config.text_config.num_hidden_layers + config_dict["text_config"]["model_type"] = "apriel_ssm_thinker_hybrid" + config_dict["text_config"]["ssm_cfg"] = { + "activation": "silu", + "d_state": dstate, + "d_xb": d_xb, + # "d_model": d_model, # will be set automatically + "expand": expand, + "d_conv": 4, + "d_inner": d_inner, # will be same as d_model * expand, + "conv_bias": True, + "bias": False, + } + llava_hybrid_config = LlavaHybridConfig(**config_dict) + return llava_hybrid_config + + +def make_hybrid_llava_model(transformer, llava_hybrid_config): + """ + Create a LlavaHybridForConditionalGeneration model with the same configuration as the given transformer model. + """ + llava_hybrid_model = LlavaHybridForConditionalGeneration(llava_hybrid_config) + # llava_hybrid_model.to(dtype=torch.bfloat16).to(device) + llava_hybrid_model.load_state_dict(transformer.state_dict(), strict=False) + return llava_hybrid_model + + +@click.command() +@click.option("--base_checkpoint", type=str, required=False, default="ServiceNow-AI/Apriel-Nemotron-15b-Thinker") +@click.option("--m2_indices", type=int, multiple=True, required=True) +@click.option("--hybrid_checkpoint", type=str, required=True) +@click.option("--save_dir", type=str, required=True) +@click.option( + "--tokenizer_dir", type=str, required=False, default="/mnt/plato/checkpoints/upstream/Mistral-Nemo-Base-2407/" +) +def main(base_checkpoint: str, m2_indices: list[int], hybrid_checkpoint: str, save_dir: str, tokenizer_dir: str): + """ + base_checkpoint: path to base transformer-model (teacher model) + m2_indices: indices of layers to convert to mamba layers with MiL init + hybrid_checkpoint: path to hybrid model (student model). Can be a hybrid with only transformer layers for the first distillation run. + save_dir: directory to save the converted model. + tokenizer_dir: directory containing tokenizer files to copy over to save_dir. 
+ """ + m2_indices = list(m2_indices) # convert tuple -> list + transformer = AutoModelForVision2Seq.from_pretrained(base_checkpoint, trust_remote_code=True) + if hybrid_checkpoint == "none": + print("No hybrid checkpoint provided, creating new config from base model.") + hybrid_config = make_hybrid_llava_config(transformer) + else: + hybrid_config = LlavaHybridConfig.from_pretrained(hybrid_checkpoint) + + hybrid_block_layout = hybrid_config.text_config.hybrid_block_layout + for m2_index in m2_indices: + hybrid_block_layout[m2_index] = "m2" + print(hybrid_block_layout) + + # MiL init + convert_layers( + transformer.model.language_model.config, + transformer.model.language_model, + hybrid_config.text_config, + hybrid_block_layout, + init_with_kqvo=True, + torch_dtype=torch.bfloat16, + ) + hybrid_config.text_config.ssm_cfg["activation"] = "silu" + + # Load existing SSM layers + if hybrid_checkpoint != "none": + hybrid_llava_model = AutoModelForVision2Seq.from_pretrained( + hybrid_checkpoint, torch_dtype=torch.bfloat16, trust_remote_code=True + ) + llava_state_dict = hybrid_llava_model.state_dict() + missing, unexpected = transformer.load_state_dict(llava_state_dict, strict=False) + for m2_index in m2_indices: + assert f"model.layers.{m2_index}.mixer.A_log" in missing + assert f"model.layers.{m2_index}.self_attn.q_proj.weight" in unexpected + print("MISSING", missing) + print("UNEXPECTED", unexpected) + + # Save state-dict + transformer.save_pretrained(save_dir) + # Save new config + hybrid_config.save_pretrained(save_dir) + + # Copy modeling and tokenizer files + modeling_files = [ + configuration_llava_hybrid.__file__, + modeling_llava_hybrid.__file__, + modeling_ssm_hybrid_apriel15b.__file__, + ] + tokenizer_files = [ + f"{tokenizer_dir}/tokenizer.json", + f"{tokenizer_dir}/tokenizer_config.json", + f"{tokenizer_dir}/generation_config.json", + f"{tokenizer_dir}/special_tokens_map.json", + ] + for f in modeling_files + tokenizer_files: + shutil.copy(f, save_dir) + + # Update config with auto_maps + config_file = f"{save_dir}/config.json" + with open(config_file) as f: + dumped_config = json.load(f) + + dumped_config["auto_map"] = { + "AutoConfig": "configuration_llava_hybrid.LlavaHybridConfig", + "AutoModel": "modeling_llava_hybrid.LlavaHybridModel", + "AutoModelForVision2Seq": "modeling_llava_hybrid.LlavaHybridForConditionalGeneration", + "AutoModelForCausalLM": "modeling_llava_hybrid.LlavaHybridForConditionalGeneration", + } + dumped_config["text_config"]["auto_map"] = { + "AutoConfig": "configuration_ssm_hybrid_apriel15b.AprielSSMHybridConfig", + "AutoModel": "modeling_ssm_hybrid_apriel15b.AprielThinkerSSMHybridModel", + "AutoModelForCausalLM": "modeling_ssm_hybrid_apriel15b.AprielThinkerSSMHybridForCausalLM", + } + dumped_config["architectures"] = ["LlavaHybridForConditionalGeneration"] + dumped_config["text_config"]["architectures"] = ["AprielThinkerSSMHybridForCausalLM"] + with open(config_file, "w") as f: + json.dump(dumped_config, f, indent=2) + + torch.cuda.empty_cache() + gc.collect() + + +if __name__ == "__main__": + main() diff --git a/fast_llm/models/ssm/huggingface.py b/fast_llm/models/ssm/huggingface.py index 77cd346f7..1ece10edf 100644 --- a/fast_llm/models/ssm/huggingface.py +++ b/fast_llm/models/ssm/huggingface.py @@ -1,9 +1,10 @@ import logging +import typing -from fast_llm.engine.huggingface.config import HuggingfaceModelConfig +from fast_llm.engine.inference.config import HuggingfaceModelConfig from fast_llm.models.gpt.huggingface import HuggingfaceGPTModelForCausalLM 
from fast_llm.models.ssm.config import HybridSSMModelConfig -from fast_llm.models.ssm.model import HybridSSMModel +from fast_llm.models.ssm.model import HybridSSMInferenceRunner, HybridSSMModel logger = logging.getLogger(__name__) @@ -17,5 +18,7 @@ class HuggingfaceSSMModelConfig(HuggingfaceModelConfig): class HuggingfaceHybridSSMModelForCausalLM(HuggingfaceGPTModelForCausalLM): config_class = HuggingfaceSSMModelConfig config: HuggingfaceSSMModelConfig + runner_class: typing.ClassVar[type[HybridSSMInferenceRunner]] = HybridSSMInferenceRunner model_class = HybridSSMModel + runner_class: typing.ClassVar[type[HybridSSMInferenceRunner]] = HybridSSMInferenceRunner _fast_llm_model: HybridSSMModel diff --git a/fast_llm/models/ssm/model.py b/fast_llm/models/ssm/model.py index 02a5ac239..fafe44090 100644 --- a/fast_llm/models/ssm/model.py +++ b/fast_llm/models/ssm/model.py @@ -3,13 +3,12 @@ from fast_llm.engine.base_model.base_model import Layer from fast_llm.engine.distributed.config import DistributedConfig -from fast_llm.layers.language_model.embedding import LanguageModelEmbedding +from fast_llm.engine.inference.runner import InferenceRunner from fast_llm.layers.language_model.head import LanguageModelHead -from fast_llm.layers.ssm.discrete_mamba2 import DiscreteMamba2 -from fast_llm.layers.ssm.llamba_block import LlambaBlock -from fast_llm.layers.ssm.mamba2 import Mamba2 -from fast_llm.layers.ssm.mamba_layer import MambaLayer -from fast_llm.layers.transformer.transformer import TransformerLayer +from fast_llm.layers.ssm.llamba_block import SSMBlock +from fast_llm.layers.ssm.preprocessing import Mamba2Preprocessor +from fast_llm.layers.transformer.transformer import TransformerBlock +from fast_llm.models.gpt.config import GPTBatchConfig from fast_llm.models.gpt.model import GPTBaseModel, GPTModel from fast_llm.models.ssm.config import HybridSSMBaseModelConfig, HybridSSMModelConfig, SSMBlockType @@ -31,60 +30,39 @@ def __init__( config: HybridSSMBaseModelConfig, distributed_config: DistributedConfig, ): - self.SSM_BLOCK_CLS = LlambaBlock # TODO: extend to other block types if needed super().__init__(config, distributed_config) + self._preprocessors.append(Mamba2Preprocessor(config, self._tensor_space)) def get_output_layers(self) -> list[Layer]: """ Get the output layers of the model. This includes the language model head and any additional heads specified in the configuration. 
""" - layers = [LanguageModelHead(self._config, self._tensor_space, prediction_distance=0)] + layers: list[Layer] = [LanguageModelHead(self._config, self._tensor_space, prediction_distance=0)] if self._config.prediction_heads > 1: block_type = self._config.default_mtp_type or self._config.hybrid_block_layout[-1] for i in range(1, self._config.prediction_heads): if block_type == SSMBlockType.transformer: layers.append( - TransformerLayer( + TransformerBlock( self._config.transformer, self._tensor_space, - layer_index=len(self._config.hybrid_block_layout), + block_index=len(self._config.hybrid_block_layout), return_input=i != self._config.prediction_heads - 1, ) ) - elif block_type == SSMBlockType.mamba2_discrete: - mamba_block = self.SSM_BLOCK_CLS( - config_transformer=self._config.transformer, - config_ssm=self._config.ssm, - mixer_cls=DiscreteMamba2, - layer_index=len(self._config.hybrid_block_layout), - tensor_space=self._tensor_space, - return_input=i != self._config.prediction_heads - 1, - ) - layers.append(mamba_block) - elif block_type == SSMBlockType.mamba: - mamba_block = self.SSM_BLOCK_CLS( - config_transformer=self._config.transformer, - config_ssm=self._config.ssm, - mixer_cls=MambaLayer, - layer_index=len(self._config.hybrid_block_layout), - tensor_space=self._tensor_space, - return_input=i != self._config.prediction_heads - 1, - ) - layers.append(mamba_block) - elif block_type == SSMBlockType.mamba2: - mamba_block = self.SSM_BLOCK_CLS( - config_transformer=self._config.transformer, - config_ssm=self._config.ssm, - mixer_cls=Mamba2, - layer_index=len(self._config.hybrid_block_layout), - tensor_space=self._tensor_space, - return_input=i != self._config.prediction_heads - 1, - ) - layers.append(mamba_block) else: - raise ValueError(f"Invalid block type: {block_type}. Must be {SSMBlockType.__members__}") + layers.append( + SSMBlock( + transformer_config=self._config.transformer, + ssm_config=self._config.ssm, + mixer_cls=self._config.ssm_block_type.get_mixer_class(), + block_index=len(self._config.hybrid_block_layout), + tensor_space=self._tensor_space, + return_input=i != self._config.prediction_heads - 1, + ) + ) layers.append(LanguageModelHead(self._config, self._tensor_space, prediction_distance=i)) return layers @@ -94,63 +72,35 @@ def get_layers(self) -> list[Layer]: Create a list of layers for the model, interleaving Transformer and Mamba blocks according to the block pattern. 
""" - layers = [LanguageModelEmbedding(self._config, self._tensor_space)] + layers: list[Layer] = self.get_embedding_layers() # Create blocks according to pattern for i, block_type in enumerate(self._config.hybrid_block_layout): if block_type == SSMBlockType.transformer: # Transformer block layers.append( - TransformerLayer( + TransformerBlock( self._config.transformer, self._tensor_space, - layer_index=i + 1, + block_index=i + 1, return_input=( i == len(self._config.hybrid_block_layout) - 1 and self._config.prediction_heads > 1 ), ) ) - elif block_type == SSMBlockType.mamba2_discrete: - mamba_block = self.SSM_BLOCK_CLS( - config_transformer=self._config.transformer, - config_ssm=self._config.ssm, - mixer_cls=DiscreteMamba2, - layer_index=i + 1, - tensor_space=self._tensor_space, - return_input=( - i == len(self._config.hybrid_block_layout) - 1 and self._config.prediction_heads > 1 - ), - ) - layers.append(mamba_block) - - elif block_type == SSMBlockType.mamba: - # Create Mamba block - mamba_block = self.SSM_BLOCK_CLS( - config_transformer=self._config.transformer, - config_ssm=self._config.ssm, - mixer_cls=MambaLayer, - layer_index=i + 1, - tensor_space=self._tensor_space, - return_input=( - i == len(self._config.hybrid_block_layout) - 1 and self._config.prediction_heads > 1 - ), - ) - layers.append(mamba_block) - - elif block_type == SSMBlockType.mamba2: - mamba_block = self.SSM_BLOCK_CLS( - config_transformer=self._config.transformer, - config_ssm=self._config.ssm, - mixer_cls=Mamba2, - layer_index=i + 1, - tensor_space=self._tensor_space, - return_input=( - i == len(self._config.hybrid_block_layout) - 1 and self._config.prediction_heads > 1 - ), - ) - layers.append(mamba_block) else: - raise ValueError(f"Invalid block type: {block_type}. Must be {SSMBlockType.__members__}") + layers.append( + SSMBlock( + transformer_config=self._config.transformer, + ssm_config=self._config.ssm, + mixer_cls=self._config.ssm_block_type.get_mixer_class(), + block_index=i + 1, + tensor_space=self._tensor_space, + return_input=( + i == len(self._config.hybrid_block_layout) - 1 and self._config.prediction_heads > 1 + ), + ) + ) # Add the output layers layers += self.get_output_layers() @@ -165,3 +115,8 @@ class HybridSSMModel[ConfigType: HybridSSMModelConfig](GPTModel[ConfigType]): config_class: typing.ClassVar[type[HybridSSMModelConfig]] = HybridSSMModelConfig base_model_class: typing.ClassVar[type[HybridSSMBaseModel]] = HybridSSMBaseModel + + +class HybridSSMInferenceRunner(InferenceRunner): + model_class: typing.ClassVar[type[HybridSSMModel]] = HybridSSMModel + batch_config_class: typing.ClassVar[type[GPTBatchConfig]] = GPTBatchConfig diff --git a/fast_llm/tensor.py b/fast_llm/tensor.py index d780e4d6d..d080e6a1e 100644 --- a/fast_llm/tensor.py +++ b/fast_llm/tensor.py @@ -1,17 +1,21 @@ +import abc import functools +import logging import math import typing import torch from fast_llm.core.distributed import ReduceOp -from fast_llm.core.ops import gather_op, reduce_op +from fast_llm.core.ops import reduce_op from fast_llm.engine.config_utils.tensor_space import TensorDim, TensorSpace from fast_llm.engine.distributed.config import DistributedDim, DistributedDimNames from fast_llm.engine.distributed.distributed import Distributed from fast_llm.functional.triton.pointwise import triton_add, triton_copy from fast_llm.utils import Assert +logger = logging.getLogger(__name__) + class _SafeTensorSliceMeta(type): def __instancecheck__(self, instance) -> bool: @@ -146,7 +150,7 @@ def from_tensor_space( reductions: 
tuple[tuple[str, ReduceOp], ...] = (), **kwargs: typing.Any, ) -> typing.Self: - dims = tuple(tensor_space.get_tensor_dim(dim_name) for dim_name in dim_names) + dims = tuple(tensor_space[dim_name] for dim_name in dim_names) if reductions: # kwarg not available for ParameterMeta, so we only provide if necessary. kwargs["reductions"] = tuple( @@ -158,22 +162,23 @@ def from_tensor_space( def global_shape(self) -> torch.Size: return torch.Size([dim.global_size for dim in self.dims]) - def local_to_global( - self, - tensor: torch.Tensor, - *, - distributed: Distributed, - ) -> tuple[torch.Tensor, ...]: + def local_to_global(self, tensor: torch.Tensor, *, distributed: Distributed) -> tuple[torch.Tensor, ...]: + """ + Reconstruct a global tensor from its distributed slices. Support lazy-loaded safetensor slices. + Returns a view of the input tensor (or the input tensor itself) when possible. + """ + if tensor.ndim == 0: + tensor = tensor[None] + Assert.eq(tensor.shape, self.shape) # Tensors are always either split or duplicated in the tensor-parallel direction. # TODO: Avoid hard-coded assumptions on duplication - is_first_rank = distributed.config.tensor_rank == 0 - modified = False - for i, dim in enumerate(self.dims): - if dim.parallel_group is not None: - tensor = gather_op( - tensor.unflatten(i, dim.expanded_shape), dim.parallel_group, i + dim.parallel_dim_index - ).flatten(i, i + len(dim.expanded_shape) - 1) - is_first_rank, modified = is_first_rank and dim.parallel_group.rank() == 0, True + is_first_rank, modified = distributed.config.tensor_rank == 0, False + + for dim, tensor_dim in enumerate(self.dims): + if tensor_dim.is_parallel: + tensor = tensor_dim.local_to_global(tensor, dim) + is_first_rank &= tensor_dim.parallel_dim.rank == 0 + modified = True for distributed_dim, op in self._reductions: if distributed_dim.group is not None: @@ -182,28 +187,44 @@ def local_to_global( tensor = tensor.clone() tensor = reduce_op(tensor, distributed_dim.group, op=op) is_first_rank, modified = is_first_rank and distributed_dim.group.rank() == 0, True + Assert.eq(tensor.shape, self.global_shape) return tensor, is_first_rank - def global_to_local( - self, - tensor: torch.Tensor | SafeTensorSlice, - # Return an expanded tensor, avoiding `flatten` which copies the data. - expand: bool = False, - ) -> torch.Tensor: + def local_to_global_partial(self, tensor: torch.Tensor, fill_value: float | int = -1) -> torch.Tensor: """ - Recover the tensor-parallel slice of a tensor. Support lazy-loaded safetensor slices. + Construct a tensor of shape `self.global_shape` that contains its local slice at the appropriate location, + i.e. for which `self.global_to_local(self.local_to_global_partial(tensor)) == tensor`. + Other entries are filled with `fill_value`. + Returns a view of the input tensor (or the input tensor itself) when possible. + """ + if tensor.ndim == 0: + tensor = tensor[None] + Assert.eq(tensor.shape, self.shape) + assert not self._reductions + for dim, tensor_dim in enumerate(self.dims): + if tensor_dim.is_parallel: + tensor = tensor_dim.local_to_global_partial(tensor, dim, fill_value) + + Assert.eq(tensor.shape, self.global_shape) + return tensor + + def global_to_local(self, tensor: torch.Tensor | SafeTensorSlice) -> torch.Tensor: + """ + Select the local slice of a global tensor. Support lazy-loaded safetensor slices. + Returns a view of the input tensor (or the input tensor itself) when possible. """ # Take a trivial slice to convert safetensor slices. 
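# Illustrative sketch (annotation, hypothetical tensor sizes): the round-trip invariant
# documented in local_to_global_partial above, for a parameter meta `meta` split along one
# tensor-parallel dimension.
#
#   local = meta.global_to_local(global_tensor)        # this rank's slice, shape == meta.shape
#   partial = meta.local_to_global_partial(local)      # shape == meta.global_shape, other ranks' slots filled with -1
#   assert torch.equal(meta.global_to_local(partial), local)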
- tensor_ = tensor[:] + tensor = tensor[:] assert not self._reductions + if tensor.ndim == 0: + tensor = tensor[None] + Assert.eq(tensor.shape, self.global_shape) - for i, dim in reversed(list(enumerate(self.dims))): - if dim.parallel_dim is not None and dim.parallel_dim.size > 1: - tensor_ = tensor_.unflatten(i, dim.global_expanded_shape).chunk( - dim.parallel_dim.size, i + dim.parallel_dim_index - )[dim.parallel_dim.rank] + for dim, tensor_dim in reversed(list(enumerate(self.dims))): + tensor = tensor_dim.global_to_local(tensor, dim) - return tensor_ if expand else tensor_.reshape(self.shape) + Assert.eq(tensor.shape, self.shape) + return tensor @classmethod def __torch_function__(cls, func, types, args=(), kwargs=None): @@ -237,7 +258,7 @@ def __init__( *, tensor_name: str = "", dims: tuple[TensorDim, ...], - init_method: typing.Callable[["ParameterMeta", torch.Tensor, torch.Generator], torch.Tensor] | None = None, + init_method: "Initializer | typing.Callable[[ParameterMeta, torch.Tensor, torch.Generator], None] | None" = None, weight_decay: bool = True, # Pass a list to split the parameter in contiguous (dim=0) chunks of equal size for optimization. lr_scale: float | None | tuple[float | None, ...] = None, @@ -247,7 +268,11 @@ def __init__( allow_no_grad: bool = False, ): super().__init__(data, tensor_name=tensor_name, dims=dims) - self.param_init_method = init_method + if init_method is not None and not isinstance(init_method, Initializer): + # Support non-wrapped callables for convenience. + assert callable(init_method) + init_method = LambdaInitializer(init_method) + self.param_init_method: Initializer | None = init_method self.param_weight_decay = weight_decay self._is_param = True self.param_grad_is_zero = False @@ -272,7 +297,7 @@ def __new__( *, tensor_name: str = "", dims: tuple[TensorDim, ...], - init_method: typing.Callable, + init_method: "Initializer | typing.Callable[[ParameterMeta, torch.Tensor, torch.Generator], None] | None", weight_decay: bool = True, lr_scale: float | None | tuple[float | None, ...] 
= None, allow_sequence_tensor_parallel: bool = True, @@ -293,12 +318,20 @@ def __repr__(self, *, tensor_contents=()) -> str: def init_parameter(self, tensor: torch.Tensor, distributed: Distributed) -> None: assert self.param_init_method is not None - if distributed.config.tensor_parallel == 1 or distributed.config.reproducible_init: + if ( + distributed.config.tensor_parallel == 1 + or distributed.config.reproducible_init + or self.param_init_method.requires_global_initialization + ): generator = distributed.pp_init_generator else: generator = distributed.tp_init_generator if self.is_tensor_parallel else distributed.pp_init_generator self.param_init_method(self, tensor, generator) + @property + def requires_global_initialization(self) -> bool: + return self.param_init_method.requires_global_initialization + def save(self) -> dict[str, typing.Any]: return { "name": self.tensor_name, @@ -330,11 +363,32 @@ def accumulate_gradient(param: torch.Tensor, grad: torch.Tensor) -> None: triton_add(grad, param.grad_buffer, out=param.grad_buffer) # noqa -def init_fill_(value) -> typing.Callable[[ParameterMeta, torch.Tensor, torch.Generator], torch.Tensor]: - def init_(meta: ParameterMeta, tensor: torch.Tensor, generator: torch.Generator): # noqa - return tensor.fill_(value) +class Initializer(abc.ABC): + @abc.abstractmethod + def __call__(self, meta: ParameterMeta, tensor: torch.Tensor, generator: torch.Generator) -> None: + pass + + requires_global_initialization = False + + +class LambdaInitializer(Initializer): + def __init__( + self, + init_method: typing.Callable[[ParameterMeta, torch.Tensor, torch.Generator], None], + requires_global_initialization: bool = False, + ) -> None: + self._init_method = init_method + self.requires_global_initialization = requires_global_initialization + + def __call__(self, meta: ParameterMeta, tensor: torch.Tensor, generator: torch.Generator) -> None: + return self._init_method(meta, tensor, generator) - return init_ + +def init_fill_(value: float) -> LambdaInitializer: + def init_(meta: ParameterMeta, tensor: torch.Tensor, generator: torch.Generator) -> None: # noqa + tensor.fill_(value) + + return LambdaInitializer(init_) init_zeros_ = init_fill_(0.0) @@ -342,30 +396,35 @@ def init_(meta: ParameterMeta, tensor: torch.Tensor, generator: torch.Generator) def init_normal_( - mean=0.0, std=1.0, min_val=None, max_val=None -) -> typing.Callable[[ParameterMeta, torch.Tensor, torch.Generator], torch.Tensor]: - def init_(meta: ParameterMeta, tensor: torch.Tensor, generator: torch.Generator): # noqa + mean: float = 0.0, std: float = 1.0, min_val: float | None = None, max_val: float | None = None +) -> LambdaInitializer: + def init_(meta: ParameterMeta, tensor: torch.Tensor, generator: torch.Generator) -> None: # noqa tensor = tensor.normal_(mean, std, generator=generator) if min_val is not None or max_val is not None: - return tensor.clamp_(min=min_val, max=max_val) # noqa - else: - return tensor + tensor.clamp_(min=min_val, max=max_val) - return init_ + return LambdaInitializer(init_) -def kaiming_init_(d_in): +def init_kaiming_(d_in: float) -> LambdaInitializer: return init_normal_(0.0, math.sqrt(2.0 / d_in)) def init_uniform_( - low=0.0, high=1.0, min_val=None, max_val=None -) -> typing.Callable[[ParameterMeta, torch.Tensor, torch.Generator], torch.Tensor]: - def init_(meta: ParameterMeta, tensor: torch.Tensor, generator: torch.Generator): # noqa + low: float = 0.0, high: float = 1.0, min_val: float | None = None, max_val: float | None = None +) -> LambdaInitializer: + def 
init_(meta: ParameterMeta, tensor: torch.Tensor, generator: torch.Generator) -> None: # noqa tensor = tensor.uniform_(low, high, generator=generator) if min_val is not None or max_val is not None: - return tensor.clamp_(min=min_val, max=max_val) # noqa - else: - return tensor + tensor.clamp_(min=min_val, max=max_val) + + return LambdaInitializer(init_) + - return init_ +def init_uniform_centered_(high: float, max_val: float | None = None, mean: float = 0.0) -> LambdaInitializer: + return init_uniform_( + mean - high, + mean + high, + min_val=None if max_val is None else mean - max_val, + max_val=None if max_val is None else mean + max_val, + ) diff --git a/setup.cfg b/setup.cfg index 843aa15ca..c2eb1f6f2 100644 --- a/setup.cfg +++ b/setup.cfg @@ -41,7 +41,7 @@ OPTIONAL = # Huggingface tools HUGGINGFACE = - transformers>=4.52.4 + transformers==4.53.2 hf-transfer>=0.1.9 datasets>=3.6.0 huggingface-hub>=0.32.6 @@ -50,13 +50,20 @@ HUGGINGFACE = # To install on cpu environment (ex. for IDE support): # MAMBA_FORCE_BUILD=TRUE CAUSAL_CONV1D_FORCE_BUILD=TRUE CAUSAL_CONV1D_SKIP_CUDA_BUILD=TRUE pip install -e ".[CORE,SSM]" --no-build-isolation SSM = - mamba_ssm[causal-conv1d]==2.2.4 + mamba_ssm[causal-conv1d] @ git+https://github.com/jxiw/varlen_mamba.git@varlen_mamba cartesia_pytorch>=0.0.2 -GENERATION = - lm_eval>=0.4.9 +# GENERATION = +# lm_eval>=0.4.9 +# Required for supporting vision inputs +VISION = + # Vision Tools + webp>=0.4.0 + pillow-simd>=9.5.0 + torchvision>=0.20.0 + DEV = # Pre-commit git hook pre-commit>=4.2.0 diff --git a/tests/data/common.py b/tests/data/common.py index 2bb90a6b4..23ed9d76b 100644 --- a/tests/data/common.py +++ b/tests/data/common.py @@ -127,10 +127,10 @@ def compare_indexed_dataset( loss_masking_spans: dict[int, list[int]] | None = None, ) -> None: Assert.eq(len(dataset), length) - sizes = dataset.get_document_sizes() + text_sizes, image_sizes = dataset.get_document_sizes() # Assert.eq(sizes.sum(), num_tokens) Assert.all_equal( - [len(dataset.get(i).token_ids) for i in range(min(len(dataset), 100))], sizes[: min(len(dataset), 100)] + [len(dataset.get(i).token_ids) for i in range(min(len(dataset), 100))], text_sizes[: min(len(dataset), 100)] ) for i, expected_sample in expected_samples.items(): Assert.all_equal(dataset.get(i).token_ids, np.array(expected_sample, dtype=np.uint16)) @@ -224,10 +224,15 @@ def __len__(self) -> int: return self._config.num_documents def get_document_sizes(self) -> np.ndarray: - return np.full(self._config.num_documents, self._config.num_tokens_per_document, dtype=np.int64) + return np.full(self._config.num_documents, self._config.num_tokens_per_document, dtype=np.int64), np.array( + [], dtype=np.int64 + ) def get_document_size(self, index: int) -> int: return self._config.num_tokens_per_document def get(self, index: int, *args, **kwargs) -> typing.Any: raise NotImplementedError() + + def has_images(self) -> bool: + return False diff --git a/tests/data/test_blending.py b/tests/data/test_blending.py index 438782dfe..3e6c37632 100644 --- a/tests/data/test_blending.py +++ b/tests/data/test_blending.py @@ -193,6 +193,7 @@ def test_gpt_blended_mixed(): def test_gpt_blended_mixed_data(): + get_test_dataset() get_test_data_and_compare_samples( { "datasets": { diff --git a/tests/data/test_concatenate.py b/tests/data/test_concatenate.py index e951cc2b1..4f36cdf89 100644 --- a/tests/data/test_concatenate.py +++ b/tests/data/test_concatenate.py @@ -39,6 +39,7 @@ def test_gpt_concatenate(): def test_gpt_concatenate_data(): + get_test_dataset() 
get_test_data_and_compare_samples( { "datasets": { diff --git a/tests/data/test_fim.py b/tests/data/test_fim.py index 7472f1958..004b96289 100644 --- a/tests/data/test_fim.py +++ b/tests/data/test_fim.py @@ -58,6 +58,7 @@ def test_gpt_fim(): def test_gpt_fim_data(): + get_test_dataset() get_test_data_and_compare_samples( { "datasets": { @@ -81,6 +82,7 @@ def test_gpt_fim_data(): def test_gpt_fim_data_legacy(): + get_test_dataset() get_test_data_and_compare_samples( { "format": "list", diff --git a/tests/data/test_sampling.py b/tests/data/test_sampling.py index 32d76fa4c..296102f7d 100644 --- a/tests/data/test_sampling.py +++ b/tests/data/test_sampling.py @@ -98,14 +98,24 @@ def __len__(self) -> int: return len(self._samples) def get_document_sizes(self) -> np.ndarray: - return np.array([self.get_document_size(index) for index in range(len(self))], dtype=np.int64) + doc_sizes = [] + im_sizes = [] + for index in range(len(self)): + doc_size, im_size = self.get_document_size(index) + doc_sizes.append(doc_size) + im_sizes.append(im_size) + return np.array(doc_sizes, dtype=np.int64), np.array(im_sizes, dtype=np.int64) def get_document_size(self, index: int) -> int: - return len(self._samples[index]) + return len(self._samples[index]), [] def name(self) -> str: return "dataset" + @property + def has_images(self) -> bool: + return False + TEST_DATASET = SimpleGPTIndexedDataset( [ diff --git a/tests/layers/test_lm_head.py b/tests/layers/test_lm_head.py index 9a878c494..6d00d05ba 100644 --- a/tests/layers/test_lm_head.py +++ b/tests/layers/test_lm_head.py @@ -23,12 +23,10 @@ def _reverse_kl_loss( ): scaled_target = target / teacher_softmax_temperature - scaled_target = torch.clamp(target, min=-50, max=50) teacher_log_probs = torch.log_softmax(scaled_target, dim=-1) with torch.enable_grad(): # Use log_softmax for consistency instead of _fused_softmax - logits = torch.clamp(logits, min=-50, max=50) student_log_probs = torch.log_softmax(logits, dim=-1) if loss_mask is None: loss = torch.nn.functional.kl_div( diff --git a/tests/models/test_checkpoint.py b/tests/models/test_checkpoint.py index 05acf23dc..665faf7ed 100644 --- a/tests/models/test_checkpoint.py +++ b/tests/models/test_checkpoint.py @@ -284,10 +284,15 @@ def test_load_pretrained( @pytest.mark.model_testing_group(ModelTestingGroup.convert) def test_huggingface_model(model_testing_config, get_convert_path): # Test that Fast-LLM's Hugging Face wrapper produces the same results as the converted Hugging Face model. + # TODO: Stress the importance of this test as the main correctness test for most models. # TODO: Review test. Move to test_generate? 
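# Note (annotation): the body below loads the same weights three ways through the Fast-LLM
# wrapper (model_ref from the distributed checkpoint, model_from_fast_llm from the
# Fast-LLM-format export, model_from_hf from the Hugging Face export), also loads the
# exported files directly with the matching `transformers` auto class, and compares their
# outputs for a single random token batch.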
fast_llm_path = get_convert_path(FastLLMCheckpointFormat, DistributedCheckpointFormat) hf_path = get_convert_path(model_testing_config.checkpoint_format, DistributedCheckpointFormat) - model_ref = model_testing_config.huggingface_model_for_causal_lm_class.from_pretrained( + try: + hf_class = model_testing_config.huggingface_model_for_causal_lm_class + except NotImplementedError: + pytest.skip(f"Hugging Face wrapper not implemented for {model_testing_config.name}.") + model_ref = hf_class.from_pretrained( CheckpointLoadConfig( path=get_convert_path(), format=DistributedCheckpointFormat, @@ -298,8 +303,8 @@ def test_huggingface_model(model_testing_config, get_convert_path): 0, model_ref.config.fast_llm_config.base_model.vocab_size, size=(4, 100), dtype=torch.int64, device="cuda" ) output_ref = model_ref(test_input) - model_from_fast_llm = model_testing_config.huggingface_model_for_causal_lm_class.from_pretrained(fast_llm_path) - model_from_hf = model_testing_config.huggingface_model_for_causal_lm_class.from_pretrained( + model_from_fast_llm = hf_class.from_pretrained(fast_llm_path) + model_from_hf = hf_class.from_pretrained( CheckpointLoadConfig( path=hf_path, format=model_testing_config.checkpoint_format, @@ -307,11 +312,12 @@ def test_huggingface_model(model_testing_config, get_convert_path): ) ) errors = [] - auto_model = ( - transformers.AutoModel - if model_testing_config.name in ("diffusion_llama", "dream") - else transformers.AutoModelForCausalLM - ) + if model_testing_config.name in ("diffusion_llama", "dream"): + auto_model = transformers.AutoModel + elif model_testing_config.name in ("llava", "vision_hybrid_mamba2"): + auto_model = transformers.AutoModelForVision2Seq + else: + auto_model = transformers.AutoModelForCausalLM model_as_hf = auto_model.from_pretrained( hf_path, trust_remote_code=model_testing_config.checkpoint_format.trust_remote_code ).cuda() diff --git a/tests/test_attention.py b/tests/test_attention.py index 87b0d3e59..dd36b840a 100644 --- a/tests/test_attention.py +++ b/tests/test_attention.py @@ -17,12 +17,12 @@ def test_decide_window_size(): # Arrange - Case 1: window_size is returned (layer_index >= max_window_layers) attention._config = TransformerConfig(window_size=512, max_window_layers=2) - attention._layer_index = 2 + attention._block_index = 2 assert attention._decide_window_size() == 512 # Arrange - Case 2: window_size is None (layer_index < max_window_layers) attention._config = TransformerConfig(window_size=512, max_window_layers=2) - attention._layer_index = 1 + attention._block_index = 1 assert attention._decide_window_size() is None # Arrange - Case 3: max_window_layers is None (always return window_size) diff --git a/tests/test_config.py b/tests/test_config.py index b6a9a9854..52c00f0a1 100644 --- a/tests/test_config.py +++ b/tests/test_config.py @@ -137,6 +137,14 @@ def test_pretrained_config(load_config: ModelConfigType, result_path): }, "tie_word_embeddings": False, "vocab_size": 1000, + "vision_encoder": { + "transformer": { + "normalization": {"type": "layer_norm"}, + "rotary": {"type": "none"}, + "peft": {"type": "none"}, + }, + "patch_norm": {"type": "layer_norm"}, + }, } else: base_model_update["transformer"]["peft"] = { @@ -146,6 +154,14 @@ def test_pretrained_config(load_config: ModelConfigType, result_path): } base_model_update["transformer"]["normalization"]["type"] = "layer_norm" base_model_update["transformer"]["rotary"] = {"type": "none"} + base_model_update["vision_encoder"] = { + "transformer": { + "normalization": {"type": 
"layer_norm"}, + "rotary": {"type": "none"}, + "peft": {"type": "none"}, + }, + "patch_norm": {"type": "layer_norm"}, + } expected_config["base_model"] = base_model_update check_equal_nested(serialized_config, expected_config) diff --git a/tests/test_multi_stage.py b/tests/test_multi_stage.py index c530a170c..2f125717e 100644 --- a/tests/test_multi_stage.py +++ b/tests/test_multi_stage.py @@ -3,9 +3,10 @@ from fast_llm.engine.distributed.distributed import Distributed from fast_llm.engine.training.config import TrainerConfig from fast_llm.engine.training.trainer import Trainer -from fast_llm.layers.ssm.llamba_block import LlambaBlock -from fast_llm.layers.transformer.transformer import TransformerLayer +from fast_llm.layers.ssm.llamba_block import SSMBlock +from fast_llm.layers.transformer.transformer import TransformerBlock from fast_llm.utils import Assert +from tests.utils.dataset import get_model_test_dataset from tests.utils.model_configs import ModelTestingGroup from tests.utils.utils import requires_cuda @@ -23,6 +24,7 @@ def _get_trainer_from_args(args: list[str], model_type: str = "gpt") -> Trainer: @requires_cuda @pytest.mark.model_testing_group(ModelTestingGroup.basic) def test_frozen_weights(model_testing_config): + get_model_test_dataset() args = model_testing_config.config_args + ["run.tensor_logs.save=False"] model_ref = _get_trainer_from_args(args, model_testing_config.model_type)._multi_stage model_frozen = _get_trainer_from_args( @@ -39,7 +41,7 @@ def test_frozen_weights(model_testing_config): model_frozen._num_stages, ) frozen_parameter_counts = [ - sum(p.numel() for p in layer.mlp.parameters()) if isinstance(layer, (TransformerLayer, LlambaBlock)) else 0 + sum(p.numel() for p in layer.mlp.parameters()) if isinstance(layer, (TransformerBlock, SSMBlock)) else 0 for layer in model_ref.base_model.layers ] for weight_buffer_ref, weight_buffer_frozen in zip( diff --git a/tests/test_ssms.py b/tests/test_ssms.py index 694faa55b..2a338f1ba 100644 --- a/tests/test_ssms.py +++ b/tests/test_ssms.py @@ -1,19 +1,60 @@ +import inspect +import itertools import pathlib +from functools import partial import pytest import torch +from mamba2 import Mamba2 from fast_llm.config import NoAutoValidate from fast_llm.engine.checkpoint.config import CheckpointLoadConfig +from fast_llm.engine.config_utils.tensor_space import TensorSpace from fast_llm.engine.distributed.config import DistributedConfig, PhaseType +from fast_llm.engine.distributed.distributed import Distributed from fast_llm.engine.schedule.config import ScheduleConfig from fast_llm.engine.schedule.runner import ScheduleRunner from fast_llm.engine.schedule.schedule import Schedule -from fast_llm.layers.transformer.config import TransformerKwargs +from fast_llm.layers.ssm.config import SSMConfig +from fast_llm.layers.ssm.llamba_block import SSMBlock +from fast_llm.layers.transformer.config import TransformerConfig, TransformerKwargs from fast_llm.models.gpt.config import GPTBatchConfig -from fast_llm.models.ssm.config import LLambaHuggingfaceCheckpointFormat +from fast_llm.models.ssm.config import HybridSSMBaseModelConfig, LLambaHuggingfaceCheckpointFormat from fast_llm.models.ssm.model import HybridSSMModel +_mamba_varlen = False +try: + from mamba_ssm.ops.selective_scan_interface import selective_scan_fn # noqa + + _mamba_available = True + sig = inspect.signature(selective_scan_fn) + if "position_indices" in sig.parameters: + _mamba_varlen = True + else: + _mamba_varlen = False + # for training with packing install 
https://github.com/jxiw/varlen_mamba + # see https://github.com/jxiw/M1/blob/main/HYBRID_PACK.md + +except (ImportError, RuntimeError): + _mamba_available = False + + +def get_hybrid_config(hybrid_block_layout=["t", "m2"], prediction_heads=1, default_mtp_type=None): + hidden_size = 512 + config = HybridSSMBaseModelConfig( + transformer=TransformerConfig(num_layers=len(hybrid_block_layout), hidden_size=hidden_size), + ssm=SSMConfig(d_xb=hidden_size, dt_rank=10, d_inner=hidden_size * 2), + hybrid_block_layout=hybrid_block_layout, + prediction_heads=prediction_heads, + default_mtp_type=default_mtp_type, + init_method_std_embed=0.02, + init_method_min_embed=-0.02, + init_method_max_embed=0.02, + use_position_embeddings=True, + tie_word_embeddings=False, + ) + return config + @pytest.mark.skip("Disabled due to cartesia_pytorch installation issue") @pytest.mark.slow @@ -80,3 +121,229 @@ def test_load_from_llamba_checkpoint(): logits = input_data[0][1]["logits"].cpu() assert torch.allclose(logits, hf_logits, atol=1e-2) + + +@pytest.fixture +def distributed_config(): + return DistributedConfig( + tensor_parallel=1, + pipeline_parallel=1, + sequence_data_parallel=1, + local_world_size=1, + world_size=1, + ) + + +@pytest.fixture +def distributed(distributed_config): + return Distributed(config=distributed_config) + + +def materialize_meta_tensors(model, tensor_space): + # Materialize parameters that are on meta device + for name, param in model.named_parameters(): + if param.device.type == "meta": + # Check if the parameter is a custom tensor type + if hasattr(param, "tensor_name") and hasattr(param, "init_parameter"): + param_data = param.new_empty(param.shape, device="cuda") + # Initialize param_data + param.init_parameter(param_data, tensor_space.distributed) + # Replace the parameter in the module + module_path, param_name = name.rsplit(".", 1) if "." 
in name else (None, name) + module = model + if module_path is not None: + for part in module_path.split("."): + module = getattr(module, part) + param = torch.nn.Parameter(param_data, requires_grad=param.requires_grad) + # TODO: add param_grad_is_zero etc., grad_buffer, etc., see test_mlp_recomputation + param.grad = None + param.grad_buffer = torch.empty_like(param) + param.param_grad_is_zero = True + module._parameters[param_name] = param + return model + + +def unpack(packed_hidden_states, cu_seqlens): + batch_size = packed_hidden_states.shape[0] + package_num = cu_seqlens.shape[0] - 1 + seq_len = (cu_seqlens[1:] - cu_seqlens[:-1]).max() + hidden_dim = packed_hidden_states.shape[2] + hidden_states = torch.zeros( + package_num * batch_size, + seq_len, + hidden_dim, + dtype=packed_hidden_states.dtype, + device=packed_hidden_states.device, + ) + for j in range(batch_size): + for i in range(package_num): + line = j * package_num + i + hidden_states[line, : cu_seqlens[i + 1] - cu_seqlens[i], :] = packed_hidden_states[ + j, cu_seqlens[i] : cu_seqlens[i + 1], : + ] + return hidden_states + + +def pack(hidden_states, cu_seqlens, batch_size): + package_num, seq_len, hidden_dim = hidden_states.shape + seq_len_list = cu_seqlens[1:] - cu_seqlens[:-1] + seq_len_list_3d = seq_len_list.unsqueeze(1).unsqueeze(2) + indices_3d = ( + torch.arange(seq_len, device=hidden_states.device).unsqueeze(0).unsqueeze(2).repeat(package_num, 1, hidden_dim) + ) + mask_3d = indices_3d < seq_len_list_3d.repeat(batch_size, 1, 1) + packed_hidden_states = hidden_states[mask_3d].view(batch_size, -1, hidden_dim) + return packed_hidden_states + + +def generate_random_cu_seqlens(seq_len, packages_num=2): + if packages_num < 1: + raise ValueError("packages_num must be at least 1") + + # base size of each chunk, and how many get an extra token + base, rem = divmod(seq_len, packages_num) + # lengths: e.g. for seq_len=10, packages=3 → [4,3,3] + lengths = [base + 1 if i < rem else base for i in range(packages_num)] + + # split points exclude the final cumulative (seq_len) + split_points = list(itertools.accumulate(lengths))[:-1] + + # cu_seqlens = [0] + split_points + [seq_len] + cu_seqlens = [0] + split_points + [seq_len] + + # index: for each chunk, we emit 0,1,...,length-1 + index = [] + for length in lengths: + index.extend(range(length)) + + # sanity check + assert len(cu_seqlens) - 1 == packages_num + assert sum(lengths) == seq_len + assert len(index) == seq_len + + return cu_seqlens, index + + +# Quick and dirty test for Mamba2 varlen block from https://github.com/jxiw/M1/blob/d92b53faa640f8ebf624d3e9e771fe24648ef014/rl/verl/tests/pack_mamba/test_mamba_layer.py +# TODO: integrate in the testing framework +@pytest.mark.slow +@pytest.mark.skipif(not torch.cuda.is_available(), reason="No CUDA available") +@pytest.mark.skipif(not _mamba_available, reason="Mamba2 is not available") +@pytest.mark.skipif(not _mamba_varlen, reason="Mamba2 varlen is not available") +def test_mamba_varlen_block(distributed_config, distributed): + """ + Compare that the output and grads of packed and unpacked Mamba2 varlen block are the same. 
+    """
+    hybrid_config = get_hybrid_config(hybrid_block_layout=["m2", "t"])
+    tensor_space = TensorSpace(distributed_config=distributed_config)
+    tensor_space.setup(distributed)
+    hybrid_config.setup_tensor_space(tensor_space)
+    layer_idx = 0
+
+    mixer_cls = partial(Mamba2, block_index=layer_idx)
+    block_packed = SSMBlock(
+        hybrid_config.transformer,
+        hybrid_config.ssm,
+        mixer_cls=mixer_cls,
+        tensor_space=tensor_space,
+        block_index=layer_idx,
+    )
+    block_ref = SSMBlock(
+        hybrid_config.transformer,
+        hybrid_config.ssm,
+        mixer_cls=mixer_cls,
+        tensor_space=tensor_space,
+        block_index=layer_idx,
+    )
+    device = "cuda"
+    materialize_meta_tensors(block_packed, tensor_space)
+    materialize_meta_tensors(block_ref, tensor_space)
+    block_ref.load_state_dict(block_packed.state_dict())
+    block_packed.to(device)
+    block_ref.to(device)
+
+    batch_size = 2
+    seq_len = 64
+    packages_num = 2
+    hidden_dim = hybrid_config.transformer.hidden_size
+
+    cu_seqlens, index = generate_random_cu_seqlens(seq_len, packages_num=packages_num)
+    cu_seqlens = torch.tensor(cu_seqlens).cuda()
+    ssm_position_ids = torch.tensor(index, dtype=torch.int32).unsqueeze(0).expand(batch_size, -1).contiguous().cuda()
+    seq_idx = (
+        torch.cat(
+            [
+                torch.full((s,), i, dtype=torch.int32, device=cu_seqlens.device)
+                for i, s in enumerate(cu_seqlens[1:] - cu_seqlens[:-1])
+            ],
+            dim=0,
+        )
+        .unsqueeze(0)
+        .repeat(batch_size, 1)
+    )
+
+    # Generate packed_hidden_states with random values for testing
+    hidden_states_list = [
+        torch.randn(l, hidden_dim, device=device, dtype=torch.bfloat16, requires_grad=True)
+        for l in (cu_seqlens[1:] - cu_seqlens[:-1]).tolist()
+    ]
+    packed_hidden_states = torch.cat(hidden_states_list, dim=0).unsqueeze(0)
+    packed_hidden_states = packed_hidden_states.expand(batch_size, -1, -1).contiguous()
+    # hidden_states should be forwarded without cu_seqlens
+    hidden_states = unpack(packed_hidden_states, cu_seqlens)
+
+    # Check: the summed lengths of the items in hidden_states_list should equal the seq_len of packed_hidden_states
+    assert sum([hs.shape[0] for hs in hidden_states_list]) == packed_hidden_states.shape[1]
+    # Check: the max length of the items in hidden_states_list should equal the seq_len of hidden_states
+    assert max([hs.shape[0] for hs in hidden_states_list]) == hidden_states.shape[1]
+
+    output_states_packed = block_packed(
+        packed_hidden_states,
+        {"cu_seqlens": cu_seqlens, "seq_idx": seq_idx, "ssm_position_ids": ssm_position_ids, "sequence_first": False},
+    )
+    output_states_unpacked = block_ref(
+        hidden_states.clone(), {"cu_seqlens": None, "seq_idx": None, "ssm_position_ids": None, "sequence_first": False}
+    )
+    tolerance = 1e-4
+    assert output_states_packed.shape == packed_hidden_states.shape
+    assert output_states_unpacked.shape == hidden_states.shape
+    assert not torch.isnan(hidden_states).any()
+    assert not torch.isinf(hidden_states).any()
+
+    output_states_unpacked = pack(output_states_unpacked, cu_seqlens, batch_size)
+    assert torch.allclose(output_states_packed, output_states_unpacked, atol=tolerance)
+
+    loss = output_states_packed.sum()
+    loss.backward()
+    loss_ref = output_states_unpacked.sum()
+    loss_ref.backward()
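+    # Gradient check: the test reads `.grad` for the conv1d parameters and the `grad_buffer`
+    # tensors (allocated in materialize_meta_tensors above) for the projection and MLP weights.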
+    assert torch.allclose(block_packed.mixer.conv1d_weight.grad, block_ref.mixer.conv1d_weight.grad, atol=tolerance)
+    assert torch.allclose(block_packed.mixer.conv1d_bias.grad, block_ref.mixer.conv1d_bias.grad, atol=tolerance)
+    assert torch.allclose(
+        block_packed.mixer.in_proj.weight.grad_buffer, block_ref.mixer.in_proj.weight.grad_buffer, atol=tolerance
+    )
+    assert torch.allclose(
+        block_packed.mixer.out_proj.weight.grad_buffer, block_ref.mixer.out_proj.weight.grad_buffer, atol=tolerance
+    )
+    assert torch.allclose(
+        block_packed.mixer.dt_in_proj.weight.grad_buffer,
+        block_ref.mixer.dt_in_proj.weight.grad_buffer,
+        atol=tolerance,
+    )
+
+    assert torch.allclose(
+        block_packed.mlp.layer_1.weight.grad_buffer, block_ref.mlp.layer_1.weight.grad_buffer, atol=tolerance
+    )
+    assert torch.allclose(
+        block_packed.mlp.layer_1.bias.grad_buffer, block_ref.mlp.layer_1.bias.grad_buffer, atol=tolerance
+    )
+    assert torch.allclose(
+        block_packed.mlp.layer_2.weight.grad_buffer, block_ref.mlp.layer_2.weight.grad_buffer, atol=tolerance
+    )
+    assert torch.allclose(
+        block_packed.mlp.layer_2.bias.grad_buffer, block_ref.mlp.layer_2.bias.grad_buffer, atol=tolerance
+    )
+
+
+if __name__ == "__main__":
+    pytest.main([__file__])
diff --git a/tests/utils/model_configs.py b/tests/utils/model_configs.py
index 1eee3675d..96982e510 100644
--- a/tests/utils/model_configs.py
+++ b/tests/utils/model_configs.py
@@ -13,13 +13,18 @@
     DiffusionDreamGPTHuggingfaceCheckpointFormat,
     DiffusionLlamaGPTHuggingfaceCheckpointFormat,
     LlamaGPTHuggingfaceCheckpointFormat,
+    LlavaGPTHuggingfaceCheckpointFormat,
     MistralGPTHuggingfaceCheckpointFormat,
     MixtralGPTHuggingfaceCheckpointFormat,
     MTPLlamaGPTHuggingfaceCheckpointFormat,
     Qwen2GPTHuggingfaceCheckpointFormat,
     Starcoder2GPTHuggingfaceCheckpointFormat,
 )
-from fast_llm.models.ssm.config import LLambaHuggingfaceCheckpointFormat
+from fast_llm.models.ssm.config import (
+    AprielThinkerSSMHHybridHuggingfaceCheckpointFormat,
+    LLambaHuggingfaceCheckpointFormat,
+    LlavaHybridHuggingfaceCheckpointFormat,
+)
 from tests.utils.dataset import MODEL_DATASET_PREFIX, MODEL_TEST_VOCAB_SIZE
 from tests.utils.distributed_configs import DistributedTestingConfig
 
@@ -465,6 +470,41 @@ def _update_and_add_testing_config(
     compare_factor=2.0,
 )
 
+_update_and_add_testing_config(
+    # Tests the Llava vision model, Llava converter.
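+    # Uses a small Pixtral-style encoder (2 layers, hidden size 256, 2D rotary embeddings)
+    # on top of the llama base so the conversion and checkpoint tests stay cheap.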
+    "llama",
+    "llava",
+    extra_args=[
+        "batch.max_image_size=128",
+        "model.base_model.vision_encoder.type=pixtral",
+        "model.base_model.vision_encoder.patch_norm.type=rms_norm",
+        "model.base_model.vision_encoder.transformer.add_linear_biases=False",
+        "model.base_model.vision_encoder.transformer.causal=False",
+        "model.base_model.vision_encoder.transformer.normalization.type=rms_norm",
+        "model.base_model.vision_encoder.transformer.type=image_encoder",
+        "model.base_model.vision_encoder.transformer.gated=True",
+        "model.base_model.vision_encoder.transformer.num_layers=2",
+        "model.base_model.vision_encoder.transformer.hidden_size=256",
+        "model.base_model.vision_encoder.transformer.num_attention_heads=8",
+        "model.base_model.vision_encoder.transformer.head_groups=8",
+        "model.base_model.vision_encoder.transformer.init_method_std=0.022",
+        "model.base_model.vision_encoder.transformer.rotary.type=rope_2d",
+        "model.base_model.vision_encoder.adapter_size=256",
+        "model.distributed.training_dtype=torch.bfloat16",
+    ],
+    megatron_args=None,
+    checkpoint_format=LlavaGPTHuggingfaceCheckpointFormat,
+    groups={
+        ModelTestingGroup.basic: ModelTestingGroupAction.normal,
+        ModelTestingGroup.checkpoint: ModelTestingGroupAction.normal,
+        ModelTestingGroup.convert: ModelTestingGroupAction.normal,
+        ModelTestingGroup.generate: ModelTestingGroupAction.not_implemented,
+        ModelTestingGroup.megatron: ModelTestingGroupAction.not_implemented,
+        ModelTestingGroup.distributed: ModelTestingGroupAction.unimportant,
+    },
+    compare_factor=8.0,
+)
+
 _update_and_add_testing_config(
     # Tests hybrid ssm, llamba converter.
     "llama",
@@ -472,10 +512,8 @@ def _update_and_add_testing_config(
     model_type="hybrid_ssm",
     extra_args=[
         "model.base_model.hybrid_block_layout=['t','m']",
-        "model.base_model.ssm.state_size=8",
-        "model.base_model.ssm.chunk_size=32",
-        "model.base_model.ssm.n_qk_heads=8",
-        "model.base_model.ssm.n_v_heads=8",
+        "model.base_model.ssm.d_inner=512",
+        "model.base_model.ssm.state_size=16",
     ],
     megatron_args=None,
     checkpoint_format=LLambaHuggingfaceCheckpointFormat,
@@ -483,57 +521,111 @@
     groups={
         ModelTestingGroup.basic: ModelTestingGroupAction.normal,
         ModelTestingGroup.checkpoint: ModelTestingGroupAction.normal,
-        ModelTestingGroup.convert: ModelTestingGroupAction.broken,  # TODO: Fix and bring back to `testing_groups`
+        ModelTestingGroup.convert: ModelTestingGroupAction.broken,
         ModelTestingGroup.generate: ModelTestingGroupAction.broken,
         ModelTestingGroup.megatron: ModelTestingGroupAction.not_implemented,
-        # TODO: Fix and bring back to `testing_groups`
-        ModelTestingGroup.distributed: ModelTestingGroupAction.broken,
+        ModelTestingGroup.distributed: ModelTestingGroupAction.not_implemented,
     },
     compare_factor=2.0,
-    # SSMs don't support sequence-first configurations.
-    skip_tests=("sf", "sdp", "stp", "ms"),
+    # Micro-sequence split not supported.
+    skip_tests=("sdp", "ms"),
+)
+
+_update_and_add_testing_config(
+    # Tests hybrid Mamba 2.
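+    # Alternates transformer ('t') and Mamba 2 ('m2') blocks, converted through the Apriel
+    # SSM-hybrid checkpoint format.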
+    "llama",
+    "hybrid_mamba2",
+    model_type="hybrid_ssm",
+    extra_args=[
+        "model.base_model.hybrid_block_layout=['t','m2']",
+        "model.base_model.ssm.d_inner=512",
+        "model.base_model.ssm.state_size=8",
+        "model.base_model.ssm.d_xb=256",
+        # f"model.base_model.transformer.debug_transformer={_LOG_LEVEL}"
+    ],
+    megatron_args=None,
+    checkpoint_format=AprielThinkerSSMHHybridHuggingfaceCheckpointFormat,
+    groups={
+        ModelTestingGroup.basic: ModelTestingGroupAction.normal,
+        ModelTestingGroup.checkpoint: ModelTestingGroupAction.normal,
+        ModelTestingGroup.convert: ModelTestingGroupAction.normal,
+        ModelTestingGroup.generate: ModelTestingGroupAction.not_implemented,
+        ModelTestingGroup.megatron: ModelTestingGroupAction.not_implemented,
+        ModelTestingGroup.distributed: ModelTestingGroupAction.normal,
+    },
+    compare_factor=2.0,
+    # Micro-sequence split not supported.
+    skip_tests=(
+        "sdp",
+        "ms",
+    ),  # "pp","dp", "ce","16", "bf", "df", "stp"),
 )
 
 _update_and_add_testing_config(
-    # Tests hybrid ssm, llamba converter.
-    "llamba",
+    # Tests hybrid discrete Mamba 2.
+    "llama",
     "hybrid_discrete_mamba2",
     model_type="hybrid_ssm",
     extra_args=[
         "model.base_model.hybrid_block_layout=['t','m2d']",
+        "model.base_model.ssm.d_inner=512",
+        "model.base_model.ssm.state_size=8",
+        "model.base_model.ssm.n_qk_heads=8",
+        "model.base_model.ssm.n_v_heads=16",
+        "model.base_model.ssm.chunk_size=32",
     ],
     megatron_args=None,
-    checkpoint_format=None,
+    checkpoint_format=AprielThinkerSSMHHybridHuggingfaceCheckpointFormat,
     groups={
         ModelTestingGroup.basic: ModelTestingGroupAction.normal,
         ModelTestingGroup.checkpoint: ModelTestingGroupAction.normal,
-        ModelTestingGroup.convert: ModelTestingGroupAction.not_implemented,
+        ModelTestingGroup.convert: ModelTestingGroupAction.normal,
         ModelTestingGroup.generate: ModelTestingGroupAction.not_implemented,
         ModelTestingGroup.megatron: ModelTestingGroupAction.not_implemented,
-        ModelTestingGroup.distributed: ModelTestingGroupAction.unimportant,
+        # TODO: Implement
+        ModelTestingGroup.distributed: ModelTestingGroupAction.normal,
     },
+    compare_factor=2.0,
+    # Micro-sequence split and sequence-first not supported.
+    skip_tests=("sdp", "ms"),
 )
 
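+# Vision variant of the hybrid Mamba 2 config: a Pixtral-style test encoder like the llava config
+# above, paired with the 't'/'m2' hybrid decoder and the Llava-hybrid converter.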
 _update_and_add_testing_config(
     # Tests hybrid ssm, llamba converter.
-    "llamba",
     "hybrid_mamba2",
+    "vision_hybrid_mamba2",
     model_type="hybrid_ssm",
     extra_args=[
-        "model.base_model.hybrid_block_layout=['t','m2']",
+        "batch.max_image_size=128",
+        "model.base_model.vision_encoder.type=pixtral",
+        "model.base_model.vision_encoder.patch_norm.type=rms_norm",
+        "model.base_model.vision_encoder.transformer.add_linear_biases=False",
+        "model.base_model.vision_encoder.transformer.causal=False",
+        "model.base_model.vision_encoder.transformer.normalization.type=rms_norm",
+        "model.base_model.vision_encoder.transformer.type=image_encoder",
+        "model.base_model.vision_encoder.transformer.gated=True",
+        "model.base_model.vision_encoder.transformer.num_layers=2",
+        "model.base_model.vision_encoder.transformer.hidden_size=256",
+        "model.base_model.vision_encoder.transformer.num_attention_heads=8",
+        "model.base_model.vision_encoder.transformer.head_groups=8",
+        "model.base_model.vision_encoder.transformer.init_method_std=0.022",
+        "model.base_model.vision_encoder.transformer.rotary.type=rope_2d",
+        "model.base_model.vision_encoder.adapter_size=512",
+        "model.distributed.training_dtype=torch.bfloat16",
     ],
     megatron_args=None,
-    checkpoint_format=None,
+    checkpoint_format=LlavaHybridHuggingfaceCheckpointFormat,
     groups={
         ModelTestingGroup.basic: ModelTestingGroupAction.normal,
         ModelTestingGroup.checkpoint: ModelTestingGroupAction.normal,
-        ModelTestingGroup.convert: ModelTestingGroupAction.not_implemented,
+        ModelTestingGroup.convert: ModelTestingGroupAction.normal,
         ModelTestingGroup.generate: ModelTestingGroupAction.not_implemented,
         ModelTestingGroup.megatron: ModelTestingGroupAction.not_implemented,
         ModelTestingGroup.distributed: ModelTestingGroupAction.unimportant,
     },
+    compare_factor=16.0,
 )