From f689ce5d39c6f1475dfc71503288ea2905c8f685 Mon Sep 17 00:00:00 2001
From: zachcp
Date: Fri, 15 Nov 2024 02:30:15 -0500
Subject: [PATCH] Documentation Pass for Models (#2617)

* links in chinese_clip
* links for clip model
* add mod docs for flux and llava
* module doc for MMDIT and MIMI
* add docs for a few more models
* mod docs for bert naser and beit
* add module docs for convmixer colpali codegeex and chatglm
* add another series of moddocs
* add fastvit-llama2_c
* module docs mamba -> mobileone
* module docs from moondream-phi3
* mod docs for quantized and qwen
* update to yi
* fix long names
* Update llama2_c.rs
* Update llama2_c_weights.rs
* Fix the link for mimi + tweaks

---------

Co-authored-by: Laurent Mazare
---
 candle-transformers/src/models/based.rs | 7 +++----
 candle-transformers/src/models/beit.rs | 7 +++++++
 candle-transformers/src/models/bert.rs | 6 ++++++
 candle-transformers/src/models/bigcode.rs | 7 +++++++
 candle-transformers/src/models/blip.rs | 7 +++++++
 candle-transformers/src/models/blip_text.rs | 6 ++++++
 candle-transformers/src/models/chatglm.rs | 7 +++++++
 .../src/models/chinese_clip/mod.rs | 5 +++--
 candle-transformers/src/models/clip/mod.rs | 5 +++--
 .../src/models/codegeex4_9b.rs | 7 +++++++
 candle-transformers/src/models/colpali.rs | 5 +++++
 candle-transformers/src/models/convmixer.rs | 7 +++++++
 candle-transformers/src/models/convnext.rs | 14 ++++++-------
 candle-transformers/src/models/dac.rs | 7 ++++++-
 .../src/models/depth_anything_v2.rs | 6 ++++++
 candle-transformers/src/models/dinov2.rs | 5 +++++
 candle-transformers/src/models/dinov2reg4.rs | 7 +++++++
 candle-transformers/src/models/distilbert.rs | 5 +++++
 .../src/models/efficientnet.rs | 5 +++++
 .../src/models/efficientvit.rs | 7 +++----
 candle-transformers/src/models/encodec.rs | 6 ++++++
 candle-transformers/src/models/eva2.rs | 6 ++++++
 candle-transformers/src/models/falcon.rs | 6 ++++++
 candle-transformers/src/models/fastvit.rs | 8 +++----
 candle-transformers/src/models/flux/mod.rs | 7 +++++++
 candle-transformers/src/models/gemma.rs | 6 ++++++
 candle-transformers/src/models/gemma2.rs | 6 ++++++
 candle-transformers/src/models/glm4.rs | 6 ++++++
 candle-transformers/src/models/granite.rs | 7 +++++++
 candle-transformers/src/models/hiera.rs | 8 +++----
 candle-transformers/src/models/jina_bert.rs | 6 ++++++
 candle-transformers/src/models/llama.rs | 6 ++++++
 candle-transformers/src/models/llama2_c.rs | 6 ++++++
 .../src/models/llama2_c_weights.rs | 6 ++++++
 candle-transformers/src/models/llava/mod.rs | 10 +++++++++
 candle-transformers/src/models/mamba.rs | 9 ++++++--
 candle-transformers/src/models/marian.rs | 6 ++++++
 candle-transformers/src/models/metavoice.rs | 6 ++++++
 candle-transformers/src/models/mimi/mod.rs | 11 +++++++---
 candle-transformers/src/models/mistral.rs | 7 +++++++
 candle-transformers/src/models/mixformer.rs | 7 +++++++
 candle-transformers/src/models/mixtral.rs | 17 +++++++++++++++
 candle-transformers/src/models/mmdit/mod.rs | 9 ++++++++
 candle-transformers/src/models/mobileclip.rs | 16 ++++++++++++++
 candle-transformers/src/models/mobilenetv4.rs | 11 +++++++---
 candle-transformers/src/models/mobileone.rs | 5 +++--
 candle-transformers/src/models/moondream.rs | 11 ++++++++++
 candle-transformers/src/models/mpt.rs | 8 ++++++++
 candle-transformers/src/models/olmo.rs | 16 ++++++++++++++
 .../src/models/openclip/mod.rs | 8 +++++++
 candle-transformers/src/models/paligemma.rs | 16 ++++++++++++++
 candle-transformers/src/models/parler_tts.rs | 17 +++++++++++++++
 candle-transformers/src/models/persimmon.rs | 16 ++++++++++++++
 candle-transformers/src/models/phi.rs | 17 +++++++++++++++
 candle-transformers/src/models/phi3.rs | 19 +++++++++++++++++
 candle-transformers/src/models/pixtral/mod.rs | 8 +++++++
 .../src/models/quantized_blip.rs | 16 ++++++++++++++
 .../src/models/quantized_blip_text.rs | 17 +++++++++++++++
 .../src/models/quantized_llama.rs | 17 +++++++++++++++
 .../src/models/quantized_llama2_c.rs | 16 ++++++++++++++
 .../src/models/quantized_metavoice.rs | 16 ++++++++++++++
 .../src/models/quantized_mistral.rs | 17 +++++++++++++++
 .../src/models/quantized_mixformer.rs | 13 ++++++++++++
 .../src/models/quantized_moondream.rs | 15 +++++++++++++
 .../src/models/quantized_mpt.rs | 18 ++++++++++++++++
 .../src/models/quantized_phi.rs | 17 +++++++++++++++
 .../src/models/quantized_phi3.rs | 15 +++++++++++++
 .../src/models/quantized_qwen2.rs | 15 +++++++++++++
 .../src/models/quantized_recurrent_gemma.rs | 17 +++++++++++++++
 .../src/models/quantized_rwkv_v5.rs | 17 +++++++++++++++
 .../src/models/quantized_rwkv_v6.rs | 18 ++++++++++++++++
 .../src/models/quantized_stable_lm.rs | 15 +++++++++++++
 .../src/models/quantized_t5.rs | 18 ++++++++++++++--
 candle-transformers/src/models/qwen2.rs | 17 +++++++++++++++
 candle-transformers/src/models/qwen2_moe.rs | 18 ++++++++++++++++
 .../src/models/recurrent_gemma.rs | 21 +++++++++++++++++--
 candle-transformers/src/models/repvgg.rs | 11 ++++++++++
 candle-transformers/src/models/resnet.rs | 14 ++++++++++---
 candle-transformers/src/models/rwkv_v5.rs | 17 +++++++++++++++
 candle-transformers/src/models/rwkv_v6.rs | 16 ++++++++++++++
 candle-transformers/src/models/segformer.rs | 16 ++++++++++++++
 .../src/models/segment_anything/mod.rs | 8 +++++++
 candle-transformers/src/models/siglip.rs | 8 +++++++
 .../src/models/stable_diffusion/mod.rs | 9 ++++++++
 candle-transformers/src/models/stable_lm.rs | 15 +++++++++++++
 candle-transformers/src/models/starcoder2.rs | 17 +++++++++++++++
 .../src/models/stella_en_v5.rs | 17 +++++++++++++++
 candle-transformers/src/models/t5.rs | 18 ++++++++++++++--
 candle-transformers/src/models/trocr.rs | 16 ++++++++++++++
 candle-transformers/src/models/vgg.rs | 15 +++++++++++--
 candle-transformers/src/models/vit.rs | 17 +++++++++++++++
 candle-transformers/src/models/whisper/mod.rs | 8 +++++++
 .../src/models/wuerstchen/mod.rs | 9 ++++++++
 candle-transformers/src/models/yi.rs | 16 +++++++++++++-
 94 files changed, 1001 insertions(+), 51 deletions(-)

diff --git a/candle-transformers/src/models/based.rs b/candle-transformers/src/models/based.rs
index aa28f52333..c54ff96629 100644
--- a/candle-transformers/src/models/based.rs
+++ b/candle-transformers/src/models/based.rs
@@ -1,10 +1,9 @@
 //! Based from the Stanford Hazy Research group.
 //!
 //! See "Simple linear attention language models balance the recall-throughput tradeoff", Arora et al. 2024
-//! -
-
-//! Original code:
-//! https://github.com/HazyResearch/based
+//! - [Arxiv](https://arxiv.org/abs/2402.18668)
+//! - [Github](https://github.com/HazyResearch/based)
+//!
 
 use candle::{DType, Device, IndexOp, Module, Result, Tensor, D};
 use candle_nn::{
diff --git a/candle-transformers/src/models/beit.rs b/candle-transformers/src/models/beit.rs
index 8f6284a8e6..2f61d9d6f1 100644
--- a/candle-transformers/src/models/beit.rs
+++ b/candle-transformers/src/models/beit.rs
@@ -1,3 +1,10 @@
+//! Based on the BEIT vision model.
+//!
+//! See "BEIT: BERT Pre-Training of Image Transformers", Bao et al. 2021
+//! - [Arxiv](https://arxiv.org/abs/2106.08254)
+//! - [Github](https://github.com/microsoft/unilm/tree/master/beit)
+//!
+
 use candle::{DType, Device, IndexOp, Result, Tensor, D};
 use candle_nn::{layer_norm, LayerNorm, Linear, Module, VarBuilder};
diff --git a/candle-transformers/src/models/bert.rs b/candle-transformers/src/models/bert.rs
index bdc0385deb..a7db075cbb 100644
--- a/candle-transformers/src/models/bert.rs
+++ b/candle-transformers/src/models/bert.rs
@@ -1,3 +1,9 @@
+//! BERT (Bidirectional Encoder Representations from Transformers)
+//!
+//! See "BERT: Pre-training of Deep Bidirectional Transformers for Language Understanding", Devlin et al. 2018
+//! - [Arxiv](https://arxiv.org/abs/1810.04805)
+//! - [Github](https://github.com/google-research/bert)
+//!
 use super::with_tracing::{layer_norm, linear, LayerNorm, Linear};
 use candle::{DType, Device, Result, Tensor};
 use candle_nn::{embedding, Embedding, Module, VarBuilder};
diff --git a/candle-transformers/src/models/bigcode.rs b/candle-transformers/src/models/bigcode.rs
index f6b4a4efdc..8ed1462b1c 100644
--- a/candle-transformers/src/models/bigcode.rs
+++ b/candle-transformers/src/models/bigcode.rs
@@ -1,3 +1,10 @@
+//! BigCode implementation in Rust based on the GPT-BigCode model.
+//!
+//! See "StarCoder: may the source be with you!", Li et al. 2023
+//! - [Arxiv](https://arxiv.org/abs/2305.06161)
+//! - [Github](https://github.com/bigcode-project/starcoder)
+//!
+
 use candle::{DType, Device, IndexOp, Result, Tensor, D};
 use candle_nn::{embedding, linear_b as linear, Embedding, LayerNorm, Linear, Module, VarBuilder};
diff --git a/candle-transformers/src/models/blip.rs b/candle-transformers/src/models/blip.rs
index e0b0b6a596..0330386574 100644
--- a/candle-transformers/src/models/blip.rs
+++ b/candle-transformers/src/models/blip.rs
@@ -1,3 +1,10 @@
+//! Based on the BLIP paper from Salesforce Research.
+//!
+//! See "BLIP: Bootstrapping Language-Image Pre-training for Unified Vision-Language Understanding and Generation"
+//! - [Arxiv](https://arxiv.org/abs/2201.12086)
+//! - [Github](https://github.com/salesforce/BLIP)
+//!
+
 use super::blip_text;
 use super::with_tracing::{conv2d, linear, Conv2d, Linear};
 use candle::{Module, Result, Tensor, D};
diff --git a/candle-transformers/src/models/blip_text.rs b/candle-transformers/src/models/blip_text.rs
index 1862abef4b..aceaf4ac1b 100644
--- a/candle-transformers/src/models/blip_text.rs
+++ b/candle-transformers/src/models/blip_text.rs
@@ -1,3 +1,9 @@
+//! Implementation of BLIP text encoder/decoder.
+//!
+//! See "BLIP: Bootstrapping Language-Image Pre-training for Unified Vision-Language Understanding and Generation"
+//! - [Arxiv](https://arxiv.org/abs/2201.12086)
+//!
+
 use super::with_tracing::{linear, Embedding, Linear};
 use candle::{Module, Result, Tensor, D};
 use candle_nn::{layer_norm, LayerNorm, VarBuilder};
diff --git a/candle-transformers/src/models/chatglm.rs b/candle-transformers/src/models/chatglm.rs
index 0686b34ef3..8d5d9ec601 100644
--- a/candle-transformers/src/models/chatglm.rs
+++ b/candle-transformers/src/models/chatglm.rs
@@ -1,3 +1,10 @@
+//! Implementation of the ChatGLM2/3 models from THUDM.
+//!
+//! See:
+//! - ChatGLM3: ["ChatGLM3: Advancing Multilingual Conversational Language Models with High-Quality Data"](https://github.com/THUDM/ChatGLM3)
+//! - ChatGLM2: ["ChatGLM2: An Open Bilingual Chat LLM"](https://github.com/THUDM/ChatGLM2-6B)
+//!
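+//! A loading sketch (the file path is a placeholder and `Config::glm3_6b` is assumed
+//! to be one of this module's config presets; weights fetched beforehand, e.g. via `hf-hub`):
+//! ```ignore
+//! use candle::{DType, Device};
+//! use candle_nn::VarBuilder;
+//! use candle_transformers::models::chatglm::{Config, Model};
+//!
+//! let device = Device::Cpu;
+//! let vb = unsafe {
+//!     VarBuilder::from_mmaped_safetensors(&["model.safetensors"], DType::F32, &device)?
+//! };
+//! let model = Model::new(&Config::glm3_6b(), vb)?;
+//! ```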
+
 use crate::models::with_tracing::{linear_b as linear, Linear};
 use candle::{DType, Device, IndexOp, Module, Result, Tensor, D};
 use candle_nn::VarBuilder;
diff --git a/candle-transformers/src/models/chinese_clip/mod.rs b/candle-transformers/src/models/chinese_clip/mod.rs
index 0f6eedd0f2..86616baa1c 100644
--- a/candle-transformers/src/models/chinese_clip/mod.rs
+++ b/candle-transformers/src/models/chinese_clip/mod.rs
@@ -3,8 +3,9 @@
 //! Chinese contrastive Language-Image Pre-Training (CLIP) is an architecture trained on
 //! pairs of images with related texts.
 //!
-//! https://github.com/OFA-Sys/Chinese-CLIP
-//! https://github.com/huggingface/transformers/blob/5af7d41e49bbfc8319f462eb45253dcb3863dfb7/src/transformers/models/chinese_clip/modeling_chinese_clip.py
+//! - [GH Link](https://github.com/OFA-Sys/Chinese-CLIP)
+//! - Transformers Python [reference implementation](https://github.com/huggingface/transformers/blob/5af7d41e49bbfc8319f462eb45253dcb3863dfb7/src/transformers/models/chinese_clip/modeling_chinese_clip.py)
+//!
 
 use candle::{Module, Result, Tensor, D};
 use candle_nn as nn;
diff --git a/candle-transformers/src/models/clip/mod.rs b/candle-transformers/src/models/clip/mod.rs
index 3dd5fb485b..e83f27e388 100644
--- a/candle-transformers/src/models/clip/mod.rs
+++ b/candle-transformers/src/models/clip/mod.rs
@@ -3,8 +3,9 @@
 //! Contrastive Language-Image Pre-Training (CLIP) is an architecture trained on
 //! pairs of images with related texts.
 //!
-//! https://github.com/openai/CLIP
-//! https://github.com/huggingface/transformers/tree/f6fa0f0bf0796ac66f201f23bdb8585de1609add/src/transformers/models/clip
+//! - [GH Link](https://github.com/openai/CLIP)
+//! - Transformers Python [reference implementation](https://github.com/huggingface/transformers/tree/f6fa0f0bf0796ac66f201f23bdb8585de1609add/src/transformers/models/clip)
+
 use self::{
     text_model::{Activation, ClipTextTransformer},
     vision_model::ClipVisionTransformer,
diff --git a/candle-transformers/src/models/codegeex4_9b.rs b/candle-transformers/src/models/codegeex4_9b.rs
index aaa99fd96d..baf4745922 100644
--- a/candle-transformers/src/models/codegeex4_9b.rs
+++ b/candle-transformers/src/models/codegeex4_9b.rs
@@ -1,3 +1,10 @@
+//! CodeGeeX4 - A multi-language code generation model
+//!
+//! See "CodeGeeX: A Pre-Trained Model for Code Generation with Multilingual Benchmarking on HumanEval-X", Zheng et al. 2023
+//! - [Arxiv](https://arxiv.org/abs/2303.17568)
+//! - [Github](https://github.com/THUDM/CodeGeeX)
+//!
+
 use crate::models::with_tracing::{linear_b as linear, Linear};
 use candle::{DType, Device, IndexOp, Module, Result, Tensor, D};
 use candle_nn::VarBuilder;
diff --git a/candle-transformers/src/models/colpali.rs b/candle-transformers/src/models/colpali.rs
index 1299b0a410..16ca4eb304 100644
--- a/candle-transformers/src/models/colpali.rs
+++ b/candle-transformers/src/models/colpali.rs
@@ -1,3 +1,8 @@
+//! ColPali Model for text/image similarity scoring.
+//!
+//! ColPali combines a vision encoder with an efficient LM for retrieving content.
+//!
+
 use candle::{Module, Result, Tensor};
 use candle_nn::VarBuilder;
diff --git a/candle-transformers/src/models/convmixer.rs b/candle-transformers/src/models/convmixer.rs
index f5abfa5da3..e095f793a4 100644
--- a/candle-transformers/src/models/convmixer.rs
+++ b/candle-transformers/src/models/convmixer.rs
@@ -1,3 +1,10 @@
+//! ConvMixer implementation.
+//!
+//! See "Patches Are All You Need?" by Trockman et al. 2022
+//! - [Arxiv](https://arxiv.org/abs/2201.09792)
+//! - [Github](https://github.com/locuslab/convmixer)
+//!
+
 use candle::Result;
 use candle_nn::{batch_norm, Conv2dConfig, Module, VarBuilder};
diff --git a/candle-transformers/src/models/convnext.rs b/candle-transformers/src/models/convnext.rs
index 94b1833ec2..d791895f1d 100644
--- a/candle-transformers/src/models/convnext.rs
+++ b/candle-transformers/src/models/convnext.rs
@@ -1,15 +1,13 @@
 //! ConvNeXt implementation.
 //!
-//! See "A ConvNet for the 2020s" Liu et al. 2022
-//!
+//! See ["A ConvNet for the 2020s" Liu et al. 2022](https://arxiv.org/abs/2201.03545)
 //! and
-//! "ConvNeXt V2: Co-designing and Scaling ConvNets with Masked Autoencoders" Woo et al. 2023
-//! -
-
+//! ["ConvNeXt V2: Co-designing and Scaling ConvNets with Masked Autoencoders" Woo et al. 2023](https://arxiv.org/abs/2301.00808)
+//!
 //! Original code:
-//! https://github.com/facebookresearch/ConvNeXt/
-//! https://github.com/facebookresearch/ConvNeXt-V2/
-//! timm: https://github.com/huggingface/pytorch-image-models/blob/main/timm/models/convnext.py
+//! - [ConvNeXt](https://github.com/facebookresearch/ConvNeXt/)
+//! - [ConvNeXt-V2](https://github.com/facebookresearch/ConvNeXt-V2/)
+//! - [timm](https://github.com/huggingface/pytorch-image-models/blob/main/timm/models/convnext.py)
 
 use candle::shape::ShapeWithOneHole;
 use candle::{Result, D};
diff --git a/candle-transformers/src/models/dac.rs b/candle-transformers/src/models/dac.rs
index fa6c8c7120..78728b4d09 100644
--- a/candle-transformers/src/models/dac.rs
+++ b/candle-transformers/src/models/dac.rs
@@ -1,4 +1,9 @@
-/// Adapted from https://github.com/descriptinc/descript-audio-codec
+//! Implementation of the Descript Audio Codec (DAC) model
+//!
+//! An efficient neural codec for compressing/decompressing audio.
+//!
+//! See: [Descript Audio Codec](https://github.com/descriptinc/descript-audio-codec)
+//!
 use crate::models::encodec;
 use candle::{IndexOp, Result, Tensor, D};
 use candle_nn::{Conv1d, Conv1dConfig, ConvTranspose1d, ConvTranspose1dConfig, VarBuilder};
diff --git a/candle-transformers/src/models/depth_anything_v2.rs b/candle-transformers/src/models/depth_anything_v2.rs
index 9eee6d1130..411b0764ff 100644
--- a/candle-transformers/src/models/depth_anything_v2.rs
+++ b/candle-transformers/src/models/depth_anything_v2.rs
@@ -1,3 +1,9 @@
+//! Implementation of the Depth Anything model.
+//!
+//! See:
+//! - ["Depth Anything: Unleashing the Power of Large-Scale Unlabeled Data"](https://github.com/LiheYoung/Depth-Anything)
+//!
+
 use candle::D::Minus1;
 use candle::{Module, Result, Tensor};
 use candle_nn::ops::Identity;
diff --git a/candle-transformers/src/models/dinov2.rs b/candle-transformers/src/models/dinov2.rs
index 706dfda0e7..df8834d1f7 100644
--- a/candle-transformers/src/models/dinov2.rs
+++ b/candle-transformers/src/models/dinov2.rs
@@ -1,3 +1,8 @@
+//! Implementation of the DINOv2 models from Meta Research.
+//!
+//! See:
+//! - DINOv2: ["DINOv2: Learning Robust Visual Features without Supervision"](https://github.com/facebookresearch/dinov2)
+//!
 use candle::{IndexOp, Result, Tensor, D};
 use candle_nn::{layer_norm, LayerNorm, Linear, Module, VarBuilder};
diff --git a/candle-transformers/src/models/dinov2reg4.rs b/candle-transformers/src/models/dinov2reg4.rs
index 1d81703c9c..0d2320e14c 100644
--- a/candle-transformers/src/models/dinov2reg4.rs
+++ b/candle-transformers/src/models/dinov2reg4.rs
@@ -1,3 +1,10 @@
+//! Implementation of the DINOv2 revision (4 register tokens)
+//!
+//! See:
- DINOv2: ["DINOv2: Learning Robust Visual Features without Supervision"](https://github.com/facebookresearch/dinov2) +//! +//! This code implements the regularization tokens version with 4 regularization tokens. +//! use candle::{IndexOp, Result, Tensor, D}; use candle_nn::{layer_norm, LayerNorm, Linear, Module, VarBuilder}; diff --git a/candle-transformers/src/models/distilbert.rs b/candle-transformers/src/models/distilbert.rs index f899d772a2..fad76cfcce 100644 --- a/candle-transformers/src/models/distilbert.rs +++ b/candle-transformers/src/models/distilbert.rs @@ -1,3 +1,8 @@ +//! Implementation of DistilBert, a distilled version of BERT. +//! +//! See: +//! - ["DistilBERT, a distilled version of BERT: smaller, faster, cheaper and lighter"](https://arxiv.org/abs/1910.01108) +//! use super::with_tracing::{layer_norm, linear, LayerNorm, Linear}; use candle::{DType, Device, Result, Tensor}; use candle_nn::{Embedding, Module, VarBuilder}; diff --git a/candle-transformers/src/models/efficientnet.rs b/candle-transformers/src/models/efficientnet.rs index f15c9c797e..ecca2509ae 100644 --- a/candle-transformers/src/models/efficientnet.rs +++ b/candle-transformers/src/models/efficientnet.rs @@ -1,3 +1,8 @@ +//! Implementation of EfficientBert, an efficient variant of BERT for computer vision tasks. +//! +//! See: +//! - ["EfficientBERT: Progressively Searching Multilayer Perceptron Architectures for BERT"](https://arxiv.org/abs/2201.00462) +//! use candle::{Result, Tensor, D}; use candle_nn as nn; use nn::{Module, VarBuilder}; diff --git a/candle-transformers/src/models/efficientvit.rs b/candle-transformers/src/models/efficientvit.rs index b17c4ea0a1..9724f702a6 100644 --- a/candle-transformers/src/models/efficientvit.rs +++ b/candle-transformers/src/models/efficientvit.rs @@ -1,9 +1,8 @@ //! EfficientViT (MSRA) inference implementation based on timm. //! -//! See "EfficientViT: Memory Efficient Vision Transformer with Cascaded Group Attention" -//! https://arxiv.org/abs/2305.07027 - -//! https://github.com/huggingface/pytorch-image-models/blob/main/timm/models/efficientvit_msra.py +//! See ["EfficientViT: Memory Efficient Vision Transformer with Cascaded Group Attention"](https://arxiv.org/abs/2305.07027) +//! +//! Based on implementation from [pytorch-image-models](https://github.com/huggingface/pytorch-image-models/blob/main/timm/models/efficientvit_msra.py) use candle::{Result, Tensor, D}; use candle_nn::{ diff --git a/candle-transformers/src/models/encodec.rs b/candle-transformers/src/models/encodec.rs index ba6686f605..a8d509ce8b 100644 --- a/candle-transformers/src/models/encodec.rs +++ b/candle-transformers/src/models/encodec.rs @@ -1,3 +1,9 @@ +//! EnCodec neural audio codec based on the Encodec implementation. +//! +//! See ["High Fidelity Neural Audio Compression"](https://arxiv.org/abs/2210.13438) +//! +//! Based on implementation from [huggingface/transformers](https://github.com/huggingface/transformers/blob/main/src/transformers/models/encodec/modeling_encodec.py) + #![allow(unused)] use candle::{DType, IndexOp, Layout, Module, Result, Shape, Tensor, D}; use candle_nn::{conv1d, Conv1d, Conv1dConfig, ConvTranspose1d, VarBuilder}; diff --git a/candle-transformers/src/models/eva2.rs b/candle-transformers/src/models/eva2.rs index 013c385d1c..ee84cca43c 100644 --- a/candle-transformers/src/models/eva2.rs +++ b/candle-transformers/src/models/eva2.rs @@ -1,3 +1,9 @@ +//! EVA-2 inference implementation. +//! +//! 
See ["EVA-02: A Visual Representation for Neon Genesis"](https://arxiv.org/abs/2303.11331) +//! +//! Based on implementation from [pytorch-image-models](https://github.com/huggingface/pytorch-image-models/blob/main/timm/models/eva2.py) + use candle::{IndexOp, Result, Tensor, D}; use candle_nn::{layer_norm, LayerNorm, Linear, Module, VarBuilder}; diff --git a/candle-transformers/src/models/falcon.rs b/candle-transformers/src/models/falcon.rs index 50ec66f316..c75b4d70d3 100644 --- a/candle-transformers/src/models/falcon.rs +++ b/candle-transformers/src/models/falcon.rs @@ -1,3 +1,9 @@ +//! Falcon language model inference implementation +//! +//! See ["Falcon: a new approach to large language models"](https://huggingface.co/blog/falcon) +//! +//! Based on implementation from [Huggingface Transformers](https://github.com/huggingface/transformers/blob/main/src/transformers/models/falcon) + use candle::{DType, Device, Result, Tensor, D}; use candle_nn::{embedding, linear_b as linear, Embedding, LayerNorm, Linear, Module, VarBuilder}; use serde::Deserialize; diff --git a/candle-transformers/src/models/fastvit.rs b/candle-transformers/src/models/fastvit.rs index 8eae8bb200..4e29665358 100644 --- a/candle-transformers/src/models/fastvit.rs +++ b/candle-transformers/src/models/fastvit.rs @@ -1,9 +1,9 @@ -//! FastViT inference implementation based on timm +//! # FastViT inference implementation based on timm //! -//! See "FastViT: A Fast Hybrid Vision Transformer using Structural Reparameterization" -//! https://arxiv.org/pdf/2303.14189 +//! ## Description +//! See ["FastViT: A Fast Hybrid Vision Transformer using Structural Reparameterization"](https://arxiv.org/pdf/2303.14189) //! -//! https://github.com/huggingface/pytorch-image-models/blob/main/timm/models/fastvit.py +//! Implementation based on [timm model](https://github.com/huggingface/pytorch-image-models/blob/main/timm/models/fastvit.py) use candle::{DType, Result, Tensor, D}; use candle_nn::{ diff --git a/candle-transformers/src/models/flux/mod.rs b/candle-transformers/src/models/flux/mod.rs index b0c8a6939a..8eb928f557 100644 --- a/candle-transformers/src/models/flux/mod.rs +++ b/candle-transformers/src/models/flux/mod.rs @@ -1,3 +1,10 @@ +//! Flux Model +//! +//! Flux is a series of text-to-image generation models based on diffusion transformers. +//! +//! - [GH Link](https://github.com/black-forest-labs/flux) +//! - Transformers Python [reference implementation](https://github.com/huggingface/transformers/blob/5af7d41e49bbfc8319f462eb45253dcb3863dfb7/src/transformers/models/chinese_clip/modeling_chinese_clip.py) +//! use candle::{Result, Tensor}; pub trait WithForward { diff --git a/candle-transformers/src/models/gemma.rs b/candle-transformers/src/models/gemma.rs index c22a39480c..4b656d6a7f 100644 --- a/candle-transformers/src/models/gemma.rs +++ b/candle-transformers/src/models/gemma.rs @@ -1,3 +1,9 @@ +//! Gemma inference implementation. +//! +//! See ["Gemma: Open Models Based on Gemini Technology"](https://blog.google/technology/developers/gemma-open-ai-model/) +//! +//! Based on implementation from Google and PyTorch + use std::sync::Arc; use candle::{DType, Device, Module, Result, Tensor, D}; diff --git a/candle-transformers/src/models/gemma2.rs b/candle-transformers/src/models/gemma2.rs index f0d650479e..ec23efc529 100644 --- a/candle-transformers/src/models/gemma2.rs +++ b/candle-transformers/src/models/gemma2.rs @@ -1,3 +1,9 @@ +//! Gemma LLM architecture (Google) inference implementation. +//! +//! 
See ["Gemma: Open Models Based on Gemini Technology"](https://blog.google/technology/developers/gemma-open-models/) +//! +//! Based on implementations from Google and OpenLLM + use std::sync::Arc; use candle::{DType, Device, Module, Result, Tensor, D}; diff --git a/candle-transformers/src/models/glm4.rs b/candle-transformers/src/models/glm4.rs index 3b436eaa6d..de6581d0b7 100644 --- a/candle-transformers/src/models/glm4.rs +++ b/candle-transformers/src/models/glm4.rs @@ -1,3 +1,9 @@ +//! GLM-4 inference implementation. +//! +//! An open bilingual language model with 130B parameters. +//! +//! Based on implementation from [ChatGLM-6B](https://github.com/THUDM/ChatGLM-6B) + use crate::models::with_tracing::{linear_b as linear, Linear}; use candle::{DType, Device, IndexOp, Module, Result, Tensor, D}; use candle_nn::VarBuilder; diff --git a/candle-transformers/src/models/granite.rs b/candle-transformers/src/models/granite.rs index 6d25c339b2..f1b2c4db5b 100644 --- a/candle-transformers/src/models/granite.rs +++ b/candle-transformers/src/models/granite.rs @@ -1,3 +1,10 @@ +//! Granite is a Long Context Transformer Language Model. +//! +//! A high performance transformer model optimized for efficient processing +//! of very long context sequences +//! +//! Based on implementation from [Nod.ai](https://github.com/nod-ai/granite) + use super::with_tracing::{linear_no_bias as linear, Linear, RmsNorm}; use candle::{DType, Device, IndexOp, Result, Tensor, D}; use candle_nn::{embedding, Embedding, Module, VarBuilder}; diff --git a/candle-transformers/src/models/hiera.rs b/candle-transformers/src/models/hiera.rs index 52efb78ea3..39f8d639b6 100644 --- a/candle-transformers/src/models/hiera.rs +++ b/candle-transformers/src/models/hiera.rs @@ -1,9 +1,9 @@ -//! Hiera inference implementation based on timm. +//! [Hiera] inference implementation based on timm. //! -//! See "Hiera: A Hierarchical Vision Transformer without the Bells-and-Whistles" -//! https://arxiv.org/abs/2306.00989 +//! See "[Hiera: A Hierarchical Vision Transformer without the Bells-and-Whistles]" +//! [Hiera: A Hierarchical Vision Transformer without the Bells-and-Whistles]: https://arxiv.org/abs/2306.00989 //! -//! https://github.com/huggingface/pytorch-image-models/blob/main/timm/models/hiera.py +//! [Hiera]: https://github.com/huggingface/pytorch-image-models/blob/main/timm/models/hiera.py use candle::{Result, D}; use candle_nn::{conv2d, layer_norm, linear, ops::softmax, Conv2dConfig, Func, VarBuilder}; diff --git a/candle-transformers/src/models/jina_bert.rs b/candle-transformers/src/models/jina_bert.rs index 1f0fae1ee4..40535a8bb9 100644 --- a/candle-transformers/src/models/jina_bert.rs +++ b/candle-transformers/src/models/jina_bert.rs @@ -1,3 +1,9 @@ +//! # JinaBERT inference implementation +//! +//! Based on implementation from huggingface for Jina BERT and its variants +//! +//! See: [Jina Embeddings on HuggingFace](https://huggingface.co/jinaai/jina-embeddings-v2-base-en) + use super::with_tracing::{linear, linear_no_bias, Embedding, Linear}; use candle::{DType, Device, IndexOp, Result, Tensor, D}; use candle_nn::{layer_norm, LayerNorm, Module, VarBuilder}; diff --git a/candle-transformers/src/models/llama.rs b/candle-transformers/src/models/llama.rs index e77697340e..4396063ff7 100644 --- a/candle-transformers/src/models/llama.rs +++ b/candle-transformers/src/models/llama.rs @@ -1,3 +1,9 @@ +//! Llama inference implementation. +//! +//! 
See ["LLaMA: Open and Efficient Foundation Language Models"](https://arxiv.org/abs/2302.13971) +//! +//! Implementation based on Hugging Face's [transformers](https://github.com/huggingface/transformers/blob/main/src/transformers/models/llama/modeling_llama.py) + use super::with_tracing::{linear_no_bias as linear, Linear, RmsNorm}; use candle::{DType, Device, IndexOp, Result, Tensor, D}; use candle_nn::{embedding, Embedding, Module, VarBuilder}; diff --git a/candle-transformers/src/models/llama2_c.rs b/candle-transformers/src/models/llama2_c.rs index 923a270646..d825d8e4dd 100644 --- a/candle-transformers/src/models/llama2_c.rs +++ b/candle-transformers/src/models/llama2_c.rs @@ -1,3 +1,9 @@ +//! Llama2 inference implementation. +//! +//! See ["LLaMA 2: Open Foundation and Fine-Tuned Chat Models"](https://arxiv.org/abs/2307.09288) +//! +//! Based on the [llama2.c](https://github.com/karpathy/llama2.c) implementation + use candle::{DType, Device, IndexOp, Result, Tensor, D}; use candle_nn::linear_no_bias as linear; use candle_nn::{embedding, rms_norm, Embedding, Linear, Module, RmsNorm, VarBuilder}; diff --git a/candle-transformers/src/models/llama2_c_weights.rs b/candle-transformers/src/models/llama2_c_weights.rs index e5a8bb8806..8149c214c9 100644 --- a/candle-transformers/src/models/llama2_c_weights.rs +++ b/candle-transformers/src/models/llama2_c_weights.rs @@ -1,3 +1,9 @@ +//! Llama2 inference implementation. +//! +//! See ["LLaMA 2: Open Foundation and Fine-Tuned Chat Models"](https://arxiv.org/abs/2307.09288) +//! +//! Based on the [llama2.c](https://github.com/karpathy/llama2.c) implementation + use byteorder::{LittleEndian, ReadBytesExt}; use candle::{DType, Device, IndexOp, Result, Shape, Tensor}; use candle_nn::VarBuilder; diff --git a/candle-transformers/src/models/llava/mod.rs b/candle-transformers/src/models/llava/mod.rs index 1ed3b50c63..44a00bf9a1 100644 --- a/candle-transformers/src/models/llava/mod.rs +++ b/candle-transformers/src/models/llava/mod.rs @@ -1,3 +1,13 @@ +//! The LLaVA (Large Language and Vision Assistant) model. +//! +//! This provides the main model implementation combining a vision tower (CLIP) with +//! language model (Llama) for multimodal capabilities. +//! +//! The architecture implements the training-free projection technique from the paper: +//! [Visual Instruction Tuning](https://arxiv.org/abs/2304.08485). +//! +//! - [GH Link](https://github.com/haotian-liu/LLaVA/tree/main) +//! pub mod config; pub mod utils; diff --git a/candle-transformers/src/models/mamba.rs b/candle-transformers/src/models/mamba.rs index a75ee87a6e..18a0285ff6 100644 --- a/candle-transformers/src/models/mamba.rs +++ b/candle-transformers/src/models/mamba.rs @@ -1,5 +1,10 @@ -/// A fast implementation of mamba for inference only. -/// This is based on: https://github.com/LaurentMazare/mamba.rs +//! Mamba inference implementation. +//! +//! See ["Mamba: Linear-Time Sequence Modeling with Selective State Spaces"](https://arxiv.org/abs/2312.00752) +//! +//! Based on reference implementation from the AlbertMamba project +//! A fast implementation of mamba for inference only. +//! 
+//! Based on Laurent Mazare's rust implementation: [mamba.rs](https://github.com/LaurentMazare/mamba.rs)
 use crate::models::with_tracing::{linear, linear_no_bias, Linear};
 use candle::{DType, Device, IndexOp, Module, Result, Tensor, D};
 use candle_nn::{RmsNorm, VarBuilder};
diff --git a/candle-transformers/src/models/marian.rs b/candle-transformers/src/models/marian.rs
index e93370c23e..c4ba0a154d 100644
--- a/candle-transformers/src/models/marian.rs
+++ b/candle-transformers/src/models/marian.rs
@@ -1,3 +1,9 @@
+//! Marian Neural Machine Translation
+//!
+//! See "Marian: Fast Neural Machine Translation in C++", Junczys-Dowmunt et al. 2018
+//! - [ACL Anthology](https://aclanthology.org/P18-4020/)
+//! - [Github](https://github.com/marian-nmt/marian)
+//!
 use super::with_tracing::{linear, Embedding, Linear};
 use candle::{Result, Tensor};
 use candle_nn::{layer_norm, LayerNorm, VarBuilder};
diff --git a/candle-transformers/src/models/metavoice.rs b/candle-transformers/src/models/metavoice.rs
index 43de594f9d..92d3ffba08 100644
--- a/candle-transformers/src/models/metavoice.rs
+++ b/candle-transformers/src/models/metavoice.rs
@@ -1,3 +1,9 @@
+//! MetaVoice Studio ML Models
+//!
+//! See MetaVoice's TTS and voice cloning models:
+//! - [Github](https://github.com/metavoiceio/metavoice-src)
+//! - [Website](https://studio.metavoice.ai/)
+
 use candle::{DType, Device, Error as E, IndexOp, Module, Result, Tensor, D};
 use candle_nn::{embedding, linear_b, rms_norm, Embedding, Linear, RmsNorm, VarBuilder};
diff --git a/candle-transformers/src/models/mimi/mod.rs b/candle-transformers/src/models/mimi/mod.rs
index dc40e38e29..f19f9ae5fa 100644
--- a/candle-transformers/src/models/mimi/mod.rs
+++ b/candle-transformers/src/models/mimi/mod.rs
@@ -1,9 +1,14 @@
-// Adapted from the reference implementation at:
-// https://github.com/kyutai-labs/moshi
+//! Mimi model
+//!
+//! Mimi is a state-of-the-art audio neural codec.
+//!
+//! - [HuggingFace Model Card](https://huggingface.co/kyutai/mimi)
+//! - [GitHub](https://github.com/kyutai-labs/moshi)
+//!
+
 // Copyright (c) Kyutai, all rights reserved.
 // This source code is licensed under the license found in the
 // LICENSE file in the root directory of this source tree.
-
 pub use candle;
 pub use candle_nn;
diff --git a/candle-transformers/src/models/mistral.rs b/candle-transformers/src/models/mistral.rs
index e8f7a7c4b8..f927f88b2d 100644
--- a/candle-transformers/src/models/mistral.rs
+++ b/candle-transformers/src/models/mistral.rs
@@ -1,3 +1,10 @@
+//! Mistral model, based on the Mistral architecture
+//!
+//! See Mistral at:
+//! - [Hugging Face](https://huggingface.co/docs/transformers/model_doc/mistral)
+//! - [Github](https://github.com/mistralai/mistral-src)
+//!
+
 use crate::models::with_tracing::{linear_no_bias, Linear, RmsNorm};
 /// Mistral LLM, https://github.com/mistralai/mistral-src
 use candle::{DType, Device, Module, Result, Tensor, D};
diff --git a/candle-transformers/src/models/mixformer.rs b/candle-transformers/src/models/mixformer.rs
index 700829e33b..2c2909c3e0 100644
--- a/candle-transformers/src/models/mixformer.rs
+++ b/candle-transformers/src/models/mixformer.rs
@@ -1,3 +1,10 @@
+//! MixFormer (Microsoft's Phi Architecture)
+//!
+//! See "Textbooks Are All You Need II: phi-1.5 technical report", Li et al. 2023
+//! - [Arxiv](https://arxiv.org/abs/2309.05463)
+//! - [HuggingFace](https://huggingface.co/microsoft/phi-1_5)
+//!
+
 use crate::models::with_tracing::{linear, Embedding as E, Linear};
 /// MixFormer model.
 /// https://huggingface.co/microsoft/phi-1_5
diff --git a/candle-transformers/src/models/mixtral.rs b/candle-transformers/src/models/mixtral.rs
index a578d6fed0..70115e10a3 100644
--- a/candle-transformers/src/models/mixtral.rs
+++ b/candle-transformers/src/models/mixtral.rs
@@ -1,3 +1,20 @@
+//! Mixtral Model, a sparse mixture of expert model based on the Mistral architecture
+//!
+//! See Mixtral model details at:
+//! - [Hugging Face](https://huggingface.co/docs/transformers/model_doc/mixtral)
+//! - [Mixtral-8x7B Blog Post](https://mistral.ai/news/mixtral-of-experts/)
+//!
+//! The model uses a mixture of experts architecture with:
+//! - 8 experts per layer
+//! - Top 2 expert routing
+//! - Sliding window attention
+//! - RoPE embeddings
+//!
+//! References:
+//! - [Hugging Face Implementation](https://github.com/huggingface/transformers/blob/main/src/transformers/models/mixtral/modeling_mixtral.py)
+//!
+
 use crate::models::with_tracing::{linear_no_bias, Linear, RmsNorm};
 /// Mixtral Model
 /// https://github.com/huggingface/transformers/blob/main/src/transformers/models/mixtral/modeling_mixtral.py
diff --git a/candle-transformers/src/models/mmdit/mod.rs b/candle-transformers/src/models/mmdit/mod.rs
index 9c4db6e085..ce4872e0b2 100644
--- a/candle-transformers/src/models/mmdit/mod.rs
+++ b/candle-transformers/src/models/mmdit/mod.rs
@@ -1,3 +1,12 @@
+//! Multimodal Diffusion Transformer (MMDiT)
+//!
+//! The Multimodal Diffusion Transformer (MMDiT) is an architecture
+//! introduced for Stable Diffusion 3, with the MMDiT-X variant used in Stable Diffusion 3.5.
+//!
+//! - [Research Paper](https://arxiv.org/abs/2403.03206)
+//! - ComfyUI [reference implementation](https://github.com/comfyanonymous/ComfyUI/blob/78e133d0415784924cd2674e2ee48f3eeca8a2aa/comfy/ldm/modules/diffusionmodules/mmdit.py)
+//! - Stability-AI [MMDiT-X implementation](https://github.com/Stability-AI/sd3.5/blob/4e484e05308d83fb77ae6f680028e6c313f9da54/mmditx.py)
+
 pub mod blocks;
 pub mod embedding;
 pub mod model;
diff --git a/candle-transformers/src/models/mobileclip.rs b/candle-transformers/src/models/mobileclip.rs
index 45a5dbad9f..f0baf9e10c 100644
--- a/candle-transformers/src/models/mobileclip.rs
+++ b/candle-transformers/src/models/mobileclip.rs
@@ -1,3 +1,19 @@
+//! MobileCLIP model, combining a lightweight vision encoder with a text encoder
+//!
+//! A mobile-optimized CLIP implementation that uses:
+//! - FastViT as the vision encoder
+//! - OpenCLIP text encoder
+//! - Projection layers to align the feature spaces
+//!
+//! See model details at:
+//! - [FastViT](https://arxiv.org/abs/2303.14189)
+//! - [OpenCLIP](https://github.com/mlfoundations/open_clip)
+//!
+//! References:
+//! - [MobileCLIP paper](https://arxiv.org/abs/2311.17049)
+//!
+
 use super::fastvit;
 use super::openclip::text_model;
 use candle::{Result, Tensor, D};
diff --git a/candle-transformers/src/models/mobilenetv4.rs b/candle-transformers/src/models/mobilenetv4.rs
index 7cbae7c385..ab1e70803f 100644
--- a/candle-transformers/src/models/mobilenetv4.rs
+++ b/candle-transformers/src/models/mobilenetv4.rs
@@ -1,9 +1,14 @@
+//! # MobileNet-v4
+//!
 //! MobileNet-v4 inference implementation based on timm.
 //!
-//! See "MobileNetV4 - Universal Models for the Mobile Ecosystem"
-//! https://arxiv.org/abs/2404.10518
+//! ## Paper
+//!
["MobileNetV4 - Universal Models for the Mobile Ecosystem"](https://arxiv.org/abs/2404.10518) +//! +//! ## References //! -//! https://github.com/huggingface/pytorch-image-models/blob/main/timm/models/mobilenetv3.py +//! - [PyTorch Implementation](https://github.com/huggingface/pytorch-image-models/blob/main/timm/models/mobilenetv3.py) use candle::{Result, Tensor, D}; use candle_nn::{ diff --git a/candle-transformers/src/models/mobileone.rs b/candle-transformers/src/models/mobileone.rs index 674da40b97..e8836745b9 100644 --- a/candle-transformers/src/models/mobileone.rs +++ b/candle-transformers/src/models/mobileone.rs @@ -1,7 +1,8 @@ +//! # MobileOne +//! //! MobileOne inference implementation based on timm and candle-repvgg //! -//! See "MobileOne: An Improved One millisecond Mobile Backbone" -//! https://arxiv.org/abs/2206.04040 +//! See ["MobileOne: An Improved One millisecond Mobile Backbone"](https://arxiv.org/abs/2206.04040) use candle::{DType, Result, Tensor, D}; use candle_nn::{ diff --git a/candle-transformers/src/models/moondream.rs b/candle-transformers/src/models/moondream.rs index cde59d43d6..d351d7c019 100644 --- a/candle-transformers/src/models/moondream.rs +++ b/candle-transformers/src/models/moondream.rs @@ -1,3 +1,14 @@ +//! MoonDream Model vision-to-text +//! +//! The model consists of: +//! - Vision encoder using a ViT-style architecture +//! - Text decoder based on Microsoft's Phi model +//! - Vision projection module to align vision and text embeddings +//! +//! References: +//! - [MoonDream Original Implementation](https://github.com/vikhyat/moondream) +//! + use crate::models::mixformer::{Config as PhiConfig, MixFormerSequentialForCausalLM as PhiModel}; use crate::models::with_tracing::{layer_norm, linear_b, LayerNorm, Linear}; use candle::{IndexOp, Module, Result, Tensor, D}; diff --git a/candle-transformers/src/models/mpt.rs b/candle-transformers/src/models/mpt.rs index d46524fcc2..d4170d6bff 100644 --- a/candle-transformers/src/models/mpt.rs +++ b/candle-transformers/src/models/mpt.rs @@ -1,3 +1,11 @@ +//! Module implementing the MPT (Multi-Purpose Transformer) model +//! +//! References: +//! - [MPT Model used by replit-code-v1_5-3b](https://huggingface.co/replit/replit-code-v1_5-3b/blob/main/modeling_mpt.py) +//! - [Configuration](https://huggingface.co/replit/replit-code-v1_5-3b/blob/main/configuration_mpt.py) +//! +//! The model uses grouped query attention and alibi positional embeddings. + use crate::models::with_tracing::{linear_no_bias, Embedding, Linear}; /// MPT model used by replit-code-v1_5-3b /// https://huggingface.co/replit/replit-code-v1_5-3b/blob/main/modeling_mpt.py diff --git a/candle-transformers/src/models/olmo.rs b/candle-transformers/src/models/olmo.rs index 983a33340a..6cf5b1f79d 100644 --- a/candle-transformers/src/models/olmo.rs +++ b/candle-transformers/src/models/olmo.rs @@ -1,3 +1,19 @@ +//! OLMo (Open Language Model) implementation +//! +//! See OLMo model details at: +//! - [Hugging Face](https://huggingface.co/allenai/OLMo) +//! - [OLMo Paper](https://allenai.org/olmo) +//! +//! The model uses: +//! - RoPE embeddings +//! - Sliding window attention +//! - Transformer architecture +//! +//! References: +//! - [Hugging Face Implementation](https://huggingface.co/allenai/OLMo) +//! - [OLMo Paper](https://allenai.org/olmo) +//! 
+
 use candle::{DType, Device, Module, Result, Tensor, D};
 use candle_nn::{linear_b, linear_no_bias, Activation, LayerNorm, Linear, VarBuilder};
 use std::sync::Arc;
diff --git a/candle-transformers/src/models/openclip/mod.rs b/candle-transformers/src/models/openclip/mod.rs
index ee2a501d6a..dacb627f9e 100644
--- a/candle-transformers/src/models/openclip/mod.rs
+++ b/candle-transformers/src/models/openclip/mod.rs
@@ -1 +1,9 @@
+//! Open Contrastive Language-Image Pre-Training
+//!
+//! Open Contrastive Language-Image Pre-Training (OpenCLIP) is an architecture trained on
+//! pairs of images with related texts.
+//!
+//! - [GH Link](https://github.com/mlfoundations/open_clip)
+//!
+
 pub mod text_model;
diff --git a/candle-transformers/src/models/paligemma.rs b/candle-transformers/src/models/paligemma.rs
index a5e7f694f5..e992869923 100644
--- a/candle-transformers/src/models/paligemma.rs
+++ b/candle-transformers/src/models/paligemma.rs
@@ -1,3 +1,19 @@
+//! Multimodal multi-purpose model combining Gemma-based language model with SigLIP image understanding
+//!
+//! See PaliGemma details at:
+//! - [Paper](https://arxiv.org/abs/2407.07726)
+//! - [HuggingFace](https://huggingface.co/google/paligemma-3b-pt-224)
+//!
+//! The model is a multimodal combination of:
+//! - SigLIP vision encoder
+//! - Gemma language model
+//! - Cross-projection layers
+//!
+
 use crate::models::{gemma, siglip};
 use candle::{Module, Result, Tensor};
 use candle_nn::{linear, Linear, VarBuilder};
diff --git a/candle-transformers/src/models/parler_tts.rs b/candle-transformers/src/models/parler_tts.rs
index da40124741..0c08aa9427 100644
--- a/candle-transformers/src/models/parler_tts.rs
+++ b/candle-transformers/src/models/parler_tts.rs
@@ -1,3 +1,20 @@
+//! Parler Model implementation for parler_tts text-to-speech synthesis
+//!
+//! Implements a transformer-based decoder architecture for generating audio tokens
+//! from text using discrete tokens. The model converts text into audio segments
+//! using multiple codebooks of quantized audio tokens.
+//!
+//! The model architecture includes:
+//! - Multi-head attention layers for text and audio processing
+//! - Feed-forward networks
+//! - Layer normalization
+//! - Positional embeddings
+//! - Multiple codebook prediction heads
+//!
+//! The implementation follows the original parler_tts architecture while focusing
+//! on audio token generation for text-to-speech synthesis.
+//!
+
 use crate::generation::LogitsProcessor;
 use crate::models::t5;
 use candle::{IndexOp, Result, Tensor};
diff --git a/candle-transformers/src/models/persimmon.rs b/candle-transformers/src/models/persimmon.rs
index afee7c83ee..0996decf55 100644
--- a/candle-transformers/src/models/persimmon.rs
+++ b/candle-transformers/src/models/persimmon.rs
@@ -1,3 +1,19 @@
+//! Persimmon Model
+//!
+//! A transformer language model for efficient inference and general-purpose tasks.
+//! See Persimmon model details at:
+//! - [Hugging Face](https://huggingface.co/adept/persimmon-8b-base)
+//!
+//! The model uses a standard transformer architecture with:
+//! - Layer normalization for Q/K attention
+//! - RoPE embeddings with partial rotary factor
+//! - ReLU activation
+//! - Separate number of attention heads and KV heads
+//!
+//! References:
+//! - [Hugging Face Implementation](https://github.com/huggingface/transformers/blob/main/src/transformers/models/persimmon/modeling_persimmon.py)
+//! - [Persimmon Config](https://github.com/huggingface/transformers/blob/main/src/transformers/models/persimmon/configuration_persimmon.py)
+//!
+
 use candle::DType;
 use serde::Deserialize;
diff --git a/candle-transformers/src/models/phi.rs b/candle-transformers/src/models/phi.rs
index bffc14faed..36a08bb3c6 100644
--- a/candle-transformers/src/models/phi.rs
+++ b/candle-transformers/src/models/phi.rs
@@ -1,3 +1,20 @@
+//! Microsoft Phi model implementation
+//!
+//! See Phi model details at:
+//! - [Phi-2 Model](https://huggingface.co/microsoft/phi-2)
+//!
+//! The Phi series are decoder-only transformers designed for code and language tasks.
+//! Key characteristics:
+//! - Decoder-only transformer architecture
+//! - Partial rotary positional embeddings (RoPE)
+//! - Layer normalization
+//!
+//! References:
+//! - [Hugging Face Implementation](https://huggingface.co/microsoft/phi-2)
+//!
+
 use crate::models::with_tracing::{layer_norm, linear, Embedding, LayerNorm, Linear};
 /// Phi model.
 /// https://huggingface.co/microsoft/phi-2
diff --git a/candle-transformers/src/models/phi3.rs b/candle-transformers/src/models/phi3.rs
index a5e3e9a948..7ce9e987c9 100644
--- a/candle-transformers/src/models/phi3.rs
+++ b/candle-transformers/src/models/phi3.rs
@@ -1,3 +1,22 @@
+//! Microsoft Phi-3 model implementation
+//!
+//! See Phi-3 model details at:
+//! - [Phi-3 Mini Model](https://huggingface.co/microsoft/Phi-3-mini-4k-instruct)
+//!
+//! The Phi series are decoder-only transformers designed for code and language tasks.
+//! Key characteristics:
+//! - Decoder-only transformer architecture
+//! - RMSNorm for layer normalization
+//! - Rotary positional embeddings (RoPE)
+//! - Improved context window handling
+//!
+//! References:
+//! - [Hugging Face Implementation](https://huggingface.co/microsoft/Phi-3-mini-4k-instruct)
+//!
+
 // This implementation is based on:
 // https://huggingface.co/microsoft/Phi-3-mini-4k-instruct/blob/main/modeling_phi3.py
 use crate::models::with_tracing::{linear_no_bias as linear, Linear, RmsNorm};
diff --git a/candle-transformers/src/models/pixtral/mod.rs b/candle-transformers/src/models/pixtral/mod.rs
index 9d0eccfb57..53f9ef9182 100644
--- a/candle-transformers/src/models/pixtral/mod.rs
+++ b/candle-transformers/src/models/pixtral/mod.rs
@@ -1,3 +1,11 @@
+//! Pixtral Language-Image Pre-Training
+//!
+//! Pixtral is an architecture trained for multimodal learning
+//! using images paired with text descriptions.
+//!
+//! - Transformers Python [reference implementation](https://github.com/huggingface/transformers/tree/main/src/transformers/models/pixtral)
+//!
+
 pub mod llava;
 pub mod vision_model;
diff --git a/candle-transformers/src/models/quantized_blip.rs b/candle-transformers/src/models/quantized_blip.rs
index 31e22b4570..acba9ba191 100644
--- a/candle-transformers/src/models/quantized_blip.rs
+++ b/candle-transformers/src/models/quantized_blip.rs
@@ -1,3 +1,19 @@
+//! BLIP model implementation with quantization support.
+//!
+//! BLIP is a vision-language model for image understanding and generation tasks.
+//! This implementation provides quantization for reduced memory and compute.
+//!
+//! Key characteristics:
+//! - Vision encoder using ViT architecture
+//! - Text decoder using BERT-style transformer
+//! - Cross-attention between vision and text features
+//! - Support for 8-bit quantization
+//!
+//! References:
+//! - [BLIP Paper](https://arxiv.org/abs/2201.12086)
+//! - [Hugging Face Implementation](https://huggingface.co/docs/transformers/model_doc/blip)
+//!
+
 use super::quantized_blip_text as blip_text;
 use crate::quantized_nn::{layer_norm, linear, Linear};
 pub use crate::quantized_var_builder::VarBuilder;
diff --git a/candle-transformers/src/models/quantized_blip_text.rs b/candle-transformers/src/models/quantized_blip_text.rs
index 652205d6f6..61e468e78b 100644
--- a/candle-transformers/src/models/quantized_blip_text.rs
+++ b/candle-transformers/src/models/quantized_blip_text.rs
@@ -1,3 +1,20 @@
+//! Quantized BLIP text module implementation.
+//!
+//! Provides the text decoder portion of the BLIP model with 8-bit quantization.
+//! Uses a BERT-style transformer architecture for text processing.
+//!
+//! Key components:
+//! - Text embeddings layer with position embeddings
+//! - Multi-head self attention layers
+//! - Cross-attention for vision-text fusion
+//! - Layer normalization and feed-forward layers
+//! - Quantized linear transformations
+//!
+//! References:
+//! - [BLIP Paper](https://arxiv.org/abs/2201.12086)
+//! - [Hugging Face Implementation](https://huggingface.co/docs/transformers/model_doc/blip)
+//!
+
 use crate::models::with_tracing::QMatMul;
 use crate::quantized_nn::{layer_norm, linear, Embedding, Linear};
 pub use crate::quantized_var_builder::VarBuilder;
diff --git a/candle-transformers/src/models/quantized_llama.rs b/candle-transformers/src/models/quantized_llama.rs
index 04a50981b6..7efd385d61 100644
--- a/candle-transformers/src/models/quantized_llama.rs
+++ b/candle-transformers/src/models/quantized_llama.rs
@@ -1,3 +1,20 @@
+//! Quantized llama model implementation.
+//!
+//! This provides a quantized implementation of the llama language model architecture.
+//! The model implements parameter-efficient quantization for reduced memory usage
+//! while maintaining model quality.
+//!
+//! Key characteristics:
+//! - Transformer decoder architecture
+//! - Support for 2/3/4/8-bit quantization
+//! - Optimized memory usage through quantization
+//! - Configurable model sizes and parameter counts
+//!
+//! References:
+//! - [LLaMA Paper](https://arxiv.org/abs/2302.13971)
+//! - [LLaMA Model](https://github.com/facebookresearch/llama)
+//!
+
 use std::collections::HashMap;
 
 use crate::quantized_nn::RmsNorm;
diff --git a/candle-transformers/src/models/quantized_llama2_c.rs b/candle-transformers/src/models/quantized_llama2_c.rs
index cbb8aad8da..3eb14bb9e6 100644
--- a/candle-transformers/src/models/quantized_llama2_c.rs
+++ b/candle-transformers/src/models/quantized_llama2_c.rs
@@ -1,3 +1,19 @@
+//! Quantized Llama2 model implementation.
+//!
+//! This provides an 8-bit quantized implementation of Meta's LLaMA2 language model
+//! for reduced memory usage and faster inference.
+//!
+//! Key characteristics:
+//! - Decoder-only transformer architecture
+//! - RoPE position embeddings
+//! - Grouped Query Attention
+//! - 8-bit quantization of weights
+//!
+//! References:
+//! - [LLaMA2 Paper](https://arxiv.org/abs/2307.09288)
+//! - [LLaMA2 Technical Report](https://ai.meta.com/research/publications/llama-2-open-foundation-and-fine-tuned-chat-models/)
+//!
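+//! For the GGUF-backed quantized models in this family (e.g. `quantized_llama`), loading
+//! typically goes through `candle::quantized::gguf_file`. A sketch (the file path and
+//! token ids are placeholders):
+//! ```ignore
+//! use candle::quantized::gguf_file;
+//! use candle::{Device, Tensor};
+//! use candle_transformers::models::quantized_llama::ModelWeights;
+//!
+//! let mut file = std::fs::File::open("llama-7b.q4_0.gguf")?;
+//! let content = gguf_file::Content::read(&mut file)?;
+//! let mut model = ModelWeights::from_gguf(content, &mut file, &Device::Cpu)?;
+//! let input = Tensor::new(&[[1u32]], &Device::Cpu)?; // placeholder token id
+//! let logits = model.forward(&input, 0)?; // position offset 0
+//! ```
+//!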
+
 use super::llama2_c::{Cache, Config};
 use crate::quantized_nn::{linear_no_bias as linear, Embedding, Linear, RmsNorm};
 pub use crate::quantized_var_builder::VarBuilder;
diff --git a/candle-transformers/src/models/quantized_metavoice.rs b/candle-transformers/src/models/quantized_metavoice.rs
index 947ab750cd..ac72162715 100644
--- a/candle-transformers/src/models/quantized_metavoice.rs
+++ b/candle-transformers/src/models/quantized_metavoice.rs
@@ -1,3 +1,19 @@
+//! Quantized MetaVoice model implementation.
+//!
+//! MetaVoice is a conditional text-to-speech model based on a transformer architecture.
+//! This implementation provides quantization for reduced memory and compute.
+//!
+//! Key characteristics:
+//! - Transformer-based autoregressive decoder
+//! - Speaker conditioning
+//! - Support for 8-bit quantization
+//! - Key-value caching for efficient inference
+//! - RMS normalization layers
+//!
+//! References:
+//! - [MetaVoice Code](https://github.com/metavoiceio/metavoice-src)
+//!
+
 use crate::quantized_nn::{linear_b, Embedding, Linear, RmsNorm};
 pub use crate::quantized_var_builder::VarBuilder;
diff --git a/candle-transformers/src/models/quantized_mistral.rs b/candle-transformers/src/models/quantized_mistral.rs
index 0583810a0d..cdb687d573 100644
--- a/candle-transformers/src/models/quantized_mistral.rs
+++ b/candle-transformers/src/models/quantized_mistral.rs
@@ -1,3 +1,20 @@
+//! Mistral model implementation with quantization support.
+//!
+//! Mistral is a large language model optimized for efficiency.
+//! This implementation provides quantization for reduced memory and compute.
+//!
+//! Key characteristics:
+//! - Sliding window attention mechanism
+//! - Grouped query attention (GQA)
+//! - RMSNorm for layer normalization
+//! - Rotary positional embeddings (RoPE)
+//! - Support for 8-bit quantization
+//!
+//! References:
+//! - [Mistral Paper](https://arxiv.org/abs/2310.06825)
+//! - [Model Card](https://huggingface.co/mistralai/Mistral-7B-v0.1)
+//!
+
 use crate::quantized_nn::{linear_no_bias, Embedding, Linear, RmsNorm};
 pub use crate::quantized_var_builder::VarBuilder;
 use candle::{DType, Device, Module, Result, Tensor, D};
diff --git a/candle-transformers/src/models/quantized_mixformer.rs b/candle-transformers/src/models/quantized_mixformer.rs
index fa72672a9e..8736544625 100644
--- a/candle-transformers/src/models/quantized_mixformer.rs
+++ b/candle-transformers/src/models/quantized_mixformer.rs
@@ -1,3 +1,16 @@
+//! Module containing quantized MixFormer model implementation.
+//!
+//! MixFormer is an efficient transformer variant for text generation that uses
+//! parallel attention and feed-forward blocks.
+//! This implementation provides quantization for reduced memory usage.
+//!
+//! Key features:
+//! - Parallel attention and feed-forward computation
+//! - Rotary positional embeddings
+//! - Optional key-value caching
+//! - Support for 8-bit quantization
+//!
+
 use crate::quantized_nn::{layer_norm, linear, Linear};
 pub use crate::quantized_var_builder::VarBuilder;
 use candle::{DType, Device, IndexOp, Module, Result, Tensor, D};
diff --git a/candle-transformers/src/models/quantized_moondream.rs b/candle-transformers/src/models/quantized_moondream.rs
index 1b125d9306..c1daffafe4 100644
--- a/candle-transformers/src/models/quantized_moondream.rs
+++ b/candle-transformers/src/models/quantized_moondream.rs
@@ -1,3 +1,18 @@
+//! Implementation of a quantized Moondream vision language model.
+//!
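+//! Like the other quantized modules here, weights are read through
+//! `candle_transformers::quantized_var_builder::VarBuilder`; a minimal sketch
+//! (the GGUF path is a placeholder):
+//! ```ignore
+//! use candle::Device;
+//! use candle_transformers::quantized_var_builder::VarBuilder;
+//!
+//! let vb = VarBuilder::from_gguf("moondream-q4_0.gguf", &Device::Cpu)?;
+//! // The constructors below take this `vb` plus a model config.
+//! ```
+//!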
+//! Moondream is a lightweight vision-language model for image understanding and generation.
+//! This module provides a quantized version for reduced memory usage and faster inference.
+//!
+//! Key features:
+//! - ViT-based vision encoder
+//! - Phi-2 text decoder model
+//! - Memory efficient 8-bit quantization
+//! - Optimized for efficient deployment
+//!
+//! References:
+//! - [Moondream Model](https://github.com/vikhyat/moondream)
+//!
+
 use crate::models::moondream::{Config, VisionConfig};
 use crate::models::quantized_mixformer::MixFormerSequentialForCausalLM as PhiModel;
 use crate::quantized_nn::{layer_norm, linear_b, Linear};
diff --git a/candle-transformers/src/models/quantized_mpt.rs b/candle-transformers/src/models/quantized_mpt.rs
index 056fcac2d1..44d8566b7b 100644
--- a/candle-transformers/src/models/quantized_mpt.rs
+++ b/candle-transformers/src/models/quantized_mpt.rs
@@ -1,3 +1,21 @@
+//! Quantized MPT model implementation.
+//!
+//! MPT (MosaicML Pretrained Transformer) is a causal transformer model series;
+//! the variant used here is optimized for code generation.
+//! This implementation provides quantization for reduced memory and compute.
+//!
+//! Key characteristics:
+//! - Multi-query attention (MQA)
+//! - Support for KV-caching
+//! - Pre-computed ALiBi attention biases
+//! - Support for 8-bit quantization
+//!
+//! References:
+//! - [Replit Code Models](https://huggingface.co/replit/replit-code-v1_5-3b)
+//! - [MPT-7B Implementation](https://github.com/mosaicml/llm-foundry)
+//!
 use crate::quantized_nn::{layer_norm_no_bias, linear_no_bias, Embedding, Linear};
 pub use crate::quantized_var_builder::VarBuilder;
 /// MPT model used by replit-code-v1_5-3b
diff --git a/candle-transformers/src/models/quantized_phi.rs b/candle-transformers/src/models/quantized_phi.rs
index 0ebf7f4d4b..b874ad94ea 100644
--- a/candle-transformers/src/models/quantized_phi.rs
+++ b/candle-transformers/src/models/quantized_phi.rs
@@ -1,3 +1,20 @@
+//! Phi2 model implementation with quantization support.
+//!
+//! Phi2 is a 2.7B parameter language model using a scaled-up Transformer decoder architecture.
+//! This implementation provides quantization for reduced memory and compute usage.
+//!
+//! Key characteristics:
+//! - Decoder-only transformer with parallel attention and feed-forward blocks
+//! - Partial rotary positional embeddings (RoPE)
+//! - Layer normalization
+//! - Support for 8-bit quantization
+//!
+//! References:
+//! - [Phi2 Paper](https://arxiv.org/abs/2309.05463)
+//! - [Model Card](https://huggingface.co/microsoft/phi-2)
+//!
+
 use std::collections::HashMap;
 
 use candle::quantized::gguf_file;
diff --git a/candle-transformers/src/models/quantized_phi3.rs b/candle-transformers/src/models/quantized_phi3.rs
index 257ad98379..51a75f3895 100644
--- a/candle-transformers/src/models/quantized_phi3.rs
+++ b/candle-transformers/src/models/quantized_phi3.rs
@@ -1,3 +1,18 @@
+//! Phi3 model implementation with quantization support.
+//!
+//! Phi3 is a family of small language models from Microsoft.
+//! This implementation provides quantization for reduced memory usage.
+//!
+//! Key characteristics:
+//! - Multi-head attention
+//! - RMSNorm for layer normalization
+//! - Rotary positional embeddings (RoPE)
+//! - Support for quantization
+//!
+//! References:
+//! - [Model Card](https://huggingface.co/microsoft/Phi-3-mini-4k-instruct)
+//!
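+//! Decoding with any of these quantized causal LMs pairs the logits with a sampler;
+//! a sketch using this crate's `LogitsProcessor` (`logits` is assumed to be the model
+//! output for the last position):
+//! ```ignore
+//! use candle_transformers::generation::LogitsProcessor;
+//!
+//! // seed, temperature, top-p
+//! let mut sampler = LogitsProcessor::new(42, Some(0.8), Some(0.95));
+//! let next_token = sampler.sample(&logits)?;
+//! ```
+//!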
+
 use std::collections::HashMap;

 use candle::quantized::gguf_file;
diff --git a/candle-transformers/src/models/quantized_qwen2.rs b/candle-transformers/src/models/quantized_qwen2.rs
index addfab2b04..c04da56925 100644
--- a/candle-transformers/src/models/quantized_qwen2.rs
+++ b/candle-transformers/src/models/quantized_qwen2.rs
@@ -1,3 +1,18 @@
+//! Qwen2 model implementation with quantization support.
+//!
+//! Qwen2 is a chat-optimized language model that supports 8-bit quantization
+//! for reduced memory usage and faster inference.
+//!
+//! Key characteristics:
+//! - Grouped query attention (GQA)
+//! - RMSNorm for layer normalization
+//! - Rotary positional embeddings (RoPE)
+//! - Support for 8-bit quantization
+//!
+//! References:
+//! - [Model Card](https://huggingface.co/Qwen/Qwen2-7B)
+//!
+
 use crate::{quantized_nn::RmsNorm, utils::repeat_kv};
 use candle::{
     quantized::{gguf_file, QMatMul},
diff --git a/candle-transformers/src/models/quantized_recurrent_gemma.rs b/candle-transformers/src/models/quantized_recurrent_gemma.rs
index c28064da6b..e40daa1f33 100644
--- a/candle-transformers/src/models/quantized_recurrent_gemma.rs
+++ b/candle-transformers/src/models/quantized_recurrent_gemma.rs
@@ -1,3 +1,20 @@
+//! Recurrent Gemma model implementation with quantization support.
+//!
+//! RecurrentGemma is a Gemma variant built on Griffin-style recurrent blocks.
+//! This implementation provides quantization for reduced memory and compute.
+//!
+//! Key characteristics:
+//! - Real-Gated Linear Recurrent Unit (RG-LRU) blocks
+//! - Convolution and attention blocks
+//! - RMSNorm for layer normalization
+//! - Rotary positional embeddings (RoPE)
+//! - Support for 8-bit quantization
+//!
+//! References:
+//! - [Griffin Paper](https://arxiv.org/abs/2402.19427)
+//! - [Model Card](https://ai.google.dev/gemma)
+//!
+
 use crate::quantized_nn::{linear_b as linear, Embedding, Linear};
 pub use crate::quantized_var_builder::VarBuilder;
 use candle::{DType, Device, IndexOp, Module, Result, Tensor, D};
diff --git a/candle-transformers/src/models/quantized_rwkv_v5.rs b/candle-transformers/src/models/quantized_rwkv_v5.rs
index c41d7b4e08..cc5204bf24 100644
--- a/candle-transformers/src/models/quantized_rwkv_v5.rs
+++ b/candle-transformers/src/models/quantized_rwkv_v5.rs
@@ -1,3 +1,20 @@
+//! RWKV v5 model implementation with quantization support.
+//!
+//! RWKV v5 is an attention-free language model optimized for efficiency.
+//! This implementation provides quantization for reduced memory and compute.
+//!
+//! Key characteristics:
+//! - Linear attention mechanism
+//! - GroupNorm layer normalization
+//! - Time-mixing layers
+//! - State-based sequential processing
+//! - Support for 8-bit quantization
+//!
+//! References:
+//! - [RWKV Model](https://github.com/BlinkDL/RWKV-LM)
+//! - [RWKV v5 Architecture](https://www.rwkv.com/v5)
+//!
+
 use crate::{
     quantized_nn::{layer_norm, linear_no_bias as linear, Embedding, Linear},
     quantized_var_builder::VarBuilder,
diff --git a/candle-transformers/src/models/quantized_rwkv_v6.rs b/candle-transformers/src/models/quantized_rwkv_v6.rs
index 81150c3ec0..91288c2e61 100644
--- a/candle-transformers/src/models/quantized_rwkv_v6.rs
+++ b/candle-transformers/src/models/quantized_rwkv_v6.rs
@@ -1,3 +1,21 @@
+//! RWKV v6 model implementation with quantization support.
+//!
+//! RWKV is a linear attention model that combines the efficiency of RNNs
+//! with the parallelizable training of Transformers. Version 6 builds on previous
+//! versions with further optimizations.
+//!
+//! Key characteristics:
+//! - Linear attention mechanism
+//! - Time mixing layers
+//! - Channel mixing layers
+//! - RMSNorm for normalization
+//! - Support for 8-bit quantization
+//!
+//! References:
+//! - [RWKV Architecture](https://github.com/BlinkDL/RWKV-LM)
+//! - [RWKV v6 Release](https://huggingface.co/BlinkDL/rwkv-6)
+//!
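+//! A rough inference sketch: instead of a growing KV cache, RWKV threads a
+//! fixed-size recurrent state through token-by-token calls (`model`, `state`
+//! and `sample` below are illustrative placeholders, not this module's exact
+//! API):
+//!
+//! ```ignore
+//! let mut state = State::new(1, &config, &device)?;
+//! let mut token = bos_token_id;
+//! for _ in 0..max_new_tokens {
+//!     let input = Tensor::new(&[[token]], &device)?;
+//!     let logits = model.forward(&input, &mut state)?;
+//!     token = sample(&logits)?; // hypothetical sampling helper
+//! }
+//! ```
+//!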
+
 use crate::{
     quantized_nn::{layer_norm, linear_no_bias as linear, Embedding, Linear},
     quantized_var_builder::VarBuilder,
diff --git a/candle-transformers/src/models/quantized_stable_lm.rs b/candle-transformers/src/models/quantized_stable_lm.rs
index da4475220f..d74ed743d8 100644
--- a/candle-transformers/src/models/quantized_stable_lm.rs
+++ b/candle-transformers/src/models/quantized_stable_lm.rs
@@ -1,3 +1,18 @@
+//! Module for quantized StableLM implementation.
+//!
+//! StableLM is a series of open large language models
+//! from Stability AI. This implementation provides
+//! quantization support for efficient model deployment.
+//!
+//! Key characteristics:
+//! - LayerNorm for normalization
+//! - Rotary positional embeddings (RoPE)
+//! - Support for 8-bit quantization
+//!
+//! References:
+//! - [StableLM](https://github.com/Stability-AI/StableLM)
+//!
+
 use crate::quantized_nn::{layer_norm, linear, linear_no_bias, Embedding, Linear};
 pub use crate::quantized_var_builder::VarBuilder;
 use candle::{DType, Device, Module, Result, Tensor, D};
diff --git a/candle-transformers/src/models/quantized_t5.rs b/candle-transformers/src/models/quantized_t5.rs
index 88224d2da3..9f770d69d9 100644
--- a/candle-transformers/src/models/quantized_t5.rs
+++ b/candle-transformers/src/models/quantized_t5.rs
@@ -1,5 +1,19 @@
-// T5 Text Model, quantized version
-// https://github.com/huggingface/transformers/blob/main/src/transformers/models/t5/modeling_t5.py
+//! T5 model implementation with quantization support.
+//!
+//! T5 is an encoder-decoder model pre-trained on a multi-task mixture of supervised
+//! and unsupervised tasks. This implementation provides quantization for reduced
+//! memory and compute requirements.
+//!
+//! Key characteristics:
+//! - Encoder-decoder architecture
+//! - T5-specific layer normalization
+//! - Relative positional encodings
+//! - Support for 8-bit quantization
+//!
+//! References:
+//! - [T5 Paper](https://arxiv.org/abs/1910.10683)
+//! - [Model Card](https://huggingface.co/t5-base)
+//! - [Reference implementation](https://github.com/huggingface/transformers/blob/main/src/transformers/models/t5/modeling_t5.py)

 use crate::models::t5::{deserialize_feed_forward_proj_activation, ActivationWithOptionalGating};
 use crate::models::with_tracing::QMatMul;
diff --git a/candle-transformers/src/models/qwen2.rs b/candle-transformers/src/models/qwen2.rs
index 187ea98a10..8dbca36b3e 100644
--- a/candle-transformers/src/models/qwen2.rs
+++ b/candle-transformers/src/models/qwen2.rs
@@ -1,3 +1,18 @@
+//! Qwen2 model implementation.
+//!
+//! Qwen2 is a large language model from Alibaba optimized for efficiency.
+//! This module implements the base model; `quantized_qwen2` provides a quantized variant.
+//!
+//! Key characteristics:
+//! - Streaming decode support
+//! - Grouped query attention (GQA)
+//! - RMSNorm for layer normalization
+//! - Rotary positional embeddings (RoPE)
+//!
+//! References:
+//! - [Model Card](https://huggingface.co/Qwen/Qwen2-7B)
+//!
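+//! A rough construction sketch, assuming safetensors weights and a standard
+//! `config.json` (file names are placeholders; check the exact constructor
+//! against this module):
+//!
+//! ```ignore
+//! use candle::{DType, Device};
+//! use candle_nn::VarBuilder;
+//! use candle_transformers::models::qwen2::{Config, ModelForCausalLM};
+//!
+//! let device = Device::Cpu;
+//! let config: Config = serde_json::from_slice(&std::fs::read("config.json")?)?;
+//! let vb = unsafe {
+//!     VarBuilder::from_mmaped_safetensors(&["model.safetensors"], DType::F32, &device)?
+//! };
+//! let model = ModelForCausalLM::new(&config, vb)?;
+//! ```
+//!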
+
 use crate::models::with_tracing::{linear, linear_no_bias, Linear, RmsNorm};
 use candle::{DType, Device, IndexOp, Module, Result, Tensor, D};
 use candle_nn::{Activation, VarBuilder};
diff --git a/candle-transformers/src/models/qwen2_moe.rs b/candle-transformers/src/models/qwen2_moe.rs
index 8d1d2f70f4..40e0279748 100644
--- a/candle-transformers/src/models/qwen2_moe.rs
+++ b/candle-transformers/src/models/qwen2_moe.rs
@@ -1,3 +1,21 @@
+//! Qwen2 model implementation with Mixture of Experts support.
+//!
+//! Qwen2 is a large language model using sparse Mixture of Experts (MoE).
+//! This implementation provides support for sparsely activated MoE layers.
+//!
+//! Key characteristics:
+//! - Mixture of Experts architecture
+//! - Sparse expert activation
+//! - Shared expert routing mechanism
+//! - Grouped query attention (GQA)
+//! - RMSNorm for layer normalization
+//! - Rotary positional embeddings (RoPE)
+//!
+//! References:
+//! - [Qwen2 Paper](https://arxiv.org/abs/2401.08985)
+//! - [Model Card](https://huggingface.co/Qwen/Qwen1.5-MoE-A2.7B)
+//!
+
 use crate::models::with_tracing::{linear, linear_no_bias, Linear, RmsNorm};
 use candle::{DType, Device, Module, Result, Tensor, D};
 use candle_nn::{Activation, VarBuilder};
diff --git a/candle-transformers/src/models/recurrent_gemma.rs b/candle-transformers/src/models/recurrent_gemma.rs
index 24d2b7e38b..d6a029babc 100644
--- a/candle-transformers/src/models/recurrent_gemma.rs
+++ b/candle-transformers/src/models/recurrent_gemma.rs
@@ -1,5 +1,22 @@
-// This implementation is based on the python version from huggingface/transformers.
-// https://github.com/huggingface/transformers/blob/b109257f4fb8b1166e7c53cc5418632014ed53a5/src/transformers/models/recurrent_gemma/modeling_recurrent_gemma.py#L2
+//! Recurrent Gemma model implementation.
+//!
+//! RecurrentGemma is a variant of the Gemma language model that replaces most
+//! global attention with recurrent blocks, letting it keep state across long sequences.
+//!
+//! Key characteristics:
+//! - Real-Gated Linear Recurrent Units (RG-LRU)
+//! - 1D convolution for local context
+//! - RMSNorm for layer normalization
+//! - Rotary positional embeddings (RoPE)
+//! - Grouped query attention
+//!
+//! References:
+//! - [Gemma: Open Models Based on Gemini Technology](https://blog.google/technology/developers/gemma-open-models/)
+//! - [Griffin Paper](https://arxiv.org/abs/2402.19427)
+//!
+//! This implementation is based on the python version from huggingface/transformers:
+//! [modeling_recurrent_gemma.py](https://github.com/huggingface/transformers/blob/b109257f4fb8b1166e7c53cc5418632014ed53a5/src/transformers/models/recurrent_gemma/modeling_recurrent_gemma.py#L2)
+//!
 use candle::{DType, Device, IndexOp, Module, Result, Tensor, D};
 use candle_nn::{linear_b as linear, Linear, VarBuilder};
 use std::sync::Arc;
diff --git a/candle-transformers/src/models/repvgg.rs b/candle-transformers/src/models/repvgg.rs
index 34016e5b45..a6ffce0d6d 100644
--- a/candle-transformers/src/models/repvgg.rs
+++ b/candle-transformers/src/models/repvgg.rs
@@ -2,6 +2,17 @@
 //!
 //! See "RepVGG: Making VGG-style ConvNets Great Again" Ding et al. 2021
 //! https://arxiv.org/abs/2101.03697
+//!
+//! Key characteristics:
+//! - Efficient inference architecture through structural reparameterization
+//! - Single 3x3 conv layer after fusing 3x3 branch, 1x1 branch and identity branch
+//! - Different configurations including a0-a2, b0-b3 and variants with group convolutions
+//! - High accuracy with VGG-like plain architecture and training
+//!
+//! References:
+//! - [RepVGG Paper](https://arxiv.org/abs/2101.03697)
+//! - [Official Implementation](https://github.com/DingXiaoH/RepVGG)
+//!
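+//! The reparameterization idea in miniature: after training, the 1x1 branch is
+//! zero-padded to 3x3 and summed into the 3x3 branch (the identity branch and
+//! batch-norm folding are omitted), so each block collapses to one convolution.
+//! A conceptual sketch, not this module's exact code:
+//!
+//! ```ignore
+//! use candle::{Result, Tensor};
+//!
+//! // w3: (out, in, 3, 3) kernel, w1: (out, in, 1, 1) kernel.
+//! fn fuse_branches(w3: &Tensor, w1: &Tensor) -> Result<Tensor> {
+//!     // Pad the 1x1 kernel with zeros so its weight sits at the kernel center.
+//!     let w1 = w1.pad_with_zeros(2, 1, 1)?.pad_with_zeros(3, 1, 1)?;
+//!     w3.add(&w1)
+//! }
+//! ```
+//!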

 use candle::{Result, Tensor, D};
 use candle_nn::{
diff --git a/candle-transformers/src/models/resnet.rs b/candle-transformers/src/models/resnet.rs
index 30029a0bd1..31395c8f84 100644
--- a/candle-transformers/src/models/resnet.rs
+++ b/candle-transformers/src/models/resnet.rs
@@ -1,7 +1,15 @@
-//! ResNet implementation.
+//! ResNet implementation.
 //!
-//! See "Deep Residual Learning for Image Recognition" He et al. 2015
-//! 
+//! ResNet is a deep convolutional network that uses skip connections
+//! ("residual connections") to enable the training of very deep networks.
+//!
+//! Key characteristics:
+//! - Residual blocks with identity shortcut connections
+//! - Batch normalization after each convolution
+//!
+//! References:
+//! - [Deep Residual Learning for Image Recognition](https://arxiv.org/abs/1512.03385), He et al. 2015
+
 use candle::{Result, D};
 use candle_nn::{batch_norm, Conv2d, Func, VarBuilder};
diff --git a/candle-transformers/src/models/rwkv_v5.rs b/candle-transformers/src/models/rwkv_v5.rs
index eb51273196..6390f886d2 100644
--- a/candle-transformers/src/models/rwkv_v5.rs
+++ b/candle-transformers/src/models/rwkv_v5.rs
@@ -1,3 +1,20 @@
+//! RWKV v5 model implementation.
+//!
+//! RWKV is an RNN with transformer-level performance, which can be formulated
+//! as either a transformer (for parallel training) or an RNN (for inference).
+//!
+//! Key characteristics:
+//! - Time-mix attention mechanism
+//! - Channel-mix feed-forward network
+//! - Linear attention
+//! - Group normalization
+//! - Token shift mechanism
+//!
+//! References:
+//! - [RWKV Language Model](https://github.com/BlinkDL/RWKV-LM)
+//! - [RWKV v5 Release](https://github.com/BlinkDL/ChatRWKV/tree/main)
+//!
+
 use super::with_tracing::{layer_norm, linear_no_bias as linear, LayerNorm, Linear};
 use candle::{DType, Device, IndexOp, Result, Tensor};
 use candle_nn::{embedding, Embedding, Module, VarBuilder};
diff --git a/candle-transformers/src/models/rwkv_v6.rs b/candle-transformers/src/models/rwkv_v6.rs
index 457c351ec1..c75aa885e9 100644
--- a/candle-transformers/src/models/rwkv_v6.rs
+++ b/candle-transformers/src/models/rwkv_v6.rs
@@ -1,3 +1,19 @@
+//! RWKV v6 model implementation.
+//!
+//! RWKV is an RNN with transformer-like performance.
+//! Version 6 introduces refinements to the architecture.
+//!
+//! Key characteristics:
+//! - Linear attention mechanism
+//! - Time-mixing for temporal dependencies
+//! - Group normalization
+//! - Feed forward gating
+//! - State recycling for efficient inference
+//!
+//! References:
+//! - [RWKV Model](https://github.com/BlinkDL/RWKV-LM)
+//!
+
 use super::with_tracing::{layer_norm, linear_no_bias as linear, LayerNorm, Linear};
 use candle::{IndexOp, Result, Tensor};
 use candle_nn::{embedding, Embedding, Module, VarBuilder};
diff --git a/candle-transformers/src/models/segformer.rs b/candle-transformers/src/models/segformer.rs
index 260ceb3a84..9e0461bc70 100644
--- a/candle-transformers/src/models/segformer.rs
+++ b/candle-transformers/src/models/segformer.rs
@@ -1,3 +1,19 @@
+//! Segformer model implementation for semantic segmentation and image classification.
+//!
+//! Segformer is a transformer-based model designed for vision tasks. It uses a hierarchical
+//! structure that progressively generates features at different scales.
+//!
+//! Key characteristics:
+//! - Efficient self-attention with sequence reduction
+//! - Hierarchical feature generation
+//! - Mix-FFN for local and global feature interaction
+//! - Lightweight all-MLP decode head
+//!
+//! References:
+//! - [SegFormer Paper](https://arxiv.org/abs/2105.15203)
+//! - [Model Card](https://huggingface.co/nvidia/mit-b0)
+//!
+
 use crate::models::with_tracing::{conv2d, linear, Conv2d, Linear};
 use candle::{Module, ModuleT, Result, Tensor, D};
 use candle_nn::{conv2d_no_bias, layer_norm, Activation, Conv2dConfig, VarBuilder};
diff --git a/candle-transformers/src/models/segment_anything/mod.rs b/candle-transformers/src/models/segment_anything/mod.rs
index c54493d296..3e85fe3594 100644
--- a/candle-transformers/src/models/segment_anything/mod.rs
+++ b/candle-transformers/src/models/segment_anything/mod.rs
@@ -1,3 +1,11 @@
+//! Segment Anything Model (SAM)
+//!
+//! SAM is an architecture for image segmentation, capable of segmenting any object
+//! in an image based on prompts like points or boxes.
+//!
+//! - [GH Link](https://github.com/facebookresearch/segment-anything)
+//! - [Paper](https://arxiv.org/abs/2304.02643)
+//!
 pub use crate::models::with_tracing::Linear;
 use candle::{Result, Tensor};
 use candle_nn::{Module, VarBuilder};
diff --git a/candle-transformers/src/models/siglip.rs b/candle-transformers/src/models/siglip.rs
index 63b6635dc1..2046401428 100644
--- a/candle-transformers/src/models/siglip.rs
+++ b/candle-transformers/src/models/siglip.rs
@@ -1,3 +1,13 @@
+//! Siglip model implementation.
+//!
+//! Siglip is a multimodal vision-language model trained with a pairwise sigmoid
+//! loss, used for zero-shot image classification and retrieval.
+//!
+//! References:
+//! - [SigLIP Paper](https://arxiv.org/abs/2303.15343)
+//! - [Model Card](https://huggingface.co/google/siglip-base-patch16-224)
+//!
+
 use crate::models::clip::div_l2_norm;
 use candle::{IndexOp, Module, Result, Tensor, D};
 use candle_nn::{layer_norm, linear, LayerNorm, Linear, VarBuilder};
diff --git a/candle-transformers/src/models/stable_diffusion/mod.rs b/candle-transformers/src/models/stable_diffusion/mod.rs
index 37f4cdbf59..d3e2032b6e 100644
--- a/candle-transformers/src/models/stable_diffusion/mod.rs
+++ b/candle-transformers/src/models/stable_diffusion/mod.rs
@@ -1,3 +1,12 @@
+//! Stable Diffusion
+//!
+//! Stable Diffusion is a latent text-to-image diffusion model capable of
+//! generating photo-realistic images given any text input.
+//!
+//! - [Original Repository](https://github.com/CompVis/stable-diffusion)
+//! - [Hugging Face](https://huggingface.co/runwayml/stable-diffusion-v1-5)
+//!
+
 pub mod attention;
 pub mod clip;
 pub mod ddim;
diff --git a/candle-transformers/src/models/stable_lm.rs b/candle-transformers/src/models/stable_lm.rs
index 2b46e8a12f..c5dbd3958d 100644
--- a/candle-transformers/src/models/stable_lm.rs
+++ b/candle-transformers/src/models/stable_lm.rs
@@ -1,3 +1,18 @@
+//! StableLM model implementation.
+//!
+//! StableLM is a family of language models trained by Stability AI.
+//! This module implements the base architecture; `quantized_stable_lm` provides a quantized variant.
+//!
+//! Key characteristics:
+//! - Grouped query attention (GQA)
+//! - Layer normalization
+//! - Rotary positional embeddings (RoPE)
+//! - Support for different model sizes (3B, 7B)
+//!
+//! References:
+//! - [Model Card](https://huggingface.co/stabilityai/stablelm-3b-4e1t)
+//!
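+//! In grouped query attention, several query heads share a single key/value
+//! head, and the key/value tensors are expanded to match the query heads. The
+//! sketch below mirrors the usual `repeat_kv` helper (simplified, for
+//! illustration):
+//!
+//! ```ignore
+//! use candle::{Result, Tensor};
+//!
+//! fn repeat_kv(xs: Tensor, n_rep: usize) -> Result<Tensor> {
+//!     if n_rep == 1 {
+//!         return Ok(xs);
+//!     }
+//!     let (b, n_kv_head, seq_len, head_dim) = xs.dims4()?;
+//!     xs.unsqueeze(2)?
+//!         .expand((b, n_kv_head, n_rep, seq_len, head_dim))?
+//!         .reshape((b, n_kv_head * n_rep, seq_len, head_dim))
+//! }
+//! ```
+//!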
+
 use crate::models::with_tracing::{linear, linear_no_bias, Linear};
 use candle::{DType, Device, Module, Result, Tensor, D};
 use candle_nn::{Activation, LayerNorm, VarBuilder};
diff --git a/candle-transformers/src/models/starcoder2.rs b/candle-transformers/src/models/starcoder2.rs
index d108d06235..833cb0679f 100644
--- a/candle-transformers/src/models/starcoder2.rs
+++ b/candle-transformers/src/models/starcoder2.rs
@@ -1,3 +1,18 @@
+//! StarCoder2 model implementation.
+//!
+//! StarCoder2 is a large language model optimized for code generation.
+//!
+//! Key characteristics:
+//! - Causal self-attention mechanism
+//! - Grouped-query attention (GQA)
+//! - LayerNorm for normalization
+//! - Rotary positional embeddings (RoPE)
+//!
+//! References:
+//! - [StarCoder2 Paper](https://arxiv.org/abs/2402.19173)
+//! - [Model Card](https://huggingface.co/bigcode/starcoder2-7b)
+//!
+
 #![allow(unused)]
 use candle::{DType, Device, Module, Result, Tensor, D};
 use candle_nn::{layer_norm, linear_b, LayerNorm, Linear, VarBuilder};
diff --git a/candle-transformers/src/models/stella_en_v5.rs b/candle-transformers/src/models/stella_en_v5.rs
index 9d933fade5..7c1d2b5ae9 100644
--- a/candle-transformers/src/models/stella_en_v5.rs
+++ b/candle-transformers/src/models/stella_en_v5.rs
@@ -1,3 +1,20 @@
+//! Stella v5 model implementation.
+//!
+//! Stella is a dense text embedding model optimized for retrieval and similarity tasks.
+//! This implementation provides support for multiple embedding dimensions.
+//!
+//! Key characteristics:
+//! - Dense text embeddings optimized for similarity search
+//! - Multiple output dimension support (256 to 8192)
+//! - Grouped query attention (GQA)
+//! - RMSNorm for layer normalization
+//! - Rotary positional embeddings (RoPE)
+//!
+//! References:
+//! - [MRL Framework](https://arxiv.org/abs/2205.13147)
+//! - [Model Card](https://huggingface.co/dunzhang/stella_en_1.5B_v5)
+//!
+
 use crate::models::with_tracing::{linear, linear_no_bias, Linear, RmsNorm};
 use candle::{DType, Device, IndexOp, Module, Result, Tensor};
 use candle_nn::{Activation, VarBuilder};
diff --git a/candle-transformers/src/models/t5.rs b/candle-transformers/src/models/t5.rs
index 8ba0c1c1d7..9da0c1afec 100644
--- a/candle-transformers/src/models/t5.rs
+++ b/candle-transformers/src/models/t5.rs
@@ -1,5 +1,19 @@
-// T5 Text Model
-// https://github.com/huggingface/transformers/blob/main/src/transformers/models/t5/modeling_t5.py
+//! T5 model implementation.
+//!
+//! T5 (Text-to-Text Transfer Transformer) is a unified text-to-text transformer model.
+//! This implementation follows the original model architecture.
+//!
+//! Key characteristics:
+//! - Text-to-text framework
+//! - Relative positional embeddings
+//! - T5-specific layer normalization
+//! - Encoder-decoder architecture
+//! - Support for sequence-to-sequence tasks
+//!
+//! References:
+//! - [T5 Paper](https://arxiv.org/abs/1910.10683)
+//! - [HuggingFace T5](https://huggingface.co/docs/transformers/model_doc/t5)
+//! 
- [GH Model](https://github.com/huggingface/transformers/blob/main/src/transformers/models/t5/modeling_t5.py) use crate::models::with_tracing::Embedding; use candle::{DType, Device, Module, Result, Tensor, D}; diff --git a/candle-transformers/src/models/trocr.rs b/candle-transformers/src/models/trocr.rs index d17eda17bf..88418dd3ca 100644 --- a/candle-transformers/src/models/trocr.rs +++ b/candle-transformers/src/models/trocr.rs @@ -1,3 +1,19 @@ +//! TrOCR model implementation. +//! +//! TrOCR is a Transformer-based OCR model that uses a Vision Transformer encoder +//! and a BART-like decoder for optical character recognition. +//! +//! Key characteristics: +//! - Vision Transformer encoder for image processing +//! - BART-style decoder for text generation +//! - Learned positional embeddings +//! - Layer normalization and self-attention +//! +//! References: +//! - [Paper](https://arxiv.org/abs/2109.10282) +//! - [Model Card](https://huggingface.co/microsoft/trocr-base-handwritten) +//! + use crate::models::vit::{Config, Embeddings, Encoder}; use candle::{DType, Result, Tensor}; use candle_nn::{ diff --git a/candle-transformers/src/models/vgg.rs b/candle-transformers/src/models/vgg.rs index 010643c8d2..57f9ae67bb 100644 --- a/candle-transformers/src/models/vgg.rs +++ b/candle-transformers/src/models/vgg.rs @@ -1,7 +1,18 @@ //! VGG-16 model implementation. //! -//! See Very Deep Convolutional Networks for Large-Scale Image Recognition -//! +//! VGG-16 is a convolutional neural network architecture. It consists of 13 +//! convolutional layers followed by 3 fully connected layers. +//! +//! Key characteristics: +//! - Conv layers with 3x3 filters +//! - Max pooling after every 2-3 conv layers +//! - Three fully connected layers of 4096, 4096, 1000 units +//! - ReLU activation and dropout +//! +//! References: +//! - [Very Deep Convolutional Networks for Large-Scale Image Recognition](https://arxiv.org/abs/1409.1556) +//! + use candle::{ModuleT, Result, Tensor}; use candle_nn::{FuncT, VarBuilder}; diff --git a/candle-transformers/src/models/vit.rs b/candle-transformers/src/models/vit.rs index 3be72bf599..49ab463017 100644 --- a/candle-transformers/src/models/vit.rs +++ b/candle-transformers/src/models/vit.rs @@ -1,3 +1,20 @@ +//! Vision Transformer (ViT) implementation. +//! +//! Vision Transformer applies transformer architecture to image classification +//! by splitting images into patches and processing them as a sequence. +//! +//! Key characteristics: +//! - Image patches as sequence tokens +//! - Self-attention between patches +//! - Position embeddings +//! - CLS token for classification +//! - Layer normalization +//! +//! References: +//! - [ViT Paper](https://arxiv.org/abs/2010.11929) +//! - [Model Card](https://huggingface.co/google/vit-base-patch16-224) +//! + use crate::models::with_tracing::{conv2d, linear, linear_no_bias, Conv2d, Linear}; use candle::{IndexOp, Module, Result, Tensor, D}; use candle_nn::{layer_norm, LayerNorm, VarBuilder}; diff --git a/candle-transformers/src/models/whisper/mod.rs b/candle-transformers/src/models/whisper/mod.rs index 8028cf2c66..6123884ae4 100644 --- a/candle-transformers/src/models/whisper/mod.rs +++ b/candle-transformers/src/models/whisper/mod.rs @@ -1,3 +1,11 @@ +//! Whisper Model Implementation +//! +//! Whisper is an automatic speech recognition (ASR) system trained on large amounts +//! of multilingual and multitask supervised data collected from the web. +//! +//! - [GH Link](https://github.com/openai/whisper) +//! 
- Transformers Python [reference implementation](https://github.com/huggingface/transformers/blob/main/src/transformers/models/whisper/modeling_whisper.py) +//! pub mod audio; pub mod model; pub mod quantized_model; diff --git a/candle-transformers/src/models/wuerstchen/mod.rs b/candle-transformers/src/models/wuerstchen/mod.rs index 7b076f0610..9bb37a3bcc 100644 --- a/candle-transformers/src/models/wuerstchen/mod.rs +++ b/candle-transformers/src/models/wuerstchen/mod.rs @@ -1,3 +1,12 @@ +//! Würstchen Efficient Diffusion Model +//! +//! Würstchen is an efficient diffusion model architecture for generating images using +//! a two-stage approach with a small decoder and prior network. +//! +//! - [Paper Link](https://openreview.net/pdf?id=gU58AyJlYz) +//! - [GH Link](https://github.com/dome272/Wuerstchen) +//! - [Reference Implementation](https://github.com/huggingface/diffusers/blob/main/src/diffusers/pipelines/wuerstchen/pipeline_wuerstchen.py) +//! pub mod attention_processor; pub mod common; pub mod ddpm; diff --git a/candle-transformers/src/models/yi.rs b/candle-transformers/src/models/yi.rs index df78ddce7a..047ea77046 100644 --- a/candle-transformers/src/models/yi.rs +++ b/candle-transformers/src/models/yi.rs @@ -1,4 +1,18 @@ -/// https://huggingface.co/01-ai/Yi-6B/blob/main/modeling_yi.py +//! Yi model implementation. +//! +//! Yi is a decoder-only large language model trained by 01.AI. +//! It follows a standard transformer architecture similar to Llama. +//! +//! Key characteristics: +//! - Multi-head attention with rotary positional embeddings +//! - RMS normalization +//! - SwiGLU activation in feed-forward layers +//! - Grouped-query attention for efficient inference +//! +//! References: +//! - [Yi Model](https://huggingface.co/01-ai/Yi-6B) +//! - [Hugging Face](https://huggingface.co/01-ai/Yi-6B/blob/main/modeling_yi.py) + use crate::models::with_tracing::{linear_no_bias, Linear, RmsNorm}; use candle::{DType, Device, Module, Result, Tensor, D}; use candle_nn::{Activation, VarBuilder};