Commit d4bed4b

First pass at updating binaries. This works, but has safety issues with SafeLLamaSamplerChainHandle/SafeLLamaSamplerHandle

1 parent c9a9b75 commit d4bed4b

File tree

75 files changed: +958 −2644 lines changed

LLama.Examples/ExampleRunner.cs

Lines changed: 0 additions & 1 deletion

@@ -31,7 +31,6 @@ public class ExampleRunner
         { "Batched Executor: Save/Load", BatchedExecutorSaveAndLoad.Run },
         { "Batched Executor: Fork", BatchedExecutorFork.Run },
         { "Batched Executor: Rewind", BatchedExecutorRewind.Run },
-        { "Batched Executor: Guidance", BatchedExecutorGuidance.Run },
         { "Batched Executor: LLava", BatchedExecutorLLava.Run },
         { "Batched Executor: BoolQ Benchmark", BatchedExecutorBoolQ.Run },
         { "Batched Executor: Beam Search", BatchedExecutorBeamSearch.Run },

LLama.Examples/Examples/BatchedExecutorBeamSearch.cs

Lines changed: 7 additions & 5 deletions

@@ -55,7 +55,9 @@ from beam in oldBeam.Sample(beamsCount)
         while (beams.Count > beamsCount)
         {
             var beam = beams[0];
-            AnsiConsole.MarkupLineInterpolated($"[red]Culling Beam {beam.Conversation.ConversationId} (prob:{beam.CumulativeProbability:P10})[/]: {beam}");
+
+            var text = beam.ToString().EscapeMarkup();
+            AnsiConsole.MarkupLine($"[red]Culling Beam {beam.Conversation.ConversationId} (prob:{beam.CumulativeProbability:P5})[/]: {text}");

             beam.Dispose();
             beams.RemoveAt(0);
@@ -121,7 +123,7 @@ public List<Beam> Sample(int nbeams)
         {
             // Apply softmax, this calculates probabilities and sorts tokens into descending order
             var logitsArr = LLamaTokenDataArray.Create(Conversation.Sample());
-            logitsArr.Softmax(Conversation.Executor.Context.NativeHandle);
+            logitsArr.Softmax();

             // Create new forked conversations, one for each beam
             var results = new List<Beam>();
@@ -135,14 +137,14 @@ public List<Beam> Sample(int nbeams)
                 var c = Conversation.Fork();

                 // Extend the conversation with the selected token.
-                c.Prompt(item.id);
+                c.Prompt(item.ID);

                 // Keep track of the cumulative probability of this entire sequence.
-                var p = CumulativeProbability * item.p;
+                var p = CumulativeProbability * item.Probability;

                 // Keep track of all tokens in this sequence, for decoding later
                 var t = Tokens.ToList();
-                t.Add(item.id);
+                t.Add(item.ID);

                 // Keep track of which beam this beam was derived from.
                 var s = Sequence.ToList();

LLama.Examples/Examples/BatchedExecutorBoolQ.cs

Lines changed: 30 additions & 14 deletions

@@ -1,7 +1,6 @@
 using System.Text;
 using LLama.Batched;
 using LLama.Common;
-using LLama.Grammars;
 using LLama.Native;
 using Spectre.Console;
 using LLama.Sampling;
@@ -10,6 +9,9 @@ namespace LLama.Examples.Examples;

 public class BatchedExecutorBoolQ
 {
+    // Answers may start with a space, and then must produce one of the listed strings followed by a newline character and nothing else.
+    private static readonly Grammar AnswerGrammar = new("root ::= (\" \")? (\"true\" | \"false\" | \"yes\" | \"no\") \"\\n\"", "root");
+
     public static async Task Run()
     {
         // Load model weights
@@ -21,9 +23,6 @@ public static async Task Run()
         var sys = AnsiConsole.Ask("System prompt", "Answer the question with a single word answer.");
         var hint = AnsiConsole.Ask("Provide hints to model (test reading comprehension instead of knowledge)", true);

-        // Answers may start with a space, and then must produce one of the listed strings followed by a newline character and nothing else.
-        var grammar = Grammar.Parse("root ::= (\" \")? (\"true\" | \"false\" | \"yes\" | \"no\") \"\\n\"", "root");
-
         // Create an executor that can evaluate a batch of conversations together
         using var executor = new BatchedExecutor(model, parameters);

@@ -53,7 +52,7 @@ await AnsiConsole.Progress()

         foreach (var chunk in chunks)
         {
-            var result = await RunBatch(executor, tokensGenerate, grammar, sys, hint, chunk);
+            var result = await RunBatch(executor, tokensGenerate, sys, hint, chunk);
             results.Add(result);

             reporter.Increment(1);
@@ -87,10 +86,10 @@ await AnsiConsole.Progress()
         }
     }

-    private static async Task<BatchResult> RunBatch(BatchedExecutor executor, int maxTokens, Grammar grammar, string sys, bool hint, IEnumerable<(string, bool, string)> batch)
+    private static async Task<BatchResult> RunBatch(BatchedExecutor executor, int maxTokens, string sys, bool hint, IEnumerable<(string, bool, string)> batch)
     {
         var conversations = (from item in batch
-                             select new ConversationRunner(executor, grammar, sys, item.Item1, item.Item2, hint ? item.Item3 : null)).ToArray();
+                             select new ConversationRunner(executor, sys, item.Item1, item.Item2, hint ? item.Item3 : null)).ToArray();

         for (var i = 0; i < maxTokens; i++)
         {
@@ -135,6 +134,9 @@ private record BatchResult(int TruePositive, int TrueNegative, int FalsePositive
         public float Accuracy => (float)Correct / Total;
     }

+    /// <summary>
+    /// All of the mechanics necessary to run a conversation to answer a single question
+    /// </summary>
     private class ConversationRunner
         : IDisposable
     {
@@ -149,14 +151,11 @@ private class ConversationRunner
         public string Question { get; }
         public bool Answer { get; }

-        public ConversationRunner(BatchedExecutor executor, Grammar grammar, string sys, string question, bool answer, string? hint)
+        public ConversationRunner(BatchedExecutor executor, string sys, string question, bool answer, string? hint)
         {
             _executor = executor;
             _decoder = new StreamingTokenDecoder(executor.Context);
-            _sampler = new GreedySamplingPipeline
-            {
-                Grammar = grammar.CreateInstance(),
-            };
+            _sampler = new GreedySamplingWithGrammarPipeline { Grammar = AnswerGrammar };

             // Make sure question ends with question mark
             if (!question.EndsWith('?'))
@@ -192,7 +191,7 @@ public void Sample()
             if (!_conversation.RequiresSampling)
                 return;

-            var token = _sampler.Sample(_executor.Context.NativeHandle, _conversation.Sample(), []);
+            var token = _sampler.Sample(_executor.Context, _conversation.GetSampleIndex());

             var tokens = _executor.Context.NativeHandle.ModelHandle.Tokens;
             if (tokens.IsEndOfGeneration(token) || tokens.Newline == token)
@@ -216,7 +215,7 @@ public void Prompt()
             var token = _sampledToken.Value;
             _sampledToken = default;

-            _sampler.Accept(_executor.Context.NativeHandle, token);
+            _sampler.Accept(token);
             _decoder.Add(token);
             _conversation.Prompt(token);
         }
@@ -245,4 +244,21 @@ public void Dispose()
             _sampler.Dispose();
         }
     }
+
+    /// <summary>
+    /// A sampling pipeline which always selects the most likely token (after applying a grammar)
+    /// </summary>
+    public class GreedySamplingWithGrammarPipeline
+        : BaseSamplingPipeline
+    {
+        public required Grammar Grammar { get; init; }
+
+        protected override SafeLLamaSamplerChainHandle CreateChain(SafeLLamaContextHandle context)
+        {
+            var chain = SafeLLamaSamplerHandle.CreateChain(LLamaSamplerChainParams.Default());
+            chain.Add(SafeLLamaSamplerHandle.CreateGrammar(context.ModelHandle, Grammar.Gbnf, Grammar.Root));
+            chain.Add(SafeLLamaSamplerHandle.CreateGreedySampler());
+            return chain;
+        }
+    }
 }
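GreedySamplingWithGrammarPipeline is this commit's replacement for the old GreedySamplingPipeline plus per-runner grammar instance: a BaseSamplingPipeline subclass that assembles a native sampler chain, grammar constraint first, then greedy selection. The commit message flags safety issues with SafeLLamaSamplerChainHandle/SafeLLamaSamplerHandle, plausibly around handle ownership once a sampler is Added into a chain, though the diff does not say. A usage sketch assembled only from calls that appear in this file's hunks; `executor` and `conversation` are assumed to be a live BatchedExecutor and Conversation:

```csharp
// Sketch: drive one constrained-generation step with the pipeline defined above.
using var sampler = new GreedySamplingWithGrammarPipeline
{
    // Grammar restricts output to an optional space, then one of the listed words, then a newline
    Grammar = new Grammar("root ::= (\" \")? (\"true\" | \"false\" | \"yes\" | \"no\") \"\\n\"", "root"),
};

await executor.Infer();

// Sample at this conversation's logit index, then feed the token back in
var token = sampler.Sample(executor.Context, conversation.GetSampleIndex());
sampler.Accept(token);
conversation.Prompt(token);
```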

LLama.Examples/Examples/BatchedExecutorFork.cs

Lines changed: 1 addition & 5 deletions

@@ -1,6 +1,5 @@
 using LLama.Batched;
 using LLama.Common;
-using LLama.Native;
 using LLama.Sampling;
 using Spectre.Console;

@@ -77,9 +76,7 @@ await AnsiConsole
     // Print some stats
     var timings = executor.Context.NativeHandle.GetTimings();
     AnsiConsole.MarkupLine($"Total Tokens Evaluated: {timings.TokensEvaluated}");
-    AnsiConsole.MarkupLine($"Total Tokens Sampled: {timings.TokensSampled}");
     AnsiConsole.MarkupLine($"Eval Time: {(timings.Eval + timings.PromptEval).TotalMilliseconds}ms");
-    AnsiConsole.MarkupLine($"Sample Time: {timings.Sampling.TotalMilliseconds}ms");
 }

 private class Node
@@ -114,8 +111,7 @@ public void Sample()

     // Sample one token
     var ctx = _conversation.Executor.Context.NativeHandle;
-    var token = _sampler.Sample(ctx, _conversation.Sample(), Array.Empty<LLamaToken>());
-    _sampler.Accept(ctx, token);
+    var token = _sampler.Sample(ctx, _conversation.GetSampleIndex());
     _decoder.Add(token);

     // Prompt the conversation with this token, to continue generating from there
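Note the shape change in the sampling call: the old two-step Sample(ctx, logits, lastTokens) followed by Accept(ctx, token) collapses into a single index-based call. Elsewhere in this commit (BatchedExecutorBoolQ) a one-argument Accept(token) is still invoked after sampling, so whether an explicit Accept is needed appears to vary by pipeline; this sketch only contrasts the signatures:

```csharp
// Before this commit (old API, shown in the removed lines):
//   var token = _sampler.Sample(ctx, conversation.Sample(), Array.Empty<LLamaToken>());
//   _sampler.Accept(ctx, token);

// After: sample directly at the conversation's logit index.
var ctx = conversation.Executor.Context.NativeHandle;
var token = _sampler.Sample(ctx, conversation.GetSampleIndex());
```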

LLama.Examples/Examples/BatchedExecutorGuidance.cs

Lines changed: 0 additions & 123 deletions
This file was deleted.

LLama.Examples/Examples/BatchedExecutorLLava.cs

Lines changed: 1 addition & 1 deletion

@@ -75,7 +75,7 @@ await AnsiConsole

     await executor.Infer();

-    var token = sampler.Sample(executor.Context.NativeHandle, conversation.Sample(), Array.Empty<LLamaToken>());
+    var token = sampler.Sample(executor.Context.NativeHandle, conversation.GetSampleIndex());
     if (executor.Context.NativeHandle.ModelHandle.Tokens.IsEndOfGeneration(token))
         break;
LLama.Examples/Examples/BatchedExecutorRewind.cs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -91,7 +91,7 @@ private class Node
9191

9292
public LLamaToken Sample(Conversation conversation)
9393
{
94-
var token = _sampler.Sample(conversation.Executor.Context.NativeHandle, conversation.Sample(), []);
94+
var token = _sampler.Sample(conversation.Executor.Context.NativeHandle, conversation.GetSampleIndex());
9595
_tokens.Add(token);
9696
return token;
9797
}

LLama.Examples/Examples/BatchedExecutorSaveAndLoad.cs

Lines changed: 2 additions & 2 deletions

@@ -1,4 +1,4 @@
-using LLama.Batched;
+using LLama.Batched;
 using LLama.Common;
 using LLama.Native;
 using LLama.Sampling;
@@ -94,7 +94,7 @@ private static async Task<LLamaToken> GenerateTokens(BatchedExecutor executor, C
     await executor.Infer();

     // Use sampling pipeline to pick a token
-    token = sampler.Sample(executor.Context.NativeHandle, conversation.Sample(), ReadOnlySpan<LLamaToken>.Empty);
+    token = sampler.Sample(executor.Context.NativeHandle, conversation.GetSampleIndex());

     // Add it to the decoder, so it can be converted into text later
     decoder.Add(token);

LLama.Examples/Examples/ChatChineseGB2312.cs

Lines changed: 1 addition & 2 deletions

@@ -1,4 +1,4 @@
-using System.Text;
+using System.Text;
 using LLama.Common;

 namespace LLama.Examples.Examples;
@@ -27,7 +27,6 @@ public static async Task Run()

     var parameters = new ModelParams(modelPath)
     {
-        Seed = 1337,
         GpuLayerCount = 5,
         Encoding = Encoding.UTF8
     };

LLama.Examples/Examples/ChatSessionStripRoleName.cs

Lines changed: 0 additions & 1 deletion

@@ -13,7 +13,6 @@ public static async Task Run()

     var parameters = new ModelParams(modelPath)
     {
-        Seed = 1337,
         GpuLayerCount = 5
     };
     using var model = await LLamaWeights.LoadFromFileAsync(parameters);

LLama.Examples/Examples/ChatSessionWithHistory.cs

Lines changed: 0 additions & 1 deletion

@@ -11,7 +11,6 @@ public static async Task Run()

     var parameters = new ModelParams(modelPath)
     {
-        Seed = 1337,
         GpuLayerCount = 5
     };
     using var model = await LLamaWeights.LoadFromFileAsync(parameters);

LLama.Examples/Examples/ChatSessionWithRestart.cs

Lines changed: 0 additions & 1 deletion

@@ -11,7 +11,6 @@ public static async Task Run()

     var parameters = new ModelParams(modelPath)
     {
-        Seed = 1337,
         GpuLayerCount = 5
     };
     using var model = await LLamaWeights.LoadFromFileAsync(parameters);

LLama.Examples/Examples/ChatSessionWithRoleName.cs

Lines changed: 0 additions & 1 deletion

@@ -11,7 +11,6 @@ public static async Task Run()

     var parameters = new ModelParams(modelPath)
    {
-        Seed = 1337,
         GpuLayerCount = 5
     };
     using var model = await LLamaWeights.LoadFromFileAsync(parameters);
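Every chat-session example drops Seed = 1337 from ModelParams. The diff does not show where seeding now lives; a plausible replacement, assuming the Seed property that later LLamaSharp releases expose on DefaultSamplingPipeline:

```csharp
// Assumption, not shown in this diff: sampling determinism is configured on
// the pipeline rather than on ModelParams after this change.
var pipeline = new DefaultSamplingPipeline
{
    Seed = 1337, // previously ModelParams.Seed
};
```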
