-
Notifications
You must be signed in to change notification settings - Fork 24
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Merge pull request #35 from tracel-ai/llama
Llama
- Loading branch information
Showing
17 changed files
with
1,676 additions
and
8 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,50 @@ | ||
[package] | ||
authors = ["guillaumelagrange <lagrange.guillaume.1@gmail.com>"] | ||
license = "MIT OR Apache-2.0" | ||
name = "llama-burn" | ||
version = "0.1.0" | ||
edition = "2021" | ||
description = "Llama 3 large language model with Burn" | ||
|
||
[features] | ||
default = ["pretrained"] | ||
pretrained = ["burn/network", "dep:dirs"] | ||
|
||
llama3 = ["dep:tiktoken-rs", "dep:rustc-hash", "dep:base64"] | ||
tiny = ["dep:tokenizers"] | ||
|
||
# Example feature flags (backend selection) | ||
tch-cpu = ["burn/tch"] | ||
tch-gpu = ["burn/tch"] | ||
wgpu = ["burn/wgpu"] | ||
|
||
[dependencies] | ||
# Note: default-features = false is needed to disable std | ||
burn = { git = "https://github.com/tracel-ai/burn", rev = "a53f459f205889a22ecea3713bbae12d3de7eb0c", default-features = false } | ||
burn-import = { git = "https://github.com/tracel-ai/burn", rev = "a53f459f205889a22ecea3713bbae12d3de7eb0c" } | ||
itertools = { version = "0.12.1", default-features = false, features = [ | ||
"use_alloc", | ||
] } | ||
dirs = { version = "5.0.1", optional = true } | ||
serde = { version = "1.0.192", default-features = false, features = [ | ||
"derive", | ||
"alloc", | ||
] } # alloc is for no_std, derive is needed | ||
|
||
# Tiktoken tokenizer (llama 3) | ||
tiktoken-rs = { version = "0.5.8", optional = true } | ||
base64 = { version = "0.22.1", optional = true } | ||
rustc-hash = { version = "1.1.0", optional = true } | ||
|
||
# SentencePiece tokenizer (tiny llama / llama 2) | ||
tokenizers = { version = "0.19.1", default-features = false, features = [ | ||
"onig", | ||
], optional = true } | ||
|
||
rand = { version = "0.8.5", default-features = false, features = [ | ||
"std_rng", | ||
] } # std_rng is for no_std | ||
|
||
[dev-dependencies] | ||
burn = { git = "https://github.com/tracel-ai/burn", rev = "a53f459f205889a22ecea3713bbae12d3de7eb0c" } | ||
clap = { version = "4.5.4", features = ["derive"] } |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,14 @@ | ||
# NOTICES AND INFORMATION | ||
|
||
This file contains notices and information required by libraries that this repository copied or | ||
derived from. The use of the following resources complies with the licenses provided. | ||
|
||
## Implementation | ||
|
||
The model implementation was adapted from the original | ||
[Llama 3 implementation](https://github.com/meta-llama/llama3), which is distributed under the | ||
[Meta Llama 3 Community License Agreement](https://github.com/meta-llama/llama3/blob/main/LICENSE). | ||
|
||
The TinyLlama implementation is derived from the same code, but its weights and tokenizers were | ||
adapted from the [original implementation](https://github.com/jzhang38/TinyLlama) distributed under | ||
the [Apache 2.0](https://github.com/jzhang38/TinyLlama/blob/main/LICENSE) open source license. |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,114 @@ | ||
# Llama Burn | ||
|
||
<img src="./assets/llama-burn.jpeg" alt="An image of a llama surrounded by fiery colors and a gust of fire" width="500px"/> | ||
|
||
The popular Llama LLM is here! | ||
|
||
This repository contains the [Llama 3](https://github.com/meta-llama/llama3) and | ||
[TinyLlama](https://github.com/jzhang38/TinyLlama) implementations with their corresponding | ||
tokenizers. You can find the [Burn](https://github.com/tracel-ai/burn) implementation for the Llama | ||
variants in [src/llama.rs](src/llama.rs). | ||
|
||
## Usage | ||
|
||
### `Cargo.toml` | ||
|
||
Add this to your `Cargo.toml`: | ||
|
||
```toml | ||
[dependencies] | ||
llama-burn = { git = "https://github.com/tracel-ai/models", package = "llama-burn", default-features = false } | ||
``` | ||
|
||
If you want to use Llama 3 or TinyLlama (including pre-trained weights if default features are | ||
active), enable the corresponding feature flag. | ||
|
||
> **Important:** these features require `std`. Note that the weights have been saved in the binary | ||
> format, which is more compact and faster to save & load, but might not be compatible in future | ||
> versions if the Burn data schema were to evolve. | ||
#### Llama 3 | ||
|
||
```toml | ||
[dependencies] | ||
llama-burn = { git = "https://github.com/tracel-ai/models", package = "llama-burn", features = ["llama3"] } | ||
``` | ||
|
||
#### TinyLlama | ||
|
||
```toml | ||
[dependencies] | ||
llama-burn = { git = "https://github.com/tracel-ai/models", package = "llama-burn", features = ["tiny"] } | ||
``` | ||
|
||
### Example Usage | ||
|
||
The [chat completion example](examples/chat.rs) initializes a Llama model from the provided weights | ||
file and generates a sequence of text based on the input prompt. The instruction-tuned model is | ||
loaded for dialogue applications, so the prompt is automatically formatted for chat completion. | ||
|
||
The example can be executed on the `tch` backend (CUDA or CPU) or `wgpu`. | ||
|
||
| Argument | Description | | ||
| :-------------- | :------------------------------------------------------------------------------------------------------------- | | ||
| `-p` | The prompt or question to pass to the LLM (default: `"How many helicopters can a human eat in one sitting?"`). | | ||
| `-n` | The number of new tokens to generate (default: `50`). | | ||
| `--top-p` | Top-p probability threshold (default: `0.9`). | | ||
| `--temperature` | Temperature value for controlling randomness in sampling. (default: `0.6`). | | ||
| `--max-seq-len` | Maximum sequence length for input text. (default: `128`). | | ||
| `--seed` | The seed to use when generating random samples.. (default: `42`). | | ||
|
||
Any of the commands below can be used by appending any of the listed arguments by appending | ||
`[-- <arguments>]`. For example, you can provided your own prompt/question | ||
`-- -p "How many llamas does it take to change a lightbulb?"`. | ||
|
||
#### Llama 3 | ||
|
||
Using the `tch` backend with CUDA: | ||
|
||
```sh | ||
export TORCH_CUDA_VERSION=cu121 | ||
cargo run --release --features llama3,tch-gpu --example chat | ||
``` | ||
|
||
Using the `tch` backend with CPU: | ||
|
||
```sh | ||
cargo run --release --features llama3,tch-cpu --example chat | ||
``` | ||
|
||
Using the `wgpu` backend: | ||
|
||
```sh | ||
cargo run --release --features llama3,wgpu --example chat | ||
``` | ||
|
||
**Built with Meta Llama 3.** This example uses the | ||
[Meta-Llama-3-8B-Instruct](https://huggingface.co/meta-llama/Meta-Llama-3-8B-Instruct) | ||
instruction-tuned model. Note that the [base pre-trained Llama-3 model](./src/pretrained.rs#L77) is | ||
also available if you wish to use it in your application. | ||
|
||
#### TinyLlama | ||
|
||
Using the `tch` backend with CUDA: | ||
|
||
```sh | ||
export TORCH_CUDA_VERSION=cu121 | ||
cargo run --release --features tiny,tch-gpu --example chat | ||
``` | ||
|
||
Using the `tch` backend with CPU: | ||
|
||
```sh | ||
cargo run --release --features tiny,tch-cpu --example chat | ||
``` | ||
|
||
Using the `wgpu` backend: | ||
|
||
```sh | ||
cargo run --release --features tiny,wgpu --example chat | ||
``` | ||
|
||
This example uses the | ||
[TinyLlama-1.1B-Chat-v1.0](https://huggingface.co/TinyLlama/TinyLlama-1.1B-Chat-v1.0) | ||
instruction-tuned model based on the Llama2 architecture and tokenizer. |
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,169 @@ | ||
use std::time::Instant; | ||
|
||
use burn::tensor::{backend::Backend, Device}; | ||
use clap::Parser; | ||
use llama_burn::{ | ||
llama::{Llama, LlamaConfig}, | ||
sampling::{Sampler, TopP}, | ||
tokenizer::Tokenizer, | ||
}; | ||
|
||
const DEFAULT_PROMPT: &str = "How many helicopters can a human eat in one sitting?"; | ||
|
||
#[derive(Parser, Debug)] | ||
#[command(version, about, long_about = None)] | ||
pub struct Config { | ||
/// Top-p probability threshold. | ||
#[arg(long, default_value_t = 0.9)] | ||
top_p: f64, | ||
|
||
/// Temperature value for controlling randomness in sampling. | ||
#[arg(long, default_value_t = 0.6)] | ||
temperature: f64, | ||
|
||
/// Maximum sequence length for input text. | ||
#[arg(long, default_value_t = 128)] | ||
max_seq_len: usize, | ||
|
||
/// The number of new tokens to generate (i.e., the number of generation steps to take). | ||
#[arg(long, short = 'n', default_value_t = 50)] | ||
sample_len: usize, | ||
|
||
/// The seed to use when generating random samples. | ||
#[arg(long, default_value_t = 42)] | ||
seed: u64, | ||
|
||
/// The input prompt. | ||
#[arg(short, long, default_value_t = String::from(DEFAULT_PROMPT))] | ||
prompt: String, | ||
} | ||
|
||
pub fn generate<B: Backend, T: Tokenizer>( | ||
llama: &mut Llama<B, T>, | ||
prompt: &str, | ||
sample_len: usize, | ||
temperature: f64, | ||
sampler: &mut Sampler, | ||
) { | ||
let now = Instant::now(); | ||
let generated = llama.generate(prompt, sample_len, temperature, sampler); | ||
let elapsed = now.elapsed().as_secs(); | ||
|
||
println!("> {}\n", generated.text); | ||
println!( | ||
"{} tokens generated ({:.4} tokens/s)\n", | ||
generated.tokens, | ||
generated.tokens as f64 / generated.time | ||
); | ||
|
||
println!( | ||
"Generation completed in {}m{}s", | ||
(elapsed / 60), | ||
elapsed % 60 | ||
); | ||
} | ||
|
||
pub fn chat<B: Backend>(args: Config, device: Device<B>) { | ||
let mut prompt = args.prompt; | ||
|
||
// Sampling strategy | ||
let mut sampler = if args.temperature > 0.0 { | ||
Sampler::TopP(TopP::new(args.top_p, args.seed)) | ||
} else { | ||
Sampler::Argmax | ||
}; | ||
|
||
#[cfg(feature = "tiny")] | ||
{ | ||
// TinyLlama-1.1B Chat v1.0 | ||
let mut llama = LlamaConfig::tiny_llama_pretrained::<B>(&device).unwrap(); | ||
println!("Processing prompt: {}", prompt); | ||
|
||
// Prompt formatting for chat model | ||
prompt = format!( | ||
"<|system|>\nYou are a friendly chatbot who always responds in the style of a pirate</s>\n<|user|>\n{prompt}</s>\n<|assistant|>\n" | ||
); | ||
|
||
generate( | ||
&mut llama, | ||
&prompt, | ||
args.sample_len, | ||
args.temperature, | ||
&mut sampler, | ||
); | ||
} | ||
|
||
#[cfg(feature = "llama3")] | ||
{ | ||
// Llama-3-8B-Instruct | ||
let mut llama = LlamaConfig::llama3_8b_pretrained::<B>(true, &device).unwrap(); | ||
println!("Processing prompt: {}", prompt); | ||
|
||
// Prompt formatting for chat model | ||
prompt = format!( | ||
"<|start_header_id|>system<|end_header_id|>\n\nA chat between a curious user and an artificial intelligence assistant. The assistant gives helpful, detailed, and polite answers to the user's questions.<|eot_id|><|start_header_id|>user<|end_header_id|>\n\n{prompt}<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n" | ||
); | ||
|
||
generate( | ||
&mut llama, | ||
&prompt, | ||
args.sample_len, | ||
args.temperature, | ||
&mut sampler, | ||
); | ||
} | ||
} | ||
|
||
#[cfg(feature = "tch-gpu")] | ||
mod tch_gpu { | ||
use super::*; | ||
use burn::{ | ||
backend::{libtorch::LibTorchDevice, LibTorch}, | ||
tensor::f16, | ||
}; | ||
|
||
pub fn run(args: Config) { | ||
#[cfg(not(target_os = "macos"))] | ||
let device = LibTorchDevice::Cuda(0); | ||
#[cfg(target_os = "macos")] | ||
let device = LibTorchDevice::Mps; | ||
|
||
chat::<LibTorch<f16>>(args, device); | ||
} | ||
} | ||
|
||
#[cfg(feature = "tch-cpu")] | ||
mod tch_cpu { | ||
use super::*; | ||
use burn::backend::{libtorch::LibTorchDevice, LibTorch}; | ||
|
||
pub fn run(args: Config) { | ||
let device = LibTorchDevice::Cpu; | ||
|
||
chat::<LibTorch>(args, device); | ||
} | ||
} | ||
|
||
#[cfg(feature = "wgpu")] | ||
mod wgpu { | ||
use super::*; | ||
use burn::backend::wgpu::{Wgpu, WgpuDevice}; | ||
|
||
pub fn run(args: Config) { | ||
let device = WgpuDevice::default(); | ||
|
||
chat::<Wgpu>(args, device); | ||
} | ||
} | ||
|
||
pub fn main() { | ||
// Parse arguments | ||
let args = Config::parse(); | ||
|
||
#[cfg(feature = "tch-gpu")] | ||
tch_gpu::run(args); | ||
#[cfg(feature = "tch-cpu")] | ||
tch_cpu::run(args); | ||
#[cfg(feature = "wgpu")] | ||
wgpu::run(args); | ||
} |
Oops, something went wrong.