Skip to content

Commit

Permalink
Update code
Browse files Browse the repository at this point in the history
  • Loading branch information
KSemenenko authored and dluc committed Mar 15, 2024
1 parent 5f8c2a5 commit 4b83cf1
Show file tree
Hide file tree
Showing 5 changed files with 393 additions and 23 deletions.
2 changes: 1 addition & 1 deletion service/Core/Core.csproj
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
<RollForward>LatestMajor</RollForward>
<AssemblyName>Microsoft.KernelMemory.Core</AssemblyName>
<RootNamespace>Microsoft.KernelMemory</RootNamespace>
<NoWarn>$(NoWarn);SKEXP0011;CA2208;CA1308;CA1724;CS1591;</NoWarn>
<NoWarn>$(NoWarn);SKEXP0011;CA2208;CA1308;CA1724;CS1591;SKEXP0050;</NoWarn>
<DefineConstants Condition="'$(SolutionName)' == 'KernelMemoryDev'">$(DefineConstants);KernelMemoryDev</DefineConstants>
</PropertyGroup>

Expand Down
20 changes: 6 additions & 14 deletions service/Core/Handlers/GenerateEmbeddingsHandler.cs
Original file line number Diff line number Diff line change
Expand Up @@ -78,26 +78,23 @@ public GenerateEmbeddingsHandler(
{
// Track new files being generated (cannot edit originalFile.GeneratedFiles while looping it)
Dictionary<string, DataPipeline.GeneratedFileDetails> newFiles = new();
var throttler = new SemaphoreSlim(initialCount: 4);

var tasks = uploadedFile.GeneratedFiles.Select(async generatedFile =>
foreach (KeyValuePair<string, DataPipeline.GeneratedFileDetails> generatedFile in uploadedFile.GeneratedFiles)
{
await throttler.WaitAsync().ConfigureAwait(false);

var partitionFile = generatedFile.Value;
if (partitionFile.AlreadyProcessedBy(this))
{
partitionsFound = true;
this._log.LogTrace("File {0} already processed by this handler", partitionFile.Name);
return;
continue;
}

// Calc embeddings only for partitions (text chunks) and synthetic data
if (partitionFile.ArtifactType != DataPipeline.ArtifactTypes.TextPartition
&& partitionFile.ArtifactType != DataPipeline.ArtifactTypes.SyntheticData)
{
this._log.LogTrace("Skipping file {0} (not a partition, not synthetic data)", partitionFile.Name);
return;
continue;
}

partitionsFound = true;
Expand Down Expand Up @@ -165,23 +162,18 @@ public GenerateEmbeddingsHandler(
Tags = partitionFile.Tags,
};
embeddingFileNameDetails.MarkProcessedBy(this);
lock (newFiles)
{
newFiles.Add(embeddingFileName, embeddingFileNameDetails);
}
newFiles.Add(embeddingFileName, embeddingFileNameDetails);
}

break;

default:
this._log.LogWarning("File {0} cannot be used to generate embedding, type not supported", partitionFile.Name);
return;
continue;
}

partitionFile.MarkProcessedBy(this);
}).ToList();

await Task.WhenAll(tasks).ConfigureAwait(false);
}

// Add new files to pipeline status
foreach (var file in newFiles)
Expand Down
171 changes: 171 additions & 0 deletions service/Core/Handlers/ParallelGenerateEmbeddingsHandler.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,171 @@
// Copyright (c) Microsoft. All rights reserved.

using System;
using System.Collections.Generic;
using System.Linq;
using System.Text.Json;
using System.Threading;
using System.Threading.Tasks;
using Microsoft.Extensions.Logging;
using Microsoft.KernelMemory.AI;
using Microsoft.KernelMemory.ContentStorage;
using Microsoft.KernelMemory.Diagnostics;
using Microsoft.KernelMemory.Pipeline;

namespace Microsoft.KernelMemory.Handlers;

/// <summary>
/// Memory ingestion pipeline handler responsible for generating text embeddings and saving them to the content storage.
/// Partitions of each uploaded file are processed concurrently; the number of partitions
/// being processed at the same time is capped at <see cref="Environment.ProcessorCount"/>.
/// </summary>
public class ParallelGenerateEmbeddingsHandler : IPipelineStepHandler
{
    private readonly IPipelineOrchestrator _orchestrator;
    private readonly List<ITextEmbeddingGenerator> _embeddingGenerators;
    private readonly ILogger<ParallelGenerateEmbeddingsHandler> _log;

    // Guards the per-file dictionary of newly generated files, which is filled by concurrent tasks.
    private readonly object _lock = new();

    /// <inheritdoc />
    public string StepName { get; }

    /// <summary>
    /// Handler responsible for generating embeddings and saving them to content storages.
    /// Note: stepName and other params are injected with DI
    /// </summary>
    /// <param name="stepName">Pipeline step for which the handler will be invoked</param>
    /// <param name="orchestrator">Current orchestrator used by the pipeline, giving access to content and other helpers.</param>
    /// <param name="log">Application logger; when null, a default logger instance is used.</param>
    public ParallelGenerateEmbeddingsHandler(
        string stepName,
        IPipelineOrchestrator orchestrator,
        ILogger<ParallelGenerateEmbeddingsHandler>? log = null)
    {
        this.StepName = stepName;
        this._orchestrator = orchestrator;
        this._log = log ?? DefaultLogger<ParallelGenerateEmbeddingsHandler>.Instance;
        this._embeddingGenerators = orchestrator.GetEmbeddingGenerators();

        this._log.LogInformation("Handler '{0}' ready, {1} embedding generators", stepName, this._embeddingGenerators.Count);
        if (this._embeddingGenerators.Count < 1)
        {
            this._log.LogError("No embedding generators configured");
        }
    }

    /// <summary>
    /// Generates one embedding file per (partition, embedding generator) pair for every uploaded file
    /// in the pipeline, writes each embedding through the orchestrator and registers it in the
    /// uploaded file's list of generated files.
    /// </summary>
    /// <param name="pipeline">Pipeline whose text partitions and synthetic data are embedded.</param>
    /// <param name="cancellationToken">Token used to cancel throttling waits, reads, embedding generation and writes.</param>
    /// <returns>A tuple with <c>true</c> and the updated pipeline.</returns>
    /// <exception cref="OperationCanceledException">When <paramref name="cancellationToken"/> is canceled while waiting for a processing slot.</exception>
    public async Task<(bool success, DataPipeline updatedPipeline)> InvokeAsync(
        DataPipeline pipeline, CancellationToken cancellationToken = default)
    {
        this._log.LogDebug("Generating embeddings, pipeline '{0}/{1}'", pipeline.Index, pipeline.DocumentId);

        foreach (var uploadedFile in pipeline.Files)
        {
            // Track new files being generated (cannot edit uploadedFile.GeneratedFiles while looping it)
            Dictionary<string, DataPipeline.GeneratedFileDetails> newFiles = new();

            // Limits how many partitions are processed at once. Disposed at the end of each iteration,
            // which is safe because Task.WhenAll below only completes after every task has finished.
            using var throttler = new SemaphoreSlim(initialCount: Environment.ProcessorCount);

            List<Task> tasks = uploadedFile.GeneratedFiles.Select(async generatedFile =>
            {
                // If the wait is canceled the slot was never acquired, so it must stay outside the try/finally.
                await throttler.WaitAsync(cancellationToken).ConfigureAwait(false);
                try
                {
                    var partitionFile = generatedFile.Value;
                    if (partitionFile.AlreadyProcessedBy(this))
                    {
                        this._log.LogTrace("File {0} already processed by this handler", partitionFile.Name);
                        return;
                    }

                    // Calc embeddings only for partitions (text chunks) and synthetic data
                    if (partitionFile.ArtifactType != DataPipeline.ArtifactTypes.TextPartition
                        && partitionFile.ArtifactType != DataPipeline.ArtifactTypes.SyntheticData)
                    {
                        this._log.LogTrace("Skipping file {0} (not a partition, not synthetic data)", partitionFile.Name);
                        return;
                    }

                    // TODO: cost/perf: if the partition SHA256 is the same and the embedding exists, avoid generating it again
                    switch (partitionFile.MimeType)
                    {
                        case MimeTypes.PlainText:
                        case MimeTypes.MarkDown:
                            this._log.LogTrace("Processing file {0}", partitionFile.Name);
                            foreach (ITextEmbeddingGenerator generator in this._embeddingGenerators)
                            {
                                EmbeddingFileContent embeddingData = new()
                                {
                                    SourceFileName = partitionFile.Name
                                };

                                // Identify the generator by the last three segments of its full type name
                                var generatorProviderClassName = generator.GetType().FullName ?? generator.GetType().Name;
                                embeddingData.GeneratorProvider = string.Join('.', generatorProviderClassName.Split('.').TakeLast(3));

                                // TODO: model name
                                embeddingData.GeneratorName = "TODO";

                                this._log.LogTrace("Generating embeddings using {0}, file: {1}", embeddingData.GeneratorProvider, partitionFile.Name);

                                // Check if embeddings have already been generated
                                string embeddingFileName = GetEmbeddingFileName(partitionFile.Name, embeddingData.GeneratorProvider, embeddingData.GeneratorName);

                                // TODO: check if the file exists in storage
                                if (uploadedFile.GeneratedFiles.ContainsKey(embeddingFileName))
                                {
                                    this._log.LogDebug("Embeddings for {0} have already been generated", partitionFile.Name);
                                    continue;
                                }

                                // TODO: handle Azure.RequestFailedException - BlobNotFound
                                string partitionContent = await this._orchestrator.ReadTextFileAsync(pipeline, partitionFile.Name, cancellationToken).ConfigureAwait(false);

                                Embedding embedding = await generator.GenerateEmbeddingAsync(partitionContent, cancellationToken).ConfigureAwait(false);
                                embeddingData.Vector = embedding;
                                embeddingData.VectorSize = embeddingData.Vector.Length;
                                embeddingData.TimeStamp = DateTimeOffset.UtcNow;

                                this._log.LogDebug("Saving embedding file {0}", embeddingFileName);
                                string text = JsonSerializer.Serialize(embeddingData);
                                await this._orchestrator.WriteTextFileAsync(pipeline, embeddingFileName, text, cancellationToken).ConfigureAwait(false);

                                var embeddingFileNameDetails = new DataPipeline.GeneratedFileDetails
                                {
                                    Id = Guid.NewGuid().ToString("N"),
                                    ParentId = uploadedFile.Id,
                                    Name = embeddingFileName,
                                    Size = text.Length,
                                    MimeType = MimeTypes.TextEmbeddingVector,
                                    ArtifactType = DataPipeline.ArtifactTypes.TextEmbeddingVector,
                                    // Carry the partition tags over, as the sequential embeddings handler does
                                    Tags = partitionFile.Tags,
                                };
                                embeddingFileNameDetails.MarkProcessedBy(this);

                                // Dictionary<TKey, TValue> is not safe for concurrent writes
                                lock (this._lock)
                                {
                                    newFiles.Add(embeddingFileName, embeddingFileNameDetails);
                                }
                            }

                            break;

                        default:
                            this._log.LogWarning("File {0} cannot be used to generate embedding, type not supported", partitionFile.Name);
                            return;
                    }

                    partitionFile.MarkProcessedBy(this);
                }
                finally
                {
                    // Always hand the slot back, including on early return and on failure,
                    // otherwise the remaining partitions would wait forever.
                    throttler.Release();
                }
            }).ToList();

            await Task.WhenAll(tasks).ConfigureAwait(false);

            // Add new files to pipeline status
            foreach (var file in newFiles)
            {
                uploadedFile.GeneratedFiles.Add(file.Key, file.Value);
            }
        }

        return (true, pipeline);
    }

    /// <summary>
    /// Builds the name of the embedding file for a given source file and generator.
    /// </summary>
    /// <param name="srcFilename">Name of the partition file the embedding is generated from.</param>
    /// <param name="type">Embedding generator provider identifier.</param>
    /// <param name="embeddingName">Embedding generator (model) name.</param>
    /// <returns>The source file name followed by the type, the embedding name and the embedding vector file extension.</returns>
    private static string GetEmbeddingFileName(string srcFilename, string type, string embeddingName)
    {
        return $"{srcFilename}.{type}.{embeddingName}{FileExtensions.TextEmbeddingVector}";
    }
}
Loading

0 comments on commit 4b83cf1

Please sign in to comment.