This repository has been archived by the owner on Jun 24, 2024. It is now read-only.

Librarification #10

Merged
merged 12 commits on Mar 16, 2023
134 changes: 53 additions & 81 deletions Cargo.lock


4 changes: 4 additions & 0 deletions Cargo.toml
@@ -2,6 +2,10 @@
members = [
"ggml-raw",
"llama-rs",
"llama-cli"
]

resolver = "2"

[workspace.dependencies]
rand = "0.8.5"
1 change: 1 addition & 0 deletions ggml-raw/build.rs
@@ -14,6 +14,7 @@ fn main() {

// This is a very basic heuristic for applying compile flags.
// Feel free to update this to fit your operating system.

let target_arch = env::var("CARGO_CFG_TARGET_ARCH").unwrap();
let is_release = env::var("PROFILE").unwrap() == "release";
let compiler = build.get_compiler();
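As an aside for readers: the rest of build.rs is collapsed in this diff, so the heuristic mentioned in the comment above is not visible here. A minimal, self-contained sketch of what a compile-flag heuristic along those lines could look like follows; the file paths and flags are illustrative assumptions, not the crate's actual build script.

use std::env;

fn main() {
    // Hypothetical sketch only: the paths and flags below are assumptions,
    // not the actual ggml-raw build script.
    let mut build = cc::Build::new();
    build.file("ggml/ggml.c").include("ggml");

    let target_arch = env::var("CARGO_CFG_TARGET_ARCH").unwrap();
    let is_release = env::var("PROFILE").unwrap() == "release";

    if is_release {
        // Optimise more aggressively for release builds.
        build.flag_if_supported("-O3");
    }
    match target_arch.as_str() {
        // Enable common SIMD extensions on x86_64.
        "x86_64" => {
            build.flag_if_supported("-mavx2");
            build.flag_if_supported("-mfma");
            build.flag_if_supported("-mf16c");
        }
        // Let the compiler tune for the host CPU on ARM.
        "aarch64" => {
            build.flag_if_supported("-mcpu=native");
        }
        _ => {}
    }

    build.compile("ggml");
}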
17 changes: 17 additions & 0 deletions llama-cli/Cargo.toml
@@ -0,0 +1,17 @@
[package]
name = "llama-cli"
version = "0.1.0"
edition = "2021"

# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html

[dependencies]
clap = { version = "4.1.8", features = ["derive"] }
env_logger = "0.10.0"
log = "0.4"
once_cell = "1.17.1"
num_cpus = "1.15.0"

llama-rs = { path = "../llama-rs" }

rand = { workspace = true }
File renamed without changes.
105 changes: 105 additions & 0 deletions llama-cli/src/main.rs
@@ -0,0 +1,105 @@
use std::io::Write;

use cli_args::CLI_ARGS;
use llama_rs::InferenceParameters;
use rand::thread_rng;

mod cli_args;

fn main() {
env_logger::builder()
.filter_level(log::LevelFilter::Info)
.parse_default_env()
.init();

let args = &*CLI_ARGS;

let inference_params = InferenceParameters {
n_threads: args.num_threads as i32,
n_predict: args.num_predict,
n_batch: args.batch_size,
top_k: args.top_k as i32,
top_p: args.top_p,
repeat_last_n: args.repeat_last_n,
repeat_penalty: args.repeat_penalty,
temp: args.temp,
};

let prompt = if let Some(path) = &args.prompt_file {
match std::fs::read_to_string(path) {
Ok(prompt) => prompt,
Err(err) => {
eprintln!("Could not read prompt file at {path}. Error {err}");
std::process::exit(1);
}
}
} else if let Some(prompt) = &args.prompt {
prompt.clone()
} else {
eprintln!("No prompt or prompt file was provided. See --help");
std::process::exit(1);
};

let (model, vocab) =
llama_rs::Model::load(&args.model_path, args.num_ctx_tokens as i32, |progress| {
use llama_rs::LoadProgress;
match progress {
LoadProgress::HyperParamsLoaded(hparams) => {
log::debug!("Loaded HyperParams {hparams:#?}")
}
LoadProgress::BadToken { index } => {
log::info!("Warning: Bad token in vocab at index {index}")
}
LoadProgress::ContextSize { bytes } => log::info!(
"ggml ctx size = {:.2} MB\n",
bytes as f64 / (1024.0 * 1024.0)
),
LoadProgress::MemorySize { bytes, n_mem } => log::info!(
"Memory size: {} MB {}",
bytes as f32 / 1024.0 / 1024.0,
n_mem
),
LoadProgress::PartLoading {
file,
current_part,
total_parts,
} => log::info!(
"Loading model part {}/{} from '{}'\n",
current_part,
total_parts,
file.to_string_lossy(),
),
LoadProgress::PartTensorLoaded {
current_tensor,
tensor_count,
..
} => {
if current_tensor % 8 == 0 {
log::info!("Loaded tensor {current_tensor}/{tensor_count}");
}
}
LoadProgress::PartLoaded {
file,
byte_size,
tensor_count,
} => {
log::info!("Loading of '{}' complete", file.to_string_lossy());
log::info!(
"Model size = {:.2} MB / num tensors = {}",
byte_size as f64 / 1024.0 / 1024.0,
tensor_count
);
}
}
})
.expect("Could not load model");

log::info!("Model fully loaded!");

let mut rng = thread_rng();
model.inference_with_prompt(&vocab, &inference_params, &prompt, &mut rng, |t| {
print!("{t}");
std::io::stdout().flush().unwrap();
});
println!();
}
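Note for readers: the cli_args module that main.rs reads CLI_ARGS from is not shown in this diff. Based only on the fields main.rs uses and the clap, once_cell and num_cpus dependencies declared above, a hypothetical sketch of that module might look like the following; the flag names, defaults and doc strings are assumptions, not the actual file.

use clap::Parser;
use once_cell::sync::Lazy;

/// Hypothetical CLI argument struct; the field names mirror what main.rs
/// reads, but the defaults and descriptions here are assumptions.
#[derive(Parser, Debug)]
pub struct Args {
    /// Path to the model file to load.
    #[arg(long)]
    pub model_path: String,

    /// Prompt text given directly on the command line.
    #[arg(long)]
    pub prompt: Option<String>,

    /// Path to a file containing the prompt.
    #[arg(long)]
    pub prompt_file: Option<String>,

    /// Number of threads to run inference on.
    #[arg(long, default_value_t = num_cpus::get_physical())]
    pub num_threads: usize,

    /// Size of the context window, in tokens.
    #[arg(long, default_value_t = 512)]
    pub num_ctx_tokens: usize,

    // The sampling parameters (num_predict, batch_size, top_k, top_p,
    // repeat_last_n, repeat_penalty, temp) would follow the same pattern.
}

/// Parsed once at startup; main.rs dereferences this as &*CLI_ARGS.
pub static CLI_ARGS: Lazy<Args> = Lazy::new(Args::parse);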
10 changes: 3 additions & 7 deletions llama-rs/Cargo.toml
@@ -6,13 +6,9 @@ edition = "2021"
# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html

[dependencies]
anyhow = { version = "1.0.69", features = ["backtrace"] }
bimap = "0.6.2"
bytemuck = "1.13.1"
clap = { version = "4.1.8", features = ["derive"] }
ggml-raw = { path = "../ggml-raw" }
num_cpus = "1.15.0"
once_cell = "1.17.1"
partial_sort = "0.2.0"
rand = "0.8.5"
regex = "1.7.1"
thiserror = "1.0"

rand = { workspace = true }