BoolQ Benchmark #802

Merged: 4 commits, Jul 2, 2024
9,428 changes: 9,428 additions & 0 deletions LLama.Examples/Assets/BoolQ/train.csv

Large diffs are not rendered by default.

3,271 changes: 3,271 additions & 0 deletions LLama.Examples/Assets/BoolQ/validation.csv

Large diffs are not rendered by default.

1 change: 1 addition & 0 deletions LLama.Examples/ExampleRunner.cs
@@ -33,6 +33,7 @@ public class ExampleRunner
{ "Batched Executor: Rewind", BatchedExecutorRewind.Run },
{ "Batched Executor: Guidance", BatchedExecutorGuidance.Run },
{ "Batched Executor: LLava", BatchedExecutorLLava.Run },
{ "Batched Executor: BoolQ Benchmark", BatchedExecutorBoolQ.Run },
{ "Batched Executor: Beam Search", BatchedExecutorBeamSearch.Run },
{ "Speech Chat: Integration with Whisper.net", SpeechChat.Run },
{ "Exit", () => { Environment.Exit(0); return Task.CompletedTask; } }
248 changes: 248 additions & 0 deletions LLama.Examples/Examples/BatchedExecutorBoolQ.cs
@@ -0,0 +1,248 @@
using System.Text;
using LLama.Batched;
using LLama.Common;
using LLama.Grammars;
using LLama.Native;
using Spectre.Console;
using LLama.Sampling;

namespace LLama.Examples.Examples;

public class BatchedExecutorBoolQ
{
public static async Task Run()
{
// Load model weights
var parameters = new ModelParams(UserSettings.GetModelPath());
using var model = await LLamaWeights.LoadFromFileAsync(parameters);

const int tokensGenerate = 8;
var batchSize = AnsiConsole.Ask("How many parallel conversations to evaluate in a batch", 64);
var sys = AnsiConsole.Ask("System prompt", "Answer the question with a single word answer.");
var hint = AnsiConsole.Ask("Provide hints to model (test reading comprehension instead of knowledge)", true);

// Answers may optionally start with a space, must then be exactly one of the listed strings, and must end with a newline and nothing else.
var grammar = Grammar.Parse("root ::= (\" \")? (\"true\" | \"false\" | \"yes\" | \"no\") \"\\n\"", "root");

// Create an executor that can evaluate a batch of conversations together
using var executor = new BatchedExecutor(model, parameters);

// Print some info
var name = model.Metadata.GetValueOrDefault("general.name", "unknown model name");
Console.WriteLine($"Created executor with model: {name}");

// Load dataset
var data = new List<(string, bool, string)>();
if (AnsiConsole.Ask("Load training dataset?", false))
data.AddRange(LoadData("Assets/BoolQ/train.csv"));
if (AnsiConsole.Ask("Load validation dataset?", true))
data.AddRange(LoadData("Assets/BoolQ/validation.csv"));
AnsiConsole.MarkupLineInterpolated($"Loaded Dataset: {data.Count} questions");
var limit = AnsiConsole.Ask("Limit dataset size", 1000);
if (data.Count > limit)
data = data.Take(limit).ToList();

// Process data in batches
var chunks = data.Chunk(batchSize).ToArray();
var results = new List<BatchResult>();
await AnsiConsole.Progress()
.Columns(new SpinnerColumn(Spinner.Known.Dots8Bit), new PercentageColumn(), new ProgressBarColumn(), new RemainingTimeColumn())
.StartAsync(async ctx =>
{
var reporter = ctx.AddTask("Processing Chunks", maxValue: chunks.Length);

foreach (var chunk in chunks)
{
var result = await RunBatch(executor, tokensGenerate, grammar, sys, hint, chunk);
results.Add(result);

reporter.Increment(1);

AnsiConsole.MarkupLineInterpolated($"[green]{result.TruePositive + result.TrueNegative}[/] / [red]{chunk.Length}[/] ({result.Accuracy:P})");
}
});

// Print final results
var correct = (from result in results select result.Correct).Sum();
var total = data.Count;
var accuracy = (float)correct / total;
AnsiConsole.WriteLine();
AnsiConsole.MarkupInterpolated($"Final Result: [green]{correct}[/] / [red]{total}[/] ({accuracy:P})");

}

private static IEnumerable<(string, bool, string)> LoadData(string path)
{
foreach (var line in File.ReadLines(path))
{
var splits = line.Split(",");

if (!bool.TryParse(splits[1], out var boolean))
continue;

var hint = string.Join(",", splits[2..]);
hint = hint.Trim('\"');

yield return (splits[0], boolean, hint);
}
}

private static async Task<BatchResult> RunBatch(BatchedExecutor executor, int maxTokens, Grammar grammar, string sys, bool hint, IEnumerable<(string, bool, string)> batch)
{
var conversations = (from item in batch
select new ConversationRunner(executor, grammar, sys, item.Item1, item.Item2, hint ? item.Item3 : null)).ToArray();

for (var i = 0; i < maxTokens; i++)
{
if (executor.BatchQueueCount > 1)
AnsiConsole.MarkupLineInterpolated($"Batch Queue: {executor.BatchQueueCount} ({i})");

// Process the entire queue of batches waiting to be processed
while (executor.BatchQueueCount > 0)
{
var result = await executor.Infer();
if (result != DecodeResult.Ok)
throw new NotImplementedException($"Decode failed: {result}");

foreach (var item in conversations)
item.Sample();
}

// Prompt each conversation that just sampled a token
foreach (var item in conversations)
item.Prompt();
}

int tp = 0, tn = 0, fp = 0, fn = 0;
foreach (var item in conversations)
{
item.Result(ref tp, ref tn, ref fp, ref fn);
item.Dispose();
}

return new BatchResult(tp, tn, fp, fn);
}

private record BatchResult(int TruePositive, int TrueNegative, int FalsePositive, int FalseNegative)
{
public int Correct => TruePositive + TrueNegative;
public int Incorrect => FalsePositive + FalseNegative;

public int TotalPositives = TruePositive + FalseNegative;
public int TotalNegatives = TrueNegative + FalsePositive;
public int Total => Correct + Incorrect;

public float Accuracy => (float)Correct / Total;
}

private class ConversationRunner
: IDisposable
{
private readonly BatchedExecutor _executor;
private readonly StreamingTokenDecoder _decoder;
private readonly ISamplingPipeline _sampler;

private readonly Conversation _conversation;
private bool _finished;
private LLamaToken? _sampledToken;

public string Question { get; }
public bool Answer { get; }

public ConversationRunner(BatchedExecutor executor, Grammar grammar, string sys, string question, bool answer, string? hint)
{
_executor = executor;
_decoder = new StreamingTokenDecoder(executor.Context);
_sampler = new GreedySamplingPipeline
{
Grammar = grammar.CreateInstance(),
};

// Make sure question ends with question mark
if (!question.EndsWith('?'))
question += '?';

// Prepend hint if necessary
if (hint != null)
{
if (!hint.EndsWith('.'))
hint += '.';
question = $"{hint}\n{question}";
}

// Template the question
var template = new LLamaTemplate(executor.Model);
template.Add("system", sys);
template.Add("user", question);
template.AddAssistant = true;
var templatedQuestion = Encoding.UTF8.GetString(template.Apply());

// Prompt
_conversation = executor.Create();
_conversation.Prompt(_executor.Context.Tokenize(templatedQuestion));

Question = question;
Answer = answer;
}

public void Sample()
{
if (_finished)
return;
if (!_conversation.RequiresSampling)
return;

var token = _sampler.Sample(_executor.Context.NativeHandle, _conversation.Sample(), []);

var tokens = _executor.Context.NativeHandle.ModelHandle.Tokens;
if (tokens.IsEndOfGeneration(token) || tokens.Newline == token)
{
_sampledToken = default;
_finished = true;
}
else
{
_sampledToken = token;
}
}

public void Prompt()
{
if (_finished)
return;
if (!_sampledToken.HasValue)
return;

var token = _sampledToken.Value;
_sampledToken = default;

_sampler.Accept(_executor.Context.NativeHandle, token);
_decoder.Add(token);
_conversation.Prompt(token);
}

public void Result(ref int tp, ref int tn, ref int fp, ref int fn)
{
var str = _decoder.Read().Trim();
var result = str switch
{
"true" or "yes" => true,
_ => false,
};

switch (Answer, result)
{
case (true, true): tp++; break;
case (true, false): fn++; break;
case (false, true): fp++; break;
case (false, false): tn++; break;
}
}

public void Dispose()
{
_conversation.Dispose();
_sampler.Dispose();
}
}
}
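As an aside on the grammar near the top of BatchedExecutorBoolQ.cs: the GBNF rule constrains every sampled answer to an optional leading space, one of four words, and a trailing newline. Below is a regex restatement for intuition only; the example enforces this through the sampling pipeline's Grammar, not a regex.

```csharp
using System;
using System.Text.RegularExpressions;

// Equivalent shape of `root ::= (" ")? ("true" | "false" | "yes" | "no") "\n"`
var answerShape = new Regex(@"^ ?(true|false|yes|no)\n\z");

Console.WriteLine(answerShape.IsMatch(" yes\n"));   // True  - accepted by the grammar
Console.WriteLine(answerShape.IsMatch("maybe\n"));  // False - rejected by the grammar
```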
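For reference, LoadData above expects each CSV row in a question,answer,passage layout: the first field is the question, the second a boolean answer, and everything after the second comma is re-joined into the passage so that commas inside the quoted passage survive; the header row drops out because its second field is not a parseable bool. The actual train.csv/validation.csv contents are not rendered in this diff, so the rows in this minimal sketch are hypothetical examples of that layout:

```csharp
using System;

// Hypothetical rows in the question,answer,passage layout LoadData expects.
string[] lines =
{
    "question,answer,passage",   // header: skipped because "answer" is not a bool
    "is the sky blue on a clear day,true,\"On a clear day, the sky appears blue due to Rayleigh scattering.\"",
};

foreach (var line in lines)
{
    var splits = line.Split(",");
    if (!bool.TryParse(splits[1], out var answer))
        continue;

    // Re-join the remaining fields so commas inside the passage are preserved, then strip the quotes.
    var passage = string.Join(",", splits[2..]).Trim('\"');
    Console.WriteLine($"Q: {splits[0]} | A: {answer} | Hint: {passage}");
}
```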
6 changes: 6 additions & 0 deletions LLama.Examples/LLama.Examples.csproj
@@ -36,6 +36,12 @@
</ItemGroup>

<ItemGroup>
<None Update="Assets\BoolQ\train.csv">
<CopyToOutputDirectory>PreserveNewest</CopyToOutputDirectory>
</None>
<None Update="Assets\BoolQ\validation.csv">
<CopyToOutputDirectory>PreserveNewest</CopyToOutputDirectory>
</None>
<None Update="Assets\chat-with-bob.json">
<CopyToOutputDirectory>PreserveNewest</CopyToOutputDirectory>
</None>
2 changes: 1 addition & 1 deletion LLama.KernelMemory/LlamaSharpTextGenerator.cs
@@ -54,7 +54,7 @@ public LlamaSharpTextGenerator(LLamaWeights weights, LLamaContext context, State
_context = context;
_executor = executor ?? new StatelessExecutor(_weights, _context.Params);
_defaultInferenceParams = inferenceParams;
MaxTokenTotal = (int)_context.Params.ContextSize;
MaxTokenTotal = (int)_context.ContextSize;
}

/// <inheritdoc/>
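The one-line change above is part of a pattern repeated in the hunks that follow (LLamaContext, LLamaEmbedder, both executors, SafeLlavaModelHandle): size limits are read from properties the context itself exposes (Context.ContextSize / Context.BatchSize) rather than from the parameters it was created with (Context.Params...). A minimal caller-side sketch of the new pattern, assuming the LLamaSharp package is referenced; the helper name is illustrative and not part of the library:

```csharp
using System;
using LLama;
using LLama.Native;

// Illustrative helper: validate a batch against the size the context actually exposes,
// mirroring the check in LLamaContext.Decode after this PR.
static class BatchGuard
{
    public static void ThrowIfTooLarge(LLamaContext context, LLamaBatch batch)
    {
        if (batch.TokenCount > context.BatchSize)
            throw new ArgumentException("Input contains more tokens than configured batch size", nameof(batch));
    }
}
```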
4 changes: 2 additions & 2 deletions LLama/LLamaContext.cs
@@ -528,7 +528,7 @@ public DecodeResult Decode(LLamaBatch batch)
{
if (batch.TokenCount == 0)
return 0;
if (batch.TokenCount > Params.BatchSize)
if (batch.TokenCount > BatchSize)
throw new ArgumentException("Input contains more tokens than configured batch size", nameof(batch));

return (DecodeResult)NativeHandle.Decode(batch);
@@ -550,7 +550,7 @@ public DecodeResult Decode(LLamaBatchEmbeddings batch)
{
if (batch.EmbeddingsCount == 0)
return 0;
if (batch.EmbeddingsCount > Params.BatchSize)
if (batch.EmbeddingsCount > BatchSize)
throw new ArgumentException("Input contains more tokens than configured batch size", nameof(batch));

return (DecodeResult)NativeHandle.Decode(batch);
4 changes: 2 additions & 2 deletions LLama/LLamaEmbedder.cs
@@ -1,4 +1,4 @@
using LLama.Native;
using LLama.Native;
using System;
using LLama.Exceptions;
using LLama.Abstractions;
@@ -67,7 +67,7 @@ public async Task<float[]> GetEmbeddings(string text, bool addBos, CancellationT
// Evaluate prompt in batch-size chunks
var n_past = 0;
var batch = new LLamaBatch();
var batchSize = (int)Context.Params.BatchSize;
var batchSize = (int)Context.BatchSize;
for (var i = 0; i < tokens.Length; i += batchSize)
{
var n_eval = tokens.Length - i;
2 changes: 1 addition & 1 deletion LLama/LLamaInstructExecutor.cs
@@ -106,7 +106,7 @@
using (var fs = new FileStream(filename, FileMode.Open, FileAccess.Read))
{
var state = await JsonSerializer.DeserializeAsync<InstructExecutorState>(fs);
await LoadState(state);

Check warning on line 109 in LLama/LLamaInstructExecutor.cs (GitHub Actions / Test: linux-release, windows-release): Possible null reference argument for parameter 'data' in 'Task InstructExecutor.LoadState(ExecutorBaseState data)'.
}
}

@@ -151,7 +151,7 @@
{
if (_embed_inps.Count <= _consumedTokensCount)
{
if (_last_n_tokens.TokensEndsWithAnyString(args.Antiprompts, Context.NativeHandle.ModelHandle, Context.Encoding))

Check warning on line 154 in LLama/LLamaInstructExecutor.cs (GitHub Actions / Test: windows-release): 'IReadOnlyListExtensions.TokensEndsWithAnyString<TTokens>(TTokens, IList<string>?, SafeLlamaModelHandle, Encoding)' is obsolete: 'Use an Antiprompt processor instead'
{
args.WaitForInput = true;
return (true, Array.Empty<string>());
@@ -252,7 +252,7 @@
_embeds.Add(_embed_inps[_consumedTokensCount]);
_last_n_tokens.Enqueue(_embed_inps[_consumedTokensCount]);
_consumedTokensCount++;
if (_embeds.Count >= Context.Params.BatchSize)
if (_embeds.Count >= Context.BatchSize)
{
break;
}
2 changes: 1 addition & 1 deletion LLama/LLamaInteractExecutor.cs
@@ -98,7 +98,7 @@
using (var fs = new FileStream(filename, FileMode.Open, FileAccess.Read))
{
var state = await JsonSerializer.DeserializeAsync<InteractiveExecutorState>(fs);
await LoadState(state);

Check warning on line 101 in LLama/LLamaInteractExecutor.cs (GitHub Actions / Test: linux-release, osx-release): Possible null reference argument for parameter 'data' in 'Task InteractiveExecutor.LoadState(ExecutorBaseState data)'.
}
}

@@ -159,7 +159,7 @@
{
foreach (var image in Images)
{
_imageEmbedHandles.Add(SafeLlavaImageEmbedHandle.CreateFromMemory(ClipModel.NativeHandle, Context, image));

Check warning on line 162 in LLama/LLamaInteractExecutor.cs (GitHub Actions / Test: linux-release, osx-release): Dereference of a possibly null reference.
}

int imageIndex = text.IndexOf("<image>");
@@ -196,11 +196,11 @@
/// <param name="inferenceParams"></param>
/// <param name="args"></param>
/// <returns></returns>
protected override async Task<(bool, IReadOnlyList<string>)> PostProcess(IInferenceParams inferenceParams, InferStateArgs args)

Check warning on line 199 in LLama/LLamaInteractExecutor.cs (GitHub Actions / Test: linux-release, osx-release): This async method lacks 'await' operators and will run synchronously. Consider using the 'await' operator to await non-blocking API calls, or 'await Task.Run(...)' to do CPU-bound work on a background thread.
{
if (_embed_inps.Count <= _consumedTokensCount)
{
if (_last_n_tokens.TokensEndsWithAnyString(args.Antiprompts, Context.NativeHandle.ModelHandle, Context.Encoding))

Check warning on line 203 in LLama/LLamaInteractExecutor.cs (GitHub Actions / Test: linux-release, osx-release): 'IReadOnlyListExtensions.TokensEndsWithAnyString<TTokens>(TTokens, IList<string>?, SafeLlamaModelHandle, Encoding)' is obsolete: 'Use an Antiprompt processor instead'
args.WaitForInput = true;

if (_pastTokensCount > 0 && args.WaitForInput)
Expand Down Expand Up @@ -339,7 +339,7 @@
_embeds.Add(_embed_inps[_consumedTokensCount]);
_last_n_tokens.Enqueue(_embed_inps[_consumedTokensCount]);
_consumedTokensCount++;
if (_embeds.Count >= Context.Params.BatchSize)
if (_embeds.Count >= Context.BatchSize)
{
break;
}
2 changes: 1 addition & 1 deletion LLama/Native/SafeLlavaModelHandle.cs
@@ -108,7 +108,7 @@ public SafeLlavaImageEmbedHandle CreateImageEmbeddings(byte[] image, int threads
/// <returns>True on success</returns>
public bool EvalImageEmbed(LLamaContext ctxLlama, SafeLlavaImageEmbedHandle imageEmbed, ref int n_past)
{
return NativeApi.llava_eval_image_embed(ctxLlama.NativeHandle, imageEmbed, (int)ctxLlama.Params.BatchSize, ref n_past );
return NativeApi.llava_eval_image_embed(ctxLlama.NativeHandle, imageEmbed, (int)ctxLlama.BatchSize, ref n_past );
}

#region native API