diff --git a/llama-cli/src/cli_args.rs b/llama-cli/src/cli_args.rs
deleted file mode 100644
index 617245d0..00000000
--- a/llama-cli/src/cli_args.rs
+++ /dev/null
@@ -1,84 +0,0 @@
-use clap::Parser;
-use once_cell::sync::Lazy;
-
-#[derive(Parser, Debug)]
-#[command(author, version, about, long_about = None)]
-pub struct Args {
-    /// Where to load the model path from
-    #[arg(long, short = 'm')]
-    pub model_path: String,
-
-    /// The prompt to feed the generator
-    #[arg(long, short = 'p', default_value = None)]
-    pub prompt: Option<String>,
-
-    /// A file to read the prompt from. Takes precedence over `prompt` if set.
-    #[arg(long, short = 'f', default_value = None)]
-    pub prompt_file: Option<String>,
-
-    /// Run in REPL mode.
-    #[arg(long, short = 'R', default_value_t = false)]
-    pub repl: bool,
-
-    /// Sets the number of threads to use
-    #[arg(long, short = 't', default_value_t = num_cpus::get_physical())]
-    pub num_threads: usize,
-
-    /// Sets how many tokens to predict
-    #[arg(long, short = 'n')]
-    pub num_predict: Option<usize>,
-
-    /// Sets the size of the context (in tokens). Allows feeding longer prompts.
-    /// Note that this affects memory. TODO: Unsure how large the limit is.
-    #[arg(long, default_value_t = 512)]
-    pub num_ctx_tokens: usize,
-
-    /// How many tokens from the prompt at a time to feed the network. Does not
-    /// affect generation.
-    #[arg(long, default_value_t = 8)]
-    pub batch_size: usize,
-
-    /// Size of the 'last N' buffer that is used for the `repeat_penalty`
-    /// option. In tokens.
-    #[arg(long, default_value_t = 64)]
-    pub repeat_last_n: usize,
-
-    /// The penalty for repeating tokens. Higher values make the generation less
-    /// likely to get into a loop, but may harm results when repetitive outputs
-    /// are desired.
-    #[arg(long, default_value_t = 1.30)]
-    pub repeat_penalty: f32,
-
-    /// Temperature
-    #[arg(long, default_value_t = 0.80)]
-    pub temp: f32,
-
-    /// Top-K: The top K words by score are kept during sampling.
-    #[arg(long, default_value_t = 40)]
-    pub top_k: usize,
-
-    /// Top-p: The cummulative probability after which no more words are kept
-    /// for sampling.
-    #[arg(long, default_value_t = 0.95)]
-    pub top_p: f32,
-
-    /// Stores a cached prompt at the given path. The same prompt can then be
-    /// loaded from disk using --restore-prompt
-    #[arg(long, default_value = None)]
-    pub cache_prompt: Option<String>,
-
-    /// Restores a cached prompt at the given path, previously using
-    /// --cache-prompt
-    #[arg(long, default_value = None)]
-    pub restore_prompt: Option<String>,
-
-    /// Specifies the seed to use during sampling. Note that, depending on
-    /// hardware, the same seed may lead to different results on two separate
-    /// machines.
-    #[arg(long, default_value = None)]
-    pub seed: Option<u64>,
-}
-
-/// CLI args are stored in a lazy static variable so they're accessible from
-/// everywhere. Arguments are parsed on first access.
-pub static CLI_ARGS: Lazy<Args> = Lazy::new(Args::parse);
diff --git a/llama-cli/src/commands/generate.rs b/llama-cli/src/commands/generate.rs
new file mode 100644
index 00000000..d6f9bbd8
--- /dev/null
+++ b/llama-cli/src/commands/generate.rs
@@ -0,0 +1,173 @@
+use clap::Args;
+use llama_rs::{InferenceError, InferenceParameters, Model, Vocabulary};
+use rand::rngs::StdRng;
+use rand::SeedableRng;
+use std::convert::Infallible;
+use std::io::Write;
+
+#[derive(Debug, Args)]
+pub struct Generate {
+    /// Where to load the model path from
+    #[arg(long, short = 'm')]
+    model_path: String,
+
+    /// Sets the number of threads to use
+    #[arg(long, short = 't', default_value_t = num_cpus::get_physical())]
+    num_threads: usize,
+
+    /// Sets how many tokens to predict
+    #[arg(long, short = 'n')]
+    num_predict: Option<usize>,
+
+    /// Sets the size of the context (in tokens). Allows feeding longer prompts.
+    /// Note that this affects memory. TODO: Unsure how large the limit is.
+    #[arg(long, default_value_t = 512)]
+    num_ctx_tokens: usize,
+
+    /// How many tokens from the prompt at a time to feed the network. Does not
+    /// affect generation.
+    #[arg(long, default_value_t = 8)]
+    batch_size: usize,
+
+    /// Size of the 'last N' buffer that is used for the `repeat_penalty`
+    /// option. In tokens.
+    #[arg(long, default_value_t = 64)]
+    repeat_last_n: usize,
+
+    /// The penalty for repeating tokens. Higher values make the generation less
+    /// likely to get into a loop, but may harm results when repetitive outputs
+    /// are desired.
+    #[arg(long, default_value_t = 1.30)]
+    repeat_penalty: f32,
+
+    /// Temperature
+    #[arg(long, default_value_t = 0.80)]
+    temp: f32,
+
+    /// Top-K: The top K words by score are kept during sampling.
+    #[arg(long, default_value_t = 40)]
+    top_k: usize,
+
+    /// Top-p: The cumulative probability after which no more words are kept
+    /// for sampling.
+    #[arg(long, default_value_t = 0.95)]
+    top_p: f32,
+
+    /// Specifies the seed to use during sampling. Note that, depending on
+    /// hardware, the same seed may lead to different results on two separate
+    /// machines.
+    #[arg(long, default_value = None)]
+    seed: Option<u64>,
+}
+
+impl Generate {
+    /// Builds the RNG, seeded from `--seed` if given, otherwise from entropy.
+    fn create_seed(&self) -> StdRng {
+        match self.seed {
+            Some(seed) => StdRng::seed_from_u64(seed),
+            None => StdRng::from_entropy(),
+        }
+    }
+
+    fn load_model(&self) -> Result<(Model, Vocabulary), String> {
+        let (model, vocab) =
+            llama_rs::Model::load(&self.model_path, self.num_ctx_tokens as i32, |progress| {
+                use llama_rs::LoadProgress;
+
+                match progress {
+                    LoadProgress::HyperparametersLoaded(hparams) => {
+                        log::debug!("Loaded HyperParams {hparams:#?}")
+                    }
+                    LoadProgress::BadToken { index } => {
+                        log::info!("Warning: Bad token in vocab at index {index}")
+                    }
+                    LoadProgress::ContextSize { bytes } => log::info!(
+                        "ggml ctx size = {:.2} MB\n",
+                        bytes as f64 / (1024.0 * 1024.0)
+                    ),
+                    LoadProgress::MemorySize { bytes, n_mem } => log::info!(
+                        "Memory size: {} MB {}",
+                        bytes as f32 / 1024.0 / 1024.0,
+                        n_mem
+                    ),
+                    LoadProgress::PartLoading {
+                        file,
+                        current_part,
+                        total_parts,
+                    } => log::info!(
+                        "Loading model part {}/{} from '{}'\n",
+                        current_part,
+                        total_parts,
+                        file.to_string_lossy(),
+                    ),
+                    LoadProgress::PartTensorLoaded {
+                        current_tensor,
+                        tensor_count,
+                        ..
+                    } => {
+                        if current_tensor % 8 == 0 {
+                            log::info!("Loaded tensor {current_tensor}/{tensor_count}");
+                        }
+                    }
+                    LoadProgress::PartLoaded {
+                        file,
+                        byte_size,
+                        tensor_count,
+                    } => {
+                        log::info!("Loading of '{}' complete", file.to_string_lossy());
+                        log::info!(
+                            "Model size = {:.2} MB / num tensors = {}",
+                            byte_size as f64 / 1024.0 / 1024.0,
+                            tensor_count
+                        );
+                    }
+                }
+            })
+            .expect("Could not load model");
+
+        log::info!("Model fully loaded!");
+        Ok((model, vocab))
+    }
+
+    pub fn run(&self, prompt: &str) -> Result<(), String> {
+        let inference_params = InferenceParameters {
+            n_threads: self.num_threads as i32,
+            n_batch: self.batch_size,
+            top_k: self.top_k,
+            top_p: self.top_p,
+            repeat_penalty: self.repeat_penalty,
+            temp: self.temp,
+        };
+
+        let (mut model, vocab) = self.load_model()?;
+        let mut rng = self.create_seed();
+        let mut session = model.start_session(self.repeat_last_n);
+        let res = session.inference_with_prompt::<Infallible>(
+            &model,
+            &vocab,
+            &inference_params,
+            prompt,
+            self.num_predict,
+            &mut rng,
+            |t| {
+                print!("{t}");
+                std::io::stdout().flush().unwrap();
+
+                Ok(())
+            },
+        );
+
+        println!();
+
+        match res {
+            Ok(stats) => {
+                println!("{}", stats);
+            }
+            Err(llama_rs::InferenceError::ContextFull) => {
+                log::warn!("Context window full, stopping inference.")
+            }
+            Err(InferenceError::UserCallback(_)) => unreachable!("cannot fail"),
+        }
+        Ok(())
+    }
+}
diff --git a/llama-cli/src/commands/mod.rs b/llama-cli/src/commands/mod.rs
new file mode 100644
index 00000000..910b24e3
--- /dev/null
+++ b/llama-cli/src/commands/mod.rs
@@ -0,0 +1,35 @@
+use clap::Args;
+use env_logger::Builder;
+
+mod generate;
+mod mode;
+mod prompt;
+
+use generate::Generate;
+use mode::Mode;
+use prompt::Prompts;
+
+#[derive(Debug, Args)]
+pub struct LlamaCmd {
+    #[command(flatten)]
+    pub generate: Generate,
+
+    #[command(flatten)]
+    pub mode: Mode,
+
+    #[command(flatten)]
+    pub prompts: Prompts,
+}
+
+impl LlamaCmd {
+    pub fn run(&self) -> Result<(), String> {
+        Builder::new()
+            .filter_level(log::LevelFilter::Info)
+            .parse_default_env()
+            .init();
+
+        let prompt = self.prompts.run()?;
+        // TODO: `Mode` (REPL/interactive) still needs the loaded model,
+        // vocabulary and inference parameters before it can be driven from
+        // here; for now only one-shot generation is wired up.
+        self.generate.run(&prompt)
+    }
+}
diff --git a/llama-cli/src/commands/mode.rs b/llama-cli/src/commands/mode.rs
new file mode 100644
index 00000000..4fabb675
--- /dev/null
+++ b/llama-cli/src/commands/mode.rs
@@ -0,0 +1,97 @@
+use clap::Parser;
+use llama_rs::{InferenceError, InferenceParameters};
+use rand::thread_rng;
+use rustyline::error::ReadlineError;
+use std::{convert::Infallible, io::Write};
+
+#[derive(Debug, Parser)]
+pub enum Mode {
+    Repl {
+        /// Run in REPL mode.
+        #[arg(long, short = 'R', default_value_t = false)]
+        repl: bool,
+    },
+
+    Interactive {
+        /// Run in interactive mode.
+        #[arg(long, short = 'i', default_value_t = false)]
+        interactive: bool,
+    },
+}
+
+impl Mode {
+    fn interactive_mode(&self, _model: &llama_rs::Model, _vocab: &llama_rs::Vocabulary) {
+        println!("activated")
+        // TODO: create a sliding window of context,
+        // convert the initial prompt into tokens,
+        // convert the model's answer into tokens and add it to the total token count,
+        // wait for the user's response, and repeat.
+        // Issue a warning once the total context exceeds 2048 tokens.
+    }
+
+    fn repl_mode(
+        &self,
+        prompt: &str,
+        model: &llama_rs::Model,
+        vocab: &llama_rs::Vocabulary,
+        params: &InferenceParameters,
+        repeat_last_n: usize,
+        num_predict: Option<usize>,
+    ) {
+        // TODO: refactor this to decouple model generation
+        // TODO: check run model then store prompt if successful
+        let mut rl = rustyline::DefaultEditor::new().unwrap();
+        loop {
+            let readline = rl.readline(">> ");
+            match readline {
+                Ok(line) => {
+                    // model generation
+                    let mut session = model.start_session(repeat_last_n);
+                    // Substitute the user's input into the prompt template via
+                    // the $PROMPT placeholder.
+                    let prompt = prompt.replace("$PROMPT", &line);
+                    let mut rng = thread_rng();
+
+                    // TODO: create UI for cli in separate struct
+                    let mut sp = spinners::Spinner::new(spinners::Spinners::Dots2, "".to_string());
+                    if let Err(InferenceError::ContextFull) =
+                        session.feed_prompt::<Infallible>(model, vocab, params, &prompt, |_| Ok(()))
+                    {
+                        log::error!("Prompt exceeds context window length.")
+                    };
+                    sp.stop();
+
+                    let res = session.inference_with_prompt::<Infallible>(
+                        model,
+                        vocab,
+                        params,
+                        "",
+                        num_predict,
+                        &mut rng,
+                        |tk| {
+                            print!("{tk}");
+                            std::io::stdout().flush().unwrap();
+                            Ok(())
+                        },
+                    );
+                    println!();
+
+                    if let Err(InferenceError::ContextFull) = res {
+                        log::error!("Reply exceeds context window length");
+                    }
+                }
+                Err(ReadlineError::Eof) | Err(ReadlineError::Interrupted) => {
+                    break;
+                }
+                Err(err) => {
+                    log::error!("{err}");
+                }
+            }
+        }
+    }
+
+    pub fn run(
+        &self,
+        prompt: &str,
+        model: &llama_rs::Model,
+        vocab: &llama_rs::Vocabulary,
+        params: &InferenceParameters,
+        repeat_last_n: usize,
+        num_predict: Option<usize>,
+    ) {
+        match self {
+            Self::Repl { .. } => {
+                self.repl_mode(prompt, model, vocab, params, repeat_last_n, num_predict)
+            }
+            Self::Interactive { .. } => self.interactive_mode(model, vocab),
+        }
+    }
+}
diff --git a/llama-cli/src/commands/prompt.rs b/llama-cli/src/commands/prompt.rs
new file mode 100644
index 00000000..2fb5ec6c
--- /dev/null
+++ b/llama-cli/src/commands/prompt.rs
@@ -0,0 +1,137 @@
+use clap::Parser;
+use llama_rs::{
+    InferenceError, InferenceParameters, InferenceSession, InferenceSnapshot, Model, Vocabulary,
+};
+use std::{convert::Infallible, io::Write};
+
+#[derive(Debug, Parser)]
+pub enum Prompts {
+    Prompt {
+        /// The prompt to feed the generator
+        #[arg(long, short = 'p', default_value = None)]
+        prompt: Option<String>,
+    },
+    PromptFile {
+        /// A file to read the prompt from. Takes precedence over `prompt` if set.
+        #[arg(long, short = 'f', default_value = None)]
+        prompt_file: Option<String>,
+    },
+
+    RestorePrompt {
+        /// Restores a cached prompt at the given path, previously using
+        /// --cache-prompt
+        #[arg(long, default_value = None)]
+        restore_prompt: Option<String>,
+    },
+
+    CachePrompt {
+        /// Stores a cached prompt at the given path. The same prompt can then be
+        /// loaded from disk using --restore-prompt
+        #[arg(long, default_value = None)]
+        cache_prompt: Option<String>,
+    },
+}
+
+impl Prompts {
+    fn cache_current_prompt(
+        &self,
+        cache_path: &str,
+        prompt: &str,
+        session: &mut InferenceSession,
+        model: &Model,
+        vocab: &Vocabulary,
+        inference_params: &InferenceParameters,
+    ) {
+        // TODO: refactor this to decouple model generation and prompt creation
+        // TODO: check run model then store prompt if successful
+        let res = session.feed_prompt::<Infallible>(model, vocab, inference_params, prompt, |t| {
+            print!("{t}");
+            std::io::stdout().flush().unwrap();
+
+            Ok(())
+        });
+
+        println!();
+
+        match res {
+            Ok(_) => (),
+            Err(InferenceError::ContextFull) => {
+                log::warn!(
+                    "Context is not large enough to fit the prompt. Saving intermediate state."
+                );
+            }
+            Err(InferenceError::UserCallback(_)) => unreachable!("cannot fail"),
+        }
+
+        // Write the memory to the cache file
+        // SAFETY: no other model functions used inside the block
+        unsafe {
+            let memory = session.get_snapshot();
+            match memory.write_to_disk(cache_path) {
+                Ok(_) => {
+                    log::info!("Successfully written prompt cache to {cache_path}");
+                }
+                Err(err) => {
+                    eprintln!("Could not write prompt cache. Error: {err}");
+                    std::process::exit(1);
+                }
+            }
+        }
+    }
+
+    fn read_prompt_from_file(&self, prompt_file: &str) -> Result<String, String> {
+        match std::fs::read_to_string(prompt_file) {
+            Ok(prompt) => Ok(prompt),
+            Err(err) => {
+                log::error!("Could not read prompt file at {prompt_file}. Error: {err}");
+                Err(format!(
+                    "Could not read prompt file at {prompt_file}. Error: {err}"
+                ))
+            }
+        }
+    }
+
+    fn create_prompt(&self, prompt: &Option<String>) -> Result<String, String> {
+        match prompt {
+            Some(prompt) => Ok(prompt.clone()),
+            None => Err("No prompt or prompt file was provided. See --help".to_string()),
+        }
+    }
+
+    fn restore_previous_prompt(
+        &self,
+        restore_path: &str,
+        model: &Model,
+    ) -> Result<InferenceSession, String> {
+        let snapshot = InferenceSnapshot::load_from_disk(restore_path);
+        match snapshot.and_then(|snapshot| model.session_from_snapshot(snapshot)) {
+            Ok(session) => {
+                log::info!("Restored cached memory from {restore_path}");
+                Ok(session)
+            }
+            Err(err) => {
+                log::error!("{err}");
+                Err(format!("Could not restore prompt. Error: {err}"))
+            }
+        }
+    }
+
+    pub fn run(&self) -> Result<String, String> {
+        match self {
+            Self::Prompt { prompt } => self.create_prompt(prompt),
+            Self::PromptFile { prompt_file } => match prompt_file {
+                Some(path) => self.read_prompt_from_file(path),
+                None => Err("No prompt file was provided. See --help".to_string()),
+            },
+            // TODO: restoring and caching prompts need a loaded model and an
+            // inference session, which currently live in `Generate`; wire
+            // `restore_previous_prompt`/`cache_current_prompt` up once that
+            // refactor lands.
+            Self::RestorePrompt { .. } | Self::CachePrompt { .. } => Err(
+                "--restore-prompt/--cache-prompt are not wired up yet".to_string(),
+            ),
+        }
+    }
+}
diff --git a/llama-cli/src/main.rs b/llama-cli/src/main.rs
index 87fd7132..89458418 100644
--- a/llama-cli/src/main.rs
+++ b/llama-cli/src/main.rs
@@ -1,237 +1,21 @@
-use std::{convert::Infallible, io::Write};
+use clap::Parser;
 
-use cli_args::CLI_ARGS;
-use llama_rs::{InferenceError, InferenceParameters, InferenceSnapshot};
-use rand::thread_rng;
-use rand::SeedableRng;
-use rustyline::error::ReadlineError;
+use commands::LlamaCmd;
+mod commands;
 
-mod cli_args;
-
-fn repl_mode(
-    prompt: &str,
-    model: &llama_rs::Model,
-    vocab: &llama_rs::Vocabulary,
-    params: &InferenceParameters,
-) {
-    let mut rl = rustyline::DefaultEditor::new().unwrap();
-    loop {
-        let readline = rl.readline(">> ");
-        match readline {
-            Ok(line) => {
-                let mut session = model.start_session(CLI_ARGS.repeat_last_n);
-                let prompt = prompt.replace("$PROMPT", &line);
-                let mut rng = thread_rng();
-
-                let mut sp = spinners::Spinner::new(spinners::Spinners::Dots2, "".to_string());
-                if let Err(InferenceError::ContextFull) =
-                    session.feed_prompt::<Infallible>(model, vocab, params, &prompt, |_| Ok(()))
-                {
-                    log::error!("Prompt exceeds context window length.")
-                };
-                sp.stop();
-
-                let res = session.inference_with_prompt::<Infallible>(
-                    model,
-                    vocab,
-                    params,
-                    "",
-                    CLI_ARGS.num_predict,
-                    &mut rng,
-                    |tk| {
-                        print!("{tk}");
-                        std::io::stdout().flush().unwrap();
-                        Ok(())
-                    },
-                );
-                println!();
+#[derive(Debug, Parser)]
+#[command(author, version, about, long_about = None)]
+pub struct Args {
+    #[command(flatten)]
+    pub cmds: LlamaCmd,
+}
 
-                if let Err(InferenceError::ContextFull) = res {
-                    log::error!("Reply exceeds context window length");
-                }
-            }
-            Err(ReadlineError::Eof) | Err(ReadlineError::Interrupted) => {
-                break;
-            }
-            Err(err) => {
-                log::error!("{err}");
-            }
-        }
+impl Args {
+    fn run(self) -> Result<(), String> {
+        self.cmds.run()
     }
 }
 
-fn main() {
-    env_logger::builder()
-        .filter_level(log::LevelFilter::Info)
-        .parse_default_env()
-        .init();
-
-    let args = &*CLI_ARGS;
-
-    let inference_params = InferenceParameters {
-        n_threads: args.num_threads as i32,
-        n_batch: args.batch_size,
-        top_k: args.top_k,
-        top_p: args.top_p,
-        repeat_penalty: args.repeat_penalty,
-        temp: args.temp,
-    };
-
-    let prompt = if let Some(path) = &args.prompt_file {
-        match std::fs::read_to_string(path) {
-            Ok(prompt) => prompt,
-            Err(err) => {
-                log::error!("Could not read prompt file at {path}. Error {err}");
-                std::process::exit(1);
-            }
-        }
-    } else if let Some(prompt) = &args.prompt {
-        prompt.clone()
-    } else {
-        log::error!("No prompt or prompt file was provided. See --help");
-        std::process::exit(1);
-    };
-
-    let (mut model, vocab) =
-        llama_rs::Model::load(&args.model_path, args.num_ctx_tokens as i32, |progress| {
-            use llama_rs::LoadProgress;
-            match progress {
-                LoadProgress::HyperparametersLoaded(hparams) => {
-                    log::debug!("Loaded HyperParams {hparams:#?}")
-                }
-                LoadProgress::BadToken { index } => {
-                    log::info!("Warning: Bad token in vocab at index {index}")
-                }
-                LoadProgress::ContextSize { bytes } => log::info!(
-                    "ggml ctx size = {:.2} MB\n",
-                    bytes as f64 / (1024.0 * 1024.0)
-                ),
-                LoadProgress::MemorySize { bytes, n_mem } => log::info!(
-                    "Memory size: {} MB {}",
-                    bytes as f32 / 1024.0 / 1024.0,
-                    n_mem
-                ),
-                LoadProgress::PartLoading {
-                    file,
-                    current_part,
-                    total_parts,
-                } => log::info!(
-                    "Loading model part {}/{} from '{}'\n",
-                    current_part,
-                    total_parts,
-                    file.to_string_lossy(),
-                ),
-                LoadProgress::PartTensorLoaded {
-                    current_tensor,
-                    tensor_count,
-                    ..
-                } => {
-                    if current_tensor % 8 == 0 {
-                        log::info!("Loaded tensor {current_tensor}/{tensor_count}");
-                    }
-                }
-                LoadProgress::PartLoaded {
-                    file,
-                    byte_size,
-                    tensor_count,
-                } => {
-                    log::info!("Loading of '{}' complete", file.to_string_lossy());
-                    log::info!(
-                        "Model size = {:.2} MB / num tensors = {}",
-                        byte_size as f64 / 1024.0 / 1024.0,
-                        tensor_count
-                    );
-                }
-            }
-        })
-        .expect("Could not load model");
-
-    log::info!("Model fully loaded!");
-
-    let mut rng = if let Some(seed) = CLI_ARGS.seed {
-        rand::rngs::StdRng::seed_from_u64(seed)
-    } else {
-        rand::rngs::StdRng::from_entropy()
-    };
-
-    let mut session = if let Some(restore_path) = &args.restore_prompt {
-        let snapshot = InferenceSnapshot::load_from_disk(restore_path);
-        match snapshot.and_then(|snapshot| model.session_from_snapshot(snapshot)) {
-            Ok(session) => {
-                log::info!("Restored cached memory from {restore_path}");
-                session
-            }
-            Err(err) => {
-                log::error!("{err}");
-                std::process::exit(1);
-            }
-        }
-    } else {
-        model.start_session(args.repeat_last_n)
-    };
-
-    if args.repl {
-        repl_mode(&prompt, &model, &vocab, &inference_params);
-    } else if let Some(cache_path) = &args.cache_prompt {
-        let res =
-            session.feed_prompt::<Infallible>(&model, &vocab, &inference_params, &prompt, |t| {
-                print!("{t}");
-                std::io::stdout().flush().unwrap();
-
-                Ok(())
-            });
-
-        println!();
-
-        match res {
-            Ok(_) => (),
-            Err(InferenceError::ContextFull) => {
-                log::warn!(
-                    "Context is not large enough to fit the prompt. Saving intermediate state."
-                );
-            }
-            Err(InferenceError::UserCallback(_)) => unreachable!("cannot fail"),
-        }
-
-        // Write the memory to the cache file
-        // SAFETY: no other model functions used inside the block
-        unsafe {
-            let memory = session.get_snapshot();
-            match memory.write_to_disk(cache_path) {
-                Ok(_) => {
-                    log::info!("Successfully written prompt cache to {cache_path}");
-                }
-                Err(err) => {
-                    eprintln!("Could not restore prompt. Error: {err}");
-                    std::process::exit(1);
-                }
-            }
-        }
-    } else {
-        let res = session.inference_with_prompt::<Infallible>(
-            &model,
-            &vocab,
-            &inference_params,
-            &prompt,
-            args.num_predict,
-            &mut rng,
-            |t| {
-                print!("{t}");
-                std::io::stdout().flush().unwrap();
-
-                Ok(())
-            },
-        );
-        println!();
-
-        match res {
-            Ok(stats) => {
-                println!("{}", stats);
-            }
-            Err(llama_rs::InferenceError::ContextFull) => {
-                log::warn!("Context window full, stopping inference.")
-            }
-            Err(InferenceError::UserCallback(_)) => unreachable!("cannot fail"),
-        }
-    }
+fn main() -> Result<(), String> {
+    Args::parse().run()
 }