From 404dc1cc5d6639cb6ec09f60c1c4b4a458990fe4 Mon Sep 17 00:00:00 2001 From: Philpax Date: Tue, 14 Mar 2023 19:17:54 +0100 Subject: [PATCH 01/10] feat(raw): make build script more resilient --- ggml-raw/build.rs | 66 +++++++++++++++++++++++++++++++++++++++++------ 1 file changed, 58 insertions(+), 8 deletions(-) diff --git a/ggml-raw/build.rs b/ggml-raw/build.rs index 729ae432..5208e768 100644 --- a/ggml-raw/build.rs +++ b/ggml-raw/build.rs @@ -1,5 +1,4 @@ -extern crate bindgen; - +use std::collections::HashSet; use std::env; use std::path::PathBuf; @@ -10,13 +9,64 @@ fn main() { let build = builder.files(ggml_src.iter()).include("include"); - // TODO: This is currently hardcoded for (my) linux. - build.flag("-mavx2"); - build.flag("-mavx"); - build.flag("-mfma"); - build.flag("-mf16c"); + // This is a very basic heuristic for applying compile flags. + // Feel free to update this to fit your operating system. + let target_arch = std::env::var("CARGO_CFG_TARGET_ARCH").unwrap(); + let target_os = std::env::var("CARGO_CFG_TARGET_OS").unwrap(); + let is_release = std::env::var("PROFILE").unwrap() == "release"; + + let supported_features: HashSet<_> = std::env::var("CARGO_CFG_TARGET_FEATURE") + .unwrap() + .split(',') + .map(|s| s.to_string()) + .collect(); + + match target_arch.as_str() { + "x86" | "x86_64" => { + let supports_fma = supported_features.contains("fma"); + let supports_avx = supported_features.contains("avx"); + let supports_avx2 = supported_features.contains("avx2"); + let supports_f16c = supported_features.contains("f16c"); + let supports_sse3 = supported_features.contains("sse3"); + + match target_os.as_str() { + "freebsd" | "haiku" | "ios" | "macos" | "linux" => { + build.flag("-pthread"); - build.compile("foo"); + if supports_avx { + build.flag("-mavx"); + } + if supports_avx2 { + build.flag("-mavx2"); + } + if supports_fma { + build.flag("-mfma"); + } + if supports_f16c { + build.flag("-mf16c"); + } + if supports_sse3 { + build.flag("-msse3"); + } + } + "windows" => match (supports_avx2, supports_avx) { + (true, _) => { + build.flag("/arch:AVX2"); + } + (_, true) => { + build.flag("/arch:AVX"); + } + _ => {} + }, + _ => {} + } + } + _ => {} + } + if is_release { + build.define("NDEBUG", None); + } + build.compile("ggml"); println!("cargo:rerun-if-changed=ggml/ggml.h"); From 90589fa3348ea6e9638eae9204de4b2543776dcb Mon Sep 17 00:00:00 2001 From: Philpax Date: Tue, 14 Mar 2023 20:09:13 +0100 Subject: [PATCH 02/10] refactor: rough decouple lib and cli --- Cargo.lock | 33 +++++++++---------------- Cargo.toml | 4 +++ llama-cli/Cargo.toml | 15 +++++++++++ {llama-rs => llama-cli}/src/cli_args.rs | 0 {llama-rs => llama-cli}/src/main.rs | 6 ++--- llama-rs/Cargo.toml | 8 ++---- llama-rs/src/lib.rs | 6 +++++ 7 files changed, 40 insertions(+), 32 deletions(-) create mode 100644 llama-cli/Cargo.toml rename {llama-rs => llama-cli}/src/cli_args.rs (100%) rename {llama-rs => llama-cli}/src/main.rs (88%) create mode 100644 llama-rs/src/lib.rs diff --git a/Cargo.lock b/Cargo.lock index eeb91ff4..f906aeb4 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -17,15 +17,6 @@ version = "1.0.2" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "f26201604c87b1e01bd3d98f8d5d9a8fcbb815e8cedb41ffccbeb4bf593a35fe" -[[package]] -name = "aho-corasick" -version = "0.7.20" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "cc936419f96fa211c1b9166887b38e5e40b19958e5b895be7c1f93adec7071ac" -dependencies = [ - "memchr", -] - [[package]] name = "anyhow" version = 
"1.0.69" @@ -50,12 +41,6 @@ dependencies = [ "rustc-demangle", ] -[[package]] -name = "bimap" -version = "0.6.2" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "bc0455254eb5c6964c4545d8bac815e1a1be4f3afe0ae695ea539c12d728d44b" - [[package]] name = "bindgen" version = "0.62.0" @@ -294,20 +279,26 @@ version = "0.1.4" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "f051f77a7c8e6957c0696eac88f26b0117e54f52d3fc682ab19397a8812846a4" +[[package]] +name = "llama-cli" +version = "0.1.0" +dependencies = [ + "clap", + "llama-rs", + "num_cpus", + "once_cell", + "rand", +] + [[package]] name = "llama-rs" version = "0.1.0" dependencies = [ "anyhow", - "bimap", "bytemuck", - "clap", "ggml-raw", - "num_cpus", - "once_cell", "partial_sort", "rand", - "regex", ] [[package]] @@ -477,8 +468,6 @@ version = "1.7.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "48aaa5748ba571fb95cd2c85c09f629215d3a6ece942baa100950af03a34f733" dependencies = [ - "aho-corasick", - "memchr", "regex-syntax", ] diff --git a/Cargo.toml b/Cargo.toml index c27b0276..43938e39 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -2,6 +2,10 @@ members = [ "ggml-raw", "llama-rs", + "llama-cli" ] resolver = "2" + +[workspace.dependencies] +rand = "0.8.5" \ No newline at end of file diff --git a/llama-cli/Cargo.toml b/llama-cli/Cargo.toml new file mode 100644 index 00000000..258d7833 --- /dev/null +++ b/llama-cli/Cargo.toml @@ -0,0 +1,15 @@ +[package] +name = "llama-cli" +version = "0.1.0" +edition = "2021" + +# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html + +[dependencies] +clap = { version = "4.1.8", features = ["derive"] } +once_cell = "1.17.1" +num_cpus = "1.15.0" + +llama-rs = { path = "../llama-rs" } + +rand = { workspace = true } \ No newline at end of file diff --git a/llama-rs/src/cli_args.rs b/llama-cli/src/cli_args.rs similarity index 100% rename from llama-rs/src/cli_args.rs rename to llama-cli/src/cli_args.rs diff --git a/llama-rs/src/main.rs b/llama-cli/src/main.rs similarity index 88% rename from llama-rs/src/main.rs rename to llama-cli/src/main.rs index a81084c9..705a1dbb 100644 --- a/llama-rs/src/main.rs +++ b/llama-cli/src/main.rs @@ -1,10 +1,8 @@ use cli_args::CLI_ARGS; -use llama::InferenceParams; +use llama_rs::InferenceParams; use rand::thread_rng; mod cli_args; -mod ggml; -mod llama; fn main() { let args = &*CLI_ARGS; @@ -35,7 +33,7 @@ fn main() { std::process::exit(1); }; - let (model, vocab) = llama::LlamaModel::load(&args.model_path, args.num_ctx_tokens as i32) + let (model, vocab) = llama_rs::Model::load(&args.model_path, args.num_ctx_tokens as i32) .expect("Could not load model"); let mut rng = thread_rng(); diff --git a/llama-rs/Cargo.toml b/llama-rs/Cargo.toml index 3ea632f2..ae756783 100644 --- a/llama-rs/Cargo.toml +++ b/llama-rs/Cargo.toml @@ -7,12 +7,8 @@ edition = "2021" [dependencies] anyhow = { version = "1.0.69", features = ["backtrace"] } -bimap = "0.6.2" bytemuck = "1.13.1" -clap = { version = "4.1.8", features = ["derive"] } ggml-raw = { path = "../ggml-raw" } -num_cpus = "1.15.0" -once_cell = "1.17.1" partial_sort = "0.2.0" -rand = "0.8.5" -regex = "1.7.1" + +rand = { workspace = true } \ No newline at end of file diff --git a/llama-rs/src/lib.rs b/llama-rs/src/lib.rs new file mode 100644 index 00000000..2887b124 --- /dev/null +++ b/llama-rs/src/lib.rs @@ -0,0 +1,6 @@ +mod ggml; +mod llama; + +pub use llama::{ + GptVocab as Vocab, InferenceParams, LlamaHyperParams as 
HyperParams, LlamaModel as Model, +}; From 175674e3cb6f6afaa07bc832c650bc406ef9bd81 Mon Sep 17 00:00:00 2001 From: Philpax Date: Tue, 14 Mar 2023 20:34:31 +0100 Subject: [PATCH 03/10] refactor(llama): output tokens through callback --- Cargo.lock | 32 ++++++++++++++++++++++ llama-cli/Cargo.toml | 1 + llama-cli/src/main.rs | 7 ++++- llama-rs/Cargo.toml | 1 + llama-rs/src/llama.rs | 63 ++++++++++++++++++++++++++----------------- 5 files changed, 79 insertions(+), 25 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index f906aeb4..57a1a7df 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -17,6 +17,15 @@ version = "1.0.2" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "f26201604c87b1e01bd3d98f8d5d9a8fcbb815e8cedb41ffccbeb4bf593a35fe" +[[package]] +name = "aho-corasick" +version = "0.7.20" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "cc936419f96fa211c1b9166887b38e5e40b19958e5b895be7c1f93adec7071ac" +dependencies = [ + "memchr", +] + [[package]] name = "anyhow" version = "1.0.69" @@ -150,6 +159,19 @@ version = "1.8.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "7fcaabb2fef8c910e7f4c7ce9f67a1283a1715879a7c230ca9d6d1ae31f16d91" +[[package]] +name = "env_logger" +version = "0.10.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "85cdab6a89accf66733ad5a1693a4dcced6aeff64602b634530dd73c1f3ee9f0" +dependencies = [ + "humantime", + "is-terminal", + "log", + "regex", + "termcolor", +] + [[package]] name = "errno" version = "0.2.8" @@ -223,6 +245,12 @@ version = "0.3.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "fed44880c466736ef9a5c5b5facefb5ed0785676d0c02d612db14e54f0d84286" +[[package]] +name = "humantime" +version = "2.1.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9a3a5bfb195931eeb336b2a7b4d761daec841b97f947d34394601737a7bba5e4" + [[package]] name = "io-lifetimes" version = "1.0.6" @@ -284,6 +312,7 @@ name = "llama-cli" version = "0.1.0" dependencies = [ "clap", + "env_logger", "llama-rs", "num_cpus", "once_cell", @@ -297,6 +326,7 @@ dependencies = [ "anyhow", "bytemuck", "ggml-raw", + "log", "partial_sort", "rand", ] @@ -468,6 +498,8 @@ version = "1.7.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "48aaa5748ba571fb95cd2c85c09f629215d3a6ece942baa100950af03a34f733" dependencies = [ + "aho-corasick", + "memchr", "regex-syntax", ] diff --git a/llama-cli/Cargo.toml b/llama-cli/Cargo.toml index 258d7833..24187868 100644 --- a/llama-cli/Cargo.toml +++ b/llama-cli/Cargo.toml @@ -7,6 +7,7 @@ edition = "2021" [dependencies] clap = { version = "4.1.8", features = ["derive"] } +env_logger = "0.10.0" once_cell = "1.17.1" num_cpus = "1.15.0" diff --git a/llama-cli/src/main.rs b/llama-cli/src/main.rs index 705a1dbb..f486d89c 100644 --- a/llama-cli/src/main.rs +++ b/llama-cli/src/main.rs @@ -5,6 +5,8 @@ use rand::thread_rng; mod cli_args; fn main() { + env_logger::init(); + let args = &*CLI_ARGS; let inference_params = InferenceParams { @@ -37,5 +39,8 @@ fn main() { .expect("Could not load model"); let mut rng = thread_rng(); - model.inference_with_prompt(&vocab, &inference_params, &prompt, &mut rng); + model.inference_with_prompt(&vocab, &inference_params, &prompt, &mut rng, |t| { + print!("{t}") + }); + println!(); } diff --git a/llama-rs/Cargo.toml b/llama-rs/Cargo.toml index ae756783..d7d31a52 100644 --- a/llama-rs/Cargo.toml +++ b/llama-rs/Cargo.toml @@ -8,6 +8,7 @@ edition = "2021" 
[dependencies] anyhow = { version = "1.0.69", features = ["backtrace"] } bytemuck = "1.13.1" +log = "0.4" ggml-raw = { path = "../ggml-raw" } partial_sort = "0.2.0" diff --git a/llama-rs/src/llama.rs b/llama-rs/src/llama.rs index 64f6a99e..68c9671b 100644 --- a/llama-rs/src/llama.rs +++ b/llama-rs/src/llama.rs @@ -1,6 +1,7 @@ use std::{ collections::HashMap, - io::{self, BufRead, Read, Seek, SeekFrom, Write}, + fmt::Display, + io::{BufRead, Read, Seek, SeekFrom}, path::Path, }; @@ -95,6 +96,24 @@ pub struct GptVocab { mapping: Vec, } +#[derive(Copy, Clone, Debug, PartialEq, Eq, PartialOrd, Ord)] +pub enum OutputToken<'a> { + Token(&'a str), + EndOfText, +} +impl Display for OutputToken<'_> { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + write!( + f, + "{}", + match self { + OutputToken::Token(t) => *t, + OutputToken::EndOfText => "[end of text]", + } + ) + } +} + fn llama_n_parts(size: i32) -> i32 { match size { 4096 => 1, @@ -164,7 +183,7 @@ impl LlamaModel { ((2 * (4 * hparams.n_embd) / 3 + hparams.n_mult - 1) / hparams.n_mult) * hparams.n_mult; let n_parts = llama_n_parts(hparams.n_embd); - eprintln!("Loaded HyperParams {hparams:#?}"); + log::debug!("Loaded HyperParams {hparams:#?}"); // =============== // Load vocabulary @@ -248,7 +267,7 @@ impl LlamaModel { ctx_size += (5 + 10 * n_layer) * 256; // object overhead - println!( + log::info!( "ggml ctx size = {:.2} MB\n", ctx_size as f64 / (1024.0 * 1024.0) ); @@ -323,7 +342,7 @@ impl LlamaModel { let memory_v = context.new_tensor_1d(GGML_TYPE_F32, n_elements); let memory_size = memory_k.nbytes() + memory_v.nbytes(); - println!( + log::info!( "Memory size: {} MB {}", memory_size as f32 / 1024.0 / 1024.0, n_mem @@ -361,7 +380,7 @@ impl LlamaModel { }; let part_path_str = part_path.to_string_lossy(); - println!( + log::info!( "loading model part {}/{} from '{}'\n", i + 1, n_parts, @@ -558,14 +577,10 @@ impl LlamaModel { } n_tensors += 1; - if n_tensors % 8 == 0 { - print!("."); - io::stdout().flush()?; - } } - println!(" done"); - println!( + log::info!("loading complete"); + log::info!( "model size = {:.2} MB / num tensors = {}\n", total_size as f64 / 1024.0 / 1024.0, n_tensors @@ -575,12 +590,13 @@ impl LlamaModel { Ok((model, vocab)) } - pub fn inference_with_prompt( + pub fn inference_with_prompt<'a>( &self, - vocab: &GptVocab, + vocab: &'a GptVocab, params: &InferenceParams, prompt: &str, rng: &mut impl rand::Rng, + callback: impl Fn(OutputToken<'a>), ) { let embd_inp = self.tokenize(vocab, prompt, true); let mut logits = Vec::new(); @@ -603,7 +619,6 @@ impl LlamaModel { self.hparams.n_ctx as usize - embd_inp.len(), ); let mut input_consumed = 0; - let mut input_noecho = false; let mut n_past = 0; let mut embd = Vec::new(); @@ -651,9 +666,6 @@ impl LlamaModel { // add it to the context embd.push(id); - // echo this to console - input_noecho = false; - // decrement remaining sampling budget remaining_tokens -= 1; } else { @@ -670,15 +682,18 @@ impl LlamaModel { } // display text - if !input_noecho { - for &id in &embd { - print!("{}", vocab.mapping[id as usize]); - io::stdout().flush().expect("flush"); - } + let mut eot = false; + for &id in &embd { + let output_token = if id == 2 { + eot = true; + OutputToken::EndOfText + } else { + OutputToken::Token(&vocab.mapping[id as usize]) + }; + callback(output_token); } - if embd.last().copied() == Some(2) { - println!(" [end of text]"); + if eot { break; } } From c437f9f1121541b5121e99ec13316f288b0700ba Mon Sep 17 00:00:00 2001 From: Philpax Date: Tue, 14 Mar 2023 
22:23:31 +0100 Subject: [PATCH 04/10] feat: export OutputToken --- llama-rs/src/lib.rs | 1 + 1 file changed, 1 insertion(+) diff --git a/llama-rs/src/lib.rs b/llama-rs/src/lib.rs index 2887b124..9c81d69a 100644 --- a/llama-rs/src/lib.rs +++ b/llama-rs/src/lib.rs @@ -3,4 +3,5 @@ mod llama; pub use llama::{ GptVocab as Vocab, InferenceParams, LlamaHyperParams as HyperParams, LlamaModel as Model, + OutputToken, }; From a4b3798147aa2685b5bc8317a82c3f234724bd5a Mon Sep 17 00:00:00 2001 From: Philpax Date: Thu, 16 Mar 2023 01:23:26 +0100 Subject: [PATCH 05/10] chore(llama): remove unnecessary lifetimes --- llama-rs/src/llama.rs | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/llama-rs/src/llama.rs b/llama-rs/src/llama.rs index 68c9671b..b68b5c35 100644 --- a/llama-rs/src/llama.rs +++ b/llama-rs/src/llama.rs @@ -590,13 +590,13 @@ impl LlamaModel { Ok((model, vocab)) } - pub fn inference_with_prompt<'a>( + pub fn inference_with_prompt( &self, - vocab: &'a GptVocab, + vocab: &GptVocab, params: &InferenceParams, prompt: &str, rng: &mut impl rand::Rng, - callback: impl Fn(OutputToken<'a>), + callback: impl Fn(OutputToken), ) { let embd_inp = self.tokenize(vocab, prompt, true); let mut logits = Vec::new(); From 9642ef9594d56a978ef206e9df301f28b7edccaa Mon Sep 17 00:00:00 2001 From: Philpax Date: Thu, 16 Mar 2023 02:00:14 +0100 Subject: [PATCH 06/10] feat(llama): provide load progress callback --- Cargo.lock | 2 +- llama-cli/Cargo.toml | 1 + llama-cli/src/main.rs | 60 ++++++++++++++++++++++++++++- llama-rs/Cargo.toml | 1 - llama-rs/src/lib.rs | 2 +- llama-rs/src/llama.rs | 90 +++++++++++++++++++++++++++++++------------ 6 files changed, 126 insertions(+), 30 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index 57a1a7df..a98faf62 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -314,6 +314,7 @@ dependencies = [ "clap", "env_logger", "llama-rs", + "log", "num_cpus", "once_cell", "rand", @@ -326,7 +327,6 @@ dependencies = [ "anyhow", "bytemuck", "ggml-raw", - "log", "partial_sort", "rand", ] diff --git a/llama-cli/Cargo.toml b/llama-cli/Cargo.toml index 24187868..d3a3b0bd 100644 --- a/llama-cli/Cargo.toml +++ b/llama-cli/Cargo.toml @@ -8,6 +8,7 @@ edition = "2021" [dependencies] clap = { version = "4.1.8", features = ["derive"] } env_logger = "0.10.0" +log = "0.4" once_cell = "1.17.1" num_cpus = "1.15.0" diff --git a/llama-cli/src/main.rs b/llama-cli/src/main.rs index f486d89c..d1268a5b 100644 --- a/llama-cli/src/main.rs +++ b/llama-cli/src/main.rs @@ -1,3 +1,5 @@ +use std::io::Write; + use cli_args::CLI_ARGS; use llama_rs::InferenceParams; use rand::thread_rng; @@ -35,12 +37,66 @@ fn main() { std::process::exit(1); }; - let (model, vocab) = llama_rs::Model::load(&args.model_path, args.num_ctx_tokens as i32) + let (model, vocab) = + llama_rs::Model::load(&args.model_path, args.num_ctx_tokens as i32, |progress| { + use llama_rs::LoadProgress; + match progress { + LoadProgress::HyperParamsLoaded(hparams) => { + log::debug!("Loaded HyperParams {hparams:#?}") + } + LoadProgress::BadToken { index } => { + log::info!("Warning: Bad token in vocab at index {index}") + } + LoadProgress::ContextSize { bytes } => log::info!( + "ggml ctx size = {:.2} MB\n", + bytes as f64 / (1024.0 * 1024.0) + ), + LoadProgress::MemorySize { bytes, n_mem } => log::info!( + "Memory size: {} MB {}", + bytes as f32 / 1024.0 / 1024.0, + n_mem + ), + LoadProgress::PartLoading { + file, + current_part, + total_parts, + } => log::info!( + "Loading model part {}/{} from '{}'\n", + current_part, + total_parts, + 
file.to_string_lossy(), + ), + LoadProgress::PartTensorLoaded { + current_tensor, + total_tensors, + .. + } => { + if current_tensor % 8 == 0 { + log::info!("Loaded tensor {current_tensor}/{total_tensors}"); + } + } + LoadProgress::PartLoaded { + file, + byte_size, + tensor_count, + } => { + log::info!("Loading of '{}' complete", file.to_string_lossy()); + log::info!( + "Model size = {:.2} MB / num tensors = {}", + byte_size as f64 / 1024.0 / 1024.0, + tensor_count + ); + } + } + }) .expect("Could not load model"); + log::info!("Model fully loaded!"); + let mut rng = thread_rng(); model.inference_with_prompt(&vocab, &inference_params, &prompt, &mut rng, |t| { - print!("{t}") + print!("{t}"); + std::io::stdout().flush().unwrap(); }); println!(); } diff --git a/llama-rs/Cargo.toml b/llama-rs/Cargo.toml index d7d31a52..ae756783 100644 --- a/llama-rs/Cargo.toml +++ b/llama-rs/Cargo.toml @@ -8,7 +8,6 @@ edition = "2021" [dependencies] anyhow = { version = "1.0.69", features = ["backtrace"] } bytemuck = "1.13.1" -log = "0.4" ggml-raw = { path = "../ggml-raw" } partial_sort = "0.2.0" diff --git a/llama-rs/src/lib.rs b/llama-rs/src/lib.rs index 9c81d69a..7811286e 100644 --- a/llama-rs/src/lib.rs +++ b/llama-rs/src/lib.rs @@ -3,5 +3,5 @@ mod llama; pub use llama::{ GptVocab as Vocab, InferenceParams, LlamaHyperParams as HyperParams, LlamaModel as Model, - OutputToken, + LoadProgress, OutputToken, }; diff --git a/llama-rs/src/llama.rs b/llama-rs/src/llama.rs index b68b5c35..614ac85f 100644 --- a/llama-rs/src/llama.rs +++ b/llama-rs/src/llama.rs @@ -14,7 +14,7 @@ use rand::{distributions::WeightedIndex, prelude::Distribution}; use crate::ggml::{GgmlCGraph, GGML_TYPE_F16, GGML_TYPE_F32, GGML_TYPE_Q4_0, GGML_TYPE_Q4_1}; -#[derive(Debug, Default)] +#[derive(Debug, Default, PartialEq, Eq, PartialOrd, Ord)] pub struct LlamaHyperParams { n_vocab: i32, n_ctx: i32, @@ -124,8 +124,44 @@ fn llama_n_parts(size: i32) -> i32 { } } +/// Each variant represents a step within the process of loading the model. +/// These can be used to report progress to the user. 
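// A minimal caller-side sketch, assuming a model path is already at hand
// (`model_path` and the context size of 512 are placeholders, not values from
// this series). Note that the new parameter is `impl Fn(LoadProgress)`, not
// `FnMut`, so a callback that accumulates state needs interior mutability,
// e.g. a `Cell`:
//
//     let tensors_seen = std::cell::Cell::new(0usize);
//     let (model, vocab) =
//         llama_rs::Model::load(&model_path, 512, |progress| {
//             if let llama_rs::LoadProgress::PartTensorLoaded { .. } = progress {
//                 tensors_seen.set(tensors_seen.get() + 1);
//             }
//         })
//         .expect("Could not load model");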
+#[derive(Clone, PartialEq, Eq, PartialOrd, Ord, Debug)] +pub enum LoadProgress<'a> { + HyperParamsLoaded(&'a LlamaHyperParams), + BadToken { + index: usize, + }, + ContextSize { + bytes: usize, + }, + MemorySize { + bytes: usize, + n_mem: usize, + }, + PartLoading { + file: &'a Path, + current_part: usize, + total_parts: usize, + }, + PartTensorLoaded { + file: &'a Path, + current_tensor: usize, + total_tensors: usize, + }, + PartLoaded { + file: &'a Path, + byte_size: usize, + tensor_count: usize, + }, +} + impl LlamaModel { - pub fn load(path: impl AsRef, n_ctx: i32) -> Result<(LlamaModel, GptVocab)> { + pub fn load( + path: impl AsRef, + n_ctx: i32, + load_progress_callback: impl Fn(LoadProgress), + ) -> Result<(LlamaModel, GptVocab)> { use std::fs::File; use std::io::BufReader; @@ -183,7 +219,7 @@ impl LlamaModel { ((2 * (4 * hparams.n_embd) / 3 + hparams.n_mult - 1) / hparams.n_mult) * hparams.n_mult; let n_parts = llama_n_parts(hparams.n_embd); - log::debug!("Loaded HyperParams {hparams:#?}"); + load_progress_callback(LoadProgress::HyperParamsLoaded(&hparams)); // =============== // Load vocabulary @@ -194,7 +230,9 @@ impl LlamaModel { if let Ok(word) = read_string(&mut reader, len as usize) { vocab.mapping.push(word); } else { - println!("Warning: Bad token in vocab at index {i}"); + load_progress_callback(LoadProgress::BadToken { + index: i.try_into()?, + }); vocab.mapping.push("�".to_string()); } } @@ -267,10 +305,9 @@ impl LlamaModel { ctx_size += (5 + 10 * n_layer) * 256; // object overhead - log::info!( - "ggml ctx size = {:.2} MB\n", - ctx_size as f64 / (1024.0 * 1024.0) - ); + load_progress_callback(LoadProgress::ContextSize { + bytes: ctx_size.try_into()?, + }); ctx_size }; @@ -342,11 +379,11 @@ impl LlamaModel { let memory_v = context.new_tensor_1d(GGML_TYPE_F32, n_elements); let memory_size = memory_k.nbytes() + memory_v.nbytes(); - log::info!( - "Memory size: {} MB {}", - memory_size as f32 / 1024.0 / 1024.0, - n_mem - ); + + load_progress_callback(LoadProgress::MemorySize { + bytes: memory_size, + n_mem: n_mem.try_into()?, + }); LlamaModel { hparams, @@ -380,12 +417,11 @@ impl LlamaModel { }; let part_path_str = part_path.to_string_lossy(); - log::info!( - "loading model part {}/{} from '{}'\n", - i + 1, - n_parts, - part_path_str, - ); + load_progress_callback(LoadProgress::PartLoading { + file: &part_path, + current_part: (i + 1).try_into()?, + total_parts: n_parts.try_into()?, + }); let mut part_reader = BufReader::new(File::open(&part_path)?); @@ -577,14 +613,18 @@ impl LlamaModel { } n_tensors += 1; + load_progress_callback(LoadProgress::PartTensorLoaded { + file: &part_path, + current_tensor: n_tensors.try_into()?, + total_tensors: model.tensors.len(), + }); } - log::info!("loading complete"); - log::info!( - "model size = {:.2} MB / num tensors = {}\n", - total_size as f64 / 1024.0 / 1024.0, - n_tensors - ); + load_progress_callback(LoadProgress::PartLoaded { + file: &part_path, + byte_size: total_size.try_into()?, + tensor_count: n_tensors.try_into()?, + }); } Ok((model, vocab)) From 5f0992b14ad106ed933eccf30670468d06d58122 Mon Sep 17 00:00:00 2001 From: Philpax Date: Thu, 16 Mar 2023 02:34:07 +0100 Subject: [PATCH 07/10] feat(llama): switch to thiserror --- Cargo.lock | 91 ++++++++------------------------ llama-rs/Cargo.toml | 2 +- llama-rs/src/llama.rs | 117 +++++++++++++++++++++++++++++++++--------- 3 files changed, 114 insertions(+), 96 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index a98faf62..0affbdc1 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ 
-2,21 +2,6 @@ # It is not intended for manual editing. version = 3 -[[package]] -name = "addr2line" -version = "0.19.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "a76fd60b23679b7d19bd066031410fb7e458ccc5e958eb5c325888ce4baedc97" -dependencies = [ - "gimli", -] - -[[package]] -name = "adler" -version = "1.0.2" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "f26201604c87b1e01bd3d98f8d5d9a8fcbb815e8cedb41ffccbeb4bf593a35fe" - [[package]] name = "aho-corasick" version = "0.7.20" @@ -26,30 +11,6 @@ dependencies = [ "memchr", ] -[[package]] -name = "anyhow" -version = "1.0.69" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "224afbd727c3d6e4b90103ece64b8d1b67fbb1973b1046c2281eed3f3803f800" -dependencies = [ - "backtrace", -] - -[[package]] -name = "backtrace" -version = "0.3.67" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "233d376d6d185f2a3093e58f283f60f880315b6c60075b01f36b3b85154564ca" -dependencies = [ - "addr2line", - "cc", - "cfg-if", - "libc", - "miniz_oxide", - "object", - "rustc-demangle", -] - [[package]] name = "bindgen" version = "0.62.0" @@ -212,12 +173,6 @@ dependencies = [ "cc", ] -[[package]] -name = "gimli" -version = "0.27.2" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "ad0a93d233ebf96623465aad4046a8d3aa4da22d4f4beba5388838c8a434bbb4" - [[package]] name = "glob" version = "0.3.1" @@ -324,11 +279,11 @@ dependencies = [ name = "llama-rs" version = "0.1.0" dependencies = [ - "anyhow", "bytemuck", "ggml-raw", "partial_sort", "rand", + "thiserror", ] [[package]] @@ -352,15 +307,6 @@ version = "0.2.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "68354c5c6bd36d73ff3feceb05efa59b6acb7626617f4962be322a825e61f79a" -[[package]] -name = "miniz_oxide" -version = "0.6.2" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "b275950c28b37e794e8c55d88aeb5e139d0ce23fdbbeda68f8d7174abdf9e8fa" -dependencies = [ - "adler", -] - [[package]] name = "nom" version = "7.1.3" @@ -381,15 +327,6 @@ dependencies = [ "libc", ] -[[package]] -name = "object" -version = "0.30.3" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "ea86265d3d3dcb6a27fc51bd29a4bf387fae9d2986b823079d4986af253eb439" -dependencies = [ - "memchr", -] - [[package]] name = "once_cell" version = "1.17.1" @@ -509,12 +446,6 @@ version = "0.6.28" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "456c603be3e8d448b072f410900c09faf164fbce2d480456f50eea6e25f9c848" -[[package]] -name = "rustc-demangle" -version = "0.1.21" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "7ef03e0a2b150c7a90d01faf6254c9c48a41e95fb2a8c2ac1c6f0d2b9aefc342" - [[package]] name = "rustc-hash" version = "1.1.0" @@ -567,6 +498,26 @@ dependencies = [ "winapi-util", ] +[[package]] +name = "thiserror" +version = "1.0.39" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a5ab016db510546d856297882807df8da66a16fb8c4101cb8b30054b0d5b2d9c" +dependencies = [ + "thiserror-impl", +] + +[[package]] +name = "thiserror-impl" +version = "1.0.39" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "5420d42e90af0c38c3290abcca25b9b3bdf379fc9f55c528f53a269d9c9a267e" +dependencies = [ + "proc-macro2", + "quote", + "syn", +] + [[package]] name = "unicode-ident" version = "1.0.8" diff --git a/llama-rs/Cargo.toml b/llama-rs/Cargo.toml index 
ae756783..8eb42cb8 100644 --- a/llama-rs/Cargo.toml +++ b/llama-rs/Cargo.toml @@ -6,9 +6,9 @@ edition = "2021" # See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html [dependencies] -anyhow = { version = "1.0.69", features = ["backtrace"] } bytemuck = "1.13.1" ggml-raw = { path = "../ggml-raw" } partial_sort = "0.2.0" +thiserror = "1.0" rand = { workspace = true } \ No newline at end of file diff --git a/llama-rs/src/llama.rs b/llama-rs/src/llama.rs index 614ac85f..bdc9fbe7 100644 --- a/llama-rs/src/llama.rs +++ b/llama-rs/src/llama.rs @@ -2,10 +2,10 @@ use std::{ collections::HashMap, fmt::Display, io::{BufRead, Read, Seek, SeekFrom}, - path::Path, + path::{Path, PathBuf}, }; -use anyhow::{Context, Result}; +use thiserror::Error; use crate::ggml::{GgmlContext, GgmlTensor, GGML_TYPE_I32}; use ggml_raw::ggml_type; @@ -156,36 +156,76 @@ pub enum LoadProgress<'a> { }, } +#[derive(Error, Debug)] +pub enum LoadError { + #[error("could not open file {path:?}")] + OpenFileFailed { + source: std::io::Error, + path: PathBuf, + }, + #[error("unable to read exactly {bytes} bytes")] + ReadExactFailed { + source: std::io::Error, + bytes: usize, + }, + #[error("non-specific I/O error")] + IO(#[from] std::io::Error), + + #[error("could not convert bytes to a UTF-8 string")] + InvalidUtf8(#[from] std::string::FromUtf8Error), + #[error("invalid integer conversion")] + InvalidIntegerConversion(#[from] std::num::TryFromIntError), + + #[error("invalid magic number for {path:?}")] + InvalidMagic { path: PathBuf }, + #[error("invalid value {value} for `f16` in hyperparameters")] + HyperparametersF16Invalid { value: i32 }, + #[error("unknown tensor `{tensor_name}` in {path:?}")] + UnknownTensor { tensor_name: String, path: PathBuf }, + #[error("the tensor `{tensor_name}` has the wrong size in {path:?}")] + TensorWrongSize { tensor_name: String, path: PathBuf }, + #[error("invalid ftype {ftype} in {path:?}")] + InvalidFtype { ftype: i32, path: PathBuf }, +} + impl LlamaModel { pub fn load( path: impl AsRef, n_ctx: i32, load_progress_callback: impl Fn(LoadProgress), - ) -> Result<(LlamaModel, GptVocab)> { + ) -> Result<(LlamaModel, GptVocab), LoadError> { use std::fs::File; use std::io::BufReader; let path = path.as_ref(); - let path_str = path.to_string_lossy(); - let mut reader = BufReader::new( - File::open(path) - .with_context(|| anyhow::anyhow!("Failed to open file at '{path_str}'",))?, - ); + let mut reader = + BufReader::new(File::open(path).map_err(|e| LoadError::OpenFileFailed { + source: e, + path: path.to_owned(), + })?); /// Helper function. Reads an int from the buffer and returns it. - fn read_i32(reader: &mut impl BufRead) -> Result { + fn read_i32(reader: &mut impl BufRead) -> Result { let mut bytes = [0u8; 4]; reader .read_exact(&mut bytes) - .context("Trying to parse metadata")?; + .map_err(|e| LoadError::ReadExactFailed { + source: e, + bytes: bytes.len(), + })?; Ok(i32::from_le_bytes(bytes)) } /// Helper function. Reads a string from the buffer and returns it. 
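// A rough caller-side sketch of what the concrete error type enables
// (`model_path` and the context size of 512 are placeholders, and `LoadError`
// only becomes nameable from outside the crate in the final patch of this
// series, when `llama.rs` is folded into `lib.rs`); the CLI itself keeps
// using `expect`:
//
//     match llama_rs::Model::load(&model_path, 512, |_| {}) {
//         Ok((_model, _vocab)) => { /* proceed to inference */ }
//         Err(llama_rs::LoadError::InvalidMagic { path }) => {
//             eprintln!("{path:?} is not a ggml model file")
//         }
//         Err(other) => eprintln!("failed to load model: {other}"),
//     }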
- fn read_string(reader: &mut BufReader, len: usize) -> Result { + fn read_string(reader: &mut BufReader, len: usize) -> Result { let mut buf = vec![0; len]; - reader.read_exact(&mut buf)?; + reader + .read_exact(&mut buf) + .map_err(|e| LoadError::ReadExactFailed { + source: e, + bytes: buf.len(), + })?; let s = String::from_utf8(buf)?; Ok(s) } @@ -194,7 +234,9 @@ impl LlamaModel { { let magic = read_i32(&mut reader)?; if magic != 0x67676d6c { - anyhow::bail!("Invalid model file '{path_str}' (bad magic)") + return Err(LoadError::InvalidMagic { + path: path.to_owned(), + }); } } @@ -245,7 +287,7 @@ impl LlamaModel { 1 => GGML_TYPE_F16, 2 => GGML_TYPE_Q4_0, 3 => GGML_TYPE_Q4_1, - invalid => anyhow::bail!("Invalid value for hparams.f16_ {invalid}"), + invalid => return Err(LoadError::HyperparametersF16Invalid { value: invalid }), }; let n_embd = hparams.n_embd; @@ -415,7 +457,6 @@ impl LlamaModel { } else { path.to_path_buf() }; - let part_path_str = part_path.to_string_lossy(); load_progress_callback(LoadProgress::PartLoading { file: &part_path, @@ -455,7 +496,7 @@ impl LlamaModel { let Some(tensor) = model.tensors.get(&tensor_name) else { - anyhow::bail!("Unknown tensor '{tensor_name}' in model_file '{part_path_str}'") + return Err(LoadError::UnknownTensor { tensor_name, path: part_path }); }; // split_type = 0: split by columns @@ -494,26 +535,41 @@ impl LlamaModel { if n_dims == 1 { if tensor.nelements() != nelements { - anyhow::bail!("Tensor {tensor_name} has the wrong size in model file"); + return Err(LoadError::TensorWrongSize { + tensor_name, + path: part_path, + }); } } else { if tensor.nelements() / n_parts != nelements { - anyhow::bail!("Tensor {tensor_name} has the wrong size in model file"); + return Err(LoadError::TensorWrongSize { + tensor_name, + path: part_path, + }); } } if n_dims == 1 { if tensor.get_ne()[0] != ne[0] || tensor.get_ne()[1] != ne[1] { - anyhow::bail!("Tensor {tensor_name} has the wrong size in model file"); + return Err(LoadError::TensorWrongSize { + tensor_name, + path: part_path, + }); } } else { if split_type == 0 { if tensor.get_ne()[0] / n_parts != ne[0] || tensor.get_ne()[1] != ne[1] { - anyhow::bail!("Tensor {tensor_name} has the wrong size in model file"); + return Err(LoadError::TensorWrongSize { + tensor_name, + path: part_path, + }); } } else { if tensor.get_ne()[0] != ne[0] || tensor.get_ne()[1] / n_parts != ne[1] { - anyhow::bail!("Tensor {tensor_name} has the wrong size in model file"); + return Err(LoadError::TensorWrongSize { + tensor_name, + path: part_path, + }); } } } @@ -537,14 +593,22 @@ impl LlamaModel { assert_eq!(ne[0] % 64, 0); ggml_type_size(GGML_TYPE_Q4_1) } - _ => anyhow::bail!("Invalid ftype {ftype} in model file"), + _ => { + return Err(LoadError::InvalidFtype { + ftype, + path: part_path, + }) + } }; if n_dims == 1 || n_parts == 1 { if (nelements as usize * bpe) / ggml_blck_size(tensor.get_type()) as usize != tensor.nbytes() { - anyhow::bail!("Tensor {tensor_name} has the wrong size in model file"); + return Err(LoadError::TensorWrongSize { + tensor_name, + path: part_path, + }); } let data = tensor.data(); @@ -564,7 +628,10 @@ impl LlamaModel { if (nelements as usize * bpe) / ggml_blck_size(tensor.get_type()) as usize != tensor.nbytes() / n_parts as usize { - anyhow::bail!("Tensor {tensor_name} has the wrong size in model file"); + return Err(LoadError::TensorWrongSize { + tensor_name, + path: part_path, + }); } if split_type == 0 { @@ -622,7 +689,7 @@ impl LlamaModel { load_progress_callback(LoadProgress::PartLoaded { file: 
&part_path, - byte_size: total_size.try_into()?, + byte_size: total_size, tensor_count: n_tensors.try_into()?, }); } From e2efd9772d461d10e509bce6936d0ead30e8f942 Mon Sep 17 00:00:00 2001 From: Philpax Date: Thu, 16 Mar 2023 11:49:07 +0100 Subject: [PATCH 08/10] total_tensors -> tensor_count --- llama-cli/src/main.rs | 4 ++-- llama-rs/src/llama.rs | 4 ++-- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/llama-cli/src/main.rs b/llama-cli/src/main.rs index d1268a5b..f06ea21d 100644 --- a/llama-cli/src/main.rs +++ b/llama-cli/src/main.rs @@ -68,11 +68,11 @@ fn main() { ), LoadProgress::PartTensorLoaded { current_tensor, - total_tensors, + tensor_count, .. } => { if current_tensor % 8 == 0 { - log::info!("Loaded tensor {current_tensor}/{total_tensors}"); + log::info!("Loaded tensor {current_tensor}/{tensor_count}"); } } LoadProgress::PartLoaded { diff --git a/llama-rs/src/llama.rs b/llama-rs/src/llama.rs index bdc9fbe7..0da03460 100644 --- a/llama-rs/src/llama.rs +++ b/llama-rs/src/llama.rs @@ -147,7 +147,7 @@ pub enum LoadProgress<'a> { PartTensorLoaded { file: &'a Path, current_tensor: usize, - total_tensors: usize, + tensor_count: usize, }, PartLoaded { file: &'a Path, @@ -683,7 +683,7 @@ impl LlamaModel { load_progress_callback(LoadProgress::PartTensorLoaded { file: &part_path, current_tensor: n_tensors.try_into()?, - total_tensors: model.tensors.len(), + tensor_count: model.tensors.len(), }); } From 9809f9694cbc374bcb55531660d49137f63cab52 Mon Sep 17 00:00:00 2001 From: Philpax Date: Thu, 16 Mar 2023 11:57:47 +0100 Subject: [PATCH 09/10] feat(cli): default log level --- llama-cli/src/main.rs | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/llama-cli/src/main.rs b/llama-cli/src/main.rs index f06ea21d..77cef923 100644 --- a/llama-cli/src/main.rs +++ b/llama-cli/src/main.rs @@ -7,7 +7,10 @@ use rand::thread_rng; mod cli_args; fn main() { - env_logger::init(); + env_logger::builder() + .filter_level(log::LevelFilter::Info) + .parse_default_env() + .init(); let args = &*CLI_ARGS; From e7656175153cec3be175536e775439ce5246b6e3 Mon Sep 17 00:00:00 2001 From: Philpax Date: Thu, 16 Mar 2023 23:05:42 +0100 Subject: [PATCH 10/10] refactor(llama): rustify type names - fix #17 --- llama-cli/src/main.rs | 4 +- llama-rs/src/ggml.rs | 96 ++-- llama-rs/src/lib.rs | 1157 +++++++++++++++++++++++++++++++++++++++- llama-rs/src/llama.rs | 1158 ----------------------------------------- 4 files changed, 1197 insertions(+), 1218 deletions(-) delete mode 100644 llama-rs/src/llama.rs diff --git a/llama-cli/src/main.rs b/llama-cli/src/main.rs index 77cef923..d7d0f2bb 100644 --- a/llama-cli/src/main.rs +++ b/llama-cli/src/main.rs @@ -1,7 +1,7 @@ use std::io::Write; use cli_args::CLI_ARGS; -use llama_rs::InferenceParams; +use llama_rs::InferenceParameters; use rand::thread_rng; mod cli_args; @@ -14,7 +14,7 @@ fn main() { let args = &*CLI_ARGS; - let inference_params = InferenceParams { + let inference_params = InferenceParameters { n_threads: args.num_threads as i32, n_predict: args.num_predict, n_batch: args.batch_size, diff --git a/llama-rs/src/ggml.rs b/llama-rs/src/ggml.rs index 2967ce26..9ce86741 100644 --- a/llama-rs/src/ggml.rs +++ b/llama-rs/src/ggml.rs @@ -1,29 +1,30 @@ use std::{ ffi::c_void, - marker::PhantomData, - ptr::{addr_of, NonNull}, + ptr::NonNull, sync::{Arc, Weak}, }; -pub const GGML_TYPE_Q4_0: ggml_raw::ggml_type = ggml_raw::ggml_type_GGML_TYPE_Q4_0; -pub const GGML_TYPE_Q4_1: ggml_raw::ggml_type = ggml_raw::ggml_type_GGML_TYPE_Q4_1; -pub const GGML_TYPE_I8: 
ggml_raw::ggml_type = ggml_raw::ggml_type_GGML_TYPE_I8; -pub const GGML_TYPE_I16: ggml_raw::ggml_type = ggml_raw::ggml_type_GGML_TYPE_I16; -pub const GGML_TYPE_I32: ggml_raw::ggml_type = ggml_raw::ggml_type_GGML_TYPE_I32; -pub const GGML_TYPE_F16: ggml_raw::ggml_type = ggml_raw::ggml_type_GGML_TYPE_F16; -pub const GGML_TYPE_F32: ggml_raw::ggml_type = ggml_raw::ggml_type_GGML_TYPE_F32; -pub const GGML_TYPE_COUNT: ggml_raw::ggml_type = ggml_raw::ggml_type_GGML_TYPE_COUNT; +pub use ggml_raw::ggml_type as Type; + +pub const TYPE_Q4_0: ggml_raw::ggml_type = ggml_raw::ggml_type_GGML_TYPE_Q4_0; +pub const TYPE_Q4_1: ggml_raw::ggml_type = ggml_raw::ggml_type_GGML_TYPE_Q4_1; +pub const TYPE_I8: ggml_raw::ggml_type = ggml_raw::ggml_type_GGML_TYPE_I8; +pub const TYPE_I16: ggml_raw::ggml_type = ggml_raw::ggml_type_GGML_TYPE_I16; +pub const TYPE_I32: ggml_raw::ggml_type = ggml_raw::ggml_type_GGML_TYPE_I32; +pub const TYPE_F16: ggml_raw::ggml_type = ggml_raw::ggml_type_GGML_TYPE_F16; +pub const TYPE_F32: ggml_raw::ggml_type = ggml_raw::ggml_type_GGML_TYPE_F32; +pub const TYPE_COUNT: ggml_raw::ggml_type = ggml_raw::ggml_type_GGML_TYPE_COUNT; /// Acts as a RAII-guard over a `ggml_raw::ggml_context`, allocating via /// ggml_init and dropping via ggml_free -pub struct GgmlContext { +pub struct Context { /// An `Arc` is used to model the relation between the context and the /// allocated tensors. Tensors are owned by the object, so a [`GgmlTensor`] /// contains a `Weak` reference underneath and doesn't let you do anything /// with it if the underlying context has been deallocated. ptr: Arc>, } -impl GgmlContext { +impl Context { pub fn init(mem_size: usize) -> Self { let raw = unsafe { ggml_raw::ggml_init(ggml_raw::ggml_init_params { @@ -38,116 +39,103 @@ impl GgmlContext { } } - fn new_tensor_raw(&self, raw: *mut ggml_raw::ggml_tensor) -> GgmlTensor { - GgmlTensor { + fn new_tensor_raw(&self, raw: *mut ggml_raw::ggml_tensor) -> Tensor { + Tensor { ptr: NonNull::new(raw).expect("Should not be null"), ctx: Arc::downgrade(&self.ptr), } } - pub fn new_tensor_1d(&self, typ: ggml_raw::ggml_type, ne0: i32) -> GgmlTensor { + pub fn new_tensor_1d(&self, typ: ggml_raw::ggml_type, ne0: i32) -> Tensor { let raw = unsafe { ggml_raw::ggml_new_tensor_1d(self.ptr.as_ptr(), typ, ne0) }; self.new_tensor_raw(raw) } - pub fn new_tensor_2d(&self, typ: ggml_raw::ggml_type, ne0: i32, ne1: i32) -> GgmlTensor { + pub fn new_tensor_2d(&self, typ: ggml_raw::ggml_type, ne0: i32, ne1: i32) -> Tensor { let raw = unsafe { ggml_raw::ggml_new_tensor_2d(self.ptr.as_ptr(), typ, ne0, ne1) }; self.new_tensor_raw(raw) } - pub fn new_tensor_3d( - &self, - typ: ggml_raw::ggml_type, - ne0: i32, - ne1: i32, - ne2: i32, - ) -> GgmlTensor { + pub fn new_tensor_3d(&self, typ: ggml_raw::ggml_type, ne0: i32, ne1: i32, ne2: i32) -> Tensor { let raw = unsafe { ggml_raw::ggml_new_tensor_3d(self.ptr.as_ptr(), typ, ne0, ne1, ne2) }; self.new_tensor_raw(raw) } - pub fn new_f32(&self, x: f32) -> GgmlTensor { + pub fn new_f32(&self, x: f32) -> Tensor { let raw = unsafe { ggml_raw::ggml_new_f32(self.ptr.as_ptr(), x) }; self.new_tensor_raw(raw) } - pub fn op_get_rows(&self, a: &GgmlTensor, b: &GgmlTensor) -> GgmlTensor { + pub fn op_get_rows(&self, a: &Tensor, b: &Tensor) -> Tensor { let tensor = unsafe { ggml_raw::ggml_get_rows(self.ptr.as_ptr(), a.ptr.as_ptr(), b.ptr.as_ptr()) }; self.new_tensor_raw(tensor) } - pub fn op_norm(&self, a: &GgmlTensor) -> GgmlTensor { + pub fn op_norm(&self, a: &Tensor) -> Tensor { let tensor = unsafe { 
ggml_raw::ggml_norm(self.ptr.as_ptr(), a.ptr.as_ptr()) }; self.new_tensor_raw(tensor) } - pub fn op_mul(&self, a: &GgmlTensor, b: &GgmlTensor) -> GgmlTensor { + pub fn op_mul(&self, a: &Tensor, b: &Tensor) -> Tensor { let tensor = unsafe { ggml_raw::ggml_mul(self.ptr.as_ptr(), a.ptr.as_ptr(), b.ptr.as_ptr()) }; self.new_tensor_raw(tensor) } - pub fn op_repeat(&self, a: &GgmlTensor, b: &GgmlTensor) -> GgmlTensor { + pub fn op_repeat(&self, a: &Tensor, b: &Tensor) -> Tensor { let tensor = unsafe { ggml_raw::ggml_repeat(self.ptr.as_ptr(), a.ptr.as_ptr(), b.ptr.as_ptr()) }; self.new_tensor_raw(tensor) } - pub fn op_mul_mat(&self, a: &GgmlTensor, b: &GgmlTensor) -> GgmlTensor { + pub fn op_mul_mat(&self, a: &Tensor, b: &Tensor) -> Tensor { let tensor = unsafe { ggml_raw::ggml_mul_mat(self.ptr.as_ptr(), a.ptr.as_ptr(), b.ptr.as_ptr()) }; self.new_tensor_raw(tensor) } - pub fn op_add(&self, a: &GgmlTensor, b: &GgmlTensor) -> GgmlTensor { + pub fn op_add(&self, a: &Tensor, b: &Tensor) -> Tensor { let tensor = unsafe { ggml_raw::ggml_add(self.ptr.as_ptr(), a.ptr.as_ptr(), b.ptr.as_ptr()) }; self.new_tensor_raw(tensor) } - pub fn op_silu(&self, a: &GgmlTensor) -> GgmlTensor { + pub fn op_silu(&self, a: &Tensor) -> Tensor { let tensor = unsafe { ggml_raw::ggml_silu(self.ptr.as_ptr(), a.ptr.as_ptr()) }; self.new_tensor_raw(tensor) } - pub fn op_scale(&self, a: &GgmlTensor, b: &GgmlTensor) -> GgmlTensor { + pub fn op_scale(&self, a: &Tensor, b: &Tensor) -> Tensor { let tensor = unsafe { ggml_raw::ggml_scale(self.ptr.as_ptr(), a.ptr.as_ptr(), b.ptr.as_ptr()) }; self.new_tensor_raw(tensor) } - pub fn op_diag_mask_inf(&self, a: &GgmlTensor, n_past: i32) -> GgmlTensor { + pub fn op_diag_mask_inf(&self, a: &Tensor, n_past: i32) -> Tensor { let tensor = unsafe { ggml_raw::ggml_diag_mask_inf(self.ptr.as_ptr(), a.ptr.as_ptr(), n_past) }; self.new_tensor_raw(tensor) } - pub fn op_soft_max(&self, a: &GgmlTensor) -> GgmlTensor { + pub fn op_soft_max(&self, a: &Tensor) -> Tensor { let tensor = unsafe { ggml_raw::ggml_soft_max(self.ptr.as_ptr(), a.ptr.as_ptr()) }; self.new_tensor_raw(tensor) } - pub fn op_view_1d(&self, a: &GgmlTensor, ne0: i32, offset: usize) -> GgmlTensor { + pub fn op_view_1d(&self, a: &Tensor, ne0: i32, offset: usize) -> Tensor { let tensor = unsafe { ggml_raw::ggml_view_1d(self.ptr.as_ptr(), a.ptr.as_ptr(), ne0, offset) }; self.new_tensor_raw(tensor) } - pub fn op_cpy(&self, a: &GgmlTensor, b: &GgmlTensor) -> GgmlTensor { + pub fn op_cpy(&self, a: &Tensor, b: &Tensor) -> Tensor { let tensor = unsafe { ggml_raw::ggml_cpy(self.ptr.as_ptr(), a.ptr.as_ptr(), b.ptr.as_ptr()) }; self.new_tensor_raw(tensor) } - pub fn op_permute( - &self, - a: &GgmlTensor, - axis0: i32, - axis1: i32, - axis2: i32, - axis3: i32, - ) -> GgmlTensor { + pub fn op_permute(&self, a: &Tensor, axis0: i32, axis1: i32, axis2: i32, axis3: i32) -> Tensor { let tensor = unsafe { ggml_raw::ggml_permute( self.ptr.as_ptr(), @@ -160,19 +148,19 @@ impl GgmlContext { }; self.new_tensor_raw(tensor) } - pub fn op_reshape_3d(&self, a: &GgmlTensor, ne0: i32, ne1: i32, ne2: i32) -> GgmlTensor { + pub fn op_reshape_3d(&self, a: &Tensor, ne0: i32, ne1: i32, ne2: i32) -> Tensor { let tensor = unsafe { ggml_raw::ggml_reshape_3d(self.ptr.as_ptr(), a.ptr.as_ptr(), ne0, ne1, ne2) }; self.new_tensor_raw(tensor) } - pub fn op_rope(&self, a: &GgmlTensor, npast: i32, ndims: i32, mode: i32) -> GgmlTensor { + pub fn op_rope(&self, a: &Tensor, npast: i32, ndims: i32, mode: i32) -> Tensor { let tensor = unsafe { ggml_raw::ggml_rope(self.ptr.as_ptr(), 
a.ptr.as_ptr(), npast, ndims, mode) }; self.new_tensor_raw(tensor) } - pub fn graph_compute(&self, graph: &mut GgmlCGraph) { + pub fn graph_compute(&self, graph: &mut ComputationGraph) { unsafe { ggml_raw::ggml_graph_compute(self.ptr.as_ptr(), &mut graph.inner); } @@ -183,7 +171,7 @@ impl GgmlContext { } } -impl Drop for GgmlContext { +impl Drop for Context { fn drop(&mut self) { // SAFETY: The only non-weak copy of ptr is no longer accessible after // this drop call. @@ -195,15 +183,15 @@ impl Drop for GgmlContext { /// Tensors are owned by the context. A tensor is alive as long as the /// underlying context it was created with is alive. -pub struct GgmlTensor { +pub struct Tensor { ptr: NonNull, ctx: Weak>, } -impl GgmlTensor { +impl Tensor { /// Creates a shared copy of this tensor pointer. pub fn share(&self) -> Self { - GgmlTensor { + Tensor { ptr: self.ptr, ctx: Weak::clone(&self.ctx), } @@ -264,11 +252,11 @@ impl GgmlTensor { } } -pub struct GgmlCGraph { +pub struct ComputationGraph { inner: ggml_raw::ggml_cgraph, } -impl GgmlCGraph { +impl ComputationGraph { pub fn new(n_threads: i32) -> Self { Self { inner: ggml_raw::ggml_cgraph { @@ -280,7 +268,7 @@ impl GgmlCGraph { } } - pub fn build_forward_expand(&mut self, tensor: &GgmlTensor) { + pub fn build_forward_expand(&mut self, tensor: &Tensor) { unsafe { ggml_raw::ggml_build_forward_expand(&mut self.inner, tensor.ptr.as_ptr()) } } } diff --git a/llama-rs/src/lib.rs b/llama-rs/src/lib.rs index 7811286e..2f68e442 100644 --- a/llama-rs/src/lib.rs +++ b/llama-rs/src/lib.rs @@ -1,7 +1,1156 @@ mod ggml; -mod llama; -pub use llama::{ - GptVocab as Vocab, InferenceParams, LlamaHyperParams as HyperParams, LlamaModel as Model, - LoadProgress, OutputToken, +use std::{ + collections::HashMap, + fmt::Display, + io::{BufRead, Read, Seek, SeekFrom}, + path::{Path, PathBuf}, }; + +use thiserror::Error; + +use partial_sort::PartialSort; +use rand::{distributions::WeightedIndex, prelude::Distribution}; + +#[derive(Debug, Default, PartialEq, Eq, PartialOrd, Ord)] +pub struct Hyperparameters { + n_vocab: i32, + n_ctx: i32, + n_embd: i32, + n_mult: i32, + n_head: i32, + n_layer: i32, + n_rot: i32, + f16_: i32, +} + +struct Layer { + attention_norm: ggml::Tensor, + + wq: ggml::Tensor, + wk: ggml::Tensor, + wv: ggml::Tensor, + wo: ggml::Tensor, + + // normalization + ffn_norm: ggml::Tensor, + + // ff + w1: ggml::Tensor, + w2: ggml::Tensor, + w3: ggml::Tensor, +} + +pub struct Model { + hparams: Hyperparameters, + + tok_embeddings: ggml::Tensor, + + norm: ggml::Tensor, + output: ggml::Tensor, + + layers: Vec, + + memory_k: ggml::Tensor, + memory_v: ggml::Tensor, + + tensors: HashMap, + + // Must be kept alive for the model + _context: ggml::Context, +} + +pub struct InferenceParameters { + pub n_threads: i32, + pub n_predict: usize, + pub n_batch: usize, + pub repeat_last_n: usize, + pub top_k: i32, + pub top_p: f32, + pub repeat_penalty: f32, + pub temp: f32, +} + +impl Default for InferenceParameters { + fn default() -> Self { + Self { + n_threads: 8, + n_predict: 128, + n_batch: 8, + repeat_last_n: 64, + top_k: 40, + top_p: 0.95, + repeat_penalty: 1.30, + temp: 0.80, + } + } +} + +type TokenId = i32; +type Token = String; + +#[derive(Default)] +pub struct Vocabulary { + /// Maps every integer (index) token id to its corresponding string + mapping: Vec, +} + +#[derive(Copy, Clone, Debug, PartialEq, Eq, PartialOrd, Ord)] +pub enum OutputToken<'a> { + Token(&'a str), + EndOfText, +} +impl Display for OutputToken<'_> { + fn fmt(&self, f: &mut 
std::fmt::Formatter<'_>) -> std::fmt::Result { + write!( + f, + "{}", + match self { + OutputToken::Token(t) => *t, + OutputToken::EndOfText => "[end of text]", + } + ) + } +} + +fn llama_n_parts(size: i32) -> i32 { + match size { + 4096 => 1, + 5120 => 2, + 6656 => 3, + 8192 => 8, + _ => unreachable!("Invalid size for N_PARTS"), + } +} + +/// Each variant represents a step within the process of loading the model. +/// These can be used to report progress to the user. +#[derive(Clone, PartialEq, Eq, PartialOrd, Ord, Debug)] +pub enum LoadProgress<'a> { + HyperParamsLoaded(&'a Hyperparameters), + BadToken { + index: usize, + }, + ContextSize { + bytes: usize, + }, + MemorySize { + bytes: usize, + n_mem: usize, + }, + PartLoading { + file: &'a Path, + current_part: usize, + total_parts: usize, + }, + PartTensorLoaded { + file: &'a Path, + current_tensor: usize, + tensor_count: usize, + }, + PartLoaded { + file: &'a Path, + byte_size: usize, + tensor_count: usize, + }, +} + +#[derive(Error, Debug)] +pub enum LoadError { + #[error("could not open file {path:?}")] + OpenFileFailed { + source: std::io::Error, + path: PathBuf, + }, + #[error("unable to read exactly {bytes} bytes")] + ReadExactFailed { + source: std::io::Error, + bytes: usize, + }, + #[error("non-specific I/O error")] + IO(#[from] std::io::Error), + + #[error("could not convert bytes to a UTF-8 string")] + InvalidUtf8(#[from] std::string::FromUtf8Error), + #[error("invalid integer conversion")] + InvalidIntegerConversion(#[from] std::num::TryFromIntError), + + #[error("invalid magic number for {path:?}")] + InvalidMagic { path: PathBuf }, + #[error("invalid value {value} for `f16` in hyperparameters")] + HyperparametersF16Invalid { value: i32 }, + #[error("unknown tensor `{tensor_name}` in {path:?}")] + UnknownTensor { tensor_name: String, path: PathBuf }, + #[error("the tensor `{tensor_name}` has the wrong size in {path:?}")] + TensorWrongSize { tensor_name: String, path: PathBuf }, + #[error("invalid ftype {ftype} in {path:?}")] + InvalidFtype { ftype: i32, path: PathBuf }, +} + +impl Model { + pub fn load( + path: impl AsRef, + n_ctx: i32, + load_progress_callback: impl Fn(LoadProgress), + ) -> Result<(Model, Vocabulary), LoadError> { + use std::fs::File; + use std::io::BufReader; + + let path = path.as_ref(); + + let mut reader = + BufReader::new(File::open(path).map_err(|e| LoadError::OpenFileFailed { + source: e, + path: path.to_owned(), + })?); + + /// Helper function. Reads an int from the buffer and returns it. + fn read_i32(reader: &mut impl BufRead) -> Result { + let mut bytes = [0u8; 4]; + reader + .read_exact(&mut bytes) + .map_err(|e| LoadError::ReadExactFailed { + source: e, + bytes: bytes.len(), + })?; + Ok(i32::from_le_bytes(bytes)) + } + + /// Helper function. Reads a string from the buffer and returns it. + fn read_string(reader: &mut BufReader, len: usize) -> Result { + let mut buf = vec![0; len]; + reader + .read_exact(&mut buf) + .map_err(|e| LoadError::ReadExactFailed { + source: e, + bytes: buf.len(), + })?; + let s = String::from_utf8(buf)?; + Ok(s) + } + + // Verify magic + { + let magic = read_i32(&mut reader)?; + if magic != 0x67676d6c { + return Err(LoadError::InvalidMagic { + path: path.to_owned(), + }); + } + } + + // ================= + // Load hyper params + // ================= + + // NOTE: Field order matters! Data is laid out in the file exactly + // in this order. 
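// Pieced together from the reads in this function, the header layout the
// loader expects is: a magic i32 of 0x67676d6c ("ggml"), then the
// little-endian i32 fields n_vocab, n_embd, n_mult, n_head, n_layer, n_rot
// and f16_ (0 = f32, 1 = f16, 2 = q4_0, 3 = q4_1), followed by n_vocab
// vocabulary entries, each stored as an i32 length and that many UTF-8 bytes.
// n_ctx is the one `Hyperparameters` field that is not read from the file;
// it comes in as an argument to `load`.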
+ let hparams = Hyperparameters { + n_vocab: read_i32(&mut reader)?, + n_ctx, + n_embd: read_i32(&mut reader)?, + n_mult: read_i32(&mut reader)?, + n_head: read_i32(&mut reader)?, + n_layer: read_i32(&mut reader)?, + n_rot: read_i32(&mut reader)?, + f16_: read_i32(&mut reader)?, + }; + + let n_ff = + ((2 * (4 * hparams.n_embd) / 3 + hparams.n_mult - 1) / hparams.n_mult) * hparams.n_mult; + let n_parts = llama_n_parts(hparams.n_embd); + + load_progress_callback(LoadProgress::HyperParamsLoaded(&hparams)); + + // =============== + // Load vocabulary + // =============== + let mut vocab = Vocabulary::default(); + for i in 0..hparams.n_vocab { + let len = read_i32(&mut reader)?; + if let Ok(word) = read_string(&mut reader, len as usize) { + vocab.mapping.push(word); + } else { + load_progress_callback(LoadProgress::BadToken { + index: i.try_into()?, + }); + vocab.mapping.push("�".to_string()); + } + } + + // for the big tensors, we have the option to store the data in 16-bit + // floats or quantized in order to save memory and also to speed up the + // computation + let wtype = match hparams.f16_ { + 0 => ggml::TYPE_F32, + 1 => ggml::TYPE_F16, + 2 => ggml::TYPE_Q4_0, + 3 => ggml::TYPE_Q4_1, + invalid => return Err(LoadError::HyperparametersF16Invalid { value: invalid }), + }; + + let n_embd = hparams.n_embd; + let n_layer = hparams.n_layer; + let n_ctx = hparams.n_ctx; + let n_vocab = hparams.n_vocab; + + let ctx_size = { + // Use 64-bit math to prevent overflow. + let n_embd = n_embd as u64; + let n_layer = n_layer as u64; + let n_ctx = n_ctx as u64; + let n_vocab = n_vocab as u64; + let n_ff = n_ff as u64; + + /// NOTE: The original code relies in promotion rules and automatic + /// cast between int to float. What we do instead is use this macro + /// to convert every term of the multiplication to f64, which should + /// have enough precision bits to hold the final value, then cast to + /// usize. I have observed a discrepancy between the ctx_size found + /// using this code, and the one in llama.cpp. The number for rust + /// ends up being slightly lower, but no "out of memory" errors are + /// reported by ggml. + macro_rules! 
mul { + ($term:expr, $($terms:expr),*) => { + (($term as f64) $(* ($terms as f64))*) as u64 + }; + } + + fn ggml_type_sizef(x: ggml_raw::ggml_type) -> f64 { + (unsafe { ggml_raw::ggml_type_sizef(x) }) as f64 + } + + let mut ctx_size: u64 = 0; + + ctx_size += mul!(n_embd, n_vocab, ggml_type_sizef(wtype)); // tok_embeddings + + ctx_size += mul!(n_embd, ggml_type_sizef(ggml::TYPE_F32)); // norm + + ctx_size += mul!(n_embd, n_vocab, ggml_type_sizef(wtype)); // output + + ctx_size += mul!(n_layer, n_embd, ggml_type_sizef(ggml::TYPE_F32)); // attention_norm + + ctx_size += mul!(n_layer, n_embd, n_embd, ggml_type_sizef(wtype)); // wq + ctx_size += mul!(n_layer, n_embd, n_embd, ggml_type_sizef(wtype)); // wk + ctx_size += mul!(n_layer, n_embd, n_embd, ggml_type_sizef(wtype)); // wv + ctx_size += mul!(n_layer, n_embd, n_embd, ggml_type_sizef(wtype)); // wo + + ctx_size += mul!(n_layer, n_embd, ggml_type_sizef(ggml::TYPE_F32)); // ffn_norm + + ctx_size += mul!(n_layer, n_ff, n_embd, ggml_type_sizef(wtype)); // w1 + ctx_size += mul!(n_layer, n_ff, n_embd, ggml_type_sizef(wtype)); // w2 + ctx_size += mul!(n_layer, n_ff, n_embd, ggml_type_sizef(wtype)); // w3 + + ctx_size += mul!(n_ctx, n_layer, n_embd, ggml_type_sizef(ggml::TYPE_F32)); // memory_k + ctx_size += mul!(n_ctx, n_layer, n_embd, ggml_type_sizef(ggml::TYPE_F32)); // memory_v + + ctx_size += (5 + 10 * n_layer) * 256; // object overhead + + load_progress_callback(LoadProgress::ContextSize { + bytes: ctx_size.try_into()?, + }); + + ctx_size + }; + + // Initialize the context + let context = ggml::Context::init(ctx_size as usize); + + let model = { + let mut tensors = HashMap::new(); + + let tok_embeddings = context.new_tensor_2d(wtype, n_embd, n_vocab); + let norm = context.new_tensor_1d(ggml::TYPE_F32, n_embd); + let output = context.new_tensor_2d(wtype, n_embd, n_vocab); + + tensors.insert("tok_embeddings.weight".to_owned(), tok_embeddings.share()); + tensors.insert("norm.weight".to_owned(), norm.share()); + tensors.insert("output.weight".to_owned(), output.share()); + + let mut layers = Vec::new(); + for i in 0..n_layer { + let layer = Layer { + attention_norm: context.new_tensor_1d(ggml::TYPE_F32, n_embd), + wq: context.new_tensor_2d(wtype, n_embd, n_embd), + wk: context.new_tensor_2d(wtype, n_embd, n_embd), + wv: context.new_tensor_2d(wtype, n_embd, n_embd), + wo: context.new_tensor_2d(wtype, n_embd, n_embd), + ffn_norm: context.new_tensor_1d(ggml::TYPE_F32, n_embd), + w1: context.new_tensor_2d(wtype, n_embd, n_ff), + w2: context.new_tensor_2d(wtype, n_ff, n_embd), + w3: context.new_tensor_2d(wtype, n_embd, n_ff), + }; + + tensors.insert( + format!("layers.{i}.attention_norm.weight"), + layer.attention_norm.share(), + ); + + tensors.insert(format!("layers.{i}.attention.wq.weight"), layer.wq.share()); + tensors.insert(format!("layers.{i}.attention.wk.weight"), layer.wk.share()); + tensors.insert(format!("layers.{i}.attention.wv.weight"), layer.wv.share()); + tensors.insert(format!("layers.{i}.attention.wo.weight"), layer.wo.share()); + + tensors.insert( + format!("layers.{i}.ffn_norm.weight"), + layer.ffn_norm.share(), + ); + + tensors.insert( + format!("layers.{i}.feed_forward.w1.weight"), + layer.w1.share(), + ); + tensors.insert( + format!("layers.{i}.feed_forward.w2.weight"), + layer.w2.share(), + ); + tensors.insert( + format!("layers.{i}.feed_forward.w3.weight"), + layer.w3.share(), + ); + + layers.push(layer); + } + + // key + value memory + let n_mem = n_layer * n_ctx; + let n_elements = n_embd * n_mem; + + let memory_k = 
context.new_tensor_1d(ggml::TYPE_F32, n_elements); + let memory_v = context.new_tensor_1d(ggml::TYPE_F32, n_elements); + + let memory_size = memory_k.nbytes() + memory_v.nbytes(); + + load_progress_callback(LoadProgress::MemorySize { + bytes: memory_size, + n_mem: n_mem.try_into()?, + }); + + Model { + hparams, + tok_embeddings, + norm, + output, + layers, + memory_k, + memory_v, + tensors, + _context: context, + } + }; + + // Close the file, but keep its offset. That way we know how to skip the + // metadata when loading the parts. + let file_offset = reader.stream_position()?; + drop(reader); + + for i in 0..n_parts { + let part_id = i; + + let part_path = if i > 0 { + let mut path = path.to_owned(); + let mut filename = path.components().last().unwrap().as_os_str().to_owned(); + filename.push(&format!(".{i}")); + path.pop(); + path.join(filename) + } else { + path.to_path_buf() + }; + + load_progress_callback(LoadProgress::PartLoading { + file: &part_path, + current_part: (i + 1).try_into()?, + total_parts: n_parts.try_into()?, + }); + + let mut part_reader = BufReader::new(File::open(&part_path)?); + + // Skip metadata + part_reader.seek(SeekFrom::Start(file_offset))?; + + let mut total_size = 0; + let mut n_tensors = 0; + + // Load weights + loop { + // NOTE: Implementation from #![feature(buf_read_has_data_left)] + let is_eof = part_reader.fill_buf().map(|b| b.is_empty())?; + + if is_eof { + break; + } + + let n_dims = read_i32(&mut part_reader)?; + let length = read_i32(&mut part_reader)?; + let ftype = read_i32(&mut part_reader)?; + + let mut nelements = 1; + let mut ne = [1i32, 1i32]; + for i in 0..n_dims { + ne[i as usize] = read_i32(&mut part_reader)?; + nelements *= ne[i as usize]; + } + + let tensor_name = read_string(&mut part_reader, length as usize)?; + + let Some(tensor) = model.tensors.get(&tensor_name) + else { + return Err(LoadError::UnknownTensor { tensor_name, path: part_path }); + }; + + // split_type = 0: split by columns + // split_type = 1: split by rows + // + // split_type = 0: + // regex: + // - tok_embeddings.* + // - layers.*.attention.wo.weight + // - layers.*.feed_forward.w2.weight + + // split_type = 1: + // regex: + // - output.* + // - layers.*.attention.wq.weight + // - layers.*.attention.wk.weight + // - layers.*.attention.wv.weight + // - layers.*.feed_forward.w1.weight + // - layers.*.feed_forward.w3.weight + #[allow(clippy::if_same_then_else)] + let split_type = if tensor_name.contains("tok_embeddings") { + 0 + } else if tensor_name.contains("layers") { + if tensor_name.contains("attention.wo.weight") { + 0 + } else if tensor_name.contains("feed_forward.w2.weight") { + 0 + } else { + 1 + } + } else if tensor_name.contains("output") { + 1 + } else { + 0 + }; + + if n_dims == 1 { + if tensor.nelements() != nelements { + return Err(LoadError::TensorWrongSize { + tensor_name, + path: part_path, + }); + } + } else { + if tensor.nelements() / n_parts != nelements { + return Err(LoadError::TensorWrongSize { + tensor_name, + path: part_path, + }); + } + } + + if n_dims == 1 { + if tensor.get_ne()[0] != ne[0] || tensor.get_ne()[1] != ne[1] { + return Err(LoadError::TensorWrongSize { + tensor_name, + path: part_path, + }); + } + } else { + if split_type == 0 { + if tensor.get_ne()[0] / n_parts != ne[0] || tensor.get_ne()[1] != ne[1] { + return Err(LoadError::TensorWrongSize { + tensor_name, + path: part_path, + }); + } + } else { + if tensor.get_ne()[0] != ne[0] || tensor.get_ne()[1] / n_parts != ne[1] { + return Err(LoadError::TensorWrongSize { + 
tensor_name, + path: part_path, + }); + } + } + } + + fn ggml_type_size(t: ggml::Type) -> usize { + unsafe { ggml_raw::ggml_type_size(t) } + } + + fn ggml_blck_size(t: ggml::Type) -> i32 { + unsafe { ggml_raw::ggml_blck_size(t) } + } + + let bpe = match ftype { + 0 => ggml_type_size(ggml::TYPE_F32), + 1 => ggml_type_size(ggml::TYPE_F16), + 2 => { + assert_eq!(ne[0] % 64, 0); + ggml_type_size(ggml::TYPE_Q4_0) + } + 3 => { + assert_eq!(ne[0] % 64, 0); + ggml_type_size(ggml::TYPE_Q4_1) + } + _ => { + return Err(LoadError::InvalidFtype { + ftype, + path: part_path, + }) + } + }; + + if n_dims == 1 || n_parts == 1 { + if (nelements as usize * bpe) / ggml_blck_size(tensor.get_type()) as usize + != tensor.nbytes() + { + return Err(LoadError::TensorWrongSize { + tensor_name, + path: part_path, + }); + } + + let data = tensor.data(); + + if part_id == 0 { + // SAFETY: yolo, same as original code + let slice = unsafe { + std::slice::from_raw_parts_mut(data as *mut u8, tensor.nbytes()) + }; + part_reader.read_exact(slice)?; + } else { + part_reader.seek(SeekFrom::Current(tensor.nbytes() as i64))?; + } + + total_size += tensor.nbytes(); + } else { + if (nelements as usize * bpe) / ggml_blck_size(tensor.get_type()) as usize + != tensor.nbytes() / n_parts as usize + { + return Err(LoadError::TensorWrongSize { + tensor_name, + path: part_path, + }); + } + + if split_type == 0 { + let np0 = ne[0]; + let row_size = (tensor.get_ne()[0] / ggml_blck_size(tensor.get_type())) + as usize + * ggml_type_size(tensor.get_type()); + + assert_eq!(row_size, tensor.get_nb()[1]); + + for i1 in 0..ne[1] { + let offset_row = i1 as usize * row_size; + let offset = offset_row + + ((part_id * np0) as usize + / ggml_blck_size(tensor.get_type()) as usize) + * ggml_type_size(tensor.get_type()); + // SAFETY: yolo, same as original code + unsafe { + let ptr = tensor.data().add(offset); + let slice = std::slice::from_raw_parts_mut( + ptr as *mut u8, + row_size / n_parts as usize, + ); + part_reader.read_exact(slice)?; + } + } + } else { + let np1 = ne[1]; + let row_size = (tensor.get_ne()[0] / ggml_blck_size(tensor.get_type())) + as usize + * ggml_type_size(tensor.get_type()); + + for i1 in 0..ne[1] { + let offset_row = (i1 + part_id * np1) as usize * row_size; + // SAFETY: yolo, same as original code + unsafe { + let ptr = tensor.data().add(offset_row); + let slice = + std::slice::from_raw_parts_mut(ptr as *mut u8, row_size); + part_reader.read_exact(slice)?; + } + } + } + + total_size += tensor.nbytes() / n_parts as usize + } + + n_tensors += 1; + load_progress_callback(LoadProgress::PartTensorLoaded { + file: &part_path, + current_tensor: n_tensors.try_into()?, + tensor_count: model.tensors.len(), + }); + } + + load_progress_callback(LoadProgress::PartLoaded { + file: &part_path, + byte_size: total_size, + tensor_count: n_tensors.try_into()?, + }); + } + + Ok((model, vocab)) + } + + pub fn inference_with_prompt( + &self, + vocab: &Vocabulary, + params: &InferenceParameters, + prompt: &str, + rng: &mut impl rand::Rng, + callback: impl Fn(OutputToken), + ) { + let embd_inp = self.tokenize(vocab, prompt, true); + let mut logits = Vec::new(); + + // determine the required inference memory per token: + let mut mem_per_token = 0; + let _ = self.evaluate( + params.n_threads, + 0, + &[0, 1, 2, 3], + &mut logits, + &mut mem_per_token, + ); + + let last_n_size = params.repeat_last_n; + let mut last_n_tokens = vec![0 as TokenId; last_n_size]; + + let mut remaining_tokens = usize::min( + params.n_predict, + self.hparams.n_ctx as usize - 
embd_inp.len(), + ); + let mut input_consumed = 0; + + let mut n_past = 0; + let mut embd = Vec::new(); + while remaining_tokens > 0 { + // predict + if embd.len() > 0 { + self.evaluate( + params.n_threads, + n_past, + &embd, + &mut logits, + &mut mem_per_token, + ); + } + + n_past += embd.len() as i32; + embd.clear(); + + if embd_inp.len() <= input_consumed { + // out of input, sample next token + let InferenceParameters { + top_k, + top_p, + repeat_penalty, + temp, + .. + } = params; + + let n_vocab = self.hparams.n_vocab; + + let id = self.sample_top_p_top_k( + vocab, + &logits[logits.len() - n_vocab as usize..], + &last_n_tokens, + *repeat_penalty as f64, + *top_k, + *top_p as f64, + *temp as f64, + rng, + ); + + last_n_tokens.remove(0); + last_n_tokens.push(id); + + // add it to the context + embd.push(id); + + // decrement remaining sampling budget + remaining_tokens -= 1; + } else { + // if here, it means we are still processing the input prompt + while embd_inp.len() > input_consumed { + embd.push(embd_inp[input_consumed]); + last_n_tokens.remove(0); + last_n_tokens.push(embd_inp[input_consumed]); + input_consumed += 1; + if embd.len() > params.n_batch { + break; + } + } + } + + // display text + let mut eot = false; + for &id in &embd { + let output_token = if id == 2 { + eot = true; + OutputToken::EndOfText + } else { + OutputToken::Token(&vocab.mapping[id as usize]) + }; + callback(output_token); + } + + if eot { + break; + } + } + } + + pub fn sample_top_p_top_k( + &self, + vocab: &Vocabulary, + logits: &[f32], + last_n_tokens: &[TokenId], + repeat_penalty: f64, + top_k: i32, + top_p: f64, + temp: f64, + rng: &mut impl rand::Rng, + ) -> TokenId { + let n_logits = vocab.mapping.len(); + let mut logits_id = Vec::<(f64, TokenId)>::with_capacity(n_logits); + + { + let scale = 1.0 / temp; + for (i, &logit) in logits.iter().enumerate() { + // repetition penalty from CTRL paper (https://arxiv.org/abs/1909.05858) + // credit https://github.com/facebookresearch/llama/compare/main...shawwn:llama:main + if last_n_tokens.contains(&(i as TokenId)) { + // if score < 0 then repetition penalty has to be multiplied to reduce the previous token probability + if logits[i] < 0.0 { + logits_id.push((logit as f64 * scale * repeat_penalty, i as TokenId)); + } else { + logits_id.push((logit as f64 * scale / repeat_penalty, i as TokenId)); + } + } else { + logits_id.push((logit as f64 * scale, i as TokenId)); + } + } + } + + // find the top K tokens + { + logits_id.partial_sort(top_k as usize, |a, b| { + // Sort descending + b.0.total_cmp(&a.0) + }); + logits_id.truncate(top_k as usize); + } + + let maxl = logits_id + .iter() + .map(|x| x.0) + .max_by(f64::total_cmp) + .unwrap(); + + // compute probs for the top K tokens + let mut probs: Vec<f64> = logits_id + .iter() + .copied() + .map(|(k, v)| (k - maxl).exp()) + .collect(); + let sum: f64 = probs.iter().copied().sum(); + + // Normalize the probs + for p in probs.iter_mut() { + *p /= sum; + } + + // Top p sampling + if top_p < 1.0 { + let mut cumsum = 0.0; + for i in 0..probs.len() { + cumsum += probs[i]; + if cumsum >= top_p { + probs.truncate(i + 1); + logits_id.truncate(i + 1); + break; + } + } + + cumsum = 1.0 / cumsum; + for p in probs.iter_mut() { + *p *= cumsum; + } + } + + let dist = WeightedIndex::new(&probs).expect("WeightedIndex error"); + let idx = dist.sample(rng); + + logits_id[idx].1 + } + + pub fn evaluate( + &self, + n_threads: i32, + n_past: i32, + embd_inp: &[TokenId], + embd_w: &mut Vec<f32>, + mem_per_token: &mut usize, + ) { + let N =
embd_inp.len(); + + let Hyperparameters { + n_vocab, + n_ctx, + n_embd, + n_mult: _, + n_head, + n_layer, + n_rot, + f16_: _, + } = self.hparams; + + let mut buf_size = 512 * 1024 * 1024; + if *mem_per_token > 0 && *mem_per_token * N > buf_size { + // add 10% to account for ggml object overhead + buf_size = (1.1f64 * *mem_per_token as f64 * N as f64) as usize; + }; + let ctx0 = ggml::Context::init(buf_size); + + let mut gf = ggml::ComputationGraph::new(n_threads); + + let embd = ctx0.new_tensor_1d(ggml::TYPE_I32, N as i32); + unsafe { embd.write_data(bytemuck::cast_slice(embd_inp)) }; + + let mut inpL = ctx0.op_get_rows(&self.tok_embeddings, &embd); + + for il in 0..n_layer as usize { + let inpSA = inpL.share(); + let mut cur: ggml::Tensor; + + // norm + { + cur = ctx0.op_norm(&inpL); + + // cur = attention_norm * cur + cur = ctx0.op_mul(&ctx0.op_repeat(&self.layers[il].attention_norm, &cur), &cur); + } + + // self-attention + { + let Qcur = ctx0.op_mul_mat(&self.layers[il].wq, &cur); + let Kcur = ctx0.op_mul_mat(&self.layers[il].wk, &cur); + let Vcur = ctx0.op_mul_mat(&self.layers[il].wv, &cur); + + // store key and value to memory + if N >= 1 { + let k = ctx0.op_view_1d( + &self.memory_k, + N as i32 * n_embd, + (self.memory_k.element_size() * n_embd as usize) + * (il * n_ctx as usize + n_past as usize), + ); + + let v = ctx0.op_view_1d( + &self.memory_v, + N as i32 * n_embd, + (self.memory_v.element_size() * n_embd as usize) + * (il * n_ctx as usize + n_past as usize), + ); + + gf.build_forward_expand(&ctx0.op_cpy(&Kcur, &k)); + gf.build_forward_expand(&ctx0.op_cpy(&Vcur, &v)); + } + + // Q = Qcur.contiguous().view(n_embd/n_head, n_head, N).permute(0, 2, 1, 3) + let Q = ctx0.op_permute( + &ctx0.op_rope( + &ctx0.op_cpy( + &Qcur, + &ctx0.new_tensor_3d(ggml::TYPE_F32, n_embd / n_head, n_head, N as i32), + ), + n_past, + n_rot, + 0, + ), + 0, + 2, + 1, + 3, + ); + + // K = Kmem.view(n_embd/n_head, n_head, n_past + N).permute(0, 2, 1, 3) + let K = ctx0.op_permute( + &ctx0.op_rope( + &ctx0.op_reshape_3d( + &ctx0.op_view_1d( + &self.memory_k, + (n_past + N as i32) * n_embd, + il * n_ctx as usize + * self.memory_k.element_size() + * n_embd as usize, + ), + n_embd / n_head, + n_head, + n_past + N as i32, + ), + n_past, + n_rot, + 1, + ), + 0, + 2, + 1, + 3, + ); + + // K * Q + let KQ = ctx0.op_mul_mat(&K, &Q); + + // KQ_scaled = KQ / sqrt(n_embd/n_head) + let KQ_scaled = ctx0.op_scale( + &KQ, + &ctx0.new_f32(1.0 / f32::sqrt(n_embd as f32 / n_head as f32)), + ); + + // KQ_masked = mask_past(KQ_scaled) + let KQ_masked = ctx0.op_diag_mask_inf(&KQ_scaled, n_past); + + // KQ = soft_max(KQ_masked) + let KQ_soft_max = ctx0.op_soft_max(&KQ_masked); + + // V_trans = Vmem.view(n_embd/n_head, n_head, n_past + N).permute(1, 2, 0, 3).contiguous() + let V_trans = ctx0.op_permute( + &ctx0.op_reshape_3d( + &ctx0.op_view_1d( + &self.memory_v, + (n_past + N as i32) * n_embd, + il * n_ctx as usize * self.memory_v.element_size() * n_embd as usize, + ), + n_embd / n_head, + n_head, + n_past + N as i32, + ), + 1, + 2, + 0, + 3, + ); + + // KQV = transpose(V) * KQ_soft_max + let KQV = ctx0.op_mul_mat(&V_trans, &KQ_soft_max); + + // KQV_merged = KQV.permute(0, 2, 1, 3) + let KQV_merged = ctx0.op_permute(&KQV, 0, 2, 1, 3); + + // cur = KQV_merged.contiguous().view(n_embd, N) + cur = ctx0.op_cpy( + &KQV_merged, + &ctx0.new_tensor_2d(ggml::TYPE_F32, n_embd, N as i32), + ); + + // projection (no bias) + cur = ctx0.op_mul_mat(&self.layers[il].wo, &cur); + } + + let inpFF = ctx0.op_add(&cur, &inpSA); + + // feed-forward 
network + { + // norm + { + cur = ctx0.op_norm(&inpFF); + + // cur = ffn_norm*cur + cur = ctx0.op_mul(&ctx0.op_repeat(&self.layers[il].ffn_norm, &cur), &cur); + } + + let tmp = ctx0.op_mul_mat(&self.layers[il].w3, &cur); + + cur = ctx0.op_mul_mat(&self.layers[il].w1, &cur); + + // SILU activation + cur = ctx0.op_silu(&cur); + + cur = ctx0.op_mul(&cur, &tmp); + + cur = ctx0.op_mul_mat(&self.layers[il].w2, &cur); + } + + cur = ctx0.op_add(&cur, &inpFF); + + // input for next layer + inpL = cur; + } + + // norm + { + inpL = ctx0.op_norm(&inpL); + + // inpL = norm*inpL + inpL = ctx0.op_mul(&ctx0.op_repeat(&self.norm, &inpL), &inpL); + } + + // lm_head + { + inpL = ctx0.op_mul_mat(&self.output, &inpL); + } + + // logits -> probs + // inpL = ctx0.op_soft_max(&inpL); + + // run the computation + gf.build_forward_expand(&inpL); + ctx0.graph_compute(&mut gf); + + // return result for just the last token + embd_w.resize(n_vocab as usize, 0.0); + // SAFETY: yolo + unsafe { + inpL.read_data( + n_vocab as usize * (N - 1) * std::mem::size_of::<f32>(), + bytemuck::cast_slice_mut(embd_w), + ) + }; + + if *mem_per_token == 0 { + *mem_per_token = ctx0.used_mem() / N; + } + } + + pub fn tokenize(&self, vocab: &Vocabulary, text: &str, bos: bool) -> Vec<TokenId> { + let mut res = Vec::new(); + if bos { + res.push(1 as TokenId); // TODO: replace with vocab.bos + } + + // Find the longest token that matches the text + let mut pos = 0; + loop { + let mut l = 0; + let mut t = 0; + + for (tk_id, tk) in vocab.mapping.iter().enumerate() { + if tk.len() < l { + continue; + } + if tk.len() > text.len() - pos { + continue; + } + if text[pos..].starts_with(tk) { + l = tk.len(); + t = tk_id; + } + } + + if l == 0 { + break; + } + + res.push(t as TokenId); + pos += l; + } + + res + } +} diff --git a/llama-rs/src/llama.rs b/llama-rs/src/llama.rs deleted file mode 100644 index 0da03460..00000000 --- a/llama-rs/src/llama.rs +++ /dev/null @@ -1,1158 +0,0 @@ -use std::{ - collections::HashMap, - fmt::Display, - io::{BufRead, Read, Seek, SeekFrom}, - path::{Path, PathBuf}, -}; - -use thiserror::Error; - -use crate::ggml::{GgmlContext, GgmlTensor, GGML_TYPE_I32}; -use ggml_raw::ggml_type; -use partial_sort::PartialSort; -use rand::{distributions::WeightedIndex, prelude::Distribution}; - -use crate::ggml::{GgmlCGraph, GGML_TYPE_F16, GGML_TYPE_F32, GGML_TYPE_Q4_0, GGML_TYPE_Q4_1}; - -#[derive(Debug, Default, PartialEq, Eq, PartialOrd, Ord)] -pub struct LlamaHyperParams { - n_vocab: i32, - n_ctx: i32, - n_embd: i32, - n_mult: i32, - n_head: i32, - n_layer: i32, - n_rot: i32, - f16_: i32, -} - -struct LlamaLayer { - attention_norm: GgmlTensor, - - wq: GgmlTensor, - wk: GgmlTensor, - wv: GgmlTensor, - wo: GgmlTensor, - - // normalization - ffn_norm: GgmlTensor, - - // ff - w1: GgmlTensor, - w2: GgmlTensor, - w3: GgmlTensor, -} - -pub struct LlamaModel { - hparams: LlamaHyperParams, - - tok_embeddings: GgmlTensor, - - norm: GgmlTensor, - output: GgmlTensor, - - layers: Vec<LlamaLayer>, - - memory_k: GgmlTensor, - memory_v: GgmlTensor, - - tensors: HashMap<String, GgmlTensor>, - - context: GgmlContext, -} - -pub struct InferenceParams { - pub n_threads: i32, - pub n_predict: usize, - pub n_batch: usize, - pub repeat_last_n: usize, - pub top_k: i32, - pub top_p: f32, - pub repeat_penalty: f32, - pub temp: f32, -} - -impl Default for InferenceParams { - fn default() -> Self { - Self { - n_threads: 8, - n_predict: 128, - n_batch: 8, - repeat_last_n: 64, - top_k: 40, - top_p: 0.95, - repeat_penalty: 1.30, - temp: 0.80, - } - } -} - -type TokenId = i32; -type Token = String; -
-#[derive(Default)] -pub struct GptVocab { - /// Maps every integer (index) token id to its corresponding string - mapping: Vec, -} - -#[derive(Copy, Clone, Debug, PartialEq, Eq, PartialOrd, Ord)] -pub enum OutputToken<'a> { - Token(&'a str), - EndOfText, -} -impl Display for OutputToken<'_> { - fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - write!( - f, - "{}", - match self { - OutputToken::Token(t) => *t, - OutputToken::EndOfText => "[end of text]", - } - ) - } -} - -fn llama_n_parts(size: i32) -> i32 { - match size { - 4096 => 1, - 5120 => 2, - 6656 => 3, - 8192 => 8, - _ => unreachable!("Invalid size for N_PARTS"), - } -} - -/// Each variant represents a step within the process of loading the model. -/// These can be used to report progress to the user. -#[derive(Clone, PartialEq, Eq, PartialOrd, Ord, Debug)] -pub enum LoadProgress<'a> { - HyperParamsLoaded(&'a LlamaHyperParams), - BadToken { - index: usize, - }, - ContextSize { - bytes: usize, - }, - MemorySize { - bytes: usize, - n_mem: usize, - }, - PartLoading { - file: &'a Path, - current_part: usize, - total_parts: usize, - }, - PartTensorLoaded { - file: &'a Path, - current_tensor: usize, - tensor_count: usize, - }, - PartLoaded { - file: &'a Path, - byte_size: usize, - tensor_count: usize, - }, -} - -#[derive(Error, Debug)] -pub enum LoadError { - #[error("could not open file {path:?}")] - OpenFileFailed { - source: std::io::Error, - path: PathBuf, - }, - #[error("unable to read exactly {bytes} bytes")] - ReadExactFailed { - source: std::io::Error, - bytes: usize, - }, - #[error("non-specific I/O error")] - IO(#[from] std::io::Error), - - #[error("could not convert bytes to a UTF-8 string")] - InvalidUtf8(#[from] std::string::FromUtf8Error), - #[error("invalid integer conversion")] - InvalidIntegerConversion(#[from] std::num::TryFromIntError), - - #[error("invalid magic number for {path:?}")] - InvalidMagic { path: PathBuf }, - #[error("invalid value {value} for `f16` in hyperparameters")] - HyperparametersF16Invalid { value: i32 }, - #[error("unknown tensor `{tensor_name}` in {path:?}")] - UnknownTensor { tensor_name: String, path: PathBuf }, - #[error("the tensor `{tensor_name}` has the wrong size in {path:?}")] - TensorWrongSize { tensor_name: String, path: PathBuf }, - #[error("invalid ftype {ftype} in {path:?}")] - InvalidFtype { ftype: i32, path: PathBuf }, -} - -impl LlamaModel { - pub fn load( - path: impl AsRef, - n_ctx: i32, - load_progress_callback: impl Fn(LoadProgress), - ) -> Result<(LlamaModel, GptVocab), LoadError> { - use std::fs::File; - use std::io::BufReader; - - let path = path.as_ref(); - - let mut reader = - BufReader::new(File::open(path).map_err(|e| LoadError::OpenFileFailed { - source: e, - path: path.to_owned(), - })?); - - /// Helper function. Reads an int from the buffer and returns it. - fn read_i32(reader: &mut impl BufRead) -> Result { - let mut bytes = [0u8; 4]; - reader - .read_exact(&mut bytes) - .map_err(|e| LoadError::ReadExactFailed { - source: e, - bytes: bytes.len(), - })?; - Ok(i32::from_le_bytes(bytes)) - } - - /// Helper function. Reads a string from the buffer and returns it. 
- fn read_string(reader: &mut BufReader, len: usize) -> Result { - let mut buf = vec![0; len]; - reader - .read_exact(&mut buf) - .map_err(|e| LoadError::ReadExactFailed { - source: e, - bytes: buf.len(), - })?; - let s = String::from_utf8(buf)?; - Ok(s) - } - - // Verify magic - { - let magic = read_i32(&mut reader)?; - if magic != 0x67676d6c { - return Err(LoadError::InvalidMagic { - path: path.to_owned(), - }); - } - } - - // ================= - // Load hyper params - // ================= - - // NOTE: Field order matters! Data is laid out in the file exactly - // in this order. - let hparams = LlamaHyperParams { - n_vocab: read_i32(&mut reader)?, - n_ctx, - n_embd: read_i32(&mut reader)?, - n_mult: read_i32(&mut reader)?, - n_head: read_i32(&mut reader)?, - n_layer: read_i32(&mut reader)?, - n_rot: read_i32(&mut reader)?, - f16_: read_i32(&mut reader)?, - }; - - let n_ff = - ((2 * (4 * hparams.n_embd) / 3 + hparams.n_mult - 1) / hparams.n_mult) * hparams.n_mult; - let n_parts = llama_n_parts(hparams.n_embd); - - load_progress_callback(LoadProgress::HyperParamsLoaded(&hparams)); - - // =============== - // Load vocabulary - // =============== - let mut vocab = GptVocab::default(); - for i in 0..hparams.n_vocab { - let len = read_i32(&mut reader)?; - if let Ok(word) = read_string(&mut reader, len as usize) { - vocab.mapping.push(word); - } else { - load_progress_callback(LoadProgress::BadToken { - index: i.try_into()?, - }); - vocab.mapping.push("�".to_string()); - } - } - - // for the big tensors, we have the option to store the data in 16-bit - // floats or quantized in order to save memory and also to speed up the - // computation - let wtype = match hparams.f16_ { - 0 => GGML_TYPE_F32, - 1 => GGML_TYPE_F16, - 2 => GGML_TYPE_Q4_0, - 3 => GGML_TYPE_Q4_1, - invalid => return Err(LoadError::HyperparametersF16Invalid { value: invalid }), - }; - - let n_embd = hparams.n_embd; - let n_layer = hparams.n_layer; - let n_ctx = hparams.n_ctx; - let n_vocab = hparams.n_vocab; - - let ctx_size = { - // Use 64-bit math to prevent overflow. - let n_embd = n_embd as u64; - let n_layer = n_layer as u64; - let n_ctx = n_ctx as u64; - let n_vocab = n_vocab as u64; - let n_ff = n_ff as u64; - - /// NOTE: The original code relies in promotion rules and automatic - /// cast between int to float. What we do instead is use this macro - /// to convert every term of the multiplication to f64, which should - /// have enough precision bits to hold the final value, then cast to - /// usize. I have observed a discrepancy between the ctx_size found - /// using this code, and the one in llama.cpp. The number for rust - /// ends up being slightly lower, but no "out of memory" errors are - /// reported by ggml. - macro_rules! 
mul { - ($term:expr, $($terms:expr),*) => { - (($term as f64) $(* ($terms as f64))*) as u64 - }; - } - - fn ggml_type_sizef(x: ggml_raw::ggml_type) -> f64 { - (unsafe { ggml_raw::ggml_type_sizef(x) }) as f64 - } - - let mut ctx_size: u64 = 0; - - ctx_size += mul!(n_embd, n_vocab, ggml_type_sizef(wtype)); // tok_embeddings - - ctx_size += mul!(n_embd, ggml_type_sizef(GGML_TYPE_F32)); // norm - - ctx_size += mul!(n_embd, n_vocab, ggml_type_sizef(wtype)); // output - - ctx_size += mul!(n_layer, n_embd, ggml_type_sizef(GGML_TYPE_F32)); // attention_norm - - ctx_size += mul!(n_layer, n_embd, n_embd, ggml_type_sizef(wtype)); // wq - ctx_size += mul!(n_layer, n_embd, n_embd, ggml_type_sizef(wtype)); // wk - ctx_size += mul!(n_layer, n_embd, n_embd, ggml_type_sizef(wtype)); // wv - ctx_size += mul!(n_layer, n_embd, n_embd, ggml_type_sizef(wtype)); // wo - - ctx_size += mul!(n_layer, n_embd, ggml_type_sizef(GGML_TYPE_F32)); // ffn_norm - - ctx_size += mul!(n_layer, n_ff, n_embd, ggml_type_sizef(wtype)); // w1 - ctx_size += mul!(n_layer, n_ff, n_embd, ggml_type_sizef(wtype)); // w2 - ctx_size += mul!(n_layer, n_ff, n_embd, ggml_type_sizef(wtype)); // w3 - - ctx_size += mul!(n_ctx, n_layer, n_embd, ggml_type_sizef(GGML_TYPE_F32)); // memory_k - ctx_size += mul!(n_ctx, n_layer, n_embd, ggml_type_sizef(GGML_TYPE_F32)); // memory_v - - ctx_size += (5 + 10 * n_layer) * 256; // object overhead - - load_progress_callback(LoadProgress::ContextSize { - bytes: ctx_size.try_into()?, - }); - - ctx_size - }; - - // Initialize the context - let context = GgmlContext::init(ctx_size as usize); - - let model = { - let mut tensors = HashMap::new(); - - let tok_embeddings = context.new_tensor_2d(wtype, n_embd, n_vocab); - let norm = context.new_tensor_1d(GGML_TYPE_F32, n_embd); - let output = context.new_tensor_2d(wtype, n_embd, n_vocab); - - tensors.insert("tok_embeddings.weight".to_owned(), tok_embeddings.share()); - tensors.insert("norm.weight".to_owned(), norm.share()); - tensors.insert("output.weight".to_owned(), output.share()); - - let mut layers = Vec::new(); - for i in 0..n_layer { - let layer = LlamaLayer { - attention_norm: context.new_tensor_1d(GGML_TYPE_F32, n_embd), - wq: context.new_tensor_2d(wtype, n_embd, n_embd), - wk: context.new_tensor_2d(wtype, n_embd, n_embd), - wv: context.new_tensor_2d(wtype, n_embd, n_embd), - wo: context.new_tensor_2d(wtype, n_embd, n_embd), - ffn_norm: context.new_tensor_1d(GGML_TYPE_F32, n_embd), - w1: context.new_tensor_2d(wtype, n_embd, n_ff), - w2: context.new_tensor_2d(wtype, n_ff, n_embd), - w3: context.new_tensor_2d(wtype, n_embd, n_ff), - }; - - tensors.insert( - format!("layers.{i}.attention_norm.weight"), - layer.attention_norm.share(), - ); - - tensors.insert(format!("layers.{i}.attention.wq.weight"), layer.wq.share()); - tensors.insert(format!("layers.{i}.attention.wk.weight"), layer.wk.share()); - tensors.insert(format!("layers.{i}.attention.wv.weight"), layer.wv.share()); - tensors.insert(format!("layers.{i}.attention.wo.weight"), layer.wo.share()); - - tensors.insert( - format!("layers.{i}.ffn_norm.weight"), - layer.ffn_norm.share(), - ); - - tensors.insert( - format!("layers.{i}.feed_forward.w1.weight"), - layer.w1.share(), - ); - tensors.insert( - format!("layers.{i}.feed_forward.w2.weight"), - layer.w2.share(), - ); - tensors.insert( - format!("layers.{i}.feed_forward.w3.weight"), - layer.w3.share(), - ); - - layers.push(layer); - } - - // key + value memory - let n_mem = n_layer * n_ctx; - let n_elements = n_embd * n_mem; - - let memory_k = 
context.new_tensor_1d(GGML_TYPE_F32, n_elements); - let memory_v = context.new_tensor_1d(GGML_TYPE_F32, n_elements); - - let memory_size = memory_k.nbytes() + memory_v.nbytes(); - - load_progress_callback(LoadProgress::MemorySize { - bytes: memory_size, - n_mem: n_mem.try_into()?, - }); - - LlamaModel { - hparams, - tok_embeddings, - norm, - output, - layers, - memory_k, - memory_v, - tensors, - context, - } - }; - - // Close the file, but keep its offset. That way we know how to skip the - // metadata when loading the parts. - let file_offset = reader.stream_position()?; - drop(reader); - - for i in 0..n_parts { - let part_id = i; - - let part_path = if i > 0 { - let mut path = path.to_owned(); - let mut filename = path.components().last().unwrap().as_os_str().to_owned(); - filename.push(&format!(".{i}")); - path.pop(); - path.join(filename) - } else { - path.to_path_buf() - }; - - load_progress_callback(LoadProgress::PartLoading { - file: &part_path, - current_part: (i + 1).try_into()?, - total_parts: n_parts.try_into()?, - }); - - let mut part_reader = BufReader::new(File::open(&part_path)?); - - // Skip metadata - part_reader.seek(SeekFrom::Start(file_offset))?; - - let mut total_size = 0; - let mut n_tensors = 0; - - // Load weights - loop { - // NOTE: Implementation from #![feature(buf_read_has_data_left)] - let is_eof = part_reader.fill_buf().map(|b| b.is_empty())?; - - if is_eof { - break; - } - - let n_dims = read_i32(&mut part_reader)?; - let length = read_i32(&mut part_reader)?; - let ftype = read_i32(&mut part_reader)?; - - let mut nelements = 1; - let mut ne = [1i32, 1i32]; - for i in 0..n_dims { - ne[i as usize] = read_i32(&mut part_reader)?; - nelements *= ne[i as usize]; - } - - let tensor_name = read_string(&mut part_reader, length as usize)?; - - let Some(tensor) = model.tensors.get(&tensor_name) - else { - return Err(LoadError::UnknownTensor { tensor_name, path: part_path }); - }; - - // split_type = 0: split by columns - // split_type = 1: split by rows - // - // split_type = 0: - // regex: - // - tok_embeddings.* - // - layers.*.attention.wo.weight - // - layers.*.feed_forward.w2.weight - - // split_type = 1: - // regex: - // - output.* - // - layers.*.attention.wq.weight - // - layers.*.attention.wk.weight - // - layers.*.attention.wv.weight - // - layers.*.feed_forward.w1.weight - // - layers.*.feed_forward.w3.weight - #[allow(clippy::if_same_then_else)] - let split_type = if tensor_name.contains("tok_embeddings") { - 0 - } else if tensor_name.contains("layers") { - if tensor_name.contains("attention.wo.weight") { - 0 - } else if tensor_name.contains("feed_forward.w2.weight") { - 0 - } else { - 1 - } - } else if tensor_name.contains("output") { - 1 - } else { - 0 - }; - - if n_dims == 1 { - if tensor.nelements() != nelements { - return Err(LoadError::TensorWrongSize { - tensor_name, - path: part_path, - }); - } - } else { - if tensor.nelements() / n_parts != nelements { - return Err(LoadError::TensorWrongSize { - tensor_name, - path: part_path, - }); - } - } - - if n_dims == 1 { - if tensor.get_ne()[0] != ne[0] || tensor.get_ne()[1] != ne[1] { - return Err(LoadError::TensorWrongSize { - tensor_name, - path: part_path, - }); - } - } else { - if split_type == 0 { - if tensor.get_ne()[0] / n_parts != ne[0] || tensor.get_ne()[1] != ne[1] { - return Err(LoadError::TensorWrongSize { - tensor_name, - path: part_path, - }); - } - } else { - if tensor.get_ne()[0] != ne[0] || tensor.get_ne()[1] / n_parts != ne[1] { - return Err(LoadError::TensorWrongSize { - tensor_name, - 
path: part_path, - }); - } - } - } - - fn ggml_type_size(t: ggml_type) -> usize { - unsafe { ggml_raw::ggml_type_size(t) } - } - - fn ggml_blck_size(t: ggml_type) -> i32 { - unsafe { ggml_raw::ggml_blck_size(t) } - } - - let bpe = match ftype { - 0 => ggml_type_size(GGML_TYPE_F32), - 1 => ggml_type_size(GGML_TYPE_F16), - 2 => { - assert_eq!(ne[0] % 64, 0); - ggml_type_size(GGML_TYPE_Q4_0) - } - 3 => { - assert_eq!(ne[0] % 64, 0); - ggml_type_size(GGML_TYPE_Q4_1) - } - _ => { - return Err(LoadError::InvalidFtype { - ftype, - path: part_path, - }) - } - }; - - if n_dims == 1 || n_parts == 1 { - if (nelements as usize * bpe) / ggml_blck_size(tensor.get_type()) as usize - != tensor.nbytes() - { - return Err(LoadError::TensorWrongSize { - tensor_name, - path: part_path, - }); - } - - let data = tensor.data(); - - if part_id == 0 { - // SAFETY: yolo, same as original code - let slice = unsafe { - std::slice::from_raw_parts_mut(data as *mut u8, tensor.nbytes()) - }; - part_reader.read_exact(slice)?; - } else { - part_reader.seek(SeekFrom::Current(tensor.nbytes() as i64))?; - } - - total_size += tensor.nbytes(); - } else { - if (nelements as usize * bpe) / ggml_blck_size(tensor.get_type()) as usize - != tensor.nbytes() / n_parts as usize - { - return Err(LoadError::TensorWrongSize { - tensor_name, - path: part_path, - }); - } - - if split_type == 0 { - let np0 = ne[0]; - let row_size = (tensor.get_ne()[0] / ggml_blck_size(tensor.get_type())) - as usize - * ggml_type_size(tensor.get_type()); - - assert_eq!(row_size, tensor.get_nb()[1]); - - for i1 in 0..ne[1] { - let offset_row = i1 as usize * row_size; - let offset = offset_row - + ((part_id * np0) as usize - / ggml_blck_size(tensor.get_type()) as usize) - * ggml_type_size(tensor.get_type()); - // SAFETY: yolo, same as original code - unsafe { - let ptr = tensor.data().add(offset); - let slice = std::slice::from_raw_parts_mut( - ptr as *mut u8, - row_size / n_parts as usize, - ); - part_reader.read_exact(slice)?; - } - } - } else { - let np1 = ne[1]; - let row_size = (tensor.get_ne()[0] / ggml_blck_size(tensor.get_type())) - as usize - * ggml_type_size(tensor.get_type()); - - for i1 in 0..ne[1] { - let offset_row = (i1 + part_id * np1) as usize * row_size; - // SAFETY: yolo, same as original code - unsafe { - let ptr = tensor.data().add(offset_row); - let slice = - std::slice::from_raw_parts_mut(ptr as *mut u8, row_size); - part_reader.read_exact(slice)?; - } - } - } - - total_size += tensor.nbytes() / n_parts as usize - } - - n_tensors += 1; - load_progress_callback(LoadProgress::PartTensorLoaded { - file: &part_path, - current_tensor: n_tensors.try_into()?, - tensor_count: model.tensors.len(), - }); - } - - load_progress_callback(LoadProgress::PartLoaded { - file: &part_path, - byte_size: total_size, - tensor_count: n_tensors.try_into()?, - }); - } - - Ok((model, vocab)) - } - - pub fn inference_with_prompt( - &self, - vocab: &GptVocab, - params: &InferenceParams, - prompt: &str, - rng: &mut impl rand::Rng, - callback: impl Fn(OutputToken), - ) { - let embd_inp = self.tokenize(vocab, prompt, true); - let mut logits = Vec::new(); - - // determine the required inference memory per token: - let mut mem_per_token = 0; - let _ = self.llama_eval( - params.n_threads, - 0, - &[0, 1, 2, 3], - &mut logits, - &mut mem_per_token, - ); - - let last_n_size = params.repeat_last_n; - let mut last_n_tokens = vec![0 as TokenId; last_n_size]; - - let mut remaining_tokens = usize::min( - params.n_predict, - self.hparams.n_ctx as usize - embd_inp.len(), - ); - let 
mut input_consumed = 0; - - let mut n_past = 0; - let mut embd = Vec::new(); - while remaining_tokens > 0 { - // predict - if embd.len() > 0 { - self.llama_eval( - params.n_threads, - n_past, - &embd, - &mut logits, - &mut mem_per_token, - ); - } - - n_past += embd.len() as i32; - embd.clear(); - - if embd_inp.len() <= input_consumed { - // out of input, sample next token - let InferenceParams { - top_k, - top_p, - repeat_penalty, - temp, - .. - } = params; - - let n_vocab = self.hparams.n_vocab; - - let id = self.sample_top_p_top_k( - vocab, - &logits[logits.len() - n_vocab as usize..], - &last_n_tokens, - *repeat_penalty as f64, - *top_k, - *top_p as f64, - *temp as f64, - rng, - ); - - last_n_tokens.remove(0); - last_n_tokens.push(id); - - // add it to the context - embd.push(id); - - // decrement remaining sampling budget - remaining_tokens -= 1; - } else { - // if here, it means we are still processing the input prompt - while embd_inp.len() > input_consumed { - embd.push(embd_inp[input_consumed]); - last_n_tokens.remove(0); - last_n_tokens.push(embd_inp[input_consumed]); - input_consumed += 1; - if embd.len() > params.n_batch { - break; - } - } - } - - // display text - let mut eot = false; - for &id in &embd { - let output_token = if id == 2 { - eot = true; - OutputToken::EndOfText - } else { - OutputToken::Token(&vocab.mapping[id as usize]) - }; - callback(output_token); - } - - if eot { - break; - } - } - } - - pub fn sample_top_p_top_k( - &self, - vocab: &GptVocab, - logits: &[f32], - last_n_tokens: &[TokenId], - repeat_penalty: f64, - top_k: i32, - top_p: f64, - temp: f64, - rng: &mut impl rand::Rng, - ) -> TokenId { - let n_logits = vocab.mapping.len(); - let mut logits_id = Vec::<(f64, TokenId)>::with_capacity(n_logits); - - { - let scale = 1.0 / temp; - for (i, &logit) in logits.iter().enumerate() { - // repetition penalty from CTRL paper (https://arxiv.org/abs/1909.05858) - // credit https://github.com/facebookresearch/llama/compare/main...shawwn:llama:main - if last_n_tokens.contains(&(i as TokenId)) { - // if score < 0 then repetition penalty has to multiplied to reduce the previous token probability - if logits[i] < 0.0 { - logits_id.push((logit as f64 * scale * repeat_penalty, i as TokenId)); - } else { - logits_id.push((logit as f64 * scale / repeat_penalty, i as TokenId)); - } - } else { - logits_id.push((logit as f64 * scale, i as TokenId)); - } - } - } - - // find the top K tokens - { - logits_id.partial_sort(top_k as usize, |a, b| { - // Sort descending - b.0.total_cmp(&a.0) - }); - logits_id.truncate(top_k as usize); - } - - let maxl = logits_id - .iter() - .map(|x| x.0) - .max_by(f64::total_cmp) - .unwrap(); - - // compute probs for the top K tokens - let mut probs: Vec = logits_id - .iter() - .copied() - .map(|(k, v)| (k - maxl).exp()) - .collect(); - let sum: f64 = probs.iter().copied().sum(); - - // Normalize the probs - for p in probs.iter_mut() { - *p /= sum; - } - - // Top p sampling - if top_p < 1.0 { - let mut cumsum = 0.0; - for i in 0..probs.len() { - cumsum += probs[i]; - if cumsum >= top_p { - probs.truncate(i + 1); - logits_id.truncate(i + 1); - break; - } - } - - cumsum = 1.0 / cumsum; - for p in probs.iter_mut() { - *p *= cumsum; - } - } - - let dist = WeightedIndex::new(&probs).expect("WeightedIndex error"); - let idx = dist.sample(rng); - - logits_id[idx].1 - } - - #[allow(non_snake_case)] - pub fn llama_eval( - &self, - n_threads: i32, - n_past: i32, - embd_inp: &[TokenId], - embd_w: &mut Vec, - mem_per_token: &mut usize, - ) { - let N = 
embd_inp.len(); - - let LlamaHyperParams { - n_vocab, - n_ctx, - n_embd, - n_mult: _, - n_head, - n_layer, - n_rot, - f16_: _, - } = self.hparams; - - let mut buf_size = 512 * 1024 * 1024; - if *mem_per_token > 0 && *mem_per_token * N > buf_size { - // add 10% to account for ggml object overhead - buf_size = (1.1f64 * *mem_per_token as f64 * N as f64) as usize; - }; - let ctx0 = GgmlContext::init(buf_size); - - let mut gf = GgmlCGraph::new(n_threads); - - let embd = ctx0.new_tensor_1d(GGML_TYPE_I32, N as i32); - unsafe { embd.write_data(bytemuck::cast_slice(embd_inp)) }; - - let mut inpL = ctx0.op_get_rows(&self.tok_embeddings, &embd); - - for il in 0..n_layer as usize { - let inpSA = inpL.share(); - let mut cur: GgmlTensor; - - // norm - { - cur = ctx0.op_norm(&inpL); - - // cur = attention_norm * cur - cur = ctx0.op_mul(&ctx0.op_repeat(&self.layers[il].attention_norm, &cur), &cur); - } - - // self-attention - { - let Qcur = ctx0.op_mul_mat(&self.layers[il].wq, &cur); - let Kcur = ctx0.op_mul_mat(&self.layers[il].wk, &cur); - let Vcur = ctx0.op_mul_mat(&self.layers[il].wv, &cur); - - // store key and value to memory - if N >= 1 { - let k = ctx0.op_view_1d( - &self.memory_k, - N as i32 * n_embd, - (self.memory_k.element_size() * n_embd as usize) - * (il * n_ctx as usize + n_past as usize), - ); - - let v = ctx0.op_view_1d( - &self.memory_v, - N as i32 * n_embd, - (self.memory_v.element_size() * n_embd as usize) - * (il * n_ctx as usize + n_past as usize), - ); - - gf.build_forward_expand(&ctx0.op_cpy(&Kcur, &k)); - gf.build_forward_expand(&ctx0.op_cpy(&Vcur, &v)); - } - - // Q = Qcur.contiguous().view(n_embd/n_head, n_head, N).permute(0, 2, 1, 3) - let Q = ctx0.op_permute( - &ctx0.op_rope( - &ctx0.op_cpy( - &Qcur, - &ctx0.new_tensor_3d(GGML_TYPE_F32, n_embd / n_head, n_head, N as i32), - ), - n_past, - n_rot, - 0, - ), - 0, - 2, - 1, - 3, - ); - - // K = Kmem.view(n_embd/n_head, n_head, n_past + N).permute(0, 2, 1, 3) - let K = ctx0.op_permute( - &ctx0.op_rope( - &ctx0.op_reshape_3d( - &ctx0.op_view_1d( - &self.memory_k, - (n_past + N as i32) * n_embd, - il * n_ctx as usize - * self.memory_k.element_size() - * n_embd as usize, - ), - n_embd / n_head, - n_head, - n_past + N as i32, - ), - n_past, - n_rot, - 1, - ), - 0, - 2, - 1, - 3, - ); - - // K * Q - let KQ = ctx0.op_mul_mat(&K, &Q); - - // KQ_scaled = KQ / sqrt(n_embd/n_head) - let KQ_scaled = ctx0.op_scale( - &KQ, - &ctx0.new_f32(1.0 / f32::sqrt(n_embd as f32 / n_head as f32)), - ); - - // KQ_masked = mask_past(KQ_scaled) - let KQ_masked = ctx0.op_diag_mask_inf(&KQ_scaled, n_past); - - // KQ = soft_max(KQ_masked) - let KQ_soft_max = ctx0.op_soft_max(&KQ_masked); - - // V_trans = Vmem.view(n_embd/n_head, n_head, n_past + N).permute(1, 2, 0, 3).contiguous() - let V_trans = ctx0.op_permute( - &ctx0.op_reshape_3d( - &ctx0.op_view_1d( - &self.memory_v, - (n_past + N as i32) * n_embd, - il * n_ctx as usize * self.memory_v.element_size() * n_embd as usize, - ), - n_embd / n_head, - n_head, - n_past + N as i32, - ), - 1, - 2, - 0, - 3, - ); - - // KQV = transpose(V) * KQ_soft_max - let KQV = ctx0.op_mul_mat(&V_trans, &KQ_soft_max); - - // KQV_merged = KQV.permute(0, 2, 1, 3) - let KQV_merged = ctx0.op_permute(&KQV, 0, 2, 1, 3); - - // cur = KQV_merged.contiguous().view(n_embd, N) - cur = ctx0.op_cpy( - &KQV_merged, - &ctx0.new_tensor_2d(GGML_TYPE_F32, n_embd, N as i32), - ); - - // projection (no bias) - cur = ctx0.op_mul_mat(&self.layers[il].wo, &cur); - } - - let inpFF = ctx0.op_add(&cur, &inpSA); - - // feed-forward network - { - // norm - 
{ - cur = ctx0.op_norm(&inpFF); - - // cur = ffn_norm*cur - cur = ctx0.op_mul(&ctx0.op_repeat(&self.layers[il].ffn_norm, &cur), &cur); - } - - let tmp = ctx0.op_mul_mat(&self.layers[il].w3, &cur); - - cur = ctx0.op_mul_mat(&self.layers[il].w1, &cur); - - // SILU activation - cur = ctx0.op_silu(&cur); - - cur = ctx0.op_mul(&cur, &tmp); - - cur = ctx0.op_mul_mat(&self.layers[il].w2, &cur); - } - - cur = ctx0.op_add(&cur, &inpFF); - - // input for next layer - inpL = cur; - } - - // norm - { - inpL = ctx0.op_norm(&inpL); - - // inpL = norm*inpL - inpL = ctx0.op_mul(&ctx0.op_repeat(&self.norm, &inpL), &inpL); - } - - // lm_head - { - inpL = ctx0.op_mul_mat(&self.output, &inpL); - } - - // logits -> probs - // inpL = ctx0.op_soft_max(&inpL); - - // run the computation - gf.build_forward_expand(&inpL); - ctx0.graph_compute(&mut gf); - - // return result for just the last token - embd_w.resize(n_vocab as usize, 0.0); - // SAFETY: yolo - unsafe { - inpL.read_data( - n_vocab as usize * (N - 1) * std::mem::size_of::(), - bytemuck::cast_slice_mut(embd_w), - ) - }; - - if *mem_per_token == 0 { - *mem_per_token = ctx0.used_mem() / N; - } - } - - pub fn tokenize(&self, vocab: &GptVocab, text: &str, bos: bool) -> Vec { - let mut res = Vec::new(); - if bos { - res.push(1 as TokenId); // TODO: replace with vocab.bos - } - - // Find the longest token that matches the text - let mut pos = 0; - loop { - let mut l = 0; - let mut t = 0; - - for (tk_id, tk) in vocab.mapping.iter().enumerate() { - if tk.len() < l { - continue; - } - if tk.len() > text.len() - pos { - continue; - } - if text[pos..].starts_with(tk) { - l = tk.len(); - t = tk_id; - } - } - - if l == 0 { - break; - } - - res.push(t as TokenId); - pos += l; - } - - res - } -}
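The context-size estimate in the loader promotes every factor to f64 through the mul! macro before the final cast, because ggml_type_sizef reports fractional bytes per element for the quantized types and 32-bit integer arithmetic can overflow once the per-tensor sizes are summed. A minimal standalone sketch of that arithmetic follows; the 7B-style shapes and the 2-bytes-per-element F16 stub are assumptions for illustration, not values taken from this patch.

macro_rules! mul {
    ($term:expr, $($terms:expr),*) => {
        (($term as f64) $(* ($terms as f64))*) as u64
    };
}

// Stand-in for ggml_raw::ggml_type_sizef: bytes per element (fractional for
// the quantized types); F16 is assumed here purely for the example.
fn type_sizef_f16() -> f64 {
    2.0
}

fn main() {
    // Hypothetical 7B-style shapes; the real values come from the file header.
    let (n_embd, n_vocab, n_layer): (u64, u64, u64) = (4096, 32_000, 32);

    let mut ctx_size: u64 = 0;
    ctx_size += mul!(n_embd, n_vocab, type_sizef_f16()); // tok_embeddings
    ctx_size += mul!(n_layer, n_embd, n_embd, type_sizef_f16()); // wq across all layers
    ctx_size += (5 + 10 * n_layer) * 256; // ggml object overhead, as in the loader

    println!("estimated ctx_size: {ctx_size} bytes");
}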
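For reference, the sampling path (temperature scaling, the CTRL-style repetition penalty, top-k truncation, then the top-p cut-off) can be sketched outside the Model impl without the partial_sort dependency by using a full sort. The function below mirrors that logic as a standalone example; its name, the toy logits, and the parameter values in main are illustrative only.

use rand::{distributions::WeightedIndex, prelude::Distribution};

fn sample(
    logits: &[f32],
    last_tokens: &[usize],
    temp: f64,
    repeat_penalty: f64,
    top_k: usize,
    top_p: f64,
    rng: &mut impl rand::Rng,
) -> usize {
    let scale = 1.0 / temp;
    let mut scored: Vec<(f64, usize)> = logits
        .iter()
        .enumerate()
        .map(|(i, &l)| {
            let mut s = l as f64 * scale;
            if last_tokens.contains(&i) {
                // Recently-seen tokens are penalised; negative logits are
                // multiplied (pushed further down), positive ones divided.
                s = if l < 0.0 { s * repeat_penalty } else { s / repeat_penalty };
            }
            (s, i)
        })
        .collect();

    // Top-k: full descending sort instead of a partial sort, then truncate.
    scored.sort_by(|a, b| b.0.total_cmp(&a.0));
    scored.truncate(top_k);

    // Softmax over the surviving scores.
    let max = scored[0].0;
    let mut probs: Vec<f64> = scored.iter().map(|(s, _)| (s - max).exp()).collect();
    let sum: f64 = probs.iter().sum();
    probs.iter_mut().for_each(|p| *p /= sum);

    // Top-p (nucleus) cut-off.
    let mut cumsum = 0.0;
    let mut keep = probs.len();
    for (i, p) in probs.iter().enumerate() {
        cumsum += p;
        if cumsum >= top_p {
            keep = i + 1;
            break;
        }
    }
    probs.truncate(keep);
    scored.truncate(keep);

    let dist = WeightedIndex::new(&probs).expect("non-empty, positive weights");
    scored[dist.sample(rng)].1
}

fn main() {
    let logits = [0.1_f32, 2.0, -1.5, 0.7];
    let mut rng = rand::thread_rng();
    let id = sample(&logits, &[1], 0.80, 1.30, 3, 0.95, &mut rng);
    println!("sampled token id: {id}");
}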
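The key/value cache indexing in evaluate() treats memory_k and memory_v as flat buffers of n_layer * n_ctx slots of n_embd elements each, with layer il appending its new rows at slot il * n_ctx + n_past. A pure-arithmetic sketch of that offset calculation, with made-up shapes:

// memory_k / memory_v layout: n_layer * n_ctx slots, each holding n_embd
// elements; layer `il` writes its new rows starting at slot il * n_ctx + n_past.
fn kv_offset_bytes(il: usize, n_ctx: usize, n_past: usize, n_embd: usize, elem_size: usize) -> usize {
    (elem_size * n_embd) * (il * n_ctx + n_past)
}

fn main() {
    // Made-up shapes for illustration (f32 elements, 512-token context).
    let (n_layer, n_ctx, n_embd, elem) = (32usize, 512usize, 4096usize, 4usize);

    // Layer 3 appending after 10 already-cached positions:
    let off = kv_offset_bytes(3, n_ctx, 10, n_embd, elem);
    assert_eq!(off, 25_329_664);

    // The whole cache spans n_layer * n_ctx slots.
    let total = n_layer * n_ctx * n_embd * elem;
    assert!(off < total);
    println!("offset {off} of {total} bytes");
}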
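tokenize() is a greedy longest-match scan over the vocabulary. A self-contained toy version of the same idea, using a hypothetical four-entry vocabulary and usize ids in place of TokenId:

fn tokenize(vocab: &[&str], text: &str) -> Vec<usize> {
    let mut res = Vec::new();
    let mut pos = 0;
    while pos < text.len() {
        // Pick the longest vocabulary entry that matches at `pos`.
        let best = vocab
            .iter()
            .enumerate()
            .filter(|(_, tk)| text[pos..].starts_with(**tk))
            .max_by_key(|(_, tk)| tk.len());
        match best {
            Some((id, tk)) => {
                res.push(id);
                pos += tk.len();
            }
            // No entry matches; the loader's version also just stops here.
            None => break,
        }
    }
    res
}

fn main() {
    let vocab = ["he", "hello", "llo", " world"];
    // Longest match picks "hello" (id 1) over "he" (id 0), then " world" (id 3).
    assert_eq!(tokenize(&vocab, "hello world"), vec![1, 3]);
}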
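Finally, a sketch of how a downstream caller might drive the API introduced here. It assumes the crate root re-exports Model, InferenceParameters, LoadProgress and OutputToken, that Model::load keeps the (path, n_ctx, progress-callback) shape used throughout this file, and that InferenceParameters retains a Default impl; the model path and prompt are placeholders.

use llama_rs::{InferenceParameters, LoadProgress, Model, OutputToken};

fn main() {
    // Placeholder path and context length.
    let (model, vocab) = Model::load("models/7B/ggml-model-f16.bin", 512, |progress| {
        if let LoadProgress::PartLoading { current_part, total_parts, .. } = progress {
            eprintln!("loading part {current_part}/{total_parts}");
        }
    })
    .expect("could not load model");

    let params = InferenceParameters::default();
    let mut rng = rand::thread_rng();

    model.inference_with_prompt(&vocab, &params, "The capital of France is", &mut rng, |t| {
        match t {
            OutputToken::Token(s) => print!("{s}"),
            OutputToken::EndOfText => println!("[end of text]"),
        }
    });
}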