2 changes: 1 addition & 1 deletion LLama.Examples/Examples/BatchedExecutorGuidance.cs
@@ -79,7 +79,7 @@ await AnsiConsole
guidance.Prompt(g);

// Early exit if we reach the natural end of the guided sentence
-                if (g == model.Tokens.EOS)
+                if (model.Tokens.IsEndOfGeneration(g))
break;

// Update progress bar
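Note on this change: end-of-stream detection now goes through `IsEndOfGeneration` instead of comparing against `EOS` directly, since chat-tuned models (e.g. Llama 3) can terminate with a dedicated end-of-turn token rather than the classic EOS. A minimal sketch of what such a check can look like — the nullable `eos`/`eot` parameters are illustrative assumptions, not the exact LLamaSharp internals:

```csharp
using LLama.Native;

// Hedged sketch: a token ends generation if it matches any terminal token
// the model's vocabulary defines (EOS, end-of-turn, ...). The parameter
// shape here is assumed for illustration.
static bool IsEndOfGenerationSketch(LLamaToken token, LLamaToken? eos, LLamaToken? eot)
{
    return (eos.HasValue && token == eos.Value)
        || (eot.HasValue && token == eot.Value);
}
```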
9 changes: 7 additions & 2 deletions LLama.Web/Common/ModelOptions.cs
@@ -29,12 +29,14 @@ public class ModelOptions
/// <inheritdoc />
public int GpuLayerCount { get; set; } = 20;

-    public uint SeqMax { get; }
+    /// <inheritdoc />
+    public uint SeqMax { get; set; }

/// <inheritdoc />
public uint Seed { get; set; } = 1686349486;

-    public bool Embeddings { get; }
+    /// <inheritdoc />
+    public bool Embeddings { get; set; }

/// <inheritdoc />
public bool UseMemorymap { get; set; } = true;
@@ -102,6 +104,9 @@ public class ModelOptions
/// <inheritdoc />
public bool NoKqvOffload { get; set; }

/// <inheritdoc />
public bool FlashAttention { get; set; }

/// <inheritdoc />
public Encoding Encoding { get; set; } = Encoding.UTF8;

5 changes: 5 additions & 0 deletions LLama/Abstractions/IContextParams.cs
@@ -108,6 +108,11 @@ public interface IContextParams
/// </summary>
bool NoKqvOffload { get; }

/// <summary>
/// Whether to use flash attention
/// </summary>
bool FlashAttention { get; }

/// <summary>
/// defragment the KV cache if holes/size > defrag_threshold, set to < 0 to disable (default)
/// </summary>
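The new `FlashAttention` flag is plumbed from `IContextParams`/`ModelOptions`/`ModelParams` down to the native context parameters. A usage sketch, assuming the usual LLamaSharp loading pattern (the model path is a placeholder):

```csharp
using LLama;
using LLama.Common;

// Sketch: request flash attention when creating a context. The flag is
// forwarded verbatim to llama.cpp; whether it helps depends on the
// backend build and hardware.
var parameters = new ModelParams(@"path/to/model.gguf") // placeholder path
{
    FlashAttention = true,
    GpuLayerCount = 20,
};

using var weights = LLamaWeights.LoadFromFile(parameters);
using var context = weights.CreateContext(parameters);
```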
24 changes: 24 additions & 0 deletions LLama/Abstractions/IModelParams.cs
@@ -4,6 +4,7 @@
using System.Collections.Generic;
using System.ComponentModel;
using System.Linq;
using System.Text;
using System.Text.Json;
using System.Text.Json.Serialization;
using LLama.Native;
@@ -241,6 +242,7 @@ public sealed record MetadataOverride
private readonly int _valueInt;
private readonly float _valueFloat;
private readonly bool _valueBool;
private readonly byte[]? _valueString;

/// <summary>
/// Create a new override for an int key
Expand Down Expand Up @@ -278,6 +280,21 @@ public MetadataOverride(string key, bool value)
Type = LLamaModelKvOverrideType.Bool;
}

/// <summary>
/// Create a new override for a string key
/// </summary>
/// <param name="key"></param>
/// <param name="value"></param>
public MetadataOverride(string key, string value)
{
Key = key;
_valueString = Encoding.UTF8.GetBytes(value);
Type = LLamaModelKvOverrideType.String;

if (_valueString.Length >= 128)
    throw new ArgumentException("Value string is too long, must be < 128 UTF8 bytes", nameof(value));
}

internal void WriteValue(ref LLamaModelMetadataOverride dest)
{
switch (Type)
@@ -291,6 +308,13 @@ internal void WriteValue(ref LLamaModelMetadataOverride dest)
case LLamaModelKvOverrideType.Bool:
dest.BoolValue = _valueBool ? -1L : 0;
break;
case LLamaModelKvOverrideType.String:
unsafe
{
fixed (byte* strValPtr = dest.StringValue)
new Span<byte>(_valueString!).CopyTo(new Span<byte>(strValPtr, 128));
}
break;
default:
throw new InvalidEnumArgumentException($"Unknown {nameof(LLamaModelKvOverrideType)} value: {Type}");
}
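With the string overload, GGUF metadata can now be overridden with all four native value types; the value is converted to UTF-8 up front so the length check happens at construction rather than at load time. A hedged usage sketch (key and value are illustrative):

```csharp
using LLama.Abstractions;
using LLama.Common;

// Sketch: override a string-typed metadata key at load time. The key shown
// is illustrative; any string value must fit the native 128-byte buffer.
var parameters = new ModelParams(@"path/to/model.gguf"); // placeholder path
parameters.MetadataOverrides.Add(new MetadataOverride("tokenizer.ggml.pre", "llama3"));
```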
3 changes: 3 additions & 0 deletions LLama/Common/ModelParams.cs
@@ -99,6 +99,9 @@ public record ModelParams
/// <inheritdoc />
public bool NoKqvOffload { get; set; }

/// <inheritdoc />
public bool FlashAttention { get; set; }

/// <inheritdoc />
public float DefragThreshold { get; set; }

1 change: 1 addition & 0 deletions LLama/Extensions/IContextParamsExtensions.cs
@@ -50,6 +50,7 @@ public static void ToLlamaContextParams(this IContextParams @params, out LLamaCo
result.type_k = @params.TypeK ?? GGMLType.GGML_TYPE_F16;
result.type_v = @params.TypeV ?? GGMLType.GGML_TYPE_F16;
result.offload_kqv = !@params.NoKqvOffload;
result.flash_attention = @params.FlashAttention;
result.llama_pooling_type = @params.PoolingType;

result.n_threads = Threads(@params.Threads);
5 changes: 3 additions & 2 deletions LLama/LLamaStatelessExecutor.cs
@@ -4,6 +4,7 @@
using System.Collections.Generic;
using System.Linq;
using System.Runtime.CompilerServices;
using System.Text;
using System.Threading;
using LLama.Exceptions;
using LLama.Native;
@@ -123,8 +124,8 @@ public async IAsyncEnumerable<string> InferAsync(string prompt, IInferenceParams
);
}

-                // Check if this is the EOS token
-                if (id == _weights.Tokens.EOS)
+                // Check if this token should end generation
+                if (_weights.Tokens.IsEndOfGeneration(id))
break;

// Decode this token into text
10 changes: 10 additions & 0 deletions LLama/Native/LLamaContextParams.cs
@@ -151,6 +151,16 @@ public bool offload_kqv
}
private sbyte _offload_kqv;

/// <summary>
/// whether to use flash attention
/// </summary>
public bool flash_attention
{
readonly get => Convert.ToBoolean(_flash_attention);
set => _flash_attention = Convert.ToSByte(value);
}
private sbyte _flash_attention;

//todo: implement abort callback support
/// <summary>
/// ggml_abort_callback
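`flash_attention` uses the same sbyte-backed pattern as `offload_kqv` above: storing the flag as `sbyte` keeps the struct blittable and exactly one byte wide at that offset, matching C's `bool`, whereas the default marshalling of a C# `bool` field is the 4-byte Win32 `BOOL`. The pattern in isolation:

```csharp
using System;

// Sketch of the sbyte-backed bool pattern used by these native structs:
// the raw byte matches C's one-byte bool, and the property gives callers
// an ordinary C# bool.
struct NativeFlagSketch
{
    public bool some_flag
    {
        readonly get => Convert.ToBoolean(_some_flag);
        set => _some_flag = Convert.ToSByte(value);
    }
    private sbyte _some_flag;
}
```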
5 changes: 5 additions & 0 deletions LLama/Native/LLamaFtype.cs
@@ -171,6 +171,11 @@ public enum LLamaFtype
/// </summary>
LLAMA_FTYPE_MOSTLY_IQ1_M = 31,

/// <summary>
/// except 1d tensors
/// </summary>
LLAMA_FTYPE_MOSTLY_BF16 = 32,

/// <summary>
/// File type was not specified
/// </summary>
11 changes: 11 additions & 0 deletions LLama/Native/LLamaModelMetadataOverride.cs
@@ -43,6 +43,12 @@ public unsafe struct LLamaModelMetadataOverride
/// </summary>
[FieldOffset(136)]
public long BoolValue;

/// <summary>
/// Value, **must** only be used if Tag == String
/// </summary>
[FieldOffset(136)]
public fixed byte StringValue[128];
}

/// <summary>
@@ -65,4 +71,9 @@ public enum LLamaModelKvOverrideType
/// Overriding a bool value
/// </summary>
Bool = 2,

/// <summary>
/// Overriding a string value
/// </summary>
String = 3,
}
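`StringValue` overlays the int/float/bool fields at offset 136, mirroring the value union in llama.cpp's `llama_model_kv_override`; the 128-byte size must stay in sync with the native `char val_str[128]`. Reading such a buffer back means treating it as null-terminated UTF-8, as in this hedged helper (assuming a zero terminator was written):

```csharp
using System;
using System.Text;

// Sketch: decode a null-terminated UTF-8 string from a fixed native buffer,
// e.g. LLamaModelMetadataOverride.StringValue. Assumes a zero terminator.
static unsafe string ReadFixedUtf8(byte* buffer, int capacity)
{
    var span = new ReadOnlySpan<byte>(buffer, capacity);
    var end = span.IndexOf((byte)0);   // locate the terminator
    if (end >= 0)
        span = span.Slice(0, end);     // trim everything after it
    return Encoding.UTF8.GetString(span);
}
```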
10 changes: 10 additions & 0 deletions LLama/Native/LLamaModelParams.cs
@@ -81,6 +81,16 @@ public bool use_mlock
}
private sbyte _use_mlock;

/// <summary>
/// validate model tensor data
/// </summary>
public bool check_tensors
{
readonly get => Convert.ToBoolean(_check_tensors);
set => _check_tensors = Convert.ToSByte(value);
}
private sbyte _check_tensors;

/// <summary>
/// Create a LLamaModelParams with default values
/// </summary>
10 changes: 10 additions & 0 deletions LLama/Native/LLamaModelQuantizeParams.cs
@@ -70,6 +70,16 @@ public bool pure
}
private sbyte _pure;

/// <summary>
/// quantize to the same number of shards
/// </summary>
public bool keep_split
{
get => Convert.ToBoolean(_keep_split);
set => _keep_split = Convert.ToSByte(value);
}
private sbyte _keep_split;

/// <summary>
/// pointer to importance matrix data
/// </summary>
17 changes: 17 additions & 0 deletions LLama/Native/LLamaVocabPreType.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@
namespace LLama.Native;

/// <summary>
///
/// </summary>
/// <remarks>llama_vocab_pre_type</remarks>
internal enum LLamaVocabPreType
{
LLAMA_VOCAB_PRE_TYPE_DEFAULT = 0,
LLAMA_VOCAB_PRE_TYPE_LLAMA3 = 1,
LLAMA_VOCAB_PRE_TYPE_DEEPSEEK_LLM = 2,
LLAMA_VOCAB_PRE_TYPE_DEEPSEEK_CODER = 3,
LLAMA_VOCAB_PRE_TYPE_FALCON = 4,
LLAMA_VOCAB_PRE_TYPE_MPT = 5,
LLAMA_VOCAB_PRE_TYPE_STARCODER = 6,
LLAMA_VOCAB_PRE_TYPE_GPT2 = 7,
}
5 changes: 3 additions & 2 deletions LLama/Native/NativeApi.LLava.cs
@@ -13,6 +13,7 @@ public static unsafe partial class NativeApi
/// <param name="ctxClip">Llava Model</param>
/// <returns>True if validated successfully</returns>
[DllImport(llavaLibraryName, EntryPoint = "llava_validate_embed_size", CallingConvention = CallingConvention.Cdecl)]
[return: MarshalAs(UnmanagedType.U1)]
public static extern bool llava_validate_embed_size( SafeLLamaContextHandle ctxLlama, SafeLlavaModelHandle ctxClip);

/// <summary>
@@ -56,7 +57,7 @@ SafeLlavaImageEmbedHandle llava_image_embed_make_with_filename(SafeLlavaModelHan
/// <param name="embed">Embedding handle</param>
/// <returns>True on success</returns>
[DllImport(llavaLibraryName, EntryPoint = "llava_eval_image_embed", CallingConvention = CallingConvention.Cdecl)]
-    public static extern bool llava_eval_image_embed(SafeLLamaContextHandle ctx_llama, SafeLlavaImageEmbedHandle embed,
-        int n_batch, ref int n_past);
+    [return: MarshalAs(UnmanagedType.U1)]
+    public static extern bool llava_eval_image_embed(SafeLLamaContextHandle ctx_llama, SafeLlavaImageEmbedHandle embed, int n_batch, ref int n_past);

}
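The `[return: MarshalAs(UnmanagedType.U1)]` attributes appearing throughout this PR fix a size mismatch: llama.cpp returns a one-byte C `bool`, while .NET's default marshalling for a returned `bool` is the 4-byte Win32 `BOOL`, which can pick up garbage bytes on some ABIs. The pattern, shown against a hypothetical native export:

```csharp
using System.Runtime.InteropServices;

internal static class NativeBoolSketch
{
    // Hypothetical import for illustration: U1 tells the marshaller the
    // native function returns a single byte (C99/C++ bool), not the
    // 4-byte Win32 BOOL that is the default for bool return values.
    [DllImport("example_native", CallingConvention = CallingConvention.Cdecl)]
    [return: MarshalAs(UnmanagedType.U1)]
    public static extern bool example_supports_feature();
}
```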
2 changes: 1 addition & 1 deletion LLama/Native/NativeApi.Sampling.cs
@@ -176,7 +176,7 @@ public static void llama_sample_apply_guidance(SafeLLamaContextHandle ctx, Span<
public static extern LLamaToken llama_sample_token_greedy(SafeLLamaContextHandle ctx, ref LLamaTokenDataArrayNative candidates);

/// <summary>
-    /// Randomly selects a token from the candidates based on their probabilities.
+    /// Randomly selects a token from the candidates based on their probabilities using the RNG of ctx.
/// </summary>
/// <param name="ctx"></param>
/// <param name="candidates">Pointer to LLamaTokenDataArray</param>
23 changes: 19 additions & 4 deletions LLama/Native/NativeApi.cs
@@ -34,20 +34,23 @@ public static void llama_empty_call()
/// </summary>
/// <returns></returns>
[DllImport(libraryName, CallingConvention = CallingConvention.Cdecl)]
[return: MarshalAs(UnmanagedType.U1)]
public static extern bool llama_supports_mmap();

/// <summary>
/// Check if memory locking is supported
/// </summary>
/// <returns></returns>
[DllImport(libraryName, CallingConvention = CallingConvention.Cdecl)]
[return: MarshalAs(UnmanagedType.U1)]
public static extern bool llama_supports_mlock();

/// <summary>
/// Check if GPU offload is supported
/// </summary>
/// <returns></returns>
[DllImport(libraryName, CallingConvention = CallingConvention.Cdecl)]
[return: MarshalAs(UnmanagedType.U1)]
public static extern bool llama_supports_gpu_offload();

/// <summary>
@@ -77,6 +80,7 @@ public static void llama_empty_call()
/// <param name="n_token_count_out"></param>
/// <returns></returns>
[DllImport(libraryName, CallingConvention = CallingConvention.Cdecl)]
[return: MarshalAs(UnmanagedType.U1)]
public static extern bool llama_state_load_file(SafeLLamaContextHandle ctx, string path_session, LLamaToken[] tokens_out, ulong n_token_capacity, out ulong n_token_count_out);

/// <summary>
@@ -88,6 +92,7 @@ public static void llama_empty_call()
/// <param name="n_token_count"></param>
/// <returns></returns>
[DllImport(libraryName, CallingConvention = CallingConvention.Cdecl)]
[return: MarshalAs(UnmanagedType.U1)]
public static extern bool llama_state_save_file(SafeLLamaContextHandle ctx, string path_session, LLamaToken[] tokens, ulong n_token_count);

[DllImport(libraryName, CallingConvention = CallingConvention.Cdecl)]
@@ -133,6 +138,14 @@ public static void llama_empty_call()
[DllImport(libraryName, CallingConvention = CallingConvention.Cdecl)]
public static extern uint llama_n_seq_max(SafeLLamaContextHandle ctx);

/// <summary>
/// Get the pooling type for this context
/// </summary>
/// <param name="ctx"></param>
/// <returns></returns>
[DllImport(libraryName, CallingConvention = CallingConvention.Cdecl)]
public static extern LLamaPoolingType llama_pooling_type(SafeLLamaContextHandle ctx);

/// <summary>
/// Get the embeddings for a specific sequence.
/// Equivalent to: llama_get_embeddings(ctx) + ctx->output_ids[i]*n_embd
@@ -218,19 +231,20 @@ public static void llama_empty_call()
/// <param name="model"></param>
/// <param name="llamaToken"></param>
/// <param name="buffer">buffer to write string into</param>
/// <param name="special">If true, special tokens are rendered in the output</param>
/// <returns>The length written, or, if the buffer is too small, a negative value indicating the required length</returns>
-    public static int llama_token_to_piece(SafeLlamaModelHandle model, LLamaToken llamaToken, Span<byte> buffer)
+    public static int llama_token_to_piece(SafeLlamaModelHandle model, LLamaToken llamaToken, Span<byte> buffer, bool special)
{
unsafe
{
fixed (byte* bufferPtr = buffer)
{
-                return llama_token_to_piece_native(model, llamaToken, bufferPtr, buffer.Length);
+                return llama_token_to_piece_native(model, llamaToken, bufferPtr, buffer.Length, special);
}
}

[DllImport(libraryName, CallingConvention = CallingConvention.Cdecl, EntryPoint = "llama_token_to_piece")]
-        static extern unsafe int llama_token_to_piece_native(SafeLlamaModelHandle model, LLamaToken llamaToken, byte* buffer, int length);
+        static extern unsafe int llama_token_to_piece_native(SafeLlamaModelHandle model, LLamaToken llamaToken, byte* buffer, int length, bool special);
}
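Per the updated doc comment, a negative return value means the buffer was too small, and its magnitude is the required length; `special` controls whether special tokens are rendered as text. A hedged usage sketch (the helper name is mine, not part of the library):

```csharp
using System;
using System.Text;
using LLama.Native;

// Sketch: decode one token to text, growing the buffer when the first call
// reports (as a negative value) how many bytes are actually needed.
static string TokenToTextSketch(SafeLlamaModelHandle model, LLamaToken token)
{
    Span<byte> buffer = stackalloc byte[32];
    var written = NativeApi.llama_token_to_piece(model, token, buffer, special: false);
    if (written < 0)
    {
        buffer = new byte[-written];   // negative result = required length
        written = NativeApi.llama_token_to_piece(model, token, buffer, special: false);
    }
    return Encoding.UTF8.GetString(buffer.Slice(0, written));
}
```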

/// <summary>
@@ -260,7 +274,7 @@ public static void llama_log_set(NativeLogConfig.LLamaLogCallback logCallback)
}

/// <summary>
-    /// Clear the KV cache
+    /// Clear the KV cache: both cell info is erased and KV data is zeroed
/// </summary>
/// <param name="ctx"></param>
[DllImport(libraryName, CallingConvention = CallingConvention.Cdecl)]
@@ -275,6 +289,7 @@ public static void llama_log_set(NativeLogConfig.LLamaLogCallback logCallback)
/// <param name="p1"></param>
/// <returns>Returns false if a partial sequence cannot be removed. Removing a whole sequence never fails</returns>
[DllImport(libraryName, CallingConvention = CallingConvention.Cdecl)]
[return: MarshalAs(UnmanagedType.U1)]
public static extern bool llama_kv_cache_seq_rm(SafeLLamaContextHandle ctx, LLamaSeqId seq, LLamaPos p0, LLamaPos p1);
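A short usage sketch for the removal call; per the doc comment, removing a partial range can fail and return false, while removing a whole sequence (negative bounds) always succeeds. The int-to-`LLamaPos`/`LLamaSeqId` conversion syntax below is assumed for illustration:

```csharp
using LLama.Native;

// Sketch: drop positions [5, 10) from sequence 0; if the backend cannot
// remove a partial range, fall back to clearing the whole sequence.
static void TrimSequenceSketch(SafeLLamaContextHandle ctx)
{
    var seq = (LLamaSeqId)0;
    if (!NativeApi.llama_kv_cache_seq_rm(ctx, seq, (LLamaPos)5, (LLamaPos)10))
    {
        // Negative p0/p1 mean "entire sequence", which never fails.
        NativeApi.llama_kv_cache_seq_rm(ctx, seq, (LLamaPos)(-1), (LLamaPos)(-1));
    }
}
```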
