Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: ✨ Use Batch Api for embeddings #13 #14

Merged
merged 1 commit into from
Nov 26, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 3 additions & 1 deletion readme.md
Original file line number Diff line number Diff line change
@@ -1,12 +1,14 @@
# AI Assist

> AI assistant for coding, chat, code explanation, review with supporting local and online language models.
> `Context Aware` AI assistant for coding, chat, code explanation, and code review, with support for local and online language models.

`AIAssist` is compatible with [OpenAI](https://platform.openai.com/docs/api-reference/introduction) and [Azure AI Services](https://azure.microsoft.com/en-us/products/ai-services) through apis or [Ollama models](https://ollama.com/search) through [ollama engine](https://ollama.com/) locally.

> [!TIP]
> You can use ollama and its models that are more compatible with code like [deepseek-v2.5](https://ollama.com/library/deepseek-v2.5) or [qwen2.5-coder](https://ollama.com/library/qwen2.5-coder) locally. To use local models, you will need to run [Ollama](https://github.com/ollama/ollama) process first. For running ollama you can use [ollama docker](https://ollama.com/blog/ollama-is-now-available-as-an-official-docker-image) container.

Note: `vscode` and `jetbrains` plugins are planned and will be added soon.

## Features

- ✅ `Context Aware` ai code assistant through [ai embeddings](src/AIAssistant/Services/CodeAssistStrategies/EmbeddingCodeAssist.cs) which is based on Retrieval Augmented Generation (RAG) or [tree-sitter application summarization](src/AIAssistant/Services/CodeAssistStrategies/TreeSitterCodeAssistSummary.cs) to summarize application context and understanding by AI.
Expand Down
8 changes: 4 additions & 4 deletions src/AIAssistant/Commands/CodeAssistCommand.cs
Original file line number Diff line number Diff line change
Expand Up @@ -63,17 +63,17 @@ public sealed class Settings : CommandSettings
[Description("[grey] the type of code assist. it can be `embedding` or `summary`.[/].")]
public CodeAssistType? CodeAssistType { get; set; }

[CommandOption("--threshold <threshold")]
[CommandOption("--threshold <threshold>")]
[Description("[grey] the threshold is a value for using in the `embedding`.[/].")]
public decimal? Threshold { get; set; }

[CommandOption("--temperature <temperature")]
[CommandOption("--temperature <temperature>")]
[Description(
"[grey] the temperature is a value for controlling creativity or randomness on the llm response.[/]."
)]
public decimal? Temperature { get; set; }

[CommandOption("--chat-api-key <key>")]
[CommandOption("--chat-api-key <chat-api-key>")]
[Description("[grey] the chat model api key.[/].")]
public string? ChatModelApiKey { get; set; }

Expand Down Expand Up @@ -159,7 +159,7 @@ await AnsiConsole

console.Write(new Rule());

userInput = "can you remove all comments from Add.cs file?";
//userInput = "can you remove all comments from Add.cs file?";
_running = await internalCommandProcessor.ProcessCommand(userInput, scope);
}

Expand Down
3 changes: 2 additions & 1 deletion src/AIAssistant/Contracts/IEmbeddingService.cs
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
using System.Collections;
using AIAssistant.Chat.Models;
using AIAssistant.Data;
using AIAssistant.Dtos;
Expand All @@ -9,7 +10,7 @@ namespace AIAssistant.Contracts;
public interface IEmbeddingService
{
Task<AddEmbeddingsForFilesResult> AddOrUpdateEmbeddingsForFiles(
IEnumerable<CodeFileMap> codeFilesMap,
IList<CodeFileMap> codeFilesMap,
ChatSession chatSession
);
Task<GetRelatedEmbeddingsResult> GetRelatedEmbeddings(string userQuery, ChatSession chatSession);
Expand Down
2 changes: 1 addition & 1 deletion src/AIAssistant/Contracts/ILLMClientManager.cs
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ public interface ILLMClientManager
CancellationToken cancellationToken = default
);
Task<GetEmbeddingResult> GetEmbeddingAsync(
string input,
IList<string> inputs,
string? path,
CancellationToken cancellationToken = default
);
Expand Down
8 changes: 8 additions & 0 deletions src/AIAssistant/Dtos/GetBatchEmbeddingResult.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
namespace AIAssistant.Dtos;

/// <summary>
/// Result of a batched embedding request: one embedding vector per batched input,
/// together with the aggregate token usage and cost for the whole batch.
/// </summary>
public class GetBatchEmbeddingResult
{
    public GetBatchEmbeddingResult(IList<IList<double>> embeddings, int totalTokensCount, decimal totalCost)
    {
        Embeddings = embeddings;
        TotalTokensCount = totalTokensCount;
        TotalCost = totalCost;
    }

    /// <summary>One embedding vector per input sent in the batch.</summary>
    public IList<IList<double>> Embeddings { get; }

    /// <summary>Total tokens consumed by the batch request.</summary>
    public int TotalTokensCount { get; }

    /// <summary>Total monetary cost of the batch request.</summary>
    public decimal TotalCost { get; }
}
6 changes: 5 additions & 1 deletion src/AIAssistant/Dtos/GetEmbeddingResult.cs
Original file line number Diff line number Diff line change
@@ -1,3 +1,7 @@
namespace AIAssistant.Dtos;

public record GetEmbeddingResult(IList<double> Embeddings, int TotalTokensCount, decimal TotalCost);
/// <summary>
/// Result of an embedding request. Supports batch requests: one embedding vector per
/// input, plus the aggregate token usage and cost across all inputs.
/// </summary>
public record GetEmbeddingResult(
    IList<IList<double>> Embeddings, // one vector per batched input — consumers assume request order; TODO confirm provider guarantees it
    int TotalTokensCount,
    decimal TotalCost
);
18 changes: 18 additions & 0 deletions src/AIAssistant/Models/FileBatch.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,18 @@
namespace AIAssistant.Models;

/// <summary>
/// Represents a batch of files and their chunks to be processed in a single embedding request.
/// </summary>
public class FileBatch
{
    // Files participating in this batch, each paired with the chunks that landed here.
    public IList<FileChunkGroup> Files { get; set; } = new List<FileChunkGroup>();

    // Running token total across every chunk currently in the batch.
    public int TotalTokens { get; set; }

    /// <summary>
    /// Combines all chunked inputs for this batch into a single list for API calls.
    /// </summary>
    public IList<string> GetBatchInputs()
    {
        var inputs = new List<string>();
        foreach (var fileChunkGroup in Files)
        {
            inputs.AddRange(fileChunkGroup.Chunks);
        }

        return inputs;
    }
}
14 changes: 14 additions & 0 deletions src/AIAssistant/Models/FileChunkGroup.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,14 @@
using TreeSitter.Bindings.CustomTypes.TreeParser;

namespace AIAssistant.Models;

/// <summary>
/// Represents a file and its associated chunks for embedding.
/// </summary>
public class FileChunkGroup
{
    public FileChunkGroup(CodeFileMap file, List<string> chunks)
    {
        File = file;
        Chunks = chunks;
    }

    /// <summary>The file these chunks were split from.</summary>
    public CodeFileMap File { get; }

    /// <summary>The file's content split into embedding-sized chunks.</summary>
    public IList<string> Chunks { get; }

    /// <summary>All chunks joined back together, newline-separated.</summary>
    public string Input => string.Join("\n", Chunks);
}
214 changes: 196 additions & 18 deletions src/AIAssistant/Services/EmbeddingService.cs
Original file line number Diff line number Diff line change
Expand Up @@ -4,49 +4,86 @@
using AIAssistant.Dtos;
using AIAssistant.Models;
using BuildingBlocks.LLM;
using BuildingBlocks.Utils;
using TreeSitter.Bindings.CustomTypes.TreeParser;

namespace AIAssistant.Services;

public class EmbeddingService(
ILLMClientManager llmClientManager,
ICodeEmbeddingsRepository codeEmbeddingsRepository,
IPromptManager promptManager
IPromptManager promptManager,
ITokenizer tokenizer
) : IEmbeddingService
{
public async Task<AddEmbeddingsForFilesResult> AddOrUpdateEmbeddingsForFiles(
IEnumerable<CodeFileMap> codeFilesMap,
IList<CodeFileMap> codeFilesMap,
ChatSession chatSession
)
{
int totalTokens = 0;
decimal totalCost = 0;

IList<CodeEmbedding> codeEmbeddings = new List<CodeEmbedding>();
var fileEmbeddingsMap = new Dictionary<string, List<IList<double>>>();

// Group files and manage batching using the updated tokenizer logic
var fileBatches = await BatchFilesByTokenLimitAsync(codeFilesMap, maxBatchTokens: 8192);

foreach (var batch in fileBatches)
{
var batchInputs = batch.GetBatchInputs();
var embeddingResult = await llmClientManager.GetEmbeddingAsync(batchInputs, null);

int resultIndex = 0;
foreach (var fileChunkGroup in batch.Files)
{
// Extract embeddings for the current file's chunks
var fileEmbeddings = embeddingResult
.Embeddings.Skip(resultIndex)
.Take(fileChunkGroup.Chunks.Count)
.ToList();

resultIndex += fileChunkGroup.Chunks.Count;

// Group embeddings by file path
if (!fileEmbeddingsMap.TryGetValue(fileChunkGroup.File.RelativePath, out List<IList<double>>? value))
{
value = new List<IList<double>>();
fileEmbeddingsMap[fileChunkGroup.File.RelativePath] = value;
}

value.AddRange(fileEmbeddings);
}

totalTokens += embeddingResult.TotalTokensCount;
totalCost += embeddingResult.TotalCost;
}

foreach (var codeFileMap in codeFilesMap)
// Merge and create final embeddings for each file
var codeEmbeddings = new List<CodeEmbedding>();
foreach (var entry in fileEmbeddingsMap)
{
var input = promptManager.GetEmbeddingInputString(codeFileMap.TreeSitterFullCode);
var embeddingResult = await llmClientManager.GetEmbeddingAsync(input, codeFileMap.RelativePath);
var filePath = entry.Key;
var embeddings = entry.Value;

// Merge embeddings for the file
var mergedEmbedding = MergeEmbeddings(embeddings);

// Retrieve the original file details from codeFilesMap
var fileDetails = codeFilesMap.First(file => file.RelativePath == filePath);

codeEmbeddings.Add(
new CodeEmbedding
{
RelativeFilePath = codeFileMap.RelativePath,
TreeSitterFullCode = codeFileMap.TreeSitterFullCode,
TreeOriginalCode = codeFileMap.TreeOriginalCode,
Code = codeFileMap.OriginalCode,
RelativeFilePath = fileDetails.RelativePath,
TreeSitterFullCode = fileDetails.TreeSitterFullCode,
TreeOriginalCode = fileDetails.TreeOriginalCode,
Code = fileDetails.OriginalCode,
SessionId = chatSession.SessionId,
Embeddings = embeddingResult.Embeddings,
Embeddings = mergedEmbedding,
}
);

totalTokens += embeddingResult.TotalTokensCount;
totalCost += embeddingResult.TotalCost;
}

// we can replace it with an embedded database like `chromadb`, it can give us n of most similarity items
await codeEmbeddingsRepository.AddOrUpdateCodeEmbeddings(codeEmbeddings);

return new AddEmbeddingsForFilesResult(totalTokens, totalCost);
Expand All @@ -59,7 +96,7 @@ public async Task<GetRelatedEmbeddingsResult> GetRelatedEmbeddings(string userQu

// Find relevant code based on the user query
var relevantCodes = codeEmbeddingsRepository.Query(
embeddingsResult.Embeddings,
embeddingsResult.Embeddings.First(),
chatSession.SessionId,
llmClientManager.EmbeddingThreshold
);
Expand All @@ -82,6 +119,147 @@ public IEnumerable<CodeEmbedding> QueryByFilter(

/// <summary>
/// Generates embeddings for a single user query (e.g. to match against stored code embeddings).
/// </summary>
/// <param name="userInput">The raw user query text.</param>
/// <returns>The embedding result for the query, including token usage and cost.</returns>
public async Task<GetEmbeddingResult> GenerateEmbeddingForUserInput(string userInput)
{
    // The batch embedding API takes a list of inputs, so wrap the single query.
    // (The stale pre-batch single-string call that was left behind here is removed.)
    return await llmClientManager.GetEmbeddingAsync(new List<string> { userInput }, null);
}

/// <summary>
/// Splits each file's content into chunks and groups those chunks into batches whose
/// combined token count stays at or below <paramref name="maxBatchTokens"/>.
/// A single file's chunks may span multiple batches.
/// </summary>
/// <param name="codeFilesMap">Files to chunk and batch for embedding.</param>
/// <param name="maxBatchTokens">Upper bound on the summed token count of one batch.</param>
/// <returns>Batches in file order, each non-empty.</returns>
private async Task<List<FileBatch>> BatchFilesByTokenLimitAsync(
    IEnumerable<CodeFileMap> codeFilesMap,
    int maxBatchTokens
)
{
    var fileBatches = new List<FileBatch>();
    var currentBatch = new FileBatch();

    foreach (var file in codeFilesMap)
    {
        // Convert the full code to an input string and split into chunks
        // (chunk token limit 8192 matches the default maxBatchTokens used by the caller).
        var input = promptManager.GetEmbeddingInputString(file.TreeSitterFullCode);
        var chunks = await SplitTextIntoChunksAsync(input, maxTokens: 8192);

        // Count tokens for all chunks of this file concurrently.
        var tokenCountTasks = chunks.Select(chunk => tokenizer.GetTokenCount(chunk));
        var tokenCounts = await Task.WhenAll(tokenCountTasks);

        // Pair chunks with their token counts
        var chunkWithTokens = chunks.Zip(
            tokenCounts,
            (chunk, tokenCount) => new { Chunk = chunk, TokenCount = tokenCount }
        );

        foreach (var chunkGroup in chunkWithTokens)
        {
            // If adding this chunk would exceed the batch token limit.
            // A non-empty batch is required before splitting, so an oversized single
            // chunk still gets a batch of its own rather than looping forever.
            if (currentBatch.TotalTokens + chunkGroup.TokenCount > maxBatchTokens && currentBatch.Files.Count > 0)
            {
                // Finalize the current batch and start a new one
                fileBatches.Add(currentBatch);
                currentBatch = new FileBatch();
            }

            // Add this chunk to the current batch.
            // NOTE(review): `f.File != file` relies on CodeFileMap's equality semantics
            // (reference equality unless CodeFileMap overloads ==/Equals) — confirm.
            if (currentBatch.Files.All(f => f.File != file))
            {
                // If this is the first chunk of this file in the current batch, add a new FileChunkGroup
                currentBatch.Files.Add(new FileChunkGroup(file, new List<string> { chunkGroup.Chunk }));
            }
            else
            {
                // Add the chunk to the existing FileChunkGroup for this file
                var fileGroup = currentBatch.Files.First(f => f.File == file);
                fileGroup.Chunks.Add(chunkGroup.Chunk);
            }

            currentBatch.TotalTokens += chunkGroup.TokenCount;
        }
    }

    // Add the last batch if it has content
    if (currentBatch.Files.Count > 0)
    {
        fileBatches.Add(currentBatch);
    }

    return fileBatches;
}

/// <summary>
/// Splits <paramref name="text"/> into whitespace-delimited chunks, each at most
/// <paramref name="maxTokens"/> tokens according to the injected tokenizer.
/// Unlike the previous implementation, no words are ever dropped: overflow words
/// trimmed from a full chunk are carried into the next chunk, and the final
/// remainder is always emitted.
/// </summary>
/// <param name="text">Text to split.</param>
/// <param name="maxTokens">Maximum token count allowed per chunk.</param>
/// <returns>Chunks that, concatenated in order, contain every word of the input.</returns>
private async Task<List<string>> SplitTextIntoChunksAsync(string text, int maxTokens)
{
    var chunks = new List<string>();
    var currentChunk = new List<string>();
    var words = text.Split(' ');

    foreach (var word in words)
    {
        currentChunk.Add(word);

        // Tokenizing is relatively expensive, so only re-check every 50 words.
        if (currentChunk.Count % 50 == 0)
        {
            await FlushOversizedChunkAsync(chunks, currentChunk, maxTokens);
        }
    }

    // Flush whatever remains. Loop because the carried-over overflow may itself
    // still exceed the limit and need further splitting.
    while (currentChunk.Count > 0)
    {
        var flushed = await FlushOversizedChunkAsync(chunks, currentChunk, maxTokens);

        if (!flushed && currentChunk.Count > 0)
        {
            // Remainder fits within the limit: emit it as the final chunk
            // (the old code silently discarded an over-limit remainder).
            chunks.Add(string.Join(" ", currentChunk));
            currentChunk.Clear();
        }
    }

    return chunks;
}

/// <summary>
/// If <paramref name="currentChunk"/> exceeds <paramref name="maxTokens"/>, emits the largest
/// fitting prefix (at least one word) into <paramref name="chunks"/> and keeps the overflow
/// words in <paramref name="currentChunk"/> for the next chunk.
/// </summary>
/// <returns>True when a chunk was emitted; false when the text was already within the limit.</returns>
private async Task<bool> FlushOversizedChunkAsync(List<string> chunks, List<string> currentChunk, int maxTokens)
{
    var candidate = string.Join(" ", currentChunk);
    var tokenCount = await tokenizer.GetTokenCount(candidate);

    if (tokenCount <= maxTokens)
    {
        return false;
    }

    // Trim words off the end until the prefix fits. Keep at least one word so we always
    // make progress; a single word over the limit is emitted as-is rather than dropped.
    var overflow = new Stack<string>();
    while (tokenCount > maxTokens && currentChunk.Count > 1)
    {
        overflow.Push(currentChunk[^1]);
        currentChunk.RemoveAt(currentChunk.Count - 1);
        candidate = string.Join(" ", currentChunk);
        tokenCount = await tokenizer.GetTokenCount(candidate);
    }

    chunks.Add(candidate);

    // Carry the trimmed words into the next chunk instead of losing them — the previous
    // implementation cleared them and restarted from the current word, dropping input text.
    currentChunk.Clear();
    while (overflow.Count > 0)
    {
        currentChunk.Add(overflow.Pop());
    }

    return true;
}

/// <summary>
/// Merges multiple embedding vectors into a single vector by element-wise averaging.
/// </summary>
/// <param name="embeddings">Embedding vectors; all must share the same dimensionality.</param>
/// <returns>The element-wise mean of the input vectors.</returns>
/// <exception cref="ArgumentException">The list is null or empty.</exception>
/// <exception cref="InvalidOperationException">The vectors differ in dimensionality.</exception>
private static IList<double> MergeEmbeddings(IList<IList<double>> embeddings)
{
    // Marked static: the method reads no instance state (CA1822).
    if (embeddings is null || embeddings.Count == 0)
        throw new ArgumentException("The embeddings list cannot be null or empty.", nameof(embeddings));

    int dimension = embeddings[0].Count;
    var mergedEmbedding = new double[dimension];

    foreach (var embedding in embeddings)
    {
        if (embedding.Count != dimension)
            throw new InvalidOperationException("All embeddings must have the same dimensionality.");

        for (int i = 0; i < dimension; i++)
        {
            mergedEmbedding[i] += embedding[i];
        }
    }

    // Average the embeddings to unify them into one.
    // NOTE(review): averaging unnormalized chunk embeddings is an approximation — confirm
    // it interacts sensibly with the similarity threshold used at query time.
    for (int i = 0; i < dimension; i++)
    {
        mergedEmbedding[i] /= embeddings.Count;
    }

    return mergedEmbedding;
}
}
Loading