Commit 4b83cf1 (1 parent: 5f8c2a5)
Showing 5 changed files with 393 additions and 23 deletions.
service/Core/Handlers/ParallelGenerateEmbeddingsHandler.cs (171 additions, 0 deletions)
@@ -0,0 +1,171 @@
// Copyright (c) Microsoft. All rights reserved.

using System;
using System.Collections.Generic;
using System.Linq;
using System.Text.Json;
using System.Threading;
using System.Threading.Tasks;
using Microsoft.Extensions.Logging;
using Microsoft.KernelMemory.AI;
using Microsoft.KernelMemory.ContentStorage;
using Microsoft.KernelMemory.Diagnostics;
using Microsoft.KernelMemory.Pipeline;

namespace Microsoft.KernelMemory.Handlers;

/// <summary>
/// Memory ingestion pipeline handler responsible for generating text embeddings and saving them to the content storage.
/// </summary>
public class ParallelGenerateEmbeddingsHandler : IPipelineStepHandler
{
    private readonly IPipelineOrchestrator _orchestrator;
    private readonly List<ITextEmbeddingGenerator> _embeddingGenerators;
    private readonly ILogger<ParallelGenerateEmbeddingsHandler> _log;

    /// <inheritdoc />
    public string StepName { get; }

    private readonly object _lock = new();

    /// <summary>
    /// Handler responsible for generating embeddings and saving them to content storage.
    /// Note: stepName and other params are injected with DI.
    /// </summary>
    /// <param name="stepName">Pipeline step for which the handler will be invoked</param>
    /// <param name="orchestrator">Current orchestrator used by the pipeline, giving access to content and other helpers.</param>
    /// <param name="log">Application logger</param>
    public ParallelGenerateEmbeddingsHandler(
        string stepName,
        IPipelineOrchestrator orchestrator,
        ILogger<ParallelGenerateEmbeddingsHandler>? log = null)
    {
        this.StepName = stepName;
        this._orchestrator = orchestrator;
        this._log = log ?? DefaultLogger<ParallelGenerateEmbeddingsHandler>.Instance;
        this._embeddingGenerators = orchestrator.GetEmbeddingGenerators();

        this._log.LogInformation("Handler '{0}' ready, {1} embedding generators", stepName, this._embeddingGenerators.Count);
        if (this._embeddingGenerators.Count < 1)
        {
            this._log.LogError("No embedding generators configured");
        }
    }

    /// <inheritdoc />
    public async Task<(bool success, DataPipeline updatedPipeline)> InvokeAsync(
        DataPipeline pipeline, CancellationToken cancellationToken = default)
    {
        this._log.LogDebug("Generating embeddings, pipeline '{0}/{1}'", pipeline.Index, pipeline.DocumentId);

        foreach (var uploadedFile in pipeline.Files)
        {
            // Track new files being generated (cannot edit originalFile.GeneratedFiles while looping it)
            Dictionary<string, DataPipeline.GeneratedFileDetails> newFiles = new();

            // Throttle the fan-out: at most Environment.ProcessorCount partitions processed concurrently
            var throttler = new SemaphoreSlim(initialCount: Environment.ProcessorCount);

            var tasks = uploadedFile.GeneratedFiles.Select(async generatedFile =>
            {
                await throttler.WaitAsync(cancellationToken).ConfigureAwait(false);
                try
                {
                    var partitionFile = generatedFile.Value;
                    if (partitionFile.AlreadyProcessedBy(this))
                    {
                        this._log.LogTrace("File {0} already processed by this handler", partitionFile.Name);
                        return;
                    }

                    // Calc embeddings only for partitions (text chunks) and synthetic data
                    if (partitionFile.ArtifactType != DataPipeline.ArtifactTypes.TextPartition
                        && partitionFile.ArtifactType != DataPipeline.ArtifactTypes.SyntheticData)
                    {
                        this._log.LogTrace("Skipping file {0} (not a partition, not synthetic data)", partitionFile.Name);
                        return;
                    }

                    // TODO: cost/perf: if the partition SHA256 is the same and the embedding exists, avoid generating it again
                    switch (partitionFile.MimeType)
                    {
                        case MimeTypes.PlainText:
                        case MimeTypes.MarkDown:
                            this._log.LogTrace("Processing file {0}", partitionFile.Name);
                            foreach (ITextEmbeddingGenerator generator in this._embeddingGenerators)
                            {
                                EmbeddingFileContent embeddingData = new()
                                {
                                    SourceFileName = partitionFile.Name
                                };

                                var generatorProviderClassName = generator.GetType().FullName ?? generator.GetType().Name;
                                embeddingData.GeneratorProvider = string.Join('.', generatorProviderClassName.Split('.').TakeLast(3));

                                // TODO: model name
                                embeddingData.GeneratorName = "TODO";

                                this._log.LogTrace("Generating embeddings using {0}, file: {1}", embeddingData.GeneratorProvider, partitionFile.Name);

                                // Check if embeddings have already been generated
                                string embeddingFileName = GetEmbeddingFileName(partitionFile.Name, embeddingData.GeneratorProvider, embeddingData.GeneratorName);

                                // TODO: check if the file exists in storage
                                if (uploadedFile.GeneratedFiles.ContainsKey(embeddingFileName))
                                {
                                    this._log.LogDebug("Embeddings for {0} have already been generated", partitionFile.Name);
                                    continue;
                                }

                                // TODO: handle Azure.RequestFailedException - BlobNotFound
                                string partitionContent = await this._orchestrator.ReadTextFileAsync(pipeline, partitionFile.Name, cancellationToken).ConfigureAwait(false);

                                Embedding embedding = await generator.GenerateEmbeddingAsync(partitionContent, cancellationToken).ConfigureAwait(false);
                                embeddingData.Vector = embedding;
                                embeddingData.VectorSize = embeddingData.Vector.Length;
                                embeddingData.TimeStamp = DateTimeOffset.UtcNow;

                                this._log.LogDebug("Saving embedding file {0}", embeddingFileName);
                                string text = JsonSerializer.Serialize(embeddingData);
                                await this._orchestrator.WriteTextFileAsync(pipeline, embeddingFileName, text, cancellationToken).ConfigureAwait(false);

                                var embeddingFileNameDetails = new DataPipeline.GeneratedFileDetails
                                {
                                    Id = Guid.NewGuid().ToString("N"),
                                    ParentId = uploadedFile.Id,
                                    Name = embeddingFileName,
                                    Size = text.Length,
                                    MimeType = MimeTypes.TextEmbeddingVector,
                                    ArtifactType = DataPipeline.ArtifactTypes.TextEmbeddingVector
                                };
                                embeddingFileNameDetails.MarkProcessedBy(this);

                                // The lambda runs concurrently across files: guard the shared dictionary
                                lock (this._lock)
                                {
                                    newFiles.Add(embeddingFileName, embeddingFileNameDetails);
                                }
                            }

                            break;

                        default:
                            this._log.LogWarning("File {0} cannot be used to generate embedding, type not supported", partitionFile.Name);
                            return;
                    }

                    partitionFile.MarkProcessedBy(this);
                }
                finally
                {
                    // Always release the slot, even when a file is skipped or a call fails,
                    // otherwise the throttler starves and Task.WhenAll never completes
                    throttler.Release();
                }
            });

            await Task.WhenAll(tasks).ConfigureAwait(false);

            // Add new files to pipeline status
            foreach (var file in newFiles)
            {
                uploadedFile.GeneratedFiles.Add(file.Key, file.Value);
            }
        }

        return (true, pipeline);
    }

    private static string GetEmbeddingFileName(string srcFilename, string type, string embeddingName)
    {
        return $"{srcFilename}.{type}.{embeddingName}{FileExtensions.TextEmbeddingVector}";
    }
}
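Usage note: the handler receives its step name and the orchestrator through its constructor, so in a serverless setup it can be instantiated directly and registered with the orchestrator. The sketch below is illustrative and not part of this commit; it assumes the KernelMemoryBuilder/MemoryServerless surface and the orchestrator's AddHandlerAsync registration used elsewhere in the repository, and the "generate_embeddings" step name is a placeholder that must match the pipeline's configured steps.

// Illustrative registration sketch (assumptions: serverless builder surface,
// IPipelineOrchestrator.AddHandlerAsync, placeholder step name and OpenAI config).
using System;
using Microsoft.KernelMemory;
using Microsoft.KernelMemory.Handlers;

var memory = new KernelMemoryBuilder()
    .WithOpenAIDefaults(Environment.GetEnvironmentVariable("OPENAI_API_KEY")!)
    .Build<MemoryServerless>();

// Register the parallel handler for the embedding step; the step name must match
// the steps the pipeline is configured to run for each document.
var handler = new ParallelGenerateEmbeddingsHandler("generate_embeddings", memory.Orchestrator);
await memory.Orchestrator.AddHandlerAsync(handler);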