From 5d8a6134c076ad68045fd19ca1574f6b68622ceb Mon Sep 17 00:00:00 2001 From: User <53880692+hhamud@users.noreply.github.com> Date: Tue, 21 Mar 2023 16:11:35 +0000 Subject: [PATCH 1/5] fix: move files --- llama-cli/src/main.rs | 55 ++--------------------------- llama-cli/src/mode/interactive.rs | 11 ++++++ llama-cli/src/mode/mod.rs | 2 ++ llama-cli/src/mode/repl.rs | 57 +++++++++++++++++++++++++++++++ 4 files changed, 72 insertions(+), 53 deletions(-) create mode 100644 llama-cli/src/mode/interactive.rs create mode 100644 llama-cli/src/mode/mod.rs create mode 100644 llama-cli/src/mode/repl.rs diff --git a/llama-cli/src/main.rs b/llama-cli/src/main.rs index 87fd7132..da7af333 100644 --- a/llama-cli/src/main.rs +++ b/llama-cli/src/main.rs @@ -1,64 +1,13 @@ use std::{convert::Infallible, io::Write}; - +use crate::mode::{repl::repl_mode, interactive::interactive_mode}; use cli_args::CLI_ARGS; use llama_rs::{InferenceError, InferenceParameters, InferenceSnapshot}; -use rand::thread_rng; use rand::SeedableRng; -use rustyline::error::ReadlineError; mod cli_args; +mod mode; -fn repl_mode( - prompt: &str, - model: &llama_rs::Model, - vocab: &llama_rs::Vocabulary, - params: &InferenceParameters, -) { - let mut rl = rustyline::DefaultEditor::new().unwrap(); - loop { - let readline = rl.readline(">> "); - match readline { - Ok(line) => { - let mut session = model.start_session(CLI_ARGS.repeat_last_n); - let prompt = prompt.replace("$PROMPT", &line); - let mut rng = thread_rng(); - - let mut sp = spinners::Spinner::new(spinners::Spinners::Dots2, "".to_string()); - if let Err(InferenceError::ContextFull) = - session.feed_prompt::(model, vocab, params, &prompt, |_| Ok(())) - { - log::error!("Prompt exceeds context window length.") - }; - sp.stop(); - - let res = session.inference_with_prompt::( - model, - vocab, - params, - "", - CLI_ARGS.num_predict, - &mut rng, - |tk| { - print!("{tk}"); - std::io::stdout().flush().unwrap(); - Ok(()) - }, - ); - println!(); - if let Err(InferenceError::ContextFull) = res { - log::error!("Reply exceeds context window length"); - } - } - Err(ReadlineError::Eof) | Err(ReadlineError::Interrupted) => { - break; - } - Err(err) => { - log::error!("{err}"); - } - } - } -} fn main() { env_logger::builder() diff --git a/llama-cli/src/mode/interactive.rs b/llama-cli/src/mode/interactive.rs new file mode 100644 index 00000000..940bd66b --- /dev/null +++ b/llama-cli/src/mode/interactive.rs @@ -0,0 +1,11 @@ + +pub fn interactive_mode(model: &llama_rs::Model, vocab: &llama_rs::Vocabulary){ + println!("activated") + // create a sliding window of context + // convert initial prompt into tokens + // convert ai answer into tokens and add into total token count + // wait for user response + // repeat + // issue a warning after the total context is > 2048 tokens + +} diff --git a/llama-cli/src/mode/mod.rs b/llama-cli/src/mode/mod.rs new file mode 100644 index 00000000..1d1bb2f3 --- /dev/null +++ b/llama-cli/src/mode/mod.rs @@ -0,0 +1,2 @@ +pub mod interactive; +pub mod repl; diff --git a/llama-cli/src/mode/repl.rs b/llama-cli/src/mode/repl.rs new file mode 100644 index 00000000..f2070027 --- /dev/null +++ b/llama-cli/src/mode/repl.rs @@ -0,0 +1,57 @@ +use rand::thread_rng; +use rustyline::error::ReadlineError; +use std::{convert::Infallible, io::Write}; +use llama_rs::{InferenceError, InferenceParameters}; +use crate::cli_args::CLI_ARGS; + +pub fn repl_mode( + prompt: &str, + model: &llama_rs::Model, + vocab: &llama_rs::Vocabulary, + params: &InferenceParameters, +) { + let mut rl = rustyline::DefaultEditor::new().unwrap(); + loop { + let readline = rl.readline(">> "); + match readline { + Ok(line) => { + let mut session = model.start_session(CLI_ARGS.repeat_last_n); + let prompt = prompt.replace("$PROMPT", &line); + let mut rng = thread_rng(); + + let mut sp = spinners::Spinner::new(spinners::Spinners::Dots2, "".to_string()); + if let Err(InferenceError::ContextFull) = + session.feed_prompt::(model, vocab, params, &prompt, |_| Ok(())) + { + log::error!("Prompt exceeds context window length.") + }; + sp.stop(); + + let res = session.inference_with_prompt::( + model, + vocab, + params, + "", + CLI_ARGS.num_predict, + &mut rng, + |tk| { + print!("{tk}"); + std::io::stdout().flush().unwrap(); + Ok(()) + }, + ); + println!(); + + if let Err(InferenceError::ContextFull) = res { + log::error!("Reply exceeds context window length"); + } + } + Err(ReadlineError::Eof) | Err(ReadlineError::Interrupted) => { + break; + } + Err(err) => { + log::error!("{err}"); + } + } + } +} From b755d1c119f4b6dc5ed40302e48385419cf231a4 Mon Sep 17 00:00:00 2001 From: User <53880692+hhamud@users.noreply.github.com> Date: Wed, 22 Mar 2023 20:13:14 +0000 Subject: [PATCH 2/5] refactor: break up cli arg struct into subcommands --- llama-cli/src/cli_args.rs | 84 ------------------- llama-cli/src/commands/cache.rs | 12 +++ llama-cli/src/commands/generate.rs | 50 +++++++++++ llama-cli/src/commands/mod.rs | 33 ++++++++ .../src/{mode/repl.rs => commands/mode.rs} | 31 ++++++- llama-cli/src/commands/model.rs | 5 ++ llama-cli/src/commands/prompt.rs | 9 ++ llama-cli/src/mode/interactive.rs | 11 --- llama-cli/src/mode/mod.rs | 2 - 9 files changed, 137 insertions(+), 100 deletions(-) delete mode 100644 llama-cli/src/cli_args.rs create mode 100644 llama-cli/src/commands/cache.rs create mode 100644 llama-cli/src/commands/generate.rs create mode 100644 llama-cli/src/commands/mod.rs rename llama-cli/src/{mode/repl.rs => commands/mode.rs} (73%) create mode 100644 llama-cli/src/commands/model.rs create mode 100644 llama-cli/src/commands/prompt.rs delete mode 100644 llama-cli/src/mode/interactive.rs delete mode 100644 llama-cli/src/mode/mod.rs diff --git a/llama-cli/src/cli_args.rs b/llama-cli/src/cli_args.rs deleted file mode 100644 index 617245d0..00000000 --- a/llama-cli/src/cli_args.rs +++ /dev/null @@ -1,84 +0,0 @@ -use clap::Parser; -use once_cell::sync::Lazy; - -#[derive(Parser, Debug)] -#[command(author, version, about, long_about = None)] -pub struct Args { - /// Where to load the model path from - #[arg(long, short = 'm')] - pub model_path: String, - - /// The prompt to feed the generator - #[arg(long, short = 'p', default_value = None)] - pub prompt: Option, - - /// A file to read the prompt from. Takes precedence over `prompt` if set. - #[arg(long, short = 'f', default_value = None)] - pub prompt_file: Option, - - /// Run in REPL mode. - #[arg(long, short = 'R', default_value_t = false)] - pub repl: bool, - - /// Sets the number of threads to use - #[arg(long, short = 't', default_value_t = num_cpus::get_physical())] - pub num_threads: usize, - - /// Sets how many tokens to predict - #[arg(long, short = 'n')] - pub num_predict: Option, - - /// Sets the size of the context (in tokens). Allows feeding longer prompts. - /// Note that this affects memory. TODO: Unsure how large the limit is. - #[arg(long, default_value_t = 512)] - pub num_ctx_tokens: usize, - - /// How many tokens from the prompt at a time to feed the network. Does not - /// affect generation. - #[arg(long, default_value_t = 8)] - pub batch_size: usize, - - /// Size of the 'last N' buffer that is used for the `repeat_penalty` - /// option. In tokens. - #[arg(long, default_value_t = 64)] - pub repeat_last_n: usize, - - /// The penalty for repeating tokens. Higher values make the generation less - /// likely to get into a loop, but may harm results when repetitive outputs - /// are desired. - #[arg(long, default_value_t = 1.30)] - pub repeat_penalty: f32, - - /// Temperature - #[arg(long, default_value_t = 0.80)] - pub temp: f32, - - /// Top-K: The top K words by score are kept during sampling. - #[arg(long, default_value_t = 40)] - pub top_k: usize, - - /// Top-p: The cummulative probability after which no more words are kept - /// for sampling. - #[arg(long, default_value_t = 0.95)] - pub top_p: f32, - - /// Stores a cached prompt at the given path. The same prompt can then be - /// loaded from disk using --restore-prompt - #[arg(long, default_value = None)] - pub cache_prompt: Option, - - /// Restores a cached prompt at the given path, previously using - /// --cache-prompt - #[arg(long, default_value = None)] - pub restore_prompt: Option, - - /// Specifies the seed to use during sampling. Note that, depending on - /// hardware, the same seed may lead to different results on two separate - /// machines. - #[arg(long, default_value = None)] - pub seed: Option, -} - -/// CLI args are stored in a lazy static variable so they're accessible from -/// everywhere. Arguments are parsed on first access. -pub static CLI_ARGS: Lazy = Lazy::new(Args::parse); diff --git a/llama-cli/src/commands/cache.rs b/llama-cli/src/commands/cache.rs new file mode 100644 index 00000000..968c3690 --- /dev/null +++ b/llama-cli/src/commands/cache.rs @@ -0,0 +1,12 @@ + + pub struct Cmd { + /// Stores a cached prompt at the given path. The same prompt can then be + /// loaded from disk using --restore-prompt + #[arg(long, default_value = None)] + cache_prompt: Option, + + /// Restores a cached prompt at the given path, previously using + /// --cache-prompt + #[arg(long, default_value = None)] + restore_prompt: Option, + }, diff --git a/llama-cli/src/commands/generate.rs b/llama-cli/src/commands/generate.rs new file mode 100644 index 00000000..73cbb1d8 --- /dev/null +++ b/llama-cli/src/commands/generate.rs @@ -0,0 +1,50 @@ + + pub struct Cmd { + /// Sets the number of threads to use + #[arg(long, short = 't', default_value_t = num_cpus::get_physical())] + num_threads: usize, + + /// Sets how many tokens to predict + #[arg(long, short = 'n')] + num_predict: Option, + + /// Sets the size of the context (in tokens). Allows feeding longer prompts. + /// Note that this affects memory. TODO: Unsure how large the limit is. + #[arg(long, default_value_t = 512)] + num_ctx_tokens: usize, + + /// How many tokens from the prompt at a time to feed the network. Does not + /// affect generation. + #[arg(long, default_value_t = 8)] + batch_size: usize, + + /// Size of the 'last N' buffer that is used for the `repeat_penalty` + /// option. In tokens. + #[arg(long, default_value_t = 64)] + repeat_last_n: usize, + + /// The penalty for repeating tokens. Higher values make the generation less + /// likely to get into a loop, but may harm results when repetitive outputs + /// are desired. + #[arg(long, default_value_t = 1.30)] + repeat_penalty: f32, + + /// Temperature + #[arg(long, default_value_t = 0.80)] + temp: f32, + + /// Top-K: The top K words by score are kept during sampling. + #[arg(long, default_value_t = 40)] + top_k: usize, + + /// Top-p: The cummulative probability after which no more words are kept + /// for sampling. + #[arg(long, default_value_t = 0.95)] + top_p: f32, + + /// Specifies the seed to use during sampling. Note that, depending on + /// hardware, the same seed may lead to different results on two separate + /// machines. + #[arg(long, default_value = None)] + seed: Option, + }, diff --git a/llama-cli/src/commands/mod.rs b/llama-cli/src/commands/mod.rs new file mode 100644 index 00000000..7f9c0056 --- /dev/null +++ b/llama-cli/src/commands/mod.rs @@ -0,0 +1,33 @@ +use clap::{Parser, Subcommand}; +use once_cell::sync::Lazy; + +mod cache; +mod mode; +mod model; +mod prompt; +mod generate; + + +#[derive(Debug, Subcommand)] +pub enum LlamaCmd { + + Cache(cache::Cmd), + + Generate(generate::Cmd), + + Mode(mode::Cmd), + + Model(model::Cmd), + + Prompt(prompt::Cmd) + +} + +impl LlamaCmd { + fn run(&self) { + match self { + + } + + } +} diff --git a/llama-cli/src/mode/repl.rs b/llama-cli/src/commands/mode.rs similarity index 73% rename from llama-cli/src/mode/repl.rs rename to llama-cli/src/commands/mode.rs index f2070027..cab80f62 100644 --- a/llama-cli/src/mode/repl.rs +++ b/llama-cli/src/commands/mode.rs @@ -1,10 +1,34 @@ +use crate::cli_args::CLI_ARGS; +use llama_rs::{InferenceError, InferenceParameters}; use rand::thread_rng; use rustyline::error::ReadlineError; use std::{convert::Infallible, io::Write}; -use llama_rs::{InferenceError, InferenceParameters}; -use crate::cli_args::CLI_ARGS; -pub fn repl_mode( + pub struct Cmd { + /// Run in REPL mode. + #[arg(long, short = 'R', default_value_t = false)] + repl: bool, + + // Run in interactive mode. + #[arg(long, short = 'i', default_value_t = false)] + interactive: bool, + } + +impl Cmd { + + fn interactive_mode(&self, model: &llama_rs::Model, vocab: &llama_rs::Vocabulary) { + println!("activated") + // create a sliding window of context + // convert initial prompt into tokens + // convert ai answer into tokens and add into total token count + // wait for user response + // repeat + // issue a warning after the total context is > 2048 tokens +} + + + fn repl_mode( + &self, prompt: &str, model: &llama_rs::Model, vocab: &llama_rs::Vocabulary, @@ -55,3 +79,4 @@ pub fn repl_mode( } } } +} diff --git a/llama-cli/src/commands/model.rs b/llama-cli/src/commands/model.rs new file mode 100644 index 00000000..02dee97d --- /dev/null +++ b/llama-cli/src/commands/model.rs @@ -0,0 +1,5 @@ + pub struct Cmd { + /// Where to load the model path from + #[arg(long, short = 'm')] + model_path: String, + }, diff --git a/llama-cli/src/commands/prompt.rs b/llama-cli/src/commands/prompt.rs new file mode 100644 index 00000000..4af0a420 --- /dev/null +++ b/llama-cli/src/commands/prompt.rs @@ -0,0 +1,9 @@ + pub struct Cmd { + /// The prompt to feed the generator + #[arg(long, short = 'p', default_value = None)] + prompt: Option, + + /// A file to read the prompt from. Takes precedence over `prompt` if set. + #[arg(long, short = 'f', default_value = None)] + prompt_file: Option, + } diff --git a/llama-cli/src/mode/interactive.rs b/llama-cli/src/mode/interactive.rs deleted file mode 100644 index 940bd66b..00000000 --- a/llama-cli/src/mode/interactive.rs +++ /dev/null @@ -1,11 +0,0 @@ - -pub fn interactive_mode(model: &llama_rs::Model, vocab: &llama_rs::Vocabulary){ - println!("activated") - // create a sliding window of context - // convert initial prompt into tokens - // convert ai answer into tokens and add into total token count - // wait for user response - // repeat - // issue a warning after the total context is > 2048 tokens - -} diff --git a/llama-cli/src/mode/mod.rs b/llama-cli/src/mode/mod.rs deleted file mode 100644 index 1d1bb2f3..00000000 --- a/llama-cli/src/mode/mod.rs +++ /dev/null @@ -1,2 +0,0 @@ -pub mod interactive; -pub mod repl; From 38106fc799c1e92743c69d3d63d08769564a2235 Mon Sep 17 00:00:00 2001 From: User <53880692+hhamud@users.noreply.github.com> Date: Thu, 23 Mar 2023 22:35:33 +0000 Subject: [PATCH 3/5] TEMP: fixup or squash later --- llama-cli/src/commands/cache.rs | 3 +- llama-cli/src/commands/generate.rs | 2 +- llama-cli/src/main.rs | 100 ++++++++++++++++++----------- 3 files changed, 64 insertions(+), 41 deletions(-) diff --git a/llama-cli/src/commands/cache.rs b/llama-cli/src/commands/cache.rs index 968c3690..c619b008 100644 --- a/llama-cli/src/commands/cache.rs +++ b/llama-cli/src/commands/cache.rs @@ -1,4 +1,3 @@ - pub struct Cmd { /// Stores a cached prompt at the given path. The same prompt can then be /// loaded from disk using --restore-prompt @@ -9,4 +8,4 @@ /// --cache-prompt #[arg(long, default_value = None)] restore_prompt: Option, - }, + } diff --git a/llama-cli/src/commands/generate.rs b/llama-cli/src/commands/generate.rs index 73cbb1d8..08abc791 100644 --- a/llama-cli/src/commands/generate.rs +++ b/llama-cli/src/commands/generate.rs @@ -47,4 +47,4 @@ /// machines. #[arg(long, default_value = None)] seed: Option, - }, + } diff --git a/llama-cli/src/main.rs b/llama-cli/src/main.rs index da7af333..ce75fcaa 100644 --- a/llama-cli/src/main.rs +++ b/llama-cli/src/main.rs @@ -1,47 +1,43 @@ -use std::{convert::Infallible, io::Write}; -use crate::mode::{repl::repl_mode, interactive::interactive_mode}; -use cli_args::CLI_ARGS; -use llama_rs::{InferenceError, InferenceParameters, InferenceSnapshot}; +use llama_rs::{InferenceError, InferenceParameters, InferenceSnapshot, Model, Vocabulary}; use rand::SeedableRng; - -mod cli_args; -mod mode; +use std::{convert::Infallible, io::Write}; +use clap::{Parser, Subcommand}; +use once_cell::sync::Lazy; +use Commands::LlamaCmd; +mod Commands; -fn main() { - env_logger::builder() - .filter_level(log::LevelFilter::Info) - .parse_default_env() - .init(); - - let args = &*CLI_ARGS; - - let inference_params = InferenceParameters { - n_threads: args.num_threads as i32, - n_batch: args.batch_size, - top_k: args.top_k, - top_p: args.top_p, - repeat_penalty: args.repeat_penalty, - temp: args.temp, - }; +#[derive(Debug, Parser)] +#[command(author, version, about, long_about = None)] +pub struct Args { + #[command(subcommand)] + pub cmds: LlamaCmd, +} - let prompt = if let Some(path) = &args.prompt_file { - match std::fs::read_to_string(path) { - Ok(prompt) => prompt, - Err(err) => { - log::error!("Could not read prompt file at {path}. Error {err}"); - std::process::exit(1); - } +/// CLI args are stored in a lazy static variable so they're accessible from +/// everywhere. Arguments are parsed on first access. +pub static CLI_ARGS: Lazy = Lazy::new(Args::parse); + +fn read_prompt_from_file(path: &str) -> Result { + match std::fs::read_to_string(path) { + Ok(prompt) => Ok(prompt), + Err(err) => { + log::error!("Could not read prompt file at {}. Error: {}", path, err); + return Err(format!( + "Could not read prompt file at {}. Error: {}", + path, err + )); } - } else if let Some(prompt) = &args.prompt { - prompt.clone() - } else { - log::error!("No prompt or prompt file was provided. See --help"); - std::process::exit(1); - }; + } +} + +fn create_prompt(prompt: &str) -> Result<&str, String> { + Ok(prompt) +} - let (mut model, vocab) = +fn select_model_and_vocab(args: &Args) -> Result<(Model, Vocabulary), String> { + let (model, vocab) = llama_rs::Model::load(&args.model_path, args.num_ctx_tokens as i32, |progress| { use llama_rs::LoadProgress; match progress { @@ -96,13 +92,39 @@ fn main() { .expect("Could not load model"); log::info!("Model fully loaded!"); + Ok((model, vocab)) +} + +fn main() { + env_logger::builder() + .filter_level(log::LevelFilter::Info) + .parse_default_env() + .init(); + + let args = &*CLI_ARGS; + + let inference_params = InferenceParameters { + n_threads: args.num_threads as i32, + n_batch: args.batch_size, + top_k: args.top_k, + top_p: args.top_p, + repeat_penalty: args.repeat_penalty, + temp: args.temp, + }; + + let prompt = create_prompt(prompt).unwrap(); + let (mut model, vocab) = select_model_and_vocab(args).unwrap(); + + fn create_seed(seed: u64) -> Result {} + // seed flag let mut rng = if let Some(seed) = CLI_ARGS.seed { rand::rngs::StdRng::seed_from_u64(seed) } else { rand::rngs::StdRng::from_entropy() }; + // restore_prompt flag let mut session = if let Some(restore_path) = &args.restore_prompt { let snapshot = InferenceSnapshot::load_from_disk(restore_path); match snapshot.and_then(|snapshot| model.session_from_snapshot(snapshot)) { @@ -119,7 +141,9 @@ fn main() { model.start_session(args.repeat_last_n) }; - if args.repl { + if args.interactive { + interactive_mode(&model, &vocab) + } else if args.repl { repl_mode(&prompt, &model, &vocab, &inference_params); } else if let Some(cache_path) = &args.cache_prompt { let res = From 4c9ac3e208651d678f1a9d3f0ff158f0de71a700 Mon Sep 17 00:00:00 2001 From: User <53880692+hhamud@users.noreply.github.com> Date: Sat, 25 Mar 2023 12:47:29 +0000 Subject: [PATCH 4/5] refactor: remove model.rs and refactor prompt --- llama-cli/src/commands/cache.rs | 68 ++++++++-- llama-cli/src/commands/generate.rs | 185 +++++++++++++++++++------- llama-cli/src/commands/mod.rs | 49 ++++--- llama-cli/src/commands/mode.rs | 129 +++++++++--------- llama-cli/src/commands/model.rs | 5 - llama-cli/src/commands/prompt.rs | 98 +++++++++++++- llama-cli/src/main.rs | 203 +---------------------------- 7 files changed, 399 insertions(+), 338 deletions(-) delete mode 100644 llama-cli/src/commands/model.rs diff --git a/llama-cli/src/commands/cache.rs b/llama-cli/src/commands/cache.rs index c619b008..6a25ea9f 100644 --- a/llama-cli/src/commands/cache.rs +++ b/llama-cli/src/commands/cache.rs @@ -1,11 +1,59 @@ - pub struct Cmd { - /// Stores a cached prompt at the given path. The same prompt can then be - /// loaded from disk using --restore-prompt - #[arg(long, default_value = None)] - cache_prompt: Option, - - /// Restores a cached prompt at the given path, previously using - /// --cache-prompt - #[arg(long, default_value = None)] - restore_prompt: Option, +use clap::Args; +use llama_rs::{InferenceError, InferenceParameters, InferenceSnapshot, InferenceSession, Model, Vocabulary}; +use std::{convert::Infallible, io::Write}; + +#[derive(Debug, Args)] +pub struct Cache { + /// Stores a cached prompt at the given path. The same prompt can then be + /// loaded from disk using --restore-prompt + #[arg(long, default_value = None)] + cache_prompt: Option, + +} + +impl Cache { + + fn cache_current_prompt(&self, session: &InferenceSession, model: &Model, vocab: &Vocabulary, prompt: &str, inference_params: &InferenceParameters) { + if self.cache_prompt.is_some() { + let res = session.feed_prompt::( + &model, + &vocab, + &inference_params, + &prompt, + |t| { + print!("{t}"); + std::io::stdout().flush().unwrap(); + + Ok(()) + }, + ); + + println!(); + + match res { + Ok(_) => (), + Err(InferenceError::ContextFull) => { + log::warn!( + "Context is not large enough to fit the prompt. Saving intermediate state." + ); + } + Err(InferenceError::UserCallback(_)) => unreachable!("cannot fail"), + } + + // Write the memory to the cache file + // SAFETY: no other model functions used inside the block + unsafe { + let memory = session.get_snapshot(); + match memory.write_to_disk(self.cache_prompt) { + Ok(_) => { + log::info!("Successfully written prompt cache to {0}", self.cache_prompt); + } + Err(err) => { + eprintln!("Could not restore prompt. Error: {err}"); + std::process::exit(1); + } + } + } + } } +} diff --git a/llama-cli/src/commands/generate.rs b/llama-cli/src/commands/generate.rs index 08abc791..6e546886 100644 --- a/llama-cli/src/commands/generate.rs +++ b/llama-cli/src/commands/generate.rs @@ -1,50 +1,139 @@ +use clap::Args; +use llama_rs::{InferenceError, InferenceParameters, InferenceSnapshot, Model, Vocabulary}; +use rand::SeedableRng; - pub struct Cmd { - /// Sets the number of threads to use - #[arg(long, short = 't', default_value_t = num_cpus::get_physical())] - num_threads: usize, - - /// Sets how many tokens to predict - #[arg(long, short = 'n')] - num_predict: Option, - - /// Sets the size of the context (in tokens). Allows feeding longer prompts. - /// Note that this affects memory. TODO: Unsure how large the limit is. - #[arg(long, default_value_t = 512)] - num_ctx_tokens: usize, - - /// How many tokens from the prompt at a time to feed the network. Does not - /// affect generation. - #[arg(long, default_value_t = 8)] - batch_size: usize, - - /// Size of the 'last N' buffer that is used for the `repeat_penalty` - /// option. In tokens. - #[arg(long, default_value_t = 64)] - repeat_last_n: usize, - - /// The penalty for repeating tokens. Higher values make the generation less - /// likely to get into a loop, but may harm results when repetitive outputs - /// are desired. - #[arg(long, default_value_t = 1.30)] - repeat_penalty: f32, - - /// Temperature - #[arg(long, default_value_t = 0.80)] - temp: f32, - - /// Top-K: The top K words by score are kept during sampling. - #[arg(long, default_value_t = 40)] - top_k: usize, - - /// Top-p: The cummulative probability after which no more words are kept - /// for sampling. - #[arg(long, default_value_t = 0.95)] - top_p: f32, - - /// Specifies the seed to use during sampling. Note that, depending on - /// hardware, the same seed may lead to different results on two separate - /// machines. - #[arg(long, default_value = None)] - seed: Option, +#[derive(Debug, Args)] +pub struct Generate { + /// Where to load the model path from + #[arg(long, short = 'm')] + model_path: String, + + /// Sets the number of threads to use + #[arg(long, short = 't', default_value_t = num_cpus::get_physical())] + num_threads: usize, + + /// Sets how many tokens to predict + #[arg(long, short = 'n')] + num_predict: Option, + + /// Sets the size of the context (in tokens). Allows feeding longer prompts. + /// Note that this affects memory. TODO: Unsure how large the limit is. + #[arg(long, default_value_t = 512)] + num_ctx_tokens: usize, + + /// How many tokens from the prompt at a time to feed the network. Does not + /// affect generation. + #[arg(long, default_value_t = 8)] + batch_size: usize, + + /// Size of the 'last N' buffer that is used for the `repeat_penalty` + /// option. In tokens. + #[arg(long, default_value_t = 64)] + repeat_last_n: usize, + + /// The penalty for repeating tokens. Higher values make the generation less + /// likely to get into a loop, but may harm results when repetitive outputs + /// are desired. + #[arg(long, default_value_t = 1.30)] + repeat_penalty: f32, + + /// Temperature + #[arg(long, default_value_t = 0.80)] + temp: f32, + + /// Top-K: The top K words by score are kept during sampling. + #[arg(long, default_value_t = 40)] + top_k: usize, + + /// Top-p: The cummulative probability after which no more words are kept + /// for sampling. + #[arg(long, default_value_t = 0.95)] + top_p: f32, + + /// Specifies the seed to use during sampling. Note that, depending on + /// hardware, the same seed may lead to different results on two separate + /// machines. + #[arg(long, default_value = None)] + seed: Option, +} + +impl Generate { + fn create_seed(&self) -> rand::rngs::StdRng { + if self.seed.is_some() { + rand::rngs::StdRng::seed_from_u64(self.seed) + } else { + rand::rngs::StdRng::from_entropy() + } + } + + fn select_model_and_vocab(&self) -> Result<(Model, Vocabulary), String> { + let (model, vocab) = + llama_rs::Model::load(&self.model_path, self.num_ctx_tokens as i32, |progress| { + use llama_rs::LoadProgress; + + match progress { + LoadProgress::HyperparametersLoaded(hparams) => { + log::debug!("Loaded HyperParams {hparams:#?}") + } + LoadProgress::BadToken { index } => { + log::info!("Warning: Bad token in vocab at index {index}") + } + LoadProgress::ContextSize { bytes } => log::info!( + "ggml ctx size = {:.2} MB\n", + bytes as f64 / (1024.0 * 1024.0) + ), + LoadProgress::MemorySize { bytes, n_mem } => log::info!( + "Memory size: {} MB {}", + bytes as f32 / 1024.0 / 1024.0, + n_mem + ), + LoadProgress::PartLoading { + file, + current_part, + total_parts, + } => log::info!( + "Loading model part {}/{} from '{}'\n", + current_part, + total_parts, + file.to_string_lossy(), + ), + LoadProgress::PartTensorLoaded { + current_tensor, + tensor_count, + .. + } => { + if current_tensor % 8 == 0 { + log::info!("Loaded tensor {current_tensor}/{tensor_count}"); + } + } + LoadProgress::PartLoaded { + file, + byte_size, + tensor_count, + } => { + log::info!("Loading of '{}' complete", file.to_string_lossy()); + log::info!( + "Model size = {:.2} MB / num tensors = {}", + byte_size as f64 / 1024.0 / 1024.0, + tensor_count + ); + } + } + }) + .expect("Could not load model"); + + log::info!("Model fully loaded!"); + Ok((model, vocab)) + } + + fn create_inference_parameters(&self) -> InferenceParameters { + InferenceParameters { + n_threads: self.num_threads as i32, + n_batch: self.batch_size, + top_k: self.top_k, + top_p: self.top_p, + repeat_penalty: self.repeat_penalty, + temp: self.temp, + }; } +} diff --git a/llama-cli/src/commands/mod.rs b/llama-cli/src/commands/mod.rs index 7f9c0056..cc30a465 100644 --- a/llama-cli/src/commands/mod.rs +++ b/llama-cli/src/commands/mod.rs @@ -1,33 +1,46 @@ -use clap::{Parser, Subcommand}; +use clap::{Args, Parser}; +use env_logger::Builder; +use llama_rs::{InferenceError, InferenceParameters, InferenceSnapshot, Model, Vocabulary}; use once_cell::sync::Lazy; +use std::{convert::Infallible, io::Write}; mod cache; +mod generate; mod mode; -mod model; mod prompt; -mod generate; - - -#[derive(Debug, Subcommand)] -pub enum LlamaCmd { - Cache(cache::Cmd), +use cache::Cache; +use generate::Generate; +use mode::Mode; +use prompt::Prompts; - Generate(generate::Cmd), +#[derive(Debug, Args)] +pub struct LlamaCmd { + #[command(flatten)] + pub cache: Cache, - Mode(mode::Cmd), + #[command(flatten)] + pub generate: Generate, - Model(model::Cmd), - - Prompt(prompt::Cmd) + #[command(flatten)] + pub mode: Mode, + #[command(flatten)] + pub prompts: Prompts, } impl LlamaCmd { - fn run(&self) { - match self { - - } - + pub fn run(&self) -> Result<(), String> { + // create and run the actual session here + // use match and if statements to build up the + Builder::new() + .filter_level(log::LevelFilter::Info) + .parse_default_env() + .init(); + + let prompt = self.prompts.run(); + let generate = self.generate.inference_parameters(); + let mode = self.mode.run(); + let (mut model, vocab) = select_model_and_vocab(args).unwrap(); } } diff --git a/llama-cli/src/commands/mode.rs b/llama-cli/src/commands/mode.rs index cab80f62..c55dba3e 100644 --- a/llama-cli/src/commands/mode.rs +++ b/llama-cli/src/commands/mode.rs @@ -1,82 +1,93 @@ -use crate::cli_args::CLI_ARGS; +use clap::{Parser,Args}; use llama_rs::{InferenceError, InferenceParameters}; use rand::thread_rng; use rustyline::error::ReadlineError; use std::{convert::Infallible, io::Write}; - pub struct Cmd { +#[derive(Debug, Parser)] +pub enum Mode { + Repl { /// Run in REPL mode. #[arg(long, short = 'R', default_value_t = false)] repl: bool, + }, + Interactive { // Run in interactive mode. #[arg(long, short = 'i', default_value_t = false)] interactive: bool, - } - -impl Cmd { - - fn interactive_mode(&self, model: &llama_rs::Model, vocab: &llama_rs::Vocabulary) { - println!("activated") - // create a sliding window of context - // convert initial prompt into tokens - // convert ai answer into tokens and add into total token count - // wait for user response - // repeat - // issue a warning after the total context is > 2048 tokens + }, } +impl Mode { + fn interactive_mode(&self, model: &llama_rs::Model, vocab: &llama_rs::Vocabulary) { + println!("activated") + // create a sliding window of context + // convert initial prompt into tokens + // convert ai answer into tokens and add into total token count + // wait for user response + // repeat + // issue a warning after the total context is > 2048 tokens + } - fn repl_mode( - &self, - prompt: &str, - model: &llama_rs::Model, - vocab: &llama_rs::Vocabulary, - params: &InferenceParameters, -) { - let mut rl = rustyline::DefaultEditor::new().unwrap(); - loop { - let readline = rl.readline(">> "); - match readline { - Ok(line) => { - let mut session = model.start_session(CLI_ARGS.repeat_last_n); - let prompt = prompt.replace("$PROMPT", &line); - let mut rng = thread_rng(); + fn repl_mode( + &self, + prompt: &str, + model: &llama_rs::Model, + vocab: &llama_rs::Vocabulary, + params: &InferenceParameters, + ) { + let mut rl = rustyline::DefaultEditor::new().unwrap(); + loop { + let readline = rl.readline(">> "); + match readline { + Ok(line) => { + let mut session = model.start_session(CLI_ARGS.repeat_last_n); + let prompt = prompt.replace("$PROMPT", &line); + let mut rng = thread_rng(); - let mut sp = spinners::Spinner::new(spinners::Spinners::Dots2, "".to_string()); - if let Err(InferenceError::ContextFull) = - session.feed_prompt::(model, vocab, params, &prompt, |_| Ok(())) - { - log::error!("Prompt exceeds context window length.") - }; - sp.stop(); + let mut sp = spinners::Spinner::new(spinners::Spinners::Dots2, "".to_string()); + if let Err(InferenceError::ContextFull) = + session.feed_prompt::(model, vocab, params, &prompt, |_| Ok(())) + { + log::error!("Prompt exceeds context window length.") + }; + sp.stop(); - let res = session.inference_with_prompt::( - model, - vocab, - params, - "", - CLI_ARGS.num_predict, - &mut rng, - |tk| { - print!("{tk}"); - std::io::stdout().flush().unwrap(); - Ok(()) - }, - ); - println!(); + let res = session.inference_with_prompt::( + model, + vocab, + params, + "", + CLI_ARGS.num_predict, + &mut rng, + |tk| { + print!("{tk}"); + std::io::stdout().flush().unwrap(); + Ok(()) + }, + ); + println!(); - if let Err(InferenceError::ContextFull) = res { - log::error!("Reply exceeds context window length"); + if let Err(InferenceError::ContextFull) = res { + log::error!("Reply exceeds context window length"); + } + } + Err(ReadlineError::Eof) | Err(ReadlineError::Interrupted) => { + break; + } + Err(err) => { + log::error!("{err}"); } - } - Err(ReadlineError::Eof) | Err(ReadlineError::Interrupted) => { - break; - } - Err(err) => { - log::error!("{err}"); } } } -} + + fn run(self) { + match self { + Self::Repl { repl } => self.repl_mode(), + Self::Interactive { interactive } => self.interactive_mode(), + _ => {} + } + } } diff --git a/llama-cli/src/commands/model.rs b/llama-cli/src/commands/model.rs deleted file mode 100644 index 02dee97d..00000000 --- a/llama-cli/src/commands/model.rs +++ /dev/null @@ -1,5 +0,0 @@ - pub struct Cmd { - /// Where to load the model path from - #[arg(long, short = 'm')] - model_path: String, - }, diff --git a/llama-cli/src/commands/prompt.rs b/llama-cli/src/commands/prompt.rs index 4af0a420..bdf33efb 100644 --- a/llama-cli/src/commands/prompt.rs +++ b/llama-cli/src/commands/prompt.rs @@ -1,9 +1,103 @@ - pub struct Cmd { +use clap::{Parser, Args}; +use llama_rs::{InferenceError, InferenceParameters, InferenceSnapshot, InferenceSession, Model, Vocabulary}; +use std::{convert::Infallible, io::Write}; + +#[derive(Debug, Parser)] +pub enum Prompts { + Prompt { /// The prompt to feed the generator #[arg(long, short = 'p', default_value = None)] prompt: Option, - + }, + PromptFile { /// A file to read the prompt from. Takes precedence over `prompt` if set. #[arg(long, short = 'f', default_value = None)] prompt_file: Option, + }, + + RestorePrompt { + /// Restores a cached prompt at the given path, previously using + /// --cache-prompt + #[arg(long, default_value = None)] + restore_prompt: Option, + }, +} + +impl Prompts { + fn read_prompt_from_file(&self) -> Result { + match std::fs::read_to_string(self.prompt_file) { + Ok(prompt) => Ok(prompt), + Err(err) => { + log::error!( + "Could not read prompt file at {}. Error: {}", + self.prompt_file, + err + ); + return Err(format!( + "Could not read prompt file at {}. Error: {}", + self.prompt_file, err + )); + } + } + } + + fn create_prompt(&self) -> Result { + // interactive, repl, cache_prompt + // if just plain prompt file or prompt fun this + Ok(self.prompt); + } + + fn create_session(&self, session: &InferenceSession) { + let res = session.inference_with_prompt::( + &model, + &vocab, + &inference_params, + &prompt, + args.num_predict, + &mut rng, + |t| { + print!("{t}"); + std::io::stdout().flush().unwrap(); + + Ok(()) + }, + ); + + println!(); + + match res { + Ok(stats) => { + println!("{}", stats); + } + Err(llama_rs::InferenceError::ContextFull) => { + log::warn!("Context window full, stopping inference.") + } + Err(InferenceError::UserCallback(_)) => unreachable!("cannot fail"), + } + Ok(()) + } + + fn restore_previous_prompt(&self, model: &Model) -> Result { + if self.restore_prompt.is_some() { + let snapshot = InferenceSnapshot::load_from_disk(&self.restore_prompt); + match snapshot.and_then(|snapshot| model.session_from_snapshot(snapshot)) { + Ok(session) => { + log::info!("Restored cached memory from {0}", self.restore_prompt); + session + } + Err(err) => { + log::error!("{err}"); + std::process::exit(1); + } + } + } + } + + fn run(&self) -> Result { + match self { + Self::Prompt { prompt } => self.create_prompt(), + Self::PromptFile { prompt_file } => self.create_prompt(), + Self::RestorePrompt { restore_prompt } => self.restore_previous_prompt(), + } } +} diff --git a/llama-cli/src/main.rs b/llama-cli/src/main.rs index ce75fcaa..89458418 100644 --- a/llama-cli/src/main.rs +++ b/llama-cli/src/main.rs @@ -1,9 +1,4 @@ -use llama_rs::{InferenceError, InferenceParameters, InferenceSnapshot, Model, Vocabulary}; -use rand::SeedableRng; -use std::{convert::Infallible, io::Write}; -use clap::{Parser, Subcommand}; -use once_cell::sync::Lazy; - +use clap::Parser; use Commands::LlamaCmd; mod Commands; @@ -11,200 +6,16 @@ mod Commands; #[derive(Debug, Parser)] #[command(author, version, about, long_about = None)] pub struct Args { - #[command(subcommand)] + #[command(flatten)] pub cmds: LlamaCmd, } -/// CLI args are stored in a lazy static variable so they're accessible from -/// everywhere. Arguments are parsed on first access. -pub static CLI_ARGS: Lazy = Lazy::new(Args::parse); - -fn read_prompt_from_file(path: &str) -> Result { - match std::fs::read_to_string(path) { - Ok(prompt) => Ok(prompt), - Err(err) => { - log::error!("Could not read prompt file at {}. Error: {}", path, err); - return Err(format!( - "Could not read prompt file at {}. Error: {}", - path, err - )); - } +impl Args { + fn run(self) -> Result<(), String> { + self.cmds.run() } } -fn create_prompt(prompt: &str) -> Result<&str, String> { - Ok(prompt) -} - -fn select_model_and_vocab(args: &Args) -> Result<(Model, Vocabulary), String> { - let (model, vocab) = - llama_rs::Model::load(&args.model_path, args.num_ctx_tokens as i32, |progress| { - use llama_rs::LoadProgress; - match progress { - LoadProgress::HyperparametersLoaded(hparams) => { - log::debug!("Loaded HyperParams {hparams:#?}") - } - LoadProgress::BadToken { index } => { - log::info!("Warning: Bad token in vocab at index {index}") - } - LoadProgress::ContextSize { bytes } => log::info!( - "ggml ctx size = {:.2} MB\n", - bytes as f64 / (1024.0 * 1024.0) - ), - LoadProgress::MemorySize { bytes, n_mem } => log::info!( - "Memory size: {} MB {}", - bytes as f32 / 1024.0 / 1024.0, - n_mem - ), - LoadProgress::PartLoading { - file, - current_part, - total_parts, - } => log::info!( - "Loading model part {}/{} from '{}'\n", - current_part, - total_parts, - file.to_string_lossy(), - ), - LoadProgress::PartTensorLoaded { - current_tensor, - tensor_count, - .. - } => { - if current_tensor % 8 == 0 { - log::info!("Loaded tensor {current_tensor}/{tensor_count}"); - } - } - LoadProgress::PartLoaded { - file, - byte_size, - tensor_count, - } => { - log::info!("Loading of '{}' complete", file.to_string_lossy()); - log::info!( - "Model size = {:.2} MB / num tensors = {}", - byte_size as f64 / 1024.0 / 1024.0, - tensor_count - ); - } - } - }) - .expect("Could not load model"); - - log::info!("Model fully loaded!"); - Ok((model, vocab)) -} - -fn main() { - env_logger::builder() - .filter_level(log::LevelFilter::Info) - .parse_default_env() - .init(); - - let args = &*CLI_ARGS; - - let inference_params = InferenceParameters { - n_threads: args.num_threads as i32, - n_batch: args.batch_size, - top_k: args.top_k, - top_p: args.top_p, - repeat_penalty: args.repeat_penalty, - temp: args.temp, - }; - - let prompt = create_prompt(prompt).unwrap(); - let (mut model, vocab) = select_model_and_vocab(args).unwrap(); - - fn create_seed(seed: u64) -> Result {} - - // seed flag - let mut rng = if let Some(seed) = CLI_ARGS.seed { - rand::rngs::StdRng::seed_from_u64(seed) - } else { - rand::rngs::StdRng::from_entropy() - }; - - // restore_prompt flag - let mut session = if let Some(restore_path) = &args.restore_prompt { - let snapshot = InferenceSnapshot::load_from_disk(restore_path); - match snapshot.and_then(|snapshot| model.session_from_snapshot(snapshot)) { - Ok(session) => { - log::info!("Restored cached memory from {restore_path}"); - session - } - Err(err) => { - log::error!("{err}"); - std::process::exit(1); - } - } - } else { - model.start_session(args.repeat_last_n) - }; - - if args.interactive { - interactive_mode(&model, &vocab) - } else if args.repl { - repl_mode(&prompt, &model, &vocab, &inference_params); - } else if let Some(cache_path) = &args.cache_prompt { - let res = - session.feed_prompt::(&model, &vocab, &inference_params, &prompt, |t| { - print!("{t}"); - std::io::stdout().flush().unwrap(); - - Ok(()) - }); - - println!(); - - match res { - Ok(_) => (), - Err(InferenceError::ContextFull) => { - log::warn!( - "Context is not large enough to fit the prompt. Saving intermediate state." - ); - } - Err(InferenceError::UserCallback(_)) => unreachable!("cannot fail"), - } - - // Write the memory to the cache file - // SAFETY: no other model functions used inside the block - unsafe { - let memory = session.get_snapshot(); - match memory.write_to_disk(cache_path) { - Ok(_) => { - log::info!("Successfully written prompt cache to {cache_path}"); - } - Err(err) => { - eprintln!("Could not restore prompt. Error: {err}"); - std::process::exit(1); - } - } - } - } else { - let res = session.inference_with_prompt::( - &model, - &vocab, - &inference_params, - &prompt, - args.num_predict, - &mut rng, - |t| { - print!("{t}"); - std::io::stdout().flush().unwrap(); - - Ok(()) - }, - ); - println!(); - - match res { - Ok(stats) => { - println!("{}", stats); - } - Err(llama_rs::InferenceError::ContextFull) => { - log::warn!("Context window full, stopping inference.") - } - Err(InferenceError::UserCallback(_)) => unreachable!("cannot fail"), - } - } +fn main() -> Result<(), String> { + Args::parse().run() } From fa0ca5c720a116a8667665f72372695179870ce8 Mon Sep 17 00:00:00 2001 From: User <53880692+hhamud@users.noreply.github.com> Date: Sat, 25 Mar 2023 19:35:56 +0000 Subject: [PATCH 5/5] refactor: remove cache.rs --- llama-cli/src/commands/cache.rs | 59 ----------------- llama-cli/src/commands/generate.rs | 52 ++++++++++++--- llama-cli/src/commands/mod.rs | 17 +---- llama-cli/src/commands/mode.rs | 8 ++- llama-cli/src/commands/prompt.rs | 102 +++++++++++++++++++---------- 5 files changed, 120 insertions(+), 118 deletions(-) delete mode 100644 llama-cli/src/commands/cache.rs diff --git a/llama-cli/src/commands/cache.rs b/llama-cli/src/commands/cache.rs deleted file mode 100644 index 6a25ea9f..00000000 --- a/llama-cli/src/commands/cache.rs +++ /dev/null @@ -1,59 +0,0 @@ -use clap::Args; -use llama_rs::{InferenceError, InferenceParameters, InferenceSnapshot, InferenceSession, Model, Vocabulary}; -use std::{convert::Infallible, io::Write}; - -#[derive(Debug, Args)] -pub struct Cache { - /// Stores a cached prompt at the given path. The same prompt can then be - /// loaded from disk using --restore-prompt - #[arg(long, default_value = None)] - cache_prompt: Option, - -} - -impl Cache { - - fn cache_current_prompt(&self, session: &InferenceSession, model: &Model, vocab: &Vocabulary, prompt: &str, inference_params: &InferenceParameters) { - if self.cache_prompt.is_some() { - let res = session.feed_prompt::( - &model, - &vocab, - &inference_params, - &prompt, - |t| { - print!("{t}"); - std::io::stdout().flush().unwrap(); - - Ok(()) - }, - ); - - println!(); - - match res { - Ok(_) => (), - Err(InferenceError::ContextFull) => { - log::warn!( - "Context is not large enough to fit the prompt. Saving intermediate state." - ); - } - Err(InferenceError::UserCallback(_)) => unreachable!("cannot fail"), - } - - // Write the memory to the cache file - // SAFETY: no other model functions used inside the block - unsafe { - let memory = session.get_snapshot(); - match memory.write_to_disk(self.cache_prompt) { - Ok(_) => { - log::info!("Successfully written prompt cache to {0}", self.cache_prompt); - } - Err(err) => { - eprintln!("Could not restore prompt. Error: {err}"); - std::process::exit(1); - } - } - } - } - } -} diff --git a/llama-cli/src/commands/generate.rs b/llama-cli/src/commands/generate.rs index 6e546886..d6f9bbd8 100644 --- a/llama-cli/src/commands/generate.rs +++ b/llama-cli/src/commands/generate.rs @@ -1,6 +1,8 @@ use clap::Args; -use llama_rs::{InferenceError, InferenceParameters, InferenceSnapshot, Model, Vocabulary}; +use llama_rs::{InferenceError, InferenceParameters, Model, Vocabulary}; +use rand::rngs::StdRng; use rand::SeedableRng; +use std::convert::Infallible; #[derive(Debug, Args)] pub struct Generate { @@ -58,15 +60,14 @@ pub struct Generate { } impl Generate { - fn create_seed(&self) -> rand::rngs::StdRng { - if self.seed.is_some() { - rand::rngs::StdRng::seed_from_u64(self.seed) - } else { - rand::rngs::StdRng::from_entropy() + fn create_seed(&self) -> StdRng { + match self.seed { + Some(seed) => StdRng::seed_from_u64(seed), + None => StdRng::from_entropy(), } } - fn select_model_and_vocab(&self) -> Result<(Model, Vocabulary), String> { + fn load_model(&self) -> Result<(Model, Vocabulary), String> { let (model, vocab) = llama_rs::Model::load(&self.model_path, self.num_ctx_tokens as i32, |progress| { use llama_rs::LoadProgress; @@ -126,8 +127,10 @@ impl Generate { Ok((model, vocab)) } - fn create_inference_parameters(&self) -> InferenceParameters { - InferenceParameters { + fn run(&self, prompt: &String) -> Result { + // model start session + + let inference_params = InferenceParameters { n_threads: self.num_threads as i32, n_batch: self.batch_size, top_k: self.top_k, @@ -135,5 +138,36 @@ impl Generate { repeat_penalty: self.repeat_penalty, temp: self.temp, }; + + let (mut model, vocab) = self.load_model()?; + let rng = self.create_seed(); + let session = model.start_session(self.repeat_last_n); + let res = session.inference_with_prompt::( + &model, + &vocab, + &inference_params, + &prompt, + self.num_predict, + &mut rng, + |t| { + print!("{t}"); + std::io::stdout().flush().unwrap(); + + Ok(()) + }, + ); + + println!(); + + match res { + Ok(stats) => { + println!("{}", stats); + } + Err(llama_rs::InferenceError::ContextFull) => { + log::warn!("Context window full, stopping inference.") + } + Err(InferenceError::UserCallback(_)) => unreachable!("cannot fail"), + } + Ok(()) } } diff --git a/llama-cli/src/commands/mod.rs b/llama-cli/src/commands/mod.rs index cc30a465..910b24e3 100644 --- a/llama-cli/src/commands/mod.rs +++ b/llama-cli/src/commands/mod.rs @@ -1,24 +1,16 @@ -use clap::{Args, Parser}; +use clap::Args; use env_logger::Builder; -use llama_rs::{InferenceError, InferenceParameters, InferenceSnapshot, Model, Vocabulary}; -use once_cell::sync::Lazy; -use std::{convert::Infallible, io::Write}; -mod cache; mod generate; mod mode; mod prompt; -use cache::Cache; use generate::Generate; use mode::Mode; use prompt::Prompts; #[derive(Debug, Args)] pub struct LlamaCmd { - #[command(flatten)] - pub cache: Cache, - #[command(flatten)] pub generate: Generate, @@ -31,16 +23,13 @@ pub struct LlamaCmd { impl LlamaCmd { pub fn run(&self) -> Result<(), String> { - // create and run the actual session here - // use match and if statements to build up the Builder::new() .filter_level(log::LevelFilter::Info) .parse_default_env() .init(); let prompt = self.prompts.run(); - let generate = self.generate.inference_parameters(); - let mode = self.mode.run(); - let (mut model, vocab) = select_model_and_vocab(args).unwrap(); + let generate = self.generate.run(&prompt); + self.mode.run(&generate); } } diff --git a/llama-cli/src/commands/mode.rs b/llama-cli/src/commands/mode.rs index c55dba3e..4fabb675 100644 --- a/llama-cli/src/commands/mode.rs +++ b/llama-cli/src/commands/mode.rs @@ -1,4 +1,4 @@ -use clap::{Parser,Args}; +use clap::Parser; use llama_rs::{InferenceError, InferenceParameters}; use rand::thread_rng; use rustyline::error::ReadlineError; @@ -37,15 +37,20 @@ impl Mode { vocab: &llama_rs::Vocabulary, params: &InferenceParameters, ) { + // TODO: refactor this to decouple model generation + // TODO: check run model then store prompt if successful let mut rl = rustyline::DefaultEditor::new().unwrap(); loop { let readline = rl.readline(">> "); match readline { Ok(line) => { + // model generation let mut session = model.start_session(CLI_ARGS.repeat_last_n); + // why this? let prompt = prompt.replace("$PROMPT", &line); let mut rng = thread_rng(); + // TODO: create UI for cli in seperate struct let mut sp = spinners::Spinner::new(spinners::Spinners::Dots2, "".to_string()); if let Err(InferenceError::ContextFull) = session.feed_prompt::(model, vocab, params, &prompt, |_| Ok(())) @@ -87,7 +92,6 @@ impl Mode { match self { Self::Repl { repl } => self.repl_mode(), Self::Interactive { interactive } => self.interactive_mode(), - _ => {} } } } diff --git a/llama-cli/src/commands/prompt.rs b/llama-cli/src/commands/prompt.rs index bdf33efb..2fb5ec6c 100644 --- a/llama-cli/src/commands/prompt.rs +++ b/llama-cli/src/commands/prompt.rs @@ -1,5 +1,7 @@ -use clap::{Parser, Args}; -use llama_rs::{InferenceError, InferenceParameters, InferenceSnapshot, InferenceSession, Model, Vocabulary}; +use clap::Parser; +use llama_rs::{ + InferenceError, InferenceParameters, InferenceSession, InferenceSnapshot, Model, Vocabulary, +}; use std::{convert::Infallible, io::Write}; #[derive(Debug, Parser)] @@ -21,9 +23,69 @@ pub enum Prompts { #[arg(long, default_value = None)] restore_prompt: Option, }, + + CachePrompt { + /// Stores a cached prompt at the given path. The same prompt can then be + /// loaded from disk using --restore-prompt + #[arg(long, default_value = None)] + cache_prompt: Option, + }, } impl Prompts { + fn cache_current_prompt( + &self, + session: &InferenceSession, + model: &Model, + vocab: &Vocabulary, + inference_params: &InferenceParameters, + ) { + // TODO: refactor this to decouple model generation and prompt creation + // TODO: check run model then store prompt if successful + let res = session.feed_prompt::( + &model, + &vocab, + &inference_params, + self.cache_prompt, + |t| { + print!("{t}"); + std::io::stdout().flush().unwrap(); + + Ok(()) + }, + ); + + println!(); + + match res { + Ok(_) => (), + Err(InferenceError::ContextFull) => { + log::warn!( + "Context is not large enough to fit the prompt. Saving intermediate state." + ); + } + Err(InferenceError::UserCallback(_)) => unreachable!("cannot fail"), + } + + // Write the memory to the cache file + // SAFETY: no other model functions used inside the block + unsafe { + let memory = session.get_snapshot(); + match memory.write_to_disk(self.cache_prompt) { + Ok(_) => { + log::info!( + "Successfully written prompt cache to {0}", + self.cache_prompt + ); + } + Err(err) => { + eprintln!("Could not restore prompt. Error: {err}"); + std::process::exit(1); + } + } + } + } + fn read_prompt_from_file(&self) -> Result { match std::fs::read_to_string(self.prompt_file) { Ok(prompt) => Ok(prompt), @@ -42,39 +104,10 @@ impl Prompts { } fn create_prompt(&self) -> Result { - // interactive, repl, cache_prompt - // if just plain prompt file or prompt fun this - Ok(self.prompt); - } - - fn create_session(&self, session: &InferenceSession) { - let res = session.inference_with_prompt::( - &model, - &vocab, - &inference_params, - &prompt, - args.num_predict, - &mut rng, - |t| { - print!("{t}"); - std::io::stdout().flush().unwrap(); - - Ok(()) - }, - ); - - println!(); - - match res { - Ok(stats) => { - println!("{}", stats); - } - Err(llama_rs::InferenceError::ContextFull) => { - log::warn!("Context window full, stopping inference.") - } - Err(InferenceError::UserCallback(_)) => unreachable!("cannot fail"), + match self.prompt { + Some(prompt) => Ok(prompt), + None => {} } - Ok(()) } fn restore_previous_prompt(&self, model: &Model) -> Result { @@ -98,6 +131,7 @@ impl Prompts { Self::Prompt { prompt } => self.create_prompt(), Self::PromptFile { prompt_file } => self.create_prompt(), Self::RestorePrompt { restore_prompt } => self.restore_previous_prompt(), + Self::CachePrompt { cache_prompt } => self.cache_current_prompt(), } } }