From 870847ee1cefc015877057f8b826c5e941920f1b Mon Sep 17 00:00:00 2001 From: Richard Date: Sat, 17 Feb 2024 19:15:51 +0000 Subject: [PATCH 1/2] added documentation and doc strings --- .gitignore | 3 + docs/data.md | 43 +++++ docs/llama.md | 41 +++++ docs/ring_attention.md | 53 ++++++ docs/train.md | 34 ++++ docs/vision_chat.md | 43 +++++ docs/vision_llama.md | 28 +++ docs/vqgan.md | 71 ++++++++ lwm/data.py | 111 +++++++++++- lwm/llama.py | 293 +++++++++++++++++++++++++++++-- lwm/ring_attention.py | 61 +++++++ lwm/train.py | 30 ++++ lwm/vision_chat.py | 41 +++++ lwm/vision_generation.py | 16 ++ lwm/vision_llama.py | 226 +++++++++++++++++++++++- lwm/vqgan.py | 332 +++++++++++++++++++++++++++++++++++ requirements.txt | 73 ++++++-- scripts/eval_needle.py | 0 scripts/eval_needle_multi.py | 0 19 files changed, 1450 insertions(+), 49 deletions(-) create mode 100644 docs/data.md create mode 100644 docs/llama.md create mode 100644 docs/ring_attention.md create mode 100644 docs/train.md create mode 100644 docs/vision_chat.md create mode 100644 docs/vision_llama.md create mode 100644 docs/vqgan.md mode change 100644 => 100755 scripts/eval_needle.py mode change 100644 => 100755 scripts/eval_needle_multi.py diff --git a/.gitignore b/.gitignore index b346ef8..e8c0296 100644 --- a/.gitignore +++ b/.gitignore @@ -149,3 +149,6 @@ data/ *.pkl *.json __pycache__/ +datasets +checkpoints +.venv diff --git a/docs/data.md b/docs/data.md new file mode 100644 index 0000000..ec7835b --- /dev/null +++ b/docs/data.md @@ -0,0 +1,43 @@ +# Data + +This script defines a flexible dataset loading and processing framework designed for machine learning models, particularly those dealing with natural language processing (NLP) and potentially vision tasks. The framework is built to work with the JAX library for high-performance machine learning and supports parallel processing and distributed training. Here's an overview of the main components: + +## DatasetFactory Class + +A factory class for creating dataset instances based on configuration parameters. It supports loading datasets from Hugging Face's datasets library (huggingface type), as well as custom JSON-formatted datasets (json and json_vision types). +It provides a method to get the default configuration for a dataset, which can be customized with specific parameters. + +## TextProcessor Class + +Processes text data by encoding strings into token IDs using a provided tokenizer. It supports adding special tokens (like BOS and EOS) and can process multiple fields from the data, concatenating them with a specified separator. +The default configuration and the processing behavior can be customized. + +## VisionTextProcessor Class + +Designed for processing datasets that include both vision (image or video frames) and text data. It handles encoding of textual data and can integrate special tokens indicating the start and end of vision-related tokens. +Supports custom configurations for handling vision data, including specifying the number of tokens per frame and the maximum number of frames. + +## HuggingfaceDataset Class + +Loads and processes datasets from the Hugging Face datasets library. It can stream data, making it efficient for large datasets. +The data is processed in chunks, with each chunk transformed into model input and target arrays, along with a loss mask to indicate which tokens should contribute to the loss calculation. + +## JsonDataset Class + +Loads and processes datasets from newline-delimited JSON files, where each line contains a JSON object representing a data example. +Supports parallel processing to tokenize and encode the data efficiently across multiple CPU cores. +Data examples are batched and padded as necessary to create fixed-size arrays suitable for model training. + +## JsonVisionDataset Class + +Similar to JsonDataset but specifically designed for datasets that include both vision and text data. +It can handle special tokens for vision data and supports different modes for padding or not padding the batches. + +## General Workflow + +*Configuration*: The user specifies the dataset type and configuration parameters, including paths to data files, batch sizes, sequence lengths, and any special tokenization or processing requirements. +Dataset Loading: Based on the configuration, the appropriate dataset class is instantiated, which loads the data and prepares it for processing. +Data Processing: Text and/or vision data is tokenized and encoded according to the specified processing rules. The data is then batched, with options for padding batches to a fixed size. + +*Iteration*: The dataset objects are iterable, providing batches of data ready for input into a machine learning model. Each batch includes input tokens, target tokens (for supervised learning), and a mask indicating which tokens should be considered for loss calculation. +This framework is highly modular and customizable, making it suitable for a wide range of machine learning tasks and models. It leverages JAX's capabilities for efficient computation and is designed with distributed and parallel processing in mind. \ No newline at end of file diff --git a/docs/llama.md b/docs/llama.md new file mode 100644 index 0000000..9cea7a4 --- /dev/null +++ b/docs/llama.md @@ -0,0 +1,41 @@ +# LLama + +This script is structured into multiple sections, each defining classes and functions related to the LLaMA model, its configuration, tokenization, and various utilities for handling model layers and attention mechanisms. Here's a detailed overview: + +## LLaMAConfig Class + +Defines the configuration for a LLaMA model, including parameters like vocabulary size, hidden layer size, and the number of attention heads. It supports loading configurations for different sizes of LLaMA models (e.g., 200m, 1b, etc.). + +## FlaxLLaMAAttention Class + +Implements the attention mechanism for LLaMA, including query, key, and value projections, as well as the attention calculation itself. It supports causal attention for autoregressive tasks and incorporates options for efficient attention mechanisms like Flash Attention. + +## FlaxLLaMAMLP Class + +Defines the feed-forward network (MLP) used within each Transformer block, including two linear layers and a GELU activation function. + +## FlaxLLaMABlock Class + +Represents a single Transformer block, combining the attention and MLP components, along with layer normalization. + +## FlaxLLaMAPreTrainedModel and FlaxLLaMAModule Classes + +Provide the base implementation for a LLaMA model in Flax, including methods for weight initialization and handling pretrained models. + +## FlaxLLaMABlockCollection Class + +Manages a collection of Transformer blocks, allowing for sequential processing of inputs through multiple blocks. + +## FlaxLLaMAModel and FlaxLLaMAForCausalLM Classes + +Define specific model variants, such as a basic LLaMA model and a causal language model variant for tasks like text generation. + +## LLaMATokenizer Class + +Implements tokenization for LLaMA using SentencePiece, including methods for encoding text into tokens and decoding tokens back into text. + +## Utility Functions and Classes + +Include various helper functions and classes such as RMSNorm for RMS normalization, apply_rotary_emb for applying rotary embeddings to queries and keys, and methods for managing model parameters and configurations. + +Each class and function is designed to be modular and interoperable, allowing for flexible configuration and usage of the LLaMA model components. The use of Flax and JAX libraries facilitates efficient training and inference on hardware accelerators. \ No newline at end of file diff --git a/docs/ring_attention.md b/docs/ring_attention.md new file mode 100644 index 0000000..a8818fa --- /dev/null +++ b/docs/ring_attention.md @@ -0,0 +1,53 @@ +# Ring Attention + +This module implements the forward and backward passes of the ring attention mechanism, which is designed for efficient computation on TPUs, especially when handling large sequences. It supports blockwise computation to reduce memory cost and incorporates fused attention for TPU compatibility. The module is structured to accommodate both standard and ring-flash attention mechanisms, with an emphasis on blockwise processing to optimize performance and memory usage. + +## Ring Attention Forward Pass + +`_ring_attention_fwd` + +This function computes the forward pass of the ring attention mechanism, dividing the computation into blocks for efficiency. It uses a scan operation to iterate over key-value (KV) blocks, applying blockwise attention and rotating KV pairs across TPU cores to implement the ring structure. + +## Ring Attention Backward Pass + +`_ring_attention_bwd` + +This function handles the backward pass, computing gradients with respect to the inputs. It mirrors the forward pass but in reverse, iterating over the blocks and applying the backward computations for blockwise attention. + +## Standard Attention Forward Pass + +`_ring_attention_standard_fwd` + +A variant of the ring attention forward pass that does not use blockwise computation. It's more straightforward but less memory efficient compared to the blockwise version. + +## Blockwise Attention Functions + +`_blockwise_attention_fwd` and `_blockwise_attention_bwd` + +These functions are core to the blockwise computation, handling the forward and backward computations within each block. They are designed to be efficient and compatible with TPU architecture. + +## Ring Flash Attention TPU-Compatible Functions + +`_ring_flash_attention_fwd_tpu` and `_ring_flash_attention_bwd_tpu` + +These functions are specialized versions of the ring attention mechanism, optimized for TPU execution. They leverage TPU-specific operations and structures to achieve high performance. + +## Utility Functions + +The module includes several utility functions, such as `_chunk_attention_bias` for computing attention bias within chunks and `_flash_attention` for a fused attention mechanism that is efficient on TPUs. + +## Data Structures + +The module defines several data structures, like SegmentIds and BlockSizes, to organize and manage the dimensions and indices involved in blockwise and ring attention computations. + +## Blockwise Computation + +This approach divides the input into smaller blocks, allowing for more efficient processing by reducing memory requirements and leveraging parallelism. + +## Ring Structure + +In the context of TPU computation, the ring structure refers to a method where data (e.g., KV pairs) is passed in a ring-like fashion across TPU cores, enabling efficient parallel computation. +## Fused Attention + +This technique combines multiple attention-related operations into a single, more efficient operation, particularly beneficial on TPUs where memory bandwidth can be a limiting factor. +This module is a comprehensive implementation of advanced attention mechanisms tailored for high-performance computing environments, particularly TPUs, with a focus on efficiency and scalability. \ No newline at end of file diff --git a/docs/train.md b/docs/train.md new file mode 100644 index 0000000..31dd3d8 --- /dev/null +++ b/docs/train.md @@ -0,0 +1,34 @@ +# Training + A training script for LWM model designed for use with the JAX library using a LLaMA (Large Language Model) and its variations, including ones for video and text processing. The script is structured to support distributed training across multiple devices or nodes, using functionalities from JAX for parallel execution and the flax library for model state management. Here's a high-level overview of how it works: + +## Configuration and Initialization: + +Default configurations and flags, include data modality (text or vision+text), dataset loading, model configuration, optimizer setup, logging, and checkpointing. +The main function initializes distributed training using `JaxDistributedConfig` and sets up logging with `tux.WandBLogger` using the integration with the Weights & Biases platform for experiment tracking. + +## Model and Dataset Setup: + +Depending on the specified modality (text or vision+text), you can select appropriate model configuration and class (`LLaMAConfig` and `FlaxLLaMAForCausalLMModule` for text, `VideoLLaMAConfig` and `FlaxVideoLLaMAForCausalLMModule` for vision+text). +The dataset is loaded using a `DatasetFactory`, which provides a way to load and preprocess data suitable for training the model. There's support for resuming training from a checkpoint or loading a specific dataset state. + +## Model Initialization: + +The model is initialized with the specified configuration, and the script prepares for distributed training by setting up a computational mesh using `pjit` (parallel JIT compilation in JAX). This involves defining how the model's parameters and operations should be partitioned across the available hardware. +Training Loop: + +The main training loop iterates over the total number of training steps. For each step, it processes a batch of data, performs a forward pass and backward pass (computing gradients), and updates the model parameters using the defined optimizer. +The script supports different data modalities by branching the logic within the training step function (train_step), handling text and vision+text differently in terms of how the model is applied and how losses are computed. + +## Evaluation and Logging: + +Optionally, the script can perform evaluation steps at a specified frequency, computing metrics on a separate evaluation dataset. +Metrics from both training and evaluation are logged using the configured logger, allowing for monitoring of the training process through the Weights & Biases platform. + +## Checkpointing: + +The script includes functionality for saving model checkpoints at specified intervals, supporting both regular checkpoints and milestone checkpoints. This allows for resuming training from a specific point and provides a way to save model states for later use or analysis. + +## Finalization: + +After completing the training loop, a final checkpoint may be saved, capturing the final state of the model. +The script is designed with modularity and flexibility in mind, allowing for various configurations and supporting complex distributed training setups. It leverages advanced features of JAX and Flax for efficient, scalable training of potentially large models on specialized hardware. \ No newline at end of file diff --git a/docs/vision_chat.md b/docs/vision_chat.md new file mode 100644 index 0000000..af3a59c --- /dev/null +++ b/docs/vision_chat.md @@ -0,0 +1,43 @@ + +# Vision Chat + +The implementation for sampling from a VideoLLaMA model using a VQGAN model for video processing and tokenization. Here's a high-level overview of the script's functionality and key components: + +## Flags Definition + +The script starts by defining various flags for configuring the sampling process, including the prompt, input file, VQGAN checkpoint, temperature for generation, maximum number of frames to consider, and various configurations related to the model and tokenization. + +## Sampler Class + + The core of the script is the Sampler class, which encapsulates the logic for sampling from the VideoLLaMA model. Key functionalities include: + +- Loading and setting up the VQGAN model for video processing. +- Initializing tokenizers for processing text and vision inputs. +- Defining the forward generation function using pjit for parallel execution across the specified JAX mesh. +- Constructing model inputs from prompts and video data, processing the video frames, and tokenizing the inputs. +- Generating outputs from the VideoLLaMA model and decoding them back to text. + +## Main Function +The main function orchestrates the sampling process by initializing the necessary configurations, creating a Sampler instance, and processing the provided prompts to generate responses. + +## Video Processing + +The script processes video inputs (handling both image files and video formats) by resizing and cropping frames to a consistent size, encoding them using the VQGAN model, and tokenizing the encoded frames for input to the VideoLLaMA model. + +## Text Processing + +Prompts and other textual inputs are tokenized using the specified tokenizer configurations. Special tokens are added as needed to mark the beginning and end of vision inputs and to structure the overall input sequence for the model. + +## Model Sampling + +The script uses pjit to define a parallelized forward generation function that leverages the JAX mesh for distributed computation. This function generates sequences from the VideoLLaMA model based on the constructed inputs. + +## Output Decoding + +Generated sequences are decoded back to text, with special handling to trim outputs at the end-of-sequence token and compile the final responses. + +## Usage + +The script is designed to be run with command-line arguments corresponding to the defined flags, allowing users to specify the prompt, input video or image file, and various model and sampling parameters. + +It is a complex integration of multiple components (video processing, tokenization, model sampling) into a cohesive pipeline for generative tasks with video and text inputs. \ No newline at end of file diff --git a/docs/vision_llama.md b/docs/vision_llama.md new file mode 100644 index 0000000..5bb4066 --- /dev/null +++ b/docs/vision_llama.md @@ -0,0 +1,28 @@ +# The FlaxVideoLLaMAForCausalLM + +A VideoLLaMA model architecture, specifically designed for causal language modeling tasks. This module is built to handle and generate sequences where each token is predicted based on the preceding tokens, making it suitable for tasks like text generation. Additionally, it extends these capabilities to multimodal inputs, allowing it to work with both text and visual data, which is particularly useful in scenarios where the model needs to understand and generate content based on a combination of textual and visual cues. + +## Causal Language Modeling + +It is tailored for generating sequences in a causal manner, meaning each token is predicted based on the previous tokens in the sequence. This is essential for generative tasks like story generation, where the narrative flows logically from one sentence to the next. + +## Multimodal Input Handling +The module can process both text and visual inputs, making it versatile for a range of applications, from generating descriptive captions for images to creating content that seamlessly integrates textual and visual information. + +## Configurable Generation + +It offers a variety of settings for sequence generation, such as controlling the maximum length of the generated sequences, specifying the end-of-sequence token, and adjusting the randomness of generation through temperature and top-k sampling parameters. + +## Efficient Generation with Caching + +The module uses a caching mechanism to speed up the generation process, especially for autoregressive generation where each token's prediction can benefit from the computations done for the previous tokens. + +## Flexible Output Formats + +It can provide outputs in different formats, catering to various downstream needs. For example, it can return just the last hidden state, all hidden states, and attention scores depending on the configuration. + +## Generation Strategies Support + +The module supports different generation strategies, including greedy decoding and sampling with temperature, allowing users to balance between the diversity and accuracy of the generated sequences. + +This module is a part of the broader VideoLLaMA framework with handling large-scale models and data. The FlaxVideoLLaMAForCausalLM is particularly noteworthy for its ability to bridge the gap between traditional NLP tasks and the emerging field of multimodal AI. \ No newline at end of file diff --git a/docs/vqgan.md b/docs/vqgan.md new file mode 100644 index 0000000..0f4b9a7 --- /dev/null +++ b/docs/vqgan.md @@ -0,0 +1,71 @@ +# Vqgan + +The provided code defines a VQGAN (Vector Quantized Generative Adversarial Network) model implementation in JAX/Flax, along with configuration and utility classes. Here's a high-level overview of the key components: + +## VQGAN Class + +### Purpose: + +Serves as the main interface for the VQGAN model, handling model initialization, encoding, and decoding functionalities. +#### Initialization: + +Loads model parameters from a checkpoint, sets up the model configuration, and initializes the VQGAN model with these parameters. It supports optional replication of parameters for distributed computing. + +### Encoding/Decoding: + +Provides methods for encoding input pixel values into a latent space and decoding these latent representations back into pixel space. These methods are optimized with jax.jit or jax.pmap for performance. + +## VQGANConfig Class + +### Purpose: + +Stores configuration settings for the VQGAN model, such as image resolution, number of channels, and model architecture specifics like hidden channels and attention resolutions. + +### Initialization: + +Can be instantiated with default settings or loaded from a configuration path. + +## VQGANModel Class + +### Purpose: + +Defines the actual VQGAN model architecture, including the encoder, decoder, and quantizer components. + +## Components: + +### Encoder: + +Transforms input pixel values into a higher-dimensional latent space. + +### Decoder: + +Converts encoded representations back into pixel space. + +### Quantizer: + +Quantizes the continuous latent space into discrete embeddings, facilitating the generation of diverse and high-quality images. +Encoder and Decoder Blocks + +### Purpose: + +Implement specific parts of the VQGAN encoder and decoder, respectively. + +### ResnetBlock: + +Implements a residual block with optional dropout, used in both the encoder and decoder for feature transformation. + +### AttnBlock: + +Adds self-attention mechanisms to the model, allowing it to capture long-range dependencies within the data. +Downsample/Upsample: Adjust the spatial dimensions of feature maps, either reducing (downsampling) or increasing (upsampling) them. +Utility Classes and Functions + +### VectorQuantizer: + +A module for quantizing the continuous latent representations into discrete tokens, a key component of the VQGAN architecture. +DownsamplingBlock and UpsamplingBlock: High-level wrappers for the downsampling and upsampling operations within the encoder and decoder, respectively. + +### MidBlock: + +A middle block used in the encoder and decoder, potentially incorporating attention mechanisms for enhanced representational capacity. +Overall, this VQGAN implementation is structured to provide flexibility in configuring the model for different resolutions and capacities, and it leverages JAX/Flax for efficient execution. The model is designed to be used in applications requiring high-quality image synthesis, such as image-to-image translation, super-resolution, and text-to-image generation when combined with other transformers. \ No newline at end of file diff --git a/lwm/data.py b/lwm/data.py index a6cb177..aa8ab60 100644 --- a/lwm/data.py +++ b/lwm/data.py @@ -14,10 +14,25 @@ class DatasetFactory(object): - """ Datset builder class. """ + """ + A factory class for creating dataset instances based on configuration parameters. + + This class supports loading different types of datasets, including those from Hugging Face's datasets library and custom JSON-formatted datasets. It facilitates the easy creation and configuration of datasets for machine learning models, particularly those dealing with NLP and potentially vision tasks. + + Methods: + get_default_config(updates=None): Returns the default configuration for a dataset, which can be customized with specific parameters. + load_dataset(config, tokenizer, **kwargs): Creates and returns an instance of a dataset based on the provided configuration and tokenizer. + + Usage: + config = DatasetFactory.get_default_config({'type': 'huggingface'}) + dataset = DatasetFactory.load_dataset(config, tokenizer) + Note: + DatasetFactory is a static class and should not be instantiated directly. + """ @staticmethod def get_default_config(updates=None): + config = ConfigDict() config.type = 'huggingface' config.text_processor = TextProcessor.get_default_config() @@ -53,7 +68,23 @@ def __init__(self): class TextProcessor(object): - """ Example processor that converts a dictionary of texts into tokens. """ + """ + Processes text data by encoding strings into token IDs using a provided tokenizer. + + This class supports adding special tokens (like BOS and EOS) and can process multiple fields from the data, concatenating them with a specified separator. It is designed to be flexible and customizable for different text processing needs. + + Parameters: + config (ConfigDict): Configuration parameters for text processing. + tokenizer: The tokenizer instance used to encode text strings into token IDs. + + Methods: + get_default_config(updates=None): Returns the default configuration for text processing. + __call__(example, has_aux=False, add_bos_token=True, add_eos_token=True): Processes a single example, returning token IDs and a corresponding loss mask. + + Usage: + text_processor = TextProcessor(config, tokenizer) + tokens, loss_mask = text_processor(example) + """ @staticmethod def get_default_config(updates=None): config = ConfigDict() @@ -124,6 +155,23 @@ def __call__(self, example, has_aux=False, add_bos_token=True, add_eos_token=Tru class VisionTextProcessor(object): + """ + Designed for processing datasets that include both vision (image or video frames) and text data. + + This class handles encoding of textual data and integrates special tokens indicating the start and end of vision-related tokens. It supports custom configurations for handling vision data, including specifying the number of tokens per frame and the maximum number of frames. + + Parameters: + config (ConfigDict): Configuration parameters for vision-text processing. + tokenizer: The tokenizer instance used to encode text strings and vision tokens into token IDs. + + Methods: + get_default_config(updates=None): Returns the default configuration for vision-text processing. + __call__(example, has_aux=False, add_bos_token=True, add_eos_token=True): Processes a single example, returning token IDs, a corresponding loss mask, and a vision mask. + + Usage: + vision_text_processor = VisionTextProcessor(config, tokenizer) + tokens, loss_mask, vision_mask = vision_text_processor(example) + """ @staticmethod def get_default_config(updates=None): config = ConfigDict() @@ -240,10 +288,25 @@ def __call__(self, example, has_aux=False, add_bos_token=True, add_eos_token=Tru class HuggingfaceDataset(object): - """ Huggingface dataset, where the dataset is loaded using the huggingface - datasets.load_dataset() function. """ + Loads and processes datasets from the Hugging Face datasets library. + + This class can stream data, making it efficient for large datasets. The data is processed in chunks, with each chunk transformed into model input and target arrays, along with a loss mask to indicate which tokens should contribute to the loss calculation. + + Parameters: + config (ConfigDict): Configuration parameters for the dataset. + tokenizer: The tokenizer instance used to encode text strings into token IDs. + text_processor (TextProcessor): The text processor instance used to preprocess text data. + + Methods: + get_default_config(updates=None): Returns the default configuration for a HuggingfaceDataset. + __iter__(): Provides an iterator over batches of processed data. + Usage: + dataset = HuggingfaceDataset(config, tokenizer, text_processor) + for batch in dataset: + # Process the batch + """ @staticmethod def get_default_config(updates=None): config = ConfigDict() @@ -331,10 +394,26 @@ def vocab_size(self): class JsonDataset(object): - """ JSON dataset, where each line of the data file contains a JSON - dictionary with text fields. """ + Loads and processes datasets from newline-delimited JSON files. + + Each line in the data file should contain a JSON object representing a data example. This class supports parallel processing to tokenize and encode the data efficiently across multiple CPU cores. Data examples are batched and padded as necessary to create fixed-size arrays suitable for model training. + + Parameters: + config (ConfigDict): Configuration parameters for the dataset. + tokenizer: The tokenizer instance used to encode text strings into token IDs. + text_processor (TextProcessor): The text processor instance used to preprocess text data. + node_info (dict): Information about the distributed training nodes, if applicable. + + Methods: + get_default_config(updates=None): Returns the default configuration for a JsonDataset. + __iter__(): Provides an iterator over batches of processed data. + Usage: + dataset = JsonDataset(config, tokenizer, text_processor, node_info) + for batch in dataset: + # Process the batch + """ @staticmethod def get_default_config(updates=None): config = ConfigDict() @@ -543,6 +622,26 @@ def vocab_size(self): class JsonVisionDataset(object): + """ + Similar to JsonDataset but specifically designed for datasets that include both vision and text data. + + It can handle special tokens for vision data and supports different modes for padding or not padding the batches. This class is ideal for tasks that involve both visual and textual inputs, such as video captioning or multimodal language models. + + Parameters: + config (ConfigDict): Configuration parameters for the dataset. + tokenizer: The tokenizer instance used to encode text strings and vision tokens into token IDs. + text_processor (VisionTextProcessor): The vision-text processor instance used to preprocess data. + node_info (dict): Information about the distributed training nodes, if applicable. + + Methods: + get_default_config(updates=None): Returns the default configuration for a JsonVisionDataset. + __iter__(): Provides an iterator over batches of processed data, including vision masks. + + Usage: + dataset = JsonVisionDataset(config, tokenizer, text_processor, node_info) + for batch in dataset: + # Process the batch + """ @staticmethod def get_default_config(updates=None): config = ConfigDict() diff --git a/lwm/llama.py b/lwm/llama.py index d6cbc64..9f14154 100644 --- a/lwm/llama.py +++ b/lwm/llama.py @@ -461,6 +461,20 @@ def apply_rotary_emb( class FlaxLLaMAAttention(nn.Module): + """ + Implements the attention mechanism for LLaMA models. + + This module computes self-attention, given query, key, and value tensors. It supports causal (autoregressive) masking and rotary position embeddings. + + Parameters: + config (LLaMAConfig): Configuration class instance for LLaMA. + dtype (jnp.dtype): The datatype of the computation (default: float32). + param_dtype (jnp.dtype): The datatype for the parameters (default: float32). + precision (Optional[Union[jax.lax.Precision, str]]): Numerical precision of the computation. + + Methods: + __call__(hidden_states, attention_mask, segment_ids, position_ids, deterministic, init_cache, output_attentions, fcm_mask): Computes the attention scores and the attended value vectors. + """ config: LLaMAConfig dtype: jnp.dtype=jnp.float32 param_dtype: jnp.dtype=jnp.float32 @@ -714,6 +728,20 @@ def __call__( class FlaxLLaMAMLP(nn.Module): + """ + Implements the feed-forward network (MLP) used within each Transformer block of LLaMA models. + + This module applies two linear transformations with a GELU activation in between. + + Parameters: + config (LLaMAConfig): Configuration class instance for LLaMA. + dtype (jnp.dtype): The datatype of the computation (default: float32). + param_dtype (jnp.dtype): The datatype for the parameters (default: float32). + precision (Optional[Union[jax.lax.Precision, str]]): Numerical precision of the computation. + + Methods: + __call__(x, deterministic): Applies the MLP transformation on input tensor `x`. + """ config: LLaMAConfig dtype: jnp.dtype=jnp.float32 param_dtype: jnp.dtype=jnp.float32 @@ -755,6 +783,27 @@ def __call__(self, x: jnp.ndarray, deterministic: bool = True) -> jnp.ndarray: class FlaxLLaMABlock(nn.Module): + """ + A Flax module representing a single Llama block, which is a core component of the + Llama architecture. This block typically consists of multi-head self-attention and + feed-forward neural network layers, with normalization and residual connections. + + Attributes: + d_model (int): The dimensionality of the model's hidden layers. + num_heads (int): The number of heads in the multi-head attention mechanism. + d_ff (int): The dimensionality of the feed-forward layer. + dropout_rate (float): Dropout rate applied to the output of the attention and + feed-forward layers. + attention_dropout_rate (float): Dropout rate applied to the attention weights. + deterministic (bool): If True, the module will behave deterministically, not + applying dropout. Useful for inference. + kernel_init (Callable): Initialization function for the kernel weights. + bias_init (Callable): Initialization function for the bias. + + Methods: + __call__: Applies the Llama block to the input data, including self-attention + and feed-forward layers, with appropriate normalization and residual connections. + """ config: LLaMAConfig dtype: jnp.dtype=jnp.float32 param_dtype: jnp.dtype=jnp.float32 @@ -852,8 +901,27 @@ def __call__( class FlaxLLaMAPreTrainedModel(FlaxPreTrainedModel): """ - An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained - models. + Base class for all Flax LLaMA models. Handles the loading and initialization of + pretrained LLaMA models and provides methods for weight initialization, setting up + cache for efficient decoding, and the forward pass of inputs through the model. + + Inherits from FlaxPreTrainedModel which provides basic utilities and weight + management functionalities. + + Attributes: + config_class (LLaMAConfig): The configuration class associated with LLaMA models. + base_model_prefix (str): A string prefix used to differentiate the base model's + parameters from other potential components like task-specific heads. + module_class (nn.Module): The Flax module associated with the LLaMA model, defined + in subclasses. + + Args: + config (LLaMAConfig): Model configuration class instance. + input_shape (Tuple[int, int]): The shape of input data expected by the model. + seed (int): Random seed for initialization. + dtype (jnp.dtype): The datatype of the model's parameters. + _do_init (bool): Whether to initialize the model's weights. + **kwargs: Additional keyword arguments passed to the module class. """ config_class = LLaMAConfig @@ -873,6 +941,18 @@ def __init__( super().__init__(config, module, input_shape=input_shape, seed=seed, dtype=dtype, _do_init=_do_init) def init_weights(self, rng: jax.random.PRNGKey, input_shape: Tuple, params: FrozenDict = None) -> FrozenDict: + """ + Initializes the weights of the model. If `params` is provided, the missing keys + in the provided `params` dictionary are filled with random parameters. + + Args: + rng (jax.random.PRNGKey): Random key used for parameter initialization. + input_shape (Tuple): The shape of the input for which the model is initialized. + params (FrozenDict, optional): Predefined model parameters. + + Returns: + FrozenDict: A dictionary containing the initialized parameters of the model. + """ # init input tensors input_ids = jnp.zeros(input_shape, dtype="i4") attention_mask = jnp.ones_like(input_ids) @@ -911,12 +991,15 @@ def init_weights(self, rng: jax.random.PRNGKey, input_shape: Tuple, params: Froz def init_cache(self, batch_size, max_length): r""" + Initializes the cache used for fast auto-regressive decoding. Args: batch_size (`int`): batch_size used for fast auto-regressive decoding. Defines the batch size of the initialized cache. max_length (`int`): maximum possible length for auto-regressive decoding. Defines the sequence length of the initialized cache. + Returns: + A dictionary representing the initialized cache. """ # init input variables to retrieve cache input_ids = jnp.ones((batch_size, max_length)) @@ -944,6 +1027,26 @@ def __call__( output_hidden_states: Optional[bool] = None, return_dict: Optional[bool] = None, ): + """ + Runs the forward pass of the model. + + Args: + input_ids: Indices of input sequence tokens in the vocabulary. + attention_mask (optional): Mask to avoid performing attention on padding token indices. + segment_ids (optional): Segment token indices to indicate first and second portions of the inputs. + position_ids (optional): Positional indices of input tokens. + params (dict, optional): Predefined model parameters. + past_key_values (dict, optional): Dictionary containing precomputed key and value hidden states. + dropout_rng (jax.random.PRNGKey, optional): JAX random key for dropout. + train (bool): Whether the model is in training mode. + output_attentions (bool, optional): Whether to output attention weights. + output_hidden_states (bool, optional): Whether to output hidden states. + return_dict (bool, optional): Whether to return a dictionary instead of a tuple. + + Returns: + Model outputs, which could include logits, attentions, hidden states, + depending on the configuration and inputs. + """ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions output_hidden_states = ( output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states @@ -1005,6 +1108,24 @@ def __call__( class FlaxLLaMABlockCollection(nn.Module): + """ + This module represents a collection of LLaMA blocks, encapsulating the entire + transformer architecture. It supports operations like forward pass through all + the transformer blocks, enabling features like deterministic execution, caching + for fast decoding, and outputting attention and hidden states. + + Attributes: + config (LLaMAConfig): Configuration object containing parameters for the LLaMA model. + dtype (jnp.dtype): Data type of the computation (default: jnp.float32). + param_dtype (jnp.dtype): Data type of the parameters (default: jnp.float32). + precision (Optional[Union[jax.lax.Precision, str]]): Numerical precision of the computation + to improve performance on certain devices. Default is None, using the highest available + precision. + + Methods: + __call__: Executes a forward pass through the block collection with the given inputs and + configuration. + """ config: LLaMAConfig dtype: jnp.dtype = jnp.float32 param_dtype: jnp.dtype=jnp.float32 @@ -1023,6 +1144,25 @@ def __call__( output_hidden_states: bool = False, return_dict: bool = True, ): + """ + Executes a forward pass through all the transformer blocks in the collection. + + Args: + hidden_states: Input tensor to the transformer blocks. + attention_mask (optional): Mask to avoid performing attention on padding token indices. + segment_ids (optional): Segment token indices to indicate first and second portions of the inputs. + position_ids (optional): Positional indices of input tokens. + deterministic (bool): If True, the module will perform deterministically. + init_cache (bool): If True, initializes a cache for fast auto-regressive decoding. + output_attentions (bool): Whether to return attention weights. + output_hidden_states (bool): Whether to return hidden states. + return_dict (bool): Whether to return outputs in a dictionary. + + Returns: + Tuple containing output hidden states, all hidden states (if `output_hidden_states` is True), + and all attentions (if `output_attentions` is True). If `return_dict` is True, these + are returned in a dictionary. + """ all_attentions = () if output_attentions else None all_hidden_states = () if output_hidden_states else None @@ -1115,6 +1255,24 @@ def __call__( class FlaxLLaMAModule(nn.Module): + """ + This module implements the core LLaMA model architecture, comprising embedding layers, + transformer blocks, and a final layer normalization. It supports features like dropout, + deterministic execution, and optional output of attention and hidden states. + + Attributes: + config (LLaMAConfig): Configuration object containing parameters for the LLaMA model. + dtype (jnp.dtype): Data type of the computation (default: jnp.float32). + param_dtype (jnp.dtype): Data type of the parameters (default: jnp.float32). + precision (Optional[Union[jax.lax.Precision, str]]): Numerical precision of the computation + to improve performance on certain devices. Default is None, using the highest available + precision. + + Methods: + setup: Initializes the module components including embedding layers, transformer blocks, + and layer normalization. + __call__: Executes a forward pass through the LLaMA model with the given inputs and configuration. + """ config: LLaMAConfig dtype: jnp.dtype = jnp.float32 param_dtype: jnp.dtype=jnp.float32 @@ -1182,9 +1340,34 @@ def __call__( @add_start_docstrings("", "") class FlaxLLaMAModel(FlaxLLaMAPreTrainedModel): + """ + The FlaxLLaMAModel encapsulates the FlaxLLaMAModule to provide a convenient interface for + pre-trained LLaMA models. It handles weight initialization, offers a simple interface for downloading and + loading pre-trained models, and provides methods for forward passes using pre-trained weights. + + Inherits from FlaxLLaMAPreTrainedModel to leverage pre-trained model utilities and conventions. + + Attributes: + module_class: Points to the FlaxLLaMAModule which contains the actual model implementation. + """ module_class = FlaxLLaMAModule class FlaxLLaMAForCausalLMModule(nn.Module): + """ + This module is an extension of the FlaxLLaMAModule for causal language modeling tasks. It adds a language modeling + head on top of the transformer structure to generate logits for the next token prediction. + + Attributes: + config (LLaMAConfig): Configuration class with all the parameters of the LLaMA model. + dtype (jnp.dtype): Data type for computation (default: jnp.float32). + param_dtype (jnp.dtype): Data type for parameters (default: jnp.float32). + precision (Optional[Union[jax.lax.Precision, str]]): Numerical precision for computation to improve performance + on certain devices. Defaults to None, using the highest precision available. + + Methods: + setup: Initializes the transformer module and language modeling head. + __call__: Performs a forward pass through the transformer and language modeling head. + """ config: LLaMAConfig dtype: jnp.dtype = jnp.float32 param_dtype: jnp.dtype=jnp.float32 @@ -1213,6 +1396,24 @@ def __call__( output_hidden_states: bool = False, return_dict: bool = True, ): + """ + Forward pass through the LLaMA transformer module and the language modeling head. + + Args: + input_ids: Indices of input sequence tokens in the vocabulary. + attention_mask: Mask to avoid performing attention on padding token indices. + segment_ids: Token type IDs for segmenting inputs into different sequences. + position_ids: Position indices for input tokens. + deterministic (bool): If True, operations will be deterministic (suitable for inference). + init_cache (bool): If True, initializes a cache for fast auto-regressive decoding. + output_attentions (bool): Whether to return the attentions tensors of all attention layers. + output_hidden_states (bool): Whether to return the hidden states of all layers. + return_dict (bool): Whether to return a `FlaxCausalLMOutput` with named fields or a tuple. + + Returns: + FlaxCausalLMOutput containing logits for the next token predictions, hidden states, and attentions if + requested. + """ batch_size, seq_length = input_ids.shape if attention_mask is None: attention_mask = jnp.ones_like(input_ids) @@ -1251,12 +1452,32 @@ def __call__( @add_start_docstrings("", "") class FlaxLLaMAForCausalLM(FlaxLLaMAPreTrainedModel): + """ + This class provides a LLaMA model for causal language modeling tasks, wrapping the `FlaxLLaMAForCausalLMModule`. + It is designed to be used with pre-trained models and provides methods for generation tasks. + + Inherits from FlaxLLaMAPreTrainedModel to utilize utilities for pre-trained models. + + Attributes: + module_class: Points to the FlaxLLaMAForCausalLMModule containing the actual model implementation. + """ module_class = FlaxLLaMAForCausalLMModule def prepare_inputs_for_generation( self, input_ids, max_length, attention_mask: Optional[jax.Array] = None, ): + """ + Prepares inputs for generation with an auto-regressive causal language model. + + Args: + input_ids: Indices of input sequence tokens in the vocabulary. + max_length: Maximum length of the sequence to be generated. + attention_mask (Optional[jax.Array]): Mask to avoid performing attention on padding token indices. + + Returns: + A dictionary containing the prepared inputs necessary for generation. + """ # initializing the cache batch_size, seq_length = input_ids.shape @@ -1290,10 +1511,20 @@ def update_inputs_for_generation(self, model_outputs, model_kwargs): class LLaMATokenizer(PreTrainedTokenizer): """ - Construct a LLaMA tokenizer. Based on byte-level Byte-Pair-Encoding. + Constructs a LLaMA tokenizer, which is based on byte-level Byte-Pair-Encoding (BPE). This tokenizer is responsible + for turning input text into tokens that can be processed by the LLaMA model. + Args: - vocab_file (`str`): - Path to the vocabulary file. + vocab_file (str): Path to the file containing the vocabulary, which defines the mapping between tokens and their IDs. + unk_token (str, optional): The token to use for unknown tokens. Defaults to "". + bos_token (str, optional): The token to use for the beginning of sequence. Defaults to "". + eos_token (str, optional): The token to use for the end of sequence. Defaults to "". + sp_model_kwargs (Optional[Dict[str, Any]], optional): Additional keyword arguments for the SentencePiece processor. + add_bos_token (bool, optional): Whether to add the beginning of sequence token at the start of each sequence. Defaults to False. + add_eos_token (bool, optional): Whether to add the end of sequence token at the end of each sequence. Defaults to False. + + This tokenizer integrates with the PreTrainedTokenizer base class from Hugging Face's transformers library, providing + compatibility with a wide range of pre-trained models and utilities for text processing. """ vocab_files_names = VOCAB_FILES_NAMES @@ -1311,6 +1542,9 @@ def __init__( add_eos_token=False, **kwargs, ): + """ + Initializes the tokenizer with the given vocabulary file and special token settings. + """ self.sp_model_kwargs = {} if sp_model_kwargs is None else sp_model_kwargs super().__init__(bos_token=bos_token, eos_token=eos_token, unk_token=unk_token, **kwargs) self.vocab_file = vocab_file @@ -1334,38 +1568,55 @@ def __init__( @property def vocab_size(self): - """Returns vocab size""" + """ + Returns the size of the vocabulary, i.e., the number of unique tokens. + """ return self.sp_model.get_piece_size() @property def bos_token_id(self) -> Optional[int]: + """ + Returns the ID of the beginning of sequence token. + """ return self.sp_model.bos_id() @property def eos_token_id(self) -> Optional[int]: + """ + Returns the ID of the end of sequence token. + """ return self.sp_model.eos_id() def get_vocab(self): - """Returns vocab as a dict""" + """ + Returns the vocabulary as a dictionary mapping tokens to their corresponding IDs in a dict""" vocab = {self.convert_ids_to_tokens(i): i for i in range(self.vocab_size)} vocab.update(self.added_tokens_encoder) return vocab def _tokenize(self, text): - """Returns a tokenized string.""" + """ + Tokenizes a text string into a list of tokens. + """ return self.sp_model.encode(text, out_type=str) def _convert_token_to_id(self, token): - """Converts a token (str) in an id using the vocab.""" + """ + Converts a token (string) to its corresponding ID using the vocabulary. + """ return self.sp_model.piece_to_id(token) def _convert_id_to_token(self, index): - """Converts an index (integer) in a token (str) using the vocab.""" + """ + Converts an ID (integer) to its corresponding token (string) using the vocabulary. + """ token = self.sp_model.IdToPiece(index) return token def convert_tokens_to_string(self, tokens): - """Converts a sequence of tokens (string) in a single string.""" + """ + Converts a sequence of tokens (list of strings) back to a single string. + """ current_sub_tokens = [] out_string = "" prev_is_special = False @@ -1385,12 +1636,14 @@ def convert_tokens_to_string(self, tokens): def save_vocabulary(self, save_directory, filename_prefix: Optional[str] = None) -> Tuple[str]: """ - Save the vocabulary and special tokens file to a directory. + Saves the vocabulary and special tokens file to the specified directory. + Args: - save_directory (`str`): - The directory in which to save the vocabulary. + save_directory (str): The directory where the vocabulary will be saved. + filename_prefix (Optional[str], optional): A prefix to add to the saved filename. + Returns: - `Tuple(str)`: Paths to the files saved. + Tuple[str]: The path(s) to the saved vocabulary file(s). """ if not os.path.isdir(save_directory): logger.error(f"Vocabulary path ({save_directory}) should be a directory") @@ -1410,6 +1663,16 @@ def save_vocabulary(self, save_directory, filename_prefix: Optional[str] = None) def build_inputs_with_special_tokens(self, token_ids_0, token_ids_1=None): if self.add_bos_token: + """ + Builds model inputs from sequences with special tokens added for sequence classification tasks. + + Args: + token_ids_0 (List[int]): The list of token IDs for the first sequence. + token_ids_1 (Optional[List[int]], optional): The list of token IDs for the second sequence, if applicable. + + Returns: + List[int]: The combined list of token IDs with special tokens added. + """ bos_token_ids = [self.bos_token_id] else: bos_token_ids = [] diff --git a/lwm/ring_attention.py b/lwm/ring_attention.py index a5d38a1..b37e2d5 100644 --- a/lwm/ring_attention.py +++ b/lwm/ring_attention.py @@ -18,6 +18,26 @@ def _ring_attention_fwd(q, k, v, attn_bias, segment_ids, axis_name, float32_logits, blockwise_kwargs): + """ + Computes the forward pass of the ring attention mechanism in a blockwise fashion. + This function divides the computation into blocks to reduce memory consumption and + leverages a ring structure for efficient data sharing across TPU cores. It supports + both standard and float32 precision for logits. + + Args: + q (jax.Array): Query tensor of shape (batch, q_len, num_heads, dim_per_head). + k (jax.Array): Key tensor of shape (batch, kv_len, num_heads, dim_per_head). + v (jax.Array): Value tensor of shape (batch, kv_len, num_heads, dim_per_head). + attn_bias (jax.Array): Attention bias to be added to the computed attention scores. + segment_ids (jax.Array): Segment IDs for segment-based masking in attention. + axis_name: Axis name for pmap to perform cross-device computation. + float32_logits (bool): Flag to use float32 precision for computing logits. + blockwise_kwargs (dict): Dictionary containing blockwise computation settings. + + Returns: + output (jax.Array): The output of the attention mechanism. + cache (tuple): A tuple containing tensors that will be used for backward pass. + """ if float32_logits: q, k = q.astype(jnp.float32), k.astype(jnp.float32) batch, q_len, num_heads, dim_per_head = q.shape @@ -45,6 +65,21 @@ def scan_kv_block(carry, idx): return output.astype(v.dtype), (output, q, k, v, attn_bias, segment_ids, denominator, max_score) def _ring_attention_bwd(axis_name, float32_logits, blockwise_kwargs, res, g): + """ + Computes the backward pass for the ring attention mechanism. This function is designed + to work in tandem with the forward pass, using the cached results from the forward pass + to efficiently compute gradients with respect to the inputs. + + Args: + axis_name: Axis name for pmap to perform cross-device gradient computation. + float32_logits (bool): Indicates if float32 precision was used for logits. + blockwise_kwargs (dict): Settings for blockwise computation used in the forward pass. + res (tuple): Cached results from the forward pass. + g (jax.Array): Gradient with respect to the output of the attention mechanism. + + Returns: + Gradients with respect to the inputs of the attention mechanism. + """ del float32_logits output, q, k, v, attn_bias, segment_ids, denominator, max_score = res batch, q_len, num_heads, dim_per_head = q.shape @@ -145,6 +180,13 @@ def ring_attention_standard(q, k, v, attn_mask, axis_name, float32_logits=True): def _blockwise_attention_fwd(q, k, v, carry, q_chunk_idx_start, k_chunk_idx_start, bias, segment_ids, causal, query_chunk_size, key_chunk_size, deterministic, dropout_rng, attn_pdrop, dtype, policy, precision, prevent_cse): + """ + Forward pass for the blockwise computation within the attention mechanism. This function + handles the computation within a single block, applying the attention mechanism to smaller + portions of the input tensors to optimize memory usage. + + [Additional parameters and return values documentation specific to the function] + """ batch, q_len, num_heads, dim_per_head = q.shape batch, kv_len, num_heads, dim_per_head = k.shape batch, kv_len, num_heads, dim_per_head = v.shape @@ -220,6 +262,13 @@ def skip_upper_half(carry, args): return numerator, denominator, max_score def _blockwise_attention_bwd(q, k, v, g, carry, q_chunk_idx_start, k_chunk_idx_start, bias, segment_ids, causal, query_chunk_size, key_chunk_size, deterministic, dropout_rng, attn_pdrop, dtype, policy, precision, prevent_cse): + """ + Backward pass for the blockwise computation within the attention mechanism. This function + computes gradients for the blockwise attention forward pass, using the results cached during + the forward computation to efficiently compute gradients with respect to the block inputs. + + [Additional parameters and return values documentation specific to the function] + """ batch, q_len, num_heads, dim_per_head = q.shape batch, kv_len, num_heads, dim_per_head = k.shape batch, kv_len, num_heads, dim_per_head = v.shape @@ -543,18 +592,30 @@ class SegmentIds(NamedTuple): different segments in the input sequence. Each array is a list of ids (integers). Only the token with the same id can attend to each other. + Named tuple for segment IDs used in segment-based masking within the attention mechanism. + Segment IDs allow different parts of the input sequence to be treated as separate segments, + preventing attention across segments. Attributes: q: segment ids along the Q sequence. kv: segment ids along the KV sequence. """ + q: jax.Array # [q_seq_len] kv: jax.Array # [kv_seq_len] @dataclasses.dataclass(frozen=True) class BlockSizes: + """ + Class to represent sizes of blocks used in blockwise computation of the attention mechanism. + This class helps in configuring the dimensions for dividing the input tensors into smaller blocks, + optimizing the computation for memory efficiency and parallelism. + + Attributes are configured to represent the size of blocks along different dimensions such as + queries, keys, values, and their corresponding major blocks for handling larger sequences. + """ block_q: int block_k_major: int block_k: int diff --git a/lwm/train.py b/lwm/train.py index 1e1fe28..9f7bf45 100644 --- a/lwm/train.py +++ b/lwm/train.py @@ -55,6 +55,36 @@ def main(argv): + """ + Performs a single training step, including forward pass, loss calculation, and parameter update. + + This function handles the training logic for different data modalities (e.g., text, vision+text) and applies the model to the input batch. It computes the loss and accuracy, performs backpropagation to calculate gradients, and updates the model parameters using the provided optimizer. + + Inputs: + train_state: TrainState object + The current training state containing the model parameters, optimizer state, and other training-related information. It must be compatible with the model being trained and the optimizer in use. + + rng: JaxRNG object + A random number generator state used for stochastic operations within the model, such as dropout. It is crucial for reproducibility and controlled randomness in the training process. + + batch: dict + A batch of data to be processed by the model. The structure of this dictionary depends on the modality of the data. For text data, it typically includes 'input_tokens', 'target_tokens', and 'loss_masks'. For vision+text data, it might also include 'input_vision_masks' and 'target_vision_masks'. + + Process: + The function first applies sharding constraints to the batch data to ensure it's correctly partitioned for distributed training. It then defines a loss function that applies the model to the input data and computes the cross-entropy loss and accuracy. This loss function is differentiated to obtain gradients, which are then used to update the model parameters in the train state. + + Outputs: + Tuple containing the updated TrainState, new RNG state, and a dictionary of metrics. + + updated_train_state: TrainState object + The updated training state after applying the gradients to the model parameters. + + new_rng: JaxRNG object + The updated random number generator state after being used in this training step. + + metrics: dict + A dictionary containing various training metrics calculated during this step, such as loss, accuracy, learning rate, parameter norms, and gradient norms. The structure of this dictionary may vary depending on the data modality and specific metrics being tracked. + """ JaxDistributedConfig.initialize(FLAGS.jax_distributed) variant = tux.get_user_flags(FLAGS, FLAGS_DEF) flags_config_dict = tux.user_flags_to_config_dict(FLAGS, FLAGS_DEF) diff --git a/lwm/vision_chat.py b/lwm/vision_chat.py index f398d11..e34a808 100644 --- a/lwm/vision_chat.py +++ b/lwm/vision_chat.py @@ -38,6 +38,35 @@ class Sampler: + """ + A sampler for generating outputs from the VideoLLaMA model using video and text inputs. + + This class encapsulates the process of video processing using a VQGAN model, tokenization of video and text inputs, + and sampling from the VideoLLaMA model to generate text outputs. + + Attributes: + mesh (jax.experimental.maps.Mesh): The JAX mesh configuration for parallel computation. + vqgan (VQGAN): The VQGAN model used for encoding video frames. + prefix_tokenizer (Tokenizer): Tokenizer for processing text inputs with specific configurations. + tokenizer (Tokenizer): General tokenizer for processing both text and video inputs. + n_tokens_per_frame (int): The number of tokens allocated for each frame in the input sequence. + min_buffer_size (int): The minimum buffer size for the input sequence to ensure model compatibility. + sharded_rng (jax.random.PRNGKey): The sharded random number generator key for generating outputs. + block_size (int): The block size used for partitioning the input sequence for efficient processing. + data_dim (int): The data dimensionality based on the mesh configuration. + config (VideoLLaMAConfig): The configuration object for the VideoLLaMA model. + model (FlaxVideoLLaMAForCausalLM): The VideoLLaMA model instance for generation. + params (dict): The parameters of the VideoLLaMA model, possibly sharded across devices. + model_ps (dict): PartitionSpec mapping for the model parameters. + + Methods: + __init__(): Initializes the sampler with the specified VQGAN checkpoint and tokenization settings. + _process_frame(image, size): Processes a single image frame to a consistent size and format. + _read_process_vision(path, max_n_frames): Reads and processes video or image input from a given path. + construct_input(prompts, max_n_frames): Constructs model input from prompts and processed video data. + _load_model(): Loads the VideoLLaMA model and its parameters, setting up sharding as needed. + __call__(prompts, max_n_frames): Generates text outputs from the VideoLLaMA model based on input prompts and video data. + """ def __init__(self): self.mesh = VideoLLaMAConfig.get_jax_mesh(FLAGS.mesh_dim) self.vqgan = VQGAN(FLAGS.vqgan_checkpoint, replicate=False) @@ -239,6 +268,18 @@ def __call__(self, prompts, max_n_frames): return output_text def main(argv): + """ + Main function for generating text outputs from the VideoLLaMA model using video and text inputs. + + This function initializes the necessary configurations, creates a Sampler instance, and processes + the provided prompts to generate responses. The generated responses are printed to the console. + + Args: + argv (list): List of command-line arguments passed to the script. + + The function reads the prompt and input file path from the command-line flags, initializes the JAX distributed + configuration, sets the random seed, and then uses the Sampler class to generate and print the model's output. + """ assert FLAGS.prompt != '' assert FLAGS.input_file != '' diff --git a/lwm/vision_generation.py b/lwm/vision_generation.py index c903402..e7f34f0 100644 --- a/lwm/vision_generation.py +++ b/lwm/vision_generation.py @@ -42,6 +42,22 @@ def main(argv): + """ + Main function to generate images or videos based on text prompts using a combination of a VQGAN model for + vision processing and a VideoLLaMA model for text-to-vision generation. + + The function initializes necessary configurations and models, processes text prompts, generates images or + video frames based on these prompts, and saves the output to a specified file. + + Args: + argv (list): Command-line arguments passed to the script, not directly used in the function. + + Raises: + ValueError: If the output file extension is unsupported. + + The output is either a single image or a sequence of images (video) saved to the specified output file, + depending on the command-line arguments provided. + """ assert FLAGS.output_file != '' if FLAGS.output_file.endswith('mp4'): assert FLAGS.n_frames > 1 diff --git a/lwm/vision_llama.py b/lwm/vision_llama.py index 154bec5..f8d6913 100644 --- a/lwm/vision_llama.py +++ b/lwm/vision_llama.py @@ -25,6 +25,20 @@ class VideoLLaMAConfig(LLaMAConfig): + """ + Configuration class for VideoLLaMA. This class extends the LLaMAConfig class, adding additional + configuration options specific to VideoLLaMA model. + + Parameters: + - vision_vocab_size (int): The size of the vision vocabulary. Default is 8448, representing 8192 + 256. + - tie_vision_embeddings (bool): Whether to tie the vision embeddings with some other embeddings. Default is False. + - sample_mode (str): Mode of sampling, can be 'all', 'text', or 'vision'. Determines the type of embeddings to be used. + - **kwargs: Additional keyword arguments passed to the superclass LLaMAConfig. + + Methods: + - get_partition_rules(scan_layers=False, scan_axis=0): Returns partitioning rules for model parallelism. + - load_config(path): Loads the model configuration from a given path or a predefined config. + """ model_type = "video_llama" def __init__(self, vision_vocab_size=8448, tie_vision_embeddings=False, sample_mode='all', **kwargs): @@ -35,10 +49,21 @@ def __init__(self, vision_vocab_size=8448, tie_vision_embeddings=False, sample_m @staticmethod def get_partition_rules(scan_layers=False, scan_axis=0): - """ Parition rules for GPTJ. Note that these rules are orderd, so that - the beginning rules match first. It is important to use - PartitionSpec() instead of None here because JAX does not treat - None as a pytree leaf. + """ + Defines the partitioning rules for distributing model parameters across devices. + These rules help in achieving model parallelism by splitting the model's computations. + + Parition rules for GPTJ. Note that these rules are orderd, so that + the beginning rules match first. It is important to use + PartitionSpec() instead of None here because JAX does not treat + None as a pytree leaf. + + Parameters: + - scan_layers (bool): Whether to scan through layers for partitioning. Default is False. + - scan_axis (int): Axis along which to scan and partition the layers. Default is 0. + + Returns: + - A tuple of partitioning rules, with each rule specifying the parameter name pattern and its corresponding PartitionSpec. """ if scan_layers: if scan_axis == 0: @@ -109,6 +134,18 @@ def get_partition_rules(scan_layers=False, scan_axis=0): @classmethod def load_config(cls, path): + """ + Loads the model configuration from a predefined configuration or a file. + + Parameters: + - path (str): Path to the configuration file or a key to a predefined configuration. + + Returns: + - An instance of this configuration class initialized with the loaded configuration. + + Raises: + - ValueError: If the path format is unrecognized or the file type is unsupported. + """ if path in VIDEO_LLAMA_STANDARD_CONFIGS: return cls.from_dict(VIDEO_LLAMA_STANDARD_CONFIGS[path]) load_type, load_path = path.split('::', 1) @@ -124,10 +161,20 @@ def load_config(cls, path): class FlaxVideoLLaMAPreTrainedModel(FlaxPreTrainedModel): """ - An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained - models. + Base class for all Flax VideoLLaMA models. This class provides common functionalities for weight initialization, + and offers a simple interface for downloading and loading pretrained models. + + Attributes: + - config_class: Points to the VideoLLaMAConfig class. + - base_model_prefix (str): Prefix indicating the base model. + - module_class: Points to the FlaxVideoLLaMAModule class. To be defined by subclasses. + + Methods: + - __init__: Constructor for the class, initializing the model with the provided configuration. + - init_cache: Initializes the cache for autoregressive generation. + - init_weights: Initializes or loads the model weights. + - __call__: Forward pass for the model, with support for various Flax-specific features like PRNG keys. """ - config_class = VideoLLaMAConfig base_model_prefix = "transformer" module_class: nn.Module = None @@ -145,7 +192,16 @@ def __init__( super().__init__(config, module, input_shape=input_shape, seed=seed, dtype=dtype, _do_init=_do_init) def init_cache(self, batch_size, max_length): - # init input variables to retrieve cache + """ + Initializes the cache used in the transformer for faster sequential generation. + + Parameters: + - batch_size (int): Batch size for the input data. + - max_length (int): Maximum length of the sequence to be generated. + + Returns: + - Initialized cache variables for the model. + """ input_ids = jnp.ones((batch_size, max_length)) attention_mask = jnp.ones_like(input_ids) segment_ids = jnp.zeros_like(input_ids) @@ -158,7 +214,17 @@ def init_cache(self, batch_size, max_length): return init_variables["cache"] def init_weights(self, rng, input_shape, params=None): - # init input tensors + """ + Initializes or loads the model weights. + + Parameters: + - rng: Random number generator (PRNG key) for weight initialization. + - input_shape: Shape of the input data. + - params (Optional): Pre-trained parameters to load into the model. + + Returns: + - Initialized model parameters, either from scratch or loaded from provided parameters. + """ input_ids = jnp.zeros(input_shape, dtype="i4") attention_mask = jnp.ones_like(input_ids) vision_masks = jnp.ones(input_ids.shape, dtype=bool) @@ -195,6 +261,26 @@ def __call__( output_hidden_states: Optional[bool] = None, return_dict: Optional[bool] = None, ): + """ + Forward pass for the VideoLLaMA model. + + Parameters: + - input_ids: Input token ids. + - vision_masks: Masks to distinguish vision tokens from text tokens. + - attention_mask (Optional): Mask to avoid performing attention on padding token indices. + - segment_ids (Optional): Segment ids for token types. + - position_ids (Optional): Position indices for the tokens in the input sequence. + - params (dict, Optional): Pre-trained parameters for model layers. + - past_key_values (dict, Optional): Cached past key values for faster generation. + - dropout_rng: PRNGKey for dropout layers. + - train (bool): Whether the model is in training mode. + - output_attentions (bool, Optional): Whether to return the attentions tensors. + - output_hidden_states (bool, Optional): Whether to return the hidden states. + - return_dict (bool, Optional): Whether to return a FlaxBaseModelOutput instance or a tuple. + + Returns: + - Model outputs, either as a FlaxBaseModelOutput object or a tuple, depending on return_dict. + """ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions output_hidden_states = ( output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states @@ -349,12 +435,36 @@ def __call__( class FlaxVideoLLaMAForCausalLMModule(nn.Module): + """ + The FlaxVideoLLaMAModule is a core component of the VideoLLaMA model architecture within the Flax framework. + It is responsible for processing input data through embeddings, dropout, a series of transformer blocks, + and layer normalization to produce a representation suitable for various tasks, such as causal language modeling. + + The module supports processing both textual and visual inputs by employing separate embeddings for each and + allows for flexible control over attention mechanisms, caching for efficient sequence generation, and the + inclusion of hidden states and attention distributions in the output. + + Attributes: + config (VideoLLaMAConfig): Configuration class for the VideoLLaMA model. + dtype (jnp.dtype): Data type for the module's parameters. Defaults to jnp.float32. + param_dtype (jnp.dtype): Data type for the parameters of submodules. Defaults to jnp.float32. + precision (Optional[Union[jax.lax.Precision, str]]): Numerical precision configuration for matrix multiplication operations. + + Methods: + setup(): Initializes the module's subcomponents, such as embeddings, dropout, transformer blocks, and layer normalization. + __call__(input_ids, vision_masks, attention_mask, segment_ids, position_ids, deterministic=True, init_cache=False, output_attentions=False, output_hidden_states=False, return_dict=True): Defines the forward pass of the module. + """ config: VideoLLaMAConfig dtype: jnp.dtype = jnp.float32 param_dtype: jnp.dtype=jnp.float32 precision: Optional[Union[jax.lax.Precision, str]]=None def setup(self): + """ + Initializes the module's subcomponents. This includes text and vision embeddings to process different types of inputs, + a dropout layer for regularization, a collection of transformer blocks for sequential processing, and a layer normalization + for stabilizing the outputs of the transformer blocks. + """ self.transformer = FlaxVideoLLaMAModule(self.config, dtype=self.dtype) self.vision_head = nn.Dense( self.config.vision_vocab_size, @@ -386,6 +496,26 @@ def __call__( output_hidden_states: bool = False, return_dict: bool = True, ): + """ + Processes input data through the VideoLLaMA module, returning the final hidden states along with optional hidden states + and attention distributions. + + Parameters: + input_ids (jnp.ndarray): Input token IDs for text and/or vision inputs. + vision_masks (jnp.ndarray): Masks to distinguish between text and vision tokens in the input. + attention_mask (jnp.ndarray): Mask to avoid performing attention on padding token indices. + segment_ids (jnp.ndarray): Segment IDs to distinguish different segments of the inputs (e.g., for tasks that involve multiple inputs like question answering). + position_ids (jnp.ndarray): Position indices for the tokens in the input sequence. + deterministic (bool): Specifies whether to operate in deterministic mode, typically used during inference to disable stochastic operations like dropout. + init_cache (bool): Whether to initialize a cache for efficiently generating sequences autoregressively. + output_attentions (bool): Whether to include attention distributions in the output. + output_hidden_states (bool): Whether to include all hidden states in the output. + return_dict (bool): Whether to return outputs in a dictionary format with named fields. + + Returns: + A FlaxBaseModelOutput object containing the last hidden state, all hidden states (if requested), + and attention distributions (if requested). If return_dict is False, a tuple of these components is returned. + """ batch_size, seq_length = input_ids.shape if attention_mask is None: attention_mask = jnp.ones_like(input_ids) @@ -447,11 +577,41 @@ def __call__( @add_start_docstrings("", "") class FlaxVideoLLaMAForCausalLM(FlaxVideoLLaMAPreTrainedModel): + """ + This model is a part of the VideoLLaMA architecture for causal language modeling tasks. It is designed to handle + sequences for generative tasks, allowing for the generation of text conditioned on previous tokens as well as + multimodal inputs including vision data. The model supports various generation strategies and configurations. + + Inherits from FlaxVideoLLaMAPreTrainedModel to utilize pre-trained weights and other foundational functionalities. + + Attributes: + module_class: Points to the FlaxVideoLLaMAForCausalLMModule that defines the forward pass of the model. + + Methods: + prepare_inputs_for_generation: Prepares the input data and cache for the generation process. + update_inputs_for_generation: Updates the input data based on the outputs from the previous generation step. + _sample_vision: Generates sequences using the model in a causal manner, specifically for vision-related tasks. + generate_vision: A high-level method for generating data, wrapping around the `_sample_vision` method. + """ module_class = FlaxVideoLLaMAForCausalLMModule def prepare_inputs_for_generation( self, input_ids, max_length, attention_mask: Optional[jax.Array] = None, vision_masks = None ): + """ + Prepares the inputs and cache for generating sequences with the model. This method initializes the cache + for autoregressive generation and prepares attention masks and other necessary inputs. + + Parameters: + input_ids (jnp.ndarray): The input token IDs. + max_length (int): The maximum length of the sequence to be generated. + attention_mask (Optional[jax.Array]): The attention mask to avoid attending to padding tokens. + vision_masks (Optional[jnp.ndarray]): Masks to distinguish vision tokens from text tokens. + + Returns: + A dictionary containing prepared inputs for the model, including 'past_key_values' for caching, + 'attention_mask', 'position_ids', and 'vision_masks'. + """ # initializing the cache batch_size, seq_length = input_ids.shape @@ -474,6 +634,17 @@ def prepare_inputs_for_generation( } def update_inputs_for_generation(self, model_outputs, model_kwargs): + """ + Updates the inputs for the next generation step based on the outputs from the model. + + Parameters: + model_outputs: The outputs from the model's forward pass. + model_kwargs: The keyword arguments for the model's forward pass. + + Returns: + A dictionary with updated inputs for the model, including 'past_key_values', 'position_ids', + 'attention_mask', and 'vision_masks'. + """ return { "past_key_values": model_outputs.past_key_values, "position_ids": model_kwargs["position_ids"][:, -1:] + 1, @@ -495,6 +666,26 @@ def _sample_vision( params: Optional[Dict[str, jnp.ndarray]] = None, model_kwargs: Optional[Dict[str, jnp.ndarray]] = None, ): + """ + Generates sequences for vision-related tasks using the model in a causal manner. This method supports various + generation strategies and configurations, allowing for controlled sequence generation. + + Parameters: + input_ids (jnp.ndarray): The input token IDs. For vision tasks, this can be set to None. + max_length (int, Optional): The maximum length of the sequence to be generated. + pad_token_id (int, Optional): The token ID used for padding. + eos_token_id (int, Optional): The token ID that signifies the end of a sequence. + prng_key (jnp.ndarray, Optional): The pseudo-random number generator key for stochastic operations like sampling. + logits_processor (FlaxLogitsProcessorList, Optional): Processors to manipulate logits during generation. + logits_warper (FlaxLogitsProcessorList, Optional): Processors to warp logits during generation. + cfg_scales (jnp.ndarray): Scales for controlling the randomness of generation in conditional generation tasks. + trace (bool): Whether to trace the execution for more efficient compilation, relevant for TPU execution. + params (Dict[str, jnp.ndarray], Optional): Pre-trained parameters for the model. + model_kwargs (Dict[str, jnp.ndarray], Optional): Additional model-specific keyword arguments. + + Returns: + A FlaxSampleOutput object containing the generated sequences. + """ # init values max_length = max_length if max_length is not None else self.generation_config.max_length pad_token_id = pad_token_id if pad_token_id is not None else self.generation_config.pad_token_id @@ -599,6 +790,23 @@ def generate_vision( logits_processor: Optional[FlaxLogitsProcessorList] = None, **kwargs, ): + """ + A high-level method for generating sequences, specifically designed for vision-related tasks. This method + wraps around the `_sample_vision` method, providing a user-friendly interface for sequence generation. + + Parameters: + input_ids (jnp.ndarray): The input token IDs for the initial context. + cfg_scales (jnp.ndarray): Scales for controlling the randomness of generation in conditional generation tasks. + generation_config (GenerationConfig, Optional): Configuration for controlling the generation process. + prng_key (jnp.ndarray, Optional): The pseudo-random number generator key for stochastic operations like sampling. + trace (bool): Whether to trace the execution for more efficient compilation, relevant for TPU execution. + params (Dict[str, jnp.ndarray], Optional): Pre-trained parameters for the model. + logits_processor (FlaxLogitsProcessorList, Optional): Processors to manipulate logits during generation. + **kwargs: Additional keyword arguments for generation configurations. + + Returns: + A FlaxSampleOutput object containing the generated sequences. + """ # Handle `generation_config` and kwargs that might update it, and validate the `.generate()` call self._validate_model_class() diff --git a/lwm/vqgan.py b/lwm/vqgan.py index 77715f8..29f13e3 100644 --- a/lwm/vqgan.py +++ b/lwm/vqgan.py @@ -12,7 +12,28 @@ class VQGAN: + """ + A class representing a Vector Quantized Generative Adversarial Network (VQGAN) model. + + Attributes: + vqgan_checkpoint (str): Path to the VQGAN model checkpoint. + replicate (bool): Flag to indicate whether to replicate the model parameters across devices. + config (VQGANConfig): Configuration object for the VQGAN model. + params (dict): Loaded model parameters. + model (VQGANModel): The VQGAN model instance. + + Methods: + encode(pixel_values): Encodes input pixel values into latent representations. + decode(encoding): Decodes latent representations back into pixel values. + """ def __init__(self, vqgan_checkpoint, replicate=False): + """ + Initializes the VQGAN model with the given checkpoint and replication settings. + + Parameters: + vqgan_checkpoint (str): Path to the VQGAN model checkpoint. + replicate (bool, optional): Whether to replicate the model parameters across devices. Defaults to False. + """ assert vqgan_checkpoint != '' self.replicate = replicate self.config = VQGANConfig.get_default_config() @@ -24,6 +45,15 @@ def __init__(self, vqgan_checkpoint, replicate=False): self.model = VQGANModel(self.config) def _wrap_fn(self, fn): + """ + Wraps a function with JAX's jit or pmap for performance optimization based on the replication setting. + + Parameters: + fn (Callable): The function to be wrapped. + + Returns: + Callable: The wrapped function, optimized with jit or pmap. + """ if self.replicate: return jax.pmap(fn, devices=jax.local_devices()) else: @@ -31,6 +61,15 @@ def _wrap_fn(self, fn): @cached_property def _encode(self): + """ + Encodes input pixel values into latent representations using the VQGAN model. + + Parameters: + pixel_values (jnp.ndarray): The input pixel values to encode. + + Returns: + jnp.ndarray: The encoded latent representations. + """ def fn(pixel_values, params): return self.model.apply( {'params': params}, @@ -41,6 +80,15 @@ def fn(pixel_values, params): @cached_property def _decode(self): + """ + Decodes latent representations back into pixel values using the VQGAN model. + + Parameters: + encoding (jnp.ndarray): The latent representations to decode. + + Returns: + jnp.ndarray: The decoded pixel values. + """ def fn(encoding, params): return self.model.apply( {'params': params}, @@ -57,6 +105,28 @@ def decode(self, encoding): class VQGANConfig(PretrainedConfig): + """ + Configuration class for the VQGAN model, containing various architectural and training settings. + + Attributes: + resolution (int): The resolution of input images. + num_channels (int): The number of channels in the input images. + hidden_channels (int): The number of hidden channels in the model. + channel_mult (tuple): Multipliers for channels in different stages of the model. + num_res_blocks (int): The number of residual blocks in each stage. + attn_resolutions (tuple): Resolutions at which to apply self-attention. + no_attn_mid_block (bool): Whether to exclude self-attention in the middle block. + z_channels (int): The number of channels in the latent space. + num_embeddings (int): The number of embeddings in the quantizer. + quantized_embed_dim (int): The dimensionality of quantized embeddings. + dropout (float): The dropout rate. + resample_with_conv (bool): Whether to use convolutional layers for resampling. + commitment_cost (float): The commitment cost for vector quantization. + + Methods: + get_default_config(updates): Returns the default configuration, optionally updated with provided values. + load_config(path): Loads the configuration from a specified path. + """ model_type = "vqgan" def __init__( @@ -75,6 +145,24 @@ def __init__( resample_with_conv=True, commitment_cost=0.25 ): + """ + Initializes the VQGAN configuration with the specified settings. + + Parameters: + resolution (int, optional): Resolution of input images. Defaults to 256. + num_channels (int, optional): Number of channels in the input images. Defaults to 3. + hidden_channels (int, optional): Number of hidden channels in the model. Defaults to 128. + channel_mult (tuple, optional): Channel multipliers for different stages. Defaults to (1, 2, 2, 4, 6). + num_res_blocks (int, optional): Number of residual blocks in each stage. Defaults to 2. + attn_resolutions (tuple, optional): Resolutions for applying self-attention. Defaults to (). + no_attn_mid_block (bool, optional): Exclude self-attention in the middle block. Defaults to True. + z_channels (int, optional): Number of channels in the latent space. Defaults to 64. + num_embeddings (int, optional): Number of embeddings in the quantizer. Defaults to 8192. + quantized_embed_dim (int, optional): Dimensionality of quantized embeddings. Defaults to 64. + dropout (float, optional): Dropout rate. Defaults to 0.0. + resample_with_conv (bool, optional): Use convolution for resampling. Defaults to True. + commitment_cost (float, optional): Commitment cost for vector quantization. Defaults to 0.25. + """ self.resolution = resolution self.num_channels = num_channels self.hidden_channels = hidden_channels @@ -91,6 +179,15 @@ def __init__( @classmethod def get_default_config(cls, updates=None): + """ + Returns the default configuration for the VQGAN model, optionally updated with provided values. + + Parameters: + updates (dict, optional): A dictionary of updates to apply to the default configuration. + + Returns: + VQGANConfig: The default (or updated) model configuration. + """ config = function_args_to_config(cls.__init__) if updates is not None: config.update(ConfigDict(updates).copy_and_resolve_references()) @@ -99,13 +196,35 @@ def get_default_config(cls, updates=None): @classmethod def load_config(cls, path): + """ + Loads the VQGAN model configuration from the specified path. + + Parameters: + path (str): The path to the configuration file. + + Returns: + VQGANConfig: The loaded model configuration. + """ return cls.get_default_config(cls) class VQGANModel(nn.Module): + """ + The VQGAN model, consisting of an encoder, decoder, and a vector quantizer for latent space discretization. + + Attributes: + config (VQGANConfig): Configuration object for the VQGAN model. + + Methods: + encode(pixel_values): Encodes input pixel values into quantized latent representations. + decode(encoding, is_codebook_indices): Decodes quantized latent representations (or codebook indices) back into pixel values. + """ config: VQGANConfig def setup(self): + """ + Sets up the VQGAN model components, including the encoder, decoder, quantizer, and related convolutional layers. + """ self.encoder = Encoder(self.config) self.decoder = Decoder(self.config) self.quantize = VectorQuantizer( @@ -115,6 +234,15 @@ def setup(self): self.post_quant_conv = nn.Conv(self.config.z_channels, [1, 1]) def encode(self, pixel_values): + """ + Encodes input pixel values into quantized latent representations using the encoder and quantizer. + + Parameters: + pixel_values (jnp.ndarray): The input pixel values to encode. + + Returns: + tuple: A tuple containing the quantized latent representations and the corresponding codebook indices. + """ T = None if len(pixel_values.shape) == 5: # video T = pixel_values.shape[1] @@ -128,6 +256,16 @@ def encode(self, pixel_values): return quantized_states, codebook_indices def decode(self, encoding, is_codebook_indices=True): + """ + Decodes quantized latent representations (or codebook indices) back into pixel values using the decoder. + + Parameters: + encoding (jnp.ndarray): The quantized latent representations or codebook indices to decode. + is_codebook_indices (bool, optional): Flag indicating whether 'encoding' contains codebook indices. Defaults to True. + + Returns: + jnp.ndarray: The decoded pixel values. + """ if is_codebook_indices: encoding = self.quantize(None, encoding) T = None @@ -141,16 +279,43 @@ def decode(self, encoding, is_codebook_indices=True): return jnp.clip(reconstructed_pixel_values, -1, 1) def __call__(self, pixel_values): + """ + Processes input pixel values through the VQGAN model, encoding and then decoding them. + + Parameters: + pixel_values (jnp.ndarray): The input pixel values to process. + + Returns: + jnp.ndarray: The reconstructed pixel values after encoding and decoding. + """ encoding = self.encode(pixel_values)[1] recon = self.decode(encoding) return recon class Encoder(nn.Module): + """ + Encoder part of the VQGAN model, responsible for converting input images into a latent representation. + + Attributes: + config (VQGANConfig): Configuration object specifying model parameters. + + Methods: + __call__(pixel_values): Processes input pixel values through a series of convolutional and downsampling layers. + """ config: VQGANConfig @nn.compact def __call__(self, pixel_values): + """ + Transforms input pixel values into a high-dimensional latent space. + + Parameters: + pixel_values (jnp.ndarray): Input pixel values with shape [batch_size, height, width, channels]. + + Returns: + jnp.ndarray: The encoded latent representation of the input images. + """ assert pixel_values.shape[1] == pixel_values.shape[2] == self.config.resolution, pixel_values.shape hidden_states = nn.Conv(self.config.hidden_channels, [3, 3])(pixel_values) for i_level in range(self.config.num_resolutions): @@ -165,10 +330,28 @@ def __call__(self, pixel_values): class Decoder(nn.Module): + """ + Decoder part of the VQGAN model, responsible for reconstructing images from latent representations. + + Attributes: + config (VQGANConfig): Configuration object specifying model parameters. + + Methods: + __call__(hidden_states): Processes latent representations through a series of upsampling and convolutional layers. + """ config: VQGANConfig @nn.compact def __call__(self, hidden_states): + """ + Reconstructs images from their latent representations. + + Parameters: + hidden_states (jnp.ndarray): Latent representations with shape [batch_size, latent_height, latent_width, latent_channels]. + + Returns: + jnp.ndarray: The reconstructed images with shape [batch_size, height, width, channels]. + """ hidden_states = nn.Conv( self.config.hidden_channels * self.config.channel_mult[self.config.num_resolutions - 1], [3, 3] @@ -185,11 +368,31 @@ def __call__(self, hidden_states): class VectorQuantizer(nn.Module): + """ + Module for quantizing the continuous latent space of the encoder's output into discrete embeddings. + + Attributes: + n_e (int): The number of embeddings. + e_dim (int): The dimension of each embedding vector. + + Methods: + __call__(z, encoding_indices): Quantizes the input tensor z into a discrete set of embeddings. + """ n_e: int e_dim: int @nn.compact def __call__(self, z, encoding_indices=None): + """ + Quantizes the input tensor into a set of discrete embeddings or retrieves embeddings by indices. + + Parameters: + z (jnp.ndarray): The input tensor to quantize. + encoding_indices (jnp.ndarray, optional): Indices of embeddings to retrieve. + + Returns: + jnp.ndarray: The quantized tensor or the retrieved embeddings based on encoding_indices. + """ def quantize(encoding_indices): w = jax.device_put(embeddings) return w[(encoding_indices,)] @@ -222,11 +425,30 @@ def quantize(encoding_indices): class DownsamplingBlock(nn.Module): + """ + A downsampling block used in the Encoder, applying a series of convolutions and reducing spatial dimensions. + + Attributes: + config (VQGANConfig): Configuration object specifying model parameters. + block_idx (int): Index of the current block in the encoder. + + Methods: + __call__(hidden_states): Applies downsampling to the input tensor. + """ config: VQGANConfig block_idx: int @nn.compact def __call__(self, hidden_states): + """ + Applies convolutions and reduces the spatial dimensions of the input tensor. + + Parameters: + hidden_states (jnp.ndarray): Input tensor to the downsampling block. + + Returns: + jnp.ndarray: The downsampled tensor. + """ block_out = self.config.hidden_channels * self.config.channel_mult[self.block_idx] for _ in range(self.config.num_res_blocks): hidden_states = ResnetBlock( @@ -240,12 +462,32 @@ def __call__(self, hidden_states): class ResnetBlock(nn.Module): + """ + A residual block, applying a series of convolutions and adding the input tensor to the output. + + Attributes: + out_channels (int, optional): The number of output channels. + use_conv_shortcut (bool): Whether to use a convolutional layer in the shortcut connection. + dropout_prob (float): Dropout probability. + + Methods: + __call__(hidden_states): Processes the input tensor through the residual block. + """ out_channels: Optional[int] = None use_conv_shortcut: bool = False dropout_prob: float = 0.0 @nn.compact def __call__(self, hidden_states): + """ + Applies convolutions and a residual connection to the input tensor. + + Parameters: + hidden_states (jnp.ndarray): Input tensor to the residual block. + + Returns: + jnp.ndarray: The output tensor of the residual block. + """ out_channels = self.out_channels or hidden_states.shape[-1] residual = hidden_states hidden_states = nn.GroupNorm()(hidden_states) @@ -264,8 +506,23 @@ def __call__(self, hidden_states): class AttnBlock(nn.Module): + """ + An attention block, applying self-attention to the input tensor. + + Methods: + __call__(hidden_states): Applies self-attention to the input tensor. + """ @nn.compact def __call__(self, hidden_states): + """ + Applies self-attention to the input tensor. + + Parameters: + hidden_states (jnp.ndarray): Input tensor to the attention block. + + Returns: + jnp.ndarray: The output tensor with applied self-attention. + """ residual = hidden_states hidden_states = nn.GroupNorm()(hidden_states) query = nn.Conv(hidden_states.shape[-1], [1, 1])(hidden_states) @@ -284,10 +541,28 @@ def __call__(self, hidden_states): class Downsample(nn.Module): + """ + Downsamples the input tensor, optionally using a convolutional layer. + + Attributes: + with_conv (bool): Whether to use a convolutional layer for downsampling. + + Methods: + __call__(hidden_states): Applies downsampling to the input tensor. + """ with_conv: bool @nn.compact def __call__(self, hidden_states): + """ + Reduces the spatial dimensions of the input tensor. + + Parameters: + hidden_states (jnp.ndarray): Input tensor to downsample. + + Returns: + jnp.ndarray: The downsampled tensor. + """ if self.with_conv: hidden_states = jnp.pad( hidden_states, @@ -304,10 +579,28 @@ def __call__(self, hidden_states): class Upsample(nn.Module): + """ + Upsamples the input tensor, optionally using a convolutional layer. + + Attributes: + with_conv (bool): Whether to use a convolutional layer for upsampling. + + Methods: + __call__(hidden_states): Applies upsampling to the input tensor. + """ with_conv: bool @nn.compact def __call__(self, hidden_states): + """ + Increases the spatial dimensions of the input tensor. + + Parameters: + hidden_states (jnp.ndarray): Input tensor to upsample. + + Returns: + jnp.ndarray: The upsampled tensor. + """ B, H, W, C = hidden_states.shape hidden_states = jax.image.resize( hidden_states, @@ -320,11 +613,30 @@ def __call__(self, hidden_states): class UpsamplingBlock(nn.Module): + """ + An upsampling block used in the Decoder, applying a series of convolutions and increasing spatial dimensions. + + Attributes: + config (VQGANConfig): Configuration object specifying model parameters. + block_idx (int): Index of the current block in the decoder. + + Methods: + __call__(hidden_states): Applies upsampling to the input tensor. + """ config: VQGANConfig block_idx: int @nn.compact def __call__(self, hidden_states): + """ + Applies convolutions and increases the spatial dimensions of the input tensor. + + Parameters: + hidden_states (jnp.ndarray): Input tensor to the upsampling block. + + Returns: + jnp.ndarray: The upsampled tensor. + """ block_out = self.config.hidden_channels * self.config.channel_mult[self.block_idx] for _ in range(self.config.num_res_blocks + 1): hidden_states = ResnetBlock( @@ -338,12 +650,32 @@ def __call__(self, hidden_states): class MidBlock(nn.Module): + """ + A middle block used in both Encoder and Decoder, applying a series of transformations without changing spatial dimensions. + + Attributes: + config (VQGANConfig): Configuration object specifying model parameters. + no_attn (bool): Whether to exclude self-attention in this block. + dropout (float): Dropout probability. + + Methods: + __call__(hidden_states): Processes the input tensor through the middle block. + """ config: VQGANConfig no_attn: bool dropout: float @nn.compact def __call__(self, hidden_states): + """ + Applies convolutions, optional self-attention, and a residual connection to the input tensor. + + Parameters: + hidden_states (jnp.ndarray): Input tensor to the middle block. + + Returns: + jnp.ndarray: The output tensor of the middle block. + """ hidden_states = ResnetBlock(dropout_prob=self.dropout)(hidden_states) if not self.no_attn: hidden_states = AttnBlock()(hidden_states) diff --git a/requirements.txt b/requirements.txt index ff00ebd..869ab4d 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,19 +1,54 @@ -flax==0.7.0 -optax==0.1.7 -chex==0.1.82 -einops -transformers==4.29.2 -datasets==2.13.0 -tqdm -ml_collections -wandb -gcsfs -requests -typing-extensions -sentencepiece -tux @ git+https://github.com/lhao499/tux.git -Pillow -ipdb -imageio[ffmpeg] -decord -tiktoken +aiohttp==3.9.3 +aiosignal==1.3.1 +archspec @ file:///croot/archspec_1697725767277/work +attrs==23.2.0 +boltons @ file:///work/ci_py311/boltons_1677685195580/work +Brotli @ file:///work/ci_py311/brotli-split_1676830125088/work +certifi @ file:///croot/certifi_1700501669400/work/certifi +cffi @ file:///croot/cffi_1700254295673/work +charset-normalizer @ file:///tmp/build/80754af9/charset-normalizer_1630003229654/work +conda @ file:///croot/conda_1701719518285/work +conda-content-trust @ file:///croot/conda-content-trust_1693490622020/work +conda-libmamba-solver @ file:///croot/conda-libmamba-solver_1702997573971/work/src +conda-package-handling @ file:///croot/conda-package-handling_1690999929514/work +conda_package_streaming @ file:///croot/conda-package-streaming_1690987966409/work +cryptography @ file:///croot/cryptography_1702070282333/work +datasets==2.17.0 +dill==0.3.8 +distro @ file:///croot/distro_1701455004953/work +filelock==3.13.1 +frozenlist==1.4.1 +fsspec==2023.10.0 +huggingface-hub==0.20.3 +idna @ file:///work/ci_py311/idna_1676822698822/work +jsonpatch @ file:///tmp/build/80754af9/jsonpatch_1615747632069/work +jsonpointer==2.1 +libmambapy @ file:///croot/mamba-split_1698782620632/work/libmambapy +menuinst @ file:///croot/menuinst_1702390294373/work +multidict==6.0.5 +multiprocess==0.70.16 +numpy==1.26.4 +packaging @ file:///croot/packaging_1693575174725/work +pandas==2.2.0 +platformdirs @ file:///croot/platformdirs_1692205439124/work +pluggy @ file:///work/ci_py311/pluggy_1676822818071/work +pyarrow==15.0.0 +pyarrow-hotfix==0.6 +pycosat @ file:///croot/pycosat_1696536503704/work +pycparser @ file:///tmp/build/80754af9/pycparser_1636541352034/work +pyOpenSSL @ file:///croot/pyopenssl_1690223430423/work +PySocks @ file:///work/ci_py311/pysocks_1676822712504/work +python-dateutil==2.8.2 +pytz==2024.1 +PyYAML==6.0.1 +requests @ file:///croot/requests_1690400202158/work +ruamel.yaml @ file:///work/ci_py311/ruamel.yaml_1676838772170/work +six==1.16.0 +tqdm @ file:///croot/tqdm_1679561862951/work +truststore @ file:///croot/truststore_1695244293384/work +typing_extensions==4.9.0 +tzdata==2024.1 +urllib3 @ file:///croot/urllib3_1698257533958/work +xxhash==3.4.1 +yarl==1.9.4 +zstandard @ file:///work/ci_py311_2/zstandard_1679339489613/work diff --git a/scripts/eval_needle.py b/scripts/eval_needle.py old mode 100644 new mode 100755 diff --git a/scripts/eval_needle_multi.py b/scripts/eval_needle_multi.py old mode 100644 new mode 100755 From 5601c03a7528f71fbbf8da690da76d3c2b1224d7 Mon Sep 17 00:00:00 2001 From: Richard Date: Sat, 17 Feb 2024 19:19:18 +0000 Subject: [PATCH 2/2] added fork --- lwm/ring_attention.py | 11 +++-------- 1 file changed, 3 insertions(+), 8 deletions(-) diff --git a/lwm/ring_attention.py b/lwm/ring_attention.py index b37e2d5..755010c 100644 --- a/lwm/ring_attention.py +++ b/lwm/ring_attention.py @@ -586,19 +586,14 @@ def ring_flash_attention_tpu(q, k, v, attn_bias, segment_ids, axis_name, float32 NUM_SUBLANES = 8 class SegmentIds(NamedTuple): - """SegmentIds for Q and KV sequences. - - SegmentIds are used to generate segment mask, which prevents attention between - different segments in the input sequence. Each array is a list of ids - (integers). - Only the token with the same id can attend to each other. + """ Named tuple for segment IDs used in segment-based masking within the attention mechanism. Segment IDs allow different parts of the input sequence to be treated as separate segments, preventing attention across segments. Attributes: - q: segment ids along the Q sequence. - kv: segment ids along the KV sequence. + q (jax.Array): Segment IDs for the query sequences. + kv (jax.Array): Segment IDs for the key/value sequences. """