Commit
Merge pull request #510 from brittlewis12/context-and-model-enhancements
MarcusDunn authored Sep 28, 2024
2 parents 0ebae0b + 7d1b2d5 commit 1466f7e
Showing 5 changed files with 189 additions and 22 deletions.
4 changes: 2 additions & 2 deletions examples/simple/src/main.rs
@@ -247,15 +247,15 @@ either reduce n_len or increase n_ctx"
while n_cur <= n_len {
// sample the next token
{
- let candidates = ctx.candidates_ith(batch.n_tokens() - 1);
+ let candidates = ctx.candidates();

let candidates_p = LlamaTokenDataArray::from_iter(candidates, false);

// sample the most likely token
let new_token_id = ctx.sample_token_greedy(candidates_p);

// is it an end of stream?
- if new_token_id == model.token_eos() {
+ if model.is_eog_token(new_token_id) {
eprintln!();
break;
}
108 changes: 88 additions & 20 deletions llama-cpp-2/src/context/kv_cache.rs
@@ -2,7 +2,21 @@

use crate::context::LlamaContext;
use std::ffi::c_int;
- use std::num::NonZeroU8;
+ use std::num::{NonZeroU8, TryFromIntError};

/// Errors that can occur when attempting to prepare values for the kv cache
#[derive(Debug, Eq, PartialEq, thiserror::Error)]
pub enum KvCacheConversionError {
/// Sequence id conversion to i32 failed
#[error("Provided sequence id is too large for a i32")]
SeqIdTooLarge(#[source] TryFromIntError),
/// Position 0 conversion to i32 failed
#[error("Provided start position is too large for a i32")]
P0TooLarge(#[source] TryFromIntError),
/// Position 1 conversion to i32 failed
#[error("Provided end position is too large for a i32")]
P1TooLarge(#[source] TryFromIntError),
}

impl LlamaContext<'_> {
/// Copy the cache from one sequence to another.
@@ -18,33 +32,63 @@ impl LlamaContext<'_> {

/// Copy the cache from one sequence to another.
///
/// # Returns
/// A `Result` indicating whether the operation was successful. If either position exceeds
/// the maximum i32 value, no copy is attempted and an `Err` is returned.
///
/// # Parameters
///
/// * `src` - The sequence id to copy the cache from.
/// * `dest` - The sequence id to copy the cache to.
/// * `p0` - The start position of the cache to copy. If `None`, the entire cache is copied up to `p1`.
/// * `p1` - The end position of the cache to copy. If `None`, the entire cache is copied starting from `p0`.
pub fn copy_kv_cache_seq(&mut self, src: i32, dest: i32, p0: Option<u16>, p1: Option<u16>) {
let p0 = p0.map_or(-1, i32::from);
let p1 = p1.map_or(-1, i32::from);
pub fn copy_kv_cache_seq(
&mut self,
src: i32,
dest: i32,
p0: Option<u32>,
p1: Option<u32>,
) -> Result<(), KvCacheConversionError> {
let p0 = p0
.map_or(Ok(-1), i32::try_from)
.map_err(|e| KvCacheConversionError::P0TooLarge(e))?;
let p1 = p1
.map_or(Ok(-1), i32::try_from)
.map_err(|e| KvCacheConversionError::P1TooLarge(e))?;
unsafe {
llama_cpp_sys_2::llama_kv_cache_seq_cp(self.context.as_ptr(), src, dest, p0, p1);
}
Ok(())
}
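
Since the copy operation is now fallible, callers handle a `Result` instead of risking a silent truncation at the FFI boundary. A minimal sketch of the new API (the helper name `fork_sequence` is hypothetical, and the import path for `KvCacheConversionError` assumes the module layout shown in this diff):

```rust
use llama_cpp_2::context::kv_cache::KvCacheConversionError;
use llama_cpp_2::context::LlamaContext;

/// Fork sequence 0 into sequence 1 so both share the same cached prefix.
fn fork_sequence(ctx: &mut LlamaContext) -> Result<(), KvCacheConversionError> {
    // `None` for both positions copies the entire sequence.
    ctx.copy_kv_cache_seq(0, 1, None, None)?;

    // A position that cannot fit in an i32 is rejected before any FFI call.
    let err = ctx
        .copy_kv_cache_seq(0, 1, Some(u32::MAX), None)
        .unwrap_err();
    assert!(matches!(err, KvCacheConversionError::P0TooLarge(_)));
    Ok(())
}
```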

- /// Clear the kv cache for the given sequence.
+ /// Clear the kv cache for the given sequence within the specified range `[p0, p1)`.
+ /// Returns `false` only when partial sequence removals fail. Full sequence removals always succeed.
///
/// # Returns
/// A `Result` indicating whether the operation was successful. If the sequence id or
/// either position exceeds the maximum i32 value, no removal is attempted and an `Err` is returned.
///
/// # Parameters
///
- /// * `src` - The sequence id to clear the cache for.
+ /// * `src` - The sequence id to clear the cache for. If `None`, matches all sequences.
/// * `p0` - The start position of the cache to clear. If `None`, the entire cache is cleared up to `p1`.
/// * `p1` - The end position of the cache to clear. If `None`, the entire cache is cleared from `p0`.
pub fn clear_kv_cache_seq(&mut self, src: i32, p0: Option<u16>, p1: Option<u16>) {
let p0 = p0.map_or(-1, i32::from);
let p1 = p1.map_or(-1, i32::from);
unsafe {
llama_cpp_sys_2::llama_kv_cache_seq_rm(self.context.as_ptr(), src, p0, p1);
}
pub fn clear_kv_cache_seq(
&mut self,
src: Option<u32>,
p0: Option<u32>,
p1: Option<u32>,
) -> Result<bool, KvCacheConversionError> {
let src = src
.map_or(Ok(-1), i32::try_from)
.map_err(|e| KvCacheConversionError::SeqIdTooLarge(e))?;
let p0 = p0
.map_or(Ok(-1), i32::try_from)
.map_err(|e| KvCacheConversionError::P0TooLarge(e))?;
let p1 = p1
.map_or(Ok(-1), i32::try_from)
.map_err(|e| KvCacheConversionError::P1TooLarge(e))?;
Ok(unsafe { llama_cpp_sys_2::llama_kv_cache_seq_rm(self.context.as_ptr(), src, p0, p1) })
}
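
A companion sketch for the removal API; the helper name `trim_history` and the cutoff position `32` are illustrative:

```rust
use llama_cpp_2::context::kv_cache::KvCacheConversionError;
use llama_cpp_2::context::LlamaContext;

fn trim_history(ctx: &mut LlamaContext) -> Result<(), KvCacheConversionError> {
    // Remove positions [0, 32) from sequence 0. The returned bool is `false`
    // only if this partial removal could not be performed.
    let removed = ctx.clear_kv_cache_seq(Some(0), None, Some(32))?;
    debug_assert!(removed);

    // `None` for the sequence id matches every sequence; full removals
    // always succeed.
    ctx.clear_kv_cache_seq(None, None, None)?;
    Ok(())
}
```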

/// Returns the number of used KV cells (i.e. those that have at least one sequence assigned to them)
@@ -73,25 +117,44 @@ impl LlamaContext<'_> {
/// - lazily on next [`LlamaContext::decode`]
/// - explicitly with [`Self::kv_cache_update`]
///
/// # Returns
/// A `Result` indicating whether the operation was successful. If either position
/// exceeds the maximum i32 value, no update is attempted and an `Err` is returned.
///
/// # Parameters
///
/// * `seq_id` - The sequence id to update
/// * `p0` - The start position of the cache to update. If `None`, the entire cache is updated up to `p1`.
/// * `p1` - The end position of the cache to update. If `None`, the entire cache is updated starting from `p0`.
/// * `delta` - The relative position to add to the tokens
pub fn kv_cache_seq_add(&mut self, seq_id: i32, p0: Option<u16>, p1: Option<u16>, delta: i32) {
let p0 = p0.map_or(-1, i32::from);
let p1 = p1.map_or(-1, i32::from);
pub fn kv_cache_seq_add(
&mut self,
seq_id: i32,
p0: Option<u32>,
p1: Option<u32>,
delta: i32,
) -> Result<(), KvCacheConversionError> {
let p0 = p0
.map_or(Ok(-1), i32::try_from)
.map_err(|e| KvCacheConversionError::P0TooLarge(e))?;
let p1 = p1
.map_or(Ok(-1), i32::try_from)
.map_err(|e| KvCacheConversionError::P1TooLarge(e))?;
unsafe {
llama_cpp_sys_2::llama_kv_cache_seq_add(self.context.as_ptr(), seq_id, p0, p1, delta);
}
Ok(())
}
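
A typical use of `kv_cache_seq_add` is context shifting: evict the oldest tokens, then slide the survivors left so their positions stay contiguous. A sketch under that assumption (`shift_left`, `n_drop`, and `n_past` are hypothetical names):

```rust
use llama_cpp_2::context::kv_cache::KvCacheConversionError;
use llama_cpp_2::context::LlamaContext;

/// Drop the oldest `n_drop` cached tokens of sequence 0, then shift the
/// remaining entries down so positions stay contiguous.
fn shift_left(
    ctx: &mut LlamaContext,
    n_drop: u32,
    n_past: u32,
) -> Result<(), KvCacheConversionError> {
    // Evict positions [0, n_drop) ...
    ctx.clear_kv_cache_seq(Some(0), None, Some(n_drop))?;
    // ... then move positions [n_drop, n_past) left by n_drop.
    ctx.kv_cache_seq_add(0, Some(n_drop), Some(n_past), -(n_drop as i32))?;
    Ok(())
}
```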

/// Integer division of the positions by a factor of `d > 1`.
/// If the KV cache is `RoPEd`, the KV data is updated accordingly:
/// - lazily on next [`LlamaContext::decode`]
/// - explicitly with [`Self::kv_cache_update`]
///
/// # Returns
/// A `Result` indicating whether the operation was successful. If either position
/// exceeds the maximum i32 value, no update is attempted and an `Err` is returned.
///
/// # Parameters
///
/// * `seq_id` - The sequence id to update
@@ -101,14 +164,19 @@ impl LlamaContext<'_> {
pub fn kv_cache_seq_div(
&mut self,
seq_id: i32,
p0: Option<u16>,
p1: Option<u16>,
p0: Option<u32>,
p1: Option<u32>,
d: NonZeroU8,
- ) {
- let p0 = p0.map_or(-1, i32::from);
- let p1 = p1.map_or(-1, i32::from);
+ ) -> Result<(), KvCacheConversionError> {
+ let p0 = p0
+ .map_or(Ok(-1), i32::try_from)
+ .map_err(|e| KvCacheConversionError::P0TooLarge(e))?;
+ let p1 = p1
+ .map_or(Ok(-1), i32::try_from)
+ .map_err(|e| KvCacheConversionError::P1TooLarge(e))?;
let d = c_int::from(d.get());
unsafe { llama_cpp_sys_2::llama_kv_cache_seq_div(self.context.as_ptr(), seq_id, p0, p1, d) }
Ok(())
}
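
And a sketch of the division variant, in the spirit of self-extend-style position compression (names are again illustrative):

```rust
use std::num::NonZeroU8;

use llama_cpp_2::context::kv_cache::KvCacheConversionError;
use llama_cpp_2::context::LlamaContext;

/// Halve every position in `[0, n_past)` of sequence 0.
fn compress_positions(
    ctx: &mut LlamaContext,
    n_past: u32,
) -> Result<(), KvCacheConversionError> {
    let d = NonZeroU8::new(2).expect("2 is non-zero");
    ctx.kv_cache_seq_div(0, None, Some(n_past), d)
}
```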

/// Returns the largest position present in the KV cache for the specified sequence
60 changes: 60 additions & 0 deletions llama-cpp-2/src/context/params.rs
@@ -197,6 +197,66 @@ impl LlamaContextParams {
self.context_params.n_ubatch
}

/// Set the `flash_attention` parameter
///
/// # Examples
///
/// ```rust
/// use llama_cpp_2::context::params::LlamaContextParams;
/// let params = LlamaContextParams::default()
/// .with_flash_attention(true);
/// assert_eq!(params.flash_attention(), true);
/// ```
#[must_use]
pub fn with_flash_attention(mut self, enabled: bool) -> Self {
self.context_params.flash_attn = enabled;
self
}

/// Get the `flash_attention` parameter
///
/// # Examples
///
/// ```rust
/// use llama_cpp_2::context::params::LlamaContextParams;
/// let params = LlamaContextParams::default();
/// assert_eq!(params.flash_attention(), false);
/// ```
#[must_use]
pub fn flash_attention(&self) -> bool {
self.context_params.flash_attn
}

/// Set the `offload_kqv` parameter to control offloading KV cache & KQV ops to GPU
///
/// # Examples
///
/// ```rust
/// use llama_cpp_2::context::params::LlamaContextParams;
/// let params = LlamaContextParams::default()
/// .with_offload_kqv(false);
/// assert_eq!(params.offload_kqv(), false);
/// ```
#[must_use]
pub fn with_offload_kqv(mut self, enabled: bool) -> Self {
self.context_params.offload_kqv = enabled;
self
}

/// Get the `offload_kqv` parameter
///
/// # Examples
///
/// ```rust
/// use llama_cpp_2::context::params::LlamaContextParams;
/// let params = LlamaContextParams::default();
/// assert_eq!(params.offload_kqv(), true);
/// ```
#[must_use]
pub fn offload_kqv(&self) -> bool {
self.context_params.offload_kqv
}
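
The two new setters compose with the rest of the builder; a doctest-style sketch using only the methods shown in this diff:

```rust
use llama_cpp_2::context::params::LlamaContextParams;

let params = LlamaContextParams::default()
    .with_flash_attention(true)
    .with_offload_kqv(false);
assert!(params.flash_attention());
assert!(!params.offload_kqv());
```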

/// Set the type of rope scaling.
///
/// # Examples
6 changes: 6 additions & 0 deletions llama-cpp-2/src/model.rs
@@ -118,6 +118,12 @@ impl LlamaModel {
LlamaToken(token)
}

/// Check if a token represents the end of generation (end of turn, end of sequence, etc.)
#[must_use]
pub fn is_eog_token(&self, token: LlamaToken) -> bool {
unsafe { llama_cpp_sys_2::llama_token_is_eog(self.model.as_ptr(), token.0) }
}

/// Get the decoder start token.
#[must_use]
pub fn decode_start_token(&self) -> LlamaToken {
33 changes: 33 additions & 0 deletions llama-cpp-2/src/token/data_array.rs
@@ -374,4 +374,37 @@ impl LlamaTokenDataArray {
*mu = unsafe { *mu_ptr };
LlamaToken(token)
}

/// Mirostat 1.0 algorithm described in the [paper](https://arxiv.org/abs/2007.14966). Uses tokens instead of words.
///
/// # Parameters
///
/// * `tau` The target cross-entropy (or surprise) value you want to achieve for the generated text. A higher value corresponds to more surprising or less predictable text, while a lower value corresponds to less surprising or more predictable text.
/// * `eta` The learning rate used to update `mu` based on the error between the target and observed surprisal of the sampled word. A larger learning rate will cause `mu` to be updated more quickly, while a smaller learning rate will result in slower updates.
/// * `m` The number of tokens considered in the estimation of `s_hat`. This is an arbitrary value that is used to calculate `s_hat`, which in turn helps to calculate the value of `k`. In the paper, they use `m = 100`, but you can experiment with different values to see how it affects the performance of the algorithm.
/// * `mu` Maximum cross-entropy. This value is initialized to be twice the target cross-entropy (`2 * tau`) and is updated in the algorithm based on the error between the target and observed surprisal.
pub fn sample_token_mirostat_v1(
&mut self,
ctx: &mut LlamaContext,
tau: f32,
eta: f32,
m: i32,
mu: &mut f32,
) -> LlamaToken {
let mu_ptr = ptr::from_mut(mu);
let token = unsafe {
self.modify_as_c_llama_token_data_array(|c_llama_token_data_array| {
llama_cpp_sys_2::llama_sample_token_mirostat(
ctx.context.as_ptr(),
c_llama_token_data_array,
tau,
eta,
m,
mu_ptr,
)
})
};
*mu = unsafe { *mu_ptr };
LlamaToken(token)
}
}
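
Putting the new binding to work: a sketch of one sampling step, with `mu` seeded at `2 * tau` as the doc comment above prescribes. `sample_step` is a hypothetical helper, the import paths are assumed from the crate's file layout, and `ctx.candidates()` is the accessor used in the updated simple example:

```rust
use llama_cpp_2::context::LlamaContext;
use llama_cpp_2::token::data_array::LlamaTokenDataArray;
use llama_cpp_2::token::LlamaToken;

/// One Mirostat 1.0 sampling step. Call once per decoded position, passing
/// the same `mu` every time; initialize it with `2.0 * tau` before the first call.
fn sample_step(ctx: &mut LlamaContext, tau: f32, eta: f32, mu: &mut f32) -> LlamaToken {
    // Rebuild the candidate array from the logits of the latest decode.
    let mut candidates = LlamaTokenDataArray::from_iter(ctx.candidates(), false);
    // m = 100 matches the value used in the Mirostat paper.
    candidates.sample_token_mirostat_v1(ctx, tau, eta, 100, mu)
}
```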
